rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/utils/sqlite_index.py
CHANGED
@@ -1,4 +1,10 @@
-"""Contains a SpatialIndex implementation that uses an sqlite database.
+"""Contains a SpatialIndex implementation that uses an sqlite database.
+
+# TODO: This is not yet complete decide to either complete it or remove this file.
+"""
+
+# Ignoring Mypy until we determine if we want to keep this file.
+# mypy: ignore-errors
 
 import json
 import sqlite3

rslearn/utils/vector_format.py
CHANGED
@@ -1,20 +1,21 @@
 """Classes for writing vector data to a UPath."""
 
 import json
+from enum import Enum
 from typing import Any
 
-import numpy as np
 import shapely
-from
+from rasterio.crs import CRS
 from upath import UPath
 
-from rslearn.config import VectorFormatConfig
 from rslearn.const import WGS84_PROJECTION
+from rslearn.log_utils import get_logger
+from rslearn.utils.fsspec import open_atomic
 
 from .feature import Feature
-from .geometry import PixelBounds, Projection, STGeometry
+from .geometry import PixelBounds, Projection, STGeometry, safely_reproject_and_clip
 
-
+logger = get_logger(__name__)
 
 
 class VectorFormat:
@@ -24,32 +25,45 @@ class VectorFormat:
     a UPath. Vector data is a list of GeoJSON-like features.
     """
 
-    def encode_vector(
-        self, path: UPath, projection: Projection, features: list[Feature]
-    ) -> None:
+    def encode_vector(self, path: UPath, features: list[Feature]) -> None:
         """Encodes vector data.
 
         Args:
             path: the directory to write to
-            projection: the projection of the raster data
             features: the vector data
         """
         raise NotImplementedError
 
-    def decode_vector(
+    def decode_vector(
+        self, path: UPath, projection: Projection, bounds: PixelBounds
+    ) -> list[Feature]:
         """Decodes vector data.
 
         Args:
             path: the directory to read from
-
+            projection: the projection to read the data in
+            bounds: the bounds to read under the given projection. Only features that
+                intersect the bounds should be returned.
 
         Returns:
             the vector data
         """
         raise NotImplementedError
 
+    @staticmethod
+    def from_config(name: str, config: dict[str, Any]) -> "VectorFormat":
+        """Create a VectorFormat from a config dict.
+
+        Args:
+            name: the name of this format
+            config: the config dict
+
+        Returns:
+            the VectorFormat instance
+        """
+        raise NotImplementedError
+
 
-@VectorFormats.register("tile")
 class TileVectorFormat(VectorFormat):
     """TileVectorFormat stores data in GeoJSON files corresponding to grid cells.
 
@@ -58,29 +72,62 @@ class TileVectorFormat(VectorFormat):
     intersect.
     """
 
-    def __init__(
+    def __init__(
+        self,
+        tile_size: int = 512,
+        projection: Projection | None = None,
+        index_property_name: str = "tvf_index",
+    ):
         """Initialize a new TileVectorFormat instance.
 
         Args:
             tile_size: the tile size (grid size in pixels), default 512
+            projection: if set, store features under this projection. Otherwise, the
+                output projection is taken from the first feature in an encode_vector
+                call.
+            index_property_name: property name used to store an index integer that
+                identifies the same feature across different tiles.
         """
         self.tile_size = tile_size
+        self.projection = projection
+        self.index_property_name = index_property_name
 
-    def encode_vector(
-        self, path: UPath, projection: Projection, features: list[Feature]
-    ) -> None:
+    def encode_vector(self, path: UPath, features: list[Feature]) -> None:
         """Encodes vector data.
 
         Args:
             path: the directory to write to
-            projection: the projection of the raster data
             features: the vector data
         """
-
-
+        # Determine the output projection to write in.
+        if len(features) == 0:
+            # We won't actually write any features but still setting output_projection
+            # to write to projection.json.
+            # We just fallback to WGS84 here.
+            output_projection = WGS84_PROJECTION
+        elif self.projection is not None:
+            output_projection = self.projection
+        else:
+            output_projection = features[0].geometry.projection
+
+        # Save metadata file containing the serialized projection so we can load it
+        # when decoding.
+        with open_atomic(path / "projection.json", "w") as f:
+            json.dump(output_projection.serialize(), f)
+
+        # Dictionary from tile (col, row) to the list of features intersecting that
+        # tile. We iterate over the features to populate tile_data, then write each
+        # tile as a separate file.
+        tile_data: dict[tuple[int, int], list[dict]] = {}
+
+        for feat_idx, feat in enumerate(features):
+            # Skip invalid features since they can cause errors.
             if not feat.geometry.shp.is_valid:
                 continue
-
+
+            # Identify each grid cell that this feature intersects.
+            geometry = feat.geometry.to_projection(output_projection)
+            bounds = geometry.shp.bounds
             start_tile = (
                 int(bounds[0]) // self.tile_size,
                 int(bounds[1]) // self.tile_size,
@@ -89,58 +136,73 @@ class TileVectorFormat(VectorFormat):
                 int(bounds[2]) // self.tile_size + 1,
                 int(bounds[3]) // self.tile_size + 1,
             )
+
+            # We add an index property to the features so when reading we can
+            # de-duplicate (in case we read multiple tiles that contain the same
+            # feature).
+            properties = {self.index_property_name: feat_idx}
+            properties.update(feat.properties)
+            # Use the re-projected geometry here.
+            output_feat = Feature(geometry, properties)
+            output_geojson = output_feat.to_geojson()
+
+            # Now we add the feature to each tile that it intersects.
             for col in range(start_tile[0], end_tile[0]):
                 for row in range(start_tile[1], end_tile[1]):
-
-
-
-
-
-                        (row + 1) * self.tile_size,
-                    )
+                    tile_box = shapely.box(
+                        col * self.tile_size,
+                        row * self.tile_size,
+                        (col + 1) * self.tile_size,
+                        (row + 1) * self.tile_size,
                     )
-
-                        cur_shp,
-                        lambda array: array
-                        - np.array([[col * self.tile_size, row * self.tile_size]]),
-                    )
-                    cur_feat = Feature(
-                        STGeometry(projection, cur_shp, None), feat.properties
-                    )
-                    try:
-                        cur_geojson = cur_feat.to_geojson()
-                    except Exception as e:
-                        print(e)
+                    if not geometry.shp.intersects(tile_box):
                         continue
                     tile = (col, row)
                     if tile not in tile_data:
                         tile_data[tile] = []
-                    tile_data[tile].append(
+                    tile_data[tile].append(output_geojson)
 
         path.mkdir(parents=True, exist_ok=True)
+
+        # Now save each tile.
         for (col, row), geojson_features in tile_data.items():
            fc = {
                 "type": "FeatureCollection",
                 "features": [geojson_feat for geojson_feat in geojson_features],
-                "properties":
+                "properties": output_projection.serialize(),
             }
-
+            cur_fname = path / f"{col}_{row}.geojson"
+            logger.debug("writing tile (%d, %d) to %s", col, row, cur_fname)
+            with open_atomic(cur_fname, "w") as f:
                 json.dump(fc, f)
 
-    def decode_vector(
+    def decode_vector(
+        self, path: UPath, projection: Projection, bounds: PixelBounds
+    ) -> list[Feature]:
         """Decodes vector data.
 
         Args:
             path: the directory to read from
-
+            projection: the projection to read the data in
+            bounds: the bounds to read under the given projection. Only features that
+                intersect the bounds should be returned.
 
         Returns:
             the vector data
         """
-
+        # Convert the bounds to the projection of the stored data.
+        with (path / "projection.json").open() as f:
+            storage_projection = Projection.deserialize(json.load(f))
+        bounds_geom = STGeometry(projection, shapely.box(*bounds), None)
+        storage_bounds = bounds_geom.to_projection(storage_projection).shp.bounds
+
+        start_tile = (
+            int(storage_bounds[0]) // self.tile_size,
+            int(storage_bounds[1]) // self.tile_size,
+        )
         end_tile = (
-            (
-            (
+            (int(storage_bounds[2]) - 1) // self.tile_size + 1,
+            (int(storage_bounds[3]) - 1) // self.tile_size + 1,
         )
         features = []
         for col in range(start_tile[0], end_tile[0]):
@@ -148,22 +210,12 @@ class TileVectorFormat(VectorFormat):
                 cur_fname = path / f"{col}_{row}.geojson"
                 if not cur_fname.exists():
                     continue
-                with cur_fname.open(
+                with cur_fname.open() as f:
                     fc = json.load(f)
-
-
-
-                        projection
-
-                for feat in fc["features"]:
-                    shp = shapely.geometry.shape(feat["geometry"])
-                    shp = shapely.transform(
-                        shp,
-                        lambda array: array
-                        + np.array([[col * self.tile_size, row * self.tile_size]]),
-                    )
-                    feat["geometry"] = json.loads(shapely.to_geojson(shp))
-                    features.append(Feature.from_geojson(projection, feat))
+
+                for geojson_feat in fc["features"]:
+                    feat = Feature.from_geojson(storage_projection, geojson_feat)
+                    features.append(feat.to_projection(projection))
         return features
 
     @staticmethod
@@ -177,54 +229,168 @@ class TileVectorFormat(VectorFormat):
         Returns:
             the TileVectorFormat
         """
-
+        kwargs = {}
+        if "tile_size" in config:
+            kwargs["tile_size"] = config["tile_size"]
+        if "projection" in config:
+            kwargs["projection"] = Projection.deserialize(config["projection"])
+        if "index_property_name" in config:
+            kwargs["index_property_name"] = config["index_property_name"]
+        return TileVectorFormat(**kwargs)
+
+
+class GeojsonCoordinateMode(Enum):
+    """The projection to use when writing GeoJSON file."""
+
+    # Write the features as is.
+    PIXEL = "pixel"
+
+    # Write the features in CRS coordinates (i.e., a projection with x_resolution=1 and
+    # y_resolution=1).
+    CRS = "crs"
+
+    # Write in WGS84 (longitude, latitude) coordinates.
+    WGS84 = "wgs84"
 
 
-@VectorFormats.register("geojson")
 class GeojsonVectorFormat(VectorFormat):
     """A vector format that uses one big GeoJSON."""
 
     fname = "data.geojson"
 
-    def
-        self,
-    )
+    def __init__(
+        self, coordinate_mode: GeojsonCoordinateMode = GeojsonCoordinateMode.PIXEL
+    ):
+        """Create a new GeojsonVectorFormat.
+
+        Args:
+            coordinate_mode: the projection to use for coordinates written to the
+                GeoJSON files. PIXEL means we write them as is, CRS means we just undo
+                the resolution in the Projection so they are in CRS coordinates, and
+                WGS84 means we always write longitude/latitude. When using PIXEL, the
+                GeoJSON will not be readable by GIS tools since it relies on a custom
+                encoding.
+        """
+        self.coordinate_mode = coordinate_mode
+
+    def encode_to_file(self, fname: UPath, features: list[Feature]) -> None:
+        """Encode vector data to a specific file.
+
+        Args:
+            fname: the file to write to
+            features: the vector data
+        """
+        fc: dict[str, Any] = {"type": "FeatureCollection"}
+
+        # Identify target projection and convert features.
+        # Also set the target projection in the FeatureCollection.
+        # For PIXEL mode, we need to use a custom encoding so the resolution is stored.
+        output_projection: Projection
+        if len(features) > 0 and self.coordinate_mode != GeojsonCoordinateMode.WGS84:
+            if self.coordinate_mode == GeojsonCoordinateMode.PIXEL:
+                output_projection = features[0].geometry.projection
+                fc["properties"] = output_projection.serialize()
+            elif self.coordinate_mode == GeojsonCoordinateMode.CRS:
+                output_projection = Projection(
+                    features[0].geometry.projection.crs, 1, 1
+                )
+                fc["crs"] = {
+                    "type": "name",
+                    "properties": {
+                        "name": output_projection.crs.to_wkt(),
+                    },
+                }
+        else:
+            # Either there are no features so we need to fallback to WGS84, or the
+            # coordinate mode is WGS84.
+            output_projection = WGS84_PROJECTION
+            fc["crs"] = {
+                "type": "name",
+                "properties": {
+                    "name": output_projection.crs.to_wkt(),
+                },
+            }
+
+        fc["features"] = []
+        for feat in features:
+            feat = feat.to_projection(output_projection)
+            fc["features"].append(feat.to_geojson())
+
+        logger.debug(
+            "writing features to %s with coordinate mode %s",
+            fname,
+            self.coordinate_mode,
+        )
+        with open_atomic(fname, "w") as f:
+            json.dump(fc, f)
+
+    def encode_vector(self, path: UPath, features: list[Feature]) -> None:
         """Encodes vector data.
 
         Args:
             path: the directory to write to
-            projection: the projection of the raster data
             features: the vector data
         """
         path.mkdir(parents=True, exist_ok=True)
-
-        json.dump(
-            {
-                "type": "FeatureCollection",
-                "features": [feat.to_geojson() for feat in features],
-                "properties": projection.serialize(),
-            },
-            f,
-        )
+        self.encode_to_file(path / self.fname, features)
 
-    def
-        """Decodes vector data.
+    def decode_from_file(self, fname: UPath) -> list[Feature]:
+        """Decodes vector data from a filename.
 
         Args:
-
-            bounds: the bounds of the vector data to read
+            fname: the filename to read.
 
         Returns:
             the vector data
         """
-        with
+        with fname.open() as f:
             fc = json.load(f)
+
+        # Detect the projection that the features are stored under.
         if "properties" in fc and "crs" in fc["properties"]:
+            # Means it uses our custom Projection encoding.
            projection = Projection.deserialize(fc["properties"])
+        elif "crs" in fc:
+            # Means it uses standard GeoJSON CRS encoding.
+            crs = CRS.from_string(fc["crs"]["properties"]["name"])
+            projection = Projection(crs, 1, 1)
         else:
+            # Otherwise it should be WGS84 (GeoJSONs created in rslearn should include
+            # the "crs" attribute, but maybe it was created externally).
             projection = WGS84_PROJECTION
+
         return [Feature.from_geojson(projection, feat) for feat in fc["features"]]
 
+    def decode_vector(
+        self, path: UPath, projection: Projection, bounds: PixelBounds
+    ) -> list[Feature]:
+        """Decodes vector data.
+
+        Args:
+            path: the directory to read from
+            projection: the projection to read the data in
+            bounds: the bounds to read under the given projection. Only features that
+                intersect the bounds should be returned.
+
+        Returns:
+            the vector data
+        """
+        features = self.decode_from_file(path / self.fname)
+
+        # Re-project to the desired projection and clip to bounds.
+        dst_geom = STGeometry(projection, shapely.box(*bounds), None)
+        reprojected_geoms = safely_reproject_and_clip(
+            [feat.geometry for feat in features], dst_geom
+        )
+        reprojected_features = []
+        for feat, geom in zip(features, reprojected_geoms):
+            if geom is None:
+                # None value means that it did not intersect the provided bounds.
                continue
+            reprojected_features.append(Feature(geom, feat.properties))
+
+        return reprojected_features
+
     @staticmethod
     def from_config(name: str, config: dict[str, Any]) -> "GeojsonVectorFormat":
         """Create a GeojsonVectorFormat from a config dict.
@@ -236,18 +402,7 @@ class GeojsonVectorFormat(VectorFormat):
         Returns:
             the GeojsonVectorFormat
         """
-
-
-
-
-    """Loads a VectorFormat from a VectorFormatConfig.
-
-    Args:
-        config: the VectorFormatConfig configuration object specifying the
-            VectorFormat.
-
-    Returns:
-        the loaded VectorFormat implementation
-    """
-    cls = VectorFormats.get_class(config.name)
-    return cls.from_config(config.name, config.config_dict)
+        kwargs = {}
+        if "coordinate_mode" in config:
+            kwargs["coordinate_mode"] = GeojsonCoordinateMode(config["coordinate_mode"])
+        return GeojsonVectorFormat(**kwargs)