rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/models/ssl4eo_s12.py CHANGED
@@ -1,20 +1,22 @@
 """SSL4EO-S12 models."""
 
-from typing import Any
-
 import torch
 import torchvision
 
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps
+
 
-class Ssl4eoS12(torch.nn.Module):
+class Ssl4eoS12(FeatureExtractor):
     """The SSL4EO-S12 family of pretrained models."""
 
     def __init__(
         self,
-        backbone_ckpt_path: str,
+        backbone_ckpt_path: str | None,
         arch: str = "resnet50",
         output_layers: list[int] = [0, 1, 2, 3],
-    ):
+    ) -> None:
         """Instantiate a new Swin instance.
 
         Args:
@@ -37,21 +39,24 @@ class Ssl4eoS12(torch.nn.Module):
         else:
             raise ValueError(f"unknown SSL4EO-S12 architecture {arch}")
 
-        state_dict = torch.load(backbone_ckpt_path)
-        state_dict = state_dict["teacher"]
-        prefix = "module.backbone."
-        state_dict = {
-            k[len(prefix) :]: v for k, v in state_dict.items() if k.startswith(prefix)
-        }
-        missing_keys, unexpected_keys = self.model.load_state_dict(
-            state_dict, strict=False
-        )
-        if missing_keys or unexpected_keys:
-            print(
-                f"warning: got missing_keys={missing_keys}, unexpected_keys={unexpected_keys} when loading SSL4EO-S12 state dict"
+        if backbone_ckpt_path is not None:
+            state_dict = torch.load(backbone_ckpt_path, weights_only=True)
+            state_dict = state_dict["teacher"]
+            prefix = "module.backbone."
+            state_dict = {
+                k[len(prefix) :]: v
+                for k, v in state_dict.items()
+                if k.startswith(prefix)
+            }
+            missing_keys, unexpected_keys = self.model.load_state_dict(
+                state_dict, strict=False
             )
+            if missing_keys or unexpected_keys:
+                print(
+                    f"warning: got missing_keys={missing_keys}, unexpected_keys={unexpected_keys} when loading SSL4EO-S12 state dict"
+                )
 
-    def get_backbone_channels(self):
+    def get_backbone_channels(self) -> list[tuple[int, int]]:
         """Returns the output channels of this model when used as a backbone.
 
         The output channels is a list of (downsample_factor, depth) that corresponds
@@ -65,28 +70,33 @@ class Ssl4eoS12(torch.nn.Module):
         """
         if self.arch == "resnet50":
             all_out_channels = [
-                [4, 256],
-                [8, 512],
-                [16, 1024],
-                [32, 2048],
+                (4, 256),
+                (8, 512),
+                (16, 1024),
+                (32, 2048),
             ]
         return [all_out_channels[idx] for idx in self.output_layers]
 
     def forward(
-        self, inputs: list[dict[str, Any]], targets: list[dict[str, Any]] = None
-    ):
+        self,
+        context: ModelContext,
+    ) -> FeatureMaps:
         """Compute outputs from the backbone.
 
         If output_layers is set, then the outputs are multi-scale feature maps;
         otherwise, the model is being used for classification so the outputs are class
         probabilities and the loss.
 
-        Inputs:
-            inputs: input dicts that must include "image" key containing the image to
-                process.
-            targets: target dicts that are ignored unless
+        Args:
+            context: the model context. Input dicts must include "image" key containing
+                the images to process.
+
+        Returns:
+            feature maps computed by the pre-trained model.
         """
-        x = torch.stack([inp["image"] for inp in inputs], dim=0)
+        x = torch.stack(
+            [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+        )
         x = self.model.conv1(x)
         x = self.model.bn1(x)
         x = self.model.relu(x)
@@ -97,4 +107,4 @@ class Ssl4eoS12(torch.nn.Module):
         layer3 = self.model.layer3(layer2)
         layer4 = self.model.layer4(layer3)
         all_features = [layer1, layer2, layer3, layer4]
-        return [all_features[idx] for idx in self.output_layers]
+        return FeatureMaps([all_features[idx] for idx in self.output_layers])
rslearn/models/swin.py CHANGED
@@ -1,7 +1,5 @@
 """Swin Transformer."""
 
-from typing import Any
-
 import torch
 import torchvision
 from torchvision.models.swin_transformer import (
@@ -13,8 +11,12 @@ from torchvision.models.swin_transformer import (
     Swin_V2_T_Weights,
 )
 
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps, FeatureVector
 
-class Swin(torch.nn.Module):
+
+class Swin(FeatureExtractor):
     """A Swin Transformer model.
 
     It can either be used stand-alone for classification, or as a feature extractor in
@@ -28,15 +30,18 @@ class Swin(torch.nn.Module):
         input_channels: int = 3,
         output_layers: list[int] | None = None,
         num_outputs: int = 1000,
-    ):
+    ) -> None:
         """Instantiate a new Swin instance.
 
         Args:
             arch: the architecture, e.g. "swin_v2_b" (default) or "swin_t"
             pretrained: set True to use ImageNet pre-trained weights
-            input_channels: number of input channels (default 3)
+            input_channels: number of input channels (default 3). If not 3, the first
+                layer is updated and will be randomly initialized even if pretrained is
+                set.
             output_layers: list of layers to output, default use as classification
-                model. For feature extraction, [1, 3, 5, 7] is recommended.
+                model (output FeatureVector). For feature extraction, [1, 3, 5, 7] is
+                recommended.
             num_outputs: number of output logits, defaults to 1000 which matches the
                 pretrained models.
         """
@@ -89,7 +94,7 @@ class Swin(torch.nn.Module):
         if num_outputs != self.model.head.out_features:
             self.model.head = torch.nn.Linear(self.model.head.in_features, num_outputs)
 
-    def get_backbone_channels(self):
+    def get_backbone_channels(self) -> list[tuple[int, int]]:
         """Returns the output channels of this model when used as a backbone.
 
         The output channels is a list of (downsample_factor, depth) that corresponds
@@ -105,43 +110,50 @@ class Swin(torch.nn.Module):
 
         if self.arch in ["swin_b", "swin_v2_b"]:
             all_out_channels = [
-                [4, 128],
-                [4, 128],
-                [8, 256],
-                [8, 256],
-                [16, 512],
-                [16, 512],
-                [32, 1024],
-                [32, 1024],
+                (4, 128),
+                (4, 128),
+                (4, 128),
+                (8, 256),
+                (8, 256),
+                (16, 512),
+                (16, 512),
+                (32, 1024),
+                (32, 1024),
             ]
         elif self.arch in ["swin_s", "swin_v2_s", "swin_t", "swin_v2_t"]:
             all_out_channels = [
-                [4, 96],
-                [4, 96],
-                [8, 192],
-                [8, 192],
-                [16, 384],
-                [16, 384],
-                [32, 768],
-                [32, 768],
+                (4, 96),
+                (4, 96),
+                (8, 192),
+                (8, 192),
+                (16, 384),
+                (16, 384),
+                (32, 768),
+                (32, 768),
             ]
         return [all_out_channels[idx] for idx in self.output_layers]
 
     def forward(
-        self, inputs: list[dict[str, Any]], targets: list[dict[str, Any]] = None
-    ):
+        self,
+        context: ModelContext,
+    ) -> FeatureVector | FeatureMaps:
         """Compute outputs from the backbone.
 
         If output_layers is set, then the outputs are multi-scale feature maps;
         otherwise, the model is being used for classification so the outputs are class
         probabilities and the loss.
 
-        Inputs:
-            inputs: input dicts that must include "image" key containing the image to
-                process.
-            targets: target dicts that are ignored unless
+        Args:
+            context: the model context. Input dicts must include "image" key containing
+                the image to process.
+
+        Returns:
+            a FeatureVector if the configured output_layers is None, or a FeatureMaps
+            otherwise containing one feature map per configured output layer.
         """
-        images = torch.stack([inp["image"] for inp in inputs], dim=0)
+        images = torch.stack(
+            [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+        )
 
         if self.output_layers:
             layer_features = []
@@ -149,7 +161,7 @@ class Swin(torch.nn.Module):
             for layer in self.model.features:
                 x = layer(x)
                 layer_features.append(x.permute(0, 3, 1, 2))
-            return [layer_features[idx] for idx in self.output_layers]
+            return FeatureMaps([layer_features[idx] for idx in self.output_layers])
 
         else:
-            return self.model(images)
+            return FeatureVector(self.model(images))
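The Swin forward() above collects one feature map per stage of torchvision's `model.features` and permutes each output from NHWC to NCHW. The following is a small sketch of that extraction loop using plain torchvision only (illustrative, outside the rslearn wrapper); the input size and printed shapes are assumptions for a 256x256 input to swin_v2_t.

# Sketch of the multi-scale feature extraction loop from the diff above,
# using plain torchvision rather than the rslearn Swin wrapper.
import torch
import torchvision

model = torchvision.models.swin_v2_t(weights=None)
images = torch.randn(2, 3, 256, 256)

layer_features = []
x = images
for layer in model.features:
    x = layer(x)
    layer_features.append(x.permute(0, 3, 1, 2))  # NHWC -> NCHW

# output_layers=[1, 3, 5, 7] picks the output after each transformer stage,
# matching the (downsample_factor, depth) pairs in get_backbone_channels().
selected = [layer_features[idx] for idx in [1, 3, 5, 7]]
print([f.shape for f in selected])
# e.g. [(2, 96, 64, 64), (2, 192, 32, 32), (2, 384, 16, 16), (2, 768, 8, 8)]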
rslearn/models/task_embedding.py ADDED
@@ -0,0 +1,250 @@
+"""Task embedding modules."""
+
+import math
+from typing import Any
+
+import torch
+from torch import nn
+
+
+class PositionalEncoding(nn.Module):
+    """Simple sinusoidal positional encoding for the task embedding. From torch docs."""
+
+    def __init__(self, d_model: int, dropout: float = 0.0, max_len: int = 1024):
+        """Initialize the positional encoding module.
+
+        Args:
+            d_model: The dimension of the model.
+            dropout: The dropout rate.
+            max_len: The maximum length of the sequence.
+        """
+        super().__init__()
+        self.dropout = nn.Dropout(p=dropout)
+
+        position = torch.arange(max_len).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
+        )
+        pe = torch.zeros(max_len, 1, d_model)
+        pe[:, 0, 0::2] = torch.sin(position * div_term)
+        pe[:, 0, 1::2] = torch.cos(position * div_term)
+        self.register_buffer("pe", pe)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply positional encoding to the input tensor.
+
+        Args:
+            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
+        """
+        x = x + self.pe[: x.size(0)]
+        return self.dropout(x)
+
+
+class BaseTaskEmbedding(torch.nn.Module):
+    """Base class for task embedding modules."""
+
+    def __init__(self, encoder_embedding_size: int) -> None:
+        """Initialize the base task embedding module.
+
+        Args:
+            encoder_embedding_size: The size of the encoder embedding.
+        """
+        super().__init__()
+        self.encoder_embedding_size = encoder_embedding_size
+
+    def register_tasks(self, task_names: list[str]) -> None:
+        """Register the tasks.
+
+        This must happen post-init so that we can dynamically determine
+        the tasks to use, so it doesn't have to be specified in the config.
+
+        Args:
+            task_names: The names of the tasks.
+        """
+        raise NotImplementedError
+
+    def compute_embeds(
+        self,
+        features: list[torch.tensor],
+        inputs: list[dict[str, Any]],
+    ) -> torch.Tensor:
+        """Compute the task-specific embeddings.
+
+        Args:
+            features: The encoder features.
+            inputs: The inputs to the model.
+
+        Returns:
+            The task-specific embeddings.
+        """
+        raise NotImplementedError
+
+
+class TaskChannelEmbedding(BaseTaskEmbedding):
+    """Registers task-specific 'tokens', i.e. embeddings.
+
+    Each embedding is learned per-channel and copied over the full spatial dimensions.
+    Optionally, add a spatial sinusoidal positional embedding to the task embedding.
+    """
+
+    def __init__(
+        self,
+        encoder_embedding_size: int,
+        default_idx: int = 0,
+        add_spatial_embed: bool = False,
+    ) -> None:
+        """Initialize the task channel embedding module.
+
+        Args:
+            encoder_embedding_size: The size of the encoder embedding.
+            default_idx: The index of the default task, useful if loading a merged model.
+            add_spatial_embed: if true, add a spatial sinusoidal positional embedding to the task embedding
+        """
+        super().__init__(encoder_embedding_size)
+        self.default_idx = default_idx
+        self.add_spatial_embed = add_spatial_embed
+        if add_spatial_embed:
+            self.pos_embed = PositionalEncoding(encoder_embedding_size)
+
+    def register_tasks(self, task_names: list[str]) -> None:
+        """Register the tasks.
+
+        This must happen post-init so that we can dynamically determine
+        the tasks to use, so it doesn't have to be specified in the config.
+
+        Args:
+            task_names: The names of the tasks.
+        """
+        self.embed = torch.nn.Embedding(len(task_names), self.encoder_embedding_size)
+        self.target_to_embed_idx = {name: i for i, name in enumerate(task_names)}
+
+    def compute_embeds(
+        self,
+        features: list[torch.tensor],
+        inputs: list[dict[str, Any]],
+    ) -> torch.Tensor:
+        """Compute the task-specific embeddings.
+
+        Args:
+            inputs: The inputs to the model.
+            features: computed encoder features
+            device: The device to compute the embeddings on.
+
+        Returns:
+            The task-specific embeddings, shape (B, T, C), T = HW
+            The embeddings are repeated over the spatial dimensions, and optionally
+            a sinusoidal positional embedding is added.
+        """
+        try:
+            idx = [self.target_to_embed_idx[inp["dataset_source"]] for inp in inputs]
+        except KeyError:
+            idx = [self.default_idx] * len(inputs)
+        embeds = self.embed(torch.tensor(idx).to(features[0].device))
+        seq_len = features[0].shape[-1] * features[0].shape[-2]  # T = HW
+        embeds = embeds.unsqueeze(0).repeat(seq_len, 1, 1)  # T x B x C
+        if self.add_spatial_embed:
+            embeds = self.pos_embed(embeds)
+        embeds = torch.einsum("tbc->btc", embeds)  # B x T x C
+        return embeds
+
+    def forward(
+        self,
+        features: list[torch.tensor],
+        inputs: list[dict[str, Any]],
+        embeds: torch.Tensor | None = None,
+    ) -> list[torch.tensor]:
+        """Compute and apply task-specific embeddings to encoder features.
+
+        Optionally, add a spatial sinusoidal positional embedding to the task embedding.
+        Otherwise, the task embedding is repeated over the spatial dimensions.
+
+        Args:
+            features: The encoder features, a 1-list of B x C x H x W features.
+            inputs: The inputs to the model.
+            embeds: Already-computed task embeddings, if provided, skip the computation.
+
+        Returns:
+            The encoder features with the task-specific embeddings added.
+        """
+        height, width = features[0].shape[-2:]
+        assert all(f.shape[-2:] == (height, width) for f in features), (
+            "features must have the same spatial dimensions"
+        )
+        if embeds is None:
+            embeds = self.compute_embeds(features, inputs)  # B x HW x C
+        embeds = embeds.unflatten(dim=1, sizes=(height, width))  # B x H x W x C
+        for i in range(len(features)):
+            features[i] += torch.einsum("bhwc->bchw", embeds)  # B x C x H x W
+        return features
+
+
+class TaskMHAEmbedding(TaskChannelEmbedding):
+    """Multi-headed cross-attention over the spatial dimensions.
+
+    The task embedding is the query and the features are the key and value.
+    We copy the task embedding over the spatial dimensions, and optionally
+    add a sinusoidal positional embedding before the MHA layer.
+    """
+
+    def __init__(
+        self,
+        encoder_embedding_size: int,
+        num_heads: int,
+        default_idx: int = 0,
+        add_spatial_embed: bool = True,
+    ) -> None:
+        """Initialize the task MHA embedding module.
+
+        Args:
+            encoder_embedding_size: The size of the encoder embedding.
+            num_heads: The number of attention heads.
+            default_idx: The index of the default task, useful if loading a merged model.
+            add_spatial_embed: if true, add a spatial sinusoidal positional embedding to the task embedding
+        """
+        super().__init__(encoder_embedding_size, default_idx, add_spatial_embed)
+        self.mha = torch.nn.MultiheadAttention(
+            encoder_embedding_size, num_heads, batch_first=True
+        )
+
+    def register_tasks(self, task_names: list[str]) -> None:
+        """Register the tasks.
+
+        This must happen post-init so that we can dynamically determine
+        the tasks to use, so it doesn't have to be specified in the config.
+
+        Args:
+            task_names: The names of the tasks.
+        """
+        super().register_tasks(task_names)
+
+    def forward(
+        self,
+        features: list[torch.tensor],
+        inputs: list[dict[str, Any]],
+        embeds: torch.Tensor | None = None,
+    ) -> list[torch.tensor]:
+        """Compute and apply task-specific embeddings to encoder features.
+
+        Also apply the MHA layer across the spatial dimension, with the task embedding
+        as the query and the features as the key and value.
+
+        Args:
+            features: The encoder features, a 1-list of B x C x H x W features.
+            inputs: The inputs to the model.
+            embeds: Already-computed task embeddings, if provided, skip the computation.
+
+        Returns:
+            The encoder features with the task-specific embeddings added.
+        """
+        assert len(features) == 1, "TaskMHAEmbedding only supports one feature"
+        x = torch.flatten(features[0], start_dim=2)  # B x C x T, T = HW
+        if embeds is None:
+            embeds = self.compute_embeds(features, inputs)  # B x T x C
+        out = self.mha(
+            embeds,  # B x T x C
+            torch.einsum("bct->btc", x),
+            torch.einsum("bct->btc", x),
+        )[0]  # B x T x C
+        out = torch.einsum("btc->bct", out)
+        out = out.view(*features[0].shape)  # B x C x H x W
+        return [out]
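A usage sketch for the new task embedding module, based only on the class definitions above (the import path follows the file listing; the task names, embedding size, feature shapes, and "dataset_source" values are illustrative assumptions): tasks are registered after construction, then the per-task embedding is added onto a 1-list of B x C x H x W encoder features.

# Illustrative usage of TaskChannelEmbedding as defined in the diff above.
import torch
from rslearn.models.task_embedding import TaskChannelEmbedding

embed = TaskChannelEmbedding(encoder_embedding_size=64, add_spatial_embed=True)
embed.register_tasks(["segment", "classify"])

features = [torch.randn(2, 64, 8, 8)]  # 1-list of B x C x H x W encoder features
inputs = [{"dataset_source": "segment"}, {"dataset_source": "classify"}]

out = embed(features, inputs)
print(out[0].shape)  # torch.Size([2, 64, 8, 8]); task embedding added per pixel

TaskMHAEmbedding is used the same way but additionally takes num_heads and runs cross-attention over the spatial positions, with the task embedding as the query and the features as key and value.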