rslearn 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. rslearn/arg_parser.py +2 -9
  2. rslearn/config/dataset.py +15 -16
  3. rslearn/dataset/dataset.py +28 -22
  4. rslearn/lightning_cli.py +22 -11
  5. rslearn/main.py +1 -1
  6. rslearn/models/anysat.py +35 -33
  7. rslearn/models/attention_pooling.py +177 -0
  8. rslearn/models/clip.py +5 -2
  9. rslearn/models/component.py +12 -0
  10. rslearn/models/croma.py +11 -3
  11. rslearn/models/dinov3.py +2 -1
  12. rslearn/models/faster_rcnn.py +2 -1
  13. rslearn/models/galileo/galileo.py +58 -31
  14. rslearn/models/module_wrapper.py +6 -1
  15. rslearn/models/molmo.py +4 -2
  16. rslearn/models/olmoearth_pretrain/model.py +206 -51
  17. rslearn/models/olmoearth_pretrain/norm.py +5 -3
  18. rslearn/models/panopticon.py +3 -1
  19. rslearn/models/presto/presto.py +45 -15
  20. rslearn/models/prithvi.py +9 -7
  21. rslearn/models/sam2_enc.py +3 -1
  22. rslearn/models/satlaspretrain.py +4 -1
  23. rslearn/models/simple_time_series.py +43 -17
  24. rslearn/models/ssl4eo_s12.py +19 -14
  25. rslearn/models/swin.py +3 -1
  26. rslearn/models/terramind.py +5 -4
  27. rslearn/train/all_patches_dataset.py +96 -28
  28. rslearn/train/dataset.py +102 -53
  29. rslearn/train/model_context.py +35 -1
  30. rslearn/train/scheduler.py +15 -0
  31. rslearn/train/tasks/classification.py +8 -2
  32. rslearn/train/tasks/detection.py +3 -2
  33. rslearn/train/tasks/multi_task.py +2 -3
  34. rslearn/train/tasks/per_pixel_regression.py +14 -5
  35. rslearn/train/tasks/regression.py +8 -2
  36. rslearn/train/tasks/segmentation.py +13 -4
  37. rslearn/train/tasks/task.py +2 -2
  38. rslearn/train/transforms/concatenate.py +45 -5
  39. rslearn/train/transforms/crop.py +22 -8
  40. rslearn/train/transforms/flip.py +13 -5
  41. rslearn/train/transforms/mask.py +11 -2
  42. rslearn/train/transforms/normalize.py +46 -15
  43. rslearn/train/transforms/pad.py +15 -3
  44. rslearn/train/transforms/resize.py +83 -0
  45. rslearn/train/transforms/select_bands.py +11 -2
  46. rslearn/train/transforms/sentinel1.py +18 -3
  47. rslearn/utils/geometry.py +73 -0
  48. rslearn/utils/jsonargparse.py +66 -0
  49. {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/METADATA +1 -1
  50. {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/RECORD +55 -53
  51. {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/WHEEL +0 -0
  52. {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/entry_points.txt +0 -0
  53. {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/licenses/LICENSE +0 -0
  54. {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/licenses/NOTICE +0 -0
  55. {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/top_level.txt +0 -0
rslearn/models/galileo/galileo.py CHANGED
@@ -3,6 +3,7 @@
 import math
 import tempfile
 from contextlib import nullcontext
+from datetime import datetime
 from enum import StrEnum
 from typing import cast
 
@@ -411,6 +412,23 @@ class GalileoModel(FeatureExtractor):
             months=months,
         )
 
+    @staticmethod
+    def time_ranges_to_timestamps(
+        time_ranges: list[tuple[datetime, datetime]],
+        device: torch.device,
+    ) -> torch.Tensor:
+        """Turn the time ranges stored in a RasterImage to timestamps accepted by Galileo.
+
+        Galileo only uses the month associated with each timestamp, so we take the midpoint
+        of the time range. For some inputs (e.g. Sentinel 2) we take an image from a specific
+        time so that start_time == end_time == mid_time.
+        """
+        mid_ranges = [t[0] + ((t[1] - t[0]) / 2) for t in time_ranges]
+        # months are indexed 0-11
+        return torch.tensor(
+            [d.month - 1 for d in mid_ranges], dtype=torch.int32, device=device
+        )
+
     def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute feature maps from the Galileo backbone.
 
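
As a sanity check on the new helper, here is a minimal standalone sketch of what time_ranges_to_timestamps computes; the dates are hypothetical, not from rslearn:

from datetime import datetime

import torch

# Hypothetical ranges: an instantaneous Sentinel-2 capture and a two-month composite.
time_ranges = [
    (datetime(2024, 3, 15), datetime(2024, 3, 15)),  # midpoint: March -> index 2
    (datetime(2024, 6, 1), datetime(2024, 8, 1)),    # midpoint: July  -> index 6
]
mid_ranges = [start + (end - start) / 2 for start, end in time_ranges]
months = torch.tensor([d.month - 1 for d in mid_ranges], dtype=torch.int32)
print(months)  # tensor([2, 6], dtype=torch.int32)
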
@@ -418,16 +436,16 @@ class GalileoModel(FeatureExtractor):
             context: the model context. Input dicts should contain keys corresponding to Galileo.input_keys
                 (also documented below) and values are tensors of the following shapes,
                 per input key:
-                "s1": B (T * C) H W
-                "s2": B (T * C) H W
-                "era5": B (T * C) H W (we will average over the H, W dimensions)
-                "tc": B (T * C) H W (we will average over the H, W dimensions)
-                "viirs": B (T * C) H W (we will average over the H, W dimensions)
-                "srtm": B C H W (SRTM has no temporal dimension)
-                "dw": : B C H W (Dynamic World should be averaged over time)
-                "wc": B C H W (WorldCereal has no temporal dimension)
-                "landscan": B C H W (we will average over the H, W dimensions)
-                "latlon": B C H W (we will average over the H, W dimensions)
+                "s1": B C T H W
+                "s2": B C T H W
+                "era5": B C T H W (we will average over the H, W dimensions)
+                "tc": B C T H W (we will average over the H, W dimensions)
+                "viirs": B C T H W (we will average over the H, W dimensions)
+                "srtm": B C 1 H W (SRTM has no temporal dimension)
+                "dw": B C 1 H W (Dynamic World should be averaged over time)
+                "wc": B C 1 H W (WorldCereal has no temporal dimension)
+                "landscan": B C 1 H W (we will average over the H, W dimensions)
+                "latlon": B C 1 H W (we will average over the H, W dimensions)
 
         The output will be an embedding representing the pooled tokens. If there is
         only a single token per h/w dimension (i.e. patch_size == h,w), then we will take
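
The shape change above replaces the flattened (T * C) channel axis with explicit channel and time axes. A small illustration of the new layout and the rearrange the forward pass now applies (the tensor sizes are made up):

import torch
from einops import rearrange

x = torch.randn(2, 3, 4, 32, 32)  # B C T H W, the new input layout
tokens = rearrange(x, "b c t h w -> b h w t c")
print(tokens.shape)  # torch.Size([2, 32, 32, 4, 3])
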
@@ -436,15 +454,35 @@ class GalileoModel(FeatureExtractor):
         If there are many spatial tokens per h/w dimension (patch_size > h,w), then we will
         take a pool of the space_time unmasked tokens (i.e. of the s1 and s2 tokens).
         """
+        space_time_modalities = ["s1", "s2"]
+        time_modalities = ["era5", "tc", "viirs"]
         stacked_inputs = {}
+        months: torch.Tensor | None = None
         for key in context.inputs[0].keys():
            # assume all the keys in an input are consistent
            if key in self.input_keys:
                stacked_inputs[key] = torch.stack(
-                    [inp[key] for inp in context.inputs], dim=0
+                    [inp[key].image for inp in context.inputs], dim=0
                )
+                if key in space_time_modalities + time_modalities:
+                    if months is None:
+                        if context.inputs[0][key].timestamps is not None:
+                            months = torch.stack(
+                                [
+                                    self.time_ranges_to_timestamps(
+                                        inp[key].timestamps,  # type: ignore
+                                        device=stacked_inputs[key].device,
+                                    )
+                                    for inp in context.inputs
+                                ],
+                                dim=0,
+                            )
+
+        if months is not None:
+            stacked_inputs["months"] = months
+
         s_t_channels = []
-        for space_time_modality in ["s1", "s2"]:
+        for space_time_modality in space_time_modalities:
            if space_time_modality not in stacked_inputs:
                continue
            if space_time_modality == "s1":
@@ -452,36 +490,27 @@ class GalileoModel(FeatureExtractor):
            else:
                s_t_channels += self.s_t_channels_s2
            cur = stacked_inputs[space_time_modality]
-            # Check if it's single or multitemporal, and reshape accordingly
-            num_bands = len(S2_BANDS) if space_time_modality == "s2" else len(S1_BANDS)
-            num_timesteps = cur.shape[1] // num_bands
-            cur = rearrange(cur, "b (t c) h w -> b h w t c", t=num_timesteps)
+            cur = rearrange(cur, "b c t h w -> b h w t c")
            stacked_inputs[space_time_modality] = cur
 
        for space_modality in ["srtm", "dw", "wc"]:
            if space_modality not in stacked_inputs:
                continue
+            # take the first (and assumed only) timestep
+            stacked_inputs[space_modality] = stacked_inputs[space_modality][:, :, 0]
            stacked_inputs[space_modality] = rearrange(
                stacked_inputs[space_modality], "b c h w -> b h w c"
            )
 
-        for time_modality in ["era5", "tc", "viirs"]:
+        for time_modality in time_modalities:
            if time_modality not in stacked_inputs:
                continue
            cur = stacked_inputs[time_modality]
-            # Check if it's single or multitemporal, and reshape accordingly
-            num_bands = {
-                "era5": len(ERA5_BANDS),
-                "tc": len(TC_BANDS),
-                "viirs": len(VIIRS_BANDS),
-            }[time_modality]
-            num_timesteps = cur.shape[1] // num_bands
            # take the average over the h, w bands since Galileo
            # treats it as a pixel-timeseries
            cur = rearrange(
-                torch.nanmean(torch.nanmean(cur, dim=-1), dim=-1),
-                "b (t c) -> b t c",
-                t=num_timesteps,
+                torch.nanmean(cur, dim=(-1, -2)),
+                "b c t -> b t c",
            )
            stacked_inputs[time_modality] = cur
 
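
Note that torch.nanmean(cur, dim=(-1, -2)) is not merely a shorter spelling of the nested call it replaces: the nested form averages row means (weighting rows equally regardless of how many NaNs each contains), while the tuple-dim form averages all valid elements at once. A quick check with made-up values:

import torch

x = torch.tensor([[1.0, float("nan")], [3.0, 5.0]])
nested = torch.nanmean(torch.nanmean(x, dim=-1), dim=-1)  # mean of row means
joint = torch.nanmean(x, dim=(-1, -2))  # mean over all non-NaN elements
print(nested)  # tensor(2.5000) -> (1 + 4) / 2
print(joint)   # tensor(3.)     -> (1 + 3 + 5) / 3

For NaN-free inputs the two forms agree, so the swap only changes behavior where NaNs are unevenly distributed across the spatial dimensions.
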
@@ -489,9 +518,8 @@ class GalileoModel(FeatureExtractor):
            if static_modality not in stacked_inputs:
                continue
            cur = stacked_inputs[static_modality]
-            stacked_inputs[static_modality] = torch.nanmean(
-                torch.nanmean(cur, dim=-1), dim=-1
-            )
+            stacked_inputs[static_modality] = torch.nanmean(cur, dim=(2, 3, 4))
+
        galileo_input = self.construct_galileo_input(**stacked_inputs, normalize=True)
        h = galileo_input.s_t_x.shape[1]
        if h < self.patch_size:
@@ -511,7 +539,6 @@ class GalileoModel(FeatureExtractor):
        torch_context = torch.amp.autocast(
            device_type=device.type, dtype=self.autocast_dtype
        )
-
        with torch_context:
            outputs = self.model(
                s_t_x=galileo_input.s_t_x,
rslearn/models/module_wrapper.py CHANGED
@@ -53,7 +53,12 @@ class EncoderModuleWrapper(FeatureExtractor):
        Returns:
            the output from the last wrapped module.
        """
-        images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
+        # take the first and only timestep. Currently no intermediate
+        # components support multitemporal inputs, so if the input is
+        # multitemporal it should be wrapped in a simple time series wrapper.
+        images = torch.stack(
+            [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+        )
        cur: Any = FeatureMaps([images])
        for m in self.encoder_modules:
            cur = m(cur, context)
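
single_ts_to_chw_tensor is a RasterImage method defined elsewhere in this release (rslearn/train/model_context.py, also changed above); judging by its call sites, it collapses a single-timestep C x T x H x W raster to C x H x W. A plausible sketch of that contract, not the actual rslearn implementation:

import torch

def single_ts_to_chw_tensor(image: torch.Tensor) -> torch.Tensor:
    """Sketch: squeeze the singleton time axis of a C x T x H x W tensor."""
    c, t, h, w = image.shape
    if t != 1:
        raise ValueError(f"expected a single timestep, got {t}")
    return image[:, 0]  # C H W
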
rslearn/models/molmo.py CHANGED
@@ -47,11 +47,13 @@ class Molmo(FeatureExtractor):
        a FeatureMaps. Molmo produces features at one scale, so it will contain one
        feature map that is a Bx24x24x2048 tensor.
        """
-        device = context.inputs[0]["image"].device
+        device = context.inputs[0]["image"].image.device
        molmo_inputs_list = []
        # Process each one so we can isolate just the full image without any crops.
        for inp in context.inputs:
-            image = inp["image"].cpu().numpy().transpose(1, 2, 0)
+            image = (
+                inp["image"].single_ts_to_chw_tensor().cpu().numpy().transpose(1, 2, 0)
+            )
            processed = self.processor.process(
                images=[image],
                text="",
rslearn/models/olmoearth_pretrain/model.py CHANGED
@@ -1,26 +1,27 @@
 """OlmoEarth model wrapper for fine-tuning in rslearn."""
 
 import json
+import warnings
 from contextlib import nullcontext
+from datetime import datetime
 from typing import Any
 
 import torch
 from einops import rearrange
-from olmo_core.config import Config
-from olmo_core.distributed.checkpoint import load_model_and_optim_state
+from olmoearth_pretrain.config import Config, require_olmo_core
 from olmoearth_pretrain.data.constants import Modality
+from olmoearth_pretrain.datatypes import MaskedOlmoEarthSample, MaskValue
 from olmoearth_pretrain.model_loader import (
     ModelID,
     load_model_from_id,
     load_model_from_path,
 )
 from olmoearth_pretrain.nn.flexihelios import Encoder, TokensAndMasks
-from olmoearth_pretrain.train.masking import MaskedOlmoEarthSample, MaskValue
 from upath import UPath
 
 from rslearn.log_utils import get_logger
-from rslearn.models.component import FeatureExtractor, FeatureMaps
-from rslearn.train.model_context import ModelContext
+from rslearn.models.component import FeatureExtractor, FeatureMaps, TokenFeatureMaps
+from rslearn.train.model_context import ModelContext, RasterImage
 
 logger = get_logger(__name__)
 
@@ -60,6 +61,8 @@ class OlmoEarth(FeatureExtractor):
        random_initialization: bool = False,
        embedding_size: int | None = None,
        autocast_dtype: str | None = "bfloat16",
+        token_pooling: bool = True,
+        use_legacy_timestamps: bool = True,
    ):
        """Create a new OlmoEarth model.
 
@@ -83,7 +86,18 @@ class OlmoEarth(FeatureExtractor):
            embedding_size: optional embedding size to report via
                get_backbone_channels (if model_id is not set).
            autocast_dtype: which dtype to use for autocasting, or set None to disable.
+            token_pooling: whether or not to pool the tokens. If True, the output will be BxCxHxW. If False,
+                there will be an extra dimension, N, (BxCxHxWxN) representing the temporal and channel
+                dimensions.
+            use_legacy_timestamps: In our original implementation of OlmoEarth, we applied timestamps starting
+                from 0 (instead of the actual timestamps of the input). The option to do this is preserved
+                for backwards compatibility with finetuned models which were trained against this implementation.
        """
+        if use_legacy_timestamps:
+            warnings.warn(
+                "For new projects, don't use legacy timesteps.", DeprecationWarning
+            )
+
        if (
            sum(
                [
@@ -133,6 +147,8 @@ class OlmoEarth(FeatureExtractor):
        else:
            model = model[part]
        self.model = model
+        self.token_pooling = token_pooling
+        self.use_legacy_timestamps = use_legacy_timestamps
 
    def _load_model_from_checkpoint(
        self, checkpoint_upath: UPath, random_initialization: bool
@@ -143,9 +159,12 @@ class OlmoEarth(FeatureExtractor):
        that contains the distributed checkpoint. This is the format produced by
        pre-training runs in olmoearth_pretrain.
        """
-        # Load the model config and initialize it.
        # We avoid loading the train module here because it depends on running within
        # olmo_core.
+        # Only pull in olmo_core when trying to load a distributed checkpoint to avoid dependency.
+        require_olmo_core("_load_model_from_checkpoint")
+        from olmo_core.distributed.checkpoint import load_model_and_optim_state
+
        with (checkpoint_upath / "config.json").open() as f:
            config_dict = json.load(f)
        model_config = Config.from_dict(config_dict["model"])
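
require_olmo_core comes from olmoearth_pretrain and its exact behavior is not shown in this diff; the surrounding code uses it as a guard before a deferred import of an optional dependency. A generic sketch of that pattern (the names here are illustrative, not the olmoearth_pretrain API):

import importlib.util

def require_dependency(caller: str, module: str = "olmo_core") -> None:
    """Raise a clear error if an optional dependency is not installed."""
    if importlib.util.find_spec(module) is None:
        raise ImportError(
            f"{caller} requires the optional '{module}' package; install it first"
        )
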
@@ -160,58 +179,161 @@ class OlmoEarth(FeatureExtractor):
 
        return model
 
-    def forward(self, context: ModelContext) -> FeatureMaps:
-        """Compute feature maps from the OlmoEarth backbone.
+    @staticmethod
+    def time_ranges_to_timestamps(
+        time_ranges: list[tuple[datetime, datetime]],
+        max_timestamps: int,
+        device: torch.device,
+    ) -> torch.Tensor:
+        """Turn the time ranges stored in a RasterImage to timestamps accepted by OlmoEarth.
+
+        OlmoEarth only uses the month associated with each timestamp, so we take the midpoint
+        of the time range. For some inputs (e.g. Sentinel 2) we take an image from a specific
+        time so that start_time == end_time == mid_time.
+        """
+        timestamps = torch.zeros((max_timestamps, 3), dtype=torch.int32, device=device)
+        mid_ranges = [t[0] + ((t[1] - t[0]) / 2) for t in time_ranges]
+        timestamps[: len(time_ranges), 0] = torch.tensor(
+            [d.day for d in mid_ranges], dtype=torch.int32
+        )
+        # months are indexed 0-11
+        timestamps[: len(time_ranges), 1] = torch.tensor(
+            [d.month - 1 for d in mid_ranges], dtype=torch.int32
+        )
+        timestamps[: len(time_ranges), 2] = torch.tensor(
+            [d.year for d in mid_ranges], dtype=torch.int32
+        )
+        return timestamps
+
+    def _prepare_modality_inputs(
+        self, context: ModelContext
+    ) -> tuple[MaskedOlmoEarthSample, list[str], torch.device]:
+        """Prepare modality tensors and masks for the OlmoEarth model.
+
+        Uses a two-pass approach to ensure all modalities have consistent timestep
+        dimensions for position encoding.
 
        Args:
-            context: the model context. Input dicts should include keys corresponding
-                to the modalities that should be passed to the OlmoEarth model.
+            context: the model context with input tensors.
 
        Returns:
-            a FeatureMaps consisting of one feature map, at 1/patch_size of the input
-                resolution. Embeddings will be pooled across modalities and timesteps.
+            tuple of (sample, present_modalities, device)
        """
        kwargs = {}
        present_modalities = []
        device = None
-        # Handle the case where some modalities are multitemporal and some are not.
-        # We assume all multitemporal modalities have the same number of timesteps.
+
+        # First pass: find global max_timesteps across all modalities and samples
+        # TODO: currently we assume all modalities have the same number of timesteps,
+        # which is not true for all cases, and time series time steps are assumed to
+        # be 1-month apart. It also assumes continuity between available timesteps.
+        # We'll have to fix all that.
        max_timesteps = 1
+        modality_data = {}
+        # we will just store the longest time range
+        # per instance in the batch. This means it may not be
+        # aligned per modality
+        timestamps_per_instance: list[list[tuple[datetime, datetime]]] = [[]] * len(
+            context.inputs
+        )
        for modality in MODALITY_NAMES:
            if modality not in context.inputs[0]:
                continue
            present_modalities.append(modality)
-            cur = torch.stack([inp[modality] for inp in context.inputs], dim=0)
-            device = cur.device
-            # Check if it's single or multitemporal, and reshape accordingly
-            num_bands = Modality.get(modality).num_bands
-            num_timesteps = cur.shape[1] // num_bands
-            max_timesteps = max(max_timesteps, num_timesteps)
-            cur = rearrange(cur, "b (t c) h w -> b h w t c", t=num_timesteps)
+            tensors = []
+            for idx, inp in enumerate(context.inputs):
+                assert isinstance(inp, RasterImage)
+                tensors.append(inp[modality].image)
+                cur_timestamps = inp[modality].timestamps
+                if cur_timestamps is not None and len(cur_timestamps) > len(
+                    timestamps_per_instance[idx]
+                ):
+                    timestamps_per_instance[idx] = cur_timestamps
+            tensors = [inp[modality].image for inp in context.inputs]
+            device = tensors[0].device
+            max_t = max(t.shape[1] for t in tensors)
+            max_timesteps = max(max_timesteps, max_t)
+            modality_data[modality] = (
+                tensors,
+                len(Modality.get(modality).band_sets),
+            )
+
+        # Second pass: pad and process each modality with global max_timesteps
+        for modality in present_modalities:
+            tensors, num_band_sets = modality_data[modality]
+
+            # Pad tensors to max_timesteps and track original timesteps for masking
+            padded = []
+            original_timesteps = []
+            for t in tensors:
+                orig_t = t.shape[1]
+                original_timesteps.append(orig_t)
+                if orig_t < max_timesteps:
+                    pad = torch.zeros(
+                        t.shape[:1] + (max_timesteps - orig_t,) + t.shape[2:],
+                        dtype=t.dtype,
+                        device=device,
+                    )
+                    t = torch.cat([t, pad], dim=1)
+                padded.append(t)
+
+            cur = torch.stack(padded, dim=0)
+            cur = rearrange(cur, "b c t h w -> b h w t c")
            kwargs[modality] = cur
-            # Create mask array which is BHWTS (without channels but with band sets).
-            num_band_sets = len(Modality.get(modality).band_sets)
-            mask_shape = cur.shape[0:4] + (num_band_sets,)
-            mask = (
-                torch.ones(mask_shape, dtype=torch.int32, device=device)
-                * MaskValue.ONLINE_ENCODER.value
+
+            # Create mask: ONLINE_ENCODER for valid, MISSING for padded timesteps
+            b, h, w = cur.shape[0], cur.shape[1], cur.shape[2]
+            mask = torch.full(
+                (b, h, w, max_timesteps, num_band_sets),
+                fill_value=MaskValue.ONLINE_ENCODER.value,
+                dtype=torch.int32,
+                device=device,
            )
+            for sample_idx, orig_t in enumerate(original_timesteps):
+                if orig_t < max_timesteps:
+                    mask[sample_idx, :, :, orig_t:, :] = MaskValue.MISSING.value
            kwargs[f"{modality}_mask"] = mask
 
-        # Timestamps is required.
-        # Note that only months (0 to 11) are used in OlmoEarth position encoding.
-        # For now, we assign same timestamps to all inputs, but later we should handle varying timestamps per input.
-        timestamps = torch.zeros(
-            (len(context.inputs), max_timesteps, 3), dtype=torch.int32, device=device
-        )
-        timestamps[:, :, 0] = 1  # day
-        timestamps[:, :, 1] = torch.arange(max_timesteps, device=device)[
-            None, :
-        ]  # month
-        timestamps[:, :, 2] = 2024  # year
-        kwargs["timestamps"] = timestamps
+        if self.use_legacy_timestamps:
+            # Note that only months (0 to 11) are used in OlmoEarth position encoding.
+            timestamps = torch.zeros(
+                (len(context.inputs), max_timesteps, 3),
+                dtype=torch.int32,
+                device=device,
+            )
+            timestamps[:, :, 0] = 1  # day
+            timestamps[:, :, 1] = torch.arange(max_timesteps, device=device)[
+                None, :
+            ]  # month
+            timestamps[:, :, 2] = 2024  # year
+            kwargs["timestamps"] = timestamps
+        else:
+            if max([len(t) for t in timestamps_per_instance]) == 0:
+                # Timestamps is required.
+                raise ValueError("No inputs had timestamps.")
+            # Note that only months (0 to 11) are used in OlmoEarth position encoding.
+            kwargs["timestamps"] = torch.stack(
+                [
+                    self.time_ranges_to_timestamps(time_range, max_timesteps, device)
+                    for time_range in timestamps_per_instance
+                ],
+                dim=0,
+            )
+
+        return MaskedOlmoEarthSample(**kwargs), present_modalities, device
+
+    def forward(self, context: ModelContext) -> FeatureMaps | TokenFeatureMaps:
+        """Compute feature maps from the OlmoEarth backbone.
+
+        Args:
+            context: the model context. Input dicts should include keys corresponding
+                to the modalities that should be passed to the OlmoEarth model.
 
-        sample = MaskedOlmoEarthSample(**kwargs)
+        Returns:
+            a FeatureMaps consisting of one feature map, at 1/patch_size of the input
+                resolution. Embeddings will be pooled across modalities and timesteps.
+        """
+        sample, present_modalities, device = self._prepare_modality_inputs(context)
 
        # Decide context based on self.autocast_dtype.
        if self.autocast_dtype is None:
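
A standalone sketch of the (day, month-1, year) tensor the new time_ranges_to_timestamps helper builds, zero-padded out to max_timestamps; the dates are hypothetical:

from datetime import datetime

import torch

time_ranges = [(datetime(2024, 5, 10), datetime(2024, 5, 20))]  # midpoint: May 15
max_timestamps = 3
ts = torch.zeros((max_timestamps, 3), dtype=torch.int32)
mid = [a + (b - a) / 2 for a, b in time_ranges]
ts[: len(time_ranges), 0] = torch.tensor([d.day for d in mid], dtype=torch.int32)
ts[: len(time_ranges), 1] = torch.tensor([d.month - 1 for d in mid], dtype=torch.int32)
ts[: len(time_ranges), 2] = torch.tensor([d.year for d in mid], dtype=torch.int32)
print(ts.tolist())  # [[15, 4, 2024], [0, 0, 0], [0, 0, 0]]
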
@@ -222,6 +344,14 @@ class OlmoEarth(FeatureExtractor):
            device_type=device.type, dtype=self.autocast_dtype
        )
 
+        # Check if we can bypass masks (fast_pass=True)
+        missing_tokens = False
+        for modality in present_modalities:
+            modality_mask = getattr(sample, f"{modality}_mask")
+            if torch.any(modality_mask == MaskValue.MISSING.value):
+                missing_tokens = True
+                break
+
        with torch_context:
            # Currently we assume the provided model always returns a TokensAndMasks object.
            tokens_and_masks: TokensAndMasks
@@ -229,7 +359,7 @@ class OlmoEarth(FeatureExtractor):
            # Encoder has a fast_pass argument to indicate mask is not needed.
            tokens_and_masks = self.model(
                sample,
-                fast_pass=True,
+                fast_pass=not missing_tokens,
                patch_size=self.patch_size,
                **self.forward_kwargs,
            )["tokens_and_masks"]
@@ -241,16 +371,41 @@ class OlmoEarth(FeatureExtractor):
 
        # Apply temporal/modality pooling so we just have one feature per patch.
        features = []
-        for modality in present_modalities:
-            modality_features = getattr(tokens_and_masks, modality)
-            # Pool over band sets and timesteps (BHWTSC -> BHWC).
-            pooled = modality_features.mean(dim=[3, 4])
-            # We want BHWC -> BCHW.
-            pooled = rearrange(pooled, "b h w c -> b c h w")
-            features.append(pooled)
-        # Pool over the modalities, so we get one BCHW feature map.
-        pooled = torch.stack(features, dim=0).mean(dim=0)
-        return FeatureMaps([pooled])
+        if self.token_pooling:
+            for modality in present_modalities:
+                modality_features = getattr(tokens_and_masks, modality)  # BHWTSC
+                # If fast_pass is False, we need to mask the missing tokens before pooling.
+                if missing_tokens:
+                    modality_masks = getattr(
+                        tokens_and_masks, f"{modality}_mask"
+                    )  # BHWTS
+                    modality_masks_bool = (
+                        modality_masks != MaskValue.MISSING.value
+                    ).unsqueeze(-1)
+                    count = modality_masks_bool.sum(dim=[3, 4])
+                    # Masked average over band sets and timesteps (BHWTSC -> BHWC).
+                    pooled = (modality_features * modality_masks_bool).sum(
+                        dim=[3, 4]
+                    ) / count.clamp(min=1)
+                else:
+                    # Pool over band sets and timesteps (BHWTSC -> BHWC).
+                    pooled = modality_features.mean(dim=[3, 4])
+                # We want BHWC -> BCHW.
+                pooled = rearrange(pooled, "b h w c -> b c h w")
+                features.append(pooled)
+            # Pool over the modalities, so we get one BCHW feature map.
+            pooled = torch.stack(features, dim=0).mean(dim=0)
+            return FeatureMaps([pooled])
+        else:
+            for modality in present_modalities:
+                modality_features = getattr(tokens_and_masks, modality)
+                # Combine band sets and timesteps into last dim (BHWTSC -> BHWCN).
+                modality_features = rearrange(
+                    modality_features, "b h w t s c -> b c h w (t s)"
+                )
+                features.append(modality_features)
+            pooled = torch.cat(features, dim=-1)
+            return TokenFeatureMaps([pooled])
 
    def get_backbone_channels(self) -> list:
        """Returns the output channels of this model when used as a backbone.
rslearn/train/transforms/normalize.py CHANGED
@@ -64,8 +64,8 @@ class OlmoEarthNormalize(Transform):
            band_norms = self.norm_config[modality_name]
            image = input_dict[modality_name]
            # Keep a set of indices to make sure that we normalize all of them.
-            needed_band_indices = set(range(image.shape[0]))
-            num_timesteps = image.shape[0] // len(cur_band_names)
+            needed_band_indices = set(range(image.image.shape[0]))
+            num_timesteps = image.image.shape[0] // len(cur_band_names)
 
            for band, norm_dict in band_norms.items():
                # If multitemporal, normalize each timestep separately.
@@ -73,7 +73,9 @@ class OlmoEarthNormalize(Transform):
                    band_idx = cur_band_names.index(band) + t * len(cur_band_names)
                    min_val = norm_dict["mean"] - self.std_multiplier * norm_dict["std"]
                    max_val = norm_dict["mean"] + self.std_multiplier * norm_dict["std"]
-                    image[band_idx] = (image[band_idx] - min_val) / (max_val - min_val)
+                    image.image[band_idx] = (image.image[band_idx] - min_val) / (
+                        max_val - min_val
+                    )
                    needed_band_indices.remove(band_idx)
 
            if len(needed_band_indices) > 0:
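
The per-timestep loop above indexes a flattened (T * C) channel axis; for a hypothetical three-band modality, band "b2" at timestep t lands at index t * 3:

cur_band_names = ["b2", "b3", "b4"]  # hypothetical band list
for t in range(2):
    band_idx = cur_band_names.index("b2") + t * len(cur_band_names)
    print(t, band_idx)  # prints 0 0, then 1 3
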
rslearn/models/panopticon.py CHANGED
@@ -142,7 +142,9 @@ class Panopticon(FeatureExtractor):
    def forward(self, context: ModelContext) -> FeatureMaps:
        """Forward pass through the panopticon model."""
        batch_inputs = {
-            key: torch.stack([inp[key] for inp in context.inputs], dim=0)
+            key: torch.stack(
+                [inp[key].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+            )
            for key in context.inputs[0].keys()
        }
        panopticon_inputs = self.prepare_input(batch_inputs)