rslearn 0.0.11__py3-none-any.whl → 0.0.12__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -1,216 +0,0 @@
- # type: ignore
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
-
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- # --------------------------------------------------------
- # Position embedding utils
- # --------------------------------------------------------
-
- import numpy as np
- import torch
-
-
- # --------------------------------------------------------
- # 2D sine-cosine position embedding
- # References:
- # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
- # MoCo v3: https://github.com/facebookresearch/moco-v3
- # --------------------------------------------------------
- def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
-     """grid_size: int of the grid height and width
-     return:
-     pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
-     """
-     grid_h = np.arange(grid_size, dtype=np.float32)
-     grid_w = np.arange(grid_size, dtype=np.float32)
-     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
-     grid = np.stack(grid, axis=0)
-
-     grid = grid.reshape([2, 1, grid_size, grid_size])
-     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
-     if cls_token:
-         pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
-     return pos_embed
-
-
- def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
-     assert embed_dim % 2 == 0
-
-     # use half of dimensions to encode grid_h
-     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
-     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
-
-     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
-     return emb
-
-
- def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
-     """embed_dim: output dimension for each position
-     pos: a list of positions to be encoded: size (M,)
-     out: (M, D)
-     """
-     assert embed_dim % 2 == 0
-     # omega = np.arange(embed_dim // 2, dtype=np.float)  # numpy deprecated in 1.20
-     omega = np.arange(embed_dim // 2, dtype=float)
-
-     omega /= embed_dim / 2.0
-     omega = 1.0 / 10000**omega  # (D/2,)
-
-     pos = pos.reshape(-1)  # (M,)
-     out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
-
-     emb_sin = np.sin(out)  # (M, D/2)
-     emb_cos = np.cos(out)  # (M, D/2)
-
-     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
-     return emb
-
-
- # --------------------------------------------------------
- # Interpolate position embeddings for high-resolution
- # References:
- # DeiT: https://github.com/facebookresearch/deit
- # --------------------------------------------------------
- def interpolate_pos_embed(model, checkpoint_model):
-     if "pos_embed" in checkpoint_model:
-         pos_embed_checkpoint = checkpoint_model["pos_embed"]
-         embedding_size = pos_embed_checkpoint.shape[-1]
-         num_patches = model.patch_embed.num_patches
-         num_extra_tokens = model.pos_embed.shape[-2] - num_patches
-         # height (== width) for the checkpoint position embedding
-         orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
-         # height (== width) for the new position embedding
-         new_size = int(num_patches**0.5)
-         # class_token and dist_token are kept unchanged
-         if orig_size != new_size:
-             print(
-                 "Position interpolate from %dx%d to %dx%d"
-                 % (orig_size, orig_size, new_size, new_size)
-             )
-             extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
-             # only the position tokens are interpolated
-             pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
-             pos_tokens = pos_tokens.reshape(
-                 -1, orig_size, orig_size, embedding_size
-             ).permute(0, 3, 1, 2)
-             pos_tokens = torch.nn.functional.interpolate(
-                 pos_tokens,
-                 size=(new_size, new_size),
-                 mode="bicubic",
-                 align_corners=False,
-             )
-             pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
-             new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
-             checkpoint_model["pos_embed"] = new_pos_embed
-
-
- def interpolate_pos_embed_ofa(model, checkpoint_model):
-     if "pos_embed" in checkpoint_model:
-         pos_embed_dict = checkpoint_model["pos_embed"]
-
-         for key, pos_embed in pos_embed_dict.items():
-             pos_embed_checkpoint = pos_embed
-             embedding_size = pos_embed_checkpoint.shape[-1]
-             num_patches = model.patch_embed[key].num_patches
-             num_extra_tokens = model.pos_embed[key].shape[-2] - num_patches
-             # height (== width) for the checkpoint position embedding
-             orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
-             # height (== width) for the new position embedding
-             new_size = int(num_patches**0.5)
-             # class_token and dist_token are kept unchanged
-             if orig_size != new_size:
-                 print(
-                     "Position interpolate from %dx%d to %dx%d"
-                     % (orig_size, orig_size, new_size, new_size)
-                 )
-                 extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
-                 # only the position tokens are interpolated
-                 pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
-                 pos_tokens = pos_tokens.reshape(
-                     -1, orig_size, orig_size, embedding_size
-                 ).permute(0, 3, 1, 2)
-                 pos_tokens = torch.nn.functional.interpolate(
-                     pos_tokens,
-                     size=(new_size, new_size),
-                     mode="bicubic",
-                     align_corners=False,
-                 )
-                 pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
-                 new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
-                 checkpoint_model["pos_embed"][key] = new_pos_embed
-
-
- def get_2d_sincos_pos_embed_with_resolution(
-     embed_dim, grid_size, res, cls_token=False, device="cpu"
- ):
-     """grid_size: int of the grid height and width
-     res: array of size n, representing the resolution of a pixel (say, in meters),
-
-     Return:
-     pos_embed: [n,grid_size*grid_size, embed_dim] or [n,1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
-     """
-     # res = torch.FloatTensor(res).to(device)
-     res = res.to(device)
-     grid_h = torch.arange(grid_size, dtype=torch.float32, device=device)
-     grid_w = torch.arange(grid_size, dtype=torch.float32, device=device)
-     grid = torch.meshgrid(
-         grid_w, grid_h, indexing="xy"
-     )  # here h goes first,direction reversed for numpy
-     grid = torch.stack(grid, dim=0)  # 2 x h x w
-
-     # grid = grid.reshape([2, 1, grid_size, grid_size])
-     grid = torch.einsum("chw,n->cnhw", grid, res)  # 2 x n x h x w
-     _, n, h, w = grid.shape
-     pos_embed = get_2d_sincos_pos_embed_from_grid_torch(
-         embed_dim, grid
-     )  # # (nxH*W, D/2)
-     pos_embed = pos_embed.reshape(n, h * w, embed_dim)
-     if cls_token:
-         pos_embed = torch.cat(
-             [
-                 torch.zeros(
-                     [n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device
-                 ),
-                 pos_embed,
-             ],
-             dim=1,
-         )
-     return pos_embed
-
-
- def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
-     assert embed_dim % 2 == 0
-
-     # use half of dimensions to encode grid_h
-     emb_h = get_1d_sincos_pos_embed_from_grid_torch(
-         embed_dim // 2, grid[0]
-     )  # (H*W, D/2)
-     emb_w = get_1d_sincos_pos_embed_from_grid_torch(
-         embed_dim // 2, grid[1]
-     )  # (H*W, D/2)
-
-     emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D)
-     return emb
-
-
- def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
-     """embed_dim: output dimension for each position
-     pos: a list of positions to be encoded: size (M,)
-     out: (M, D)
-     """
-     assert embed_dim % 2 == 0
-     old_shape = pos
-     omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
-     omega /= embed_dim / 2.0
-     omega = 1.0 / 10000**omega  # (D/2,)
-
-     pos = pos.reshape(-1)  # (M,)
-     out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
-
-     emb_sin = torch.sin(out)  # (M, D/2)
-     emb_cos = torch.cos(out)  # (M, D/2)
-
-     emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
-     return emb
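
For reference, the removed module provided the sine-cosine position-embedding helpers used by ViT-style encoders, plus utilities for interpolating a checkpoint's position embeddings to a new grid size. The sketch below is not part of either wheel and does not reflect rslearn's own API; it only illustrates how the two main entry points are typically called, with the import path assumed for illustration.

# Usage sketch; the module name `pos_embed` is an assumption for illustration only.
import torch
from pos_embed import get_2d_sincos_pos_embed, get_2d_sincos_pos_embed_with_resolution

# Fixed-grid variant: embeddings for a 14x14 patch grid plus a zero cls-token row.
pos = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
print(pos.shape)  # (197, 768), i.e. (1 + 14*14, 768)

# Resolution-aware variant: the grid is scaled per sample by its pixel resolution
# (e.g. meters per pixel), giving one embedding table per sample.
res = torch.tensor([10.0, 20.0])  # hypothetical resolutions for two samples
pos_res = get_2d_sincos_pos_embed_with_resolution(768, 14, res, cls_token=True)
print(pos_res.shape)  # torch.Size([2, 197, 768])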