rslearn 0.0.16__py3-none-any.whl → 0.0.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/__init__.py +2 -0
- rslearn/config/dataset.py +55 -4
- rslearn/dataset/add_windows.py +1 -1
- rslearn/dataset/dataset.py +9 -65
- rslearn/dataset/materialize.py +5 -5
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +26 -80
- rslearn/lightning_cli.py +10 -3
- rslearn/main.py +11 -36
- rslearn/models/anysat.py +11 -9
- rslearn/models/clay/clay.py +8 -9
- rslearn/models/clip.py +18 -15
- rslearn/models/component.py +99 -0
- rslearn/models/concatenate_features.py +21 -11
- rslearn/models/conv.py +15 -8
- rslearn/models/croma.py +13 -8
- rslearn/models/detr/detr.py +25 -14
- rslearn/models/dinov3.py +11 -6
- rslearn/models/faster_rcnn.py +19 -9
- rslearn/models/feature_center_crop.py +12 -9
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/galileo.py +23 -18
- rslearn/models/module_wrapper.py +26 -57
- rslearn/models/molmo.py +16 -14
- rslearn/models/multitask.py +102 -73
- rslearn/models/olmoearth_pretrain/model.py +20 -17
- rslearn/models/panopticon.py +8 -7
- rslearn/models/pick_features.py +18 -24
- rslearn/models/pooling_decoder.py +22 -14
- rslearn/models/presto/presto.py +16 -10
- rslearn/models/presto/single_file_presto.py +4 -10
- rslearn/models/prithvi.py +12 -8
- rslearn/models/resize_features.py +21 -7
- rslearn/models/sam2_enc.py +11 -9
- rslearn/models/satlaspretrain.py +15 -9
- rslearn/models/simple_time_series.py +31 -17
- rslearn/models/singletask.py +24 -17
- rslearn/models/ssl4eo_s12.py +15 -10
- rslearn/models/swin.py +22 -13
- rslearn/models/terramind.py +24 -7
- rslearn/models/trunk.py +6 -3
- rslearn/models/unet.py +18 -9
- rslearn/models/upsample.py +22 -9
- rslearn/train/all_patches_dataset.py +22 -18
- rslearn/train/dataset.py +69 -54
- rslearn/train/lightning_module.py +51 -32
- rslearn/train/model_context.py +54 -0
- rslearn/train/prediction_writer.py +111 -41
- rslearn/train/tasks/classification.py +34 -15
- rslearn/train/tasks/detection.py +24 -31
- rslearn/train/tasks/embedding.py +33 -29
- rslearn/train/tasks/multi_task.py +7 -7
- rslearn/train/tasks/per_pixel_regression.py +41 -19
- rslearn/train/tasks/regression.py +38 -21
- rslearn/train/tasks/segmentation.py +33 -15
- rslearn/train/tasks/task.py +3 -2
- {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/METADATA +58 -25
- {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/RECORD +65 -62
- rslearn/dataset/index.py +0 -173
- rslearn/models/registry.py +0 -22
- {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/WHEEL +0 -0
- {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/top_level.txt +0 -0
rslearn/models/molmo.py
CHANGED
@@ -1,12 +1,14 @@
 """Molmo model."""

-from typing import Any
-
 import torch
 from transformers import AutoModelForCausalLM, AutoProcessor

+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps

-class Molmo(torch.nn.Module):
+
+class Molmo(FeatureExtractor):
     """Molmo image encoder."""

     def __init__(
@@ -34,21 +36,21 @@ class Molmo(torch.nn.Module):
         ) # nosec
         self.encoder = model.model.vision_backbone

-    def forward(self,
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute outputs from the backbone.

-
-
-                process. The images should have values 0-255.
+        Args:
+            context: the model context. Input dicts must include "image" key containing
+                the image to process. The images should have values 0-255.

         Returns:
-
-
+            a FeatureMaps. Molmo produces features at one scale, so it will contain one
+            feature map that is a Bx24x24x2048 tensor.
         """
-        device = inputs[0]["image"].device
+        device = context.inputs[0]["image"].device
         molmo_inputs_list = []
         # Process each one so we can isolate just the full image without any crops.
-        for inp in inputs:
+        for inp in context.inputs:
             image = inp["image"].cpu().numpy().transpose(1, 2, 0)
             processed = self.processor.process(
                 images=[image],
@@ -60,6 +62,6 @@ class Molmo(torch.nn.Module):
         image_features, _ = self.encoder.encode_image(molmo_inputs.to(device))

         # 576x2048 -> 24x24x2048
-        return
-            image_features[:, 0, :, :].reshape(-1, 24, 24, 2048).permute(0, 3, 1, 2)
-
+        return FeatureMaps(
+            [image_features[:, 0, :, :].reshape(-1, 24, 24, 2048).permute(0, 3, 1, 2)]
+        )
rslearn/models/multitask.py
CHANGED
@@ -1,5 +1,6 @@
 """MultiTaskModel for rslearn."""

+from collections.abc import Iterable
 from copy import deepcopy
 from typing import Any

@@ -7,6 +8,9 @@ import torch

 from rslearn.log_utils import get_logger
 from rslearn.models.trunk import DecoderTrunk
+from rslearn.train.model_context import ModelContext, ModelOutput
+
+from .component import FeatureExtractor, IntermediateComponent, Predictor

 logger = get_logger(__name__)

@@ -58,8 +62,8 @@ class MultiTaskModel(torch.nn.Module):

     def __init__(
         self,
-        encoder: list[
-        decoders: dict[str, list[
+        encoder: list[FeatureExtractor | IntermediateComponent],
+        decoders: dict[str, list[IntermediateComponent | Predictor]],
         lazy_decode: bool = False,
         loss_weights: dict[str, float] | None = None,
         trunk: DecoderTrunk | None = None,
@@ -67,8 +71,12 @@ class MultiTaskModel(torch.nn.Module):
         """Initialize a new MultiTaskModel.

         Args:
-            encoder: modules to compute intermediate feature representations.
+            encoder: modules to compute intermediate feature representations. The first
+                module must be a FeatureExtractor, and following modules must be
+                IntermediateComponents.
             decoders: modules to compute outputs and loss, should match number of tasks.
+                The last module must be a Predictor, while the previous modules must be
+                IntermediateComponents.
             lazy_decode: if True, only decode the outputs specified in the batch.
             loss_weights: weights for each task's loss (default: None = equal weights).
             trunk: if provided, use this trunk module to postprocess the features
@@ -76,7 +84,7 @@ class MultiTaskModel(torch.nn.Module):
         """
         super().__init__()
         self.lazy_decode = lazy_decode
-        self.encoder = torch.nn.
+        self.encoder = torch.nn.ModuleList(encoder)
         self.decoders = torch.nn.ModuleDict(
             sort_keys(
                 {
@@ -120,32 +128,28 @@ class MultiTaskModel(torch.nn.Module):

     def apply_decoder(
         self,
-
-
+        intermediates: Any,
+        context: ModelContext,
         targets: list[dict[str, Any]] | None,
-        decoder: list[
+        decoder: list[IntermediateComponent | Predictor],
         task_name: str,
-
-        losses: dict[str, torch.Tensor],
-    ) -> tuple[list[dict[str, Any]], dict[str, torch.Tensor]]:
+    ) -> ModelOutput:
         """Apply a decoder to a list of inputs and targets.

         Args:
-
-
+            intermediates: the intermediate output from the encoder.
+            context: the model context.
             targets: list of target dicts
             decoder: list of decoder modules
             task_name: the name of the task
-            outputs: list of output dicts
-            losses: dictionary of loss values

         Returns:
-
+            a ModelOutput containing outputs across all the decoders.
         """
         # First, apply all but the last module in the decoder to the features
-        cur =
+        cur = intermediates
         for module in decoder[:-1]:
-            cur = module(cur,
+            cur = module(cur, context)

         if targets is None:
             cur_targets = None
@@ -153,14 +157,7 @@ class MultiTaskModel(torch.nn.Module):
             cur_targets = [target[task_name] for target in targets]

         # Then, apply the last module to the features and targets
-
-        for idx, entry in enumerate(cur_output):
-            outputs[idx][task_name] = entry
-        for loss_name, loss_value in cur_loss_dict.items():
-            losses[f"{task_name}_{loss_name}"] = (
-                loss_value * self.loss_weights[task_name]
-            )
-        return outputs, losses
+        return decoder[-1](cur, context, cur_targets)

     def _get_tasks_from_decoder(self, decoder: str) -> list[str]:
         """Get the tasks corresponding to this decoder.
@@ -172,66 +169,84 @@ class MultiTaskModel(torch.nn.Module):

     def apply_decoders(
         self,
-
-
+        intermediates: Any,
+        context: ModelContext,
         targets: list[dict[str, Any]] | None,
-    ) ->
+    ) -> ModelOutput:
         """Apply all the decoders to the features and targets.

         Args:
-
-
+            intermediates: the intermediates from the encoder.
+            context: the model context
             targets: list of target dicts

         Returns:
-
+            combined ModelOutput. The outputs is a list of output dicts, one per example,
+            where the dict maps from task name to the corresponding task output. The
+            losses is a flat dict but the task name is prepended to the loss names.
         """
-        outputs: list[dict[str, torch.Tensor | dict]] = [{} for _ in inputs]
+        outputs: list[dict[str, torch.Tensor | dict]] = [{} for _ in context.inputs]
         losses: dict[str, torch.Tensor] = {}

         if self.lazy_decode:
             # Assume that all inputs have the same dataset_source
-
-
-
-
-
-
+            task_name = context.metadatas[0].dataset_source
+
+            if task_name is None:
+                raise ValueError("dataset_source must be set for lazy decoding")
+
+            decoder = self.decoders[self.target_to_decoder.get(task_name, task_name)]
+            model_output = self.apply_decoder(
+                intermediates, context, targets, decoder, task_name
             )
+            for idx, entry in enumerate(model_output.outputs):
+                outputs[idx][task_name] = entry
+            for loss_name, loss_value in model_output.loss_dict.items():
+                losses[f"{task_name}_{loss_name}"] = (
+                    loss_value * self.loss_weights[task_name]
+                )
         else:
             for decoder_name, decoder in self.decoders.items():
                 for task_name in self._get_tasks_from_decoder(decoder_name):
-                    self.apply_decoder(
-
+                    model_output = self.apply_decoder(
+                        intermediates, context, targets, decoder, task_name
                     )
-
-
-
-
-
+                    for idx, entry in enumerate(model_output.outputs):
+                        outputs[idx][task_name] = entry
+                    for loss_name, loss_value in model_output.loss_dict.items():
+                        losses[f"{task_name}_{loss_name}"] = (
+                            loss_value * self.loss_weights[task_name]
+                        )
+
+        return ModelOutput(
+            outputs=outputs,
+            loss_dict=losses,
+        )

     def forward(
         self,
-
+        context: ModelContext,
         targets: list[dict[str, Any]] | None = None,
-    ) ->
+    ) -> ModelOutput:
         """Apply the sequence of modules on the inputs, including shared trunk.

         Args:
-
+            context: the model context.
             targets: optional list of target dicts

         Returns:
-
+            the model output from apply_decoders.
         """
-
+        cur = self.encoder[0](context)
+        for module in self.encoder[1:]:
+            cur = module(cur, context)
         if self.trunk is not None:
-            trunk_out = self.trunk(
-            outs = self.apply_decoders(trunk_out.pop("outputs"),
+            trunk_out = self.trunk(cur, context)
+            outs = self.apply_decoders(trunk_out.pop("outputs"), context, targets)
             self.trunk.apply_auxiliary_losses(trunk_out, outs)
             return outs | trunk_out
         else:
-            return self.apply_decoders(
+            return self.apply_decoders(cur, context, targets)


 class MultiTaskMergedModel(MultiTaskModel):
@@ -247,8 +262,8 @@ class MultiTaskMergedModel(MultiTaskModel):

     def __init__(
         self,
-        encoder: list[
-        decoders: dict[str, list[
+        encoder: list[FeatureExtractor | IntermediateComponent],
+        decoders: dict[str, list[IntermediateComponent | Predictor]],
         decoder_to_target: dict[str, list[str]],
         task_label_offsets: dict[str, dict[str, Any]],
         lazy_decode: bool = False,
@@ -273,7 +288,7 @@ class MultiTaskMergedModel(MultiTaskModel):
         torch.nn.Module.__init__(self)

         self.lazy_decode = lazy_decode
-        self.encoder = torch.nn.
+        self.encoder = torch.nn.ModuleList(encoder)
         self.decoders = torch.nn.ModuleDict(
             sort_keys(
                 {
@@ -329,9 +344,9 @@ class MultiTaskMergedModel(MultiTaskModel):
         return offset_targets

     def unmerge_output_labels(
-        self, outputs:
-    ) ->
-        """Unmerge the task
+        self, outputs: Iterable[Any], task_name: str
+    ) -> list[dict[str, torch.Tensor | dict]]:
+        """Unmerge the task outputs.

         For most tasks, this means chopping off the corresponding label dimensions.
         For some, we might just need to subtract an offset from the target (ex: segmentation).
@@ -340,10 +355,15 @@ class MultiTaskMergedModel(MultiTaskModel):
         Args:
             outputs: the predictions
             task_name: the name of the task
+
+        Returns:
+            the unmerged outputs.
         """
         offset = self.task_label_offsets[task_name]["offset"]
         num_outputs = self.task_label_offsets[task_name]["num_outputs"]
         output_key = self.task_label_offsets[task_name]["outputs_key"]
+
+        unmerged_outputs: list[dict[str, torch.Tensor | dict]] = [{} for _ in outputs]
         with torch.no_grad():
             for i, output in enumerate(outputs):
                 if not output:
@@ -353,35 +373,44 @@ class MultiTaskMergedModel(MultiTaskModel):
                 if isinstance(output, dict):
                     # For some tasks (eg object detection), we have discrete label
                     # predictions instead of a distribution over labels
-
+                    unmerged_output = output.copy()
+                    unmerged_output[output_key] = unmerged_output[output_key] - offset
+                    unmerged_outputs[i][task_name] = unmerged_output
                 elif isinstance(output, torch.Tensor):
                     # For classification/segmentation tasks, we have a distribution
                     # over labels, so we need to scale the predictions so that they
                     # sum to 1 since we chop off some of the probability densities
-
-
-
+                    unmerged_output = output[offset : offset + num_outputs, ...]
+                    unmerged_output /= unmerged_output.sum(dim=0, keepdim=True).type(
+                        torch.float32
+                    )
+                    unmerged_outputs[i][task_name] = unmerged_output
+
+        return unmerged_outputs

     def forward(
         self,
-
+        context: ModelContext,
         targets: list[dict[str, Any]] | None = None,
-    ) ->
+    ) -> ModelOutput:
         """Apply the sequence of modules on the inputs.

         Args:
-
+            context: the model context.
             targets: optional list of target dicts

         Returns:
-
+            the model output.
         """
-        dataset_source =
-        assert
-
-        outs = super().forward(
-        self.unmerge_output_labels(outs
-        return
+        dataset_source = context.metadatas[0].dataset_source
+        assert dataset_source is not None
+        merged_targets = self.merge_task_labels(targets, dataset_source)
+        outs = super().forward(context, merged_targets)
+        unmerged_outputs = self.unmerge_output_labels(outs.outputs, dataset_source)
+        return ModelOutput(
+            outputs=unmerged_outputs,
+            loss_dict=outs.loss_dict,
+        )

     def _get_tasks_from_decoder(self, decoder: str) -> list[str]:
         """Get the tasks corresponding to this decoder.
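The refactored apply_decoder now delegates entirely to the final decoder module, calling decoder[-1](cur, context, cur_targets) and expecting a ModelOutput back. A hedged sketch of what such a final module (a Predictor) could look like is below; ToyRegressionHead, the "value" target key, and the MSE loss are hypothetical, while the ModelOutput fields (outputs, loss_dict) and the call signature follow the diff above.

from typing import Any

import torch

from rslearn.models.component import Predictor
from rslearn.train.model_context import ModelContext, ModelOutput


class ToyRegressionHead(Predictor):
    """Hypothetical final decoder module mapping a feature vector to one value per example."""

    def __init__(self, in_channels: int) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(in_channels, 1)

    def forward(
        self,
        intermediates: Any,  # assumed to be a FeatureVector, e.g. from PoolingDecoder
        context: ModelContext,
        targets: list[dict[str, Any]] | None = None,
    ) -> ModelOutput:
        preds = self.linear(intermediates.feature_vector)[:, 0]
        losses: dict[str, torch.Tensor] = {}
        if targets is not None:
            # The "value" target key is made up for this example.
            gt = torch.stack([t["value"] for t in targets]).float()
            losses["mse"] = torch.nn.functional.mse_loss(preds, gt)
        # apply_decoders expects one output entry per input example.
        return ModelOutput(outputs=list(preds), loss_dict=losses)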
rslearn/models/olmoearth_pretrain/model.py
CHANGED
@@ -19,6 +19,8 @@ from olmoearth_pretrain.train.masking import MaskedOlmoEarthSample, MaskValue
 from upath import UPath

 from rslearn.log_utils import get_logger
+from rslearn.models.component import FeatureExtractor, FeatureMaps
+from rslearn.train.model_context import ModelContext

 logger = get_logger(__name__)

@@ -44,7 +46,7 @@ EMBEDDING_SIZES = {
 }


-class OlmoEarth(torch.nn.Module):
+class OlmoEarth(FeatureExtractor):
     """A wrapper to support the OlmoEarth model."""

     def __init__(
@@ -153,20 +155,21 @@ class OlmoEarth(torch.nn.Module):
         # Load the checkpoint.
         if not random_initialization:
             train_module_dir = checkpoint_upath / "model_and_optim"
-
-
-                logger.info(f"loaded OlmoEarth encoder from {train_module_dir}")
-            else:
-                logger.info(f"could not find OlmoEarth encoder at {train_module_dir}")
+            load_model_and_optim_state(str(train_module_dir), model)
+            logger.info(f"loaded OlmoEarth encoder from {train_module_dir}")

         return model

-    def forward(self,
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute feature maps from the OlmoEarth backbone.

-
-
-                that should be passed to the OlmoEarth model.
+        Args:
+            context: the model context. Input dicts should include keys corresponding
+                to the modalities that should be passed to the OlmoEarth model.
+
+        Returns:
+            a FeatureMaps consisting of one feature map, at 1/patch_size of the input
+            resolution. Embeddings will be pooled across modalities and timesteps.
         """
         kwargs = {}
         present_modalities = []
@@ -175,10 +178,10 @@ class OlmoEarth(torch.nn.Module):
         # We assume all multitemporal modalities have the same number of timesteps.
         max_timesteps = 1
         for modality in MODALITY_NAMES:
-            if modality not in inputs[0]:
+            if modality not in context.inputs[0]:
                 continue
             present_modalities.append(modality)
-            cur = torch.stack([inp[modality] for inp in inputs], dim=0)
+            cur = torch.stack([inp[modality] for inp in context.inputs], dim=0)
             device = cur.device
             # Check if it's single or multitemporal, and reshape accordingly
             num_bands = Modality.get(modality).num_bands
@@ -199,7 +202,7 @@ class OlmoEarth(torch.nn.Module):
         # Note that only months (0 to 11) are used in OlmoEarth position encoding.
         # For now, we assign same timestamps to all inputs, but later we should handle varying timestamps per input.
         timestamps = torch.zeros(
-            (len(inputs), max_timesteps, 3), dtype=torch.int32, device=device
+            (len(context.inputs), max_timesteps, 3), dtype=torch.int32, device=device
         )
         timestamps[:, :, 0] = 1  # day
         timestamps[:, :, 1] = torch.arange(max_timesteps, device=device)[
@@ -212,14 +215,14 @@ class OlmoEarth(torch.nn.Module):

         # Decide context based on self.autocast_dtype.
         if self.autocast_dtype is None:
-
+            torch_context = nullcontext()
         else:
             assert device is not None
-
+            torch_context = torch.amp.autocast(
                 device_type=device.type, dtype=self.autocast_dtype
             )

-        with
+        with torch_context:
             # Currently we assume the provided model always returns a TokensAndMasks object.
             tokens_and_masks: TokensAndMasks
             if isinstance(self.model, Encoder):
@@ -247,7 +250,7 @@ class OlmoEarth(torch.nn.Module):
             features.append(pooled)
         # Pool over the modalities, so we get one BCHW feature map.
         pooled = torch.stack(features, dim=0).mean(dim=0)
-        return [pooled]
+        return FeatureMaps([pooled])

     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
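One detail worth noting in the forward change is the renamed autocast helper: because the forward argument is now a ModelContext named context, the torch context manager is stored as torch_context, falling back to a no-op nullcontext when no autocast dtype is configured. A standalone sketch of that pattern follows (run_with_optional_autocast is hypothetical; the nullcontext/autocast selection mirrors the diff above).

from contextlib import nullcontext

import torch


def run_with_optional_autocast(
    fn, device: torch.device, autocast_dtype: torch.dtype | None
):
    """Run fn() under torch.amp.autocast only when a dtype is configured."""
    if autocast_dtype is None:
        torch_context = nullcontext()
    else:
        torch_context = torch.amp.autocast(
            device_type=device.type, dtype=autocast_dtype
        )
    with torch_context:
        return fn()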
rslearn/models/panopticon.py
CHANGED
@@ -3,15 +3,16 @@
 import math
 from enum import StrEnum
 from importlib import resources
-from typing import Any

 import torch
 import torch.nn.functional as F
 import yaml
 from einops import rearrange, repeat
-from torch import nn

 from rslearn.log_utils import get_logger
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps

 logger = get_logger(__name__)

@@ -28,7 +29,7 @@ class PanopticonModalities(StrEnum):
     # Add more modalities as needed


-class Panopticon(nn.Module):
+class Panopticon(FeatureExtractor):
     """Class containing the Panopticon model that can ingest MaskedHeliosSample objects."""

     patch_size: int = 14
@@ -138,11 +139,11 @@ class Panopticon(nn.Module):
             "chn_ids": chn_ids,
         }

-    def forward(self,
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Forward pass through the panopticon model."""
         batch_inputs = {
-            key: torch.stack([inp[key] for inp in inputs], dim=0)
-            for key in inputs[0].keys()
+            key: torch.stack([inp[key] for inp in context.inputs], dim=0)
+            for key in context.inputs[0].keys()
         }
         panopticon_inputs = self.prepare_input(batch_inputs)
         output_features = self.model.forward_features(panopticon_inputs)[
@@ -154,7 +155,7 @@ class Panopticon(nn.Module):
         output_features = rearrange(
             output_features, "b (h w) d -> b d h w", h=height, w=height
         )
-        return [output_features]
+        return FeatureMaps([output_features])

     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
rslearn/models/pick_features.py
CHANGED
@@ -2,45 +2,39 @@

 from typing import Any

-import
+from rslearn.train.model_context import ModelContext

+from .component import (
+    FeatureMaps,
+    IntermediateComponent,
+)

-
+
+class PickFeatures(IntermediateComponent):
     """Picks a subset of feature maps in a multi-scale feature map list."""

-    def __init__(self, indexes: list[int]
+    def __init__(self, indexes: list[int]):
         """Create a new PickFeatures.

         Args:
             indexes: the indexes of the input feature map list to select.
-            collapse: return one feature map instead of list. If enabled, indexes must
-                consist of one index. This is mainly useful for using PickFeatures as
-                the final module in the decoder, since the final prediction is expected
-                to be one feature map for most tasks like segmentation.
         """
         super().__init__()
         self.indexes = indexes
-        self.collapse = collapse
-
-        if self.collapse and len(self.indexes) != 1:
-            raise ValueError("if collapse is enabled, must get exactly one index")

     def forward(
         self,
-
-
-
-    ) -> list[torch.Tensor]:
+        intermediates: Any,
+        context: ModelContext,
+    ) -> FeatureMaps:
         """Pick a subset of the features.

         Args:
-
-
-            targets: targets, not used
+            intermediates: the output from the previous component, which must be a FeatureMaps.
+            context: the model context.
         """
-
-
-
-
-
-        return new_features
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to PickFeatures must be FeatureMaps")
+
+        new_features = [intermediates.feature_maps[idx] for idx in self.indexes]
+        return FeatureMaps(new_features)
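PickFeatures illustrates the general IntermediateComponent contract: forward(intermediates, context) receives the previous component's output (here required to be a FeatureMaps) and returns a new FeatureMaps. A hedged sketch of a custom component following the same pattern (ReverseFeatures is hypothetical):

from typing import Any

from rslearn.models.component import FeatureMaps, IntermediateComponent
from rslearn.train.model_context import ModelContext


class ReverseFeatures(IntermediateComponent):
    """Hypothetical component that reverses the order of the feature maps."""

    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
        if not isinstance(intermediates, FeatureMaps):
            raise ValueError("input to ReverseFeatures must be FeatureMaps")
        # Return a new FeatureMaps, as PickFeatures does above.
        return FeatureMaps(list(reversed(intermediates.feature_maps)))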
rslearn/models/pooling_decoder.py
CHANGED
@@ -4,8 +4,16 @@ from typing import Any

 import torch

+from rslearn.train.model_context import ModelContext

-class PoolingDecoder(torch.nn.Module):
+from .component import (
+    FeatureMaps,
+    FeatureVector,
+    IntermediateComponent,
+)
+
+
+class PoolingDecoder(IntermediateComponent):
     """Decoder that computes flat vector from a 2D feature map.

     It inputs multi-scale features, but only uses the last feature map. Then applies a
@@ -57,25 +65,26 @@ class PoolingDecoder(torch.nn.Module):

         self.output_layer = torch.nn.Linear(prev_channels, out_channels)

-    def forward(
-        self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
-    ) -> torch.Tensor:
+    def forward(self, intermediates: Any, context: ModelContext) -> Any:
         """Compute flat output vector from multi-scale feature map.

         Args:
-
-
+            intermediates: the output from the previous component, which must be a FeatureMaps.
+            context: the model context.

         Returns:
             flat feature vector
         """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to PoolingDecoder must be a FeatureMaps")
+
         # Only use last feature map.
-        features =
+        features = intermediates.feature_maps[-1]

         features = self.conv_layers(features)
         features = torch.amax(features, dim=(2, 3))
         features = self.fc_layers(features)
-        return self.output_layer(features)
+        return FeatureVector(self.output_layer(features))


 class SegmentationPoolingDecoder(PoolingDecoder):
@@ -108,14 +117,13 @@ class SegmentationPoolingDecoder(PoolingDecoder):
         super().__init__(in_channels=in_channels, out_channels=out_channels, **kwargs)
         self.image_key = image_key

-    def forward(
-        self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
-    ) -> torch.Tensor:
+    def forward(self, intermediates: Any, context: ModelContext) -> Any:
         """Extend PoolingDecoder forward to upsample the output to a segmentation mask.

         This only works when all of the pixels have the same segmentation target.
         """
-        output_probs = super().forward(
+        output_probs = super().forward(intermediates, context)
         # BC -> BCHW
-        h, w = inputs[0][self.image_key].shape[1:3]
-
+        h, w = context.inputs[0][self.image_key].shape[1:3]
+        feat_map = output_probs.feature_vector[:, :, None, None].repeat([1, 1, h, w])
+        return FeatureMaps([feat_map])
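The pooling step that PoolingDecoder applies after its conv layers is a spatial max over the feature map, collapsing BxCxHxW to BxC before the fully connected layers; the result is now wrapped in a FeatureVector. A small illustration of just that pooling step (the shapes are made up for the example):

import torch

features = torch.randn(2, 128, 16, 16)     # BxCxHxW feature map (example shapes)
pooled = torch.amax(features, dim=(2, 3))  # max over H and W -> shape (2, 128)
assert pooled.shape == (2, 128)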