rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/utils/geometry.py CHANGED
@@ -1,18 +1,31 @@
1
1
  """Spatiotemporal geometry utilities."""
2
2
 
3
+ import functools
4
+ from collections.abc import Sequence
3
5
  from datetime import datetime, timedelta
4
6
  from typing import Any
5
7
 
6
8
  import numpy as np
9
+ import numpy.typing as npt
7
10
  import rasterio.warp
8
11
  import shapely
9
12
  import shapely.wkt
10
13
  from rasterio.crs import CRS
11
14
 
12
- RESOLUTION_EPSILON = 1e-6
13
-
15
+ from rslearn.log_utils import get_logger
14
16
 
17
+ logger = get_logger(__name__)
15
18
  PixelBounds = tuple[int, int, int, int]
19
+ FloatBounds = tuple[float, float, float, float]
20
+
21
+ RESOLUTION_EPSILON = 1e-6
22
+ WGS84_EPSG = 4326
23
+ WGS84_BOUNDS: PixelBounds = (-180, -90, 180, 90)
24
+
25
+ # Threshold in degrees above which a geometry is probably not going to re-project
26
+ # correctly due to projections with limited validity and other issues.
27
+ # 6 degrees corresponds to the UTM zone interval.
28
+ MAX_GEOMETRY_DEGREES = 6
16
29
 
17
30
 
18
31
  def is_same_resolution(res1: float, res2: float) -> bool:
@@ -20,7 +33,7 @@ def is_same_resolution(res1: float, res2: float) -> bool:
20
33
  return (max(res1, res2) / min(res1, res2) - 1) < RESOLUTION_EPSILON
21
34
 
22
35
 
23
- def shp_intersects(shp1: shapely.Geometry, shp2: shapely.Geometry):
36
+ def shp_intersects(shp1: shapely.Geometry, shp2: shapely.Geometry) -> bool:
24
37
  """Returns whether the two shapes intersect.
25
38
 
26
39
  Tries shp.intersects but falls back to shp.intersection which can be more
@@ -98,6 +111,84 @@ class Projection:
98
111
  )
99
112
 
100
113
 
114
# The Projection for WGS-84 assuming 1 degree per pixel.
# This can be used to create STGeometry with shapes in longitude/latitude coordinates.
WGS84_PROJECTION = Projection(CRS.from_epsg(WGS84_EPSG), 1, 1)


class ResolutionFactor:
    """Multiplier for the resolution in a Projection.

    The multiplier is either an integer x, or the inverse of an integer (1/x).

    Factors greater than 1 decrease the projection units per pixel, making the
    resolution finer (more pixels per projection unit). Factors less than 1 make it
    coarser (fewer pixels per projection unit).
    """

    def __init__(self, numerator: int = 1, denominator: int = 1):
        """Create a new ResolutionFactor.

        Args:
            numerator: the numerator of the fraction.
            denominator: the denominator of the fraction. If set, numerator must be 1.

        Raises:
            ValueError: if either value is not an integer >= 1, or if both differ
                from 1.
        """
        # Validate types and ranges first so the error describes the real problem
        # (e.g. a float numerator is reported as such, not as "one of ... must be 1").
        if not isinstance(numerator, int) or not isinstance(denominator, int):
            raise ValueError("numerator and denominator must be integers")
        if numerator < 1 or denominator < 1:
            raise ValueError("numerator and denominator must be >= 1")
        if numerator != 1 and denominator != 1:
            raise ValueError("one of numerator or denominator must be 1")
        self.numerator = numerator
        self.denominator = denominator

    def multiply_projection(self, projection: Projection) -> Projection:
        """Multiply the projection by this factor.

        The x/y resolutions (projection units per pixel) are multiplied by the
        denominator when coarsening, or divided by the numerator when refining.

        Args:
            projection: the projection to scale.

        Returns:
            a new Projection with the same CRS and scaled resolutions.
        """
        if self.denominator > 1:
            return Projection(
                projection.crs,
                projection.x_resolution * self.denominator,
                projection.y_resolution * self.denominator,
            )
        # Use true division: resolutions are floats, and floor division would
        # truncate fractional results (10 // 4 == 2, 0.5 // 2 == 0.0) and round
        # negative values away from zero.
        return Projection(
            projection.crs,
            projection.x_resolution / self.numerator,
            projection.y_resolution / self.numerator,
        )

    def multiply_bounds(self, bounds: PixelBounds) -> PixelBounds:
        """Multiply the bounds by this factor.

        When coarsening, the width and height of the given bounds must be a multiple of
        the denominator.

        Args:
            bounds: the pixel bounds to scale.

        Returns:
            the scaled pixel bounds.

        Raises:
            ValueError: when coarsening and the width or height is not a multiple of
                the denominator.
        """
        if self.denominator > 1:
            # Verify the width and height are multiples of the denominator.
            # Otherwise the new width and height is not an integer.
            width = bounds[2] - bounds[0]
            height = bounds[3] - bounds[1]
            if width % self.denominator != 0 or height % self.denominator != 0:
                raise ValueError(
                    f"width {width} or height {height} is not a multiple of the resolution factor {self.denominator}"
                )
            # TODO: an offset could be introduced by bounds not being a multiple
            # of the denominator -> will need to decide how to handle that.
            return (
                bounds[0] // self.denominator,
                bounds[1] // self.denominator,
                bounds[2] // self.denominator,
                bounds[3] // self.denominator,
            )
        return (
            bounds[0] * self.numerator,
            bounds[1] * self.numerator,
            bounds[2] * self.numerator,
            bounds[3] * self.numerator,
        )
190
+
191
+
101
192
  class STGeometry:
102
193
  """A spatiotemporal geometry.
103
194
 
@@ -166,7 +257,7 @@ class STGeometry:
166
257
 
167
258
  def intersects_time_range(
168
259
  self, time_range: tuple[datetime, datetime] | None
169
- ) -> timedelta:
260
+ ) -> bool:
170
261
  """Returns whether this geometry intersects the other time range."""
171
262
  if self.time_range is None or time_range is None:
172
263
  return True
@@ -176,6 +267,30 @@ class STGeometry:
176
267
  return False
177
268
  return True
178
269
 
270
def is_global(self) -> bool:
    """Check whether this geometry is the special "global coverage" marker.

    The marker is the geometry produced by get_global_geometry: it uses
    WGS84_PROJECTION and its shape is the box spanning WGS84_BOUNDS
    (-180, -90, 180, 90).
    """
    # The projection is compared first; the box comparison only matters for WGS84.
    matches_marker = self.projection == WGS84_PROJECTION and self.shp == shapely.box(
        *WGS84_BOUNDS
    )
    return bool(matches_marker)
281
+
282
def is_too_large(self) -> bool:
    """Check whether this geometry covers too large an area to re-project reliably.

    The geometry is converted to WGS84 and its longitude/latitude extents are
    compared against MAX_GEOMETRY_DEGREES.

    Returns:
        True if either extent exceeds MAX_GEOMETRY_DEGREES, False otherwise.
    """
    min_lon, min_lat, max_lon, max_lat = self.to_projection(WGS84_PROJECTION).shp.bounds
    lon_span = max_lon - min_lon
    lat_span = max_lat - min_lat
    return lon_span > MAX_GEOMETRY_DEGREES or lat_span > MAX_GEOMETRY_DEGREES
293
+
179
294
  def intersects(self, other: "STGeometry") -> bool:
180
295
  """Returns whether this box intersects the other box."""
181
296
  # Check temporal.
@@ -183,6 +298,9 @@ class STGeometry:
183
298
  return False
184
299
 
185
300
  # Check spatial.
301
+ if self.is_global() or other.is_global():
302
+ # One of the geometries indicates global coverage.
303
+ return True
186
304
  # Need to reproject if projections don't match.
187
305
  if other.projection != self.projection:
188
306
  other = other.to_projection(self.projection)
@@ -194,7 +312,12 @@ class STGeometry:
194
312
  def to_projection(self, projection: Projection) -> "STGeometry":
195
313
  """Transforms this geometry to the specified projection."""
196
314
 
197
- def apply_resolution(array, x_resolution, y_resolution, forward=True):
315
+ def apply_resolution(
316
+ array: np.ndarray,
317
+ x_resolution: float,
318
+ y_resolution: float,
319
+ forward: bool = True,
320
+ ) -> np.ndarray:
198
321
  if forward:
199
322
  return np.stack(
200
323
  [array[:, 0] / x_resolution, array[:, 1] / y_resolution], axis=1
@@ -215,8 +338,12 @@ class STGeometry:
215
338
  ),
216
339
  )
217
340
  # Change crs.
218
- shp = rasterio.warp.transform_geom(self.projection.crs, projection.crs, shp)
219
- shp = shapely.geometry.shape(shp)
341
+ # We only apply transform_geom if the CRS doesn't match, because even if we
342
+ # call transform_geom with the same source and destination CRS, it takes
343
+ # several milliseconds.
344
+ if self.projection.crs != projection.crs:
345
+ shp = rasterio.warp.transform_geom(self.projection.crs, projection.crs, shp)
346
+ shp = shapely.geometry.shape(shp)
220
347
  # Apply new resolution.
221
348
  shp = shapely.transform(
222
349
  shp,
@@ -224,6 +351,7 @@ class STGeometry:
224
351
  array, projection.x_resolution, projection.y_resolution, forward=True
225
352
  ),
226
353
  )
354
+
227
355
  return STGeometry(projection, shp, self.time_range)
228
356
 
229
357
  def __repr__(self) -> str:
@@ -260,3 +388,215 @@ class STGeometry:
260
388
  else None
261
389
  ),
262
390
  )
391
+
392
+
393
def get_global_geometry(time_range: tuple[datetime, datetime] | None) -> STGeometry:
    """Build the marker geometry that stands for global spatial coverage.

    STGeometry.is_global recognizes the returned geometry.

    Args:
        time_range: the time range to attach to the geometry (may be None).

    Returns:
        an STGeometry in WGS84_PROJECTION whose shape is the box over WGS84_BOUNDS.
    """
    whole_world = shapely.box(*WGS84_BOUNDS)
    return STGeometry(WGS84_PROJECTION, whole_world, time_range)
403
+
404
+
405
def flatten_shape(shp: shapely.Geometry) -> list[shapely.Geometry]:
    """Expand a shape into its primitive parts (Point, LineString, Polygon, ...).

    Multi-geometries and geometry collections are expanded recursively, depth first,
    keeping their parts in their original order. Any other shape is returned as a
    one-element list.

    Args:
        shp: the shape, either a primitive or a (possibly nested) collection.

    Returns:
        list of primitive shapes.
    """
    collection_types = (
        shapely.MultiPoint,
        shapely.MultiLineString,
        shapely.MultiPolygon,
        shapely.GeometryCollection,
    )
    primitives: list[shapely.Geometry] = []
    pending = [shp]
    while pending:
        current = pending.pop()
        if isinstance(current, collection_types):
            # Push children in reverse so they are popped in their original order.
            pending.extend(reversed(list(current.geoms)))
        else:
            primitives.append(current)
    return primitives
428
+
429
+
430
def _collect_shapes(shapes: list[shapely.Geometry]) -> shapely.Geometry:
    """Combine shapes into a single geometry of the most specific type.

    The shapes are flattened into primitives (see flatten_shape) and empty parts are
    dropped. A single remaining primitive is returned as-is. Otherwise the result is a
    MultiPoint, MultiLineString or MultiPolygon when every primitive has that type,
    and a GeometryCollection when the types are mixed.

    Args:
        shapes: the shapes to combine.

    Returns:
        the combined geometry.
    """
    flat_list: list[shapely.Geometry] = []
    for shp in shapes:
        flat_list.extend(flatten_shape(shp))

    # Drop empty parts. Clipping a shape at the antimeridian can leave one side
    # empty, and shapely refuses to build MultiPoint/MultiLineString objects that
    # contain empty components.
    non_empty = [shp for shp in flat_list if not shp.is_empty]
    if not non_empty:
        # Only empty shapes: keep a lone empty shape, else return an empty collection.
        return flat_list[0] if len(flat_list) == 1 else shapely.GeometryCollection()
    flat_list = non_empty

    if len(flat_list) == 1:
        return flat_list[0]

    if all(isinstance(shp, shapely.Point) for shp in flat_list):
        return shapely.MultiPoint(flat_list)

    if all(isinstance(shp, shapely.LineString) for shp in flat_list):
        return shapely.MultiLineString(flat_list)

    if all(isinstance(shp, shapely.Polygon) for shp in flat_list):
        return shapely.MultiPolygon(flat_list)

    return shapely.GeometryCollection(flat_list)
449
+
450
+
451
def split_shape_at_antimeridian(
    shp: shapely.Geometry, epsilon: float = 1e-6
) -> shapely.Geometry:
    """Split the given shape at the antimeridian.

    The shape must be in WGS84 coordinates.

    See split_at_antimeridian for details.

    Args:
        shp: the shape to split.
        epsilon: the padding in degrees. Applied to every component of a
            multi-geometry or collection as well.

    Returns:
        the split shape, in WGS84 projection.

    Raises:
        TypeError: if the shape is of an unsupported type.
    """
    # Empty shapes have NaN bounds and no coordinates, so there is nothing to split
    # or pad; return them unchanged.
    if shp.is_empty:
        return shp

    # We assume the shape is fine if:
    # 1. It doesn't need padding (no coordinates close to +/- 180).
    # 2. And all coordinates are either less than 90 or more than -90 (meaning the
    #    shape approaches the antimeridian on at most one side).
    min_lon, _, max_lon, _ = shp.bounds
    if min_lon > -180 + epsilon and max_lon < 90:
        return shp
    if min_lon > -90 and max_lon < 180 - epsilon:
        return shp

    if isinstance(
        shp,
        shapely.MultiPoint
        | shapely.MultiLineString
        | shapely.MultiPolygon
        | shapely.GeometryCollection,
    ):
        # Split each component separately, propagating the caller's epsilon.
        return _collect_shapes(
            [
                split_shape_at_antimeridian(component, epsilon=epsilon)
                for component in shp.geoms
            ]
        )

    if isinstance(shp, shapely.Point):
        # Points only need padding.
        lon = shp.x
        if lon < -180 + epsilon:
            lon = -180 + epsilon
        if lon > 180 - epsilon:
            lon = 180 - epsilon
        return shapely.Point(lon, shp.y)

    if isinstance(shp, shapely.LineString | shapely.Polygon):
        # We add 360 to the negative coordinates and then separate the parts above and
        # below 180.
        def add360(array: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
            # shapely.transform passes an (N, 2) coordinate array.
            new_array = array.copy()
            new_array[new_array[:, 0] < 0, 0] += 360
            return new_array

        shp = shapely.transform(shp, add360)

        # The gap of 2 * epsilon around 180 produces the padding.
        positive_part = shapely.box(0, -90, 180 - epsilon, 90)
        negative_part = shapely.box(180 + epsilon, -90, 360, 90)
        positive_shp = shp.intersection(positive_part)
        negative_shp = shp.intersection(negative_part)
        negative_shp = shapely.transform(negative_shp, lambda coords: coords - [360, 0])
        return _collect_shapes([positive_shp, negative_shp])

    raise TypeError("Unsupported shape type")
515
+
516
+
517
def split_at_antimeridian(geometry: STGeometry, epsilon: float = 1e-6) -> STGeometry:
    """Split lines and polygons in the given geometry at the antimeridian.

    The returned geometry will always be in WGS84 projection.

    Small padding is also introduced to ensure coordinates are a bit more than -180 or
    a bit less than 180.

    For example, if the input is a polygon:

        Polygon([[-180, 10], [180, 11], [-179, 11], [-179, 10]])

    Then it would be converted to:

        Polygon([[-179.999999, 10], [-179.999999, 11], [-179, 11], [-179, 10]])

    This function may produce unexpected results if the geometries span more than 90
    degrees on either dimension.

    Args:
        geometry: the geometry to split.
        epsilon: the padding in degrees. It is equivalent to about 1 m at the equator.
            We ensure no longitude coordinates are within this padding of +/- 180.

    Returns:
        the padded geometry, in WGS84 projection, with the input's time range.
    """
    # Convert to WGS84.
    geometry = geometry.to_projection(WGS84_PROJECTION)
    new_shp = split_shape_at_antimeridian(geometry.shp, epsilon=epsilon)
    return STGeometry(geometry.projection, new_shp, geometry.time_range)
548
+
549
+
550
def safely_reproject_and_clip(
    src_geoms: Sequence[STGeometry], dst_geom: STGeometry
) -> Sequence[STGeometry | None]:
    """Re-project src_geoms into the projection of dst_geom.

    The resulting geometries will be clipped to dst_geom. If there is no intersection
    for an src_geom, then the result will be None. The list of results is returned,
    in the same order as src_geoms, and each result keeps its src_geom's time range.

    Only spatial intersection is considered. Time ranges are never compared, whether
    or not the source and destination CRS match.

    This function addresses issues with direct re-projection (e.g. using
    src_geom.to_projection(dst_geom.projection)), which may fail if the source geometry
    is outside the area of use of the destination projection.

    When the CRS differs, it first checks for intersection in WGS84 and only proceeds
    with re-projection if the geometries intersect there.

    Geometries with global coverage (see get_global_geometry) are handled specially:
    - a global src_geom clips to exactly dst_geom's shape;
    - a global dst_geom skips the WGS84 pre-check.

    This function may produce unexpected results if the geometries span more than 90
    degrees on either dimension.

    Args:
        src_geoms: the geometries to re-project and clip.
        dst_geom: the geometry whose projection and shape define the output.

    Returns:
        one entry per src_geom: the clipped geometry in dst_geom's projection, or None
        if it does not intersect dst_geom.
    """

    # We cache re-projecting the destination geometry to WGS84 since the re-projection
    # can be costly. This also avoids re-projecting in case all the src_geoms are
    # already in the same projection as dst_geom.
    @functools.cache
    def get_dst_shp_wgs84() -> shapely.Geometry:
        """Lazily compute and cache dst_geom's shape in WGS84, split at the antimeridian."""
        return split_at_antimeridian(dst_geom).shp

    def intersects_in_wgs84(src_geom: STGeometry) -> bool:
        """Return whether src_geom spatially intersects dst_geom in WGS84."""
        src_shp_wgs84 = split_at_antimeridian(src_geom).shp
        return shp_intersects(src_shp_wgs84, get_dst_shp_wgs84())

    dst_is_global = dst_geom.is_global()
    results: list[STGeometry | None] = []
    for src_geom in src_geoms:
        if src_geom.is_global():
            # Global coverage contains everything, so clipping it to dst_geom yields
            # dst_geom's shape. This also avoids both the WGS84 pre-check and
            # re-projecting the whole-world box. split_at_antimeridian collapses that
            # box, since all its corners map to longitude 180, and it far exceeds
            # MAX_GEOMETRY_DEGREES.
            results.append(
                STGeometry(dst_geom.projection, dst_geom.shp, src_geom.time_range)
            )
            continue

        # Only do the extra check in WGS84 if the projections don't already match.
        # A global dst_geom covers everything, so the check is skipped for it too.
        if (
            src_geom.projection.crs != dst_geom.projection.crs
            and not dst_is_global
            and not intersects_in_wgs84(src_geom)
        ):
            results.append(None)
            continue

        src_geom_in_dst_projection = src_geom.to_projection(dst_geom.projection)
        if not shp_intersects(src_geom_in_dst_projection.shp, dst_geom.shp):
            results.append(None)
            continue
        intersect_shp = src_geom_in_dst_projection.shp.intersection(dst_geom.shp)
        results.append(
            STGeometry(dst_geom.projection, intersect_shp, src_geom.time_range)
        )

    return results
rslearn/utils/get_utm_ups_crs.py CHANGED
@@ -5,7 +5,7 @@ import pyproj.database
5
5
  import shapely
6
6
  from rasterio.crs import CRS
7
7
 
8
- from rslearn.utils import Projection, STGeometry
8
+ from rslearn.utils.geometry import WGS84_PROJECTION, Projection, STGeometry
9
9
 
10
10
  UPS_NORTH_EPSG = 5041
11
11
  """EPSG code for the UPS North CRS."""
@@ -121,8 +121,7 @@ def get_proj_bounds(utm_crs: CRS) -> tuple[float, float, float, float]:
121
121
  """
122
122
  bounds = get_wgs84_bounds(utm_crs)
123
123
  # Convert from WGS84 to the UTM zone.
124
- src_proj = Projection(CRS.from_epsg(4326), 1, 1)
125
124
  dst_proj = Projection(utm_crs, 1, 1)
126
125
  shp = shapely.box(*bounds)
127
- result = STGeometry(src_proj, shp, None).to_projection(dst_proj).shp
126
+ result = STGeometry(WGS84_PROJECTION, shp, None).to_projection(dst_proj).shp
128
127
  return result.bounds
rslearn/utils/grid_index.py CHANGED
@@ -13,15 +13,15 @@ class GridIndex(SpatialIndex):
13
13
  Each cell in the grid contains a list of geometries that intersect it.
14
14
  """
15
15
 
16
- def __init__(self, size):
16
+ def __init__(self, size: float) -> None:
17
17
  """Initialize a new GridIndex.
18
18
 
19
19
  Args:
20
20
  size: the size of the grid cells
21
21
  """
22
22
  self.size = size
23
- self.grid = {}
24
- self.items = []
23
+ self.grid: dict = {}
24
+ self.items: list = []
25
25
 
26
26
  def insert(self, box: tuple[float, float, float, float], data: Any) -> None:
27
27
  """Insert a box into the index.
@@ -33,7 +33,7 @@ class GridIndex(SpatialIndex):
33
33
  item_idx = len(self.items)
34
34
  self.items.append(data)
35
35
 
36
- def f(cell):
36
+ def f(cell: tuple[int, int]) -> None:
37
37
  if cell not in self.grid:
38
38
  self.grid[cell] = []
39
39
  self.grid[cell].append(item_idx)
@@ -71,7 +71,7 @@ class GridIndex(SpatialIndex):
71
71
  """
72
72
  matches = set()
73
73
 
74
- def f(cell):
74
+ def f(cell: tuple[int, int]) -> None:
75
75
  if cell not in self.grid:
76
76
  return
77
77
  for item_idx in self.grid[cell]:
rslearn/utils/jsonargparse.py ADDED
@@ -0,0 +1,178 @@
1
+ """Custom serialization for jsonargparse."""
2
+
3
+ from datetime import datetime
4
+ from typing import TYPE_CHECKING, Any
5
+
6
+ import jsonargparse
7
+ from rasterio.crs import CRS
8
+ from upath import UPath
9
+
10
+ from rslearn.config.dataset import LayerConfig
11
+ from rslearn.utils.geometry import ResolutionFactor
12
+
13
+ if TYPE_CHECKING:
14
+ from rslearn.data_sources.data_source import DataSourceContext
15
+
16
# Set to True by init_jsonargparse once the custom types have been registered, so
# repeated calls skip the registration.
INITIALIZED: bool = False
17
+
18
+
19
def crs_serializer(v: CRS) -> str:
    """Encode a CRS as a string for jsonargparse.

    Args:
        v: the CRS to encode.

    Returns:
        the string produced by CRS.to_string().
    """
    encoded = v.to_string()
    return encoded
29
+
30
+
31
def crs_deserializer(v: str) -> CRS:
    """Decode a CRS from its string form for jsonargparse.

    Args:
        v: the encoded CRS, as accepted by CRS.from_string.

    Returns:
        the decoded CRS object.
    """
    decoded = CRS.from_string(v)
    return decoded
41
+
42
+
43
def datetime_serializer(v: datetime) -> str:
    """Encode a datetime as an ISO 8601 string for jsonargparse.

    Args:
        v: the datetime to encode.

    Returns:
        the result of v.isoformat().
    """
    encoded = v.isoformat()
    return encoded
53
+
54
+
55
def datetime_deserializer(v: str) -> datetime:
    """Decode an ISO 8601 string into a datetime for jsonargparse.

    Args:
        v: the encoded datetime, as accepted by datetime.fromisoformat.

    Returns:
        the decoded datetime object.
    """
    parsed = datetime.fromisoformat(v)
    return parsed
65
+
66
+
67
def data_source_context_serializer(v: "DataSourceContext") -> dict[str, Any]:
    """Encode a DataSourceContext as a JSON-compatible dict for jsonargparse.

    Args:
        v: the context to encode.

    Returns:
        a dict with "ds_path" (string or None) and "layer_config" (the layer config
        dumped with model_dump(mode="json"), or None).
    """
    ds_path = None if v.ds_path is None else str(v.ds_path)
    layer_config = None
    if v.layer_config is not None:
        layer_config = v.layer_config.model_dump(mode="json")
    return {"ds_path": ds_path, "layer_config": layer_config}
78
+
79
+
80
def data_source_context_deserializer(v: dict[str, Any]) -> "DataSourceContext":
    """Decode a DataSourceContext from the dict form used by jsonargparse.

    Args:
        v: a dict with "ds_path" and "layer_config" entries, either of which may be
            None.

    Returns:
        the reconstructed DataSourceContext.
    """
    # Imported here rather than at module level to avoid a cyclic dependency.
    from rslearn.data_sources.data_source import DataSourceContext

    raw_path = v["ds_path"]
    ds_path = None if raw_path is None else UPath(raw_path)
    raw_layer_config = v["layer_config"]
    layer_config = (
        None
        if raw_layer_config is None
        else LayerConfig.model_validate(raw_layer_config)
    )
    return DataSourceContext(ds_path=ds_path, layer_config=layer_config)
93
+
94
+
95
def resolution_factor_serializer(v: ResolutionFactor) -> str:
    """Encode a ResolutionFactor as "numerator/denominator" for jsonargparse.

    Args:
        v: the ResolutionFactor, or an object whose init_args attribute carries
            numerator/denominator (the class_path form jsonargparse may pass in).

    Returns:
        the string "<numerator>/<denominator>".
    """
    source = v.init_args if hasattr(v, "init_args") else v
    return f"{source.numerator}/{source.denominator}"
109
+
110
+
111
def resolution_factor_deserializer(v: int | str | dict) -> ResolutionFactor:
    """Deserialize ResolutionFactor for jsonargparse.

    Accepted encodings:

    - a ResolutionFactor, returned unchanged;
    - an object with an ``init_args`` attribute (the Namespace jsonargparse builds
      from class_path syntax during config save/validation);
    - a dict in class_path syntax, e.g. ``{"class_path": ..., "init_args": {...}}``;
      ``init_args`` may be omitted or null to use the defaults;
    - a plain dict with optional ``numerator`` and ``denominator`` keys;
    - an int ``x``, meaning a factor of x;
    - a string ``"x"`` or ``"x/y"``.

    Missing numerator/denominator values in the dict forms default to 1.

    Args:
        v: the encoded ResolutionFactor.

    Returns:
        the decoded ResolutionFactor object

    Raises:
        ValueError: if v is not one of the accepted encodings, or if ResolutionFactor
            rejects the decoded values.
    """
    # Handle already-instantiated ResolutionFactor
    if isinstance(v, ResolutionFactor):
        return v

    # Handle Namespace from class_path syntax (used during config save/validation)
    if hasattr(v, "init_args"):
        init_args = v.init_args
        return ResolutionFactor(
            numerator=init_args.numerator,
            denominator=init_args.denominator,
        )

    if isinstance(v, dict):
        if "init_args" in v or "class_path" in v:
            # class_path syntax in YAML config; no/null init_args means defaults.
            init_args = v.get("init_args") or {}
        elif set(v) <= {"numerator", "denominator"}:
            # Plain mapping of the constructor arguments.
            init_args = v
        else:
            raise ValueError(
                f"unexpected keys in resolution factor dict: {list(v)}"
            )
        return ResolutionFactor(
            numerator=init_args.get("numerator", 1),
            denominator=init_args.get("denominator", 1),
        )

    if isinstance(v, int):
        return ResolutionFactor(numerator=v)
    elif isinstance(v, str):
        parts = v.split("/")
        if len(parts) == 1:
            return ResolutionFactor(numerator=int(parts[0]))
        elif len(parts) == 2:
            return ResolutionFactor(
                numerator=int(parts[0]),
                denominator=int(parts[1]),
            )
        else:
            raise ValueError("expected resolution factor to be of the form x or 1/x")
    else:
        raise ValueError("expected resolution factor to be str, int, or dict")
155
+
156
+
157
def init_jsonargparse() -> None:
    """Register rslearn's custom jsonargparse type serializers.

    Registration runs only on the first call. Later calls return immediately once
    the module-level INITIALIZED flag is set.
    """
    global INITIALIZED
    if INITIALIZED:
        return

    eager_registrations = (
        (CRS, crs_serializer, crs_deserializer),
        (datetime, datetime_serializer, datetime_deserializer),
        (
            ResolutionFactor,
            resolution_factor_serializer,
            resolution_factor_deserializer,
        ),
    )
    for registered_type, serializer, deserializer in eager_registrations:
        jsonargparse.typing.register_type(registered_type, serializer, deserializer)

    # Imported lazily to avoid a cyclic dependency with the data sources package.
    from rslearn.data_sources.data_source import DataSourceContext

    jsonargparse.typing.register_type(
        DataSourceContext,
        data_source_context_serializer,
        data_source_context_deserializer,
    )

    INITIALIZED = True
rslearn/utils/mp.py CHANGED
@@ -1,7 +1,8 @@
1
1
  """Multi-processing utilities."""
2
2
 
3
3
  import multiprocessing.pool
4
- from collections.abc import Callable, Generator
4
+ from collections.abc import Callable
5
+ from multiprocessing.pool import IMapIterator
5
6
  from typing import Any
6
7
 
7
8
 
@@ -20,7 +21,7 @@ class StarImapUnorderedWrapper:
20
21
  """
21
22
  self.fn = fn
22
23
 
23
- def __call__(self, kwargs: dict[str, Any]):
24
+ def __call__(self, kwargs: dict[str, Any]) -> Any:
24
25
  """Wrapped call to the underlying function.
25
26
 
26
27
  Args:
@@ -33,7 +34,7 @@ def star_imap_unordered(
33
34
  p: multiprocessing.pool.Pool,
34
35
  fn: Callable[..., Any],
35
36
  kwargs_list: list[dict[str, Any]],
36
- ) -> Generator[Any, None, None]:
37
+ ) -> IMapIterator:
37
38
  """Wrapper for Pool.imap_unordered that exposes kwargs to the function.
38
39
 
39
40
  Args: