rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,530 @@
1
+ """Wrapper around ModelDataset to load all patches (crops) in a window."""
2
+
3
+ import itertools
4
+ from collections.abc import Iterable, Iterator
5
+ from dataclasses import replace
6
+ from typing import Any
7
+
8
+ import shapely
9
+ import torch
10
+
11
+ from rslearn.dataset import Window
12
+ from rslearn.train.dataset import DataInput, ModelDataset
13
+ from rslearn.train.model_context import RasterImage, SampleMetadata
14
+ from rslearn.utils.geometry import PixelBounds, STGeometry
15
+
16
+
17
def get_window_patch_options(
    patch_size: tuple[int, int],
    overlap_size: tuple[int, int],
    bounds: "PixelBounds",
) -> "list[PixelBounds]":
    """Compute the bounds of every input patch within the given window bounds.

    This is used when running inference on all patches (crops) of a large window, to
    compute the position of each patch.

    Args:
        patch_size: the size of the patches to extract.
        overlap_size: the size of the overlap between patches.
        bounds: the window bounds to divide up into smaller patches.

    Returns:
        a list of patch bounds within the overall bounds. The rightmost and
        bottommost patches may extend beyond the provided bounds.
    """

    def axis_offsets(lo: int, hi: int, size: int, overlap: int) -> list[int]:
        # Patches are strided by (size - overlap). The first patch is always
        # anchored at the low edge so it is guaranteed to be included.
        offsets = [lo] + list(range(lo + size, hi - size, size - overlap))
        # Add a final patch flush with the high edge (so it never exceeds the
        # window bounds, overlapping the previous patch instead) — but only
        # when a single patch does not already cover the whole axis.
        if hi - size > lo:
            offsets.append(hi - size)
        return offsets

    cols = axis_offsets(bounds[0], bounds[2], patch_size[0], overlap_size[0])
    rows = axis_offsets(bounds[1], bounds[3], patch_size[1], overlap_size[1])
    # Column-major order: all rows of one column before the next column.
    return [
        (col, row, col + patch_size[0], row + patch_size[1])
        for col in cols
        for row in rows
    ]
65
+
66
+
67
def pad_slice_protect(
    raw_inputs: dict[str, Any],
    passthrough_inputs: dict[str, Any],
    patch_size: tuple[int, int],
    inputs: "dict[str, DataInput]",
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Pad tensors in-place by patch size to protect slicing near right/bottom edges.

    The padding is scaled based on each input's resolution_factor.

    Args:
        raw_inputs: the raw inputs to pad.
        passthrough_inputs: the passthrough inputs to pad.
        patch_size: the size of the patches to extract (at window resolution).
        inputs: the DataInput definitions, used to get resolution_factor per input.

    Returns:
        a tuple of (raw_inputs, passthrough_inputs).
    """
    for input_dict in (raw_inputs, passthrough_inputs):
        # Only tensor entries get padded; other values (e.g. feature lists)
        # are left untouched.
        tensor_names = [
            name
            for name, value in input_dict.items()
            if isinstance(value, torch.Tensor)
        ]
        for name in tensor_names:
            # The padding amount must be expressed at this input's own
            # resolution, so scale patch_size by its resolution_factor.
            factor = inputs[name].resolution_factor
            scale = factor.numerator / factor.denominator
            pad_right = int(patch_size[0] * scale)
            pad_bottom = int(patch_size[1] * scale)
            input_dict[name] = torch.nn.functional.pad(
                input_dict[name], pad=(0, pad_right, 0, pad_bottom)
            )
    return raw_inputs, passthrough_inputs
100
+
101
+
102
def crop_tensor_or_rasterimage(
    x: "torch.Tensor | RasterImage", start: tuple[int, int], end: tuple[int, int]
) -> "torch.Tensor | RasterImage":
    """Crop a tensor or a RasterImage.

    Args:
        x: a CHW tensor, or a RasterImage wrapping a CTHW tensor.
        start: (col, row) offset of the top-left corner of the crop.
        end: (col, row) offset of the bottom-right corner of the crop.

    Returns:
        the cropped tensor or RasterImage; the crop is cloned so it does not
        share storage with the input.
    """
    if isinstance(x, torch.Tensor):
        # CHW layout: keep all channels, slice rows (dim 1) then columns (dim 2).
        return x[:, start[1] : end[1], start[0] : end[0]].clone()
    # CTHW layout: keep channels and time, slice the spatial dims; the
    # timestamps are carried over unchanged.
    cropped = x.image[:, :, start[1] : end[1], start[0] : end[0]].clone()
    return RasterImage(cropped, x.timestamps)
124
+
125
+
126
class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
    """This wraps a ModelDataset to iterate over all patches in that dataset.

    This should be used when SplitConfig.load_all_patches is enabled. The ModelDataset
    is configured with no patch size (load entire windows), and the dataset is wrapped
    in an AllPatchesDataset.

    Similar to DistributedSampler, we add extra samples at each rank to ensure
    consistent number of batches across all ranks.
    """

    def __init__(
        self,
        dataset: ModelDataset,
        patch_size: tuple[int, int],
        overlap_ratio: float = 0.0,
        rank: int = 0,
        world_size: int = 1,
    ):
        """Create a new IterableAllPatchesDataset.

        Args:
            dataset: the ModelDataset to wrap.
            patch_size: the size of the patches to extract.
            overlap_ratio: whether to include overlap between the patches. Note that
                the right/bottom-most patches may still overlap since we ensure that
                all patches are contained in the window bounds.
            rank: the global rank of this train worker process.
            world_size: the total number of train worker processes.
        """
        super().__init__()
        self.dataset = dataset
        self.patch_size = patch_size
        # Overlap expressed in pixels, derived from the ratio per axis.
        self.overlap_size = (
            round(self.patch_size[0] * overlap_ratio),
            round(self.patch_size[1] * overlap_ratio),
        )
        self.rank = rank
        self.world_size = world_size
        # Fetch the window list once; it is indexed repeatedly during iteration.
        self.windows = self.dataset.get_dataset_examples()
        self.inputs = dataset.inputs

    def set_name(self, name: str) -> None:
        """Sets dataset name.

        Args:
            name: dataset name
        """
        self.dataset.set_name(name)

    def get_window_num_patches(self, bounds: PixelBounds) -> int:
        """Get the number of patches for these bounds.

        This corresponds to the length of the list returned by get_patch_options.
        """
        # Count of stride positions per axis plus one for the final
        # edge-flush patch added by get_window_patch_options.
        num_cols = (
            len(
                range(
                    bounds[0],
                    bounds[2] - self.patch_size[0],
                    self.patch_size[0] - self.overlap_size[0],
                )
            )
            + 1
        )
        num_rows = (
            len(
                range(
                    bounds[1],
                    bounds[3] - self.patch_size[1],
                    self.patch_size[1] - self.overlap_size[1],
                )
            )
            + 1
        )
        return num_cols * num_rows

    def _get_worker_iteration_data(self) -> tuple[Iterable[int], int]:
        """Get the windows we should iterate over.

        This is split both by training worker (self.rank) and data loader worker (via
        get_worker_info).

        We also compute the total number of samples that each data loader worker should
        yield. This is important for DDP to ensure that all ranks see the same number
        of batches.

        Returns:
            a tuple (window_ids, num_samples_per_worker).
        """
        # Figure out the total number of data loader workers and our worker ID.
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Not running inside a DataLoader worker process.
            worker_id = 0
            num_workers = 1
        else:
            worker_id = worker_info.id
            num_workers = worker_info.num_workers
        global_worker_id = self.rank * num_workers + worker_id
        global_num_workers = self.world_size * num_workers

        # Split up the windows evenly among the workers.
        # We compute this for all workers since we will need to see the maximum number
        # of samples under this assignment across workers.
        window_indexes = range(len(self.windows))
        windows_by_worker = [
            window_indexes[cur_rank :: self.world_size][cur_worker_id::num_workers]
            for cur_rank in range(self.world_size)
            for cur_worker_id in range(num_workers)
        ]

        # Now compute the maximum number of samples across workers.
        max_num_patches = 0
        for worker_windows in windows_by_worker:
            worker_num_patches = 0
            for window_id in worker_windows:
                worker_num_patches += self.get_window_num_patches(
                    self.windows[window_id].bounds
                )
            max_num_patches = max(max_num_patches, worker_num_patches)

        # Each worker needs at least one window, otherwise it won't be able to pad.
        # Unless there are zero windows total, which is fine.
        # Previously we would address this by borrowing the windows from another
        # worker, but this causes issues with RslearnWriter: if we yield the same
        # window from parallel workers, it may end up writing an empty output for that
        # window in the end.
        # So now we raise an error instead, and require the number of workers to be
        # less than the number of windows.
        if len(windows_by_worker[global_worker_id]) == 0 and max_num_patches > 0:
            raise ValueError(
                f"the number of workers {global_num_workers} must be <= the number of windows {len(self.windows)}"
            )

        return (windows_by_worker[global_worker_id], max_num_patches)

    def __iter__(
        self,
    ) -> Iterator[tuple[dict[str, Any], dict[str, Any], SampleMetadata]]:
        """Iterate over all patches in each element of the underlying ModelDataset."""
        # Iterate over the window IDs until we have returned enough samples.
        window_ids, num_samples_needed = self._get_worker_iteration_data()
        num_samples_returned = 0

        # itertools.count() lets us re-loop over our windows until this worker
        # has yielded the same number of samples as the busiest worker.
        for iteration_idx in itertools.count():
            for window_id in window_ids:
                raw_inputs, passthrough_inputs, metadata = self.dataset.get_raw_inputs(
                    window_id
                )
                bounds = metadata.patch_bounds

                # For simplicity, pad tensors by patch size to ensure that any patch bounds
                # extending outside the window bounds will not have issues when we slice
                # the tensors later. Padding is scaled per-input based on resolution_factor.
                pad_slice_protect(
                    raw_inputs, passthrough_inputs, self.patch_size, self.inputs
                )

                # Now iterate over the patches and extract/yield the crops.
                # Note that, in case user is leveraging RslearnWriter, it is important that
                # the patch_idx be increasing (as we iterate) within one window.
                patches = get_window_patch_options(
                    self.patch_size, self.overlap_size, bounds
                )
                for patch_idx, patch_bounds in enumerate(patches):
                    cur_geom = STGeometry(
                        metadata.projection, shapely.box(*patch_bounds), None
                    )
                    # Offsets of this patch relative to the window origin.
                    start_offset = (
                        patch_bounds[0] - bounds[0],
                        patch_bounds[1] - bounds[1],
                    )
                    end_offset = (
                        patch_bounds[2] - bounds[0],
                        patch_bounds[3] - bounds[1],
                    )

                    # Define a helper function to handle each input dict.
                    # Crop coordinates are scaled based on each input's resolution_factor.
                    def crop_input_dict(d: dict[str, Any]) -> dict[str, Any]:
                        cropped = {}
                        for input_name, value in d.items():
                            if isinstance(value, torch.Tensor | RasterImage):
                                # Get resolution scale for this input
                                rf = self.inputs[input_name].resolution_factor
                                scale = rf.numerator / rf.denominator
                                # Scale the crop coordinates
                                scaled_start = (
                                    int(start_offset[0] * scale),
                                    int(start_offset[1] * scale),
                                )
                                scaled_end = (
                                    int(end_offset[0] * scale),
                                    int(end_offset[1] * scale),
                                )
                                cropped[input_name] = crop_tensor_or_rasterimage(
                                    value, scaled_start, scaled_end
                                )
                            elif isinstance(value, list):
                                # Vector inputs: keep only features intersecting
                                # this patch's geometry.
                                cropped[input_name] = [
                                    feat
                                    for feat in value
                                    if cur_geom.intersects(feat.geometry)
                                ]
                            else:
                                raise ValueError(
                                    "got input that is neither tensor nor feature list"
                                )
                        return cropped

                    cur_raw_inputs = crop_input_dict(raw_inputs)
                    cur_passthrough_inputs = crop_input_dict(passthrough_inputs)

                    # Adjust the metadata as well.
                    cur_metadata = replace(
                        metadata,
                        patch_bounds=patch_bounds,
                        patch_idx=patch_idx,
                        num_patches_in_window=len(patches),
                    )

                    # Now we can compute input and target dicts via the task.
                    input_dict, target_dict = self.dataset.task.process_inputs(
                        cur_raw_inputs,
                        metadata=cur_metadata,
                        load_targets=not self.dataset.split_config.get_skip_targets(),
                    )
                    input_dict.update(cur_passthrough_inputs)
                    input_dict, target_dict = self.dataset.transforms(
                        input_dict, target_dict
                    )

                    if num_samples_returned < num_samples_needed:
                        yield input_dict, target_dict, cur_metadata
                        num_samples_returned += 1
                    else:
                        # We only stop yielding on a repeat pass; the first pass
                        # must cover every assigned window.
                        assert iteration_idx > 0

            if num_samples_returned >= num_samples_needed:
                break

    def get_dataset_examples(self) -> list[Window]:
        """Returns a list of windows in this dataset."""
        return self.dataset.get_dataset_examples()
370
+
371
+
372
class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
    """This wraps a ModelDataset to iterate over all patches in that dataset.

    This should be used when SplitConfig.load_all_patches is enabled.

    This is a simpler version of IterableAllPatchesDataset that caches all windows in memory.
    This is useful for small datasets that fit in memory.
    """

    def __init__(
        self,
        dataset: ModelDataset,
        patch_size: tuple[int, int],
        overlap_ratio: float = 0.0,
    ):
        """Create a new InMemoryAllPatchesDataset.

        Args:
            dataset: the ModelDataset to wrap.
            patch_size: the size of the patches to extract.
            overlap_ratio: whether to include overlap between the patches. Note that
                the right/bottom-most patches may still overlap since we ensure that
                all patches are contained in the window bounds.
        """
        super().__init__()
        self.dataset = dataset
        self.patch_size = patch_size
        # Overlap expressed in pixels, derived from the ratio per axis.
        self.overlap_size = (
            round(self.patch_size[0] * overlap_ratio),
            round(self.patch_size[1] * overlap_ratio),
        )
        self.windows = self.dataset.get_dataset_examples()
        self.inputs = dataset.inputs
        # Cache of padded raw inputs, keyed by window index. NOTE(review):
        # this cache grows unboundedly with the number of distinct windows
        # accessed, which is by design for small in-memory datasets.
        self.window_cache: dict[
            int, tuple[dict[str, Any], dict[str, Any], SampleMetadata]
        ] = {}

        # Precompute the batch boundaries for each window
        self.patches = []
        for window_id, window in enumerate(self.windows):
            patch_bounds = get_window_patch_options(
                self.patch_size, self.overlap_size, window.bounds
            )
            for i, patch_bound in enumerate(patch_bounds):
                # Flattened index: (window, patch bounds, (patch idx, count)).
                self.patches.append((window_id, patch_bound, (i, len(patch_bounds))))

    def get_raw_inputs(
        self, index: int
    ) -> tuple[dict[str, Any], dict[str, Any], SampleMetadata]:
        """Get the raw inputs for a single patch. Retrieve from cache if possible.

        Also crops/pads the tensors by patch size to protect slicing near right/bottom edges.

        Args:
            index: the index of the patch.

        Returns:
            a tuple of (raw_inputs, passthrough_inputs, metadata).
        """
        if index in self.window_cache:
            return self.window_cache[index]

        raw_inputs, passthrough_inputs, metadata = self.dataset.get_raw_inputs(index)
        # Pad once before caching so every later crop is safe near the edges.
        pad_slice_protect(raw_inputs, passthrough_inputs, self.patch_size, self.inputs)

        self.window_cache[index] = (raw_inputs, passthrough_inputs, metadata)
        return self.window_cache[index]

    def _crop_input_dict(
        self,
        d: dict[str, Any],
        start_offset: tuple[int, int],
        end_offset: tuple[int, int],
        cur_geom: STGeometry,
    ) -> dict[str, Any]:
        """Crop a dictionary of inputs to the given bounds.

        Crop coordinates are scaled based on each input's resolution_factor.
        """
        cropped = {}
        for input_name, value in d.items():
            if isinstance(value, torch.Tensor | RasterImage):
                # Get resolution scale for this input
                rf = self.inputs[input_name].resolution_factor
                scale = rf.numerator / rf.denominator
                # Scale the crop coordinates
                scaled_start = (
                    int(start_offset[0] * scale),
                    int(start_offset[1] * scale),
                )
                scaled_end = (
                    int(end_offset[0] * scale),
                    int(end_offset[1] * scale),
                )
                cropped[input_name] = crop_tensor_or_rasterimage(
                    value, scaled_start, scaled_end
                )

            elif isinstance(value, list):
                # Vector inputs: keep only features intersecting this patch.
                cropped[input_name] = [
                    feat for feat in value if cur_geom.intersects(feat.geometry)
                ]
            else:
                raise ValueError("got input that is neither tensor nor feature list")
        return cropped

    def __len__(self) -> int:
        """Return the total number of patches in the dataset."""
        return len(self.patches)

    def __getitem__(
        self, index: int
    ) -> tuple[dict[str, Any], dict[str, Any], SampleMetadata]:
        """Return (input_dict, target_dict, metadata) for a single flattened patch."""
        (window_id, patch_bounds, (patch_idx, num_patches)) = self.patches[index]
        raw_inputs, passthrough_inputs, metadata = self.get_raw_inputs(window_id)
        bounds = metadata.patch_bounds

        cur_geom = STGeometry(metadata.projection, shapely.box(*patch_bounds), None)
        # Offsets of this patch relative to the window origin.
        start_offset = (patch_bounds[0] - bounds[0], patch_bounds[1] - bounds[1])
        end_offset = (patch_bounds[2] - bounds[0], patch_bounds[3] - bounds[1])

        cur_raw_inputs = self._crop_input_dict(
            raw_inputs, start_offset, end_offset, cur_geom
        )
        cur_passthrough_inputs = self._crop_input_dict(
            passthrough_inputs, start_offset, end_offset, cur_geom
        )

        # Adjust the metadata as well.
        cur_metadata = replace(
            metadata,
            patch_bounds=patch_bounds,
            patch_idx=patch_idx,
            num_patches_in_window=num_patches,
        )

        # Now we can compute input and target dicts via the task.
        input_dict, target_dict = self.dataset.task.process_inputs(
            cur_raw_inputs,
            metadata=cur_metadata,
            load_targets=not self.dataset.split_config.get_skip_targets(),
        )
        input_dict.update(cur_passthrough_inputs)
        input_dict, target_dict = self.dataset.transforms(input_dict, target_dict)

        return input_dict, target_dict, cur_metadata

    def get_dataset_examples(self) -> list[Window]:
        """Returns a list of windows in this dataset."""
        return self.dataset.get_dataset_examples()

    def set_name(self, name: str) -> None:
        """Sets dataset name.

        Args:
            name: dataset name
        """
        self.dataset.set_name(name)
@@ -0,0 +1,53 @@
1
+ """Callback to activate/deactivate adapter layers."""
2
+
3
+ from typing import Any
4
+
5
+ from lightning.pytorch import LightningModule
6
+ from lightning.pytorch.callbacks import Callback
7
+ from lightning.pytorch.trainer import Trainer
8
+
9
+ from rslearn.log_utils import get_logger
10
+
11
+ logger = get_logger(__name__)
12
+
13
+
14
class ActivateLayers(Callback):
    """Activates adapter layers on a given epoch.

    By default, at every epoch, every adapter layer is deactivated.
    To activate an adapter layer, add a selector with the name of the adapter layer
    and the epoch at which to activate it. Once an adapter layer is activated, it
    remains active until the end of training.
    """

    def __init__(self, selectors: list[dict[str, Any]]) -> None:
        """Initialize the callback.

        Args:
            selectors: List of selectors to activate.
                Each selector is a dictionary with the following keys:
                - "name": Substring selector of modules to activate (str).
                - "at_epoch": The epoch at which to activate (int).
        """
        # Cooperate with the Callback base class (and any mixins).
        super().__init__()
        self.selectors = selectors

    def on_train_epoch_start(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
    ) -> None:
        """Activate adapter layers on a given epoch.

        Adapter layers are activated/deactivated by setting the `active` attribute.

        Args:
            trainer: The trainer object.
            pl_module: The LightningModule object.
        """
        status = {}
        for name, module in pl_module.named_modules():
            for selector in self.selectors:
                if selector["name"] in name:
                    # Once current_epoch reaches at_epoch, the layer stays
                    # active for the remainder of training.
                    module.active = trainer.current_epoch >= selector["at_epoch"]
                    status[selector["name"]] = "active" if module.active else "inactive"
        logger.info(f"Updated adapter status: {status}")
+ logger.info(f"Updated adapter status: {status}")