rslearn 0.0.17__py3-none-any.whl → 0.0.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. rslearn/config/__init__.py +2 -0
  2. rslearn/config/dataset.py +49 -4
  3. rslearn/dataset/add_windows.py +1 -1
  4. rslearn/dataset/dataset.py +9 -65
  5. rslearn/dataset/materialize.py +5 -5
  6. rslearn/dataset/storage/__init__.py +1 -0
  7. rslearn/dataset/storage/file.py +202 -0
  8. rslearn/dataset/storage/storage.py +140 -0
  9. rslearn/dataset/window.py +26 -80
  10. rslearn/main.py +11 -36
  11. rslearn/models/anysat.py +11 -9
  12. rslearn/models/clay/clay.py +8 -9
  13. rslearn/models/clip.py +18 -15
  14. rslearn/models/component.py +99 -0
  15. rslearn/models/concatenate_features.py +21 -11
  16. rslearn/models/conv.py +15 -8
  17. rslearn/models/croma.py +13 -8
  18. rslearn/models/detr/detr.py +25 -14
  19. rslearn/models/dinov3.py +11 -6
  20. rslearn/models/faster_rcnn.py +19 -9
  21. rslearn/models/feature_center_crop.py +12 -9
  22. rslearn/models/fpn.py +19 -8
  23. rslearn/models/galileo/galileo.py +23 -18
  24. rslearn/models/module_wrapper.py +26 -57
  25. rslearn/models/molmo.py +16 -14
  26. rslearn/models/multitask.py +102 -73
  27. rslearn/models/olmoearth_pretrain/model.py +18 -12
  28. rslearn/models/panopticon.py +8 -7
  29. rslearn/models/pick_features.py +18 -24
  30. rslearn/models/pooling_decoder.py +22 -14
  31. rslearn/models/presto/presto.py +16 -10
  32. rslearn/models/presto/single_file_presto.py +4 -10
  33. rslearn/models/prithvi.py +12 -8
  34. rslearn/models/resize_features.py +21 -7
  35. rslearn/models/sam2_enc.py +11 -9
  36. rslearn/models/satlaspretrain.py +15 -9
  37. rslearn/models/simple_time_series.py +31 -17
  38. rslearn/models/singletask.py +24 -17
  39. rslearn/models/ssl4eo_s12.py +15 -10
  40. rslearn/models/swin.py +22 -13
  41. rslearn/models/terramind.py +24 -7
  42. rslearn/models/trunk.py +6 -3
  43. rslearn/models/unet.py +18 -9
  44. rslearn/models/upsample.py +22 -9
  45. rslearn/train/all_patches_dataset.py +22 -18
  46. rslearn/train/dataset.py +69 -54
  47. rslearn/train/lightning_module.py +51 -32
  48. rslearn/train/model_context.py +54 -0
  49. rslearn/train/prediction_writer.py +111 -41
  50. rslearn/train/tasks/classification.py +34 -15
  51. rslearn/train/tasks/detection.py +24 -31
  52. rslearn/train/tasks/embedding.py +33 -29
  53. rslearn/train/tasks/multi_task.py +7 -7
  54. rslearn/train/tasks/per_pixel_regression.py +41 -19
  55. rslearn/train/tasks/regression.py +38 -21
  56. rslearn/train/tasks/segmentation.py +33 -15
  57. rslearn/train/tasks/task.py +3 -2
  58. {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/METADATA +1 -1
  59. {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/RECORD +64 -61
  60. rslearn/dataset/index.py +0 -173
  61. rslearn/models/registry.py +0 -22
  62. {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/WHEEL +0 -0
  63. {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/entry_points.txt +0 -0
  64. {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/licenses/LICENSE +0 -0
  65. {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/licenses/NOTICE +0 -0
  66. {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/top_level.txt +0 -0
rslearn/models/croma.py CHANGED
@@ -12,9 +12,11 @@ from einops import rearrange
 from upath import UPath
 
 from rslearn.log_utils import get_logger
+from rslearn.train.model_context import ModelContext
 from rslearn.train.transforms.transform import Transform
 from rslearn.utils.fsspec import open_atomic
 
+from .component import FeatureExtractor, FeatureMaps
 from .use_croma import PretrainedCROMA
 
 logger = get_logger(__name__)
@@ -76,7 +78,7 @@ MODALITY_BANDS = {
 }
 
 
-class Croma(torch.nn.Module):
+class Croma(FeatureExtractor):
     """CROMA backbones.
 
     There are two model sizes, base and large.
@@ -160,20 +162,23 @@ class Croma(torch.nn.Module):
             align_corners=False,
         )
 
-    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute feature maps from the Croma backbone.
 
-        Inputs:
-            inputs: input dicts that must include either/both of "sentinel2" or
-                "sentinel1" keys depending on the configured modality.
+        Args:
+            context: the model context. Input dicts must include either/both of
+                "sentinel2" or "sentinel1" keys depending on the configured modality.
+
+        Returns:
+            a FeatureMaps with one feature map at 1/8 the input resolution.
         """
         sentinel1: torch.Tensor | None = None
         sentinel2: torch.Tensor | None = None
         if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL1]:
-            sentinel1 = torch.stack([inp["sentinel1"] for inp in inputs], dim=0)
+            sentinel1 = torch.stack([inp["sentinel1"] for inp in context.inputs], dim=0)
             sentinel1 = self._resize_image(sentinel1) if self.do_resizing else sentinel1
         if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL2]:
-            sentinel2 = torch.stack([inp["sentinel2"] for inp in inputs], dim=0)
+            sentinel2 = torch.stack([inp["sentinel2"] for inp in context.inputs], dim=0)
             sentinel2 = self._resize_image(sentinel2) if self.do_resizing else sentinel2
 
         outputs = self.model(
@@ -200,7 +205,7 @@ class Croma(torch.nn.Module):
             w=num_patches_per_dim,
        )
 
-        return [features]
+        return FeatureMaps([features])
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
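The same change repeats across the backbones in this release: forward() now takes a ModelContext and returns a FeatureMaps instead of a plain list of tensors. A minimal sketch of a backbone written against the new FeatureExtractor interface (the class name, constructor arguments, and conv layer below are illustrative assumptions, not part of rslearn):

    import torch

    from rslearn.models.component import FeatureExtractor, FeatureMaps
    from rslearn.train.model_context import ModelContext


    class ToyBackbone(FeatureExtractor):
        """Illustrative backbone following the new 0.0.18 component interface."""

        def __init__(self, in_channels: int = 3, out_channels: int = 64):
            super().__init__()
            # One strided conv standing in for a real encoder.
            self.conv = torch.nn.Conv2d(
                in_channels, out_channels, kernel_size=3, stride=8, padding=1
            )

        def forward(self, context: ModelContext) -> FeatureMaps:
            # Stack the per-example input dicts into a batch, as Croma does above.
            images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
            # Wrap the (single-scale) feature map list in a FeatureMaps.
            return FeatureMaps([self.conv(images)])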
rslearn/models/detr/detr.py CHANGED
@@ -13,6 +13,8 @@ import torch.nn.functional as F
 from torch import nn
 
 import rslearn.models.detr.box_ops as box_ops
+from rslearn.models.component import FeatureMaps, Predictor
+from rslearn.train.model_context import ModelContext, ModelOutput
 
 from .matcher import HungarianMatcher
 from .position_encoding import PositionEmbeddingSine
@@ -405,7 +407,7 @@ class PostProcess(nn.Module):
         return results
 
 
-class Detr(nn.Module):
+class Detr(Predictor):
     """DETR prediction module.
 
     This combines PositionEmbeddingSine, DetrPredictor, SetCriterion, and PostProcess.
@@ -440,33 +442,39 @@
 
     def forward(
         self,
-        features: list[torch.Tensor],
-        inputs: list[dict[str, Any]],
+        intermediates: Any,
+        context: ModelContext,
         targets: list[dict[str, Any]] | None = None,
-    ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
+    ) -> ModelOutput:
         """Compute the detection outputs and loss from features.
 
         DETR will use only the last feature map, which should correspond to the lowest
         resolution one.
 
         Args:
-            features: multi-scale feature maps.
-            inputs: original inputs, should contain image key for original image size.
-            targets: should contain class key that stores the class label.
+            intermediates: the output from the previous component. It must be a FeatureMaps.
+            context: the model context. Input dicts must contain an "image" key which
+                will be used to establish the original image size.
+            targets: must contain class key that stores the class label.
 
         Returns:
-            tuple of outputs and loss dict.
+            the model output.
         """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to Detr must be a FeatureMaps")
+
+        # We only use the last feature map (most fine-grained).
+        features = intermediates.feature_maps[-1]
+
         # Get image sizes.
         image_sizes = torch.tensor(
-            [[inp["image"].shape[2], inp["image"].shape[1]] for inp in inputs],
+            [[inp["image"].shape[2], inp["image"].shape[1]] for inp in context.inputs],
             dtype=torch.int32,
-            device=features[0].device,
+            device=features.device,
         )
 
-        feat_map = features[-1]
-        pos_embedding = self.pos_embedding(feat_map)
-        outputs = self.predictor(feat_map, pos_embedding)
+        pos_embedding = self.pos_embedding(features)
+        outputs = self.predictor(features, pos_embedding)
 
         if targets is not None:
             # Convert boxes from [x0, y0, x1, y1] to [cx, cy, w, h].
@@ -490,4 +498,7 @@
 
         results = self.postprocess(outputs, image_sizes)
 
-        return results, losses
+        return ModelOutput(
+            outputs=results,
+            loss_dict=losses,
+        )
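Prediction heads such as Detr and FasterRCNN now implement the Predictor interface: forward(intermediates, context, targets) returns a ModelOutput carrying both the outputs and the loss dict. A hedged sketch of a toy head under that contract (the pooling classifier, the "class" target key, and the loss name are illustrative assumptions, not rslearn's detection tasks):

    from typing import Any

    import torch
    import torch.nn.functional as F

    from rslearn.models.component import FeatureMaps, Predictor
    from rslearn.train.model_context import ModelContext, ModelOutput


    class ToyClassifierHead(Predictor):
        """Illustrative head: global-average-pool the last feature map and classify."""

        def __init__(self, in_channels: int, num_classes: int):
            super().__init__()
            self.linear = torch.nn.Linear(in_channels, num_classes)

        def forward(
            self,
            intermediates: Any,
            context: ModelContext,
            targets: list[dict[str, Any]] | None = None,
        ) -> ModelOutput:
            if not isinstance(intermediates, FeatureMaps):
                raise ValueError("input to ToyClassifierHead must be FeatureMaps")
            # Pool the last feature map over its spatial dimensions.
            pooled = intermediates.feature_maps[-1].mean(dim=(2, 3))
            logits = self.linear(pooled)
            losses: dict[str, torch.Tensor] = {}
            if targets is not None:
                # Assumes each target dict stores an integer class label tensor.
                labels = torch.stack([target["class"] for target in targets])
                losses["cls"] = F.cross_entropy(logits, labels)
            return ModelOutput(outputs=logits, loss_dict=losses)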
rslearn/models/dinov3.py CHANGED
@@ -13,9 +13,12 @@ import torch
 import torchvision
 from einops import rearrange
 
+from rslearn.train.model_context import ModelContext
 from rslearn.train.transforms.normalize import Normalize
 from rslearn.train.transforms.transform import Transform
 
+from .component import FeatureExtractor, FeatureMaps
+
 
 class DinoV3Models(StrEnum):
     """Names for different DinoV3 images on torch hub."""
@@ -40,7 +43,7 @@ DINOV3_PTHS: dict[str, str] = {
 }
 
 
-class DinoV3(torch.nn.Module):
+class DinoV3(FeatureExtractor):
     """DinoV3 Backbones.
 
     Must have the pretrained weights downloaded in checkpoint_dir for them to be loaded.
@@ -91,16 +94,18 @@ class DinoV3(torch.nn.Module):
         self.do_resizing = do_resizing
         self.model = self._load_model(size, checkpoint_dir)
 
-    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Forward pass for the dinov3 model.
 
         Args:
-            inputs: input dicts that must include "image" key.
+            context: the model context. Input dicts must include "image" key.
 
         Returns:
-            List[torch.Tensor]: Single-scale feature tensors from the encoder.
+            a FeatureMaps with one feature map.
         """
-        cur = torch.stack([inp["image"] for inp in inputs], dim=0)  # (B, C, H, W)
+        cur = torch.stack(
+            [inp["image"] for inp in context.inputs], dim=0
+        )  # (B, C, H, W)
 
         if self.do_resizing and (
             cur.shape[2] != self.image_size or cur.shape[3] != self.image_size
@@ -118,7 +123,7 @@ class DinoV3(torch.nn.Module):
         height, width = int(num_patches**0.5), int(num_patches**0.5)
         features = rearrange(features, "b (h w) d -> b d h w", h=height, w=width)
 
-        return [features]
+        return FeatureMaps([features])
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
rslearn/models/faster_rcnn.py CHANGED
@@ -6,6 +6,10 @@ from typing import Any
 import torch
 import torchvision
 
+from rslearn.train.model_context import ModelContext, ModelOutput
+
+from .component import FeatureMaps, Predictor
+
 
 class NoopTransform(torch.nn.Module):
     """A placeholder transform used with torchvision detection model."""
@@ -55,7 +59,7 @@ class NoopTransform(torch.nn.Module):
         return image_list, targets
 
 
-class FasterRCNN(torch.nn.Module):
+class FasterRCNN(Predictor):
     """Faster R-CNN head for predicting bounding boxes.
 
     It inputs multi-scale features, using each feature map to predict ROIs and then
@@ -176,20 +180,23 @@ class FasterRCNN(torch.nn.Module):
 
     def forward(
         self,
-        features: list[torch.Tensor],
-        inputs: list[dict[str, Any]],
+        intermediates: Any,
+        context: ModelContext,
         targets: list[dict[str, Any]] | None = None,
-    ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
+    ) -> ModelOutput:
         """Compute the detection outputs and loss from features.
 
         Args:
-            features: multi-scale feature maps.
-            inputs: original inputs, should cotnain image key for original image size.
+            intermediates: the output from the previous component, which must be a FeatureMaps.
+            context: the model context. Input dicts must contain image key for original image size.
             targets: should contain class key that stores the class label.
 
         Returns:
             tuple of outputs and loss dict
         """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to FasterRCNN must be FeatureMaps")
+
         # Fix target labels to be 1 size in case it's empty.
         # For some reason this is needed.
         if targets:
@@ -203,11 +210,11 @@ class FasterRCNN(torch.nn.Module):
                 ),
            )
 
-        image_list = [inp["image"] for inp in inputs]
+        image_list = [inp["image"] for inp in context.inputs]
         images, targets = self.noop_transform(image_list, targets)
 
         feature_dict = collections.OrderedDict()
-        for i, feat_map in enumerate(features):
+        for i, feat_map in enumerate(intermediates.feature_maps):
             feature_dict[f"feat{i}"] = feat_map
 
         proposals, proposal_losses = self.rpn(images, feature_dict, targets)
@@ -219,4 +226,7 @@ class FasterRCNN(torch.nn.Module):
         losses.update(proposal_losses)
         losses.update(detector_losses)
 
-        return detections, losses
+        return ModelOutput(
+            outputs=detections,
+            loss_dict=losses,
+        )
rslearn/models/feature_center_crop.py CHANGED
@@ -2,10 +2,12 @@
 
 from typing import Any
 
-import torch
+from rslearn.train.model_context import ModelContext
 
+from .component import FeatureMaps, IntermediateComponent
 
-class FeatureCenterCrop(torch.nn.Module):
+
+class FeatureCenterCrop(IntermediateComponent):
     """Apply center cropping on the input feature maps."""
 
     def __init__(
@@ -24,20 +26,21 @@ class FeatureCenterCrop(torch.nn.Module):
         super().__init__()
         self.sizes = sizes
 
-    def forward(
-        self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
-    ) -> list[torch.Tensor]:
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
         """Apply center cropping on the feature maps.
 
         Args:
-            features: list of feature maps at different resolutions.
-            inputs: original inputs (ignored).
+            intermediates: output from the previous model component, which must be a FeatureMaps.
+            context: the model context.
 
         Returns:
             center cropped feature maps.
         """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to FeatureCenterCrop must be FeatureMaps")
+
         new_features = []
-        for i, feat in enumerate(features):
+        for i, feat in enumerate(intermediates.feature_maps):
             height, width = self.sizes[i]
             if feat.shape[2] < height or feat.shape[3] < width:
                 raise ValueError(
@@ -47,4 +50,4 @@ class FeatureCenterCrop(torch.nn.Module):
             start_w = feat.shape[3] // 2 - width // 2
             feat = feat[:, :, start_h : start_h + height, start_w : start_w + width]
             new_features.append(feat)
-        return new_features
+        return FeatureMaps(new_features)
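Feature-space transforms such as FeatureCenterCrop and Fpn now implement the IntermediateComponent interface: forward(intermediates, context) receives the previous component's output (normally a FeatureMaps) plus the ModelContext and returns new intermediates. A small illustrative component under that contract (the class itself is an assumption for demonstration, not part of rslearn):

    from typing import Any

    from rslearn.models.component import FeatureMaps, IntermediateComponent
    from rslearn.train.model_context import ModelContext


    class ScaleFeatures(IntermediateComponent):
        """Illustrative component that multiplies every feature map by a constant."""

        def __init__(self, factor: float = 2.0):
            super().__init__()
            self.factor = factor

        def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
            if not isinstance(intermediates, FeatureMaps):
                raise ValueError("input to ScaleFeatures must be FeatureMaps")
            # Scale each map and rewrap the list in a FeatureMaps.
            return FeatureMaps([f * self.factor for f in intermediates.feature_maps])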
rslearn/models/fpn.py CHANGED
@@ -1,12 +1,16 @@
 """Feature pyramid network."""
 
 import collections
+from typing import Any
 
-import torch
 import torchvision
 
+from rslearn.train.model_context import ModelContext
 
-class Fpn(torch.nn.Module):
+from .component import FeatureMaps, IntermediateComponent
+
+
+class Fpn(IntermediateComponent):
     """A feature pyramid network (FPN).
 
     The FPN inputs a multi-scale feature map. At each scale, it computes new features
@@ -32,20 +36,27 @@ class Fpn(torch.nn.Module):
             in_channels_list=in_channels, out_channels=out_channels
         )
 
-    def forward(self, x: list[torch.Tensor]) -> list[torch.Tensor]:
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
         """Compute outputs of the FPN.
 
         Args:
-            x: the multi-scale feature maps
+            intermediates: the output from the previous component, which must be a FeatureMaps.
+            context: the model context.
 
         Returns:
-            new multi-scale feature maps from the FPN
+            new multi-scale feature maps from the FPN.
         """
-        inp = collections.OrderedDict([(f"feat{i}", el) for i, el in enumerate(x)])
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to Fpn must be FeatureMaps")
+
+        feature_maps = intermediates.feature_maps
+        inp = collections.OrderedDict(
+            [(f"feat{i}", el) for i, el in enumerate(feature_maps)]
+        )
         output = self.fpn(inp)
         output = list(output.values())
 
         if self.prepend:
-            return output + x
+            return FeatureMaps(output + feature_maps)
         else:
-            return output
+            return FeatureMaps(output)
rslearn/models/galileo/galileo.py CHANGED
@@ -4,16 +4,16 @@ import math
 import tempfile
 from contextlib import nullcontext
 from enum import StrEnum
-from typing import Any, cast
+from typing import cast
 
 import numpy as np
 import torch
-import torch.nn as nn
 from einops import rearrange, repeat
 from huggingface_hub import hf_hub_download
 from upath import UPath
 
 from rslearn.log_utils import get_logger
+from rslearn.models.component import FeatureExtractor, FeatureMaps
 from rslearn.models.galileo.single_file_galileo import (
     CONFIG_FILENAME,
     DW_BANDS,
@@ -39,6 +39,7 @@ from rslearn.models.galileo.single_file_galileo import (
     MaskedOutput,
     Normalizer,
 )
+from rslearn.train.model_context import ModelContext
 
 logger = get_logger(__name__)
 
@@ -70,7 +71,7 @@ AUTOCAST_DTYPE_MAP = {
 }
 
 
-class GalileoModel(nn.Module):
+class GalileoModel(FeatureExtractor):
     """Galileo backbones."""
 
     input_keys = [
@@ -410,11 +411,11 @@
             months=months,
         )
 
-    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute feature maps from the Galileo backbone.
 
-        Inputs:
-            inputs: a dictionary of tensors, where the keys are one of Galileo.input_keys
+        Args:
+            context: the model context. Input dicts should contain keys corresponding to Galileo.input_keys
                 (also documented below) and values are tensors of the following shapes,
                 per input key:
                 "s1": B (T * C) H W
@@ -436,10 +437,12 @@
         take a pool of the space_time unmasked tokens (i.e. of the s1 and s2 tokens).
         """
         stacked_inputs = {}
-        for key in inputs[0].keys():
+        for key in context.inputs[0].keys():
             # assume all the keys in an input are consistent
             if key in self.input_keys:
-                stacked_inputs[key] = torch.stack([inp[key] for inp in inputs], dim=0)
+                stacked_inputs[key] = torch.stack(
+                    [inp[key] for inp in context.inputs], dim=0
+                )
         s_t_channels = []
         for space_time_modality in ["s1", "s2"]:
             if space_time_modality not in stacked_inputs:
@@ -502,14 +505,14 @@
         # Decide context based on self.autocast_dtype.
         device = galileo_input.s_t_x.device
         if self.autocast_dtype is None:
-            context = nullcontext()
+            torch_context = nullcontext()
         else:
             assert device is not None
-            context = torch.amp.autocast(
+            torch_context = torch.amp.autocast(
                 device_type=device.type, dtype=self.autocast_dtype
            )
 
-        with context:
+        with torch_context:
             outputs = self.model(
                 s_t_x=galileo_input.s_t_x,
                 s_t_m=galileo_input.s_t_m,
@@ -530,18 +533,20 @@
             averaged = self.model.average_tokens(
                 s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m
            )
-            return [repeat(averaged, "b d -> b d 1 1")]
+            return FeatureMaps([repeat(averaged, "b d -> b d 1 1")])
         else:
             s_t_x = outputs[0]
             # we will be assuming we only want s_t_x, and (for now) that we want s1 or s2 bands
             # s_t_x has shape [b, h, w, t, c_g, d]
             # and we want [b, d, h, w]
-            return [
-                rearrange(
-                    s_t_x[:, :, :, :, s_t_channels, :].mean(dim=3),
-                    "b h w c_g d -> b c_g d h w",
-                ).mean(dim=1)
-            ]
+            return FeatureMaps(
+                [
+                    rearrange(
+                        s_t_x[:, :, :, :, s_t_channels, :].mean(dim=3),
+                        "b h w c_g d -> b c_g d h w",
+                    ).mean(dim=1)
+                ]
+            )
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
rslearn/models/module_wrapper.py CHANGED
@@ -1,67 +1,35 @@
-"""Module wrappers."""
+"""Module wrapper provided for backwards compatibility."""
 
 from typing import Any
 
 import torch
 
+from rslearn.train.model_context import ModelContext
 
-class DecoderModuleWrapper(torch.nn.Module):
-    """Wrapper for a module that processes features to work in decoder.
+from .component import (
+    FeatureExtractor,
+    FeatureMaps,
+    IntermediateComponent,
+)
 
-    The module should input feature map and produce a new feature map.
 
-    We wrap it to process each feature map in multi-scale features which is what's used
-    for most decoders.
-    """
-
-    def __init__(
-        self,
-        module: torch.nn.Module,
-    ):
-        """Initialize a DecoderModuleWrapper.
+class EncoderModuleWrapper(FeatureExtractor):
+    """Wraps one or more IntermediateComponents to function as the feature extractor.
 
-        Args:
-            module: the module to wrap
-        """
-        super().__init__()
-        self.module = module
-
-    def forward(
-        self, features: list[torch.Tensor], inputs: list[torch.Tensor]
-    ) -> list[torch.Tensor]:
-        """Apply the wrapped module on each feature map.
-
-        Args:
-            features: list of feature maps at different resolutions.
-            inputs: original inputs (ignored).
-
-        Returns:
-            new features
-        """
-        new_features = []
-        for feat_map in features:
-            feat_map = self.module(feat_map)
-            new_features.append(feat_map)
-        return new_features
-
-
-class EncoderModuleWrapper(torch.nn.Module):
-    """Wraps a module that is intended to be used as the decoder to work in encoder.
-
-    The module should input a feature map that corresponds to the original image, i.e.
-    the depth of the feature map would be the number of bands in the input image.
+    The first component should input a FeatureMaps, which will be computed from the
+    overall inputs by stacking the "image" key from each input dict.
     """
 
     def __init__(
         self,
-        module: torch.nn.Module | None = None,
-        modules: list[torch.nn.Module] = [],
+        module: IntermediateComponent | None = None,
+        modules: list[IntermediateComponent] = [],
     ):
         """Initialize an EncoderModuleWrapper.
 
         Args:
-            module: the encoder module to wrap. Exactly one one of module or modules
-                must be set.
+            module: the IntermediateComponent to wrap for use as a FeatureExtractor.
+                Exactly one of module or modules must be set.
             modules: list of modules to wrap
         """
         super().__init__()
@@ -74,18 +42,19 @@ class EncoderModuleWrapper(torch.nn.Module):
         else:
             raise ValueError("one of module or modules must be set")
 
-    def forward(
-        self,
-        inputs: list[dict[str, Any]],
-    ) -> list[torch.Tensor]:
+    def forward(self, context: ModelContext) -> Any:
         """Compute outputs from the wrapped module.
 
-        Inputs:
-            inputs: input dicts that must include "image" key containing the image to
-                process.
+        Args:
+            context: the model context. Input dicts must include "image" key containing
+                the image to convert to a FeatureMaps, which will be passed to the
+                first wrapped module.
+
+        Returns:
+            the output from the last wrapped module.
         """
-        images = torch.stack([inp["image"] for inp in inputs], dim=0)
-        cur = [images]
+        images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
+        cur: Any = FeatureMaps([images])
         for m in self.encoder_modules:
-            cur = m(cur, inputs)
+            cur = m(cur, context)
         return cur
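With DecoderModuleWrapper removed, EncoderModuleWrapper is the remaining adapter: it stacks the "image" inputs into a FeatureMaps itself and threads the ModelContext through each wrapped IntermediateComponent. A hedged usage sketch, reusing the illustrative ScaleFeatures component defined in the earlier sketch (any IntermediateComponent would do):

    from rslearn.models.module_wrapper import EncoderModuleWrapper

    # Wrap a single IntermediateComponent so it acts as the feature extractor;
    # it receives FeatureMaps([stacked "image" tensors]) plus the ModelContext.
    encoder = EncoderModuleWrapper(module=ScaleFeatures(factor=0.5))

    # Or chain several components; each receives the previous component's output.
    encoder = EncoderModuleWrapper(modules=[ScaleFeatures(2.0), ScaleFeatures(0.5)])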
rslearn/models/molmo.py CHANGED
@@ -1,12 +1,14 @@
 """Molmo model."""
 
-from typing import Any
-
 import torch
 from transformers import AutoModelForCausalLM, AutoProcessor
 
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps
 
-class Molmo(torch.nn.Module):
+
+class Molmo(FeatureExtractor):
     """Molmo image encoder."""
 
     def __init__(
@@ -34,21 +36,21 @@ class Molmo(torch.nn.Module):
         ) # nosec
         self.encoder = model.model.vision_backbone
 
-    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute outputs from the backbone.
 
-        Inputs:
-            inputs: input dicts that must include "image" key containing the image to
-                process. The images should have values 0-255.
+        Args:
+            context: the model context. Input dicts must include "image" key containing
+                the image to process. The images should have values 0-255.
 
         Returns:
-            list of feature maps. Molmo produces features at one scale, so the list
-                contains a single Bx24x24x2048 tensor.
+            a FeatureMaps. Molmo produces features at one scale, so it will contain one
+                feature map that is a Bx24x24x2048 tensor.
         """
-        device = inputs[0]["image"].device
+        device = context.inputs[0]["image"].device
         molmo_inputs_list = []
         # Process each one so we can isolate just the full image without any crops.
-        for inp in inputs:
+        for inp in context.inputs:
             image = inp["image"].cpu().numpy().transpose(1, 2, 0)
             processed = self.processor.process(
                 images=[image],
@@ -60,6 +62,6 @@
         image_features, _ = self.encoder.encode_image(molmo_inputs.to(device))
 
         # 576x2048 -> 24x24x2048
-        return [
-            image_features[:, 0, :, :].reshape(-1, 24, 24, 2048).permute(0, 3, 1, 2)
-        ]
+        return FeatureMaps(
+            [image_features[:, 0, :, :].reshape(-1, 24, 24, 2048).permute(0, 3, 1, 2)]
+        )