rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/models/croma.py
ADDED
@@ -0,0 +1,306 @@

"""CROMA models."""

import shutil
import tempfile
import urllib.request
from enum import Enum
from typing import Any

import torch
import torch.nn.functional as F
from einops import rearrange
from upath import UPath

from rslearn.log_utils import get_logger
from rslearn.train.model_context import ModelContext
from rslearn.train.transforms.transform import Transform
from rslearn.utils.fsspec import open_atomic

from .component import FeatureExtractor, FeatureMaps
from .use_croma import PretrainedCROMA

logger = get_logger(__name__)


class CromaSize(str, Enum):
    """CROMA model size."""

    BASE = "base"
    LARGE = "large"


class CromaModality(str, Enum):
    """CROMA model configured input modalities."""

    BOTH = "both"
    SENTINEL1 = "SAR"
    SENTINEL2 = "optical"


PATCH_SIZE = 8
DEFAULT_IMAGE_RESOLUTION = 120
PRETRAINED_URLS: dict[CromaSize, str] = {
    CromaSize.BASE: "https://huggingface.co/antofuller/CROMA/resolve/main/CROMA_base.pt",
    CromaSize.LARGE: "https://huggingface.co/antofuller/CROMA/resolve/main/CROMA_large.pt",
}
MEAN_AND_STD_BY_BAND: dict[tuple[str, str], tuple[float, float]] = {
    ("sentinel1", "vv"): (0.15, 0.82),
    ("sentinel1", "vh"): (0.03, 0.15),
    ("sentinel2", "B01"): (1116, 1956),
    ("sentinel2", "B02"): (1189, 1859),
    ("sentinel2", "B03"): (1408, 1728),
    ("sentinel2", "B04"): (1513, 1741),
    ("sentinel2", "B05"): (1891, 1755),
    ("sentinel2", "B06"): (2484, 1622),
    ("sentinel2", "B07"): (2723, 1622),
    ("sentinel2", "B08"): (2755, 1612),
    ("sentinel2", "B8A"): (2886, 1611),
    ("sentinel2", "B09"): (3270, 2651),
    ("sentinel2", "B11"): (2563, 1442),
    ("sentinel2", "B12"): (1914, 1329),
}
MODALITY_BANDS = {
    "sentinel1": ["vv", "vh"],
    "sentinel2": [
        "B01",
        "B02",
        "B03",
        "B04",
        "B05",
        "B06",
        "B07",
        "B08",
        "B8A",
        "B09",
        "B11",
        "B12",
    ],
}


class Croma(FeatureExtractor):
    """CROMA backbones.

    There are two model sizes, base and large.

    The model can be applied with just Sentinel-1, just Sentinel-2, or both. The input
    must be defined a priori by passing the corresponding CromaModality. Sentinel-1
    images should be passed under the "sentinel1" key while Sentinel-2 images should be
    passed under the "sentinel2" key. Only a single timestep can be provided.

    The band order for Sentinel-1 is: vv, vh.

    The band order for Sentinel-2 is: B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09,
    B11, B12. It is trained on L1C images with B10 removed.

    See https://github.com/antofuller/CROMA for more details.
    """

    def __init__(
        self,
        size: CromaSize,
        modality: CromaModality,
        pretrained_path: str | None = None,
        image_resolution: int = DEFAULT_IMAGE_RESOLUTION,
        do_resizing: bool = False,
    ) -> None:
        """Instantiate a new Croma instance.

        Args:
            size: the model size, either base or large.
            modality: the modalities to configure the model to accept.
            pretrained_path: the local path to the pretrained weights. Otherwise they
                are downloaded and cached in a temporary directory.
            image_resolution: the width and height of the input images passed to the
                model. If do_resizing is True, the image will be resized to this
                resolution.
            do_resizing: whether to resize the image to the input resolution.
        """
        super().__init__()
        self.size = size
        self.modality = modality
        self.do_resizing = do_resizing
        if not do_resizing:
            self.image_resolution = image_resolution
        else:
            # With single pixel input, we always resample to the patch size.
            if image_resolution == 1:
                self.image_resolution = PATCH_SIZE
            else:
                self.image_resolution = DEFAULT_IMAGE_RESOLUTION

        # Cache the CROMA weights to a deterministic path in the temporary directory
        # if the path is not provided by the user.
        if pretrained_path is None:
            pretrained_url = PRETRAINED_URLS[self.size]
            local_fname = UPath(
                tempfile.gettempdir(), "rslearn_cache", "croma", f"{self.size.value}.pt"
            )
            if not local_fname.exists():
                logger.info(
                    "caching CROMA weights from %s to %s", pretrained_url, local_fname
                )
                local_fname.parent.mkdir(parents=True, exist_ok=True)
                with urllib.request.urlopen(pretrained_url) as response:
                    with open_atomic(local_fname, "wb") as f:
                        shutil.copyfileobj(response, f)
            else:
                logger.info("using cached CROMA weights at %s", local_fname)
            pretrained_path = local_fname.path

        self.model = PretrainedCROMA(
            pretrained_path=pretrained_path,
            size=size.value,
            modality=modality.value,
            image_resolution=self.image_resolution,
        )

    def _resize_image(self, image: torch.Tensor) -> torch.Tensor:
        """Resize the image to the input resolution."""
        return F.interpolate(
            image,
            size=(self.image_resolution, self.image_resolution),
            mode="bilinear",
            align_corners=False,
        )

    def forward(self, context: ModelContext) -> FeatureMaps:
        """Compute feature maps from the Croma backbone.

        Args:
            context: the model context. Input dicts must include either/both of the
                "sentinel2" or "sentinel1" keys depending on the configured modality.

        Returns:
            a FeatureMaps with one feature map at 1/8 the input resolution.
        """
        sentinel1: torch.Tensor | None = None
        sentinel2: torch.Tensor | None = None
        if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL1]:
            sentinel1 = torch.stack(
                [inp["sentinel1"].single_ts_to_chw_tensor() for inp in context.inputs],
                dim=0,
            )
            sentinel1 = self._resize_image(sentinel1) if self.do_resizing else sentinel1
        if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL2]:
            sentinel2 = torch.stack(
                [inp["sentinel2"].single_ts_to_chw_tensor() for inp in context.inputs],
                dim=0,
            )
            sentinel2 = self._resize_image(sentinel2) if self.do_resizing else sentinel2

        outputs = self.model(
            SAR_images=sentinel1,
            optical_images=sentinel2,
        )

        # Pick which encoding to use.
        # If modality is both, there are three options; we could concatenate the SAR
        # and optical encodings, but for now we just use the joint encodings.
        if self.modality == CromaModality.BOTH:
            features = outputs["joint_encodings"]
        elif self.modality == CromaModality.SENTINEL1:
            features = outputs["SAR_encodings"]
        elif self.modality == CromaModality.SENTINEL2:
            features = outputs["optical_encodings"]

        # Rearrange from patch embeddings to 2D feature map.
        num_patches_per_dim = self.image_resolution // PATCH_SIZE
        features = rearrange(
            features,
            "b (h w) d -> b d h w",
            h=num_patches_per_dim,
            w=num_patches_per_dim,
        )

        return FeatureMaps([features])

    def get_backbone_channels(self) -> list:
        """Returns the output channels of this model when used as a backbone.

        The output channels are a list of (downsample_factor, depth) tuples that
        correspond to the feature maps that the backbone returns. For example, an
        element [2, 32] indicates that the corresponding feature map is 1/2 the input
        resolution and has 32 channels.

        Returns:
            the output channels of the backbone as a list of (downsample_factor, depth)
            tuples.
        """
        if self.size == CromaSize.BASE:
            depth = 768
        elif self.size == CromaSize.LARGE:
            depth = 1024
        else:
            raise ValueError(f"unknown CromaSize {self.size}")
        return [(PATCH_SIZE, depth)]


class CromaNormalize(Transform):
    """Normalize inputs using CROMA normalization.

    It will apply normalization to the "sentinel1" and "sentinel2" input keys (if set).
    """

    def __init__(self) -> None:
        """Initialize a new CromaNormalize."""
        super().__init__()

    def apply_image(self, image: torch.Tensor, modality: str) -> torch.Tensor:
        """Normalize the specified image with CROMA normalization.

        CROMA normalizes based on batch statistics, but we may apply the model with
        small batches, so we instead use preset statistics corresponding to the dataset
        distribution.

        The normalized value is based on clipping to [mean-2*std, mean+2*std] and then
        linear rescaling to [0, 1].

        Args:
            image: the image to transform.
            modality: the modality of the image.
        """
        image = image.float()

        # The number of channels must be a multiple of the expected number of bands for
        # this modality. It can be a multiple since we accept stacked time series.
        band_names = MODALITY_BANDS[modality]
        if image.shape[0] % len(band_names) != 0:
            raise ValueError(
                f"image has {image.shape[0]} channels for modality {modality} which is not a multiple of expected number of bands {len(band_names)}"
            )

        normalized_bands = []
        for band_idx in range(image.shape[0]):
            band_name = band_names[band_idx % len(band_names)]
            mean, std = MEAN_AND_STD_BY_BAND[(modality, band_name)]

            orig = image[band_idx, :, :]
            min_value = mean - 2 * std
            max_value = mean + 2 * std

            normalized = (orig - min_value) / (max_value - min_value)
            normalized = torch.clip(normalized, 0, 1)
            normalized_bands.append(normalized)

        return torch.stack(normalized_bands, dim=0)

    def forward(
        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Apply normalization over the inputs and targets.

        Args:
            input_dict: the input
            target_dict: the target

        Returns:
            normalized (input_dict, target_dict) tuple
        """
        for modality in MODALITY_BANDS.keys():
            if modality not in input_dict:
                continue
            input_dict[modality].image = self.apply_image(
                input_dict[modality].image, modality
            )
        return input_dict, target_dict
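
For reference, a minimal usage sketch of the Croma backbone and CromaNormalize defined above. This is not taken from rslearn's documentation: the tensor values and shapes are illustrative, and constructing a full ModelContext for forward() is omitted since that wiring is rslearn-internal.

import torch

from rslearn.models.croma import Croma, CromaModality, CromaNormalize, CromaSize

# Instantiating the backbone downloads and caches CROMA_base.pt from Hugging Face
# unless pretrained_path points at a local copy.
model = Croma(size=CromaSize.BASE, modality=CromaModality.SENTINEL2)
print(model.get_backbone_channels())  # [(8, 768)]: one map at 1/8 resolution, 768 channels

# Normalize a fake 12-band Sentinel-2 L1C chip (B10 removed) to [0, 1] using the
# preset per-band statistics, i.e. clip to mean +/- 2*std and rescale linearly.
raw = torch.randint(0, 4000, (12, 120, 120))
normalized = CromaNormalize().apply_image(raw, "sentinel2")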
rslearn/models/detr/box_ops.py
ADDED
@@ -0,0 +1,103 @@

"""Utilities for bounding box manipulation and GIoU.

This is copied from https://github.com/facebookresearch/detr/.
The original code is:
Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""

import torch
from torchvision.ops.boxes import box_area


def box_cxcywh_to_xyxy(x: torch.Tensor) -> torch.Tensor:
    """Convert boxes from cxcywh format to xyxy format."""
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)


def box_xyxy_to_cxcywh(x: torch.Tensor) -> torch.Tensor:
    """Convert boxes from xyxy format to cxcywh format."""
    x0, y0, x1, y1 = x.unbind(-1)
    b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
    return torch.stack(b, dim=-1)


# modified from torchvision to also return the union
def box_iou(
    boxes1: torch.Tensor, boxes2: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute the intersection-over-union score between the two lists of boxes.

    The boxes should be in xyxy format.

    Args:
        boxes1: the first list of boxes (Nx4).
        boxes2: the second list of boxes (Mx4).

    Returns:
        the intersection-over-union score.
    """
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    iou = inter / union
    return iou, union


def generalized_box_iou(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
    """Generalized IoU from https://giou.stanford.edu/.

    The boxes should be in [x0, y0, x1, y1] format.

    Returns an [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2).
    """
    # Degenerate boxes give inf / nan results, so do an early check.
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    iou, union = box_iou(boxes1, boxes2)

    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    area = wh[:, :, 0] * wh[:, :, 1]

    return iou - (area - union) / area


def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
    """Compute the bounding boxes around the provided masks.

    The masks should be in format [N, H, W] where N is the number of masks and (H, W)
    are the spatial dimensions.

    Returns an [N, 4] tensor, with the boxes in xyxy format.
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device)

    h, w = masks.shape[-2:]

    y = torch.arange(0, h, dtype=torch.float)
    x = torch.arange(0, w, dtype=torch.float)
    y, x = torch.meshgrid(y, x)

    x_mask = masks * x.unsqueeze(0)
    x_max = x_mask.flatten(1).max(-1)[0]
    x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    y_mask = masks * y.unsqueeze(0)
    y_max = y_mask.flatten(1).max(-1)[0]
    y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    return torch.stack([x_min, y_min, x_max, y_max], 1)