rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/models/module_wrapper.py
@@ -0,0 +1,65 @@
+"""Module wrapper provided for backwards compatibility."""
+
+from typing import Any
+
+import torch
+
+from rslearn.train.model_context import ModelContext
+
+from .component import (
+    FeatureExtractor,
+    FeatureMaps,
+    IntermediateComponent,
+)
+
+
+class EncoderModuleWrapper(FeatureExtractor):
+    """Wraps one or more IntermediateComponents to function as the feature extractor.
+
+    The first component should input a FeatureMaps, which will be computed from the
+    overall inputs by stacking the "image" key from each input dict.
+    """
+
+    def __init__(
+        self,
+        module: IntermediateComponent | None = None,
+        modules: list[IntermediateComponent] = [],
+    ):
+        """Initialize an EncoderModuleWrapper.
+
+        Args:
+            module: the IntermediateComponent to wrap for use as a FeatureExtractor.
+                Exactly one of module or modules must be set.
+            modules: list of modules to wrap
+        """
+        super().__init__()
+        if module is not None and len(modules) > 0:
+            raise ValueError("only one of module or modules should be set")
+        if module is not None:
+            self.encoder_modules = torch.nn.ModuleList([module])
+        elif len(modules) > 0:
+            self.encoder_modules = torch.nn.ModuleList(modules)
+        else:
+            raise ValueError("one of module or modules must be set")
+
+    def forward(self, context: ModelContext) -> Any:
+        """Compute outputs from the wrapped module.
+
+        Args:
+            context: the model context. Input dicts must include "image" key containing
+                the image to convert to a FeatureMaps, which will be passed to the
+                first wrapped module.
+
+        Returns:
+            the output from the last wrapped module.
+        """
+        # take the first and only timestep. Currently no intermediate
+        # components support multi temporal inputs, so if the input is
+        # multitemporal it should be wrapped in a simple time series wrapper.
+        images = torch.stack(
+            [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+        )
+        cur: Any = FeatureMaps([images])
+        for m in self.encoder_modules:
+            cur = m(cur, context)
+        return cur
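The batching step in EncoderModuleWrapper.forward stacks one CHW tensor per input dict before wrapping the result in a FeatureMaps for the first wrapped component. A minimal runnable sketch of that step, with plain random tensors standing in for the values returned by inp["image"].single_ts_to_chw_tensor() (shapes are illustrative only):

import torch

# Hypothetical per-example CHW tensors (stand-ins for single_ts_to_chw_tensor()).
per_example = [torch.randn(3, 64, 64) for _ in range(4)]
batch = torch.stack(per_example, dim=0)  # (B, C, H, W) = (4, 3, 64, 64)
print(batch.shape)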
rslearn/models/molmo.py
@@ -0,0 +1,69 @@
+"""Molmo model."""
+
+import torch
+from transformers import AutoModelForCausalLM, AutoProcessor
+
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps
+
+
+class Molmo(FeatureExtractor):
+    """Molmo image encoder."""
+
+    def __init__(
+        self,
+        model_name: str,
+    ):
+        """Instantiate a new Molmo instance.
+
+        Args:
+            model_name: the model name like "allenai/Molmo-7B-D-0924".
+        """
+        super().__init__()
+
+        self.processor = AutoProcessor.from_pretrained(
+            model_name,
+            trust_remote_code=True,
+            torch_dtype="auto",
+            device_map="cpu",
+        )  # nosec
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            trust_remote_code=True,
+            torch_dtype="auto",
+            device_map="cpu",
+        )  # nosec
+        self.encoder = model.model.vision_backbone
+
+    def forward(self, context: ModelContext) -> FeatureMaps:
+        """Compute outputs from the backbone.
+
+        Args:
+            context: the model context. Input dicts must include "image" key containing
+                the image to process. The images should have values 0-255.
+
+        Returns:
+            a FeatureMaps. Molmo produces features at one scale, so it will contain one
+            feature map that is a Bx24x24x2048 tensor.
+        """
+        device = context.inputs[0]["image"].image.device
+        molmo_inputs_list = []
+        # Process each one so we can isolate just the full image without any crops.
+        for inp in context.inputs:
+            image = (
+                inp["image"].single_ts_to_chw_tensor().cpu().numpy().transpose(1, 2, 0)
+            )
+            processed = self.processor.process(
+                images=[image],
+                text="",
+            )
+            molmo_inputs_list.append(processed["images"][0])
+        molmo_inputs: torch.Tensor = torch.stack(molmo_inputs_list, dim=0).unsqueeze(1)
+
+        image_features, _ = self.encoder.encode_image(molmo_inputs.to(device))
+
+        # 576x2048 -> 24x24x2048
+        return FeatureMaps(
+            [image_features[:, 0, :, :].reshape(-1, 24, 24, 2048).permute(0, 3, 1, 2)]
+        )
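A note on the shape handling in Molmo.forward: per the "576x2048 -> 24x24x2048" comment, the vision backbone appears to return 576 patch tokens of width 2048 per crop; taking crop 0 and reshaping gives a 24x24 grid, and the final permute moves channels first, so the returned feature map is laid out as (B, 2048, 24, 24). A small runnable sketch of just that reshape, with random values standing in for the encoder output (the (B, crops, tokens, dim) shape is an assumption inferred from the indexing above):

import torch

# Dummy stand-in for self.encoder.encode_image(...) output: (B, crops, tokens, dim).
image_features = torch.randn(2, 1, 576, 2048)
fmap = image_features[:, 0, :, :].reshape(-1, 24, 24, 2048).permute(0, 3, 1, 2)
print(fmap.shape)  # torch.Size([2, 2048, 24, 24])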
rslearn/models/multitask.py
@@ -1,9 +1,53 @@
 """MultiTaskModel for rslearn."""
 
+from collections.abc import Iterable
+from copy import deepcopy
 from typing import Any
 
 import torch
 
+from rslearn.log_utils import get_logger
+from rslearn.models.trunk import DecoderTrunk
+from rslearn.train.model_context import ModelContext, ModelOutput
+
+from .component import FeatureExtractor, IntermediateComponent, Predictor
+
+logger = get_logger(__name__)
+
+
+def sort_keys(d: dict[str, Any]) -> dict[str, Any]:
+    """Recursively (half in place) sort the keys of a dictionary.
+
+    Need this so that the order of task embeddings indexing is consistent.
+
+    Args:
+        d (dict[str, Any]): The dictionary to sort.
+    """
+    d = {k: d[k] for k in sorted(d)}
+    for k, v in d.items():
+        if isinstance(v, dict):
+            d[k] = sort_keys(v)
+    return d
+
+
+def deepcopy_tensordict(d: dict[Any, Any]) -> dict[Any, Any]:
+    """Deepcopy a dict with torch.Tensor, dict, and other types.
+
+    Make sure tensor copying is handled properly.
+
+    Args:
+        d: the dict to deepcopy
+    """
+    new_d = {}
+    for k, v in d.items():
+        if isinstance(v, torch.Tensor):
+            new_d[k] = torch.clone(v)
+        elif isinstance(v, dict):
+            new_d[k] = deepcopy_tensordict(v)
+        else:
+            new_d[k] = deepcopy(v)
+    return new_d
+
 
 class MultiTaskModel(torch.nn.Module):
     """MultiTask model wrapper.
@@ -12,54 +56,366 @@ class MultiTaskModel(torch.nn.Module):
 
     Then, it applies one sequential decoder for each configured task. It computes
     outputs and loss using the final module in the decoder.
+
+    Optionally include a shared trunk module to postprocess the encoder features.
     """
 
     def __init__(
-        self, encoder: list[torch.nn.Module], decoders: dict[str, list[torch.nn.Module]]
+        self,
+        encoder: list[FeatureExtractor | IntermediateComponent],
+        decoders: dict[str, list[IntermediateComponent | Predictor]],
+        lazy_decode: bool = False,
+        loss_weights: dict[str, float] | None = None,
+        trunk: DecoderTrunk | None = None,
     ):
         """Initialize a new MultiTaskModel.
 
         Args:
-            encoder: modules to compute intermediate feature representations.
+            encoder: modules to compute intermediate feature representations. The first
+                module must be a FeatureExtractor, and following modules must be
+                IntermediateComponents.
             decoders: modules to compute outputs and loss, should match number of tasks.
+                The last module must be a Predictor, while the previous modules must be
+                IntermediateComponents.
+            lazy_decode: if True, only decode the outputs specified in the batch.
+            loss_weights: weights for each task's loss (default: None = equal weights).
+            trunk: if provided, use this trunk module to postprocess the features
+                (recommend including a task-specific embedding module here).
         """
         super().__init__()
-        self.encoder = torch.nn.Sequential(*encoder)
+        self.lazy_decode = lazy_decode
+        self.encoder = torch.nn.ModuleList(encoder)
         self.decoders = torch.nn.ModuleDict(
-            {name: torch.nn.ModuleList(decoder) for name, decoder in decoders.items()}
+            sort_keys(
+                {
+                    name: torch.nn.ModuleList(decoder)
+                    for name, decoder in decoders.items()
+                }
+            )
+        )
+        self._init_loss_weights(loss_weights, list(self.decoders.keys()))
+        self._init_trunk(trunk, list(self.decoders.keys()))
+
+    def _init_loss_weights(
+        self, loss_weights: dict[str, float] | None, task_names: list[str]
+    ) -> None:
+        """Initialize the loss weights for the tasks.
+
+        Args:
+            loss_weights: weights for each task's loss (default: None = equal weights).
+            task_names: list of task names.
+        """
+        if loss_weights is None:
+            loss_weights = {name: 1.0 for name in task_names}
+        for name in task_names:
+            if name not in loss_weights:
+                logger.warning(f"task {name} not in loss_weights, setting to 1.0")
+                loss_weights[name] = 1.0
+        self.loss_weights = sort_keys(loss_weights)
+        logger.info(f"loss_weights: {self.loss_weights}")
+
+    def _init_trunk(self, trunk: DecoderTrunk | None, task_names: list[str]) -> None:
+        """Initialize the trunk module.
+
+        Args:
+            trunk: the trunk module.
+            task_names: list of task names.
+        """
+        self.trunk = trunk
+        if trunk is not None:
+            trunk.register_tasks(task_names)
+            logger.info("registered decoders with trunk")
+
+    def apply_decoder(
+        self,
+        intermediates: Any,
+        context: ModelContext,
+        targets: list[dict[str, Any]] | None,
+        decoder: list[IntermediateComponent | Predictor],
+        task_name: str,
+    ) -> ModelOutput:
+        """Apply a decoder to a list of inputs and targets.
+
+        Args:
+            intermediates: the intermediate output from the encoder.
+            context: the model context.
+            targets: list of target dicts
+            decoder: list of decoder modules
+            task_name: the name of the task
+
+        Returns:
+            a ModelOutput containing outputs across all the decoders.
+        """
+        # First, apply all but the last module in the decoder to the features
+        cur = intermediates
+        for module in decoder[:-1]:
+            cur = module(cur, context)
+
+        if targets is None:
+            cur_targets = None
+        else:
+            cur_targets = [target[task_name] for target in targets]
+
+        # Then, apply the last module to the features and targets
+        return decoder[-1](cur, context, cur_targets)
+
+    def _get_tasks_from_decoder(self, decoder: str) -> list[str]:
+        """Get the tasks corresponding to this decoder.
+
+        Args:
+            decoder: the name of the decoder
+        """
+        return [decoder]
+
+    def apply_decoders(
+        self,
+        intermediates: Any,
+        context: ModelContext,
+        targets: list[dict[str, Any]] | None,
+    ) -> ModelOutput:
+        """Apply all the decoders to the features and targets.
+
+        Args:
+            intermediates: the intermediates from the encoder.
+            context: the model context
+            targets: list of target dicts
+
+        Returns:
+            combined ModelOutput. The outputs is a list of output dicts, one per example,
+            where the dict maps from task name to the corresponding task output. The
+            losses is a flat dict but the task name is prepended to the loss names.
+        """
+        outputs: list[dict[str, torch.Tensor | dict]] = [{} for _ in context.inputs]
+        losses: dict[str, torch.Tensor] = {}
+
+        if self.lazy_decode:
+            # Assume that all inputs have the same dataset_source
+            task_name = context.metadatas[0].dataset_source
+
+            if task_name is None:
+                raise ValueError("dataset_source must be set for lazy decoding")
+
+            decoder = self.decoders[self.target_to_decoder.get(task_name, task_name)]
+            model_output = self.apply_decoder(
+                intermediates, context, targets, decoder, task_name
+            )
+            for idx, entry in enumerate(model_output.outputs):
+                outputs[idx][task_name] = entry
+            for loss_name, loss_value in model_output.loss_dict.items():
+                losses[f"{task_name}_{loss_name}"] = (
+                    loss_value * self.loss_weights[task_name]
+                )
+        else:
+            for decoder_name, decoder in self.decoders.items():
+                for task_name in self._get_tasks_from_decoder(decoder_name):
+                    model_output = self.apply_decoder(
+                        intermediates, context, targets, decoder, task_name
+                    )
+                    for idx, entry in enumerate(model_output.outputs):
+                        outputs[idx][task_name] = entry
+                    for loss_name, loss_value in model_output.loss_dict.items():
+                        losses[f"{task_name}_{loss_name}"] = (
+                            loss_value * self.loss_weights[task_name]
+                        )
+
+        return ModelOutput(
+            outputs=outputs,
+            loss_dict=losses,
         )
 
     def forward(
         self,
-        inputs: list[dict[str, Any]],
+        context: ModelContext,
         targets: list[dict[str, Any]] | None = None,
-    ) -> tuple[list[dict[str, Any]], dict[str, torch.Tensor]]:
+    ) -> ModelOutput:
+        """Apply the sequence of modules on the inputs, including shared trunk.
+
+        Args:
+            context: the model context.
+            targets: optional list of target dicts
+
+        Returns:
+            the model output from apply_decoders.
+        """
+        cur = self.encoder[0](context)
+        for module in self.encoder[1:]:
+            cur = module(cur, context)
+        if self.trunk is not None:
+            trunk_out = self.trunk(cur, context)
+            outs = self.apply_decoders(trunk_out.pop("outputs"), context, targets)
+            self.trunk.apply_auxiliary_losses(trunk_out, outs)
+            return outs | trunk_out
+        else:
+            return self.apply_decoders(cur, context, targets)
+
+
+class MultiTaskMergedModel(MultiTaskModel):
+    """Similar to MultiTaskModel, but allow merging in label space.
+
+    For example, if you have two classification tasks with N and M labels each, this will
+    handle generating an output layer with N+M layers and the corresponding modification
+    of targets/predictions/metrics.
+
+    Applies one sequential decoder for each configured task. It computes
+    outputs and loss using the final module in the decoder.
+    """
+
+    def __init__(
+        self,
+        encoder: list[FeatureExtractor | IntermediateComponent],
+        decoders: dict[str, list[IntermediateComponent | Predictor]],
+        decoder_to_target: dict[str, list[str]],
+        task_label_offsets: dict[str, dict[str, Any]],
+        lazy_decode: bool = False,
+        loss_weights: dict[str, float] | None = None,
+        trunk: DecoderTrunk | None = None,
+    ):
+        """Initialize a new MultiTaskModel.
+
+        Args:
+            encoder: modules to compute intermediate feature representations.
+            decoders: modules to compute outputs and loss, should match number of tasks.
+            decoder_to_target: mapping from decoder id to list of task names
+                (specify if merging heads, otherwise leave as None).
+            task_label_offsets: mapping from task name to dict of info (output_key, offset)
+                (specify if merging label groups across a single task).
+            lazy_decode: if True, only decode the outputs specified in the batch.
+            loss_weights: weights for each task's loss (default: None = equal weights).
+            trunk: if provided, use this trunk module to postprocess the features
+                (recommend including a task-specific embedding module here).
+        """
+        # Can't use super() because we need to skip calls to _init_loss_weights and _init_trunk
+        torch.nn.Module.__init__(self)
+
+        self.lazy_decode = lazy_decode
+        self.encoder = torch.nn.ModuleList(encoder)
+        self.decoders = torch.nn.ModuleDict(
+            sort_keys(
+                {
+                    name: torch.nn.ModuleList(decoder)
+                    for name, decoder in decoders.items()
+                }
+            )
+        )
+        self.task_label_offsets = task_label_offsets
+
+        self.decoder_to_target = sort_keys(decoder_to_target)
+        logger.info(f"merged decoders: {self.decoder_to_target}")
+
+        self.target_to_decoder = {}
+        for decoder_id, task_names in self.decoder_to_target.items():
+            for task_name in task_names:
+                self.target_to_decoder[task_name] = decoder_id
+        self.target_to_decoder = sort_keys(self.target_to_decoder)
+
+        self._init_loss_weights(loss_weights, list(self.target_to_decoder.keys()))
+        self._init_trunk(trunk, list(self.target_to_decoder.keys()))
+
+    def merge_task_labels(
+        self,
+        targets: list[dict[str, Any]] | None,
+        task_name: str,
+    ) -> list[dict[str, Any]] | None:
+        """Merge the task labels by adding an offset to the label key.
+
+        Make a clone before doing this because we may use targets elsewhere.
+
+        Args:
+            targets: the target dicts
+            task_name: the name of the task
+        """
+        if targets is None:
+            return targets
+        offset = self.task_label_offsets[task_name]["offset"]
+        outputs_key = self.task_label_offsets[task_name]["outputs_key"]
+        offset_targets = []
+        for target in targets:
+            offset_target = deepcopy_tensordict(target)
+            spliced = offset_target[task_name]
+            if torch.is_floating_point(spliced[outputs_key]):
+                logger.warning(
+                    f"task {task_name} has targets of type "
+                    f"{spliced[outputs_key].dtype}, "
+                    f"expected int (shape {spliced[outputs_key].shape})"
+                )
+            with torch.no_grad():
+                spliced[outputs_key] += offset
+            offset_targets.append(offset_target)
+        return offset_targets
+
+    def unmerge_output_labels(
+        self, outputs: Iterable[Any], task_name: str
+    ) -> list[dict[str, torch.Tensor | dict]]:
+        """Unmerge the task outputs.
+
+        For most tasks, this means chopping off the corresponding label dimensions.
+        For some, we might just need to subtract an offset from the target (ex: segmentation).
+        Assume first dimension is the number of outputs.
+
+        Args:
+            outputs: the predictions
+            task_name: the name of the task
+
+        Returns:
+            the unmerged outputs.
+        """
+        offset = self.task_label_offsets[task_name]["offset"]
+        num_outputs = self.task_label_offsets[task_name]["num_outputs"]
+        output_key = self.task_label_offsets[task_name]["outputs_key"]
+
+        unmerged_outputs: list[dict[str, torch.Tensor | dict]] = [{} for _ in outputs]
+        with torch.no_grad():
+            for i, output in enumerate(outputs):
+                if not output:
+                    # Possible if there are no detections
+                    continue
+                output = output[task_name]
+                if isinstance(output, dict):
+                    # For some tasks (eg object detection), we have discrete label
+                    # predictions instead of a distribution over labels
+                    unmerged_output = output.copy()
+                    unmerged_output[output_key] = unmerged_output[output_key] - offset
+                    unmerged_outputs[i][task_name] = unmerged_output
+                elif isinstance(output, torch.Tensor):
+                    # For classification/segmentation tasks, we have a distribution
+                    # over labels, so we need to scale the predictions so that they
+                    # sum to 1 since we chop off some of the probability densities
+                    unmerged_output = output[offset : offset + num_outputs, ...]
+                    unmerged_output /= unmerged_output.sum(dim=0, keepdim=True).type(
+                        torch.float32
+                    )
+                    unmerged_outputs[i][task_name] = unmerged_output
+
+        return unmerged_outputs
+
+    def forward(
+        self,
+        context: ModelContext,
+        targets: list[dict[str, Any]] | None = None,
+    ) -> ModelOutput:
         """Apply the sequence of modules on the inputs.
 
         Args:
-            inputs: list of input dicts
+            context: the model context.
            targets: optional list of target dicts
 
         Returns:
-            tuple (outputs, loss_dict) from the last module.
+            the model output.
+        """
+        dataset_source = context.metadatas[0].dataset_source
+        assert dataset_source is not None
+        merged_targets = self.merge_task_labels(targets, dataset_source)
+        outs = super().forward(context, merged_targets)
+        unmerged_outputs = self.unmerge_output_labels(outs.outputs, dataset_source)
+        return ModelOutput(
+            outputs=unmerged_outputs,
+            loss_dict=outs.loss_dict,
+        )
+
+    def _get_tasks_from_decoder(self, decoder: str) -> list[str]:
+        """Get the tasks corresponding to this decoder.
+
+        Args:
+            decoder: the name of the decoder
         """
-        features = self.encoder(inputs)
-        outputs = [{} for _ in inputs]
-        losses = {}
-        for name, decoder in self.decoders.items():
-            cur = features
-            for module in decoder[:-1]:
-                cur = module(cur, inputs)
-
-            if targets is None:
-                cur_targets = None
-            else:
-                cur_targets = [target[name] for target in targets]
-
-            cur_output, cur_loss_dict = decoder[-1](cur, inputs, cur_targets)
-
-            for idx, entry in enumerate(cur_output):
-                outputs[idx][name] = entry
-            for loss_name, loss_value in cur_loss_dict.items():
-                losses[f"{name}_{loss_name}"] = loss_value
-        return outputs, losses
+        return self.decoder_to_target[decoder]
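In both classes, per-task losses keep the naming and weighting scheme from apply_decoders (keys of the form f"{task_name}_{loss_name}", each scaled by loss_weights), while MultiTaskMergedModel additionally slices each task's predictions back out of the merged head and renormalizes them. A runnable sketch of that slice-and-renormalize step from unmerge_output_labels, using a made-up merged distribution over 5 = 2 + 3 labels where the second task has offset 2 and num_outputs 3:

import torch

# Made-up merged class distribution; the second task owns entries [2:5].
merged = torch.tensor([0.10, 0.15, 0.30, 0.25, 0.20])
offset, num_outputs = 2, 3
unmerged = merged[offset : offset + num_outputs, ...]
unmerged = unmerged / unmerged.sum(dim=0, keepdim=True)
print(unmerged)  # tensor([0.4000, 0.3333, 0.2667]) -- sums to 1 again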
rslearn/models/olmoearth_pretrain/__init__.py
@@ -0,0 +1 @@
+"""OlmoEarth model architecture."""