rslearn 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl
This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- rslearn/arg_parser.py +2 -9
- rslearn/config/dataset.py +15 -16
- rslearn/dataset/dataset.py +28 -22
- rslearn/lightning_cli.py +22 -11
- rslearn/main.py +1 -1
- rslearn/models/anysat.py +35 -33
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clip.py +5 -2
- rslearn/models/component.py +12 -0
- rslearn/models/croma.py +11 -3
- rslearn/models/dinov3.py +2 -1
- rslearn/models/faster_rcnn.py +2 -1
- rslearn/models/galileo/galileo.py +58 -31
- rslearn/models/module_wrapper.py +6 -1
- rslearn/models/molmo.py +4 -2
- rslearn/models/olmoearth_pretrain/model.py +206 -51
- rslearn/models/olmoearth_pretrain/norm.py +5 -3
- rslearn/models/panopticon.py +3 -1
- rslearn/models/presto/presto.py +45 -15
- rslearn/models/prithvi.py +9 -7
- rslearn/models/sam2_enc.py +3 -1
- rslearn/models/satlaspretrain.py +4 -1
- rslearn/models/simple_time_series.py +43 -17
- rslearn/models/ssl4eo_s12.py +19 -14
- rslearn/models/swin.py +3 -1
- rslearn/models/terramind.py +5 -4
- rslearn/train/all_patches_dataset.py +96 -28
- rslearn/train/dataset.py +102 -53
- rslearn/train/model_context.py +35 -1
- rslearn/train/scheduler.py +15 -0
- rslearn/train/tasks/classification.py +8 -2
- rslearn/train/tasks/detection.py +3 -2
- rslearn/train/tasks/multi_task.py +2 -3
- rslearn/train/tasks/per_pixel_regression.py +14 -5
- rslearn/train/tasks/regression.py +8 -2
- rslearn/train/tasks/segmentation.py +13 -4
- rslearn/train/tasks/task.py +2 -2
- rslearn/train/transforms/concatenate.py +45 -5
- rslearn/train/transforms/crop.py +22 -8
- rslearn/train/transforms/flip.py +13 -5
- rslearn/train/transforms/mask.py +11 -2
- rslearn/train/transforms/normalize.py +46 -15
- rslearn/train/transforms/pad.py +15 -3
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +11 -2
- rslearn/train/transforms/sentinel1.py +18 -3
- rslearn/utils/geometry.py +73 -0
- rslearn/utils/jsonargparse.py +66 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/METADATA +1 -1
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/RECORD +55 -53
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/WHEEL +0 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/top_level.txt +0 -0
rslearn/models/galileo/galileo.py
CHANGED

@@ -3,6 +3,7 @@
 import math
 import tempfile
 from contextlib import nullcontext
+from datetime import datetime
 from enum import StrEnum
 from typing import cast
 
@@ -411,6 +412,23 @@ class GalileoModel(FeatureExtractor):
             months=months,
         )
 
+    @staticmethod
+    def time_ranges_to_timestamps(
+        time_ranges: list[tuple[datetime, datetime]],
+        device: torch.device,
+    ) -> torch.Tensor:
+        """Turn the time ranges stored in a RasterImage to timestamps accepted by Galileo.
+
+        Galileo only uses the month associated with each timestamp, so we take the midpoint
+        of the time range. For some inputs (e.g. Sentinel 2) we take an image from a specific
+        time so that start_time == end_time == mid_time.
+        """
+        mid_ranges = [t[0] + ((t[1] - t[0]) / 2) for t in time_ranges]
+        # months are indexed 0-11
+        return torch.tensor(
+            [d.month - 1 for d in mid_ranges], dtype=torch.int32, device=device
+        )
+
     def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute feature maps from the Galileo backbone.
 
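As a quick illustration of the new helper, here is a minimal standalone sketch of the midpoint-month conversion (plain Python and PyTorch; the inputs are made up and this is not the rslearn API itself):

```python
from datetime import datetime

import torch

# Midpoint-month conversion as in time_ranges_to_timestamps above: Galileo
# only consumes the month index (0-11) of each timestamp.
time_ranges = [
    (datetime(2024, 1, 1), datetime(2024, 3, 1)),    # midpoint 2024-01-31 -> 0
    (datetime(2024, 6, 15), datetime(2024, 6, 15)),  # start == end -> 5
]
mid_ranges = [start + (end - start) / 2 for start, end in time_ranges]
months = torch.tensor([d.month - 1 for d in mid_ranges], dtype=torch.int32)
print(months)  # tensor([0, 5], dtype=torch.int32)
```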
@@ -418,16 +436,16 @@ class GalileoModel(FeatureExtractor):
             context: the model context. Input dicts should contain keys corresponding to Galileo.input_keys
                 (also documented below) and values are tensors of the following shapes,
                 per input key:
-                "s1": B
-                "s2": B
-                "era5": B
-                "tc": B
-                "viirs": B
-                "srtm": B C H W (SRTM has no temporal dimension)
-                "dw": : B C H W (Dynamic World should be averaged over time)
-                "wc": B C H W (WorldCereal has no temporal dimension)
-                "landscan": B C H W (we will average over the H, W dimensions)
-                "latlon": B C H W (we will average over the H, W dimensions)
+                "s1": B C T H W
+                "s2": B C T H W
+                "era5": B C T H W (we will average over the H, W dimensions)
+                "tc": B C T H W (we will average over the H, W dimensions)
+                "viirs": B C T H W (we will average over the H, W dimensions)
+                "srtm": B C 1 H W (SRTM has no temporal dimension)
+                "dw": : B C 1 H W (Dynamic World should be averaged over time)
+                "wc": B C 1 H W (WorldCereal has no temporal dimension)
+                "landscan": B C 1 H W (we will average over the H, W dimensions)
+                "latlon": B C 1 H W (we will average over the H, W dimensions)
 
         The output will be an embedding representing the pooled tokens. If there is
         only a single token per h/w dimension (i.e. patch_size == h,w), then we will take
@@ -436,15 +454,35 @@ class GalileoModel(FeatureExtractor):
         If there are many spatial tokens per h/w dimension (patch_size > h,w), then we will
         take a pool of the space_time unmasked tokens (i.e. of the s1 and s2 tokens).
         """
+        space_time_modalities = ["s1", "s2"]
+        time_modalities = ["era5", "tc", "viirs"]
         stacked_inputs = {}
+        months: torch.Tensor | None = None
         for key in context.inputs[0].keys():
             # assume all the keys in an input are consistent
             if key in self.input_keys:
                 stacked_inputs[key] = torch.stack(
-                    [inp[key] for inp in context.inputs], dim=0
+                    [inp[key].image for inp in context.inputs], dim=0
                 )
+                if key in space_time_modalities + time_modalities:
+                    if months is None:
+                        if context.inputs[0][key].timestamps is not None:
+                            months = torch.stack(
+                                [
+                                    self.time_ranges_to_timestamps(
+                                        inp[key].timestamps,  # type: ignore
+                                        device=stacked_inputs[key].device,
+                                    )
+                                    for inp in context.inputs
+                                ],
+                                dim=0,
+                            )
+
+        if months is not None:
+            stacked_inputs["months"] = months
+
         s_t_channels = []
-        for space_time_modality in
+        for space_time_modality in space_time_modalities:
             if space_time_modality not in stacked_inputs:
                 continue
             if space_time_modality == "s1":
@@ -452,36 +490,27 @@ class GalileoModel(FeatureExtractor):
             else:
                 s_t_channels += self.s_t_channels_s2
             cur = stacked_inputs[space_time_modality]
-
-            num_bands = len(S2_BANDS) if space_time_modality == "s2" else len(S1_BANDS)
-            num_timesteps = cur.shape[1] // num_bands
-            cur = rearrange(cur, "b (t c) h w -> b h w t c", t=num_timesteps)
+            cur = rearrange(cur, "b c t h w -> b h w t c")
             stacked_inputs[space_time_modality] = cur
 
         for space_modality in ["srtm", "dw", "wc"]:
             if space_modality not in stacked_inputs:
                 continue
+            # take the first (and assumed only) timestep
+            stacked_inputs[space_modality] = stacked_inputs[space_modality][:, :, 0]
             stacked_inputs[space_modality] = rearrange(
                 stacked_inputs[space_modality], "b c h w -> b h w c"
             )
 
-        for time_modality in
+        for time_modality in time_modalities:
             if time_modality not in stacked_inputs:
                 continue
             cur = stacked_inputs[time_modality]
-            # Check if it's single or multitemporal, and reshape accordingly
-            num_bands = {
-                "era5": len(ERA5_BANDS),
-                "tc": len(TC_BANDS),
-                "viirs": len(VIIRS_BANDS),
-            }[time_modality]
-            num_timesteps = cur.shape[1] // num_bands
             # take the average over the h, w bands since Galileo
             # treats it as a pixel-timeseries
             cur = rearrange(
-                torch.nanmean(
-                "b
-                t=num_timesteps,
+                torch.nanmean(cur, dim=(-1, -2)),
+                "b c t -> b t c",
             )
             stacked_inputs[time_modality] = cur
 
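Both layout changes in this hunk are shape-level only; a small sketch with made-up sizes shows the before/after contract (illustrative tensors, not rslearn code):

```python
import torch
from einops import rearrange

# Space-time inputs now arrive as B C T H W, so the timestep count no longer
# has to be recovered from a flattened (T C) channel axis.
s2 = torch.randn(2, 13, 4, 32, 32)  # B C T H W: 13 bands, 4 timesteps (made up)
print(rearrange(s2, "b c t h w -> b h w t c").shape)  # torch.Size([2, 32, 32, 4, 13])

# Time-only modalities are averaged over H and W into a pixel time series.
era5 = torch.randn(2, 2, 4, 32, 32)  # B C T H W
pooled = rearrange(torch.nanmean(era5, dim=(-1, -2)), "b c t -> b t c")
print(pooled.shape)  # torch.Size([2, 4, 2])
```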
@@ -489,9 +518,8 @@ class GalileoModel(FeatureExtractor):
             if static_modality not in stacked_inputs:
                 continue
             cur = stacked_inputs[static_modality]
-            stacked_inputs[static_modality] = torch.nanmean(
-
-            )
+            stacked_inputs[static_modality] = torch.nanmean(cur, dim=(2, 3, 4))
+
         galileo_input = self.construct_galileo_input(**stacked_inputs, normalize=True)
         h = galileo_input.s_t_x.shape[1]
         if h < self.patch_size:
@@ -511,7 +539,6 @@ class GalileoModel(FeatureExtractor):
         torch_context = torch.amp.autocast(
             device_type=device.type, dtype=self.autocast_dtype
         )
-
         with torch_context:
             outputs = self.model(
                 s_t_x=galileo_input.s_t_x,
rslearn/models/module_wrapper.py
CHANGED

@@ -53,7 +53,12 @@ class EncoderModuleWrapper(FeatureExtractor):
         Returns:
             the output from the last wrapped module.
         """
-
+        # take the first and only timestep. Currently no intermediate
+        # components support multi temporal inputs, so if the input is
+        # multitemporal it should be wrapped in a simple time series wrapper.
+        images = torch.stack(
+            [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+        )
         cur: Any = FeatureMaps([images])
         for m in self.encoder_modules:
             cur = m(cur, context)
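Several wrappers in this release (module_wrapper, molmo, panopticon) switch from raw tensors to a RasterImage object with an .image tensor, optional .timestamps, and a single_ts_to_chw_tensor() helper. Below is a hypothetical stand-in illustrating the assumed contract; the real class lives in rslearn.train.model_context, and the C T H W layout is inferred from the rearrange calls elsewhere in this diff:

```python
import torch

class RasterImageSketch:
    """Hypothetical stand-in for rslearn's RasterImage, for illustration only."""

    def __init__(self, image: torch.Tensor, timestamps=None):
        self.image = image            # assumed layout: C T H W
        self.timestamps = timestamps  # optional list of (start, end) datetimes

    def single_ts_to_chw_tensor(self) -> torch.Tensor:
        # Only valid for single-timestep inputs: drop the T axis to get C H W.
        assert self.image.shape[1] == 1, "expected a single timestep"
        return self.image[:, 0]
```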
rslearn/models/molmo.py
CHANGED

@@ -47,11 +47,13 @@ class Molmo(FeatureExtractor):
             a FeatureMaps. Molmo produces features at one scale, so it will contain one
             feature map that is a Bx24x24x2048 tensor.
         """
-        device = context.inputs[0]["image"].device
+        device = context.inputs[0]["image"].image.device
         molmo_inputs_list = []
         # Process each one so we can isolate just the full image without any crops.
         for inp in context.inputs:
-            image =
+            image = (
+                inp["image"].single_ts_to_chw_tensor().cpu().numpy().transpose(1, 2, 0)
+            )
             processed = self.processor.process(
                 images=[image],
                 text="",
rslearn/models/olmoearth_pretrain/model.py
CHANGED

@@ -1,26 +1,27 @@
 """OlmoEarth model wrapper for fine-tuning in rslearn."""
 
 import json
+import warnings
 from contextlib import nullcontext
+from datetime import datetime
 from typing import Any
 
 import torch
 from einops import rearrange
-from
-from olmo_core.distributed.checkpoint import load_model_and_optim_state
+from olmoearth_pretrain.config import Config, require_olmo_core
 from olmoearth_pretrain.data.constants import Modality
+from olmoearth_pretrain.datatypes import MaskedOlmoEarthSample, MaskValue
 from olmoearth_pretrain.model_loader import (
     ModelID,
     load_model_from_id,
     load_model_from_path,
 )
 from olmoearth_pretrain.nn.flexihelios import Encoder, TokensAndMasks
-from olmoearth_pretrain.train.masking import MaskedOlmoEarthSample, MaskValue
 from upath import UPath
 
 from rslearn.log_utils import get_logger
-from rslearn.models.component import FeatureExtractor, FeatureMaps
-from rslearn.train.model_context import ModelContext
+from rslearn.models.component import FeatureExtractor, FeatureMaps, TokenFeatureMaps
+from rslearn.train.model_context import ModelContext, RasterImage
 
 logger = get_logger(__name__)
 
@@ -60,6 +61,8 @@ class OlmoEarth(FeatureExtractor):
         random_initialization: bool = False,
         embedding_size: int | None = None,
         autocast_dtype: str | None = "bfloat16",
+        token_pooling: bool = True,
+        use_legacy_timestamps: bool = True,
     ):
         """Create a new OlmoEarth model.
 
@@ -83,7 +86,18 @@ class OlmoEarth(FeatureExtractor):
             embedding_size: optional embedding size to report via
                 get_backbone_channels (if model_id is not set).
             autocast_dtype: which dtype to use for autocasting, or set None to disable.
+            token_pooling: whether or not to pool the tokens. If True, the output will be BxCxHxW. If False,
+                there will be an extra dimension, N, (BxCxHxWxN) representing the temporal and channel
+                dimensions.
+            use_legacy_timestamps: In our original implementation of OlmoEarth, we applied timestamps starting
+                from 0 (instead of the actual timestamps of the input). The option to do this is preserved
+                for backwards compatibility with finetuned models which were trained against this implementation.
         """
+        if use_legacy_timestamps:
+            warnings.warn(
+                "For new projects, don't use legacy timesteps.", DeprecationWarning
+            )
+
         if (
             sum(
                 [
@@ -133,6 +147,8 @@ class OlmoEarth(FeatureExtractor):
         else:
             model = model[part]
         self.model = model
+        self.token_pooling = token_pooling
+        self.use_legacy_timestamps = use_legacy_timestamps
 
     def _load_model_from_checkpoint(
         self, checkpoint_upath: UPath, random_initialization: bool
@@ -143,9 +159,12 @@ class OlmoEarth(FeatureExtractor):
         that contains the distributed checkpoint. This is the format produced by
         pre-training runs in olmoearth_pretrain.
         """
-        # Load the model config and initialize it.
         # We avoid loading the train module here because it depends on running within
         # olmo_core.
+        # Only pull in olmo_core when trying to load a distributed checkpoint to avoid dependency.
+        require_olmo_core("_load_model_from_checkpoint")
+        from olmo_core.distributed.checkpoint import load_model_and_optim_state
+
         with (checkpoint_upath / "config.json").open() as f:
            config_dict = json.load(f)
         model_config = Config.from_dict(config_dict["model"])
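The hunk above defers the olmo_core import so the dependency is only needed when a distributed checkpoint is actually loaded. A generic sketch of that guard pattern follows; this is not necessarily how olmoearth_pretrain implements require_olmo_core, and the error message is illustrative:

```python
def require_olmo_core(caller: str) -> None:
    # Illustrative optional-dependency guard: fail with a clear message,
    # naming the function that needed the package.
    try:
        import olmo_core  # noqa: F401
    except ImportError as e:
        raise ImportError(f"{caller} requires the optional olmo_core package") from e
```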
@@ -160,58 +179,161 @@ class OlmoEarth(FeatureExtractor):
 
         return model
 
-
-
+    @staticmethod
+    def time_ranges_to_timestamps(
+        time_ranges: list[tuple[datetime, datetime]],
+        max_timestamps: int,
+        device: torch.device,
+    ) -> torch.Tensor:
+        """Turn the time ranges stored in a RasterImage to timestamps accepted by OlmoEarth.
+
+        OlmoEarth only uses the month associated with each timestamp, so we take the midpoint
+        of the time range. For some inputs (e.g. Sentinel 2) we take an image from a specific
+        time so that start_time == end_time == mid_time.
+        """
+        timestamps = torch.zeros((max_timestamps, 3), dtype=torch.int32, device=device)
+        mid_ranges = [t[0] + ((t[1] - t[0]) / 2) for t in time_ranges]
+        timestamps[: len(time_ranges), 0] = torch.tensor(
+            [d.day for d in mid_ranges], dtype=torch.int32
+        )
+        # months are indexed 0-11
+        timestamps[: len(time_ranges), 1] = torch.tensor(
+            [d.month - 1 for d in mid_ranges], dtype=torch.int32
+        )
+        timestamps[: len(time_ranges), 2] = torch.tensor(
+            [d.year for d in mid_ranges], dtype=torch.int32
+        )
+        return timestamps
+
+    def _prepare_modality_inputs(
+        self, context: ModelContext
+    ) -> tuple[MaskedOlmoEarthSample, list[str], torch.device]:
+        """Prepare modality tensors and masks for the OlmoEarth model.
+
+        Uses a two-pass approach to ensure all modalities have consistent timestep
+        dimensions for position encoding.
 
         Args:
-            context: the model context
-                to the modalities that should be passed to the OlmoEarth model.
+            context: the model context with input tensors.
 
         Returns:
-
-            resolution. Embeddings will be pooled across modalities and timesteps.
+            tuple of (sample, present_modalities, device)
         """
         kwargs = {}
         present_modalities = []
         device = None
-
-        #
+
+        # First pass: find global max_timesteps across all modalities and samples
+        # TODO: currently we assume all modalities have the same number of timesteps,
+        # which is not true for all cases, and time series time steps are assumed to
+        # be 1-month apart. It also assumes continuity between available timesteps.
+        # We'll have to fix all that.
         max_timesteps = 1
+        modality_data = {}
+        # we will just store the longest time range
+        # per instance in the batch. This means it may not be
+        # aligned per modality
+        timestamps_per_instance: list[list[tuple[datetime, datetime]]] = [[]] * len(
+            context.inputs
+        )
         for modality in MODALITY_NAMES:
             if modality not in context.inputs[0]:
                 continue
             present_modalities.append(modality)
-
-
-
-
-
-
-
+            tensors = []
+            for idx, inp in enumerate(context.inputs):
+                assert isinstance(inp, RasterImage)
+                tensors.append(inp[modality].image)
+                cur_timestamps = inp[modality].timestamps
+                if cur_timestamps is not None and len(cur_timestamps) > len(
+                    timestamps_per_instance[idx]
+                ):
+                    timestamps_per_instance[idx] = cur_timestamps
+            tensors = [inp[modality].image for inp in context.inputs]
+            device = tensors[0].device
+            max_t = max(t.shape[1] for t in tensors)
+            max_timesteps = max(max_timesteps, max_t)
+            modality_data[modality] = (
+                tensors,
+                len(Modality.get(modality).band_sets),
+            )
+
+        # Second pass: pad and process each modality with global max_timesteps
+        for modality in present_modalities:
+            tensors, num_band_sets = modality_data[modality]
+
+            # Pad tensors to target_ch and track original timesteps for masking
+            padded = []
+            original_timesteps = []
+            for t in tensors:
+                orig_t = t.shape[1]
+                original_timesteps.append(orig_t)
+                if orig_t < max_timesteps:
+                    pad = torch.zeros(
+                        t.shape[:1] + (max_timesteps - orig_t,) + t.shape[2:],
+                        dtype=t.dtype,
+                        device=device,
+                    )
+                    t = torch.cat([t, pad], dim=1)
+                padded.append(t)
+
+            cur = torch.stack(padded, dim=0)
+            cur = rearrange(cur, "b c t h w -> b h w t c")
             kwargs[modality] = cur
-
-
-
-            mask = (
-
-
+
+            # Create mask: ONLINE_ENCODER for valid, MISSING for padded timesteps
+            b, h, w = cur.shape[0], cur.shape[1], cur.shape[2]
+            mask = torch.full(
+                (b, h, w, max_timesteps, num_band_sets),
+                fill_value=MaskValue.ONLINE_ENCODER.value,
+                dtype=torch.int32,
+                device=device,
             )
+            for sample_idx, orig_t in enumerate(original_timesteps):
+                if orig_t < max_timesteps:
+                    mask[sample_idx, :, :, orig_t:, :] = MaskValue.MISSING.value
             kwargs[f"{modality}_mask"] = mask
 
-
-
-
-
-
-
-
-
-
-
-
-
+        if self.use_legacy_timestamps:
+            # Note that only months (0 to 11) are used in OlmoEarth position encoding.
+            timestamps = torch.zeros(
+                (len(context.inputs), max_timesteps, 3),
+                dtype=torch.int32,
+                device=device,
+            )
+            timestamps[:, :, 0] = 1  # day
+            timestamps[:, :, 1] = torch.arange(max_timesteps, device=device)[
+                None, :
+            ]  # month
+            timestamps[:, :, 2] = 2024  # year
+            kwargs["timestamps"] = timestamps
+        else:
+            if max([len(t) for t in timestamps_per_instance]) == 0:
+                # Timestamps is required.
+                raise ValueError("No inputs had timestamps.")
+            # Note that only months (0 to 11) are used in OlmoEarth position encoding.
+            kwargs["timestamps"] = torch.stack(
+                [
+                    self.time_ranges_to_timestamps(time_range, max_timesteps, device)
+                    for time_range in timestamps_per_instance
+                ],
+                dim=0,
+            )
+
+        return MaskedOlmoEarthSample(**kwargs), present_modalities, device
+
+    def forward(self, context: ModelContext) -> FeatureMaps | TokenFeatureMaps:
+        """Compute feature maps from the OlmoEarth backbone.
+
+        Args:
+            context: the model context. Input dicts should include keys corresponding
+                to the modalities that should be passed to the OlmoEarth model.
 
-
+        Returns:
+            a FeatureMaps consisting of one feature map, at 1/patch_size of the input
+            resolution. Embeddings will be pooled across modalities and timesteps.
+        """
+        sample, present_modalities, device = self._prepare_modality_inputs(context)
 
         # Decide context based on self.autocast_dtype.
         if self.autocast_dtype is None:
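The timestamp tensor consumed by the encoder is (B, T, 3) with columns (day, month index 0-11, year). A sketch of what the legacy path above produces for T=4, with a made-up batch size (not rslearn code):

```python
import torch

# Legacy timestamps: day fixed to 1, months counted up from 0, year 2024.
batch, max_timesteps = 2, 4
legacy = torch.zeros((batch, max_timesteps, 3), dtype=torch.int32)
legacy[:, :, 0] = 1                                     # day
legacy[:, :, 1] = torch.arange(max_timesteps)[None, :]  # month index
legacy[:, :, 2] = 2024                                  # year
print(legacy[0].tolist())  # [[1, 0, 2024], [1, 1, 2024], [1, 2, 2024], [1, 3, 2024]]
```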
@@ -222,6 +344,14 @@ class OlmoEarth(FeatureExtractor):
                 device_type=device.type, dtype=self.autocast_dtype
             )
 
+        # Check if we can bypass masks (fast_pass=True)
+        missing_tokens = False
+        for modality in present_modalities:
+            modality_mask = getattr(sample, f"{modality}_mask")
+            if torch.any(modality_mask == MaskValue.MISSING.value):
+                missing_tokens = True
+                break
+
         with torch_context:
             # Currently we assume the provided model always returns a TokensAndMasks object.
             tokens_and_masks: TokensAndMasks
@@ -229,7 +359,7 @@ class OlmoEarth(FeatureExtractor):
             # Encoder has a fast_pass argument to indicate mask is not needed.
             tokens_and_masks = self.model(
                 sample,
-                fast_pass=
+                fast_pass=not missing_tokens,
                 patch_size=self.patch_size,
                 **self.forward_kwargs,
             )["tokens_and_masks"]
@@ -241,16 +371,41 @@ class OlmoEarth(FeatureExtractor):
 
         # Apply temporal/modality pooling so we just have one feature per patch.
         features = []
-
-
-
-
-
-
-
-
-
-
+        if self.token_pooling:
+            for modality in present_modalities:
+                modality_features = getattr(tokens_and_masks, modality)  # BHWTSC
+                # If fast_pass is False, we need to mask the missing tokens before pooling.
+                if missing_tokens:
+                    modality_masks = getattr(
+                        tokens_and_masks, f"{modality}_mask"
+                    )  # BHWTS
+                    modality_masks_bool = (
+                        modality_masks != MaskValue.MISSING.value
+                    ).unsqueeze(-1)
+                    count = modality_masks_bool.sum(dim=[3, 4])
+                    # Masked average over band sets and timesteps (BHWTSC -> BHWC).
+                    pooled = (modality_features * modality_masks_bool).sum(
+                        dim=[3, 4]
+                    ) / count.clamp(min=1)
+                else:
+                    # Pool over band sets and timesteps (BHWTSC -> BHWC).
+                    pooled = modality_features.mean(dim=[3, 4])
+                # We want BHWC -> BCHW.
+                pooled = rearrange(pooled, "b h w c -> b c h w")
+                features.append(pooled)
+            # Pool over the modalities, so we get one BCHW feature map.
+            pooled = torch.stack(features, dim=0).mean(dim=0)
+            return FeatureMaps([pooled])
+        else:
+            for modality in present_modalities:
+                modality_features = getattr(tokens_and_masks, modality)
+                # Combine band sets and timesteps into last dim (BHWTSC -> BHWCN).
+                modality_features = rearrange(
+                    modality_features, "b h w t s c -> b c h w (t s)"
+                )
+                features.append(modality_features)
+            pooled = torch.cat(features, dim=-1)
+            return TokenFeatureMaps([pooled])
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
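The masked branch above only changes how tokens are averaged; a shape-only sketch with made-up dimensions (B H W T S C features, B H W T S mask) shows the arithmetic:

```python
import torch

# Masked mean over timesteps (T) and band sets (S): missing tokens are zeroed
# before summing, and the count is clamped to avoid division by zero.
feats = torch.randn(1, 8, 8, 4, 2, 16)               # B H W T S C (made up)
valid = torch.ones(1, 8, 8, 4, 2, dtype=torch.bool)  # B H W T S
valid[..., 2:, :] = False                             # last two timesteps missing
mask = valid.unsqueeze(-1)
pooled = (feats * mask).sum(dim=[3, 4]) / mask.sum(dim=[3, 4]).clamp(min=1)
print(pooled.shape)  # torch.Size([1, 8, 8, 16])
```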
rslearn/models/olmoearth_pretrain/norm.py
CHANGED

@@ -64,8 +64,8 @@ class OlmoEarthNormalize(Transform):
         band_norms = self.norm_config[modality_name]
         image = input_dict[modality_name]
         # Keep a set of indices to make sure that we normalize all of them.
-        needed_band_indices = set(range(image.shape[0]))
-        num_timesteps = image.shape[0] // len(cur_band_names)
+        needed_band_indices = set(range(image.image.shape[0]))
+        num_timesteps = image.image.shape[0] // len(cur_band_names)
 
         for band, norm_dict in band_norms.items():
             # If multitemporal, normalize each timestep separately.
@@ -73,7 +73,9 @@ class OlmoEarthNormalize(Transform):
                 band_idx = cur_band_names.index(band) + t * len(cur_band_names)
                 min_val = norm_dict["mean"] - self.std_multiplier * norm_dict["std"]
                 max_val = norm_dict["mean"] + self.std_multiplier * norm_dict["std"]
-                image[band_idx] = (image[band_idx] - min_val) / (
+                image.image[band_idx] = (image.image[band_idx] - min_val) / (
+                    max_val - min_val
+                )
                 needed_band_indices.remove(band_idx)
 
         if len(needed_band_indices) > 0:
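OlmoEarthNormalize maps each band linearly so that mean - k·std lands at 0 and mean + k·std at 1 (k = std_multiplier); only the tensor access changed in this hunk, now going through image.image. Worked arithmetic with made-up statistics:

```python
# Hypothetical band statistics, k = std_multiplier = 2.
mean, std, k = 1000.0, 500.0, 2.0
min_val = mean - k * std  # 0.0
max_val = mean + k * std  # 2000.0
value = 1500.0
print((value - min_val) / (max_val - min_val))  # 0.75
```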
rslearn/models/panopticon.py
CHANGED

@@ -142,7 +142,9 @@ class Panopticon(FeatureExtractor):
     def forward(self, context: ModelContext) -> FeatureMaps:
         """Forward pass through the panopticon model."""
         batch_inputs = {
-            key: torch.stack(
+            key: torch.stack(
+                [inp[key].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+            )
             for key in context.inputs[0].keys()
         }
         panopticon_inputs = self.prepare_input(batch_inputs)