rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/models/anysat.py
@@ -0,0 +1,215 @@
+"""AnySat model.
+
+This code loads the AnySat model from torch hub. See
+https://github.com/gastruc/AnySat for applicable license and copyright information.
+"""
+
+from datetime import datetime
+
+import torch
+from einops import rearrange
+
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps
+
+# AnySat github: https://github.com/gastruc/AnySat
+# Modalities and expected resolutions (meters)
+MODALITY_RESOLUTIONS: dict[str, float] = {
+    "aerial": 0.2,
+    "aerial-flair": 0.2,
+    "spot": 1,
+    "naip": 1.25,
+    "s2": 10,
+    "s1-asc": 10,
+    "s1": 10,
+    "alos": 30,
+    "l7": 30,
+    "l8": 10,  # L8 must be upsampled to 10 m in AnySat
+    "modis": 250,
+}
+
+# Modalities and expected band names
+MODALITY_BANDS: dict[str, list[str]] = {
+    "aerial": ["R", "G", "B", "NiR"],
+    "aerial-flair": ["R", "G", "B", "NiR", "Elevation"],
+    "spot": ["R", "G", "B"],
+    "naip": ["R", "G", "B", "NiR"],
+    "s2": ["B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8a", "B11", "B12"],
+    "s1-asc": ["VV", "VH"],
+    "s1": ["VV", "VH", "Ratio"],
+    "alos": ["HH", "HV", "Ratio"],
+    "l7": ["B1", "B2", "B3", "B4", "B5", "B7"],
+    "l8": ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"],
+    "modis": ["B1", "B2", "B3", "B4", "B5", "B6", "B7"],
+}
+
+# Modalities that require *_dates* input
+TIME_SERIES_MODALITIES = {"s2", "s1-asc", "s1", "alos", "l7", "l8", "modis"}
+
+
+class AnySat(FeatureExtractor):
+    """AnySat backbone (outputs one feature map)."""
+
+    def __init__(
+        self,
+        modalities: list[str],
+        patch_size_meters: int,
+        output: str = "patch",
+        output_modality: str | None = None,
+        hub_repo: str = "gastruc/anysat",
+        pretrained: bool = True,
+        force_reload: bool = False,
+        flash_attn: bool = False,
+    ) -> None:
+        """Initialize an AnySat model.
+
+        Args:
+            modalities: list of modalities to use as input (1 or more).
+            patch_size_meters: patch size in meters (must be a multiple of 10). Avoid
+                having more than 1024 patches per tile, i.e. the height/width in
+                meters should be <= 32 * patch_size_meters.
+            output: 'patch' (default) or 'dense'. Use 'patch' for classification tasks,
+                'dense' for segmentation tasks.
+            output_modality: required if output='dense', specifies which modality to use
+                for the dense output (one of the input modalities).
+            hub_repo: torch.hub repository to load AnySat from.
+            pretrained: whether to load pretrained weights.
+            force_reload: whether to force re-download of the model.
+            flash_attn: whether to use flash attention (if available).
+        """
+        super().__init__()
+
+        if not modalities:
+            raise ValueError("At least one modality must be specified.")
+        for m in modalities:
+            if m not in MODALITY_RESOLUTIONS:
+                raise ValueError(f"Invalid modality: {m}")
+
+        if patch_size_meters % 10 != 0:
+            raise ValueError(
+                "In AnySat, `patch_size` is in meters and must be a multiple of 10."
+            )
+
+        output = output.lower()
+        if output not in {"patch", "dense"}:
+            raise ValueError("`output` must be 'patch' or 'dense'.")
+        if output == "dense" and output_modality is None:
+            raise ValueError("`output_modality` is required when output='dense'.")
+
+        self.modalities = modalities
+        self.patch_size_meters = int(patch_size_meters)
+        self.output = output
+        self.output_modality = output_modality
+
+        self.model = torch.hub.load(  # nosec B614
+            hub_repo,
+            "anysat",
+            pretrained=pretrained,
+            force_reload=force_reload,
+            flash_attn=flash_attn,
+        )
+        self._embed_dim = 768  # base width, 'dense' returns 2x
+
+    @staticmethod
+    def time_ranges_to_doy(
+        time_ranges: list[tuple[datetime, datetime]],
+        device: torch.device,
+    ) -> torch.Tensor:
+        """Turn the time ranges stored in a RasterImage into timestamps accepted by AnySat.
+
+        AnySat uses the day of year (doy) for each timestamp, so we take the midpoint
+        of the time range. For some inputs (e.g. Sentinel-2) we take an image from a
+        specific time so that start_time == end_time == mid_time.
+        """
+        doys = [(t[0] + ((t[1] - t[0]) / 2)).timetuple().tm_yday for t in time_ranges]
+        return torch.tensor(doys, dtype=torch.int32, device=device)
+
+    def forward(self, context: ModelContext) -> FeatureMaps:
+        """Forward pass for the AnySat model.
+
+        Args:
+            context: the model context. Input dicts must include as keys the
+                modalities defined in the self.modalities list.
+
+        Returns:
+            a FeatureMaps with one feature map at the configured patch size.
+        """
+        inputs = context.inputs
+
+        batch: dict[str, torch.Tensor] = {}
+        spatial_extent: tuple[float, float] | None = None
+
+        for modality in self.modalities:
+            if modality not in inputs[0]:
+                raise ValueError(f"Modality '{modality}' not present in inputs.")
+
+            cur = torch.stack(
+                [inp[modality].image for inp in inputs], dim=0
+            )  # (B, C, T, H, W)
+
+            if modality in TIME_SERIES_MODALITIES:
+                num_bands = cur.shape[1]
+                cur = rearrange(cur, "b c t h w -> b t c h w")
+                H, W = cur.shape[-2], cur.shape[-1]
+
+                if inputs[0][modality].timestamps is None:
+                    raise ValueError(
+                        f"Require timestamps for time series modality {modality}"
+                    )
+                timestamps = torch.stack(
+                    [
+                        self.time_ranges_to_doy(inp[modality].timestamps, cur.device)  # type: ignore
+                        for inp in inputs
+                    ],
+                    dim=0,
+                )
+                batch[f"{modality}_dates"] = timestamps
+            else:
+                # take the first (assumed only) timestep
+                cur = cur[:, :, 0]
+                num_bands = cur.shape[1]
+                H, W = cur.shape[-2], cur.shape[-1]
+
+            if num_bands != len(MODALITY_BANDS[modality]):
+                raise ValueError(
+                    f"Modality '{modality}' expected {len(MODALITY_BANDS[modality])} bands, "
+                    f"got {num_bands} (shape {tuple(cur.shape)})"
+                )
+
+            batch[modality] = cur
+
+            # Ensure same spatial extent across all modalities (H*res, W*res)
+            extent = (
+                H * MODALITY_RESOLUTIONS[modality],
+                W * MODALITY_RESOLUTIONS[modality],
+            )
+            if spatial_extent is None:
+                spatial_extent = extent
+            elif spatial_extent != extent:
+                raise ValueError(
+                    "All modalities must share the same spatial extent (H*res, W*res)."
+                )
+
+        kwargs = {"patch_size": self.patch_size_meters, "output": self.output}
+        if self.output == "dense":
+            kwargs["output_modality"] = self.output_modality
+
+        features = self.model(batch, **kwargs)
+        return FeatureMaps([rearrange(features, "b h w d -> b d h w")])
+
+    def get_backbone_channels(self) -> list:
+        """Returns the output channels of this model when used as a backbone.
+
+        The output channels are given as a list of (patch_size, depth) tuples that
+        correspond to the feature maps that the backbone returns.
+
+        Returns:
+            the output channels of the backbone as a list of (patch_size, depth) tuples.
+        """
+        if self.output == "patch":
+            return [(self.patch_size_meters // 10, 768)]
+        elif self.output == "dense":
+            return [(1, 1536)]
+        else:
+            raise ValueError(f"invalid output type: {self.output}")
rslearn/models/attention_pooling.py
@@ -0,0 +1,177 @@
+"""An attention pooling layer."""
+
+import math
+from typing import Any
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from torch import nn
+
+from rslearn.models.component import (
+    FeatureMaps,
+    IntermediateComponent,
+    TokenFeatureMaps,
+)
+from rslearn.train.model_context import ModelContext
+
+
+class SimpleAttentionPool(IntermediateComponent):
+    """Simple attention pooling.
+
+    Given a token feature map of shape BCHWN, learn an attention layer which
+    aggregates over the N dimension.
+
+    This is done simply by learning a mapping D->1 which gives the weight
+    assigned to each token during averaging:
+
+    output = sum [feat_token * W(feat_token) for feat_token in feat_tokens]
+    """
+
+    def __init__(self, in_dim: int, hidden_linear: bool = False) -> None:
+        """Initialize the simple attention pooling layer.
+
+        Args:
+            in_dim: the encoding dimension D.
+            hidden_linear: whether to apply an additional linear transformation D -> D
+                to the feat tokens. If this is True, a ReLU activation is applied
+                after the first linear transformation.
+        """
+        super().__init__()
+        if hidden_linear:
+            self.hidden_linear = nn.Linear(in_features=in_dim, out_features=in_dim)
+        else:
+            self.hidden_linear = None
+        self.linear = nn.Linear(in_features=in_dim, out_features=1)
+
+    def forward_for_map(self, feat_tokens: torch.Tensor) -> torch.Tensor:
+        """Attention pooling for a single feature map (BCHWN tensor)."""
+        B, D, H, W, N = feat_tokens.shape
+        feat_tokens = rearrange(feat_tokens, "b d h w n -> (b h w) n d")
+        if self.hidden_linear is not None:
+            feat_tokens = torch.nn.functional.relu(self.hidden_linear(feat_tokens))
+        attention_scores = torch.nn.functional.softmax(self.linear(feat_tokens), dim=1)
+        feat_tokens = (attention_scores * feat_tokens).sum(dim=1)
+        return rearrange(feat_tokens, "(b h w) d -> b d h w", b=B, h=H, w=W)
+
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+        """Forward pass for the attention pooling probe.
+
+        Args:
+            intermediates: the output from the previous component, which must be a
+                TokenFeatureMaps. We pool over the final dimension in the
+                TokenFeatureMaps. If multiple maps are passed, we apply the same
+                linear layers to all of them.
+            context: the model context.
+
+        Returns:
+            a FeatureMaps in which each map is pooled over the last dimension,
+            yielding shape (B, C, H, W).
+        """
+        if not isinstance(intermediates, TokenFeatureMaps):
+            raise ValueError("input to Attention Pool must be a TokenFeatureMaps")
+
+        features = []
+        for feat in intermediates.feature_maps:
+            features.append(self.forward_for_map(feat))
+        return FeatureMaps(features)
+
+
+class AttentionPool(IntermediateComponent):
+    """Attention pooling.
+
+    Given a feature map of shape BCHWN, learn an attention layer which aggregates
+    over the N dimension.
+
+    We do this by learning a query token, and applying a standard attention
+    mechanism against this learned query token.
+    """
+
+    def __init__(self, in_dim: int, num_heads: int, linear_on_kv: bool = True) -> None:
+        """Initialize the attention pooling layer.
+
+        Args:
+            in_dim: the encoding dimension D.
+            num_heads: the number of heads to use.
+            linear_on_kv: whether to apply a linear layer on the input tokens
+                to create the key and value tokens.
+        """
+        super().__init__()
+        self.query_token: nn.Parameter = nn.Parameter(torch.empty(in_dim))
+        if linear_on_kv:
+            self.k_linear = nn.Linear(in_dim, in_dim)
+            self.v_linear = nn.Linear(in_dim, in_dim)
+        else:
+            self.k_linear = None
+            self.v_linear = None
+        if in_dim % num_heads != 0:
+            raise ValueError(
+                f"in_dim must be divisible by num_heads. Got {in_dim} and {num_heads}."
+            )
+        self.num_heads = num_heads
+        self.init_weights()
+
+    def init_weights(self) -> None:
+        """Initialize weights for the probe."""
+        nn.init.trunc_normal_(self.query_token, std=0.02)
+
+    def forward_for_map(self, feat_tokens: torch.Tensor) -> torch.Tensor:
+        """Attention pooling for a single feature map (BCHWN tensor)."""
+        B, D, H, W, N = feat_tokens.shape
+        feat_tokens = rearrange(feat_tokens, "b d h w n -> (b h w) n d")
+        collapsed_dim = B * H * W
+        q = self.query_token.expand(collapsed_dim, 1, -1)
+        q = q.reshape(
+            collapsed_dim, 1, self.num_heads, D // self.num_heads
+        )  # [B, 1, head, D_head]
+        q = rearrange(q, "b h n d -> b n h d")
+        if self.k_linear is not None:
+            assert self.v_linear is not None
+            k = self.k_linear(feat_tokens).reshape(
+                collapsed_dim, N, self.num_heads, D // self.num_heads
+            )
+            v = self.v_linear(feat_tokens).reshape(
+                collapsed_dim, N, self.num_heads, D // self.num_heads
+            )
+        else:
+            k = feat_tokens.reshape(
+                collapsed_dim, N, self.num_heads, D // self.num_heads
+            )
+            v = feat_tokens.reshape(
+                collapsed_dim, N, self.num_heads, D // self.num_heads
+            )
+        k = rearrange(k, "b n h d -> b h n d")
+        v = rearrange(v, "b n h d -> b h n d")
+
+        # Compute attention scores
+        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(
+            D // self.num_heads
+        )
+        attn_weights = F.softmax(attn_scores, dim=-1)
+        x = torch.matmul(attn_weights, v)  # [B, head, 1, D_head]
+        return x.reshape(B, D, H, W)
+
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+        """Forward pass for the attention pooling probe.
+
+        Args:
+            intermediates: the output from the previous component, which must be a
+                TokenFeatureMaps. We pool over the final dimension in the
+                TokenFeatureMaps. If multiple feature maps are passed, we apply the
+                same attention weights (query token and linear k, v layers) to all
+                of the maps.
+            context: the model context.
+
+        Returns:
+            a FeatureMaps in which each map is pooled over the last dimension,
+            yielding shape (B, C, H, W).
+        """
+        if not isinstance(intermediates, TokenFeatureMaps):
+            raise ValueError("input to Attention Pool must be a TokenFeatureMaps")
+
+        features = []
+        for feat in intermediates.feature_maps:
+            features.append(self.forward_for_map(feat))
+        return FeatureMaps(features)
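
Both pooling components reduce a BCHWN token map to BCHW. The mechanics are easiest to see by calling forward_for_map directly on a dummy tensor, which avoids constructing TokenFeatureMaps and ModelContext; the shapes below are illustrative only:

    import torch

    from rslearn.models.attention_pooling import AttentionPool, SimpleAttentionPool

    # 2 samples, D=16 channels, a 4x4 spatial grid, and N=5 tokens per location.
    feat_tokens = torch.randn(2, 16, 4, 4, 5)

    # SimpleAttentionPool: a learned D->1 mapping produces per-token scores that are
    # softmaxed over N and used as weights for an average over the tokens.
    simple = SimpleAttentionPool(in_dim=16)
    print(simple.forward_for_map(feat_tokens).shape)  # torch.Size([2, 16, 4, 4])

    # AttentionPool: a learned query token attends over the N tokens at each location.
    attn = AttentionPool(in_dim=16, num_heads=4)
    print(attn.forward_for_map(feat_tokens).shape)  # torch.Size([2, 16, 4, 4])

In the full pipeline, forward() applies the same learned layers to every map in the incoming TokenFeatureMaps and wraps the results in a FeatureMaps.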
rslearn/models/clay/clay.py
@@ -0,0 +1,231 @@
+"""Clay models."""
+
+from __future__ import annotations
+
+import math
+from enum import Enum
+from importlib.resources import files
+from typing import Any
+
+import torch
+import torch.nn.functional as F
+import yaml
+from einops import rearrange
+from huggingface_hub import hf_hub_download
+
+# from claymodel.module import ClayMAEModule
+from terratorch.models.backbones.clay_v15.module import ClayMAEModule
+
+from rslearn.models.component import FeatureExtractor, FeatureMaps
+from rslearn.train.model_context import ModelContext
+from rslearn.train.transforms.normalize import Normalize
+from rslearn.train.transforms.transform import Transform
+
+
+class ClaySize(str, Enum):
+    """Size of the Clay model."""
+
+    BASE = "base"
+    LARGE = "large"
+
+
+PATCH_SIZE = 8
+CLAY_MODALITIES = ["sentinel-2-l2a", "sentinel-1-rtc", "landsat-c2l1", "naip"]
+CONFIG_DIR = files("rslearn.models.clay.configs")
+CLAY_METADATA_PATH = str(CONFIG_DIR / "metadata.yaml")
+DEFAULT_IMAGE_RESOLUTION = 128  # image resolution during pretraining
+
+
+def get_clay_checkpoint_path(
+    filename: str = "v1.5/clay-v1.5.ckpt",
+    repo_id: str = "made-with-clay/Clay",
+) -> str:
+    """Return a cached local path to the Clay ckpt from the Hugging Face Hub."""
+    return hf_hub_download(repo_id=repo_id, filename=filename)  # nosec B615
+
+
+class Clay(FeatureExtractor):
+    """Clay backbones."""
+
+    def __init__(
+        self,
+        model_size: ClaySize,
+        modality: str = "sentinel-2-l2a",
+        checkpoint_path: str | None = None,
+        metadata_path: str = CLAY_METADATA_PATH,
+        do_resizing: bool = False,
+    ) -> None:
+        """Initialize the Clay model.
+
+        Args:
+            model_size: The size of the Clay model.
+            modality: The modality to use (one of CLAY_MODALITIES).
+            checkpoint_path: Path to clay-v1.5.ckpt; if None, fetch from the HF Hub.
+            metadata_path: Path to metadata.yaml.
+            do_resizing: Whether to resize the image to the input resolution.
+        """
+        super().__init__()
+
+        # Clay only supports single modality input
+        if modality not in CLAY_MODALITIES:
+            raise ValueError(f"Invalid modality: {modality}")
+
+        ckpt = checkpoint_path or get_clay_checkpoint_path()
+        if model_size == ClaySize.LARGE:
+            self.model = ClayMAEModule.load_from_checkpoint(
+                checkpoint_path=ckpt,
+                model_size="large",
+                metadata_path=metadata_path,
+                dolls=[16, 32, 64, 128, 256, 768, 1024],
+                doll_weights=[1, 1, 1, 1, 1, 1, 1],
+                mask_ratio=0.0,
+                shuffle=False,
+            )
+        elif model_size == ClaySize.BASE:
+            # Failed to load Base model in Clay v1.5
+            raise ValueError("Clay BASE model currently not supported in v1.5.")
+            self.model = ClayMAEModule.load_from_checkpoint(
+                checkpoint_path=ckpt,
+                model_size="base",
+                metadata_path=metadata_path,
+                dolls=[16, 32, 64, 128, 256, 768],
+                doll_weights=[1, 1, 1, 1, 1, 1],
+                mask_ratio=0.0,
+                shuffle=False,
+            )
+        else:
+            raise ValueError(f"Invalid model size: {model_size}")
+
+        with open(metadata_path) as f:
+            self.metadata = yaml.safe_load(f)
+
+        self.model_size = model_size
+        self.modality = modality
+        self.do_resizing = do_resizing
+
+    def _resize_image(self, image: torch.Tensor, original_hw: int) -> torch.Tensor:
+        """Resize the image to the input resolution."""
+        new_hw = self.patch_size if original_hw == 1 else DEFAULT_IMAGE_RESOLUTION
+        return F.interpolate(
+            image, size=(new_hw, new_hw), mode="bilinear", align_corners=False
+        )
+
+    def forward(self, context: ModelContext) -> FeatureMaps:
+        """Forward pass for the Clay model.
+
+        Args:
+            context: the model context. Input dicts must include `self.modality` as a key.
+
+        Returns:
+            a FeatureMaps consisting of one feature map, computed by Clay.
+        """
+        param = next(self.model.parameters())
+        device = param.device
+
+        chips = torch.stack(
+            [inp[self.modality] for inp in context.inputs], dim=0
+        )  # (B, C, H, W)
+        if self.do_resizing:
+            chips = self._resize_image(chips, chips.shape[2])
+        order = self.metadata[self.modality]["band_order"]
+        wavelengths = []
+        for band in self.metadata[self.modality]["band_order"]:
+            wavelengths.append(
+                self.metadata[self.modality]["bands"]["wavelength"][band] * 1000
+            )  # Convert to nm
+        # Check channel count matches Clay expectation
+        if chips.shape[1] != len(order):
+            raise ValueError(
+                f"Channel count {chips.shape[1]} does not match expected {len(order)} for {self.modality}"
+            )
+
+        # Time & latlon zeros are valid per Clay doc
+        # https://clay-foundation.github.io/model/getting-started/basic_use.html
+        datacube = {
+            "platform": self.modality,
+            "time": torch.zeros(chips.shape[0], 4).to(device),
+            "latlon": torch.zeros(chips.shape[0], 4).to(device),
+            "pixels": chips.to(device),
+            "gsd": torch.tensor(self.metadata[self.modality]["gsd"]).to(device),
+            "waves": torch.tensor(wavelengths).to(device),
+        }
+
+        tokens, *_ = self.model.model.encoder(datacube)  # (B, 1 + N, D)
+
+        # Remove CLS token
+        spatial = tokens[:, 1:, :]  # (B, N, D)
+        n_tokens = spatial.shape[1]
+        side = int(math.isqrt(n_tokens))
+        if chips.shape[2] != side * PATCH_SIZE or chips.shape[3] != side * PATCH_SIZE:
+            raise ValueError(
+                f"Input spatial size {(chips.shape[2], chips.shape[3])} is not compatible with patch size {PATCH_SIZE}"
+            )
+
+        features = rearrange(spatial, "b (h w) d -> b d h w", h=side, w=side)
+        return FeatureMaps([features])
+
+    def get_backbone_channels(self) -> list:
+        """Return output channels of this model when used as a backbone."""
+        if self.model_size == ClaySize.LARGE:
+            depth = 1024
+        elif self.model_size == ClaySize.BASE:
+            depth = 768
+        else:
+            raise ValueError(f"Invalid model size: {self.model_size}")
+        return [(PATCH_SIZE, depth)]
+
+
+class ClayNormalize(Transform):
+    """Normalize inputs using Clay metadata.
+
+    For Sentinel-1, the intensities should be converted to decibels.
+    """
+
+    def __init__(self, metadata_path: str = CLAY_METADATA_PATH) -> None:
+        """Initialize ClayNormalize."""
+        super().__init__()
+        with open(metadata_path) as f:
+            metadata = yaml.safe_load(f)
+        normalizers = {}
+        for modality in CLAY_MODALITIES:
+            if modality not in metadata:
+                continue
+            modality_metadata = metadata[modality]
+            means = [
+                modality_metadata["bands"]["mean"][b]
+                for b in modality_metadata["band_order"]
+            ]
+            stds = [
+                modality_metadata["bands"]["std"][b]
+                for b in modality_metadata["band_order"]
+            ]
+            normalizers[modality] = Normalize(
+                mean=means,
+                std=stds,
+                selectors=[modality],
+                num_bands=len(means),
+            )
+        self.normalizers = torch.nn.ModuleDict(normalizers)
+
+    def apply_image(
+        self, image: torch.Tensor, means: list[float], stds: list[float]
+    ) -> torch.Tensor:
+        """Normalize the specified image with Clay normalization."""
+        x = image.float()
+        if x.shape[0] != len(means):
+            raise ValueError(
+                f"channel count {x.shape[0]} does not match provided band stats {len(means)}"
+            )
+        for c in range(x.shape[0]):
+            x[c] = (x[c] - means[c]) / stds[c]
+        return x
+
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Normalize the specified image with Clay normalization."""
+        for modality, normalizer in self.normalizers.items():
+            if modality not in input_dict:
+                continue
+            input_dict, target_dict = normalizer(input_dict, target_dict)
+        return input_dict, target_dict
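
ClayNormalize.apply_image is plain per-band standardization, (x - mean) / std, with the band statistics taken from metadata.yaml in the modality's band order. A standalone sketch of the same arithmetic with made-up statistics (the real values live in rslearn/models/clay/configs/metadata.yaml):

    import torch

    # Hypothetical per-band statistics; the real ones are read from metadata.yaml.
    means = [1200.0, 1100.0, 1000.0]
    stds = [600.0, 550.0, 500.0]

    image = torch.full((3, 4, 4), 1600.0)  # (C, H, W) chip with 3 bands

    normalized = image.clone().float()
    for c in range(normalized.shape[0]):
        normalized[c] = (normalized[c] - means[c]) / stds[c]

    print(normalized[:, 0, 0])  # tensor([0.6667, 0.9091, 1.2000])

For the feature-map geometry in Clay.forward: a chip at the default 128 x 128 pretraining resolution yields 128 / PATCH_SIZE = 16 tokens per side, so the returned feature map has shape (B, D, 16, 16), with D = 1024 for the large model.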