rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/models/use_croma.py
ADDED
@@ -0,0 +1,508 @@
+"""CROMA code, copied from https://github.com/antofuller/CROMA.
+
+The code is released under:
+
+MIT License
+Copyright (c) 2023 Anthony Fuller
+"""
+
+import itertools
+import math
+import warnings
+
+import torch
+from einops import rearrange
+from torch import einsum, nn
+
+
+class PretrainedCROMA(nn.Module):
+    """Pre-trained CROMA model."""
+
+    def __init__(
+        self,
+        pretrained_path: str = "CROMA_base.pt",
+        size: str = "base",
+        modality: str = "both",
+        image_resolution: int = 120,
+    ):
+        """Create a new PretrainedCROMA.
+
+        NOTE: image_resolution is not the spatial, spectral, or temporal resolution. It is the height and width of the image, in pixels.
+        E.g., CROMA was pretrained on 120x120px images, hence image_resolution is 120 by default
+        """
+        super().__init__()
+        # check types
+        assert isinstance(pretrained_path, str), (
+            f"pretrained_path must be a string, not {type(pretrained_path)}"
+        )
+        assert isinstance(size, str), f"size must be a string, not {type(size)}"
+        assert isinstance(modality, str), (
+            f"modality must be a string, not {type(modality)}"
+        )
+        assert isinstance(image_resolution, int), (
+            f"image_resolution must be an int, not {type(image_resolution)}"
+        )
+
+        # check values
+        assert size in [
+            "base",
+            "large",
+        ], f"size must be either base or large, not {size}"
+        assert image_resolution % 8 == 0, (
+            f"image_resolution must be a multiple of 8, not {image_resolution}"
+        )
+        assert modality in [
+            "both",
+            "SAR",
+            "optical",
+        ], f"modality must be either both, SAR, or optical, not {modality}"
+
+        # warn the user if the path contains a different size than the size parameter
+        if size == "base" and "large" in pretrained_path:
+            warnings.warn(
+                "The size is set to base, but the word large appears in the pretrained path!"
+            )
+        elif size == "large" and "base" in pretrained_path:
+            warnings.warn(
+                "The size is set to large, but the word base appears in the pretrained path!"
+            )
+
+        if size == "base":
+            self.encoder_dim = 768
+            self.encoder_depth = 12
+            self.num_heads = 16
+            self.patch_size = 8
+        else:
+            # large by default
+            self.encoder_dim = 1024
+            self.encoder_depth = 24
+            self.num_heads = 16
+            self.patch_size = 8
+
+        self.modality = modality
+        self.num_patches = int((image_resolution / 8) ** 2)
+        self.s1_channels = 2  # fixed at 2 SAR backscatter channels
+        self.s2_channels = 12  # fixed at 12 multispectral optical channels
+        self.attn_bias = get_2dalibi(
+            num_heads=self.num_heads, num_patches=self.num_patches
+        )
+
+        if modality in ["SAR", "both"]:
+            print("Initializing SAR encoder")
+            self.s1_encoder = ViT(
+                dim=self.encoder_dim,
+                depth=int(self.encoder_depth / 2),
+                in_channels=self.s1_channels,
+            )
+            self.GAP_FFN_s1 = nn.Sequential(
+                nn.LayerNorm(self.encoder_dim),
+                nn.Linear(
+                    self.encoder_dim, int(4 * self.encoder_dim)
+                ),  # (BSZ, num_patches, inner_dim)
+                nn.GELU(),  # (BSZ, num_patches, inner_dim)
+                nn.Linear(
+                    int(4 * self.encoder_dim), self.encoder_dim
+                ),  # (BSZ, num_patches, dim)
+            )
+
+            # load weights
+            self.s1_encoder.load_state_dict(
+                torch.load(pretrained_path, weights_only=True)["s1_encoder"]
+            )
+            self.GAP_FFN_s1.load_state_dict(
+                torch.load(pretrained_path, weights_only=True)["s1_GAP_FFN"]
+            )
+
+        if modality in ["optical", "both"]:
+            print("Initializing optical encoder")
+            self.s2_encoder = ViT(
+                dim=self.encoder_dim,
+                depth=self.encoder_depth,
+                in_channels=self.s2_channels,
+            )
+            self.GAP_FFN_s2 = nn.Sequential(
+                nn.LayerNorm(self.encoder_dim),
+                nn.Linear(
+                    self.encoder_dim, int(4 * self.encoder_dim)
+                ),  # (BSZ, num_patches, inner_dim)
+                nn.GELU(),  # (BSZ, num_patches, inner_dim)
+                nn.Linear(
+                    int(4 * self.encoder_dim), self.encoder_dim
+                ),  # (BSZ, num_patches, dim)
+            )
+
+            # load weights
+            self.s2_encoder.load_state_dict(
+                torch.load(pretrained_path, weights_only=True)["s2_encoder"]
+            )
+            self.GAP_FFN_s2.load_state_dict(
+                torch.load(pretrained_path, weights_only=True)["s2_GAP_FFN"]
+            )
+
+        if modality == "both":
+            print("Initializing joint SAR-optical encoder")
+            self.cross_encoder = BaseTransformerCrossAttn(
+                dim=self.encoder_dim,
+                depth=int(self.encoder_depth / 2),
+                num_heads=self.num_heads,
+            )
+
+            # load weights
+            self.cross_encoder.load_state_dict(
+                torch.load(pretrained_path, weights_only=True)["joint_encoder"]
+            )
+
+    def forward(
+        self,
+        SAR_images: torch.Tensor | None = None,
+        optical_images: torch.Tensor | None = None,
+    ) -> dict[str, torch.Tensor]:
+        """Forward pass through PretrainedCROMA."""
+        return_dict = {}
+        if self.modality in ["SAR", "both"]:
+            assert SAR_images is not None, (
+                f"Modality is set to {self.modality}, but SAR_images are None"
+            )
+            SAR_encodings = self.s1_encoder(
+                imgs=SAR_images, attn_bias=self.attn_bias.to(SAR_images.device)
+            )  # (bsz, num_patches, encoder_dim)
+            SAR_GAP = self.GAP_FFN_s1(SAR_encodings.mean(dim=1))  # (bsz, encoder_dim)
+            return_dict["SAR_encodings"] = SAR_encodings
+            return_dict["SAR_GAP"] = SAR_GAP
+
+        if self.modality in ["optical", "both"]:
+            assert optical_images is not None, (
+                f"Modality is set to {self.modality}, but optical_images are None"
+            )
+            optical_encodings = self.s2_encoder(
+                imgs=optical_images, attn_bias=self.attn_bias.to(optical_images.device)
+            )  # (bsz, num_patches, encoder_dim)
+            optical_GAP = self.GAP_FFN_s2(
+                optical_encodings.mean(dim=1)
+            )  # (bsz, encoder_dim)
+            return_dict["optical_encodings"] = optical_encodings
+            return_dict["optical_GAP"] = optical_GAP
+
+        if self.modality == "both":
+            assert SAR_images is not None
+            assert optical_images is not None
+            joint_encodings = self.cross_encoder(
+                x=SAR_encodings,
+                context=optical_encodings,
+                relative_position_bias=self.attn_bias.to(optical_images.device),
+            )  # (bsz, num_patches, encoder_dim)
+            joint_GAP = joint_encodings.mean(dim=1)  # (bsz, encoder_dim)
+            return_dict["joint_encodings"] = joint_encodings
+            return_dict["joint_GAP"] = joint_GAP
+
+        return return_dict
+
+
+def get_2dalibi(num_heads: int, num_patches: int) -> torch.Tensor:
+    """Get 2D bias initialization for attention layer.
+
+    Args:
+        num_heads: the number of heads in the attention layer.
+        num_patches: the total number of patches, which should be a square.
+    """
+    # inspired by: https://github.com/ofirpress/attention_with_linear_biases
+    points = list(
+        itertools.product(
+            range(int(math.sqrt(num_patches))), range(int(math.sqrt(num_patches)))
+        )
+    )
+
+    def get_slopes(n: int) -> list[float]:
+        def get_slopes_power_of_2(n: int) -> list[float]:
+            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
+            ratio = start
+            return [start * ratio**i for i in range(n)]
+
+        if math.log2(n).is_integer():
+            return get_slopes_power_of_2(n)
+        else:
+            closest_power_of_2 = 2 ** math.floor(math.log2(n))
+            return (
+                get_slopes_power_of_2(closest_power_of_2)
+                + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
+            )
+
+    slopes = torch.Tensor(get_slopes(num_heads)).unsqueeze(1)
+    idxs = []
+    for p1 in points:
+        for p2 in points:
+            dist = math.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)
+            idxs.append(dist * slopes * -1)
+    all_bias = torch.cat(idxs, dim=1)
+    return all_bias.view(1, num_heads, num_patches, num_patches)
+
+
+class FFN(nn.Module):
+    """Feed-forward network block."""
+
+    def __init__(
+        self,
+        dim: int,
+        mult: int = 4,
+        dropout: float = 0.0,
+    ):
+        """Create a new FFN.
+
+        Args:
+            dim: the input dimension.
+            mult: the MLP factor (how much larger the hidden dimension should be).
+            dropout: the dropout rate.
+        """
+        super().__init__()
+        inner_dim = int(dim * mult)
+
+        self.net = nn.Sequential(
+            nn.Linear(dim, inner_dim),  # (BSZ, num_patches, inner_dim)
+            nn.GELU(),  # (BSZ, num_patches, inner_dim)
+            nn.Dropout(dropout),  # (BSZ, num_patches, inner_dim)
+            nn.Linear(inner_dim, dim),  # (BSZ, num_patches, dim)
+        )
+        self.input_norm = nn.LayerNorm(dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass through the FFN."""
+        x = self.input_norm(x)  # (BSZ, num_patches, dim)
+        return self.net(x)  # (BSZ, num_patches, dim)
+
+
+class Attention(nn.Module):
+    """Attention block."""
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        dropout: float = 0.0,
+    ):
+        """Create a new Attention."""
+        super().__init__()
+        self.num_heads = num_heads
+        assert dim % num_heads == 0, "dim must be evenly divisible by num_heads"
+        dim_head = int(dim / num_heads)
+        self.scale = dim_head**-0.5
+
+        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
+        self.to_out = nn.Linear(dim, dim)
+        self.input_norm = nn.LayerNorm(dim)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(
+        self, x: torch.Tensor, relative_position_bias: torch.Tensor
+    ) -> torch.Tensor:
+        """Forward pass through the Attention."""
+        x = self.input_norm(x)  # (BSZ, num_patches, dim)
+        q, k, v = self.to_qkv(x).chunk(3, dim=-1)  # (BSZ, num_patches, dim)
+        q, k, v = map(
+            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (q, k, v)
+        )  # (BSZ, num_heads, num_patches, dim_head)
+
+        attention_scores = (
+            einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
+        )  # (BSZ, num_heads, num_patches, num_patches)
+        attention_scores = (
+            attention_scores + relative_position_bias
+        )  # (BSZ, num_heads, num_patches, num_patches)
+
+        attn = attention_scores.softmax(
+            dim=-1
+        )  # (BSZ, num_heads, num_patches, num_patches)
+        attn = self.dropout(attn)  # (BSZ, num_heads, num_patches, num_patches)
+
+        out = einsum(
+            "b h i j, b h j d -> b h i d", attn, v
+        )  # (BSZ, num_heads, num_patches, dim_head)
+        out = rearrange(out, "b h n d -> b n (h d)")  # (BSZ, num_patches, dim)
+        return self.to_out(out)  # (BSZ, num_patches, dim)
+
+
+class CrossAttention(nn.Module):
+    """Cross-attention block."""
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        dropout: float = 0.0,
+    ):
+        """Create a new CrossAttention."""
+        super().__init__()
+        self.num_heads = num_heads
+        assert dim % num_heads == 0, "dim must be evenly divisible by num_heads"
+        dim_head = int(dim / num_heads)
+        self.scale = dim_head**-0.5
+
+        self.to_q = nn.Linear(dim, dim, bias=False)
+        self.to_k = nn.Linear(dim, dim, bias=False)
+        self.to_v = nn.Linear(dim, dim, bias=False)
+
+        self.to_out = nn.Linear(dim, dim)
+        self.input_norm = nn.LayerNorm(dim)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        context: torch.Tensor,
+        relative_position_bias: torch.Tensor,
+    ) -> torch.Tensor:
+        """Forward pass through the CrossAttention."""
+        x = self.input_norm(x)  # (BSZ, num_patches, dim)
+        context = self.input_norm(context)  # (BSZ, num_patches, dim)
+
+        q = self.to_q(x)  # (BSZ, num_patches, dim)
+        k = self.to_k(context)  # (BSZ, num_patches, dim)
+        v = self.to_v(context)  # (BSZ, num_patches, dim)
+
+        q, k, v = map(
+            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (q, k, v)
+        )  # (BSZ, num_heads, num_patches, dim_head)
+
+        attention_scores = (
+            einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
+        )  # (BSZ, num_heads, num_patches, num_patches)
+        attention_scores = (
+            attention_scores + relative_position_bias
+        )  # (BSZ, num_heads, num_patches, num_patches)
+
+        attn = attention_scores.softmax(
+            dim=-1
+        )  # (BSZ, num_heads, num_patches, num_patches)
+        attn = self.dropout(attn)  # (BSZ, num_heads, num_patches, num_patches)
+
+        out = einsum(
+            "b h i j, b h j d -> b h i d", attn, v
+        )  # (BSZ, num_heads, num_patches, dim_head)
+        out = rearrange(out, "b h n d -> b n (h d)")  # (BSZ, num_patches, dim)
+        return self.to_out(out)  # (BSZ, num_patches, dim)
+
+
+class BaseTransformer(nn.Module):
+    """Base transformer class."""
+
+    def __init__(
+        self,
+        dim: int,
+        depth: int,
+        num_heads: int = 8,
+        attn_dropout: float = 0.0,
+        ff_dropout: float = 0.0,
+        ff_mult: int = 4,
+        final_norm: bool = True,
+    ):
+        """Create a new BaseTransformer."""
+        super().__init__()
+        self.final_norm = final_norm
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(
+                nn.ModuleList(
+                    [
+                        Attention(dim=dim, num_heads=num_heads, dropout=attn_dropout),
+                        FFN(dim=dim, mult=ff_mult, dropout=ff_dropout),
+                    ]
+                )
+            )
+
+        if self.final_norm:
+            self.norm_out = nn.LayerNorm(dim)
+
+    def forward(
+        self, x: torch.Tensor, relative_position_bias: torch.Tensor
+    ) -> torch.Tensor:
+        """Forward pass through the BaseTransformer."""
+        for self_attn, ffn in self.layers:
+            x = self_attn(x, relative_position_bias) + x  # (BSZ, num_patches, dim)
+            x = ffn(x) + x  # (BSZ, num_patches, dim)
+
+        if self.final_norm:
+            return self.norm_out(x)
+        else:
+            return x
+
+
+class BaseTransformerCrossAttn(nn.Module):
+    """Base transformer class for cross attention."""
+
+    def __init__(
+        self,
+        dim: int,
+        depth: int,
+        num_heads: int = 8,
+        attn_dropout: float = 0.0,
+        ff_dropout: float = 0.0,
+        ff_mult: int = 4,
+    ):
+        """Create a new BaseTransformerCrossAttn."""
+        super().__init__()
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(
+                nn.ModuleList(
+                    [
+                        Attention(dim=dim, num_heads=num_heads, dropout=attn_dropout),
+                        CrossAttention(
+                            dim=dim, num_heads=num_heads, dropout=attn_dropout
+                        ),
+                        FFN(dim=dim, mult=ff_mult, dropout=ff_dropout),
+                    ]
+                )
+            )
+
+        self.norm_out = nn.LayerNorm(dim)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        context: torch.Tensor,
+        relative_position_bias: torch.Tensor,
+    ) -> torch.Tensor:
+        """Forward pass through the BaseTransformerCrossAttn."""
+        for self_attn, cross_attn, ffn in self.layers:
+            x = self_attn(x, relative_position_bias) + x  # (BSZ, num_patches, dim)
+            x = (
+                cross_attn(x, context, relative_position_bias) + x
+            )  # (BSZ, num_patches, dim)
+            x = ffn(x) + x  # (BSZ, num_patches, dim)
+
+        x = self.norm_out(x)
+        return x  # (BSZ, num_patches, dim)
+
+
+class ViT(nn.Module):
+    """ViT model."""
+
+    def __init__(self, dim: int, depth: int, in_channels: int):
+        """Create a new ViT."""
+        super().__init__()
+        self.depth = depth
+        self.in_channels = in_channels
+        self.dim = dim
+        self.num_heads = 16  # always 16, for base and large models
+        self.patch_size = 8  # always 8, for base and large models
+
+        pixels_per_patch = int(self.patch_size * self.patch_size * in_channels)
+        self.linear_input = nn.Linear(pixels_per_patch, self.dim)
+        self.transformer = BaseTransformer(
+            dim=self.dim,
+            depth=self.depth,
+            num_heads=self.num_heads,
+        )
+
+    def forward(self, imgs: torch.Tensor, attn_bias: torch.Tensor) -> torch.Tensor:
+        """Forward pass through the ViT."""
+        x = rearrange(
+            imgs,
+            "b c (h i) (w j) -> b (h w) (c i j)",
+            i=self.patch_size,
+            j=self.patch_size,
+        )
+        # x is shape -> (bsz, num_patches, self.channels*self.patch_size*self.patch_size)
+
+        x = self.linear_input(x)  # (bsz, num_patches, dim)
+        x = self.transformer(x, relative_position_bias=attn_bias)
+        return x
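The module above is self-contained, so a short usage sketch may help orient readers. It is illustrative only: it assumes a local `CROMA_base.pt` checkpoint (the constructor loads the pretrained weights eagerly, so the file must exist), and the input sizes follow the defaults shown in the diff.

```python
# Hypothetical usage sketch; "CROMA_base.pt" is an assumed local checkpoint path.
import torch

from rslearn.models.use_croma import PretrainedCROMA

model = PretrainedCROMA(
    pretrained_path="CROMA_base.pt",
    size="base",
    modality="both",
    image_resolution=120,
)
model.eval()

sar = torch.randn(1, 2, 120, 120)  # 2 SAR backscatter channels
optical = torch.randn(1, 12, 120, 120)  # 12 multispectral optical channels
with torch.no_grad():
    outputs = model(SAR_images=sar, optical_images=optical)

# For the base model: outputs["joint_GAP"] has shape (1, 768), and
# outputs["joint_encodings"] has shape (1, 225, 768), since (120 / 8) ** 2 = 225 patches.
```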
rslearn/template_params.py
ADDED
@@ -0,0 +1,26 @@
+"""Template parameter substitution utilities for rslearn configuration files."""
+
+import os
+import re
+
+
+def substitute_env_vars_in_string(content: str) -> str:
+    """Substitute environment variables in a string.
+
+    Replaces ${VAR_NAME} patterns with os.getenv(VAR_NAME, "") values.
+    This works on raw string content before YAML/JSON parsing.
+
+    Args:
+        content: The string content containing template variables
+
+    Returns:
+        The string with environment variables substituted
+    """
+    pattern = r"\$\{([^}]+)\}"
+
+    def replace_variable(match_obj: re.Match[str]) -> str:
+        var_name = match_obj.group(1)
+        env_value = os.getenv(var_name, "")
+        return env_value if env_value is not None else ""
+
+    return re.sub(pattern, replace_variable, content)
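The substitution helper operates on raw text before any YAML/JSON parsing, so a quick illustration of the behavior described in its docstring follows; the variable names and values are invented for the example.

```python
# Illustrative example only; DATASET_ROOT and MISSING_VAR are made-up names.
import os

from rslearn.template_params import substitute_env_vars_in_string

os.environ["DATASET_ROOT"] = "/data/rslearn"
raw = "ds_path: ${DATASET_ROOT}/dataset\ntoken: ${MISSING_VAR}"
print(substitute_env_vars_in_string(raw))
# ds_path: /data/rslearn/dataset
# token:            <- unset variables are replaced with an empty string
```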
rslearn/tile_stores/__init__.py
CHANGED
@@ -1,37 +1,60 @@
 """Tile stores that store ingested raster and vector data before materialization."""
 
-from
+from typing import Any
 
-
+import jsonargparse
+from upath import UPath
 
-from .
-from .
-    LayerMetadata,
-    PrefixedTileStore,
-    TileStore,
-    TileStoreLayer,
-    get_tile_store_for_layer,
-)
+from rslearn.config import LayerConfig
+from rslearn.utils.jsonargparse import init_jsonargparse
 
-
+from .default import DefaultTileStore
+from .tile_store import TileStore, TileStoreWithLayer
 
 
-def load_tile_store(config:
+def load_tile_store(config: dict[str, Any], ds_path: UPath) -> TileStore:
     """Load a tile store from a configuration.
 
     Args:
         config: the tile store configuration.
         ds_path: the dataset root path.
+
+    Returns:
+        the TileStore
+    """
+    init_jsonargparse()
+    parser = jsonargparse.ArgumentParser()
+    parser.add_argument("--tile_store", type=TileStore)
+    cfg = parser.parse_object({"tile_store": config})
+    tile_store = parser.instantiate_classes(cfg).tile_store
+    tile_store.set_dataset_path(ds_path)
+    return tile_store
+
+
+def get_tile_store_with_layer(
+    tile_store: TileStore, layer_name: str, layer_cfg: LayerConfig
+) -> TileStoreWithLayer:
+    """Get the TileStoreWithLayer for the specified layer.
+
+    Uses alias of the layer if it is set, otherwise just the layer name.
+
+    Args:
+        tile_store: the tile store.
+        layer_name: the layer name.
+        layer_cfg: the layer configuration which can specify an alias.
+
+    Returns:
+        corresponding TileStoreWithLayer
     """
-
+    if layer_cfg.alias is not None:
+        return TileStoreWithLayer(tile_store, layer_cfg.alias)
+    return TileStoreWithLayer(tile_store, layer_name)
 
 
 __all__ = (
-    "
-    "LayerMetadata",
-    "PrefixedTileStore",
+    "DefaultTileStore",
     "TileStore",
-    "
+    "TileStoreWithLayer",
     "load_tile_store",
-    "
+    "get_tile_store_with_layer",
 )
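Since `load_tile_store` now instantiates the tile store through jsonargparse rather than a hand-rolled registry, the configuration selects the implementation by class path. The sketch below is a hedged example of a call site: it assumes `DefaultTileStore` can be constructed with its default arguments, and the dataset path is illustrative.

```python
# Hedged sketch; the class_path and ds_path values are assumptions for illustration.
from upath import UPath

from rslearn.tile_stores import load_tile_store

config = {"class_path": "rslearn.tile_stores.default.DefaultTileStore"}
tile_store = load_tile_store(config, UPath("/tmp/example_dataset"))
```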