rslearn 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. rslearn/arg_parser.py +2 -9
  2. rslearn/config/__init__.py +2 -0
  3. rslearn/config/dataset.py +64 -20
  4. rslearn/dataset/add_windows.py +1 -1
  5. rslearn/dataset/dataset.py +34 -84
  6. rslearn/dataset/materialize.py +5 -5
  7. rslearn/dataset/storage/__init__.py +1 -0
  8. rslearn/dataset/storage/file.py +202 -0
  9. rslearn/dataset/storage/storage.py +140 -0
  10. rslearn/dataset/window.py +26 -80
  11. rslearn/lightning_cli.py +22 -11
  12. rslearn/main.py +12 -37
  13. rslearn/models/anysat.py +11 -9
  14. rslearn/models/attention_pooling.py +177 -0
  15. rslearn/models/clay/clay.py +8 -9
  16. rslearn/models/clip.py +18 -15
  17. rslearn/models/component.py +111 -0
  18. rslearn/models/concatenate_features.py +21 -11
  19. rslearn/models/conv.py +15 -8
  20. rslearn/models/croma.py +13 -8
  21. rslearn/models/detr/detr.py +25 -14
  22. rslearn/models/dinov3.py +11 -6
  23. rslearn/models/faster_rcnn.py +19 -9
  24. rslearn/models/feature_center_crop.py +12 -9
  25. rslearn/models/fpn.py +19 -8
  26. rslearn/models/galileo/galileo.py +23 -18
  27. rslearn/models/module_wrapper.py +26 -57
  28. rslearn/models/molmo.py +16 -14
  29. rslearn/models/multitask.py +102 -73
  30. rslearn/models/olmoearth_pretrain/model.py +135 -38
  31. rslearn/models/panopticon.py +8 -7
  32. rslearn/models/pick_features.py +18 -24
  33. rslearn/models/pooling_decoder.py +22 -14
  34. rslearn/models/presto/presto.py +16 -10
  35. rslearn/models/presto/single_file_presto.py +4 -10
  36. rslearn/models/prithvi.py +12 -8
  37. rslearn/models/resize_features.py +21 -7
  38. rslearn/models/sam2_enc.py +11 -9
  39. rslearn/models/satlaspretrain.py +15 -9
  40. rslearn/models/simple_time_series.py +37 -17
  41. rslearn/models/singletask.py +24 -17
  42. rslearn/models/ssl4eo_s12.py +15 -10
  43. rslearn/models/swin.py +22 -13
  44. rslearn/models/terramind.py +24 -7
  45. rslearn/models/trunk.py +6 -3
  46. rslearn/models/unet.py +18 -9
  47. rslearn/models/upsample.py +22 -9
  48. rslearn/train/all_patches_dataset.py +89 -37
  49. rslearn/train/dataset.py +105 -97
  50. rslearn/train/lightning_module.py +51 -32
  51. rslearn/train/model_context.py +54 -0
  52. rslearn/train/prediction_writer.py +111 -41
  53. rslearn/train/scheduler.py +15 -0
  54. rslearn/train/tasks/classification.py +34 -15
  55. rslearn/train/tasks/detection.py +24 -31
  56. rslearn/train/tasks/embedding.py +33 -29
  57. rslearn/train/tasks/multi_task.py +7 -7
  58. rslearn/train/tasks/per_pixel_regression.py +41 -19
  59. rslearn/train/tasks/regression.py +38 -21
  60. rslearn/train/tasks/segmentation.py +33 -15
  61. rslearn/train/tasks/task.py +3 -2
  62. rslearn/train/transforms/resize.py +74 -0
  63. rslearn/utils/geometry.py +73 -0
  64. rslearn/utils/jsonargparse.py +66 -0
  65. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/METADATA +1 -1
  66. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/RECORD +71 -66
  67. rslearn/dataset/index.py +0 -173
  68. rslearn/models/registry.py +0 -22
  69. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/WHEEL +0 -0
  70. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/entry_points.txt +0 -0
  71. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/LICENSE +0 -0
  72. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/NOTICE +0 -0
  73. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/top_level.txt +0 -0
rslearn/models/attention_pooling.py ADDED
@@ -0,0 +1,177 @@
+ """An attention pooling layer."""
+
+ import math
+ from typing import Any
+
+ import torch
+ import torch.nn.functional as F
+ from einops import rearrange
+ from torch import nn
+
+ from rslearn.models.component import (
+     FeatureMaps,
+     IntermediateComponent,
+     TokenFeatureMaps,
+ )
+ from rslearn.train.model_context import ModelContext
+
+
+ class SimpleAttentionPool(IntermediateComponent):
+     """Simple Attention Pooling.
+
+     Given a token feature map of shape BCHWN, learn an attention layer which
+     aggregates over the N dimension.
+
+     This is done simply by learning a mapping D->1 which gives the weight
+     assigned to each token during averaging:
+
+     output = sum [feat_token * W(feat_token) for feat_token in feat_tokens]
+     """
+
+     def __init__(self, in_dim: int, hidden_linear: bool = False) -> None:
+         """Initialize the simple attention pooling layer.
+
+         Args:
+             in_dim: the encoding dimension D.
+             hidden_linear: whether to apply an additional linear transformation
+                 D -> D to the feat tokens. If this is True, a ReLU activation is
+                 applied after the first linear transformation.
+         """
+         super().__init__()
+         if hidden_linear:
+             self.hidden_linear = nn.Linear(in_features=in_dim, out_features=in_dim)
+         else:
+             self.hidden_linear = None
+         self.linear = nn.Linear(in_features=in_dim, out_features=1)
+
+     def forward_for_map(self, feat_tokens: torch.Tensor) -> torch.Tensor:
+         """Attention pooling for a single feature map (BCHWN tensor)."""
+         B, D, H, W, N = feat_tokens.shape
+         feat_tokens = rearrange(feat_tokens, "b d h w n -> (b h w) n d")
+         if self.hidden_linear is not None:
+             feat_tokens = torch.nn.functional.relu(self.hidden_linear(feat_tokens))
+         attention_scores = torch.nn.functional.softmax(self.linear(feat_tokens), dim=1)
+         feat_tokens = (attention_scores * feat_tokens).sum(dim=1)
+         return rearrange(feat_tokens, "(b h w) d -> b d h w", b=B, h=H, w=W)
+
+     def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+         """Forward pass for attention pooling.
+
+         Args:
+             intermediates: the output from the previous component, which must be a
+                 TokenFeatureMaps. We pool over the final dimension in the
+                 TokenFeatureMaps. If multiple maps are passed, we apply the same
+                 linear layers to all of them.
+             context: the model context.
+
+         Returns:
+             a FeatureMaps where each map is pooled over the last dimension, giving
+             shape (B, C, H, W).
+         """
+         if not isinstance(intermediates, TokenFeatureMaps):
+             raise ValueError("input to Attention Pool must be a TokenFeatureMaps")
+
+         features = []
+         for feat in intermediates.feature_maps:
+             features.append(self.forward_for_map(feat))
+         return FeatureMaps(features)
+
+
+ class AttentionPool(IntermediateComponent):
+     """Attention Pooling.
+
+     Given a feature map of shape BCHWN, learn an attention layer which
+     aggregates over the N dimension.
+
+     We do this by learning a query token, and applying a standard attention
+     mechanism against this learned query token.
+     """
+
+     def __init__(self, in_dim: int, num_heads: int, linear_on_kv: bool = True) -> None:
+         """Initialize the attention pooling layer.
+
+         Args:
+             in_dim: the encoding dimension D.
+             num_heads: the number of heads to use.
+             linear_on_kv: whether to apply a linear layer on the input tokens
+                 to create the key and value tokens.
+         """
+         super().__init__()
+         self.query_token: nn.Parameter = nn.Parameter(torch.empty(in_dim))
+         if linear_on_kv:
+             self.k_linear = nn.Linear(in_dim, in_dim)
+             self.v_linear = nn.Linear(in_dim, in_dim)
+         else:
+             self.k_linear = None
+             self.v_linear = None
+         if in_dim % num_heads != 0:
+             raise ValueError(
+                 f"in_dim must be divisible by num_heads. Got {in_dim} and {num_heads}."
+             )
+         self.num_heads = num_heads
+         self.init_weights()
+
+     def init_weights(self) -> None:
+         """Initialize weights for the probe."""
+         nn.init.trunc_normal_(self.query_token, std=0.02)
+
+     def forward_for_map(self, feat_tokens: torch.Tensor) -> torch.Tensor:
+         """Attention pooling for a single feature map (BCHWN tensor)."""
+         B, D, H, W, N = feat_tokens.shape
+         feat_tokens = rearrange(feat_tokens, "b d h w n -> (b h w) n d")
+         collapsed_dim = B * H * W
+         q = self.query_token.expand(collapsed_dim, 1, -1)
+         q = q.reshape(
+             collapsed_dim, 1, self.num_heads, D // self.num_heads
+         )  # [(B H W), 1, head, D_head]
+         q = rearrange(q, "b h n d -> b n h d")
+         if self.k_linear is not None:
+             assert self.v_linear is not None
+             k = self.k_linear(feat_tokens).reshape(
+                 collapsed_dim, N, self.num_heads, D // self.num_heads
+             )
+             v = self.v_linear(feat_tokens).reshape(
+                 collapsed_dim, N, self.num_heads, D // self.num_heads
+             )
+         else:
+             k = feat_tokens.reshape(
+                 collapsed_dim, N, self.num_heads, D // self.num_heads
+             )
+             v = feat_tokens.reshape(
+                 collapsed_dim, N, self.num_heads, D // self.num_heads
+             )
+         k = rearrange(k, "b n h d -> b h n d")
+         v = rearrange(v, "b n h d -> b h n d")
+
+         # Compute attention scores for the single query against the N tokens.
+         attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(
+             D // self.num_heads
+         )
+         attn_weights = F.softmax(attn_scores, dim=-1)
+         x = torch.matmul(attn_weights, v)  # [(B H W), head, 1, D_head]
+         # Merge the heads back into D, then restore the spatial layout.
+         x = x.reshape(collapsed_dim, D)
+         return rearrange(x, "(b h w) d -> b d h w", b=B, h=H, w=W)
+
+     def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+         """Forward pass for attention pooling.
+
+         Args:
+             intermediates: the output from the previous component, which must be a
+                 TokenFeatureMaps. We pool over the final dimension in the
+                 TokenFeatureMaps. If multiple feature maps are passed, we apply the
+                 same attention weights (query token and linear k, v layers) to all
+                 the maps.
+             context: the model context.
+
+         Returns:
+             a FeatureMaps where each map is pooled over the last dimension, giving
+             shape (B, C, H, W).
+         """
+         if not isinstance(intermediates, TokenFeatureMaps):
+             raise ValueError("input to Attention Pool must be a TokenFeatureMaps")
+
+         features = []
+         for feat in intermediates.feature_maps:
+             features.append(self.forward_for_map(feat))
+         return FeatureMaps(features)
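
For intuition, the core of SimpleAttentionPool is small enough to sketch in plain torch: a learned D->1 scoring layer produces one weight per token, the weights are softmaxed over the token dimension N, and the tokens are averaged with those weights. A standalone sketch (not part of the package) mirroring the shapes above:

    import torch
    from einops import rearrange
    from torch import nn

    # Sketch of the SimpleAttentionPool computation on a BCHWN tensor.
    B, D, H, W, N = 2, 16, 4, 4, 5
    feat_tokens = torch.randn(B, D, H, W, N)
    score = nn.Linear(D, 1)  # the learned D->1 token weighting

    tokens = rearrange(feat_tokens, "b d h w n -> (b h w) n d")
    weights = torch.softmax(score(tokens), dim=1)  # ((b h w), n, 1), sums to 1 over n
    pooled = (weights * tokens).sum(dim=1)         # ((b h w), d)
    out = rearrange(pooled, "(b h w) d -> b d h w", b=B, h=H, w=W)
    print(out.shape)  # torch.Size([2, 16, 4, 4])
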
rslearn/models/clay/clay.py CHANGED
@@ -16,6 +16,8 @@ from huggingface_hub import hf_hub_download
  # from claymodel.module import ClayMAEModule
  from terratorch.models.backbones.clay_v15.module import ClayMAEModule
 
+ from rslearn.models.component import FeatureExtractor, FeatureMaps
+ from rslearn.train.model_context import ModelContext
  from rslearn.train.transforms.normalize import Normalize
  from rslearn.train.transforms.transform import Transform
 
@@ -42,7 +44,7 @@ def get_clay_checkpoint_path(
      return hf_hub_download(repo_id=repo_id, filename=filename)  # nosec B615
 
 
- class Clay(torch.nn.Module):
+ class Clay(FeatureExtractor):
      """Clay backbones."""
 
      def __init__(
@@ -108,23 +110,20 @@ class Clay(torch.nn.Module):
              image, size=(new_hw, new_hw), mode="bilinear", align_corners=False
          )
 
-     def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+     def forward(self, context: ModelContext) -> FeatureMaps:
          """Forward pass for the Clay model.
 
          Args:
-             inputs: input dicts that must include `self.modality` as a key
+             context: the model context. Input dicts must include `self.modality` as a key.
 
          Returns:
-             List[torch.Tensor]: Single-scale feature tensors from the encoder.
+             a FeatureMaps consisting of one feature map, computed by Clay.
          """
-         if self.modality not in inputs[0]:
-             raise ValueError(f"Missing modality {self.modality} in inputs.")
-
          param = next(self.model.parameters())
          device = param.device
 
          chips = torch.stack(
-             [inp[self.modality] for inp in inputs], dim=0
+             [inp[self.modality] for inp in context.inputs], dim=0
          )  # (B, C, H, W)
          if self.do_resizing:
              chips = self._resize_image(chips, chips.shape[2])
@@ -163,7 +162,7 @@ class Clay(torch.nn.Module):
          )
 
          features = rearrange(spatial, "b (h w) d -> b d h w", h=side, w=side)
-         return [features]
+         return FeatureMaps([features])
 
      def get_backbone_channels(self) -> list:
          """Return output channels of this model when used as a backbone."""
rslearn/models/clip.py CHANGED
@@ -1,12 +1,13 @@
  """OpenAI CLIP models."""
 
- from typing import Any
-
- import torch
  from transformers import AutoModelForZeroShotImageClassification, AutoProcessor
 
+ from rslearn.train.model_context import ModelContext
+
+ from .component import FeatureExtractor, FeatureMaps
+
 
- class CLIP(torch.nn.Module):
+ class CLIP(FeatureExtractor):
      """CLIP image encoder."""
 
      def __init__(
@@ -31,17 +32,17 @@ class CLIP(torch.nn.Module):
          self.height = crop_size["height"] // stride[0]
          self.width = crop_size["width"] // stride[1]
 
-     def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+     def forward(self, context: ModelContext) -> FeatureMaps:
          """Compute outputs from the backbone.
 
-         Inputs:
-             inputs: input dicts that must include "image" key containing the image to
-                 process. The images should have values 0-255.
+         Args:
+             context: the model context. Input dicts must include "image" key containing
+                 the image to process. The images should have values 0-255.
 
          Returns:
-             list of feature maps. The ViT produces features at one scale, so the list
-                 contains a single Bx24x24x1024 feature map.
+             a FeatureMaps with one feature map from the ViT, which is always Bx24x24x1024.
          """
+         inputs = context.inputs
          device = inputs[0]["image"].device
          clip_inputs = self.processor(
              images=[inp["image"].cpu().numpy().transpose(1, 2, 0) for inp in inputs],
@@ -55,8 +56,10 @@ class CLIP(torch.nn.Module):
          batch_size = image_features.shape[0]
 
          # 576x1024 -> HxWxC
-         return [
-             image_features.reshape(
-                 batch_size, self.height, self.width, self.num_features
-             ).permute(0, 3, 1, 2)
-         ]
+         return FeatureMaps(
+             [
+                 image_features.reshape(
+                     batch_size, self.height, self.width, self.num_features
+                 ).permute(0, 3, 1, 2)
+             ]
+         )
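
The final reshape in CLIP's forward is easier to follow with concrete numbers: the ViT emits 576 patch tokens per image (a 24x24 grid of 1024-d features), and the reshape/permute turns them into one BxCxHxW map. A standalone sketch:

    import torch

    # 576 patch tokens of 1024-d per image -> one spatial feature map.
    batch_size, height, width, num_features = 2, 24, 24, 1024
    image_features = torch.randn(batch_size, height * width, num_features)
    fmap = image_features.reshape(
        batch_size, height, width, num_features
    ).permute(0, 3, 1, 2)
    print(fmap.shape)  # torch.Size([2, 1024, 24, 24])
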
rslearn/models/component.py ADDED
@@ -0,0 +1,111 @@
+ """Model component API."""
+
+ import abc
+ from dataclasses import dataclass
+ from typing import Any
+
+ import torch
+
+ from rslearn.train.model_context import ModelContext, ModelOutput
+
+
+ class FeatureExtractor(torch.nn.Module, abc.ABC):
+     """A feature extractor that performs initial processing of the inputs.
+
+     The FeatureExtractor is the first component in the encoders list for
+     SingleTaskModel and MultiTaskModel.
+     """
+
+     @abc.abstractmethod
+     def forward(self, context: ModelContext) -> Any:
+         """Extract an initial intermediate from the model context.
+
+         Args:
+             context: the model context.
+
+         Returns:
+             any intermediate to pass to downstream components. Oftentimes this is
+             a FeatureMaps.
+         """
+         raise NotImplementedError
+
+
+ class IntermediateComponent(torch.nn.Module, abc.ABC):
+     """An intermediate component in the model.
+
+     In SingleTaskModel and MultiTaskModel, modules after the first module in the
+     encoders list are IntermediateComponents, as are modules before the last
+     module in the decoders list(s).
+     """
+
+     @abc.abstractmethod
+     def forward(self, intermediates: Any, context: ModelContext) -> Any:
+         """Process the given intermediate into another intermediate.
+
+         Args:
+             intermediates: the output from the previous component (either a
+                 FeatureExtractor or another IntermediateComponent).
+             context: the model context.
+
+         Returns:
+             any intermediate to pass to downstream components.
+         """
+         raise NotImplementedError
+
+
+ class Predictor(torch.nn.Module, abc.ABC):
+     """A predictor that computes task-specific outputs and a loss dict.
+
+     In SingleTaskModel and MultiTaskModel, the last module(s) in the decoders
+     list(s) are Predictors.
+     """
+
+     @abc.abstractmethod
+     def forward(
+         self,
+         intermediates: Any,
+         context: ModelContext,
+         targets: list[dict[str, torch.Tensor]] | None = None,
+     ) -> ModelOutput:
+         """Compute task-specific outputs and loss dict.
+
+         Args:
+             intermediates: the output from the previous component.
+             context: the model context.
+             targets: the training targets, or None during prediction.
+
+         Returns:
+             a ModelOutput containing the task-specific outputs (which should be
+             compatible with the configured Task) and the loss dict. The loss dict
+             maps from a name for each loss to a scalar tensor.
+         """
+         raise NotImplementedError
+
+
+ @dataclass
+ class FeatureMaps:
+     """An intermediate output type for multi-resolution feature maps."""
+
+     # List of BxCxHxW feature maps at different scales, ordered from highest
+     # resolution (most fine-grained) to lowest resolution (coarsest).
+     feature_maps: list[torch.Tensor]
+
+
+ @dataclass
+ class TokenFeatureMaps:
+     """An intermediate output type for multi-resolution BCHWN feature maps with a token dimension.
+
+     Unlike `FeatureMaps`, these include an additional dimension for unpooled
+     tokens.
+     """
+
+     # List of BxCxHxWxN feature maps at different scales, ordered from highest
+     # resolution (most fine-grained) to lowest resolution (coarsest).
+     feature_maps: list[torch.Tensor]
+
+
+ @dataclass
+ class FeatureVector:
+     """An intermediate output type for a flat feature vector."""
+
+     # Flat BxC feature vector.
+     feature_vector: torch.Tensor
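
To make the component contract concrete, here is a hypothetical pipeline in the spirit of the new API: a feature-extractor stage reads the context, and a predictor stage emits outputs plus a named loss dict. DummyContext, TinyExtractor, and TinyHead are illustrative stand-ins, not rslearn classes; in particular, DummyContext only mimics the inputs attribute used throughout this diff, while the real ModelContext and ModelOutput live in rslearn/train/model_context.py (not shown here).

    from dataclasses import dataclass
    from typing import Any

    import torch

    @dataclass
    class DummyContext:
        # Stand-in for ModelContext; only the `inputs` list is mimicked.
        inputs: list[dict[str, Any]]

    class TinyExtractor(torch.nn.Module):
        """Plays the FeatureExtractor role: context in, feature maps out."""

        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)

        def forward(self, context: DummyContext) -> list[torch.Tensor]:
            images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
            return [self.conv(images)]  # the payload a FeatureMaps would carry

    class TinyHead(torch.nn.Module):
        """Plays the Predictor role: features in, (outputs, loss dict) out."""

        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(8, 2)

        def forward(self, feature_maps, context, targets=None):
            # Pool each BxCxHxW map to BxC, then classify.
            logits = self.linear(feature_maps[0].mean(dim=(2, 3)))
            losses = {}
            if targets is not None:
                labels = torch.stack([t["label"] for t in targets])
                losses["cls"] = torch.nn.functional.cross_entropy(logits, labels)
            return logits, losses

    ctx = DummyContext(inputs=[{"image": torch.randn(3, 32, 32)} for _ in range(2)])
    feats = TinyExtractor()(ctx)
    outputs, losses = TinyHead()(feats, ctx, targets=[{"label": torch.tensor(0)}] * 2)
    print(outputs.shape, losses["cls"].item())
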
rslearn/models/concatenate_features.py CHANGED
@@ -4,8 +4,12 @@ from typing import Any
 
  import torch
 
+ from rslearn.train.model_context import ModelContext
 
- class ConcatenateFeatures(torch.nn.Module):
+ from .component import FeatureMaps, IntermediateComponent
+
+
+ class ConcatenateFeatures(IntermediateComponent):
      """Concatenate feature map with additional raw data inputs."""
 
      def __init__(
@@ -55,26 +59,32 @@ class ConcatenateFeatures(torch.nn.Module):
 
          self.conv_layers = torch.nn.Sequential(*conv_layers)
 
-     def forward(
-         self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
-     ) -> list[torch.Tensor]:
+     def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
          """Concatenate the feature map with the raw data inputs.
 
          Args:
-             features: list of feature maps at different resolutions.
-             inputs: original inputs.
+             intermediates: the previous output, which must be a FeatureMaps.
+             context: the model context. The input dicts must have a key matching
+                 the configured key.
 
          Returns:
              concatenated feature maps.
          """
-         if not features:
-             raise ValueError("Expected at least one feature map, got none.")
+         if (
+             not isinstance(intermediates, FeatureMaps)
+             or len(intermediates.feature_maps) == 0
+         ):
+             raise ValueError(
+                 "Expected input to be FeatureMaps with at least one feature map"
+             )
 
-         add_data = torch.stack([input_data[self.key] for input_data in inputs], dim=0)
+         add_data = torch.stack(
+             [input_data[self.key] for input_data in context.inputs], dim=0
+         )
          add_features = self.conv_layers(add_data)
 
          new_features: list[torch.Tensor] = []
-         for feature_map in features:
+         for feature_map in intermediates.feature_maps:
              # Shape of feature map: BCHW
              feat_h, feat_w = feature_map.shape[2], feature_map.shape[3]
 
@@ -90,4 +100,4 @@ class ConcatenateFeatures(torch.nn.Module):
 
              new_features.append(torch.cat([feature_map, resized_add_features], dim=1))
 
-         return new_features
+         return FeatureMaps(new_features)
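
The unchanged logic around this hunk resizes the auxiliary features to each feature map's spatial size before concatenating along channels. A sketch of that step (the bilinear mode here is an assumption; the actual resize call sits in the unchanged part of the file):

    import torch
    import torch.nn.functional as F

    feature_map = torch.randn(2, 64, 16, 16)  # BCHW feature map
    add_features = torch.randn(2, 8, 32, 32)  # conv-processed raw inputs

    # Match the feature map's HxW, then concatenate along channels.
    resized = F.interpolate(add_features, size=feature_map.shape[2:], mode="bilinear")
    merged = torch.cat([feature_map, resized], dim=1)
    print(merged.shape)  # torch.Size([2, 72, 16, 16])
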
rslearn/models/conv.py CHANGED
@@ -4,8 +4,12 @@ from typing import Any
 
  import torch
 
+ from rslearn.train.model_context import ModelContext
 
- class Conv(torch.nn.Module):
+ from .component import FeatureMaps, IntermediateComponent
+
+
+ class Conv(IntermediateComponent):
      """A single convolutional layer.
 
      It inputs a set of feature maps; the conv layer is applied to each feature map
@@ -38,19 +42,22 @@ class Conv(torch.nn.Module):
          )
          self.activation = activation
 
-     def forward(self, features: list[torch.Tensor], inputs: Any) -> list[torch.Tensor]:
-         """Compute flat output vector from multi-scale feature map.
+     def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+         """Apply the conv layer on each feature map.
 
          Args:
-             features: list of feature maps at different resolutions.
-             inputs: original inputs (ignored).
+             intermediates: the previous output, which must be a FeatureMaps.
+             context: the model context.
 
          Returns:
-             flat feature vector
+             the resulting feature maps after applying the same Conv2d on each one.
          """
+         if not isinstance(intermediates, FeatureMaps):
+             raise ValueError("input to Conv must be FeatureMaps")
+
          new_features = []
-         for feat_map in features:
+         for feat_map in intermediates.feature_maps:
              feat_map = self.layer(feat_map)
              feat_map = self.activation(feat_map)
              new_features.append(feat_map)
-         return new_features
+         return FeatureMaps(new_features)
rslearn/models/croma.py CHANGED
@@ -12,9 +12,11 @@ from einops import rearrange
  from upath import UPath
 
  from rslearn.log_utils import get_logger
+ from rslearn.train.model_context import ModelContext
  from rslearn.train.transforms.transform import Transform
  from rslearn.utils.fsspec import open_atomic
 
+ from .component import FeatureExtractor, FeatureMaps
  from .use_croma import PretrainedCROMA
 
  logger = get_logger(__name__)
@@ -76,7 +78,7 @@ MODALITY_BANDS = {
  }
 
 
- class Croma(torch.nn.Module):
+ class Croma(FeatureExtractor):
      """CROMA backbones.
 
      There are two model sizes, base and large.
@@ -160,20 +162,23 @@ class Croma(torch.nn.Module):
              align_corners=False,
          )
 
-     def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+     def forward(self, context: ModelContext) -> FeatureMaps:
          """Compute feature maps from the Croma backbone.
 
-         Inputs:
-             inputs: input dicts that must include either/both of "sentinel2" or
-                 "sentinel1" keys depending on the configured modality.
+         Args:
+             context: the model context. Input dicts must include either/both of
+                 "sentinel2" or "sentinel1" keys depending on the configured modality.
+
+         Returns:
+             a FeatureMaps with one feature map at 1/8 the input resolution.
          """
          sentinel1: torch.Tensor | None = None
          sentinel2: torch.Tensor | None = None
          if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL1]:
-             sentinel1 = torch.stack([inp["sentinel1"] for inp in inputs], dim=0)
+             sentinel1 = torch.stack([inp["sentinel1"] for inp in context.inputs], dim=0)
              sentinel1 = self._resize_image(sentinel1) if self.do_resizing else sentinel1
          if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL2]:
-             sentinel2 = torch.stack([inp["sentinel2"] for inp in inputs], dim=0)
+             sentinel2 = torch.stack([inp["sentinel2"] for inp in context.inputs], dim=0)
              sentinel2 = self._resize_image(sentinel2) if self.do_resizing else sentinel2
 
          outputs = self.model(
@@ -200,7 +205,7 @@ class Croma(torch.nn.Module):
              w=num_patches_per_dim,
          )
 
-         return [features]
+         return FeatureMaps([features])
 
      def get_backbone_channels(self) -> list:
          """Returns the output channels of this model when used as a backbone.
rslearn/models/detr/detr.py CHANGED
@@ -13,6 +13,8 @@ import torch.nn.functional as F
  from torch import nn
 
  import rslearn.models.detr.box_ops as box_ops
+ from rslearn.models.component import FeatureMaps, Predictor
+ from rslearn.train.model_context import ModelContext, ModelOutput
 
  from .matcher import HungarianMatcher
  from .position_encoding import PositionEmbeddingSine
@@ -405,7 +407,7 @@ class PostProcess(nn.Module):
          return results
 
 
- class Detr(nn.Module):
+ class Detr(Predictor):
      """DETR prediction module.
 
      This combines PositionEmbeddingSine, DetrPredictor, SetCriterion, and PostProcess.
@@ -440,33 +442,39 @@ class Detr(nn.Module):
 
      def forward(
          self,
-         features: list[torch.Tensor],
-         inputs: list[dict[str, Any]],
+         intermediates: Any,
+         context: ModelContext,
          targets: list[dict[str, Any]] | None = None,
-     ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
+     ) -> ModelOutput:
          """Compute the detection outputs and loss from features.
 
          DETR will use only the last feature map, which should correspond to the lowest
          resolution one.
 
          Args:
-             features: multi-scale feature maps.
-             inputs: original inputs, should contain image key for original image size.
-             targets: should contain class key that stores the class label.
+             intermediates: the output from the previous component. It must be a FeatureMaps.
+             context: the model context. Input dicts must contain an "image" key, which
+                 is used to establish the original image size.
+             targets: must contain a class key that stores the class label.
 
          Returns:
-             tuple of outputs and loss dict.
+             the model output.
          """
+         if not isinstance(intermediates, FeatureMaps):
+             raise ValueError("input to Detr must be a FeatureMaps")
+
+         # We only use the last feature map (the lowest resolution one).
+         features = intermediates.feature_maps[-1]
+
          # Get image sizes.
          image_sizes = torch.tensor(
-             [[inp["image"].shape[2], inp["image"].shape[1]] for inp in inputs],
+             [[inp["image"].shape[2], inp["image"].shape[1]] for inp in context.inputs],
              dtype=torch.int32,
-             device=features[0].device,
+             device=features.device,
          )
 
-         feat_map = features[-1]
-         pos_embedding = self.pos_embedding(feat_map)
-         outputs = self.predictor(feat_map, pos_embedding)
+         pos_embedding = self.pos_embedding(features)
+         outputs = self.predictor(features, pos_embedding)
 
          if targets is not None:
              # Convert boxes from [x0, y0, x1, y1] to [cx, cy, w, h].
@@ -490,4 +498,7 @@ class Detr(nn.Module):
 
          results = self.postprocess(outputs, image_sizes)
 
-         return results, losses
+         return ModelOutput(
+             outputs=results,
+             loss_dict=losses,
+         )
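
The box conversion mentioned in the hunk above ([x0, y0, x1, y1] corners to [cx, cy, w, h] center format, which DETR's matching and loss consume) is simple enough to spell out; a sketch independent of the box_ops helpers the module actually uses:

    import torch

    # Corner boxes [x0, y0, x1, y1] -> center format [cx, cy, w, h].
    boxes = torch.tensor([[10.0, 20.0, 50.0, 80.0]])
    x0, y0, x1, y1 = boxes.unbind(dim=-1)
    cxcywh = torch.stack([(x0 + x1) / 2, (y0 + y1) / 2, x1 - x0, y1 - y0], dim=-1)
    print(cxcywh)  # tensor([[30., 50., 40., 60.]])
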