rslearn 0.0.17__py3-none-any.whl → 0.0.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/__init__.py +2 -0
- rslearn/config/dataset.py +49 -4
- rslearn/dataset/add_windows.py +1 -1
- rslearn/dataset/dataset.py +9 -65
- rslearn/dataset/materialize.py +5 -5
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +26 -80
- rslearn/main.py +11 -36
- rslearn/models/anysat.py +11 -9
- rslearn/models/clay/clay.py +8 -9
- rslearn/models/clip.py +18 -15
- rslearn/models/component.py +99 -0
- rslearn/models/concatenate_features.py +21 -11
- rslearn/models/conv.py +15 -8
- rslearn/models/croma.py +13 -8
- rslearn/models/detr/detr.py +25 -14
- rslearn/models/dinov3.py +11 -6
- rslearn/models/faster_rcnn.py +19 -9
- rslearn/models/feature_center_crop.py +12 -9
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/galileo.py +23 -18
- rslearn/models/module_wrapper.py +26 -57
- rslearn/models/molmo.py +16 -14
- rslearn/models/multitask.py +102 -73
- rslearn/models/olmoearth_pretrain/model.py +18 -12
- rslearn/models/panopticon.py +8 -7
- rslearn/models/pick_features.py +18 -24
- rslearn/models/pooling_decoder.py +22 -14
- rslearn/models/presto/presto.py +16 -10
- rslearn/models/presto/single_file_presto.py +4 -10
- rslearn/models/prithvi.py +12 -8
- rslearn/models/resize_features.py +21 -7
- rslearn/models/sam2_enc.py +11 -9
- rslearn/models/satlaspretrain.py +15 -9
- rslearn/models/simple_time_series.py +31 -17
- rslearn/models/singletask.py +24 -17
- rslearn/models/ssl4eo_s12.py +15 -10
- rslearn/models/swin.py +22 -13
- rslearn/models/terramind.py +24 -7
- rslearn/models/trunk.py +6 -3
- rslearn/models/unet.py +18 -9
- rslearn/models/upsample.py +22 -9
- rslearn/train/all_patches_dataset.py +22 -18
- rslearn/train/dataset.py +69 -54
- rslearn/train/lightning_module.py +51 -32
- rslearn/train/model_context.py +54 -0
- rslearn/train/prediction_writer.py +111 -41
- rslearn/train/tasks/classification.py +34 -15
- rslearn/train/tasks/detection.py +24 -31
- rslearn/train/tasks/embedding.py +33 -29
- rslearn/train/tasks/multi_task.py +7 -7
- rslearn/train/tasks/per_pixel_regression.py +41 -19
- rslearn/train/tasks/regression.py +38 -21
- rslearn/train/tasks/segmentation.py +33 -15
- rslearn/train/tasks/task.py +3 -2
- {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/METADATA +1 -1
- {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/RECORD +64 -61
- rslearn/dataset/index.py +0 -173
- rslearn/models/registry.py +0 -22
- {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/WHEEL +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/top_level.txt +0 -0
rslearn/dataset/window.py
CHANGED
|
@@ -1,20 +1,16 @@
|
|
|
1
1
|
"""rslearn windows."""
|
|
2
2
|
|
|
3
|
-
import json
|
|
4
3
|
from datetime import datetime
|
|
5
|
-
from typing import TYPE_CHECKING, Any
|
|
4
|
+
from typing import Any
|
|
6
5
|
|
|
7
6
|
import shapely
|
|
8
7
|
from upath import UPath
|
|
9
8
|
|
|
9
|
+
from rslearn.dataset.storage.storage import WindowStorage
|
|
10
10
|
from rslearn.log_utils import get_logger
|
|
11
11
|
from rslearn.utils import Projection, STGeometry
|
|
12
|
-
from rslearn.utils.fsspec import open_atomic
|
|
13
12
|
from rslearn.utils.raster_format import get_bandset_dirname
|
|
14
13
|
|
|
15
|
-
if TYPE_CHECKING:
|
|
16
|
-
from .index import DatasetIndex
|
|
17
|
-
|
|
18
14
|
logger = get_logger(__name__)
|
|
19
15
|
|
|
20
16
|
LAYERS_DIRECTORY_NAME = "layers"
|
|
@@ -138,14 +134,13 @@ class Window:
|
|
|
138
134
|
|
|
139
135
|
def __init__(
|
|
140
136
|
self,
|
|
141
|
-
|
|
137
|
+
storage: WindowStorage,
|
|
142
138
|
group: str,
|
|
143
139
|
name: str,
|
|
144
140
|
projection: Projection,
|
|
145
141
|
bounds: tuple[int, int, int, int],
|
|
146
142
|
time_range: tuple[datetime, datetime] | None,
|
|
147
143
|
options: dict[str, Any] = {},
|
|
148
|
-
index: "DatasetIndex | None" = None,
|
|
149
144
|
) -> None:
|
|
150
145
|
"""Creates a new Window instance.
|
|
151
146
|
|
|
@@ -153,23 +148,21 @@ class Window:
|
|
|
153
148
|
stored in metadata.json.
|
|
154
149
|
|
|
155
150
|
Args:
|
|
156
|
-
|
|
151
|
+
storage: the dataset storage for the underlying rslearn dataset.
|
|
157
152
|
group: the group the window belongs to
|
|
158
153
|
name: the unique name for this window
|
|
159
154
|
projection: the projection of the window
|
|
160
155
|
bounds: the bounds of the window in pixel coordinates
|
|
161
156
|
time_range: optional time range of the window
|
|
162
157
|
options: additional options (?)
|
|
163
|
-
index: DatasetIndex if it is available
|
|
164
158
|
"""
|
|
165
|
-
self.path = path
|
|
159
|
+
self.storage = storage
|
|
166
160
|
self.group = group
|
|
167
161
|
self.name = name
|
|
168
162
|
self.projection = projection
|
|
169
163
|
self.bounds = bounds
|
|
170
164
|
self.time_range = time_range
|
|
171
165
|
self.options = options
|
|
172
|
-
self.index = index
|
|
173
166
|
|
|
174
167
|
def get_geometry(self) -> STGeometry:
|
|
175
168
|
"""Computes the STGeometry corresponding to this window."""
|
|
@@ -181,29 +174,11 @@ class Window:
|
|
|
181
174
|
|
|
182
175
|
def load_layer_datas(self) -> dict[str, WindowLayerData]:
|
|
183
176
|
"""Load layer datas describing items in retrieved layers from items.json."""
|
|
184
|
-
|
|
185
|
-
if self.index is not None:
|
|
186
|
-
layer_datas = self.index.layer_datas.get(self.name, [])
|
|
187
|
-
|
|
188
|
-
else:
|
|
189
|
-
items_fname = self.path / "items.json"
|
|
190
|
-
if not items_fname.exists():
|
|
191
|
-
return {}
|
|
192
|
-
with items_fname.open("r") as f:
|
|
193
|
-
layer_datas = [
|
|
194
|
-
WindowLayerData.deserialize(layer_data)
|
|
195
|
-
for layer_data in json.load(f)
|
|
196
|
-
]
|
|
197
|
-
|
|
198
|
-
return {layer_data.layer_name: layer_data for layer_data in layer_datas}
|
|
177
|
+
return self.storage.get_layer_datas(self.group, self.name)
|
|
199
178
|
|
|
200
179
|
def save_layer_datas(self, layer_datas: dict[str, WindowLayerData]) -> None:
|
|
201
180
|
"""Save layer datas to items.json."""
|
|
202
|
-
|
|
203
|
-
items_fname = self.path / "items.json"
|
|
204
|
-
logger.info(f"Saving window items to {items_fname}")
|
|
205
|
-
with open_atomic(items_fname, "w") as f:
|
|
206
|
-
json.dump(json_data, f)
|
|
181
|
+
self.storage.save_layer_datas(self.group, self.name, layer_datas)
|
|
207
182
|
|
|
208
183
|
def list_completed_layers(self) -> list[tuple[str, int]]:
|
|
209
184
|
"""List the layers available for this window that are completed.
|
|
@@ -211,18 +186,7 @@ class Window:
|
|
|
211
186
|
Returns:
|
|
212
187
|
a list of (layer_name, group_idx) completed layers.
|
|
213
188
|
"""
|
|
214
|
-
|
|
215
|
-
if not layers_directory.exists():
|
|
216
|
-
return []
|
|
217
|
-
|
|
218
|
-
completed_layers = []
|
|
219
|
-
for layer_dir in layers_directory.iterdir():
|
|
220
|
-
layer_name, group_idx = get_layer_and_group_from_dir_name(layer_dir.name)
|
|
221
|
-
if not self.is_layer_completed(layer_name, group_idx):
|
|
222
|
-
continue
|
|
223
|
-
completed_layers.append((layer_name, group_idx))
|
|
224
|
-
|
|
225
|
-
return completed_layers
|
|
189
|
+
return self.storage.list_completed_layers(self.group, self.name)
|
|
226
190
|
|
|
227
191
|
def get_layer_dir(self, layer_name: str, group_idx: int = 0) -> UPath:
|
|
228
192
|
"""Get the directory containing materialized data for the specified layer.
|
|
@@ -235,7 +199,9 @@ class Window:
|
|
|
235
199
|
Returns:
|
|
236
200
|
the path where data is or should be materialized.
|
|
237
201
|
"""
|
|
238
|
-
return get_window_layer_dir(self.path, layer_name, group_idx)
|
|
202
|
+
return get_window_layer_dir(
|
|
203
|
+
self.storage.get_window_root(self.group, self.name), layer_name, group_idx
|
|
204
|
+
)
|
|
239
205
|
|
|
240
206
|
def is_layer_completed(self, layer_name: str, group_idx: int = 0) -> bool:
|
|
241
207
|
"""Check whether the specified layer is completed.
|
|
@@ -250,14 +216,9 @@ class Window:
|
|
|
250
216
|
Returns:
|
|
251
217
|
whether the layer is completed
|
|
252
218
|
"""
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
self.name, []
|
|
257
|
-
)
|
|
258
|
-
|
|
259
|
-
layer_dir = self.get_layer_dir(layer_name, group_idx)
|
|
260
|
-
return (layer_dir / "completed").exists()
|
|
219
|
+
return self.storage.is_layer_completed(
|
|
220
|
+
self.group, self.name, layer_name, group_idx
|
|
221
|
+
)
|
|
261
222
|
|
|
262
223
|
def mark_layer_completed(self, layer_name: str, group_idx: int = 0) -> None:
|
|
263
224
|
"""Mark the specified layer completed.
|
|
@@ -272,8 +233,7 @@ class Window:
|
|
|
272
233
|
layer_name: the layer name.
|
|
273
234
|
group_idx: the index of the group within the layer.
|
|
274
235
|
"""
|
|
275
|
-
|
|
276
|
-
(layer_dir / "completed").touch()
|
|
236
|
+
self.storage.mark_layer_completed(self.group, self.name, layer_name, group_idx)
|
|
277
237
|
|
|
278
238
|
def get_raster_dir(
|
|
279
239
|
self, layer_name: str, bands: list[str], group_idx: int = 0
|
|
@@ -289,7 +249,12 @@ class Window:
|
|
|
289
249
|
Returns:
|
|
290
250
|
the directory containing the raster.
|
|
291
251
|
"""
|
|
292
|
-
return get_window_raster_dir(self.path, layer_name, bands, group_idx)
|
|
252
|
+
return get_window_raster_dir(
|
|
253
|
+
self.storage.get_window_root(self.group, self.name),
|
|
254
|
+
layer_name,
|
|
255
|
+
bands,
|
|
256
|
+
group_idx,
|
|
257
|
+
)
|
|
293
258
|
|
|
294
259
|
def get_metadata(self) -> dict[str, Any]:
|
|
295
260
|
"""Returns the window metadata dictionary."""
|
|
@@ -308,18 +273,14 @@ class Window:
|
|
|
308
273
|
|
|
309
274
|
def save(self) -> None:
|
|
310
275
|
"""Save the window metadata to its root directory."""
|
|
311
|
-
self.
|
|
312
|
-
metadata_path = self.path / "metadata.json"
|
|
313
|
-
logger.debug(f"Saving window metadata to {metadata_path}")
|
|
314
|
-
with open_atomic(metadata_path, "w") as f:
|
|
315
|
-
json.dump(self.get_metadata(), f)
|
|
276
|
+
self.storage.create_or_update_window(self)
|
|
316
277
|
|
|
317
278
|
@staticmethod
|
|
318
|
-
def from_metadata(path: UPath, metadata: dict[str, Any]) -> "Window":
|
|
319
|
-
"""Create a Window from
|
|
279
|
+
def from_metadata(storage: WindowStorage, metadata: dict[str, Any]) -> "Window":
|
|
280
|
+
"""Create a Window from the WindowStorage and the window's metadata dictionary.
|
|
320
281
|
|
|
321
282
|
Args:
|
|
322
|
-
|
|
283
|
+
storage: the WindowStorage for the underlying dataset.
|
|
323
284
|
metadata: the window metadata.
|
|
324
285
|
|
|
325
286
|
Returns:
|
|
@@ -334,7 +295,7 @@ class Window:
|
|
|
334
295
|
)
|
|
335
296
|
|
|
336
297
|
return Window(
|
|
337
|
-
|
|
298
|
+
storage=storage,
|
|
338
299
|
group=metadata["group"],
|
|
339
300
|
name=metadata["name"],
|
|
340
301
|
projection=Projection.deserialize(metadata["projection"]),
|
|
@@ -350,21 +311,6 @@ class Window:
|
|
|
350
311
|
options=metadata["options"],
|
|
351
312
|
)
|
|
352
313
|
|
|
353
|
-
@staticmethod
|
|
354
|
-
def load(path: UPath) -> "Window":
|
|
355
|
-
"""Load a Window from a UPath.
|
|
356
|
-
|
|
357
|
-
Args:
|
|
358
|
-
path: the root directory of the window
|
|
359
|
-
|
|
360
|
-
Returns:
|
|
361
|
-
the Window
|
|
362
|
-
"""
|
|
363
|
-
metadata_fname = path / "metadata.json"
|
|
364
|
-
with metadata_fname.open("r") as f:
|
|
365
|
-
metadata = json.load(f)
|
|
366
|
-
return Window.from_metadata(path, metadata)
|
|
367
|
-
|
|
368
314
|
@staticmethod
|
|
369
315
|
def get_window_root(ds_path: UPath, group: str, name: str) -> UPath:
|
|
370
316
|
"""Gets the root directory of a window.
|
rslearn/main.py
CHANGED
|
@@ -27,13 +27,13 @@ from rslearn.dataset.handler_summaries import (
|
|
|
27
27
|
PrepareDatasetWindowsSummary,
|
|
28
28
|
UnknownIngestCounts,
|
|
29
29
|
)
|
|
30
|
-
from rslearn.dataset.index import DatasetIndex
|
|
31
30
|
from rslearn.dataset.manage import (
|
|
32
31
|
AttemptsCounter,
|
|
33
32
|
materialize_dataset_windows,
|
|
34
33
|
prepare_dataset_windows,
|
|
35
34
|
retry,
|
|
36
35
|
)
|
|
36
|
+
from rslearn.dataset.storage.file import FileWindowStorage
|
|
37
37
|
from rslearn.log_utils import get_logger
|
|
38
38
|
from rslearn.tile_stores import get_tile_store_with_layer
|
|
39
39
|
from rslearn.utils import Projection, STGeometry
|
|
@@ -315,7 +315,8 @@ def apply_on_windows(
|
|
|
315
315
|
load_workers: optional different number of workers to use for loading the
|
|
316
316
|
windows. If set, workers controls the number of workers to process the
|
|
317
317
|
jobs, while load_workers controls the number of workers to use for reading
|
|
318
|
-
windows from the rslearn dataset.
|
|
318
|
+
windows from the rslearn dataset. Workers is only passed if the window
|
|
319
|
+
storage is FileWindowStorage.
|
|
319
320
|
batch_size: if workers > 0, the maximum number of windows to pass to the
|
|
320
321
|
function.
|
|
321
322
|
jobs_per_process: optional, terminate processes after they have handled this
|
|
@@ -336,11 +337,14 @@ def apply_on_windows(
|
|
|
336
337
|
else:
|
|
337
338
|
groups = group
|
|
338
339
|
|
|
339
|
-
if
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
340
|
+
# Load the windows. We pass workers and show_progress if it is FileWindowStorage.
|
|
341
|
+
kwargs: dict[str, Any] = {}
|
|
342
|
+
if isinstance(dataset.storage, FileWindowStorage):
|
|
343
|
+
if load_workers is None:
|
|
344
|
+
load_workers = workers
|
|
345
|
+
kwargs["workers"] = load_workers
|
|
346
|
+
kwargs["show_progress"] = True
|
|
347
|
+
windows = dataset.load_windows(groups=groups, names=names, **kwargs)
|
|
344
348
|
logger.info(f"found {len(windows)} windows")
|
|
345
349
|
|
|
346
350
|
if hasattr(f, "get_jobs"):
|
|
@@ -798,35 +802,6 @@ def dataset_materialize() -> None:
|
|
|
798
802
|
apply_on_windows_args(fn, args)
|
|
799
803
|
|
|
800
804
|
|
|
801
|
-
@register_handler("dataset", "build_index")
|
|
802
|
-
def dataset_build_index() -> None:
|
|
803
|
-
"""Handler for the rslearn dataset build_index command."""
|
|
804
|
-
parser = argparse.ArgumentParser(
|
|
805
|
-
prog="rslearn dataset build_index",
|
|
806
|
-
description=("rslearn dataset build_index: " + "create a dataset index file"),
|
|
807
|
-
)
|
|
808
|
-
parser.add_argument(
|
|
809
|
-
"--root",
|
|
810
|
-
type=str,
|
|
811
|
-
required=True,
|
|
812
|
-
help="Dataset path",
|
|
813
|
-
)
|
|
814
|
-
parser.add_argument(
|
|
815
|
-
"--workers",
|
|
816
|
-
type=int,
|
|
817
|
-
default=16,
|
|
818
|
-
help="Number of workers",
|
|
819
|
-
)
|
|
820
|
-
args = parser.parse_args(args=sys.argv[3:])
|
|
821
|
-
ds_path = UPath(args.root)
|
|
822
|
-
dataset = Dataset(ds_path)
|
|
823
|
-
index = DatasetIndex.build_index(
|
|
824
|
-
dataset=dataset,
|
|
825
|
-
workers=args.workers,
|
|
826
|
-
)
|
|
827
|
-
index.save_index(ds_path)
|
|
828
|
-
|
|
829
|
-
|
|
830
805
|
@register_handler("model", "fit")
|
|
831
806
|
def model_fit() -> None:
|
|
832
807
|
"""Handler for rslearn model fit."""
|
rslearn/models/anysat.py
CHANGED
|
@@ -4,11 +4,13 @@ This code loads the AnySat model from torch hub. See
|
|
|
4
4
|
https://github.com/gastruc/AnySat for applicable license and copyright information.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
-
from typing import Any
|
|
8
|
-
|
|
9
7
|
import torch
|
|
10
8
|
from einops import rearrange
|
|
11
9
|
|
|
10
|
+
from rslearn.train.model_context import ModelContext
|
|
11
|
+
|
|
12
|
+
from .component import FeatureExtractor, FeatureMaps
|
|
13
|
+
|
|
12
14
|
# AnySat github: https://github.com/gastruc/AnySat
|
|
13
15
|
# Modalities and expected resolutions (meters)
|
|
14
16
|
MODALITY_RESOLUTIONS: dict[str, float] = {
|
|
@@ -44,7 +46,7 @@ MODALITY_BANDS: dict[str, list[str]] = {
|
|
|
44
46
|
TIME_SERIES_MODALITIES = {"s2", "s1-asc", "s1", "alos", "l7", "l8", "modis"}
|
|
45
47
|
|
|
46
48
|
|
|
47
|
-
class AnySat(torch.nn.Module):
|
|
49
|
+
class AnySat(FeatureExtractor):
|
|
48
50
|
"""AnySat backbone (outputs one feature map)."""
|
|
49
51
|
|
|
50
52
|
def __init__(
|
|
@@ -117,17 +119,17 @@ class AnySat(torch.nn.Module):
|
|
|
117
119
|
)
|
|
118
120
|
self._embed_dim = 768 # base width, 'dense' returns 2x
|
|
119
121
|
|
|
120
|
-
def forward(self,
|
|
122
|
+
def forward(self, context: ModelContext) -> FeatureMaps:
|
|
121
123
|
"""Forward pass for the AnySat model.
|
|
122
124
|
|
|
123
125
|
Args:
|
|
124
|
-
|
|
126
|
+
context: the model context. Input dicts must include modalities as keys
|
|
127
|
+
which are defined in the self.modalities list
|
|
125
128
|
|
|
126
129
|
Returns:
|
|
127
|
-
|
|
130
|
+
a FeatureMaps with one feature map at the configured patch size.
|
|
128
131
|
"""
|
|
129
|
-
|
|
130
|
-
raise ValueError("empty inputs")
|
|
132
|
+
inputs = context.inputs
|
|
131
133
|
|
|
132
134
|
batch: dict[str, torch.Tensor] = {}
|
|
133
135
|
spatial_extent: tuple[float, float] | None = None
|
|
@@ -192,7 +194,7 @@ class AnySat(torch.nn.Module):
|
|
|
192
194
|
kwargs["output_modality"] = self.output_modality
|
|
193
195
|
|
|
194
196
|
features = self.model(batch, **kwargs)
|
|
195
|
-
return [rearrange(features, "b h w d -> b d h w")]
|
|
197
|
+
return FeatureMaps([rearrange(features, "b h w d -> b d h w")])
|
|
196
198
|
|
|
197
199
|
def get_backbone_channels(self) -> list:
|
|
198
200
|
"""Returns the output channels of this model when used as a backbone.
|
rslearn/models/clay/clay.py
CHANGED
|
@@ -16,6 +16,8 @@ from huggingface_hub import hf_hub_download
|
|
|
16
16
|
# from claymodel.module import ClayMAEModule
|
|
17
17
|
from terratorch.models.backbones.clay_v15.module import ClayMAEModule
|
|
18
18
|
|
|
19
|
+
from rslearn.models.component import FeatureExtractor, FeatureMaps
|
|
20
|
+
from rslearn.train.model_context import ModelContext
|
|
19
21
|
from rslearn.train.transforms.normalize import Normalize
|
|
20
22
|
from rslearn.train.transforms.transform import Transform
|
|
21
23
|
|
|
@@ -42,7 +44,7 @@ def get_clay_checkpoint_path(
|
|
|
42
44
|
return hf_hub_download(repo_id=repo_id, filename=filename) # nosec B615
|
|
43
45
|
|
|
44
46
|
|
|
45
|
-
class Clay(torch.nn.Module):
|
|
47
|
+
class Clay(FeatureExtractor):
|
|
46
48
|
"""Clay backbones."""
|
|
47
49
|
|
|
48
50
|
def __init__(
|
|
@@ -108,23 +110,20 @@ class Clay(torch.nn.Module):
|
|
|
108
110
|
image, size=(new_hw, new_hw), mode="bilinear", align_corners=False
|
|
109
111
|
)
|
|
110
112
|
|
|
111
|
-
def forward(self,
|
|
113
|
+
def forward(self, context: ModelContext) -> FeatureMaps:
|
|
112
114
|
"""Forward pass for the Clay model.
|
|
113
115
|
|
|
114
116
|
Args:
|
|
115
|
-
|
|
117
|
+
context: the model context. Input dicts must include `self.modality` as a key
|
|
116
118
|
|
|
117
119
|
Returns:
|
|
118
|
-
|
|
120
|
+
a FeatureMaps consisting of one feature map, computed by Clay.
|
|
119
121
|
"""
|
|
120
|
-
if self.modality not in inputs[0]:
|
|
121
|
-
raise ValueError(f"Missing modality {self.modality} in inputs.")
|
|
122
|
-
|
|
123
122
|
param = next(self.model.parameters())
|
|
124
123
|
device = param.device
|
|
125
124
|
|
|
126
125
|
chips = torch.stack(
|
|
127
|
-
[inp[self.modality] for inp in inputs], dim=0
|
|
126
|
+
[inp[self.modality] for inp in context.inputs], dim=0
|
|
128
127
|
) # (B, C, H, W)
|
|
129
128
|
if self.do_resizing:
|
|
130
129
|
chips = self._resize_image(chips, chips.shape[2])
|
|
@@ -163,7 +162,7 @@ class Clay(torch.nn.Module):
|
|
|
163
162
|
)
|
|
164
163
|
|
|
165
164
|
features = rearrange(spatial, "b (h w) d -> b d h w", h=side, w=side)
|
|
166
|
-
return [features]
|
|
165
|
+
return FeatureMaps([features])
|
|
167
166
|
|
|
168
167
|
def get_backbone_channels(self) -> list:
|
|
169
168
|
"""Return output channels of this model when used as a backbone."""
|
rslearn/models/clip.py
CHANGED
|
@@ -1,12 +1,13 @@
|
|
|
1
1
|
"""OpenAI CLIP models."""
|
|
2
2
|
|
|
3
|
-
from typing import Any
|
|
4
|
-
|
|
5
|
-
import torch
|
|
6
3
|
from transformers import AutoModelForZeroShotImageClassification, AutoProcessor
|
|
7
4
|
|
|
5
|
+
from rslearn.train.model_context import ModelContext
|
|
6
|
+
|
|
7
|
+
from .component import FeatureExtractor, FeatureMaps
|
|
8
|
+
|
|
8
9
|
|
|
9
|
-
class CLIP(torch.nn.Module):
|
|
10
|
+
class CLIP(FeatureExtractor):
|
|
10
11
|
"""CLIP image encoder."""
|
|
11
12
|
|
|
12
13
|
def __init__(
|
|
@@ -31,17 +32,17 @@ class CLIP(torch.nn.Module):
|
|
|
31
32
|
self.height = crop_size["height"] // stride[0]
|
|
32
33
|
self.width = crop_size["width"] // stride[1]
|
|
33
34
|
|
|
34
|
-
def forward(self,
|
|
35
|
+
def forward(self, context: ModelContext) -> FeatureMaps:
|
|
35
36
|
"""Compute outputs from the backbone.
|
|
36
37
|
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
process. The images should have values 0-255.
|
|
38
|
+
Args:
|
|
39
|
+
context: the model context. Input dicts must include "image" key containing
|
|
40
|
+
the image to process. The images should have values 0-255.
|
|
40
41
|
|
|
41
42
|
Returns:
|
|
42
|
-
|
|
43
|
-
contains a single Bx24x24x1024 feature map.
|
|
43
|
+
a FeatureMaps with one feature map from the ViT, which is always Bx24x24x1024.
|
|
44
44
|
"""
|
|
45
|
+
inputs = context.inputs
|
|
45
46
|
device = inputs[0]["image"].device
|
|
46
47
|
clip_inputs = self.processor(
|
|
47
48
|
images=[inp["image"].cpu().numpy().transpose(1, 2, 0) for inp in inputs],
|
|
@@ -55,8 +56,10 @@ class CLIP(torch.nn.Module):
|
|
|
55
56
|
batch_size = image_features.shape[0]
|
|
56
57
|
|
|
57
58
|
# 576x1024 -> HxWxC
|
|
58
|
-
return
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
59
|
+
return FeatureMaps(
|
|
60
|
+
[
|
|
61
|
+
image_features.reshape(
|
|
62
|
+
batch_size, self.height, self.width, self.num_features
|
|
63
|
+
).permute(0, 3, 1, 2)
|
|
64
|
+
]
|
|
65
|
+
)
|
rslearn/models/component.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
"""Model component API."""
|
|
2
|
+
|
|
3
|
+
import abc
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
from rslearn.train.model_context import ModelContext, ModelOutput
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class FeatureExtractor(torch.nn.Module, abc.ABC):
|
|
13
|
+
"""A feature extractor that performs initial processing of the inputs.
|
|
14
|
+
|
|
15
|
+
The FeatureExtractor is the first component in the encoders list for
|
|
16
|
+
SingleTaskModel and MultiTaskModel.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
@abc.abstractmethod
|
|
20
|
+
def forward(self, context: ModelContext) -> Any:
|
|
21
|
+
"""Extract an initial intermediate from the model context.
|
|
22
|
+
|
|
23
|
+
Args:
|
|
24
|
+
context: the model context.
|
|
25
|
+
|
|
26
|
+
Returns:
|
|
27
|
+
any intermediate to pass to downstream components. Oftentimes this is a
|
|
28
|
+
FeatureMaps.
|
|
29
|
+
"""
|
|
30
|
+
raise NotImplementedError
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class IntermediateComponent(torch.nn.Module, abc.ABC):
|
|
34
|
+
"""An intermediate component in the model.
|
|
35
|
+
|
|
36
|
+
In SingleTaskModel and MultiTaskModel, modules after the first module
|
|
37
|
+
in the encoders list are IntermediateComponents, as are modules before the last
|
|
38
|
+
module in the decoders list(s).
|
|
39
|
+
"""
|
|
40
|
+
|
|
41
|
+
@abc.abstractmethod
|
|
42
|
+
def forward(self, intermediates: Any, context: ModelContext) -> Any:
|
|
43
|
+
"""Process the given intermediate into another intermediate.
|
|
44
|
+
|
|
45
|
+
Args:
|
|
46
|
+
intermediates: the output from the previous component (either a
|
|
47
|
+
FeatureExtractor or another IntermediateComponent).
|
|
48
|
+
context: the model context.
|
|
49
|
+
|
|
50
|
+
Returns:
|
|
51
|
+
any intermediate to pass to downstream components.
|
|
52
|
+
"""
|
|
53
|
+
raise NotImplementedError
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class Predictor(torch.nn.Module, abc.ABC):
|
|
57
|
+
"""A predictor that computes task-specific outputs and a loss dict.
|
|
58
|
+
|
|
59
|
+
In SingleTaskModel and MultiTaskModel, the last module(s) in the decoders list(s)
|
|
60
|
+
are Predictors.
|
|
61
|
+
"""
|
|
62
|
+
|
|
63
|
+
@abc.abstractmethod
|
|
64
|
+
def forward(
|
|
65
|
+
self,
|
|
66
|
+
intermediates: Any,
|
|
67
|
+
context: ModelContext,
|
|
68
|
+
targets: list[dict[str, torch.Tensor]] | None = None,
|
|
69
|
+
) -> ModelOutput:
|
|
70
|
+
"""Compute task-specific outputs and loss dict.
|
|
71
|
+
|
|
72
|
+
Args:
|
|
73
|
+
intermediates: the output from the previous component.
|
|
74
|
+
context: the model context.
|
|
75
|
+
targets: the training targets, or None during prediction.
|
|
76
|
+
|
|
77
|
+
Returns:
|
|
78
|
+
a tuple of the task-specific outputs (which should be compatible with the
|
|
79
|
+
configured Task) and loss dict. The loss dict maps from a name for each
|
|
80
|
+
loss to a scalar tensor.
|
|
81
|
+
"""
|
|
82
|
+
raise NotImplementedError
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@dataclass
|
|
86
|
+
class FeatureMaps:
|
|
87
|
+
"""An intermediate output type for multi-resolution feature maps."""
|
|
88
|
+
|
|
89
|
+
# List of BxCxHxW feature maps at different scales, ordered from highest resolution
|
|
90
|
+
# (most fine-grained) to lowest resolution (coarsest).
|
|
91
|
+
feature_maps: list[torch.Tensor]
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@dataclass
|
|
95
|
+
class FeatureVector:
|
|
96
|
+
"""An intermediate output type for a flat feature vector."""
|
|
97
|
+
|
|
98
|
+
# Flat BxC feature vector.
|
|
99
|
+
feature_vector: torch.Tensor
|
rslearn/models/concatenate_features.py
CHANGED
|
@@ -4,8 +4,12 @@ from typing import Any
|
|
|
4
4
|
|
|
5
5
|
import torch
|
|
6
6
|
|
|
7
|
+
from rslearn.train.model_context import ModelContext
|
|
7
8
|
|
|
8
|
-
class ConcatenateFeatures(torch.nn.Module):
|
|
9
|
+
from .component import FeatureMaps, IntermediateComponent
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ConcatenateFeatures(IntermediateComponent):
|
|
9
13
|
"""Concatenate feature map with additional raw data inputs."""
|
|
10
14
|
|
|
11
15
|
def __init__(
|
|
@@ -55,26 +59,32 @@ class ConcatenateFeatures(torch.nn.Module):
|
|
|
55
59
|
|
|
56
60
|
self.conv_layers = torch.nn.Sequential(*conv_layers)
|
|
57
61
|
|
|
58
|
-
def forward(
|
|
59
|
-
self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
|
|
60
|
-
) -> list[torch.Tensor]:
|
|
62
|
+
def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
|
|
61
63
|
"""Concatenate the feature map with the raw data inputs.
|
|
62
64
|
|
|
63
65
|
Args:
|
|
64
|
-
|
|
65
|
-
|
|
66
|
+
intermediates: the previous output, which must be a FeatureMaps.
|
|
67
|
+
context: the model context. The input dicts must have a key matching the
|
|
68
|
+
configured key.
|
|
66
69
|
|
|
67
70
|
Returns:
|
|
68
71
|
concatenated feature maps.
|
|
69
72
|
"""
|
|
70
|
-
if
|
|
71
|
-
|
|
73
|
+
if (
|
|
74
|
+
not isinstance(intermediates, FeatureMaps)
|
|
75
|
+
or len(intermediates.feature_maps) == 0
|
|
76
|
+
):
|
|
77
|
+
raise ValueError(
|
|
78
|
+
"Expected input to be FeatureMaps with at least one feature map"
|
|
79
|
+
)
|
|
72
80
|
|
|
73
|
-
add_data = torch.stack(
|
|
81
|
+
add_data = torch.stack(
|
|
82
|
+
[input_data[self.key] for input_data in context.inputs], dim=0
|
|
83
|
+
)
|
|
74
84
|
add_features = self.conv_layers(add_data)
|
|
75
85
|
|
|
76
86
|
new_features: list[torch.Tensor] = []
|
|
77
|
-
for feature_map in
|
|
87
|
+
for feature_map in intermediates.feature_maps:
|
|
78
88
|
# Shape of feature map: BCHW
|
|
79
89
|
feat_h, feat_w = feature_map.shape[2], feature_map.shape[3]
|
|
80
90
|
|
|
@@ -90,4 +100,4 @@ class ConcatenateFeatures(torch.nn.Module):
|
|
|
90
100
|
|
|
91
101
|
new_features.append(torch.cat([feature_map, resized_add_features], dim=1))
|
|
92
102
|
|
|
93
|
-
return new_features
|
|
103
|
+
return FeatureMaps(new_features)
|
rslearn/models/conv.py
CHANGED
|
@@ -4,8 +4,12 @@ from typing import Any
|
|
|
4
4
|
|
|
5
5
|
import torch
|
|
6
6
|
|
|
7
|
+
from rslearn.train.model_context import ModelContext
|
|
7
8
|
|
|
8
|
-
class Conv(torch.nn.Module):
|
|
9
|
+
from .component import FeatureMaps, IntermediateComponent
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Conv(IntermediateComponent):
|
|
9
13
|
"""A single convolutional layer.
|
|
10
14
|
|
|
11
15
|
It inputs a set of feature maps; the conv layer is applied to each feature map
|
|
@@ -38,19 +42,22 @@ class Conv(torch.nn.Module):
|
|
|
38
42
|
)
|
|
39
43
|
self.activation = activation
|
|
40
44
|
|
|
41
|
-
def forward(self,
|
|
42
|
-
"""
|
|
45
|
+
def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
|
|
46
|
+
"""Apply conv layer on each feature map.
|
|
43
47
|
|
|
44
48
|
Args:
|
|
45
|
-
|
|
46
|
-
|
|
49
|
+
intermediates: the previous output, which must be a FeatureMaps.
|
|
50
|
+
context: the model context.
|
|
47
51
|
|
|
48
52
|
Returns:
|
|
49
|
-
|
|
53
|
+
the resulting feature maps after applying the same Conv2d on each one.
|
|
50
54
|
"""
|
|
55
|
+
if not isinstance(intermediates, FeatureMaps):
|
|
56
|
+
raise ValueError("input to Conv must be FeatureMaps")
|
|
57
|
+
|
|
51
58
|
new_features = []
|
|
52
|
-
for feat_map in
|
|
59
|
+
for feat_map in intermediates.feature_maps:
|
|
53
60
|
feat_map = self.layer(feat_map)
|
|
54
61
|
feat_map = self.activation(feat_map)
|
|
55
62
|
new_features.append(feat_map)
|
|
56
|
-
return new_features
|
|
63
|
+
return FeatureMaps(new_features)
|