rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/models/olmoearth_pretrain/model.py
@@ -0,0 +1,421 @@
+ """OlmoEarth model wrapper for fine-tuning in rslearn."""
+
+ import json
+ import warnings
+ from contextlib import nullcontext
+ from datetime import datetime
+ from typing import Any
+
+ import torch
+ from einops import rearrange
+ from olmoearth_pretrain.config import Config, require_olmo_core
+ from olmoearth_pretrain.data.constants import Modality
+ from olmoearth_pretrain.datatypes import MaskedOlmoEarthSample, MaskValue
+ from olmoearth_pretrain.model_loader import (
+     ModelID,
+     load_model_from_id,
+     load_model_from_path,
+ )
+ from olmoearth_pretrain.nn.flexihelios import Encoder, TokensAndMasks
+ from upath import UPath
+
+ from rslearn.log_utils import get_logger
+ from rslearn.models.component import FeatureExtractor, FeatureMaps, TokenFeatureMaps
+ from rslearn.train.model_context import ModelContext, RasterImage
+
+ logger = get_logger(__name__)
+
+ MODALITY_NAMES = [
+     "sentinel2_l2a",
+     "sentinel1",
+     "worldcover",
+     "openstreetmap_raster",
+     "landsat",
+ ]
+
+ AUTOCAST_DTYPE_MAP = {
+     "bfloat16": torch.bfloat16,
+     "float16": torch.float16,
+     "float32": torch.float32,
+ }
+
+ EMBEDDING_SIZES = {
+     ModelID.OLMOEARTH_V1_NANO: 128,
+     ModelID.OLMOEARTH_V1_TINY: 192,
+     ModelID.OLMOEARTH_V1_BASE: 768,
+     ModelID.OLMOEARTH_V1_LARGE: 1024,
+ }
+
+
+ class OlmoEarth(FeatureExtractor):
+     """A wrapper to support the OlmoEarth model."""
+
+     def __init__(
+         self,
+         patch_size: int,
+         model_id: ModelID | None = None,
+         model_path: str | None = None,
+         checkpoint_path: str | None = None,
+         selector: list[str | int] = ["encoder"],
+         forward_kwargs: dict[str, Any] = {},
+         random_initialization: bool = False,
+         embedding_size: int | None = None,
+         autocast_dtype: str | None = "bfloat16",
+         token_pooling: bool = True,
+         use_legacy_timestamps: bool = True,
+     ):
+         """Create a new OlmoEarth model.
+
+         Args:
+             patch_size: token spatial patch size to use.
+             model_id: the model ID to load. One of model_id or model_path or checkpoint_path must be
+                 set.
+             model_path: the path to load the model from. One of model_id or model_path or checkpoint_path must be
+                 set. Same structure as the HF-hosted `model_id` models: bundle with a config.json and weights.pth.
+             checkpoint_path: the checkpoint directory to load from, if model_id or model_path is not
+                 set. It should contain a distributed checkpoint with a config.json file as well as model_and_optim
+                 folder.
+             selector: an optional sequence of attribute names or list indices to select
+                 the sub-module that should be applied on the input images. Defaults to
+                 ["encoder"] to select only the transformer encoder.
+             forward_kwargs: additional arguments to pass to forward pass besides the
+                 MaskedOlmoEarthSample.
+             random_initialization: whether to skip loading the checkpoint so the
+                 weights are randomly initialized. In this case, the checkpoint is only
+                 used to define the model architecture.
+             embedding_size: optional embedding size to report via
+                 get_backbone_channels (if model_id is not set).
+             autocast_dtype: which dtype to use for autocasting, or set None to disable.
+             token_pooling: whether or not to pool the tokens. If True, the output will be BxCxHxW. If False,
+                 there will be an extra dimension, N, (BxCxHxWxN) representing the temporal and channel
+                 dimensions.
+             use_legacy_timestamps: In our original implementation of OlmoEarth, we applied timestamps starting
+                 from 0 (instead of the actual timestamps of the input). The option to do this is preserved
+                 for backwards compatibility with finetuned models which were trained against this implementation.
+         """
+         if use_legacy_timestamps:
+             warnings.warn(
+                 "For new projects, don't use legacy timesteps.", DeprecationWarning
+             )
+
+         if (
+             sum(
+                 [
+                     model_id is not None,
+                     model_path is not None,
+                     checkpoint_path is not None,
+                 ]
+             )
+             != 1
+         ):
+             raise ValueError(
+                 "exactly one of model_id, model_path, or checkpoint_path must be set"
+             )
+
+         super().__init__()
+         self.patch_size = patch_size
+         self.forward_kwargs = forward_kwargs
+         self.embedding_size = embedding_size
+
+         if autocast_dtype is not None:
+             self.autocast_dtype = AUTOCAST_DTYPE_MAP[autocast_dtype]
+         else:
+             self.autocast_dtype = None
+
+         if model_id is not None:
+             # Load from Hugging Face.
+             model = load_model_from_id(model_id, load_weights=not random_initialization)
+             if self.embedding_size is None and model_id in EMBEDDING_SIZES:
+                 self.embedding_size = EMBEDDING_SIZES[model_id]
+
+         elif model_path is not None:
+             # Load from path.
+             model = load_model_from_path(
+                 UPath(model_path), load_weights=not random_initialization
+             )
+
+         else:
+             # Load the distributed model checkpoint by path through Olmo Core.
+             model = self._load_model_from_checkpoint(
+                 UPath(checkpoint_path), random_initialization
+             )
+
+         # Select just the portion of the model that we actually want to use.
+         for part in selector:
+             if isinstance(part, str):
+                 model = getattr(model, part)
+             else:
+                 model = model[part]
+         self.model = model
+         self.token_pooling = token_pooling
+         self.use_legacy_timestamps = use_legacy_timestamps
+
+     def _load_model_from_checkpoint(
+         self, checkpoint_upath: UPath, random_initialization: bool
+     ) -> torch.nn.Module:
+         """Load the OlmoEarth pre-trained model from a distributed checkpoint folder.
+
+         The folder should contain config.json as well as the model_and_optim folder
+         that contains the distributed checkpoint. This is the format produced by
+         pre-training runs in olmoearth_pretrain.
+         """
+         with (checkpoint_upath / "config.json").open() as f:
+             config_dict = json.load(f)
+         model_config = Config.from_dict(config_dict["model"])
+
+         model = model_config.build()
+
+         # Load the checkpoint (requires olmo_core for distributed checkpoint loading).
+         if not random_initialization:
+             require_olmo_core(
+                 "_load_model_from_checkpoint with random_initialization=False"
+             )
+             from olmo_core.distributed.checkpoint import load_model_and_optim_state
+
+             train_module_dir = checkpoint_upath / "model_and_optim"
+             load_model_and_optim_state(str(train_module_dir), model)
+             logger.info(f"loaded OlmoEarth encoder from {train_module_dir}")
+
+         return model
+
+     @staticmethod
+     def time_ranges_to_timestamps(
+         time_ranges: list[tuple[datetime, datetime]],
+         max_timestamps: int,
+         device: torch.device,
+     ) -> torch.Tensor:
+         """Turn the time ranges stored in a RasterImage to timestamps accepted by OlmoEarth.
+
+         OlmoEarth only uses the month associated with each timestamp, so we take the midpoint
+         of the time range. For some inputs (e.g. Sentinel 2) we take an image from a specific
+         time so that start_time == end_time == mid_time.
+         """
+         timestamps = torch.zeros((max_timestamps, 3), dtype=torch.int32, device=device)
+         mid_ranges = [t[0] + ((t[1] - t[0]) / 2) for t in time_ranges]
+         timestamps[: len(time_ranges), 0] = torch.tensor(
+             [d.day for d in mid_ranges], dtype=torch.int32
+         )
+         # months are indexed 0-11
+         timestamps[: len(time_ranges), 1] = torch.tensor(
+             [d.month - 1 for d in mid_ranges], dtype=torch.int32
+         )
+         timestamps[: len(time_ranges), 2] = torch.tensor(
+             [d.year for d in mid_ranges], dtype=torch.int32
+         )
+         return timestamps
+
+     def _prepare_modality_inputs(
+         self, context: ModelContext
+     ) -> tuple[MaskedOlmoEarthSample, list[str], torch.device]:
+         """Prepare modality tensors and masks for the OlmoEarth model.
+
+         Uses a two-pass approach to ensure all modalities have consistent timestep
+         dimensions for position encoding.
+
+         Args:
+             context: the model context with input tensors.
+
+         Returns:
+             tuple of (sample, present_modalities, device)
+         """
+         kwargs = {}
+         present_modalities = []
+         device = None
+
+         # First pass: find global max_timesteps across all modalities and samples.
+         # TODO: currently we assume all modalities have the same number of timesteps,
+         # which is not true for all cases, and time series time steps are assumed to
+         # be 1-month apart. It also assumes continuity between available timesteps.
+         # We'll have to fix all that.
+         max_timesteps = 1
+         modality_data = {}
+         # We will just store the longest time range
+         # per instance in the batch. This means it may not be
+         # aligned per modality.
+         timestamps_per_instance: list[list[tuple[datetime, datetime]]] = [[]] * len(
+             context.inputs
+         )
+         for modality in MODALITY_NAMES:
+             if modality not in context.inputs[0]:
+                 continue
+             present_modalities.append(modality)
+             tensors = []
+             for idx, inp in enumerate(context.inputs):
+                 assert isinstance(inp[modality], RasterImage)
+                 tensors.append(inp[modality].image)
+                 cur_timestamps = inp[modality].timestamps
+                 if cur_timestamps is not None and len(cur_timestamps) > len(
+                     timestamps_per_instance[idx]
+                 ):
+                     timestamps_per_instance[idx] = cur_timestamps
+             tensors = [inp[modality].image for inp in context.inputs]
+             device = tensors[0].device
+             max_t = max(t.shape[1] for t in tensors)
+             max_timesteps = max(max_timesteps, max_t)
+             modality_data[modality] = (
+                 tensors,
+                 len(Modality.get(modality).band_sets),
+             )
+
+         # Second pass: pad and process each modality with global max_timesteps.
+         for modality in present_modalities:
+             tensors, num_band_sets = modality_data[modality]
+
+             # Pad tensors to target_ch and track original timesteps for masking.
+             padded = []
+             original_timesteps = []
+             for t in tensors:
+                 orig_t = t.shape[1]
+                 original_timesteps.append(orig_t)
+                 if orig_t < max_timesteps:
+                     pad = torch.zeros(
+                         t.shape[:1] + (max_timesteps - orig_t,) + t.shape[2:],
+                         dtype=t.dtype,
+                         device=device,
+                     )
+                     t = torch.cat([t, pad], dim=1)
+                 padded.append(t)
+
+             cur = torch.stack(padded, dim=0)
+             cur = rearrange(cur, "b c t h w -> b h w t c")
+             kwargs[modality] = cur
+
+             # Create mask: ONLINE_ENCODER for valid, MISSING for padded timesteps.
+             b, h, w = cur.shape[0], cur.shape[1], cur.shape[2]
+             mask = torch.full(
+                 (b, h, w, max_timesteps, num_band_sets),
+                 fill_value=MaskValue.ONLINE_ENCODER.value,
+                 dtype=torch.int32,
+                 device=device,
+             )
+             for sample_idx, orig_t in enumerate(original_timesteps):
+                 if orig_t < max_timesteps:
+                     mask[sample_idx, :, :, orig_t:, :] = MaskValue.MISSING.value
+             kwargs[f"{modality}_mask"] = mask
+
+         if self.use_legacy_timestamps:
+             # Note that only months (0 to 11) are used in OlmoEarth position encoding.
+             timestamps = torch.zeros(
+                 (len(context.inputs), max_timesteps, 3),
+                 dtype=torch.int32,
+                 device=device,
+             )
+             timestamps[:, :, 0] = 1  # day
+             timestamps[:, :, 1] = torch.arange(max_timesteps, device=device)[
+                 None, :
+             ]  # month
+             timestamps[:, :, 2] = 2024  # year
+             kwargs["timestamps"] = timestamps
+         else:
+             if max([len(t) for t in timestamps_per_instance]) == 0:
+                 # Timestamps are required.
+                 raise ValueError("No inputs had timestamps.")
+             # Note that only months (0 to 11) are used in OlmoEarth position encoding.
+             kwargs["timestamps"] = torch.stack(
+                 [
+                     self.time_ranges_to_timestamps(time_range, max_timesteps, device)
+                     for time_range in timestamps_per_instance
+                 ],
+                 dim=0,
+             )
+
+         return MaskedOlmoEarthSample(**kwargs), present_modalities, device
+
+     def forward(self, context: ModelContext) -> FeatureMaps | TokenFeatureMaps:
+         """Compute feature maps from the OlmoEarth backbone.
+
+         Args:
+             context: the model context. Input dicts should include keys corresponding
+                 to the modalities that should be passed to the OlmoEarth model.
+
+         Returns:
+             a FeatureMaps consisting of one feature map, at 1/patch_size of the input
+             resolution. Embeddings will be pooled across modalities and timesteps.
+         """
+         sample, present_modalities, device = self._prepare_modality_inputs(context)
+
+         # Decide context based on self.autocast_dtype.
+         if self.autocast_dtype is None:
+             torch_context = nullcontext()
+         else:
+             assert device is not None
+             torch_context = torch.amp.autocast(
+                 device_type=device.type, dtype=self.autocast_dtype
+             )
+
+         # Check if we can bypass masks (fast_pass=True).
+         missing_tokens = False
+         for modality in present_modalities:
+             modality_mask = getattr(sample, f"{modality}_mask")
+             if torch.any(modality_mask == MaskValue.MISSING.value):
+                 missing_tokens = True
+                 break
+
+         with torch_context:
+             # Currently we assume the provided model always returns a TokensAndMasks object.
+             tokens_and_masks: TokensAndMasks
+             if isinstance(self.model, Encoder):
+                 # Encoder has a fast_pass argument to indicate the mask is not needed.
+                 tokens_and_masks = self.model(
+                     sample,
+                     fast_pass=not missing_tokens,
+                     patch_size=self.patch_size,
+                     **self.forward_kwargs,
+                 )["tokens_and_masks"]
+             else:
+                 # Other models like STEncoder do not have this option supported.
+                 tokens_and_masks = self.model(
+                     sample, patch_size=self.patch_size, **self.forward_kwargs
+                 )["tokens_and_masks"]
+
+         # Apply temporal/modality pooling so we just have one feature per patch.
+         features = []
+         if self.token_pooling:
+             for modality in present_modalities:
+                 modality_features = getattr(tokens_and_masks, modality)  # BHWTSC
+                 # If fast_pass is False, we need to mask the missing tokens before pooling.
+                 if missing_tokens:
+                     modality_masks = getattr(
+                         tokens_and_masks, f"{modality}_mask"
+                     )  # BHWTS
+                     modality_masks_bool = (
+                         modality_masks != MaskValue.MISSING.value
+                     ).unsqueeze(-1)
+                     count = modality_masks_bool.sum(dim=[3, 4])
+                     # Masked average over band sets and timesteps (BHWTSC -> BHWC).
+                     pooled = (modality_features * modality_masks_bool).sum(
+                         dim=[3, 4]
+                     ) / count.clamp(min=1)
+                 else:
+                     # Pool over band sets and timesteps (BHWTSC -> BHWC).
+                     pooled = modality_features.mean(dim=[3, 4])
+                 # We want BHWC -> BCHW.
+                 pooled = rearrange(pooled, "b h w c -> b c h w")
+                 features.append(pooled)
+             # Pool over the modalities, so we get one BCHW feature map.
+             pooled = torch.stack(features, dim=0).mean(dim=0)
+             return FeatureMaps([pooled])
+         else:
+             for modality in present_modalities:
+                 modality_features = getattr(tokens_and_masks, modality)
+                 # Combine band sets and timesteps into the last dim (BHWTSC -> BCHWN).
+                 modality_features = rearrange(
+                     modality_features, "b h w t s c -> b c h w (t s)"
+                 )
+                 features.append(modality_features)
+             pooled = torch.cat(features, dim=-1)
+             return TokenFeatureMaps([pooled])
+
+     def get_backbone_channels(self) -> list:
+         """Returns the output channels of this model when used as a backbone.
+
+         The output channels is a list of (downsample_factor, depth) that corresponds
+         to the feature maps that the backbone returns. For example, an element [2, 32]
+         indicates that the corresponding feature map is 1/2 the input resolution and
+         has 32 channels.
+
+         Returns:
+             the output channels of the backbone as a list of (downsample_factor, depth)
+             tuples.
+         """
+         return [(self.patch_size, self.embedding_size)]
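The class above is configured entirely through its constructor. The snippet below is a rough, untested sketch of instantiating it from a released model ID; it assumes olmoearth_pretrain is installed, and the patch size is an illustrative value rather than anything specified in the diff.

```python
# Hypothetical usage sketch (not part of the diff): build the OlmoEarth backbone
# from a released model ID and inspect the channels it reports.
from olmoearth_pretrain.model_loader import ModelID

from rslearn.models.olmoearth_pretrain.model import OlmoEarth

backbone = OlmoEarth(
    patch_size=8,  # illustrative token patch size, not prescribed by the diff
    model_id=ModelID.OLMOEARTH_V1_BASE,  # exactly one of model_id/model_path/checkpoint_path
    token_pooling=True,  # return a pooled BxCxHxW feature map
    use_legacy_timestamps=False,  # avoid the deprecated legacy timestamp behavior
)

# One feature map at 1/patch_size resolution with the model's embedding size;
# for the base model the diff maps ModelID.OLMOEARTH_V1_BASE to 768 channels.
print(backbone.get_backbone_channels())  # e.g. [(8, 768)]
```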
rslearn/models/olmoearth_pretrain/norm.py
@@ -0,0 +1,86 @@
+ """Normalization transforms."""
+
+ import json
+ from typing import Any
+
+ from olmoearth_pretrain.data.normalize import load_computed_config
+
+ from rslearn.log_utils import get_logger
+ from rslearn.train.transforms.transform import Transform
+
+ logger = get_logger(__file__)
+
+
+ class OlmoEarthNormalize(Transform):
+     """Normalize using OlmoEarth JSON config.
+
+     For Sentinel-1 data, the values should be converted to decibels before being passed
+     to this transform.
+     """
+
+     def __init__(
+         self,
+         band_names: dict[str, list[str]],
+         std_multiplier: float | None = 2,
+         config_fname: str | None = None,
+     ) -> None:
+         """Initialize a new OlmoEarthNormalize.
+
+         Args:
+             band_names: map from modality name to the list of bands in that modality in
+                 the order they are being loaded. Note that this order must match the
+                 expected order for the OlmoEarth model.
+             std_multiplier: the std multiplier matching the one used for the model
+                 training in OlmoEarth.
+             config_fname: load the normalization configuration from this file, instead
+                 of getting it from OlmoEarth.
+         """
+         super().__init__()
+         self.band_names = band_names
+         self.std_multiplier = std_multiplier
+
+         if config_fname is None:
+             self.norm_config = load_computed_config()
+         else:
+             logger.warning(
+                 f"Loading normalization config from {config_fname}. This argument is deprecated and will be removed in a future version."
+             )
+             with open(config_fname) as f:
+                 self.norm_config = json.load(f)
+
+     def forward(
+         self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+     ) -> tuple[dict[str, Any], dict[str, Any]]:
+         """Apply normalization over the inputs and targets.
+
+         Args:
+             input_dict: the input
+             target_dict: the target
+
+         Returns:
+             normalized (input_dicts, target_dicts) tuple
+         """
+         for modality_name, cur_band_names in self.band_names.items():
+             band_norms = self.norm_config[modality_name]
+             image = input_dict[modality_name]
+             # Keep a set of indices to make sure that we normalize all of them.
+             needed_band_indices = set(range(image.image.shape[0]))
+             num_timesteps = image.image.shape[0] // len(cur_band_names)
+
+             for band, norm_dict in band_norms.items():
+                 # If multitemporal, normalize each timestep separately.
+                 for t in range(num_timesteps):
+                     band_idx = cur_band_names.index(band) + t * len(cur_band_names)
+                     min_val = norm_dict["mean"] - self.std_multiplier * norm_dict["std"]
+                     max_val = norm_dict["mean"] + self.std_multiplier * norm_dict["std"]
+                     image.image[band_idx] = (image.image[band_idx] - min_val) / (
+                         max_val - min_val
+                     )
+                     needed_band_indices.remove(band_idx)
+
+             if len(needed_band_indices) > 0:
+                 raise ValueError(
+                     f"for modality {modality_name}, bands {needed_band_indices} were unexpectedly not normalized"
+                 )
+
+         return input_dict, target_dict
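As a rough sketch of how this transform might be instantiated (the band lists below are placeholders, not taken from the diff; they must match the order in which the dataset config actually loads each modality and the order expected by the OlmoEarth model):

```python
# Hypothetical configuration sketch; band lists are illustrative placeholders.
from rslearn.models.olmoearth_pretrain.norm import OlmoEarthNormalize

normalize = OlmoEarthNormalize(
    band_names={
        "sentinel2_l2a": ["B02", "B03", "B04", "B08"],  # placeholder band order
        "sentinel1": ["vv", "vh"],  # values should already be in decibels
    },
    std_multiplier=2,
)

# In the training pipeline, forward() receives the input and target dicts (each
# modality value exposes its pixels via an .image tensor) and returns them
# normalized to (value - (mean - k*std)) / (2*k*std) per band:
# input_dict, target_dict = normalize.forward(input_dict, target_dict)
```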
rslearn/models/panopticon.py
@@ -0,0 +1,170 @@
+ """Wrapper for the Panopticon model."""
+
+ import math
+ from enum import StrEnum
+ from importlib import resources
+
+ import torch
+ import torch.nn.functional as F
+ import yaml
+ from einops import rearrange, repeat
+
+ from rslearn.log_utils import get_logger
+ from rslearn.train.model_context import ModelContext
+
+ from .component import FeatureExtractor, FeatureMaps
+
+ logger = get_logger(__name__)
+
+
+ class PanopticonModalities(StrEnum):
+     """Modalities supported by Panopticon.
+
+     These are the keys needed to load the yaml file from panopticon_data/sensors
+     """
+
+     SENTINEL2 = "sentinel2"
+     LANDSAT8 = "landsat8"
+     SENTINEL1 = "sentinel1"
+     # Add more modalities as needed
+
+
+ class Panopticon(FeatureExtractor):
+     """Class containing the Panopticon model that can ingest MaskedHeliosSample objects."""
+
+     patch_size: int = 14
+     base_image_size: int = 224
+
+     def __init__(
+         self,
+         band_order: dict[str, list[str]],
+         torchhub_id: str = "panopticon_vitb14",
+     ):
+         """Initialize the Panopticon wrapper.
+
+         Args:
+             band_order: The band order for the panopticon model, must match the specified order in the data config
+             torchhub_id: The torch hub model ID for panopticon
+         """
+         super().__init__()
+         # Load the panopticon model
+         self._load_model(torchhub_id)
+         self.output_dim = self.model.embed_dim
+         self.band_order = band_order
+         self.supported_modalities = list(band_order.keys())
+
+     def _load_model(self, torchhub_id: str) -> None:
+         """Load the panopticon model from torch hub."""
+         import time
+
+         # Hack to get around https://discuss.pytorch.org/t/torch-hub-load-gives-httperror-rate-limit-exceeded/124769
+         torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
+         for attempt in range(2):
+             try:
+                 self.model = torch.hub.load(  # nosec B614
+                     "panopticon-FM/panopticon",
+                     torchhub_id,
+                 )
+                 break
+             except Exception as e:
+                 logger.warning(
+                     f"Error loading panopticon model: {e}. Retrying in 5 seconds..."
+                 )
+                 time.sleep(5)
+         else:
+             raise RuntimeError(
+                 f"Failed to load panopticon model {torchhub_id} after retrying."
+             )
+
+     def _process_modality_data(self, data: torch.Tensor) -> torch.Tensor:
+         """Process individual modality data.
+
+         Args:
+             data: Input tensor of shape [B, C, H, W]
+
+         Returns:
+             Processed tensor of shape [B, C, H, W]
+         """
+         original_height = data.shape[2]
+         new_height = self.patch_size if original_height == 1 else self.base_image_size
+
+         data = F.interpolate(
+             data,
+             size=(new_height, new_height),
+             mode="bilinear",
+             align_corners=False,
+         )
+         return data
+
+     def _create_channel_ids(
+         self, modality: str, batch_size: int, device: torch.device
+     ) -> torch.Tensor:
+         """Create channel IDs for the panopticon model."""
+         with resources.open_text(
+             "rslearn.models.panopticon_data.sensors", f"{modality}.yaml"
+         ) as f:
+             sensor_config = yaml.safe_load(f)
+
+         band_order = self.band_order[modality]
+         chn_ids = [
+             sensor_config["bands"][band.upper()]["gaussian"]["mu"]
+             for band in band_order
+         ]
+         chn_ids = torch.tensor(chn_ids, dtype=torch.float32, device=device)
+         chn_ids = repeat(chn_ids, "c -> b c", b=batch_size)
+         return chn_ids
+
+     def prepare_input(
+         self, input_data: dict[str, torch.Tensor]
+     ) -> dict[str, torch.Tensor]:
+         """Prepare input for the panopticon model from MaskedHeliosSample."""
+         channel_ids_list: list[torch.Tensor] = []
+         processed_data_list: list[torch.Tensor] = []
+         for modality in self.supported_modalities:
+             if modality not in input_data.keys():
+                 logger.debug(f"Modality {modality} not found in input data")
+                 continue
+             data = input_data[modality]
+             device = data.device
+             processed_data = self._process_modality_data(data)
+             processed_data_list.append(processed_data)
+             batch_size = processed_data.shape[0]
+             chn_ids = self._create_channel_ids(modality, batch_size, device)
+             channel_ids_list.append(chn_ids)
+
+         processed_data = torch.cat(processed_data_list, dim=1)
+         chn_ids = torch.cat(channel_ids_list, dim=1)
+         return {
+             "imgs": processed_data,
+             "chn_ids": chn_ids,
+         }
+
+     def forward(self, context: ModelContext) -> FeatureMaps:
+         """Forward pass through the panopticon model."""
+         batch_inputs = {
+             key: torch.stack(
+                 [inp[key].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+             )
+             for key in context.inputs[0].keys()
+         }
+         panopticon_inputs = self.prepare_input(batch_inputs)
+         output_features = self.model.forward_features(panopticon_inputs)[
+             "x_norm_patchtokens"
+         ]
+
+         num_tokens = output_features.shape[1]
+         height = int(math.sqrt(num_tokens))
+         output_features = rearrange(
+             output_features, "b (h w) d -> b d h w", h=height, w=height
+         )
+         return FeatureMaps([output_features])
+
+     def get_backbone_channels(self) -> list:
+         """Returns the output channels of this model when used as a backbone.
+
+         The output channels is a list of (downsample_factor, depth) that corresponds
+         to the feature maps that the backbone returns. For example, an element [2, 32]
+         indicates that the corresponding feature map is 1/2 the input resolution and
+         has 32 channels.
+         """
+         return [(self.patch_size, self.output_dim)]
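A minimal sketch of constructing the wrapper above; the band names are placeholders and must match the upper-cased band keys in the corresponding panopticon_data/sensors YAML files (e.g. sentinel2.yaml), and instantiation downloads the backbone from torch hub, so a network connection is assumed.

```python
# Hypothetical usage sketch (band names are placeholders, not from the diff).
from rslearn.models.panopticon import Panopticon

model = Panopticon(
    band_order={
        "sentinel2": ["B02", "B03", "B04", "B08"],  # keys select panopticon_data/sensors/sentinel2.yaml
        "sentinel1": ["vv", "vh"],
    },
    torchhub_id="panopticon_vitb14",
)

# ViT-B/14 backbone: one feature map at 1/14 of the (resized) input resolution,
# with the embedding dimension reported by the loaded model.
print(model.get_backbone_channels())  # [(14, model.output_dim)]
```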