olmoearth-pretrain-minimal 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- olmoearth_pretrain_minimal/__init__.py +16 -0
- olmoearth_pretrain_minimal/model_loader.py +123 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/__init__.py +6 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/__init__.py +1 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/attention.py +559 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/encodings.py +115 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_patch_embed.py +304 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_vit.py +2219 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/latent_mim.py +166 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/tokenization.py +194 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/utils.py +83 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/olmoearth_pretrain_v1.py +152 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/__init__.py +2 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/config.py +264 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/constants.py +519 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/datatypes.py +165 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/decorators.py +75 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/types.py +8 -0
- olmoearth_pretrain_minimal/test.py +51 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/METADATA +326 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/RECORD +24 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/WHEEL +5 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/licenses/LICENSE +204 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,304 @@
|
|
|
1
|
+
"""Flexible patch embedding Module.
|
|
2
|
+
|
|
3
|
+
Extended from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/patch_embed.py#L24
|
|
4
|
+
by https://github.com/bwconrad/flexivit/
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
from collections.abc import Iterable
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
import torch.nn as nn
|
|
12
|
+
import torch.nn.functional as F
|
|
13
|
+
from einops import rearrange
|
|
14
|
+
from torch import Tensor
|
|
15
|
+
|
|
16
|
+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.constants import ModalitySpec
|
|
17
|
+
|
|
18
|
+
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class FlexiPatchEmbed(nn.Module):
    """Flexible patch embedding nn.Module.

    Projects an image-like tensor into a grid of patch embeddings using a
    ``Conv2d`` whose kernel and stride equal the module's base patch size.
    A different patch size may be requested per call. The input is then
    resized so that the base-size convolution yields the same patch grid the
    requested patch size would have produced.
    """

    def __init__(
        self,
        modality_spec: ModalitySpec,
        patch_size_at_16: int | tuple[int, int],
        in_chans: int = 3,
        embedding_size: int = 128,
        norm_layer: nn.Module | None = None,
        bias: bool = True,
        interpolation: str = "bicubic",
        antialias: bool = True,
    ) -> None:
        """2D image to patch embedding w/ flexible patch sizes.

        Extended from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/patch_embed.py#L24
        by https://github.com/bwconrad/flexivit/

        Args:
            modality_spec: The modality spec for this modality. Its
                ``image_tile_size_factor`` multiplies every patch size
                component.
            patch_size_at_16: Base patch size, i.e. the size of the parameter
                buffer at a resolution of 16. Given as an int (square patch)
                or an ``(h, w)`` pair; each component is multiplied by
                ``modality_spec.image_tile_size_factor``.
            in_chans: Number of input image channels.
            embedding_size: Network embedding dimension size (conv output
                channels).
            norm_layer: Optional normalization layer factory. It is called as
                ``norm_layer(embedding_size)`` and applied to the channel-last
                output. ``None`` means no normalization.
            bias: Whether to use bias in the convolution.
            interpolation: ``mode`` passed to ``F.interpolate`` when the
                requested patch size differs from the base patch size.
            antialias: ``antialias`` flag passed to ``F.interpolate`` when
                resizing the input.
        """
        super().__init__()

        self.embedding_size = embedding_size

        self.modality_spec = modality_spec
        # Scale each component separately. Multiplying a tuple by an int would
        # repeat it (e.g. (8, 8) * 2 == (8, 8, 8, 8)) instead of scaling it.
        factor = modality_spec.image_tile_size_factor
        self.patch_size = tuple(p * factor for p in self.to_2tuple(patch_size_at_16))

        self.proj = nn.Conv2d(
            in_chans,
            embedding_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=bias,
        )
        self.norm = norm_layer(embedding_size) if norm_layer else nn.Identity()

        # Flexi specific attributes
        self.interpolation = interpolation
        self.antialias = antialias

    @staticmethod
    def to_2tuple(x: Any) -> Any:
        """Convert a value to a 2-tuple by either converting an iterable or repeating a scalar.

        This is used to handle patch sizes that can be specified either as:
        - A single integer (e.g. 16) which gets converted to (16, 16) for square patches
        - A tuple/list of 2 integers (e.g. (16, 32)) for rectangular patches

        Args:
            x: Value to convert to a 2-tuple. Can be an iterable (list/tuple,
                or a one-shot iterator) of 2 elements, or a single value to
                repeat twice.

        Returns:
            A 2-tuple containing either the iterable's values or the input repeated twice.

        Raises:
            AssertionError: If ``x`` is a non-string iterable whose length is not 2.
        """
        if isinstance(x, Iterable) and not isinstance(x, str):
            # Materialize exactly once. Checking len(list(x)) and then calling
            # tuple(x) would exhaust one-shot iterators and return ().
            values = tuple(x)
            assert len(values) == 2, "x must be a 2-tuple"
            return values
        return (x, x)

    def forward(
        self,
        x: Tensor,
        patch_size: int | tuple[int, int] | None = None,
    ) -> Tensor:
        """Embed ``x`` into a grid of patch tokens.

        Args:
            x: Input tensor with shape [b, h, w, (t), c]. A 5-D input is
                treated as having a time axis at dim 3.
            patch_size: Patch size to use for the embedding, as an int or an
                ``(h, w)`` pair, *before* scaling by
                ``modality_spec.image_tile_size_factor``. If None (or any
                falsy value), the module's base patch size is used as-is.

        Returns:
            Embeddings of shape [b, h // p_h, w // p_w, (t), embedding_size],
            where (p_h, p_w) is the effective (scaled) patch size, after
            ``self.norm`` has been applied.
        """
        batch_size = x.shape[0]
        has_time_dimension = len(x.shape) == 5
        # Only meaningful when has_time_dimension is True.
        num_timesteps = x.shape[3] if has_time_dimension else 0

        if has_time_dimension:
            # Fold time into the batch so one 2D conv handles every timestep.
            x = rearrange(x, "b h w t c -> (b t) c h w")
        else:
            x = rearrange(x, "b h w c -> b c h w")

        if not patch_size:
            # No override requested: use the (already scaled) base patch size.
            patch_size = self.patch_size
        else:
            # Normalize to a pair first so ints, tuples and lists are all
            # scaled per component.
            factor = self.modality_spec.image_tile_size_factor
            patch_h, patch_w = self.to_2tuple(patch_size)
            patch_size = (patch_h * factor, patch_w * factor)
        assert isinstance(patch_size, tuple) and len(patch_size) == 2, (
            "patch_size must be a 2-tuple"
        )

        if patch_size != self.patch_size:
            # Resize the *input* (the conv weights are untouched). After
            # resizing, tiling with the base kernel produces floor(H / p_h) x
            # floor(W / p_w) patches, the same grid the requested size gives.
            height, width = x.shape[-2:]
            new_shape = (
                height // patch_size[0] * self.patch_size[0],
                width // patch_size[1] * self.patch_size[1],
            )
            x = F.interpolate(
                x,
                size=new_shape,
                mode=self.interpolation,
                antialias=self.antialias,
            )

        x = self.proj(x)
        # x now has embedding_size channels.
        if has_time_dimension:
            _, d, h, w = x.shape
            x = rearrange(
                x,
                "(b t) d h w -> b h w t d",
                b=batch_size,
                t=num_timesteps,
                d=d,
                h=h,
                w=w,
            )
        else:
            x = rearrange(x, "b d h w -> b h w d")

        x = self.norm(x)

        return x
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
class FlexiPatchReconstruction(nn.Module):
|
|
167
|
+
"""Flexible patch reconstruction nn.Module."""
|
|
168
|
+
|
|
169
|
+
def __init__(
|
|
170
|
+
self,
|
|
171
|
+
max_patch_size: int | tuple[int, int],
|
|
172
|
+
out_chans: int = 3,
|
|
173
|
+
embedding_size: int = 128,
|
|
174
|
+
norm_layer: nn.Module | None = None,
|
|
175
|
+
bias: bool = True,
|
|
176
|
+
interpolation: str = "bicubic",
|
|
177
|
+
antialias: bool = True,
|
|
178
|
+
) -> None:
|
|
179
|
+
"""Patch embeding to 2d image reconstruction w/ flexible patch sizes.
|
|
180
|
+
|
|
181
|
+
Args:
|
|
182
|
+
max_patch_size: Base patch size. i.e the size of the parameter buffer
|
|
183
|
+
out_chans: Number of out image channels
|
|
184
|
+
embedding_size: Network embedding dimension size
|
|
185
|
+
norm_layer: Optional normalization layer
|
|
186
|
+
bias: Whether to use bias in convolution
|
|
187
|
+
interpolation: Resize interpolation type
|
|
188
|
+
antialias: Whether to apply antialiasing resizing
|
|
189
|
+
"""
|
|
190
|
+
super().__init__()
|
|
191
|
+
|
|
192
|
+
self.embedding_size = embedding_size
|
|
193
|
+
|
|
194
|
+
self.max_patch_size = self.to_2tuple(max_patch_size)
|
|
195
|
+
|
|
196
|
+
self.proj = nn.ConvTranspose2d(
|
|
197
|
+
embedding_size,
|
|
198
|
+
out_chans,
|
|
199
|
+
kernel_size=max_patch_size,
|
|
200
|
+
stride=max_patch_size,
|
|
201
|
+
bias=bias,
|
|
202
|
+
)
|
|
203
|
+
self.norm = norm_layer(embedding_size) if norm_layer else nn.Identity()
|
|
204
|
+
|
|
205
|
+
# Flexi specific attributes
|
|
206
|
+
self.interpolation = interpolation
|
|
207
|
+
self.antialias = antialias
|
|
208
|
+
|
|
209
|
+
@staticmethod
|
|
210
|
+
def to_2tuple(x: Any) -> Any:
|
|
211
|
+
"""Convert a value to a 2-tuple by either converting an iterable or repeating a scalar.
|
|
212
|
+
|
|
213
|
+
This is used to handle patch sizes that can be specified either as:
|
|
214
|
+
- A single integer (e.g. 16) which gets converted to (16, 16) for square patches
|
|
215
|
+
- A tuple/list of 2 integers (e.g. (16, 32)) for rectangular patches
|
|
216
|
+
|
|
217
|
+
Args:
|
|
218
|
+
x: Value to convert to a 2-tuple. Can be an iterable (list/tuple) of 2 elements,
|
|
219
|
+
or a single value to repeat twice.
|
|
220
|
+
|
|
221
|
+
Returns:
|
|
222
|
+
A 2-tuple containing either the original iterable values or the input repeated twice.
|
|
223
|
+
"""
|
|
224
|
+
if isinstance(x, Iterable) and not isinstance(x, str):
|
|
225
|
+
assert len(list(x)) == 2, "x must be a 2-tuple"
|
|
226
|
+
return tuple(x)
|
|
227
|
+
return (x, x)
|
|
228
|
+
|
|
229
|
+
def _resize(self, x: Tensor, shape: tuple[int, int]) -> Tensor:
|
|
230
|
+
"""Resize the input tensor to the target shape.
|
|
231
|
+
|
|
232
|
+
Args:
|
|
233
|
+
x: Input tensor
|
|
234
|
+
shape: Target shape
|
|
235
|
+
|
|
236
|
+
Returns:
|
|
237
|
+
Resized tensor
|
|
238
|
+
"""
|
|
239
|
+
x_resized = F.interpolate(
|
|
240
|
+
x[None, None, ...],
|
|
241
|
+
shape,
|
|
242
|
+
mode=self.interpolation,
|
|
243
|
+
antialias=self.antialias,
|
|
244
|
+
)
|
|
245
|
+
return x_resized[0, 0, ...]
|
|
246
|
+
|
|
247
|
+
def forward(
|
|
248
|
+
self,
|
|
249
|
+
x: Tensor,
|
|
250
|
+
patch_size: int | tuple[int, int] | None = None,
|
|
251
|
+
) -> Tensor | tuple[Tensor, tuple[int, int]]:
|
|
252
|
+
"""Forward pass for the FlexiPatchReconstruction module.
|
|
253
|
+
|
|
254
|
+
Args:
|
|
255
|
+
x: Input tensor with shape [b, h, w, (t), d]
|
|
256
|
+
patch_size: Patch size to use for the reconstruction. If None, the base patch size
|
|
257
|
+
will be used.
|
|
258
|
+
"""
|
|
259
|
+
# x has input shape [b, h, w, (t), d]
|
|
260
|
+
if len(x.shape) == 4:
|
|
261
|
+
has_time_dimension = False
|
|
262
|
+
b, h, w, d = x.shape
|
|
263
|
+
t = 1
|
|
264
|
+
else:
|
|
265
|
+
has_time_dimension = True
|
|
266
|
+
b, h, w, t, d = x.shape
|
|
267
|
+
|
|
268
|
+
if not patch_size:
|
|
269
|
+
# During evaluation use base patch size if not specified
|
|
270
|
+
patch_size = self.max_patch_size
|
|
271
|
+
|
|
272
|
+
patch_size = self.to_2tuple(patch_size)
|
|
273
|
+
|
|
274
|
+
if has_time_dimension:
|
|
275
|
+
x = rearrange(x, "b h w t d -> (b t) d h w", b=b, t=t)
|
|
276
|
+
else:
|
|
277
|
+
x = rearrange(x, "b h w d -> b d h w")
|
|
278
|
+
|
|
279
|
+
x = self.proj(x)
|
|
280
|
+
|
|
281
|
+
if patch_size != self.max_patch_size:
|
|
282
|
+
x = rearrange(
|
|
283
|
+
x,
|
|
284
|
+
"b c (h p_h) (w p_w) -> b h w c p_h p_w",
|
|
285
|
+
p_h=self.max_patch_size[0],
|
|
286
|
+
p_w=self.max_patch_size[1],
|
|
287
|
+
)
|
|
288
|
+
bl, hl, wl, cl = x.shape[:4]
|
|
289
|
+
x = rearrange(x, "b h w c p_h p_w -> (b h w) c p_h p_w")
|
|
290
|
+
x = F.interpolate(
|
|
291
|
+
x, patch_size, mode=self.interpolation, antialias=self.antialias
|
|
292
|
+
)
|
|
293
|
+
x = rearrange(
|
|
294
|
+
x, "(b h w) c p_h p_w -> b c (h p_h) (w p_w)", b=bl, h=hl, w=wl
|
|
295
|
+
)
|
|
296
|
+
|
|
297
|
+
if has_time_dimension:
|
|
298
|
+
x = rearrange(x, "(b t) c h w -> b h w t c", b=b, t=t)
|
|
299
|
+
else:
|
|
300
|
+
x = rearrange(x, "b c h w -> b h w c")
|
|
301
|
+
|
|
302
|
+
x = self.norm(x)
|
|
303
|
+
|
|
304
|
+
return x
|