rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166) hide show
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
@@ -1,40 +1,154 @@
1
1
  """rslearn PredictionWriter implementation."""
2
2
 
3
- from collections.abc import Sequence
3
+ import json
4
+ from collections.abc import Iterable, Sequence
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
4
7
  from typing import Any
5
8
 
6
9
  import numpy as np
10
+ import numpy.typing as npt
7
11
  from lightning.pytorch import LightningModule, Trainer
8
12
  from lightning.pytorch.callbacks import BasePredictionWriter
9
13
  from upath import UPath
10
14
 
11
- from rslearn.config import LayerType, RasterFormatConfig
12
- from rslearn.dataset import Dataset
15
+ from rslearn.config import (
16
+ DatasetConfig,
17
+ LayerConfig,
18
+ LayerType,
19
+ StorageConfig,
20
+ )
21
+ from rslearn.dataset import Window
22
+ from rslearn.dataset.storage.storage import WindowStorage
23
+ from rslearn.log_utils import get_logger
24
+ from rslearn.train.model_context import SampleMetadata
13
25
  from rslearn.utils.array import copy_spatial_array
14
- from rslearn.utils.raster_format import load_raster_format
15
- from rslearn.utils.vector_format import load_vector_format
26
+ from rslearn.utils.feature import Feature
27
+ from rslearn.utils.geometry import PixelBounds
28
+ from rslearn.utils.raster_format import (
29
+ RasterFormat,
30
+ adjust_projection_and_bounds_for_array,
31
+ )
32
+ from rslearn.utils.vector_format import VectorFormat
16
33
 
17
34
  from .lightning_module import RslearnLightningModule
35
+ from .model_context import ModelOutput
36
+ from .tasks.task import Task
37
+
38
+ logger = get_logger(__name__)
39
+
40
+
41
@dataclass
class PendingPatchOutput:
    """Output of one patch that is waiting to be merged into its window.

    Instances are accumulated per window and handed to a merger once every
    patch of that window has been processed.

    Attributes:
        bounds: the pixel bounds of the patch that produced this output.
        output: the processed output for the patch (a raster array or a list
            of vector features, depending on the output layer type).
    """

    bounds: PixelBounds
    output: Any
18
47
 
19
48
 
20
49
  class PatchPredictionMerger:
21
50
  """Base class for merging predictions from multiple patches."""
22
51
 
23
52
  def merge(
24
- self, outputs: Sequence[Any], metadatas: Sequence[Any]
25
- ) -> tuple[Sequence[Any], Sequence[Any]]:
26
- """Merge the outputs and metadatas.
53
+ self,
54
+ window: Window,
55
+ outputs: Sequence[PendingPatchOutput],
56
+ layer_config: LayerConfig,
57
+ ) -> Any:
58
+ """Merge the outputs.
27
59
 
28
60
  Args:
61
+ window: the window we are merging the outputs for.
29
62
  outputs: the outputs to process.
30
- metadatas: the metadatas to process.
63
+ layer_config: the output layer configuration.
31
64
 
32
65
  Returns:
33
- the merged outputs and metadatas.
66
+ the merged outputs.
34
67
  """
35
68
  raise NotImplementedError
36
69
 
37
70
 
71
class VectorMerger(PatchPredictionMerger):
    """Merger for vector outputs: joins the per-patch feature lists together."""

    def merge(
        self,
        window: Window,
        outputs: Sequence[PendingPatchOutput],
        layer_config: LayerConfig,
    ) -> list[Feature]:
        """Concatenate the features of every pending patch output.

        Args:
            window: the window the outputs belong to (unused here).
            outputs: the pending patch outputs; each one's ``output`` is iterated
                for its features.
            layer_config: the output layer configuration (unused here).

        Returns:
            a single flat list with the features in patch order.
        """
        merged: list[Feature] = []
        for pending in outputs:
            merged.extend(pending.output)
        return merged
82
+
83
+
84
class RasterMerger(PatchPredictionMerger):
    """Merger for raster outputs: pastes each patch raster into one window array."""

    def __init__(self, padding: int | None = None, downsample_factor: int = 1):
        """Create a new RasterMerger.

        Args:
            padding: number of pixels to trim from the leading (left/top) edge of
                each patch output before pasting it. Typically used together
                with overlapping patches. Edges that coincide with the window's
                left/top boundary are not trimmed.
            downsample_factor: the factor by which the task's output rasters are
                lower in resolution than the window.
        """
        self.padding = padding
        self.downsample_factor = downsample_factor

    def merge(
        self,
        window: Window,
        outputs: Sequence[PendingPatchOutput],
        layer_config: LayerConfig,
    ) -> npt.NDArray:
        """Merge the raster outputs of all patches of a window.

        Args:
            window: the window the outputs belong to; its bounds determine the
                size of the merged array.
            outputs: the pending patch outputs. The first axis of each output
                array is taken as the channel axis.
            layer_config: the output layer configuration; the dtype of its first
                band set is used for the merged array.

        Returns:
            the merged array with shape (channels, height, width), where height
            and width are the window extents divided by the downsample factor.
        """
        factor = self.downsample_factor
        pad = self.padding

        num_channels = outputs[0].output.shape[0]
        win_col, win_row = window.bounds[0], window.bounds[1]
        out_height = (window.bounds[3] - win_row) // factor
        out_width = (window.bounds[2] - win_col) // factor
        merged_image = np.zeros(
            (num_channels, out_height, out_width),
            dtype=layer_config.band_sets[0].dtype.get_numpy_dtype(),
        )

        # The destination origin is the same for every patch, so compute it once.
        dst_offset = (win_col // factor, win_row // factor)

        # Process patches ordered by bounds[0] and then bounds[1], so that a
        # patch further along either axis is written later and overwrites the
        # overlapping part of earlier patches.
        ordered = sorted(outputs, key=lambda item: (item.bounds[0], item.bounds[1]))
        for patch in ordered:
            src = patch.output
            col_offset = patch.bounds[0] // factor
            row_offset = patch.bounds[1] // factor

            # Trim the leading padding only on edges that are interior to the
            # window; data on the window's own left/top boundary is kept.
            if pad is not None:
                if patch.bounds[0] != win_col:
                    src = src[:, :, pad:]
                    col_offset = col_offset + pad
                if patch.bounds[1] != win_row:
                    src = src[:, pad:, :]
                    row_offset = row_offset + pad

            copy_spatial_array(
                src=src,
                dst=merged_image,
                src_offset=(col_offset, row_offset),
                dst_offset=dst_offset,
            )

        return merged_image
150
+
151
+
38
152
  class RslearnWriter(BasePredictionWriter):
39
153
  """A writer that writes predictions back into the rslearn dataset.
40
154
 
@@ -46,9 +160,12 @@ class RslearnWriter(BasePredictionWriter):
46
160
  self,
47
161
  path: str,
48
162
  output_layer: str,
49
- path_options: dict[str, Any] = {},
50
- selector: list[str] = [],
163
+ path_options: dict[str, Any] | None = None,
164
+ selector: list[str] | None = None,
51
165
  merger: PatchPredictionMerger | None = None,
166
+ output_path: str | Path | None = None,
167
+ layer_config: LayerConfig | None = None,
168
+ storage_config: StorageConfig | None = None,
52
169
  ):
53
170
  """Create a new RslearnWriter.
54
171
 
@@ -57,42 +174,125 @@ class RslearnWriter(BasePredictionWriter):
57
174
  output_layer: which layer to write the outputs under.
58
175
  path_options: additional options for path to pass to fsspec
59
176
  selector: keys to access the desired output in the output dict if needed.
177
+ e.g ["key1", "key2"] gets output["key1"]["key2"]
60
178
  merger: merger to use to merge outputs from overlapped patches.
179
+ output_path: optional custom path for writing predictions. If provided,
180
+ predictions will be written to this path instead of deriving from dataset path.
181
+ layer_config: optional layer configuration. If provided, this config will be
182
+ used instead of reading from the dataset config, allowing usage without
183
+ requiring dataset config at the output path.
184
+ storage_config: optional storage configuration, needed similar to layer_config
185
+ if there is no dataset config.
61
186
  """
62
187
  super().__init__(write_interval="batch")
63
188
  self.output_layer = output_layer
64
- self.selector = selector
65
- self.path = UPath(path, **path_options)
66
- self.dataset = Dataset(self.path)
67
- self.layer_config = self.dataset.layers[self.output_layer]
189
+ self.selector = selector or []
190
+ ds_upath = UPath(path, **path_options or {})
191
+ output_upath = (
192
+ UPath(output_path, **path_options or {})
193
+ if output_path is not None
194
+ else ds_upath
195
+ )
68
196
 
69
- if self.layer_config.layer_type == LayerType.RASTER:
70
- band_cfg = self.layer_config.band_sets[0]
71
- self.format = load_raster_format(
72
- RasterFormatConfig(band_cfg.format["name"], band_cfg.format)
197
+ self.layer_config, self.dataset_storage = (
198
+ self._get_layer_config_and_dataset_storage(
199
+ ds_upath, output_upath, layer_config, storage_config
73
200
  )
74
- elif self.layer_config.layer_type == LayerType.VECTOR:
75
- self.format = load_vector_format(self.layer_config.format)
201
+ )
202
+
203
+ self.format: RasterFormat | VectorFormat
204
+ if self.layer_config.type == LayerType.RASTER:
205
+ band_cfg = self.layer_config.band_sets[0]
206
+ self.format = band_cfg.instantiate_raster_format()
207
+ elif self.layer_config.type == LayerType.VECTOR:
208
+ self.format = self.layer_config.instantiate_vector_format()
76
209
  else:
77
- raise ValueError(f"invalid layer type {self.layer_config.layer_type}")
210
+ raise ValueError(f"invalid layer type {self.layer_config.type}")
78
211
 
79
- self.merger = merger
212
+ if merger is not None:
213
+ self.merger = merger
214
+ elif self.layer_config.type == LayerType.RASTER:
215
+ self.merger = RasterMerger()
216
+ elif self.layer_config.type == LayerType.VECTOR:
217
+ self.merger = VectorMerger()
80
218
 
81
219
  # Map from window name to pending data to write.
82
220
  # This is used when windows are split up into patches, so the data from all the
83
221
  # patches of each window need to be reconstituted.
84
- self.pending_outputs = {}
222
+ self.pending_outputs: dict[str, list[PendingPatchOutput]] = {}
223
+
224
+ def _get_layer_config_and_dataset_storage(
225
+ self,
226
+ ds_upath: UPath,
227
+ output_upath: UPath,
228
+ layer_config: LayerConfig | None,
229
+ storage_config: StorageConfig | None,
230
+ ) -> tuple[LayerConfig, WindowStorage]:
231
+ """Get the layer config and dataset storage to use.
232
+
233
+ This is a helper function for the init method.
234
+
235
+ If layer_config is set, we use that. If storage_config is set, we use it to
236
+ instantiate a WindowStorage using the output_upath.
237
+
238
+ If one of them is not set, we load the config from the ds_upath. Otherwise, we
239
+ avoid reading the dataset config; this way, RslearnWriter can be used with
240
+ output directories that do not contain the dataset config, as long as
241
+ layer_config and storage_config are both provided.
242
+
243
+ Args:
244
+ ds_upath: the dataset path, where a dataset config can be loaded from if
245
+ layer_config or storage_config is not provided.
246
+ output_upath: the output directory, which could be different from the
247
+ dataset path.
248
+ layer_config: optional LayerConfig to provide.
249
+ storage_config: optional StorageConfig to provide.
250
+
251
+ Returns:
252
+ a tuple (layer_config, dataset_storage)
253
+ """
254
+ dataset_storage: WindowStorage | None = None
255
+
256
+ # Instantiate the WindowStorage from the storage_config if provided.
257
+ if storage_config:
258
+ dataset_storage = (
259
+ storage_config.instantiate_window_storage_factory().get_storage(
260
+ output_upath
261
+ )
262
+ )
263
+
264
+ if not layer_config or not dataset_storage:
265
+ # Need to load dataset config since one of LayerConfig/StorageConfig is missing.
266
+ # We use DatasetConfig.model_validate instead of initializing the Dataset
267
+ # because we want to get a WindowStorage that has the dataset path set to
268
+ # output_upath instead of ds_upath.
269
+ with (ds_upath / "config.json").open() as f:
270
+ dataset_config = DatasetConfig.model_validate(json.load(f))
271
+
272
+ if not layer_config:
273
+ if self.output_layer not in dataset_config.layers:
274
+ raise KeyError(
275
+ f"Output layer '{self.output_layer}' not found in dataset layers."
276
+ )
277
+ layer_config = dataset_config.layers[self.output_layer]
278
+
279
+ if not dataset_storage:
280
+ dataset_storage = dataset_config.storage.instantiate_window_storage_factory().get_storage(
281
+ output_upath
282
+ )
283
+
284
+ return (layer_config, dataset_storage)
85
285
 
86
286
  def write_on_batch_end(
87
287
  self,
88
288
  trainer: Trainer,
89
289
  pl_module: LightningModule,
90
- prediction: Sequence[Any],
91
- batch_indices: Sequence[Any],
92
- batch: Any,
290
+ prediction: ModelOutput,
291
+ batch_indices: Sequence[int] | None,
292
+ batch: tuple[list, list, list],
93
293
  batch_idx: int,
94
294
  dataloader_idx: int,
95
- ):
295
+ ) -> None:
96
296
  """Write a batch of predictions into the rslearn dataset.
97
297
 
98
298
  Args:
@@ -100,14 +300,38 @@ class RslearnWriter(BasePredictionWriter):
100
300
  pl_module: the LightningModule.
101
301
  prediction: the prediction to write.
102
302
  batch_indices: batch indices.
103
- batch: the batch that was input to the model.
303
+ batch: the batch that was input to the model. It should be a list of
304
+ (inputs, targets, metadatas).
104
305
  batch_idx: the batch index.
105
306
  dataloader_idx: the index in the dataloader.
106
307
  """
107
308
  assert isinstance(pl_module, RslearnLightningModule)
108
- metadatas = batch[2]
109
- outputs = [
110
- pl_module.task.process_output(output, metadata)
309
+ task = pl_module.task
310
+ _, _, metadatas = batch
311
+ self.process_output_batch(task, prediction.outputs, metadatas)
312
+
313
+ def process_output_batch(
314
+ self,
315
+ task: Task,
316
+ prediction: Iterable[Any],
317
+ metadatas: Iterable[SampleMetadata],
318
+ ) -> None:
319
+ """Write a prediction batch with simplified API.
320
+
321
+ write_on_batch_end wraps this function to work with lightning API, but only a
322
+ subset of arguments are used.
323
+
324
+ Args:
325
+ task: the Task that we are writing outputs for.
326
+ prediction: the list of predictions in this batch to write. These outputs
327
+ will be processed by the task to obtain a vector (list[Feature]) or
328
+ raster (npt.NDArray) output.
329
+ metadatas: corresponding list of metadatas from the batch describing the
330
+ patches that were processed.
331
+ """
332
+ # Process the predictions into outputs that can be written.
333
+ outputs: list = [
334
+ task.process_output(output, metadata)
111
335
  for output, metadata in zip(prediction, metadatas)
112
336
  ]
113
337
 
@@ -115,64 +339,75 @@ class RslearnWriter(BasePredictionWriter):
115
339
  for k in self.selector:
116
340
  output = output[k]
117
341
 
118
- window_name = metadata["window_name"]
119
- cur_bounds = metadata["bounds"]
120
- window_bounds = metadata["window_bounds"]
121
-
122
- if self.layer_config.layer_type == LayerType.RASTER:
123
- if window_name not in self.pending_outputs:
124
- self.pending_outputs[window_name] = np.zeros(
125
- (
126
- output.shape[0],
127
- window_bounds[3] - window_bounds[1],
128
- window_bounds[2] - window_bounds[0],
129
- ),
130
- dtype=output.dtype,
131
- )
342
+ window = Window(
343
+ storage=self.dataset_storage,
344
+ group=metadata.window_group,
345
+ name=metadata.window_name,
346
+ projection=metadata.projection,
347
+ bounds=metadata.window_bounds,
348
+ time_range=metadata.time_range,
349
+ )
350
+ self.process_output(
351
+ window,
352
+ metadata.patch_idx,
353
+ metadata.num_patches_in_window,
354
+ metadata.patch_bounds,
355
+ output,
356
+ )
132
357
 
133
- # Use copy_spatial_array to handle the copy since, when using patches,
134
- # the last column/row of outputs might extend beyond the bounds of the
135
- # window.
136
- copy_spatial_array(
137
- src=output,
138
- dst=self.pending_outputs[window_name],
139
- src_offset=(cur_bounds[0], cur_bounds[1]),
140
- dst_offset=(window_bounds[0], window_bounds[1]),
141
- )
358
+ def process_output(
359
+ self,
360
+ window: Window,
361
+ patch_idx: int,
362
+ num_patches: int,
363
+ cur_bounds: PixelBounds,
364
+ output: npt.NDArray | list[Feature],
365
+ ) -> None:
366
+ """Process one output from the model.
142
367
 
143
- elif self.layer_config.layer_type == LayerType.VECTOR:
144
- if window_name not in self.pending_outputs:
145
- self.pending_outputs[window_name] = []
368
+ Args:
369
+ window: the window that the output pertains to.
370
+ patch_idx: the index of this patch for the window.
371
+ num_patches: the total number of patches to be processed for the window.
372
+ cur_bounds: the bounds of the current patch.
373
+ output: the output data.
374
+ """
375
+ # Incorporate the output into our list of pending patch outputs.
376
+ if window.name not in self.pending_outputs:
377
+ self.pending_outputs[window.name] = []
378
+ self.pending_outputs[window.name].append(PendingPatchOutput(cur_bounds, output))
379
+ logger.debug(
380
+ f"Stored PendingPatchOutput for patch #{patch_idx}/{num_patches} at window {window.name}"
381
+ )
146
382
 
147
- self.pending_outputs[window_name].extend(output)
383
+ if patch_idx < num_patches - 1:
384
+ return
148
385
 
149
- if metadata["patch_idx"] < metadata["num_patches"] - 1:
150
- continue
386
+ # This is the last patch so it's time to write it.
387
+ # First get the pending output and clear it.
388
+ pending_output = self.pending_outputs[window.name]
389
+ del self.pending_outputs[window.name]
151
390
 
152
- pending_output = self.pending_outputs[window_name]
153
- del self.pending_outputs[window_name]
391
+ # Merge outputs from overlapped patches if merger is set.
392
+ logger.debug(f"Merging and writing for window {window.name}")
393
+ merged_output = self.merger.merge(window, pending_output, self.layer_config)
154
394
 
155
- # This is the last patch so it's time to merge outputs from overlapped patches
156
- if self.merger is not None:
157
- pending_output = self.merger.merge(pending_output)
395
+ if self.layer_config.type == LayerType.RASTER:
396
+ raster_dir = window.get_raster_dir(
397
+ self.output_layer, self.layer_config.band_sets[0].bands
398
+ )
399
+ assert isinstance(self.format, RasterFormat)
158
400
 
159
- # This is the last patch so it's time to write it.
160
- layer_dir = (
161
- self.dataset.path
162
- / "windows"
163
- / metadata["group"]
164
- / window_name
165
- / "layers"
166
- / self.output_layer
401
+ # In case the merged_output is at a different resolution than the window,
402
+ # get adjusted projection and bounds for writing it.
403
+ projection, bounds = adjust_projection_and_bounds_for_array(
404
+ window.projection, window.bounds, merged_output
167
405
  )
406
+ self.format.encode_raster(raster_dir, projection, bounds, merged_output)
168
407
 
169
- if self.layer_config.layer_type == LayerType.RASTER:
170
- band_dir = layer_dir / "_".join(self.layer_config.band_sets[0].bands)
171
- self.format.encode_raster(
172
- band_dir, metadata["projection"], window_bounds, pending_output
173
- )
408
+ elif self.layer_config.type == LayerType.VECTOR:
409
+ layer_dir = window.get_layer_dir(self.output_layer)
410
+ assert isinstance(self.format, VectorFormat)
411
+ self.format.encode_vector(layer_dir, merged_output)
174
412
 
175
- elif self.layer_config.layer_type == LayerType.VECTOR:
176
- self.format.encode_vector(
177
- layer_dir, metadata["projection"], pending_output
178
- )
413
+ window.mark_layer_completed(self.output_layer)
@@ -0,0 +1,92 @@
1
+ """Learning rate schedulers for rslearn."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from dataclasses import asdict, dataclass
5
+
6
+ from torch.optim import Optimizer
7
+ from torch.optim.lr_scheduler import (
8
+ CosineAnnealingLR,
9
+ CosineAnnealingWarmRestarts,
10
+ LRScheduler,
11
+ MultiStepLR,
12
+ ReduceLROnPlateau,
13
+ )
14
+
15
+ from rslearn.log_utils import get_logger
16
+
17
+ logger = get_logger(__name__)
18
+
19
+
20
+ class SchedulerFactory(ABC):
21
+ """A factory class that initializes an LR scheduler given the optimizer."""
22
+
23
+ def get_kwargs(self) -> dict:
24
+ """Get the keyword arguments for the scheduler."""
25
+ return {k: v for k, v in asdict(self).items() if v is not None} # type: ignore
26
+
27
+ @abstractmethod
28
+ def build(self, optimizer: Optimizer) -> LRScheduler:
29
+ """Build the learning rate scheduler configured by this factory class."""
30
+ logger.info(
31
+ f"Using scheduler {self.__class__.__name__} with kwargs {self.get_kwargs()}"
32
+ )
33
+
34
+
35
@dataclass
class PlateauScheduler(SchedulerFactory):
    """Factory for torch's ReduceLROnPlateau scheduler.

    Every field mirrors the ReduceLROnPlateau keyword argument of the same name.
    A field left as None is not passed on, so the torch default is used.
    """

    mode: str | None = None
    factor: float | None = None
    patience: int | None = None
    threshold: float | None = None
    threshold_mode: str | None = None
    cooldown: int | None = None
    min_lr: float | None = None
    eps: float | None = None

    def build(self, optimizer: Optimizer) -> LRScheduler:
        """Build the ReduceLROnPlateau scheduler.

        Args:
            optimizer: the optimizer whose learning rate will be scheduled.

        Returns:
            the configured ReduceLROnPlateau instance.
        """
        # The base implementation logs the scheduler name and its kwargs.
        super().build(optimizer)
        scheduler_kwargs = self.get_kwargs()
        return ReduceLROnPlateau(optimizer, **scheduler_kwargs)
52
+
53
+
54
@dataclass
class MultiStepScheduler(SchedulerFactory):
    """Multi-step learning rate scheduler (factory for torch's MultiStepLR).

    Every field mirrors the MultiStepLR keyword argument of the same name.
    Optional fields left as None are not passed on, so the torch default is used.
    """

    # Passed to MultiStepLR as ``milestones`` (required).
    milestones: list[int]
    # Passed to MultiStepLR as ``gamma`` when set.
    gamma: float | None = None
    # Passed to MultiStepLR as ``last_epoch`` when set.
    last_epoch: int | None = None

    def build(self, optimizer: Optimizer) -> LRScheduler:
        """Build the MultiStepLR scheduler.

        Args:
            optimizer: the optimizer whose learning rate will be scheduled.

        Returns:
            the configured MultiStepLR instance.
        """
        # The base implementation logs the scheduler name and its kwargs.
        super().build(optimizer)
        return MultiStepLR(optimizer, **self.get_kwargs())
66
+
67
+
68
@dataclass
class CosineAnnealingScheduler(SchedulerFactory):
    """Factory for torch's CosineAnnealingLR scheduler.

    ``T_max`` is required; ``eta_min`` is only passed on when it is not None.
    """

    T_max: int
    eta_min: float | None = None

    def build(self, optimizer: Optimizer) -> LRScheduler:
        """Build the CosineAnnealingLR scheduler.

        Args:
            optimizer: the optimizer whose learning rate will be scheduled.

        Returns:
            the configured CosineAnnealingLR instance.
        """
        # The base implementation logs the scheduler name and its kwargs.
        super().build(optimizer)
        scheduler_kwargs = self.get_kwargs()
        return CosineAnnealingLR(optimizer, **scheduler_kwargs)
79
+
80
+
81
@dataclass
class CosineAnnealingWarmRestartsScheduler(SchedulerFactory):
    """Factory for torch's CosineAnnealingWarmRestarts scheduler.

    ``T_0`` is required; ``T_mult`` and ``eta_min`` carry explicit defaults and
    are therefore always passed on to the scheduler.
    """

    T_0: int
    T_mult: int = 1
    eta_min: float = 0.0

    def build(self, optimizer: Optimizer) -> LRScheduler:
        """Build the CosineAnnealingWarmRestarts scheduler.

        Args:
            optimizer: the optimizer whose learning rate will be scheduled.

        Returns:
            the configured CosineAnnealingWarmRestarts instance.
        """
        # The base implementation logs the scheduler name and its kwargs.
        super().build(optimizer)
        scheduler_kwargs = self.get_kwargs()
        return CosineAnnealingWarmRestarts(optimizer, **scheduler_kwargs)