rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/models/galileo/galileo.py
@@ -0,0 +1,595 @@
+"""Galileo models."""
+
+import math
+import tempfile
+from contextlib import nullcontext
+from datetime import datetime
+from enum import StrEnum
+from typing import cast
+
+import numpy as np
+import torch
+from einops import rearrange, repeat
+from huggingface_hub import hf_hub_download
+from upath import UPath
+
+from rslearn.log_utils import get_logger
+from rslearn.models.component import FeatureExtractor, FeatureMaps
+from rslearn.models.galileo.single_file_galileo import (
+    CONFIG_FILENAME,
+    DW_BANDS,
+    ENCODER_FILENAME,
+    ERA5_BANDS,
+    LANDSCAN_BANDS,
+    LOCATION_BANDS,
+    S1_BANDS,
+    S2_BANDS,
+    SPACE_BAND_GROUPS_IDX,
+    SPACE_BANDS,
+    SPACE_TIME_BANDS,
+    SPACE_TIME_BANDS_GROUPS_IDX,
+    SRTM_BANDS,
+    STATIC_BAND_GROUPS_IDX,
+    STATIC_BANDS,
+    TC_BANDS,
+    TIME_BAND_GROUPS_IDX,
+    TIME_BANDS,
+    VIIRS_BANDS,
+    WC_BANDS,
+    Encoder,
+    MaskedOutput,
+    Normalizer,
+)
+from rslearn.train.model_context import ModelContext
+
+logger = get_logger(__name__)
+
+
+HF_HUB_ID = "nasaharvest/galileo"
+DEFAULT_MONTH = 5
+
+
+# Galileo provides three sizes: nano, tiny, base
+class GalileoSize(StrEnum):
+    """Size of the Galileo model."""
+
+    NANO = "nano"
+    TINY = "tiny"
+    BASE = "base"
+
+
+pretrained_weights: dict[GalileoSize, str] = {
+    GalileoSize.NANO: "models/nano",
+    GalileoSize.TINY: "models/tiny",
+    GalileoSize.BASE: "models/base",
+}
+
+DEFAULT_NORMALIZER = Normalizer()
+
+AUTOCAST_DTYPE_MAP = {
+    "bfloat16": torch.bfloat16,
+    "float32": torch.float32,
+}
+
+
+class GalileoModel(FeatureExtractor):
+    """Galileo backbones."""
+
+    input_keys = [
+        "s1",
+        "s2",
+        "era5",
+        "tc",
+        "viirs",
+        "srtm",
+        "dw",
+        "wc",
+        "landscan",
+        "latlon",
+    ]
+
+    def __init__(
+        self,
+        size: GalileoSize,
+        patch_size: int = 4,
+        pretrained_path: str | UPath | None = None,
+        autocast_dtype: str | None = "bfloat16",
+    ) -> None:
+        """Initialize the Galileo model.
+
+        Args:
+            size: The size of the Galileo model.
+            patch_size: The patch size to use.
+            pretrained_path: the local path to the pretrained weights. Otherwise, the
+                weights are downloaded and cached in a temp directory.
+            autocast_dtype: which dtype to use for autocasting, or set None to disable.
+        """
+        super().__init__()
+        if pretrained_path is None:
+            pretrained_path = UPath(tempfile.gettempdir(), "rslearn_cache", "galileo")
+
+        pretrained_path_for_size = UPath(pretrained_path) / pretrained_weights[size]
+        if not (pretrained_path_for_size / CONFIG_FILENAME).exists():
+            _ = hf_hub_download(
+                local_dir=pretrained_path,
+                repo_id=HF_HUB_ID,
+                filename=f"{pretrained_weights[size]}/{CONFIG_FILENAME}",
+                revision="f039dd5dde966a931baeda47eb680fa89b253e4e",
+            )
+        if not (pretrained_path_for_size / ENCODER_FILENAME).exists():
+            _ = hf_hub_download(
+                local_dir=pretrained_path,
+                repo_id=HF_HUB_ID,
+                filename=f"{pretrained_weights[size]}/{ENCODER_FILENAME}",
+                revision="f039dd5dde966a931baeda47eb680fa89b253e4e",
+            )
+
+        assert (pretrained_path_for_size / ENCODER_FILENAME).exists()
+        assert (pretrained_path_for_size / CONFIG_FILENAME).exists()
+
+        self.model = Encoder.load_from_folder(
+            pretrained_path_for_size, device=torch.device("cpu")
+        )
+
+        self.s_t_channels_s2 = [
+            idx for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX) if "S2" in key
+        ]
+        self.s_t_channels_s1 = [
+            idx for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX) if "S1" in key
+        ]
+
+        self.size = size
+        self.patch_size = patch_size
+
+        if autocast_dtype is not None:
+            self.autocast_dtype = AUTOCAST_DTYPE_MAP[autocast_dtype]
+        else:
+            self.autocast_dtype = None
+
+    @staticmethod
+    def to_cartesian(
+        lat: float | np.ndarray | torch.Tensor, lon: float | np.ndarray | torch.Tensor
+    ) -> np.ndarray | torch.Tensor:
+        """Transform latitudes and longitudes to cartesian coordinates."""
+        if isinstance(lat, float):
+            assert -90 <= lat <= 90, (
+                f"lat out of range ({lat}). Make sure you are in EPSG:4326"
+            )
+            assert -180 <= lon <= 180, (
+                f"lon out of range ({lon}). Make sure you are in EPSG:4326"
+            )
+            assert isinstance(lon, float), f"Expected float got {type(lon)}"
+            # transform to radians
+            lat = lat * math.pi / 180
+            lon = lon * math.pi / 180
+            x = math.cos(lat) * math.cos(lon)
+            y = math.cos(lat) * math.sin(lon)
+            z = math.sin(lat)
+            return np.array([x, y, z])
+        elif isinstance(lon, np.ndarray):
+            assert -90 <= lat.min(), (
+                f"lat out of range ({lat.min()}). Make sure you are in EPSG:4326"
+            )
+            assert 90 >= lat.max(), (
+                f"lat out of range ({lat.max()}). Make sure you are in EPSG:4326"
+            )
+            assert -180 <= lon.min(), (
+                f"lon out of range ({lon.min()}). Make sure you are in EPSG:4326"
+            )
+            assert 180 >= lon.max(), (
+                f"lon out of range ({lon.max()}). Make sure you are in EPSG:4326"
+            )
+            assert isinstance(lat, np.ndarray), f"Expected np.ndarray got {type(lat)}"
+            # transform to radians
+            lat = lat * math.pi / 180
+            lon = lon * math.pi / 180
+            x_np = np.cos(lat) * np.cos(lon)
+            y_np = np.cos(lat) * np.sin(lon)
+            z_np = np.sin(lat)
+            return np.stack([x_np, y_np, z_np], axis=-1)
+        elif isinstance(lon, torch.Tensor):
+            assert -90 <= lat.min(), (
+                f"lat out of range ({lat.min()}). Make sure you are in EPSG:4326"
+            )
+            assert 90 >= lat.max(), (
+                f"lat out of range ({lat.max()}). Make sure you are in EPSG:4326"
+            )
+            assert -180 <= lon.min(), (
+                f"lon out of range ({lon.min()}). Make sure you are in EPSG:4326"
+            )
+            assert 180 >= lon.max(), (
+                f"lon out of range ({lon.max()}). Make sure you are in EPSG:4326"
+            )
+            assert isinstance(lat, torch.Tensor), (
+                f"Expected torch.Tensor got {type(lat)}"
+            )
+            # transform to radians
+            lat = lat * math.pi / 180
+            lon = lon * math.pi / 180
+            x_t = torch.cos(lat) * torch.cos(lon)
+            y_t = torch.cos(lat) * torch.sin(lon)
+            z_t = torch.sin(lat)
+            return torch.stack([x_t, y_t, z_t], dim=-1)
+        else:
+            raise AssertionError(f"Unexpected input type {type(lon)}")
+
+    @classmethod
+    def construct_galileo_input(
+        cls,
+        s1: torch.Tensor | None = None,  # [B, H, W, T, D]
+        s2: torch.Tensor | None = None,  # [B, H, W, T, D]
+        era5: torch.Tensor | None = None,  # [B, T, D]
+        tc: torch.Tensor | None = None,  # [B, T, D]
+        viirs: torch.Tensor | None = None,  # [B, T, D]
+        srtm: torch.Tensor | None = None,  # [B, H, W, D]
+        dw: torch.Tensor | None = None,  # [B, H, W, D]
+        wc: torch.Tensor | None = None,  # [B, H, W, D]
+        landscan: torch.Tensor | None = None,  # [B, D]
+        latlon: torch.Tensor | None = None,  # [B, D]
+        months: torch.Tensor | None = None,  # [B, T]
+        normalize: bool = False,
+    ) -> MaskedOutput:
+        """Construct a Galileo input."""
+        space_time_inputs = [s1, s2]
+        time_inputs = [era5, tc, viirs]
+        space_inputs = [srtm, dw, wc]
+        static_inputs = [landscan, latlon]
+        devices = [
+            x.device
+            for x in space_time_inputs + time_inputs + space_inputs + static_inputs
+            if x is not None
+        ]
+
+        if len(devices) == 0:
+            raise ValueError("At least one input must be non-None")
+        if not all(devices[0] == device for device in devices):
+            raise ValueError("Received tensors on multiple devices")
+        device = devices[0]
+
+        # first, check all the input shapes are consistent
+        batch_list = (
+            [x.shape[0] for x in space_time_inputs if x is not None]
+            + [x.shape[0] for x in time_inputs if x is not None]
+            + [x.shape[0] for x in space_inputs if x is not None]
+            + [x.shape[0] for x in static_inputs if x is not None]
+        )
+        timesteps_list = [x.shape[3] for x in space_time_inputs if x is not None] + [
+            x.shape[1] for x in time_inputs if x is not None
+        ]
+        height_list = [x.shape[1] for x in space_time_inputs if x is not None] + [
+            x.shape[1] for x in space_inputs if x is not None
+        ]
+        width_list = [x.shape[2] for x in space_time_inputs if x is not None] + [
+            x.shape[2] for x in space_inputs if x is not None
+        ]
+        if len(batch_list) > 0:
+            if len(set(batch_list)) > 1:
+                raise ValueError("Inconsistent number of batch sizes per input")
+            b = batch_list[0]
+
+        if len(timesteps_list) > 0:
+            if not all(timesteps_list[0] == timestep for timestep in timesteps_list):
+                raise ValueError("Inconsistent number of timesteps per input")
+            t = timesteps_list[0]
+        else:
+            t = 1
+        if len(height_list) > 0:
+            if not all(height_list[0] == height for height in height_list):
+                raise ValueError("Inconsistent heights per input")
+            if not all(width_list[0] == width for width in width_list):
+                raise ValueError("Inconsistent widths per input")
+            h = height_list[0]
+            w = width_list[0]
+        else:
+            h, w = 1, 1
+
+        # now, we can construct our empty input tensors. By default, everything is masked
+        s_t_x = torch.zeros(
+            (b, h, w, t, len(SPACE_TIME_BANDS)), dtype=torch.float, device=device
+        )
+        s_t_m = torch.ones(
+            (b, h, w, t, len(SPACE_TIME_BANDS_GROUPS_IDX)),
+            dtype=torch.float,
+            device=device,
+        )
+        sp_x = torch.zeros(
+            (b, h, w, len(SPACE_BANDS)), dtype=torch.float, device=device
+        )
+        sp_m = torch.ones(
+            (b, h, w, len(SPACE_BAND_GROUPS_IDX)), dtype=torch.float, device=device
+        )
+        t_x = torch.zeros((b, t, len(TIME_BANDS)), dtype=torch.float, device=device)
+        t_m = torch.ones(
+            (b, t, len(TIME_BAND_GROUPS_IDX)), dtype=torch.float, device=device
+        )
+        st_x = torch.zeros((b, len(STATIC_BANDS)), dtype=torch.float, device=device)
+        st_m = torch.ones(
+            (b, len(STATIC_BAND_GROUPS_IDX)), dtype=torch.float, device=device
+        )
+
+        for x, bands_list, group_key in zip(
+            [s1, s2], [S1_BANDS, S2_BANDS], ["S1", "S2"]
+        ):
+            if x is not None:
+                indices = [
+                    idx for idx, val in enumerate(SPACE_TIME_BANDS) if val in bands_list
+                ]
+                groups_idx = [
+                    idx
+                    for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX)
+                    if group_key in key
+                ]
+                s_t_x[:, :, :, :, indices] = x
+                s_t_m[:, :, :, :, groups_idx] = 0
+
+        for x, bands_list, group_key in zip(
+            [srtm, dw, wc], [SRTM_BANDS, DW_BANDS, WC_BANDS], ["SRTM", "DW", "WC"]
+        ):
+            if x is not None:
+                indices = [
+                    idx for idx, val in enumerate(SPACE_BANDS) if val in bands_list
+                ]
+                groups_idx = [
+                    idx
+                    for idx, key in enumerate(SPACE_BAND_GROUPS_IDX)
+                    if group_key in key
+                ]
+                sp_x[:, :, :, indices] = x
+                sp_m[:, :, :, groups_idx] = 0
+
+        for x, bands_list, group_key in zip(
+            [era5, tc, viirs],
+            [ERA5_BANDS, TC_BANDS, VIIRS_BANDS],
+            ["ERA5", "TC", "VIIRS"],
+        ):
+            if x is not None:
+                indices = [
+                    idx for idx, val in enumerate(TIME_BANDS) if val in bands_list
+                ]
+                groups_idx = [
+                    idx
+                    for idx, key in enumerate(TIME_BAND_GROUPS_IDX)
+                    if group_key in key
+                ]
+                t_x[:, :, indices] = x
+                t_m[:, :, groups_idx] = 0
+
+        for x, bands_list, group_key in zip(
+            [landscan, latlon], [LANDSCAN_BANDS, LOCATION_BANDS], ["LS", "location"]
+        ):
+            if x is not None:
+                if group_key == "location":
+                    # transform latlon to cartesian
+                    x = cast(torch.Tensor, cls.to_cartesian(x[:, 0], x[:, 1]))
+                indices = [
+                    idx for idx, val in enumerate(STATIC_BANDS) if val in bands_list
+                ]
+                groups_idx = [
+                    idx
+                    for idx, key in enumerate(STATIC_BAND_GROUPS_IDX)
+                    if group_key in key
+                ]
+                st_x[:, indices] = x
+                st_m[:, groups_idx] = 0
+
+        if months is None:
+            months = torch.ones((b, t), dtype=torch.long, device=device) * DEFAULT_MONTH
+        else:
+            if months.shape[1] != t:
+                raise ValueError("Incorrect number of input months")
+
+        if normalize:
+            s_t_x = (
+                torch.from_numpy(DEFAULT_NORMALIZER(s_t_x.cpu().numpy()))
+                .to(device)
+                .float()
+            )
+            sp_x = (
+                torch.from_numpy(DEFAULT_NORMALIZER(sp_x.cpu().numpy()))
+                .to(device)
+                .float()
+            )
+            t_x = (
+                torch.from_numpy(DEFAULT_NORMALIZER(t_x.cpu().numpy()))
+                .to(device)
+                .float()
+            )
+            st_x = (
+                torch.from_numpy(DEFAULT_NORMALIZER(st_x.cpu().numpy()))
+                .to(device)
+                .float()
+            )
+
+        return MaskedOutput(
+            s_t_x=s_t_x,
+            s_t_m=s_t_m,
+            sp_x=sp_x,
+            sp_m=sp_m,
+            t_x=t_x,
+            t_m=t_m,
+            st_x=st_x,
+            st_m=st_m,
+            months=months,
+        )
+
+    @staticmethod
+    def time_ranges_to_timestamps(
+        time_ranges: list[tuple[datetime, datetime]],
+        device: torch.device,
+    ) -> torch.Tensor:
+        """Turn the time ranges stored in a RasterImage into timestamps accepted by Galileo.
+
+        Galileo only uses the month associated with each timestamp, so we take the midpoint
+        of the time range. For some inputs (e.g. Sentinel-2) we take an image from a specific
+        time so that start_time == end_time == mid_time.
+        """
+        mid_ranges = [t[0] + ((t[1] - t[0]) / 2) for t in time_ranges]
+        # months are indexed 0-11
+        return torch.tensor(
+            [d.month - 1 for d in mid_ranges], dtype=torch.int32, device=device
+        )
+
+    def forward(self, context: ModelContext) -> FeatureMaps:
+        """Compute feature maps from the Galileo backbone.
+
+        Args:
+            context: the model context. Input dicts should contain keys corresponding to GalileoModel.input_keys
+                (also documented below) and values are tensors of the following shapes,
+                per input key:
+                "s1": B C T H W
+                "s2": B C T H W
+                "era5": B C T H W (we will average over the H, W dimensions)
+                "tc": B C T H W (we will average over the H, W dimensions)
+                "viirs": B C T H W (we will average over the H, W dimensions)
+                "srtm": B C 1 H W (SRTM has no temporal dimension)
+                "dw": B C 1 H W (Dynamic World should be averaged over time)
+                "wc": B C 1 H W (WorldCover has no temporal dimension)
+                "landscan": B C 1 H W (we will average over the H, W dimensions)
+                "latlon": B C 1 H W (we will average over the H, W dimensions)
+
+        The output will be an embedding representing the pooled tokens. If there is
+        only a single token per h/w dimension (i.e. patch_size == h,w), then we will take
+        a pool of all the unmasked tokens.
+
+        If there are many spatial tokens per h/w dimension (patch_size < h,w), then we will
+        take a pool of the space_time unmasked tokens (i.e. of the s1 and s2 tokens).
+        """
+        space_time_modalities = ["s1", "s2"]
+        time_modalities = ["era5", "tc", "viirs"]
+        stacked_inputs = {}
+        months: torch.Tensor | None = None
+        for key in context.inputs[0].keys():
+            # assume all the keys in an input are consistent
+            if key in self.input_keys:
+                stacked_inputs[key] = torch.stack(
+                    [inp[key].image for inp in context.inputs], dim=0
+                )
+                if key in space_time_modalities + time_modalities:
+                    if months is None:
+                        if context.inputs[0][key].timestamps is not None:
+                            months = torch.stack(
+                                [
+                                    self.time_ranges_to_timestamps(
+                                        inp[key].timestamps,  # type: ignore
+                                        device=stacked_inputs[key].device,
+                                    )
+                                    for inp in context.inputs
+                                ],
+                                dim=0,
+                            )
+
+        if months is not None:
+            stacked_inputs["months"] = months
+
+        s_t_channels = []
+        for space_time_modality in space_time_modalities:
+            if space_time_modality not in stacked_inputs:
+                continue
+            if space_time_modality == "s1":
+                s_t_channels += self.s_t_channels_s1
+            else:
+                s_t_channels += self.s_t_channels_s2
+            cur = stacked_inputs[space_time_modality]
+            cur = rearrange(cur, "b c t h w -> b h w t c")
+            stacked_inputs[space_time_modality] = cur
+
+        for space_modality in ["srtm", "dw", "wc"]:
+            if space_modality not in stacked_inputs:
+                continue
+            # take the first (and assumed only) timestep
+            stacked_inputs[space_modality] = stacked_inputs[space_modality][:, :, 0]
+            stacked_inputs[space_modality] = rearrange(
+                stacked_inputs[space_modality], "b c h w -> b h w c"
+            )
+
+        for time_modality in time_modalities:
+            if time_modality not in stacked_inputs:
+                continue
+            cur = stacked_inputs[time_modality]
+            # take the average over the h, w dimensions since Galileo
+            # treats it as a pixel-timeseries
+            cur = rearrange(
+                torch.nanmean(cur, dim=(-1, -2)),
+                "b c t -> b t c",
+            )
+            stacked_inputs[time_modality] = cur
+
+        for static_modality in ["landscan", "latlon"]:
+            if static_modality not in stacked_inputs:
+                continue
+            cur = stacked_inputs[static_modality]
+            stacked_inputs[static_modality] = torch.nanmean(cur, dim=(2, 3, 4))
+
+        galileo_input = self.construct_galileo_input(**stacked_inputs, normalize=True)
+        h = galileo_input.s_t_x.shape[1]
+        if h < self.patch_size:
+            logger.warning(
+                f"Input size {h} is smaller than patch size {self.patch_size}. Reducing patch size to {h}"
+            )
+            patch_size = h
+        else:
+            patch_size = self.patch_size
+
+        # Decide context based on self.autocast_dtype.
+        device = galileo_input.s_t_x.device
+        if self.autocast_dtype is None:
+            torch_context = nullcontext()
+        else:
+            assert device is not None
+            torch_context = torch.amp.autocast(
+                device_type=device.type, dtype=self.autocast_dtype
+            )
+        with torch_context:
+            outputs = self.model(
+                s_t_x=galileo_input.s_t_x,
+                s_t_m=galileo_input.s_t_m,
+                sp_x=galileo_input.sp_x,
+                sp_m=galileo_input.sp_m,
+                t_x=galileo_input.t_x,
+                t_m=galileo_input.t_m,
+                st_x=galileo_input.st_x,
+                st_m=galileo_input.st_m,
+                months=galileo_input.months,
+                patch_size=patch_size,
+            )
+
+        if h == patch_size:
+            # only one spatial patch, so we can just take an average
+            # of all the tokens to output b d 1 1
+            s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m, _ = outputs
+            averaged = self.model.average_tokens(
+                s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m
+            )
+            return FeatureMaps([repeat(averaged, "b d -> b d 1 1")])
+        else:
+            s_t_x = outputs[0]
+            # we will be assuming we only want s_t_x, and (for now) that we want s1 or s2 bands
+            # s_t_x has shape [b, h, w, t, c_g, d]
+            # and we want [b, d, h, w]
+            return FeatureMaps(
+                [
+                    rearrange(
+                        s_t_x[:, :, :, :, s_t_channels, :].mean(dim=3),
+                        "b h w c_g d -> b c_g d h w",
+                    ).mean(dim=1)
+                ]
+            )
+
+    def get_backbone_channels(self) -> list:
+        """Returns the output channels of this model when used as a backbone.
+
+        The output channels are a list of (patch_size, depth) tuples that correspond
+        to the feature maps that the backbone returns.
+
+        Returns:
+            the output channels of the backbone as a list of (patch_size, depth) tuples.
+        """
+        if self.size == GalileoSize.BASE:
+            depth = 768
+        elif self.size == GalileoSize.TINY:
+            depth = 192
+        elif self.size == GalileoSize.NANO:
+            depth = 128
+        else:
+            raise ValueError(f"Invalid model size: {self.size}")
+        return [(self.patch_size, depth)]
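
For readers skimming the hunk above, a minimal usage sketch of the two public entry points in galileo.py, construct_galileo_input (packs whichever modalities are present into fully populated tensors plus per-group masks) and get_backbone_channels (reports the single (patch_size, depth) feature map). This example is not part of the package; the batch/height/width/timestep sizes and the variable names are illustrative, and it assumes the constants and classmethod signatures exactly as shown in the diff.

import torch

from rslearn.models.galileo.galileo import GalileoModel, GalileoSize
from rslearn.models.galileo.single_file_galileo import S2_BANDS

# Instantiation downloads the nano encoder weights from the
# nasaharvest/galileo Hugging Face repo on first use and caches them
# under the system temp directory unless pretrained_path is given.
model = GalileoModel(size=GalileoSize.NANO, patch_size=4)

# Build a batched Sentinel-2 input of shape [B, H, W, T, D]; every other
# modality stays masked because it is left as None.
s2 = torch.randn(2, 8, 8, 3, len(S2_BANDS))
masked = GalileoModel.construct_galileo_input(s2=s2, normalize=True)
print(masked.s_t_x.shape)  # (2, 8, 8, 3, number of space-time bands)
print(masked.s_t_m.unique())  # 0 for the S2 group, 1 (masked) elsewhere

# When used as an rslearn backbone, the feature map spec follows the
# size-to-depth mapping in get_backbone_channels.
print(model.get_backbone_channels())  # [(4, 128)] for the nano encoder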