rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166) hide show
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,297 @@
1
+ """Presto wrapper to ingest Masked Helios Samples."""
2
+
3
+ import logging
4
+ import tempfile
5
+ from datetime import datetime
6
+
7
+ import torch
8
+ from einops import rearrange, repeat
9
+ from huggingface_hub import hf_hub_download
10
+ from upath import UPath
11
+
12
+ from rslearn.models.component import FeatureExtractor, FeatureMaps
13
+ from rslearn.models.presto.single_file_presto import (
14
+ ERA5_BANDS,
15
+ NUM_DYNAMIC_WORLD_CLASSES,
16
+ PRESTO_ADD_BY,
17
+ PRESTO_BANDS,
18
+ PRESTO_DIV_BY,
19
+ PRESTO_S1_BANDS,
20
+ PRESTO_S2_BANDS,
21
+ SRTM_BANDS,
22
+ )
23
+ from rslearn.models.presto.single_file_presto import Presto as SFPresto
24
+ from rslearn.train.model_context import ModelContext
25
+
26
logger = logging.getLogger(__name__)

# Band lists actually accepted by this wrapper: identical to the Presto band
# lists from single_file_presto, except that B09 is excluded from the inputs.
INPUT_PRESTO_BANDS = [band for band in PRESTO_BANDS if band != "B09"]
INPUT_PRESTO_S2_BANDS = [band for band in PRESTO_S2_BANDS if band != "B09"]

# Per-sensor normalization constants (subtract, then divide).
# NOTE(review): these are not referenced elsewhere in this module's visible
# code — presumably consumed by callers preparing inputs; confirm.
PRESTO_S1_SUBTRACT_VALUE = -25.0
PRESTO_S1_DIV_VALUE = 25.0
PRESTO_S2_SUBTRACT_VALUE = 0.0
PRESTO_S2_DIV_VALUE = 1e4

# Hugging Face Hub location of the pretrained Presto checkpoint.
HF_HUB_ID = "nasaharvest/presto"
MODEL_FILENAME = "default_model.pt"
38
+
39
+
40
class Presto(FeatureExtractor):
    """Presto wrapper around the single-file Presto encoder.

    Presto encodes per-pixel timeseries, so forward() flattens the spatial
    dimensions into the batch, encodes pixels in chunks of pixel_batch_size,
    and reshapes the embeddings back into a per-pixel feature map.
    """

    # ModelContext input keys this wrapper consumes; other keys are ignored.
    input_keys = [
        "s1",
        "s2",
        "era5",
        "srtm",
        "dynamic_world",
        "latlon",
    ]

    def __init__(
        self,
        pretrained_path: str | UPath | None = None,
        pixel_batch_size: int = 128,
    ):
        """Initialize the Presto wrapper.

        Downloads the pretrained weights from Hugging Face Hub if they are not
        already present at pretrained_path, loads them, and keeps only the
        encoder.

        Args:
            pretrained_path: The directory to load from (and download into if
                the weights file is missing). Defaults to a cache directory
                under the system temp dir.
            pixel_batch_size: If the input has a h,w dimension >1, this is
                flattened into a batch dimension (b h w) before being passed
                to the model (since Presto is designed for pixel timeseries).
                The flattened pixels are encoded in chunks of this size.
        """
        super().__init__()

        if pretrained_path is None:
            pretrained_path = UPath(tempfile.gettempdir(), "rslearn_cache", "presto")
        if not (UPath(pretrained_path) / MODEL_FILENAME).exists():
            _ = hf_hub_download(
                local_dir=UPath(pretrained_path),
                repo_id=HF_HUB_ID,
                filename=MODEL_FILENAME,
                # pin the model to a specific hugging face commit
                revision="1b97f885969da4e2d5834ca8c92707c737911464",
            )

        model = SFPresto.construct()
        model.load_state_dict(
            torch.load(
                UPath(pretrained_path) / MODEL_FILENAME,
                map_location="cpu",
                # weights_only avoids arbitrary code execution during load
                weights_only=True,
            )
        )
        self.pixel_batch_size = pixel_batch_size
        # Only the encoder is kept; the rest of the Presto model is discarded.
        self.model = model.encoder
        self.month = 6  # default month used when no timestamps are available

    def construct_presto_input(
        self,
        s1: torch.Tensor | None = None,
        s1_bands: list[str] | None = None,
        s2: torch.Tensor | None = None,
        s2_bands: list[str] | None = None,
        era5: torch.Tensor | None = None,
        era5_bands: list[str] | None = None,
        srtm: torch.Tensor | None = None,
        srtm_bands: list[str] | None = None,
        dynamic_world: torch.Tensor | None = None,
        months: torch.Tensor | None = None,
        normalize: bool = True,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Assemble per-modality tensors into a single Presto input.

        Inputs are paired into a tensor input <X> and a list <X>_bands naming
        the channels of <X>. Each <X> should be channels-first with shape
        (b, len(<X>_bands), num_timesteps, h, w), with the following bands for
        each input:

        s1: ["VV", "VH"]
        s2: ["B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B11", "B12"]
        era5: ["temperature_2m", "total_precipitation"]
            "temperature_2m": Temperature of air at 2m above the surface of land,
            sea or in-land waters in Kelvin (K)
            "total_precipitation": Accumulated liquid and frozen water, including rain and snow,
            that falls to the Earth's surface. Measured in metres (m)
        srtm: ["elevation", "slope"]

        dynamic_world has shape (b, num_timesteps, h, w), the dynamic world
        class of each pixel at each timestep. If None, it is filled with
        NUM_DYNAMIC_WORLD_CLASSES (presumably the "no data" sentinel — confirm
        in single_file_presto).

        Returns:
            a tuple (x, mask, dynamic_world, months) where x and mask have
            shape (b, t, h, w, len(INPUT_PRESTO_BANDS)); mask is 1 where a
            band was not provided; dynamic_world and months are long tensors.
        """
        # Collect each dimension across the provided modalities so we can
        # check they agree before merging.
        bs = [x.shape[0] for x in [s1, s2, era5, srtm] if x is not None]
        ts = [x.shape[2] for x in [s1, s2, era5, srtm] if x is not None]
        hs = [x.shape[3] for x in [s1, s2, era5, srtm] if x is not None]
        ws = [x.shape[4] for x in [s1, s2, era5, srtm] if x is not None]
        devices = [x.device for x in [s1, s2, era5, srtm] if x is not None]

        # All provided modalities must share batch/time/spatial dims and device.
        assert len(set(bs)) == 1
        assert len(set(hs)) == 1
        assert len(set(ws)) == 1
        assert len(set(devices)) == 1
        assert len(set(ts)) == 1
        b, h, w, t, device = bs[0], hs[0], ws[0], ts[0], devices[0]
        # these values will be initialized as
        # we iterate through the data
        x: torch.Tensor | None = None
        mask: torch.Tensor | None = None

        for band_group in [
            (s1, s1_bands),
            (s2, s2_bands),
            (era5, era5_bands),
            (srtm, srtm_bands),
        ]:
            data, input_bands = band_group
            if data is not None:
                assert input_bands is not None
            else:
                continue

            # channels-last so bands can be scattered into x's last dimension
            data = rearrange(data, "b c t h w -> b t h w c")
            if x is None:
                x = torch.zeros(b, t, h, w, len(INPUT_PRESTO_BANDS), device=device)
            if mask is None:
                # 1 marks "band missing"; provided bands are cleared to 0 below
                mask = torch.ones(b, t, h, w, len(INPUT_PRESTO_BANDS), device=device)

            # construct a mapping from the input bands to the presto input bands
            input_to_output_mapping = [
                INPUT_PRESTO_BANDS.index(val) for val in input_bands
            ]
            x[:, :, :, :, input_to_output_mapping] = data
            mask[:, :, :, :, input_to_output_mapping] = 0

        # At least one modality must have been provided.
        assert x is not None
        assert mask is not None
        assert t is not None

        if dynamic_world is None:
            dynamic_world = (
                torch.ones(b, t, h, w, device=device) * NUM_DYNAMIC_WORLD_CLASSES
            )

        if months is None:
            # Fall back to the wrapper's default month for every timestep.
            months = torch.ones((b, t), device=device) * self.month
        else:
            assert months.shape[-1] == t

        if normalize:
            # NOTE(review): PRESTO_ADD_BY / PRESTO_DIV_BY come from
            # single_file_presto; this assumes they broadcast against the
            # len(INPUT_PRESTO_BANDS) channel dimension (i.e. B09 handling is
            # consistent) — confirm.
            x = (x + PRESTO_ADD_BY.to(device=device)) / PRESTO_DIV_BY.to(device=device)
        return x, mask, dynamic_world.long(), months.long()

    @staticmethod
    def time_ranges_to_timestamps(
        time_ranges: list[tuple[datetime, datetime]],
        device: torch.device,
    ) -> torch.Tensor:
        """Turn the time ranges stored in a RasterImage to timestamps accepted by Presto.

        Presto only uses the month associated with each timestamp, so we take
        the midpoint of the time range. For some inputs (e.g. Sentinel 2) we
        take an image from a specific time so that
        start_time == end_time == mid_time.
        """
        mid_ranges = [t[0] + ((t[1] - t[0]) / 2) for t in time_ranges]
        # months are indexed 0-11
        return torch.tensor(
            [d.month - 1 for d in mid_ranges], dtype=torch.int32, device=device
        )

    def forward(self, context: ModelContext) -> FeatureMaps:
        """Compute feature maps from the Presto backbone.

        Args:
            context: the model context. Input dicts should have some subset of Presto.input_keys.

        Returns:
            a FeatureMaps with one feature map that is at the same resolution as the
            input (since Presto operates per-pixel).
        """
        time_modalities = ["s1", "s2", "era5"]
        stacked_inputs = {}
        latlons: torch.Tensor | None = None
        months: torch.Tensor | None = None
        for key in context.inputs[0].keys():
            # assume all the keys in an input are consistent
            if key in self.input_keys:
                if key == "latlon":
                    # latlon is handled separately; it is not a Presto band input.
                    latlons = torch.stack(
                        [inp[key].image for inp in context.inputs], dim=0
                    )
                else:
                    stacked_inputs[key] = torch.stack(
                        [inp[key].image for inp in context.inputs], dim=0
                    )
                    # Derive months once, from the first time-varying modality
                    # that carries timestamps.
                    if key in time_modalities:
                        if months is None:
                            if context.inputs[0][key].timestamps is not None:
                                months = torch.stack(
                                    [
                                        self.time_ranges_to_timestamps(
                                            inp[key].timestamps,  # type: ignore
                                            device=stacked_inputs[key].device,
                                        )
                                        for inp in context.inputs
                                    ],
                                    dim=0,
                                )
        if months is not None:
            stacked_inputs["months"] = months

        (
            x,
            mask,
            dynamic_world,
            months,
        ) = self.construct_presto_input(
            **stacked_inputs,
            s1_bands=PRESTO_S1_BANDS,
            s2_bands=INPUT_PRESTO_S2_BANDS,
            era5_bands=ERA5_BANDS,
            srtm_bands=SRTM_BANDS,
            normalize=True,
        )
        b, _, h, w, _ = x.shape

        output_features = torch.zeros(
            b * h * w, self.model.embedding_size, device=x.device
        )

        # Flatten spatial dims into the batch so Presto sees pixel timeseries.
        x = rearrange(x, "b t h w d -> (b h w) t d")
        mask = rearrange(mask, "b t h w d -> (b h w) t d")
        dynamic_world = rearrange(dynamic_world, "b t h w -> (b h w) t")
        months = repeat(months, "b t -> (b h w) t", h=h, w=w)
        if latlons is not None:
            latlons = rearrange(latlons, "b c h w -> (b h w) c")

        # Encode the flattened pixels in chunks to bound peak memory.
        for batch_idx in range(0, b * h * w, self.pixel_batch_size):
            x_b = x[batch_idx : batch_idx + self.pixel_batch_size]
            mask_b = mask[batch_idx : batch_idx + self.pixel_batch_size]
            dw = dynamic_world[batch_idx : batch_idx + self.pixel_batch_size]
            months_b = months[batch_idx : batch_idx + self.pixel_batch_size]
            if latlons is not None:
                l_b = latlons[batch_idx : batch_idx + self.pixel_batch_size]
            else:
                l_b = None
            output_b = self.model(
                x=x_b,
                dynamic_world=dw,
                mask=mask_b,
                month=months_b,
                latlons=l_b,
                eval_task=True,
            )
            output_features[batch_idx : batch_idx + self.pixel_batch_size] = output_b

        # Restore the (b, d, h, w) feature-map layout expected by FeatureMaps.
        return FeatureMaps(
            [rearrange(output_features, "(b h w) d -> b d h w", h=h, w=w, b=b)]
        )

    def get_backbone_channels(self) -> list:
        """Returns the output channels of this model when used as a backbone.

        The output channels is a list of (patch_size, depth) that corresponds
        to the feature maps that the backbone returns.

        Returns:
            the output channels of the backbone as a list of (patch_size, depth) tuples.
        """
        # NOTE(review): the hard-coded depth 128 assumes the encoder embedding
        # size is 128; forward() uses self.model.embedding_size — confirm they
        # agree.
        return [(1, 128)]