rslearn 0.0.9__py3-none-any.whl → 0.0.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/models/anysat.py +5 -1
- rslearn/models/dinov3.py +6 -1
- rslearn/models/feature_center_crop.py +50 -0
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +263 -0
- rslearn/models/olmoearth_pretrain/norm.py +84 -0
- rslearn/models/pooling_decoder.py +43 -0
- rslearn/models/prithvi.py +9 -1
- rslearn/train/lightning_module.py +0 -3
- rslearn/train/tasks/classification.py +2 -2
- rslearn/train/tasks/detection.py +5 -5
- rslearn/train/tasks/per_pixel_regression.py +5 -4
- rslearn/train/tasks/regression.py +5 -5
- rslearn/train/transforms/pad.py +3 -3
- {rslearn-0.0.9.dist-info → rslearn-0.0.12.dist-info}/METADATA +3 -1
- {rslearn-0.0.9.dist-info → rslearn-0.0.12.dist-info}/RECORD +21 -25
- rslearn-0.0.12.dist-info/licenses/NOTICE +115 -0
- rslearn/models/copernicusfm.py +0 -228
- rslearn/models/copernicusfm_src/__init__.py +0 -1
- rslearn/models/copernicusfm_src/aurora/area.py +0 -50
- rslearn/models/copernicusfm_src/aurora/fourier.py +0 -134
- rslearn/models/copernicusfm_src/dynamic_hypernetwork.py +0 -523
- rslearn/models/copernicusfm_src/flexivit/patch_embed.py +0 -260
- rslearn/models/copernicusfm_src/flexivit/utils.py +0 -69
- rslearn/models/copernicusfm_src/model_vit.py +0 -348
- rslearn/models/copernicusfm_src/util/pos_embed.py +0 -216
- {rslearn-0.0.9.dist-info → rslearn-0.0.12.dist-info}/WHEEL +0 -0
- {rslearn-0.0.9.dist-info → rslearn-0.0.12.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.9.dist-info → rslearn-0.0.12.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.9.dist-info → rslearn-0.0.12.dist-info}/top_level.txt +0 -0
rslearn/models/anysat.py
CHANGED
rslearn/models/dinov3.py
CHANGED
|
@@ -1,4 +1,9 @@
|
|
|
1
|
-
"""DinoV3 model.
|
|
1
|
+
"""DinoV3 model.
|
|
2
|
+
|
|
3
|
+
This code loads the DINOv3 model. You must obtain the model separately from Meta to use
|
|
4
|
+
it. See https://github.com/facebookresearch/dinov3 for applicable license and copyright
|
|
5
|
+
information.
|
|
6
|
+
"""
|
|
2
7
|
|
|
3
8
|
from enum import StrEnum
|
|
4
9
|
from pathlib import Path
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
"""Apply center cropping on a feature map."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class FeatureCenterCrop(torch.nn.Module):
|
|
9
|
+
"""Apply center cropping on the input feature maps."""
|
|
10
|
+
|
|
11
|
+
def __init__(
|
|
12
|
+
self,
|
|
13
|
+
sizes: list[tuple[int, int]],
|
|
14
|
+
) -> None:
|
|
15
|
+
"""Create a new FeatureCenterCrop.
|
|
16
|
+
|
|
17
|
+
Only the center of each feature map will be retained and passed to the next
|
|
18
|
+
module.
|
|
19
|
+
|
|
20
|
+
Args:
|
|
21
|
+
sizes: a list of (height, width) tuples, with one tuple for each input
|
|
22
|
+
feature map.
|
|
23
|
+
"""
|
|
24
|
+
super().__init__()
|
|
25
|
+
self.sizes = sizes
|
|
26
|
+
|
|
27
|
+
def forward(
|
|
28
|
+
self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
|
|
29
|
+
) -> list[torch.Tensor]:
|
|
30
|
+
"""Apply center cropping on the feature maps.
|
|
31
|
+
|
|
32
|
+
Args:
|
|
33
|
+
features: list of feature maps at different resolutions.
|
|
34
|
+
inputs: original inputs (ignored).
|
|
35
|
+
|
|
36
|
+
Returns:
|
|
37
|
+
center cropped feature maps.
|
|
38
|
+
"""
|
|
39
|
+
new_features = []
|
|
40
|
+
for i, feat in enumerate(features):
|
|
41
|
+
height, width = self.sizes[i]
|
|
42
|
+
if feat.shape[2] < height or feat.shape[3] < width:
|
|
43
|
+
raise ValueError(
|
|
44
|
+
"feature map is smaller than the desired height and width"
|
|
45
|
+
)
|
|
46
|
+
start_h = feat.shape[2] // 2 - height // 2
|
|
47
|
+
start_w = feat.shape[3] // 2 - width // 2
|
|
48
|
+
feat = feat[:, :, start_h : start_h + height, start_w : start_w + width]
|
|
49
|
+
new_features.append(feat)
|
|
50
|
+
return new_features
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""OlmoEarth model architecture."""
|
|
@@ -0,0 +1,263 @@
|
|
|
1
|
+
"""OlmoEarth model wrapper for fine-tuning in rslearn."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from contextlib import nullcontext
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from einops import rearrange
|
|
9
|
+
from olmo_core.config import Config
|
|
10
|
+
from olmo_core.distributed.checkpoint import load_model_and_optim_state
|
|
11
|
+
from olmoearth_pretrain.data.constants import Modality
|
|
12
|
+
from olmoearth_pretrain.model_loader import (
|
|
13
|
+
ModelID,
|
|
14
|
+
load_model_from_id,
|
|
15
|
+
load_model_from_path,
|
|
16
|
+
)
|
|
17
|
+
from olmoearth_pretrain.nn.flexihelios import Encoder, TokensAndMasks
|
|
18
|
+
from olmoearth_pretrain.train.masking import MaskedOlmoEarthSample, MaskValue
|
|
19
|
+
from upath import UPath
|
|
20
|
+
|
|
21
|
+
from rslearn.log_utils import get_logger
|
|
22
|
+
|
|
23
|
+
logger = get_logger(__name__)
|
|
24
|
+
|
|
25
|
+
# Input-dict keys that OlmoEarth.forward looks for. A modality is only used when
# its key is present in the first input dict of the batch.
MODALITY_NAMES = [
    "sentinel2_l2a",
    "sentinel1",
    "worldcover",
    "openstreetmap_raster",
    "landsat",
]

# Maps the autocast_dtype constructor argument (a string) to the torch dtype
# that is passed to torch.amp.autocast.
AUTOCAST_DTYPE_MAP = {
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
    "float32": torch.float32,
}

# Embedding size reported via get_backbone_channels for each known model ID,
# used when the caller does not pass embedding_size explicitly.
EMBEDDING_SIZES = {
    ModelID.OLMOEARTH_V1_NANO: 128,
    ModelID.OLMOEARTH_V1_TINY: 192,
    ModelID.OLMOEARTH_V1_BASE: 768,
}
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class OlmoEarth(torch.nn.Module):
    """A wrapper to support the OlmoEarth model."""

    def __init__(
        self,
        patch_size: int,
        model_id: ModelID | None = None,
        model_path: str | None = None,
        checkpoint_path: str | None = None,
        selector: list[str | int] = ["encoder"],
        forward_kwargs: dict[str, Any] = {},
        random_initialization: bool = False,
        embedding_size: int | None = None,
        autocast_dtype: str | None = "bfloat16",
    ):
        """Create a new OlmoEarth model.

        Args:
            patch_size: token spatial patch size to use.
            model_id: the model ID to load. Exactly one of model_id, model_path, or
                checkpoint_path must be set.
            model_path: the path to load the model from. Exactly one of model_id,
                model_path, or checkpoint_path must be set. Same structure as the
                HF-hosted `model_id` models: bundle with a config.json and
                weights.pth.
            checkpoint_path: the checkpoint directory to load from, if model_id or
                model_path is not set. It should contain a distributed checkpoint
                with a config.json file as well as model_and_optim folder.
            selector: an optional sequence of attribute names or list indices to
                select the sub-module that should be applied on the input images.
                Defaults to ["encoder"] to select only the transformer encoder.
            forward_kwargs: additional arguments to pass to forward pass besides the
                MaskedOlmoEarthSample.
            random_initialization: whether to skip loading the checkpoint so the
                weights are randomly initialized. In this case, the checkpoint is
                only used to define the model architecture.
            embedding_size: optional embedding size to report via
                get_backbone_channels (if model_id is not set).
            autocast_dtype: which dtype to use for autocasting ("bfloat16",
                "float16", or "float32"), or set None to disable.

        Raises:
            ValueError: if zero or more than one of model_id, model_path, and
                checkpoint_path is set.
        """
        num_sources = sum(
            source is not None for source in (model_id, model_path, checkpoint_path)
        )
        if num_sources != 1:
            raise ValueError(
                "exactly one of model_id, model_path, or checkpoint_path must be set"
            )

        super().__init__()
        self.patch_size = patch_size
        # Copy so this instance never aliases the shared default dict (or a dict
        # that the caller keeps mutating).
        self.forward_kwargs = dict(forward_kwargs)
        self.embedding_size = embedding_size

        if autocast_dtype is not None:
            self.autocast_dtype = AUTOCAST_DTYPE_MAP[autocast_dtype]
        else:
            self.autocast_dtype = None

        if model_id is not None:
            # Load from Hugging Face.
            model = load_model_from_id(model_id, load_weights=not random_initialization)
            if self.embedding_size is None and model_id in EMBEDDING_SIZES:
                self.embedding_size = EMBEDDING_SIZES[model_id]

        elif model_path is not None:
            # Load from path.
            model = load_model_from_path(
                UPath(model_path), load_weights=not random_initialization
            )

        else:
            # Load the distributed model checkpoint by path through Olmo Core
            model = self._load_model_from_checkpoint(
                UPath(checkpoint_path), random_initialization
            )

        # Select just the portion of the model that we actually want to use.
        for part in selector:
            if isinstance(part, str):
                model = getattr(model, part)
            else:
                model = model[part]
        self.model = model

    def _load_model_from_checkpoint(
        self, checkpoint_upath: UPath, random_initialization: bool
    ) -> torch.nn.Module:
        """Load the OlmoEarth pre-trained model from a distributed checkpoint folder.

        The folder should contain config.json as well as the model_and_optim folder
        that contains the distributed checkpoint. This is the format produced by
        pre-training runs in olmoearth_pretrain.

        Args:
            checkpoint_upath: the checkpoint folder.
            random_initialization: if True, only build the model from config.json
                and do not load any weights.

        Returns:
            the model built from the "model" section of config.json, with the
            checkpoint state loaded when it was requested and found.
        """
        # Load the model config and initialize it.
        # We avoid loading the train module here because it depends on running within
        # olmo_core.
        with (checkpoint_upath / "config.json").open() as f:
            config_dict = json.load(f)
            model_config = Config.from_dict(config_dict["model"])

        model = model_config.build()

        # Load the checkpoint.
        if not random_initialization:
            train_module_dir = checkpoint_upath / "model_and_optim"
            if train_module_dir.exists():
                load_model_and_optim_state(str(train_module_dir), model)
                logger.info(f"loaded OlmoEarth encoder from {train_module_dir}")
            else:
                # Pretrained weights were requested but are missing, so the model
                # keeps whatever weights build() produced. Surface this as a
                # warning so that it is not lost among info-level messages.
                logger.warning(
                    f"could not find OlmoEarth encoder at {train_module_dir}"
                )

        return model

    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
        """Compute feature maps from the OlmoEarth backbone.

        Args:
            inputs: input dicts. It should include keys corresponding to the
                modalities that should be passed to the OlmoEarth model. Which
                modalities are used is decided from the first input dict.

        Returns:
            a list containing a single BCHW feature map, averaged over timesteps,
            band sets, and the modalities that were present.

        Raises:
            ValueError: if none of the supported modalities is present in the
                inputs.
        """
        kwargs = {}
        present_modalities = []
        device = None
        # Handle the case where some modalities are multitemporal and some are not.
        # We assume all multitemporal modalities have the same number of timesteps.
        max_timesteps = 1
        for modality in MODALITY_NAMES:
            if modality not in inputs[0]:
                continue
            present_modalities.append(modality)
            cur = torch.stack([inp[modality] for inp in inputs], dim=0)
            device = cur.device
            modality_spec = Modality.get(modality)
            # Check if it's single or multitemporal, and reshape accordingly
            num_bands = modality_spec.num_bands
            num_timesteps = cur.shape[1] // num_bands
            max_timesteps = max(max_timesteps, num_timesteps)
            cur = rearrange(cur, "b (t c) h w -> b h w t c", t=num_timesteps)
            kwargs[modality] = cur
            # Create mask array which is BHWTS (without channels but with band sets).
            num_band_sets = len(modality_spec.band_sets)
            mask_shape = cur.shape[0:4] + (num_band_sets,)
            mask = (
                torch.ones(mask_shape, dtype=torch.int32, device=device)
                * MaskValue.ONLINE_ENCODER.value
            )
            kwargs[f"{modality}_mask"] = mask

        # Without any modality there is nothing to encode; fail early with a clear
        # message instead of an assertion or an error from torch.stack below.
        if not present_modalities:
            raise ValueError(
                f"inputs must contain at least one of the modalities {MODALITY_NAMES}"
            )

        # Timestamps is required.
        # Note that only months (0 to 11) are used in OlmoEarth position encoding.
        # For now, we assign same timestamps to all inputs, but later we should handle varying timestamps per input.
        # NOTE(review): the month index is simply the timestep index, so it exceeds
        # 11 when there are more than 12 timesteps; confirm that the encoder
        # tolerates this.
        timestamps = torch.zeros(
            (len(inputs), max_timesteps, 3), dtype=torch.int32, device=device
        )
        timestamps[:, :, 0] = 1  # day
        timestamps[:, :, 1] = torch.arange(max_timesteps, device=device)[
            None, :
        ]  # month
        timestamps[:, :, 2] = 2024  # year
        kwargs["timestamps"] = timestamps

        sample = MaskedOlmoEarthSample(**kwargs)

        # Decide context based on self.autocast_dtype.
        if self.autocast_dtype is None:
            context = nullcontext()
        else:
            # Guaranteed by the present_modalities check above; kept so that type
            # checkers can narrow device.
            assert device is not None
            context = torch.amp.autocast(
                device_type=device.type, dtype=self.autocast_dtype
            )

        with context:
            # Currently we assume the provided model always returns a TokensAndMasks object.
            tokens_and_masks: TokensAndMasks
            if isinstance(self.model, Encoder):
                # Encoder has a fast_pass argument to indicate mask is not needed.
                tokens_and_masks = self.model(
                    sample,
                    fast_pass=True,
                    patch_size=self.patch_size,
                    **self.forward_kwargs,
                )["tokens_and_masks"]
            else:
                # Other models like STEncoder do not have this option supported.
                tokens_and_masks = self.model(
                    sample, patch_size=self.patch_size, **self.forward_kwargs
                )["tokens_and_masks"]

        # Apply temporal/modality pooling so we just have one feature per patch.
        features = []
        for modality in present_modalities:
            modality_features = getattr(tokens_and_masks, modality)
            # Pool over band sets and timesteps (BHWTSC -> BHWC).
            pooled = modality_features.mean(dim=[3, 4])
            # We want BHWC -> BCHW.
            pooled = rearrange(pooled, "b h w c -> b c h w")
            features.append(pooled)
        # Pool over the modalities, so we get one BCHW feature map.
        pooled = torch.stack(features, dim=0).mean(dim=0)
        return [pooled]

    def get_backbone_channels(self) -> list:
        """Returns the output channels of this model when used as a backbone.

        The output channels is a list of (downsample_factor, depth) that corresponds
        to the feature maps that the backbone returns. For example, an element [2, 32]
        indicates that the corresponding feature map is 1/2 the input resolution and
        has 32 channels.

        Returns:
            the output channels of the backbone as a list of (downsample_factor, depth)
            tuples. The depth is None if embedding_size was neither passed to the
            constructor nor derived from a known model_id.
        """
        return [(self.patch_size, self.embedding_size)]
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
"""Normalization transforms."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from olmoearth_pretrain.data.normalize import load_computed_config
|
|
7
|
+
|
|
8
|
+
from rslearn.log_utils import get_logger
|
|
9
|
+
from rslearn.train.transforms.transform import Transform
|
|
10
|
+
|
|
11
|
+
# Name the logger after the module (not the file path) so that it takes part in
# the package's logger hierarchy.
logger = get_logger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class OlmoEarthNormalize(Transform):
    """Normalize using OlmoEarth JSON config.

    For Sentinel-1 data, the values should be converted to decibels before being passed
    to this transform.
    """

    def __init__(
        self,
        band_names: dict[str, list[str]],
        std_multiplier: float | None = 2,
        config_fname: str | None = None,
    ) -> None:
        """Initialize a new OlmoEarthNormalize.

        Args:
            band_names: map from modality name to the list of bands in that modality in
                the order they are being loaded. Note that this order must match the
                expected order for the OlmoEarth model.
            std_multiplier: the std multiplier matching the one used for the model
                training in OlmoEarth.
                NOTE(review): None is accepted by the signature, but forward
                multiplies by this value, so None fails there; confirm the
                intended meaning of None.
            config_fname: load the normalization configuration from this file, instead
                of getting it from OlmoEarth.
        """
        super().__init__()
        self.band_names = band_names
        self.std_multiplier = std_multiplier

        if config_fname is None:
            self.norm_config = load_computed_config()
        else:
            logger.warning(
                f"Loading normalization config from {config_fname}. This argument is deprecated and will be removed in a future version."
            )
            with open(config_fname, encoding="utf-8") as f:
                self.norm_config = json.load(f)

    def forward(
        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Apply normalization over the inputs and targets.

        Each band is rescaled so that mean - std_multiplier * std maps to 0 and
        mean + std_multiplier * std maps to 1. The images in input_dict are
        modified in place.

        Args:
            input_dict: the input
            target_dict: the target

        Returns:
            normalized (input_dicts, target_dicts) tuple

        Raises:
            ValueError: if the normalization config lists a band that is missing
                from band_names, or if some channels of an image were not
                normalized.
        """
        for modality_name, cur_band_names in self.band_names.items():
            band_norms = self.norm_config[modality_name]
            image = input_dict[modality_name]
            num_bands = len(cur_band_names)
            # Keep a set of indices to make sure that we normalize all of them.
            needed_band_indices = set(range(image.shape[0]))
            num_timesteps = image.shape[0] // num_bands

            for band, norm_dict in band_norms.items():
                # The band position and the value range are the same for every
                # timestep, so compute them once per band.
                try:
                    band_offset = cur_band_names.index(band)
                except ValueError as e:
                    raise ValueError(
                        f"for modality {modality_name}, band {band} is in the "
                        "normalization config but not in band_names"
                    ) from e
                min_val = norm_dict["mean"] - self.std_multiplier * norm_dict["std"]
                max_val = norm_dict["mean"] + self.std_multiplier * norm_dict["std"]

                # If multitemporal, normalize each timestep separately.
                for t in range(num_timesteps):
                    band_idx = band_offset + t * num_bands
                    image[band_idx] = (image[band_idx] - min_val) / (max_val - min_val)
                    needed_band_indices.remove(band_idx)

            if len(needed_band_indices) > 0:
                raise ValueError(
                    f"for modality {modality_name}, bands {needed_band_indices} were unexpectedly not normalized"
                )

        return input_dict, target_dict
|
|
@@ -76,3 +76,46 @@ class PoolingDecoder(torch.nn.Module):
|
|
|
76
76
|
features = torch.amax(features, dim=(2, 3))
|
|
77
77
|
features = self.fc_layers(features)
|
|
78
78
|
return self.output_layer(features)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class SegmentationPoolingDecoder(PoolingDecoder):
    """PoolingDecoder variant that broadcasts its output to every pixel.

    The decoder still computes one global output per example, but returns it
    repeated over a height x width grid so that it can be used together with
    SegmentationTask. Since every pixel receives the same values, this only makes
    sense for very small windows. The main use case is to train for a
    classification-like task on small windows, but still produce a raster during
    inference on large windows.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        image_key: str = "image",
        **kwargs: Any,
    ):
        """Create a new SegmentationPoolingDecoder.

        Args:
            in_channels: input channels (channels in the last feature map passed to
                this module)
            out_channels: channels for the output flat feature vector
            image_key: the key in inputs for the image from which the expected width
                and height is derived.
            kwargs: other arguments to pass to PoolingDecoder.
        """
        super().__init__(in_channels=in_channels, out_channels=out_channels, **kwargs)
        self.image_key = image_key

    def forward(
        self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
    ) -> torch.Tensor:
        """Run PoolingDecoder and tile its output into a segmentation-shaped tensor.

        This only works when all of the pixels have the same segmentation target.

        Args:
            features: feature maps, passed through to PoolingDecoder.forward.
            inputs: original inputs. The image under image_key in the first input
                determines the output height and width (its dimensions 1 and 2).

        Returns:
            the PoolingDecoder output expanded from BC to BCHW.
        """
        pooled_output = super().forward(features, inputs)
        # The spatial size is taken from the first example only.
        reference_image = inputs[0][self.image_key]
        height, width = reference_image.shape[1:3]
        # BC -> BC11 -> BCHW
        per_pixel = pooled_output[:, :, None, None]
        return per_pixel.repeat([1, 1, height, width])
|
rslearn/models/prithvi.py
CHANGED
|
@@ -1,4 +1,12 @@
|
|
|
1
|
-
"""Prithvi V2.
|
|
1
|
+
"""Prithvi V2.
|
|
2
|
+
|
|
3
|
+
This code is adapted from https://github.com/NASA-IMPACT/Prithvi-WxC
|
|
4
|
+
|
|
5
|
+
The code is released under:
|
|
6
|
+
|
|
7
|
+
MIT License
|
|
8
|
+
Copyright (c) 2024 Inter Agency Implementation and Advanced Concepts
|
|
9
|
+
"""
|
|
2
10
|
|
|
3
11
|
import json
|
|
4
12
|
import logging
|
|
@@ -94,7 +94,6 @@ class RslearnLightningModule(L.LightningModule):
|
|
|
94
94
|
restore_config: RestoreConfig | None = None,
|
|
95
95
|
print_parameters: bool = False,
|
|
96
96
|
print_model: bool = False,
|
|
97
|
-
strict_loading: bool = True,
|
|
98
97
|
# Deprecated options.
|
|
99
98
|
lr: float = 1e-3,
|
|
100
99
|
plateau: bool = False,
|
|
@@ -118,7 +117,6 @@ class RslearnLightningModule(L.LightningModule):
|
|
|
118
117
|
print_parameters: whether to print the list of model parameters after model
|
|
119
118
|
initialization
|
|
120
119
|
print_model: whether to print the model after model initialization
|
|
121
|
-
strict_loading: whether to strictly load the model parameters.
|
|
122
120
|
lr: deprecated.
|
|
123
121
|
plateau: deprecated.
|
|
124
122
|
plateau_factor: deprecated.
|
|
@@ -132,7 +130,6 @@ class RslearnLightningModule(L.LightningModule):
|
|
|
132
130
|
self.visualize_dir = visualize_dir
|
|
133
131
|
self.metrics_file = metrics_file
|
|
134
132
|
self.restore_config = restore_config
|
|
135
|
-
self.strict_loading = strict_loading
|
|
136
133
|
|
|
137
134
|
self.scheduler_factory: SchedulerFactory | None = None
|
|
138
135
|
if scheduler:
|
|
@@ -49,8 +49,8 @@ class ClassificationTask(BasicTask):
|
|
|
49
49
|
features with matching properties.
|
|
50
50
|
read_class_id: whether to read an integer class ID instead of the class
|
|
51
51
|
name.
|
|
52
|
-
allow_invalid: instead of throwing error when no
|
|
53
|
-
at a window, simply mark the example invalid for this task
|
|
52
|
+
allow_invalid: instead of throwing error when no classification label is
|
|
53
|
+
found at a window, simply mark the example invalid for this task
|
|
54
54
|
skip_unknown_categories: whether to skip examples with categories that are
|
|
55
55
|
not passed via classes, instead of throwing error
|
|
56
56
|
prob_property: when predicting, write probabilities in addition to class ID
|
rslearn/train/tasks/detection.py
CHANGED
|
@@ -72,11 +72,11 @@ class DetectionTask(BasicTask):
|
|
|
72
72
|
f1_metric_kwargs: dict[str, Any] = {},
|
|
73
73
|
**kwargs: Any,
|
|
74
74
|
) -> None:
|
|
75
|
-
"""Initialize a new
|
|
75
|
+
"""Initialize a new DetectionTask.
|
|
76
76
|
|
|
77
77
|
Args:
|
|
78
|
-
property_name: the property from which to extract the class name.
|
|
79
|
-
|
|
78
|
+
property_name: the property from which to extract the class name. Features
|
|
79
|
+
without this property name are ignored.
|
|
80
80
|
classes: a list of class names.
|
|
81
81
|
filters: optional list of (property_name, property_value) to only consider
|
|
82
82
|
features with matching properties.
|
|
@@ -86,8 +86,8 @@ class DetectionTask(BasicTask):
|
|
|
86
86
|
not passed via classes, instead of throwing error
|
|
87
87
|
skip_empty_examples: whether to skip examples with zero labels.
|
|
88
88
|
colors: optional colors for each class
|
|
89
|
-
box_size: force all boxes to be this size, centered at the
|
|
90
|
-
geometry. Required for Point geometries.
|
|
89
|
+
box_size: force all boxes to be two times this size, centered at the
|
|
90
|
+
centroid of the geometry. Required for Point geometries.
|
|
91
91
|
clip_boxes: whether to clip boxes to the image bounds.
|
|
92
92
|
exclude_by_center: before optionally clipping boxes, exclude boxes if the
|
|
93
93
|
center is outside the image bounds.
|
|
@@ -26,10 +26,11 @@ class PerPixelRegressionTask(BasicTask):
|
|
|
26
26
|
"""Initialize a new PerPixelRegressionTask.
|
|
27
27
|
|
|
28
28
|
Args:
|
|
29
|
-
scale_factor: multiply
|
|
29
|
+
scale_factor: multiply ground truth values by this factor before using it for
|
|
30
30
|
training.
|
|
31
|
-
metric_mode: what metric to use, either mse or l1
|
|
32
|
-
nodata_value: optional value to treat as invalid
|
|
31
|
+
metric_mode: what metric to use, either "mse" (default) or "l1"
|
|
32
|
+
nodata_value: optional value to treat as invalid. The loss will be masked
|
|
33
|
+
at pixels where the ground truth value is equal to nodata_value.
|
|
33
34
|
kwargs: other arguments to pass to BasicTask
|
|
34
35
|
"""
|
|
35
36
|
super().__init__(**kwargs)
|
|
@@ -141,7 +142,7 @@ class PerPixelRegressionHead(torch.nn.Module):
|
|
|
141
142
|
"""Initialize a new RegressionHead.
|
|
142
143
|
|
|
143
144
|
Args:
|
|
144
|
-
loss_mode: the loss function to use, either "mse" or "l1".
|
|
145
|
+
loss_mode: the loss function to use, either "mse" (default) or "l1".
|
|
145
146
|
use_sigmoid: whether to apply a sigmoid activation on the output. This
|
|
146
147
|
requires targets to be between 0-1.
|
|
147
148
|
"""
|
|
@@ -33,14 +33,14 @@ class RegressionTask(BasicTask):
|
|
|
33
33
|
"""Initialize a new RegressionTask.
|
|
34
34
|
|
|
35
35
|
Args:
|
|
36
|
-
property_name: the property from which to extract the
|
|
37
|
-
value is read from the first matching feature.
|
|
36
|
+
property_name: the property from which to extract the ground truth
|
|
37
|
+
regression value. The value is read from the first matching feature.
|
|
38
38
|
filters: optional list of (property_name, property_value) to only consider
|
|
39
39
|
features with matching properties.
|
|
40
40
|
allow_invalid: instead of throwing error when no regression label is found
|
|
41
41
|
at a window, simply mark the example invalid for this task
|
|
42
|
-
scale_factor: multiply the label value by this factor
|
|
43
|
-
metric_mode: what metric to use, either mse or l1
|
|
42
|
+
scale_factor: multiply the label value by this factor for training
|
|
43
|
+
metric_mode: what metric to use, either "mse" (default) or "l1"
|
|
44
44
|
use_accuracy_metric: include metric that reports percentage of
|
|
45
45
|
examples where output is within a factor of the ground truth.
|
|
46
46
|
within_factor: the factor for accuracy metric. If it's 0.2, and ground
|
|
@@ -189,7 +189,7 @@ class RegressionHead(torch.nn.Module):
|
|
|
189
189
|
"""Initialize a new RegressionHead.
|
|
190
190
|
|
|
191
191
|
Args:
|
|
192
|
-
loss_mode: the loss function to use, either "mse" or "l1".
|
|
192
|
+
loss_mode: the loss function to use, either "mse" (default) or "l1".
|
|
193
193
|
use_sigmoid: whether to apply a sigmoid activation on the output. This
|
|
194
194
|
requires targets to be between 0-1.
|
|
195
195
|
"""
|
rslearn/train/transforms/pad.py
CHANGED
|
@@ -25,8 +25,8 @@ class Pad(Transform):
|
|
|
25
25
|
Args:
|
|
26
26
|
size: the size to pad to, or a min/max range of pad sizes. If the image is
|
|
27
27
|
larger than this size, then it is cropped instead.
|
|
28
|
-
mode: "
|
|
29
|
-
"
|
|
28
|
+
mode: "topleft" (default) to only apply padding on the bottom and right
|
|
29
|
+
sides, or "center" to apply padding equally on all sides.
|
|
30
30
|
image_selectors: image items to transform.
|
|
31
31
|
box_selectors: boxes items to transform.
|
|
32
32
|
"""
|
|
@@ -64,7 +64,7 @@ class Pad(Transform):
|
|
|
64
64
|
) -> torch.Tensor:
|
|
65
65
|
# Before/after must either be both non-negative or both negative.
|
|
66
66
|
# >=0 indicates padding while <0 indicates cropping.
|
|
67
|
-
assert (before < 0 and after
|
|
67
|
+
assert (before < 0 and after <= 0) or (before >= 0 and after >= 0)
|
|
68
68
|
if before > 0:
|
|
69
69
|
# Padding.
|
|
70
70
|
if horizontal:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: rslearn
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.12
|
|
4
4
|
Summary: A library for developing remote sensing datasets and models
|
|
5
5
|
Author: OlmoEarth Team
|
|
6
6
|
License: Apache License
|
|
@@ -211,6 +211,7 @@ Project-URL: repository, https://github.com/allenai/rslearn
|
|
|
211
211
|
Requires-Python: >=3.11
|
|
212
212
|
Description-Content-Type: text/markdown
|
|
213
213
|
License-File: LICENSE
|
|
214
|
+
License-File: NOTICE
|
|
214
215
|
Requires-Dist: boto3>=1.39
|
|
215
216
|
Requires-Dist: fiona>=1.10
|
|
216
217
|
Requires-Dist: fsspec>=2025.9.0
|
|
@@ -243,6 +244,7 @@ Requires-Dist: planetary_computer>=1.0; extra == "extra"
|
|
|
243
244
|
Requires-Dist: pycocotools>=2.0; extra == "extra"
|
|
244
245
|
Requires-Dist: pystac_client>=0.9; extra == "extra"
|
|
245
246
|
Requires-Dist: rtree>=1.4; extra == "extra"
|
|
247
|
+
Requires-Dist: termcolor>=3.0; extra == "extra"
|
|
246
248
|
Requires-Dist: satlaspretrain_models>=0.3; extra == "extra"
|
|
247
249
|
Requires-Dist: scipy>=1.16; extra == "extra"
|
|
248
250
|
Requires-Dist: terratorch>=1.0.2; extra == "extra"
|