rslearn 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +2 -9
- rslearn/config/__init__.py +2 -0
- rslearn/config/dataset.py +64 -20
- rslearn/dataset/add_windows.py +1 -1
- rslearn/dataset/dataset.py +34 -84
- rslearn/dataset/materialize.py +5 -5
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +26 -80
- rslearn/lightning_cli.py +22 -11
- rslearn/main.py +12 -37
- rslearn/models/anysat.py +11 -9
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +8 -9
- rslearn/models/clip.py +18 -15
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +21 -11
- rslearn/models/conv.py +15 -8
- rslearn/models/croma.py +13 -8
- rslearn/models/detr/detr.py +25 -14
- rslearn/models/dinov3.py +11 -6
- rslearn/models/faster_rcnn.py +19 -9
- rslearn/models/feature_center_crop.py +12 -9
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/galileo.py +23 -18
- rslearn/models/module_wrapper.py +26 -57
- rslearn/models/molmo.py +16 -14
- rslearn/models/multitask.py +102 -73
- rslearn/models/olmoearth_pretrain/model.py +135 -38
- rslearn/models/panopticon.py +8 -7
- rslearn/models/pick_features.py +18 -24
- rslearn/models/pooling_decoder.py +22 -14
- rslearn/models/presto/presto.py +16 -10
- rslearn/models/presto/single_file_presto.py +4 -10
- rslearn/models/prithvi.py +12 -8
- rslearn/models/resize_features.py +21 -7
- rslearn/models/sam2_enc.py +11 -9
- rslearn/models/satlaspretrain.py +15 -9
- rslearn/models/simple_time_series.py +37 -17
- rslearn/models/singletask.py +24 -17
- rslearn/models/ssl4eo_s12.py +15 -10
- rslearn/models/swin.py +22 -13
- rslearn/models/terramind.py +24 -7
- rslearn/models/trunk.py +6 -3
- rslearn/models/unet.py +18 -9
- rslearn/models/upsample.py +22 -9
- rslearn/train/all_patches_dataset.py +89 -37
- rslearn/train/dataset.py +105 -97
- rslearn/train/lightning_module.py +51 -32
- rslearn/train/model_context.py +54 -0
- rslearn/train/prediction_writer.py +111 -41
- rslearn/train/scheduler.py +15 -0
- rslearn/train/tasks/classification.py +34 -15
- rslearn/train/tasks/detection.py +24 -31
- rslearn/train/tasks/embedding.py +33 -29
- rslearn/train/tasks/multi_task.py +7 -7
- rslearn/train/tasks/per_pixel_regression.py +41 -19
- rslearn/train/tasks/regression.py +38 -21
- rslearn/train/tasks/segmentation.py +33 -15
- rslearn/train/tasks/task.py +3 -2
- rslearn/train/transforms/resize.py +74 -0
- rslearn/utils/geometry.py +73 -0
- rslearn/utils/jsonargparse.py +66 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/METADATA +1 -1
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/RECORD +71 -66
- rslearn/dataset/index.py +0 -173
- rslearn/models/registry.py +0 -22
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/WHEEL +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/top_level.txt +0 -0
rslearn/models/multitask.py
CHANGED
@@ -1,5 +1,6 @@
 """MultiTaskModel for rslearn."""

+from collections.abc import Iterable
 from copy import deepcopy
 from typing import Any

@@ -7,6 +8,9 @@ import torch

 from rslearn.log_utils import get_logger
 from rslearn.models.trunk import DecoderTrunk
+from rslearn.train.model_context import ModelContext, ModelOutput
+
+from .component import FeatureExtractor, IntermediateComponent, Predictor

 logger = get_logger(__name__)

@@ -58,8 +62,8 @@ class MultiTaskModel(torch.nn.Module):

     def __init__(
         self,
-        encoder: list[
-        decoders: dict[str, list[
+        encoder: list[FeatureExtractor | IntermediateComponent],
+        decoders: dict[str, list[IntermediateComponent | Predictor]],
         lazy_decode: bool = False,
         loss_weights: dict[str, float] | None = None,
         trunk: DecoderTrunk | None = None,
@@ -67,8 +71,12 @@ class MultiTaskModel(torch.nn.Module):
         """Initialize a new MultiTaskModel.

         Args:
-            encoder: modules to compute intermediate feature representations.
+            encoder: modules to compute intermediate feature representations. The first
+                module must be a FeatureExtractor, and following modules must be
+                IntermediateComponents.
             decoders: modules to compute outputs and loss, should match number of tasks.
+                The last module must be a Predictor, while the previous modules must be
+                IntermediateComponents.
             lazy_decode: if True, only decode the outputs specified in the batch.
             loss_weights: weights for each task's loss (default: None = equal weights).
             trunk: if provided, use this trunk module to postprocess the features
@@ -76,7 +84,7 @@ class MultiTaskModel(torch.nn.Module):
         """
         super().__init__()
         self.lazy_decode = lazy_decode
-        self.encoder = torch.nn.
+        self.encoder = torch.nn.ModuleList(encoder)
         self.decoders = torch.nn.ModuleDict(
             sort_keys(
                 {
@@ -120,32 +128,28 @@ class MultiTaskModel(torch.nn.Module):

     def apply_decoder(
         self,
-
-
+        intermediates: Any,
+        context: ModelContext,
         targets: list[dict[str, Any]] | None,
-        decoder: list[
+        decoder: list[IntermediateComponent | Predictor],
         task_name: str,
-
-        losses: dict[str, torch.Tensor],
-    ) -> tuple[list[dict[str, Any]], dict[str, torch.Tensor]]:
+    ) -> ModelOutput:
         """Apply a decoder to a list of inputs and targets.

         Args:
-
-
+            intermediates: the intermediate output from the encoder.
+            context: the model context.
             targets: list of target dicts
             decoder: list of decoder modules
             task_name: the name of the task
-            outputs: list of output dicts
-            losses: dictionary of loss values

         Returns:
-
+            a ModelOutput containing outputs across all the decoders.
         """
         # First, apply all but the last module in the decoder to the features
-        cur =
+        cur = intermediates
         for module in decoder[:-1]:
-            cur = module(cur,
+            cur = module(cur, context)

         if targets is None:
             cur_targets = None
@@ -153,14 +157,7 @@ class MultiTaskModel(torch.nn.Module):
             cur_targets = [target[task_name] for target in targets]

         # Then, apply the last module to the features and targets
-
-        for idx, entry in enumerate(cur_output):
-            outputs[idx][task_name] = entry
-        for loss_name, loss_value in cur_loss_dict.items():
-            losses[f"{task_name}_{loss_name}"] = (
-                loss_value * self.loss_weights[task_name]
-            )
-        return outputs, losses
+        return decoder[-1](cur, context, cur_targets)

     def _get_tasks_from_decoder(self, decoder: str) -> list[str]:
         """Get the tasks corresponding to this decoder.
@@ -172,66 +169,84 @@ class MultiTaskModel(torch.nn.Module):

     def apply_decoders(
         self,
-
-
+        intermediates: Any,
+        context: ModelContext,
         targets: list[dict[str, Any]] | None,
-    ) ->
+    ) -> ModelOutput:
         """Apply all the decoders to the features and targets.

         Args:
-
-
+            intermediates: the intermediates from the encoder.
+            context: the model context
             targets: list of target dicts

         Returns:
-
+            combined ModelOutput. The outputs is a list of output dicts, one per example,
+                where the dict maps from task name to the corresponding task output. The
+                losses is a flat dict but the task name is prepended to the loss names.
         """
-        outputs: list[dict[str, torch.Tensor | dict]] = [{} for _ in inputs]
+        outputs: list[dict[str, torch.Tensor | dict]] = [{} for _ in context.inputs]
        losses: dict[str, torch.Tensor] = {}

         if self.lazy_decode:
             # Assume that all inputs have the same dataset_source
-
-
-
-
-
-
+            task_name = context.metadatas[0].dataset_source
+
+            if task_name is None:
+                raise ValueError("dataset_source must be set for lazy decoding")
+
+            decoder = self.decoders[self.target_to_decoder.get(task_name, task_name)]
+            model_output = self.apply_decoder(
+                intermediates, context, targets, decoder, task_name
             )
+            for idx, entry in enumerate(model_output.outputs):
+                outputs[idx][task_name] = entry
+            for loss_name, loss_value in model_output.loss_dict.items():
+                losses[f"{task_name}_{loss_name}"] = (
+                    loss_value * self.loss_weights[task_name]
+                )
         else:
             for decoder_name, decoder in self.decoders.items():
                 for task_name in self._get_tasks_from_decoder(decoder_name):
-                    self.apply_decoder(
-
+                    model_output = self.apply_decoder(
+                        intermediates, context, targets, decoder, task_name
                     )
-
-
-
-
-
+                    for idx, entry in enumerate(model_output.outputs):
+                        outputs[idx][task_name] = entry
+                    for loss_name, loss_value in model_output.loss_dict.items():
+                        losses[f"{task_name}_{loss_name}"] = (
+                            loss_value * self.loss_weights[task_name]
+                        )
+
+        return ModelOutput(
+            outputs=outputs,
+            loss_dict=losses,
+        )

     def forward(
         self,
-
+        context: ModelContext,
         targets: list[dict[str, Any]] | None = None,
-    ) ->
+    ) -> ModelOutput:
         """Apply the sequence of modules on the inputs, including shared trunk.

         Args:
-
+            context: the model context.
             targets: optional list of target dicts

         Returns:
-
+            the model output from apply_decoders.
         """
-
+        cur = self.encoder[0](context)
+        for module in self.encoder[1:]:
+            cur = module(cur, context)
         if self.trunk is not None:
-            trunk_out = self.trunk(
-            outs = self.apply_decoders(trunk_out.pop("outputs"),
+            trunk_out = self.trunk(cur, context)
+            outs = self.apply_decoders(trunk_out.pop("outputs"), context, targets)
             self.trunk.apply_auxiliary_losses(trunk_out, outs)
             return outs | trunk_out
         else:
-            return self.apply_decoders(
+            return self.apply_decoders(cur, context, targets)


 class MultiTaskMergedModel(MultiTaskModel):
@@ -247,8 +262,8 @@ class MultiTaskMergedModel(MultiTaskModel):

     def __init__(
         self,
-        encoder: list[
-        decoders: dict[str, list[
+        encoder: list[FeatureExtractor | IntermediateComponent],
+        decoders: dict[str, list[IntermediateComponent | Predictor]],
         decoder_to_target: dict[str, list[str]],
         task_label_offsets: dict[str, dict[str, Any]],
         lazy_decode: bool = False,
@@ -273,7 +288,7 @@ class MultiTaskMergedModel(MultiTaskModel):
         torch.nn.Module.__init__(self)

         self.lazy_decode = lazy_decode
-        self.encoder = torch.nn.
+        self.encoder = torch.nn.ModuleList(encoder)
         self.decoders = torch.nn.ModuleDict(
             sort_keys(
                 {
@@ -329,9 +344,9 @@ class MultiTaskMergedModel(MultiTaskModel):
         return offset_targets

     def unmerge_output_labels(
-        self, outputs:
-    ) ->
-        """Unmerge the task
+        self, outputs: Iterable[Any], task_name: str
+    ) -> list[dict[str, torch.Tensor | dict]]:
+        """Unmerge the task outputs.

         For most tasks, this means chopping off the corresponding label dimensions.
         For some, we might just need to subtract an offset from the target (ex: segmentation).
@@ -340,10 +355,15 @@ class MultiTaskMergedModel(MultiTaskModel):
         Args:
             outputs: the predictions
             task_name: the name of the task
+
+        Returns:
+            the unmerged outputs.
         """
         offset = self.task_label_offsets[task_name]["offset"]
         num_outputs = self.task_label_offsets[task_name]["num_outputs"]
         output_key = self.task_label_offsets[task_name]["outputs_key"]
+
+        unmerged_outputs: list[dict[str, torch.Tensor | dict]] = [{} for _ in outputs]
         with torch.no_grad():
             for i, output in enumerate(outputs):
                 if not output:
@@ -353,35 +373,44 @@ class MultiTaskMergedModel(MultiTaskModel):
                 if isinstance(output, dict):
                     # For some tasks (eg object detection), we have discrete label
                     # predictions instead of a distribution over labels
-
+                    unmerged_output = output.copy()
+                    unmerged_output[output_key] = unmerged_output[output_key] - offset
+                    unmerged_outputs[i][task_name] = unmerged_output
                 elif isinstance(output, torch.Tensor):
                     # For classification/segmentation tasks, we have a distribution
                     # over labels, so we need to scale the predictions so that they
                     # sum to 1 since we chop off some of the probability densities
-
-
-
+                    unmerged_output = output[offset : offset + num_outputs, ...]
+                    unmerged_output /= unmerged_output.sum(dim=0, keepdim=True).type(
+                        torch.float32
+                    )
+                    unmerged_outputs[i][task_name] = unmerged_output
+
+        return unmerged_outputs

     def forward(
         self,
-
+        context: ModelContext,
         targets: list[dict[str, Any]] | None = None,
-    ) ->
+    ) -> ModelOutput:
         """Apply the sequence of modules on the inputs.

         Args:
-
+            context: the model context.
             targets: optional list of target dicts

         Returns:
-
+            the model output.
         """
-        dataset_source =
-        assert
-
-        outs = super().forward(
-        self.unmerge_output_labels(outs
-        return
+        dataset_source = context.metadatas[0].dataset_source
+        assert dataset_source is not None
+        merged_targets = self.merge_task_labels(targets, dataset_source)
+        outs = super().forward(context, merged_targets)
+        unmerged_outputs = self.unmerge_output_labels(outs.outputs, dataset_source)
+        return ModelOutput(
+            outputs=unmerged_outputs,
+            loss_dict=outs.loss_dict,
+        )

     def _get_tasks_from_decoder(self, decoder: str) -> list[str]:
         """Get the tasks corresponding to this decoder.
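Taken together, this diff moves MultiTaskModel from an (inputs, targets) -> (outputs, losses) interface to a ModelContext -> ModelOutput interface. A minimal sketch of the new calling convention, as read from the hunks above; the helper name and the loss summation are illustrative, and the exact construction of ModelContext is not shown in this diff:

import torch

from rslearn.models.multitask import MultiTaskModel
from rslearn.train.model_context import ModelContext, ModelOutput


def training_step(
    model: MultiTaskModel, context: ModelContext, targets: list[dict] | None
) -> torch.Tensor:
    # forward() now takes the whole ModelContext (which carries context.inputs
    # and context.metadatas) rather than a list of input dicts.
    output: ModelOutput = model(context, targets)
    # output.outputs: one dict per example, keyed by task name.
    # output.loss_dict: flat dict keyed by f"{task_name}_{loss_name}", already
    # scaled by the configured loss_weights.
    return sum(output.loss_dict.values(), torch.tensor(0.0))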
rslearn/models/olmoearth_pretrain/model.py
CHANGED
@@ -19,6 +19,8 @@ from olmoearth_pretrain.train.masking import MaskedOlmoEarthSample, MaskValue
 from upath import UPath

 from rslearn.log_utils import get_logger
+from rslearn.models.component import FeatureExtractor, FeatureMaps, TokenFeatureMaps
+from rslearn.train.model_context import ModelContext

 logger = get_logger(__name__)

@@ -44,7 +46,7 @@ EMBEDDING_SIZES = {
 }


-class OlmoEarth(
+class OlmoEarth(FeatureExtractor):
     """A wrapper to support the OlmoEarth model."""

     def __init__(
@@ -58,6 +60,7 @@ class OlmoEarth(torch.nn.Module):
         random_initialization: bool = False,
         embedding_size: int | None = None,
         autocast_dtype: str | None = "bfloat16",
+        token_pooling: bool = True,
     ):
         """Create a new OlmoEarth model.

@@ -81,6 +84,9 @@ class OlmoEarth(torch.nn.Module):
             embedding_size: optional embedding size to report via
                 get_backbone_channels (if model_id is not set).
             autocast_dtype: which dtype to use for autocasting, or set None to disable.
+            token_pooling: whether or not to pool the tokens. If True, the output will be BxCxHxW. If False,
+                there will be an extra dimension, N, (BxCxHxWxN) representing the temporal and channel
+                dimensions.
         """
         if (
             sum(
@@ -131,6 +137,7 @@ class OlmoEarth(torch.nn.Module):
             else:
                 model = model[part]
         self.model = model
+        self.token_pooling = token_pooling

     def _load_model_from_checkpoint(
         self, checkpoint_upath: UPath, random_initialization: bool
@@ -158,45 +165,89 @@ class OlmoEarth(torch.nn.Module):

         return model

-    def
-
+    def _prepare_modality_inputs(
+        self, context: ModelContext
+    ) -> tuple[MaskedOlmoEarthSample, list[str], torch.device]:
+        """Prepare modality tensors and masks for the OlmoEarth model.
+
+        Uses a two-pass approach to ensure all modalities have consistent timestep
+        dimensions for position encoding.
+
+        Args:
+            context: the model context with input tensors.

-
-
-            that should be passed to the OlmoEarth model.
+        Returns:
+            tuple of (sample, present_modalities, device)
         """
         kwargs = {}
         present_modalities = []
         device = None
-
-        #
+
+        # First pass: find global max_timesteps across all modalities and samples
+        # TODO: currently we assume all modalities have the same number of timesteps,
+        # which is not true for all cases, and time series time steps are assumed to
+        # be 1-month apart. It also assumes continuity between available timesteps.
+        # We'll have to fix all that.
         max_timesteps = 1
+        modality_data = {}
         for modality in MODALITY_NAMES:
-            if modality not in inputs[0]:
+            if modality not in context.inputs[0]:
                 continue
             present_modalities.append(modality)
-
-            device =
-            # Check if it's single or multitemporal, and reshape accordingly
+            tensors = [inp[modality] for inp in context.inputs]
+            device = tensors[0].device
             num_bands = Modality.get(modality).num_bands
-
-            max_timesteps = max(max_timesteps,
-
+            max_t = max(t.shape[0] for t in tensors) // num_bands
+            max_timesteps = max(max_timesteps, max_t)
+            modality_data[modality] = (
+                tensors,
+                num_bands,
+                len(Modality.get(modality).band_sets),
+            )
+
+        # Second pass: pad and process each modality with global max_timesteps
+        for modality in present_modalities:
+            tensors, num_bands, num_band_sets = modality_data[modality]
+            target_ch = max_timesteps * num_bands
+
+            # Pad tensors to target_ch and track original timesteps for masking
+            padded = []
+            original_timesteps = []
+            for t in tensors:
+                orig_t = t.shape[0] // num_bands
+                original_timesteps.append(orig_t)
+                if t.shape[0] < target_ch:
+                    pad = torch.zeros(
+                        (target_ch - t.shape[0],) + t.shape[1:],
+                        dtype=t.dtype,
+                        device=device,
+                    )
+                    t = torch.cat([t, pad], dim=0)
+                padded.append(t)
+
+            cur = torch.stack(padded, dim=0)
+            cur = rearrange(cur, "b (t c) h w -> b h w t c", t=max_timesteps)
             kwargs[modality] = cur
-
-
-
-            mask = (
-
-
+
+            # Create mask: ONLINE_ENCODER for valid, MISSING for padded timesteps
+            b, h, w = cur.shape[0], cur.shape[1], cur.shape[2]
+            mask = torch.full(
+                (b, h, w, max_timesteps, num_band_sets),
+                fill_value=MaskValue.ONLINE_ENCODER.value,
+                dtype=torch.int32,
+                device=device,
             )
+            for sample_idx, orig_t in enumerate(original_timesteps):
+                if orig_t < max_timesteps:
+                    mask[sample_idx, :, :, orig_t:, :] = MaskValue.MISSING.value
             kwargs[f"{modality}_mask"] = mask

         # Timestamps is required.
         # Note that only months (0 to 11) are used in OlmoEarth position encoding.
-        # For now, we assign same timestamps to all inputs, but later we should
+        # For now, we assign same timestamps to all inputs, but later we should
+        # handle varying timestamps per input.
         timestamps = torch.zeros(
-            (len(inputs), max_timesteps, 3), dtype=torch.int32, device=device
+            (len(context.inputs), max_timesteps, 3), dtype=torch.int32, device=device
         )
         timestamps[:, :, 0] = 1  # day
         timestamps[:, :, 1] = torch.arange(max_timesteps, device=device)[
@@ -205,25 +256,46 @@ class OlmoEarth(torch.nn.Module):
         timestamps[:, :, 2] = 2024  # year
         kwargs["timestamps"] = timestamps

-
+        return MaskedOlmoEarthSample(**kwargs), present_modalities, device
+
+    def forward(self, context: ModelContext) -> FeatureMaps | TokenFeatureMaps:
+        """Compute feature maps from the OlmoEarth backbone.
+
+        Args:
+            context: the model context. Input dicts should include keys corresponding
+                to the modalities that should be passed to the OlmoEarth model.
+
+        Returns:
+            a FeatureMaps consisting of one feature map, at 1/patch_size of the input
+                resolution. Embeddings will be pooled across modalities and timesteps.
+        """
+        sample, present_modalities, device = self._prepare_modality_inputs(context)

         # Decide context based on self.autocast_dtype.
         if self.autocast_dtype is None:
-
+            torch_context = nullcontext()
         else:
             assert device is not None
-
+            torch_context = torch.amp.autocast(
                 device_type=device.type, dtype=self.autocast_dtype
             )

-
+        # Check if we can bypass masks (fast_pass=True)
+        missing_tokens = False
+        for modality in present_modalities:
+            modality_mask = getattr(sample, f"{modality}_mask")
+            if torch.any(modality_mask == MaskValue.MISSING.value):
+                missing_tokens = True
+                break
+
+        with torch_context:
             # Currently we assume the provided model always returns a TokensAndMasks object.
             tokens_and_masks: TokensAndMasks
             if isinstance(self.model, Encoder):
                 # Encoder has a fast_pass argument to indicate mask is not needed.
                 tokens_and_masks = self.model(
                     sample,
-                    fast_pass=
+                    fast_pass=not missing_tokens,
                     patch_size=self.patch_size,
                     **self.forward_kwargs,
                 )["tokens_and_masks"]
@@ -235,16 +307,41 @@ class OlmoEarth(torch.nn.Module):

         # Apply temporal/modality pooling so we just have one feature per patch.
         features = []
-
-
-
-
-
-
-
-
-
-
+        if self.token_pooling:
+            for modality in present_modalities:
+                modality_features = getattr(tokens_and_masks, modality)  # BHWTSC
+                # If fast_pass is False, we need to mask the missing tokens before pooling.
+                if missing_tokens:
+                    modality_masks = getattr(
+                        tokens_and_masks, f"{modality}_mask"
+                    )  # BHWTS
+                    modality_masks_bool = (
+                        modality_masks != MaskValue.MISSING.value
+                    ).unsqueeze(-1)
+                    count = modality_masks_bool.sum(dim=[3, 4])
+                    # Masked average over band sets and timesteps (BHWTSC -> BHWC).
+                    pooled = (modality_features * modality_masks_bool).sum(
+                        dim=[3, 4]
+                    ) / count.clamp(min=1)
+                else:
+                    # Pool over band sets and timesteps (BHWTSC -> BHWC).
+                    pooled = modality_features.mean(dim=[3, 4])
+                # We want BHWC -> BCHW.
+                pooled = rearrange(pooled, "b h w c -> b c h w")
+                features.append(pooled)
+            # Pool over the modalities, so we get one BCHW feature map.
+            pooled = torch.stack(features, dim=0).mean(dim=0)
+            return FeatureMaps([pooled])
+        else:
+            for modality in present_modalities:
+                modality_features = getattr(tokens_and_masks, modality)
+                # Combine band sets and timesteps into last dim (BHWTSC -> BHWCN).
+                modality_features = rearrange(
+                    modality_features, "b h w t s c -> b c h w (t s)"
+                )
+                features.append(modality_features)
+            pooled = torch.cat(features, dim=-1)
+            return TokenFeatureMaps([pooled])

     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
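The token_pooling=True branch above averages BHWTSC tokens over band sets and timesteps, skipping positions marked MISSING when padding was applied. A standalone sketch of that masked-mean pattern; the shapes and the zero sentinel are illustrative, not taken from the package:

import torch

B, H, W, T, S, C = 2, 4, 4, 3, 2, 8
tokens = torch.randn(B, H, W, T, S, C)
mask = torch.ones(B, H, W, T, S, dtype=torch.int32)
mask[0, :, :, 2:, :] = 0  # pretend the last timestep of sample 0 is missing

valid = (mask != 0).unsqueeze(-1)                               # BHWTS1
count = valid.sum(dim=[3, 4])                                   # BHW1
pooled = (tokens * valid).sum(dim=[3, 4]) / count.clamp(min=1)  # BHWC
pooled = pooled.permute(0, 3, 1, 2)                             # BCHW, like FeatureMaps
print(pooled.shape)  # torch.Size([2, 8, 4, 4])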
rslearn/models/panopticon.py
CHANGED
@@ -3,15 +3,16 @@
 import math
 from enum import StrEnum
 from importlib import resources
-from typing import Any

 import torch
 import torch.nn.functional as F
 import yaml
 from einops import rearrange, repeat
-from torch import nn

 from rslearn.log_utils import get_logger
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps

 logger = get_logger(__name__)

@@ -28,7 +29,7 @@ class PanopticonModalities(StrEnum):
     # Add more modalities as needed


-class Panopticon(
+class Panopticon(FeatureExtractor):
     """Class containing the Panopticon model that can ingest MaskedHeliosSample objects."""

     patch_size: int = 14
@@ -138,11 +139,11 @@ class Panopticon(nn.Module):
             "chn_ids": chn_ids,
         }

-    def forward(self,
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Forward pass through the panopticon model."""
         batch_inputs = {
-            key: torch.stack([inp[key] for inp in inputs], dim=0)
-            for key in inputs[0].keys()
+            key: torch.stack([inp[key] for inp in context.inputs], dim=0)
+            for key in context.inputs[0].keys()
         }
         panopticon_inputs = self.prepare_input(batch_inputs)
         output_features = self.model.forward_features(panopticon_inputs)[
@@ -154,7 +155,7 @@ class Panopticon(nn.Module):
         output_features = rearrange(
             output_features, "b (h w) d -> b d h w", h=height, w=height
         )
-        return [output_features]
+        return FeatureMaps([output_features])

     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
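Both wrappers in this diff now subclass FeatureExtractor and implement forward(context) -> FeatureMaps, reading raw tensors from context.inputs. A hedged sketch of the same pattern for a custom backbone; the "image" key, the toy convolution, and the assumption that FeatureExtractor.__init__ takes no required arguments are illustrative only:

import torch

from rslearn.models.component import FeatureExtractor, FeatureMaps
from rslearn.train.model_context import ModelContext


class ToyExtractor(FeatureExtractor):
    def __init__(self, in_channels: int = 3, out_channels: int = 32):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, context: ModelContext) -> FeatureMaps:
        # Batch the per-example tensors the same way Panopticon.forward does above.
        images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
        return FeatureMaps([self.conv(images)])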