rslearn 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/dataset/handler_summaries.py +130 -0
- rslearn/dataset/manage.py +157 -22
- rslearn/main.py +60 -8
- rslearn/models/anysat.py +207 -0
- rslearn/models/clay/clay.py +219 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/copernicusfm.py +37 -25
- rslearn/models/dinov3.py +165 -0
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +517 -0
- rslearn/models/galileo/single_file_galileo.py +1672 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/presto/presto.py +10 -7
- rslearn/models/prithvi.py +1122 -0
- rslearn/models/resize_features.py +45 -0
- rslearn/models/simple_time_series.py +65 -10
- rslearn/models/unet.py +17 -11
- rslearn/models/upsample.py +2 -2
- rslearn/tile_stores/default.py +31 -6
- rslearn/train/transforms/normalize.py +34 -5
- rslearn/train/transforms/select_bands.py +67 -0
- rslearn/train/transforms/sentinel1.py +60 -0
- rslearn/utils/geometry.py +61 -1
- rslearn/utils/raster_format.py +7 -1
- rslearn/utils/vector_format.py +13 -10
- {rslearn-0.0.6.dist-info → rslearn-0.0.8.dist-info}/METADATA +144 -15
- {rslearn-0.0.6.dist-info → rslearn-0.0.8.dist-info}/RECORD +42 -18
- {rslearn-0.0.6.dist-info → rslearn-0.0.8.dist-info}/WHEEL +0 -0
- {rslearn-0.0.6.dist-info → rslearn-0.0.8.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.6.dist-info → rslearn-0.0.8.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.6.dist-info → rslearn-0.0.8.dist-info}/top_level.txt +0 -0
rslearn/models/copernicusfm.py
CHANGED
@@ -3,11 +3,12 @@
 import logging
 import math
 from enum import Enum
+from pathlib import Path

 import torch
 import torch.nn.functional as F
 from einops import rearrange
-from
+from huggingface_hub import hf_hub_download

 from .copernicusfm_src.model_vit import vit_base_patch16

@@ -64,6 +65,10 @@ MODALITY_TO_WAVELENGTH_BANDWIDTHS: dict[str, dict[str, list]] = {
     },
 }

+HF_REPO_ID = "wangyi111/Copernicus-FM"
+HF_REPO_REVISION = "e1db406d517a122c8373802e1c130c5fc4789f84"
+HF_FILENAME = "CopernicusFM_ViT_base_varlang_e100.pth"
+

 class CopernicusFM(torch.nn.Module):
     """Wrapper for Copernicus FM to ingest Masked Helios Sample."""

@@ -80,44 +85,51 @@ class CopernicusFM(torch.nn.Module):
     def __init__(
         self,
         band_order: dict[str, list[str]],
-
+        cache_dir: str | Path | None = None,
     ) -> None:
         """Initialize the Copernicus FM wrapper.

         Args:
-            band_order: The band order for each modality
-
+            band_order: The band order for each modality that will be used. The bands
+                can be provided in any order, and any subset can be used.
+            cache_dir: The directory to cache the weights. If None, a default directory
+                managed by huggingface_hub is used. The weights are downloaded from
+                Hugging Face (https://huggingface.co/wangyi111/Copernicus-FM).
         """
         super().__init__()

+        # Make sure all keys in band_order are in supported_modalities.
+        for modality_name in band_order.keys():
+            if modality_name in self.supported_modalities:
+                continue
+            raise ValueError(
+                f"band_order contains unsupported modality {modality_name}"
+            )
+
         # global_pool=True so that we initialize the fc_norm layer
-        self.band_order = band_order
         self.model = vit_base_patch16(num_classes=10, global_pool=True)
-
-
-
-
-
-
-
-
-
-
-
-        # take MODALITY_TO_WAVELENGTH_BANDWIDTHS and
-        # ordering as the
+
+        # Load weights, downloading if needed.
+        local_fname = hf_hub_download(
+            repo_id=HF_REPO_ID,
+            revision=HF_REPO_REVISION,
+            filename=HF_FILENAME,
+            local_dir=cache_dir,
+        )  # nosec
+        state_dict = torch.load(local_fname, weights_only=True)
+        self.model.load_state_dict(state_dict, strict=False)
+
+        # take MODALITY_TO_WAVELENGTH_BANDWIDTHS and rearrange it so that it has the same
+        # ordering as the user-provided band order.
         self.modality_to_wavelength_bandwidths = {}
         for modality in self.supported_modalities:
+            if modality not in band_order:
+                continue
+
             wavelength_bandwidths = MODALITY_TO_WAVELENGTH_BANDWIDTHS[modality]
             wavelengths = []
             bandwidths = []
-
-            if modality_band_order is None:
-                logger.warning(
-                    f"Band order for modality {modality} not found in band_order dictionary, unable to use this modality unless specified"
-                )
-                continue
-            for b in modality_band_order:
+            for b in band_order[modality]:
                 cfm_idx = wavelength_bandwidths["band_names"].index(b)
                 wavelengths.append(wavelength_bandwidths["band_wavelengths"][cfm_idx])
                 bandwidths.append(wavelength_bandwidths["band_bandwidths"][cfm_idx])
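The new constructor validates band_order against supported_modalities, pins the checkpoint download to a fixed Hugging Face revision, and lets callers control where weights are cached. A minimal usage sketch; the "sentinel2" modality key and band names below are hypothetical and must match the entries actually defined in MODALITY_TO_WAVELENGTH_BANDWIDTHS:

from rslearn.models.copernicusfm import CopernicusFM

# Hypothetical modality key and band subset; per the docstring, any subset of
# a supported modality's bands may be given, in any order.
model = CopernicusFM(
    band_order={"sentinel2": ["B04", "B03", "B02"]},
    cache_dir="/data/checkpoints/copernicusfm",  # weights cached here via hf_hub_download
)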
rslearn/models/dinov3.py
ADDED
@@ -0,0 +1,165 @@
+"""DinoV3 model."""
+
+from enum import StrEnum
+from pathlib import Path
+from typing import Any
+
+import torch
+import torchvision
+from einops import rearrange
+from torchvision.transforms import v2
+
+from rslearn.train.transforms.transform import Transform
+
+
+class DinoV3Models(StrEnum):
+    """Names for different DinoV3 images on torch hub."""
+
+    SMALL_WEB = "dinov3_vits16"
+    SMALL_PLUS_WEB = "dinov3_vits16plus"
+    BASE_WEB = "dinov3_vitb16"
+    LARGE_WEB = "dinov3_vitl16"
+    HUGE_PLUS_WEB = "dinov3_vith16plus"
+    FULL_7B_WEB = "dinov3_vit7b16"
+    LARGE_SATELLITE = "dinov3_vitl16_sat"
+    FULL_7B_SATELLITE = "dinov3_vit7b16_sat"
+
+
+DINOV3_PTHS: dict[str, str] = {
+    DinoV3Models.LARGE_SATELLITE: "dinov3_vitl16_pretrain_sat493m-eadcf0ff.pth",
+    DinoV3Models.FULL_7B_SATELLITE: "dinov3_vit7b16_pretrain_sat493m-a6675841.pth",
+    DinoV3Models.BASE_WEB: "dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth",
+    DinoV3Models.LARGE_WEB: "dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth",
+    DinoV3Models.HUGE_PLUS_WEB: "dinov3_vith16plus_pretrain_lvd1689m-7c1da9a5.pth",
+    DinoV3Models.FULL_7B_WEB: "dinov3_vit7b16_pretrain_lvd1689m-a955f4.pth",
+}
+
+
+class DinoV3(torch.nn.Module):
+    """DinoV3 Backbones.
+
+    Must have the pretrained weights downloaded in checkpoint_dir for them to be loaded.
+    See https://github.com/facebookresearch/dinov3?tab=readme-ov-file#pretrained-models
+
+    Only takes RGB as input. Expects normalized data (use the below normalizer).
+
+    Uses patch size 16. The input is resized to 256x256; when applying DinoV3 on
+    segmentation or detection tasks with inputs larger than 256x256, it may be best to
+    train and predict on 256x256 crops (using SplitConfig.patch_size argument).
+    """
+
+    image_size: int = 256
+    patch_size: int = 16
+    output_dim: int = 1024
+
+    def _load_model(self, size: str, checkpoint_dir: str | None) -> torch.nn.Module:
+        model_name = size.replace("_sat", "")
+        if checkpoint_dir is not None:
+            weights = str(Path(checkpoint_dir) / DINOV3_PTHS[size])
+            return torch.hub.load(
+                "facebookresearch/dinov3",
+                model_name,
+                weights=weights,
+            )  # nosec
+        return torch.hub.load("facebookresearch/dinov3", model_name, pretrained=False)  # nosec
+
+    def __init__(
+        self,
+        checkpoint_dir: str | None,
+        size: str = DinoV3Models.LARGE_SATELLITE,
+        use_cls_token: bool = False,
+        do_resizing: bool = True,
+    ) -> None:
+        """Instantiate a new DinoV3 instance.
+
+        Args:
+            checkpoint_dir: the local path to the pretrained weight dir. If None, we load the architecture
+                only (randomly initialized).
+            size: the model size, see class for various models.
+            use_cls_token: use pooled class token (for classification), otherwise returns spatial feature map.
+            do_resizing: whether to resize inputs to 256x256. Default true.
+        """
+        super().__init__()
+        self.size = size
+        self.checkpoint_dir = checkpoint_dir
+        self.use_cls_token = use_cls_token
+        self.do_resizing = do_resizing
+        self.model = self._load_model(size, checkpoint_dir)
+
+    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+        """Forward pass for the dinov3 model.
+
+        Args:
+            inputs: input dicts that must include "image" key.
+
+        Returns:
+            List[torch.Tensor]: Single-scale feature tensors from the encoder.
+        """
+        cur = torch.stack([inp["image"] for inp in inputs], dim=0)  # (B, C, H, W)
+
+        if self.do_resizing and (
+            cur.shape[2] != self.image_size or cur.shape[3] != self.image_size
+        ):
+            cur = torchvision.transforms.functional.resize(
+                cur,
+                [self.image_size, self.image_size],
+            )
+
+        if self.use_cls_token:
+            features = self.model(cur)
+        else:
+            features = self.model.forward_features(cur)["x_norm_patchtokens"]
+            batch_size, num_patches, _ = features.shape
+            height, width = int(num_patches**0.5), int(num_patches**0.5)
+            features = rearrange(features, "b (h w) d -> b d h w", h=height, w=width)
+
+        return [features]
+
+    def get_backbone_channels(self) -> list:
+        """Returns the output channels of this model when used as a backbone.
+
+        The output channels is a list of (downsample_factor, depth) that corresponds
+        to the feature maps that the backbone returns. For example, an element [2, 32]
+        indicates that the corresponding feature map is 1/2 the input resolution and
+        has 32 channels.
+        """
+        return [(self.patch_size, self.output_dim)]
+
+
+class DinoV3Normalize(Transform):
+    """Normalize inputs using DinoV3 normalization.
+
+    Normalize the "image" key in the input according to the DinoV3 pretraining statistics. Satellite pretraining uses slightly different normalization than the base image models, so set 'satellite' according to which pretrained model you are using.
+
+    Input "image" should be an RGB-like image with values between 0-255.
+    """
+
+    def __init__(self, satellite: bool = True):
+        """Initialize a new DinoV3Normalize."""
+        super().__init__()
+        self.satellite = satellite
+        if satellite:
+            self.normalize = v2.Normalize(
+                mean=(0.430, 0.411, 0.296),
+                std=(0.213, 0.156, 0.143),
+            )
+        else:
+            self.normalize = v2.Normalize(
+                mean=(0.485, 0.456, 0.406),
+                std=(0.229, 0.224, 0.225),
+            )
+
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Normalize the specified image with DinoV3 normalization.
+
+        Args:
+            input_dict: the input dictionary.
+            target_dict: the target dictionary.
+
+        Returns:
+            normalized (input_dicts, target_dicts) tuple
+        """
+        input_dict["image"] = self.normalize(input_dict["image"] / 255.0)
+        return input_dict, target_dict
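Since forward() consumes a list of per-sample input dicts and, without the class token, returns a single 16x-downsampled feature map, the two new classes compose as in this sketch. The random input and checkpoint_dir=None are illustrative (the latter yields a randomly initialized backbone), and the Transform instance is assumed callable like any torch.nn.Module:

import torch

from rslearn.models.dinov3 import DinoV3, DinoV3Normalize

# Architecture only (random weights); pass a local weight dir for pretrained weights.
model = DinoV3(checkpoint_dir=None)
normalize = DinoV3Normalize(satellite=True)

# One sample: an RGB image with 0-255 values, normalized before the backbone.
input_dict = {"image": torch.randint(0, 256, (3, 256, 256)).float()}
input_dict, _ = normalize(input_dict, {})

features = model([input_dict])
print(features[0].shape)  # torch.Size([1, 1024, 16, 16]), matching get_backbone_channels()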