rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
|
@@ -1,40 +1,154 @@
|
|
|
1
1
|
"""rslearn PredictionWriter implementation."""
|
|
2
2
|
|
|
3
|
-
|
|
3
|
+
import json
|
|
4
|
+
from collections.abc import Iterable, Sequence
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from pathlib import Path
|
|
4
7
|
from typing import Any
|
|
5
8
|
|
|
6
9
|
import numpy as np
|
|
10
|
+
import numpy.typing as npt
|
|
7
11
|
from lightning.pytorch import LightningModule, Trainer
|
|
8
12
|
from lightning.pytorch.callbacks import BasePredictionWriter
|
|
9
13
|
from upath import UPath
|
|
10
14
|
|
|
11
|
-
from rslearn.config import
|
|
12
|
-
|
|
15
|
+
from rslearn.config import (
|
|
16
|
+
DatasetConfig,
|
|
17
|
+
LayerConfig,
|
|
18
|
+
LayerType,
|
|
19
|
+
StorageConfig,
|
|
20
|
+
)
|
|
21
|
+
from rslearn.dataset import Window
|
|
22
|
+
from rslearn.dataset.storage.storage import WindowStorage
|
|
23
|
+
from rslearn.log_utils import get_logger
|
|
24
|
+
from rslearn.train.model_context import SampleMetadata
|
|
13
25
|
from rslearn.utils.array import copy_spatial_array
|
|
14
|
-
from rslearn.utils.
|
|
15
|
-
from rslearn.utils.
|
|
26
|
+
from rslearn.utils.feature import Feature
|
|
27
|
+
from rslearn.utils.geometry import PixelBounds
|
|
28
|
+
from rslearn.utils.raster_format import (
|
|
29
|
+
RasterFormat,
|
|
30
|
+
adjust_projection_and_bounds_for_array,
|
|
31
|
+
)
|
|
32
|
+
from rslearn.utils.vector_format import VectorFormat
|
|
16
33
|
|
|
17
34
|
from .lightning_module import RslearnLightningModule
|
|
35
|
+
from .model_context import ModelOutput
|
|
36
|
+
from .tasks.task import Task
|
|
37
|
+
|
|
38
|
+
logger = get_logger(__name__)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dataclass
class PendingPatchOutput:
    """One patch's processed output, held until its window can be merged.

    Instances are accumulated per window by RslearnWriter.process_output and
    handed to a PatchPredictionMerger once the last patch has arrived.
    """

    # Pixel bounds of the patch this output was produced for.
    bounds: PixelBounds
    # The processed output for the patch (raster array or list of features,
    # depending on the layer type being written).
    output: Any
|
|
18
47
|
|
|
19
48
|
|
|
20
49
|
class PatchPredictionMerger:
    """Interface for combining the per-patch predictions of a single window."""

    def merge(
        self,
        window: Window,
        outputs: Sequence[PendingPatchOutput],
        layer_config: LayerConfig,
    ) -> Any:
        """Combine the pending patch outputs of one window into a single output.

        Subclasses must override this method; the base implementation always
        raises NotImplementedError.

        Args:
            window: the window that the patch outputs belong to.
            outputs: the pending patch outputs to combine.
            layer_config: the configuration of the layer being written.

        Returns:
            the merged output for the window.
        """
        raise NotImplementedError
|
|
36
69
|
|
|
37
70
|
|
|
71
|
+
class VectorMerger(PatchPredictionMerger):
    """Vector merger: the merged result is every patch's features, concatenated."""

    def merge(
        self,
        window: Window,
        outputs: Sequence[PendingPatchOutput],
        layer_config: LayerConfig,
    ) -> list[Feature]:
        """Flatten the per-patch feature lists into one list.

        Features are kept in the order the patches appear in `outputs`; the
        window and layer_config arguments are not used.
        """
        merged: list[Feature] = []
        for patch in outputs:
            merged.extend(patch.output)
        return merged
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class RasterMerger(PatchPredictionMerger):
    """Raster merger that pastes each patch output into one window-sized array."""

    def __init__(self, padding: int | None = None, downsample_factor: int = 1):
        """Create a new RasterMerger.

        Args:
            padding: number of pixels to trim from the patch outputs before they
                are pasted, typically used with overlapping patches. It is only
                trimmed on a side where the patch does not start at the window's
                own bounds[0] / bounds[1], so output along those window borders
                is retained.
            downsample_factor: the factor by which the rasters output by the task
                are lower in resolution relative to the window resolution.
        """
        self.padding = padding
        self.downsample_factor = downsample_factor

    def merge(
        self,
        window: Window,
        outputs: Sequence[PendingPatchOutput],
        layer_config: LayerConfig,
    ) -> npt.NDArray:
        """Merge the raster patch outputs into a single array for the window.

        Args:
            window: the window being merged; its bounds determine the output size.
            outputs: the pending patch outputs. Must be non-empty, since the
                channel count is read from the first one.
            layer_config: the output layer config; the dtype of its first band
                set is used for the merged array.

        Returns:
            a zero-initialized array of shape (channels, height, width), with the
            spatial size divided by downsample_factor, that the patches have been
            copied into.
        """
        factor = self.downsample_factor
        channel_count = outputs[0].output.shape[0]
        out_height = (window.bounds[3] - window.bounds[1]) // factor
        out_width = (window.bounds[2] - window.bounds[0]) // factor
        canvas = np.zeros(
            (channel_count, out_height, out_width),
            dtype=layer_config.band_sets[0].dtype.get_numpy_dtype(),
        )

        # The window origin in downsampled coordinates is the same for every patch.
        window_origin = (window.bounds[0] // factor, window.bounds[1] // factor)

        # Paste in ascending (bounds[0], bounds[1]) order so that, where patches
        # overlap, the patch with the larger offsets is copied last and wins.
        ordered = sorted(outputs, key=lambda patch: (patch.bounds[0], patch.bounds[1]))
        for patch in ordered:
            data = patch.output
            first_offset = patch.bounds[0] // factor
            second_offset = patch.bounds[1] // factor

            if self.padding is not None:
                # Trim the padding only on sides that are interior to the window,
                # and shift the patch offset by the same amount.
                if patch.bounds[0] != window.bounds[0]:
                    data = data[:, :, self.padding :]
                    first_offset += self.padding
                if patch.bounds[1] != window.bounds[1]:
                    data = data[:, self.padding :, :]
                    second_offset += self.padding

            copy_spatial_array(
                src=data,
                dst=canvas,
                src_offset=(first_offset, second_offset),
                dst_offset=window_origin,
            )

        return canvas
|
|
150
|
+
|
|
151
|
+
|
|
38
152
|
class RslearnWriter(BasePredictionWriter):
|
|
39
153
|
"""A writer that writes predictions back into the rslearn dataset.
|
|
40
154
|
|
|
@@ -46,9 +160,12 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
46
160
|
self,
|
|
47
161
|
path: str,
|
|
48
162
|
output_layer: str,
|
|
49
|
-
path_options: dict[str, Any] =
|
|
50
|
-
selector: list[str] =
|
|
163
|
+
path_options: dict[str, Any] | None = None,
|
|
164
|
+
selector: list[str] | None = None,
|
|
51
165
|
merger: PatchPredictionMerger | None = None,
|
|
166
|
+
output_path: str | Path | None = None,
|
|
167
|
+
layer_config: LayerConfig | None = None,
|
|
168
|
+
storage_config: StorageConfig | None = None,
|
|
52
169
|
):
|
|
53
170
|
"""Create a new RslearnWriter.
|
|
54
171
|
|
|
@@ -57,42 +174,125 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
57
174
|
output_layer: which layer to write the outputs under.
|
|
58
175
|
path_options: additional options for path to pass to fsspec
|
|
59
176
|
selector: keys to access the desired output in the output dict if needed.
|
|
177
|
+
e.g ["key1", "key2"] gets output["key1"]["key2"]
|
|
60
178
|
merger: merger to use to merge outputs from overlapped patches.
|
|
179
|
+
output_path: optional custom path for writing predictions. If provided,
|
|
180
|
+
predictions will be written to this path instead of deriving from dataset path.
|
|
181
|
+
layer_config: optional layer configuration. If provided, this config will be
|
|
182
|
+
used instead of reading from the dataset config, allowing usage without
|
|
183
|
+
requiring dataset config at the output path.
|
|
184
|
+
storage_config: optional storage configuration, needed similar to layer_config
|
|
185
|
+
if there is no dataset config.
|
|
61
186
|
"""
|
|
62
187
|
super().__init__(write_interval="batch")
|
|
63
188
|
self.output_layer = output_layer
|
|
64
|
-
self.selector = selector
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
189
|
+
self.selector = selector or []
|
|
190
|
+
ds_upath = UPath(path, **path_options or {})
|
|
191
|
+
output_upath = (
|
|
192
|
+
UPath(output_path, **path_options or {})
|
|
193
|
+
if output_path is not None
|
|
194
|
+
else ds_upath
|
|
195
|
+
)
|
|
68
196
|
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
RasterFormatConfig(band_cfg.format["name"], band_cfg.format)
|
|
197
|
+
self.layer_config, self.dataset_storage = (
|
|
198
|
+
self._get_layer_config_and_dataset_storage(
|
|
199
|
+
ds_upath, output_upath, layer_config, storage_config
|
|
73
200
|
)
|
|
74
|
-
|
|
75
|
-
|
|
201
|
+
)
|
|
202
|
+
|
|
203
|
+
self.format: RasterFormat | VectorFormat
|
|
204
|
+
if self.layer_config.type == LayerType.RASTER:
|
|
205
|
+
band_cfg = self.layer_config.band_sets[0]
|
|
206
|
+
self.format = band_cfg.instantiate_raster_format()
|
|
207
|
+
elif self.layer_config.type == LayerType.VECTOR:
|
|
208
|
+
self.format = self.layer_config.instantiate_vector_format()
|
|
76
209
|
else:
|
|
77
|
-
raise ValueError(f"invalid layer type {self.layer_config.
|
|
210
|
+
raise ValueError(f"invalid layer type {self.layer_config.type}")
|
|
78
211
|
|
|
79
|
-
|
|
212
|
+
if merger is not None:
|
|
213
|
+
self.merger = merger
|
|
214
|
+
elif self.layer_config.type == LayerType.RASTER:
|
|
215
|
+
self.merger = RasterMerger()
|
|
216
|
+
elif self.layer_config.type == LayerType.VECTOR:
|
|
217
|
+
self.merger = VectorMerger()
|
|
80
218
|
|
|
81
219
|
# Map from window name to pending data to write.
|
|
82
220
|
# This is used when windows are split up into patches, so the data from all the
|
|
83
221
|
# patches of each window need to be reconstituted.
|
|
84
|
-
self.pending_outputs = {}
|
|
222
|
+
self.pending_outputs: dict[str, list[PendingPatchOutput]] = {}
|
|
223
|
+
|
|
224
|
+
def _get_layer_config_and_dataset_storage(
    self,
    ds_upath: UPath,
    output_upath: UPath,
    layer_config: LayerConfig | None,
    storage_config: StorageConfig | None,
) -> tuple[LayerConfig, WindowStorage]:
    """Resolve the LayerConfig and WindowStorage the writer should use.

    Helper for __init__. Explicitly provided values take precedence:
    layer_config is used as-is, and storage_config is used to create a
    WindowStorage rooted at output_upath. The dataset config at
    ds_upath / "config.json" is only read when at least one of the two is
    missing, so the writer can target an output directory without a dataset
    config when both are supplied.

    Args:
        ds_upath: the dataset path, from which the dataset config is loaded if
            layer_config or storage_config is not provided.
        output_upath: the output directory that the returned WindowStorage is
            created for; it may differ from the dataset path.
        layer_config: optional LayerConfig to use.
        storage_config: optional StorageConfig to use.

    Returns:
        a tuple (layer_config, dataset_storage)

    Raises:
        KeyError: if layer_config is not provided and self.output_layer is not
            one of the layers in the dataset config.
    """
    storage: WindowStorage | None = None
    if storage_config:
        factory = storage_config.instantiate_window_storage_factory()
        storage = factory.get_storage(output_upath)

    # Nothing is missing, so the dataset config does not have to be read at all.
    if layer_config and storage:
        return (layer_config, storage)

    # Validate the raw JSON with DatasetConfig rather than building a Dataset,
    # since the WindowStorage must point at output_upath instead of ds_upath.
    with (ds_upath / "config.json").open() as f:
        dataset_config = DatasetConfig.model_validate(json.load(f))

    if not layer_config:
        if self.output_layer not in dataset_config.layers:
            raise KeyError(
                f"Output layer '{self.output_layer}' not found in dataset layers."
            )
        layer_config = dataset_config.layers[self.output_layer]

    if not storage:
        fallback_factory = dataset_config.storage.instantiate_window_storage_factory()
        storage = fallback_factory.get_storage(output_upath)

    return (layer_config, storage)
|
|
85
285
|
|
|
86
286
|
def write_on_batch_end(
|
|
87
287
|
self,
|
|
88
288
|
trainer: Trainer,
|
|
89
289
|
pl_module: LightningModule,
|
|
90
|
-
prediction:
|
|
91
|
-
batch_indices: Sequence[
|
|
92
|
-
batch:
|
|
290
|
+
prediction: ModelOutput,
|
|
291
|
+
batch_indices: Sequence[int] | None,
|
|
292
|
+
batch: tuple[list, list, list],
|
|
93
293
|
batch_idx: int,
|
|
94
294
|
dataloader_idx: int,
|
|
95
|
-
):
|
|
295
|
+
) -> None:
|
|
96
296
|
"""Write a batch of predictions into the rslearn dataset.
|
|
97
297
|
|
|
98
298
|
Args:
|
|
@@ -100,14 +300,38 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
100
300
|
pl_module: the LightningModule.
|
|
101
301
|
prediction: the prediction to write.
|
|
102
302
|
batch_indices: batch indices.
|
|
103
|
-
batch: the batch that was input to the model.
|
|
303
|
+
batch: the batch that was input to the model. It should be a list of
|
|
304
|
+
(inputs, targets, metadatas).
|
|
104
305
|
batch_idx: the batch index.
|
|
105
306
|
dataloader_idx: the index in the dataloader.
|
|
106
307
|
"""
|
|
107
308
|
assert isinstance(pl_module, RslearnLightningModule)
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
309
|
+
task = pl_module.task
|
|
310
|
+
_, _, metadatas = batch
|
|
311
|
+
self.process_output_batch(task, prediction.outputs, metadatas)
|
|
312
|
+
|
|
313
|
+
def process_output_batch(
|
|
314
|
+
self,
|
|
315
|
+
task: Task,
|
|
316
|
+
prediction: Iterable[Any],
|
|
317
|
+
metadatas: Iterable[SampleMetadata],
|
|
318
|
+
) -> None:
|
|
319
|
+
"""Write a prediction batch with simplified API.
|
|
320
|
+
|
|
321
|
+
write_on_batch_end wraps this function to work with lightning API, but only a
|
|
322
|
+
subset of arguments are used.
|
|
323
|
+
|
|
324
|
+
Args:
|
|
325
|
+
task: the Task that we are writing outputs for.
|
|
326
|
+
prediction: the list of predictions in this batch to write. These outputs
|
|
327
|
+
will be processed by the task to obtain a vector (list[Feature]) or
|
|
328
|
+
raster (npt.NDArray) output.
|
|
329
|
+
metadatas: corresponding list of metadatas from the batch describing the
|
|
330
|
+
patches that were processed.
|
|
331
|
+
"""
|
|
332
|
+
# Process the predictions into outputs that can be written.
|
|
333
|
+
outputs: list = [
|
|
334
|
+
task.process_output(output, metadata)
|
|
111
335
|
for output, metadata in zip(prediction, metadatas)
|
|
112
336
|
]
|
|
113
337
|
|
|
@@ -115,64 +339,75 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
115
339
|
for k in self.selector:
|
|
116
340
|
output = output[k]
|
|
117
341
|
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
342
|
+
window = Window(
|
|
343
|
+
storage=self.dataset_storage,
|
|
344
|
+
group=metadata.window_group,
|
|
345
|
+
name=metadata.window_name,
|
|
346
|
+
projection=metadata.projection,
|
|
347
|
+
bounds=metadata.window_bounds,
|
|
348
|
+
time_range=metadata.time_range,
|
|
349
|
+
)
|
|
350
|
+
self.process_output(
|
|
351
|
+
window,
|
|
352
|
+
metadata.patch_idx,
|
|
353
|
+
metadata.num_patches_in_window,
|
|
354
|
+
metadata.patch_bounds,
|
|
355
|
+
output,
|
|
356
|
+
)
|
|
132
357
|
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
358
|
+
def process_output(
    self,
    window: Window,
    patch_idx: int,
    num_patches: int,
    cur_bounds: PixelBounds,
    output: npt.NDArray | list[Feature],
) -> None:
    """Record one patch output and write the window once its last patch arrives.

    The output is queued under the window's name. Nothing is written until the
    patch whose index is num_patches - 1 (or greater) is seen; at that point
    the queued outputs are merged, encoded with the layer's raster or vector
    format, and the output layer is marked completed for the window.

    Args:
        window: the window that the output pertains to.
        patch_idx: the index of this patch for the window.
        num_patches: the total number of patches to be processed for the window.
        cur_bounds: the bounds of the current patch.
        output: the output data.
    """
    # Queue this patch alongside any earlier patches of the same window.
    self.pending_outputs.setdefault(window.name, []).append(
        PendingPatchOutput(cur_bounds, output)
    )
    logger.debug(
        f"Stored PendingPatchOutput for patch #{patch_idx}/{num_patches} at window {window.name}"
    )

    # More patches are still expected for this window.
    if patch_idx < num_patches - 1:
        return

    # Final patch: take the queued outputs out of the pending map and merge them.
    queued = self.pending_outputs.pop(window.name)
    logger.debug(f"Merging and writing for window {window.name}")
    merged_output = self.merger.merge(window, queued, self.layer_config)

    if self.layer_config.type == LayerType.RASTER:
        first_band_set = self.layer_config.band_sets[0]
        raster_dir = window.get_raster_dir(self.output_layer, first_band_set.bands)
        assert isinstance(self.format, RasterFormat)

        # The merged array may be at a different resolution than the window, so
        # obtain a projection and bounds adjusted to the array before encoding.
        projection, bounds = adjust_projection_and_bounds_for_array(
            window.projection, window.bounds, merged_output
        )
        self.format.encode_raster(raster_dir, projection, bounds, merged_output)

    elif self.layer_config.type == LayerType.VECTOR:
        layer_dir = window.get_layer_dir(self.output_layer)
        assert isinstance(self.format, VectorFormat)
        self.format.encode_vector(layer_dir, merged_output)

    window.mark_layer_completed(self.output_layer)
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
"""Learning rate schedulers for rslearn."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from dataclasses import asdict, dataclass
|
|
5
|
+
|
|
6
|
+
from torch.optim import Optimizer
|
|
7
|
+
from torch.optim.lr_scheduler import (
|
|
8
|
+
CosineAnnealingLR,
|
|
9
|
+
CosineAnnealingWarmRestarts,
|
|
10
|
+
LRScheduler,
|
|
11
|
+
MultiStepLR,
|
|
12
|
+
ReduceLROnPlateau,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
from rslearn.log_utils import get_logger
|
|
16
|
+
|
|
17
|
+
logger = get_logger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SchedulerFactory(ABC):
    """Abstract factory that creates an LR scheduler for a given optimizer.

    Subclasses in this module are dataclasses whose fields are forwarded as
    keyword arguments to the scheduler constructor.
    """

    def get_kwargs(self) -> dict:
        """Return this factory's dataclass fields, leaving out those set to None."""
        kwargs = {}
        for name, value in asdict(self).items():  # type: ignore
            if value is None:
                continue
            kwargs[name] = value
        return kwargs

    @abstractmethod
    def build(self, optimizer: Optimizer) -> LRScheduler:
        """Build the learning rate scheduler configured by this factory class.

        The base implementation only logs the scheduler class name and its
        keyword arguments; subclasses call it before constructing and
        returning the actual scheduler.
        """
        scheduler_name = self.__class__.__name__
        logger.info(
            f"Using scheduler {scheduler_name} with kwargs {self.get_kwargs()}"
        )
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
class PlateauScheduler(SchedulerFactory):
    """Factory for torch's ReduceLROnPlateau scheduler.

    Every field defaults to None; fields left as None are not passed to
    ReduceLROnPlateau.
    """

    mode: str | None = None
    factor: float | None = None
    patience: int | None = None
    threshold: float | None = None
    threshold_mode: str | None = None
    cooldown: int | None = None
    min_lr: float | None = None
    eps: float | None = None

    def build(self, optimizer: Optimizer) -> LRScheduler:
        """Build the ReduceLROnPlateau scheduler for the given optimizer."""
        # The base class logs the configuration that is about to be used.
        super().build(optimizer)
        scheduler_kwargs = self.get_kwargs()
        return ReduceLROnPlateau(optimizer, **scheduler_kwargs)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@dataclass
class MultiStepScheduler(SchedulerFactory):
    """Multi-step learning rate scheduler (factory for torch's MultiStepLR).

    milestones is required. gamma and last_epoch are optional and are only
    passed to MultiStepLR when they are not None.
    """

    milestones: list[int]
    gamma: float | None = None
    last_epoch: int | None = None

    def build(self, optimizer: Optimizer) -> LRScheduler:
        """Build the MultiStepLR scheduler.

        Args:
            optimizer: the optimizer whose learning rate will be scheduled.

        Returns:
            a MultiStepLR constructed with the non-None fields of this factory.
        """
        # The base class logs the configuration that is about to be used.
        super().build(optimizer)
        return MultiStepLR(optimizer, **self.get_kwargs())
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@dataclass
class CosineAnnealingScheduler(SchedulerFactory):
    """Factory for torch's CosineAnnealingLR scheduler.

    T_max is required; eta_min is only passed on when it is not None.
    """

    T_max: int
    eta_min: float | None = None

    def build(self, optimizer: Optimizer) -> LRScheduler:
        """Build the CosineAnnealingLR scheduler for the given optimizer."""
        # The base class logs the configuration that is about to be used.
        super().build(optimizer)
        scheduler_kwargs = self.get_kwargs()
        return CosineAnnealingLR(optimizer, **scheduler_kwargs)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@dataclass
class CosineAnnealingWarmRestartsScheduler(SchedulerFactory):
    """Factory for torch's CosineAnnealingWarmRestarts scheduler.

    T_0 is required; T_mult and eta_min have concrete defaults here (1 and
    0.0), so they are always passed to the scheduler.
    """

    T_0: int
    T_mult: int = 1
    eta_min: float = 0.0

    def build(self, optimizer: Optimizer) -> LRScheduler:
        """Build the CosineAnnealingWarmRestarts scheduler for the given optimizer."""
        # The base class logs the configuration that is about to be used.
        super().build(optimizer)
        scheduler_kwargs = self.get_kwargs()
        return CosineAnnealingWarmRestarts(optimizer, **scheduler_kwargs)
|