rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/models/olmoearth_pretrain/model.py (new file)

@@ -0,0 +1,421 @@

```python
"""OlmoEarth model wrapper for fine-tuning in rslearn."""

import json
import warnings
from contextlib import nullcontext
from datetime import datetime
from typing import Any

import torch
from einops import rearrange
from olmoearth_pretrain.config import Config, require_olmo_core
from olmoearth_pretrain.data.constants import Modality
from olmoearth_pretrain.datatypes import MaskedOlmoEarthSample, MaskValue
from olmoearth_pretrain.model_loader import (
    ModelID,
    load_model_from_id,
    load_model_from_path,
)
from olmoearth_pretrain.nn.flexihelios import Encoder, TokensAndMasks
from upath import UPath

from rslearn.log_utils import get_logger
from rslearn.models.component import FeatureExtractor, FeatureMaps, TokenFeatureMaps
from rslearn.train.model_context import ModelContext, RasterImage

logger = get_logger(__name__)

MODALITY_NAMES = [
    "sentinel2_l2a",
    "sentinel1",
    "worldcover",
    "openstreetmap_raster",
    "landsat",
]

AUTOCAST_DTYPE_MAP = {
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
    "float32": torch.float32,
}

EMBEDDING_SIZES = {
    ModelID.OLMOEARTH_V1_NANO: 128,
    ModelID.OLMOEARTH_V1_TINY: 192,
    ModelID.OLMOEARTH_V1_BASE: 768,
    ModelID.OLMOEARTH_V1_LARGE: 1024,
}


class OlmoEarth(FeatureExtractor):
    """A wrapper to support the OlmoEarth model."""

    def __init__(
        self,
        patch_size: int,
        model_id: ModelID | None = None,
        model_path: str | None = None,
        checkpoint_path: str | None = None,
        selector: list[str | int] = ["encoder"],
        forward_kwargs: dict[str, Any] = {},
        random_initialization: bool = False,
        embedding_size: int | None = None,
        autocast_dtype: str | None = "bfloat16",
        token_pooling: bool = True,
        use_legacy_timestamps: bool = True,
    ):
        """Create a new OlmoEarth model.

        Exactly one of model_id, model_path, or checkpoint_path must be set.

        Args:
            patch_size: token spatial patch size to use.
            model_id: the model ID to load.
            model_path: the path to load the model from. Same structure as the
                HF-hosted model_id models: a bundle with a config.json and
                weights.pth.
            checkpoint_path: the checkpoint directory to load from. It should
                contain a distributed checkpoint with a config.json file as well
                as a model_and_optim folder.
            selector: an optional sequence of attribute names or list indices to
                select the sub-module that should be applied on the input images.
                Defaults to ["encoder"] to select only the transformer encoder.
            forward_kwargs: additional arguments to pass to the forward pass
                besides the MaskedOlmoEarthSample.
            random_initialization: whether to skip loading the checkpoint so the
                weights are randomly initialized. In this case, the checkpoint is
                only used to define the model architecture.
            embedding_size: optional embedding size to report via
                get_backbone_channels (if model_id is not set).
            autocast_dtype: which dtype to use for autocasting, or None to
                disable.
            token_pooling: whether or not to pool the tokens. If True, the output
                will be BxCxHxW. If False, there will be an extra dimension N
                (BxCxHxWxN) representing the temporal and channel dimensions.
            use_legacy_timestamps: in the original implementation of OlmoEarth,
                timestamps were applied starting from 0 (instead of the actual
                timestamps of the input). The option to do this is preserved for
                backwards compatibility with fine-tuned models that were trained
                against that implementation.
        """
        if use_legacy_timestamps:
            warnings.warn(
                "For new projects, don't use legacy timestamps.", DeprecationWarning
            )

        if (
            sum(
                [
                    model_id is not None,
                    model_path is not None,
                    checkpoint_path is not None,
                ]
            )
            != 1
        ):
            raise ValueError(
                "exactly one of model_id, model_path, or checkpoint_path must be set"
            )

        super().__init__()
        self.patch_size = patch_size
        self.forward_kwargs = forward_kwargs
        self.embedding_size = embedding_size

        if autocast_dtype is not None:
            self.autocast_dtype = AUTOCAST_DTYPE_MAP[autocast_dtype]
        else:
            self.autocast_dtype = None

        if model_id is not None:
            # Load from Hugging Face.
            model = load_model_from_id(model_id, load_weights=not random_initialization)
            if self.embedding_size is None and model_id in EMBEDDING_SIZES:
                self.embedding_size = EMBEDDING_SIZES[model_id]

        elif model_path is not None:
            # Load from path.
            model = load_model_from_path(
                UPath(model_path), load_weights=not random_initialization
            )

        else:
            # Load the distributed model checkpoint by path through Olmo Core.
            model = self._load_model_from_checkpoint(
                UPath(checkpoint_path), random_initialization
            )

        # Select just the portion of the model that we actually want to use.
        for part in selector:
            if isinstance(part, str):
                model = getattr(model, part)
            else:
                model = model[part]
        self.model = model
        self.token_pooling = token_pooling
        self.use_legacy_timestamps = use_legacy_timestamps

    def _load_model_from_checkpoint(
        self, checkpoint_upath: UPath, random_initialization: bool
    ) -> torch.nn.Module:
        """Load the OlmoEarth pre-trained model from a distributed checkpoint folder.

        The folder should contain config.json as well as the model_and_optim folder
        that contains the distributed checkpoint. This is the format produced by
        pre-training runs in olmoearth_pretrain.
        """
        with (checkpoint_upath / "config.json").open() as f:
            config_dict = json.load(f)
        model_config = Config.from_dict(config_dict["model"])

        model = model_config.build()

        # Load the checkpoint (requires olmo_core for distributed checkpoint loading).
        if not random_initialization:
            require_olmo_core(
                "_load_model_from_checkpoint with random_initialization=False"
            )
            from olmo_core.distributed.checkpoint import load_model_and_optim_state

            train_module_dir = checkpoint_upath / "model_and_optim"
            load_model_and_optim_state(str(train_module_dir), model)
            logger.info(f"loaded OlmoEarth encoder from {train_module_dir}")

        return model

    @staticmethod
    def time_ranges_to_timestamps(
        time_ranges: list[tuple[datetime, datetime]],
        max_timestamps: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Turn the time ranges stored in a RasterImage into timestamps accepted by OlmoEarth.

        OlmoEarth only uses the month associated with each timestamp, so we take the
        midpoint of the time range. For some inputs (e.g. Sentinel-2) we take an
        image from a specific time so that start_time == end_time == mid_time.
        """
        timestamps = torch.zeros((max_timestamps, 3), dtype=torch.int32, device=device)
        mid_ranges = [t[0] + ((t[1] - t[0]) / 2) for t in time_ranges]
        timestamps[: len(time_ranges), 0] = torch.tensor(
            [d.day for d in mid_ranges], dtype=torch.int32
        )
        # months are indexed 0-11
        timestamps[: len(time_ranges), 1] = torch.tensor(
            [d.month - 1 for d in mid_ranges], dtype=torch.int32
        )
        timestamps[: len(time_ranges), 2] = torch.tensor(
            [d.year for d in mid_ranges], dtype=torch.int32
        )
        return timestamps

    def _prepare_modality_inputs(
        self, context: ModelContext
    ) -> tuple[MaskedOlmoEarthSample, list[str], torch.device]:
        """Prepare modality tensors and masks for the OlmoEarth model.

        Uses a two-pass approach to ensure all modalities have consistent timestep
        dimensions for position encoding.

        Args:
            context: the model context with input tensors.

        Returns:
            tuple of (sample, present_modalities, device)
        """
        kwargs = {}
        present_modalities = []
        device = None

        # First pass: find global max_timesteps across all modalities and samples.
        # TODO: currently we assume all modalities have the same number of timesteps,
        # which is not true for all cases, and time series time steps are assumed to
        # be 1-month apart. It also assumes continuity between available timesteps.
        # We'll have to fix all that.
        max_timesteps = 1
        modality_data = {}
        # We just store the longest time range per instance in the batch. This means
        # it may not be aligned per modality.
        timestamps_per_instance: list[list[tuple[datetime, datetime]]] = [[]] * len(
            context.inputs
        )
        for modality in MODALITY_NAMES:
            if modality not in context.inputs[0]:
                continue
            present_modalities.append(modality)
            tensors = []
            for idx, inp in enumerate(context.inputs):
                assert isinstance(inp[modality], RasterImage)
                tensors.append(inp[modality].image)
                cur_timestamps = inp[modality].timestamps
                if cur_timestamps is not None and len(cur_timestamps) > len(
                    timestamps_per_instance[idx]
                ):
                    timestamps_per_instance[idx] = cur_timestamps
            tensors = [inp[modality].image for inp in context.inputs]
            device = tensors[0].device
            max_t = max(t.shape[1] for t in tensors)
            max_timesteps = max(max_timesteps, max_t)
            modality_data[modality] = (
                tensors,
                len(Modality.get(modality).band_sets),
            )

        # Second pass: pad and process each modality with global max_timesteps.
        for modality in present_modalities:
            tensors, num_band_sets = modality_data[modality]

            # Pad tensors along the time dimension to max_timesteps and track the
            # original timesteps for masking.
            padded = []
            original_timesteps = []
            for t in tensors:
                orig_t = t.shape[1]
                original_timesteps.append(orig_t)
                if orig_t < max_timesteps:
                    pad = torch.zeros(
                        t.shape[:1] + (max_timesteps - orig_t,) + t.shape[2:],
                        dtype=t.dtype,
                        device=device,
                    )
                    t = torch.cat([t, pad], dim=1)
                padded.append(t)

            cur = torch.stack(padded, dim=0)
            cur = rearrange(cur, "b c t h w -> b h w t c")
            kwargs[modality] = cur

            # Create mask: ONLINE_ENCODER for valid, MISSING for padded timesteps.
            b, h, w = cur.shape[0], cur.shape[1], cur.shape[2]
            mask = torch.full(
                (b, h, w, max_timesteps, num_band_sets),
                fill_value=MaskValue.ONLINE_ENCODER.value,
                dtype=torch.int32,
                device=device,
            )
            for sample_idx, orig_t in enumerate(original_timesteps):
                if orig_t < max_timesteps:
                    mask[sample_idx, :, :, orig_t:, :] = MaskValue.MISSING.value
            kwargs[f"{modality}_mask"] = mask

        if self.use_legacy_timestamps:
            # Note that only months (0 to 11) are used in OlmoEarth position encoding.
            timestamps = torch.zeros(
                (len(context.inputs), max_timesteps, 3),
                dtype=torch.int32,
                device=device,
            )
            timestamps[:, :, 0] = 1  # day
            timestamps[:, :, 1] = torch.arange(max_timesteps, device=device)[
                None, :
            ]  # month
            timestamps[:, :, 2] = 2024  # year
            kwargs["timestamps"] = timestamps
        else:
            if max([len(t) for t in timestamps_per_instance]) == 0:
                # Timestamps are required.
                raise ValueError("No inputs had timestamps.")
            # Note that only months (0 to 11) are used in OlmoEarth position encoding.
            kwargs["timestamps"] = torch.stack(
                [
                    self.time_ranges_to_timestamps(time_range, max_timesteps, device)
                    for time_range in timestamps_per_instance
                ],
                dim=0,
            )

        return MaskedOlmoEarthSample(**kwargs), present_modalities, device

    def forward(self, context: ModelContext) -> FeatureMaps | TokenFeatureMaps:
        """Compute feature maps from the OlmoEarth backbone.

        Args:
            context: the model context. Input dicts should include keys corresponding
                to the modalities that should be passed to the OlmoEarth model.

        Returns:
            a FeatureMaps consisting of one feature map at 1/patch_size of the input
            resolution, with embeddings pooled across modalities and timesteps. If
            token_pooling is False, a TokenFeatureMaps is returned instead, keeping
            band sets and timesteps in an extra token dimension.
        """
        sample, present_modalities, device = self._prepare_modality_inputs(context)

        # Decide context based on self.autocast_dtype.
        if self.autocast_dtype is None:
            torch_context = nullcontext()
        else:
            assert device is not None
            torch_context = torch.amp.autocast(
                device_type=device.type, dtype=self.autocast_dtype
            )

        # Check if we can bypass masks (fast_pass=True).
        missing_tokens = False
        for modality in present_modalities:
            modality_mask = getattr(sample, f"{modality}_mask")
            if torch.any(modality_mask == MaskValue.MISSING.value):
                missing_tokens = True
                break

        with torch_context:
            # Currently we assume the provided model always returns a TokensAndMasks object.
            tokens_and_masks: TokensAndMasks
            if isinstance(self.model, Encoder):
                # Encoder has a fast_pass argument to indicate the mask is not needed.
                tokens_and_masks = self.model(
                    sample,
                    fast_pass=not missing_tokens,
                    patch_size=self.patch_size,
                    **self.forward_kwargs,
                )["tokens_and_masks"]
            else:
                # Other models like STEncoder do not support this option.
                tokens_and_masks = self.model(
                    sample, patch_size=self.patch_size, **self.forward_kwargs
                )["tokens_and_masks"]

        # Apply temporal/modality pooling so we just have one feature per patch.
        features = []
        if self.token_pooling:
            for modality in present_modalities:
                modality_features = getattr(tokens_and_masks, modality)  # BHWTSC
                # If fast_pass is False, we need to mask the missing tokens before pooling.
                if missing_tokens:
                    modality_masks = getattr(
                        tokens_and_masks, f"{modality}_mask"
                    )  # BHWTS
                    modality_masks_bool = (
                        modality_masks != MaskValue.MISSING.value
                    ).unsqueeze(-1)
                    count = modality_masks_bool.sum(dim=[3, 4])
                    # Masked average over band sets and timesteps (BHWTSC -> BHWC).
                    pooled = (modality_features * modality_masks_bool).sum(
                        dim=[3, 4]
                    ) / count.clamp(min=1)
                else:
                    # Pool over band sets and timesteps (BHWTSC -> BHWC).
                    pooled = modality_features.mean(dim=[3, 4])
                # We want BHWC -> BCHW.
                pooled = rearrange(pooled, "b h w c -> b c h w")
                features.append(pooled)
            # Pool over the modalities, so we get one BCHW feature map.
            pooled = torch.stack(features, dim=0).mean(dim=0)
            return FeatureMaps([pooled])
        else:
            for modality in present_modalities:
                modality_features = getattr(tokens_and_masks, modality)
                # Combine band sets and timesteps into the last dim (BHWTSC -> BCHWN).
                modality_features = rearrange(
                    modality_features, "b h w t s c -> b c h w (t s)"
                )
                features.append(modality_features)
            pooled = torch.cat(features, dim=-1)
            return TokenFeatureMaps([pooled])

    def get_backbone_channels(self) -> list:
        """Returns the output channels of this model when used as a backbone.

        The output channels is a list of (downsample_factor, depth) that corresponds
        to the feature maps that the backbone returns. For example, an element [2, 32]
        indicates that the corresponding feature map is 1/2 the input resolution and
        has 32 channels.

        Returns:
            the output channels of the backbone as a list of (downsample_factor, depth)
            tuples.
        """
        return [(self.patch_size, self.embedding_size)]
```
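
As a usage illustration (editor's sketch, not part of the package diff): constructing this backbone from the HF-hosted weights. The patch size here is illustrative, and downstream decoder wiring is omitted.

```python
# Minimal sketch, assuming olmoearth_pretrain is installed and can fetch the
# HF-hosted checkpoint for ModelID.OLMOEARTH_V1_BASE.
from olmoearth_pretrain.model_loader import ModelID

from rslearn.models.olmoearth_pretrain.model import OlmoEarth

backbone = OlmoEarth(
    patch_size=8,  # illustrative value; token spatial patch size
    model_id=ModelID.OLMOEARTH_V1_BASE,  # exactly one of model_id/model_path/checkpoint_path
    token_pooling=True,  # pooled BxCxHxW output
    use_legacy_timestamps=False,  # use real timestamps; legacy mode warns as deprecated
)

# One feature map at 1/patch_size resolution; the base model reports 768 channels.
print(backbone.get_backbone_channels())  # [(8, 768)]
```

In an rslearn pipeline, forward() then receives a ModelContext whose input dicts map modality names such as "sentinel2_l2a" to RasterImage values.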
rslearn/models/olmoearth_pretrain/norm.py (new file)

@@ -0,0 +1,86 @@

```python
"""Normalization transforms."""

import json
from typing import Any

from olmoearth_pretrain.data.normalize import load_computed_config

from rslearn.log_utils import get_logger
from rslearn.train.transforms.transform import Transform

logger = get_logger(__file__)


class OlmoEarthNormalize(Transform):
    """Normalize using OlmoEarth JSON config.

    For Sentinel-1 data, the values should be converted to decibels before being
    passed to this transform.
    """

    def __init__(
        self,
        band_names: dict[str, list[str]],
        std_multiplier: float | None = 2,
        config_fname: str | None = None,
    ) -> None:
        """Initialize a new OlmoEarthNormalize.

        Args:
            band_names: map from modality name to the list of bands in that modality
                in the order they are being loaded. Note that this order must match
                the expected order for the OlmoEarth model.
            std_multiplier: the std multiplier matching the one used for the model
                training in OlmoEarth.
            config_fname: load the normalization configuration from this file,
                instead of getting it from OlmoEarth.
        """
        super().__init__()
        self.band_names = band_names
        self.std_multiplier = std_multiplier

        if config_fname is None:
            self.norm_config = load_computed_config()
        else:
            logger.warning(
                f"Loading normalization config from {config_fname}. This argument is deprecated and will be removed in a future version."
            )
            with open(config_fname) as f:
                self.norm_config = json.load(f)

    def forward(
        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Apply normalization over the inputs and targets.

        Args:
            input_dict: the input
            target_dict: the target

        Returns:
            normalized (input_dicts, target_dicts) tuple
        """
        for modality_name, cur_band_names in self.band_names.items():
            band_norms = self.norm_config[modality_name]
            image = input_dict[modality_name]
            # Keep a set of indices to make sure that we normalize all of them.
            needed_band_indices = set(range(image.image.shape[0]))
            num_timesteps = image.image.shape[0] // len(cur_band_names)

            for band, norm_dict in band_norms.items():
                # If multitemporal, normalize each timestep separately.
                for t in range(num_timesteps):
                    band_idx = cur_band_names.index(band) + t * len(cur_band_names)
                    min_val = norm_dict["mean"] - self.std_multiplier * norm_dict["std"]
                    max_val = norm_dict["mean"] + self.std_multiplier * norm_dict["std"]
                    image.image[band_idx] = (image.image[band_idx] - min_val) / (
                        max_val - min_val
                    )
                    needed_band_indices.remove(band_idx)

            if len(needed_band_indices) > 0:
                raise ValueError(
                    f"for modality {modality_name}, bands {needed_band_indices} were unexpectedly not normalized"
                )

        return input_dict, target_dict
```
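
A hedged sketch of the transform's per-band scaling (editor's illustration): the band list and the stand-in raster class below are assumptions; in practice the transform operates on rslearn RasterImage inputs and, by default, pulls its statistics from olmoearth_pretrain.

```python
# Sketch only: exercising OlmoEarthNormalize's min/max scaling. Note that
# band_names must cover every band in the modality's normalization config,
# since each config entry is located with cur_band_names.index(band).
import torch

from rslearn.models.olmoearth_pretrain.norm import OlmoEarthNormalize


class FakeRaster:
    """Stand-in for rslearn's RasterImage; the transform only touches .image."""

    def __init__(self, image: torch.Tensor) -> None:
        self.image = image


normalize = OlmoEarthNormalize(
    band_names={"sentinel2_l2a": ["B02", "B03", "B04", "B08"]},  # assumed band list
    std_multiplier=2,  # should match the multiplier used in OlmoEarth pre-training
)

# Two timesteps of four bands stacked on the channel axis (T*C = 8). Each band b
# is rescaled to (x - (mean_b - 2*std_b)) / (4*std_b).
inputs = {"sentinel2_l2a": FakeRaster(torch.rand(8, 32, 32))}
inputs, targets = normalize.forward(inputs, {})  # invoked by the pipeline normally
```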
rslearn/models/panopticon.py (new file)

@@ -0,0 +1,170 @@

```python
"""Wrapper for the Panopticon model."""

import math
from enum import StrEnum
from importlib import resources

import torch
import torch.nn.functional as F
import yaml
from einops import rearrange, repeat

from rslearn.log_utils import get_logger
from rslearn.train.model_context import ModelContext

from .component import FeatureExtractor, FeatureMaps

logger = get_logger(__name__)


class PanopticonModalities(StrEnum):
    """Modalities supported by Panopticon.

    These are the keys needed to load the yaml file from panopticon_data/sensors.
    """

    SENTINEL2 = "sentinel2"
    LANDSAT8 = "landsat8"
    SENTINEL1 = "sentinel1"
    # Add more modalities as needed


class Panopticon(FeatureExtractor):
    """Class containing the Panopticon model that can ingest MaskedHeliosSample objects."""

    patch_size: int = 14
    base_image_size: int = 224

    def __init__(
        self,
        band_order: dict[str, list[str]],
        torchhub_id: str = "panopticon_vitb14",
    ):
        """Initialize the Panopticon wrapper.

        Args:
            band_order: the band order for the panopticon model; must match the
                specified order in the data config.
            torchhub_id: the torch hub model ID for panopticon.
        """
        super().__init__()
        # Load the panopticon model.
        self._load_model(torchhub_id)
        self.output_dim = self.model.embed_dim
        self.band_order = band_order
        self.supported_modalities = list(band_order.keys())

    def _load_model(self, torchhub_id: str) -> None:
        """Load the panopticon model from torch hub."""
        import time

        # Hack to get around https://discuss.pytorch.org/t/torch-hub-load-gives-httperror-rate-limit-exceeded/124769
        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
        for attempt in range(2):
            try:
                self.model = torch.hub.load(  # nosec B614
                    "panopticon-FM/panopticon",
                    torchhub_id,
                )
                break
            except Exception as e:
                logger.warning(
                    f"Error loading panopticon model: {e}. Retrying in 5 seconds..."
                )
                time.sleep(5)
        else:
            raise RuntimeError(
                f"Failed to load panopticon model {torchhub_id} after retrying."
            )

    def _process_modality_data(self, data: torch.Tensor) -> torch.Tensor:
        """Process individual modality data.

        Args:
            data: Input tensor of shape [B, C, H, W]

        Returns:
            Processed tensor of shape [B, C, H, W]
        """
        original_height = data.shape[2]
        new_height = self.patch_size if original_height == 1 else self.base_image_size

        data = F.interpolate(
            data,
            size=(new_height, new_height),
            mode="bilinear",
            align_corners=False,
        )
        return data

    def _create_channel_ids(
        self, modality: str, batch_size: int, device: torch.device
    ) -> torch.Tensor:
        """Create channel IDs for the panopticon model."""
        with resources.open_text(
            "rslearn.models.panopticon_data.sensors", f"{modality}.yaml"
        ) as f:
            sensor_config = yaml.safe_load(f)

        band_order = self.band_order[modality]
        chn_ids = [
            sensor_config["bands"][band.upper()]["gaussian"]["mu"]
            for band in band_order
        ]
        chn_ids = torch.tensor(chn_ids, dtype=torch.float32, device=device)
        chn_ids = repeat(chn_ids, "c -> b c", b=batch_size)
        return chn_ids

    def prepare_input(
        self, input_data: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        """Prepare input for the panopticon model from MaskedHeliosSample."""
        channel_ids_list: list[torch.Tensor] = []
        processed_data_list: list[torch.Tensor] = []
        for modality in self.supported_modalities:
            if modality not in input_data.keys():
                logger.debug(f"Modality {modality} not found in input data")
                continue
            data = input_data[modality]
            device = data.device
            processed_data = self._process_modality_data(data)
            processed_data_list.append(processed_data)
            batch_size = processed_data.shape[0]
            chn_ids = self._create_channel_ids(modality, batch_size, device)
            channel_ids_list.append(chn_ids)

        processed_data = torch.cat(processed_data_list, dim=1)
        chn_ids = torch.cat(channel_ids_list, dim=1)
        return {
            "imgs": processed_data,
            "chn_ids": chn_ids,
        }

    def forward(self, context: ModelContext) -> FeatureMaps:
        """Forward pass through the panopticon model."""
        batch_inputs = {
            key: torch.stack(
                [inp[key].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
            )
            for key in context.inputs[0].keys()
        }
        panopticon_inputs = self.prepare_input(batch_inputs)
        output_features = self.model.forward_features(panopticon_inputs)[
            "x_norm_patchtokens"
        ]

        num_tokens = output_features.shape[1]
        height = int(math.sqrt(num_tokens))
        output_features = rearrange(
            output_features, "b (h w) d -> b d h w", h=height, w=height
        )
        return FeatureMaps([output_features])

    def get_backbone_channels(self) -> list:
        """Returns the output channels of this model when used as a backbone.

        The output channels is a list of (downsample_factor, depth) that corresponds
        to the feature maps that the backbone returns. For example, an element [2, 32]
        indicates that the corresponding feature map is 1/2 the input resolution and
        has 32 channels.
        """
        return [(self.patch_size, self.output_dim)]
```
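
A small usage sketch (editor's illustration): the band names are assumptions against the bundled sentinel2.yaml, and constructing the wrapper downloads the checkpoint via torch.hub, so it requires network access.

```python
# Sketch only: preparing a batch for the Panopticon wrapper.
import torch

from rslearn.models.panopticon import Panopticon

# Downloads panopticon_vitb14 from the panopticon-FM/panopticon hub repo.
model = Panopticon(band_order={"sentinel2": ["B02", "B03", "B04", "B08"]})

# prepare_input() resizes each BxCxHxW modality to 224x224 (base_image_size)
# and looks up per-band channel IDs from panopticon_data/sensors/sentinel2.yaml.
batch = {"sentinel2": torch.randn(2, 4, 128, 128)}
inputs = model.prepare_input(batch)
print(inputs["imgs"].shape)     # torch.Size([2, 4, 224, 224])
print(inputs["chn_ids"].shape)  # torch.Size([2, 4])
```

In training, forward() builds this batch itself from the ModelContext and reshapes the ViT's patch tokens back into a BxDxHxW feature map.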