rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/models/prithvi.py (new file)
@@ -0,0 +1,1147 @@
"""Prithvi V2.

This code is adapted from https://github.com/NASA-IMPACT/Prithvi-WxC

The code is released under:

MIT License
Copyright (c) 2024 Inter Agency Implementation and Advanced Concepts
"""

import json
import logging
import tempfile
import warnings
from enum import StrEnum
from pathlib import Path
from typing import Any

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from huggingface_hub import hf_hub_download
from timm.layers import to_2tuple
from timm.models.vision_transformer import Block
from torch.nn import functional as F

from rslearn.train.model_context import ModelContext
from rslearn.train.transforms.normalize import Normalize
from rslearn.train.transforms.transform import Transform

from .component import FeatureExtractor, FeatureMaps

logger = logging.getLogger(__name__)


class PrithviV2Models(StrEnum):
    """Names for different Prithvi models on torch hub."""

    VIT_300 = "VIT_300"
    VIT_600 = "VIT_600"


MODEL_TO_HF_INFO = {
    PrithviV2Models.VIT_300: {
        "hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-2.0-300M",
        "weights": "Prithvi_EO_V2_300M.pt",
        "revision": "b2f2520ab889f42a25c5361ba18761fcb4ea44ad",
    },
    PrithviV2Models.VIT_600: {
        "hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-2.0-600M",
        "weights": "Prithvi_EO_V2_600M.pt",
        "revision": "87f15784813828dc37aa3197a143cd4689e4d080",
    },
}


HF_HUB_CONFIG_FNAME = "config.json"
DEFAULT_CACHE_DIR = Path(tempfile.gettempdir(), "rslearn_cache", "prithvi_v2")


def get_config(cache_dir: Path, hf_hub_id: str, hf_hub_revision: str) -> dict[str, Any]:
    """Get the JSON config dict.

    Args:
        cache_dir: the directory to cache the config.json file, which will be
            downloaded from HF Hub.
        hf_hub_id: the HF Hub ID from which to download the config.
        hf_hub_revision: The revision (commit) to download the config from.
    """
    cache_fname = cache_dir / HF_HUB_CONFIG_FNAME
    if not cache_fname.exists():
        _ = hf_hub_download(
            local_dir=cache_dir,
            repo_id=hf_hub_id,
            filename=HF_HUB_CONFIG_FNAME,
            revision=hf_hub_revision,
        )  # nosec
    with cache_fname.open() as f:
        return json.load(f)["pretrained_cfg"]


class PrithviV2(FeatureExtractor):
    """An Rslearn wrapper for Prithvi 2.0."""

    INPUT_KEY = "image"

    def __init__(
        self,
        cache_dir: str | Path | None = None,
        size: PrithviV2Models = PrithviV2Models.VIT_300,
        num_frames: int = 1,
    ):
        """Create a new PrithviV2.

        Args:
            cache_dir: The local folder in which to download the prithvi config and
                weights. If None, it downloads to a temporary folder.
            size: the model size, see class for various models.
            num_frames: The number of input frames (timesteps). The model was trained on 3,
                but if there is just one timestamp examples use 1 (e.g.
                https://github.com/NASA-IMPACT/Prithvi-EO-2.0/blob/main/examples/
                example_landslide4sense.ipynb)

        """
        super().__init__()
        if cache_dir is None:
            cache_dir = DEFAULT_CACHE_DIR
        cache_dir = Path(cache_dir)

        hub_id = MODEL_TO_HF_INFO[size]["hf_hub_id"]
        revision = MODEL_TO_HF_INFO[size]["revision"]
        checkpoint_fname = MODEL_TO_HF_INFO[size]["weights"]

        config = get_config(cache_dir, hub_id, revision)
        config["num_frames"] = num_frames
        self.model = PrithviMAE(**config)

        if not (cache_dir / checkpoint_fname).exists():
            _ = hf_hub_download(
                local_dir=cache_dir,
                repo_id=hub_id,
                filename=checkpoint_fname,
                revision=revision,
            )  # nosec

        state_dict = torch.load(
            cache_dir / checkpoint_fname,
            map_location="cpu",
            weights_only=True,
        )
        # discard fixed pos_embedding weight, following
        # https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M/blob/e4aabdc440c8ee703a749def8af5bf4700dee35b/inference.py#L362
        for k in list(state_dict.keys()):
            if "pos_embed" in k:
                del state_dict[k]
        self.model.load_state_dict(state_dict, strict=False)
        self.image_resolution = config["img_size"]
        self.bands = config["bands"]
        # patch size is a list [t, h, w], where h == w
        self.patch_size = config["patch_size"][-1]

    def _resize_data(self, data: torch.Tensor) -> torch.Tensor:
        """Process individual modality data.

        Args:
            data: Input tensor of shape [B, C, T, H, W]

        Returns:
            list of tensors of shape [B, C, T, H, W]
        """
        # Get original dimensions
        B, C, T, H, W = data.shape
        data = rearrange(data, "b c t h w -> b (c t) h w")
        original_height = H
        new_height = self.patch_size if original_height == 1 else self.image_resolution
        data = F.interpolate(
            data,
            size=(new_height, new_height),
            mode="bilinear",
            align_corners=False,
        )
        data = rearrange(data, "b (c t) h w -> b c t h w", c=C, t=T)
        return data

    def forward(self, context: ModelContext) -> FeatureMaps:
        """Compute feature maps from the Prithvi V2 backbone.

        Args:
            context: the model context. Input dicts must include "image" key containing
                HLS (Harmonized Landsat-Sentinel) data.

        Returns:
            a FeatureMaps with one map of shape [B, H/p_s, W/p_s, 11*1024] that contains stacked
                feature maps across the 11 transformer blocks.
        """
        # x has shape BCTHW
        x = torch.stack([inp[self.INPUT_KEY].image for inp in context.inputs], dim=0)
        x = self._resize_data(x)
        features = self.model.encoder.forward_features(x)
        # prepare_features_for_image_model was slightly modified since we already
        # know the number of timesteps and don't need to recompute it.
        # in addition we average along the time dimension (instead of concatenating)
        # to keep the embeddings reasonably sized.
        result = self.model.encoder.prepare_features_for_image_model(
            features, x.shape[2]
        )
        return FeatureMaps([torch.cat(result, dim=1)])

    def get_backbone_channels(self) -> list:
        """Returns the output channels of this model when used as a backbone.

        The output channels is a list of (patch_size, depth) that corresponds
        to the feature maps that the backbone returns.

        Returns:
            the output channels of the backbone as a list of (patch_size, depth) tuples.
        """
        return [(1, 1024)]


class PrithviNormalize(Transform):
    """Normalize inputs using Prithvi normalization.

    Similar to the model, the input should be an image time series under the key
    "image".
    """

    def __init__(
        self,
        cache_dir: str | Path | None = None,
        size: PrithviV2Models = PrithviV2Models.VIT_300,
    ) -> None:
        """Initialize a new PrithviNormalize.

        Args:
            cache_dir: the local directory to cache the config.json which contains the
                means and standard deviations used in the normalization.
            size: the model size, see class for various models. In this case (and
                for the current hf revision), the config values (mean and std) are the
                same for both the 300M and 600M model, so its safe to not set this.
        """
        super().__init__()
        hub_id = MODEL_TO_HF_INFO[size]["hf_hub_id"]
        revision = MODEL_TO_HF_INFO[size]["revision"]
        if cache_dir is None:
            cache_dir = DEFAULT_CACHE_DIR
        cache_dir = Path(cache_dir)
        config = get_config(cache_dir, hub_id, revision)
        self.normalizer = Normalize(
            mean=config["mean"],
            std=config["std"],
            num_bands=len(config["mean"]),
            selectors=[PrithviV2.INPUT_KEY],
        )

    def forward(
        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Apply Prithvi normalization on the image.

        Args:
            input_dict: the input, which must contain the "image" key.
            target_dict: the target

        Returns:
            normalized (input_dicts, target_dicts) tuple
        """
        return self.normalizer(input_dict, target_dict)


# Copyright (c) IBM Corp. 2024. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# transformers: https://github.com/huggingface/transformers
# --------------------------------------------------------


def get_3d_sincos_pos_embed(
    embed_dim: int,
    grid_size: tuple[int, int, int] | list[int],
    add_cls_token: bool = False,
) -> torch.Tensor:
    """Create 3D sin/cos positional embeddings.

    Args:
        embed_dim (int):
            Embedding dimension.
        grid_size (tuple[int, int, int] | list[int]):
            The grid depth, height and width.
        add_cls_token (bool, *optional*, defaults to False):
            Whether or not to add a classification (CLS) token.

    Returns:
        (`torch.FloatTensor` of shape (grid_size[0]*grid_size[1]*grid_size[2], embed_dim) or
        (1+grid_size[0]*grid_size[1]*grid_size[2], embed_dim): the position embeddings (with or without cls token)
    """
    assert embed_dim % 16 == 0

    t_size, h_size, w_size = grid_size

    w_embed_dim = embed_dim // 16 * 6
    h_embed_dim = embed_dim // 16 * 6
    t_embed_dim = embed_dim // 16 * 4

    w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
    h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
    t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))

    w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
    h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
    t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)

    pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)

    if add_cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_1d_sincos_pos_embed_from_grid(
    embed_dim: int, pos: torch.Tensor
) -> torch.Tensor:
    """embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)."""
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be even")

    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


def _get_1d_sincos_embed_from_grid_torch(
    embed_dim: int, pos: torch.Tensor
) -> torch.Tensor:
    """Modified torch version of *get_1d_sincos_pos_embed_from_grid()*.

    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,) - must be float dtype!
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    assert pos.dtype in [torch.float32, torch.float16, torch.bfloat16]

    omega = torch.arange(embed_dim // 2, dtype=pos.dtype).to(pos.device)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)

    return emb


def _init_weights(module: nn.Module) -> None:
    """Initialize the weights."""
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)


def _interpolate_pos_encoding(
    pos_embed: torch.Tensor,
    grid_size: tuple[int, int, int] | list[int],
    patch_size: tuple[int, int, int] | list[int],
    shape: tuple[int, int, int] | list[int],
    embed_dim: int,
) -> torch.Tensor:
    """_interpolate_pos_encoding.

    Adapted from:
    - transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding,
    - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194
    """
    t, h, w = shape
    t_patches = t // patch_size[0]
    h_patches = h // patch_size[1]
    w_patches = w // patch_size[2]

    if [t_patches, h_patches, w_patches] == grid_size:
        # No interpolation needed
        return pos_embed
    if t_patches != grid_size[0]:
        # Re-compute pos embedding to handle changed num_frames
        new_grid_size = (t_patches, *grid_size[1:])
        new_pos_embed = get_3d_sincos_pos_embed(
            pos_embed.shape[-1], new_grid_size, add_cls_token=True
        )
        new_pos_embed = torch.from_numpy(new_pos_embed).float().unsqueeze(0)
    else:
        new_grid_size = grid_size  # type: ignore
        new_pos_embed = pos_embed

    class_pos_embed, patch_pos_embed = new_pos_embed[:, :1], new_pos_embed[:, 1:]

    patch_pos_embed = patch_pos_embed.reshape(*new_grid_size, embed_dim).permute(
        0, 3, 1, 2
    )

    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed,
        size=(h_patches, w_patches),
        mode="bicubic",
        align_corners=True,
    )
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, embed_dim)

    return torch.cat((class_pos_embed, patch_pos_embed), dim=1)


class PatchEmbed(nn.Module):
    """3D version of timm.models.vision_transformer.PatchEmbed."""

    def __init__(
        self,
        input_size: tuple[int, int, int] = (1, 224, 224),
        patch_size: tuple[int, int, int] = (1, 16, 16),
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: nn.Module | None = None,
        flatten: bool = True,
        bias: bool = True,
    ) -> None:
        """Init."""
        super().__init__()
        self.input_size = input_size
        self.patch_size = patch_size
        self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)]
        assert self.grid_size >= [1, 1, 1], "Patch size is bigger than input size."
        self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
        self.flatten = flatten

        self.proj = nn.Conv3d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias
        )
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward."""
        B, C, T, H, W = x.shape

        if (
            T / self.patch_size[0] % 1
            or H / self.patch_size[1] % 1
            or W / self.patch_size[2] % 1
        ):
            warnings.warn(
                f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}."
                f"The border will be ignored, add backbone_padding for pixel-wise tasks."
            )

        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # B,C,T,H,W -> B,C,L -> B,L,C
        x = self.norm(x)
        return x


class TemporalEncoder(nn.Module):
    """TemporalEncoder."""

    def __init__(self, embed_dim: int, trainable_scale: bool = False):
        """Init."""
        super().__init__()
        self.embed_dim = embed_dim
        self.year_embed_dim = embed_dim // 2
        self.julian_day_embed_dim = embed_dim - self.year_embed_dim

        # If trainable, initialize scale with small number
        if trainable_scale:
            self.scale = nn.Parameter(torch.full((1,), 0.1))
        else:
            self.register_buffer("scale", torch.ones(1))

    def forward(
        self, temporal_coords: torch.Tensor, tokens_per_frame: int | None = None
    ) -> torch.Tensor:
        """Forward.

        temporal_coords: year and day-of-year info with shape (B, T, 2).
        tokens_per_frame: number of tokens for each frame in the sample. If provided, embeddings will be
            repeated over T dimension, and final shape is (B, T*tokens_per_frame, embed_dim).
        """
        shape = temporal_coords.shape[:2] + (-1,)  # B, T, -1

        year = _get_1d_sincos_embed_from_grid_torch(
            self.year_embed_dim, temporal_coords[:, :, 0].flatten()
        ).reshape(shape)
        julian_day = _get_1d_sincos_embed_from_grid_torch(
            self.julian_day_embed_dim, temporal_coords[:, :, 1].flatten()
        ).reshape(shape)

        embedding = self.scale * torch.cat([year, julian_day], dim=-1)

        if tokens_per_frame is not None:
            embedding = torch.repeat_interleave(embedding, tokens_per_frame, dim=1)

        return embedding  # B, T*tokens_per_frame, embed_dim


class LocationEncoder(nn.Module):
    """LocationEncoder."""

    def __init__(self, embed_dim: int, trainable_scale: bool = False):
        """Init."""
        super().__init__()
        self.embed_dim = embed_dim
        self.lat_embed_dim = embed_dim // 2
        self.lon_embed_dim = embed_dim - self.lat_embed_dim

        # If trainable, initialize scale with small number
        if trainable_scale:
            self.scale = nn.Parameter(torch.full((1,), 0.1))
        else:
            self.register_buffer("scale", torch.ones(1))

    def forward(self, location_coords: torch.Tensor) -> torch.Tensor:
        """location_coords: lat and lon info with shape (B, 2)."""
        shape = location_coords.shape[:1] + (1, -1)  # B, 1, -1

        lat = _get_1d_sincos_embed_from_grid_torch(
            self.lat_embed_dim, location_coords[:, 0].flatten()
        ).reshape(shape)
        lon = _get_1d_sincos_embed_from_grid_torch(
            self.lon_embed_dim, location_coords[:, 1].flatten()
        ).reshape(shape)

        embedding = self.scale * torch.cat([lat, lon], dim=-1)

        return embedding  # B, 1, embed_dim


class PrithviViT(nn.Module):
    """Prithvi ViT Encoder."""

    def __init__(
        self,
        img_size: int | tuple[int, int] = 224,
        patch_size: int | tuple[int, int, int] = (1, 16, 16),
        num_frames: int = 1,
        in_chans: int = 3,
        embed_dim: int = 1024,
        depth: int = 24,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        norm_layer: nn.Module = nn.LayerNorm,
        coords_encoding: list[str] | None = None,
        coords_scale_learn: bool = False,
        drop_path: float = 0.0,
        **kwargs: Any,
    ) -> None:
        """Init."""
        super().__init__()

        self.in_chans = in_chans
        self.num_frames = num_frames
        self.embed_dim = embed_dim
        self.img_size = to_2tuple(img_size)
        if isinstance(patch_size, int):
            patch_size = (1, patch_size, patch_size)

        # 3D patch embedding
        self.patch_embed = PatchEmbed(
            input_size=(num_frames,) + self.img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        self.out_channels = [embed_dim * self.patch_embed.grid_size[0]] * depth

        # Optional temporal and location embedding
        coords_encoding = coords_encoding or []
        self.temporal_encoding = "time" in coords_encoding
        self.location_encoding = "location" in coords_encoding
        if self.temporal_encoding:
            assert patch_size[0] == 1, (
                f"With temporal encoding, patch_size[0] must be 1, received {patch_size[0]}"
            )
            self.temporal_embed_enc = TemporalEncoder(embed_dim, coords_scale_learn)
        if self.location_encoding:
            self.location_embed_enc = LocationEncoder(embed_dim, coords_scale_learn)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.register_buffer(
            "pos_embed", torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim)
        )

        # Transformer layers
        self.blocks = []
        for i in range(depth):
            self.blocks.append(
                Block(
                    embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                    drop_path=drop_path,
                )
            )
        self.blocks = nn.ModuleList(self.blocks)

        self.norm = norm_layer(embed_dim)

        self.initialize_weights()

    def initialize_weights(self) -> None:
        """initialize_weights."""
        # initialize (and freeze) position embeddings by sin-cos embedding
        pos_embed = get_3d_sincos_pos_embed(
            self.pos_embed.shape[-1], self.patch_embed.grid_size, add_cls_token=True
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # initialize patch_embeddings like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=0.02)
        self.apply(_init_weights)

    def random_masking(
        self,
        sequence: torch.Tensor,
        mask_ratio: float,
        noise: None | torch.Tensor = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Perform per-sample random masking by per-sample shuffling.

        Per-sample shuffling is done by argsort random
        noise.

        Args:
            sequence: (`torch.FloatTensor` of shape `(batch_size, sequence_length, dim)`)
            mask_ratio: (float): mask ratio to use.
            noise: (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) which is
                mainly used for testing purposes to control randomness and maintain the reproducibility
        """
        batch_size, seq_length, dim = sequence.shape
        len_keep = int(seq_length * (1 - mask_ratio))

        if noise is None:
            noise = torch.rand(
                batch_size, seq_length, device=sequence.device
            )  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1).to(
            sequence.device
        )  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        sequence_unmasked = torch.gather(
            sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim)
        )

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([batch_size, seq_length], device=sequence.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return sequence_unmasked, mask, ids_restore

    def interpolate_pos_encoding(
        self, sample_shape: tuple[int, int, int] | list[int]
    ) -> torch.Tensor:
        """interpolate_pos_encoding."""
        pos_embed = _interpolate_pos_encoding(
            pos_embed=self.pos_embed,
            grid_size=self.patch_embed.grid_size,
            patch_size=self.patch_embed.patch_size,
            shape=sample_shape,
            embed_dim=self.embed_dim,
        )
        return pos_embed

    def forward(
        self,
        x: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
        mask_ratio: float = 0.75,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward."""
        if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
            # add time dim
            x = x.unsqueeze(2)
        sample_shape = x.shape[-3:]

        # embed patches
        x = self.patch_embed(x)

        pos_embed = self.interpolate_pos_encoding(sample_shape)
        # add pos embed w/o cls token
        x = x + pos_embed[:, 1:, :]

        if self.temporal_encoding and temporal_coords is not None:
            num_tokens_per_frame = x.shape[1] // self.num_frames
            temporal_encoding = self.temporal_embed_enc(
                temporal_coords, num_tokens_per_frame
            )
            x = x + temporal_encoding
        if self.location_encoding and location_coords is not None:
            location_encoding = self.location_embed_enc(location_coords)
            x = x + location_encoding

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)

        return x, mask, ids_restore

    def forward_features(
        self,
        x: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
    ) -> list[torch.Tensor]:
        """forward_features."""
        if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
            # add time dim
            x = x.unsqueeze(2)
        sample_shape = x.shape[-3:]

        # embed patches
        x = self.patch_embed(x)

        pos_embed = self.interpolate_pos_encoding(sample_shape)
        # add pos embed w/o cls token
        x = x + pos_embed[:, 1:, :]

        if self.temporal_encoding and temporal_coords is not None:
            num_tokens_per_frame = x.shape[1] // self.num_frames
            temporal_encoding = self.temporal_embed_enc(
                temporal_coords, num_tokens_per_frame
            )
            x = x + temporal_encoding
        if self.location_encoding and location_coords is not None:
            location_encoding = self.location_embed_enc(location_coords)
            x = x + location_encoding

        # append cls token
        cls_token = self.cls_token + pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        out = []
        for block in self.blocks:
            x = block(x)
            out.append(x.clone())

        x = self.norm(x)
        out[-1] = x
        return out

    def prepare_features_for_image_model(
        self, features: list[torch.Tensor], t: int
    ) -> list[torch.Tensor]:
        """prepare_features_for_image_model."""
        out = []
        for x in features:
            x_no_token = x[:, 1:, :]
            number_of_tokens = x_no_token.shape[1]
            tokens_per_timestep = number_of_tokens // t
            h = int(np.sqrt(tokens_per_timestep))
            encoded = rearrange(
                x_no_token,
                "batch (t h w) e -> batch t e h w",
                e=self.embed_dim,
                t=t,
                h=h,
            )
            # mean along the time dimension
            out.append(encoded.mean(dim=1))
        return out


class MAEDecoder(nn.Module):
    """Transformer Decoder used in the Prithvi MAE."""

    def __init__(
        self,
        patch_size: int | tuple[int, int, int] = (1, 16, 16),
        grid_size: list[int] | tuple[int, int, int] = (3, 14, 14),
        in_chans: int = 3,
        encoder_embed_dim: int = 1024,
        decoder_embed_dim: int = 512,
        depth: int = 8,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        norm_layer: nn.Module = nn.LayerNorm,
        coords_encoding: list[str] | None = None,
        coords_scale_learn: bool = False,
    ) -> None:
        """Init."""
        super().__init__()

        self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
        self.decoder_embed_dim = decoder_embed_dim
        self.grid_size = grid_size
        if isinstance(patch_size, int):
            patch_size = (1, patch_size, patch_size)
        self.patch_size = patch_size
        self.num_frames = self.grid_size[0] * patch_size[0]
        num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]

        # Optional temporal and location embedding
        coords_encoding = coords_encoding or []
        self.temporal_encoding = "time" in coords_encoding
        self.location_encoding = "location" in coords_encoding
        if self.temporal_encoding:
            self.temporal_embed_dec = TemporalEncoder(
                decoder_embed_dim, coords_scale_learn
            )
        if self.location_encoding:
            self.location_embed_dec = LocationEncoder(
                decoder_embed_dim, coords_scale_learn
            )

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.register_buffer(
            "decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_embed_dim)
        )

        self.decoder_blocks = nn.ModuleList(
            [
                Block(
                    decoder_embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                )
                for _ in range(depth)
            ]
        )

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(
            decoder_embed_dim,
            patch_size[0] * patch_size[1] * patch_size[2] * in_chans,
            bias=True,
        )

        self.initialize_weights()

    def initialize_weights(self) -> None:
        """initialize_weights."""
        # initialize (and freeze) position embeddings by sin-cos embedding
        decoder_pos_embed = get_3d_sincos_pos_embed(
            self.decoder_pos_embed.shape[-1], self.grid_size, add_cls_token=True
        )
        self.decoder_pos_embed.data.copy_(
            torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
        )

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.mask_token, std=0.02)
        self.apply(_init_weights)

    def interpolate_pos_encoding(
        self, sample_shape: tuple[int, int, int]
    ) -> torch.Tensor:
        """interpolate_pos_encoding."""
        pos_embed = _interpolate_pos_encoding(
            pos_embed=self.decoder_pos_embed,
            grid_size=self.grid_size,
            patch_size=self.patch_size,
            shape=sample_shape,
            embed_dim=self.decoder_embed_dim,
        )

        return pos_embed

    def forward(
        self,
        hidden_states: torch.Tensor,
        ids_restore: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
        input_size: list[int] | None = None,
    ) -> torch.Tensor:
        """Forward."""
        # embed tokens
        x = self.decoder_embed(hidden_states)
        cls_token = x[:, :1, :]

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(
            x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1
        )
        x = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        # unshuffle
        x = torch.gather(
            x,
            dim=1,
            index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x.device),
        )

        # add pos embed
        decoder_pos_embed = self.interpolate_pos_encoding(input_size[-3:])  # type: ignore
        cls_token = cls_token + decoder_pos_embed[:, :1, :]
        x = x + decoder_pos_embed[:, 1:, :]

        if self.temporal_encoding and temporal_coords is not None:
            num_tokens_per_frame = x.shape[1] // self.num_frames
            temporal_encoding = self.temporal_embed_dec(
                temporal_coords, num_tokens_per_frame
            )
            # Add temporal encoding w/o cls token
            x = x + temporal_encoding
        if self.location_encoding and location_coords is not None:
            location_encoding = self.location_embed_dec(location_coords)
            # Add location encoding w/o cls token
            x = x + location_encoding

        # append cls token
        x = torch.cat([cls_token, x], dim=1)

        # apply Transformer layers (blocks)
        for block in self.decoder_blocks:
            x = block(x)
        x = self.decoder_norm(x)

        # predictor projection
        pred = self.decoder_pred(x)

        # remove cls token
        pred = pred[:, 1:, :]

        return pred


class PrithviMAE(nn.Module):
    """Prithvi Masked Autoencoder."""

    def __init__(
        self,
        img_size: int | tuple[int, int] = 224,
        patch_size: int | tuple[int, int, int] = (1, 16, 16),
        num_frames: int = 4,
        in_chans: int = 6,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        decoder_embed_dim: int = 512,
        decoder_depth: int = 8,
        decoder_num_heads: int = 16,
        mlp_ratio: float = 4.0,
        norm_layer: nn.Module = nn.LayerNorm,
        norm_pix_loss: bool = False,
        coords_encoding: list[str] | None = None,
        coords_scale_learn: bool = False,
        drop_path: float = 0.0,
        mask_ratio: float = 0.75,
        **kwargs: Any,
    ):
        """Init."""
        super().__init__()

        self.encoder = PrithviViT(
            img_size=img_size,
            num_frames=num_frames,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer,
            coords_encoding=coords_encoding,
            coords_scale_learn=coords_scale_learn,
            drop_path=drop_path,
        )

        self.decoder = MAEDecoder(
            patch_size=patch_size,
            grid_size=self.encoder.patch_embed.grid_size,
            in_chans=in_chans,
            encoder_embed_dim=embed_dim,
            decoder_embed_dim=decoder_embed_dim,
            depth=decoder_depth,
            num_heads=decoder_num_heads,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer,
            coords_encoding=coords_encoding,
            coords_scale_learn=coords_scale_learn,
        )

        self.mask_ratio = mask_ratio
        self.norm_pix_loss = norm_pix_loss
        self.out_channels = self.encoder.out_channels

    def patchify(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Patchify.

        Args:
            pixel_values: (torch.FloatTensor of shape `(batch_size, num_channels, time, height, width)`):
                Pixel values.

        Returns:
            torch.FloatTensor of shape
            `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
                Patchified pixel values.
        """
        patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
        num_channels = self.encoder.in_chans

        # patchify
        patchified_pixel_values = rearrange(
            pixel_values,
            "b c (t s) (h p) (w q) -> b (t h w) (s p q c)",
            c=num_channels,
            s=patch_size_t,
            p=patch_size_h,
            q=patch_size_w,
        )

        return patchified_pixel_values

    def unpatchify(
        self,
        patchified_pixel_values: torch.Tensor,
        image_size: tuple[int, int] | None = None,
    ) -> torch.Tensor:
        """Unpatchify.

        Args:
            patchified_pixel_values: (`torch.FloatTensor` of shape
                `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels))`:
                Patchified pixel values.
            image_size: (`tuple[int, int]`, *optional*):
                Original image size.

        Returns:
            `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
                Pixel values.
        """
        patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
        image_size = (
            to_2tuple(image_size) if image_size is not None else self.encoder.img_size
        )
        original_height, original_width = image_size
        num_patches_h = original_height // patch_size_h
        num_patches_w = original_width // patch_size_w
        num_channels = self.encoder.in_chans

        pixel_values = rearrange(
            patchified_pixel_values,
            "b (t h w) (s p q c) -> b c (t s) (h p) (w q)",
            c=num_channels,
            h=num_patches_h,
            w=num_patches_w,
            s=patch_size_t,
            p=patch_size_h,
            q=patch_size_w,
        )
        return pixel_values

    def forward_loss(
        self, pixel_values: torch.Tensor, pred: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """forward_loss.

        Args:
            pixel_values: (`torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`):
                Pixel values.
            pred: (`torch.FloatTensor` of shape
                `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
                Predicted pixel values.
            mask: (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
                Tensor indicating which patches are masked (1) and which are not (0).

        Returns:
            `torch.FloatTensor`: Pixel reconstruction loss.
        """
        target = self.patchify(pixel_values)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.0e-6) ** 0.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward(
        self,
        pixel_values: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
        mask_ratio: float | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward."""
        if len(pixel_values.shape) == 4 and self.encoder.patch_embed.input_size[0] == 1:
            # add time dim
            pixel_values = pixel_values.unsqueeze(2)

        mask_ratio = mask_ratio or self.mask_ratio
        latent, mask, ids_restore = self.encoder(
            pixel_values, temporal_coords, location_coords, mask_ratio
        )
        pred = self.decoder(
            latent,
            ids_restore,
            temporal_coords,
            location_coords,
            input_size=pixel_values.shape,
        )
        loss = self.forward_loss(pixel_values, pred, mask)
        return loss, pred, mask

    def forward_features(
        self,
        x: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
    ) -> list[torch.Tensor]:
        """forward_features."""
        return self.encoder.forward_features(x, temporal_coords, location_coords)