rslearn-0.0.11-py3-none-any.whl → rslearn-0.0.12-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in the public registry. It is provided for informational purposes only.
- rslearn/models/anysat.py +5 -1
- rslearn/models/dinov3.py +6 -1
- rslearn/models/feature_center_crop.py +50 -0
- rslearn/models/olmoearth_pretrain/model.py +87 -27
- rslearn/models/prithvi.py +9 -1
- rslearn/train/lightning_module.py +0 -3
- rslearn/train/tasks/classification.py +2 -2
- rslearn/train/tasks/detection.py +5 -5
- rslearn/train/tasks/per_pixel_regression.py +5 -4
- rslearn/train/tasks/regression.py +5 -5
- rslearn/train/transforms/pad.py +3 -3
- {rslearn-0.0.11.dist-info → rslearn-0.0.12.dist-info}/METADATA +2 -1
- {rslearn-0.0.11.dist-info → rslearn-0.0.12.dist-info}/RECORD +18 -25
- rslearn-0.0.12.dist-info/licenses/NOTICE +115 -0
- rslearn/models/copernicusfm.py +0 -228
- rslearn/models/copernicusfm_src/__init__.py +0 -1
- rslearn/models/copernicusfm_src/aurora/area.py +0 -50
- rslearn/models/copernicusfm_src/aurora/fourier.py +0 -134
- rslearn/models/copernicusfm_src/dynamic_hypernetwork.py +0 -523
- rslearn/models/copernicusfm_src/flexivit/patch_embed.py +0 -260
- rslearn/models/copernicusfm_src/flexivit/utils.py +0 -69
- rslearn/models/copernicusfm_src/model_vit.py +0 -348
- rslearn/models/copernicusfm_src/util/pos_embed.py +0 -216
- {rslearn-0.0.11.dist-info → rslearn-0.0.12.dist-info}/WHEEL +0 -0
- {rslearn-0.0.11.dist-info → rslearn-0.0.12.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.11.dist-info → rslearn-0.0.12.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.11.dist-info → rslearn-0.0.12.dist-info}/top_level.txt +0 -0
The only file-level diff included here is the deletion of rslearn/models/copernicusfm_src/util/pos_embed.py, removed in 0.0.12 along with the rest of the CopernicusFM sources listed above:

```diff
--- a/rslearn/models/copernicusfm_src/util/pos_embed.py
+++ /dev/null
@@ -1,216 +0,0 @@
-# type: ignore
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# Position embedding utils
-# --------------------------------------------------------
-
-import numpy as np
-import torch
-
-
-# --------------------------------------------------------
-# 2D sine-cosine position embedding
-# References:
-# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
-# MoCo v3: https://github.com/facebookresearch/moco-v3
-# --------------------------------------------------------
-def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
-    """grid_size: int of the grid height and width
-    return:
-    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
-    """
-    grid_h = np.arange(grid_size, dtype=np.float32)
-    grid_w = np.arange(grid_size, dtype=np.float32)
-    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
-    grid = np.stack(grid, axis=0)
-
-    grid = grid.reshape([2, 1, grid_size, grid_size])
-    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
-    if cls_token:
-        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
-    return pos_embed
-
-
-def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
-    assert embed_dim % 2 == 0
-
-    # use half of dimensions to encode grid_h
-    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
-    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
-
-    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
-    return emb
-
-
-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
-    """embed_dim: output dimension for each position
-    pos: a list of positions to be encoded: size (M,)
-    out: (M, D)
-    """
-    assert embed_dim % 2 == 0
-    # omega = np.arange(embed_dim // 2, dtype=np.float)  # numpy deprecated in 1.20
-    omega = np.arange(embed_dim // 2, dtype=float)
-
-    omega /= embed_dim / 2.0
-    omega = 1.0 / 10000**omega  # (D/2,)
-
-    pos = pos.reshape(-1)  # (M,)
-    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
-
-    emb_sin = np.sin(out)  # (M, D/2)
-    emb_cos = np.cos(out)  # (M, D/2)
-
-    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
-    return emb
-
-
-# --------------------------------------------------------
-# Interpolate position embeddings for high-resolution
-# References:
-# DeiT: https://github.com/facebookresearch/deit
-# --------------------------------------------------------
-def interpolate_pos_embed(model, checkpoint_model):
-    if "pos_embed" in checkpoint_model:
-        pos_embed_checkpoint = checkpoint_model["pos_embed"]
-        embedding_size = pos_embed_checkpoint.shape[-1]
-        num_patches = model.patch_embed.num_patches
-        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
-        # height (== width) for the checkpoint position embedding
-        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
-        # height (== width) for the new position embedding
-        new_size = int(num_patches**0.5)
-        # class_token and dist_token are kept unchanged
-        if orig_size != new_size:
-            print(
-                "Position interpolate from %dx%d to %dx%d"
-                % (orig_size, orig_size, new_size, new_size)
-            )
-            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
-            # only the position tokens are interpolated
-            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
-            pos_tokens = pos_tokens.reshape(
-                -1, orig_size, orig_size, embedding_size
-            ).permute(0, 3, 1, 2)
-            pos_tokens = torch.nn.functional.interpolate(
-                pos_tokens,
-                size=(new_size, new_size),
-                mode="bicubic",
-                align_corners=False,
-            )
-            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
-            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
-            checkpoint_model["pos_embed"] = new_pos_embed
-
-
-def interpolate_pos_embed_ofa(model, checkpoint_model):
-    if "pos_embed" in checkpoint_model:
-        pos_embed_dict = checkpoint_model["pos_embed"]
-
-        for key, pos_embed in pos_embed_dict.items():
-            pos_embed_checkpoint = pos_embed
-            embedding_size = pos_embed_checkpoint.shape[-1]
-            num_patches = model.patch_embed[key].num_patches
-            num_extra_tokens = model.pos_embed[key].shape[-2] - num_patches
-            # height (== width) for the checkpoint position embedding
-            orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
-            # height (== width) for the new position embedding
-            new_size = int(num_patches**0.5)
-            # class_token and dist_token are kept unchanged
-            if orig_size != new_size:
-                print(
-                    "Position interpolate from %dx%d to %dx%d"
-                    % (orig_size, orig_size, new_size, new_size)
-                )
-                extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
-                # only the position tokens are interpolated
-                pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
-                pos_tokens = pos_tokens.reshape(
-                    -1, orig_size, orig_size, embedding_size
-                ).permute(0, 3, 1, 2)
-                pos_tokens = torch.nn.functional.interpolate(
-                    pos_tokens,
-                    size=(new_size, new_size),
-                    mode="bicubic",
-                    align_corners=False,
-                )
-                pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
-                new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
-                checkpoint_model["pos_embed"][key] = new_pos_embed
-
-
-def get_2d_sincos_pos_embed_with_resolution(
-    embed_dim, grid_size, res, cls_token=False, device="cpu"
-):
-    """grid_size: int of the grid height and width
-    res: array of size n, representing the resolution of a pixel (say, in meters),
-
-    Return:
-    pos_embed: [n,grid_size*grid_size, embed_dim] or [n,1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
-    """
-    # res = torch.FloatTensor(res).to(device)
-    res = res.to(device)
-    grid_h = torch.arange(grid_size, dtype=torch.float32, device=device)
-    grid_w = torch.arange(grid_size, dtype=torch.float32, device=device)
-    grid = torch.meshgrid(
-        grid_w, grid_h, indexing="xy"
-    )  # here h goes first,direction reversed for numpy
-    grid = torch.stack(grid, dim=0)  # 2 x h x w
-
-    # grid = grid.reshape([2, 1, grid_size, grid_size])
-    grid = torch.einsum("chw,n->cnhw", grid, res)  # 2 x n x h x w
-    _, n, h, w = grid.shape
-    pos_embed = get_2d_sincos_pos_embed_from_grid_torch(
-        embed_dim, grid
-    )  # # (nxH*W, D/2)
-    pos_embed = pos_embed.reshape(n, h * w, embed_dim)
-    if cls_token:
-        pos_embed = torch.cat(
-            [
-                torch.zeros(
-                    [n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device
-                ),
-                pos_embed,
-            ],
-            dim=1,
-        )
-    return pos_embed
-
-
-def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
-    assert embed_dim % 2 == 0
-
-    # use half of dimensions to encode grid_h
-    emb_h = get_1d_sincos_pos_embed_from_grid_torch(
-        embed_dim // 2, grid[0]
-    )  # (H*W, D/2)
-    emb_w = get_1d_sincos_pos_embed_from_grid_torch(
-        embed_dim // 2, grid[1]
-    )  # (H*W, D/2)
-
-    emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D)
-    return emb
-
-
-def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
-    """embed_dim: output dimension for each position
-    pos: a list of positions to be encoded: size (M,)
-    out: (M, D)
-    """
-    assert embed_dim % 2 == 0
-    old_shape = pos
-    omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
-    omega /= embed_dim / 2.0
-    omega = 1.0 / 10000**omega  # (D/2,)
-
-    pos = pos.reshape(-1)  # (M,)
-    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
-
-    emb_sin = torch.sin(out)  # (M, D/2)
-    emb_cos = torch.cos(out)  # (M, D/2)
-
-    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
-    return emb
```

The four dist-info files listed above with +0 -0 (WHEEL, entry_points.txt, licenses/LICENSE, top_level.txt) have no content changes.
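For context, the deleted pos_embed.py implements the standard MAE-style fixed 2D sine-cosine positional embedding: each spatial axis gets embed_dim/2 channels, split evenly between sine and cosine terms at frequencies ω_k = 10000^(−2k/d) with d = embed_dim/2, evaluated at the (optionally resolution-scaled) grid coordinates. Below is a minimal, hedged usage sketch, not rslearn's documented API: the grid size, the per-sample resolutions, and the frozen-parameter wrapping are illustrative assumptions, and the import path reflects the 0.0.11 layout of the module deleted above, which no longer exists in 0.0.12.

```python
# Hedged sketch: how the removed sine-cosine helpers are typically used to build
# a fixed (non-learned) positional embedding for a ViT-style patch grid.
import torch

# 0.0.11 location of the deleted module; removed in 0.0.12.
from rslearn.models.copernicusfm_src.util.pos_embed import (
    get_2d_sincos_pos_embed,
    get_2d_sincos_pos_embed_with_resolution,
)

embed_dim, grid_size = 768, 14  # e.g. ViT-B/16 on 224x224 inputs -> 14x14 patches

# NumPy array of shape (1 + grid_size * grid_size, embed_dim); row 0 is the
# all-zero slot reserved for the cls token.
pos_embed = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=True)
assert pos_embed.shape == (1 + grid_size * grid_size, embed_dim)

# Usually loaded into the model as a frozen parameter (the MAE recipe).
pe = torch.nn.Parameter(
    torch.from_numpy(pos_embed).float().unsqueeze(0), requires_grad=False
)
print(pe.shape)  # torch.Size([1, 197, 768])

# The resolution-aware torch variant scales the grid by a per-sample ground
# resolution (e.g. meters per pixel) before computing the embedding.
res = torch.tensor([10.0, 20.0])  # hypothetical per-sample resolutions
pos_embed_res = get_2d_sincos_pos_embed_with_resolution(
    embed_dim, grid_size, res, cls_token=False
)
assert pos_embed_res.shape == (len(res), grid_size * grid_size, embed_dim)
```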