rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/models/use_croma.py
@@ -0,0 +1,508 @@
+"""CROMA code, copied from https://github.com/antofuller/CROMA.
+
+The code is released under:
+
+MIT License
+Copyright (c) 2023 Anthony Fuller
+"""
+
+import itertools
+import math
+import warnings
+
+import torch
+from einops import rearrange
+from torch import einsum, nn
+
+
+class PretrainedCROMA(nn.Module):
+    """Pre-trained CROMA model."""
+
+    def __init__(
+        self,
+        pretrained_path: str = "CROMA_base.pt",
+        size: str = "base",
+        modality: str = "both",
+        image_resolution: int = 120,
+    ):
+        """Create a new PretrainedCROMA.
+
+        NOTE: image_resolution is not the spatial, spectral, or temporal resolution. It is the height and width of the image, in pixels.
+        E.g., CROMA was pretrained on 120x120px images, hence image_resolution is 120 by default
+        """
+        super().__init__()
+        # check types
+        assert isinstance(pretrained_path, str), (
+            f"pretrained_path must be a string, not {type(pretrained_path)}"
+        )
+        assert isinstance(size, str), f"size must be a string, not {type(size)}"
+        assert isinstance(modality, str), (
+            f"modality must be a string, not {type(modality)}"
+        )
+        assert isinstance(image_resolution, int), (
+            f"image_resolution must be an int, not {type(image_resolution)}"
+        )
+
+        # check values
+        assert size in [
+            "base",
+            "large",
+        ], f"size must be either base or large, not {size}"
+        assert image_resolution % 8 == 0, (
+            f"image_resolution must be a multiple of 8, not {image_resolution}"
+        )
+        assert modality in [
+            "both",
+            "SAR",
+            "optical",
+        ], f"modality must be either both, SAR, or optical, not {modality}"
+
+        # warn the user if the path contains a different size than the size parameter
+        if size == "base" and "large" in pretrained_path:
+            warnings.warn(
+                "The size is set to base, but the word large appears in the pretrained path!"
+            )
+        elif size == "large" and "base" in pretrained_path:
+            warnings.warn(
+                "The size is set to large, but the word base appears in the pretrained path!"
+            )
+
+        if size == "base":
+            self.encoder_dim = 768
+            self.encoder_depth = 12
+            self.num_heads = 16
+            self.patch_size = 8
+        else:
+            # large by default
+            self.encoder_dim = 1024
+            self.encoder_depth = 24
+            self.num_heads = 16
+            self.patch_size = 8
+
+        self.modality = modality
+        self.num_patches = int((image_resolution / 8) ** 2)
+        self.s1_channels = 2  # fixed at 2 SAR backscatter channels
+        self.s2_channels = 12  # fixed at 12 multispectral optical channels
+        self.attn_bias = get_2dalibi(
+            num_heads=self.num_heads, num_patches=self.num_patches
+        )
+
+        if modality in ["SAR", "both"]:
+            print("Initializing SAR encoder")
+            self.s1_encoder = ViT(
+                dim=self.encoder_dim,
+                depth=int(self.encoder_depth / 2),
+                in_channels=self.s1_channels,
+            )
+            self.GAP_FFN_s1 = nn.Sequential(
+                nn.LayerNorm(self.encoder_dim),
+                nn.Linear(
+                    self.encoder_dim, int(4 * self.encoder_dim)
+                ),  # (BSZ, num_patches, inner_dim)
+                nn.GELU(),  # (BSZ, num_patches, inner_dim)
+                nn.Linear(
+                    int(4 * self.encoder_dim), self.encoder_dim
+                ),  # (BSZ, num_patches, dim)
+            )
+
+            # load weights
+            self.s1_encoder.load_state_dict(
+                torch.load(pretrained_path, weights_only=True)["s1_encoder"]
+            )
+            self.GAP_FFN_s1.load_state_dict(
+                torch.load(pretrained_path, weights_only=True)["s1_GAP_FFN"]
+            )
+
+        if modality in ["optical", "both"]:
+            print("Initializing optical encoder")
+            self.s2_encoder = ViT(
+                dim=self.encoder_dim,
+                depth=self.encoder_depth,
+                in_channels=self.s2_channels,
+            )
+            self.GAP_FFN_s2 = nn.Sequential(
+                nn.LayerNorm(self.encoder_dim),
+                nn.Linear(
+                    self.encoder_dim, int(4 * self.encoder_dim)
+                ),  # (BSZ, num_patches, inner_dim)
+                nn.GELU(),  # (BSZ, num_patches, inner_dim)
+                nn.Linear(
+                    int(4 * self.encoder_dim), self.encoder_dim
+                ),  # (BSZ, num_patches, dim)
+            )
+
+            # load weights
+            self.s2_encoder.load_state_dict(
+                torch.load(pretrained_path, weights_only=True)["s2_encoder"]
+            )
+            self.GAP_FFN_s2.load_state_dict(
+                torch.load(pretrained_path, weights_only=True)["s2_GAP_FFN"]
+            )
+
+        if modality == "both":
+            print("Initializing joint SAR-optical encoder")
+            self.cross_encoder = BaseTransformerCrossAttn(
+                dim=self.encoder_dim,
+                depth=int(self.encoder_depth / 2),
+                num_heads=self.num_heads,
+            )
+
+            # load weights
+            self.cross_encoder.load_state_dict(
+                torch.load(pretrained_path, weights_only=True)["joint_encoder"]
+            )
+
+    def forward(
+        self,
+        SAR_images: torch.Tensor | None = None,
+        optical_images: torch.Tensor | None = None,
+    ) -> dict[str, torch.Tensor]:
+        """Forward pass through PretrainedCROMA."""
+        return_dict = {}
+        if self.modality in ["SAR", "both"]:
+            assert SAR_images is not None, (
+                f"Modality is set to {self.modality}, but SAR_images are None"
+            )
+            SAR_encodings = self.s1_encoder(
+                imgs=SAR_images, attn_bias=self.attn_bias.to(SAR_images.device)
+            )  # (bsz, num_patches, encoder_dim)
+            SAR_GAP = self.GAP_FFN_s1(SAR_encodings.mean(dim=1))  # (bsz, encoder_dim)
+            return_dict["SAR_encodings"] = SAR_encodings
+            return_dict["SAR_GAP"] = SAR_GAP
+
+        if self.modality in ["optical", "both"]:
+            assert optical_images is not None, (
+                f"Modality is set to {self.modality}, but optical_images are None"
+            )
+            optical_encodings = self.s2_encoder(
+                imgs=optical_images, attn_bias=self.attn_bias.to(optical_images.device)
+            )  # (bsz, num_patches, encoder_dim)
+            optical_GAP = self.GAP_FFN_s2(
+                optical_encodings.mean(dim=1)
+            )  # (bsz, encoder_dim)
+            return_dict["optical_encodings"] = optical_encodings
+            return_dict["optical_GAP"] = optical_GAP
+
+        if self.modality == "both":
+            assert SAR_images is not None
+            assert optical_images is not None
+            joint_encodings = self.cross_encoder(
+                x=SAR_encodings,
+                context=optical_encodings,
+                relative_position_bias=self.attn_bias.to(optical_images.device),
+            )  # (bsz, num_patches, encoder_dim)
+            joint_GAP = joint_encodings.mean(dim=1)  # (bsz, encoder_dim)
+            return_dict["joint_encodings"] = joint_encodings
+            return_dict["joint_GAP"] = joint_GAP
+
+        return return_dict
+
+
+def get_2dalibi(num_heads: int, num_patches: int) -> torch.Tensor:
+    """Get 2D bias initialization for attention layer.
+
+    Args:
+        num_heads: the number of heads in the attention layer.
+        num_patches: the total number of patches, which should be a square.
+    """
+    # inspired by: https://github.com/ofirpress/attention_with_linear_biases
+    points = list(
+        itertools.product(
+            range(int(math.sqrt(num_patches))), range(int(math.sqrt(num_patches)))
+        )
+    )
+
+    def get_slopes(n: int) -> list[float]:
+        def get_slopes_power_of_2(n: int) -> list[float]:
+            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
+            ratio = start
+            return [start * ratio**i for i in range(n)]
+
+        if math.log2(n).is_integer():
+            return get_slopes_power_of_2(n)
+        else:
+            closest_power_of_2 = 2 ** math.floor(math.log2(n))
+            return (
+                get_slopes_power_of_2(closest_power_of_2)
+                + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
+            )
+
+    slopes = torch.Tensor(get_slopes(num_heads)).unsqueeze(1)
+    idxs = []
+    for p1 in points:
+        for p2 in points:
+            dist = math.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)
+            idxs.append(dist * slopes * -1)
+    all_bias = torch.cat(idxs, dim=1)
+    return all_bias.view(1, num_heads, num_patches, num_patches)
+
+
+class FFN(nn.Module):
+    """Feed-forward network block."""
+
+    def __init__(
+        self,
+        dim: int,
+        mult: int = 4,
+        dropout: float = 0.0,
+    ):
+        """Create a new FFN.
+
+        Args:
+            dim: the input dimension.
+            mult: the MLP factor (how much larger the hidden dimension should be).
+            dropout: the dropout rate.
+        """
+        super().__init__()
+        inner_dim = int(dim * mult)
+
+        self.net = nn.Sequential(
+            nn.Linear(dim, inner_dim),  # (BSZ, num_patches, inner_dim)
+            nn.GELU(),  # (BSZ, num_patches, inner_dim)
+            nn.Dropout(dropout),  # (BSZ, num_patches, inner_dim)
+            nn.Linear(inner_dim, dim),  # (BSZ, num_patches, dim)
+        )
+        self.input_norm = nn.LayerNorm(dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass through the FFN."""
+        x = self.input_norm(x)  # (BSZ, num_patches, dim)
+        return self.net(x)  # (BSZ, num_patches, dim)
+
+
+class Attention(nn.Module):
+    """Attention block."""
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        dropout: float = 0.0,
+    ):
+        """Create a new Attention."""
+        super().__init__()
+        self.num_heads = num_heads
+        assert dim % num_heads == 0, "dim must be evenly divisible by num_heads"
+        dim_head = int(dim / num_heads)
+        self.scale = dim_head**-0.5
+
+        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
+        self.to_out = nn.Linear(dim, dim)
+        self.input_norm = nn.LayerNorm(dim)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(
+        self, x: torch.Tensor, relative_position_bias: torch.Tensor
+    ) -> torch.Tensor:
+        """Forward pass through the Attention."""
+        x = self.input_norm(x)  # (BSZ, num_patches, dim)
+        q, k, v = self.to_qkv(x).chunk(3, dim=-1)  # (BSZ, num_patches, dim)
+        q, k, v = map(
+            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (q, k, v)
+        )  # (BSZ, num_heads, num_patches, dim_head)
+
+        attention_scores = (
+            einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
+        )  # (BSZ, num_heads, num_patches, num_patches)
+        attention_scores = (
+            attention_scores + relative_position_bias
+        )  # (BSZ, num_heads, num_patches, num_patches)
+
+        attn = attention_scores.softmax(
+            dim=-1
+        )  # (BSZ, num_heads, num_patches, num_patches)
+        attn = self.dropout(attn)  # (BSZ, num_heads, num_patches, num_patches)
+
+        out = einsum(
+            "b h i j, b h j d -> b h i d", attn, v
+        )  # (BSZ, num_heads, num_patches, dim_head)
+        out = rearrange(out, "b h n d -> b n (h d)")  # (BSZ, num_patches, dim)
+        return self.to_out(out)  # (BSZ, num_patches, dim)
+
+
+class CrossAttention(nn.Module):
+    """Cross-attention block."""
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        dropout: float = 0.0,
+    ):
+        """Create a new CrossAttention."""
+        super().__init__()
+        self.num_heads = num_heads
+        assert dim % num_heads == 0, "dim must be evenly divisible by num_heads"
+        dim_head = int(dim / num_heads)
+        self.scale = dim_head**-0.5
+
+        self.to_q = nn.Linear(dim, dim, bias=False)
+        self.to_k = nn.Linear(dim, dim, bias=False)
+        self.to_v = nn.Linear(dim, dim, bias=False)
+
+        self.to_out = nn.Linear(dim, dim)
+        self.input_norm = nn.LayerNorm(dim)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        context: torch.Tensor,
+        relative_position_bias: torch.Tensor,
+    ) -> torch.Tensor:
+        """Forward pass through the CrossAttention."""
+        x = self.input_norm(x)  # (BSZ, num_patches, dim)
+        context = self.input_norm(context)  # (BSZ, num_patches, dim)
+
+        q = self.to_q(x)  # (BSZ, num_patches, dim)
+        k = self.to_k(context)  # (BSZ, num_patches, dim)
+        v = self.to_v(context)  # (BSZ, num_patches, dim)
+
+        q, k, v = map(
+            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (q, k, v)
+        )  # (BSZ, num_heads, num_patches, dim_head)
+
+        attention_scores = (
+            einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
+        )  # (BSZ, num_heads, num_patches, num_patches)
+        attention_scores = (
+            attention_scores + relative_position_bias
+        )  # (BSZ, num_heads, num_patches, num_patches)
+
+        attn = attention_scores.softmax(
+            dim=-1
+        )  # (BSZ, num_heads, num_patches, num_patches)
+        attn = self.dropout(attn)  # (BSZ, num_heads, num_patches, num_patches)
+
+        out = einsum(
+            "b h i j, b h j d -> b h i d", attn, v
+        )  # (BSZ, num_heads, num_patches, dim_head)
+        out = rearrange(out, "b h n d -> b n (h d)")  # (BSZ, num_patches, dim)
+        return self.to_out(out)  # (BSZ, num_patches, dim)
+
+
+class BaseTransformer(nn.Module):
+    """Base transformer class."""
+
+    def __init__(
+        self,
+        dim: int,
+        depth: int,
+        num_heads: int = 8,
+        attn_dropout: float = 0.0,
+        ff_dropout: float = 0.0,
+        ff_mult: int = 4,
+        final_norm: bool = True,
+    ):
+        """Create a new BaseTransformer."""
+        super().__init__()
+        self.final_norm = final_norm
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(
+                nn.ModuleList(
+                    [
+                        Attention(dim=dim, num_heads=num_heads, dropout=attn_dropout),
+                        FFN(dim=dim, mult=ff_mult, dropout=ff_dropout),
+                    ]
+                )
+            )
+
+        if self.final_norm:
+            self.norm_out = nn.LayerNorm(dim)
+
+    def forward(
+        self, x: torch.Tensor, relative_position_bias: torch.Tensor
+    ) -> torch.Tensor:
+        """Forward pass through the BaseTransformer."""
+        for self_attn, ffn in self.layers:
+            x = self_attn(x, relative_position_bias) + x  # (BSZ, num_patches, dim)
+            x = ffn(x) + x  # (BSZ, num_patches, dim)
+
+        if self.final_norm:
+            return self.norm_out(x)
+        else:
+            return x
+
+
+class BaseTransformerCrossAttn(nn.Module):
+    """Base transformer class for cross attention."""
+
+    def __init__(
+        self,
+        dim: int,
+        depth: int,
+        num_heads: int = 8,
+        attn_dropout: float = 0.0,
+        ff_dropout: float = 0.0,
+        ff_mult: int = 4,
+    ):
+        """Create a new BaseTransformerCrossAttn."""
+        super().__init__()
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(
+                nn.ModuleList(
+                    [
+                        Attention(dim=dim, num_heads=num_heads, dropout=attn_dropout),
+                        CrossAttention(
+                            dim=dim, num_heads=num_heads, dropout=attn_dropout
+                        ),
+                        FFN(dim=dim, mult=ff_mult, dropout=ff_dropout),
+                    ]
+                )
+            )
+
+        self.norm_out = nn.LayerNorm(dim)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        context: torch.Tensor,
+        relative_position_bias: torch.Tensor,
+    ) -> torch.Tensor:
+        """Forward pass through the BaseTransformerCrossAttn."""
+        for self_attn, cross_attn, ffn in self.layers:
+            x = self_attn(x, relative_position_bias) + x  # (BSZ, num_patches, dim)
+            x = (
+                cross_attn(x, context, relative_position_bias) + x
+            )  # (BSZ, num_patches, dim)
+            x = ffn(x) + x  # (BSZ, num_patches, dim)
+
+        x = self.norm_out(x)
+        return x  # (BSZ, num_patches, dim)
+
+
+class ViT(nn.Module):
+    """ViT model."""
+
+    def __init__(self, dim: int, depth: int, in_channels: int):
+        """Create a new ViT."""
+        super().__init__()
+        self.depth = depth
+        self.in_channels = in_channels
+        self.dim = dim
+        self.num_heads = 16  # always 16, for base and large models
+        self.patch_size = 8  # always 8, for base and large models
+
+        pixels_per_patch = int(self.patch_size * self.patch_size * in_channels)
+        self.linear_input = nn.Linear(pixels_per_patch, self.dim)
+        self.transformer = BaseTransformer(
+            dim=self.dim,
+            depth=self.depth,
+            num_heads=self.num_heads,
+        )
+
+    def forward(self, imgs: torch.Tensor, attn_bias: torch.Tensor) -> torch.Tensor:
+        """Forward pass through the ViT."""
+        x = rearrange(
+            imgs,
+            "b c (h i) (w j) -> b (h w) (c i j)",
+            i=self.patch_size,
+            j=self.patch_size,
+        )
+        # x is shape -> (bsz, num_patches, self.channels*self.patch_size*self.patch_size)
+
+        x = self.linear_input(x)  # (bsz, num_patches, dim)
+        x = self.transformer(x, relative_position_bias=attn_bias)
+        return x
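A minimal usage sketch of the module added above (not part of the diff; it assumes this hunk is rslearn/models/use_croma.py from the file list and that a CROMA_base.pt checkpoint from the upstream CROMA repository is available locally):

import torch
from rslearn.models.use_croma import PretrainedCROMA

# Base model: encoder_dim=768, 8x8 patches, so 120x120 inputs give 15*15=225 patches.
model = PretrainedCROMA(
    pretrained_path="CROMA_base.pt",
    size="base",
    modality="both",
    image_resolution=120,
)
model.eval()

sar = torch.randn(1, 2, 120, 120)       # (BSZ, s1_channels, H, W)
optical = torch.randn(1, 12, 120, 120)  # (BSZ, s2_channels, H, W)

with torch.no_grad():
    out = model(SAR_images=sar, optical_images=optical)

print(out["joint_GAP"].shape)           # torch.Size([1, 768]), pooled joint embedding
print(out["optical_encodings"].shape)   # torch.Size([1, 225, 768]), per-patch features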
rslearn/template_params.py
@@ -0,0 +1,26 @@
+"""Template parameter substitution utilities for rslearn configuration files."""
+
+import os
+import re
+
+
+def substitute_env_vars_in_string(content: str) -> str:
+    """Substitute environment variables in a string.
+
+    Replaces ${VAR_NAME} patterns with os.getenv(VAR_NAME, "") values.
+    This works on raw string content before YAML/JSON parsing.
+
+    Args:
+        content: The string content containing template variables
+
+    Returns:
+        The string with environment variables substituted
+    """
+    pattern = r"\$\{([^}]+)\}"
+
+    def replace_variable(match_obj: re.Match[str]) -> str:
+        var_name = match_obj.group(1)
+        env_value = os.getenv(var_name, "")
+        return env_value if env_value is not None else ""
+
+    return re.sub(pattern, replace_variable, content)
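A short illustration of the substitution behavior (not part of the diff; the variable names are hypothetical):

import os
from rslearn.template_params import substitute_env_vars_in_string

os.environ["DATASET_ROOT"] = "gs://my-bucket/dataset"  # hypothetical variable
raw = 'path: "${DATASET_ROOT}/windows", token: "${UNSET_VAR}"'

# Defined variables are substituted; unset variables become "".
print(substitute_env_vars_in_string(raw))
# path: "gs://my-bucket/dataset/windows", token: ""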
rslearn/tile_stores/__init__.py
@@ -1,37 +1,60 @@
 """Tile stores that store ingested raster and vector data before materialization."""
 
-from upath import UPath
+from typing import Any
 
-from rslearn.config import TileStoreConfig
+import jsonargparse
+from upath import UPath
 
-from .file import FileTileStore
-from .tile_store import (
-    LayerMetadata,
-    PrefixedTileStore,
-    TileStore,
-    TileStoreLayer,
-    get_tile_store_for_layer,
-)
+from rslearn.config import LayerConfig
+from rslearn.utils.jsonargparse import init_jsonargparse
 
-registry = {"file": FileTileStore}
+from .default import DefaultTileStore
+from .tile_store import TileStore, TileStoreWithLayer
 
 
-def load_tile_store(config: TileStoreConfig, ds_path: UPath) -> TileStore:
+def load_tile_store(config: dict[str, Any], ds_path: UPath) -> TileStore:
     """Load a tile store from a configuration.
 
     Args:
         config: the tile store configuration.
         ds_path: the dataset root path.
+
+    Returns:
+        the TileStore
+    """
+    init_jsonargparse()
+    parser = jsonargparse.ArgumentParser()
+    parser.add_argument("--tile_store", type=TileStore)
+    cfg = parser.parse_object({"tile_store": config})
+    tile_store = parser.instantiate_classes(cfg).tile_store
+    tile_store.set_dataset_path(ds_path)
+    return tile_store
+
+
+def get_tile_store_with_layer(
+    tile_store: TileStore, layer_name: str, layer_cfg: LayerConfig
+) -> TileStoreWithLayer:
+    """Get the TileStoreWithLayer for the specified layer.
+
+    Uses alias of the layer if it is set, otherwise just the layer name.
+
+    Args:
+        tile_store: the tile store.
+        layer_name: the layer name.
+        layer_cfg: the layer configuration which can specify an alias.
+
+    Returns:
+        corresponding TileStoreWithLayer
     """
-    return registry[config.name].from_config(config, ds_path)
+    if layer_cfg.alias is not None:
+        return TileStoreWithLayer(tile_store, layer_cfg.alias)
+    return TileStoreWithLayer(tile_store, layer_name)
 
 
 __all__ = (
-    "FileTileStore",
-    "LayerMetadata",
-    "PrefixedTileStore",
+    "DefaultTileStore",
     "TileStore",
-    "TileStoreLayer",
+    "TileStoreWithLayer",
     "load_tile_store",
-    "get_tile_store_for_layer",
+    "get_tile_store_with_layer",
 )
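A sketch of calling the reworked load_tile_store (not part of the diff). The config dict is hypothetical and assumes jsonargparse's usual class_path/init_args convention for subclass types, with DefaultTileStore living in rslearn/tile_stores/default.py per the file list:

from upath import UPath
from rslearn.tile_stores import load_tile_store

# Hypothetical tile store config; jsonargparse resolves the subclass from class_path.
config = {"class_path": "rslearn.tile_stores.default.DefaultTileStore"}
tile_store = load_tile_store(config, UPath("/path/to/dataset"))
# load_tile_store instantiates the class via jsonargparse and then calls set_dataset_path on it.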