rslearn 0.0.18__py3-none-any.whl → 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +2 -9
- rslearn/config/dataset.py +15 -16
- rslearn/dataset/dataset.py +28 -22
- rslearn/lightning_cli.py +22 -11
- rslearn/main.py +1 -1
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/component.py +12 -0
- rslearn/models/olmoearth_pretrain/model.py +125 -34
- rslearn/models/simple_time_series.py +7 -1
- rslearn/train/all_patches_dataset.py +67 -19
- rslearn/train/dataset.py +36 -43
- rslearn/train/scheduler.py +15 -0
- rslearn/train/transforms/resize.py +74 -0
- rslearn/utils/geometry.py +73 -0
- rslearn/utils/jsonargparse.py +66 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.19.dist-info}/METADATA +1 -1
- {rslearn-0.0.18.dist-info → rslearn-0.0.19.dist-info}/RECORD +22 -20
- {rslearn-0.0.18.dist-info → rslearn-0.0.19.dist-info}/WHEEL +0 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.19.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.19.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.19.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.19.dist-info}/top_level.txt +0 -0
rslearn/arg_parser.py
CHANGED
@@ -1,6 +1,5 @@
 """Custom Lightning ArgumentParser with environment variable substitution support."""
 
-import os
 from typing import Any
 
 from jsonargparse import Namespace
@@ -21,11 +20,7 @@ class RslearnArgumentParser(LightningArgumentParser):
     def parse_string(
         self,
         cfg_str: str,
-        cfg_path: str = "",
-        ext_vars: dict | None = None,
-        env: bool | None = None,
-        defaults: bool = True,
-        with_meta: bool | None = None,
+        *args: Any,
         **kwargs: Any,
     ) -> Namespace:
         """Pre-processes string for environment variable substitution before parsing."""
@@ -33,6 +28,4 @@ class RslearnArgumentParser(LightningArgumentParser):
         substituted_cfg_str = substitute_env_vars_in_string(cfg_str)
 
         # Call the parent method with the substituted config
-        return super().parse_string(
-            substituted_cfg_str, cfg_path, ext_vars, env, defaults, with_meta, **kwargs
-        )
+        return super().parse_string(substituted_cfg_str, *args, **kwargs)
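The net effect of this change is that RslearnArgumentParser.parse_string no longer re-declares jsonargparse's full parse_string signature; it substitutes environment variables in the config string and forwards everything else to the parent unchanged, so the override keeps working if the parent signature evolves. Below is a minimal, self-contained sketch of the same pre-process-then-delegate pattern; the ${VAR} substitution syntax and the helper shown here are illustrative stand-ins, not the actual substitute_env_vars_in_string implementation.

import os
import re
from typing import Any


def substitute_env_vars(text: str) -> str:
    # Replace ${VAR} with the value of the VAR environment variable, leaving
    # unknown variables untouched. (Assumed syntax; rslearn's helper may differ.)
    return re.sub(r"\$\{(\w+)\}", lambda m: os.environ.get(m.group(1), m.group(0)), text)


class BaseParser:
    def parse_string(self, cfg_str: str, *args: Any, **kwargs: Any) -> dict:
        return {"raw": cfg_str}


class EnvAwareParser(BaseParser):
    def parse_string(self, cfg_str: str, *args: Any, **kwargs: Any) -> dict:
        # Pre-process the string, then forward all remaining arguments untouched.
        return super().parse_string(substitute_env_vars(cfg_str), *args, **kwargs)


print(EnvAwareParser().parse_string("data_root: ${HOME}/datasets"))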
rslearn/config/dataset.py
CHANGED
@@ -25,7 +25,7 @@ from rasterio.enums import Resampling
 from upath import UPath
 
 from rslearn.log_utils import get_logger
-from rslearn.utils import PixelBounds, Projection
+from rslearn.utils.geometry import PixelBounds, Projection, ResolutionFactor
 from rslearn.utils.raster_format import RasterFormat
 from rslearn.utils.vector_format import VectorFormat
 
@@ -215,22 +215,12 @@ class BandSetConfig(BaseModel):
         Returns:
             tuple of updated projection and bounds with zoom offset applied
         """
-        if self.zoom_offset
-
-        projection = Projection(
-            projection.crs,
-            projection.x_resolution / (2**self.zoom_offset),
-            projection.y_resolution / (2**self.zoom_offset),
-        )
-        if self.zoom_offset > 0:
-            zoom_factor = 2**self.zoom_offset
-            bounds = tuple(x * zoom_factor for x in bounds)  # type: ignore
+        if self.zoom_offset >= 0:
+            factor = ResolutionFactor(numerator=2**self.zoom_offset)
         else:
-
-
-
-            )
-        return projection, bounds
+            factor = ResolutionFactor(denominator=2 ** (-self.zoom_offset))
+
+        return (factor.multiply_projection(projection), factor.multiply_bounds(bounds))
 
     @field_validator("format", mode="before")
     @classmethod
@@ -645,3 +635,12 @@ class DatasetConfig(BaseModel):
         default_factory=lambda: StorageConfig(),
         description="jsonargparse configuration for the WindowStorageFactory.",
     )
+
+    @field_validator("layers", mode="after")
+    @classmethod
+    def layer_names_validator(cls, v: dict[str, LayerConfig]) -> dict[str, LayerConfig]:
+        """Ensure layer names don't contain periods, since we use periods to distinguish different materialized groups within a layer."""
+        for layer_name in v.keys():
+            if "." in layer_name:
+                raise ValueError(f"layer names must not contain periods: {layer_name}")
+        return v
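The zoom-offset arithmetic is now expressed through the new ResolutionFactor helper added in rslearn/utils/geometry.py (that file is not shown in this excerpt): a non-negative offset becomes a numerator of 2**zoom_offset, a negative offset a denominator of 2**(-zoom_offset), and the same factor is applied to both the projection resolution and the pixel bounds. The sketch below illustrates the bounds side of that mapping with a hypothetical stand-in class; the real ResolutionFactor API may differ.

from dataclasses import dataclass


@dataclass
class ResolutionFactorSketch:
    # Hypothetical stand-in for rslearn.utils.geometry.ResolutionFactor.
    numerator: int = 1
    denominator: int = 1

    def multiply_bounds(self, bounds: tuple[int, ...]) -> tuple[int, ...]:
        # Scale pixel bounds by numerator/denominator using integer math.
        return tuple(x * self.numerator // self.denominator for x in bounds)


def factor_for_zoom_offset(zoom_offset: int) -> ResolutionFactorSketch:
    # Positive offsets zoom in (finer pixels, so pixel bounds grow);
    # negative offsets zoom out, mirroring the BandSetConfig logic above.
    if zoom_offset >= 0:
        return ResolutionFactorSketch(numerator=2**zoom_offset)
    return ResolutionFactorSketch(denominator=2 ** (-zoom_offset))


print(factor_for_zoom_offset(2).multiply_bounds((10, 10, 20, 20)))   # (40, 40, 80, 80)
print(factor_for_zoom_offset(-1).multiply_bounds((10, 10, 20, 20)))  # (5, 5, 10, 10)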
rslearn/dataset/dataset.py
CHANGED
@@ -23,7 +23,7 @@ class Dataset:
     .. code-block:: none
 
         dataset/
-            config.json
+            config.json # optional, if config provided as runtime object
             windows/
                 group1/
                     epsg:3857_10_623565_1528020/
@@ -40,37 +40,43 @@ class Dataset:
     materialize.
     """
 
-    def __init__(
+    def __init__(
+        self,
+        path: UPath,
+        disabled_layers: list[str] = [],
+        dataset_config: DatasetConfig | None = None,
+    ) -> None:
         """Initializes a new Dataset.
 
         Args:
             path: the root directory of the dataset
             disabled_layers: list of layers to disable
+            dataset_config: optional dataset configuration to use instead of loading from the dataset directory
         """
         self.path = path
 
-
-
-
-
-
-
-
-        for layer_name, layer_config in config.layers.items():
-            # Layer names must not contain period, since we use period to
-            # distinguish different materialized groups within a layer.
-            assert "." not in layer_name, "layer names must not contain periods"
-            if layer_name in disabled_layers:
-                logger.warning(f"Layer {layer_name} is disabled")
-                continue
-            self.layers[layer_name] = layer_config
-
-        self.tile_store_config = config.tile_store
-        self.storage = (
-            config.storage.instantiate_window_storage_factory().get_storage(
-                self.path
+        if dataset_config is None:
+            # Load dataset configuration from the dataset directory.
+            with (self.path / "config.json").open("r") as f:
+                config_content = f.read()
+            config_content = substitute_env_vars_in_string(config_content)
+            dataset_config = DatasetConfig.model_validate(
+                json.loads(config_content)
             )
+
+        self.layers = {}
+        for layer_name, layer_config in dataset_config.layers.items():
+            if layer_name in disabled_layers:
+                logger.warning(f"Layer {layer_name} is disabled")
+                continue
+            self.layers[layer_name] = layer_config
+
+        self.tile_store_config = dataset_config.tile_store
+        self.storage = (
+            dataset_config.storage.instantiate_window_storage_factory().get_storage(
+                self.path
             )
+        )
 
     def load_windows(
         self,
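With the new dataset_config argument, a Dataset can be constructed from a configuration object built at runtime instead of requiring config.json to exist under the dataset path. A usage sketch is below; the import paths and the template file name are assumptions, and the omitted-argument case keeps the previous behavior of reading <path>/config.json.

import json

from upath import UPath

from rslearn.config.dataset import DatasetConfig
from rslearn.dataset.dataset import Dataset

# Build the configuration at runtime, e.g. from a template kept outside the
# dataset directory, and pass it in directly.
with open("dataset_config_template.json") as f:
    config = DatasetConfig.model_validate(json.load(f))

dataset = Dataset(UPath("/data/my_dataset"), dataset_config=config)

# Omitting dataset_config loads <path>/config.json as before.
dataset_from_disk = Dataset(UPath("/data/my_dataset"))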
rslearn/lightning_cli.py
CHANGED
@@ -21,6 +21,7 @@ from rslearn.log_utils import get_logger
 from rslearn.train.data_module import RslearnDataModule
 from rslearn.train.lightning_module import RslearnLightningModule
 from rslearn.utils.fsspec import open_atomic
+from rslearn.utils.jsonargparse import init_jsonargparse
 
 WANDB_ID_FNAME = "wandb_id"
 
@@ -390,8 +391,15 @@
 
         Sets the dataset path for any configured RslearnPredictionWriter callbacks.
         """
-
-
+        if not hasattr(self.config, "subcommand"):
+            logger.warning(
+                "Config does not have subcommand attribute, assuming we are in run=False mode"
+            )
+            subcommand = None
+            c = self.config
+        else:
+            subcommand = self.config.subcommand
+            c = self.config[subcommand]
 
         # If there is a RslearnPredictionWriter, set its path.
         prediction_writer_callback = None
@@ -415,16 +423,17 @@
         if subcommand == "predict":
             c.return_predictions = False
 
-        #
+        # Default to DDP with find_unused_parameters. Likely won't get called with unified config
         if subcommand == "fit":
-            c.trainer.strategy
-
-
-
-
-
-
-
+            if not c.trainer.strategy:
+                c.trainer.strategy = jsonargparse.Namespace(
+                    {
+                        "class_path": "lightning.pytorch.strategies.DDPStrategy",
+                        "init_args": jsonargparse.Namespace(
+                            {"find_unused_parameters": True}
+                        ),
+                    }
+                )
 
         if c.management_dir:
             self.enable_project_management(c.management_dir)
@@ -432,6 +441,8 @@
 
 def model_handler() -> None:
     """Handler for any rslearn model X commands."""
+    init_jsonargparse()
+
     RslearnLightningCLI(
         model_class=RslearnLightningModule,
         datamodule_class=RslearnDataModule,
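For the fit subcommand, the CLI now injects a default strategy only when trainer.strategy is unset, so any explicitly configured strategy still takes precedence. For reference, the injected jsonargparse namespace corresponds to the following plain Lightning configuration (a sketch showing only the strategy argument; other Trainer arguments are omitted).

from lightning.pytorch import Trainer
from lightning.pytorch.strategies import DDPStrategy

# Equivalent of the default that RslearnLightningCLI now applies during fit
# when no trainer.strategy is configured.
trainer = Trainer(strategy=DDPStrategy(find_unused_parameters=True))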
rslearn/main.py
CHANGED
@@ -380,7 +380,7 @@ def apply_on_windows(
 
 def apply_on_windows_args(f: Callable[..., Any], args: argparse.Namespace) -> None:
     """Call apply_on_windows with arguments passed via command-line interface."""
-    dataset = Dataset(UPath(args.root), args.disabled_layers)
+    dataset = Dataset(UPath(args.root), disabled_layers=args.disabled_layers)
     apply_on_windows(
         f=f,
         dataset=dataset,
rslearn/models/attention_pooling.py
ADDED
@@ -0,0 +1,177 @@
"""An attention pooling layer."""

import math
from typing import Any

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn

from rslearn.models.component import (
    FeatureMaps,
    IntermediateComponent,
    TokenFeatureMaps,
)
from rslearn.train.model_context import ModelContext


class SimpleAttentionPool(IntermediateComponent):
    """Simple Attention Pooling.

    Given a token feature map of shape BCHWN,
    learn an attention layer which aggregates over
    the N dimension.

    This is done simply by learning a mapping D->1 which is the weight
    which should be assigned to each token during averaging:

    output = sum [feat_token * W(feat_token) for feat_token in feat_tokens]
    """

    def __init__(self, in_dim: int, hidden_linear: bool = False) -> None:
        """Initialize the simple attention pooling layer.

        Args:
            in_dim: the encoding dimension D
            hidden_linear: whether to apply an additional linear transformation D -> D
                to the feat tokens. If this is True, a ReLU activation is applied
                after the first linear transformation.
        """
        super().__init__()
        if hidden_linear:
            self.hidden_linear = nn.Linear(in_features=in_dim, out_features=in_dim)
        else:
            self.hidden_linear = None
        self.linear = nn.Linear(in_features=in_dim, out_features=1)

    def forward_for_map(self, feat_tokens: torch.Tensor) -> torch.Tensor:
        """Attention pooling for a single feature map (BCHWN tensor)."""
        B, D, H, W, N = feat_tokens.shape
        feat_tokens = rearrange(feat_tokens, "b d h w n -> (b h w) n d")
        if self.hidden_linear is not None:
            feat_tokens = torch.nn.functional.relu(self.hidden_linear(feat_tokens))
        attention_scores = torch.nn.functional.softmax(self.linear(feat_tokens), dim=1)
        feat_tokens = (attention_scores * feat_tokens).sum(dim=1)
        return rearrange(feat_tokens, "(b h w) d -> b d h w", b=B, h=H, w=W)

    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
        """Forward pass for attention pooling linear probe.

        Args:
            intermediates: the output from the previous component, which must be a TokenFeatureMaps.
                We pool over the final dimension in the TokenFeatureMaps. If multiple maps
                are passed, we apply the same linear layers to all of them.
            context: the model context.
            feat_tokens (torch.Tensor): Input feature tokens of shape (B, C, H, W, N).

        Returns:
            torch.Tensor:
                - output, attentioned pool over the last dimension (B, C, H, W)
        """
        if not isinstance(intermediates, TokenFeatureMaps):
            raise ValueError("input to Attention Pool must be a TokenFeatureMaps")

        features = []
        for feat in intermediates.feature_maps:
            features.append(self.forward_for_map(feat))
        return FeatureMaps(features)


class AttentionPool(IntermediateComponent):
    """Attention Pooling.

    Given a feature map of shape BCHWN,
    learn an attention layer which aggregates over
    the N dimension.

    We do this by learning a query token, and applying a standard
    attention mechanism against this learned query token.
    """

    def __init__(self, in_dim: int, num_heads: int, linear_on_kv: bool = True) -> None:
        """Initialize the attention pooling layer.

        Args:
            in_dim: the encoding dimension D
            num_heads: the number of heads to use
            linear_on_kv: Whether to apply a linear layer on the input tokens
                to create the key and value tokens.
        """
        super().__init__()
        self.query_token: nn.Parameter = nn.Parameter(torch.empty(in_dim))
        if linear_on_kv:
            self.k_linear = nn.Linear(in_dim, in_dim)
            self.v_linear = nn.Linear(in_dim, in_dim)
        else:
            self.k_linear = None
            self.v_linear = None
        if in_dim % num_heads != 0:
            raise ValueError(
                f"in_dim must be divisible by num_heads. Got {in_dim} and {num_heads}."
            )
        self.num_heads = num_heads
        self.init_weights()

    def init_weights(self) -> None:
        """Initialize weights for the probe."""
        nn.init.trunc_normal_(self.query_token, std=0.02)

    def forward_for_map(self, feat_tokens: torch.Tensor) -> torch.Tensor:
        """Attention pooling for a single feature map (BCHWN tensor)."""
        B, D, H, W, N = feat_tokens.shape
        feat_tokens = rearrange(feat_tokens, "b d h w n -> (b h w) n d")
        collapsed_dim = B * H * W
        q = self.query_token.expand(collapsed_dim, 1, -1)
        q = q.reshape(
            collapsed_dim, 1, self.num_heads, D // self.num_heads
        )  # [B, 1, head, D_head]
        q = rearrange(q, "b h n d -> b n h d")
        if self.k_linear is not None:
            assert self.v_linear is not None
            k = self.k_linear(feat_tokens).reshape(
                collapsed_dim, N, self.num_heads, D // self.num_heads
            )
            v = self.v_linear(feat_tokens).reshape(
                collapsed_dim, N, self.num_heads, D // self.num_heads
            )
        else:
            k = feat_tokens.reshape(
                collapsed_dim, N, self.num_heads, D // self.num_heads
            )
            v = feat_tokens.reshape(
                collapsed_dim, N, self.num_heads, D // self.num_heads
            )
        k = rearrange(k, "b n h d -> b h n d")
        v = rearrange(v, "b n h d -> b h n d")

        # Compute attention scores
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(
            D // self.num_heads
        )
        attn_weights = F.softmax(attn_scores, dim=-1)
        x = torch.matmul(attn_weights, v)  # [B, head, 1, D_head]
        return x.reshape(B, D, H, W)

    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
        """Forward pass for attention pooling linear probe.

        Args:
            intermediates: the output from the previous component, which must be a TokenFeatureMaps.
                We pool over the final dimension in the TokenFeatureMaps. If multiple feature
                maps are passed, we apply the same attention weights (query token and linear k, v layers)
                to all the maps.
            context: the model context.
            feat_tokens (torch.Tensor): Input feature tokens of shape (B, C, H, W, N).

        Returns:
            torch.Tensor:
                - output, attentioned pool over the last dimension (B, C, H, W)
        """
        if not isinstance(intermediates, TokenFeatureMaps):
            raise ValueError("input to Attention Pool must be a TokenFeatureMaps")

        features = []
        for feat in intermediates.feature_maps:
            features.append(self.forward_for_map(feat))
        return FeatureMaps(features)
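As a quick shape check, SimpleAttentionPool.forward_for_map collapses the trailing token dimension of a single BxCxHxWxN map. In normal use the component sits after a backbone that emits TokenFeatureMaps (for example the OlmoEarth model with token_pooling=False, shown later in this diff) rather than being called directly; the snippet below is only an illustrative sketch.

import torch

from rslearn.models.attention_pooling import SimpleAttentionPool

# One BxCxHxWxN token feature map: batch 2, 768-dim tokens, a 4x4 grid of patch
# locations, and 6 unpooled tokens (e.g. timesteps x band sets) per location.
tokens = torch.randn(2, 768, 4, 4, 6)

pool = SimpleAttentionPool(in_dim=768)
pooled = pool.forward_for_map(tokens)
print(pooled.shape)  # torch.Size([2, 768, 4, 4])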
rslearn/models/component.py
CHANGED
@@ -91,6 +91,18 @@ class FeatureMaps:
     feature_maps: list[torch.Tensor]
 
 
+@dataclass
+class TokenFeatureMaps:
+    """An intermediate output type for multi-resolution BCHWN feature maps with a token dimension.
+
+    Unlike `FeatureMaps`, these include an additional dimension for unpooled tokens.
+    """
+
+    # List of BxCxHxWxN feature maps at different scales, ordered from highest resolution
+    # (most fine-grained) to lowest resolution (coarsest).
+    feature_maps: list[torch.Tensor]
+
+
 @dataclass
 class FeatureVector:
     """An intermediate output type for a flat feature vector."""
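The difference between the two intermediate types is only the extra trailing token dimension; a minimal sketch with illustrative shapes:

import torch

from rslearn.models.component import FeatureMaps, TokenFeatureMaps

# TokenFeatureMaps keeps the per-token dimension N (e.g. timesteps x band sets)...
token_maps = TokenFeatureMaps([torch.randn(2, 768, 4, 4, 6)])
# ...while FeatureMaps holds maps already pooled down to BxCxHxW.
feature_maps = FeatureMaps([torch.randn(2, 768, 4, 4)])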
rslearn/models/olmoearth_pretrain/model.py
CHANGED
@@ -19,7 +19,7 @@ from olmoearth_pretrain.train.masking import MaskedOlmoEarthSample, MaskValue
 from upath import UPath
 
 from rslearn.log_utils import get_logger
-from rslearn.models.component import FeatureExtractor, FeatureMaps
+from rslearn.models.component import FeatureExtractor, FeatureMaps, TokenFeatureMaps
 from rslearn.train.model_context import ModelContext
 
 logger = get_logger(__name__)
@@ -60,6 +60,7 @@ class OlmoEarth(FeatureExtractor):
         random_initialization: bool = False,
         embedding_size: int | None = None,
         autocast_dtype: str | None = "bfloat16",
+        token_pooling: bool = True,
     ):
         """Create a new OlmoEarth model.
 
@@ -83,6 +84,9 @@ class OlmoEarth(FeatureExtractor):
             embedding_size: optional embedding size to report via
                 get_backbone_channels (if model_id is not set).
             autocast_dtype: which dtype to use for autocasting, or set None to disable.
+            token_pooling: whether or not to pool the tokens. If True, the output will be BxCxHxW. If False,
+                there will be an extra dimension, N, (BxCxHxWxN) representing the temporal and channel
+                dimensions.
         """
         if (
             sum(
@@ -133,6 +137,7 @@
             else:
                 model = model[part]
         self.model = model
+        self.token_pooling = token_pooling
 
     def _load_model_from_checkpoint(
         self, checkpoint_upath: UPath, random_initialization: bool
@@ -160,47 +165,87 @@
 
         return model
 
-    def
-
+    def _prepare_modality_inputs(
+        self, context: ModelContext
+    ) -> tuple[MaskedOlmoEarthSample, list[str], torch.device]:
+        """Prepare modality tensors and masks for the OlmoEarth model.
+
+        Uses a two-pass approach to ensure all modalities have consistent timestep
+        dimensions for position encoding.
 
         Args:
-            context: the model context
-                to the modalities that should be passed to the OlmoEarth model.
+            context: the model context with input tensors.
 
         Returns:
-
-                resolution. Embeddings will be pooled across modalities and timesteps.
+            tuple of (sample, present_modalities, device)
         """
         kwargs = {}
         present_modalities = []
         device = None
-
-        #
+
+        # First pass: find global max_timesteps across all modalities and samples
+        # TODO: currently we assume all modalities have the same number of timesteps,
+        # which is not true for all cases, and time series time steps are assumed to
+        # be 1-month apart. It also assumes continuity between available timesteps.
+        # We'll have to fix all that.
         max_timesteps = 1
+        modality_data = {}
         for modality in MODALITY_NAMES:
             if modality not in context.inputs[0]:
                 continue
             present_modalities.append(modality)
-
-            device =
-            # Check if it's single or multitemporal, and reshape accordingly
+            tensors = [inp[modality] for inp in context.inputs]
+            device = tensors[0].device
             num_bands = Modality.get(modality).num_bands
-
-            max_timesteps = max(max_timesteps,
-
+            max_t = max(t.shape[0] for t in tensors) // num_bands
+            max_timesteps = max(max_timesteps, max_t)
+            modality_data[modality] = (
+                tensors,
+                num_bands,
+                len(Modality.get(modality).band_sets),
+            )
+
+        # Second pass: pad and process each modality with global max_timesteps
+        for modality in present_modalities:
+            tensors, num_bands, num_band_sets = modality_data[modality]
+            target_ch = max_timesteps * num_bands
+
+            # Pad tensors to target_ch and track original timesteps for masking
+            padded = []
+            original_timesteps = []
+            for t in tensors:
+                orig_t = t.shape[0] // num_bands
+                original_timesteps.append(orig_t)
+                if t.shape[0] < target_ch:
+                    pad = torch.zeros(
+                        (target_ch - t.shape[0],) + t.shape[1:],
+                        dtype=t.dtype,
+                        device=device,
+                    )
+                    t = torch.cat([t, pad], dim=0)
+                padded.append(t)
+
+            cur = torch.stack(padded, dim=0)
+            cur = rearrange(cur, "b (t c) h w -> b h w t c", t=max_timesteps)
             kwargs[modality] = cur
-
-
-
-            mask = (
-
-
+
+            # Create mask: ONLINE_ENCODER for valid, MISSING for padded timesteps
+            b, h, w = cur.shape[0], cur.shape[1], cur.shape[2]
+            mask = torch.full(
+                (b, h, w, max_timesteps, num_band_sets),
+                fill_value=MaskValue.ONLINE_ENCODER.value,
+                dtype=torch.int32,
+                device=device,
             )
+            for sample_idx, orig_t in enumerate(original_timesteps):
+                if orig_t < max_timesteps:
+                    mask[sample_idx, :, :, orig_t:, :] = MaskValue.MISSING.value
             kwargs[f"{modality}_mask"] = mask
 
         # Timestamps is required.
         # Note that only months (0 to 11) are used in OlmoEarth position encoding.
-        # For now, we assign same timestamps to all inputs, but later we should
+        # For now, we assign same timestamps to all inputs, but later we should
+        # handle varying timestamps per input.
         timestamps = torch.zeros(
             (len(context.inputs), max_timesteps, 3), dtype=torch.int32, device=device
         )
@@ -211,7 +256,20 @@
         timestamps[:, :, 2] = 2024  # year
         kwargs["timestamps"] = timestamps
 
-
+        return MaskedOlmoEarthSample(**kwargs), present_modalities, device
+
+    def forward(self, context: ModelContext) -> FeatureMaps | TokenFeatureMaps:
+        """Compute feature maps from the OlmoEarth backbone.
+
+        Args:
+            context: the model context. Input dicts should include keys corresponding
+                to the modalities that should be passed to the OlmoEarth model.
+
+        Returns:
+            a FeatureMaps consisting of one feature map, at 1/patch_size of the input
+                resolution. Embeddings will be pooled across modalities and timesteps.
+        """
+        sample, present_modalities, device = self._prepare_modality_inputs(context)
 
         # Decide context based on self.autocast_dtype.
         if self.autocast_dtype is None:
@@ -222,6 +280,14 @@
                 device_type=device.type, dtype=self.autocast_dtype
             )
 
+        # Check if we can bypass masks (fast_pass=True)
+        missing_tokens = False
+        for modality in present_modalities:
+            modality_mask = getattr(sample, f"{modality}_mask")
+            if torch.any(modality_mask == MaskValue.MISSING.value):
+                missing_tokens = True
+                break
+
         with torch_context:
             # Currently we assume the provided model always returns a TokensAndMasks object.
             tokens_and_masks: TokensAndMasks
@@ -229,7 +295,7 @@
             # Encoder has a fast_pass argument to indicate mask is not needed.
             tokens_and_masks = self.model(
                 sample,
-                fast_pass=
+                fast_pass=not missing_tokens,
                 patch_size=self.patch_size,
                 **self.forward_kwargs,
             )["tokens_and_masks"]
@@ -241,16 +307,41 @@
 
         # Apply temporal/modality pooling so we just have one feature per patch.
         features = []
-
-
-
-
-
-
-
-
-
-
+        if self.token_pooling:
+            for modality in present_modalities:
+                modality_features = getattr(tokens_and_masks, modality)  # BHWTSC
+                # If fast_pass is False, we need to mask the missing tokens before pooling.
+                if missing_tokens:
+                    modality_masks = getattr(
+                        tokens_and_masks, f"{modality}_mask"
+                    )  # BHWTS
+                    modality_masks_bool = (
+                        modality_masks != MaskValue.MISSING.value
+                    ).unsqueeze(-1)
+                    count = modality_masks_bool.sum(dim=[3, 4])
+                    # Masked average over band sets and timesteps (BHWTSC -> BHWC).
+                    pooled = (modality_features * modality_masks_bool).sum(
+                        dim=[3, 4]
+                    ) / count.clamp(min=1)
+                else:
+                    # Pool over band sets and timesteps (BHWTSC -> BHWC).
+                    pooled = modality_features.mean(dim=[3, 4])
+                # We want BHWC -> BCHW.
+                pooled = rearrange(pooled, "b h w c -> b c h w")
+                features.append(pooled)
+            # Pool over the modalities, so we get one BCHW feature map.
+            pooled = torch.stack(features, dim=0).mean(dim=0)
+            return FeatureMaps([pooled])
+        else:
+            for modality in present_modalities:
+                modality_features = getattr(tokens_and_masks, modality)
+                # Combine band sets and timesteps into last dim (BHWTSC -> BHWCN).
+                modality_features = rearrange(
+                    modality_features, "b h w t s c -> b c h w (t s)"
+                )
+                features.append(modality_features)
+            pooled = torch.cat(features, dim=-1)
+            return TokenFeatureMaps([pooled])
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
|