rslearn 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. rslearn/arg_parser.py +2 -9
  2. rslearn/config/__init__.py +2 -0
  3. rslearn/config/dataset.py +64 -20
  4. rslearn/dataset/add_windows.py +1 -1
  5. rslearn/dataset/dataset.py +34 -84
  6. rslearn/dataset/materialize.py +5 -5
  7. rslearn/dataset/storage/__init__.py +1 -0
  8. rslearn/dataset/storage/file.py +202 -0
  9. rslearn/dataset/storage/storage.py +140 -0
  10. rslearn/dataset/window.py +26 -80
  11. rslearn/lightning_cli.py +22 -11
  12. rslearn/main.py +12 -37
  13. rslearn/models/anysat.py +11 -9
  14. rslearn/models/attention_pooling.py +177 -0
  15. rslearn/models/clay/clay.py +8 -9
  16. rslearn/models/clip.py +18 -15
  17. rslearn/models/component.py +111 -0
  18. rslearn/models/concatenate_features.py +21 -11
  19. rslearn/models/conv.py +15 -8
  20. rslearn/models/croma.py +13 -8
  21. rslearn/models/detr/detr.py +25 -14
  22. rslearn/models/dinov3.py +11 -6
  23. rslearn/models/faster_rcnn.py +19 -9
  24. rslearn/models/feature_center_crop.py +12 -9
  25. rslearn/models/fpn.py +19 -8
  26. rslearn/models/galileo/galileo.py +23 -18
  27. rslearn/models/module_wrapper.py +26 -57
  28. rslearn/models/molmo.py +16 -14
  29. rslearn/models/multitask.py +102 -73
  30. rslearn/models/olmoearth_pretrain/model.py +135 -38
  31. rslearn/models/panopticon.py +8 -7
  32. rslearn/models/pick_features.py +18 -24
  33. rslearn/models/pooling_decoder.py +22 -14
  34. rslearn/models/presto/presto.py +16 -10
  35. rslearn/models/presto/single_file_presto.py +4 -10
  36. rslearn/models/prithvi.py +12 -8
  37. rslearn/models/resize_features.py +21 -7
  38. rslearn/models/sam2_enc.py +11 -9
  39. rslearn/models/satlaspretrain.py +15 -9
  40. rslearn/models/simple_time_series.py +37 -17
  41. rslearn/models/singletask.py +24 -17
  42. rslearn/models/ssl4eo_s12.py +15 -10
  43. rslearn/models/swin.py +22 -13
  44. rslearn/models/terramind.py +24 -7
  45. rslearn/models/trunk.py +6 -3
  46. rslearn/models/unet.py +18 -9
  47. rslearn/models/upsample.py +22 -9
  48. rslearn/train/all_patches_dataset.py +89 -37
  49. rslearn/train/dataset.py +105 -97
  50. rslearn/train/lightning_module.py +51 -32
  51. rslearn/train/model_context.py +54 -0
  52. rslearn/train/prediction_writer.py +111 -41
  53. rslearn/train/scheduler.py +15 -0
  54. rslearn/train/tasks/classification.py +34 -15
  55. rslearn/train/tasks/detection.py +24 -31
  56. rslearn/train/tasks/embedding.py +33 -29
  57. rslearn/train/tasks/multi_task.py +7 -7
  58. rslearn/train/tasks/per_pixel_regression.py +41 -19
  59. rslearn/train/tasks/regression.py +38 -21
  60. rslearn/train/tasks/segmentation.py +33 -15
  61. rslearn/train/tasks/task.py +3 -2
  62. rslearn/train/transforms/resize.py +74 -0
  63. rslearn/utils/geometry.py +73 -0
  64. rslearn/utils/jsonargparse.py +66 -0
  65. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/METADATA +1 -1
  66. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/RECORD +71 -66
  67. rslearn/dataset/index.py +0 -173
  68. rslearn/models/registry.py +0 -22
  69. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/WHEEL +0 -0
  70. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/entry_points.txt +0 -0
  71. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/LICENSE +0 -0
  72. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/NOTICE +0 -0
  73. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/top_level.txt +0 -0
rslearn/models/dinov3.py CHANGED
@@ -13,9 +13,12 @@ import torch
 import torchvision
 from einops import rearrange
 
+from rslearn.train.model_context import ModelContext
 from rslearn.train.transforms.normalize import Normalize
 from rslearn.train.transforms.transform import Transform
 
+from .component import FeatureExtractor, FeatureMaps
+
 
 class DinoV3Models(StrEnum):
     """Names for different DinoV3 images on torch hub."""
@@ -40,7 +43,7 @@ DINOV3_PTHS: dict[str, str] = {
 }
 
 
-class DinoV3(torch.nn.Module):
+class DinoV3(FeatureExtractor):
     """DinoV3 Backbones.
 
     Must have the pretrained weights downloaded in checkpoint_dir for them to be loaded.
@@ -91,16 +94,18 @@ class DinoV3(torch.nn.Module):
         self.do_resizing = do_resizing
         self.model = self._load_model(size, checkpoint_dir)
 
-    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Forward pass for the dinov3 model.
 
         Args:
-            inputs: input dicts that must include "image" key.
+            context: the model context. Input dicts must include "image" key.
 
         Returns:
-            List[torch.Tensor]: Single-scale feature tensors from the encoder.
+            a FeatureMaps with one feature map.
         """
-        cur = torch.stack([inp["image"] for inp in inputs], dim=0)  # (B, C, H, W)
+        cur = torch.stack(
+            [inp["image"] for inp in context.inputs], dim=0
+        )  # (B, C, H, W)
 
         if self.do_resizing and (
             cur.shape[2] != self.image_size or cur.shape[3] != self.image_size
@@ -118,7 +123,7 @@ class DinoV3(torch.nn.Module):
         height, width = int(num_patches**0.5), int(num_patches**0.5)
         features = rearrange(features, "b (h w) d -> b d h w", h=height, w=width)
 
-        return [features]
+        return FeatureMaps([features])
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
rslearn/models/faster_rcnn.py CHANGED
@@ -6,6 +6,10 @@ from typing import Any
 import torch
 import torchvision
 
+from rslearn.train.model_context import ModelContext, ModelOutput
+
+from .component import FeatureMaps, Predictor
+
 
 class NoopTransform(torch.nn.Module):
     """A placeholder transform used with torchvision detection model."""
@@ -55,7 +59,7 @@ class NoopTransform(torch.nn.Module):
         return image_list, targets
 
 
-class FasterRCNN(torch.nn.Module):
+class FasterRCNN(Predictor):
     """Faster R-CNN head for predicting bounding boxes.
 
     It inputs multi-scale features, using each feature map to predict ROIs and then
@@ -176,20 +180,23 @@ class FasterRCNN(torch.nn.Module):
 
     def forward(
         self,
-        features: list[torch.Tensor],
-        inputs: list[dict[str, Any]],
+        intermediates: Any,
+        context: ModelContext,
         targets: list[dict[str, Any]] | None = None,
-    ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
+    ) -> ModelOutput:
         """Compute the detection outputs and loss from features.
 
         Args:
-            features: multi-scale feature maps.
-            inputs: original inputs, should cotnain image key for original image size.
+            intermediates: the output from the previous component, which must be a FeatureMaps.
+            context: the model context. Input dicts must contain image key for original image size.
             targets: should contain class key that stores the class label.
 
         Returns:
             tuple of outputs and loss dict
         """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to FasterRCNN must be FeatureMaps")
+
         # Fix target labels to be 1 size in case it's empty.
         # For some reason this is needed.
         if targets:
@@ -203,11 +210,11 @@ class FasterRCNN(torch.nn.Module):
                     ),
                 )
 
-        image_list = [inp["image"] for inp in inputs]
+        image_list = [inp["image"] for inp in context.inputs]
         images, targets = self.noop_transform(image_list, targets)
 
         feature_dict = collections.OrderedDict()
-        for i, feat_map in enumerate(features):
+        for i, feat_map in enumerate(intermediates.feature_maps):
             feature_dict[f"feat{i}"] = feat_map
 
         proposals, proposal_losses = self.rpn(images, feature_dict, targets)
@@ -219,4 +226,7 @@ class FasterRCNN(torch.nn.Module):
         losses.update(proposal_losses)
         losses.update(detector_losses)
 
-        return detections, losses
+        return ModelOutput(
+            outputs=detections,
+            loss_dict=losses,
+        )
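Prediction heads follow a parallel contract: they take the previous component's output plus the ModelContext and optional targets, and return a ModelOutput carrying both predictions and a loss dict. A minimal sketch under the same assumptions (the head, its "score" target key, and the loss are hypothetical, not part of rslearn):

from typing import Any

import torch

from rslearn.models.component import FeatureMaps, Predictor
from rslearn.train.model_context import ModelContext, ModelOutput


class MeanScoreHead(Predictor):
    """Hypothetical head that regresses one scalar per example."""

    def forward(
        self,
        intermediates: Any,
        context: ModelContext,
        targets: list[dict[str, Any]] | None = None,
    ) -> ModelOutput:
        if not isinstance(intermediates, FeatureMaps):
            raise ValueError("input to MeanScoreHead must be FeatureMaps")
        # Collapse the first feature map to one score per example.
        scores = intermediates.feature_maps[0].mean(dim=(1, 2, 3))
        loss_dict: dict[str, torch.Tensor] = {}
        if targets is not None:
            target = torch.stack([t["score"] for t in targets])
            loss_dict["mse"] = torch.nn.functional.mse_loss(scores, target)
        return ModelOutput(outputs=scores, loss_dict=loss_dict)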
rslearn/models/feature_center_crop.py CHANGED
@@ -2,10 +2,12 @@
 
 from typing import Any
 
-import torch
+from rslearn.train.model_context import ModelContext
 
+from .component import FeatureMaps, IntermediateComponent
 
-class FeatureCenterCrop(torch.nn.Module):
+
+class FeatureCenterCrop(IntermediateComponent):
     """Apply center cropping on the input feature maps."""
 
     def __init__(
@@ -24,20 +26,21 @@ class FeatureCenterCrop(torch.nn.Module):
         super().__init__()
         self.sizes = sizes
 
-    def forward(
-        self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
-    ) -> list[torch.Tensor]:
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
         """Apply center cropping on the feature maps.
 
         Args:
-            features: list of feature maps at different resolutions.
-            inputs: original inputs (ignored).
+            intermediates: output from the previous model component, which must be a FeatureMaps.
+            context: the model context.
 
         Returns:
             center cropped feature maps.
         """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to FeatureCenterCrop must be FeatureMaps")
+
         new_features = []
-        for i, feat in enumerate(features):
+        for i, feat in enumerate(intermediates.feature_maps):
             height, width = self.sizes[i]
             if feat.shape[2] < height or feat.shape[3] < width:
                 raise ValueError(
@@ -47,4 +50,4 @@ class FeatureCenterCrop(torch.nn.Module):
             start_w = feat.shape[3] // 2 - width // 2
             feat = feat[:, :, start_h : start_h + height, start_w : start_w + width]
             new_features.append(feat)
-        return new_features
+        return FeatureMaps(new_features)
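FeatureCenterCrop is now an IntermediateComponent: it checks that the incoming intermediates are a FeatureMaps, transforms each map, and returns a new FeatureMaps. Any per-feature-map transform can plug into the same slot; a hedged sketch following that pattern (the normalization component itself is hypothetical):

from typing import Any

import torch

from rslearn.models.component import FeatureMaps, IntermediateComponent
from rslearn.train.model_context import ModelContext


class L2NormalizeFeatures(IntermediateComponent):
    """Hypothetical component: L2-normalize each feature map over channels."""

    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
        if not isinstance(intermediates, FeatureMaps):
            raise ValueError("input to L2NormalizeFeatures must be FeatureMaps")
        # Normalize along the channel dimension of each (B, C, H, W) map.
        new_features = [
            torch.nn.functional.normalize(feat, dim=1)
            for feat in intermediates.feature_maps
        ]
        return FeatureMaps(new_features)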
rslearn/models/fpn.py CHANGED
@@ -1,12 +1,16 @@
 """Feature pyramid network."""
 
 import collections
+from typing import Any
 
-import torch
 import torchvision
 
+from rslearn.train.model_context import ModelContext
 
-class Fpn(torch.nn.Module):
+from .component import FeatureMaps, IntermediateComponent
+
+
+class Fpn(IntermediateComponent):
     """A feature pyramid network (FPN).
 
     The FPN inputs a multi-scale feature map. At each scale, it computes new features
@@ -32,20 +36,27 @@ class Fpn(torch.nn.Module):
             in_channels_list=in_channels, out_channels=out_channels
         )
 
-    def forward(self, x: list[torch.Tensor]) -> list[torch.Tensor]:
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
         """Compute outputs of the FPN.
 
         Args:
-            x: the multi-scale feature maps
+            intermediates: the output from the previous component, which must be a FeatureMaps.
+            context: the model context.
 
         Returns:
-            new multi-scale feature maps from the FPN
+            new multi-scale feature maps from the FPN.
         """
-        inp = collections.OrderedDict([(f"feat{i}", el) for i, el in enumerate(x)])
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to Fpn must be FeatureMaps")
+
+        feature_maps = intermediates.feature_maps
+        inp = collections.OrderedDict(
+            [(f"feat{i}", el) for i, el in enumerate(feature_maps)]
+        )
         output = self.fpn(inp)
         output = list(output.values())
 
         if self.prepend:
-            return output + x
+            return FeatureMaps(output + feature_maps)
         else:
-            return output
+            return FeatureMaps(output)
rslearn/models/galileo/galileo.py CHANGED
@@ -4,16 +4,16 @@ import math
 import tempfile
 from contextlib import nullcontext
 from enum import StrEnum
-from typing import Any, cast
+from typing import cast
 
 import numpy as np
 import torch
-import torch.nn as nn
 from einops import rearrange, repeat
 from huggingface_hub import hf_hub_download
 from upath import UPath
 
 from rslearn.log_utils import get_logger
+from rslearn.models.component import FeatureExtractor, FeatureMaps
 from rslearn.models.galileo.single_file_galileo import (
     CONFIG_FILENAME,
     DW_BANDS,
@@ -39,6 +39,7 @@ from rslearn.models.galileo.single_file_galileo import (
     MaskedOutput,
     Normalizer,
 )
+from rslearn.train.model_context import ModelContext
 
 logger = get_logger(__name__)
 
@@ -70,7 +71,7 @@ AUTOCAST_DTYPE_MAP = {
 }
 
 
-class GalileoModel(nn.Module):
+class GalileoModel(FeatureExtractor):
     """Galileo backbones."""
 
     input_keys = [
@@ -410,11 +411,11 @@ class GalileoModel(nn.Module):
             months=months,
         )
 
-    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute feature maps from the Galileo backbone.
 
-        Inputs:
-            inputs: a dictionary of tensors, where the keys are one of Galileo.input_keys
+        Args:
+            context: the model context. Input dicts should contain keys corresponding to Galileo.input_keys
                 (also documented below) and values are tensors of the following shapes,
                 per input key:
                 "s1": B (T * C) H W
@@ -436,10 +437,12 @@ class GalileoModel(nn.Module):
             take a pool of the space_time unmasked tokens (i.e. of the s1 and s2 tokens).
         """
         stacked_inputs = {}
-        for key in inputs[0].keys():
+        for key in context.inputs[0].keys():
             # assume all the keys in an input are consistent
             if key in self.input_keys:
-                stacked_inputs[key] = torch.stack([inp[key] for inp in inputs], dim=0)
+                stacked_inputs[key] = torch.stack(
+                    [inp[key] for inp in context.inputs], dim=0
+                )
         s_t_channels = []
         for space_time_modality in ["s1", "s2"]:
             if space_time_modality not in stacked_inputs:
@@ -502,14 +505,14 @@ class GalileoModel(nn.Module):
         # Decide context based on self.autocast_dtype.
        device = galileo_input.s_t_x.device
         if self.autocast_dtype is None:
-            context = nullcontext()
+            torch_context = nullcontext()
         else:
             assert device is not None
-            context = torch.amp.autocast(
+            torch_context = torch.amp.autocast(
                 device_type=device.type, dtype=self.autocast_dtype
             )
 
-        with context:
+        with torch_context:
             outputs = self.model(
                 s_t_x=galileo_input.s_t_x,
                 s_t_m=galileo_input.s_t_m,
@@ -530,18 +533,20 @@ class GalileoModel(nn.Module):
             averaged = self.model.average_tokens(
                 s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m
             )
-            return [repeat(averaged, "b d -> b d 1 1")]
+            return FeatureMaps([repeat(averaged, "b d -> b d 1 1")])
         else:
             s_t_x = outputs[0]
             # we will be assuming we only want s_t_x, and (for now) that we want s1 or s2 bands
             # s_t_x has shape [b, h, w, t, c_g, d]
             # and we want [b, d, h, w]
-            return [
-                rearrange(
-                    s_t_x[:, :, :, :, s_t_channels, :].mean(dim=3),
-                    "b h w c_g d -> b c_g d h w",
-                ).mean(dim=1)
-            ]
+            return FeatureMaps(
+                [
+                    rearrange(
+                        s_t_x[:, :, :, :, s_t_channels, :].mean(dim=3),
+                        "b h w c_g d -> b c_g d h w",
+                    ).mean(dim=1)
+                ]
+            )
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
rslearn/models/module_wrapper.py CHANGED
@@ -1,67 +1,35 @@
-"""Module wrappers."""
+"""Module wrapper provided for backwards compatibility."""
 
 from typing import Any
 
 import torch
 
+from rslearn.train.model_context import ModelContext
 
-class DecoderModuleWrapper(torch.nn.Module):
-    """Wrapper for a module that processes features to work in decoder.
+from .component import (
+    FeatureExtractor,
+    FeatureMaps,
+    IntermediateComponent,
+)
 
-    The module should input feature map and produce a new feature map.
 
-    We wrap it to process each feature map in multi-scale features which is what's used
-    for most decoders.
-    """
-
-    def __init__(
-        self,
-        module: torch.nn.Module,
-    ):
-        """Initialize a DecoderModuleWrapper.
+class EncoderModuleWrapper(FeatureExtractor):
+    """Wraps one or more IntermediateComponents to function as the feature extractor.
 
-        Args:
-            module: the module to wrap
-        """
-        super().__init__()
-        self.module = module
-
-    def forward(
-        self, features: list[torch.Tensor], inputs: list[torch.Tensor]
-    ) -> list[torch.Tensor]:
-        """Apply the wrapped module on each feature map.
-
-        Args:
-            features: list of feature maps at different resolutions.
-            inputs: original inputs (ignored).
-
-        Returns:
-            new features
-        """
-        new_features = []
-        for feat_map in features:
-            feat_map = self.module(feat_map)
-            new_features.append(feat_map)
-        return new_features
-
-
-class EncoderModuleWrapper(torch.nn.Module):
-    """Wraps a module that is intended to be used as the decoder to work in encoder.
-
-    The module should input a feature map that corresponds to the original image, i.e.
-    the depth of the feature map would be the number of bands in the input image.
+    The first component should input a FeatureMaps, which will be computed from the
+    overall inputs by stacking the "image" key from each input dict.
     """
 
     def __init__(
         self,
-        module: torch.nn.Module | None = None,
-        modules: list[torch.nn.Module] = [],
+        module: IntermediateComponent | None = None,
+        modules: list[IntermediateComponent] = [],
     ):
         """Initialize an EncoderModuleWrapper.
 
         Args:
-            module: the encoder module to wrap. Exactly one one of module or modules
-                must be set.
+            module: the IntermediateComponent to wrap for use as a FeatureExtractor.
+                Exactly one of module or modules must be set.
             modules: list of modules to wrap
         """
         super().__init__()
@@ -74,18 +42,19 @@ class EncoderModuleWrapper(torch.nn.Module):
         else:
             raise ValueError("one of module or modules must be set")
 
-    def forward(
-        self,
-        inputs: list[dict[str, Any]],
-    ) -> list[torch.Tensor]:
+    def forward(self, context: ModelContext) -> Any:
         """Compute outputs from the wrapped module.
 
-        Inputs:
-            inputs: input dicts that must include "image" key containing the image to
-                process.
+        Args:
+            context: the model context. Input dicts must include "image" key containing
+                the image to convert to a FeatureMaps, which will be passed to the
+                first wrapped module.
+
+        Returns:
+            the output from the last wrapped module.
         """
-        images = torch.stack([inp["image"] for inp in inputs], dim=0)
-        cur = [images]
+        images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
+        cur: Any = FeatureMaps([images])
         for m in self.encoder_modules:
-            cur = m(cur, inputs)
+            cur = m(cur, context)
         return cur
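DecoderModuleWrapper is removed, and EncoderModuleWrapper now adapts IntermediateComponents (rather than raw torch modules) into a FeatureExtractor: it stacks each input dict's "image" into a FeatureMaps and threads it through the wrapped components in order. A hedged usage sketch under those assumptions (the ScaleFeatures component is hypothetical):

from typing import Any

from rslearn.models.component import FeatureMaps, IntermediateComponent
from rslearn.models.module_wrapper import EncoderModuleWrapper
from rslearn.train.model_context import ModelContext


class ScaleFeatures(IntermediateComponent):
    """Hypothetical component that multiplies every feature map by a constant."""

    def __init__(self, factor: float):
        super().__init__()
        self.factor = factor

    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
        # Assumes the previous component produced a FeatureMaps.
        return FeatureMaps([feat * self.factor for feat in intermediates.feature_maps])


# The wrapper builds the initial FeatureMaps from the batch's "image" tensors and
# applies the wrapped components in order, returning the last component's output.
encoder = EncoderModuleWrapper(modules=[ScaleFeatures(0.5), ScaleFeatures(2.0)])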
rslearn/models/molmo.py CHANGED
@@ -1,12 +1,14 @@
 """Molmo model."""
 
-from typing import Any
-
 import torch
 from transformers import AutoModelForCausalLM, AutoProcessor
 
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps
 
-class Molmo(torch.nn.Module):
+
+class Molmo(FeatureExtractor):
     """Molmo image encoder."""
 
     def __init__(
@@ -34,21 +36,21 @@ class Molmo(torch.nn.Module):
         )  # nosec
         self.encoder = model.model.vision_backbone
 
-    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute outputs from the backbone.
 
-        Inputs:
-            inputs: input dicts that must include "image" key containing the image to
-                process. The images should have values 0-255.
+        Args:
+            context: the model context. Input dicts must include "image" key containing
+                the image to process. The images should have values 0-255.
 
         Returns:
-            list of feature maps. Molmo produces features at one scale, so the list
-            contains a single Bx24x24x2048 tensor.
+            a FeatureMaps. Molmo produces features at one scale, so it will contain one
+            feature map that is a Bx24x24x2048 tensor.
         """
-        device = inputs[0]["image"].device
+        device = context.inputs[0]["image"].device
         molmo_inputs_list = []
         # Process each one so we can isolate just the full image without any crops.
-        for inp in inputs:
+        for inp in context.inputs:
             image = inp["image"].cpu().numpy().transpose(1, 2, 0)
             processed = self.processor.process(
                 images=[image],
@@ -60,6 +62,6 @@ class Molmo(torch.nn.Module):
         image_features, _ = self.encoder.encode_image(molmo_inputs.to(device))
 
         # 576x2048 -> 24x24x2048
-        return [
-            image_features[:, 0, :, :].reshape(-1, 24, 24, 2048).permute(0, 3, 1, 2)
-        ]
+        return FeatureMaps(
+            [image_features[:, 0, :, :].reshape(-1, 24, 24, 2048).permute(0, 3, 1, 2)]
+        )