rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
@@ -4,37 +4,32 @@ from typing import Any
4
4
 
5
5
  import numpy as np
6
6
  import numpy.typing as npt
7
- from class_registry import ClassRegistry
8
- from upath import UPath
7
+ from rasterio.enums import Resampling
9
8
 
10
9
  from rslearn.config import (
10
+ BandSetConfig,
11
+ CompositingMethod,
11
12
  LayerConfig,
12
- RasterFormatConfig,
13
- RasterLayerConfig,
14
- VectorLayerConfig,
15
13
  )
16
- from rslearn.data_sources import Item
17
- from rslearn.tile_stores import TileStore, TileStoreLayer, get_tile_store_for_layer
18
- from rslearn.utils import Feature, PixelBounds
19
- from rslearn.utils.raster_format import load_raster_format
20
- from rslearn.utils.vector_format import load_vector_format
14
+ from rslearn.data_sources.data_source import ItemType
15
+ from rslearn.tile_stores import TileStoreWithLayer
16
+ from rslearn.utils.feature import Feature
17
+ from rslearn.utils.geometry import PixelBounds, Projection
21
18
 
22
19
  from .remap import Remapper, load_remapper
23
20
  from .window import Window
24
21
 
25
- Materializers = ClassRegistry()
26
-
27
22
 
28
23
  class Materializer:
29
24
  """An abstract class that materializes data from a tile store."""
30
25
 
31
26
  def materialize(
32
27
  self,
33
- tile_store: TileStore,
28
+ tile_store: TileStoreWithLayer,
34
29
  window: Window,
35
30
  layer_name: str,
36
31
  layer_cfg: LayerConfig,
37
- item_groups: list[list[Item]],
32
+ item_groups: list[list[ItemType]],
38
33
  ) -> None:
39
34
  """Materialize portions of items corresponding to this window into the dataset.
40
35
 
@@ -50,11 +45,16 @@ class Materializer:
50
45
 
51
46
  def read_raster_window_from_tiles(
52
47
  dst: npt.NDArray[Any],
53
- ts_layer: TileStoreLayer,
48
+ tile_store: TileStoreWithLayer,
49
+ item_name: str,
50
+ bands: list[str],
51
+ projection: Projection,
54
52
  bounds: PixelBounds,
55
53
  src_indexes: list[int],
56
54
  dst_indexes: list[int],
55
+ nodata_vals: list[float],
57
56
  remapper: Remapper | None = None,
57
+ resampling: Resampling = Resampling.bilinear,
58
58
  ) -> None:
59
59
  """Read a window of raster data from tiles in a tile store.
60
60
 
@@ -62,13 +62,22 @@ def read_raster_window_from_tiles(
62
62
 
63
63
  Args:
64
64
  dst: the destination numpy array
65
- ts_layer: the tile store layer to read
66
- bounds: the bounds in pixel coordinates matching projection of ts_layer
65
+ tile_store: the TileStore to read from.
66
+ item_name: the item name.
67
+ bands: the bands that identify the raster we want to read.
68
+ projection: the projection of the dst array.
69
+ bounds: the bounds of the dst array.
67
70
  src_indexes: the source band indexes to use
68
71
  dst_indexes: corresponding destination band indexes for each source band index
72
+ nodata_vals: the nodata values for each band, to determine which parts of dst
73
+ should be overwritten.
69
74
  remapper: optional remapper to apply on the source pixel values
75
+ resampling: how to resample the pixels in case re-projection is needed.
70
76
  """
71
- src_bounds = ts_layer.get_raster_bounds()
77
+ # Only read the portion of the raster that overlaps with dst.
78
+ # This way we can avoid creating big arrays that are all empty which speeds things
79
+ # up for large windows.
80
+ src_bounds = tile_store.get_raster_bounds(item_name, bands, projection)
72
81
  intersection = (
73
82
  max(bounds[0], src_bounds[0]),
74
83
  max(bounds[1], src_bounds[1]),
@@ -81,7 +90,9 @@ def read_raster_window_from_tiles(
81
90
  dst_col_offset = intersection[0] - bounds[0]
82
91
  dst_row_offset = intersection[1] - bounds[1]
83
92
 
84
- src = ts_layer.read_raster(intersection)
93
+ src = tile_store.read_raster(
94
+ item_name, bands, projection, intersection, resampling=resampling
95
+ )
85
96
  src = src[src_indexes, :, :]
86
97
  if remapper:
87
98
  src = remapper(src, dst.dtype)
@@ -91,45 +102,403 @@ def read_raster_window_from_tiles(
91
102
  dst_row_offset : dst_row_offset + src.shape[1],
92
103
  dst_col_offset : dst_col_offset + src.shape[2],
93
104
  ]
94
- mask = dst_crop[dst_indexes, :, :].max(axis=0) == 0
105
+
106
+ # Create mask indicating where dst has no data (based on nodata_vals).
107
+ # We overwrite dst at pixels where all the bands are nodata.
108
+ nodata_vals_arr = np.array(nodata_vals)[:, None, None]
109
+ mask = (dst_crop[dst_indexes, :, :] == nodata_vals_arr).min(axis=0)
110
+
95
111
  for src_index, dst_index in enumerate(dst_indexes):
96
112
  dst_crop[dst_index, mask] = src[src_index, mask]
97
113
 
98
114
 
99
- @Materializers.register("raster")
115
def get_needed_band_sets_and_indexes(
    item: ItemType,
    bands: list[str],
    tile_store: TileStoreWithLayer,
) -> list[tuple[list[str], list[int], list[int]]]:
    """Map the requested bands onto the raster band sets stored for an item.

    Args:
        item: the item whose stored rasters should be inspected.
        bands: the band names wanted, in destination order.
        tile_store: the tile store the item's rasters were ingested into.

    Returns:
        One ``(src_bands, src_idx, dst_idx)`` tuple for every stored band set
        (as listed by ``tile_store.get_raster_bands``) that supplies at least
        one requested band, where:
        - src_bands: the full band list of that stored raster,
        - src_idx: positions within src_bands of the requested bands it holds,
        - dst_idx: the matching positions within ``bands``.
        Each requested band is taken from the first band set that contains it.
        If any requested band is not available at all, an empty list is
        returned so that the caller skips the item.
    """
    # Requested bands that have not been located yet -> position in `bands`.
    remaining = {band_name: position for position, band_name in enumerate(bands)}
    plan = []

    for src_bands in tile_store.get_raster_bands(item.name):
        src_idx: list[int] = []
        dst_idx: list[int] = []
        for position, band_name in enumerate(src_bands):
            if band_name in remaining:
                src_idx.append(position)
                # pop() so a band is only ever sourced from one band set.
                dst_idx.append(remaining.pop(band_name))
        if src_idx:
            plan.append((src_bands, src_idx, dst_idx))

    # Some requested band was never found: this item cannot be used.
    return [] if remaining else plan
157
+
158
+
159
def build_first_valid_composite(
    group: list[ItemType],
    nodata_vals: list[Any],
    bands: list[str],
    bounds: PixelBounds,
    band_dtype: npt.DTypeLike,
    tile_store: TileStoreWithLayer,
    projection: Projection,
    remapper: Remapper | None,
    resampling_method: Resampling = Resampling.bilinear,
) -> npt.NDArray[np.generic]:
    """Composite a group by keeping, per pixel, the first value that is not nodata.

    The output starts filled with each band's nodata value. Items are then read
    in group order through read_raster_window_from_tiles, which only overwrites
    pixels whose destination bands all still hold nodata. Earlier items in the
    group therefore take precedence. Items that lack any requested band are
    skipped.

    Args:
        group: items to composite, in priority order.
        nodata_vals: nodata value for each band.
        bands: names of the bands to include in the composite.
        bounds: pixel bounds (x0, y0, x1, y1) of the composite.
        band_dtype: dtype of the output array.
        tile_store: tile store holding the ingested rasters.
        projection: projection of the composite.
        remapper: optional remapper applied to source pixel values.
        resampling_method: resampling used if reprojection is needed.

    Returns:
        Array of shape (len(bands), y1 - y0, x1 - x0) with dtype ``band_dtype``.
    """
    height = bounds[3] - bounds[1]
    width = bounds[2] - bounds[0]

    # Start from zeros so any band lacking a nodata entry still has a defined fill.
    composite = np.zeros((len(bands), height, width), dtype=band_dtype)
    for band_idx, fill_value in enumerate(nodata_vals):
        composite[band_idx] = fill_value

    for item in group:
        band_plan = get_needed_band_sets_and_indexes(item, bands, tile_store)
        for src_bands, src_indexes, dst_indexes in band_plan:
            read_raster_window_from_tiles(
                dst=composite,
                tile_store=tile_store,
                item_name=item.name,
                bands=src_bands,
                projection=projection,
                bounds=bounds,
                src_indexes=src_indexes,
                dst_indexes=dst_indexes,
                nodata_vals=[nodata_vals[band_idx] for band_idx in dst_indexes],
                remapper=remapper,
                resampling=resampling_method,
            )

    return composite
226
+
227
+
228
def read_and_stack_raster_windows(
    group: list[ItemType],
    bounds: PixelBounds,
    bands: list[str],
    tile_store: TileStoreWithLayer,
    projection: Projection,
    nodata_vals: list[Any],
    remapper: Remapper | None,
    band_dtype: npt.DTypeLike,
    resampling_method: Resampling = Resampling.bilinear,
) -> npt.NDArray[np.generic]:
    """Read each item of a group into its own window and stack the results.

    Every item gets a fresh (len(bands), height, width) window, pre-filled with
    the per-band nodata values. The item's rasters are then read into it.
    Items that lack any requested band keep an all-nodata window.

    Args:
        group: items to read, in order.
        bounds: pixel bounds (x0, y0, x1, y1) of each window.
        bands: names of the bands to read.
        tile_store: tile store holding the ingested rasters.
        projection: projection of the windows.
        nodata_vals: nodata value for each band.
        remapper: optional remapper applied to source pixel values.
        band_dtype: dtype of the windows.
        resampling_method: resampling used if reprojection is needed.

    Returns:
        Array of shape (len(group), len(bands), height, width). np.stack
        raises ValueError if ``group`` is empty.
    """
    window_shape = (len(bands), bounds[3] - bounds[1], bounds[2] - bounds[0])
    per_item_windows: list[np.ndarray] = []

    for item in group:
        window = np.empty(window_shape, dtype=band_dtype)
        for band_idx, fill_value in enumerate(nodata_vals):
            window[band_idx, :, :] = fill_value

        band_plan = get_needed_band_sets_and_indexes(item, bands, tile_store)
        for src_bands, src_indexes, dst_indexes in band_plan:
            read_raster_window_from_tiles(
                dst=window,
                tile_store=tile_store,
                item_name=item.name,
                bands=src_bands,
                projection=projection,
                bounds=bounds,
                src_indexes=src_indexes,
                dst_indexes=dst_indexes,
                nodata_vals=[nodata_vals[band_idx] for band_idx in dst_indexes],
                remapper=remapper,
                resampling=resampling_method,
            )

        per_item_windows.append(window)

    # The new leading axis indexes items: (items, bands, rows, cols).
    return np.stack(per_item_windows)
296
+
297
+
298
def mask_stacked_rasters(
    stacked_rasters: npt.NDArray[np.generic],
    nodata_vals: list[Any],
) -> np.ma.MaskedArray:
    """Mask every pixel that equals the nodata value of its band.

    Args:
        stacked_rasters: array of shape (num_items, num_bands, height, width)
            containing the raster values of each item in the group.
        nodata_vals: one nodata value per band. Element ``i`` is compared
            against band ``i`` of every item.

    Returns:
        np.ma.MaskedArray with the same shape and data as ``stacked_rasters``,
        masked wherever a pixel equals the nodata value of its band. A NaN
        nodata value masks NaN pixels. NaN never compares equal to itself, so
        an equality test alone would leave those pixels unmasked and let them
        propagate into mean/median composites.
    """
    # Broadcast the per-band nodata values over the (item, band, row, col) axes.
    nodata_arr = np.asarray(nodata_vals).reshape(1, -1, 1, 1)
    invalid = stacked_rasters == nodata_arr

    # Only float/complex nodata values can be NaN; match those with isnan.
    if nodata_arr.dtype.kind in "fc":
        nan_nodata = np.isnan(nodata_arr)
        if nan_nodata.any():
            invalid = invalid | (nan_nodata & np.isnan(stacked_rasters))

    return np.ma.masked_where(invalid, stacked_rasters)
322
+
323
+
324
def build_mean_composite(
    group: list[ItemType],
    nodata_vals: list[Any],
    bands: list[str],
    bounds: PixelBounds,
    band_dtype: npt.DTypeLike,
    tile_store: TileStoreWithLayer,
    projection: Projection,
    remapper: Remapper | None,
    resampling_method: Resampling = Resampling.bilinear,
) -> npt.NDArray[np.generic]:
    """Build a composite by computing the mean of valid pixels across items in the group.

    A composite of shape (bands, height, width) is created by computing the
    per-pixel mean of valid (non-nodata) pixels across all items in the group.
    Pixels with no valid value in any item are set to the band's nodata value.
    The float mean is converted to ``band_dtype`` with ``astype``, which
    truncates toward zero for integer dtypes.

    Args:
        group: list of items to composite together
        nodata_vals: list of nodata values for each band
        bands: list of band names to include in the composite
        bounds: pixel bounds defining the spatial extent of the composite
        band_dtype: data type for the output bands
        tile_store: tile store containing the raster data
        projection: spatial projection for the composite
        remapper: remapper to apply to pixel values, or None
        resampling_method: resampling method to use when reprojecting

    Returns:
        Composite of shape (bands, height, width) having the per-pixel mean of
        all items in the group. It is all nodata when the group is empty.
    """
    if not group:
        # np.stack inside read_and_stack_raster_windows rejects an empty list.
        # With no items there is nothing to average, so return an all-nodata
        # composite, matching what build_first_valid_composite produces.
        composite = np.zeros(
            (len(bands), bounds[3] - bounds[1], bounds[2] - bounds[0]),
            dtype=band_dtype,
        )
        for band_idx, nodata_val in enumerate(nodata_vals):
            composite[band_idx] = nodata_val
        return composite

    # TODO: Might want to add a running sum/count based method to reduce memory utilization

    stacked_arrays = read_and_stack_raster_windows(
        group=group,
        bounds=bounds,
        bands=bands,
        tile_store=tile_store,
        projection=projection,
        nodata_vals=nodata_vals,
        band_dtype=band_dtype,
        remapper=remapper,
        resampling_method=resampling_method,
    )

    # Mask stacked arrays with nodata values of each band
    masked_data = mask_stacked_rasters(stacked_arrays, nodata_vals)

    # Mean along the items axis; pixels masked in every item stay masked.
    mean_result = np.ma.mean(masked_data, axis=0)

    # Fill masked values with per-band nodata and convert to target dtype
    fill_vals = np.array(nodata_vals).reshape(-1, 1, 1)
    result = np.ma.filled(mean_result, fill_value=fill_vals).astype(band_dtype)

    return result
379
+
380
+
381
def build_median_composite(
    group: list[ItemType],
    nodata_vals: list[Any],
    bands: list[str],
    bounds: PixelBounds,
    band_dtype: npt.DTypeLike,
    tile_store: TileStoreWithLayer,
    projection: Projection,
    remapper: Remapper | None,
    resampling_method: Resampling = Resampling.bilinear,
) -> npt.NDArray[np.generic]:
    """Build a composite by computing the median of valid pixels across items in the group.

    A composite of shape (bands, height, width) is created by computing the
    per-pixel median of valid (non-nodata) pixels across all items in the group.
    Pixels with no valid value in any item are set to the band's nodata value.
    The median is converted to ``band_dtype`` with ``astype``, which truncates
    toward zero for integer dtypes (relevant when the median averages two
    middle values).

    Args:
        group: list of items to composite together
        nodata_vals: list of nodata values for each band
        bands: list of band names to include in the composite
        bounds: pixel bounds defining the spatial extent of the composite
        band_dtype: data type for the output bands
        tile_store: tile store containing the raster data
        projection: spatial projection for the composite
        remapper: remapper to apply to pixel values, or None
        resampling_method: resampling method to use when reprojecting

    Returns:
        Composite of shape (bands, height, width) having the per-pixel median
        of all items in the group. It is all nodata when the group is empty.
    """
    if not group:
        # np.stack inside read_and_stack_raster_windows rejects an empty list.
        # With no items there is no median to take, so return an all-nodata
        # composite, matching what build_first_valid_composite produces.
        composite = np.zeros(
            (len(bands), bounds[3] - bounds[1], bounds[2] - bounds[0]),
            dtype=band_dtype,
        )
        for band_idx, nodata_val in enumerate(nodata_vals):
            composite[band_idx] = nodata_val
        return composite

    stacked_arrays = read_and_stack_raster_windows(
        group=group,
        bounds=bounds,
        bands=bands,
        tile_store=tile_store,
        projection=projection,
        nodata_vals=nodata_vals,
        band_dtype=band_dtype,
        remapper=remapper,
        resampling_method=resampling_method,
    )

    # Mask stacked arrays with nodata values of each band
    masked_data = mask_stacked_rasters(stacked_arrays, nodata_vals)

    # Median along the items axis; pixels masked in every item stay masked.
    median_result = np.ma.median(masked_data, axis=0)

    # Fill masked values with per-band nodata and convert to target dtype
    fill_vals = np.array(nodata_vals).reshape(-1, 1, 1)
    result = np.ma.filled(median_result, fill_value=fill_vals).astype(band_dtype)

    return result
434
+
435
+
436
# Dispatch table: each CompositingMethod maps to the builder that implements it.
# Every builder accepts the same keyword arguments (group, nodata_vals, bands,
# bounds, band_dtype, tile_store, projection, remapper, resampling_method), so
# build_composite can call any entry uniformly.
compositing_methods: dict[CompositingMethod, Any] = {
    CompositingMethod.FIRST_VALID: build_first_valid_composite,
    CompositingMethod.MEAN: build_mean_composite,
    CompositingMethod.MEDIAN: build_median_composite,
}
441
+
442
+
443
def build_composite(
    group: list[ItemType],
    compositing_method: CompositingMethod,
    tile_store: TileStoreWithLayer,
    layer_cfg: LayerConfig,
    band_cfg: BandSetConfig,
    projection: Projection,
    bounds: PixelBounds,
    remapper: Remapper | None,
) -> npt.NDArray[np.generic]:
    """Build a temporal composite for one band set from the items in a group.

    Args:
        group: list of items to composite together.
        compositing_method: which compositing strategy to use. FIRST_VALID keeps
            the first non-nodata value per pixel, MEAN takes the per-pixel mean
            of valid values, and MEDIAN takes the per-pixel median of valid
            values.
        tile_store: tile store containing the raster data.
        layer_cfg: configuration of the layer being materialized. Its
            resampling_method is used when reprojecting.
        band_cfg: configuration of the band set being materialized. It supplies
            the bands, the output dtype and the per-band nodata values.
        projection: spatial projection for the composite.
        bounds: pixel bounds defining the spatial extent of the composite.
        remapper: remapper to apply to pixel values, or None.

    Returns:
        The composite array produced by the selected compositing function, with
        one channel per band in ``band_cfg.bands``.

    Raises:
        KeyError: if ``compositing_method`` has no entry in compositing_methods.
    """
    nodata_vals = band_cfg.nodata_vals
    if nodata_vals is None:
        # No nodata values configured: treat 0 as nodata in every band.
        nodata_vals = [0] * len(band_cfg.bands)

    build_fn = compositing_methods[compositing_method]
    return build_fn(
        group=group,
        nodata_vals=nodata_vals,
        bands=band_cfg.bands,
        bounds=bounds,
        band_dtype=band_cfg.dtype.get_numpy_dtype(),
        tile_store=tile_store,
        projection=projection,
        resampling_method=layer_cfg.resampling_method.get_rasterio_resampling(),
        remapper=remapper,
    )
480
+
481
+
100
482
  class RasterMaterializer(Materializer):
101
483
  """A Materializer for raster data."""
102
484
 
103
485
  def materialize(
104
486
  self,
105
- tile_store: TileStore,
487
+ tile_store: TileStoreWithLayer,
106
488
  window: Window,
107
489
  layer_name: str,
108
490
  layer_cfg: LayerConfig,
109
- item_groups: list[list[Item]],
491
+ item_groups: list[list[ItemType]],
110
492
  ) -> None:
111
493
  """Materialize portions of items corresponding to this window into the dataset.
112
494
 
113
495
  Args:
114
- tile_store: the tile store where the items have been ingested (unprefixed)
496
+ tile_store: the tile store where the items have been ingested
115
497
  window: the window to materialize
116
498
  layer_name: name of the layer to materialize
117
499
  layer_cfg: the configuration of the layer to materialize
118
500
  item_groups: the items associated with this window and layer
119
501
  """
120
- assert isinstance(layer_cfg, RasterLayerConfig)
121
-
122
- layer_tile_store = get_tile_store_for_layer(tile_store, layer_name, layer_cfg)
123
-
124
- out_layer_dirs: list[UPath] = []
125
- for group_id in range(len(item_groups)):
126
- if group_id == 0:
127
- out_layer_name = layer_name
128
- else:
129
- out_layer_name = f"{layer_name}.{group_id}"
130
- out_layer_dir = window.path / "layers" / out_layer_name
131
- out_layer_dirs.append(out_layer_dir)
132
-
133
502
  for band_cfg in layer_cfg.band_sets:
134
503
  # band_cfg could specify zoom_offset and maybe other parameters that affect
135
504
  # projection/bounds, so use the corrected projection/bounds.
@@ -142,72 +511,40 @@ class RasterMaterializer(Materializer):
142
511
  if band_cfg.remap:
143
512
  remapper = load_remapper(band_cfg.remap)
144
513
 
145
- raster_format = load_raster_format(
146
- RasterFormatConfig(band_cfg.format["name"], band_cfg.format)
147
- )
514
+ raster_format = band_cfg.instantiate_raster_format()
148
515
 
149
516
  for group_id, group in enumerate(item_groups):
150
- dst = np.zeros(
151
- (len(band_cfg.bands), bounds[3] - bounds[1], bounds[2] - bounds[0]),
152
- dtype=band_cfg.dtype.value,
517
+ composite = build_composite(
518
+ group=group,
519
+ compositing_method=layer_cfg.compositing_method,
520
+ tile_store=tile_store,
521
+ layer_cfg=layer_cfg,
522
+ band_cfg=band_cfg,
523
+ projection=projection,
524
+ bounds=bounds,
525
+ remapper=remapper,
153
526
  )
154
- for item in group:
155
- # Identify which tile store layer(s) to read to get the configured
156
- # bands.
157
- needed_band_indexes = {}
158
- for i, band in enumerate(band_cfg.bands):
159
- needed_band_indexes[band] = i
160
- suffixes = layer_tile_store.list_layers((item.name,))
161
- needed_suffixes_and_indexes = []
162
- for suffix in suffixes:
163
- bands = suffix.split("_")
164
- needed_src_indexes = []
165
- needed_dst_indexes = []
166
- for i, band in enumerate(bands):
167
- if band not in needed_band_indexes:
168
- continue
169
- needed_src_indexes.append(i)
170
- needed_dst_indexes.append(needed_band_indexes[band])
171
- del needed_band_indexes[band]
172
- if len(needed_src_indexes) == 0:
173
- continue
174
- needed_suffixes_and_indexes.append(
175
- (suffix, needed_src_indexes, needed_dst_indexes)
176
- )
177
- if len(needed_band_indexes) > 0:
178
- # This item doesn't have all the needed bands, so skip it.
179
- continue
180
-
181
- for suffix, src_indexes, dst_indexes in needed_suffixes_and_indexes:
182
- ts_layer = layer_tile_store.get_layer(
183
- (item.name, suffix, str(projection))
184
- )
185
- read_raster_window_from_tiles(
186
- dst, ts_layer, bounds, src_indexes, dst_indexes, remapper
187
- )
188
-
189
527
  raster_format.encode_raster(
190
- out_layer_dirs[group_id] / "_".join(band_cfg.bands),
528
+ window.get_raster_dir(layer_name, band_cfg.bands, group_id),
191
529
  projection,
192
530
  bounds,
193
- dst,
531
+ composite,
194
532
  )
195
533
 
196
- for out_layer_dir in out_layer_dirs:
197
- (out_layer_dir / "completed").touch()
534
+ for group_id in range(len(item_groups)):
535
+ window.mark_layer_completed(layer_name, group_id)
198
536
 
199
537
 
200
- @Materializers.register("vector")
201
538
  class VectorMaterializer(Materializer):
202
539
  """A Materializer for vector data."""
203
540
 
204
541
  def materialize(
205
542
  self,
206
- tile_store: TileStore,
543
+ tile_store: TileStoreWithLayer,
207
544
  window: Window,
208
545
  layer_name: str,
209
546
  layer_cfg: LayerConfig,
210
- item_groups: list[list[Item]],
547
+ item_groups: list[list[ItemType]],
211
548
  ) -> None:
212
549
  """Materialize portions of items corresponding to this window into the dataset.
213
550
 
@@ -218,33 +555,20 @@ class VectorMaterializer(Materializer):
218
555
  layer_cfg: the configuration of the layer to materialize
219
556
  item_groups: the items associated with this window and layer
220
557
  """
221
- assert isinstance(layer_cfg, VectorLayerConfig)
222
-
223
- projection, bounds = layer_cfg.get_final_projection_and_bounds(
224
- window.projection, window.bounds
225
- )
226
- vector_format = load_vector_format(layer_cfg.format)
227
-
228
- out_layer_dirs: list[UPath] = []
229
- for group_id in range(len(item_groups)):
230
- if group_id == 0:
231
- out_layer_name = layer_name
232
- else:
233
- out_layer_name = f"{layer_name}.{group_id}"
234
- out_layer_dir = window.path / "layers" / out_layer_name
235
- out_layer_dirs.append(out_layer_dir)
558
+ vector_format = layer_cfg.instantiate_vector_format()
236
559
 
237
560
  for group_id, group in enumerate(item_groups):
238
561
  features: list[Feature] = []
239
562
 
240
563
  for item in group:
241
- ts_layer = get_tile_store_for_layer(
242
- tile_store, layer_name, layer_cfg
243
- ).get_layer((item.name, str(projection)))
244
- cur_features = ts_layer.read_vector(bounds)
564
+ cur_features = tile_store.read_vector(
565
+ item.name, window.projection, window.bounds
566
+ )
245
567
  features.extend(cur_features)
246
568
 
247
- vector_format.encode_vector(out_layer_dirs[group_id], projection, features)
569
+ vector_format.encode_vector(
570
+ window.get_layer_dir(layer_name, group_id), features
571
+ )
248
572
 
249
- for out_layer_dir in out_layer_dirs:
250
- (out_layer_dir / "completed").touch()
573
+ for group_id in range(len(item_groups)):
574
+ window.mark_layer_completed(layer_name, group_id)