rslearn 0.0.9__py3-none-any.whl → 0.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
+++ rslearn/models/olmoearth_pretrain/__init__.py
@@ -0,0 +1 @@
+ """OlmoEarth model architecture."""
+++ rslearn/models/olmoearth_pretrain/model.py
@@ -0,0 +1,203 @@
+ """OlmoEarth model wrapper for fine-tuning in rslearn."""
+
+ import json
+ from contextlib import nullcontext
+ from typing import Any
+
+ import torch
+ from einops import rearrange
+ from olmo_core.config import Config
+ from olmo_core.distributed.checkpoint import load_model_and_optim_state
+ from olmoearth_pretrain.data.constants import Modality
+ from olmoearth_pretrain.nn.flexihelios import Encoder, TokensAndMasks
+ from olmoearth_pretrain.train.masking import MaskedOlmoEarthSample, MaskValue
+ from upath import UPath
+
+ from rslearn.log_utils import get_logger
+
+ logger = get_logger(__name__)
+
+ MODALITY_NAMES = [
+     "sentinel2_l2a",
+     "sentinel1",
+     "worldcover",
+     "openstreetmap_raster",
+     "landsat",
+ ]
+
+ AUTOCAST_DTYPE_MAP = {
+     "bfloat16": torch.bfloat16,
+     "float16": torch.float16,
+     "float32": torch.float32,
+ }
+
+
+ class OlmoEarth(torch.nn.Module):
+     """A wrapper to support the OlmoEarth model."""
+
+     def __init__(
+         self,
+         # TODO: we should accept a model ID instead of checkpoint_path once we are
+         # closer to being ready for release.
+         checkpoint_path: str,
+         selector: list[str | int] = [],
+         forward_kwargs: dict[str, Any] = {},
+         random_initialization: bool = False,
+         embedding_size: int | None = None,
+         patch_size: int | None = None,
+         autocast_dtype: str | None = "bfloat16",
+     ):
+         """Create a new OlmoEarth model.
+
+         Args:
+             checkpoint_path: the checkpoint directory to load. It should contain a
+                 config.json file as well as a model_and_optim folder.
+             selector: an optional sequence of attribute names or list indices to
+                 select the sub-module that should be applied on the input images.
+             forward_kwargs: additional arguments to pass to the forward pass besides
+                 the MaskedOlmoEarthSample.
+             random_initialization: whether to skip loading the checkpoint so the
+                 weights are randomly initialized. In this case, the checkpoint is only
+                 used to define the model architecture.
+             embedding_size: optional embedding size to report via
+                 get_backbone_channels.
+             patch_size: optional patch size to report via get_backbone_channels.
+             autocast_dtype: which dtype to use for autocasting, or None to disable.
+         """
+         super().__init__()
+         _checkpoint_path = UPath(checkpoint_path)
+         self.forward_kwargs = forward_kwargs
+         self.embedding_size = embedding_size
+         self.patch_size = patch_size
+
+         if autocast_dtype is not None:
+             self.autocast_dtype = AUTOCAST_DTYPE_MAP[autocast_dtype]
+         else:
+             self.autocast_dtype = None
+
+         # Load the model config and initialize it.
+         # We avoid loading the train module here because it depends on running
+         # within olmo_core.
+         with (_checkpoint_path / "config.json").open() as f:
+             config_dict = json.load(f)
+         model_config = Config.from_dict(config_dict["model"])
+
+         model = model_config.build()
+
+         # Load the checkpoint.
+         if not random_initialization:
+             train_module_dir = _checkpoint_path / "model_and_optim"
+             if train_module_dir.exists():
+                 load_model_and_optim_state(str(train_module_dir), model)
+                 logger.info(f"loaded OlmoEarth encoder from {train_module_dir}")
+             else:
+                 logger.info(f"could not find OlmoEarth encoder at {train_module_dir}")
+         else:
+             logger.info("skipping loading OlmoEarth encoder")
+
+         # Select just the portion of the model that we actually want to use.
+         for part in selector:
+             if isinstance(part, str):
+                 model = getattr(model, part)
+             else:
+                 model = model[part]
+         self.model = model
+
+     def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+         """Compute feature maps from the OlmoEarth backbone.
+
+         Args:
+             inputs: input dicts. Each should include keys corresponding to the
+                 modalities that should be passed to the OlmoEarth model.
+         """
+         kwargs = {}
+         present_modalities = []
+         device = None
+         # Handle the case where some modalities are multitemporal and some are not.
+         # We assume all multitemporal modalities have the same number of timesteps.
+         max_timesteps = 1
+         for modality in MODALITY_NAMES:
+             if modality not in inputs[0]:
+                 continue
+             present_modalities.append(modality)
+             cur = torch.stack([inp[modality] for inp in inputs], dim=0)
+             device = cur.device
+             # Check whether it is single- or multi-temporal, and reshape accordingly.
+             num_bands = Modality.get(modality).num_bands
+             num_timesteps = cur.shape[1] // num_bands
+             max_timesteps = max(max_timesteps, num_timesteps)
+             cur = rearrange(cur, "b (t c) h w -> b h w t c", t=num_timesteps)
+             kwargs[modality] = cur
+             # Create the mask array, which is BHWTS (without channels but with band sets).
+             num_band_sets = len(Modality.get(modality).band_sets)
+             mask_shape = cur.shape[0:4] + (num_band_sets,)
+             mask = (
+                 torch.ones(mask_shape, dtype=torch.int32, device=device)
+                 * MaskValue.ONLINE_ENCODER.value
+             )
+             kwargs[f"{modality}_mask"] = mask
+
+         # Timestamps are required.
+         # Note that only months (0 to 11) are used in the OlmoEarth position encoding.
+         # For now, we assign the same timestamps to all inputs, but later we should handle varying timestamps per input.
+         timestamps = torch.zeros(
+             (len(inputs), max_timesteps, 3), dtype=torch.int32, device=device
+         )
+         timestamps[:, :, 0] = 1  # day
+         timestamps[:, :, 1] = torch.arange(max_timesteps, device=device)[
+             None, :
+         ]  # month
+         timestamps[:, :, 2] = 2024  # year
+         kwargs["timestamps"] = timestamps
+
+         sample = MaskedOlmoEarthSample(**kwargs)
+
+         # Decide the autocast context based on self.autocast_dtype.
+         if self.autocast_dtype is None:
+             context = nullcontext()
+         else:
+             assert device is not None
+             context = torch.amp.autocast(
+                 device_type=device.type, dtype=self.autocast_dtype
+             )
+
+         with context:
+             # Currently we assume the provided model always returns a TokensAndMasks object.
+             tokens_and_masks: TokensAndMasks
+             if isinstance(self.model, Encoder):
+                 # Encoder has a fast_pass argument to indicate that the mask is not needed.
+                 tokens_and_masks = self.model(
+                     sample, fast_pass=True, **self.forward_kwargs
+                 )["tokens_and_masks"]
+             else:
+                 # Other models like STEncoder do not support this option.
+                 tokens_and_masks = self.model(sample, **self.forward_kwargs)[
+                     "tokens_and_masks"
+                 ]
+
+         # Apply temporal/modality pooling so we have just one feature per patch.
+         features = []
+         for modality in present_modalities:
+             modality_features = getattr(tokens_and_masks, modality)
+             # Pool over band sets and timesteps (BHWTSC -> BHWC).
+             pooled = modality_features.mean(dim=[3, 4])
+             # We want BHWC -> BCHW.
+             pooled = rearrange(pooled, "b h w c -> b c h w")
+             features.append(pooled)
+         # Pool over the modalities, so we get one BCHW feature map.
+         pooled = torch.stack(features, dim=0).mean(dim=0)
+         return [pooled]
+
+     def get_backbone_channels(self) -> list:
+         """Returns the output channels of this model when used as a backbone.
+
+         The output channels are a list of (downsample_factor, depth) tuples that
+         correspond to the feature maps that the backbone returns. For example, an
+         element [2, 32] indicates that the corresponding feature map is at 1/2 the
+         input resolution and has 32 channels.
+
+         Returns:
+             the output channels of the backbone as a list of (downsample_factor,
+             depth) tuples.
+         """
+         return [(self.patch_size, self.embedding_size)]
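
To make the new wrapper concrete, here is a minimal usage sketch. It is not taken from the package: the checkpoint path, the "encoder" selector, the embedding/patch sizes, and the 12-band Sentinel-2 shape are all illustrative assumptions, and running it requires the olmoearth_pretrain and olmo_core dependencies plus a real checkpoint directory.

```python
# Sketch only: paths, selector, band count, and shapes below are hypothetical.
import torch

from rslearn.models.olmoearth_pretrain.model import OlmoEarth

model = OlmoEarth(
    checkpoint_path="/path/to/olmoearth_checkpoint",  # hypothetical checkpoint dir
    selector=["encoder"],  # hypothetical attribute on the built model
    embedding_size=768,  # reported via get_backbone_channels()
    patch_size=8,
)
model.eval()

# One dict per batch element; multitemporal inputs are stacked along the
# channel dim as (T*C, H, W). Here: 4 timesteps, 12 bands (assumed), 64x64.
inputs = [{"sentinel2_l2a": torch.randn(4 * 12, 64, 64)}]
with torch.no_grad():
    features = model(inputs)  # list containing one BCHW feature map
```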
+++ rslearn/models/olmoearth_pretrain/norm.py
@@ -0,0 +1,84 @@
+ """Normalization transforms."""
+
+ import json
+ from typing import Any
+
+ from olmoearth_pretrain.data.normalize import load_computed_config
+
+ from rslearn.log_utils import get_logger
+ from rslearn.train.transforms.transform import Transform
+
+ logger = get_logger(__file__)
+
+
+ class OlmoEarthNormalize(Transform):
+     """Normalize using the OlmoEarth JSON config.
+
+     For Sentinel-1 data, the values should be converted to decibels before being
+     passed to this transform.
+     """
+
+     def __init__(
+         self,
+         band_names: dict[str, list[str]],
+         std_multiplier: float | None = 2,
+         config_fname: str | None = None,
+     ) -> None:
+         """Initialize a new OlmoEarthNormalize.
+
+         Args:
+             band_names: map from modality name to the list of bands in that modality
+                 in the order they are being loaded. Note that this order must match
+                 the expected order for the OlmoEarth model.
+             std_multiplier: the std multiplier matching the one used for model
+                 training in OlmoEarth.
+             config_fname: load the normalization configuration from this file,
+                 instead of getting it from OlmoEarth.
+         """
+         super().__init__()
+         self.band_names = band_names
+         self.std_multiplier = std_multiplier
+
+         if config_fname is None:
+             self.norm_config = load_computed_config()
+         else:
+             logger.warning(
+                 f"Loading normalization config from {config_fname}. This argument is deprecated and will be removed in a future version."
+             )
+             with open(config_fname) as f:
+                 self.norm_config = json.load(f)
+
+     def forward(
+         self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+     ) -> tuple[dict[str, Any], dict[str, Any]]:
+         """Apply normalization over the inputs and targets.
+
+         Args:
+             input_dict: the input
+             target_dict: the target
+
+         Returns:
+             normalized (input_dict, target_dict) tuple
+         """
+         for modality_name, cur_band_names in self.band_names.items():
+             band_norms = self.norm_config[modality_name]
+             image = input_dict[modality_name]
+             # Keep a set of indices to make sure that we normalize all of them.
+             needed_band_indices = set(range(image.shape[0]))
+             num_timesteps = image.shape[0] // len(cur_band_names)
+
+             for band, norm_dict in band_norms.items():
+                 # If multitemporal, normalize each timestep separately.
+                 for t in range(num_timesteps):
+                     band_idx = cur_band_names.index(band) + t * len(cur_band_names)
+                     min_val = norm_dict["mean"] - self.std_multiplier * norm_dict["std"]
+                     max_val = norm_dict["mean"] + self.std_multiplier * norm_dict["std"]
+                     image[band_idx] = (image[band_idx] - min_val) / (max_val - min_val)
+                     needed_band_indices.remove(band_idx)
+
+             if len(needed_band_indices) > 0:
+                 raise ValueError(
+                     f"for modality {modality_name}, bands {needed_band_indices} were unexpectedly not normalized"
+                 )
+
+         return input_dict, target_dict
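
The normalization itself is a linear rescaling per band: a value at mean - std_multiplier * std maps to 0 and a value at mean + std_multiplier * std maps to 1. A minimal sketch, assuming a hypothetical config file and band list (the real statistics come from load_computed_config()):

```python
# Sketch only: the band list, file name, and statistics are made up for
# illustration; real values come from the OlmoEarth computed config.
import json

import torch

from rslearn.models.olmoearth_pretrain.norm import OlmoEarthNormalize

# Hypothetical config: one entry per modality, one (mean, std) per band.
config = {"sentinel2_l2a": {"B02": {"mean": 1000.0, "std": 500.0}}}
with open("norm_config.json", "w") as f:
    json.dump(config, f)

transform = OlmoEarthNormalize(
    band_names={"sentinel2_l2a": ["B02"]},
    std_multiplier=2,
    config_fname="norm_config.json",  # deprecated; omit to use the built-in config
)

# With mean=1000, std=500, std_multiplier=2: min=0, max=2000, so a raw value
# of 1000 is mapped to (1000 - 0) / (2000 - 0) = 0.5.
input_dict = {"sentinel2_l2a": torch.full((1, 32, 32), 1000.0)}
input_dict, _ = transform.forward(input_dict, {})
```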
+++ rslearn/models/pooling_decoder.py
@@ -76,3 +76,46 @@ class PoolingDecoder(torch.nn.Module):
          features = torch.amax(features, dim=(2, 3))
          features = self.fc_layers(features)
          return self.output_layer(features)
+
+
+ class SegmentationPoolingDecoder(PoolingDecoder):
+     """Like PoolingDecoder, but copies the output to all pixels.
+
+     This allows the model to produce a global output while still being compatible
+     with SegmentationTask. This only makes sense for very small windows, since the
+     output probabilities will be the same at all pixels. The main use case is to
+     train for a classification-like task on small windows, but still produce a
+     raster during inference on large windows.
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         image_key: str = "image",
+         **kwargs: Any,
+     ):
+         """Create a new SegmentationPoolingDecoder.
+
+         Args:
+             in_channels: input channels (channels in the last feature map passed to
+                 this module)
+             out_channels: channels for the output flat feature vector
+             image_key: the key in inputs for the image from which the expected
+                 width and height are derived.
+             kwargs: other arguments to pass to PoolingDecoder.
+         """
+         super().__init__(in_channels=in_channels, out_channels=out_channels, **kwargs)
+         self.image_key = image_key
+
+     def forward(
+         self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
+     ) -> torch.Tensor:
+         """Extend PoolingDecoder forward to broadcast the output to a segmentation mask.
+
+         This only works when all of the pixels have the same segmentation target.
+         """
+         output_probs = super().forward(features, inputs)
+         # BC -> BCHW
+         h, w = inputs[0][self.image_key].shape[1:3]
+         return output_probs[:, :, None, None].repeat([1, 1, h, w])
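
The key step in SegmentationPoolingDecoder.forward is the final broadcast: a per-image output vector of shape BC is expanded to a BCHW raster whose pixels all carry the same values. A standalone sketch of just that tensor operation:

```python
# Standalone sketch of the broadcast performed by SegmentationPoolingDecoder:
# a per-image class vector (BC) is copied to every pixel to form a BCHW mask.
import torch

batch, classes, h, w = 2, 3, 4, 4
output_probs = torch.randn(batch, classes)  # global per-image logits

mask = output_probs[:, :, None, None].repeat([1, 1, h, w])
assert mask.shape == (batch, classes, h, w)
# Every pixel carries the same logits as the global prediction.
assert torch.equal(mask[:, :, 0, 0], output_probs)
```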
+++ rslearn-0.0.11.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: rslearn
- Version: 0.0.9
+ Version: 0.0.11
  Summary: A library for developing remote sensing datasets and models
  Author: OlmoEarth Team
  License: Apache License
@@ -243,6 +243,7 @@ Requires-Dist: planetary_computer>=1.0; extra == "extra"
  Requires-Dist: pycocotools>=2.0; extra == "extra"
  Requires-Dist: pystac_client>=0.9; extra == "extra"
  Requires-Dist: rtree>=1.4; extra == "extra"
+ Requires-Dist: termcolor>=3.0; extra == "extra"
  Requires-Dist: satlaspretrain_models>=0.3; extra == "extra"
  Requires-Dist: scipy>=1.16; extra == "extra"
  Requires-Dist: terratorch>=1.0.2; extra == "extra"
+++ rslearn-0.0.11.dist-info/RECORD
@@ -57,7 +57,7 @@ rslearn/models/molmo.py,sha256=mVrARBhZciMzOgOOjGB5AHlPIf2iO9IBSJmdyKSl1L8,2061
  rslearn/models/multitask.py,sha256=j2Kiwj_dUiUp_CIUr25bS8HiyeoFlr1PGqjTfpgIGLc,14672
  rslearn/models/panopticon.py,sha256=woNEs53wVc5D-NxbSDEPRZ_mYe8vllnuldmADjvhfDQ,5806
  rslearn/models/pick_features.py,sha256=y8e4tJFhyG7ZuVSElWhQ5-Aer4ZKJCEH9wLGJU7WqGI,1551
- rslearn/models/pooling_decoder.py,sha256=jZfEQCfthfa21C9sEjgFHUcfhHMVlvG7_nDMw_1FLaE,2727
+ rslearn/models/pooling_decoder.py,sha256=unr2fSE_QmJHPi3dKtopqMtb1Kn-2h94LgwwAVP9vZg,4437
  rslearn/models/prithvi.py,sha256=SVM3ypJlVTkXQ69pPhB4UeJr87VnmADTCuyV365dbkU,39961
  rslearn/models/registry.py,sha256=yCcrOvLkbn07Xtln1j7hAB_kmGw0MGsiR2TloJq9Bmk,504
  rslearn/models/resize_features.py,sha256=asKXWrLHIBrU6GaAV0Ory9YuK7IK104XjhkB4ljzI3A,1289
@@ -93,6 +93,9 @@ rslearn/models/detr/util.py,sha256=NMHhHbkIo7PoBUVbDqa2ZknJBTswmaxFCGHrPtFKnGg,6
  rslearn/models/galileo/__init__.py,sha256=QQa0C29nuPRva0KtGiMHQ2ZB02n9SSwj_wqTKPz18NM,112
  rslearn/models/galileo/galileo.py,sha256=jUHA64YvVC3Fz5fevc_9dFJfZaINODRDrhSGLIiOZcw,21115
  rslearn/models/galileo/single_file_galileo.py,sha256=l5tlmmdr2eieHNH-M7rVIvcptkv0Fuk3vKXFW691ezA,56143
+ rslearn/models/olmoearth_pretrain/__init__.py,sha256=AjRvbjBdadCdPh-EdvySH76sVAQ8NGQaJt11Tsn1D5I,36
+ rslearn/models/olmoearth_pretrain/model.py,sha256=F-B1ym9UZuTPJ0OY15Jwb1TkNtr_EtAUlqI-tr_Z2uo,8352
+ rslearn/models/olmoearth_pretrain/norm.py,sha256=rHjFyWkpNLYMx9Ow7TsU-jGm9Sjx7FVf0p4R__ohx2c,3266
  rslearn/models/panopticon_data/sensors/drone.yaml,sha256=xqWS-_QMtJyRoWXJm-igoSur9hAmCFdqkPin8DT5qpw,431
  rslearn/models/panopticon_data/sensors/enmap.yaml,sha256=b2j6bSgYR2yKR9DRm3SPIzSVYlHf51ny_p-1B4B9sB4,13431
  rslearn/models/panopticon_data/sensors/goes.yaml,sha256=o00aoWCYqam0aB1rPmXq1MKe8hsKak_qyBG7BPL27Sc,152
@@ -156,9 +159,9 @@ rslearn/utils/spatial_index.py,sha256=eomJAUgzmjir8j9HZnSgQoJHwN9H0wGTjmJkMkLLfs
  rslearn/utils/sqlite_index.py,sha256=YGOJi66544e6JNtfSft6YIlHklFdSJO2duxQ4TJ2iu4,2920
  rslearn/utils/time.py,sha256=2ilSLG94_sxLP3y5RSV5L5CG8CoND_dbdzYEHVtN-I8,387
  rslearn/utils/vector_format.py,sha256=EIChYCL6GLOILS2TO2JBkca1TuaWsSubWv6iRS3P2ds,16139
- rslearn-0.0.9.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
- rslearn-0.0.9.dist-info/METADATA,sha256=6BV8wt9tuo94FkoKjR3RcF3AbKNbU3IodkJtK4tASkE,36248
- rslearn-0.0.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- rslearn-0.0.9.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
- rslearn-0.0.9.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
- rslearn-0.0.9.dist-info/RECORD,,
+ rslearn-0.0.11.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
+ rslearn-0.0.11.dist-info/METADATA,sha256=jwB0ZZ-oLa1Y_1iuZRKCQoB4i3kOFDJ0xSeMTJP7zww,36297
+ rslearn-0.0.11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ rslearn-0.0.11.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
+ rslearn-0.0.11.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
+ rslearn-0.0.11.dist-info/RECORD,,