rslearn 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +2 -9
- rslearn/config/__init__.py +2 -0
- rslearn/config/dataset.py +64 -20
- rslearn/dataset/add_windows.py +1 -1
- rslearn/dataset/dataset.py +34 -84
- rslearn/dataset/materialize.py +5 -5
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +26 -80
- rslearn/lightning_cli.py +22 -11
- rslearn/main.py +12 -37
- rslearn/models/anysat.py +11 -9
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +8 -9
- rslearn/models/clip.py +18 -15
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +21 -11
- rslearn/models/conv.py +15 -8
- rslearn/models/croma.py +13 -8
- rslearn/models/detr/detr.py +25 -14
- rslearn/models/dinov3.py +11 -6
- rslearn/models/faster_rcnn.py +19 -9
- rslearn/models/feature_center_crop.py +12 -9
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/galileo.py +23 -18
- rslearn/models/module_wrapper.py +26 -57
- rslearn/models/molmo.py +16 -14
- rslearn/models/multitask.py +102 -73
- rslearn/models/olmoearth_pretrain/model.py +135 -38
- rslearn/models/panopticon.py +8 -7
- rslearn/models/pick_features.py +18 -24
- rslearn/models/pooling_decoder.py +22 -14
- rslearn/models/presto/presto.py +16 -10
- rslearn/models/presto/single_file_presto.py +4 -10
- rslearn/models/prithvi.py +12 -8
- rslearn/models/resize_features.py +21 -7
- rslearn/models/sam2_enc.py +11 -9
- rslearn/models/satlaspretrain.py +15 -9
- rslearn/models/simple_time_series.py +37 -17
- rslearn/models/singletask.py +24 -17
- rslearn/models/ssl4eo_s12.py +15 -10
- rslearn/models/swin.py +22 -13
- rslearn/models/terramind.py +24 -7
- rslearn/models/trunk.py +6 -3
- rslearn/models/unet.py +18 -9
- rslearn/models/upsample.py +22 -9
- rslearn/train/all_patches_dataset.py +89 -37
- rslearn/train/dataset.py +105 -97
- rslearn/train/lightning_module.py +51 -32
- rslearn/train/model_context.py +54 -0
- rslearn/train/prediction_writer.py +111 -41
- rslearn/train/scheduler.py +15 -0
- rslearn/train/tasks/classification.py +34 -15
- rslearn/train/tasks/detection.py +24 -31
- rslearn/train/tasks/embedding.py +33 -29
- rslearn/train/tasks/multi_task.py +7 -7
- rslearn/train/tasks/per_pixel_regression.py +41 -19
- rslearn/train/tasks/regression.py +38 -21
- rslearn/train/tasks/segmentation.py +33 -15
- rslearn/train/tasks/task.py +3 -2
- rslearn/train/transforms/resize.py +74 -0
- rslearn/utils/geometry.py +73 -0
- rslearn/utils/jsonargparse.py +66 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/METADATA +1 -1
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/RECORD +71 -66
- rslearn/dataset/index.py +0 -173
- rslearn/models/registry.py +0 -22
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/WHEEL +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/top_level.txt +0 -0
rslearn/models/dinov3.py
CHANGED
@@ -13,9 +13,12 @@ import torch
 import torchvision
 from einops import rearrange
 
+from rslearn.train.model_context import ModelContext
 from rslearn.train.transforms.normalize import Normalize
 from rslearn.train.transforms.transform import Transform
 
+from .component import FeatureExtractor, FeatureMaps
+
 
 class DinoV3Models(StrEnum):
     """Names for different DinoV3 images on torch hub."""
@@ -40,7 +43,7 @@ DINOV3_PTHS: dict[str, str] = {
 }
 
 
-class DinoV3(torch.nn.Module):
+class DinoV3(FeatureExtractor):
     """DinoV3 Backbones.
 
     Must have the pretrained weights downloaded in checkpoint_dir for them to be loaded.
@@ -91,16 +94,18 @@ class DinoV3(torch.nn.Module):
         self.do_resizing = do_resizing
         self.model = self._load_model(size, checkpoint_dir)
 
-    def forward(self,
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Forward pass for the dinov3 model.
 
         Args:
-
+            context: the model context. Input dicts must include "image" key.
 
         Returns:
-
+            a FeatureMaps with one feature map.
         """
-        cur = torch.stack(
+        cur = torch.stack(
+            [inp["image"] for inp in context.inputs], dim=0
+        )  # (B, C, H, W)
 
         if self.do_resizing and (
             cur.shape[2] != self.image_size or cur.shape[3] != self.image_size
@@ -118,7 +123,7 @@ class DinoV3(torch.nn.Module):
         height, width = int(num_patches**0.5), int(num_patches**0.5)
         features = rearrange(features, "b (h w) d -> b d h w", h=height, w=width)
 
-        return [features]
+        return FeatureMaps([features])
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
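
The dinov3.py change illustrates the new encoder interface in 0.0.19: backbones subclass FeatureExtractor, take a ModelContext, and return a FeatureMaps instead of a plain list of tensors. Below is a minimal sketch of that contract, assuming (as the diffs here indicate) that ModelContext exposes an `inputs` list of per-sample dicts and that FeatureMaps wraps a list of BxCxHxW tensors; the toy convolution backbone itself is hypothetical.

# Hedged sketch of the FeatureExtractor contract shown in the dinov3.py diff.
import torch

from rslearn.models.component import FeatureExtractor, FeatureMaps
from rslearn.train.model_context import ModelContext


class ToyBackbone(FeatureExtractor):
    """Hypothetical backbone that returns a single downsampled feature map."""

    def __init__(self, in_channels: int = 3, out_channels: int = 64):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=4)

    def forward(self, context: ModelContext) -> FeatureMaps:
        # Stack per-sample images into a (B, C, H, W) batch, as DinoV3 does.
        images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
        return FeatureMaps([self.conv(images)])
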
rslearn/models/faster_rcnn.py
CHANGED
@@ -6,6 +6,10 @@ from typing import Any
 import torch
 import torchvision
 
+from rslearn.train.model_context import ModelContext, ModelOutput
+
+from .component import FeatureMaps, Predictor
+
 
 class NoopTransform(torch.nn.Module):
     """A placeholder transform used with torchvision detection model."""
@@ -55,7 +59,7 @@
         return image_list, targets
 
 
-class FasterRCNN(torch.nn.Module):
+class FasterRCNN(Predictor):
     """Faster R-CNN head for predicting bounding boxes.
 
     It inputs multi-scale features, using each feature map to predict ROIs and then
@@ -176,20 +180,23 @@ class FasterRCNN(torch.nn.Module):
 
     def forward(
         self,
-
-
+        intermediates: Any,
+        context: ModelContext,
         targets: list[dict[str, Any]] | None = None,
-    ) ->
+    ) -> ModelOutput:
         """Compute the detection outputs and loss from features.
 
         Args:
-
-
+            intermediates: the output from the previous component, which must be a FeatureMaps.
+            context: the model context. Input dicts must contain image key for original image size.
            targets: should contain class key that stores the class label.
 
         Returns:
            tuple of outputs and loss dict
        """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to FasterRCNN must be FeatureMaps")
+
        # Fix target labels to be 1 size in case it's empty.
        # For some reason this is needed.
        if targets:
@@ -203,11 +210,11 @@
             ),
         )
 
-        image_list = [inp["image"] for inp in inputs]
+        image_list = [inp["image"] for inp in context.inputs]
         images, targets = self.noop_transform(image_list, targets)
 
         feature_dict = collections.OrderedDict()
-        for i, feat_map in enumerate(
+        for i, feat_map in enumerate(intermediates.feature_maps):
             feature_dict[f"feat{i}"] = feat_map
 
         proposals, proposal_losses = self.rpn(images, feature_dict, targets)
@@ -219,4 +226,7 @@
         losses.update(proposal_losses)
         losses.update(detector_losses)
 
-        return
+        return ModelOutput(
+            outputs=detections,
+            loss_dict=losses,
+        )
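
faster_rcnn.py shows the corresponding head-side contract: a Predictor receives the previous component's output plus the ModelContext and returns a ModelOutput with outputs and loss_dict. A hedged sketch of a head following that signature is below; the pooling classifier, its use of a "class" target key, and the exact type placed in outputs are illustrative assumptions rather than part of the package.

# Hedged sketch of a Predictor-style head, mirroring the FasterRCNN forward signature above.
from typing import Any

import torch

from rslearn.models.component import FeatureMaps, Predictor
from rslearn.train.model_context import ModelContext, ModelOutput


class ToyClassifierHead(Predictor):
    """Hypothetical head that pools the first feature map and predicts class logits."""

    def __init__(self, in_channels: int, num_classes: int):
        super().__init__()
        self.fc = torch.nn.Linear(in_channels, num_classes)

    def forward(
        self,
        intermediates: Any,
        context: ModelContext,
        targets: list[dict[str, Any]] | None = None,
    ) -> ModelOutput:
        if not isinstance(intermediates, FeatureMaps):
            raise ValueError("input to ToyClassifierHead must be FeatureMaps")
        feats = intermediates.feature_maps[0]  # (B, C, H, W)
        logits = self.fc(feats.mean(dim=[2, 3]))
        losses = {}
        if targets is not None:
            # Assumes each target dict stores a scalar class label under "class".
            labels = torch.stack([target["class"] for target in targets], dim=0)
            losses["cls"] = torch.nn.functional.cross_entropy(logits, labels)
        return ModelOutput(outputs=logits, loss_dict=losses)
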
rslearn/models/feature_center_crop.py
CHANGED
@@ -2,10 +2,12 @@
 
 from typing import Any
 
-import
+from rslearn.train.model_context import ModelContext
 
+from .component import FeatureMaps, IntermediateComponent
 
-class FeatureCenterCrop(torch.nn.Module):
+
+class FeatureCenterCrop(IntermediateComponent):
     """Apply center cropping on the input feature maps."""
 
     def __init__(
@@ -24,20 +26,21 @@ class FeatureCenterCrop(torch.nn.Module):
         super().__init__()
         self.sizes = sizes
 
-    def forward(
-        self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
-    ) -> list[torch.Tensor]:
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
         """Apply center cropping on the feature maps.
 
         Args:
-
-
+            intermediates: output from the previous model component, which must be a FeatureMaps.
+            context: the model context.
 
         Returns:
             center cropped feature maps.
         """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to FeatureCenterCrop must be FeatureMaps")
+
         new_features = []
-        for i, feat in enumerate(
+        for i, feat in enumerate(intermediates.feature_maps):
             height, width = self.sizes[i]
             if feat.shape[2] < height or feat.shape[3] < width:
                 raise ValueError(
@@ -47,4 +50,4 @@ class FeatureCenterCrop(torch.nn.Module):
         start_w = feat.shape[3] // 2 - width // 2
         feat = feat[:, :, start_h : start_h + height, start_w : start_w + width]
         new_features.append(feat)
-        return new_features
+        return FeatureMaps(new_features)
rslearn/models/fpn.py
CHANGED
@@ -1,12 +1,16 @@
 """Feature pyramid network."""
 
 import collections
+from typing import Any
 
-import torch
 import torchvision
 
+from rslearn.train.model_context import ModelContext
 
-class Fpn(torch.nn.Module):
+from .component import FeatureMaps, IntermediateComponent
+
+
+class Fpn(IntermediateComponent):
     """A feature pyramid network (FPN).
 
     The FPN inputs a multi-scale feature map. At each scale, it computes new features
@@ -32,20 +36,27 @@ class Fpn(torch.nn.Module):
             in_channels_list=in_channels, out_channels=out_channels
         )
 
-    def forward(self,
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
         """Compute outputs of the FPN.
 
         Args:
-
+            intermediates: the output from the previous component, which must be a FeatureMaps.
+            context: the model context.
 
         Returns:
-            new multi-scale feature maps from the FPN
+            new multi-scale feature maps from the FPN.
         """
-
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to Fpn must be FeatureMaps")
+
+        feature_maps = intermediates.feature_maps
+        inp = collections.OrderedDict(
+            [(f"feat{i}", el) for i, el in enumerate(feature_maps)]
+        )
         output = self.fpn(inp)
         output = list(output.values())
 
         if self.prepend:
-            return output +
+            return FeatureMaps(output + feature_maps)
         else:
-            return output
+            return FeatureMaps(output)
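
Fpn and FeatureCenterCrop follow the same IntermediateComponent pattern: check that the incoming value is a FeatureMaps, transform each map, and return a new FeatureMaps. A minimal user-defined component in that style might look like the following sketch; the scaling operation is a made-up example.

# Hedged sketch of a custom IntermediateComponent, mirroring the Fpn / FeatureCenterCrop pattern.
from typing import Any

from rslearn.models.component import FeatureMaps, IntermediateComponent
from rslearn.train.model_context import ModelContext


class ScaleFeatures(IntermediateComponent):
    """Hypothetical component that multiplies every feature map by a constant."""

    def __init__(self, factor: float = 2.0):
        super().__init__()
        self.factor = factor

    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
        if not isinstance(intermediates, FeatureMaps):
            raise ValueError("input to ScaleFeatures must be FeatureMaps")
        # Apply the same elementwise scaling to each map and re-wrap the results.
        return FeatureMaps([feat * self.factor for feat in intermediates.feature_maps])
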
rslearn/models/galileo/galileo.py
CHANGED
@@ -4,16 +4,16 @@ import math
 import tempfile
 from contextlib import nullcontext
 from enum import StrEnum
-from typing import
+from typing import cast
 
 import numpy as np
 import torch
-import torch.nn as nn
 from einops import rearrange, repeat
 from huggingface_hub import hf_hub_download
 from upath import UPath
 
 from rslearn.log_utils import get_logger
+from rslearn.models.component import FeatureExtractor, FeatureMaps
 from rslearn.models.galileo.single_file_galileo import (
     CONFIG_FILENAME,
     DW_BANDS,
@@ -39,6 +39,7 @@ from rslearn.models.galileo.single_file_galileo import (
     MaskedOutput,
     Normalizer,
 )
+from rslearn.train.model_context import ModelContext
 
 logger = get_logger(__name__)
 
@@ -70,7 +71,7 @@ AUTOCAST_DTYPE_MAP = {
 }
 
 
-class GalileoModel(nn.Module):
+class GalileoModel(FeatureExtractor):
     """Galileo backbones."""
 
     input_keys = [
@@ -410,11 +411,11 @@ class GalileoModel(nn.Module):
             months=months,
         )
 
-    def forward(self,
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute feature maps from the Galileo backbone.
 
-
-
+        Args:
+            context: the model context. Input dicts should contain keys corresponding to Galileo.input_keys
             (also documented below) and values are tensors of the following shapes,
             per input key:
                 "s1": B (T * C) H W
@@ -436,10 +437,12 @@
         take a pool of the space_time unmasked tokens (i.e. of the s1 and s2 tokens).
         """
         stacked_inputs = {}
-        for key in inputs[0].keys():
+        for key in context.inputs[0].keys():
             # assume all the keys in an input are consistent
             if key in self.input_keys:
-                stacked_inputs[key] = torch.stack(
+                stacked_inputs[key] = torch.stack(
+                    [inp[key] for inp in context.inputs], dim=0
+                )
         s_t_channels = []
         for space_time_modality in ["s1", "s2"]:
             if space_time_modality not in stacked_inputs:
@@ -502,14 +505,14 @@
         # Decide context based on self.autocast_dtype.
         device = galileo_input.s_t_x.device
         if self.autocast_dtype is None:
-
+            torch_context = nullcontext()
         else:
             assert device is not None
-
+            torch_context = torch.amp.autocast(
                 device_type=device.type, dtype=self.autocast_dtype
             )
 
-        with
+        with torch_context:
             outputs = self.model(
                 s_t_x=galileo_input.s_t_x,
                 s_t_m=galileo_input.s_t_m,
@@ -530,18 +533,20 @@
             averaged = self.model.average_tokens(
                 s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m
             )
-            return [repeat(averaged, "b d -> b d 1 1")]
+            return FeatureMaps([repeat(averaged, "b d -> b d 1 1")])
         else:
             s_t_x = outputs[0]
             # we will be assuming we only want s_t_x, and (for now) that we want s1 or s2 bands
             # s_t_x has shape [b, h, w, t, c_g, d]
             # and we want [b, d, h, w]
-            return
-
-
-
-
-
+            return FeatureMaps(
+                [
+                    rearrange(
+                        s_t_x[:, :, :, :, s_t_channels, :].mean(dim=3),
+                        "b h w c_g d -> b c_g d h w",
+                    ).mean(dim=1)
+                ]
+            )
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
rslearn/models/module_wrapper.py
CHANGED
@@ -1,67 +1,35 @@
-"""Module
+"""Module wrapper provided for backwards compatibility."""
 
 from typing import Any
 
 import torch
 
+from rslearn.train.model_context import ModelContext
 
-
-
+from .component import (
+    FeatureExtractor,
+    FeatureMaps,
+    IntermediateComponent,
+)
 
-The module should input feature map and produce a new feature map.
 
-
-
-"""
-
-    def __init__(
-        self,
-        module: torch.nn.Module,
-    ):
-        """Initialize a DecoderModuleWrapper.
+class EncoderModuleWrapper(FeatureExtractor):
+    """Wraps one or more IntermediateComponents to function as the feature extractor.
 
-
-
-        """
-        super().__init__()
-        self.module = module
-
-    def forward(
-        self, features: list[torch.Tensor], inputs: list[torch.Tensor]
-    ) -> list[torch.Tensor]:
-        """Apply the wrapped module on each feature map.
-
-        Args:
-            features: list of feature maps at different resolutions.
-            inputs: original inputs (ignored).
-
-        Returns:
-            new features
-        """
-        new_features = []
-        for feat_map in features:
-            feat_map = self.module(feat_map)
-            new_features.append(feat_map)
-        return new_features
-
-
-class EncoderModuleWrapper(torch.nn.Module):
-    """Wraps a module that is intended to be used as the decoder to work in encoder.
-
-    The module should input a feature map that corresponds to the original image, i.e.
-    the depth of the feature map would be the number of bands in the input image.
+    The first component should input a FeatureMaps, which will be computed from the
+    overall inputs by stacking the "image" key from each input dict.
     """
 
     def __init__(
         self,
-        module:
-        modules: list[
+        module: IntermediateComponent | None = None,
+        modules: list[IntermediateComponent] = [],
    ):
        """Initialize an EncoderModuleWrapper.

        Args:
-            module: the
-                must be set.
+            module: the IntermediateComponent to wrap for use as a FeatureExtractor.
+                Exactly one of module or modules must be set.
            modules: list of modules to wrap
        """
        super().__init__()
@@ -74,18 +42,19 @@ class EncoderModuleWrapper(torch.nn.Module):
         else:
             raise ValueError("one of module or modules must be set")
 
-    def forward(
-        self,
-        inputs: list[dict[str, Any]],
-    ) -> list[torch.Tensor]:
+    def forward(self, context: ModelContext) -> Any:
         """Compute outputs from the wrapped module.
 
-
-
-
+        Args:
+            context: the model context. Input dicts must include "image" key containing
+                the image to convert to a FeatureMaps, which will be passed to the
+                first wrapped module.
+
+        Returns:
+            the output from the last wrapped module.
         """
-        images = torch.stack([inp["image"] for inp in inputs], dim=0)
-        cur = [images]
+        images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
+        cur: Any = FeatureMaps([images])
         for m in self.encoder_modules:
-            cur = m(cur,
+            cur = m(cur, context)
         return cur
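
The rewritten EncoderModuleWrapper stacks each input dict's "image" into a FeatureMaps([images]) and then calls every wrapped component as m(cur, context), returning the final output. A small sketch of wrapping a custom component follows, based only on the constructor and forward shown above; the IdentityComponent is hypothetical, and the ModelContext itself would come from the training pipeline rather than being built by hand here.

# Hedged sketch: composing IntermediateComponents into a FeatureExtractor via EncoderModuleWrapper.
from typing import Any

from rslearn.models.component import IntermediateComponent
from rslearn.models.module_wrapper import EncoderModuleWrapper
from rslearn.train.model_context import ModelContext


class IdentityComponent(IntermediateComponent):
    """Hypothetical component that passes its input through unchanged."""

    def forward(self, intermediates: Any, context: ModelContext) -> Any:
        return intermediates


# The wrapper stacks the "image" inputs into a FeatureMaps and threads it,
# together with the ModelContext, through each wrapped component in order.
encoder = EncoderModuleWrapper(modules=[IdentityComponent()])
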
rslearn/models/molmo.py
CHANGED
@@ -1,12 +1,14 @@
 """Molmo model."""
 
-from typing import Any
-
 import torch
 from transformers import AutoModelForCausalLM, AutoProcessor
 
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps
 
-class Molmo(torch.nn.Module):
+
+class Molmo(FeatureExtractor):
     """Molmo image encoder."""
 
     def __init__(
@@ -34,21 +36,21 @@ class Molmo(torch.nn.Module):
         )  # nosec
         self.encoder = model.model.vision_backbone
 
-    def forward(self,
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute outputs from the backbone.
 
-
-
-                process. The images should have values 0-255.
+        Args:
+            context: the model context. Input dicts must include "image" key containing
+                the image to process. The images should have values 0-255.
 
         Returns:
-
-
+            a FeatureMaps. Molmo produces features at one scale, so it will contain one
+            feature map that is a Bx24x24x2048 tensor.
         """
-        device = inputs[0]["image"].device
+        device = context.inputs[0]["image"].device
         molmo_inputs_list = []
         # Process each one so we can isolate just the full image without any crops.
-        for inp in inputs:
+        for inp in context.inputs:
             image = inp["image"].cpu().numpy().transpose(1, 2, 0)
             processed = self.processor.process(
                 images=[image],
@@ -60,6 +62,6 @@
         image_features, _ = self.encoder.encode_image(molmo_inputs.to(device))
 
         # 576x2048 -> 24x24x2048
-        return
-        image_features[:, 0, :, :].reshape(-1, 24, 24, 2048).permute(0, 3, 1, 2)
-
+        return FeatureMaps(
+            [image_features[:, 0, :, :].reshape(-1, 24, 24, 2048).permute(0, 3, 1, 2)]
+        )