rslearn 0.0.18__py3-none-any.whl → 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rslearn/arg_parser.py CHANGED
@@ -1,6 +1,5 @@
 """Custom Lightning ArgumentParser with environment variable substitution support."""
 
-import os
 from typing import Any
 
 from jsonargparse import Namespace
@@ -21,11 +20,7 @@ class RslearnArgumentParser(LightningArgumentParser):
     def parse_string(
         self,
         cfg_str: str,
-        cfg_path: str | os.PathLike = "",
-        ext_vars: dict | None = None,
-        env: bool | None = None,
-        defaults: bool = True,
-        with_meta: bool | None = None,
+        *args: Any,
         **kwargs: Any,
     ) -> Namespace:
         """Pre-processes string for environment variable substitution before parsing."""
@@ -33,6 +28,4 @@ class RslearnArgumentParser(LightningArgumentParser):
         substituted_cfg_str = substitute_env_vars_in_string(cfg_str)
 
         # Call the parent method with the substituted config
-        return super().parse_string(
-            substituted_cfg_str, cfg_path, ext_vars, env, defaults, with_meta, **kwargs
-        )
+        return super().parse_string(substituted_cfg_str, *args, **kwargs)
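
The refactored `parse_string` still substitutes environment variables into the config string before handing it to the parent parser, but now forwards any extra arguments untouched via `*args`/`**kwargs`. A rough usage sketch, assuming `substitute_env_vars_in_string` expands `${VAR}`-style placeholders (the exact placeholder syntax is not shown in this diff):

```python
import os

from rslearn.arg_parser import RslearnArgumentParser

os.environ["DATA_ROOT"] = "/data/my_dataset"

parser = RslearnArgumentParser()
parser.add_argument("--data.path", type=str)

# Hypothetical config string; "${DATA_ROOT}" is an assumed placeholder form.
cfg = parser.parse_string("data:\n  path: ${DATA_ROOT}\n")
print(cfg.data.path)  # "/data/my_dataset"
```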
rslearn/config/dataset.py CHANGED
@@ -25,7 +25,7 @@ from rasterio.enums import Resampling
 from upath import UPath
 
 from rslearn.log_utils import get_logger
-from rslearn.utils import PixelBounds, Projection
+from rslearn.utils.geometry import PixelBounds, Projection, ResolutionFactor
 from rslearn.utils.raster_format import RasterFormat
 from rslearn.utils.vector_format import VectorFormat
 
@@ -215,22 +215,12 @@ class BandSetConfig(BaseModel):
         Returns:
             tuple of updated projection and bounds with zoom offset applied
         """
-        if self.zoom_offset == 0:
-            return projection, bounds
-        projection = Projection(
-            projection.crs,
-            projection.x_resolution / (2**self.zoom_offset),
-            projection.y_resolution / (2**self.zoom_offset),
-        )
-        if self.zoom_offset > 0:
-            zoom_factor = 2**self.zoom_offset
-            bounds = tuple(x * zoom_factor for x in bounds)  # type: ignore
+        if self.zoom_offset >= 0:
+            factor = ResolutionFactor(numerator=2**self.zoom_offset)
         else:
-            bounds = tuple(
-                x // (2 ** (-self.zoom_offset))
-                for x in bounds  # type: ignore
-            )
-        return projection, bounds
+            factor = ResolutionFactor(denominator=2 ** (-self.zoom_offset))
+
+        return (factor.multiply_projection(projection), factor.multiply_bounds(bounds))
 
     @field_validator("format", mode="before")
     @classmethod
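
The refactor delegates the zoom arithmetic to `ResolutionFactor`, but the intended behavior matches the removed code: a positive `zoom_offset` of k divides the per-pixel resolution by 2^k and multiplies pixel bounds by 2^k, while a negative offset does the opposite (with integer division on bounds). A standalone sketch of that arithmetic, not using the rslearn classes:

```python
def apply_zoom_offset(resolution: float, bounds: tuple[int, int, int, int], zoom_offset: int):
    """Mirror the old BandSetConfig logic: same window, finer or coarser pixel grid."""
    if zoom_offset >= 0:
        factor = 2**zoom_offset
        return resolution / factor, tuple(x * factor for x in bounds)
    factor = 2 ** (-zoom_offset)
    return resolution * factor, tuple(x // factor for x in bounds)

# zoom_offset=1: 10 m/pixel becomes 5 m/pixel and pixel coordinates double.
print(apply_zoom_offset(10.0, (0, 0, 256, 256), 1))   # (5.0, (0, 0, 512, 512))
print(apply_zoom_offset(10.0, (0, 0, 256, 256), -1))  # (20.0, (0, 0, 128, 128))
```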
@@ -645,3 +635,12 @@ class DatasetConfig(BaseModel):
         default_factory=lambda: StorageConfig(),
         description="jsonargparse configuration for the WindowStorageFactory.",
     )
+
+    @field_validator("layers", mode="after")
+    @classmethod
+    def layer_names_validator(cls, v: dict[str, LayerConfig]) -> dict[str, LayerConfig]:
+        """Ensure layer names don't contain periods, since we use periods to distinguish different materialized groups within a layer."""
+        for layer_name in v.keys():
+            if "." in layer_name:
+                raise ValueError(f"layer names must not contain periods: {layer_name}")
+        return v
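
The period check that used to be an `assert` in `Dataset.__init__` now runs when the config is validated. A minimal standalone sketch of the same pydantic pattern (toy classes, not the actual rslearn models):

```python
from pydantic import BaseModel, ValidationError, field_validator


class ToyDatasetConfig(BaseModel):
    layers: dict[str, dict]

    @field_validator("layers", mode="after")
    @classmethod
    def layer_names_validator(cls, v: dict[str, dict]) -> dict[str, dict]:
        for layer_name in v:
            if "." in layer_name:
                raise ValueError(f"layer names must not contain periods: {layer_name}")
        return v


try:
    ToyDatasetConfig(layers={"sentinel2.bad": {}})
except ValidationError as e:
    print(e)  # reports the period in "sentinel2.bad"
```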
@@ -23,7 +23,7 @@ class Dataset:
     .. code-block:: none
 
         dataset/
-            config.json
+            config.json  # optional, if config provided as runtime object
            windows/
                group1/
                    epsg:3857_10_623565_1528020/
@@ -40,37 +40,43 @@ class Dataset:
     materialize.
     """
 
-    def __init__(self, path: UPath, disabled_layers: list[str] = []) -> None:
+    def __init__(
+        self,
+        path: UPath,
+        disabled_layers: list[str] = [],
+        dataset_config: DatasetConfig | None = None,
+    ) -> None:
         """Initializes a new Dataset.
 
         Args:
             path: the root directory of the dataset
             disabled_layers: list of layers to disable
+            dataset_config: optional dataset configuration to use instead of loading from the dataset directory
         """
         self.path = path
 
-        # Load dataset configuration.
-        with (self.path / "config.json").open("r") as f:
-            config_content = f.read()
-        config_content = substitute_env_vars_in_string(config_content)
-        config = DatasetConfig.model_validate(json.loads(config_content))
-
-        self.layers = {}
-        for layer_name, layer_config in config.layers.items():
-            # Layer names must not contain period, since we use period to
-            # distinguish different materialized groups within a layer.
-            assert "." not in layer_name, "layer names must not contain periods"
-            if layer_name in disabled_layers:
-                logger.warning(f"Layer {layer_name} is disabled")
-                continue
-            self.layers[layer_name] = layer_config
-
-        self.tile_store_config = config.tile_store
-        self.storage = (
-            config.storage.instantiate_window_storage_factory().get_storage(
-                self.path
+        if dataset_config is None:
+            # Load dataset configuration from the dataset directory.
+            with (self.path / "config.json").open("r") as f:
+                config_content = f.read()
+            config_content = substitute_env_vars_in_string(config_content)
+            dataset_config = DatasetConfig.model_validate(
+                json.loads(config_content)
             )
+
+        self.layers = {}
+        for layer_name, layer_config in dataset_config.layers.items():
+            if layer_name in disabled_layers:
+                logger.warning(f"Layer {layer_name} is disabled")
+                continue
+            self.layers[layer_name] = layer_config
+
+        self.tile_store_config = dataset_config.tile_store
+        self.storage = (
+            dataset_config.storage.instantiate_window_storage_factory().get_storage(
+                self.path
             )
+        )
 
     def load_windows(
         self,
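
With the new `dataset_config` parameter, a `Dataset` can be built from an in-memory configuration and `config.json` is never read. A rough sketch, with import paths assumed and a hypothetical minimal layer dict (the full set of required `DatasetConfig` fields is not shown in this diff):

```python
from upath import UPath

from rslearn.config.dataset import DatasetConfig  # import path assumed
from rslearn.dataset import Dataset  # import path assumed

# Hypothetical minimal config; real layer definitions likely need more fields.
config = DatasetConfig.model_validate({"layers": {"sentinel2": {"type": "raster"}}})

# config.json is not read in this case; the runtime object is used directly.
dataset = Dataset(UPath("/tmp/my_dataset"), dataset_config=config)
print(list(dataset.layers))  # ["sentinel2"]
```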
rslearn/lightning_cli.py CHANGED
@@ -21,6 +21,7 @@ from rslearn.log_utils import get_logger
 from rslearn.train.data_module import RslearnDataModule
 from rslearn.train.lightning_module import RslearnLightningModule
 from rslearn.utils.fsspec import open_atomic
+from rslearn.utils.jsonargparse import init_jsonargparse
 
 WANDB_ID_FNAME = "wandb_id"
 
@@ -390,8 +391,15 @@ class RslearnLightningCLI(LightningCLI):
 
         Sets the dataset path for any configured RslearnPredictionWriter callbacks.
         """
-        subcommand = self.config.subcommand
-        c = self.config[subcommand]
+        if not hasattr(self.config, "subcommand"):
+            logger.warning(
+                "Config does not have subcommand attribute, assuming we are in run=False mode"
+            )
+            subcommand = None
+            c = self.config
+        else:
+            subcommand = self.config.subcommand
+            c = self.config[subcommand]
 
         # If there is a RslearnPredictionWriter, set its path.
         prediction_writer_callback = None
@@ -415,16 +423,17 @@
         if subcommand == "predict":
             c.return_predictions = False
 
-        # For now we use DDP strategy with find_unused_parameters=True.
+        # Default to DDP with find_unused_parameters. Likely won't get called with unified config
         if subcommand == "fit":
-            c.trainer.strategy = jsonargparse.Namespace(
-                {
-                    "class_path": "lightning.pytorch.strategies.DDPStrategy",
-                    "init_args": jsonargparse.Namespace(
-                        {"find_unused_parameters": True}
-                    ),
-                }
-            )
+            if not c.trainer.strategy:
+                c.trainer.strategy = jsonargparse.Namespace(
+                    {
+                        "class_path": "lightning.pytorch.strategies.DDPStrategy",
+                        "init_args": jsonargparse.Namespace(
+                            {"find_unused_parameters": True}
+                        ),
+                    }
+                )
 
         if c.management_dir:
             self.enable_project_management(c.management_dir)
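
The new guard means a strategy already set in the user's config is no longer overwritten during `fit`; the DDP default is only installed when `trainer.strategy` is empty. A sketch of that logic isolated from the CLI class, reusing the `Namespace` construction pattern shown in the diff:

```python
import jsonargparse


def apply_fit_strategy_default(c: jsonargparse.Namespace) -> None:
    """Only install the DDP default when no strategy was configured."""
    if not c.trainer.strategy:
        c.trainer.strategy = jsonargparse.Namespace(
            {
                "class_path": "lightning.pytorch.strategies.DDPStrategy",
                "init_args": jsonargparse.Namespace({"find_unused_parameters": True}),
            }
        )
```

So a config that explicitly sets `trainer.strategy` (for example to an FSDP strategy) keeps that value instead of being forced back to DDP.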
@@ -432,6 +441,8 @@
 
 def model_handler() -> None:
     """Handler for any rslearn model X commands."""
+    init_jsonargparse()
+
     RslearnLightningCLI(
         model_class=RslearnLightningModule,
         datamodule_class=RslearnDataModule,
rslearn/main.py CHANGED
@@ -380,7 +380,7 @@ def apply_on_windows(
 
 def apply_on_windows_args(f: Callable[..., Any], args: argparse.Namespace) -> None:
     """Call apply_on_windows with arguments passed via command-line interface."""
-    dataset = Dataset(UPath(args.root), args.disabled_layers)
+    dataset = Dataset(UPath(args.root), disabled_layers=args.disabled_layers)
     apply_on_windows(
         f=f,
         dataset=dataset,
@@ -0,0 +1,177 @@
+"""An attention pooling layer."""
+
+import math
+from typing import Any
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from torch import nn
+
+from rslearn.models.component import (
+    FeatureMaps,
+    IntermediateComponent,
+    TokenFeatureMaps,
+)
+from rslearn.train.model_context import ModelContext
+
+
+class SimpleAttentionPool(IntermediateComponent):
+    """Simple Attention Pooling.
+
+    Given a token feature map of shape BCHWN,
+    learn an attention layer which aggregates over
+    the N dimension.
+
+    This is done simply by learning a mapping D->1 which is the weight
+    which should be assigned to each token during averaging:
+
+    output = sum [feat_token * W(feat_token) for feat_token in feat_tokens]
+    """
+
+    def __init__(self, in_dim: int, hidden_linear: bool = False) -> None:
+        """Initialize the simple attention pooling layer.
+
+        Args:
+            in_dim: the encoding dimension D
+            hidden_linear: whether to apply an additional linear transformation D -> D
+                to the feat tokens. If this is True, a ReLU activation is applied
+                after the first linear transformation.
+        """
+        super().__init__()
+        if hidden_linear:
+            self.hidden_linear = nn.Linear(in_features=in_dim, out_features=in_dim)
+        else:
+            self.hidden_linear = None
+        self.linear = nn.Linear(in_features=in_dim, out_features=1)
+
+    def forward_for_map(self, feat_tokens: torch.Tensor) -> torch.Tensor:
+        """Attention pooling for a single feature map (BCHWN tensor)."""
+        B, D, H, W, N = feat_tokens.shape
+        feat_tokens = rearrange(feat_tokens, "b d h w n -> (b h w) n d")
+        if self.hidden_linear is not None:
+            feat_tokens = torch.nn.functional.relu(self.hidden_linear(feat_tokens))
+        attention_scores = torch.nn.functional.softmax(self.linear(feat_tokens), dim=1)
+        feat_tokens = (attention_scores * feat_tokens).sum(dim=1)
+        return rearrange(feat_tokens, "(b h w) d -> b d h w", b=B, h=H, w=W)
+
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+        """Forward pass for attention pooling linear probe.
+
+        Args:
+            intermediates: the output from the previous component, which must be a TokenFeatureMaps.
+                We pool over the final dimension in the TokenFeatureMaps. If multiple maps
+                are passed, we apply the same linear layers to all of them.
+            context: the model context.
+            feat_tokens (torch.Tensor): Input feature tokens of shape (B, C, H, W, N).
+
+        Returns:
+            torch.Tensor:
+                - output, attentioned pool over the last dimension (B, C, H, W)
+        """
+        if not isinstance(intermediates, TokenFeatureMaps):
+            raise ValueError("input to Attention Pool must be a TokenFeatureMaps")
+
+        features = []
+        for feat in intermediates.feature_maps:
+            features.append(self.forward_for_map(feat))
+        return FeatureMaps(features)
+
+
+class AttentionPool(IntermediateComponent):
+    """Attention Pooling.
+
+    Given a feature map of shape BCHWN,
+    learn an attention layer which aggregates over
+    the N dimension.
+
+    We do this by learning a query token, and applying a standard
+    attention mechanism against this learned query token.
+    """
+
+    def __init__(self, in_dim: int, num_heads: int, linear_on_kv: bool = True) -> None:
+        """Initialize the attention pooling layer.
+
+        Args:
+            in_dim: the encoding dimension D
+            num_heads: the number of heads to use
+            linear_on_kv: Whether to apply a linear layer on the input tokens
+                to create the key and value tokens.
+        """
+        super().__init__()
+        self.query_token: nn.Parameter = nn.Parameter(torch.empty(in_dim))
+        if linear_on_kv:
+            self.k_linear = nn.Linear(in_dim, in_dim)
+            self.v_linear = nn.Linear(in_dim, in_dim)
+        else:
+            self.k_linear = None
+            self.v_linear = None
+        if in_dim % num_heads != 0:
+            raise ValueError(
+                f"in_dim must be divisible by num_heads. Got {in_dim} and {num_heads}."
+            )
+        self.num_heads = num_heads
+        self.init_weights()
+
+    def init_weights(self) -> None:
+        """Initialize weights for the probe."""
+        nn.init.trunc_normal_(self.query_token, std=0.02)
+
+    def forward_for_map(self, feat_tokens: torch.Tensor) -> torch.Tensor:
+        """Attention pooling for a single feature map (BCHWN tensor)."""
+        B, D, H, W, N = feat_tokens.shape
+        feat_tokens = rearrange(feat_tokens, "b d h w n -> (b h w) n d")
+        collapsed_dim = B * H * W
+        q = self.query_token.expand(collapsed_dim, 1, -1)
+        q = q.reshape(
+            collapsed_dim, 1, self.num_heads, D // self.num_heads
+        )  # [B, 1, head, D_head]
+        q = rearrange(q, "b h n d -> b n h d")
+        if self.k_linear is not None:
+            assert self.v_linear is not None
+            k = self.k_linear(feat_tokens).reshape(
+                collapsed_dim, N, self.num_heads, D // self.num_heads
+            )
+            v = self.v_linear(feat_tokens).reshape(
+                collapsed_dim, N, self.num_heads, D // self.num_heads
+            )
+        else:
+            k = feat_tokens.reshape(
+                collapsed_dim, N, self.num_heads, D // self.num_heads
+            )
+            v = feat_tokens.reshape(
+                collapsed_dim, N, self.num_heads, D // self.num_heads
+            )
+        k = rearrange(k, "b n h d -> b h n d")
+        v = rearrange(v, "b n h d -> b h n d")
+
+        # Compute attention scores
+        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(
+            D // self.num_heads
+        )
+        attn_weights = F.softmax(attn_scores, dim=-1)
+        x = torch.matmul(attn_weights, v)  # [B, head, 1, D_head]
+        return x.reshape(B, D, H, W)
+
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+        """Forward pass for attention pooling linear probe.
+
+        Args:
+            intermediates: the output from the previous component, which must be a TokenFeatureMaps.
+                We pool over the final dimension in the TokenFeatureMaps. If multiple feature
+                maps are passed, we apply the same attention weights (query token and linear k, v layers)
+                to all the maps.
+            context: the model context.
+            feat_tokens (torch.Tensor): Input feature tokens of shape (B, C, H, W, N).
+
+        Returns:
+            torch.Tensor:
+                - output, attentioned pool over the last dimension (B, C, H, W)
+        """
+        if not isinstance(intermediates, TokenFeatureMaps):
+            raise ValueError("input to Attention Pool must be a TokenFeatureMaps")
+
+        features = []
+        for feat in intermediates.feature_maps:
+            features.append(self.forward_for_map(feat))
+        return FeatureMaps(features)
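
Both pooling modules reduce the token dimension N while keeping the spatial grid, so a BCHWN input becomes BCHW. A small shape-check sketch calling `forward_for_map` directly (the module path is assumed, and this bypasses `TokenFeatureMaps`/`ModelContext`):

```python
import torch

from rslearn.models.attention_pool import AttentionPool, SimpleAttentionPool  # module path assumed

feat = torch.randn(2, 64, 8, 8, 12)  # B=2, D=64, H=W=8, N=12 tokens per patch

simple = SimpleAttentionPool(in_dim=64, hidden_linear=True)
print(simple.forward_for_map(feat).shape)  # torch.Size([2, 64, 8, 8])

attn = AttentionPool(in_dim=64, num_heads=8)
print(attn.forward_for_map(feat).shape)  # torch.Size([2, 64, 8, 8])
```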
@@ -91,6 +91,18 @@ class FeatureMaps:
     feature_maps: list[torch.Tensor]
 
 
+@dataclass
+class TokenFeatureMaps:
+    """An intermediate output type for multi-resolution BCHWN feature maps with a token dimension.
+
+    Unlike `FeatureMaps`, these include an additional dimension for unpooled tokens.
+    """
+
+    # List of BxCxHxWxN feature maps at different scales, ordered from highest resolution
+    # (most fine-grained) to lowest resolution (coarsest).
+    feature_maps: list[torch.Tensor]
+
+
 @dataclass
 class FeatureVector:
     """An intermediate output type for a flat feature vector."""
@@ -19,7 +19,7 @@ from olmoearth_pretrain.train.masking import MaskedOlmoEarthSample, MaskValue
 from upath import UPath
 
 from rslearn.log_utils import get_logger
-from rslearn.models.component import FeatureExtractor, FeatureMaps
+from rslearn.models.component import FeatureExtractor, FeatureMaps, TokenFeatureMaps
 from rslearn.train.model_context import ModelContext
 
 logger = get_logger(__name__)
@@ -60,6 +60,7 @@ class OlmoEarth(FeatureExtractor):
         random_initialization: bool = False,
         embedding_size: int | None = None,
         autocast_dtype: str | None = "bfloat16",
+        token_pooling: bool = True,
     ):
         """Create a new OlmoEarth model.
 
@@ -83,6 +84,9 @@
             embedding_size: optional embedding size to report via
                 get_backbone_channels (if model_id is not set).
             autocast_dtype: which dtype to use for autocasting, or set None to disable.
+            token_pooling: whether or not to pool the tokens. If True, the output will be BxCxHxW. If False,
+                there will be an extra dimension, N, (BxCxHxWxN) representing the temporal and channel
+                dimensions.
         """
         if (
             sum(
@@ -133,6 +137,7 @@
         else:
             model = model[part]
         self.model = model
+        self.token_pooling = token_pooling
 
     def _load_model_from_checkpoint(
         self, checkpoint_upath: UPath, random_initialization: bool
@@ -160,47 +165,87 @@
 
         return model
 
-    def forward(self, context: ModelContext) -> FeatureMaps:
-        """Compute feature maps from the OlmoEarth backbone.
+    def _prepare_modality_inputs(
+        self, context: ModelContext
+    ) -> tuple[MaskedOlmoEarthSample, list[str], torch.device]:
+        """Prepare modality tensors and masks for the OlmoEarth model.
+
+        Uses a two-pass approach to ensure all modalities have consistent timestep
+        dimensions for position encoding.
 
         Args:
-            context: the model context. Input dicts should include keys corresponding
-                to the modalities that should be passed to the OlmoEarth model.
+            context: the model context with input tensors.
 
         Returns:
-            a FeatureMaps consisting of one feature map, at 1/patch_size of the input
-                resolution. Embeddings will be pooled across modalities and timesteps.
+            tuple of (sample, present_modalities, device)
         """
         kwargs = {}
         present_modalities = []
         device = None
-        # Handle the case where some modalities are multitemporal and some are not.
-        # We assume all multitemporal modalities have the same number of timesteps.
+
+        # First pass: find global max_timesteps across all modalities and samples
+        # TODO: currently we assume all modalities have the same number of timesteps,
+        # which is not true for all cases, and time series time steps are assumed to
+        # be 1-month apart. It also assumes continuity between available timesteps.
+        # We'll have to fix all that.
         max_timesteps = 1
+        modality_data = {}
         for modality in MODALITY_NAMES:
             if modality not in context.inputs[0]:
                 continue
             present_modalities.append(modality)
-            cur = torch.stack([inp[modality] for inp in context.inputs], dim=0)
-            device = cur.device
-            # Check if it's single or multitemporal, and reshape accordingly
+            tensors = [inp[modality] for inp in context.inputs]
+            device = tensors[0].device
             num_bands = Modality.get(modality).num_bands
-            num_timesteps = cur.shape[1] // num_bands
-            max_timesteps = max(max_timesteps, num_timesteps)
-            cur = rearrange(cur, "b (t c) h w -> b h w t c", t=num_timesteps)
+            max_t = max(t.shape[0] for t in tensors) // num_bands
+            max_timesteps = max(max_timesteps, max_t)
+            modality_data[modality] = (
+                tensors,
+                num_bands,
+                len(Modality.get(modality).band_sets),
+            )
+
+        # Second pass: pad and process each modality with global max_timesteps
+        for modality in present_modalities:
+            tensors, num_bands, num_band_sets = modality_data[modality]
+            target_ch = max_timesteps * num_bands
+
+            # Pad tensors to target_ch and track original timesteps for masking
+            padded = []
+            original_timesteps = []
+            for t in tensors:
+                orig_t = t.shape[0] // num_bands
+                original_timesteps.append(orig_t)
+                if t.shape[0] < target_ch:
+                    pad = torch.zeros(
+                        (target_ch - t.shape[0],) + t.shape[1:],
+                        dtype=t.dtype,
+                        device=device,
+                    )
+                    t = torch.cat([t, pad], dim=0)
+                padded.append(t)
+
+            cur = torch.stack(padded, dim=0)
+            cur = rearrange(cur, "b (t c) h w -> b h w t c", t=max_timesteps)
             kwargs[modality] = cur
-            # Create mask array which is BHWTS (without channels but with band sets).
-            num_band_sets = len(Modality.get(modality).band_sets)
-            mask_shape = cur.shape[0:4] + (num_band_sets,)
-            mask = (
-                torch.ones(mask_shape, dtype=torch.int32, device=device)
-                * MaskValue.ONLINE_ENCODER.value
+
+            # Create mask: ONLINE_ENCODER for valid, MISSING for padded timesteps
+            b, h, w = cur.shape[0], cur.shape[1], cur.shape[2]
+            mask = torch.full(
+                (b, h, w, max_timesteps, num_band_sets),
+                fill_value=MaskValue.ONLINE_ENCODER.value,
+                dtype=torch.int32,
+                device=device,
             )
+            for sample_idx, orig_t in enumerate(original_timesteps):
+                if orig_t < max_timesteps:
+                    mask[sample_idx, :, :, orig_t:, :] = MaskValue.MISSING.value
             kwargs[f"{modality}_mask"] = mask
 
         # Timestamps is required.
         # Note that only months (0 to 11) are used in OlmoEarth position encoding.
-        # For now, we assign same timestamps to all inputs, but later we should handle varying timestamps per input.
+        # For now, we assign same timestamps to all inputs, but later we should
+        # handle varying timestamps per input.
         timestamps = torch.zeros(
             (len(context.inputs), max_timesteps, 3), dtype=torch.int32, device=device
         )
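
The second pass pads every sample up to the global `max_timesteps` so the per-modality batch can be stacked into one rectangular tensor, and remembers each sample's original length so the padded trailing timesteps can later be marked MISSING. The same idea on toy tensors, independent of the OlmoEarth classes:

```python
import torch

num_bands = 3
tensors = [torch.randn(2 * num_bands, 4, 4), torch.randn(1 * num_bands, 4, 4)]  # 2 vs 1 timesteps
max_timesteps = max(t.shape[0] for t in tensors) // num_bands
target_ch = max_timesteps * num_bands

padded, original_timesteps = [], []
for t in tensors:
    original_timesteps.append(t.shape[0] // num_bands)
    if t.shape[0] < target_ch:
        pad = torch.zeros((target_ch - t.shape[0],) + t.shape[1:], dtype=t.dtype)
        t = torch.cat([t, pad], dim=0)
    padded.append(t)

cur = torch.stack(padded, dim=0)  # (B, T*C, H, W), rectangular
valid = torch.ones(len(tensors), max_timesteps, dtype=torch.bool)
for i, orig_t in enumerate(original_timesteps):
    valid[i, orig_t:] = False  # trailing padded timesteps are "missing"
print(cur.shape, valid)
```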
@@ -211,7 +256,20 @@
         timestamps[:, :, 2] = 2024  # year
         kwargs["timestamps"] = timestamps
 
-        sample = MaskedOlmoEarthSample(**kwargs)
+        return MaskedOlmoEarthSample(**kwargs), present_modalities, device
+
+    def forward(self, context: ModelContext) -> FeatureMaps | TokenFeatureMaps:
+        """Compute feature maps from the OlmoEarth backbone.
+
+        Args:
+            context: the model context. Input dicts should include keys corresponding
+                to the modalities that should be passed to the OlmoEarth model.
+
+        Returns:
+            a FeatureMaps consisting of one feature map, at 1/patch_size of the input
+                resolution. Embeddings will be pooled across modalities and timesteps.
+        """
+        sample, present_modalities, device = self._prepare_modality_inputs(context)
 
         # Decide context based on self.autocast_dtype.
         if self.autocast_dtype is None:
@@ -222,6 +280,14 @@
                 device_type=device.type, dtype=self.autocast_dtype
             )
 
+        # Check if we can bypass masks (fast_pass=True)
+        missing_tokens = False
+        for modality in present_modalities:
+            modality_mask = getattr(sample, f"{modality}_mask")
+            if torch.any(modality_mask == MaskValue.MISSING.value):
+                missing_tokens = True
+                break
+
         with torch_context:
             # Currently we assume the provided model always returns a TokensAndMasks object.
             tokens_and_masks: TokensAndMasks
@@ -229,7 +295,7 @@
             # Encoder has a fast_pass argument to indicate mask is not needed.
             tokens_and_masks = self.model(
                 sample,
-                fast_pass=True,
+                fast_pass=not missing_tokens,
                 patch_size=self.patch_size,
                 **self.forward_kwargs,
             )["tokens_and_masks"]
@@ -241,16 +307,41 @@
 
         # Apply temporal/modality pooling so we just have one feature per patch.
         features = []
-        for modality in present_modalities:
-            modality_features = getattr(tokens_and_masks, modality)
-            # Pool over band sets and timesteps (BHWTSC -> BHWC).
-            pooled = modality_features.mean(dim=[3, 4])
-            # We want BHWC -> BCHW.
-            pooled = rearrange(pooled, "b h w c -> b c h w")
-            features.append(pooled)
-        # Pool over the modalities, so we get one BCHW feature map.
-        pooled = torch.stack(features, dim=0).mean(dim=0)
-        return FeatureMaps([pooled])
+        if self.token_pooling:
+            for modality in present_modalities:
+                modality_features = getattr(tokens_and_masks, modality)  # BHWTSC
+                # If fast_pass is False, we need to mask the missing tokens before pooling.
+                if missing_tokens:
+                    modality_masks = getattr(
+                        tokens_and_masks, f"{modality}_mask"
+                    )  # BHWTS
+                    modality_masks_bool = (
+                        modality_masks != MaskValue.MISSING.value
+                    ).unsqueeze(-1)
+                    count = modality_masks_bool.sum(dim=[3, 4])
+                    # Masked average over band sets and timesteps (BHWTSC -> BHWC).
+                    pooled = (modality_features * modality_masks_bool).sum(
+                        dim=[3, 4]
+                    ) / count.clamp(min=1)
+                else:
+                    # Pool over band sets and timesteps (BHWTSC -> BHWC).
+                    pooled = modality_features.mean(dim=[3, 4])
+                # We want BHWC -> BCHW.
+                pooled = rearrange(pooled, "b h w c -> b c h w")
+                features.append(pooled)
+            # Pool over the modalities, so we get one BCHW feature map.
+            pooled = torch.stack(features, dim=0).mean(dim=0)
+            return FeatureMaps([pooled])
+        else:
+            for modality in present_modalities:
+                modality_features = getattr(tokens_and_masks, modality)
+                # Combine band sets and timesteps into last dim (BHWTSC -> BHWCN).
+                modality_features = rearrange(
+                    modality_features, "b h w t s c -> b c h w (t s)"
+                )
+                features.append(modality_features)
+            pooled = torch.cat(features, dim=-1)
+            return TokenFeatureMaps([pooled])
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.