rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,6 @@
1
1
  """Abstract RasterFormat class."""
2
2
 
3
+ import hashlib
3
4
  import json
4
5
  from typing import Any, BinaryIO
5
6
 
@@ -7,16 +8,134 @@ import affine
7
8
  import numpy as np
8
9
  import numpy.typing as npt
9
10
  import rasterio
10
- from class_registry import ClassRegistry
11
11
  from PIL import Image
12
+ from rasterio.crs import CRS
13
+ from rasterio.enums import Resampling
12
14
  from upath import UPath
13
15
 
14
- from rslearn.config import RasterFormatConfig
15
16
  from rslearn.const import TILE_SIZE
17
+ from rslearn.log_utils import get_logger
18
+ from rslearn.utils.fsspec import open_rasterio_upath_reader, open_rasterio_upath_writer
16
19
 
17
20
  from .geometry import PixelBounds, Projection
18
21
 
19
- RasterFormats = ClassRegistry()
22
+ logger = get_logger(__name__)
23
+
24
+
25
def get_bandset_dirname(bands: list[str]) -> str:
    """Get the directory name that should be used to store the given group of bands.

    A human-readable name (the band names joined by underscores) is used when
    possible; otherwise a SHA-256 hex digest is used instead.

    Args:
        bands: the band names in the group. Order is significant, since both the
            joined name and the hashed representation preserve it.

    Returns:
        the underscore-joined band names, or a SHA-256 hex digest if any band name
        contains an underscore or the joined name is longer than 64 characters.
    """
    # Underscore is the delimiter of the human-readable name, so if a band name
    # contains one the joined form isn't straightforward. In this case we hash the
    # JSON representation of the bands.
    if any("_" in band for band in bands):
        return hashlib.sha256(json.dumps(bands).encode()).hexdigest()

    dirname = "_".join(bands)
    if len(dirname) > 64:
        # Previously the bands were simply joined, but this can result in a
        # directory name that is too long. This code path hashes the joined name
        # rather than the JSON, for historical reasons (to maintain backwards
        # compatibility).
        dirname = hashlib.sha256(dirname.encode()).hexdigest()
    return dirname
41
+
42
+
43
def get_raster_projection_and_bounds_from_transform(
    crs: CRS, transform: affine.Affine, width: int, height: int
) -> tuple[Projection, PixelBounds]:
    """Compute the Projection and pixel bounds implied by a CRS and transform.

    Args:
        crs: the coordinate reference system.
        transform: corresponding affine transform matrix.
        width: the array width
        height: the array height

    Returns:
        a tuple (projection, bounds).
    """
    # The a and e terms of the transform are used as the x and y resolutions.
    x_res, y_res = transform.a, transform.e
    projection = Projection(crs, x_res, y_res)

    # The c and f terms divided by the resolutions give the pixel coordinates of
    # the top-left corner, rounded to the nearest integer.
    left = int(round(transform.c / x_res))
    top = int(round(transform.f / y_res))

    return (projection, (left, top, left + width, top + height))
66
+
67
+
68
def get_raster_projection_and_bounds(
    raster: rasterio.DatasetReader,
) -> tuple[Projection, PixelBounds]:
    """Compute the Projection and pixel bounds of an opened raster.

    This reads the CRS, transform, width, and height from the dataset and passes
    them to get_raster_projection_and_bounds_from_transform.

    Args:
        raster: the raster dataset opened with rasterio.

    Returns:
        a tuple (projection, bounds).
    """
    crs = raster.crs
    transform = raster.transform
    width = raster.width
    height = raster.height
    return get_raster_projection_and_bounds_from_transform(
        crs, transform, width, height
    )
82
+
83
+
84
def get_transform_from_projection_and_bounds(
    projection: Projection, bounds: PixelBounds
) -> affine.Affine:
    """Build the affine transform matching the given projection and bounds.

    Args:
        projection: the projection. Only the resolutions are used.
        bounds: the bounding box. Only the top-left corner is used.

    Returns:
        an affine.Affine whose scale terms are the projection's x and y
        resolutions, whose two off-diagonal terms are zero, and whose offset terms
        are the top-left corner of bounds multiplied by the respective resolution.
    """
    x_res = projection.x_resolution
    y_res = projection.y_resolution
    # Offsets are the top-left pixel coordinates scaled by the resolution.
    x_offset = bounds[0] * x_res
    y_offset = bounds[1] * y_res
    return affine.Affine(x_res, 0, x_offset, 0, y_res, y_offset)
101
+
102
+
103
def adjust_projection_and_bounds_for_array(
    projection: Projection, bounds: PixelBounds, array: npt.NDArray
) -> tuple[Projection, PixelBounds]:
    """Rescale a projection and bounds so their size matches that of an array.

    The result covers the same spatial extent as the inputs, but its bounds have
    the array's width and height and its resolution is scaled accordingly.

    Args:
        projection: the original projection.
        bounds: the original bounds.
        array: the CHW array for which to compute an updated projection and bounds.
            The returned bounds will have the same width and height as this array.

    Returns:
        a tuple of adjusted (projection, bounds). If the array already has the same
        width and height as bounds, the inputs are returned unchanged.
    """
    array_height, array_width = array.shape[1], array.shape[2]
    bounds_width = bounds[2] - bounds[0]
    bounds_height = bounds[3] - bounds[1]

    # Nothing to adjust when the array size already agrees with the bounds.
    if (array_width, array_height) == (bounds_width, bounds_height):
        return (projection, bounds)

    # Ratio of array pixels to original pixels along each axis.
    x_factor = array_width / bounds_width
    y_factor = array_height / bounds_height

    adjusted_projection = Projection(
        projection.crs,
        projection.x_resolution / x_factor,
        projection.y_resolution / y_factor,
    )

    # Scale the top-left corner, then extend it by the array dimensions.
    left = round(bounds[0] * x_factor)
    top = round(bounds[1] * y_factor)
    adjusted_bounds = (left, top, left + array_width, top + array_height)
    return (adjusted_projection, adjusted_bounds)
20
139
 
21
140
 
22
141
  class RasterFormat:
@@ -44,21 +163,39 @@ class RasterFormat:
44
163
  raise NotImplementedError
45
164
 
46
165
  def decode_raster(
47
- self, path: UPath, bounds: PixelBounds
48
- ) -> npt.NDArray[Any] | None:
166
+ self,
167
+ path: UPath,
168
+ projection: Projection,
169
+ bounds: PixelBounds,
170
+ resampling: Resampling = Resampling.bilinear,
171
+ ) -> npt.NDArray[Any]:
49
172
  """Decodes raster data.
50
173
 
51
174
  Args:
52
175
  path: the directory to read from
53
- bounds: the bounds of the raster to read
176
+ projection: the projection to read the raster in.
177
+ bounds: the bounds to read in the given projection.
178
+ resampling: resampling method to use in case resampling is needed.
179
+
180
+ Returns:
181
+ the raster data
182
+ """
183
+ raise NotImplementedError
184
+
185
+ @staticmethod
186
+ def from_config(name: str, config: dict[str, Any]) -> "RasterFormat":
187
+ """Create a RasterFormat from a config dict.
188
+
189
+ Args:
190
+ name: the name of this format
191
+ config: the config dict
54
192
 
55
193
  Returns:
56
- the raster data, or None if no image content is found
194
+ the RasterFormat instance
57
195
  """
58
196
  raise NotImplementedError
59
197
 
60
198
 
61
- @RasterFormats.register("image_tile")
62
199
  class ImageTileRasterFormat(RasterFormat):
63
200
  """A RasterFormat that stores data in image tiles corresponding to grid cells.
64
201
 
@@ -152,6 +289,19 @@ class ImageTileRasterFormat(RasterFormat):
152
289
  bounds: the bounds of the raster data in the projection
153
290
  array: the raster data (must be CHW)
154
291
  """
292
+ # Write metadata about the projection that we are writing under.
293
+ # We also save dtype and number of bands so we can return correct shape when
294
+ # there are no intersecting tiles.
295
+ with (path / "metadata.json").open("w") as f:
296
+ json.dump(
297
+ {
298
+ "projection": projection.serialize(),
299
+ "dtype": array.dtype.name,
300
+ "num_bands": array.shape[0],
301
+ },
302
+ f,
303
+ )
304
+
155
305
  start_tile = (bounds[0] // self.tile_size, bounds[1] // self.tile_size)
156
306
  end_tile = (bounds[2] // self.tile_size + 1, bounds[3] // self.tile_size + 1)
157
307
  extension = self.get_extension()
@@ -190,17 +340,34 @@ class ImageTileRasterFormat(RasterFormat):
190
340
  self.encode_tile(f, projection, cur_bounds, cur_array)
191
341
 
192
342
  def decode_raster(
193
- self, path: UPath, bounds: PixelBounds
194
- ) -> npt.NDArray[Any] | None:
343
+ self,
344
+ path: UPath,
345
+ projection: Projection,
346
+ bounds: PixelBounds,
347
+ resampling: Resampling = Resampling.bilinear,
348
+ ) -> npt.NDArray[Any]:
195
349
  """Decodes raster data.
196
350
 
197
351
  Args:
198
352
  path: the directory to read from
199
- bounds: the bounds of the raster to read
353
+ projection: the projection to read the raster in.
354
+ bounds: the bounds to read in the given projection.
355
+ resampling: resampling method to use in case resampling is needed.
200
356
 
201
357
  Returns:
202
- the raster data, or None if no image content is found
358
+ the raster data
203
359
  """
360
+ # Verify that the source data has the same projection as the requested one.
361
+ # ImageTileRasterFormat currently does not support re-projecting.
362
+ with (path / "metadata.json").open() as f:
363
+ image_metadata = json.load(f)
364
+ source_data_projection = Projection.deserialize(image_metadata["projection"])
365
+ if source_data_projection != projection:
366
+ raise NotImplementedError(
367
+ "not implemented to re-project source data "
368
+ + f"(source projection {source_data_projection} does not match requested projection {projection})"
369
+ )
370
+
204
371
  extension = self.get_extension()
205
372
 
206
373
  # Load tiles one at a time.
@@ -209,7 +376,12 @@ class ImageTileRasterFormat(RasterFormat):
209
376
  (bounds[2] - 1) // self.tile_size + 1,
210
377
  (bounds[3] - 1) // self.tile_size + 1,
211
378
  )
212
- dst = None
379
+ dst_shape = (
380
+ image_metadata["num_bands"],
381
+ bounds[3] - bounds[1],
382
+ bounds[2] - bounds[0],
383
+ )
384
+ dst = np.zeros(dst_shape, dtype=image_metadata["dtype"])
213
385
  for col in range(start_tile[0], end_tile[0]):
214
386
  for row in range(start_tile[1], end_tile[1]):
215
387
  fname = path / f"{col}_{row}.{extension}"
@@ -272,13 +444,17 @@ class ImageTileRasterFormat(RasterFormat):
272
444
  )
273
445
 
274
446
 
275
- @RasterFormats.register("geotiff")
276
447
  class GeotiffRasterFormat(RasterFormat):
277
448
  """A raster format that uses one big, tiled GeoTIFF with small block size."""
278
449
 
279
450
  fname = "geotiff.tif"
280
451
 
281
- def __init__(self, block_size: int = TILE_SIZE, always_enable_tiling: bool = False):
452
+ def __init__(
453
+ self,
454
+ block_size: int = TILE_SIZE,
455
+ always_enable_tiling: bool = False,
456
+ geotiff_options: dict[str, Any] = {},
457
+ ):
282
458
  """Initializes a GeotiffRasterFormat.
283
459
 
284
460
  Args:
@@ -287,9 +463,11 @@ class GeotiffRasterFormat(RasterFormat):
287
463
  GeoTIFFs. The default is False so that tiling is only used if the size
288
464
  of the GeoTIFF exceeds the block_size on either dimension. If True,
289
465
  then tiling is always enabled (cloud-optimized GeoTIFF).
466
+ geotiff_options: other options to pass to rasterio.open (for writes).
290
467
  """
291
468
  self.block_size = block_size
292
469
  self.always_enable_tiling = always_enable_tiling
470
+ self.geotiff_options = geotiff_options
293
471
 
294
472
  def encode_raster(
295
473
  self,
@@ -297,6 +475,7 @@ class GeotiffRasterFormat(RasterFormat):
297
475
  projection: Projection,
298
476
  bounds: PixelBounds,
299
477
  array: npt.NDArray[Any],
478
+ fname: str | None = None,
300
479
  ) -> None:
301
480
  """Encodes raster data.
302
481
 
@@ -305,7 +484,11 @@ class GeotiffRasterFormat(RasterFormat):
305
484
  projection: the projection of the raster data
306
485
  bounds: the bounds of the raster data in the projection
307
486
  array: the raster data
487
+ fname: override the filename to save as
308
488
  """
489
+ if fname is None:
490
+ fname = self.fname
491
+
309
492
  crs = projection.crs
310
493
  transform = affine.Affine(
311
494
  projection.x_resolution,
@@ -338,76 +521,48 @@ class GeotiffRasterFormat(RasterFormat):
338
521
  profile["blockxsize"] = self.block_size
339
522
  profile["blockysize"] = self.block_size
340
523
 
524
+ profile.update(self.geotiff_options)
525
+
341
526
  path.mkdir(parents=True, exist_ok=True)
342
- with (path / self.fname).open("wb") as f:
343
- with rasterio.open(f, "w", **profile) as dst:
344
- dst.write(array)
527
+ logger.debug(f"Writing geotiff to {path / fname}")
528
+ with open_rasterio_upath_writer(path / fname, **profile) as dst:
529
+ dst.write(array)
345
530
 
346
531
  def decode_raster(
347
- self, path: UPath, bounds: PixelBounds
348
- ) -> npt.NDArray[Any] | None:
532
+ self,
533
+ path: UPath,
534
+ projection: Projection,
535
+ bounds: PixelBounds,
536
+ resampling: Resampling = Resampling.bilinear,
537
+ fname: str | None = None,
538
+ ) -> npt.NDArray[Any]:
349
539
  """Decodes raster data.
350
540
 
351
541
  Args:
352
542
  path: the directory to read from
353
- bounds: the bounds of the raster to read
543
+ projection: the projection to read the raster in.
544
+ bounds: the bounds to read in the given projection.
545
+ resampling: resampling method to use in case resampling is needed.
546
+ fname: override the filename to read from
354
547
 
355
548
  Returns:
356
- the raster data, or None if no image content is found
549
+ the raster data
357
550
  """
358
- with (path / self.fname).open("rb") as f:
359
- with rasterio.open(f) as src:
360
- transform = src.transform
361
- x_resolution = transform.a
362
- y_resolution = transform.e
363
- offset = (
364
- int(transform.c / x_resolution),
365
- int(transform.f / y_resolution),
366
- )
367
- # bounds is in global pixel coordinates.
368
- # We first convert that to pixels relative to top-left of the raster.
369
- relative_bounds = [
370
- bounds[0] - offset[0],
371
- bounds[1] - offset[1],
372
- bounds[2] - offset[0],
373
- bounds[3] - offset[1],
374
- ]
375
- if (
376
- relative_bounds[2] < 0
377
- or relative_bounds[3] < 0
378
- or relative_bounds[0] >= src.width
379
- or relative_bounds[1] >= src.height
380
- ):
381
- return None
382
- # Now get the actual pixels we will read, which must be contained in
383
- # the GeoTIFF.
384
- # Padding is (before_x, before_y, after_x, after_y) and will be used to
385
- # pad the output back to the originally requested bounds.
386
- padding = [0, 0, 0, 0]
387
- if relative_bounds[0] < 0:
388
- padding[0] = -relative_bounds[0]
389
- relative_bounds[0] = 0
390
- if relative_bounds[1] < 0:
391
- padding[1] = -relative_bounds[1]
392
- relative_bounds[1] = 0
393
- if relative_bounds[2] > src.width:
394
- padding[2] = relative_bounds[2] - src.width
395
- relative_bounds[2] = src.width
396
- if relative_bounds[3] > src.height:
397
- padding[3] = relative_bounds[3] - src.height
398
- relative_bounds[3] = src.height
399
-
400
- window = rasterio.windows.Window(
401
- relative_bounds[0],
402
- relative_bounds[1],
403
- relative_bounds[2] - relative_bounds[0],
404
- relative_bounds[3] - relative_bounds[1],
405
- )
406
- array = src.read(window=window)
407
- array = np.pad(
408
- array, ((0, 0), (padding[1], padding[3]), (padding[0], padding[2]))
409
- )
410
- return array
551
+ if fname is None:
552
+ fname = self.fname
553
+
554
+ # Construct the transform to use for the warped dataset.
555
+ wanted_transform = get_transform_from_projection_and_bounds(projection, bounds)
556
+ with open_rasterio_upath_reader(path / fname) as src:
557
+ with rasterio.vrt.WarpedVRT(
558
+ src,
559
+ crs=projection.crs,
560
+ transform=wanted_transform,
561
+ width=bounds[2] - bounds[0],
562
+ height=bounds[3] - bounds[1],
563
+ resampling=resampling,
564
+ ) as vrt:
565
+ return vrt.read()
411
566
 
412
567
  def get_raster_bounds(self, path: UPath) -> PixelBounds:
413
568
  """Returns the bounds of the stored raster.
@@ -418,21 +573,9 @@ class GeotiffRasterFormat(RasterFormat):
418
573
  Returns:
419
574
  the PixelBounds of the raster
420
575
  """
421
- with (path / self.fname).open("rb") as f:
422
- with rasterio.open(f) as src:
423
- transform = src.transform
424
- x_resolution = transform.a
425
- y_resolution = transform.e
426
- offset = (
427
- int(transform.c / x_resolution),
428
- int(transform.f / y_resolution),
429
- )
430
- return (
431
- offset[0],
432
- offset[1],
433
- offset[0] + src.width,
434
- offset[1] + src.height,
435
- )
576
+ with open_rasterio_upath_reader(path / self.fname) as src:
577
+ _, bounds = get_raster_projection_and_bounds(src)
578
+ return bounds
436
579
 
437
580
  @staticmethod
438
581
  def from_config(name: str, config: dict[str, Any]) -> "GeotiffRasterFormat":
@@ -450,10 +593,11 @@ class GeotiffRasterFormat(RasterFormat):
450
593
  kwargs["block_size"] = config["block_size"]
451
594
  if "always_enable_tiling" in config:
452
595
  kwargs["always_enable_tiling"] = config["always_enable_tiling"]
596
+ if "geotiff_options" in config:
597
+ kwargs["geotiff_options"] = config["geotiff_options"]
453
598
  return GeotiffRasterFormat(**kwargs)
454
599
 
455
600
 
456
- @RasterFormats.register("single_image")
457
601
  class SingleImageRasterFormat(RasterFormat):
458
602
  """A raster format that produces a single image called image.png/jpg.
459
603
 
@@ -503,35 +647,57 @@ class SingleImageRasterFormat(RasterFormat):
503
647
  if array.shape[2] == 1:
504
648
  array = array[:, :, 0]
505
649
  Image.fromarray(array).save(f, format=self.format.upper())
650
+
651
+ # Since the image file doesn't include the georeferencing, we store it in an
652
+ # auxiliary metadata file.
506
653
  with (path / "metadata.json").open("w") as f:
507
654
  json.dump(
508
655
  {
656
+ "projection": projection.serialize(),
509
657
  "bounds": bounds,
510
658
  },
511
659
  f,
512
660
  )
513
661
 
514
662
  def decode_raster(
515
- self, path: UPath, bounds: PixelBounds
516
- ) -> npt.NDArray[Any] | None:
663
+ self,
664
+ path: UPath,
665
+ projection: Projection,
666
+ bounds: PixelBounds,
667
+ resampling: Resampling = Resampling.bilinear,
668
+ ) -> npt.NDArray[Any]:
517
669
  """Decodes raster data.
518
670
 
519
671
  Args:
520
672
  path: the directory to read from
521
- bounds: the bounds of the raster to read
673
+ projection: the projection to read the raster in.
674
+ bounds: the bounds to read in the given projection.
675
+ resampling: resampling method to use in case resampling is needed.
522
676
 
523
677
  Returns:
524
- the raster data, or None if no image content is found
678
+ the raster data
525
679
  """
526
- image_fname = path / ("image." + self.get_extension())
680
+ # Try to get the bounds of the saved image from the metadata file.
681
+ # In old versions, the file may be missing the projection key.
527
682
  metadata_fname = path / "metadata.json"
528
- if metadata_fname.exists():
529
- with metadata_fname.open() as f:
530
- image_bounds = json.load(f)["bounds"]
531
- else:
532
- # Backwards compatibility -- assume that requested bounds matches the window bounds.
533
- image_bounds = bounds
683
+ with metadata_fname.open() as f:
684
+ image_metadata = json.load(f)
685
+
686
+ image_bounds = image_metadata["bounds"]
534
687
 
688
+ # If the projection key is set, verify that it matches the requested projection
689
+ # since SingleImageRasterFormat currently does not support re-projecting.
690
+ if "projection" in image_metadata:
691
+ source_data_projection = Projection.deserialize(
692
+ image_metadata["projection"]
693
+ )
694
+ if projection != source_data_projection:
695
+ raise NotImplementedError(
696
+ "not implemented to re-project source data "
697
+ + f"(source projection {source_data_projection} does not match requested projection {projection})"
698
+ )
699
+
700
+ image_fname = path / ("image." + self.get_extension())
535
701
  with image_fname.open("rb") as f:
536
702
  array = np.array(Image.open(f, formats=[self.format.upper()]))
537
703
 
@@ -583,17 +749,3 @@ class SingleImageRasterFormat(RasterFormat):
583
749
  if "format" in config:
584
750
  kwargs["format"] = config["format"]
585
751
  return SingleImageRasterFormat(**kwargs)
586
-
587
-
588
- def load_raster_format(config: RasterFormatConfig) -> RasterFormat:
589
- """Loads a RasterFormat from a RasterFormatConfig.
590
-
591
- Args:
592
- config: the RasterFormatConfig configuration object specifying the
593
- RasterFormat.
594
-
595
- Returns:
596
- the loaded RasterFormat implementation
597
- """
598
- cls = RasterFormats.get_class(config.name)
599
- return cls.from_config(config.name, config.config_dict)
@@ -1,7 +1,9 @@
1
1
  """RtreeIndex spatial index implementation."""
2
2
 
3
+ import hashlib
3
4
  import os
4
5
  import shutil
6
+ import tempfile
5
7
  from collections.abc import Callable
6
8
  from typing import Any
7
9
 
@@ -9,8 +11,11 @@ import fsspec
9
11
  from rtree import index
10
12
  from upath import UPath
11
13
 
14
+ from rslearn.log_utils import get_logger
12
15
  from rslearn.utils.spatial_index import SpatialIndex
13
16
 
17
+ logger = get_logger(__name__)
18
+
14
19
 
15
20
  class RtreeIndex(SpatialIndex):
16
21
  """An index of spatiotemporal geometries backed by an rtree index.
@@ -18,7 +23,7 @@ class RtreeIndex(SpatialIndex):
18
23
  Both in-memory and on-disk options are supported.
19
24
  """
20
25
 
21
- def __init__(self, fname: str | None = None):
26
+ def __init__(self, fname: str | None = None) -> None:
22
27
  """Initialize a new RtreeIndex.
23
28
 
24
29
  If fname is set, the index is persisted on disk, otherwise it is in-memory.
@@ -50,6 +55,7 @@ class RtreeIndex(SpatialIndex):
50
55
  self.counter += 1
51
56
  self.index.insert(id=self.counter, coordinates=box, obj=data)
52
57
 
58
+ # TODO: Make a named tuple for all the bounding box stuff
53
59
  def query(self, box: tuple[float, float, float, float]) -> list[Any]:
54
60
  """Query the index for objects intersecting a box.
55
61
 
@@ -63,20 +69,51 @@ class RtreeIndex(SpatialIndex):
63
69
  return [r.object for r in results]
64
70
 
65
71
 
72
def delete_partially_created_local_files(fname: str) -> None:
    """Delete partially created .dat and .idx files.

    Args:
        fname: the local index filename without extension. The files
            fname + ".dat" and fname + ".idx" are removed if they are present.
    """
    for ext in (".dat", ".idx"):
        try:
            os.unlink(fname + ext)
        except FileNotFoundError:
            # The file was never created, or was removed between a check and the
            # unlink (e.g. by another process cleaning up the same path). Either
            # way there is nothing left to delete.
            pass
79
+
80
+
81
def _get_tmp_dir_for_cached_rtree(cache_dir: UPath) -> str:
    """Map a cache_dir to a local temporary directory for holding its rtree files.

    The mapping is deterministic: the directory name is derived from a SHA-256 hash
    of str(cache_dir), so the same cache_dir always yields the same local
    temporary directory. The directory is created if it does not exist.

    Note that the directory is not cleaned up after the program exits, so the rtree
    will remain there. This is because this function may be called from multiple
    worker processes but the index should be reused across workers.

    Args:
        cache_dir: the non-local directory where the rtree files are stored.

    Returns:
        the temporary local directory to copy the cached rtree to.
    """
    digest = hashlib.sha256(str(cache_dir).encode()).hexdigest()
    path_parts = (tempfile.gettempdir(), "rslearn_cache", "rtree_index", digest)
    local_dir = os.path.join(*path_parts)
    os.makedirs(local_dir, exist_ok=True)
    return local_dir
103
+
104
+
66
105
  def get_cached_rtree(
67
- cache_dir: UPath, tmp_dir: str, build_fn: Callable[[RtreeIndex], None]
106
+ cache_dir: UPath, build_fn: Callable[[RtreeIndex], None]
68
107
  ) -> RtreeIndex:
69
108
  """Returns an RtreeIndex cached in cache_dir, creating it if needed.
70
109
 
71
110
  The .dat and .idx files are cached in cache_dir. Since RtreeIndex expects local
72
- filesystem, it is copied to a local temporary directory if needed. If the index
73
- doesn't exist yet, then it is created using build_fn.
111
+ filesystem, it is copied to a local temporary directory if needed (it is not needed
112
+ if the cache_dir is already on local filesystem). If the index doesn't exist yet,
113
+ then it is created using build_fn.
74
114
 
75
115
  Args:
76
116
  cache_dir: directory to cache the index files.
77
- tmp_dir: temporary local directory to use in case cache_dir is on a remote
78
- filesystem. The caller is responsible for cleaning this up when they don't
79
- need the index anymore.
80
117
  build_fn: function to build the index in case it doesn't exist yet.
81
118
 
82
119
  Returns:
@@ -95,14 +132,13 @@ def get_cached_rtree(
95
132
  if is_local_cache:
96
133
  local_fname = (cache_dir / "rtree_index").path
97
134
  else:
135
+ tmp_dir = _get_tmp_dir_for_cached_rtree(cache_dir)
98
136
  local_fname = os.path.join(tmp_dir, "rtree_index")
137
+ delete_partially_created_local_files(local_fname)
99
138
 
100
- # Delete any local files that might be partially created.
101
- for ext in extensions:
102
- cur_fname = local_fname + ext
103
- if os.path.exists(cur_fname):
104
- os.unlink(cur_fname)
105
-
139
+ logger.info(
140
+ "building rtree index at %s to be cached at %s", local_fname, cache_dir
141
+ )
106
142
  rtree_index = RtreeIndex(local_fname)
107
143
  build_fn(rtree_index)
108
144
  del rtree_index
@@ -115,6 +151,7 @@ def get_cached_rtree(
115
151
 
116
152
  # Create the completed file to indicate index is ready in cache.
117
153
  completed_fname.touch()
154
+ logger.info("rtree index is built and ready")
118
155
 
119
156
  else:
120
157
  # Initialize the index from the cached version.
@@ -122,10 +159,20 @@ def get_cached_rtree(
122
159
  if is_local_cache:
123
160
  local_fname = (cache_dir / "rtree_index").path
124
161
  else:
162
+ tmp_dir = _get_tmp_dir_for_cached_rtree(cache_dir)
125
163
  local_fname = os.path.join(tmp_dir, "rtree_index")
126
- for ext in extensions:
127
- with (cache_dir / f"rtree_index{ext}").open("rb") as src:
128
- with open(local_fname + ext, "wb") as dst:
129
- shutil.copyfileobj(src, dst)
164
+
165
+ if not os.path.exists(local_fname + extensions[-1]):
166
+ logger.info(
167
+ "copying rtree index from non-local cache at %s to local temporary directory %s",
168
+ cache_dir,
169
+ local_fname,
170
+ )
171
+ for ext in extensions:
172
+ with (cache_dir / f"rtree_index{ext}").open("rb") as src:
173
+ with open(local_fname + ext + ".tmp", "wb") as dst:
174
+ shutil.copyfileobj(src, dst)
175
+ os.rename(local_fname + ext + ".tmp", local_fname + ext)
176
+ logger.info("rtree index is ready")
130
177
 
131
178
  return RtreeIndex(local_fname)