rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/utils/raster_format.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""Abstract RasterFormat class."""
|
|
2
2
|
|
|
3
|
+
import hashlib
|
|
3
4
|
import json
|
|
4
5
|
from typing import Any, BinaryIO
|
|
5
6
|
|
|
@@ -7,16 +8,134 @@ import affine
|
|
|
7
8
|
import numpy as np
|
|
8
9
|
import numpy.typing as npt
|
|
9
10
|
import rasterio
|
|
10
|
-
from class_registry import ClassRegistry
|
|
11
11
|
from PIL import Image
|
|
12
|
+
from rasterio.crs import CRS
|
|
13
|
+
from rasterio.enums import Resampling
|
|
12
14
|
from upath import UPath
|
|
13
15
|
|
|
14
|
-
from rslearn.config import RasterFormatConfig
|
|
15
16
|
from rslearn.const import TILE_SIZE
|
|
17
|
+
from rslearn.log_utils import get_logger
|
|
18
|
+
from rslearn.utils.fsspec import open_rasterio_upath_reader, open_rasterio_upath_writer
|
|
16
19
|
|
|
17
20
|
from .geometry import PixelBounds, Projection
|
|
18
21
|
|
|
19
|
-
|
|
22
|
+
logger = get_logger(__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def get_bandset_dirname(bands: list[str]) -> str:
|
|
26
|
+
"""Get the directory name that should be used to store the given group of bands."""
|
|
27
|
+
# We try to use a human-readable name with underscore as the delimiter, but if that
|
|
28
|
+
# isn't straightforward then we use hash instead.
|
|
29
|
+
if any(["_" in band for band in bands]):
|
|
30
|
+
# In this case we hash the JSON representation of the bands.
|
|
31
|
+
return hashlib.sha256(json.dumps(bands).encode()).hexdigest()
|
|
32
|
+
dirname = "_".join(bands)
|
|
33
|
+
if len(dirname) > 64:
|
|
34
|
+
# Previously we simply joined the bands, but this can result in directory name
|
|
35
|
+
# that is too long. In this case, now we use hash instead.
|
|
36
|
+
# We use a different code path here where we hash the initial directory name
|
|
37
|
+
# instead of the JSON, for historical reasons (to maintain backwards
|
|
38
|
+
# compatibility).
|
|
39
|
+
dirname = hashlib.sha256(dirname.encode()).hexdigest()
|
|
40
|
+
return dirname
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def get_raster_projection_and_bounds_from_transform(
|
|
44
|
+
crs: CRS, transform: affine.Affine, width: int, height: int
|
|
45
|
+
) -> tuple[Projection, PixelBounds]:
|
|
46
|
+
"""Determine Projection and bounds from the specified CRS and transform.
|
|
47
|
+
|
|
48
|
+
Args:
|
|
49
|
+
crs: the coordinate reference system.
|
|
50
|
+
transform: corresponding affine transform matrix.
|
|
51
|
+
width: the array width
|
|
52
|
+
height: the array height
|
|
53
|
+
|
|
54
|
+
Returns:
|
|
55
|
+
a tuple (projection, bounds).
|
|
56
|
+
"""
|
|
57
|
+
x_resolution = transform.a
|
|
58
|
+
y_resolution = transform.e
|
|
59
|
+
projection = Projection(crs, x_resolution, y_resolution)
|
|
60
|
+
offset = (
|
|
61
|
+
int(round(transform.c / x_resolution)),
|
|
62
|
+
int(round(transform.f / y_resolution)),
|
|
63
|
+
)
|
|
64
|
+
bounds = (offset[0], offset[1], offset[0] + width, offset[1] + height)
|
|
65
|
+
return (projection, bounds)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def get_raster_projection_and_bounds(
|
|
69
|
+
raster: rasterio.DatasetReader,
|
|
70
|
+
) -> tuple[Projection, PixelBounds]:
|
|
71
|
+
"""Determine the Projection and bounds of the specified raster.
|
|
72
|
+
|
|
73
|
+
Args:
|
|
74
|
+
raster: the raster dataset opened with rasterio.
|
|
75
|
+
|
|
76
|
+
Returns:
|
|
77
|
+
a tuple (projection, bounds).
|
|
78
|
+
"""
|
|
79
|
+
return get_raster_projection_and_bounds_from_transform(
|
|
80
|
+
raster.crs, raster.transform, raster.width, raster.height
|
|
81
|
+
)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def get_transform_from_projection_and_bounds(
|
|
85
|
+
projection: Projection, bounds: PixelBounds
|
|
86
|
+
) -> affine.Affine:
|
|
87
|
+
"""Get the affine transform that corresponds to the given projection and bounds.
|
|
88
|
+
|
|
89
|
+
Args:
|
|
90
|
+
projection: the projection. Only the resolutions are used.
|
|
91
|
+
bounds: the bounding box. Only the top-left corner is used.
|
|
92
|
+
"""
|
|
93
|
+
return affine.Affine(
|
|
94
|
+
projection.x_resolution,
|
|
95
|
+
0,
|
|
96
|
+
bounds[0] * projection.x_resolution,
|
|
97
|
+
0,
|
|
98
|
+
projection.y_resolution,
|
|
99
|
+
bounds[1] * projection.y_resolution,
|
|
100
|
+
)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def adjust_projection_and_bounds_for_array(
|
|
104
|
+
projection: Projection, bounds: PixelBounds, array: npt.NDArray
|
|
105
|
+
) -> tuple[Projection, PixelBounds]:
|
|
106
|
+
"""Adjust the projection and bounds to correspond to the resolution of the array.
|
|
107
|
+
|
|
108
|
+
The returned projection and bounds cover the same spatial extent as the inputs, but
|
|
109
|
+
are updated so that the width and height match that of the array.
|
|
110
|
+
|
|
111
|
+
Args:
|
|
112
|
+
projection: the original projection.
|
|
113
|
+
bounds: the original bounds.
|
|
114
|
+
array: the CHW array for which to compute an updated projection and bounds. The
|
|
115
|
+
returned bounds will have the same width and height as this array.
|
|
116
|
+
|
|
117
|
+
Returns:
|
|
118
|
+
a tuple of adjusted (projection, bounds)
|
|
119
|
+
"""
|
|
120
|
+
if array.shape[2] == (bounds[2] - bounds[0]) and array.shape[1] == (
|
|
121
|
+
bounds[3] - bounds[1]
|
|
122
|
+
):
|
|
123
|
+
return (projection, bounds)
|
|
124
|
+
|
|
125
|
+
x_factor = array.shape[2] / (bounds[2] - bounds[0])
|
|
126
|
+
y_factor = array.shape[1] / (bounds[3] - bounds[1])
|
|
127
|
+
adjusted_projection = Projection(
|
|
128
|
+
projection.crs,
|
|
129
|
+
projection.x_resolution / x_factor,
|
|
130
|
+
projection.y_resolution / y_factor,
|
|
131
|
+
)
|
|
132
|
+
adjusted_bounds = (
|
|
133
|
+
round(bounds[0] * x_factor),
|
|
134
|
+
round(bounds[1] * y_factor),
|
|
135
|
+
round(bounds[0] * x_factor) + array.shape[2],
|
|
136
|
+
round(bounds[1] * y_factor) + array.shape[1],
|
|
137
|
+
)
|
|
138
|
+
return (adjusted_projection, adjusted_bounds)
|
|
20
139
|
|
|
21
140
|
|
|
22
141
|
class RasterFormat:
|
|
@@ -44,21 +163,39 @@ class RasterFormat:
|
|
|
44
163
|
raise NotImplementedError
|
|
45
164
|
|
|
46
165
|
def decode_raster(
|
|
47
|
-
self,
|
|
48
|
-
|
|
166
|
+
self,
|
|
167
|
+
path: UPath,
|
|
168
|
+
projection: Projection,
|
|
169
|
+
bounds: PixelBounds,
|
|
170
|
+
resampling: Resampling = Resampling.bilinear,
|
|
171
|
+
) -> npt.NDArray[Any]:
|
|
49
172
|
"""Decodes raster data.
|
|
50
173
|
|
|
51
174
|
Args:
|
|
52
175
|
path: the directory to read from
|
|
53
|
-
|
|
176
|
+
projection: the projection to read the raster in.
|
|
177
|
+
bounds: the bounds to read in the given projection.
|
|
178
|
+
resampling: resampling method to use in case resampling is needed.
|
|
179
|
+
|
|
180
|
+
Returns:
|
|
181
|
+
the raster data
|
|
182
|
+
"""
|
|
183
|
+
raise NotImplementedError
|
|
184
|
+
|
|
185
|
+
@staticmethod
|
|
186
|
+
def from_config(name: str, config: dict[str, Any]) -> "RasterFormat":
|
|
187
|
+
"""Create a RasterFormat from a config dict.
|
|
188
|
+
|
|
189
|
+
Args:
|
|
190
|
+
name: the name of this format
|
|
191
|
+
config: the config dict
|
|
54
192
|
|
|
55
193
|
Returns:
|
|
56
|
-
the
|
|
194
|
+
the RasterFormat instance
|
|
57
195
|
"""
|
|
58
196
|
raise NotImplementedError
|
|
59
197
|
|
|
60
198
|
|
|
61
|
-
@RasterFormats.register("image_tile")
|
|
62
199
|
class ImageTileRasterFormat(RasterFormat):
|
|
63
200
|
"""A RasterFormat that stores data in image tiles corresponding to grid cells.
|
|
64
201
|
|
|
@@ -152,6 +289,19 @@ class ImageTileRasterFormat(RasterFormat):
|
|
|
152
289
|
bounds: the bounds of the raster data in the projection
|
|
153
290
|
array: the raster data (must be CHW)
|
|
154
291
|
"""
|
|
292
|
+
# Write metadata about the projection that we are writing under.
|
|
293
|
+
# We also save dtype and number of bands so we can return correct shape when
|
|
294
|
+
# there are no intersecting tiles.
|
|
295
|
+
with (path / "metadata.json").open("w") as f:
|
|
296
|
+
json.dump(
|
|
297
|
+
{
|
|
298
|
+
"projection": projection.serialize(),
|
|
299
|
+
"dtype": array.dtype.name,
|
|
300
|
+
"num_bands": array.shape[0],
|
|
301
|
+
},
|
|
302
|
+
f,
|
|
303
|
+
)
|
|
304
|
+
|
|
155
305
|
start_tile = (bounds[0] // self.tile_size, bounds[1] // self.tile_size)
|
|
156
306
|
end_tile = (bounds[2] // self.tile_size + 1, bounds[3] // self.tile_size + 1)
|
|
157
307
|
extension = self.get_extension()
|
|
@@ -190,17 +340,34 @@ class ImageTileRasterFormat(RasterFormat):
|
|
|
190
340
|
self.encode_tile(f, projection, cur_bounds, cur_array)
|
|
191
341
|
|
|
192
342
|
def decode_raster(
|
|
193
|
-
self,
|
|
194
|
-
|
|
343
|
+
self,
|
|
344
|
+
path: UPath,
|
|
345
|
+
projection: Projection,
|
|
346
|
+
bounds: PixelBounds,
|
|
347
|
+
resampling: Resampling = Resampling.bilinear,
|
|
348
|
+
) -> npt.NDArray[Any]:
|
|
195
349
|
"""Decodes raster data.
|
|
196
350
|
|
|
197
351
|
Args:
|
|
198
352
|
path: the directory to read from
|
|
199
|
-
|
|
353
|
+
projection: the projection to read the raster in.
|
|
354
|
+
bounds: the bounds to read in the given projection.
|
|
355
|
+
resampling: resampling method to use in case resampling is needed.
|
|
200
356
|
|
|
201
357
|
Returns:
|
|
202
|
-
the raster data
|
|
358
|
+
the raster data
|
|
203
359
|
"""
|
|
360
|
+
# Verify that the source data has the same projection as the requested one.
|
|
361
|
+
# ImageTileRasterFormat currently does not support re-projecting.
|
|
362
|
+
with (path / "metadata.json").open() as f:
|
|
363
|
+
image_metadata = json.load(f)
|
|
364
|
+
source_data_projection = Projection.deserialize(image_metadata["projection"])
|
|
365
|
+
if source_data_projection != projection:
|
|
366
|
+
raise NotImplementedError(
|
|
367
|
+
"not implemented to re-project source data "
|
|
368
|
+
+ f"(source projection {source_data_projection} does not match requested projection {projection})"
|
|
369
|
+
)
|
|
370
|
+
|
|
204
371
|
extension = self.get_extension()
|
|
205
372
|
|
|
206
373
|
# Load tiles one at a time.
|
|
@@ -209,7 +376,12 @@ class ImageTileRasterFormat(RasterFormat):
|
|
|
209
376
|
(bounds[2] - 1) // self.tile_size + 1,
|
|
210
377
|
(bounds[3] - 1) // self.tile_size + 1,
|
|
211
378
|
)
|
|
212
|
-
|
|
379
|
+
dst_shape = (
|
|
380
|
+
image_metadata["num_bands"],
|
|
381
|
+
bounds[3] - bounds[1],
|
|
382
|
+
bounds[2] - bounds[0],
|
|
383
|
+
)
|
|
384
|
+
dst = np.zeros(dst_shape, dtype=image_metadata["dtype"])
|
|
213
385
|
for col in range(start_tile[0], end_tile[0]):
|
|
214
386
|
for row in range(start_tile[1], end_tile[1]):
|
|
215
387
|
fname = path / f"{col}_{row}.{extension}"
|
|
@@ -272,13 +444,17 @@ class ImageTileRasterFormat(RasterFormat):
|
|
|
272
444
|
)
|
|
273
445
|
|
|
274
446
|
|
|
275
|
-
@RasterFormats.register("geotiff")
|
|
276
447
|
class GeotiffRasterFormat(RasterFormat):
|
|
277
448
|
"""A raster format that uses one big, tiled GeoTIFF with small block size."""
|
|
278
449
|
|
|
279
450
|
fname = "geotiff.tif"
|
|
280
451
|
|
|
281
|
-
def __init__(
|
|
452
|
+
def __init__(
|
|
453
|
+
self,
|
|
454
|
+
block_size: int = TILE_SIZE,
|
|
455
|
+
always_enable_tiling: bool = False,
|
|
456
|
+
geotiff_options: dict[str, Any] = {},
|
|
457
|
+
):
|
|
282
458
|
"""Initializes a GeotiffRasterFormat.
|
|
283
459
|
|
|
284
460
|
Args:
|
|
@@ -287,9 +463,11 @@ class GeotiffRasterFormat(RasterFormat):
|
|
|
287
463
|
GeoTIFFs. The default is False so that tiling is only used if the size
|
|
288
464
|
of the GeoTIFF exceeds the block_size on either dimension. If True,
|
|
289
465
|
then tiling is always enabled (cloud-optimized GeoTIFF).
|
|
466
|
+
geotiff_options: other options to pass to rasterio.open (for writes).
|
|
290
467
|
"""
|
|
291
468
|
self.block_size = block_size
|
|
292
469
|
self.always_enable_tiling = always_enable_tiling
|
|
470
|
+
self.geotiff_options = geotiff_options
|
|
293
471
|
|
|
294
472
|
def encode_raster(
|
|
295
473
|
self,
|
|
@@ -297,6 +475,7 @@ class GeotiffRasterFormat(RasterFormat):
|
|
|
297
475
|
projection: Projection,
|
|
298
476
|
bounds: PixelBounds,
|
|
299
477
|
array: npt.NDArray[Any],
|
|
478
|
+
fname: str | None = None,
|
|
300
479
|
) -> None:
|
|
301
480
|
"""Encodes raster data.
|
|
302
481
|
|
|
@@ -305,7 +484,11 @@ class GeotiffRasterFormat(RasterFormat):
|
|
|
305
484
|
projection: the projection of the raster data
|
|
306
485
|
bounds: the bounds of the raster data in the projection
|
|
307
486
|
array: the raster data
|
|
487
|
+
fname: override the filename to save as
|
|
308
488
|
"""
|
|
489
|
+
if fname is None:
|
|
490
|
+
fname = self.fname
|
|
491
|
+
|
|
309
492
|
crs = projection.crs
|
|
310
493
|
transform = affine.Affine(
|
|
311
494
|
projection.x_resolution,
|
|
@@ -338,76 +521,48 @@ class GeotiffRasterFormat(RasterFormat):
|
|
|
338
521
|
profile["blockxsize"] = self.block_size
|
|
339
522
|
profile["blockysize"] = self.block_size
|
|
340
523
|
|
|
524
|
+
profile.update(self.geotiff_options)
|
|
525
|
+
|
|
341
526
|
path.mkdir(parents=True, exist_ok=True)
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
527
|
+
logger.debug(f"Writing geotiff to {path / fname}")
|
|
528
|
+
with open_rasterio_upath_writer(path / fname, **profile) as dst:
|
|
529
|
+
dst.write(array)
|
|
345
530
|
|
|
346
531
|
def decode_raster(
|
|
347
|
-
self,
|
|
348
|
-
|
|
532
|
+
self,
|
|
533
|
+
path: UPath,
|
|
534
|
+
projection: Projection,
|
|
535
|
+
bounds: PixelBounds,
|
|
536
|
+
resampling: Resampling = Resampling.bilinear,
|
|
537
|
+
fname: str | None = None,
|
|
538
|
+
) -> npt.NDArray[Any]:
|
|
349
539
|
"""Decodes raster data.
|
|
350
540
|
|
|
351
541
|
Args:
|
|
352
542
|
path: the directory to read from
|
|
353
|
-
|
|
543
|
+
projection: the projection to read the raster in.
|
|
544
|
+
bounds: the bounds to read in the given projection.
|
|
545
|
+
resampling: resampling method to use in case resampling is needed.
|
|
546
|
+
fname: override the filename to read from
|
|
354
547
|
|
|
355
548
|
Returns:
|
|
356
|
-
the raster data
|
|
549
|
+
the raster data
|
|
357
550
|
"""
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
bounds[3] - offset[1],
|
|
374
|
-
]
|
|
375
|
-
if (
|
|
376
|
-
relative_bounds[2] < 0
|
|
377
|
-
or relative_bounds[3] < 0
|
|
378
|
-
or relative_bounds[0] >= src.width
|
|
379
|
-
or relative_bounds[1] >= src.height
|
|
380
|
-
):
|
|
381
|
-
return None
|
|
382
|
-
# Now get the actual pixels we will read, which must be contained in
|
|
383
|
-
# the GeoTIFF.
|
|
384
|
-
# Padding is (before_x, before_y, after_x, after_y) and will be used to
|
|
385
|
-
# pad the output back to the originally requested bounds.
|
|
386
|
-
padding = [0, 0, 0, 0]
|
|
387
|
-
if relative_bounds[0] < 0:
|
|
388
|
-
padding[0] = -relative_bounds[0]
|
|
389
|
-
relative_bounds[0] = 0
|
|
390
|
-
if relative_bounds[1] < 0:
|
|
391
|
-
padding[1] = -relative_bounds[1]
|
|
392
|
-
relative_bounds[1] = 0
|
|
393
|
-
if relative_bounds[2] > src.width:
|
|
394
|
-
padding[2] = relative_bounds[2] - src.width
|
|
395
|
-
relative_bounds[2] = src.width
|
|
396
|
-
if relative_bounds[3] > src.height:
|
|
397
|
-
padding[3] = relative_bounds[3] - src.height
|
|
398
|
-
relative_bounds[3] = src.height
|
|
399
|
-
|
|
400
|
-
window = rasterio.windows.Window(
|
|
401
|
-
relative_bounds[0],
|
|
402
|
-
relative_bounds[1],
|
|
403
|
-
relative_bounds[2] - relative_bounds[0],
|
|
404
|
-
relative_bounds[3] - relative_bounds[1],
|
|
405
|
-
)
|
|
406
|
-
array = src.read(window=window)
|
|
407
|
-
array = np.pad(
|
|
408
|
-
array, ((0, 0), (padding[1], padding[3]), (padding[0], padding[2]))
|
|
409
|
-
)
|
|
410
|
-
return array
|
|
551
|
+
if fname is None:
|
|
552
|
+
fname = self.fname
|
|
553
|
+
|
|
554
|
+
# Construct the transform to use for the warped dataset.
|
|
555
|
+
wanted_transform = get_transform_from_projection_and_bounds(projection, bounds)
|
|
556
|
+
with open_rasterio_upath_reader(path / fname) as src:
|
|
557
|
+
with rasterio.vrt.WarpedVRT(
|
|
558
|
+
src,
|
|
559
|
+
crs=projection.crs,
|
|
560
|
+
transform=wanted_transform,
|
|
561
|
+
width=bounds[2] - bounds[0],
|
|
562
|
+
height=bounds[3] - bounds[1],
|
|
563
|
+
resampling=resampling,
|
|
564
|
+
) as vrt:
|
|
565
|
+
return vrt.read()
|
|
411
566
|
|
|
412
567
|
def get_raster_bounds(self, path: UPath) -> PixelBounds:
|
|
413
568
|
"""Returns the bounds of the stored raster.
|
|
@@ -418,21 +573,9 @@ class GeotiffRasterFormat(RasterFormat):
|
|
|
418
573
|
Returns:
|
|
419
574
|
the PixelBounds of the raster
|
|
420
575
|
"""
|
|
421
|
-
with (path / self.fname)
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
x_resolution = transform.a
|
|
425
|
-
y_resolution = transform.e
|
|
426
|
-
offset = (
|
|
427
|
-
int(transform.c / x_resolution),
|
|
428
|
-
int(transform.f / y_resolution),
|
|
429
|
-
)
|
|
430
|
-
return (
|
|
431
|
-
offset[0],
|
|
432
|
-
offset[1],
|
|
433
|
-
offset[0] + src.width,
|
|
434
|
-
offset[1] + src.height,
|
|
435
|
-
)
|
|
576
|
+
with open_rasterio_upath_reader(path / self.fname) as src:
|
|
577
|
+
_, bounds = get_raster_projection_and_bounds(src)
|
|
578
|
+
return bounds
|
|
436
579
|
|
|
437
580
|
@staticmethod
|
|
438
581
|
def from_config(name: str, config: dict[str, Any]) -> "GeotiffRasterFormat":
|
|
@@ -450,10 +593,11 @@ class GeotiffRasterFormat(RasterFormat):
|
|
|
450
593
|
kwargs["block_size"] = config["block_size"]
|
|
451
594
|
if "always_enable_tiling" in config:
|
|
452
595
|
kwargs["always_enable_tiling"] = config["always_enable_tiling"]
|
|
596
|
+
if "geotiff_options" in config:
|
|
597
|
+
kwargs["geotiff_options"] = config["geotiff_options"]
|
|
453
598
|
return GeotiffRasterFormat(**kwargs)
|
|
454
599
|
|
|
455
600
|
|
|
456
|
-
@RasterFormats.register("single_image")
|
|
457
601
|
class SingleImageRasterFormat(RasterFormat):
|
|
458
602
|
"""A raster format that produces a single image called image.png/jpg.
|
|
459
603
|
|
|
@@ -503,35 +647,57 @@ class SingleImageRasterFormat(RasterFormat):
|
|
|
503
647
|
if array.shape[2] == 1:
|
|
504
648
|
array = array[:, :, 0]
|
|
505
649
|
Image.fromarray(array).save(f, format=self.format.upper())
|
|
650
|
+
|
|
651
|
+
# Since the image file doesn't include the georeferencing, we store it in an
|
|
652
|
+
# auxiliary metadata file.
|
|
506
653
|
with (path / "metadata.json").open("w") as f:
|
|
507
654
|
json.dump(
|
|
508
655
|
{
|
|
656
|
+
"projection": projection.serialize(),
|
|
509
657
|
"bounds": bounds,
|
|
510
658
|
},
|
|
511
659
|
f,
|
|
512
660
|
)
|
|
513
661
|
|
|
514
662
|
def decode_raster(
|
|
515
|
-
self,
|
|
516
|
-
|
|
663
|
+
self,
|
|
664
|
+
path: UPath,
|
|
665
|
+
projection: Projection,
|
|
666
|
+
bounds: PixelBounds,
|
|
667
|
+
resampling: Resampling = Resampling.bilinear,
|
|
668
|
+
) -> npt.NDArray[Any]:
|
|
517
669
|
"""Decodes raster data.
|
|
518
670
|
|
|
519
671
|
Args:
|
|
520
672
|
path: the directory to read from
|
|
521
|
-
|
|
673
|
+
projection: the projection to read the raster in.
|
|
674
|
+
bounds: the bounds to read in the given projection.
|
|
675
|
+
resampling: resampling method to use in case resampling is needed.
|
|
522
676
|
|
|
523
677
|
Returns:
|
|
524
|
-
the raster data
|
|
678
|
+
the raster data
|
|
525
679
|
"""
|
|
526
|
-
|
|
680
|
+
# Try to get the bounds of the saved image from the metadata file.
|
|
681
|
+
# In old versions, the file may be missing the projection key.
|
|
527
682
|
metadata_fname = path / "metadata.json"
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
# Backwards compatibility -- assume that requested bounds matches the window bounds.
|
|
533
|
-
image_bounds = bounds
|
|
683
|
+
with metadata_fname.open() as f:
|
|
684
|
+
image_metadata = json.load(f)
|
|
685
|
+
|
|
686
|
+
image_bounds = image_metadata["bounds"]
|
|
534
687
|
|
|
688
|
+
# If the projection key is set, verify that it matches the requested projection
|
|
689
|
+
# since SingleImageRasterFormat currently does not support re-projecting.
|
|
690
|
+
if "projection" in image_metadata:
|
|
691
|
+
source_data_projection = Projection.deserialize(
|
|
692
|
+
image_metadata["projection"]
|
|
693
|
+
)
|
|
694
|
+
if projection != source_data_projection:
|
|
695
|
+
raise NotImplementedError(
|
|
696
|
+
"not implemented to re-project source data "
|
|
697
|
+
+ f"(source projection {source_data_projection} does not match requested projection {projection})"
|
|
698
|
+
)
|
|
699
|
+
|
|
700
|
+
image_fname = path / ("image." + self.get_extension())
|
|
535
701
|
with image_fname.open("rb") as f:
|
|
536
702
|
array = np.array(Image.open(f, formats=[self.format.upper()]))
|
|
537
703
|
|
|
@@ -583,17 +749,3 @@ class SingleImageRasterFormat(RasterFormat):
|
|
|
583
749
|
if "format" in config:
|
|
584
750
|
kwargs["format"] = config["format"]
|
|
585
751
|
return SingleImageRasterFormat(**kwargs)
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
def load_raster_format(config: RasterFormatConfig) -> RasterFormat:
|
|
589
|
-
"""Loads a RasterFormat from a RasterFormatConfig.
|
|
590
|
-
|
|
591
|
-
Args:
|
|
592
|
-
config: the RasterFormatConfig configuration object specifying the
|
|
593
|
-
RasterFormat.
|
|
594
|
-
|
|
595
|
-
Returns:
|
|
596
|
-
the loaded RasterFormat implementation
|
|
597
|
-
"""
|
|
598
|
-
cls = RasterFormats.get_class(config.name)
|
|
599
|
-
return cls.from_config(config.name, config.config_dict)
|
rslearn/utils/rtree_index.py
CHANGED
|
@@ -1,7 +1,9 @@
|
|
|
1
1
|
"""RtreeIndex spatial index implementation."""
|
|
2
2
|
|
|
3
|
+
import hashlib
|
|
3
4
|
import os
|
|
4
5
|
import shutil
|
|
6
|
+
import tempfile
|
|
5
7
|
from collections.abc import Callable
|
|
6
8
|
from typing import Any
|
|
7
9
|
|
|
@@ -9,8 +11,11 @@ import fsspec
|
|
|
9
11
|
from rtree import index
|
|
10
12
|
from upath import UPath
|
|
11
13
|
|
|
14
|
+
from rslearn.log_utils import get_logger
|
|
12
15
|
from rslearn.utils.spatial_index import SpatialIndex
|
|
13
16
|
|
|
17
|
+
logger = get_logger(__name__)
|
|
18
|
+
|
|
14
19
|
|
|
15
20
|
class RtreeIndex(SpatialIndex):
|
|
16
21
|
"""An index of spatiotemporal geometries backed by an rtree index.
|
|
@@ -18,7 +23,7 @@ class RtreeIndex(SpatialIndex):
|
|
|
18
23
|
Both in-memory and on-disk options are supported.
|
|
19
24
|
"""
|
|
20
25
|
|
|
21
|
-
def __init__(self, fname: str | None = None):
|
|
26
|
+
def __init__(self, fname: str | None = None) -> None:
|
|
22
27
|
"""Initialize a new RtreeIndex.
|
|
23
28
|
|
|
24
29
|
If fname is set, the index is persisted on disk, otherwise it is in-memory.
|
|
@@ -50,6 +55,7 @@ class RtreeIndex(SpatialIndex):
|
|
|
50
55
|
self.counter += 1
|
|
51
56
|
self.index.insert(id=self.counter, coordinates=box, obj=data)
|
|
52
57
|
|
|
58
|
+
# TODO: Make a named tuple for all the bounding box stuff
|
|
53
59
|
def query(self, box: tuple[float, float, float, float]) -> list[Any]:
|
|
54
60
|
"""Query the index for objects intersecting a box.
|
|
55
61
|
|
|
@@ -63,20 +69,51 @@ class RtreeIndex(SpatialIndex):
|
|
|
63
69
|
return [r.object for r in results]
|
|
64
70
|
|
|
65
71
|
|
|
72
|
+
def delete_partially_created_local_files(fname: str) -> None:
|
|
73
|
+
"""Delete partially created .dat and .idx files."""
|
|
74
|
+
extensions = [".dat", ".idx"]
|
|
75
|
+
for ext in extensions:
|
|
76
|
+
cur_fname = fname + ext
|
|
77
|
+
if os.path.exists(cur_fname):
|
|
78
|
+
os.unlink(cur_fname)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _get_tmp_dir_for_cached_rtree(cache_dir: UPath) -> str:
|
|
82
|
+
"""Get a local temporary directory to store the rtree from the specified cache_dir.
|
|
83
|
+
|
|
84
|
+
This function is deterministic so the same cache_dir will always yield the same
|
|
85
|
+
local temporary directory.
|
|
86
|
+
|
|
87
|
+
Note that the directory is not cleaned up after the program exits, so the rtree
|
|
88
|
+
will remain there. This is because this function may be called from multiple worker
|
|
89
|
+
processes but the index should be reused across workers.
|
|
90
|
+
|
|
91
|
+
Args:
|
|
92
|
+
cache_dir: the non-local directory where the rtree files are stored.
|
|
93
|
+
|
|
94
|
+
Returns:
|
|
95
|
+
the temporary local directory to copy the cached rtree to.
|
|
96
|
+
"""
|
|
97
|
+
cache_id = hashlib.sha256(str(cache_dir).encode()).hexdigest()
|
|
98
|
+
tmp_dir = os.path.join(
|
|
99
|
+
tempfile.gettempdir(), "rslearn_cache", "rtree_index", cache_id
|
|
100
|
+
)
|
|
101
|
+
os.makedirs(tmp_dir, exist_ok=True)
|
|
102
|
+
return tmp_dir
|
|
103
|
+
|
|
104
|
+
|
|
66
105
|
def get_cached_rtree(
|
|
67
|
-
cache_dir: UPath,
|
|
106
|
+
cache_dir: UPath, build_fn: Callable[[RtreeIndex], None]
|
|
68
107
|
) -> RtreeIndex:
|
|
69
108
|
"""Returns an RtreeIndex cached in cache_dir, creating it if needed.
|
|
70
109
|
|
|
71
110
|
The .dat and .idx files are cached in cache_dir. Since RtreeIndex expects local
|
|
72
|
-
filesystem, it is copied to a local temporary directory if needed
|
|
73
|
-
doesn't exist yet,
|
|
111
|
+
filesystem, it is copied to a local temporary directory if needed (it is not needed
|
|
112
|
+
if the cache_dir is already on local filesystem). If the index doesn't exist yet,
|
|
113
|
+
then it is created using build_fn.
|
|
74
114
|
|
|
75
115
|
Args:
|
|
76
116
|
cache_dir: directory to cache the index files.
|
|
77
|
-
tmp_dir: temporary local directory to use in case cache_dir is on a remote
|
|
78
|
-
filesystem. The caller is responsible for cleaning this up when they don't
|
|
79
|
-
need the index anymore.
|
|
80
117
|
build_fn: function to build the index in case it doesn't exist yet.
|
|
81
118
|
|
|
82
119
|
Returns:
|
|
@@ -95,14 +132,13 @@ def get_cached_rtree(
|
|
|
95
132
|
if is_local_cache:
|
|
96
133
|
local_fname = (cache_dir / "rtree_index").path
|
|
97
134
|
else:
|
|
135
|
+
tmp_dir = _get_tmp_dir_for_cached_rtree(cache_dir)
|
|
98
136
|
local_fname = os.path.join(tmp_dir, "rtree_index")
|
|
137
|
+
delete_partially_created_local_files(local_fname)
|
|
99
138
|
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
if os.path.exists(cur_fname):
|
|
104
|
-
os.unlink(cur_fname)
|
|
105
|
-
|
|
139
|
+
logger.info(
|
|
140
|
+
"building rtree index at %s to be cached at %s", local_fname, cache_dir
|
|
141
|
+
)
|
|
106
142
|
rtree_index = RtreeIndex(local_fname)
|
|
107
143
|
build_fn(rtree_index)
|
|
108
144
|
del rtree_index
|
|
@@ -115,6 +151,7 @@ def get_cached_rtree(
|
|
|
115
151
|
|
|
116
152
|
# Create the completed file to indicate index is ready in cache.
|
|
117
153
|
completed_fname.touch()
|
|
154
|
+
logger.info("rtree index is built and ready")
|
|
118
155
|
|
|
119
156
|
else:
|
|
120
157
|
# Initialize the index from the cached version.
|
|
@@ -122,10 +159,20 @@ def get_cached_rtree(
|
|
|
122
159
|
if is_local_cache:
|
|
123
160
|
local_fname = (cache_dir / "rtree_index").path
|
|
124
161
|
else:
|
|
162
|
+
tmp_dir = _get_tmp_dir_for_cached_rtree(cache_dir)
|
|
125
163
|
local_fname = os.path.join(tmp_dir, "rtree_index")
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
164
|
+
|
|
165
|
+
if not os.path.exists(local_fname + extensions[-1]):
|
|
166
|
+
logger.info(
|
|
167
|
+
"copying rtree index from non-local cache at %s to local temporary directory %s",
|
|
168
|
+
cache_dir,
|
|
169
|
+
local_fname,
|
|
170
|
+
)
|
|
171
|
+
for ext in extensions:
|
|
172
|
+
with (cache_dir / f"rtree_index{ext}").open("rb") as src:
|
|
173
|
+
with open(local_fname + ext + ".tmp", "wb") as dst:
|
|
174
|
+
shutil.copyfileobj(src, dst)
|
|
175
|
+
os.rename(local_fname + ext + ".tmp", local_fname + ext)
|
|
176
|
+
logger.info("rtree index is ready")
|
|
130
177
|
|
|
131
178
|
return RtreeIndex(local_fname)
|