rslearn 0.0.7__py3-none-any.whl → 0.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/data_sources/local_files.py +20 -3
- rslearn/data_sources/planetary_computer.py +79 -14
- rslearn/dataset/handler_summaries.py +130 -0
- rslearn/dataset/manage.py +159 -24
- rslearn/dataset/materialize.py +21 -2
- rslearn/dataset/remap.py +29 -4
- rslearn/main.py +60 -8
- rslearn/models/clay/clay.py +29 -14
- rslearn/models/copernicusfm.py +37 -25
- rslearn/models/dinov3.py +166 -0
- rslearn/models/galileo/galileo.py +58 -12
- rslearn/models/galileo/single_file_galileo.py +7 -1
- rslearn/models/presto/presto.py +11 -0
- rslearn/models/prithvi.py +139 -52
- rslearn/models/registry.py +19 -2
- rslearn/models/resize_features.py +45 -0
- rslearn/models/simple_time_series.py +65 -10
- rslearn/models/upsample.py +2 -2
- rslearn/tile_stores/default.py +34 -7
- rslearn/train/transforms/normalize.py +34 -5
- rslearn/train/transforms/select_bands.py +67 -0
- rslearn/train/transforms/sentinel1.py +60 -0
- rslearn/train/transforms/transform.py +23 -6
- rslearn/utils/raster_format.py +44 -5
- rslearn/utils/vector_format.py +35 -4
- {rslearn-0.0.7.dist-info → rslearn-0.0.9.dist-info}/METADATA +3 -4
- {rslearn-0.0.7.dist-info → rslearn-0.0.9.dist-info}/RECORD +31 -26
- {rslearn-0.0.7.dist-info → rslearn-0.0.9.dist-info}/WHEEL +0 -0
- {rslearn-0.0.7.dist-info → rslearn-0.0.9.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.7.dist-info → rslearn-0.0.9.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.7.dist-info → rslearn-0.0.9.dist-info}/top_level.txt +0 -0
rslearn/dataset/materialize.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
"""Classes to implement dataset materialization."""
|
|
2
2
|
|
|
3
|
+
from collections.abc import Callable
|
|
3
4
|
from typing import Any, Generic, TypeVar
|
|
4
5
|
|
|
5
6
|
import numpy as np
|
|
6
7
|
import numpy.typing as npt
|
|
7
|
-
from class_registry import ClassRegistry
|
|
8
8
|
from rasterio.enums import Resampling
|
|
9
9
|
|
|
10
10
|
from rslearn.config import (
|
|
@@ -25,7 +25,26 @@ from rslearn.utils.vector_format import load_vector_format
|
|
|
25
25
|
from .remap import Remapper, load_remapper
|
|
26
26
|
from .window import Window
|
|
27
27
|
|
|
28
|
-
|
|
28
|
+
_MaterializerT = TypeVar("_MaterializerT", bound="Materializer")
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class _MaterializerRegistry(dict[str, type["Materializer"]]):
|
|
32
|
+
"""Registry for Materializer classes."""
|
|
33
|
+
|
|
34
|
+
def register(
|
|
35
|
+
self, name: str
|
|
36
|
+
) -> Callable[[type[_MaterializerT]], type[_MaterializerT]]:
|
|
37
|
+
"""Decorator to register a materializer class."""
|
|
38
|
+
|
|
39
|
+
def decorator(cls: type[_MaterializerT]) -> type[_MaterializerT]:
|
|
40
|
+
self[name] = cls
|
|
41
|
+
return cls
|
|
42
|
+
|
|
43
|
+
return decorator
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
Materializers = _MaterializerRegistry()
|
|
47
|
+
|
|
29
48
|
|
|
30
49
|
LayerConfigType = TypeVar("LayerConfigType", bound=LayerConfig)
|
|
31
50
|
|
rslearn/dataset/remap.py
CHANGED
|
@@ -1,18 +1,42 @@
|
|
|
1
1
|
"""Classes to remap raster values."""
|
|
2
2
|
|
|
3
|
-
from
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from typing import Any, TypeVar
|
|
4
5
|
|
|
5
6
|
import numpy as np
|
|
6
7
|
import numpy.typing as npt
|
|
7
|
-
from class_registry import ClassRegistry
|
|
8
8
|
|
|
9
|
-
|
|
9
|
+
_RemapperT = TypeVar("_RemapperT", bound="Remapper")
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class _RemapperRegistry(dict[str, type["Remapper"]]):
|
|
13
|
+
"""Registry for Remapper classes."""
|
|
14
|
+
|
|
15
|
+
def register(self, name: str) -> Callable[[type[_RemapperT]], type[_RemapperT]]:
|
|
16
|
+
"""Decorator to register a remapper class."""
|
|
17
|
+
|
|
18
|
+
def decorator(cls: type[_RemapperT]) -> type[_RemapperT]:
|
|
19
|
+
self[name] = cls
|
|
20
|
+
return cls
|
|
21
|
+
|
|
22
|
+
return decorator
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
Remappers = _RemapperRegistry()
|
|
10
26
|
"""Registry of Remapper implementations."""
|
|
11
27
|
|
|
12
28
|
|
|
13
29
|
class Remapper:
|
|
14
30
|
"""An abstract class that remaps pixel values based on layer configuration."""
|
|
15
31
|
|
|
32
|
+
def __init__(self, config: dict[str, Any]) -> None:
|
|
33
|
+
"""Initialize a Remapper.
|
|
34
|
+
|
|
35
|
+
Args:
|
|
36
|
+
config: the config dict for this remapper.
|
|
37
|
+
"""
|
|
38
|
+
pass
|
|
39
|
+
|
|
16
40
|
def __call__(
|
|
17
41
|
self, array: npt.NDArray[Any], dtype: npt.DTypeLike
|
|
18
42
|
) -> npt.NDArray[Any]:
|
|
@@ -67,4 +91,5 @@ class LinearRemapper(Remapper):
|
|
|
67
91
|
|
|
68
92
|
def load_remapper(config: dict[str, Any]) -> Remapper:
|
|
69
93
|
"""Load a remapper from a configuration dictionary."""
|
|
70
|
-
|
|
94
|
+
cls = Remappers[config["name"]]
|
|
95
|
+
return cls(config)
|
rslearn/main.py
CHANGED
|
@@ -4,6 +4,7 @@ import argparse
|
|
|
4
4
|
import multiprocessing
|
|
5
5
|
import random
|
|
6
6
|
import sys
|
|
7
|
+
import time
|
|
7
8
|
from collections.abc import Callable
|
|
8
9
|
from datetime import UTC, datetime, timedelta
|
|
9
10
|
from typing import Any, TypeVar
|
|
@@ -19,8 +20,18 @@ from rslearn.const import WGS84_EPSG
|
|
|
19
20
|
from rslearn.data_sources import Item, data_source_from_config
|
|
20
21
|
from rslearn.dataset import Dataset, Window, WindowLayerData
|
|
21
22
|
from rslearn.dataset.add_windows import add_windows_from_box, add_windows_from_file
|
|
23
|
+
from rslearn.dataset.handler_summaries import (
|
|
24
|
+
ErrorOutcome,
|
|
25
|
+
IngestCounts,
|
|
26
|
+
IngestDatasetJobsSummary,
|
|
27
|
+
LayerIngestSummary,
|
|
28
|
+
MaterializeDatasetWindowsSummary,
|
|
29
|
+
PrepareDatasetWindowsSummary,
|
|
30
|
+
UnknownIngestCounts,
|
|
31
|
+
)
|
|
22
32
|
from rslearn.dataset.index import DatasetIndex
|
|
23
33
|
from rslearn.dataset.manage import (
|
|
34
|
+
AttemptsCounter,
|
|
24
35
|
materialize_dataset_windows,
|
|
25
36
|
prepare_dataset_windows,
|
|
26
37
|
retry,
|
|
@@ -287,7 +298,7 @@ def add_apply_on_windows_args(parser: argparse.ArgumentParser) -> None:
|
|
|
287
298
|
|
|
288
299
|
|
|
289
300
|
def apply_on_windows(
|
|
290
|
-
f: Callable[[list[Window]],
|
|
301
|
+
f: Callable[[list[Window]], Any],
|
|
291
302
|
dataset: Dataset,
|
|
292
303
|
group: str | list[str] | None = None,
|
|
293
304
|
names: list[str] | None = None,
|
|
@@ -367,7 +378,7 @@ def apply_on_windows(
|
|
|
367
378
|
p.close()
|
|
368
379
|
|
|
369
380
|
|
|
370
|
-
def apply_on_windows_args(f: Callable[...,
|
|
381
|
+
def apply_on_windows_args(f: Callable[..., Any], args: argparse.Namespace) -> None:
|
|
371
382
|
"""Call apply_on_windows with arguments passed via command-line interface."""
|
|
372
383
|
dataset = Dataset(UPath(args.root), args.disabled_layers)
|
|
373
384
|
apply_on_windows(
|
|
@@ -413,12 +424,12 @@ class PrepareHandler:
|
|
|
413
424
|
"""
|
|
414
425
|
self.dataset = dataset
|
|
415
426
|
|
|
416
|
-
def __call__(self, windows: list[Window]) ->
|
|
427
|
+
def __call__(self, windows: list[Window]) -> PrepareDatasetWindowsSummary:
|
|
417
428
|
"""Prepares the windows from apply_on_windows."""
|
|
418
429
|
logger.info(f"Running prepare on {len(windows)} windows")
|
|
419
430
|
if self.dataset is None:
|
|
420
431
|
raise ValueError("dataset not set")
|
|
421
|
-
prepare_dataset_windows(
|
|
432
|
+
return prepare_dataset_windows(
|
|
422
433
|
self.dataset,
|
|
423
434
|
windows,
|
|
424
435
|
self.force,
|
|
@@ -502,14 +513,20 @@ class IngestHandler:
|
|
|
502
513
|
|
|
503
514
|
def __call__(
|
|
504
515
|
self, jobs: list[tuple[str, LayerConfig, Item, list[STGeometry]]]
|
|
505
|
-
) ->
|
|
516
|
+
) -> IngestDatasetJobsSummary:
|
|
506
517
|
"""Ingest the specified items.
|
|
507
518
|
|
|
508
519
|
The items are computed from list of windows via IngestHandler.get_jobs.
|
|
509
520
|
|
|
510
521
|
Args:
|
|
511
|
-
jobs: list of (layer_name, item, geometries) tuples to ingest.
|
|
522
|
+
jobs: list of (layer_name, layer_cfg, item, geometries) tuples to ingest.
|
|
523
|
+
|
|
524
|
+
Returns:
|
|
525
|
+
summary of the ingest jobs operation fit for telemetry purposes.
|
|
512
526
|
"""
|
|
527
|
+
start_time = time.monotonic()
|
|
528
|
+
layer_summaries: list[LayerIngestSummary] = []
|
|
529
|
+
|
|
513
530
|
logger.info(f"Running ingest for {len(jobs)} jobs")
|
|
514
531
|
import gc
|
|
515
532
|
|
|
@@ -533,6 +550,8 @@ class IngestHandler:
|
|
|
533
550
|
layer_cfg = self.dataset.layers[layer_name]
|
|
534
551
|
data_source = data_source_from_config(layer_cfg, self.dataset.path)
|
|
535
552
|
|
|
553
|
+
attempts_counter = AttemptsCounter()
|
|
554
|
+
ingest_counts: IngestCounts | UnknownIngestCounts
|
|
536
555
|
try:
|
|
537
556
|
retry(
|
|
538
557
|
lambda: data_source.ingest(
|
|
@@ -544,18 +563,47 @@ class IngestHandler:
|
|
|
544
563
|
),
|
|
545
564
|
retry_max_attempts=self.retry_max_attempts,
|
|
546
565
|
retry_backoff=self.retry_backoff,
|
|
566
|
+
attempts_counter=attempts_counter,
|
|
567
|
+
)
|
|
568
|
+
ingest_counts = IngestCounts(
|
|
569
|
+
items_ingested=len(items_and_geometries),
|
|
570
|
+
geometries_ingested=sum(
|
|
571
|
+
len(geometries) for _, geometries in items_and_geometries
|
|
572
|
+
),
|
|
547
573
|
)
|
|
548
574
|
except Exception as e:
|
|
549
575
|
if not self.ignore_errors:
|
|
550
576
|
raise
|
|
551
577
|
|
|
578
|
+
ingest_counts = UnknownIngestCounts(
|
|
579
|
+
items_attempted=len(items_and_geometries),
|
|
580
|
+
geometries_attempted=sum(
|
|
581
|
+
len(geometries) for _, geometries in items_and_geometries
|
|
582
|
+
),
|
|
583
|
+
)
|
|
552
584
|
logger.error(
|
|
553
585
|
"warning: got error while ingesting "
|
|
554
586
|
+ f"{len(items_and_geometries)} items: {e}"
|
|
555
587
|
)
|
|
556
588
|
|
|
589
|
+
layer_summaries.append(
|
|
590
|
+
LayerIngestSummary(
|
|
591
|
+
layer_name=layer_name,
|
|
592
|
+
data_source_name=getattr(layer_cfg.data_source, "name", "N/A"),
|
|
593
|
+
duration_seconds=time.monotonic() - start_time,
|
|
594
|
+
ingest_counts=ingest_counts,
|
|
595
|
+
ingest_attempts=attempts_counter.value,
|
|
596
|
+
)
|
|
597
|
+
)
|
|
598
|
+
|
|
557
599
|
gc.collect()
|
|
558
600
|
|
|
601
|
+
return IngestDatasetJobsSummary(
|
|
602
|
+
duration_seconds=time.monotonic() - start_time,
|
|
603
|
+
num_jobs=len(jobs),
|
|
604
|
+
layer_summaries=layer_summaries,
|
|
605
|
+
)
|
|
606
|
+
|
|
559
607
|
def _load_layer_data_for_windows(
|
|
560
608
|
self, windows: list[Window], workers: int
|
|
561
609
|
) -> list[tuple[Window, dict[str, WindowLayerData]]]:
|
|
@@ -686,13 +734,16 @@ class MaterializeHandler:
|
|
|
686
734
|
"""
|
|
687
735
|
self.dataset = dataset
|
|
688
736
|
|
|
689
|
-
def __call__(
|
|
737
|
+
def __call__(
|
|
738
|
+
self, windows: list[Window]
|
|
739
|
+
) -> MaterializeDatasetWindowsSummary | ErrorOutcome:
|
|
690
740
|
"""Materializes the windows from apply_on_windows."""
|
|
691
741
|
logger.info(f"Running Materialize with {len(windows)} windows")
|
|
742
|
+
start_time = time.monotonic()
|
|
692
743
|
if self.dataset is None:
|
|
693
744
|
raise ValueError("dataset not set")
|
|
694
745
|
try:
|
|
695
|
-
materialize_dataset_windows(
|
|
746
|
+
return materialize_dataset_windows(
|
|
696
747
|
self.dataset,
|
|
697
748
|
windows,
|
|
698
749
|
retry_max_attempts=self.retry_max_attempts,
|
|
@@ -703,6 +754,7 @@ class MaterializeHandler:
|
|
|
703
754
|
logger.error(f"Error materializing windows: {e}")
|
|
704
755
|
raise
|
|
705
756
|
logger.warning(f"Ignoring error while materializing windows: {e}")
|
|
757
|
+
return ErrorOutcome(duration_seconds=time.monotonic() - start_time)
|
|
706
758
|
|
|
707
759
|
|
|
708
760
|
@register_handler("dataset", "materialize")
|
rslearn/models/clay/clay.py
CHANGED
|
@@ -15,6 +15,7 @@ from huggingface_hub import hf_hub_download
|
|
|
15
15
|
# from claymodel.module import ClayMAEModule
|
|
16
16
|
from terratorch.models.backbones.clay_v15.module import ClayMAEModule
|
|
17
17
|
|
|
18
|
+
from rslearn.train.transforms.normalize import Normalize
|
|
18
19
|
from rslearn.train.transforms.transform import Transform
|
|
19
20
|
|
|
20
21
|
|
|
@@ -163,13 +164,36 @@ class Clay(torch.nn.Module):
|
|
|
163
164
|
|
|
164
165
|
|
|
165
166
|
class ClayNormalize(Transform):
|
|
166
|
-
"""Normalize inputs using Clay metadata.
|
|
167
|
+
"""Normalize inputs using Clay metadata.
|
|
168
|
+
|
|
169
|
+
For Sentinel-1, the intensities should be converted to decibels.
|
|
170
|
+
"""
|
|
167
171
|
|
|
168
172
|
def __init__(self, metadata_path: str = CLAY_METADATA_PATH) -> None:
|
|
169
173
|
"""Initialize ClayNormalize."""
|
|
170
174
|
super().__init__()
|
|
171
175
|
with open(metadata_path) as f:
|
|
172
|
-
|
|
176
|
+
metadata = yaml.safe_load(f)
|
|
177
|
+
normalizers = {}
|
|
178
|
+
for modality in CLAY_MODALITIES:
|
|
179
|
+
if modality not in metadata:
|
|
180
|
+
continue
|
|
181
|
+
modality_metadata = metadata[modality]
|
|
182
|
+
means = [
|
|
183
|
+
modality_metadata["bands"]["mean"][b]
|
|
184
|
+
for b in modality_metadata["band_order"]
|
|
185
|
+
]
|
|
186
|
+
stds = [
|
|
187
|
+
modality_metadata["bands"]["std"][b]
|
|
188
|
+
for b in modality_metadata["band_order"]
|
|
189
|
+
]
|
|
190
|
+
normalizers[modality] = Normalize(
|
|
191
|
+
mean=means,
|
|
192
|
+
std=stds,
|
|
193
|
+
selectors=[modality],
|
|
194
|
+
num_bands=len(means),
|
|
195
|
+
)
|
|
196
|
+
self.normalizers = torch.nn.ModuleDict(normalizers)
|
|
173
197
|
|
|
174
198
|
def apply_image(
|
|
175
199
|
self, image: torch.Tensor, means: list[float], stds: list[float]
|
|
@@ -188,17 +212,8 @@ class ClayNormalize(Transform):
|
|
|
188
212
|
self, input_dict: dict[str, Any], target_dict: dict[str, Any]
|
|
189
213
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
190
214
|
"""Normalize the specified image with Clay normalization."""
|
|
191
|
-
for modality in
|
|
192
|
-
if modality not in input_dict
|
|
215
|
+
for modality, normalizer in self.normalizers.items():
|
|
216
|
+
if modality not in input_dict:
|
|
193
217
|
continue
|
|
194
|
-
|
|
195
|
-
means = [
|
|
196
|
-
modality_metadata["bands"]["mean"][b]
|
|
197
|
-
for b in modality_metadata["band_order"]
|
|
198
|
-
]
|
|
199
|
-
stds = [
|
|
200
|
-
modality_metadata["bands"]["std"][b]
|
|
201
|
-
for b in modality_metadata["band_order"]
|
|
202
|
-
]
|
|
203
|
-
input_dict[modality] = self.apply_image(input_dict[modality], means, stds)
|
|
218
|
+
input_dict, target_dict = normalizer(input_dict, target_dict)
|
|
204
219
|
return input_dict, target_dict
|
rslearn/models/copernicusfm.py
CHANGED
|
@@ -3,11 +3,12 @@
|
|
|
3
3
|
import logging
|
|
4
4
|
import math
|
|
5
5
|
from enum import Enum
|
|
6
|
+
from pathlib import Path
|
|
6
7
|
|
|
7
8
|
import torch
|
|
8
9
|
import torch.nn.functional as F
|
|
9
10
|
from einops import rearrange
|
|
10
|
-
from
|
|
11
|
+
from huggingface_hub import hf_hub_download
|
|
11
12
|
|
|
12
13
|
from .copernicusfm_src.model_vit import vit_base_patch16
|
|
13
14
|
|
|
@@ -64,6 +65,10 @@ MODALITY_TO_WAVELENGTH_BANDWIDTHS: dict[str, dict[str, list]] = {
|
|
|
64
65
|
},
|
|
65
66
|
}
|
|
66
67
|
|
|
68
|
+
HF_REPO_ID = "wangyi111/Copernicus-FM"
|
|
69
|
+
HF_REPO_REVISION = "e1db406d517a122c8373802e1c130c5fc4789f84"
|
|
70
|
+
HF_FILENAME = "CopernicusFM_ViT_base_varlang_e100.pth"
|
|
71
|
+
|
|
67
72
|
|
|
68
73
|
class CopernicusFM(torch.nn.Module):
|
|
69
74
|
"""Wrapper for Copernicus FM to ingest Masked Helios Sample."""
|
|
@@ -80,44 +85,51 @@ class CopernicusFM(torch.nn.Module):
|
|
|
80
85
|
def __init__(
|
|
81
86
|
self,
|
|
82
87
|
band_order: dict[str, list[str]],
|
|
83
|
-
|
|
88
|
+
cache_dir: str | Path | None = None,
|
|
84
89
|
) -> None:
|
|
85
90
|
"""Initialize the Copernicus FM wrapper.
|
|
86
91
|
|
|
87
92
|
Args:
|
|
88
|
-
band_order: The band order for each modality
|
|
89
|
-
|
|
93
|
+
band_order: The band order for each modality that will be used. The bands
|
|
94
|
+
can be provided in any order, and any subset can be used.
|
|
95
|
+
cache_dir: The directory to cache the weights. If None, a default directory
|
|
96
|
+
managed by huggingface_hub is used. The weights are downloaded from
|
|
97
|
+
Hugging Face (https://huggingface.co/wangyi111/Copernicus-FM).
|
|
90
98
|
"""
|
|
91
99
|
super().__init__()
|
|
92
100
|
|
|
101
|
+
# Make sure all keys in band_order are in supported_modalities.
|
|
102
|
+
for modality_name in band_order.keys():
|
|
103
|
+
if modality_name in self.supported_modalities:
|
|
104
|
+
continue
|
|
105
|
+
raise ValueError(
|
|
106
|
+
f"band_order contains unsupported modality {modality_name}"
|
|
107
|
+
)
|
|
108
|
+
|
|
93
109
|
# global_pool=True so that we initialize the fc_norm layer
|
|
94
|
-
self.band_order = band_order
|
|
95
110
|
self.model = vit_base_patch16(num_classes=10, global_pool=True)
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
# take MODALITY_TO_WAVELENGTH_BANDWIDTHS and
|
|
108
|
-
# ordering as the
|
|
111
|
+
|
|
112
|
+
# Load weights, downloading if needed.
|
|
113
|
+
local_fname = hf_hub_download(
|
|
114
|
+
repo_id=HF_REPO_ID,
|
|
115
|
+
revision=HF_REPO_REVISION,
|
|
116
|
+
filename=HF_FILENAME,
|
|
117
|
+
local_dir=cache_dir,
|
|
118
|
+
) # nosec
|
|
119
|
+
state_dict = torch.load(local_fname, weights_only=True)
|
|
120
|
+
self.model.load_state_dict(state_dict, strict=False)
|
|
121
|
+
|
|
122
|
+
# take MODALITY_TO_WAVELENGTH_BANDWIDTHS and rearrange it so that it has the same
|
|
123
|
+
# ordering as the user-provided band order.
|
|
109
124
|
self.modality_to_wavelength_bandwidths = {}
|
|
110
125
|
for modality in self.supported_modalities:
|
|
126
|
+
if modality not in band_order:
|
|
127
|
+
continue
|
|
128
|
+
|
|
111
129
|
wavelength_bandwidths = MODALITY_TO_WAVELENGTH_BANDWIDTHS[modality]
|
|
112
130
|
wavelengths = []
|
|
113
131
|
bandwidths = []
|
|
114
|
-
|
|
115
|
-
if modality_band_order is None:
|
|
116
|
-
logger.warning(
|
|
117
|
-
f"Band order for modality {modality} not found in band_order dictionary, unable to use this modality unless specified"
|
|
118
|
-
)
|
|
119
|
-
continue
|
|
120
|
-
for b in modality_band_order:
|
|
132
|
+
for b in band_order[modality]:
|
|
121
133
|
cfm_idx = wavelength_bandwidths["band_names"].index(b)
|
|
122
134
|
wavelengths.append(wavelength_bandwidths["band_wavelengths"][cfm_idx])
|
|
123
135
|
bandwidths.append(wavelength_bandwidths["band_bandwidths"][cfm_idx])
|
rslearn/models/dinov3.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
"""DinoV3 model."""
|
|
2
|
+
|
|
3
|
+
from enum import StrEnum
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
import torchvision
|
|
9
|
+
from einops import rearrange
|
|
10
|
+
|
|
11
|
+
from rslearn.train.transforms.normalize import Normalize
|
|
12
|
+
from rslearn.train.transforms.transform import Transform
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class DinoV3Models(StrEnum):
|
|
16
|
+
"""Names for different DinoV3 images on torch hub."""
|
|
17
|
+
|
|
18
|
+
SMALL_WEB = "dinov3_vits16"
|
|
19
|
+
SMALL_PLUS_WEB = "dinov3_vits16plus"
|
|
20
|
+
BASE_WEB = "dinov3_vitb16"
|
|
21
|
+
LARGE_WEB = "dinov3_vitl16"
|
|
22
|
+
HUGE_PLUS_WEB = "dinov3_vith16plus"
|
|
23
|
+
FULL_7B_WEB = "dinov3_vit7b16"
|
|
24
|
+
LARGE_SATELLITE = "dinov3_vitl16_sat"
|
|
25
|
+
FULL_7B_SATELLITE = "dinov3_vit7b16_sat"
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
DINOV3_PTHS: dict[str, str] = {
|
|
29
|
+
DinoV3Models.LARGE_SATELLITE: "dinov3_vitl16_pretrain_sat493m-eadcf0ff.pth",
|
|
30
|
+
DinoV3Models.FULL_7B_SATELLITE: "dinov3_vit7b16_pretrain_sat493m-a6675841.pth",
|
|
31
|
+
DinoV3Models.BASE_WEB: "dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth",
|
|
32
|
+
DinoV3Models.LARGE_WEB: "dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth",
|
|
33
|
+
DinoV3Models.HUGE_PLUS_WEB: "dinov3_vith16plus_pretrain_lvd1689m-7c1da9a5.pth",
|
|
34
|
+
DinoV3Models.FULL_7B_WEB: "dinov3_vit7b16_pretrain_lvd1689m-a955f4.pth",
|
|
35
|
+
}
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class DinoV3(torch.nn.Module):
|
|
39
|
+
"""DinoV3 Backbones.
|
|
40
|
+
|
|
41
|
+
Must have the pretrained weights downloaded in checkpoint_dir for them to be loaded.
|
|
42
|
+
See https://github.com/facebookresearch/dinov3?tab=readme-ov-file#pretrained-models
|
|
43
|
+
|
|
44
|
+
Only takes RGB as input. Expects normalized data (use the below normalizer).
|
|
45
|
+
|
|
46
|
+
Uses patch size 16. The input is resized to 256x256; when applying DinoV3 on
|
|
47
|
+
segmentation or detection tasks with inputs larger than 256x256, it may be best to
|
|
48
|
+
train and predict on 256x256 crops (using SplitConfig.patch_size argument).
|
|
49
|
+
"""
|
|
50
|
+
|
|
51
|
+
image_size: int = 256
|
|
52
|
+
patch_size: int = 16
|
|
53
|
+
output_dim: int = 1024
|
|
54
|
+
|
|
55
|
+
def _load_model(self, size: str, checkpoint_dir: str | None) -> torch.nn.Module:
|
|
56
|
+
model_name = size.replace("_sat", "")
|
|
57
|
+
if checkpoint_dir is not None:
|
|
58
|
+
weights = str(Path(checkpoint_dir) / DINOV3_PTHS[size])
|
|
59
|
+
return torch.hub.load(
|
|
60
|
+
"facebookresearch/dinov3",
|
|
61
|
+
model_name,
|
|
62
|
+
weights=weights,
|
|
63
|
+
) # nosec
|
|
64
|
+
return torch.hub.load("facebookresearch/dinov3", model_name, pretrained=False) # nosec
|
|
65
|
+
|
|
66
|
+
def __init__(
|
|
67
|
+
self,
|
|
68
|
+
checkpoint_dir: str | None,
|
|
69
|
+
size: str = DinoV3Models.LARGE_SATELLITE,
|
|
70
|
+
use_cls_token: bool = False,
|
|
71
|
+
do_resizing: bool = True,
|
|
72
|
+
) -> None:
|
|
73
|
+
"""Instantiate a new DinoV3 instance.
|
|
74
|
+
|
|
75
|
+
Args:
|
|
76
|
+
checkpoint_dir: the local path to the pretrained weight dir. If None, we load the architecture
|
|
77
|
+
only (randomly initialized).
|
|
78
|
+
size: the model size, see class for various models.
|
|
79
|
+
use_cls_token: use pooled class token (for classification), otherwise returns spatial feature map.
|
|
80
|
+
do_resizing: whether to resize inputs to 256x256. Default true.
|
|
81
|
+
"""
|
|
82
|
+
super().__init__()
|
|
83
|
+
self.size = size
|
|
84
|
+
self.checkpoint_dir = checkpoint_dir
|
|
85
|
+
self.use_cls_token = use_cls_token
|
|
86
|
+
self.do_resizing = do_resizing
|
|
87
|
+
self.model = self._load_model(size, checkpoint_dir)
|
|
88
|
+
|
|
89
|
+
def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
|
|
90
|
+
"""Forward pass for the dinov3 model.
|
|
91
|
+
|
|
92
|
+
Args:
|
|
93
|
+
inputs: input dicts that must include "image" key.
|
|
94
|
+
|
|
95
|
+
Returns:
|
|
96
|
+
List[torch.Tensor]: Single-scale feature tensors from the encoder.
|
|
97
|
+
"""
|
|
98
|
+
cur = torch.stack([inp["image"] for inp in inputs], dim=0) # (B, C, H, W)
|
|
99
|
+
|
|
100
|
+
if self.do_resizing and (
|
|
101
|
+
cur.shape[2] != self.image_size or cur.shape[3] != self.image_size
|
|
102
|
+
):
|
|
103
|
+
cur = torchvision.transforms.functional.resize(
|
|
104
|
+
cur,
|
|
105
|
+
[self.image_size, self.image_size],
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
if self.use_cls_token:
|
|
109
|
+
features = self.model(cur)
|
|
110
|
+
else:
|
|
111
|
+
features = self.model.forward_features(cur)["x_norm_patchtokens"]
|
|
112
|
+
batch_size, num_patches, _ = features.shape
|
|
113
|
+
height, width = int(num_patches**0.5), int(num_patches**0.5)
|
|
114
|
+
features = rearrange(features, "b (h w) d -> b d h w", h=height, w=width)
|
|
115
|
+
|
|
116
|
+
return [features]
|
|
117
|
+
|
|
118
|
+
def get_backbone_channels(self) -> list:
|
|
119
|
+
"""Returns the output channels of this model when used as a backbone.
|
|
120
|
+
|
|
121
|
+
The output channels is a list of (downsample_factor, depth) that corresponds
|
|
122
|
+
to the feature maps that the backbone returns. For example, an element [2, 32]
|
|
123
|
+
indicates that the corresponding feature map is 1/2 the input resolution and
|
|
124
|
+
has 32 channels.
|
|
125
|
+
"""
|
|
126
|
+
return [(self.patch_size, self.output_dim)]
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class DinoV3Normalize(Transform):
|
|
130
|
+
"""Normalize inputs using DinoV3 normalization.
|
|
131
|
+
|
|
132
|
+
Normalize "image" key in input according to Dino statistics from pretraining. Satellite pretraining has slightly different normalizing than the base image model so set 'satellite' depending on what pretrained model you are using.
|
|
133
|
+
|
|
134
|
+
Input "image" should be RGB-like image between 0-255.
|
|
135
|
+
"""
|
|
136
|
+
|
|
137
|
+
def __init__(self, satellite: bool = True):
|
|
138
|
+
"""Initialize a new DinoV3Normalize."""
|
|
139
|
+
super().__init__()
|
|
140
|
+
self.satellite = satellite
|
|
141
|
+
if satellite:
|
|
142
|
+
mean = [0.430, 0.411, 0.296]
|
|
143
|
+
std = [0.213, 0.156, 0.143]
|
|
144
|
+
else:
|
|
145
|
+
mean = [0.485, 0.456, 0.406]
|
|
146
|
+
std = [0.229, 0.224, 0.225]
|
|
147
|
+
|
|
148
|
+
self.normalize = Normalize(
|
|
149
|
+
[value * 255 for value in mean],
|
|
150
|
+
[value * 255 for value in std],
|
|
151
|
+
num_bands=3,
|
|
152
|
+
)
|
|
153
|
+
|
|
154
|
+
def forward(
|
|
155
|
+
self, input_dict: dict[str, Any], target_dict: dict[str, Any]
|
|
156
|
+
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
157
|
+
"""Normalize the specified image with DinoV3 normalization.
|
|
158
|
+
|
|
159
|
+
Args:
|
|
160
|
+
input_dict: the input dictionary.
|
|
161
|
+
target_dict: the target dictionary.
|
|
162
|
+
|
|
163
|
+
Returns:
|
|
164
|
+
normalized (input_dicts, target_dicts) tuple
|
|
165
|
+
"""
|
|
166
|
+
return self.normalize(input_dict, target_dict)
|