rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/models/module_wrapper.py
ADDED

@@ -0,0 +1,65 @@
+"""Module wrapper provided for backwards compatibility."""
+
+from typing import Any
+
+import torch
+
+from rslearn.train.model_context import ModelContext
+
+from .component import (
+    FeatureExtractor,
+    FeatureMaps,
+    IntermediateComponent,
+)
+
+
+class EncoderModuleWrapper(FeatureExtractor):
+    """Wraps one or more IntermediateComponents to function as the feature extractor.
+
+    The first component should input a FeatureMaps, which will be computed from the
+    overall inputs by stacking the "image" key from each input dict.
+    """
+
+    def __init__(
+        self,
+        module: IntermediateComponent | None = None,
+        modules: list[IntermediateComponent] = [],
+    ):
+        """Initialize an EncoderModuleWrapper.
+
+        Args:
+            module: the IntermediateComponent to wrap for use as a FeatureExtractor.
+                Exactly one of module or modules must be set.
+            modules: list of modules to wrap
+        """
+        super().__init__()
+        if module is not None and len(modules) > 0:
+            raise ValueError("only one of module or modules should be set")
+        if module is not None:
+            self.encoder_modules = torch.nn.ModuleList([module])
+        elif len(modules) > 0:
+            self.encoder_modules = torch.nn.ModuleList(modules)
+        else:
+            raise ValueError("one of module or modules must be set")
+
+    def forward(self, context: ModelContext) -> Any:
+        """Compute outputs from the wrapped module.
+
+        Args:
+            context: the model context. Input dicts must include "image" key containing
+                the image to convert to a FeatureMaps, which will be passed to the
+                first wrapped module.
+
+        Returns:
+            the output from the last wrapped module.
+        """
+        # Take the first and only timestep. Currently no intermediate
+        # components support multi-temporal inputs, so if the input is
+        # multitemporal it should be wrapped in a simple time series wrapper.
+        images = torch.stack(
+            [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+        )
+        cur: Any = FeatureMaps([images])
+        for m in self.encoder_modules:
+            cur = m(cur, context)
+        return cur
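
For orientation, a minimal usage sketch follows. It is not from the package: the import path is inferred from the file location in this diff, and a toy torch module stands in for a real IntermediateComponent. The only interface the wrapper relies on is that each wrapped module is callable as module(features, context).

# Hypothetical sketch based only on this diff; `Identity` is a toy stand-in and
# the import path for EncoderModuleWrapper is assumed.
import torch

from rslearn.models.module_wrapper import EncoderModuleWrapper


class Identity(torch.nn.Module):
    """Toy component that returns its input feature maps unchanged."""

    def forward(self, features, context):
        # Matches the m(cur, context) call in EncoderModuleWrapper.forward.
        return features


# Exactly one of `module` or `modules` may be set; both or neither raises ValueError.
single = EncoderModuleWrapper(module=Identity())
chained = EncoderModuleWrapper(modules=[Identity(), Identity()])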
rslearn/models/molmo.py
ADDED
@@ -0,0 +1,69 @@
+"""Molmo model."""
+
+import torch
+from transformers import AutoModelForCausalLM, AutoProcessor
+
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps
+
+
+class Molmo(FeatureExtractor):
+    """Molmo image encoder."""
+
+    def __init__(
+        self,
+        model_name: str,
+    ):
+        """Instantiate a new Molmo instance.
+
+        Args:
+            model_name: the model name like "allenai/Molmo-7B-D-0924".
+        """
+        super().__init__()
+
+        self.processor = AutoProcessor.from_pretrained(
+            model_name,
+            trust_remote_code=True,
+            torch_dtype="auto",
+            device_map="cpu",
+        )  # nosec
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            trust_remote_code=True,
+            torch_dtype="auto",
+            device_map="cpu",
+        )  # nosec
+        self.encoder = model.model.vision_backbone
+
+    def forward(self, context: ModelContext) -> FeatureMaps:
+        """Compute outputs from the backbone.
+
+        Args:
+            context: the model context. Input dicts must include "image" key containing
+                the image to process. The images should have values 0-255.
+
+        Returns:
+            a FeatureMaps. Molmo produces features at one scale, so it will contain one
+            feature map that is a Bx24x24x2048 tensor.
+        """
+        device = context.inputs[0]["image"].image.device
+        molmo_inputs_list = []
+        # Process each one so we can isolate just the full image without any crops.
+        for inp in context.inputs:
+            image = (
+                inp["image"].single_ts_to_chw_tensor().cpu().numpy().transpose(1, 2, 0)
+            )
+            processed = self.processor.process(
+                images=[image],
+                text="",
+            )
+            molmo_inputs_list.append(processed["images"][0])
+        molmo_inputs: torch.Tensor = torch.stack(molmo_inputs_list, dim=0).unsqueeze(1)
+
+        image_features, _ = self.encoder.encode_image(molmo_inputs.to(device))
+
+        # 576x2048 -> 24x24x2048
+        return FeatureMaps(
+            [image_features[:, 0, :, :].reshape(-1, 24, 24, 2048).permute(0, 3, 1, 2)]
+        )
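
A rough usage sketch, assuming the import path implied by the file location above. Note that the constructor loads both the processor and the full causal LM, then keeps only model.model.vision_backbone, so the language-model weights are discarded after __init__.

# Hedged sketch: the import path is assumed from this diff, not documented API.
from rslearn.models.molmo import Molmo

# Downloads the Hugging Face processor and model onto CPU (trust_remote_code=True),
# then retains only the vision backbone as the feature extractor.
encoder = Molmo(model_name="allenai/Molmo-7B-D-0924")

# encoder(context) expects 0-255 "image" inputs and returns a FeatureMaps with a
# single map: 576 patch tokens reshaped to a 24x24 grid of 2048-d features
# (channels-first after the final permute in forward above).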
rslearn/models/multitask.py
CHANGED
@@ -1,9 +1,53 @@
 """MultiTaskModel for rslearn."""
 
+from collections.abc import Iterable
+from copy import deepcopy
 from typing import Any
 
 import torch
 
+from rslearn.log_utils import get_logger
+from rslearn.models.trunk import DecoderTrunk
+from rslearn.train.model_context import ModelContext, ModelOutput
+
+from .component import FeatureExtractor, IntermediateComponent, Predictor
+
+logger = get_logger(__name__)
+
+
+def sort_keys(d: dict[str, Any]) -> dict[str, Any]:
+    """Recursively (half in place) sort the keys of a dictionary.
+
+    Need this so that the order of task embeddings indexing is consistent.
+
+    Args:
+        d (dict[str, Any]): The dictionary to sort.
+    """
+    d = {k: d[k] for k in sorted(d)}
+    for k, v in d.items():
+        if isinstance(v, dict):
+            d[k] = sort_keys(v)
+    return d
+
+
+def deepcopy_tensordict(d: dict[Any, Any]) -> dict[Any, Any]:
+    """Deepcopy a dict with torch.Tensor, dict, and other types.
+
+    Make sure tensor copying is handled properly.
+
+    Args:
+        d: the dict to deepcopy
+    """
+    new_d = {}
+    for k, v in d.items():
+        if isinstance(v, torch.Tensor):
+            new_d[k] = torch.clone(v)
+        elif isinstance(v, dict):
+            new_d[k] = deepcopy_tensordict(v)
+        else:
+            new_d[k] = deepcopy(v)
+    return new_d
+
 
 class MultiTaskModel(torch.nn.Module):
     """MultiTask model wrapper.

@@ -12,54 +56,366 @@ class MultiTaskModel(torch.nn.Module):
 
     Then, it applies one sequential decoder for each configured task. It computes
     outputs and loss using the final module in the decoder.
+
+    Optionally include a shared trunk module to postprocess the encoder features.
     """
 
     def __init__(
-        self,
+        self,
+        encoder: list[FeatureExtractor | IntermediateComponent],
+        decoders: dict[str, list[IntermediateComponent | Predictor]],
+        lazy_decode: bool = False,
+        loss_weights: dict[str, float] | None = None,
+        trunk: DecoderTrunk | None = None,
    ):
         """Initialize a new MultiTaskModel.
 
         Args:
-            encoder: modules to compute intermediate feature representations.
+            encoder: modules to compute intermediate feature representations. The first
+                module must be a FeatureExtractor, and following modules must be
+                IntermediateComponents.
             decoders: modules to compute outputs and loss, should match number of tasks.
+                The last module must be a Predictor, while the previous modules must be
+                IntermediateComponents.
+            lazy_decode: if True, only decode the outputs specified in the batch.
+            loss_weights: weights for each task's loss (default: None = equal weights).
+            trunk: if provided, use this trunk module to postprocess the features
+                (recommend including a task-specific embedding module here).
         """
         super().__init__()
-        self.
+        self.lazy_decode = lazy_decode
+        self.encoder = torch.nn.ModuleList(encoder)
         self.decoders = torch.nn.ModuleDict(
-
+            sort_keys(
+                {
+                    name: torch.nn.ModuleList(decoder)
+                    for name, decoder in decoders.items()
+                }
+            )
+        )
+        self._init_loss_weights(loss_weights, list(self.decoders.keys()))
+        self._init_trunk(trunk, list(self.decoders.keys()))
+
+    def _init_loss_weights(
+        self, loss_weights: dict[str, float] | None, task_names: list[str]
+    ) -> None:
+        """Initialize the loss weights for the tasks.
+
+        Args:
+            loss_weights: weights for each task's loss (default: None = equal weights).
+            task_names: list of task names.
+        """
+        if loss_weights is None:
+            loss_weights = {name: 1.0 for name in task_names}
+        for name in task_names:
+            if name not in loss_weights:
+                logger.warning(f"task {name} not in loss_weights, setting to 1.0")
+                loss_weights[name] = 1.0
+        self.loss_weights = sort_keys(loss_weights)
+        logger.info(f"loss_weights: {self.loss_weights}")
+
+    def _init_trunk(self, trunk: DecoderTrunk | None, task_names: list[str]) -> None:
+        """Initialize the trunk module.
+
+        Args:
+            trunk: the trunk module.
+            task_names: list of task names.
+        """
+        self.trunk = trunk
+        if trunk is not None:
+            trunk.register_tasks(task_names)
+            logger.info("registered decoders with trunk")
+
+    def apply_decoder(
+        self,
+        intermediates: Any,
+        context: ModelContext,
+        targets: list[dict[str, Any]] | None,
+        decoder: list[IntermediateComponent | Predictor],
+        task_name: str,
+    ) -> ModelOutput:
+        """Apply a decoder to a list of inputs and targets.
+
+        Args:
+            intermediates: the intermediate output from the encoder.
+            context: the model context.
+            targets: list of target dicts
+            decoder: list of decoder modules
+            task_name: the name of the task
+
+        Returns:
+            a ModelOutput containing outputs across all the decoders.
+        """
+        # First, apply all but the last module in the decoder to the features
+        cur = intermediates
+        for module in decoder[:-1]:
+            cur = module(cur, context)
+
+        if targets is None:
+            cur_targets = None
+        else:
+            cur_targets = [target[task_name] for target in targets]
+
+        # Then, apply the last module to the features and targets
+        return decoder[-1](cur, context, cur_targets)
+
+    def _get_tasks_from_decoder(self, decoder: str) -> list[str]:
+        """Get the tasks corresponding to this decoder.
+
+        Args:
+            decoder: the name of the decoder
+        """
+        return [decoder]
+
+    def apply_decoders(
+        self,
+        intermediates: Any,
+        context: ModelContext,
+        targets: list[dict[str, Any]] | None,
+    ) -> ModelOutput:
+        """Apply all the decoders to the features and targets.
+
+        Args:
+            intermediates: the intermediates from the encoder.
+            context: the model context
+            targets: list of target dicts
+
+        Returns:
+            combined ModelOutput. The outputs is a list of output dicts, one per example,
+            where the dict maps from task name to the corresponding task output. The
+            losses is a flat dict but the task name is prepended to the loss names.
+        """
+        outputs: list[dict[str, torch.Tensor | dict]] = [{} for _ in context.inputs]
+        losses: dict[str, torch.Tensor] = {}
+
+        if self.lazy_decode:
+            # Assume that all inputs have the same dataset_source
+            task_name = context.metadatas[0].dataset_source
+
+            if task_name is None:
+                raise ValueError("dataset_source must be set for lazy decoding")
+
+            decoder = self.decoders[self.target_to_decoder.get(task_name, task_name)]
+            model_output = self.apply_decoder(
+                intermediates, context, targets, decoder, task_name
+            )
+            for idx, entry in enumerate(model_output.outputs):
+                outputs[idx][task_name] = entry
+            for loss_name, loss_value in model_output.loss_dict.items():
+                losses[f"{task_name}_{loss_name}"] = (
+                    loss_value * self.loss_weights[task_name]
+                )
+        else:
+            for decoder_name, decoder in self.decoders.items():
+                for task_name in self._get_tasks_from_decoder(decoder_name):
+                    model_output = self.apply_decoder(
+                        intermediates, context, targets, decoder, task_name
+                    )
+                    for idx, entry in enumerate(model_output.outputs):
+                        outputs[idx][task_name] = entry
+                    for loss_name, loss_value in model_output.loss_dict.items():
+                        losses[f"{task_name}_{loss_name}"] = (
+                            loss_value * self.loss_weights[task_name]
+                        )
+
+        return ModelOutput(
+            outputs=outputs,
+            loss_dict=losses,
         )
 
     def forward(
         self,
-
+        context: ModelContext,
         targets: list[dict[str, Any]] | None = None,
-    ) ->
+    ) -> ModelOutput:
+        """Apply the sequence of modules on the inputs, including shared trunk.
+
+        Args:
+            context: the model context.
+            targets: optional list of target dicts
+
+        Returns:
+            the model output from apply_decoders.
+        """
+        cur = self.encoder[0](context)
+        for module in self.encoder[1:]:
+            cur = module(cur, context)
+        if self.trunk is not None:
+            trunk_out = self.trunk(cur, context)
+            outs = self.apply_decoders(trunk_out.pop("outputs"), context, targets)
+            self.trunk.apply_auxiliary_losses(trunk_out, outs)
+            return outs | trunk_out
+        else:
+            return self.apply_decoders(cur, context, targets)
+
+
+class MultiTaskMergedModel(MultiTaskModel):
+    """Similar to MultiTaskModel, but allow merging in label space.
+
+    For example, if you have two classification tasks with N and M labels each, this will
+    handle generating an output layer with N+M layers and the corresponding modification
+    of targets/predictions/metrics.
+
+    Applies one sequential decoder for each configured task. It computes
+    outputs and loss using the final module in the decoder.
+    """
+
+    def __init__(
+        self,
+        encoder: list[FeatureExtractor | IntermediateComponent],
+        decoders: dict[str, list[IntermediateComponent | Predictor]],
+        decoder_to_target: dict[str, list[str]],
+        task_label_offsets: dict[str, dict[str, Any]],
+        lazy_decode: bool = False,
+        loss_weights: dict[str, float] | None = None,
+        trunk: DecoderTrunk | None = None,
+    ):
+        """Initialize a new MultiTaskModel.
+
+        Args:
+            encoder: modules to compute intermediate feature representations.
+            decoders: modules to compute outputs and loss, should match number of tasks.
+            decoder_to_target: mapping from decoder id to list of task names
+                (specify if merging heads, otherwise leave as None).
+            task_label_offsets: mapping from task name to dict of info (output_key, offset)
+                (specify if merging label groups across a single task).
+            lazy_decode: if True, only decode the outputs specified in the batch.
+            loss_weights: weights for each task's loss (default: None = equal weights).
+            trunk: if provided, use this trunk module to postprocess the features
+                (recommend including a task-specific embedding module here).
+        """
+        # Can't use super() because we need to skip calls to _init_loss_weights and _init_trunk
+        torch.nn.Module.__init__(self)
+
+        self.lazy_decode = lazy_decode
+        self.encoder = torch.nn.ModuleList(encoder)
+        self.decoders = torch.nn.ModuleDict(
+            sort_keys(
+                {
+                    name: torch.nn.ModuleList(decoder)
+                    for name, decoder in decoders.items()
+                }
+            )
+        )
+        self.task_label_offsets = task_label_offsets
+
+        self.decoder_to_target = sort_keys(decoder_to_target)
+        logger.info(f"merged decoders: {self.decoder_to_target}")
+
+        self.target_to_decoder = {}
+        for decoder_id, task_names in self.decoder_to_target.items():
+            for task_name in task_names:
+                self.target_to_decoder[task_name] = decoder_id
+        self.target_to_decoder = sort_keys(self.target_to_decoder)
+
+        self._init_loss_weights(loss_weights, list(self.target_to_decoder.keys()))
+        self._init_trunk(trunk, list(self.target_to_decoder.keys()))
+
+    def merge_task_labels(
+        self,
+        targets: list[dict[str, Any]] | None,
+        task_name: str,
+    ) -> list[dict[str, Any]] | None:
+        """Merge the task labels by adding an offset to the label key.
+
+        Make a clone before doing this because we may use targets elsewhere.
+
+        Args:
+            targets: the target dicts
+            task_name: the name of the task
+        """
+        if targets is None:
+            return targets
+        offset = self.task_label_offsets[task_name]["offset"]
+        outputs_key = self.task_label_offsets[task_name]["outputs_key"]
+        offset_targets = []
+        for target in targets:
+            offset_target = deepcopy_tensordict(target)
+            spliced = offset_target[task_name]
+            if torch.is_floating_point(spliced[outputs_key]):
+                logger.warning(
+                    f"task {task_name} has targets of type "
+                    f"{spliced[outputs_key].dtype}, "
+                    f"expected int (shape {spliced[outputs_key].shape})"
+                )
+            with torch.no_grad():
+                spliced[outputs_key] += offset
+            offset_targets.append(offset_target)
+        return offset_targets
+
+    def unmerge_output_labels(
+        self, outputs: Iterable[Any], task_name: str
+    ) -> list[dict[str, torch.Tensor | dict]]:
+        """Unmerge the task outputs.
+
+        For most tasks, this means chopping off the corresponding label dimensions.
+        For some, we might just need to subtract an offset from the target (ex: segmentation).
+        Assume first dimension is the number of outputs.
+
+        Args:
+            outputs: the predictions
+            task_name: the name of the task
+
+        Returns:
+            the unmerged outputs.
+        """
+        offset = self.task_label_offsets[task_name]["offset"]
+        num_outputs = self.task_label_offsets[task_name]["num_outputs"]
+        output_key = self.task_label_offsets[task_name]["outputs_key"]
+
+        unmerged_outputs: list[dict[str, torch.Tensor | dict]] = [{} for _ in outputs]
+        with torch.no_grad():
+            for i, output in enumerate(outputs):
+                if not output:
+                    # Possible if there are no detections
+                    continue
+                output = output[task_name]
+                if isinstance(output, dict):
+                    # For some tasks (eg object detection), we have discrete label
+                    # predictions instead of a distribution over labels
+                    unmerged_output = output.copy()
+                    unmerged_output[output_key] = unmerged_output[output_key] - offset
+                    unmerged_outputs[i][task_name] = unmerged_output
+                elif isinstance(output, torch.Tensor):
+                    # For classification/segmentation tasks, we have a distribution
+                    # over labels, so we need to scale the predictions so that they
+                    # sum to 1 since we chop off some of the probability densities
+                    unmerged_output = output[offset : offset + num_outputs, ...]
+                    unmerged_output /= unmerged_output.sum(dim=0, keepdim=True).type(
+                        torch.float32
+                    )
+                    unmerged_outputs[i][task_name] = unmerged_output
+
+        return unmerged_outputs
+
+    def forward(
+        self,
+        context: ModelContext,
+        targets: list[dict[str, Any]] | None = None,
+    ) -> ModelOutput:
         """Apply the sequence of modules on the inputs.
 
         Args:
-
+            context: the model context.
             targets: optional list of target dicts
 
         Returns:
-
+            the model output.
+        """
+        dataset_source = context.metadatas[0].dataset_source
+        assert dataset_source is not None
+        merged_targets = self.merge_task_labels(targets, dataset_source)
+        outs = super().forward(context, merged_targets)
+        unmerged_outputs = self.unmerge_output_labels(outs.outputs, dataset_source)
+        return ModelOutput(
+            outputs=unmerged_outputs,
+            loss_dict=outs.loss_dict,
+        )
+
+    def _get_tasks_from_decoder(self, decoder: str) -> list[str]:
+        """Get the tasks corresponding to this decoder.
+
+        Args:
+            decoder: the name of the decoder
         """
-
-        outputs = [{} for _ in inputs]
-        losses = {}
-        for name, decoder in self.decoders.items():
-            cur = features
-            for module in decoder[:-1]:
-                cur = module(cur, inputs)
-
-            if targets is None:
-                cur_targets = None
-            else:
-                cur_targets = [target[name] for target in targets]
-
-            cur_output, cur_loss_dict = decoder[-1](cur, inputs, cur_targets)
-
-            for idx, entry in enumerate(cur_output):
-                outputs[idx][name] = entry
-            for loss_name, loss_value in cur_loss_dict.items():
-                losses[f"{name}_{loss_name}"] = loss_value
-        return outputs, losses
+        return self.decoder_to_target[decoder]
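
To make the label-space merging concrete, the sketch below shows how a task_label_offsets dict appears to be interpreted by merge_task_labels and unmerge_output_labels above. The task names and class counts are illustrative only; the key names ("offset", "num_outputs", "outputs_key") are read directly from this diff.

# Illustrative values only: two classification tasks with 3 and 4 labels merged
# into one 7-way head, matching the N+M scheme described in the class docstring.
task_label_offsets = {
    # Task A occupies merged labels 0-2.
    "crop_type": {"offset": 0, "num_outputs": 3, "outputs_key": "classes"},
    # Task B is shifted to merged labels 3-6.
    "land_cover": {"offset": 3, "num_outputs": 4, "outputs_key": "classes"},
}

# merge_task_labels: a land_cover target label 2 becomes 2 + 3 = 5 in merged space.
# unmerge_output_labels: the 7-way output is sliced back to output[3:3+4] and
# renormalized so the four retained probabilities sum to 1.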
rslearn/models/olmoearth_pretrain/__init__.py
ADDED

@@ -0,0 +1 @@
+"""OlmoEarth model architecture."""