rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,297 @@
|
|
|
1
|
+
"""Presto wrapper to ingest Masked Helios Samples."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import tempfile
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from einops import rearrange, repeat
|
|
9
|
+
from huggingface_hub import hf_hub_download
|
|
10
|
+
from upath import UPath
|
|
11
|
+
|
|
12
|
+
from rslearn.models.component import FeatureExtractor, FeatureMaps
|
|
13
|
+
from rslearn.models.presto.single_file_presto import (
|
|
14
|
+
ERA5_BANDS,
|
|
15
|
+
NUM_DYNAMIC_WORLD_CLASSES,
|
|
16
|
+
PRESTO_ADD_BY,
|
|
17
|
+
PRESTO_BANDS,
|
|
18
|
+
PRESTO_DIV_BY,
|
|
19
|
+
PRESTO_S1_BANDS,
|
|
20
|
+
PRESTO_S2_BANDS,
|
|
21
|
+
SRTM_BANDS,
|
|
22
|
+
)
|
|
23
|
+
from rslearn.models.presto.single_file_presto import Presto as SFPresto
|
|
24
|
+
from rslearn.train.model_context import ModelContext
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)

# Band lists for the inputs this wrapper ingests. B09 is dropped from both the
# full Presto band list and the Sentinel-2-only list (presumably because the
# surrounding data pipeline does not supply B09 — confirm against the dataset
# configuration that feeds this model).
INPUT_PRESTO_BANDS = [b for b in PRESTO_BANDS if b != "B09"]
INPUT_PRESTO_S2_BANDS = [b for b in PRESTO_S2_BANDS if b != "B09"]

# Per-sensor normalization offsets/divisors (value is first shifted by the
# subtract value, then divided). NOTE(review): these constants are not
# referenced elsewhere in this file; they appear to exist for callers that
# normalize inputs before invoking Presto — verify before removing.
PRESTO_S1_SUBTRACT_VALUE = -25.0
PRESTO_S1_DIV_VALUE = 25.0
PRESTO_S2_SUBTRACT_VALUE = 0.0
PRESTO_S2_DIV_VALUE = 1e4

# Location of the pretrained Presto checkpoint on the Hugging Face Hub.
HF_HUB_ID = "nasaharvest/presto"
MODEL_FILENAME = "default_model.pt"
class Presto(FeatureExtractor):
    """Feature extractor wrapping the pretrained Presto pixel-timeseries encoder.

    Presto operates on individual pixel timeseries, so spatial inputs are
    flattened to a batch of pixels, encoded in chunks, and reassembled into a
    single feature map at the input resolution.
    """

    # Input-dict keys this extractor consumes from the model context; any
    # other keys in the inputs are ignored.
    input_keys = [
        "s1",
        "s2",
        "era5",
        "srtm",
        "dynamic_world",
        "latlon",
    ]

    def __init__(
        self,
        pretrained_path: str | UPath | None = None,
        pixel_batch_size: int = 128,
    ) -> None:
        """Initialize the Presto wrapper.

        Downloads the pretrained checkpoint from the Hugging Face Hub (pinned
        to a specific revision) if it is not already present at
        pretrained_path, then loads it and keeps only the encoder.

        Args:
            pretrained_path: The directory to load the checkpoint from (and
                download into if missing). Defaults to a per-machine cache
                directory under the system temp dir.
            pixel_batch_size: If the input has a h,w dimension >1, this is
                flattened into a batch dimension (b h w) before being passed
                to the model (since Presto is designed for pixel timeseries);
                pixels are then encoded in chunks of this size.
        """
        super().__init__()

        if pretrained_path is None:
            pretrained_path = UPath(tempfile.gettempdir(), "rslearn_cache", "presto")
        if not (UPath(pretrained_path) / MODEL_FILENAME).exists():
            _ = hf_hub_download(
                local_dir=UPath(pretrained_path),
                repo_id=HF_HUB_ID,
                filename=MODEL_FILENAME,
                # pin the model to a specific hugging face commit
                revision="1b97f885969da4e2d5834ca8c92707c737911464",
            )

        model = SFPresto.construct()
        model.load_state_dict(
            torch.load(
                UPath(pretrained_path) / MODEL_FILENAME,
                map_location="cpu",
                # weights_only avoids arbitrary-code execution on load
                weights_only=True,
            )
        )
        self.pixel_batch_size = pixel_batch_size
        # Only the encoder is used for feature extraction.
        self.model = model.encoder
        self.month = 6  # default month (0-indexed) used when no timestamps are given

    def construct_presto_input(
        self,
        s1: torch.Tensor | None = None,
        s1_bands: list[str] | None = None,
        s2: torch.Tensor | None = None,
        s2_bands: list[str] | None = None,
        era5: torch.Tensor | None = None,
        era5_bands: list[str] | None = None,
        srtm: torch.Tensor | None = None,
        srtm_bands: list[str] | None = None,
        dynamic_world: torch.Tensor | None = None,
        months: torch.Tensor | None = None,
        normalize: bool = True,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Inputs are paired into a tensor input <X> and a list <X>_bands, which describes <X>.

        <X> should have shape (b, len(<X>_bands), num_timesteps, h, w), with the
        following bands for each input:

        s1: ["VV", "VH"]
        s2: ["B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B11", "B12"]
        era5: ["temperature_2m", "total_precipitation"]
            "temperature_2m": Temperature of air at 2m above the surface of land,
                sea or in-land waters in Kelvin (K)
            "total_precipitation": Accumulated liquid and frozen water, including rain and snow,
                that falls to the Earth's surface. Measured in metres (m)
        srtm: ["elevation", "slope"]

        dynamic_world is an input of shape (b, num_timesteps, h, w) representing
        the dynamic world classes of each timestep for each pixel. If None, it is
        filled with NUM_DYNAMIC_WORLD_CLASSES (presumably Presto's
        "missing/unknown" class index — confirm against single_file_presto).

        Returns:
            a tuple (x, mask, dynamic_world, months) where x is
            (b, t, h, w, len(INPUT_PRESTO_BANDS)), mask is the same shape with
            1 marking bands that were NOT provided, and dynamic_world/months
            are long tensors.
        """
        # Collect per-input dims so we can check that all provided modalities
        # agree on batch size, timesteps, spatial size, and device.
        bs = [x.shape[0] for x in [s1, s2, era5, srtm] if x is not None]
        ts = [x.shape[2] for x in [s1, s2, era5, srtm] if x is not None]
        hs = [x.shape[3] for x in [s1, s2, era5, srtm] if x is not None]
        ws = [x.shape[4] for x in [s1, s2, era5, srtm] if x is not None]
        devices = [x.device for x in [s1, s2, era5, srtm] if x is not None]

        assert len(set(bs)) == 1
        assert len(set(hs)) == 1
        assert len(set(ws)) == 1
        assert len(set(devices)) == 1
        assert len(set(ts)) == 1
        b, h, w, t, device = bs[0], hs[0], ws[0], ts[0], devices[0]
        # these values will be initialized as
        # we iterate through the data
        x: torch.Tensor | None = None
        mask: torch.Tensor | None = None

        for band_group in [
            (s1, s1_bands),
            (s2, s2_bands),
            (era5, era5_bands),
            (srtm, srtm_bands),
        ]:
            data, input_bands = band_group
            if data is not None:
                assert input_bands is not None
            else:
                continue

            # Channels-first -> channels-last to match the band-indexed layout
            # of x below.
            data = rearrange(data, "b c t h w -> b t h w c")
            if x is None:
                x = torch.zeros(b, t, h, w, len(INPUT_PRESTO_BANDS), device=device)
            if mask is None:
                # mask starts all-ones ("missing") and is zeroed where data is
                # actually written.
                mask = torch.ones(b, t, h, w, len(INPUT_PRESTO_BANDS), device=device)

            # construct a mapping from the input bands to the presto input bands
            input_to_output_mapping = [
                INPUT_PRESTO_BANDS.index(val) for val in input_bands
            ]
            x[:, :, :, :, input_to_output_mapping] = data
            mask[:, :, :, :, input_to_output_mapping] = 0

        # At least one modality must have been provided.
        assert x is not None
        assert mask is not None
        assert t is not None

        if dynamic_world is None:
            dynamic_world = (
                torch.ones(b, t, h, w, device=device) * NUM_DYNAMIC_WORLD_CLASSES
            )

        if months is None:
            # Fall back to the configured default month for every timestep.
            months = torch.ones((b, t), device=device) * self.month
        else:
            assert months.shape[-1] == t

        if normalize:
            x = (x + PRESTO_ADD_BY.to(device=device)) / PRESTO_DIV_BY.to(device=device)
        return x, mask, dynamic_world.long(), months.long()

    @staticmethod
    def time_ranges_to_timestamps(
        time_ranges: list[tuple[datetime, datetime]],
        device: torch.device,
    ) -> torch.Tensor:
        """Turn the time ranges stored in a RasterImage to timestamps accepted by Presto.

        Presto only uses the month associated with each timestamp, so we take the
        midpoint of the time range. For some inputs (e.g. Sentinel 2) we take an
        image from a specific time so that start_time == end_time == mid_time.

        Args:
            time_ranges: (start, end) datetime pairs, one per timestep.
            device: device for the returned tensor.

        Returns:
            an int32 tensor of 0-indexed months, one per time range.
        """
        mid_ranges = [t[0] + ((t[1] - t[0]) / 2) for t in time_ranges]
        # months are indexed 0-11
        return torch.tensor(
            [d.month - 1 for d in mid_ranges], dtype=torch.int32, device=device
        )

    def forward(self, context: ModelContext) -> FeatureMaps:
        """Compute feature maps from the Presto backbone.

        Args:
            context: the model context. Input dicts should have some subset of Presto.input_keys.

        Returns:
            a FeatureMaps with one feature map that is at the same resolution as the
            input (since Presto operates per-pixel).
        """
        # Modalities that carry timestamps from which months can be derived.
        time_modalities = ["s1", "s2", "era5"]
        stacked_inputs = {}
        latlons: torch.Tensor | None = None
        months: torch.Tensor | None = None
        for key in context.inputs[0].keys():
            # assume all the keys in an input are consistent
            if key in self.input_keys:
                if key == "latlon":
                    latlons = torch.stack(
                        [inp[key].image for inp in context.inputs], dim=0
                    )
                else:
                    # Stack this modality across the batch of inputs.
                    stacked_inputs[key] = torch.stack(
                        [inp[key].image for inp in context.inputs], dim=0
                    )
                    if key in time_modalities:
                        # Derive months from the first time modality that has
                        # timestamps; later modalities do not overwrite it.
                        if months is None:
                            if context.inputs[0][key].timestamps is not None:
                                months = torch.stack(
                                    [
                                        self.time_ranges_to_timestamps(
                                            inp[key].timestamps,  # type: ignore
                                            device=stacked_inputs[key].device,
                                        )
                                        for inp in context.inputs
                                    ],
                                    dim=0,
                                )
        if months is not None:
            stacked_inputs["months"] = months

        (
            x,
            mask,
            dynamic_world,
            months,
        ) = self.construct_presto_input(
            **stacked_inputs,
            s1_bands=PRESTO_S1_BANDS,
            s2_bands=INPUT_PRESTO_S2_BANDS,
            era5_bands=ERA5_BANDS,
            srtm_bands=SRTM_BANDS,
            normalize=True,
        )
        b, _, h, w, _ = x.shape

        # Preallocate the per-pixel embeddings; filled chunk by chunk below.
        output_features = torch.zeros(
            b * h * w, self.model.embedding_size, device=x.device
        )

        # Flatten spatial dims into the batch: Presto encodes pixel timeseries.
        x = rearrange(x, "b t h w d -> (b h w) t d")
        mask = rearrange(mask, "b t h w d -> (b h w) t d")
        dynamic_world = rearrange(dynamic_world, "b t h w -> (b h w) t")
        # months is per-(b, t); replicate it for every pixel.
        months = repeat(months, "b t -> (b h w) t", h=h, w=w)
        if latlons is not None:
            latlons = rearrange(latlons, "b c h w -> (b h w) c")

        # Encode pixels in chunks of pixel_batch_size to bound memory use.
        for batch_idx in range(0, b * h * w, self.pixel_batch_size):
            x_b = x[batch_idx : batch_idx + self.pixel_batch_size]
            mask_b = mask[batch_idx : batch_idx + self.pixel_batch_size]
            dw = dynamic_world[batch_idx : batch_idx + self.pixel_batch_size]
            months_b = months[batch_idx : batch_idx + self.pixel_batch_size]
            if latlons is not None:
                l_b = latlons[batch_idx : batch_idx + self.pixel_batch_size]
            else:
                l_b = None
            output_b = self.model(
                x=x_b,
                dynamic_world=dw,
                mask=mask_b,
                month=months_b,
                latlons=l_b,
                eval_task=True,
            )
            output_features[batch_idx : batch_idx + self.pixel_batch_size] = output_b

        # Reassemble the flat pixel embeddings into a (b, d, h, w) feature map.
        return FeatureMaps(
            [rearrange(output_features, "(b h w) d -> b d h w", h=h, w=w, b=b)]
        )

    def get_backbone_channels(self) -> list:
        """Returns the output channels of this model when used as a backbone.

        The output channels is a list of (patch_size, depth) that corresponds
        to the feature maps that the backbone returns.

        Returns:
            the output channels of the backbone as a list of (patch_size, depth) tuples.
        """
        # Patch size 1 (per-pixel features); depth 128 — presumably matching
        # the encoder's embedding_size, confirm against single_file_presto.
        return [(1, 128)]