rslearn 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +2 -9
- rslearn/config/dataset.py +15 -16
- rslearn/dataset/dataset.py +28 -22
- rslearn/lightning_cli.py +22 -11
- rslearn/main.py +1 -1
- rslearn/models/anysat.py +35 -33
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clip.py +5 -2
- rslearn/models/component.py +12 -0
- rslearn/models/croma.py +11 -3
- rslearn/models/dinov3.py +2 -1
- rslearn/models/faster_rcnn.py +2 -1
- rslearn/models/galileo/galileo.py +58 -31
- rslearn/models/module_wrapper.py +6 -1
- rslearn/models/molmo.py +4 -2
- rslearn/models/olmoearth_pretrain/model.py +206 -51
- rslearn/models/olmoearth_pretrain/norm.py +5 -3
- rslearn/models/panopticon.py +3 -1
- rslearn/models/presto/presto.py +45 -15
- rslearn/models/prithvi.py +9 -7
- rslearn/models/sam2_enc.py +3 -1
- rslearn/models/satlaspretrain.py +4 -1
- rslearn/models/simple_time_series.py +43 -17
- rslearn/models/ssl4eo_s12.py +19 -14
- rslearn/models/swin.py +3 -1
- rslearn/models/terramind.py +5 -4
- rslearn/train/all_patches_dataset.py +96 -28
- rslearn/train/dataset.py +102 -53
- rslearn/train/model_context.py +35 -1
- rslearn/train/scheduler.py +15 -0
- rslearn/train/tasks/classification.py +8 -2
- rslearn/train/tasks/detection.py +3 -2
- rslearn/train/tasks/multi_task.py +2 -3
- rslearn/train/tasks/per_pixel_regression.py +14 -5
- rslearn/train/tasks/regression.py +8 -2
- rslearn/train/tasks/segmentation.py +13 -4
- rslearn/train/tasks/task.py +2 -2
- rslearn/train/transforms/concatenate.py +45 -5
- rslearn/train/transforms/crop.py +22 -8
- rslearn/train/transforms/flip.py +13 -5
- rslearn/train/transforms/mask.py +11 -2
- rslearn/train/transforms/normalize.py +46 -15
- rslearn/train/transforms/pad.py +15 -3
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +11 -2
- rslearn/train/transforms/sentinel1.py +18 -3
- rslearn/utils/geometry.py +73 -0
- rslearn/utils/jsonargparse.py +66 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/METADATA +1 -1
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/RECORD +55 -53
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/WHEEL +0 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/top_level.txt +0 -0
rslearn/arg_parser.py
CHANGED

@@ -1,6 +1,5 @@
 """Custom Lightning ArgumentParser with environment variable substitution support."""
 
-import os
 from typing import Any
 
 from jsonargparse import Namespace
@@ -21,11 +20,7 @@ class RslearnArgumentParser(LightningArgumentParser):
     def parse_string(
         self,
         cfg_str: str,
-        cfg_path: str = "",
-        ext_vars: dict | None = None,
-        env: bool | None = None,
-        defaults: bool = True,
-        with_meta: bool | None = None,
+        *args: Any,
         **kwargs: Any,
     ) -> Namespace:
         """Pre-processes string for environment variable substitution before parsing."""
@@ -33,6 +28,4 @@ class RslearnArgumentParser(LightningArgumentParser):
         substituted_cfg_str = substitute_env_vars_in_string(cfg_str)
 
         # Call the parent method with the substituted config
-        return super().parse_string(
-            substituted_cfg_str, cfg_path, ext_vars, env, defaults, with_meta, **kwargs
-        )
+        return super().parse_string(substituted_cfg_str, *args, **kwargs)
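A usage sketch (not part of the package diff): the override now rewrites only cfg_str and forwards any remaining positional or keyword arguments straight to the parent parser. The ${VAR} substitution syntax and the argument names below are assumptions for illustration, not documented behavior of substitute_env_vars_in_string.

    import os

    from rslearn.arg_parser import RslearnArgumentParser

    os.environ["RSLEARN_DATA_ROOT"] = "/data/my_dataset"  # value to be substituted

    parser = RslearnArgumentParser()
    parser.add_argument("--data_root", type=str)

    # defaults=True is passed through unchanged via **kwargs to the parent parse_string.
    cfg = parser.parse_string("data_root: ${RSLEARN_DATA_ROOT}\n", defaults=True)
    print(cfg.data_root)  # "/data/my_dataset", if the assumed ${VAR} syntax is correct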
rslearn/config/dataset.py
CHANGED

@@ -25,7 +25,7 @@ from rasterio.enums import Resampling
 from upath import UPath
 
 from rslearn.log_utils import get_logger
-from rslearn.utils import PixelBounds, Projection
+from rslearn.utils.geometry import PixelBounds, Projection, ResolutionFactor
 from rslearn.utils.raster_format import RasterFormat
 from rslearn.utils.vector_format import VectorFormat
 
@@ -215,22 +215,12 @@
         Returns:
             tuple of updated projection and bounds with zoom offset applied
         """
-        if self.zoom_offset
-
-        projection = Projection(
-            projection.crs,
-            projection.x_resolution / (2**self.zoom_offset),
-            projection.y_resolution / (2**self.zoom_offset),
-        )
-        if self.zoom_offset > 0:
-            zoom_factor = 2**self.zoom_offset
-            bounds = tuple(x * zoom_factor for x in bounds)  # type: ignore
+        if self.zoom_offset >= 0:
+            factor = ResolutionFactor(numerator=2**self.zoom_offset)
         else:
-
-
-
-            )
-        return projection, bounds
+            factor = ResolutionFactor(denominator=2 ** (-self.zoom_offset))
+
+        return (factor.multiply_projection(projection), factor.multiply_bounds(bounds))
 
     @field_validator("format", mode="before")
     @classmethod
@@ -645,3 +635,12 @@
         default_factory=lambda: StorageConfig(),
         description="jsonargparse configuration for the WindowStorageFactory.",
     )
+
+    @field_validator("layers", mode="after")
+    @classmethod
+    def layer_names_validator(cls, v: dict[str, LayerConfig]) -> dict[str, LayerConfig]:
+        """Ensure layer names don't contain periods, since we use periods to distinguish different materialized groups within a layer."""
+        for layer_name in v.keys():
+            if "." in layer_name:
+                raise ValueError(f"layer names must not contain periods: {layer_name}")
+        return v
rslearn/dataset/dataset.py
CHANGED

@@ -23,7 +23,7 @@ class Dataset:
     .. code-block:: none
 
         dataset/
-            config.json
+            config.json  # optional, if config provided as runtime object
             windows/
                 group1/
                     epsg:3857_10_623565_1528020/
@@ -40,37 +40,43 @@
     materialize.
     """
 
-    def __init__(self, path: UPath, disabled_layers: list[str] = []) -> None:
+    def __init__(
+        self,
+        path: UPath,
+        disabled_layers: list[str] = [],
+        dataset_config: DatasetConfig | None = None,
+    ) -> None:
         """Initializes a new Dataset.
 
         Args:
            path: the root directory of the dataset
            disabled_layers: list of layers to disable
+           dataset_config: optional dataset configuration to use instead of loading from the dataset directory
        """
        self.path = path
 
-
-
-
-
-
-
-
-        for layer_name, layer_config in config.layers.items():
-            # Layer names must not contain period, since we use period to
-            # distinguish different materialized groups within a layer.
-            assert "." not in layer_name, "layer names must not contain periods"
-            if layer_name in disabled_layers:
-                logger.warning(f"Layer {layer_name} is disabled")
-                continue
-            self.layers[layer_name] = layer_config
-
-        self.tile_store_config = config.tile_store
-        self.storage = (
-            config.storage.instantiate_window_storage_factory().get_storage(
-                self.path
+        if dataset_config is None:
+            # Load dataset configuration from the dataset directory.
+            with (self.path / "config.json").open("r") as f:
+                config_content = f.read()
+            config_content = substitute_env_vars_in_string(config_content)
+            dataset_config = DatasetConfig.model_validate(
+                json.loads(config_content)
             )
+
+        self.layers = {}
+        for layer_name, layer_config in dataset_config.layers.items():
+            if layer_name in disabled_layers:
+                logger.warning(f"Layer {layer_name} is disabled")
+                continue
+            self.layers[layer_name] = layer_config
+
+        self.tile_store_config = dataset_config.tile_store
+        self.storage = (
+            dataset_config.storage.instantiate_window_storage_factory().get_storage(
+                self.path
            )
+        )
 
     def load_windows(
         self,
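A construction sketch (hypothetical usage; the import path and the DatasetConfig contents are assumptions): the constructor either reads <path>/config.json as before, or accepts an already-validated DatasetConfig and skips the read entirely.

    from upath import UPath

    from rslearn.dataset import Dataset  # import path assumed

    # 1) Load configuration from <path>/config.json (previous behavior).
    ds = Dataset(UPath("/data/my_dataset"))

    # 2) Pass a DatasetConfig built at runtime; config.json need not exist on disk.
    # config = DatasetConfig.model_validate(config_dict)  # config_dict built elsewhere
    # ds = Dataset(UPath("/data/my_dataset"), dataset_config=config)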
rslearn/lightning_cli.py
CHANGED

@@ -21,6 +21,7 @@ from rslearn.log_utils import get_logger
 from rslearn.train.data_module import RslearnDataModule
 from rslearn.train.lightning_module import RslearnLightningModule
 from rslearn.utils.fsspec import open_atomic
+from rslearn.utils.jsonargparse import init_jsonargparse
 
 WANDB_ID_FNAME = "wandb_id"
 
@@ -390,8 +391,15 @@
 
         Sets the dataset path for any configured RslearnPredictionWriter callbacks.
         """
-        subcommand = self.config.subcommand
-        c = self.config[subcommand]
+        if not hasattr(self.config, "subcommand"):
+            logger.warning(
+                "Config does not have subcommand attribute, assuming we are in run=False mode"
+            )
+            subcommand = None
+            c = self.config
+        else:
+            subcommand = self.config.subcommand
+            c = self.config[subcommand]
 
         # If there is a RslearnPredictionWriter, set its path.
         prediction_writer_callback = None
@@ -415,16 +423,17 @@
         if subcommand == "predict":
             c.return_predictions = False
 
-        #
+        # Default to DDP with find_unused_parameters. Likely won't get called with unified config
         if subcommand == "fit":
-            c.trainer.strategy
-
-
-
-
-
-
-
+            if not c.trainer.strategy:
+                c.trainer.strategy = jsonargparse.Namespace(
+                    {
+                        "class_path": "lightning.pytorch.strategies.DDPStrategy",
+                        "init_args": jsonargparse.Namespace(
+                            {"find_unused_parameters": True}
+                        ),
+                    }
+                )
 
         if c.management_dir:
             self.enable_project_management(c.management_dir)
@@ -432,6 +441,8 @@
 
 def model_handler() -> None:
     """Handler for any rslearn model X commands."""
+    init_jsonargparse()
+
     RslearnLightningCLI(
         model_class=RslearnLightningModule,
         datamodule_class=RslearnDataModule,
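The default injected above is roughly equivalent to configuring Lightning directly as in the sketch below (not code from the package); supplying any trainer.strategy in the user config suppresses the default.

    from lightning.pytorch import Trainer
    from lightning.pytorch.strategies import DDPStrategy

    # What the injected jsonargparse Namespace resolves to after instantiation.
    trainer = Trainer(strategy=DDPStrategy(find_unused_parameters=True))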
rslearn/main.py
CHANGED

@@ -380,7 +380,7 @@ def apply_on_windows(
 
 def apply_on_windows_args(f: Callable[..., Any], args: argparse.Namespace) -> None:
     """Call apply_on_windows with arguments passed via command-line interface."""
-    dataset = Dataset(UPath(args.root), args.disabled_layers)
+    dataset = Dataset(UPath(args.root), disabled_layers=args.disabled_layers)
     apply_on_windows(
         f=f,
         dataset=dataset,
rslearn/models/anysat.py
CHANGED

@@ -4,6 +4,8 @@ This code loads the AnySat model from torch hub. See
 https://github.com/gastruc/AnySat for applicable license and copyright information.
 """
 
+from datetime import datetime
+
 import torch
 from einops import rearrange
 
@@ -53,7 +55,6 @@ class AnySat(FeatureExtractor):
         self,
         modalities: list[str],
         patch_size_meters: int,
-        dates: dict[str, list[int]],
         output: str = "patch",
         output_modality: str | None = None,
         hub_repo: str = "gastruc/anysat",
@@ -85,14 +86,6 @@ class AnySat(FeatureExtractor):
             if m not in MODALITY_RESOLUTIONS:
                 raise ValueError(f"Invalid modality: {m}")
 
-        if not all(m in TIME_SERIES_MODALITIES for m in dates.keys()):
-            raise ValueError("`dates` keys must be time-series modalities only.")
-        for m in modalities:
-            if m in TIME_SERIES_MODALITIES and m not in dates:
-                raise ValueError(
-                    f"Missing required dates for time-series modality '{m}'."
-                )
-
         if patch_size_meters % 10 != 0:
             raise ValueError(
                 "In AnySat, `patch_size` is in meters and must be a multiple of 10."
@@ -106,7 +99,6 @@ class AnySat(FeatureExtractor):
 
         self.modalities = modalities
         self.patch_size_meters = int(patch_size_meters)
-        self.dates = dates
         self.output = output
         self.output_modality = output_modality
 
@@ -119,6 +111,20 @@ class AnySat(FeatureExtractor):
         )
         self._embed_dim = 768  # base width, 'dense' returns 2x
 
+    @staticmethod
+    def time_ranges_to_doy(
+        time_ranges: list[tuple[datetime, datetime]],
+        device: torch.device,
+    ) -> torch.Tensor:
+        """Turn the time ranges stored in a RasterImage to timestamps accepted by AnySat.
+
+        AnySat uses the doy with each timestamp, so we take the midpoint
+        the time range. For some inputs (e.g. Sentinel 2) we take an image from a specific
+        time so that start_time == end_time == mid_time.
+        """
+        doys = [(t[0] + ((t[1] - t[0]) / 2)).timetuple().tm_yday for t in time_ranges]
+        return torch.tensor(doys, dtype=torch.int32, device=device)
+
     def forward(self, context: ModelContext) -> FeatureMaps:
         """Forward pass for the AnySat model.
 
@@ -139,17 +145,29 @@ class AnySat(FeatureExtractor):
                 raise ValueError(f"Modality '{modality}' not present in inputs.")
 
             cur = torch.stack(
-                [inp[modality] for inp in inputs], dim=0
-            )  # (B, C,
+                [inp[modality].image for inp in inputs], dim=0
+            )  # (B, C, T, H, W)
 
             if modality in TIME_SERIES_MODALITIES:
-
-
-                cur = rearrange(
-                    cur, "b (t c) h w -> b t c h w", t=num_dates, c=num_bands
-                )
+                num_bands = cur.shape[1]
+                cur = rearrange(cur, "b c t h w -> b t c h w")
                 H, W = cur.shape[-2], cur.shape[-1]
+
+                if inputs[0][modality].timestamps is None:
+                    raise ValueError(
+                        f"Require timestamps for time series modality {modality}"
+                    )
+                timestamps = torch.stack(
+                    [
+                        self.time_ranges_to_doy(inp[modality].timestamps, cur.device)  # type: ignore
+                        for inp in inputs
+                    ],
+                    dim=0,
+                )
+                batch[f"{modality}_dates"] = timestamps
             else:
+                # take the first (assumed only) timestep
+                cur = cur[:, :, 0]
                 num_bands = cur.shape[1]
                 H, W = cur.shape[-2], cur.shape[-1]
 
@@ -173,22 +191,6 @@ class AnySat(FeatureExtractor):
                 "All modalities must share the same spatial extent (H*res, W*res)."
             )
 
-        # Add *_dates
-        to_add = {}
-        for modality, x in list(batch.items()):
-            if modality in TIME_SERIES_MODALITIES:
-                B, T = x.shape[0], x.shape[1]
-                d = torch.as_tensor(
-                    self.dates[modality], dtype=torch.long, device=x.device
-                )
-                if d.ndim != 1 or d.numel() != T:
-                    raise ValueError(
-                        f"dates for '{modality}' must be 1D length {T}, got {tuple(d.shape)}"
-                    )
-                to_add[f"{modality}_dates"] = d.unsqueeze(0).repeat(B, 1)
-
-        batch.update(to_add)
-
         kwargs = {"patch_size": self.patch_size_meters, "output": self.output}
         if self.output == "dense":
             kwargs["output_modality"] = self.output_modality
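The added day-of-year conversion reduces to taking the midpoint of each (start, end) range; a plain-Python illustration (standard library only, no rslearn imports):

    from datetime import datetime

    start = datetime(2023, 6, 1)
    end = datetime(2023, 6, 11)
    mid = start + (end - start) / 2          # June 6, the midpoint of the range
    print(mid.timetuple().tm_yday)           # 157, the day-of-year fed to AnySat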
rslearn/models/attention_pooling.py
ADDED

@@ -0,0 +1,177 @@
+"""An attention pooling layer."""
+
+import math
+from typing import Any
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from torch import nn
+
+from rslearn.models.component import (
+    FeatureMaps,
+    IntermediateComponent,
+    TokenFeatureMaps,
+)
+from rslearn.train.model_context import ModelContext
+
+
+class SimpleAttentionPool(IntermediateComponent):
+    """Simple Attention Pooling.
+
+    Given a token feature map of shape BCHWN,
+    learn an attention layer which aggregates over
+    the N dimension.
+
+    This is done simply by learning a mapping D->1 which is the weight
+    which should be assigned to each token during averaging:
+
+    output = sum [feat_token * W(feat_token) for feat_token in feat_tokens]
+    """
+
+    def __init__(self, in_dim: int, hidden_linear: bool = False) -> None:
+        """Initialize the simple attention pooling layer.
+
+        Args:
+            in_dim: the encoding dimension D
+            hidden_linear: whether to apply an additional linear transformation D -> D
+                to the feat tokens. If this is True, a ReLU activation is applied
+                after the first linear transformation.
+        """
+        super().__init__()
+        if hidden_linear:
+            self.hidden_linear = nn.Linear(in_features=in_dim, out_features=in_dim)
+        else:
+            self.hidden_linear = None
+        self.linear = nn.Linear(in_features=in_dim, out_features=1)
+
+    def forward_for_map(self, feat_tokens: torch.Tensor) -> torch.Tensor:
+        """Attention pooling for a single feature map (BCHWN tensor)."""
+        B, D, H, W, N = feat_tokens.shape
+        feat_tokens = rearrange(feat_tokens, "b d h w n -> (b h w) n d")
+        if self.hidden_linear is not None:
+            feat_tokens = torch.nn.functional.relu(self.hidden_linear(feat_tokens))
+        attention_scores = torch.nn.functional.softmax(self.linear(feat_tokens), dim=1)
+        feat_tokens = (attention_scores * feat_tokens).sum(dim=1)
+        return rearrange(feat_tokens, "(b h w) d -> b d h w", b=B, h=H, w=W)
+
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+        """Forward pass for attention pooling linear probe.
+
+        Args:
+            intermediates: the output from the previous component, which must be a TokenFeatureMaps.
+                We pool over the final dimension in the TokenFeatureMaps. If multiple maps
+                are passed, we apply the same linear layers to all of them.
+            context: the model context.
+            feat_tokens (torch.Tensor): Input feature tokens of shape (B, C, H, W, N).
+
+        Returns:
+            torch.Tensor:
+                - output, attentioned pool over the last dimension (B, C, H, W)
+        """
+        if not isinstance(intermediates, TokenFeatureMaps):
+            raise ValueError("input to Attention Pool must be a TokenFeatureMaps")
+
+        features = []
+        for feat in intermediates.feature_maps:
+            features.append(self.forward_for_map(feat))
+        return FeatureMaps(features)
+
+
+class AttentionPool(IntermediateComponent):
+    """Attention Pooling.
+
+    Given a feature map of shape BCHWN,
+    learn an attention layer which aggregates over
+    the N dimension.
+
+    We do this by learning a query token, and applying a standard
+    attention mechanism against this learned query token.
+    """
+
+    def __init__(self, in_dim: int, num_heads: int, linear_on_kv: bool = True) -> None:
+        """Initialize the attention pooling layer.
+
+        Args:
+            in_dim: the encoding dimension D
+            num_heads: the number of heads to use
+            linear_on_kv: Whether to apply a linear layer on the input tokens
+                to create the key and value tokens.
+        """
+        super().__init__()
+        self.query_token: nn.Parameter = nn.Parameter(torch.empty(in_dim))
+        if linear_on_kv:
+            self.k_linear = nn.Linear(in_dim, in_dim)
+            self.v_linear = nn.Linear(in_dim, in_dim)
+        else:
+            self.k_linear = None
+            self.v_linear = None
+        if in_dim % num_heads != 0:
+            raise ValueError(
+                f"in_dim must be divisible by num_heads. Got {in_dim} and {num_heads}."
+            )
+        self.num_heads = num_heads
+        self.init_weights()
+
+    def init_weights(self) -> None:
+        """Initialize weights for the probe."""
+        nn.init.trunc_normal_(self.query_token, std=0.02)
+
+    def forward_for_map(self, feat_tokens: torch.Tensor) -> torch.Tensor:
+        """Attention pooling for a single feature map (BCHWN tensor)."""
+        B, D, H, W, N = feat_tokens.shape
+        feat_tokens = rearrange(feat_tokens, "b d h w n -> (b h w) n d")
+        collapsed_dim = B * H * W
+        q = self.query_token.expand(collapsed_dim, 1, -1)
+        q = q.reshape(
+            collapsed_dim, 1, self.num_heads, D // self.num_heads
+        )  # [B, 1, head, D_head]
+        q = rearrange(q, "b h n d -> b n h d")
+        if self.k_linear is not None:
+            assert self.v_linear is not None
+            k = self.k_linear(feat_tokens).reshape(
+                collapsed_dim, N, self.num_heads, D // self.num_heads
+            )
+            v = self.v_linear(feat_tokens).reshape(
+                collapsed_dim, N, self.num_heads, D // self.num_heads
+            )
+        else:
+            k = feat_tokens.reshape(
+                collapsed_dim, N, self.num_heads, D // self.num_heads
+            )
+            v = feat_tokens.reshape(
+                collapsed_dim, N, self.num_heads, D // self.num_heads
+            )
+        k = rearrange(k, "b n h d -> b h n d")
+        v = rearrange(v, "b n h d -> b h n d")
+
+        # Compute attention scores
+        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(
+            D // self.num_heads
+        )
+        attn_weights = F.softmax(attn_scores, dim=-1)
+        x = torch.matmul(attn_weights, v)  # [B, head, 1, D_head]
+        return x.reshape(B, D, H, W)
+
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+        """Forward pass for attention pooling linear probe.
+
+        Args:
+            intermediates: the output from the previous component, which must be a TokenFeatureMaps.
+                We pool over the final dimension in the TokenFeatureMaps. If multiple feature
+                maps are passed, we apply the same attention weights (query token and linear k, v layers)
+                to all the maps.
+            context: the model context.
+            feat_tokens (torch.Tensor): Input feature tokens of shape (B, C, H, W, N).
+
+        Returns:
+            torch.Tensor:
+                - output, attentioned pool over the last dimension (B, C, H, W)
+        """
+        if not isinstance(intermediates, TokenFeatureMaps):
+            raise ValueError("input to Attention Pool must be a TokenFeatureMaps")
+
+        features = []
+        for feat in intermediates.feature_maps:
+            features.append(self.forward_for_map(feat))
+        return FeatureMaps(features)
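A shape-level usage sketch for the new module (assumes rslearn 0.0.20 is installed; it calls forward_for_map directly to avoid constructing a full ModelContext):

    import torch

    from rslearn.models.attention_pooling import SimpleAttentionPool

    pool = SimpleAttentionPool(in_dim=768)
    tokens = torch.randn(2, 768, 8, 8, 4)   # (B, C, H, W, N)
    pooled = pool.forward_for_map(tokens)
    print(pooled.shape)                     # torch.Size([2, 768, 8, 8]); the N dimension is pooled away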
rslearn/models/clip.py
CHANGED

@@ -43,9 +43,12 @@ class CLIP(FeatureExtractor):
             a FeatureMaps with one feature map from the ViT, which is always Bx24x24x1024.
         """
         inputs = context.inputs
-        device = inputs[0]["image"].device
+        device = inputs[0]["image"].image.device
         clip_inputs = self.processor(
-            images=[
+            images=[
+                inp["image"].single_ts_to_chw_tensor().cpu().numpy().transpose(1, 2, 0)
+                for inp in inputs
+            ],
             return_tensors="pt",
             padding=True,
         )
rslearn/models/component.py
CHANGED

@@ -91,6 +91,18 @@
     feature_maps: list[torch.Tensor]
 
 
+@dataclass
+class TokenFeatureMaps:
+    """An intermediate output type for multi-resolution BCHWN feature maps with a token dimension.
+
+    Unlike `FeatureMaps`, these include an additional dimension for unpooled tokens.
+    """
+
+    # List of BxCxHxWxN feature maps at different scales, ordered from highest resolution
+    # (most fine-grained) to lowest resolution (coarsest).
+    feature_maps: list[torch.Tensor]
+
+
 @dataclass
 class FeatureVector:
     """An intermediate output type for a flat feature vector."""
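A small sketch of the new container next to the existing one (assumes rslearn 0.0.20; names taken from the diff above): TokenFeatureMaps carries BCHWN tensors, which components such as the attention pooling layers reduce to the BCHW FeatureMaps expected downstream.

    import torch

    from rslearn.models.component import FeatureMaps, TokenFeatureMaps

    token_maps = TokenFeatureMaps(feature_maps=[torch.randn(1, 64, 32, 32, 6)])  # BxCxHxWxN
    plain_maps = FeatureMaps(feature_maps=[torch.randn(1, 64, 32, 32)])          # BxCxHxW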
rslearn/models/croma.py
CHANGED

@@ -175,10 +175,16 @@ class Croma(FeatureExtractor):
         sentinel1: torch.Tensor | None = None
         sentinel2: torch.Tensor | None = None
         if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL1]:
-            sentinel1 = torch.stack(
+            sentinel1 = torch.stack(
+                [inp["sentinel1"].single_ts_to_chw_tensor() for inp in context.inputs],
+                dim=0,
+            )
             sentinel1 = self._resize_image(sentinel1) if self.do_resizing else sentinel1
         if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL2]:
-            sentinel2 = torch.stack(
+            sentinel2 = torch.stack(
+                [inp["sentinel2"].single_ts_to_chw_tensor() for inp in context.inputs],
+                dim=0,
+            )
             sentinel2 = self._resize_image(sentinel2) if self.do_resizing else sentinel2
 
         outputs = self.model(
@@ -294,5 +300,7 @@ class CromaNormalize(Transform):
         for modality in MODALITY_BANDS.keys():
             if modality not in input_dict:
                 continue
-            input_dict[modality] = self.apply_image(
+            input_dict[modality].image = self.apply_image(
+                input_dict[modality].image, modality
+            )
         return input_dict, target_dict
rslearn/models/dinov3.py
CHANGED

@@ -104,7 +104,8 @@ class DinoV3(FeatureExtractor):
             a FeatureMaps with one feature map.
         """
         cur = torch.stack(
-            [inp["image"] for inp in context.inputs],
+            [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs],
+            dim=0,
         )  # (B, C, H, W)
 
         if self.do_resizing and (
rslearn/models/faster_rcnn.py
CHANGED

@@ -210,7 +210,8 @@ class FasterRCNN(Predictor):
             ),
         )
 
-
+        # take the first (and assumed to be only) timestep
+        image_list = [inp["image"].image[:, 0] for inp in context.inputs]
         images, targets = self.noop_transform(image_list, targets)
 
         feature_dict = collections.OrderedDict()