rslearn 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (42)
  1. rslearn/dataset/handler_summaries.py +130 -0
  2. rslearn/dataset/manage.py +157 -22
  3. rslearn/main.py +60 -8
  4. rslearn/models/anysat.py +207 -0
  5. rslearn/models/clay/clay.py +219 -0
  6. rslearn/models/clay/configs/metadata.yaml +295 -0
  7. rslearn/models/copernicusfm.py +37 -25
  8. rslearn/models/dinov3.py +165 -0
  9. rslearn/models/galileo/__init__.py +5 -0
  10. rslearn/models/galileo/galileo.py +517 -0
  11. rslearn/models/galileo/single_file_galileo.py +1672 -0
  12. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  13. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  14. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  15. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  16. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  17. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  18. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  19. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  20. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  21. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  22. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  23. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  24. rslearn/models/presto/presto.py +10 -7
  25. rslearn/models/prithvi.py +1122 -0
  26. rslearn/models/resize_features.py +45 -0
  27. rslearn/models/simple_time_series.py +65 -10
  28. rslearn/models/unet.py +17 -11
  29. rslearn/models/upsample.py +2 -2
  30. rslearn/tile_stores/default.py +31 -6
  31. rslearn/train/transforms/normalize.py +34 -5
  32. rslearn/train/transforms/select_bands.py +67 -0
  33. rslearn/train/transforms/sentinel1.py +60 -0
  34. rslearn/utils/geometry.py +61 -1
  35. rslearn/utils/raster_format.py +7 -1
  36. rslearn/utils/vector_format.py +13 -10
  37. {rslearn-0.0.6.dist-info → rslearn-0.0.8.dist-info}/METADATA +144 -15
  38. {rslearn-0.0.6.dist-info → rslearn-0.0.8.dist-info}/RECORD +42 -18
  39. {rslearn-0.0.6.dist-info → rslearn-0.0.8.dist-info}/WHEEL +0 -0
  40. {rslearn-0.0.6.dist-info → rslearn-0.0.8.dist-info}/entry_points.txt +0 -0
  41. {rslearn-0.0.6.dist-info → rslearn-0.0.8.dist-info}/licenses/LICENSE +0 -0
  42. {rslearn-0.0.6.dist-info → rslearn-0.0.8.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,517 @@
+"""Galileo models."""
+
+import math
+import tempfile
+from enum import StrEnum
+from typing import Any, cast
+
+import numpy as np
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+from huggingface_hub import hf_hub_download
+from upath import UPath
+
+from rslearn.log_utils import get_logger
+from rslearn.models.galileo.single_file_galileo import (
+    CONFIG_FILENAME,
+    DW_BANDS,
+    ENCODER_FILENAME,
+    ERA5_BANDS,
+    LANDSCAN_BANDS,
+    LOCATION_BANDS,
+    S1_BANDS,
+    S2_BANDS,
+    SPACE_BAND_GROUPS_IDX,
+    SPACE_BANDS,
+    SPACE_TIME_BANDS,
+    SPACE_TIME_BANDS_GROUPS_IDX,
+    SRTM_BANDS,
+    STATIC_BAND_GROUPS_IDX,
+    STATIC_BANDS,
+    TC_BANDS,
+    TIME_BAND_GROUPS_IDX,
+    TIME_BANDS,
+    VIIRS_BANDS,
+    WC_BANDS,
+    Encoder,
+    MaskedOutput,
+    Normalizer,
+)
+
+logger = get_logger(__name__)
+
+
+HF_HUB_ID = "nasaharvest/galileo"
+DEFAULT_MONTH = 5
+
+
+# Galileo provides three sizes: nano, tiny, base
+class GalileoSize(StrEnum):
+    """Size of the Galileo model."""
+
+    NANO = "nano"
+    TINY = "tiny"
+    BASE = "base"
+
+
+pretrained_weights: dict[GalileoSize, str] = {
+    GalileoSize.NANO: "models/nano",
+    GalileoSize.TINY: "models/tiny",
+    GalileoSize.BASE: "models/base",
+}
+
+DEFAULT_NORMALIZER = Normalizer()
+
+
+class GalileoModel(nn.Module):
+    """Galileo backbones."""
+
+    input_keys = [
+        "s1",
+        "s2",
+        "era5",
+        "tc",
+        "viirs",
+        "srtm",
+        "dw",
+        "wc",
+        "landscan",
+        "latlon",
+    ]
+
+    def __init__(
+        self,
+        size: GalileoSize,
+        patch_size: int = 4,
+        pretrained_path: str | UPath | None = None,
+    ) -> None:
+        """Initialize the Galileo model.
+
+        Args:
+            size: The size of the Galileo model.
+            patch_size: The patch size to use.
+            pretrained_path: the local path to the pretrained weights. If None, the
+                weights are downloaded and cached in a temporary directory.
+        """
+        super().__init__()
+        if pretrained_path is None:
+            pretrained_path = UPath(tempfile.gettempdir(), "rslearn_cache", "galileo")
+
+        pretrained_path_for_size = UPath(pretrained_path) / pretrained_weights[size]
+        if not (pretrained_path_for_size / CONFIG_FILENAME).exists():
+            _ = hf_hub_download(
+                local_dir=pretrained_path,
+                repo_id=HF_HUB_ID,
+                filename=f"{pretrained_weights[size]}/{CONFIG_FILENAME}",
+                revision="f039dd5dde966a931baeda47eb680fa89b253e4e",
+            )
+        if not (pretrained_path_for_size / ENCODER_FILENAME).exists():
+            _ = hf_hub_download(
+                local_dir=pretrained_path,
+                repo_id=HF_HUB_ID,
+                filename=f"{pretrained_weights[size]}/{ENCODER_FILENAME}",
+                revision="f039dd5dde966a931baeda47eb680fa89b253e4e",
+            )
+
+        assert (pretrained_path_for_size / ENCODER_FILENAME).exists()
+        assert (pretrained_path_for_size / CONFIG_FILENAME).exists()
+
+        self.model = Encoder.load_from_folder(
+            pretrained_path_for_size, device=torch.device("cpu")
+        )
+
+        self.s_t_channels_s2 = [
+            idx for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX) if "S2" in key
+        ]
+        self.s_t_channels_s1 = [
+            idx for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX) if "S1" in key
+        ]
+
+        self.patch_size = patch_size
+
+    @staticmethod
+    def to_cartesian(
+        lat: float | np.ndarray | torch.Tensor, lon: float | np.ndarray | torch.Tensor
+    ) -> np.ndarray | torch.Tensor:
+        """Transform latitudes and longitudes to cartesian coordinates."""
+        if isinstance(lat, float):
+            assert -90 <= lat <= 90, (
+                f"lat out of range ({lat}). Make sure you are in EPSG:4326"
+            )
+            assert -180 <= lon <= 180, (
+                f"lon out of range ({lon}). Make sure you are in EPSG:4326"
+            )
+            assert isinstance(lon, float), f"Expected float got {type(lon)}"
+            # transform to radians
+            lat = lat * math.pi / 180
+            lon = lon * math.pi / 180
+            x = math.cos(lat) * math.cos(lon)
+            y = math.cos(lat) * math.sin(lon)
+            z = math.sin(lat)
+            return np.array([x, y, z])
+        elif isinstance(lon, np.ndarray):
+            assert -90 <= lat.min(), (
+                f"lat out of range ({lat.min()}). Make sure you are in EPSG:4326"
+            )
+            assert 90 >= lat.max(), (
+                f"lat out of range ({lat.max()}). Make sure you are in EPSG:4326"
+            )
+            assert -180 <= lon.min(), (
+                f"lon out of range ({lon.min()}). Make sure you are in EPSG:4326"
+            )
+            assert 180 >= lon.max(), (
+                f"lon out of range ({lon.max()}). Make sure you are in EPSG:4326"
+            )
+            assert isinstance(lat, np.ndarray), f"Expected np.ndarray got {type(lat)}"
+            # transform to radians
+            lat = lat * math.pi / 180
+            lon = lon * math.pi / 180
+            x_np = np.cos(lat) * np.cos(lon)
+            y_np = np.cos(lat) * np.sin(lon)
+            z_np = np.sin(lat)
+            return np.stack([x_np, y_np, z_np], axis=-1)
+        elif isinstance(lon, torch.Tensor):
+            assert -90 <= lat.min(), (
+                f"lat out of range ({lat.min()}). Make sure you are in EPSG:4326"
+            )
+            assert 90 >= lat.max(), (
+                f"lat out of range ({lat.max()}). Make sure you are in EPSG:4326"
+            )
+            assert -180 <= lon.min(), (
+                f"lon out of range ({lon.min()}). Make sure you are in EPSG:4326"
+            )
+            assert 180 >= lon.max(), (
+                f"lon out of range ({lon.max()}). Make sure you are in EPSG:4326"
+            )
+            assert isinstance(lat, torch.Tensor), (
+                f"Expected torch.Tensor got {type(lat)}"
+            )
+            # transform to radians
+            lat = lat * math.pi / 180
+            lon = lon * math.pi / 180
+            x_t = torch.cos(lat) * torch.cos(lon)
+            y_t = torch.cos(lat) * torch.sin(lon)
+            z_t = torch.sin(lat)
+            return torch.stack([x_t, y_t, z_t], dim=-1)
+        else:
+            raise AssertionError(f"Unexpected input type {type(lon)}")
+
+    @classmethod
+    def construct_galileo_input(
+        cls,
+        s1: torch.Tensor | None = None,  # [H, W, T, D]
+        s2: torch.Tensor | None = None,  # [H, W, T, D]
+        era5: torch.Tensor | None = None,  # [T, D]
+        tc: torch.Tensor | None = None,  # [T, D]
+        viirs: torch.Tensor | None = None,  # [T, D]
+        srtm: torch.Tensor | None = None,  # [H, W, D]
+        dw: torch.Tensor | None = None,  # [H, W, D]
+        wc: torch.Tensor | None = None,  # [H, W, D]
+        landscan: torch.Tensor | None = None,  # [D]
+        latlon: torch.Tensor | None = None,  # [D]
+        months: torch.Tensor | None = None,  # [T]
+        normalize: bool = False,
+    ) -> MaskedOutput:
+        """Construct a Galileo input."""
+        space_time_inputs = [s1, s2]
+        time_inputs = [era5, tc, viirs]
+        space_inputs = [srtm, dw, wc]
+        static_inputs = [landscan, latlon]
+        devices = [
+            x.device
+            for x in space_time_inputs + time_inputs + space_inputs + static_inputs
+            if x is not None
+        ]
+
+        if len(devices) == 0:
+            raise ValueError("At least one input must be not None")
+        if not all(devices[0] == device for device in devices):
+            raise ValueError("Received tensors on multiple devices")
+        device = devices[0]
+
+        # first, check all the input shapes are consistent
+        batch_list = (
+            [x.shape[0] for x in space_time_inputs if x is not None]
+            + [x.shape[0] for x in time_inputs if x is not None]
+            + [x.shape[0] for x in space_inputs if x is not None]
+            + [x.shape[0] for x in static_inputs if x is not None]
+        )
+        timesteps_list = [x.shape[3] for x in space_time_inputs if x is not None] + [
+            x.shape[1] for x in time_inputs if x is not None
+        ]
+        height_list = [x.shape[1] for x in space_time_inputs if x is not None] + [
+            x.shape[1] for x in space_inputs if x is not None
+        ]
+        width_list = [x.shape[2] for x in space_time_inputs if x is not None] + [
+            x.shape[2] for x in space_inputs if x is not None
+        ]
+        if len(batch_list) > 0:
+            if len(set(batch_list)) > 1:
+                raise ValueError("Inconsistent number of batch sizes per input")
+            b = batch_list[0]
+
+        if len(timesteps_list) > 0:
+            if not all(timesteps_list[0] == timestep for timestep in timesteps_list):
+                raise ValueError("Inconsistent number of timesteps per input")
+            t = timesteps_list[0]
+        else:
+            t = 1
+        if len(height_list) > 0:
+            if not all(height_list[0] == height for height in height_list):
+                raise ValueError("Inconsistent heights per input")
+            if not all(width_list[0] == width for width in width_list):
+                raise ValueError("Inconsistent widths per input")
+            h = height_list[0]
+            w = width_list[0]
+        else:
+            h, w = 1, 1
+
+        # now, we can construct our empty input tensors. By default, everything is masked
+        s_t_x = torch.zeros(
+            (b, h, w, t, len(SPACE_TIME_BANDS)), dtype=torch.float, device=device
+        )
+        s_t_m = torch.ones(
+            (b, h, w, t, len(SPACE_TIME_BANDS_GROUPS_IDX)),
+            dtype=torch.float,
+            device=device,
+        )
+        sp_x = torch.zeros(
+            (b, h, w, len(SPACE_BANDS)), dtype=torch.float, device=device
+        )
+        sp_m = torch.ones(
+            (b, h, w, len(SPACE_BAND_GROUPS_IDX)), dtype=torch.float, device=device
+        )
+        t_x = torch.zeros((b, t, len(TIME_BANDS)), dtype=torch.float, device=device)
+        t_m = torch.ones(
+            (b, t, len(TIME_BAND_GROUPS_IDX)), dtype=torch.float, device=device
+        )
+        st_x = torch.zeros((b, len(STATIC_BANDS)), dtype=torch.float, device=device)
+        st_m = torch.ones(
+            (b, len(STATIC_BAND_GROUPS_IDX)), dtype=torch.float, device=device
+        )
+
+        for x, bands_list, group_key in zip(
+            [s1, s2], [S1_BANDS, S2_BANDS], ["S1", "S2"]
+        ):
+            if x is not None:
+                indices = [
+                    idx for idx, val in enumerate(SPACE_TIME_BANDS) if val in bands_list
+                ]
+                groups_idx = [
+                    idx
+                    for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX)
+                    if group_key in key
+                ]
+                s_t_x[:, :, :, :, indices] = x
+                s_t_m[:, :, :, :, groups_idx] = 0
+
+        for x, bands_list, group_key in zip(
+            [srtm, dw, wc], [SRTM_BANDS, DW_BANDS, WC_BANDS], ["SRTM", "DW", "WC"]
+        ):
+            if x is not None:
+                indices = [
+                    idx for idx, val in enumerate(SPACE_BANDS) if val in bands_list
+                ]
+                groups_idx = [
+                    idx
+                    for idx, key in enumerate(SPACE_BAND_GROUPS_IDX)
+                    if group_key in key
+                ]
+                sp_x[:, :, :, indices] = x
+                sp_m[:, :, :, groups_idx] = 0
+
+        for x, bands_list, group_key in zip(
+            [era5, tc, viirs],
+            [ERA5_BANDS, TC_BANDS, VIIRS_BANDS],
+            ["ERA5", "TC", "VIIRS"],
+        ):
+            if x is not None:
+                indices = [
+                    idx for idx, val in enumerate(TIME_BANDS) if val in bands_list
+                ]
+                groups_idx = [
+                    idx
+                    for idx, key in enumerate(TIME_BAND_GROUPS_IDX)
+                    if group_key in key
+                ]
+                t_x[:, :, indices] = x
+                t_m[:, :, groups_idx] = 0
+
+        for x, bands_list, group_key in zip(
+            [landscan, latlon], [LANDSCAN_BANDS, LOCATION_BANDS], ["LS", "location"]
+        ):
+            if x is not None:
+                if group_key == "location":
+                    # transform latlon to cartesian
+                    x = cast(torch.Tensor, cls.to_cartesian(x[:, 0], x[:, 1]))
+                indices = [
+                    idx for idx, val in enumerate(STATIC_BANDS) if val in bands_list
+                ]
+                groups_idx = [
+                    idx
+                    for idx, key in enumerate(STATIC_BAND_GROUPS_IDX)
+                    if group_key in key
+                ]
+                st_x[:, indices] = x
+                st_m[:, groups_idx] = 0
+
+        if months is None:
+            months = torch.ones((b, t), dtype=torch.long, device=device) * DEFAULT_MONTH
+        else:
+            if months.shape[1] != t:
+                raise ValueError("Incorrect number of input months")
+
+        if normalize:
+            s_t_x = (
+                torch.from_numpy(DEFAULT_NORMALIZER(s_t_x.cpu().numpy()))
+                .to(device)
+                .float()
+            )
+            sp_x = (
+                torch.from_numpy(DEFAULT_NORMALIZER(sp_x.cpu().numpy()))
+                .to(device)
+                .float()
+            )
+            t_x = (
+                torch.from_numpy(DEFAULT_NORMALIZER(t_x.cpu().numpy()))
+                .to(device)
+                .float()
+            )
+            st_x = (
+                torch.from_numpy(DEFAULT_NORMALIZER(st_x.cpu().numpy()))
+                .to(device)
+                .float()
+            )
+
+        return MaskedOutput(
+            s_t_x=s_t_x,
+            s_t_m=s_t_m,
+            sp_x=sp_x,
+            sp_m=sp_m,
+            t_x=t_x,
+            t_m=t_m,
+            st_x=st_x,
+            st_m=st_m,
+            months=months,
+        )
+
+    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+        """Compute feature maps from the Galileo backbone.
+
+        Args:
+            inputs: a list of per-sample dicts, where the keys are one of
+                GalileoModel.input_keys (also documented below) and the values, once
+                stacked along the batch dimension, are tensors of the following
+                shapes, per input key:
+                "s1": B (T * C) H W
+                "s2": B (T * C) H W
+                "era5": B (T * C) H W (we will average over the H, W dimensions)
+                "tc": B (T * C) H W (we will average over the H, W dimensions)
+                "viirs": B (T * C) H W (we will average over the H, W dimensions)
+                "srtm": B C H W (SRTM has no temporal dimension)
+                "dw": B C H W (Dynamic World should be averaged over time)
+                "wc": B C H W (WorldCereal has no temporal dimension)
+                "landscan": B C H W (we will average over the H, W dimensions)
+                "latlon": B C H W (we will average over the H, W dimensions)
+
+        The output will be an embedding representing the pooled tokens. If there is
+        only a single token per h/w dimension (i.e. patch_size == h, w), then we will
+        take a pool of all the unmasked tokens.
+
+        If there are many spatial tokens per h/w dimension (patch_size < h, w), then we
+        will take a pool of the space_time unmasked tokens (i.e. of the s1 and s2 tokens).
+        """
+        stacked_inputs = {}
+        for key in inputs[0].keys():
+            # assume all the keys in an input are consistent
+            if key in self.input_keys:
+                stacked_inputs[key] = torch.stack([inp[key] for inp in inputs], dim=0)
+        s_t_channels = []
+        for space_time_modality in ["s1", "s2"]:
+            if space_time_modality not in stacked_inputs:
+                continue
+            if space_time_modality == "s1":
+                s_t_channels += self.s_t_channels_s1
+            else:
+                s_t_channels += self.s_t_channels_s2
+            cur = stacked_inputs[space_time_modality]
+            # Check if it's single or multitemporal, and reshape accordingly
+            num_bands = len(S2_BANDS) if space_time_modality == "s2" else len(S1_BANDS)
+            num_timesteps = cur.shape[1] // num_bands
+            cur = rearrange(cur, "b (t c) h w -> b h w t c", t=num_timesteps)
+            stacked_inputs[space_time_modality] = cur
+
+        for space_modality in ["srtm", "dw", "wc"]:
+            if space_modality not in stacked_inputs:
+                continue
+            stacked_inputs[space_modality] = rearrange(
+                stacked_inputs[space_modality], "b c h w -> b h w c"
+            )
+
+        for time_modality in ["era5", "tc", "viirs"]:
+            if time_modality not in stacked_inputs:
+                continue
+            cur = stacked_inputs[time_modality]
+            # Check if it's single or multitemporal, and reshape accordingly
+            num_bands = {
+                "era5": len(ERA5_BANDS),
+                "tc": len(TC_BANDS),
+                "viirs": len(VIIRS_BANDS),
+            }[time_modality]
+            num_timesteps = cur.shape[1] // num_bands
+            # take the average over the h, w bands since Galileo
+            # treats it as a pixel-timeseries
+            cur = rearrange(
+                torch.nanmean(torch.nanmean(cur, dim=-1), dim=-1),
+                "b (t c) -> b t c",
+                t=num_timesteps,
+            )
+            stacked_inputs[time_modality] = cur
+
+        for static_modality in ["landscan", "latlon"]:
+            if static_modality not in stacked_inputs:
+                continue
+            cur = stacked_inputs[static_modality]
+            stacked_inputs[static_modality] = torch.nanmean(
+                torch.nanmean(cur, dim=-1), dim=-1
+            )
+        galileo_input = self.construct_galileo_input(**stacked_inputs, normalize=True)
+        h = galileo_input.s_t_x.shape[1]
+        if h < self.patch_size:
+            logger.warning(
+                f"Given patch size {self.patch_size} > input size {h}. Reducing patch size to {h}"
+            )
+            patch_size = h
+        else:
+            patch_size = self.patch_size
+        outputs = self.model(
+            s_t_x=galileo_input.s_t_x,
+            s_t_m=galileo_input.s_t_m,
+            sp_x=galileo_input.sp_x,
+            sp_m=galileo_input.sp_m,
+            t_x=galileo_input.t_x,
+            t_m=galileo_input.t_m,
+            st_x=galileo_input.st_x,
+            st_m=galileo_input.st_m,
+            months=galileo_input.months,
+            patch_size=patch_size,
+        )
+        if h == patch_size:
+            # only one spatial patch, so we can just take an average
+            # of all the tokens to output b c_g 1 1
+            s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m, _ = outputs
+            averaged = self.model.average_tokens(
+                s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m
+            )
+            return [repeat(averaged, "b d -> b d 1 1")]
+        else:
+            s_t_x = outputs[0]
+            # we will be assuming we only want s_t_x, and (for now) that we want s1 or s2 bands
+            # s_t_x has shape [b, h, w, t, c_g, d]
+            # and we want [b, d, h, w]
+            return [
+                rearrange(
+                    s_t_x[:, :, :, :, s_t_channels, :].mean(dim=3),
+                    "b h w c_g d -> b c_g d h w",
+                ).mean(dim=1)
+            ]
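
For orientation, a minimal usage sketch of the new GalileoModel backbone follows. It is not part of the diff above: the sample count, image size, and two-timestep Sentinel-2 input are made-up illustration values, and the expected output shape is inferred from the forward() logic shown in this file.

import torch

from rslearn.models.galileo.galileo import GalileoModel, GalileoSize
from rslearn.models.galileo.single_file_galileo import S2_BANDS

# Weights are fetched from the nasaharvest/galileo Hugging Face repo on first use.
model = GalileoModel(size=GalileoSize.NANO, patch_size=4)
model.eval()

# Two samples, each a dict keyed by entries of GalileoModel.input_keys.
# Per-sample "s2" tensors are (T * C) x H x W; forward() stacks them into a batch.
num_timesteps, height, width = 2, 16, 16
inputs = [
    {"s2": torch.randn(num_timesteps * len(S2_BANDS), height, width)}
    for _ in range(2)
]

with torch.no_grad():
    features = model(inputs)

# With patch_size < H, W, the result should be a single feature map over the
# token grid, i.e. roughly [B, D, H / patch_size, W / patch_size].
print(features[0].shape)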