rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/dataset/materialize.py
CHANGED
|
@@ -4,37 +4,32 @@ from typing import Any
|
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
import numpy.typing as npt
|
|
7
|
-
from
|
|
8
|
-
from upath import UPath
|
|
7
|
+
from rasterio.enums import Resampling
|
|
9
8
|
|
|
10
9
|
from rslearn.config import (
|
|
10
|
+
BandSetConfig,
|
|
11
|
+
CompositingMethod,
|
|
11
12
|
LayerConfig,
|
|
12
|
-
RasterFormatConfig,
|
|
13
|
-
RasterLayerConfig,
|
|
14
|
-
VectorLayerConfig,
|
|
15
13
|
)
|
|
16
|
-
from rslearn.data_sources import
|
|
17
|
-
from rslearn.tile_stores import
|
|
18
|
-
from rslearn.utils import Feature
|
|
19
|
-
from rslearn.utils.
|
|
20
|
-
from rslearn.utils.vector_format import load_vector_format
|
|
14
|
+
from rslearn.data_sources.data_source import ItemType
|
|
15
|
+
from rslearn.tile_stores import TileStoreWithLayer
|
|
16
|
+
from rslearn.utils.feature import Feature
|
|
17
|
+
from rslearn.utils.geometry import PixelBounds, Projection
|
|
21
18
|
|
|
22
19
|
from .remap import Remapper, load_remapper
|
|
23
20
|
from .window import Window
|
|
24
21
|
|
|
25
|
-
Materializers = ClassRegistry()
|
|
26
|
-
|
|
27
22
|
|
|
28
23
|
class Materializer:
|
|
29
24
|
"""An abstract class that materializes data from a tile store."""
|
|
30
25
|
|
|
31
26
|
def materialize(
|
|
32
27
|
self,
|
|
33
|
-
tile_store:
|
|
28
|
+
tile_store: TileStoreWithLayer,
|
|
34
29
|
window: Window,
|
|
35
30
|
layer_name: str,
|
|
36
31
|
layer_cfg: LayerConfig,
|
|
37
|
-
item_groups: list[list[
|
|
32
|
+
item_groups: list[list[ItemType]],
|
|
38
33
|
) -> None:
|
|
39
34
|
"""Materialize portions of items corresponding to this window into the dataset.
|
|
40
35
|
|
|
@@ -50,11 +45,16 @@ class Materializer:
|
|
|
50
45
|
|
|
51
46
|
def read_raster_window_from_tiles(
|
|
52
47
|
dst: npt.NDArray[Any],
|
|
53
|
-
|
|
48
|
+
tile_store: TileStoreWithLayer,
|
|
49
|
+
item_name: str,
|
|
50
|
+
bands: list[str],
|
|
51
|
+
projection: Projection,
|
|
54
52
|
bounds: PixelBounds,
|
|
55
53
|
src_indexes: list[int],
|
|
56
54
|
dst_indexes: list[int],
|
|
55
|
+
nodata_vals: list[float],
|
|
57
56
|
remapper: Remapper | None = None,
|
|
57
|
+
resampling: Resampling = Resampling.bilinear,
|
|
58
58
|
) -> None:
|
|
59
59
|
"""Read a window of raster data from tiles in a tile store.
|
|
60
60
|
|
|
@@ -62,13 +62,22 @@ def read_raster_window_from_tiles(
|
|
|
62
62
|
|
|
63
63
|
Args:
|
|
64
64
|
dst: the destination numpy array
|
|
65
|
-
|
|
66
|
-
|
|
65
|
+
tile_store: the TileStore to read from.
|
|
66
|
+
item_name: the item name.
|
|
67
|
+
bands: the bands that identify the raster we want to read.
|
|
68
|
+
projection: the projection of the dst array.
|
|
69
|
+
bounds: the bounds of the dst array.
|
|
67
70
|
src_indexes: the source band indexes to use
|
|
68
71
|
dst_indexes: corresponding destination band indexes for each source band index
|
|
72
|
+
nodata_vals: the nodata values for each band, to determine which parts of dst
|
|
73
|
+
should be overwritten.
|
|
69
74
|
remapper: optional remapper to apply on the source pixel values
|
|
75
|
+
resampling: how to resample the pixels in case re-projection is needed.
|
|
70
76
|
"""
|
|
71
|
-
|
|
77
|
+
# Only read the portion of the raster that overlaps with dst.
|
|
78
|
+
# This way we can avoid creating big arrays that are all empty which speeds things
|
|
79
|
+
# up for large windows.
|
|
80
|
+
src_bounds = tile_store.get_raster_bounds(item_name, bands, projection)
|
|
72
81
|
intersection = (
|
|
73
82
|
max(bounds[0], src_bounds[0]),
|
|
74
83
|
max(bounds[1], src_bounds[1]),
|
|
@@ -81,7 +90,9 @@ def read_raster_window_from_tiles(
|
|
|
81
90
|
dst_col_offset = intersection[0] - bounds[0]
|
|
82
91
|
dst_row_offset = intersection[1] - bounds[1]
|
|
83
92
|
|
|
84
|
-
src =
|
|
93
|
+
src = tile_store.read_raster(
|
|
94
|
+
item_name, bands, projection, intersection, resampling=resampling
|
|
95
|
+
)
|
|
85
96
|
src = src[src_indexes, :, :]
|
|
86
97
|
if remapper:
|
|
87
98
|
src = remapper(src, dst.dtype)
|
|
@@ -91,45 +102,403 @@ def read_raster_window_from_tiles(
|
|
|
91
102
|
dst_row_offset : dst_row_offset + src.shape[1],
|
|
92
103
|
dst_col_offset : dst_col_offset + src.shape[2],
|
|
93
104
|
]
|
|
94
|
-
|
|
105
|
+
|
|
106
|
+
# Create mask indicating where dst has no data (based on nodata_vals).
|
|
107
|
+
# We overwrite dst at pixels where all the bands are nodata.
|
|
108
|
+
nodata_vals_arr = np.array(nodata_vals)[:, None, None]
|
|
109
|
+
mask = (dst_crop[dst_indexes, :, :] == nodata_vals_arr).min(axis=0)
|
|
110
|
+
|
|
95
111
|
for src_index, dst_index in enumerate(dst_indexes):
|
|
96
112
|
dst_crop[dst_index, mask] = src[src_index, mask]
|
|
97
113
|
|
|
98
114
|
|
|
99
|
-
|
|
115
|
+
def get_needed_band_sets_and_indexes(
|
|
116
|
+
item: ItemType,
|
|
117
|
+
bands: list[str],
|
|
118
|
+
tile_store: TileStoreWithLayer,
|
|
119
|
+
) -> list[tuple[list[str], list[int], list[int]]]:
|
|
120
|
+
"""Identify indexes of required bands in tile store.
|
|
121
|
+
|
|
122
|
+
Returns:
|
|
123
|
+
A list containing, for each tile-store layer that contains at least
|
|
124
|
+
one requested band, a tuple: (src_bands, src_idx, dst_idx) where
|
|
125
|
+
- src_bands: the full band list for that layer,
|
|
126
|
+
- src_idx: indexes into src_bands of the bands that were requested,
|
|
127
|
+
- dst_idx: corresponding indexes in the requested `bands` list.
|
|
128
|
+
"""
|
|
129
|
+
# Identify which tile store layer(s) to read to get the configured bands.
|
|
130
|
+
wanted_band_indexes = {}
|
|
131
|
+
for i, band in enumerate(bands):
|
|
132
|
+
wanted_band_indexes[band] = i
|
|
133
|
+
|
|
134
|
+
available_bands = tile_store.get_raster_bands(item.name)
|
|
135
|
+
needed_band_sets_and_indexes = []
|
|
136
|
+
|
|
137
|
+
for src_bands in available_bands:
|
|
138
|
+
needed_src_indexes = []
|
|
139
|
+
needed_dst_indexes = []
|
|
140
|
+
for i, band in enumerate(src_bands):
|
|
141
|
+
if band not in wanted_band_indexes:
|
|
142
|
+
continue
|
|
143
|
+
needed_src_indexes.append(i)
|
|
144
|
+
needed_dst_indexes.append(wanted_band_indexes[band])
|
|
145
|
+
del wanted_band_indexes[band]
|
|
146
|
+
if len(needed_src_indexes) == 0:
|
|
147
|
+
continue
|
|
148
|
+
needed_band_sets_and_indexes.append(
|
|
149
|
+
(src_bands, needed_src_indexes, needed_dst_indexes)
|
|
150
|
+
)
|
|
151
|
+
|
|
152
|
+
if len(wanted_band_indexes) > 0:
|
|
153
|
+
# This item doesn't have all the needed bands, so skip it.
|
|
154
|
+
return []
|
|
155
|
+
|
|
156
|
+
return needed_band_sets_and_indexes
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def build_first_valid_composite(
|
|
160
|
+
group: list[ItemType],
|
|
161
|
+
nodata_vals: list[Any],
|
|
162
|
+
bands: list[str],
|
|
163
|
+
bounds: PixelBounds,
|
|
164
|
+
band_dtype: npt.DTypeLike,
|
|
165
|
+
tile_store: TileStoreWithLayer,
|
|
166
|
+
projection: Projection,
|
|
167
|
+
remapper: Remapper | None,
|
|
168
|
+
resampling_method: Resampling = Resampling.bilinear,
|
|
169
|
+
) -> npt.NDArray[np.generic]:
|
|
170
|
+
"""Build a composite by selecting the first valid pixel of items in the group.
|
|
171
|
+
|
|
172
|
+
A composite of shape (bands, bounds) is created by iterating over items in
|
|
173
|
+
group in order and selecting the first pixel that is not nodata per index.
|
|
174
|
+
|
|
175
|
+
Args:
|
|
176
|
+
group: list of items to composite together
|
|
177
|
+
nodata_vals: list of nodata values for each band
|
|
178
|
+
bands: list of band names to include in the composite
|
|
179
|
+
bounds: pixel bounds defining the spatial extent of the composite
|
|
180
|
+
band_dtype: data type for the output bands
|
|
181
|
+
tile_store: tile store containing the actual raster data
|
|
182
|
+
projection: spatial projection for the composite
|
|
183
|
+
remapper: remapper to apply to pixel values, or None
|
|
184
|
+
resampling_method: resampling method to use when reprojecting
|
|
185
|
+
|
|
186
|
+
Returns:
|
|
187
|
+
Composite of shape (bands, bounds) built from all items in the group
|
|
188
|
+
|
|
189
|
+
"""
|
|
190
|
+
# Initialize the destination array to the nodata values.
|
|
191
|
+
# We default the nodata value to 0.
|
|
192
|
+
dst = np.zeros(
|
|
193
|
+
(len(bands), bounds[3] - bounds[1], bounds[2] - bounds[0]),
|
|
194
|
+
dtype=band_dtype,
|
|
195
|
+
)
|
|
196
|
+
|
|
197
|
+
for idx, nodata_val in enumerate(nodata_vals):
|
|
198
|
+
dst[idx] = nodata_val
|
|
199
|
+
|
|
200
|
+
for item in group:
|
|
201
|
+
needed_band_sets_and_indexes = get_needed_band_sets_and_indexes(
|
|
202
|
+
item, bands, tile_store
|
|
203
|
+
)
|
|
204
|
+
|
|
205
|
+
for (
|
|
206
|
+
src_bands,
|
|
207
|
+
src_indexes,
|
|
208
|
+
dst_indexes,
|
|
209
|
+
) in needed_band_sets_and_indexes:
|
|
210
|
+
cur_nodata_vals = [nodata_vals[idx] for idx in dst_indexes]
|
|
211
|
+
read_raster_window_from_tiles(
|
|
212
|
+
dst=dst,
|
|
213
|
+
tile_store=tile_store,
|
|
214
|
+
item_name=item.name,
|
|
215
|
+
bands=src_bands,
|
|
216
|
+
projection=projection,
|
|
217
|
+
bounds=bounds,
|
|
218
|
+
src_indexes=src_indexes,
|
|
219
|
+
dst_indexes=dst_indexes,
|
|
220
|
+
nodata_vals=cur_nodata_vals,
|
|
221
|
+
remapper=remapper,
|
|
222
|
+
resampling=resampling_method,
|
|
223
|
+
)
|
|
224
|
+
|
|
225
|
+
return dst
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def read_and_stack_raster_windows(
|
|
229
|
+
group: list[ItemType],
|
|
230
|
+
bounds: PixelBounds,
|
|
231
|
+
bands: list[str],
|
|
232
|
+
tile_store: TileStoreWithLayer,
|
|
233
|
+
projection: Projection,
|
|
234
|
+
nodata_vals: list[Any],
|
|
235
|
+
remapper: Remapper | None,
|
|
236
|
+
band_dtype: npt.DTypeLike,
|
|
237
|
+
resampling_method: Resampling = Resampling.bilinear,
|
|
238
|
+
) -> npt.NDArray[np.generic]:
|
|
239
|
+
"""Create a stack of extent aligned raster windows.
|
|
240
|
+
|
|
241
|
+
Args:
|
|
242
|
+
group: Iterable of items (e.g., scene metadata objects) to read data from.
|
|
243
|
+
bounds: Pixel bounds as (xmin, ymin, xmax, ymax) defining the spatial extent.
|
|
244
|
+
bands: List of band names to include in the output.
|
|
245
|
+
tile_store: Tile store containing the raster tiles for the items.
|
|
246
|
+
projection: Projection object specifying the spatial reference system.
|
|
247
|
+
nodata_vals: List of nodata values corresponding to each band.
|
|
248
|
+
band_dtype: Data type for the output raster (e.g., np.uint16, np.float32).
|
|
249
|
+
remapper: Optional remapper object to transform pixel values after reading.
|
|
250
|
+
resampling_method: Resampling method to use when reading/reprojecting tiles.
|
|
251
|
+
|
|
252
|
+
Returns:
|
|
253
|
+
NumPy array of shape (num_items, num_bands, height, width) containing
|
|
254
|
+
the stacked rasters for all items, with nodata values filled where data
|
|
255
|
+
is missing.
|
|
256
|
+
"""
|
|
257
|
+
height = bounds[3] - bounds[1]
|
|
258
|
+
width = bounds[2] - bounds[0]
|
|
259
|
+
window_shape = (len(bands), height, width)
|
|
260
|
+
|
|
261
|
+
extent_aligned_raster_windows: list[np.ndarray] = []
|
|
262
|
+
|
|
263
|
+
for item in group:
|
|
264
|
+
# Initialize destination array to nodata
|
|
265
|
+
dst = np.empty(window_shape, dtype=band_dtype)
|
|
266
|
+
for idx, nodata_val in enumerate(nodata_vals):
|
|
267
|
+
dst[idx, :, :] = nodata_val
|
|
268
|
+
|
|
269
|
+
# Determine which source band sets/indexes are needed for this item
|
|
270
|
+
needed_band_sets_and_indexes = get_needed_band_sets_and_indexes(
|
|
271
|
+
item, bands, tile_store
|
|
272
|
+
)
|
|
273
|
+
|
|
274
|
+
# Fill the destination window from the tile store
|
|
275
|
+
for src_bands, src_indexes, dst_indexes in needed_band_sets_and_indexes:
|
|
276
|
+
cur_nodata_vals = [nodata_vals[idx] for idx in dst_indexes]
|
|
277
|
+
read_raster_window_from_tiles(
|
|
278
|
+
dst=dst,
|
|
279
|
+
tile_store=tile_store,
|
|
280
|
+
item_name=item.name,
|
|
281
|
+
bands=src_bands,
|
|
282
|
+
projection=projection,
|
|
283
|
+
bounds=bounds,
|
|
284
|
+
src_indexes=src_indexes,
|
|
285
|
+
dst_indexes=dst_indexes,
|
|
286
|
+
nodata_vals=cur_nodata_vals,
|
|
287
|
+
remapper=remapper,
|
|
288
|
+
resampling=resampling_method,
|
|
289
|
+
)
|
|
290
|
+
|
|
291
|
+
extent_aligned_raster_windows.append(dst)
|
|
292
|
+
|
|
293
|
+
# Stack along a new axis (items axis): (N_items, N_bands, H, W)
|
|
294
|
+
stacked_arrays = np.stack(extent_aligned_raster_windows, axis=0)
|
|
295
|
+
return stacked_arrays
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def mask_stacked_rasters(
|
|
299
|
+
stacked_rasters: npt.NDArray[np.generic],
|
|
300
|
+
nodata_vals: list[Any],
|
|
301
|
+
) -> np.ma.MaskedArray:
|
|
302
|
+
Mask the stacked rasters wherever a band equals its corresponding nodata value.
|
|
303
|
+
|
|
304
|
+
Args:
|
|
305
|
+
stacked_rasters: NumPy array of shape (num_items, num_bands, height, width)
|
|
306
|
+
containing raster values for each item in the group.
|
|
307
|
+
nodata_vals: Sequence of nodata values, one per band, used to identify invalid
|
|
308
|
+
pixels in the stacked rasters.
|
|
309
|
+
|
|
310
|
+
Returns:
|
|
311
|
+
np.ma.MaskedArray with the same shape as `stacked_rasters`, where all
|
|
312
|
+
pixels equal to the per-band nodata value are masked.
|
|
313
|
+
"""
|
|
314
|
+
# Create mask based on nodata values
|
|
315
|
+
nodata_vals_array = np.array(nodata_vals).reshape(1, -1, 1, 1)
|
|
316
|
+
valid_mask = stacked_rasters != nodata_vals_array
|
|
317
|
+
|
|
318
|
+
# Create masked array for all bands
|
|
319
|
+
masked_data = np.ma.masked_where(~valid_mask, stacked_rasters)
|
|
320
|
+
|
|
321
|
+
return masked_data
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def build_mean_composite(
|
|
325
|
+
group: list[ItemType],
|
|
326
|
+
nodata_vals: list[Any],
|
|
327
|
+
bands: list[str],
|
|
328
|
+
bounds: PixelBounds,
|
|
329
|
+
band_dtype: npt.DTypeLike,
|
|
330
|
+
tile_store: TileStoreWithLayer,
|
|
331
|
+
projection: Projection,
|
|
332
|
+
remapper: Remapper | None,
|
|
333
|
+
resampling_method: Resampling = Resampling.bilinear,
|
|
334
|
+
) -> npt.NDArray[np.generic]:
|
|
335
|
+
"""Build a composite by computing the mean of valid pixels across items in the group.
|
|
336
|
+
|
|
337
|
+
A composite of shape (bands, bounds) is created by computing the per-pixel mean of
|
|
338
|
+
valid (non-nodata) pixels across all items in the group.
|
|
339
|
+
|
|
340
|
+
Args:
|
|
341
|
+
group: list of items to composite together
|
|
342
|
+
nodata_vals: list of nodata values for each band
|
|
343
|
+
bands: list of band names to include in the composite
|
|
344
|
+
bounds: pixel bounds defining the spatial extent of the composite
|
|
345
|
+
band_dtype: data type for the output bands
|
|
346
|
+
tile_store: tile store containing the raster data
|
|
347
|
+
projection: spatial projection for the composite
|
|
348
|
+
remapper: remapper to apply to pixel values, or None
|
|
349
|
+
resampling_method: resampling method to use when reprojecting
|
|
350
|
+
|
|
351
|
+
Returns:
|
|
352
|
+
Composite of shape (bands, bounds) having per-pixel mean of all items in the group
|
|
353
|
+
"""
|
|
354
|
+
# TODO: Might want to add a running sum/count based method to reduce memory utilization
|
|
355
|
+
|
|
356
|
+
stacked_arrays = read_and_stack_raster_windows(
|
|
357
|
+
group=group,
|
|
358
|
+
bounds=bounds,
|
|
359
|
+
bands=bands,
|
|
360
|
+
tile_store=tile_store,
|
|
361
|
+
projection=projection,
|
|
362
|
+
nodata_vals=nodata_vals,
|
|
363
|
+
band_dtype=band_dtype,
|
|
364
|
+
remapper=remapper,
|
|
365
|
+
resampling_method=resampling_method,
|
|
366
|
+
)
|
|
367
|
+
|
|
368
|
+
# Mask stacked arrays with nodata values of each band
|
|
369
|
+
masked_data = mask_stacked_rasters(stacked_arrays, nodata_vals)
|
|
370
|
+
|
|
371
|
+
# Compute mean along the items axis for all bands
|
|
372
|
+
mean_result = np.ma.mean(masked_data, axis=0)
|
|
373
|
+
|
|
374
|
+
# Fill masked values and convert to target dtype
|
|
375
|
+
fill_vals = np.array(nodata_vals).reshape(-1, 1, 1)
|
|
376
|
+
result = np.ma.filled(mean_result, fill_value=fill_vals).astype(band_dtype)
|
|
377
|
+
|
|
378
|
+
return result
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
def build_median_composite(
|
|
382
|
+
group: list[ItemType],
|
|
383
|
+
nodata_vals: list[Any],
|
|
384
|
+
bands: list[str],
|
|
385
|
+
bounds: PixelBounds,
|
|
386
|
+
band_dtype: npt.DTypeLike,
|
|
387
|
+
tile_store: TileStoreWithLayer,
|
|
388
|
+
projection: Projection,
|
|
389
|
+
remapper: Remapper | None,
|
|
390
|
+
resampling_method: Resampling = Resampling.bilinear,
|
|
391
|
+
) -> npt.NDArray[np.generic]:
|
|
392
|
+
"""Build a composite by computing the median of valid pixels across items in the group.
|
|
393
|
+
|
|
394
|
+
A composite of shape (bands, bounds) is created by computing the per-pixel median of
|
|
395
|
+
valid (non-nodata) pixels across all items in the group.
|
|
396
|
+
|
|
397
|
+
Args:
|
|
398
|
+
group: list of items to composite together
|
|
399
|
+
nodata_vals: list of nodata values for each band
|
|
400
|
+
bands: list of band names to include in the composite
|
|
401
|
+
bounds: pixel bounds defining the spatial extent of the composite
|
|
402
|
+
band_dtype: data type for the output bands
|
|
403
|
+
tile_store: tile store containing the raster data
|
|
404
|
+
projection: spatial projection for the composite
|
|
405
|
+
remapper: remapper to apply to pixel values, or None
|
|
406
|
+
resampling_method: resampling method to use when reprojecting
|
|
407
|
+
|
|
408
|
+
Returns:
|
|
409
|
+
Composite of shape (bands, bounds) having per-pixel median of all items in the group
|
|
410
|
+
"""
|
|
411
|
+
stacked_arrays = read_and_stack_raster_windows(
|
|
412
|
+
group=group,
|
|
413
|
+
bounds=bounds,
|
|
414
|
+
bands=bands,
|
|
415
|
+
tile_store=tile_store,
|
|
416
|
+
projection=projection,
|
|
417
|
+
nodata_vals=nodata_vals,
|
|
418
|
+
band_dtype=band_dtype,
|
|
419
|
+
remapper=remapper,
|
|
420
|
+
resampling_method=resampling_method,
|
|
421
|
+
)
|
|
422
|
+
|
|
423
|
+
# Mask stacked arrays with nodata values of each band
|
|
424
|
+
masked_data = mask_stacked_rasters(stacked_arrays, nodata_vals)
|
|
425
|
+
|
|
426
|
+
# Compute median along the items axis for all bands
|
|
427
|
+
mean_result = np.ma.median(masked_data, axis=0)
|
|
428
|
+
|
|
429
|
+
# Fill masked values and convert to target dtype
|
|
430
|
+
fill_vals = np.array(nodata_vals).reshape(-1, 1, 1)
|
|
431
|
+
result = np.ma.filled(mean_result, fill_value=fill_vals).astype(band_dtype)
|
|
432
|
+
|
|
433
|
+
return result
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
compositing_methods = {
|
|
437
|
+
CompositingMethod.FIRST_VALID: build_first_valid_composite,
|
|
438
|
+
CompositingMethod.MEAN: build_mean_composite,
|
|
439
|
+
CompositingMethod.MEDIAN: build_median_composite,
|
|
440
|
+
}
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
def build_composite(
|
|
444
|
+
group: list[ItemType],
|
|
445
|
+
compositing_method: CompositingMethod,
|
|
446
|
+
tile_store: TileStoreWithLayer,
|
|
447
|
+
layer_cfg: LayerConfig,
|
|
448
|
+
band_cfg: BandSetConfig,
|
|
449
|
+
projection: Projection,
|
|
450
|
+
bounds: PixelBounds,
|
|
451
|
+
remapper: Remapper | None,
|
|
452
|
+
) -> npt.NDArray[np.generic]:
|
|
453
|
+
"""Build a temporal composite for specified bands from items in the group.
|
|
454
|
+
|
|
455
|
+
Args:
|
|
456
|
+
group: list of items to composite together
|
|
457
|
+
compositing_method: which method to use for compositing. FIRST_VALID takes the first valid value per pixel, MEAN takes the per-pixel mean of valid values, and MEDIAN takes the per-pixel median of valid values.
|
|
458
|
+
tile_store: tile store containing the raster data
|
|
459
|
+
layer_cfg: the configuration of the layer to materialize
|
|
460
|
+
band_cfg: the configuration of the band set to materialize. Contains the bands to process.
|
|
461
|
+
projection: spatial projection for the composite
|
|
462
|
+
bounds: pixel bounds defining the spatial extent of the composite
|
|
463
|
+
remapper: remapper to apply to pixel values, or None
|
|
464
|
+
"""
|
|
465
|
+
nodata_vals = band_cfg.nodata_vals
|
|
466
|
+
if nodata_vals is None:
|
|
467
|
+
nodata_vals = [0 for _ in band_cfg.bands]
|
|
468
|
+
|
|
469
|
+
return compositing_methods[compositing_method](
|
|
470
|
+
group=group,
|
|
471
|
+
nodata_vals=nodata_vals,
|
|
472
|
+
bands=band_cfg.bands,
|
|
473
|
+
bounds=bounds,
|
|
474
|
+
band_dtype=band_cfg.dtype.get_numpy_dtype(),
|
|
475
|
+
tile_store=tile_store,
|
|
476
|
+
projection=projection,
|
|
477
|
+
resampling_method=layer_cfg.resampling_method.get_rasterio_resampling(),
|
|
478
|
+
remapper=remapper,
|
|
479
|
+
)
|
|
480
|
+
|
|
481
|
+
|
|
100
482
|
class RasterMaterializer(Materializer):
|
|
101
483
|
"""A Materializer for raster data."""
|
|
102
484
|
|
|
103
485
|
def materialize(
|
|
104
486
|
self,
|
|
105
|
-
tile_store:
|
|
487
|
+
tile_store: TileStoreWithLayer,
|
|
106
488
|
window: Window,
|
|
107
489
|
layer_name: str,
|
|
108
490
|
layer_cfg: LayerConfig,
|
|
109
|
-
item_groups: list[list[
|
|
491
|
+
item_groups: list[list[ItemType]],
|
|
110
492
|
) -> None:
|
|
111
493
|
"""Materialize portions of items corresponding to this window into the dataset.
|
|
112
494
|
|
|
113
495
|
Args:
|
|
114
|
-
tile_store: the tile store where the items have been ingested
|
|
496
|
+
tile_store: the tile store where the items have been ingested
|
|
115
497
|
window: the window to materialize
|
|
116
498
|
layer_name: name of the layer to materialize
|
|
117
499
|
layer_cfg: the configuration of the layer to materialize
|
|
118
500
|
item_groups: the items associated with this window and layer
|
|
119
501
|
"""
|
|
120
|
-
assert isinstance(layer_cfg, RasterLayerConfig)
|
|
121
|
-
|
|
122
|
-
layer_tile_store = get_tile_store_for_layer(tile_store, layer_name, layer_cfg)
|
|
123
|
-
|
|
124
|
-
out_layer_dirs: list[UPath] = []
|
|
125
|
-
for group_id in range(len(item_groups)):
|
|
126
|
-
if group_id == 0:
|
|
127
|
-
out_layer_name = layer_name
|
|
128
|
-
else:
|
|
129
|
-
out_layer_name = f"{layer_name}.{group_id}"
|
|
130
|
-
out_layer_dir = window.path / "layers" / out_layer_name
|
|
131
|
-
out_layer_dirs.append(out_layer_dir)
|
|
132
|
-
|
|
133
502
|
for band_cfg in layer_cfg.band_sets:
|
|
134
503
|
# band_cfg could specify zoom_offset and maybe other parameters that affect
|
|
135
504
|
# projection/bounds, so use the corrected projection/bounds.
|
|
@@ -142,72 +511,40 @@ class RasterMaterializer(Materializer):
|
|
|
142
511
|
if band_cfg.remap:
|
|
143
512
|
remapper = load_remapper(band_cfg.remap)
|
|
144
513
|
|
|
145
|
-
raster_format =
|
|
146
|
-
RasterFormatConfig(band_cfg.format["name"], band_cfg.format)
|
|
147
|
-
)
|
|
514
|
+
raster_format = band_cfg.instantiate_raster_format()
|
|
148
515
|
|
|
149
516
|
for group_id, group in enumerate(item_groups):
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
517
|
+
composite = build_composite(
|
|
518
|
+
group=group,
|
|
519
|
+
compositing_method=layer_cfg.compositing_method,
|
|
520
|
+
tile_store=tile_store,
|
|
521
|
+
layer_cfg=layer_cfg,
|
|
522
|
+
band_cfg=band_cfg,
|
|
523
|
+
projection=projection,
|
|
524
|
+
bounds=bounds,
|
|
525
|
+
remapper=remapper,
|
|
153
526
|
)
|
|
154
|
-
for item in group:
|
|
155
|
-
# Identify which tile store layer(s) to read to get the configured
|
|
156
|
-
# bands.
|
|
157
|
-
needed_band_indexes = {}
|
|
158
|
-
for i, band in enumerate(band_cfg.bands):
|
|
159
|
-
needed_band_indexes[band] = i
|
|
160
|
-
suffixes = layer_tile_store.list_layers((item.name,))
|
|
161
|
-
needed_suffixes_and_indexes = []
|
|
162
|
-
for suffix in suffixes:
|
|
163
|
-
bands = suffix.split("_")
|
|
164
|
-
needed_src_indexes = []
|
|
165
|
-
needed_dst_indexes = []
|
|
166
|
-
for i, band in enumerate(bands):
|
|
167
|
-
if band not in needed_band_indexes:
|
|
168
|
-
continue
|
|
169
|
-
needed_src_indexes.append(i)
|
|
170
|
-
needed_dst_indexes.append(needed_band_indexes[band])
|
|
171
|
-
del needed_band_indexes[band]
|
|
172
|
-
if len(needed_src_indexes) == 0:
|
|
173
|
-
continue
|
|
174
|
-
needed_suffixes_and_indexes.append(
|
|
175
|
-
(suffix, needed_src_indexes, needed_dst_indexes)
|
|
176
|
-
)
|
|
177
|
-
if len(needed_band_indexes) > 0:
|
|
178
|
-
# This item doesn't have all the needed bands, so skip it.
|
|
179
|
-
continue
|
|
180
|
-
|
|
181
|
-
for suffix, src_indexes, dst_indexes in needed_suffixes_and_indexes:
|
|
182
|
-
ts_layer = layer_tile_store.get_layer(
|
|
183
|
-
(item.name, suffix, str(projection))
|
|
184
|
-
)
|
|
185
|
-
read_raster_window_from_tiles(
|
|
186
|
-
dst, ts_layer, bounds, src_indexes, dst_indexes, remapper
|
|
187
|
-
)
|
|
188
|
-
|
|
189
527
|
raster_format.encode_raster(
|
|
190
|
-
|
|
528
|
+
window.get_raster_dir(layer_name, band_cfg.bands, group_id),
|
|
191
529
|
projection,
|
|
192
530
|
bounds,
|
|
193
|
-
|
|
531
|
+
composite,
|
|
194
532
|
)
|
|
195
533
|
|
|
196
|
-
for
|
|
197
|
-
(
|
|
534
|
+
for group_id in range(len(item_groups)):
|
|
535
|
+
window.mark_layer_completed(layer_name, group_id)
|
|
198
536
|
|
|
199
537
|
|
|
200
|
-
@Materializers.register("vector")
|
|
201
538
|
class VectorMaterializer(Materializer):
|
|
202
539
|
"""A Materializer for vector data."""
|
|
203
540
|
|
|
204
541
|
def materialize(
|
|
205
542
|
self,
|
|
206
|
-
tile_store:
|
|
543
|
+
tile_store: TileStoreWithLayer,
|
|
207
544
|
window: Window,
|
|
208
545
|
layer_name: str,
|
|
209
546
|
layer_cfg: LayerConfig,
|
|
210
|
-
item_groups: list[list[
|
|
547
|
+
item_groups: list[list[ItemType]],
|
|
211
548
|
) -> None:
|
|
212
549
|
"""Materialize portions of items corresponding to this window into the dataset.
|
|
213
550
|
|
|
@@ -218,33 +555,20 @@ class VectorMaterializer(Materializer):
|
|
|
218
555
|
layer_cfg: the configuration of the layer to materialize
|
|
219
556
|
item_groups: the items associated with this window and layer
|
|
220
557
|
"""
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
projection, bounds = layer_cfg.get_final_projection_and_bounds(
|
|
224
|
-
window.projection, window.bounds
|
|
225
|
-
)
|
|
226
|
-
vector_format = load_vector_format(layer_cfg.format)
|
|
227
|
-
|
|
228
|
-
out_layer_dirs: list[UPath] = []
|
|
229
|
-
for group_id in range(len(item_groups)):
|
|
230
|
-
if group_id == 0:
|
|
231
|
-
out_layer_name = layer_name
|
|
232
|
-
else:
|
|
233
|
-
out_layer_name = f"{layer_name}.{group_id}"
|
|
234
|
-
out_layer_dir = window.path / "layers" / out_layer_name
|
|
235
|
-
out_layer_dirs.append(out_layer_dir)
|
|
558
|
+
vector_format = layer_cfg.instantiate_vector_format()
|
|
236
559
|
|
|
237
560
|
for group_id, group in enumerate(item_groups):
|
|
238
561
|
features: list[Feature] = []
|
|
239
562
|
|
|
240
563
|
for item in group:
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
)
|
|
244
|
-
cur_features = ts_layer.read_vector(bounds)
|
|
564
|
+
cur_features = tile_store.read_vector(
|
|
565
|
+
item.name, window.projection, window.bounds
|
|
566
|
+
)
|
|
245
567
|
features.extend(cur_features)
|
|
246
568
|
|
|
247
|
-
vector_format.encode_vector(
|
|
569
|
+
vector_format.encode_vector(
|
|
570
|
+
window.get_layer_dir(layer_name, group_id), features
|
|
571
|
+
)
|
|
248
572
|
|
|
249
|
-
for
|
|
250
|
-
(
|
|
573
|
+
for group_id in range(len(item_groups)):
|
|
574
|
+
window.mark_layer_completed(layer_name, group_id)
|