olmoearth-pretrain-minimal 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- olmoearth_pretrain_minimal/__init__.py +16 -0
- olmoearth_pretrain_minimal/model_loader.py +123 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/__init__.py +6 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/__init__.py +1 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/attention.py +559 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/encodings.py +115 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_patch_embed.py +304 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_vit.py +2219 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/latent_mim.py +166 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/tokenization.py +194 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/utils.py +83 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/olmoearth_pretrain_v1.py +152 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/__init__.py +2 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/config.py +264 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/constants.py +519 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/datatypes.py +165 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/decorators.py +75 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/types.py +8 -0
- olmoearth_pretrain_minimal/test.py +51 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/METADATA +326 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/RECORD +24 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/WHEEL +5 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/licenses/LICENSE +204 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,519 @@
|
|
|
1
|
+
"""Constants shared across the OlmoEarth Pretrain package.
|
|
2
|
+
|
|
3
|
+
Warning: this is only developed for raster data currently.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from enum import Enum
|
|
8
|
+
|
|
9
|
+
# The highest resolution that we are working at, in meters per pixel.
# Everything else is a factor (which is a power of 2) coarser than this resolution,
# e.g. a resolution factor of 16 gives 0.625 * 16 = 10 m/pixel.
BASE_RESOLUTION = 0.625

# The default image tile size (side length in pixels).
# Some images may be smaller if they are stored at a coarser resolution compared to the
# resolution that the grid is based on.
IMAGE_TILE_SIZE = 256

# Coordinate reference system identifier (EPSG:4326 is WGS 84 geographic coordinates).
PROJECTION_CRS = "EPSG:4326"

# Default missing value for raster data.
MISSING_VALUE = -99999

# Default maximum sequence length.
# NOTE(review): presumably the maximum number of timesteps per sample — confirm against callers.
MAX_SEQUENCE_LENGTH = 12

# Resolution (ground sample distance) of the input data in meters; this equals
# BASE_RESOLUTION * 16.
BASE_GSD = 10
# Default nodata value for Sentinel-1 data.
SENTINEL1_NODATA = -32768

# Number of timesteps for TimeSpan.YEAR data (monthly over one year).
YEAR_NUM_TIMESTEPS = 12
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def get_resolution(resolution_factor: int) -> float | int:
    """Return the ground resolution ``BASE_RESOLUTION * resolution_factor``.

    Whole-number results come back as ``int`` rather than ``float``: some files in
    the raw OlmoEarth Pretrain dataset are named after the integer resolution, so
    the value must format without a trailing ``.0``. This conversion may be
    removed in the future.
    """
    scaled = BASE_RESOLUTION * resolution_factor
    is_whole_number = float(int(scaled)) == scaled
    return int(scaled) if is_whole_number else scaled
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass(frozen=True)
class BandSet:
    """Bands that share a single storage resolution.

    Most modalities consist of exactly one band set; a few split their bands
    across several band sets because different bands are stored at different
    resolutions.
    """

    # Names of the bands in this set, in order.
    bands: list[str]

    # Resolution is BASE_RESOLUTION * resolution_factor. A factor of 0 means the
    # data does not vary in space (e.g. latlons).
    resolution_factor: int

    def __hash__(self) -> int:
        """Hash on the band names (converted to a tuple) and the resolution factor."""
        # ``bands`` is a list and therefore unhashable, so hash a tuple copy instead.
        hash_key = (tuple(self.bands), self.resolution_factor)
        return hash(hash_key)

    def get_resolution(self) -> float:
        """Return the ground resolution of this band set."""
        return get_resolution(self.resolution_factor)

    def get_expected_image_size(self, modality_resolution_factor: int) -> int:
        """Return the expected side length of images containing these bands.

        Args:
            modality_resolution_factor: the resolution factor of the modality.

        Returns:
            IMAGE_TILE_SIZE integer-divided by how many times coarser this band
            set is than the modality (that ratio is itself an integer division).
        """
        # NOTE(review): if resolution_factor < modality_resolution_factor the ratio
        # below is 0 and this raises ZeroDivisionError — confirm callers never hit it.
        coarsening_ratio = self.resolution_factor // modality_resolution_factor
        return IMAGE_TILE_SIZE // coarsening_ratio
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class TimeSpan(str, Enum):
    """Temporal coverage of a piece of data."""

    # A single data point rather than a time series.
    STATIC = "static"

    # Monthly data covering one year.
    YEAR = "year"

    # Every data point within a two-week window.
    TWO_WEEK = "two_week"

    def get_suffix(self) -> str:
        """Returns the suffix used for this timespan in raw OlmoEarth Pretrain dataset."""
        suffix_by_span = {
            TimeSpan.STATIC: "",
            TimeSpan.YEAR: "_monthly",
            TimeSpan.TWO_WEEK: "_freq",
        }
        if self in suffix_by_span:
            return suffix_by_span[self]
        raise ValueError("invalid TimeSpan")
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
@dataclass(frozen=True)
class ModalitySpec:
    """Modality specification.

    Args:
        name: the name of the modality.
        tile_resolution_factor: the factor of how much more ground area is covered by the tile
            compared with a tile of IMAGE_TILE_SIZE x IMAGE_TILE_SIZE pixels at the base
            resolution. A factor of 0 marks a modality with no spatial extent (see is_spatial).
        band_sets: the band sets of the modality, i.e. the units of tokenization.
        is_multitemporal: whether the modality is multitemporal.
        ignore_when_parsing: whether to ignore the modality when parsing the data from the CSV file.
        image_tile_size_factor: the factor of how much bigger the dimensions of the image tile are
            compared with the base tile size. A negative value means the tile is that many times
            smaller instead (see get_expected_tile_size).
    """

    name: str
    tile_resolution_factor: int
    band_sets: list[BandSet]
    is_multitemporal: bool
    # If true, this modality is not parsed from the CSV file and not loaded from a file.
    ignore_when_parsing: bool
    image_tile_size_factor: int = 1

    def __hash__(self) -> int:
        """Hash this modality by its name only."""
        # ``band_sets`` holds lists, so the dataclass-generated field hash would fail;
        # the name is sufficient because equal specs always have equal names.
        return hash(self.name)

    def get_tile_resolution(self) -> float | int:
        """Compute the tile resolution (BASE_RESOLUTION * tile_resolution_factor)."""
        return get_resolution(self.tile_resolution_factor)

    def bandsets_as_indices(self) -> list[list[int]]:
        """Return, for each band set, the indices of its bands within ``band_order``."""
        indices = []
        offset = 0
        for band_set in self.band_sets:
            num_bands = len(band_set.bands)
            indices.append(list(range(offset, offset + num_bands)))
            offset += num_bands
        return indices

    @property
    def band_order(self) -> list[str]:
        """Get all bands, concatenated across band sets in declaration order."""
        # Flat comprehension instead of ``sum(lists, [])``, which is quadratic.
        return [band for band_set in self.band_sets for band in band_set.bands]

    @property
    def num_band_sets(self) -> int:
        """Get the number of band sets."""
        return len(self.band_sets)

    @property
    def num_bands(self) -> int:
        """Get the number of channels.

        The number of channels is the sum of the number of bands in all the band sets.
        """
        return sum(len(band_set.bands) for band_set in self.band_sets)

    def get_expected_tile_size(self) -> int:
        """Get the expected side length of the tile in pixels.

        A negative ``image_tile_size_factor`` shrinks the tile (IMAGE_TILE_SIZE // |factor|);
        otherwise the tile is IMAGE_TILE_SIZE * factor.
        """
        if self.image_tile_size_factor < 0:
            return IMAGE_TILE_SIZE // abs(self.image_tile_size_factor)
        return IMAGE_TILE_SIZE * self.image_tile_size_factor

    @property
    def is_spatial(self) -> bool:
        """Does the modality have spatial data."""
        # Tile size must be greater than 1 to have spatial varying data.
        return self.get_tile_resolution() > 0 and self.get_expected_tile_size() > 1

    @property
    def is_spacetime_varying(self) -> bool:
        """Does the modality vary in space and time."""
        return self.is_spatial and self.is_multitemporal

    @property
    def is_space_only_varying(self) -> bool:
        """Does the modality vary in space and not time."""
        return self.is_spatial and not self.is_multitemporal

    @property
    def is_time_only_varying(self) -> bool:
        """Does the modality vary in time and not space."""
        return not self.is_spatial and self.is_multitemporal

    @property
    def is_static_in_space_and_time(self) -> bool:
        """Does the modality vary in neither space nor time."""
        return not self.is_spatial and not self.is_multitemporal
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
class Modality:
    """Enum-like namespace of every known ModalitySpec (e.g. ``Modality.SENTINEL2_L2A``)."""

    NAIP = ModalitySpec(
        name="naip",
        tile_resolution_factor=1,
        band_sets=[BandSet(["R", "G", "B", "IR"], 1)],
        is_multitemporal=False,
        ignore_when_parsing=False,
    )

    # NAIP_10 is the NAIP data that covers the same extent as a IMAGE_TILE_SIZE x IMAGE_TILE_SIZE tile
    # at 10 m/pixel resolution but is still stored at NAIP resolution.
    NAIP_10 = ModalitySpec(
        name="naip_10",
        tile_resolution_factor=16,
        band_sets=[BandSet(["R", "G", "B", "IR"], 1)],
        is_multitemporal=False,
        ignore_when_parsing=False,
        # Currently this is set to 4x (2.5 m/pixel) so that it is more feasible to
        # train with NAIP_10. This way we end up with 512x512 NAIP images in the
        # 128x128 H5 files instead of 2048x2048, which slows down data loading.
        image_tile_size_factor=4,
    )

    SENTINEL1 = ModalitySpec(
        name="sentinel1",
        tile_resolution_factor=16,
        band_sets=[BandSet(["vv", "vh"], 16)],
        is_multitemporal=True,
        ignore_when_parsing=False,
    )

    SENTINEL2 = ModalitySpec(
        name="sentinel2",
        tile_resolution_factor=16,
        band_sets=[
            # 10 m/pixel bands.
            BandSet(["B02", "B03", "B04", "B08"], 16),
            # 20 m/pixel bands.
            BandSet(["B05", "B06", "B07", "B8A", "B11", "B12"], 32),
            # 60 m/pixel bands that we store at 40 m/pixel.
            BandSet(["B01", "B09", "B10"], 64),
        ],
        is_multitemporal=True,
        ignore_when_parsing=False,
    )

    SENTINEL2_L2A = ModalitySpec(
        name="sentinel2_l2a",
        tile_resolution_factor=16,
        band_sets=[
            # 10 m/pixel bands.
            BandSet(["B02", "B03", "B04", "B08"], 16),
            # 20 m/pixel bands.
            BandSet(["B05", "B06", "B07", "B8A", "B11", "B12"], 32),
            # 60 m/pixel bands that we store at 40 m/pixel.
            BandSet(["B01", "B09"], 64),
        ],
        is_multitemporal=True,
        ignore_when_parsing=False,
    )

    LANDSAT = ModalitySpec(
        name="landsat",
        tile_resolution_factor=16,
        band_sets=[
            # 15 m/pixel bands that we store at 10 m/pixel.
            BandSet(["B8"], 16),
            # 30 m/pixel bands that we store at 20 m/pixel.
            BandSet(["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"], 32),
        ],
        is_multitemporal=True,
        ignore_when_parsing=False,
    )

    WORLDCOVER = ModalitySpec(
        name="worldcover",
        tile_resolution_factor=16,
        band_sets=[BandSet(["B1"], 16)],
        is_multitemporal=False,
        ignore_when_parsing=False,
    )

    WORLDCEREAL = ModalitySpec(
        name="worldcereal",
        tile_resolution_factor=16,
        band_sets=[
            BandSet(
                [
                    "tc-annual-temporarycrops-classification",
                    "tc-maize-main-irrigation-classification",
                    "tc-maize-main-maize-classification",
                    "tc-maize-second-irrigation-classification",
                    "tc-maize-second-maize-classification",
                    "tc-springcereals-springcereals-classification",
                    "tc-wintercereals-irrigation-classification",
                    "tc-wintercereals-wintercereals-classification",
                ],
                16,
            )
        ],
        is_multitemporal=False,
        ignore_when_parsing=False,
    )

    SRTM = ModalitySpec(
        name="srtm",
        tile_resolution_factor=16,
        band_sets=[BandSet(["srtm"], 16)],
        is_multitemporal=False,
        ignore_when_parsing=False,
    )

    OPENSTREETMAP = ModalitySpec(
        name="openstreetmap",
        tile_resolution_factor=16,
        band_sets=[
            BandSet(
                [
                    "aerialway_pylon",
                    "aerodrome",
                    "airstrip",
                    "amenity_fuel",
                    "building",
                    "chimney",
                    "communications_tower",
                    "crane",
                    "flagpole",
                    "fountain",
                    "generator_wind",
                    "helipad",
                    "highway",
                    "leisure",
                    "lighthouse",
                    "obelisk",
                    "observatory",
                    "parking",
                    "petroleum_well",
                    "power_plant",
                    "power_substation",
                    "power_tower",
                    "river",
                    "runway",
                    "satellite_dish",
                    "silo",
                    "storage_tank",
                    "taxiway",
                    "water_tower",
                    "works",
                ],
                1,
            )
        ],
        is_multitemporal=False,
        ignore_when_parsing=True,
    )

    OPENSTREETMAP_RASTER = ModalitySpec(
        name="openstreetmap_raster",
        tile_resolution_factor=16,
        band_sets=[
            BandSet(
                [
                    "aerialway_pylon",
                    "aerodrome",
                    "airstrip",
                    "amenity_fuel",
                    "building",
                    "chimney",
                    "communications_tower",
                    "crane",
                    "flagpole",
                    "fountain",
                    "generator_wind",
                    "helipad",
                    "highway",
                    "leisure",
                    "lighthouse",
                    "obelisk",
                    "observatory",
                    "parking",
                    "petroleum_well",
                    "power_plant",
                    "power_substation",
                    "power_tower",
                    "river",
                    "runway",
                    "satellite_dish",
                    "silo",
                    "storage_tank",
                    "taxiway",
                    "water_tower",
                    "works",
                ],
                4,
            )
        ],
        is_multitemporal=False,
        ignore_when_parsing=False,
    )

    ERA5 = ModalitySpec(
        name="era5",
        # 9 km/pixel bands that we store at 160 m/pixel (BASE_RESOLUTION * 256).
        tile_resolution_factor=256,
        band_sets=[
            BandSet(
                [
                    "2m-temperature",
                    "2m-dewpoint-temperature",
                    "surface-pressure",
                    "10m-u-component-of-wind",
                    "10m-v-component-of-wind",
                    "total-precipitation",
                ],
                256,
            ),
        ],
        is_multitemporal=True,
        ignore_when_parsing=True,
    )

    ERA5_10 = ModalitySpec(
        name="era5_10",
        # 9 km/pixel bands that we store at 2.56 km/pixel (BASE_RESOLUTION * 4096).
        tile_resolution_factor=16,
        band_sets=[
            BandSet(
                [
                    "2m-temperature",
                    "2m-dewpoint-temperature",
                    "surface-pressure",
                    "10m-u-component-of-wind",
                    "10m-v-component-of-wind",
                    "total-precipitation",
                ],
                4096,
            ),
        ],
        is_multitemporal=True,
        ignore_when_parsing=False,
        # Negative factor: the tile is IMAGE_TILE_SIZE // 256 = 1 pixel across.
        image_tile_size_factor=-256,
    )

    LATLON = ModalitySpec(
        name="latlon",
        tile_resolution_factor=0,
        band_sets=[BandSet(["lat", "lon"], 0)],
        is_multitemporal=False,
        ignore_when_parsing=True,
    )

    GSE = ModalitySpec(
        name="gse",
        tile_resolution_factor=16,
        band_sets=[
            BandSet(
                [f"A{idx:02d}" for idx in range(64)],
                16,
            ),
        ],
        is_multitemporal=False,
        ignore_when_parsing=False,
    )

    CDL = ModalitySpec(
        name="cdl",
        tile_resolution_factor=16,
        band_sets=[BandSet(["cdl"], 16)],
        is_multitemporal=False,
        ignore_when_parsing=False,
    )

    WORLDPOP = ModalitySpec(
        name="worldpop",
        tile_resolution_factor=16,
        band_sets=[BandSet(["B1"], 16)],
        is_multitemporal=False,
        ignore_when_parsing=False,
    )

    WRI_CANOPY_HEIGHT_MAP = ModalitySpec(
        name="wri_canopy_height_map",
        tile_resolution_factor=16,
        band_sets=[BandSet(["B1"], 16)],
        is_multitemporal=False,
        ignore_when_parsing=False,
    )

    @classmethod
    def get(cls, name: str) -> ModalitySpec:
        """Get the ModalitySpec with the specified name.

        Args:
            name: the modality name exactly as stored in ``ModalitySpec.name``
                (lowercase, e.g. "sentinel2_l2a").

        Returns:
            the matching ModalitySpec.

        Raises:
            AttributeError: if the class has no attribute named ``name.upper()``.
            ValueError: if the attribute found is not a ModalitySpec whose name equals
                ``name`` (e.g. the name was given in a different case).
        """
        modality = getattr(cls, name.upper())
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if not isinstance(modality, ModalitySpec) or modality.name != name:
            raise ValueError(f"{name!r} is not a valid modality name")
        return modality

    @classmethod
    def values(cls) -> list[ModalitySpec]:
        """Get all of the ModalitySpecs, ordered by attribute name (as ``dir`` sorts them)."""
        return [
            attr
            for attr in (getattr(cls, key) for key in dir(cls))
            if isinstance(attr, ModalitySpec)
        ]

    @classmethod
    def names(cls) -> list[str]:
        """Get all of the modality names, in the same order as ``values()``."""
        return [modality.name for modality in cls.values()]
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
# Band names for latitude/longitude inputs (the same names Modality.LATLON uses),
# and the components of each timestamp in the order day, month, year.
LATLON = ["lat", "lon"]
TIMESTAMPS = ["day", "month", "year"]
|
|
515
|
+
|
|
516
|
+
|
|
517
|
+
def get_modality_specs_from_names(names: list[str]) -> list[ModalitySpec]:
    """Look up the ModalitySpec for each name via ``Modality.get``, preserving order."""
    return list(map(Modality.get, names))
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
"""Data structures for OlmoEarth Pretrain."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from typing import TYPE_CHECKING, Any, NamedTuple
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.types import ArrayTensor
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class MaskValue(Enum):
    """Role a token plays under masking; one of four values.

    ONLINE_ENCODER: the token is seen by the online encoder.
    TARGET_ENCODER_ONLY: the token is seen by the target encoder only.
    DECODER: the token is seen by the decoder only.
    MISSING: the token is missing.
    """

    ONLINE_ENCODER = 0  # visible to the online encoder
    TARGET_ENCODER_ONLY = 1  # visible to the target encoder only
    DECODER = 2  # visible to the decoder only
    MISSING = 3  # no data for this token
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class MaskedOlmoEarthSample(NamedTuple):
    """A masked sample of the data from the OlmoEarth Pretrain dataset.

    This is a namedtuple that contains the data for a single sample from the OlmoEarth
    Pretrain dataset. latlon and timestamps are the same for all modalities.
    For each modality we have an ArrayTensor named by the modality, and a mask named
    ``<modality>_mask``; the latlon mask is called ``latlon_mask``.

    Only ``timestamps`` is a required field; every modality and mask defaults to None.
    NOTE(review): earlier documentation stated that sentinel2 data is always required,
    but nothing in this class enforces that — confirm where it is checked.
    """

    timestamps: (
        ArrayTensor  # [B, T, D=3], where D=[day, month, year] (months are zero indexed)
    )
    sentinel2_l2a: ArrayTensor | None = None
    sentinel2_l2a_mask: ArrayTensor | None = None
    sentinel1: ArrayTensor | None = None
    sentinel1_mask: ArrayTensor | None = None
    worldcover: ArrayTensor | None = None
    worldcover_mask: ArrayTensor | None = None
    latlon: ArrayTensor | None = None  # [B, 2]
    latlon_mask: ArrayTensor | None = None
    openstreetmap_raster: ArrayTensor | None = None
    openstreetmap_raster_mask: ArrayTensor | None = None
    srtm: ArrayTensor | None = None
    srtm_mask: ArrayTensor | None = None
    landsat: ArrayTensor | None = None
    landsat_mask: ArrayTensor | None = None
    naip: ArrayTensor | None = None
    naip_mask: ArrayTensor | None = None
    naip_10: ArrayTensor | None = None
    naip_10_mask: ArrayTensor | None = None
    gse: ArrayTensor | None = None
    gse_mask: ArrayTensor | None = None
    cdl: ArrayTensor | None = None
    cdl_mask: ArrayTensor | None = None
    worldpop: ArrayTensor | None = None
    worldpop_mask: ArrayTensor | None = None
    worldcereal: ArrayTensor | None = None
    worldcereal_mask: ArrayTensor | None = None
    wri_canopy_height_map: ArrayTensor | None = None
    wri_canopy_height_map_mask: ArrayTensor | None = None
    era5_10: ArrayTensor | None = None
    era5_10_mask: ArrayTensor | None = None

    def as_dict(self, return_none: bool = True) -> dict[str, Any]:
        """Convert the namedtuple to a dictionary.

        Args:
            return_none: if True, include every field (absent ones map to None);
                if False, omit fields whose value is None.

        Returns:
            Dictionary mapping field names to values, in field order.
        """
        return {
            field: val
            for field, val in zip(self._fields, self)
            if return_none or val is not None
        }

    def unmask(self) -> MaskedOlmoEarthSample:
        """Return an unmasked MaskedOlmoEarthSample.

        All mask values are MaskValue.ONLINE_ENCODER except for MaskValue.MISSING,
        which remain MISSING. Fields that are None stay None.
        """
        return_dict: dict[str, ArrayTensor] = {}
        for key, val in self.as_dict(return_none=False).items():
            if key.endswith("_mask"):
                # True where the token is MISSING, False elsewhere. Multiplying keeps
                # MISSING entries (MISSING * 1) and zeroes everything else, which
                # relies on MaskValue.ONLINE_ENCODER.value being 0.
                is_missing = val == MaskValue.MISSING.value
                return_dict[key] = val * is_missing
            else:
                return_dict[key] = val
        return MaskedOlmoEarthSample(**return_dict)

    @property
    def modalities(self) -> list[str]:
        """Get the present modalities in this instance of MaskedOlmoEarthSample."""
        return [
            field
            for field in self._fields
            if not field.endswith("_mask")
            and field != "timestamps"
            and getattr(self, field) is not None
        ]

    @staticmethod
    def get_masked_modality_name(modality: str) -> str:
        """Get the masked modality name, i.e. ``<modality>_mask``."""
        return f"{modality}_mask"

    @staticmethod
    def get_unmasked_modality_name(modality_mask_name: str) -> str:
        """Get the unmasked modality name by stripping a trailing ``_mask``.

        Only the suffix is removed, so ``_mask`` occurring elsewhere in the name is kept.
        """
        return modality_mask_name.removesuffix("_mask")

    @classmethod
    def from_olmoearthsample(
        cls,
        sample: Any,  # OlmoEarthSample - not available in minimal repo
    ) -> MaskedOlmoEarthSample:
        """Transforms a OlmoEarthSample into a MaskedOlmoEarthSample.

        This function assumes modalities are uniformly missing: a present modality gets
        a mask entirely set to MaskValue.ONLINE_ENCODER, and an absent one gets a None mask.

        Args:
            sample: an object providing ``as_dict(ignore_nones=False)`` and
                ``shape(key, mask=False)``; presumably an OlmoEarthSample — confirm.
        """
        masked_sample_dict: dict[str, Any] = {}
        for key, t in sample.as_dict(ignore_nones=False).items():
            if key == "timestamps":
                # Timestamps are assumed to be present and have no mask.
                masked_sample_dict[key] = t
                continue
            mask_key = cls.get_masked_modality_name(key)
            masked_sample_dict[key] = t
            if t is None:
                masked_sample_dict[mask_key] = None
            else:
                masked_sample_dict[mask_key] = (
                    torch.ones(sample.shape(key, mask=False))
                    * MaskValue.ONLINE_ENCODER.value
                )

        return MaskedOlmoEarthSample(**masked_sample_dict)

    @classmethod
    def from_dict(cls, dict: dict[str, Any]) -> MaskedOlmoEarthSample:
        """Create a MaskedOlmoEarthSample from a dictionary.

        Fields absent from ``dict`` take their default value (None); no tensors are
        created for missing modalities.

        Args:
            dict: mapping from field name to value. Must contain ``timestamps``;
                a missing ``timestamps`` or an unknown key raises TypeError.
        """
        # The parameter name shadows the builtin ``dict`` but is kept for
        # backward compatibility with keyword callers.
        return cls(**dict)
|