olmoearth-pretrain-minimal 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24) hide show
  1. olmoearth_pretrain_minimal/__init__.py +16 -0
  2. olmoearth_pretrain_minimal/model_loader.py +123 -0
  3. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/__init__.py +6 -0
  4. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/__init__.py +1 -0
  5. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/attention.py +559 -0
  6. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/encodings.py +115 -0
  7. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_patch_embed.py +304 -0
  8. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_vit.py +2219 -0
  9. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/latent_mim.py +166 -0
  10. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/tokenization.py +194 -0
  11. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/utils.py +83 -0
  12. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/olmoearth_pretrain_v1.py +152 -0
  13. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/__init__.py +2 -0
  14. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/config.py +264 -0
  15. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/constants.py +519 -0
  16. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/datatypes.py +165 -0
  17. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/decorators.py +75 -0
  18. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/types.py +8 -0
  19. olmoearth_pretrain_minimal/test.py +51 -0
  20. olmoearth_pretrain_minimal-0.0.1.dist-info/METADATA +326 -0
  21. olmoearth_pretrain_minimal-0.0.1.dist-info/RECORD +24 -0
  22. olmoearth_pretrain_minimal-0.0.1.dist-info/WHEEL +5 -0
  23. olmoearth_pretrain_minimal-0.0.1.dist-info/licenses/LICENSE +204 -0
  24. olmoearth_pretrain_minimal-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,519 @@
1
+ """Constants shared across the OlmoEarth Pretrain package.
2
+
3
+ Warning: this is only developed for raster data currently.
4
+ """
5
+
6
+ from dataclasses import dataclass
7
+ from enum import Enum
8
+
9
# Finest resolution we work at, in meters per pixel.
# Every other resolution is this value multiplied by a resolution factor
# (a power of 2); e.g. a factor of 16 gives 0.625 * 16 = 10 m/pixel.
BASE_RESOLUTION = 0.625

# Default side length of an image tile, in pixels.
# Images stored at a coarser resolution than the one the grid is based on
# may be smaller than this.
IMAGE_TILE_SIZE = 256

# NOTE(review): presumably the coordinate reference system used when
# projecting data -- confirm against the callers.
PROJECTION_CRS = "EPSG:4326"

# Default value marking missing raster data.
MISSING_VALUE = -99999

# Default maximum sequence length.
MAX_SEQUENCE_LENGTH = 12

# Resolution (ground sample distance) of the input data, in meters.
BASE_GSD = 10
# Default nodata value for Sentinel-1 data.
SENTINEL1_NODATA = -32768

# Number of timesteps for YEAR data.
YEAR_NUM_TIMESTEPS = 12
33
+
34
+
35
def get_resolution(resolution_factor: int) -> float | int:
    """Return the resolution that corresponds to ``resolution_factor``.

    The resolution is ``BASE_RESOLUTION * resolution_factor``. When that product
    is a whole number it is returned as an ``int``, because some files in the raw
    OlmoEarth Pretrain dataset are named after the integer value. This special
    case may be removed in the future.

    Args:
        resolution_factor: multiplier applied to ``BASE_RESOLUTION``.

    Returns:
        The resolution, as an ``int`` when it is whole and a ``float`` otherwise.
    """
    scaled = BASE_RESOLUTION * resolution_factor
    whole = int(scaled)
    # Only collapse to int when no fractional part would be lost.
    return whole if float(whole) == scaled else scaled
46
+
47
+
48
+ @dataclass(frozen=True)
49
+ class BandSet:
50
+ """A group of bands that is stored at the same resolution.
51
+
52
+ Many modalities only have one band set, but some have different bands at different
53
+ resolutions.
54
+ """
55
+
56
+ # List of band names.
57
+ bands: list[str]
58
+
59
+ # Resolution is BASE_RESOLUTION * resolution_factor.
60
+ # If resolution == 0, this means the data
61
+ # does not vary in space (e.g. latlons)
62
+ resolution_factor: int
63
+
64
+ def __hash__(self) -> int:
65
+ """Hash this BandSet."""
66
+ return hash((tuple(self.bands), self.resolution_factor))
67
+
68
+ def get_resolution(self) -> float:
69
+ """Compute the resolution."""
70
+ return get_resolution(self.resolution_factor)
71
+
72
+ def get_expected_image_size(self, modality_resolution_factor: int) -> int:
73
+ """Get the expected size of images containing these bands.
74
+
75
+ Args:
76
+ modality_resolution_factor: the resolution factor of the modality.
77
+
78
+ Returns:
79
+ the expected image size.
80
+ """
81
+ return IMAGE_TILE_SIZE // (self.resolution_factor // modality_resolution_factor)
82
+
83
+
84
class TimeSpan(str, Enum):
    """Time range that a piece of data is valid for."""

    # A single data point rather than a time series.
    STATIC = "static"

    # Monthly data over one year.
    YEAR = "year"

    # Every data point within a two-week period.
    TWO_WEEK = "two_week"

    def get_suffix(self) -> str:
        """Return the suffix used for this timespan in the raw OlmoEarth Pretrain dataset.

        Raises:
            ValueError: if this timespan has no known suffix.
        """
        suffix_by_span = {
            TimeSpan.STATIC: "",
            TimeSpan.YEAR: "_monthly",
            TimeSpan.TWO_WEEK: "_freq",
        }
        suffix = suffix_by_span.get(self)
        if suffix is None:
            raise ValueError("invalid TimeSpan")
        return suffix
105
+
106
+
107
@dataclass(frozen=True)
class ModalitySpec:
    """Modality specification.

    Args:
        name: the name of the modality.
        tile_resolution_factor: the factor of how much more ground area is covered
            by the tile compared with a tile of IMAGE_TILE_SIZE x IMAGE_TILE_SIZE
            pixels at the base resolution.
        band_sets: the band sets of the modality, i.e. the units of tokenization.
        is_multitemporal: whether the modality is multitemporal.
        ignore_when_parsing: whether to ignore the modality when parsing the data
            from the csv file.
        image_tile_size_factor: the factor of how much bigger the dimensions of
            the image tile are compared with the base tile size. A negative value
            means the tile is smaller by that factor (see get_expected_tile_size).
    """

    name: str
    tile_resolution_factor: int
    band_sets: list[BandSet]
    is_multitemporal: bool
    # If true, this modality is not parsed from the csv file and not loaded from a file.
    ignore_when_parsing: bool
    image_tile_size_factor: int = 1

    def __hash__(self) -> int:
        """Hash this Modality by its name only."""
        return hash(self.name)

    def get_tile_resolution(self) -> float | int:
        """Compute the tile resolution from ``tile_resolution_factor``."""
        return get_resolution(self.tile_resolution_factor)

    def bandsets_as_indices(self) -> list[list[int]]:
        """Return band sets as indices.

        Bands are numbered consecutively across band sets in declaration order,
        so each inner list holds the positions of one band set's bands.
        """
        indices: list[list[int]] = []
        start = 0
        for band_set in self.band_sets:
            stop = start + len(band_set.bands)
            indices.append(list(range(start, stop)))
            start = stop
        return indices

    @property
    def band_order(self) -> list[str]:
        """Get all band names, band set by band set, as a new flat list."""
        # A flat comprehension avoids the quadratic copying of ``sum(lists, [])``.
        return [band for band_set in self.band_sets for band in band_set.bands]

    @property
    def num_band_sets(self) -> int:
        """Get the number of band sets."""
        return len(self.band_sets)

    @property
    def num_bands(self) -> int:
        """Get the number of channels.

        The number of channels is the sum of the number of bands in all the band sets.
        """
        return sum(len(band_set.bands) for band_set in self.band_sets)

    def get_expected_tile_size(self) -> int:
        """Get the expected size of the tile.

        A negative ``image_tile_size_factor`` shrinks the tile:
        IMAGE_TILE_SIZE is floor-divided by its absolute value. Otherwise
        IMAGE_TILE_SIZE is multiplied by the factor.
        """
        if self.image_tile_size_factor < 0:
            return IMAGE_TILE_SIZE // abs(self.image_tile_size_factor)
        return IMAGE_TILE_SIZE * self.image_tile_size_factor

    @property
    def is_spatial(self) -> bool:
        """Does the modality have spatial data."""
        # Tile size must be greater than 1 to have spatially varying data.
        return self.get_tile_resolution() > 0 and self.get_expected_tile_size() > 1

    @property
    def is_spacetime_varying(self) -> bool:
        """Does the modality vary in space and time."""
        return self.is_spatial and self.is_multitemporal

    @property
    def is_space_only_varying(self) -> bool:
        """Does the modality vary in space and not time."""
        return self.is_spatial and not self.is_multitemporal

    @property
    def is_time_only_varying(self) -> bool:
        """Does the modality vary in time and not space."""
        return not self.is_spatial and self.is_multitemporal

    @property
    def is_static_in_space_and_time(self) -> bool:
        """Does the modality vary in neither space nor time."""
        return not self.is_spatial and not self.is_multitemporal
196
+
197
+
198
class Modality:
    """Enum-like access to ModalitySpecs.

    Each upper-case class attribute is a ModalitySpec whose ``name`` is the
    lower-case form of the attribute name; ``get`` relies on that convention.
    """

    NAIP = ModalitySpec(
        name="naip",
        tile_resolution_factor=1,
        band_sets=[BandSet(["R", "G", "B", "IR"], 1)],
        is_multitemporal=False,
        ignore_when_parsing=False,
    )

    # NAIP_10 is the NAIP data that covers the same extent as a IMAGE_TILE_SIZE x IMAGE_TILE_SIZE tile
    # at 10 m/pixel resolution but is still stored at NAIP resolution.
    NAIP_10 = ModalitySpec(
        name="naip_10",
        tile_resolution_factor=16,
        band_sets=[BandSet(["R", "G", "B", "IR"], 1)],
        is_multitemporal=False,
        ignore_when_parsing=False,
        # Currently this is set to 4x (2.5 m/pixel) so that it is more feasible to
        # train with NAIP_10. This way we end up with 512x512 NAIP images in the
        # 128x128 H5 files instead of 2048x2048, which slows down data loading.
        image_tile_size_factor=4,
    )

    SENTINEL1 = ModalitySpec(
        name="sentinel1",
        tile_resolution_factor=16,
        band_sets=[BandSet(["vv", "vh"], 16)],
        is_multitemporal=True,
        ignore_when_parsing=False,
    )

    SENTINEL2 = ModalitySpec(
        name="sentinel2",
        tile_resolution_factor=16,
        band_sets=[
            # 10 m/pixel bands.
            BandSet(["B02", "B03", "B04", "B08"], 16),
            # 20 m/pixel bands.
            BandSet(["B05", "B06", "B07", "B8A", "B11", "B12"], 32),
            # 60 m/pixel bands that we store at 40 m/pixel.
            BandSet(["B01", "B09", "B10"], 64),
        ],
        is_multitemporal=True,
        ignore_when_parsing=False,
    )

    SENTINEL2_L2A = ModalitySpec(
        name="sentinel2_l2a",
        tile_resolution_factor=16,
        band_sets=[
            # 10 m/pixel bands.
            BandSet(["B02", "B03", "B04", "B08"], 16),
            # 20 m/pixel bands.
            BandSet(["B05", "B06", "B07", "B8A", "B11", "B12"], 32),
            # 60 m/pixel bands that we store at 40 m/pixel.
            BandSet(["B01", "B09"], 64),
        ],
        is_multitemporal=True,
        ignore_when_parsing=False,
    )

    LANDSAT = ModalitySpec(
        name="landsat",
        tile_resolution_factor=16,
        band_sets=[
            # 15 m/pixel bands that we store at 10 m/pixel.
            BandSet(["B8"], 16),
            # 30 m/pixel bands that we store at 20 m/pixel.
            BandSet(["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"], 32),
        ],
        is_multitemporal=True,
        ignore_when_parsing=False,
    )

    WORLDCOVER = ModalitySpec(
        name="worldcover",
        tile_resolution_factor=16,
        band_sets=[BandSet(["B1"], 16)],
        is_multitemporal=False,
        ignore_when_parsing=False,
    )

    WORLDCEREAL = ModalitySpec(
        name="worldcereal",
        tile_resolution_factor=16,
        band_sets=[
            BandSet(
                [
                    "tc-annual-temporarycrops-classification",
                    "tc-maize-main-irrigation-classification",
                    "tc-maize-main-maize-classification",
                    "tc-maize-second-irrigation-classification",
                    "tc-maize-second-maize-classification",
                    "tc-springcereals-springcereals-classification",
                    "tc-wintercereals-irrigation-classification",
                    "tc-wintercereals-wintercereals-classification",
                ],
                16,
            )
        ],
        is_multitemporal=False,
        ignore_when_parsing=False,
    )

    SRTM = ModalitySpec(
        name="srtm",
        tile_resolution_factor=16,
        band_sets=[BandSet(["srtm"], 16)],
        is_multitemporal=False,
        ignore_when_parsing=False,
    )

    OPENSTREETMAP = ModalitySpec(
        name="openstreetmap",
        tile_resolution_factor=16,
        band_sets=[
            BandSet(
                [
                    "aerialway_pylon",
                    "aerodrome",
                    "airstrip",
                    "amenity_fuel",
                    "building",
                    "chimney",
                    "communications_tower",
                    "crane",
                    "flagpole",
                    "fountain",
                    "generator_wind",
                    "helipad",
                    "highway",
                    "leisure",
                    "lighthouse",
                    "obelisk",
                    "observatory",
                    "parking",
                    "petroleum_well",
                    "power_plant",
                    "power_substation",
                    "power_tower",
                    "river",
                    "runway",
                    "satellite_dish",
                    "silo",
                    "storage_tank",
                    "taxiway",
                    "water_tower",
                    "works",
                ],
                1,
            )
        ],
        is_multitemporal=False,
        ignore_when_parsing=True,
    )

    OPENSTREETMAP_RASTER = ModalitySpec(
        name="openstreetmap_raster",
        tile_resolution_factor=16,
        band_sets=[
            BandSet(
                [
                    "aerialway_pylon",
                    "aerodrome",
                    "airstrip",
                    "amenity_fuel",
                    "building",
                    "chimney",
                    "communications_tower",
                    "crane",
                    "flagpole",
                    "fountain",
                    "generator_wind",
                    "helipad",
                    "highway",
                    "leisure",
                    "lighthouse",
                    "obelisk",
                    "observatory",
                    "parking",
                    "petroleum_well",
                    "power_plant",
                    "power_substation",
                    "power_tower",
                    "river",
                    "runway",
                    "satellite_dish",
                    "silo",
                    "storage_tank",
                    "taxiway",
                    "water_tower",
                    "works",
                ],
                4,
            )
        ],
        is_multitemporal=False,
        ignore_when_parsing=False,
    )

    ERA5 = ModalitySpec(
        name="era5",
        # 9 km/pixel bands that we store at 150 m/pixel.
        # NOTE(review): BASE_RESOLUTION (0.625) * 256 is 160 m/pixel, not 150 --
        # confirm which figure is intended.
        tile_resolution_factor=256,
        band_sets=[
            BandSet(
                [
                    "2m-temperature",
                    "2m-dewpoint-temperature",
                    "surface-pressure",
                    "10m-u-component-of-wind",
                    "10m-v-component-of-wind",
                    "total-precipitation",
                ],
                256,
            ),
        ],
        is_multitemporal=True,
        ignore_when_parsing=True,
    )

    ERA5_10 = ModalitySpec(
        name="era5_10",
        # 9 km/pixel bands that we store at 2.56 km/pixel.
        tile_resolution_factor=16,
        band_sets=[
            BandSet(
                [
                    "2m-temperature",
                    "2m-dewpoint-temperature",
                    "surface-pressure",
                    "10m-u-component-of-wind",
                    "10m-v-component-of-wind",
                    "total-precipitation",
                ],
                4096,
            ),
        ],
        is_multitemporal=True,
        ignore_when_parsing=False,
        # Negative factor: the tile is IMAGE_TILE_SIZE // 256 pixels wide.
        image_tile_size_factor=-256,
    )

    LATLON = ModalitySpec(
        name="latlon",
        tile_resolution_factor=0,
        band_sets=[BandSet(["lat", "lon"], 0)],
        is_multitemporal=False,
        ignore_when_parsing=True,
    )

    GSE = ModalitySpec(
        name="gse",
        tile_resolution_factor=16,
        band_sets=[
            BandSet(
                [f"A{idx:02d}" for idx in range(64)],
                16,
            ),
        ],
        is_multitemporal=False,
        ignore_when_parsing=False,
    )

    CDL = ModalitySpec(
        name="cdl",
        tile_resolution_factor=16,
        band_sets=[BandSet(["cdl"], 16)],
        is_multitemporal=False,
        ignore_when_parsing=False,
    )

    WORLDPOP = ModalitySpec(
        name="worldpop",
        tile_resolution_factor=16,
        band_sets=[BandSet(["B1"], 16)],
        is_multitemporal=False,
        ignore_when_parsing=False,
    )

    WRI_CANOPY_HEIGHT_MAP = ModalitySpec(
        name="wri_canopy_height_map",
        tile_resolution_factor=16,
        band_sets=[BandSet(["B1"], 16)],
        is_multitemporal=False,
        ignore_when_parsing=False,
    )

    @classmethod
    def get(cls, name: str) -> ModalitySpec:
        """Get the ModalitySpec with the specified name.

        Args:
            name: the exact (lower-case) modality name, e.g. "sentinel2_l2a".

        Returns:
            The ModalitySpec stored under the upper-cased attribute name.

        Raises:
            AttributeError: if no attribute matches the upper-cased name.
            AssertionError: if the stored spec's name differs from ``name``
                (e.g. the name was passed with the wrong case).
        """
        modality = getattr(cls, name.upper())
        # Raised explicitly rather than via ``assert`` so that the check is not
        # stripped under ``python -O``. The exception type is kept for backward
        # compatibility.
        if modality.name != name:
            raise AssertionError(
                f"modality name mismatch: requested {name!r}, found {modality.name!r}"
            )
        return modality

    @classmethod
    def values(cls) -> list[ModalitySpec]:
        """Get all of the ModalitySpecs, ordered by attribute name."""
        # dir() returns attribute names sorted alphabetically; keep only the specs.
        candidates = (getattr(cls, attr) for attr in dir(cls))
        return [spec for spec in candidates if isinstance(spec, ModalitySpec)]

    @classmethod
    def names(cls) -> list[str]:
        """Get all of the modality names."""
        return [modality.name for modality in cls.values()]
510
+
511
+
512
# Component names of the latlon vector and of a timestamp entry.
LATLON = ["lat", "lon"]
TIMESTAMPS = ["day", "month", "year"]
515
+
516
+
517
def get_modality_specs_from_names(names: list[str]) -> list[ModalitySpec]:
    """Look up the ModalitySpec for each name, preserving the input order.

    Args:
        names: modality names as accepted by ``Modality.get``.

    Returns:
        One ModalitySpec per entry of ``names``.
    """
    return list(map(Modality.get, names))
@@ -0,0 +1,165 @@
1
+ """Data structures for OlmoEarth Pretrain."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from enum import Enum
6
+ from typing import TYPE_CHECKING, Any, NamedTuple
7
+
8
+ import torch
9
+
10
+ from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.types import ArrayTensor
11
+
12
+
13
class MaskValue(Enum):
    """The four values a mask entry can take.

    Each value states which part of the model sees the corresponding token.
    """

    # The token is seen by the online encoder.
    ONLINE_ENCODER = 0
    # The token is seen by the target encoder only.
    TARGET_ENCODER_ONLY = 1
    # The token is seen by the decoder only.
    DECODER = 2
    # The token is missing.
    MISSING = 3
26
+
27
+
28
class MaskedOlmoEarthSample(NamedTuple):
    """A masked sample of the data from the OlmoEarth Pretrain dataset.

    This is a namedtuple that contains the data for a single sample from the
    OlmoEarth Pretrain dataset. Only ``timestamps`` is required; every modality
    field defaults to None when absent.
    latlon and timestamps are the same for all modalities.
    For each modality, we have an ArrayTensor named by modality, and a mask for
    each modality named ``<modality>_mask``. We also have a mask for the latlon
    called latlon_mask.
    """

    # [B, T, D=3], where D=[day, month, year] (months are zero indexed)
    timestamps: ArrayTensor
    sentinel2_l2a: ArrayTensor | None = None
    sentinel2_l2a_mask: ArrayTensor | None = None
    sentinel1: ArrayTensor | None = None
    sentinel1_mask: ArrayTensor | None = None
    worldcover: ArrayTensor | None = None
    worldcover_mask: ArrayTensor | None = None
    latlon: ArrayTensor | None = None  # [B, 2]
    latlon_mask: ArrayTensor | None = None
    openstreetmap_raster: ArrayTensor | None = None
    openstreetmap_raster_mask: ArrayTensor | None = None
    srtm: ArrayTensor | None = None
    srtm_mask: ArrayTensor | None = None
    landsat: ArrayTensor | None = None
    landsat_mask: ArrayTensor | None = None
    naip: ArrayTensor | None = None
    naip_mask: ArrayTensor | None = None
    naip_10: ArrayTensor | None = None
    naip_10_mask: ArrayTensor | None = None
    gse: ArrayTensor | None = None
    gse_mask: ArrayTensor | None = None
    cdl: ArrayTensor | None = None
    cdl_mask: ArrayTensor | None = None
    worldpop: ArrayTensor | None = None
    worldpop_mask: ArrayTensor | None = None
    worldcereal: ArrayTensor | None = None
    worldcereal_mask: ArrayTensor | None = None
    wri_canopy_height_map: ArrayTensor | None = None
    wri_canopy_height_map_mask: ArrayTensor | None = None
    era5_10: ArrayTensor | None = None
    era5_10_mask: ArrayTensor | None = None

    def as_dict(self, return_none: bool = True) -> dict[str, Any]:
        """Convert the namedtuple to a dictionary.

        Args:
            return_none: if True, keep fields whose value is None; otherwise
                drop them.

        Returns:
            Dictionary representation of the namedtuple, in field order.
        """
        fields = self._asdict()
        if return_none:
            return fields
        return {name: value for name, value in fields.items() if value is not None}

    def unmask(self) -> MaskedOlmoEarthSample:
        """Return an unmasked MaskedOlmoEarthSample.

        All mask values are MaskValue.ONLINE_ENCODER except for MaskValue.MISSING,
        which remain MISSING. Fields that are None stay unset in the result.
        """
        missing = MaskValue.MISSING.value
        unmasked: dict[str, ArrayTensor] = {}
        for key, val in self.as_dict(return_none=False).items():
            if key.endswith("_mask"):
                # ``val == missing`` is 1 where the token is missing and 0
                # elsewhere, so the product keeps MISSING and zeroes the rest
                # (0 is MaskValue.ONLINE_ENCODER).
                unmasked[key] = val * (val == missing)
            else:
                unmasked[key] = val
        return MaskedOlmoEarthSample(**unmasked)

    @property
    def modalities(self) -> list[str]:
        """Get the present modalities in this instance of MaskedOlmoEarthSample.

        A modality is present when its field is not None; mask fields and
        ``timestamps`` are never reported.
        """
        return [
            field
            for field in self._fields
            if not field.endswith("_mask")
            and field != "timestamps"
            and getattr(self, field) is not None
        ]

    @staticmethod
    def get_masked_modality_name(modality: str) -> str:
        """Get the name of the mask field that belongs to ``modality``."""
        return f"{modality}_mask"

    @staticmethod
    def get_unmasked_modality_name(modality_mask_name: str) -> str:
        """Get the modality name that belongs to a mask field name.

        Only a trailing "_mask" is stripped, so this is the exact inverse of
        get_masked_modality_name even when the modality name itself contains
        "_mask". Names without the suffix are returned unchanged.
        """
        return modality_mask_name.removesuffix("_mask")

    @classmethod
    def from_olmoearthsample(
        cls,
        sample: Any,  # OlmoEarthSample - not available in minimal repo
    ) -> MaskedOlmoEarthSample:
        """Transforms a OlmoEarthSample into a MaskedOlmoEarthSample.

        Every present modality gets a mask filled with
        MaskValue.ONLINE_ENCODER; absent modalities get None for both the data
        and the mask. This function assumes modalities are uniformly missing.

        Args:
            sample: object exposing ``as_dict(ignore_nones=False)`` and
                ``shape(key, mask=False)``.
        """
        masked_sample_dict: dict[str, Any] = {}
        for key, t in sample.as_dict(ignore_nones=False).items():
            if key == "timestamps":
                # Timestamps carry no mask and are assumed to be present.
                masked_sample_dict[key] = t
                continue
            mask_key = cls.get_masked_modality_name(key)
            masked_sample_dict[key] = t
            if t is None:
                masked_sample_dict[mask_key] = None
            else:
                masked_sample_dict[mask_key] = (
                    torch.ones(sample.shape(key, mask=False))
                    * MaskValue.ONLINE_ENCODER.value
                )
        return cls(**masked_sample_dict)

    @classmethod
    def from_dict(cls, dict: dict[str, Any]) -> MaskedOlmoEarthSample:
        """Create a MaskedOlmoEarthSample from a dictionary.

        Keys are passed straight through as field values; fields that are not
        in the dictionary keep their default of None.

        Args:
            dict: Dictionary representation of the MaskedOlmoEarthSample. It
                must contain "timestamps" and only known field names.
        """
        return cls(**dict)