rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/utils/geometry.py
CHANGED
|
@@ -1,18 +1,31 @@
|
|
|
1
1
|
"""Spatiotemporal geometry utilities."""
|
|
2
2
|
|
|
3
|
+
import functools
|
|
4
|
+
from collections.abc import Sequence
|
|
3
5
|
from datetime import datetime, timedelta
|
|
4
6
|
from typing import Any
|
|
5
7
|
|
|
6
8
|
import numpy as np
|
|
9
|
+
import numpy.typing as npt
|
|
7
10
|
import rasterio.warp
|
|
8
11
|
import shapely
|
|
9
12
|
import shapely.wkt
|
|
10
13
|
from rasterio.crs import CRS
|
|
11
14
|
|
|
12
|
-
|
|
13
|
-
|
|
15
|
+
from rslearn.log_utils import get_logger
|
|
14
16
|
|
|
17
|
+
logger = get_logger(__name__)
|
|
15
18
|
PixelBounds = tuple[int, int, int, int]
|
|
19
|
+
FloatBounds = tuple[float, float, float, float]
|
|
20
|
+
|
|
21
|
+
RESOLUTION_EPSILON = 1e-6
|
|
22
|
+
WGS84_EPSG = 4326
|
|
23
|
+
WGS84_BOUNDS: PixelBounds = (-180, -90, 180, 90)
|
|
24
|
+
|
|
25
|
+
# Threshold in degrees above which a geometry is probably not going to re-project
|
|
26
|
+
# correctly due to projections with limited validity and other issues.
|
|
27
|
+
# 6 degrees corresponds to the UTM zone interval.
|
|
28
|
+
MAX_GEOMETRY_DEGREES = 6
|
|
16
29
|
|
|
17
30
|
|
|
18
31
|
def is_same_resolution(res1: float, res2: float) -> bool:
|
|
@@ -20,7 +33,7 @@ def is_same_resolution(res1: float, res2: float) -> bool:
|
|
|
20
33
|
return (max(res1, res2) / min(res1, res2) - 1) < RESOLUTION_EPSILON
|
|
21
34
|
|
|
22
35
|
|
|
23
|
-
def shp_intersects(shp1: shapely.Geometry, shp2: shapely.Geometry):
|
|
36
|
+
def shp_intersects(shp1: shapely.Geometry, shp2: shapely.Geometry) -> bool:
|
|
24
37
|
"""Returns whether the two shapes intersect.
|
|
25
38
|
|
|
26
39
|
Tries shp.intersects but falls back to shp.intersection which can be more
|
|
@@ -98,6 +111,84 @@ class Projection:
|
|
|
98
111
|
)
|
|
99
112
|
|
|
100
113
|
|
|
114
|
+
# The Projection for WGS-84 assuming 1 degree per pixel.
|
|
115
|
+
# This can be used to create STGeometry with shapes in longitude/latitude coordinates.
|
|
116
|
+
WGS84_PROJECTION = Projection(CRS.from_epsg(WGS84_EPSG), 1, 1)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class ResolutionFactor:
    """Multiplier for the resolution in a Projection.

    The multiplier is either an integer x, or the inverse of an integer (1/x).

    Factors greater than 1 make the projection finer: the projection-units-per-pixel
    value is divided by the factor, so there are more pixels per projection unit.
    Factors less than 1 make it coarser (fewer pixels per projection unit).
    """

    def __init__(self, numerator: int = 1, denominator: int = 1):
        """Create a new ResolutionFactor.

        Args:
            numerator: the numerator of the fraction.
            denominator: the denominator of the fraction. If set, numerator must be 1.

        Raises:
            ValueError: if both values differ from 1, if either is not an integer,
                or if either is less than 1.
        """
        if numerator != 1 and denominator != 1:
            raise ValueError("one of numerator or denominator must be 1")
        if not isinstance(numerator, int) or not isinstance(denominator, int):
            raise ValueError("numerator and denominator must be integers")
        if numerator < 1 or denominator < 1:
            raise ValueError("numerator and denominator must be >= 1")
        self.numerator = numerator
        self.denominator = denominator

    def multiply_projection(self, projection: Projection) -> Projection:
        """Multiply the projection by this factor.

        Args:
            projection: the projection to rescale.

        Returns:
            a new Projection with the same CRS whose x/y resolution is multiplied
            by the denominator (coarser) or divided by the numerator (finer).
        """
        if self.denominator > 1:
            return Projection(
                projection.crs,
                projection.x_resolution * self.denominator,
                projection.y_resolution * self.denominator,
            )
        else:
            # True division is required here: resolutions are generally
            # non-integral (and the y resolution is often negative), so floor
            # division would silently round them, e.g. 0.5 // 2 == 0.0 and
            # -10 // 4 == -3.
            return Projection(
                projection.crs,
                projection.x_resolution / self.numerator,
                projection.y_resolution / self.numerator,
            )

    def multiply_bounds(self, bounds: PixelBounds) -> PixelBounds:
        """Multiply the bounds by this factor.

        When coarsening, the width and height of the given bounds must be a multiple of
        the denominator.

        Args:
            bounds: the pixel bounds to rescale.

        Returns:
            the bounds with each coordinate floor-divided by the denominator
            (coarsening) or multiplied by the numerator (otherwise).

        Raises:
            ValueError: if coarsening and the width or height is not a multiple of
                the denominator.
        """
        if self.denominator > 1:
            # Verify the width and height are multiples of the denominator.
            # Otherwise the new width and height is not an integer.
            width = bounds[2] - bounds[0]
            height = bounds[3] - bounds[1]
            if width % self.denominator != 0 or height % self.denominator != 0:
                raise ValueError(
                    f"width {width} or height {height} is not a multiple of the resolution factor {self.denominator}"
                )
            # TODO: an offset could be introduced by bounds not being a multiple
            # of the denominator -> will need to decide how to handle that.
            return (
                bounds[0] // self.denominator,
                bounds[1] // self.denominator,
                bounds[2] // self.denominator,
                bounds[3] // self.denominator,
            )
        else:
            return (
                bounds[0] * self.numerator,
                bounds[1] * self.numerator,
                bounds[2] * self.numerator,
                bounds[3] * self.numerator,
            )
|
|
190
|
+
|
|
191
|
+
|
|
101
192
|
class STGeometry:
|
|
102
193
|
"""A spatiotemporal geometry.
|
|
103
194
|
|
|
@@ -166,7 +257,7 @@ class STGeometry:
|
|
|
166
257
|
|
|
167
258
|
def intersects_time_range(
|
|
168
259
|
self, time_range: tuple[datetime, datetime] | None
|
|
169
|
-
) ->
|
|
260
|
+
) -> bool:
|
|
170
261
|
"""Returns whether this geometry intersects the other time range."""
|
|
171
262
|
if self.time_range is None or time_range is None:
|
|
172
263
|
return True
|
|
@@ -176,6 +267,30 @@ class STGeometry:
|
|
|
176
267
|
return False
|
|
177
268
|
return True
|
|
178
269
|
|
|
270
|
+
def is_global(self) -> bool:
    """Report whether this geometry denotes global spatial coverage.

    Global coverage is represented by a sentinel geometry: WGS84 projection with
    a box whose corners are (-180, -90, 180, 90) (see get_global_geometry).

    Returns:
        True only if both the projection and the shape match that sentinel.
    """
    # The box is only built when the projection already matches.
    differs_from_sentinel = (
        self.projection != WGS84_PROJECTION
        or self.shp != shapely.box(*WGS84_BOUNDS)
    )
    return not differs_from_sentinel
|
|
281
|
+
|
|
282
|
+
def is_too_large(self) -> bool:
    """Report whether this geometry's spatial extent is too large.

    A geometry is considered too large when its WGS84 bounding box exceeds
    MAX_GEOMETRY_DEGREES in either dimension; such geometries will likely have
    issues during re-projection.

    Returns:
        True if the horizontal or the vertical extent exceeds the threshold.
    """
    min_x, min_y, max_x, max_y = self.to_projection(WGS84_PROJECTION).shp.bounds
    x_extent_exceeded = max_x - min_x > MAX_GEOMETRY_DEGREES
    y_extent_exceeded = max_y - min_y > MAX_GEOMETRY_DEGREES
    return x_extent_exceeded or y_extent_exceeded
|
|
293
|
+
|
|
179
294
|
def intersects(self, other: "STGeometry") -> bool:
|
|
180
295
|
"""Returns whether this box intersects the other box."""
|
|
181
296
|
# Check temporal.
|
|
@@ -183,6 +298,9 @@ class STGeometry:
|
|
|
183
298
|
return False
|
|
184
299
|
|
|
185
300
|
# Check spatial.
|
|
301
|
+
if self.is_global() or other.is_global():
|
|
302
|
+
# One of the geometries indicates global coverage.
|
|
303
|
+
return True
|
|
186
304
|
# Need to reproject if projections don't match.
|
|
187
305
|
if other.projection != self.projection:
|
|
188
306
|
other = other.to_projection(self.projection)
|
|
@@ -194,7 +312,12 @@ class STGeometry:
|
|
|
194
312
|
def to_projection(self, projection: Projection) -> "STGeometry":
|
|
195
313
|
"""Transforms this geometry to the specified projection."""
|
|
196
314
|
|
|
197
|
-
def apply_resolution(
|
|
315
|
+
def apply_resolution(
|
|
316
|
+
array: np.ndarray,
|
|
317
|
+
x_resolution: float,
|
|
318
|
+
y_resolution: float,
|
|
319
|
+
forward: bool = True,
|
|
320
|
+
) -> np.ndarray:
|
|
198
321
|
if forward:
|
|
199
322
|
return np.stack(
|
|
200
323
|
[array[:, 0] / x_resolution, array[:, 1] / y_resolution], axis=1
|
|
@@ -215,8 +338,12 @@ class STGeometry:
|
|
|
215
338
|
),
|
|
216
339
|
)
|
|
217
340
|
# Change crs.
|
|
218
|
-
|
|
219
|
-
|
|
341
|
+
# We only apply transform_geom if the CRS doesn't match, because even if we
|
|
342
|
+
# call transform_geom with the same source and destination CRS, it takes
|
|
343
|
+
# several milliseconds.
|
|
344
|
+
if self.projection.crs != projection.crs:
|
|
345
|
+
shp = rasterio.warp.transform_geom(self.projection.crs, projection.crs, shp)
|
|
346
|
+
shp = shapely.geometry.shape(shp)
|
|
220
347
|
# Apply new resolution.
|
|
221
348
|
shp = shapely.transform(
|
|
222
349
|
shp,
|
|
@@ -224,6 +351,7 @@ class STGeometry:
|
|
|
224
351
|
array, projection.x_resolution, projection.y_resolution, forward=True
|
|
225
352
|
),
|
|
226
353
|
)
|
|
354
|
+
|
|
227
355
|
return STGeometry(projection, shp, self.time_range)
|
|
228
356
|
|
|
229
357
|
def __repr__(self) -> str:
|
|
@@ -260,3 +388,215 @@ class STGeometry:
|
|
|
260
388
|
else None
|
|
261
389
|
),
|
|
262
390
|
)
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
def get_global_geometry(time_range: tuple[datetime, datetime] | None) -> STGeometry:
    """Build the sentinel geometry that stands for global spatial coverage.

    Args:
        time_range: the time range to attach to the STGeometry (may be None).

    Returns:
        an STGeometry in WGS84_PROJECTION whose shape is the box given by
        WGS84_BOUNDS, carrying the specified time range.
    """
    whole_world = shapely.box(*WGS84_BOUNDS)
    return STGeometry(WGS84_PROJECTION, whole_world, time_range)
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
def flatten_shape(shp: shapely.Geometry) -> list[shapely.Geometry]:
    """Expand a shape into its primitive parts (Point, LineString, and Polygon).

    Args:
        shp: the shape, either a primitive or an arbitrarily nested collection.

    Returns:
        the primitive shapes in depth-first, left-to-right order. A primitive input
        yields a single-element list.
    """
    collection_types = (
        shapely.MultiPoint,
        shapely.MultiLineString,
        shapely.MultiPolygon,
        shapely.GeometryCollection,
    )
    primitives: list[shapely.Geometry] = []
    # Explicit stack instead of recursion; children are pushed in reverse so that
    # they are popped (and therefore emitted) in their original order.
    pending: list[shapely.Geometry] = [shp]
    while pending:
        current = pending.pop()
        if isinstance(current, collection_types):
            pending.extend(list(current.geoms)[::-1])
        else:
            primitives.append(current)
    return primitives
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
def _collect_shapes(shapes: list[shapely.Geometry]) -> shapely.Geometry:
    """Combine shapes into the most specific container that can hold them all.

    Args:
        shapes: the shapes to combine; collections are flattened first.

    Returns:
        the lone primitive if exactly one remains after flattening; otherwise a
        MultiPoint, MultiLineString or MultiPolygon when all primitives share that
        type, or a GeometryCollection for mixed types.
    """
    primitives: list[shapely.Geometry] = [
        part for shp in shapes for part in flatten_shape(shp)
    ]

    if len(primitives) == 1:
        return primitives[0]

    # Checked in this order so that homogeneous inputs get a typed container.
    homogeneous_containers = (
        (shapely.Point, shapely.MultiPoint),
        (shapely.LineString, shapely.MultiLineString),
        (shapely.Polygon, shapely.MultiPolygon),
    )
    for primitive_type, container_type in homogeneous_containers:
        if all(isinstance(part, primitive_type) for part in primitives):
            return container_type(primitives)

    return shapely.GeometryCollection(primitives)
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
def split_shape_at_antimeridian(
    shp: shapely.Geometry, epsilon: float = 1e-6
) -> shapely.Geometry:
    """Split the given shape at the antimeridian.

    The shape must be in WGS84 coordinates.

    See split_at_antimeridian for details.

    Args:
        shp: the shape to split.
        epsilon: the padding in degrees.

    Returns:
        the split shape, in WGS84 projection.

    Raises:
        TypeError: if the shape is not a Point, LineString, Polygon, or a
            collection of those.
    """
    # We assume the shape is fine if:
    # 1. It doesn't need padding (no coordinates close to +/- 180).
    # 2. And all longitudes are either less than 90 or more than -90 (meaning the
    #    shape approaches the antimeridian on at most one side).
    bounds = shp.bounds
    if bounds[0] > -180 + epsilon and bounds[2] < 90:
        return shp
    if bounds[0] > -90 and bounds[2] < 180 - epsilon:
        return shp

    if isinstance(
        shp,
        shapely.MultiPoint
        | shapely.MultiLineString
        | shapely.MultiPolygon
        | shapely.GeometryCollection,
    ):
        # Forward epsilon so the components use the caller's padding rather than
        # silently falling back to the default.
        return _collect_shapes(
            [
                split_shape_at_antimeridian(component, epsilon=epsilon)
                for component in shp.geoms
            ]
        )

    if isinstance(shp, shapely.Point):
        # Points only need padding.
        lon = shp.x
        if lon < -180 + epsilon:
            lon = -180 + epsilon
        if lon > 180 - epsilon:
            lon = 180 - epsilon
        return shapely.Point(lon, shp.y)

    if isinstance(shp, shapely.LineString | shapely.Polygon):
        # We add 360 to the negative coordinates and then separate the parts above and
        # below 180.
        def add360(array: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
            new_array = array.copy()
            new_array[new_array[:, 0] < 0, 0] += 360
            return new_array

        shp = shapely.transform(shp, add360)

        # The two clipping boxes leave a gap of 2 * epsilon around longitude 180,
        # which provides the padding.
        positive_part = shapely.box(0, -90, 180 - epsilon, 90)
        negative_part = shapely.box(180 + epsilon, -90, 360, 90)
        positive_shp = shp.intersection(positive_part)
        negative_shp = shp.intersection(negative_part)
        # Shift the part beyond 180 back into the negative longitudes.
        negative_shp = shapely.transform(negative_shp, lambda coords: coords - [360, 0])
        return _collect_shapes([positive_shp, negative_shp])

    raise TypeError("Unsupported shape type")
|
|
515
|
+
|
|
516
|
+
|
|
517
|
+
def split_at_antimeridian(geometry: STGeometry, epsilon: float = 1e-6) -> STGeometry:
    """Split lines and polygons in the given geometry at the antimeridian.

    The returned geometry will always be in WGS84 projection.

    Small padding is also introduced to ensure coordinates are a bit more than -180 or
    a bit less than 180.

    For example, if the input is a polygon:

        Polygon([[-180, 10], [180, 11], [-179, 11], [-179, 10]])

    Then it would be converted to:

        Polygon([[-179.999999, 10], [-179.999999, 11], [-179, 11], [-179, 10]])

    This function may produce unexpected results if the geometries span more than 90
    degrees on either dimension.

    Args:
        geometry: the geometry to split.
        epsilon: the padding in degrees. It is equivalent to about 1 m at the equator.
            We ensure no longitude coordinates are within this padding of +/- 180.

    Returns:
        the padded geometry, in WGS84 projection.
    """
    # Convert to WGS84.
    geometry = geometry.to_projection(WGS84_PROJECTION)
    new_shp = split_shape_at_antimeridian(geometry.shp, epsilon=epsilon)
    return STGeometry(geometry.projection, new_shp, geometry.time_range)
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
def safely_reproject_and_clip(
    src_geoms: Sequence[STGeometry], dst_geom: STGeometry
) -> Sequence[STGeometry | None]:
    """Re-project src_geoms into dst_geom's projection and clip them to dst_geom.

    Direct re-projection (e.g. src_geom.to_projection(dst_geom.projection)) may fail
    when the source geometry lies outside the area of use of the destination
    projection. To avoid that, geometries whose CRS differs from the destination are
    first compared against dst_geom in WGS84, and are only re-projected if they
    intersect there.

    This function may produce unexpected results if the geometries span more than 90
    degrees on either dimension.

    Args:
        src_geoms: the geometries to re-project and clip.
        dst_geom: the geometry providing the target projection and clipping shape.

    Returns:
        one entry per source geometry, in order: the clipped geometry in dst_geom's
        projection (keeping the source time range), or None if there is no
        intersection.
    """

    # Re-projecting the destination can be costly, so it is computed at most once,
    # and not at all when every source geometry already shares the destination CRS.
    @functools.cache
    def dst_in_wgs84() -> STGeometry:
        """Lazily compute and cache dst_geom in WGS84 projection."""
        return split_at_antimeridian(dst_geom.to_projection(WGS84_PROJECTION))

    def clip_one(src_geom: STGeometry) -> STGeometry | None:
        """Clip a single source geometry, or return None if it misses dst_geom."""
        # The extra WGS84 check is only needed when the CRS differs.
        if src_geom.projection.crs != dst_geom.projection.crs:
            src_in_wgs84 = split_at_antimeridian(
                src_geom.to_projection(WGS84_PROJECTION)
            )
            if not src_in_wgs84.intersects(dst_in_wgs84()):
                return None

        reprojected = src_geom.to_projection(dst_geom.projection)
        if not shp_intersects(reprojected.shp, dst_geom.shp):
            return None

        clipped_shp = reprojected.shp.intersection(dst_geom.shp)
        return STGeometry(dst_geom.projection, clipped_shp, src_geom.time_range)

    return [clip_one(src_geom) for src_geom in src_geoms]
|
rslearn/utils/get_utm_ups_crs.py
CHANGED
|
@@ -5,7 +5,7 @@ import pyproj.database
|
|
|
5
5
|
import shapely
|
|
6
6
|
from rasterio.crs import CRS
|
|
7
7
|
|
|
8
|
-
from rslearn.utils import Projection, STGeometry
|
|
8
|
+
from rslearn.utils.geometry import WGS84_PROJECTION, Projection, STGeometry
|
|
9
9
|
|
|
10
10
|
UPS_NORTH_EPSG = 5041
|
|
11
11
|
"""EPSG code for the UPS North CRS."""
|
|
@@ -121,8 +121,7 @@ def get_proj_bounds(utm_crs: CRS) -> tuple[float, float, float, float]:
|
|
|
121
121
|
"""
|
|
122
122
|
bounds = get_wgs84_bounds(utm_crs)
|
|
123
123
|
# Convert from WGS84 to the UTM zone.
|
|
124
|
-
src_proj = Projection(CRS.from_epsg(4326), 1, 1)
|
|
125
124
|
dst_proj = Projection(utm_crs, 1, 1)
|
|
126
125
|
shp = shapely.box(*bounds)
|
|
127
|
-
result = STGeometry(
|
|
126
|
+
result = STGeometry(WGS84_PROJECTION, shp, None).to_projection(dst_proj).shp
|
|
128
127
|
return result.bounds
|
rslearn/utils/grid_index.py
CHANGED
|
@@ -13,15 +13,15 @@ class GridIndex(SpatialIndex):
|
|
|
13
13
|
Each cell in the grid contains a list of geometries that intersect it.
|
|
14
14
|
"""
|
|
15
15
|
|
|
16
|
-
def __init__(self, size):
|
|
16
|
+
def __init__(self, size: float) -> None:
|
|
17
17
|
"""Initialize a new GridIndex.
|
|
18
18
|
|
|
19
19
|
Args:
|
|
20
20
|
size: the size of the grid cells
|
|
21
21
|
"""
|
|
22
22
|
self.size = size
|
|
23
|
-
self.grid = {}
|
|
24
|
-
self.items = []
|
|
23
|
+
self.grid: dict = {}
|
|
24
|
+
self.items: list = []
|
|
25
25
|
|
|
26
26
|
def insert(self, box: tuple[float, float, float, float], data: Any) -> None:
|
|
27
27
|
"""Insert a box into the index.
|
|
@@ -33,7 +33,7 @@ class GridIndex(SpatialIndex):
|
|
|
33
33
|
item_idx = len(self.items)
|
|
34
34
|
self.items.append(data)
|
|
35
35
|
|
|
36
|
-
def f(cell):
|
|
36
|
+
def f(cell: tuple[int, int]) -> None:
|
|
37
37
|
if cell not in self.grid:
|
|
38
38
|
self.grid[cell] = []
|
|
39
39
|
self.grid[cell].append(item_idx)
|
|
@@ -71,7 +71,7 @@ class GridIndex(SpatialIndex):
|
|
|
71
71
|
"""
|
|
72
72
|
matches = set()
|
|
73
73
|
|
|
74
|
-
def f(cell):
|
|
74
|
+
def f(cell: tuple[int, int]) -> None:
|
|
75
75
|
if cell not in self.grid:
|
|
76
76
|
return
|
|
77
77
|
for item_idx in self.grid[cell]:
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
"""Custom serialization for jsonargparse."""
|
|
2
|
+
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from typing import TYPE_CHECKING, Any
|
|
5
|
+
|
|
6
|
+
import jsonargparse
|
|
7
|
+
from rasterio.crs import CRS
|
|
8
|
+
from upath import UPath
|
|
9
|
+
|
|
10
|
+
from rslearn.config.dataset import LayerConfig
|
|
11
|
+
from rslearn.utils.geometry import ResolutionFactor
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from rslearn.data_sources.data_source import DataSourceContext
|
|
15
|
+
|
|
16
|
+
INITIALIZED = False
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def crs_serializer(v: CRS) -> str:
    """Encode a CRS as a string for jsonargparse.

    Args:
        v: the CRS object.

    Returns:
        the result of the CRS object's to_string() method.
    """
    encoded = v.to_string()
    return encoded
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def crs_deserializer(v: str) -> CRS:
    """Decode a CRS from its string form for jsonargparse.

    Args:
        v: the encoded CRS.

    Returns:
        the CRS object built by CRS.from_string.
    """
    decoded = CRS.from_string(v)
    return decoded
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def datetime_serializer(v: datetime) -> str:
    """Encode a datetime as an ISO 8601 string for jsonargparse.

    Args:
        v: the datetime object.

    Returns:
        the datetime in isoformat() form.
    """
    encoded = v.isoformat()
    return encoded
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def datetime_deserializer(v: str) -> datetime:
    """Decode a datetime from its ISO 8601 string for jsonargparse.

    Args:
        v: the encoded datetime.

    Returns:
        the datetime parsed by datetime.fromisoformat.
    """
    decoded = datetime.fromisoformat(v)
    return decoded
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def data_source_context_serializer(v: "DataSourceContext") -> dict[str, Any]:
    """Encode a DataSourceContext as a plain dict for jsonargparse.

    Args:
        v: the DataSourceContext object.

    Returns:
        a dict with "ds_path" (the path as a string, or None) and "layer_config"
        (the config dumped in JSON mode, or None).
    """
    ds_path = None
    if v.ds_path is not None:
        ds_path = str(v.ds_path)

    layer_config = None
    if v.layer_config is not None:
        layer_config = v.layer_config.model_dump(mode="json")

    return {"ds_path": ds_path, "layer_config": layer_config}
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def data_source_context_deserializer(v: dict[str, Any]) -> "DataSourceContext":
    """Rebuild a DataSourceContext from its dict encoding for jsonargparse.

    Args:
        v: a dict with "ds_path" and "layer_config" keys; either value may be None.

    Returns:
        the DataSourceContext, with ds_path wrapped in a UPath and layer_config
        validated into a LayerConfig when they are not None.
    """
    # Imported lazily to avoid a cyclic dependency.
    from rslearn.data_sources.data_source import DataSourceContext

    raw_path = v["ds_path"]
    ds_path = None if raw_path is None else UPath(raw_path)

    raw_config = v["layer_config"]
    layer_config = (
        None if raw_config is None else LayerConfig.model_validate(raw_config)
    )

    return DataSourceContext(ds_path=ds_path, layer_config=layer_config)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def resolution_factor_serializer(v: ResolutionFactor) -> str:
    """Encode a ResolutionFactor as "numerator/denominator" for jsonargparse.

    Args:
        v: the ResolutionFactor object, or an object exposing the values through
            an init_args attribute.

    Returns:
        the ResolutionFactor encoded to string
    """
    # When init_args is present, the values are read from it instead of from v.
    source = v.init_args if hasattr(v, "init_args") else v
    return f"{source.numerator}/{source.denominator}"
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def resolution_factor_deserializer(v: int | str | dict) -> ResolutionFactor:
    """Decode a ResolutionFactor for jsonargparse.

    Args:
        v: the encoded ResolutionFactor. Accepted forms are an existing
            ResolutionFactor, an object with an init_args attribute, a dict with an
            "init_args" key, an int, or a string of the form "x" or "x/y".

    Returns:
        the decoded ResolutionFactor object

    Raises:
        ValueError: if v is of an unsupported type or a string with more than one
            "/" separator.
    """
    # Already instantiated: nothing to decode.
    if isinstance(v, ResolutionFactor):
        return v

    # Namespace from class_path syntax (used during config save/validation).
    if hasattr(v, "init_args"):
        namespace_args = v.init_args
        return ResolutionFactor(
            numerator=namespace_args.numerator,
            denominator=namespace_args.denominator,
        )

    # Dict from class_path syntax in YAML config.
    if isinstance(v, dict) and "init_args" in v:
        dict_args = v["init_args"]
        return ResolutionFactor(
            numerator=dict_args.get("numerator", 1),
            denominator=dict_args.get("denominator", 1),
        )

    if isinstance(v, int):
        return ResolutionFactor(numerator=v)

    if not isinstance(v, str):
        raise ValueError("expected resolution factor to be str or int")

    pieces = v.split("/")
    if len(pieces) > 2:
        raise ValueError("expected resolution factor to be of the form x or 1/x")
    if len(pieces) == 1:
        return ResolutionFactor(numerator=int(pieces[0]))
    return ResolutionFactor(
        numerator=int(pieces[0]),
        denominator=int(pieces[1]),
    )
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def init_jsonargparse() -> None:
    """Register the custom jsonargparse serializers, at most once per process.

    Registers CRS, datetime, ResolutionFactor and DataSourceContext. Calls after
    the first successful one do nothing, guarded by the module-level INITIALIZED
    flag.
    """
    global INITIALIZED
    if not INITIALIZED:
        eager_registrations = (
            (CRS, crs_serializer, crs_deserializer),
            (datetime, datetime_serializer, datetime_deserializer),
            (
                ResolutionFactor,
                resolution_factor_serializer,
                resolution_factor_deserializer,
            ),
        )
        for type_class, serializer, deserializer in eager_registrations:
            jsonargparse.typing.register_type(type_class, serializer, deserializer)

        # Imported here rather than at module level (the module only imports it
        # under TYPE_CHECKING).
        from rslearn.data_sources.data_source import DataSourceContext

        jsonargparse.typing.register_type(
            DataSourceContext,
            data_source_context_serializer,
            data_source_context_deserializer,
        )

        INITIALIZED = True
|
rslearn/utils/mp.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
"""Multi-processing utilities."""
|
|
2
2
|
|
|
3
3
|
import multiprocessing.pool
|
|
4
|
-
from collections.abc import Callable
|
|
4
|
+
from collections.abc import Callable
|
|
5
|
+
from multiprocessing.pool import IMapIterator
|
|
5
6
|
from typing import Any
|
|
6
7
|
|
|
7
8
|
|
|
@@ -20,7 +21,7 @@ class StarImapUnorderedWrapper:
|
|
|
20
21
|
"""
|
|
21
22
|
self.fn = fn
|
|
22
23
|
|
|
23
|
-
def __call__(self, kwargs: dict[str, Any]):
|
|
24
|
+
def __call__(self, kwargs: dict[str, Any]) -> Any:
|
|
24
25
|
"""Wrapped call to the underlying function.
|
|
25
26
|
|
|
26
27
|
Args:
|
|
@@ -33,7 +34,7 @@ def star_imap_unordered(
|
|
|
33
34
|
p: multiprocessing.pool.Pool,
|
|
34
35
|
fn: Callable[..., Any],
|
|
35
36
|
kwargs_list: list[dict[str, Any]],
|
|
36
|
-
) ->
|
|
37
|
+
) -> IMapIterator:
|
|
37
38
|
"""Wrapper for Pool.imap_unordered that exposes kwargs to the function.
|
|
38
39
|
|
|
39
40
|
Args:
|