olmoearth_pretrain_minimal-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. olmoearth_pretrain_minimal/__init__.py +16 -0
  2. olmoearth_pretrain_minimal/model_loader.py +123 -0
  3. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/__init__.py +6 -0
  4. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/__init__.py +1 -0
  5. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/attention.py +559 -0
  6. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/encodings.py +115 -0
  7. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_patch_embed.py +304 -0
  8. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_vit.py +2219 -0
  9. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/latent_mim.py +166 -0
  10. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/tokenization.py +194 -0
  11. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/utils.py +83 -0
  12. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/olmoearth_pretrain_v1.py +152 -0
  13. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/__init__.py +2 -0
  14. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/config.py +264 -0
  15. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/constants.py +519 -0
  16. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/datatypes.py +165 -0
  17. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/decorators.py +75 -0
  18. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/types.py +8 -0
  19. olmoearth_pretrain_minimal/test.py +51 -0
  20. olmoearth_pretrain_minimal-0.0.1.dist-info/METADATA +326 -0
  21. olmoearth_pretrain_minimal-0.0.1.dist-info/RECORD +24 -0
  22. olmoearth_pretrain_minimal-0.0.1.dist-info/WHEEL +5 -0
  23. olmoearth_pretrain_minimal-0.0.1.dist-info/licenses/LICENSE +204 -0
  24. olmoearth_pretrain_minimal-0.0.1.dist-info/top_level.txt +1 -0
olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_patch_embed.py
@@ -0,0 +1,304 @@
"""Flexible patch embedding Module.

Extended from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/patch_embed.py#L24
by https://github.com/bwconrad/flexivit/
"""

import logging
from collections.abc import Iterable
from typing import Any

import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor

from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.constants import ModalitySpec

logger = logging.getLogger(__name__)


class FlexiPatchEmbed(nn.Module):
    """Flexible patch embedding nn.Module."""

    def __init__(
        self,
        modality_spec: ModalitySpec,
        patch_size_at_16: int | tuple[int, int],
        in_chans: int = 3,
        embedding_size: int = 128,
        norm_layer: nn.Module | None = None,
        bias: bool = True,
        interpolation: str = "bicubic",
        antialias: bool = True,
    ) -> None:
        """2D image to patch embedding w/ flexible patch sizes.

        Extended from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/patch_embed.py#L24
        by https://github.com/bwconrad/flexivit/

        Args:
            modality_spec: The modality spec for this modality
            patch_size_at_16: Base patch size, i.e. the size of the parameter buffer at a resolution of 16
            in_chans: Number of input image channels
            embedding_size: Network embedding dimension size
            norm_layer: Optional normalization layer
            bias: Whether to use bias in the convolution
            interpolation: Resize interpolation type
            antialias: Whether to apply antialiasing when resizing (TODO: Add a link or more info)
        """
        super().__init__()

        self.embedding_size = embedding_size

        self.modality_spec = modality_spec
        # Effective patch size: the base patch size scaled by this modality's
        # image_tile_size_factor (assumes a scalar patch_size_at_16).
        self.patch_size = self.to_2tuple(
            patch_size_at_16 * modality_spec.image_tile_size_factor
        )

        self.proj = nn.Conv2d(
            in_chans,
            embedding_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=bias,
        )
        self.norm = norm_layer(embedding_size) if norm_layer else nn.Identity()

        # Flexi specific attributes
        self.interpolation = interpolation
        self.antialias = antialias

    @staticmethod
    def to_2tuple(x: Any) -> Any:
        """Convert a value to a 2-tuple by either converting an iterable or repeating a scalar.

        This is used to handle patch sizes that can be specified either as:
        - A single integer (e.g. 16) which gets converted to (16, 16) for square patches
        - A tuple/list of 2 integers (e.g. (16, 32)) for rectangular patches

        Args:
            x: Value to convert to a 2-tuple. Can be an iterable (list/tuple) of 2 elements,
                or a single value to repeat twice.

        Returns:
            A 2-tuple containing either the original iterable values or the input repeated twice.
        """
        if isinstance(x, Iterable) and not isinstance(x, str):
            assert len(list(x)) == 2, "x must be a 2-tuple"
            return tuple(x)
        return (x, x)

    def forward(
        self,
        x: Tensor,
        patch_size: int | tuple[int, int] | None = None,
    ) -> Tensor | tuple[Tensor, tuple[int, int]]:
        """Forward pass for the FlexiPatchEmbed module.

        Args:
            x: Input tensor with shape [b, h, w, (t), c]
            patch_size: Patch size to use for the embedding, given at the base resolution
                of 16; it is scaled by the modality's image_tile_size_factor. If None,
                the base patch size is used.

        Returns:
            Patch embeddings with shape [b, h / p, w / p, (t), embedding_size].
        """
        # x has input shape [b, h, w, (t), c]
        batch_size = x.shape[0]
        has_time_dimension = False
        num_timesteps = 0  # ignored if has_time_dimension is False

        if len(x.shape) == 5:
            has_time_dimension = True
            num_timesteps = x.shape[3]
            x = rearrange(x, "b h w t c -> (b t) c h w")
        else:
            x = rearrange(x, "b h w c -> b c h w")

        if not patch_size:
            # During evaluation use base patch size if not specified
            patch_size = self.patch_size
        else:
            if isinstance(patch_size, tuple):
                patch_size = (
                    patch_size[0] * self.modality_spec.image_tile_size_factor,
                    patch_size[1] * self.modality_spec.image_tile_size_factor,
                )
            else:
                patch_size = patch_size * self.modality_spec.image_tile_size_factor
            patch_size = self.to_2tuple(patch_size)
        assert isinstance(patch_size, tuple) and len(patch_size) == 2, (
            "patch_size must be a 2-tuple"
        )
        # Resize the input rather than the conv weights: scale H and W so that the
        # base-size kernel yields the patch grid the requested patch_size would give
        # on the original input (e.g. h=64, requested 8, base 16 -> 64 // 8 * 16 = 128).
        if patch_size != self.patch_size:
            shape = x.shape[-2:]
            new_shape = (
                shape[0] // patch_size[0] * self.patch_size[0],
                shape[1] // patch_size[1] * self.patch_size[1],
            )
            x = F.interpolate(
                x,
                size=new_shape,
                mode=self.interpolation,
                antialias=self.antialias,
            )
        # Apply the conv; its kernel matches the base patch size, so the resize above
        # produces the requested number of patches
        x = self.proj(x)
        # At this point x has embedding dim sized channel dimension
        if has_time_dimension:
            _, d, h, w = x.shape
            x = rearrange(
                x,
                "(b t) d h w -> b h w t d",
                b=batch_size,
                t=num_timesteps,
                d=d,
                h=h,
                w=w,
            )
        else:
            x = rearrange(x, "b d h w -> b h w d")

        x = self.norm(x)

        return x
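
To make the flexible patch size behaviour concrete, here is a minimal usage sketch (not shipped in the wheel). The only attribute FlexiPatchEmbed reads from its modality spec is image_tile_size_factor, so a stand-in object with that single attribute is enough; the shapes in the comments follow from the forward pass above.

from types import SimpleNamespace

import torch

spec = SimpleNamespace(image_tile_size_factor=1)  # stand-in spec, hypothetical value
embed = FlexiPatchEmbed(modality_spec=spec, patch_size_at_16=16, in_chans=3, embedding_size=128)

images = torch.randn(2, 64, 64, 3)        # [b, h, w, c]
print(embed(images).shape)                # torch.Size([2, 4, 4, 128]); 16x16 patches
print(embed(images, patch_size=8).shape)  # torch.Size([2, 8, 8, 128]); input resized so the base kernel yields an 8x8 grid

timeseries = torch.randn(2, 64, 64, 5, 3)  # [b, h, w, t, c]
print(embed(timeseries).shape)             # torch.Size([2, 4, 4, 5, 128])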


class FlexiPatchReconstruction(nn.Module):
    """Flexible patch reconstruction nn.Module."""

    def __init__(
        self,
        max_patch_size: int | tuple[int, int],
        out_chans: int = 3,
        embedding_size: int = 128,
        norm_layer: nn.Module | None = None,
        bias: bool = True,
        interpolation: str = "bicubic",
        antialias: bool = True,
    ) -> None:
        """Patch embedding to 2D image reconstruction w/ flexible patch sizes.

        Args:
            max_patch_size: Base patch size, i.e. the size of the parameter buffer
            out_chans: Number of output image channels
            embedding_size: Network embedding dimension size
            norm_layer: Optional normalization layer
            bias: Whether to use bias in the convolution
            interpolation: Resize interpolation type
            antialias: Whether to apply antialiasing when resizing
        """
        super().__init__()

        self.embedding_size = embedding_size

        self.max_patch_size = self.to_2tuple(max_patch_size)

        self.proj = nn.ConvTranspose2d(
            embedding_size,
            out_chans,
            kernel_size=max_patch_size,
            stride=max_patch_size,
            bias=bias,
        )
        self.norm = norm_layer(embedding_size) if norm_layer else nn.Identity()

        # Flexi specific attributes
        self.interpolation = interpolation
        self.antialias = antialias

    @staticmethod
    def to_2tuple(x: Any) -> Any:
        """Convert a value to a 2-tuple by either converting an iterable or repeating a scalar.

        This is used to handle patch sizes that can be specified either as:
        - A single integer (e.g. 16) which gets converted to (16, 16) for square patches
        - A tuple/list of 2 integers (e.g. (16, 32)) for rectangular patches

        Args:
            x: Value to convert to a 2-tuple. Can be an iterable (list/tuple) of 2 elements,
                or a single value to repeat twice.

        Returns:
            A 2-tuple containing either the original iterable values or the input repeated twice.
        """
        if isinstance(x, Iterable) and not isinstance(x, str):
            assert len(list(x)) == 2, "x must be a 2-tuple"
            return tuple(x)
        return (x, x)

    def _resize(self, x: Tensor, shape: tuple[int, int]) -> Tensor:
        """Resize the input tensor to the target shape.

        Args:
            x: Input tensor
            shape: Target shape

        Returns:
            Resized tensor
        """
        x_resized = F.interpolate(
            x[None, None, ...],
            shape,
            mode=self.interpolation,
            antialias=self.antialias,
        )
        return x_resized[0, 0, ...]

    def forward(
        self,
        x: Tensor,
        patch_size: int | tuple[int, int] | None = None,
    ) -> Tensor | tuple[Tensor, tuple[int, int]]:
        """Forward pass for the FlexiPatchReconstruction module.

        Args:
            x: Input tensor with shape [b, h, w, (t), d]
            patch_size: Patch size to use for the reconstruction. If None, the base
                patch size will be used.

        Returns:
            Reconstructed tensor with shape [b, h * p, w * p, (t), out_chans].
        """
        # x has input shape [b, h, w, (t), d]
        if len(x.shape) == 4:
            has_time_dimension = False
            b, h, w, d = x.shape
            t = 1
        else:
            has_time_dimension = True
            b, h, w, t, d = x.shape

        if not patch_size:
            # During evaluation use base patch size if not specified
            patch_size = self.max_patch_size

        patch_size = self.to_2tuple(patch_size)

        if has_time_dimension:
            x = rearrange(x, "b h w t d -> (b t) d h w", b=b, t=t)
        else:
            x = rearrange(x, "b h w d -> b d h w")

        x = self.proj(x)

        # The transposed conv always emits max_patch_size patches; if a different
        # patch_size was requested, resize each patch individually and reassemble
        # the grid.
        if patch_size != self.max_patch_size:
            x = rearrange(
                x,
                "b c (h p_h) (w p_w) -> b h w c p_h p_w",
                p_h=self.max_patch_size[0],
                p_w=self.max_patch_size[1],
            )
            bl, hl, wl, cl = x.shape[:4]
            x = rearrange(x, "b h w c p_h p_w -> (b h w) c p_h p_w")
            x = F.interpolate(
                x, patch_size, mode=self.interpolation, antialias=self.antialias
            )
            x = rearrange(
                x, "(b h w) c p_h p_w -> b c (h p_h) (w p_w)", b=bl, h=hl, w=wl
            )

        if has_time_dimension:
            x = rearrange(x, "(b t) c h w -> b h w t c", b=b, t=t)
        else:
            x = rearrange(x, "b c h w -> b h w c")

        x = self.norm(x)

        return x
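
The reconstruction path mirrors the embedding path. Here is a minimal sketch (again not part of the wheel) showing how a token grid is projected back to pixel space, and how a smaller patch_size shrinks each reconstructed patch:

import torch

recon = FlexiPatchReconstruction(max_patch_size=16, out_chans=3, embedding_size=128)

tokens = torch.randn(2, 4, 4, 128)        # [b, h, w, d], e.g. the FlexiPatchEmbed output above
print(recon(tokens).shape)                # torch.Size([2, 64, 64, 3]); each token becomes a 16x16 patch
print(recon(tokens, patch_size=8).shape)  # torch.Size([2, 32, 32, 3]); each 16x16 patch is resized to 8x8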