rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/train/dataset.py
CHANGED
|
@@ -1,30 +1,55 @@
|
|
|
1
1
|
"""Default Dataset for rslearn."""
|
|
2
2
|
|
|
3
|
+
import hashlib
|
|
4
|
+
import json
|
|
3
5
|
import multiprocessing
|
|
4
6
|
import os
|
|
5
7
|
import random
|
|
8
|
+
import tempfile
|
|
6
9
|
import time
|
|
10
|
+
import uuid
|
|
11
|
+
from datetime import datetime
|
|
7
12
|
from typing import Any
|
|
8
13
|
|
|
9
14
|
import torch
|
|
10
15
|
import tqdm
|
|
16
|
+
from rasterio.warp import Resampling
|
|
11
17
|
|
|
12
18
|
import rslearn.train.transforms.transform
|
|
13
19
|
from rslearn.config import (
|
|
14
20
|
DType,
|
|
15
|
-
|
|
16
|
-
RasterLayerConfig,
|
|
17
|
-
VectorLayerConfig,
|
|
21
|
+
LayerConfig,
|
|
18
22
|
)
|
|
19
|
-
from rslearn.
|
|
20
|
-
from rslearn.
|
|
21
|
-
from rslearn.
|
|
23
|
+
from rslearn.data_sources.data_source import Item
|
|
24
|
+
from rslearn.dataset.dataset import Dataset
|
|
25
|
+
from rslearn.dataset.storage.file import FileWindowStorage
|
|
26
|
+
from rslearn.dataset.window import (
|
|
27
|
+
Window,
|
|
28
|
+
WindowLayerData,
|
|
29
|
+
get_layer_and_group_from_dir_name,
|
|
30
|
+
)
|
|
31
|
+
from rslearn.log_utils import get_logger
|
|
32
|
+
from rslearn.train.model_context import RasterImage
|
|
33
|
+
from rslearn.utils.feature import Feature
|
|
34
|
+
from rslearn.utils.geometry import PixelBounds, ResolutionFactor
|
|
22
35
|
from rslearn.utils.mp import star_imap_unordered
|
|
23
|
-
from rslearn.utils.raster_format import load_raster_format
|
|
24
|
-
from rslearn.utils.vector_format import load_vector_format
|
|
25
36
|
|
|
37
|
+
from .model_context import SampleMetadata
|
|
38
|
+
from .tasks import Task
|
|
26
39
|
from .transforms import Sequential
|
|
27
40
|
|
|
41
|
+
logger = get_logger(__name__)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def get_torch_dtype(dtype: DType) -> torch.dtype:
    """Map an rslearn DType onto the equivalent torch dtype.

    Args:
        dtype: the rslearn data type to translate.

    Returns:
        the matching torch dtype.

    Raises:
        ValueError: if the DType is neither INT32 nor FLOAT32.
    """
    # Only these two rslearn dtypes have a torch counterpart here.
    supported = (
        (DType.INT32, torch.int32),
        (DType.FLOAT32, torch.float32),
    )
    for rslearn_dtype, torch_dtype in supported:
        if dtype == rslearn_dtype:
            return torch_dtype
    raise ValueError(f"unable to handle {dtype} as a torch dtype")
|
|
52
|
+
|
|
28
53
|
|
|
29
54
|
class SamplerFactory:
|
|
30
55
|
"""Factory to produce a Sampler.
|
|
@@ -47,7 +72,9 @@ class SamplerFactory:
|
|
|
47
72
|
class RandomSamplerFactory(SamplerFactory):
|
|
48
73
|
"""A sampler factory for RandomSampler."""
|
|
49
74
|
|
|
50
|
-
def __init__(
|
|
75
|
+
def __init__(
|
|
76
|
+
self, replacement: bool = False, num_samples: int | None = None
|
|
77
|
+
) -> None:
|
|
51
78
|
"""Initialize a RandomSamplerFactory.
|
|
52
79
|
|
|
53
80
|
Args:
|
|
@@ -75,7 +102,9 @@ class RandomSamplerFactory(SamplerFactory):
|
|
|
75
102
|
class WeightedRandomSamplerFactory(SamplerFactory):
|
|
76
103
|
"""A sampler factory for WeightedRandomSampler."""
|
|
77
104
|
|
|
78
|
-
def __init__(
|
|
105
|
+
def __init__(
|
|
106
|
+
self, option_key: str, num_samples: int, replacement: bool = True
|
|
107
|
+
) -> None:
|
|
79
108
|
"""Initialize a WeightedRandomSamplerFactory.
|
|
80
109
|
|
|
81
110
|
Args:
|
|
@@ -97,7 +126,7 @@ class WeightedRandomSamplerFactory(SamplerFactory):
|
|
|
97
126
|
a RandomSampler
|
|
98
127
|
"""
|
|
99
128
|
weights = []
|
|
100
|
-
for window in dataset.
|
|
129
|
+
for window in dataset.get_dataset_examples():
|
|
101
130
|
weights.append(window.options[self.option_key])
|
|
102
131
|
return torch.utils.data.WeightedRandomSampler(
|
|
103
132
|
weights, self.num_samples, replacement=self.replacement
|
|
@@ -108,6 +137,10 @@ class DataInput:
|
|
|
108
137
|
"""Specification of a piece of data from a window that is needed for training.
|
|
109
138
|
|
|
110
139
|
The DataInput includes which layer(s) the data can be obtained from for each window.
|
|
140
|
+
|
|
141
|
+
Note that this class is not a dataclass because jsonargparse does not play well
|
|
142
|
+
with dataclasses without enabling specialized options which we have not validated
|
|
143
|
+
will work with the rest of our code.
|
|
111
144
|
"""
|
|
112
145
|
|
|
113
146
|
def __init__(
|
|
@@ -119,6 +152,10 @@ class DataInput:
|
|
|
119
152
|
passthrough: bool = False,
|
|
120
153
|
is_target: bool = False,
|
|
121
154
|
dtype: DType = DType.FLOAT32,
|
|
155
|
+
load_all_layers: bool = False,
|
|
156
|
+
load_all_item_groups: bool = False,
|
|
157
|
+
resolution_factor: ResolutionFactor = ResolutionFactor(),
|
|
158
|
+
resampling: Resampling = Resampling.nearest,
|
|
122
159
|
):
|
|
123
160
|
"""Initialize a new DataInput.
|
|
124
161
|
|
|
@@ -132,6 +169,21 @@ class DataInput:
|
|
|
132
169
|
is_target: whether this DataInput represents a target for the task. Targets
|
|
133
170
|
are not read during prediction phase.
|
|
134
171
|
dtype: data type to load the raster as
|
|
172
|
+
load_all_layers: whether to load all of the layers specified in the list of
|
|
173
|
+
layer names. By default, we randomly pick one layer to read. When
|
|
174
|
+
reading multiple layers, the images are stacked on the channel
|
|
175
|
+
dimension. This option will also cause the dataset to only include
|
|
176
|
+
windows where all of the layers are materialized (by default, only
|
|
177
|
+
windows with none of the layers materialized would be excluded).
|
|
178
|
+
load_all_item_groups: whether to load all item groups in the layer(s) we
|
|
179
|
+
are reading from. By default, we assume the specified layer name is of
|
|
180
|
+
the form "{layer_name}.{group_idx}" and read that item group only. With
|
|
181
|
+
this option enabled, we ignore the group_idx and read all item groups.
|
|
182
|
+
resolution_factor: controls the resolution at which raster data is loaded for training.
|
|
183
|
+
By default (factor=1), data is loaded at the window resolution.
|
|
184
|
+
E.g. for a 64x64 window at 10 m/pixel with resolution_factor=1/2,
|
|
185
|
+
the resulting tensor is 32x32 (covering the same geographic area at 20 m/pixel).
|
|
186
|
+
resampling: resampling method (default nearest neighbor).
|
|
135
187
|
"""
|
|
136
188
|
self.data_type = data_type
|
|
137
189
|
self.layers = layers
|
|
@@ -140,6 +192,241 @@ class DataInput:
|
|
|
140
192
|
self.passthrough = passthrough
|
|
141
193
|
self.is_target = is_target
|
|
142
194
|
self.dtype = dtype
|
|
195
|
+
self.load_all_layers = load_all_layers
|
|
196
|
+
self.load_all_item_groups = load_all_item_groups
|
|
197
|
+
self.resolution_factor = resolution_factor
|
|
198
|
+
self.resampling = resampling
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def read_raster_layer_for_data_input(
    window: Window,
    bounds: PixelBounds,
    layer_name: str,
    group_idx: int,
    layer_config: LayerConfig,
    data_input: DataInput,
) -> torch.Tensor:
    """Read a raster layer for a DataInput.

    This scans the band sets configured for the layer to determine which ones are
    needed to get all of the bands requested by the DataInput, then decodes each
    needed band set and copies the requested bands into a single output tensor, in
    the order given by data_input.bands.

    Args:
        window: the window to read from.
        bounds: the bounds to read.
        layer_name: the layer.
        group_idx: the item group.
        layer_config: the layer configuration.
        data_input: the DataInput that specifies the bands, dtype, resolution factor,
            and resampling method.

    Returns:
        Raster data as a tensor with one channel per requested band.

    Raises:
        ValueError: if the DataInput has no bands, if some requested band is not
            provided by any band set of the layer, or if a needed band set has no
            format configured.
    """
    needed_bands = data_input.bands
    if needed_bands is None:
        raise ValueError(f"No bands specified for {layer_name}")

    # Map each requested band to its position in the output tensor. Entries are
    # removed as they are matched, so whatever remains afterwards is missing.
    needed_band_indexes = {band: i for i, band in enumerate(needed_bands)}

    # See what different sets of bands we need to read to get all the configured
    # bands. Each entry is (band_set, indexes within that band set, indexes within
    # the output tensor).
    needed_sets_and_indexes = []
    for band_set in layer_config.band_sets:
        if band_set.bands is None:
            continue
        needed_src_indexes = []
        needed_dst_indexes = []
        for i, band in enumerate(band_set.bands):
            if band not in needed_band_indexes:
                continue
            needed_src_indexes.append(i)
            needed_dst_indexes.append(needed_band_indexes[band])
            del needed_band_indexes[band]
        if len(needed_src_indexes) == 0:
            continue
        needed_sets_and_indexes.append(
            (band_set, needed_src_indexes, needed_dst_indexes)
        )
    if len(needed_band_indexes) > 0:
        raise ValueError(
            "could not get all the needed bands from "
            + f"window {window.name} layer {layer_name} group {group_idx}"
        )

    # Get the projection and bounds to read under (multiply window resolution by
    # the specified resolution factor).
    final_projection = data_input.resolution_factor.multiply_projection(
        window.projection
    )
    final_bounds = data_input.resolution_factor.multiply_bounds(bounds)

    # NOTE(review): assumes PixelBounds is ordered (x0, y0, x1, y1), so the second
    # dimension is rows and the third is columns -- confirm.
    image = torch.zeros(
        (
            len(needed_bands),
            final_bounds[3] - final_bounds[1],
            final_bounds[2] - final_bounds[0],
        ),
        dtype=get_torch_dtype(data_input.dtype),
    )

    for band_set, src_indexes, dst_indexes in needed_sets_and_indexes:
        if band_set.format is None:
            raise ValueError(f"No format specified for {layer_name}")
        raster_format = band_set.instantiate_raster_format()
        raster_dir = window.get_raster_dir(
            layer_name, band_set.bands, group_idx=group_idx
        )

        # TODO: previously we try to read based on band_set.zoom_offset when possible,
        # and handle zooming in with torch.repeat (if resampling method is nearest
        # neighbor). However, we have not benchmarked whether this actually improves
        # data loading speed, so for simplicity, for now we let rasterio handle the
        # resampling. If it really is much faster to handle it via torch, then it may
        # make sense to bring back that functionality.

        # Honor the resampling method configured on the DataInput rather than
        # always using nearest neighbor (the DataInput default is still nearest).
        src = raster_format.decode_raster(
            raster_dir,
            final_projection,
            final_bounds,
            resampling=data_input.resampling,
        )
        image[dst_indexes, :, :] = torch.as_tensor(
            src[src_indexes, :, :].astype(data_input.dtype.get_numpy_dtype())
        )

    return image
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def read_layer_time_range(
    layer_data: WindowLayerData | None, group_idx: int
) -> tuple[datetime, datetime] | None:
    """Compute the overall time span covered by the items of one item group.

    The result is (earliest start, latest end) across every item in the group.

    Args:
        layer_data: the layer data holding the serialized item groups, or None if
            the layer has no associated layer data.
        group_idx: index of the item group to inspect.

    Returns:
        the combined time range, or None if there is no layer data, the group is
        empty, or the first item in the group has no time range.

    Raises:
        ValueError: if the first item has a time range but a later item does not.
    """
    if layer_data is None:
        return None

    group = layer_data.serialized_item_groups[group_idx]
    if not group:
        return None

    # The first item decides whether this group is treated as having time ranges.
    if Item.deserialize(group[0]).geometry.time_range is None:
        return None

    # From here on every item is required to carry a time range.
    starts: list[datetime] = []
    ends: list[datetime] = []
    for raw_item in group:
        item = Item.deserialize(raw_item)
        span = item.geometry.time_range
        if span is None:
            raise ValueError(
                f"Item '{item.name}' has no time_range, but first item does. "
                "All items in a group must consistently have or lack time_range."
            )
        starts.append(span[0])
        ends.append(span[1])

    return (min(starts), max(ends))
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def read_data_input(
    dataset: Dataset,
    window: Window,
    bounds: PixelBounds,
    data_input: DataInput,
    rng: random.Random,
) -> RasterImage | list[Feature]:
    """Read the data specified by the DataInput from the window.

    Args:
        dataset: the dataset, to get layer configs.
        window: the window to read from.
        bounds: the bounds of the patch we are reading. Only used for raster data;
            vector data is decoded under the full window bounds.
        data_input: the DataInput that specifies what layers to read.
        rng: random number generator, used to pick one layer when
            data_input.load_all_layers is not set.

    Returns:
        the raster or vector data. For raster data, the per-layer images are stacked
        along a new dimension 1 and returned in a RasterImage together with the
        per-layer time ranges (or None if the first layer has no time range). For
        vector data, the features of all layers that were read are concatenated.

    Raises:
        ValueError: if the data type is neither "raster" nor "vector", or if some
            raster layers have a time range while others do not.
        IndexError: if load_all_layers is not set and none of the candidate layers
            is completed in the window (raised by rng.choice).
    """
    # We first enumerate which layers are available.
    # If load_all_item_groups is set, we need to check each item group within the
    # layer.
    layer_options: list[tuple[str, int]] = []
    if data_input.load_all_item_groups:
        wanted_layers = set(data_input.layers)
        for layer_name, group_idx in window.list_completed_layers():
            if layer_name not in wanted_layers:
                continue
            layer_options.append((layer_name, group_idx))
    else:
        for option in data_input.layers:
            layer_name, group_idx = get_layer_and_group_from_dir_name(option)
            if not window.is_layer_completed(layer_name, group_idx):
                continue
            layer_options.append((layer_name, group_idx))

    # Now determine the layers that we should actually read.
    # We randomly pick one, unless load_all_layers is set, in which case we read all of
    # them.
    layers_to_read: list[tuple[str, int]]
    if data_input.load_all_layers:
        # We assume that the user has ensured the layers are compatible, e.g. raster
        # layers will need to have the same number of bands.
        layers_to_read = layer_options
    else:
        layers_to_read = [rng.choice(layer_options)]

    if data_input.data_type == "raster":
        # Load the layer datas once here rather than once per layer.
        layer_datas = window.load_layer_datas()
        images: list[torch.Tensor] = []
        time_ranges: list[tuple[datetime, datetime] | None] = []
        for layer_name, group_idx in layers_to_read:
            layer_config = dataset.layers[layer_name]
            image = read_raster_layer_for_data_input(
                window,
                bounds,
                layer_name,
                group_idx,
                layer_config,
                data_input,
            )
            # some layers (e.g. "label_raster") won't have associated layer datas
            layer_data = layer_datas.get(layer_name)
            time_range = read_layer_time_range(layer_data, group_idx)
            # Each time range is compared against the previous one, so a mix of
            # tuples and None anywhere in the sequence is rejected.
            if len(time_ranges) > 0:
                if type(time_ranges[-1]) is not type(time_range):
                    raise ValueError(
                        f"All time ranges should be datetime tuples or None. Got {type(time_range)} and {type(time_ranges[-1])}"
                    )
            images.append(image)
            time_ranges.append(time_range)
        return RasterImage(
            torch.stack(images, dim=1),
            time_ranges if time_ranges[0] is not None else None,  # type: ignore
        )

    elif data_input.data_type == "vector":
        # We don't really support time series for vector data currently, we just
        # concatenate the features together.
        features: list[Feature] = []
        for layer_name, group_idx in layers_to_read:
            layer_config = dataset.layers[layer_name]
            vector_format = layer_config.instantiate_vector_format()
            layer_dir = window.get_layer_dir(layer_name, group_idx=group_idx)
            # NOTE(review): this decodes under window.bounds, not the patch
            # bounds argument -- confirm that is intended.
            cur_features = vector_format.decode_vector(
                layer_dir, window.projection, window.bounds
            )
            features.extend(cur_features)

        return features

    else:
        raise ValueError(f"unknown data type {data_input.data_type}")
|
|
143
430
|
|
|
144
431
|
|
|
145
432
|
class SplitConfig:
|
|
@@ -149,15 +436,16 @@ class SplitConfig:
|
|
|
149
436
|
self,
|
|
150
437
|
groups: list[str] | None = None,
|
|
151
438
|
names: list[str] | None = None,
|
|
152
|
-
tags: dict[str,
|
|
439
|
+
tags: dict[str, Any] | None = None,
|
|
153
440
|
num_samples: int | None = None,
|
|
441
|
+
num_patches: int | None = None,
|
|
154
442
|
transforms: list[torch.nn.Module] | None = None,
|
|
155
443
|
sampler: SamplerFactory | None = None,
|
|
156
444
|
patch_size: int | tuple[int, int] | None = None,
|
|
157
445
|
overlap_ratio: float | None = None,
|
|
158
446
|
load_all_patches: bool | None = None,
|
|
159
447
|
skip_targets: bool | None = None,
|
|
160
|
-
):
|
|
448
|
+
) -> None:
|
|
161
449
|
"""Initialize a new SplitConfig.
|
|
162
450
|
|
|
163
451
|
Args:
|
|
@@ -168,6 +456,7 @@ class SplitConfig:
|
|
|
168
456
|
value. If value is empty, then only the existence of the key in the
|
|
169
457
|
window options is checked.
|
|
170
458
|
num_samples: limit this split to this many examples
|
|
459
|
+
num_patches: limit this split to this many patches
|
|
171
460
|
transforms: transforms to apply
|
|
172
461
|
sampler: SamplerFactory for this split
|
|
173
462
|
patch_size: an optional square size or (width, height) tuple. If set, read
|
|
@@ -183,15 +472,19 @@ class SplitConfig:
|
|
|
183
472
|
self.names = names
|
|
184
473
|
self.tags = tags
|
|
185
474
|
self.num_samples = num_samples
|
|
475
|
+
self.num_patches = num_patches
|
|
186
476
|
self.transforms = transforms
|
|
187
477
|
self.sampler = sampler
|
|
188
478
|
self.patch_size = patch_size
|
|
189
|
-
self.load_all_patches = load_all_patches
|
|
190
479
|
self.skip_targets = skip_targets
|
|
480
|
+
|
|
481
|
+
# Note that load_all_patches is handled by the RslearnDataModule rather than
|
|
482
|
+
# the ModelDataset.
|
|
483
|
+
self.load_all_patches = load_all_patches
|
|
191
484
|
self.overlap_ratio = overlap_ratio
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
485
|
+
|
|
486
|
+
if self.overlap_ratio is not None and not (0 < self.overlap_ratio < 1):
|
|
487
|
+
raise ValueError("overlap_ratio must be between 0 and 1 (exclusive)")
|
|
195
488
|
|
|
196
489
|
def update(self, other: "SplitConfig") -> "SplitConfig":
|
|
197
490
|
"""Override settings in this SplitConfig with those in another.
|
|
@@ -204,6 +497,7 @@ class SplitConfig:
|
|
|
204
497
|
names=self.names,
|
|
205
498
|
tags=self.tags,
|
|
206
499
|
num_samples=self.num_samples,
|
|
500
|
+
num_patches=self.num_patches,
|
|
207
501
|
transforms=self.transforms,
|
|
208
502
|
sampler=self.sampler,
|
|
209
503
|
patch_size=self.patch_size,
|
|
@@ -219,6 +513,8 @@ class SplitConfig:
|
|
|
219
513
|
result.tags = other.tags
|
|
220
514
|
if other.num_samples:
|
|
221
515
|
result.num_samples = other.num_samples
|
|
516
|
+
if other.num_patches:
|
|
517
|
+
result.num_patches = other.num_patches
|
|
222
518
|
if other.transforms:
|
|
223
519
|
result.transforms = other.transforms
|
|
224
520
|
if other.sampler:
|
|
@@ -233,6 +529,18 @@ class SplitConfig:
|
|
|
233
529
|
result.skip_targets = other.skip_targets
|
|
234
530
|
return result
|
|
235
531
|
|
|
532
|
+
def get_patch_size(self) -> tuple[int, int] | None:
|
|
533
|
+
"""Get patch size normalized to int tuple."""
|
|
534
|
+
if self.patch_size is None:
|
|
535
|
+
return None
|
|
536
|
+
if isinstance(self.patch_size, int):
|
|
537
|
+
return (self.patch_size, self.patch_size)
|
|
538
|
+
return self.patch_size
|
|
539
|
+
|
|
540
|
+
def get_overlap_ratio(self) -> float:
    """Return the configured overlap ratio, falling back to 0.0 when unset."""
    if self.overlap_ratio is None:
        return 0.0
    return self.overlap_ratio
|
|
543
|
+
|
|
236
544
|
def get_load_all_patches(self) -> bool:
|
|
237
545
|
"""Returns whether loading all patches is enabled (default False)."""
|
|
238
546
|
return True if self.load_all_patches is True else False
|
|
@@ -242,7 +550,7 @@ class SplitConfig:
|
|
|
242
550
|
return True if self.skip_targets is True else False
|
|
243
551
|
|
|
244
552
|
|
|
245
|
-
def check_window(inputs: dict[str, DataInput], window: Window) ->
|
|
553
|
+
def check_window(inputs: dict[str, DataInput], window: Window) -> Window | None:
|
|
246
554
|
"""Verify that the window has the required layers based on the specified inputs.
|
|
247
555
|
|
|
248
556
|
Args:
|
|
@@ -254,17 +562,25 @@ def check_window(inputs: dict[str, DataInput], window: Window) -> bool:
|
|
|
254
562
|
"""
|
|
255
563
|
|
|
256
564
|
# Make sure window has all the needed layers.
|
|
257
|
-
def
|
|
565
|
+
def is_available(data_input: DataInput) -> bool:
|
|
566
|
+
# If load_all_layers is enabled, we should check that all the layers are
|
|
567
|
+
# present. Otherwise, we just need one layer.
|
|
568
|
+
is_any_layer_available = False
|
|
569
|
+
are_all_layers_available = True
|
|
258
570
|
for layer_name in data_input.layers:
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
571
|
+
if window.is_layer_completed(layer_name):
|
|
572
|
+
is_any_layer_available = True
|
|
573
|
+
else:
|
|
574
|
+
are_all_layers_available = False
|
|
575
|
+
if data_input.load_all_layers:
|
|
576
|
+
return are_all_layers_available
|
|
577
|
+
else:
|
|
578
|
+
return is_any_layer_available
|
|
263
579
|
|
|
264
580
|
for data_input in inputs.values():
|
|
265
581
|
if not data_input.required:
|
|
266
582
|
continue
|
|
267
|
-
if not
|
|
583
|
+
if not is_available(data_input):
|
|
268
584
|
logger.debug(
|
|
269
585
|
"Skipping window %s since check for layers %s failed",
|
|
270
586
|
window.name,
|
|
@@ -285,7 +601,9 @@ class ModelDataset(torch.utils.data.Dataset):
|
|
|
285
601
|
inputs: dict[str, DataInput],
|
|
286
602
|
task: Task,
|
|
287
603
|
workers: int,
|
|
288
|
-
|
|
604
|
+
name: str | None = None,
|
|
605
|
+
fix_patch_pick: bool = False,
|
|
606
|
+
) -> None:
|
|
289
607
|
"""Instantiate a new ModelDataset.
|
|
290
608
|
|
|
291
609
|
Args:
|
|
@@ -294,50 +612,30 @@ class ModelDataset(torch.utils.data.Dataset):
|
|
|
294
612
|
inputs: data to read from the dataset for training
|
|
295
613
|
task: the task to train on
|
|
296
614
|
workers: number of workers to use for initializing the dataset
|
|
615
|
+
name: name of the dataset (default: None)
|
|
616
|
+
fix_patch_pick: if True, fix the patch pick to be the same every time
|
|
617
|
+
for a given window. Useful for testing (default: False)
|
|
297
618
|
"""
|
|
298
619
|
self.dataset = dataset
|
|
299
620
|
self.split_config = split_config
|
|
300
621
|
self.inputs = inputs
|
|
301
622
|
self.task = task
|
|
302
|
-
|
|
623
|
+
self.name = name
|
|
624
|
+
self.fix_patch_pick = fix_patch_pick
|
|
303
625
|
if split_config.transforms:
|
|
304
626
|
self.transforms = Sequential(*split_config.transforms)
|
|
305
627
|
else:
|
|
306
628
|
self.transforms = rslearn.train.transforms.transform.Identity()
|
|
307
629
|
|
|
308
|
-
#
|
|
309
|
-
if
|
|
630
|
+
# Get normalized patch size from the SplitConfig.
|
|
631
|
+
# But if load all patches is enabled, this is handled by AllPatchesDataset, so
|
|
632
|
+
# here we instead load the entire windows.
|
|
633
|
+
if split_config.get_load_all_patches():
|
|
310
634
|
self.patch_size = None
|
|
311
|
-
elif isinstance(split_config.patch_size, int):
|
|
312
|
-
self.patch_size = (split_config.patch_size, split_config.patch_size)
|
|
313
635
|
else:
|
|
314
|
-
self.patch_size = split_config.
|
|
636
|
+
self.patch_size = split_config.get_patch_size()
|
|
315
637
|
|
|
316
|
-
|
|
317
|
-
windows = self.dataset.load_windows(
|
|
318
|
-
groups=split_config.groups,
|
|
319
|
-
names=split_config.names,
|
|
320
|
-
show_progress=True,
|
|
321
|
-
workers=workers,
|
|
322
|
-
)
|
|
323
|
-
elif split_config.groups:
|
|
324
|
-
windows = self.dataset.load_windows(
|
|
325
|
-
groups=split_config.groups, show_progress=True, workers=workers
|
|
326
|
-
)
|
|
327
|
-
else:
|
|
328
|
-
windows = self.dataset.load_windows(show_progress=True, workers=workers)
|
|
329
|
-
|
|
330
|
-
if split_config.tags:
|
|
331
|
-
# Filter the window.options.
|
|
332
|
-
new_windows = []
|
|
333
|
-
for window in windows:
|
|
334
|
-
for k, v in split_config.tags.items():
|
|
335
|
-
if k not in window.options:
|
|
336
|
-
continue
|
|
337
|
-
if v and window.options[k] != v:
|
|
338
|
-
continue
|
|
339
|
-
new_windows.append(window)
|
|
340
|
-
windows = new_windows
|
|
638
|
+
windows = self._get_initial_windows(split_config, workers)
|
|
341
639
|
|
|
342
640
|
# If targets are not needed, remove them from the inputs.
|
|
343
641
|
if split_config.get_skip_targets():
|
|
@@ -347,98 +645,178 @@ class ModelDataset(torch.utils.data.Dataset):
|
|
|
347
645
|
|
|
348
646
|
# Eliminate windows that are missing either a requisite input layer, or missing
|
|
349
647
|
# all target layers.
|
|
350
|
-
p = multiprocessing.Pool(workers)
|
|
351
|
-
outputs = star_imap_unordered(
|
|
352
|
-
p,
|
|
353
|
-
check_window,
|
|
354
|
-
[
|
|
355
|
-
dict(
|
|
356
|
-
inputs=self.inputs,
|
|
357
|
-
window=window,
|
|
358
|
-
)
|
|
359
|
-
for window in windows
|
|
360
|
-
],
|
|
361
|
-
)
|
|
362
648
|
new_windows = []
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
649
|
+
if workers == 0:
|
|
650
|
+
for window in windows:
|
|
651
|
+
if check_window(self.inputs, window) is None:
|
|
652
|
+
continue
|
|
653
|
+
new_windows.append(window)
|
|
654
|
+
else:
|
|
655
|
+
p = multiprocessing.Pool(workers)
|
|
656
|
+
outputs = star_imap_unordered(
|
|
657
|
+
p,
|
|
658
|
+
check_window,
|
|
659
|
+
[
|
|
660
|
+
dict(
|
|
661
|
+
inputs=self.inputs,
|
|
662
|
+
window=window,
|
|
663
|
+
)
|
|
664
|
+
for window in windows
|
|
665
|
+
],
|
|
666
|
+
)
|
|
667
|
+
for window in tqdm.tqdm(
|
|
668
|
+
outputs, total=len(windows), desc="Checking available layers in windows"
|
|
669
|
+
):
|
|
670
|
+
if window is None:
|
|
671
|
+
continue
|
|
672
|
+
new_windows.append(window)
|
|
673
|
+
p.close()
|
|
370
674
|
windows = new_windows
|
|
371
675
|
|
|
676
|
+
# Sort the windows to ensure that the dataset is consistent across GPUs.
|
|
677
|
+
# Inconsistent ordering can lead to a subset of windows being processed during
|
|
678
|
+
# "model test" / "model predict" when using multiple GPUs.
|
|
679
|
+
# We use a hash so that functionality like num_samples limit gets a random
|
|
680
|
+
# subset of windows (with respect to the hash function choice).
|
|
681
|
+
windows.sort(
|
|
682
|
+
key=lambda window: hashlib.sha256(window.name.encode()).hexdigest()
|
|
683
|
+
)
|
|
684
|
+
|
|
372
685
|
# Limit windows to num_samples if requested.
|
|
373
686
|
if split_config.num_samples:
|
|
374
|
-
#
|
|
687
|
+
# The windows are sorted by hash of window name so this distribution should
|
|
688
|
+
# be representative of the population.
|
|
375
689
|
windows = windows[0 : split_config.num_samples]
|
|
376
690
|
|
|
377
|
-
|
|
691
|
+
# Write dataset_examples to a file so that we can load it lazily in the worker
|
|
692
|
+
# processes. Otherwise it takes a long time to transmit it when spawning each
|
|
693
|
+
# process.
|
|
694
|
+
self.dataset_examples_fname = os.path.join(
|
|
695
|
+
tempfile.gettempdir(),
|
|
696
|
+
"rslearn_dataset_examples",
|
|
697
|
+
f"{os.getpid()}_{uuid.uuid4()}.json",
|
|
698
|
+
)
|
|
699
|
+
self.num_dataset_examples = len(windows)
|
|
700
|
+
self.dataset_examples: list[Window] | None = None
|
|
701
|
+
logger.info(
|
|
702
|
+
f"Writing {len(windows)} dataset examples to {self.dataset_examples_fname}"
|
|
703
|
+
)
|
|
704
|
+
os.makedirs(os.path.dirname(self.dataset_examples_fname), exist_ok=True)
|
|
705
|
+
with open(self.dataset_examples_fname, "w") as f:
|
|
706
|
+
json.dump([self._serialize_item(example) for example in windows], f)
|
|
378
707
|
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
708
|
+
def _get_initial_windows(
|
|
709
|
+
self, split_config: SplitConfig, workers: int
|
|
710
|
+
) -> list[Window]:
|
|
711
|
+
"""Get the initial windows before input layer filtering.
|
|
712
|
+
|
|
713
|
+
The windows are filtered based on configured window names, groups, and tags.
|
|
714
|
+
|
|
715
|
+
This is a helper for the init function.
|
|
716
|
+
|
|
717
|
+
Args:
|
|
718
|
+
split_config: the split configuration.
|
|
719
|
+
workers: number of worker processes.
|
|
720
|
+
|
|
721
|
+
Returns:
|
|
722
|
+
list of windows from the dataset after applying the aforementioned filters.
|
|
723
|
+
"""
|
|
724
|
+
# Load windows from dataset.
|
|
725
|
+
# If the window storage is FileWindowStorage, we pass the workers/show_progress arguments.
|
|
726
|
+
kwargs: dict[str, Any] = {}
|
|
727
|
+
if isinstance(self.dataset.storage, FileWindowStorage):
|
|
728
|
+
kwargs["workers"] = workers
|
|
729
|
+
kwargs["show_progress"] = True
|
|
730
|
+
# We also add the name/group filters to the kwargs.
|
|
731
|
+
if split_config.names:
|
|
732
|
+
kwargs["names"] = split_config.names
|
|
733
|
+
if split_config.groups:
|
|
734
|
+
kwargs["groups"] = split_config.groups
|
|
735
|
+
|
|
736
|
+
windows = self.dataset.load_windows(**kwargs)
|
|
737
|
+
|
|
738
|
+
# Filter by tags (if provided) using the window.options.
|
|
739
|
+
if split_config.tags:
|
|
740
|
+
new_windows = []
|
|
741
|
+
num_removed: dict[str, int] = {}
|
|
742
|
+
for window in windows:
|
|
743
|
+
for k, v in split_config.tags.items():
|
|
744
|
+
if k not in window.options or (v and window.options[k] != v):
|
|
745
|
+
num_removed[k] = num_removed.get(k, 0) + 1
|
|
746
|
+
break
|
|
747
|
+
else:
|
|
748
|
+
new_windows.append(window)
|
|
749
|
+
logger.info(
|
|
750
|
+
f"Started with {len(windows)} windows, ended with {len(new_windows)} windows for {self.dataset.path}"
|
|
751
|
+
)
|
|
752
|
+
for k, v in num_removed.items():
|
|
753
|
+
logger.info(f"Removed {v} windows due to tag {k}")
|
|
754
|
+
windows = new_windows
|
|
755
|
+
|
|
756
|
+
return windows
|
|
757
|
+
|
|
758
|
+
def _serialize_item(self, example: Window) -> dict[str, Any]:
|
|
759
|
+
return example.get_metadata()
|
|
760
|
+
|
|
761
|
+
def _deserialize_item(self, d: dict[str, Any]) -> Window:
|
|
762
|
+
return Window.from_metadata(
|
|
763
|
+
self.dataset.storage,
|
|
764
|
+
d,
|
|
765
|
+
)
|
|
766
|
+
|
|
767
|
+
def get_dataset_examples(self) -> list[Window]:
|
|
768
|
+
"""Get a list of examples in the dataset.
|
|
769
|
+
|
|
770
|
+
If load_all_patches is False, this is a list of Windows. Otherwise, this is a
|
|
771
|
+
list of (window, patch_bounds, (patch_idx, # patches)) tuples.
|
|
772
|
+
"""
|
|
773
|
+
if self.dataset_examples is None:
|
|
774
|
+
logger.debug(
|
|
775
|
+
f"Loading dataset examples from {self.dataset_examples_fname} in process {os.getpid()}"
|
|
386
776
|
)
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
):
|
|
394
|
-
for row in range(
|
|
395
|
-
window.bounds[1],
|
|
396
|
-
window.bounds[3],
|
|
397
|
-
self.patch_size[1] - overlap_size,
|
|
398
|
-
):
|
|
399
|
-
cur_patches.append(
|
|
400
|
-
(
|
|
401
|
-
col,
|
|
402
|
-
row,
|
|
403
|
-
col + self.patch_size[0],
|
|
404
|
-
row + self.patch_size[1],
|
|
405
|
-
)
|
|
406
|
-
)
|
|
407
|
-
for i, patch_bounds in enumerate(cur_patches):
|
|
408
|
-
patches.append((window, patch_bounds, (i, len(cur_patches))))
|
|
409
|
-
self.windows = patches
|
|
777
|
+
with open(self.dataset_examples_fname) as f:
|
|
778
|
+
self.dataset_examples = [
|
|
779
|
+
self._deserialize_item(d) for d in json.load(f)
|
|
780
|
+
]
|
|
781
|
+
logger.debug(f"Finished loading dataset examples in process {os.getpid()}")
|
|
782
|
+
return self.dataset_examples
|
|
410
783
|
|
|
411
784
|
def __len__(self) -> int:
|
|
412
785
|
"""Returns the dataset length."""
|
|
413
|
-
return
|
|
786
|
+
return self.num_dataset_examples
|
|
414
787
|
|
|
415
|
-
def
|
|
416
|
-
|
|
788
|
+
def get_raw_inputs(
|
|
789
|
+
self, idx: int
|
|
790
|
+
) -> tuple[dict[str, Any], dict[str, Any], SampleMetadata]:
|
|
791
|
+
"""Get the raw inputs and base metadata for this example.
|
|
792
|
+
|
|
793
|
+
This is the raster or vector data before being processed by the Task. So it
|
|
794
|
+
should be a Tensor for raster and list[Feature] for vector.
|
|
417
795
|
|
|
418
796
|
Args:
|
|
419
797
|
idx: the index in the dataset.
|
|
420
798
|
|
|
421
799
|
Returns:
|
|
422
|
-
a tuple (
|
|
800
|
+
a tuple (raw_inputs, passthrough_inputs, metadata).
|
|
423
801
|
"""
|
|
424
|
-
|
|
425
|
-
|
|
802
|
+
dataset_examples = self.get_dataset_examples()
|
|
803
|
+
example = dataset_examples[idx]
|
|
804
|
+
rng = random.Random(idx if self.fix_patch_pick else None)
|
|
426
805
|
|
|
427
806
|
# Select bounds to read.
|
|
428
|
-
if self.
|
|
429
|
-
window
|
|
430
|
-
elif self.patch_size:
|
|
807
|
+
if self.patch_size:
|
|
808
|
+
window = example
|
|
431
809
|
|
|
432
|
-
def get_patch_range(n_patch, n_window):
|
|
810
|
+
def get_patch_range(n_patch: int, n_window: int) -> list[int]:
|
|
433
811
|
if n_patch > n_window:
|
|
434
812
|
# Select arbitrary range containing the entire window.
|
|
435
813
|
# Basically arbitrarily padding the window to get to patch size.
|
|
436
|
-
start =
|
|
814
|
+
start = rng.randint(n_window - n_patch, 0)
|
|
437
815
|
return [start, start + n_patch]
|
|
438
816
|
|
|
439
817
|
else:
|
|
440
818
|
# Select arbitrary patch within the window.
|
|
441
|
-
start =
|
|
819
|
+
start = rng.randint(0, n_window - n_patch)
|
|
442
820
|
return [start, start + n_patch]
|
|
443
821
|
|
|
444
822
|
window_size = (
|
|
@@ -449,128 +827,56 @@ class ModelDataset(torch.utils.data.Dataset):
|
|
|
449
827
|
get_patch_range(self.patch_size[0], window_size[0]),
|
|
450
828
|
get_patch_range(self.patch_size[1], window_size[1]),
|
|
451
829
|
]
|
|
452
|
-
bounds =
|
|
830
|
+
bounds = (
|
|
453
831
|
window.bounds[0] + patch_ranges[0][0],
|
|
454
832
|
window.bounds[1] + patch_ranges[1][0],
|
|
455
833
|
window.bounds[0] + patch_ranges[0][1],
|
|
456
834
|
window.bounds[1] + patch_ranges[1][1],
|
|
457
|
-
|
|
835
|
+
)
|
|
836
|
+
|
|
458
837
|
else:
|
|
838
|
+
window = example
|
|
459
839
|
bounds = window.bounds
|
|
460
840
|
|
|
461
|
-
|
|
462
|
-
def read_input(data_input: DataInput):
|
|
463
|
-
# First enumerate all options of individual layers to read.
|
|
464
|
-
layer_options = []
|
|
465
|
-
for layer_name in data_input.layers:
|
|
466
|
-
completed_fname = window.path / "layers" / layer_name / "completed"
|
|
467
|
-
if not completed_fname.exists():
|
|
468
|
-
continue
|
|
469
|
-
layer_options.append(layer_name)
|
|
470
|
-
|
|
471
|
-
# For now we just randomly pick one option.
|
|
472
|
-
# In the future we need to support different configuration for how to pick
|
|
473
|
-
# the options, as well as picking multiple for series inputs.
|
|
474
|
-
layer = random.choice(layer_options)
|
|
475
|
-
layer_dir = window.path / "layers" / layer
|
|
476
|
-
layer_config = self.dataset.layers[layer]
|
|
477
|
-
|
|
478
|
-
if data_input.data_type == "raster":
|
|
479
|
-
assert isinstance(layer_config, RasterLayerConfig)
|
|
480
|
-
|
|
481
|
-
# See what different sets of bands we need to read to get all the
|
|
482
|
-
# configured bands.
|
|
483
|
-
needed_bands = data_input.bands
|
|
484
|
-
needed_band_indexes = {}
|
|
485
|
-
for i, band in enumerate(needed_bands):
|
|
486
|
-
needed_band_indexes[band] = i
|
|
487
|
-
needed_sets_and_indexes = []
|
|
488
|
-
for band_set in layer_config.band_sets:
|
|
489
|
-
needed_src_indexes = []
|
|
490
|
-
needed_dst_indexes = []
|
|
491
|
-
for i, band in enumerate(band_set.bands):
|
|
492
|
-
if band not in needed_band_indexes:
|
|
493
|
-
continue
|
|
494
|
-
needed_src_indexes.append(i)
|
|
495
|
-
needed_dst_indexes.append(needed_band_indexes[band])
|
|
496
|
-
del needed_band_indexes[band]
|
|
497
|
-
if len(needed_src_indexes) == 0:
|
|
498
|
-
continue
|
|
499
|
-
needed_sets_and_indexes.append(
|
|
500
|
-
(band_set, needed_src_indexes, needed_dst_indexes)
|
|
501
|
-
)
|
|
502
|
-
if len(needed_band_indexes) > 0:
|
|
503
|
-
raise Exception(
|
|
504
|
-
"could not get all the needed bands from "
|
|
505
|
-
+ f"window {window.name} layer {layer}"
|
|
506
|
-
)
|
|
507
|
-
|
|
508
|
-
image = torch.zeros(
|
|
509
|
-
(len(needed_bands), bounds[3] - bounds[1], bounds[2] - bounds[0]),
|
|
510
|
-
dtype=data_input.dtype.get_torch_dtype(),
|
|
511
|
-
)
|
|
512
|
-
|
|
513
|
-
for band_set, src_indexes, dst_indexes in needed_sets_and_indexes:
|
|
514
|
-
_, final_bounds = band_set.get_final_projection_and_bounds(
|
|
515
|
-
window.projection, bounds
|
|
516
|
-
)
|
|
517
|
-
raster_format = load_raster_format(
|
|
518
|
-
RasterFormatConfig(band_set.format["name"], band_set.format)
|
|
519
|
-
)
|
|
520
|
-
cur_path = layer_dir / "_".join(band_set.bands)
|
|
521
|
-
src = raster_format.decode_raster(cur_path, final_bounds)
|
|
522
|
-
|
|
523
|
-
# Resize to patch size if needed.
|
|
524
|
-
# This is for band sets that are stored at a lower resolution.
|
|
525
|
-
# Here we assume that it is a multiple.
|
|
526
|
-
if src.shape[1:3] != image.shape[1:3]:
|
|
527
|
-
if src.shape[1] < image.shape[1]:
|
|
528
|
-
factor = image.shape[1] // src.shape[1]
|
|
529
|
-
src = src.repeat(repeats=factor, axis=1).repeat(
|
|
530
|
-
repeats=factor, axis=2
|
|
531
|
-
)
|
|
532
|
-
else:
|
|
533
|
-
factor = src.shape[1] // image.shape[1]
|
|
534
|
-
src = src[:, ::factor, ::factor]
|
|
535
|
-
|
|
536
|
-
image[dst_indexes, :, :] = torch.as_tensor(
|
|
537
|
-
src[src_indexes, :, :].astype(
|
|
538
|
-
data_input.dtype.get_numpy_dtype()
|
|
539
|
-
)
|
|
540
|
-
)
|
|
541
|
-
|
|
542
|
-
return image
|
|
543
|
-
|
|
544
|
-
elif data_input.data_type == "vector":
|
|
545
|
-
assert isinstance(layer_config, VectorLayerConfig)
|
|
546
|
-
vector_format = load_vector_format(layer_config.format)
|
|
547
|
-
features = vector_format.decode_vector(layer_dir, bounds)
|
|
548
|
-
return features
|
|
549
|
-
|
|
550
|
-
else:
|
|
551
|
-
raise Exception(f"unknown data type {data_input.data_type}")
|
|
841
|
+
assert isinstance(window, Window)
|
|
552
842
|
|
|
553
843
|
raw_inputs = {}
|
|
554
844
|
passthrough_inputs = {}
|
|
555
845
|
for name, data_input in self.inputs.items():
|
|
556
|
-
raw_inputs[name] =
|
|
846
|
+
raw_inputs[name] = read_data_input(
|
|
847
|
+
self.dataset, window, bounds, data_input, rng
|
|
848
|
+
)
|
|
557
849
|
if data_input.passthrough:
|
|
558
850
|
passthrough_inputs[name] = raw_inputs[name]
|
|
559
851
|
|
|
560
|
-
metadata =
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
852
|
+
metadata = SampleMetadata(
|
|
853
|
+
window_group=window.group,
|
|
854
|
+
window_name=window.name,
|
|
855
|
+
window_bounds=window.bounds,
|
|
856
|
+
patch_bounds=bounds,
|
|
857
|
+
patch_idx=0,
|
|
858
|
+
num_patches_in_window=1,
|
|
859
|
+
time_range=window.time_range,
|
|
860
|
+
projection=window.projection,
|
|
861
|
+
dataset_source=self.name,
|
|
862
|
+
)
|
|
863
|
+
|
|
864
|
+
return raw_inputs, passthrough_inputs, metadata
|
|
865
|
+
|
|
866
|
+
def __getitem__(
|
|
867
|
+
self, idx: int
|
|
868
|
+
) -> tuple[dict[str, Any], dict[str, Any], SampleMetadata]:
|
|
869
|
+
"""Read one training example.
|
|
870
|
+
|
|
871
|
+
Args:
|
|
872
|
+
idx: the index in the dataset.
|
|
873
|
+
|
|
874
|
+
Returns:
|
|
875
|
+
a tuple (input_dict, target_dict, metadata)
|
|
876
|
+
"""
|
|
877
|
+
logger.debug("__getitem__ start pid=%d item_idx=%d", os.getpid(), idx)
|
|
878
|
+
|
|
879
|
+
raw_inputs, passthrough_inputs, metadata = self.get_raw_inputs(idx)
|
|
574
880
|
|
|
575
881
|
input_dict, target_dict = self.task.process_inputs(
|
|
576
882
|
raw_inputs,
|
|
@@ -584,17 +890,21 @@ class ModelDataset(torch.utils.data.Dataset):
|
|
|
584
890
|
|
|
585
891
|
return input_dict, target_dict, metadata
|
|
586
892
|
|
|
587
|
-
def
|
|
588
|
-
"""
|
|
589
|
-
|
|
893
|
+
def set_name(self, name: str) -> None:
|
|
894
|
+
"""Set the name of the dataset.
|
|
895
|
+
|
|
896
|
+
Args:
|
|
897
|
+
name: the name to set.
|
|
898
|
+
"""
|
|
899
|
+
self.name = name
|
|
590
900
|
|
|
591
901
|
|
|
592
902
|
class RetryDataset(torch.utils.data.Dataset):
|
|
593
903
|
"""A dataset wrapper that retries getitem upon encountering error."""
|
|
594
904
|
|
|
595
905
|
def __init__(
|
|
596
|
-
self, dataset:
|
|
597
|
-
):
|
|
906
|
+
self, dataset: ModelDataset, retries: int = 3, delay: float = 5
|
|
907
|
+
) -> None:
|
|
598
908
|
"""Create a new RetryDataset.
|
|
599
909
|
|
|
600
910
|
Args:
|
|
@@ -606,7 +916,15 @@ class RetryDataset(torch.utils.data.Dataset):
|
|
|
606
916
|
self.retries = retries
|
|
607
917
|
self.delay = delay
|
|
608
918
|
|
|
609
|
-
def
|
|
919
|
+
def set_name(self, name: str) -> None:
|
|
920
|
+
"""Set the name of the dataset.
|
|
921
|
+
|
|
922
|
+
Args:
|
|
923
|
+
name: the name to set.
|
|
924
|
+
"""
|
|
925
|
+
self.dataset.set_name(name)
|
|
926
|
+
|
|
927
|
+
def __len__(self) -> int:
|
|
610
928
|
"""Return length of the dataset."""
|
|
611
929
|
return len(self.dataset)
|
|
612
930
|
|
|
@@ -632,6 +950,41 @@ class RetryDataset(torch.utils.data.Dataset):
|
|
|
632
950
|
# One last try -- but don't catch any more errors.
|
|
633
951
|
return self.dataset[idx]
|
|
634
952
|
|
|
635
|
-
def
|
|
953
|
+
def get_dataset_examples(self) -> list[Window]:
|
|
636
954
|
"""Returns a list of windows in this dataset."""
|
|
637
|
-
return self.dataset.
|
|
955
|
+
return self.dataset.get_dataset_examples()
|
|
956
|
+
|
|
957
|
+
|
|
958
|
+
class MultiDataset(torch.utils.data.Dataset):
|
|
959
|
+
"""A dataset that combines multiple datasets."""
|
|
960
|
+
|
|
961
|
+
def __init__(self, datasets: dict[str, RetryDataset]) -> None:
|
|
962
|
+
"""Create a new MultiDataset.
|
|
963
|
+
|
|
964
|
+
Args:
|
|
965
|
+
datasets: map of dataset name to dataset.
|
|
966
|
+
"""
|
|
967
|
+
self.datasets = datasets
|
|
968
|
+
self.buckets = {}
|
|
969
|
+
curr_offset = 0
|
|
970
|
+
for name, ds in datasets.items():
|
|
971
|
+
self.buckets[name] = range(curr_offset, curr_offset + len(ds))
|
|
972
|
+
curr_offset += len(ds)
|
|
973
|
+
|
|
974
|
+
def __len__(self) -> int:
|
|
975
|
+
"""Return length of the dataset."""
|
|
976
|
+
return sum(len(ds) for ds in self.datasets.values())
|
|
977
|
+
|
|
978
|
+
def __getitem__(self, idx: int) -> Any:
|
|
979
|
+
"""Get item from the dataset.
|
|
980
|
+
|
|
981
|
+
Args:
|
|
982
|
+
idx: the item index.
|
|
983
|
+
|
|
984
|
+
Returns:
|
|
985
|
+
the item data.
|
|
986
|
+
"""
|
|
987
|
+
for name, bucket in self.buckets.items():
|
|
988
|
+
if idx in bucket:
|
|
989
|
+
return self.datasets[name][idx - bucket.start]
|
|
990
|
+
raise IndexError(f"Index {idx} out of range (len={len(self)})")
|