rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
@@ -6,10 +6,12 @@ import json
6
6
  import os
7
7
  import tempfile
8
8
  import time
9
- from datetime import datetime, timezone
9
+ from datetime import UTC, datetime
10
10
  from typing import Any
11
11
 
12
12
  import ee
13
+ import numpy as np
14
+ import numpy.typing as npt
13
15
  import rasterio
14
16
  import rasterio.merge
15
17
  import shapely
@@ -18,53 +20,96 @@ from google.cloud import storage
18
20
  from upath import UPath
19
21
 
20
22
  import rslearn.data_sources.utils
21
- import rslearn.utils.mgrs
22
- from rslearn.config import DType, LayerConfig, RasterLayerConfig
23
+ from rslearn.config import DType, LayerConfig
23
24
  from rslearn.const import WGS84_PROJECTION
24
- from rslearn.tile_stores import PrefixedTileStore, TileStore
25
- from rslearn.utils import STGeometry
25
+ from rslearn.dataset.materialize import RasterMaterializer
26
+ from rslearn.dataset.window import Window
27
+ from rslearn.log_utils import get_logger
28
+ from rslearn.tile_stores import TileStore, TileStoreWithLayer
29
+ from rslearn.utils.array import copy_spatial_array
26
30
  from rslearn.utils.fsspec import join_upath
27
- from rslearn.utils.rtree_index import get_cached_rtree
31
+ from rslearn.utils.geometry import PixelBounds, Projection, STGeometry
32
+ from rslearn.utils.raster_format import (
33
+ Resampling,
34
+ get_raster_projection_and_bounds_from_transform,
35
+ get_transform_from_projection_and_bounds,
36
+ )
37
+ from rslearn.utils.rtree_index import RtreeIndex, get_cached_rtree
28
38
 
29
- from .data_source import DataSource, Item, QueryConfig
30
- from .raster_source import ArrayWithTransform, get_needed_projections, ingest_raster
39
+ from .data_source import DataSource, DataSourceContext, Item, QueryConfig
31
40
 
41
+ logger = get_logger(__name__)
32
42
 
33
- class GEE(DataSource):
43
+
class NoValidPixelsException(Exception):
    """Raised when a GEE export task fails because the region has no valid pixels.

    GEE reports this situation as a failed task. ``read_raster`` treats it as an
    empty (all-zero) result rather than as a hard error.
    """

    # Exact ``error_message`` string GEE sets on a task that fails for this reason.
    GEE_MESSAGE = "No valid (un-masked) pixels in export region."
49
+
50
+
class ExportException(Exception):
    """Raised when a GEE export task fails for a reason other than no valid pixels."""
55
+
56
+
57
+ class GEE(DataSource, TileStore):
34
58
  """A data source for ingesting images from Google Earth Engine."""
35
59
 
36
60
  def __init__(
37
61
  self,
38
- config: LayerConfig,
39
62
  collection_name: str,
40
63
  gcs_bucket_name: str,
41
- index_cache_dir: UPath,
64
+ index_cache_dir: str,
42
65
  service_account_name: str,
43
66
  service_account_credentials: str,
67
+ bands: list[str] | None = None,
44
68
  filters: list[tuple[str, Any]] | None = None,
45
69
  dtype: DType | None = None,
70
+ context: DataSourceContext = DataSourceContext(),
46
71
  ) -> None:
47
72
  """Initialize a new GEE instance.
48
73
 
49
74
  Args:
50
- config: configuration for this layer.
51
- collection_name: the Earth Engine collection to ingest images from
75
+ collection_name: the Earth Engine ImageCollection to ingest images from
52
76
  gcs_bucket_name: the Cloud Storage bucket to export GEE images to
53
77
  index_cache_dir: cache directory to store rtree index
54
78
  service_account_name: name of the service account to use for authentication
55
79
  service_account_credentials: service account credentials filename
80
+ bands: the list of bands to ingest, in case the layer config is not present
81
+ in the context.
56
82
  filters: optional list of tuples (property_name, property_value) to filter
57
83
  images (using ee.Filter.eq)
58
84
  dtype: optional desired array data type. If the data obtained from GEE does
59
85
  not match this type, then it is converted.
86
+ context: the data source context.
60
87
  """
61
- self.config = config
62
88
  self.collection_name = collection_name
63
89
  self.gcs_bucket_name = gcs_bucket_name
64
- self.index_cache_dir = index_cache_dir
65
90
  self.filters = filters
66
91
  self.dtype = dtype
67
92
 
93
+ # Get index cache dir depending on dataset path.
94
+ if context.ds_path is not None:
95
+ self.index_cache_dir = join_upath(context.ds_path, index_cache_dir)
96
+ else:
97
+ self.index_cache_dir = UPath(index_cache_dir)
98
+
99
+ # Get bands we need to export.
100
+ if context.layer_config is not None:
101
+ self.bands = [
102
+ band
103
+ for band_set in context.layer_config.band_sets
104
+ for band in band_set.bands
105
+ ]
106
+ elif bands is not None:
107
+ self.bands = bands
108
+ else:
109
+ raise ValueError(
110
+ "bands must be specified if layer_config is not present in the context"
111
+ )
112
+
68
113
  self.bucket = storage.Client().bucket(self.gcs_bucket_name)
69
114
 
70
115
  credentials = ee.ServiceAccountCredentials(
@@ -72,44 +117,27 @@ class GEE(DataSource):
72
117
  )
73
118
  ee.Initialize(credentials)
74
119
 
75
- self.rtree_tmp_dir = tempfile.TemporaryDirectory()
76
- self.rtree_index = get_cached_rtree(
77
- self.index_cache_dir, self.rtree_tmp_dir.name, self._build_index
78
- )
120
+ self.index_cache_dir.mkdir(parents=True, exist_ok=True)
121
+ self.rtree_index = get_cached_rtree(self.index_cache_dir, self._build_index)
79
122
 
80
- @staticmethod
81
- def from_config(config: LayerConfig, ds_path: UPath) -> "GEE":
82
- """Creates a new GEE instance from a configuration dictionary."""
83
- d = config.data_source.config_dict
84
- kwargs = {
85
- "config": config,
86
- "collection_name": d["collection_name"],
87
- "gcs_bucket_name": d["gcs_bucket_name"],
88
- "service_account_name": d["service_account_name"],
89
- "service_account_credentials": d["service_account_credentials"],
90
- "filters": d.get("filters"),
91
- "index_cache_dir": join_upath(ds_path, d["index_cache_dir"]),
92
- }
93
- if "dtype" in d:
94
- kwargs["dtype"] = DType(d["dtype"])
95
-
96
- return GEE(**kwargs)
97
-
def get_collection(self) -> ee.ImageCollection:
    """Build the Earth Engine ImageCollection backing this data source.

    Every ``(property_name, property_value)`` pair in ``self.filters`` is applied,
    in order, as an ``ee.Filter.eq`` equality filter. With no filters configured
    the bare collection is returned.

    Returns:
        the (possibly filtered) ImageCollection.
    """
    collection = ee.ImageCollection(self.collection_name)
    for prop_name, prop_value in self.filters or []:
        collection = collection.filter(ee.Filter.eq(prop_name, prop_value))
    return collection
105
133
 
106
- def _build_index(self, rtree_index):
134
+ def _build_index(self, rtree_index: RtreeIndex) -> None:
107
135
  csv_blob = self.bucket.blob(f"{self.collection_name}/index.csv")
108
136
 
109
137
  if not csv_blob.exists():
110
138
  # Export feature collection of image metadata to GCS.
111
- def image_to_feature(image):
112
- geometry = image.geometry().transform(proj="EPSG:4326")
139
+ def image_to_feature(image: ee.Image) -> ee.Feature:
140
+ geometry = image.geometry().transform(proj="EPSG:4326", maxError=0.001)
113
141
  return ee.Feature(geometry, {"time": image.date().format()})
114
142
 
115
143
  fc = self.get_collection().map(image_to_feature)
@@ -121,17 +149,23 @@ class GEE(DataSource):
121
149
  fileFormat="CSV",
122
150
  )
123
151
  task.start()
124
- print(
125
- "started task to export GEE index "
126
- + f"for image collection {self.collection_name}"
152
+ logger.info(
153
+ "Started task to export GEE index for image collection %s",
154
+ self.collection_name,
127
155
  )
128
156
  while True:
129
157
  time.sleep(10)
130
158
  status_dict = task.status()
131
- print(status_dict)
159
+ logger.debug(
160
+ "Waiting for export task to complete, current status is %s",
161
+ status_dict,
162
+ )
132
163
  if status_dict["state"] in ["UNSUBMITTED", "READY", "RUNNING"]:
133
164
  continue
134
- assert status_dict["state"] == "COMPLETED"
165
+ elif status_dict["state"] != "COMPLETED":
166
+ raise ValueError(
167
+ f"got unexpected GEE task state {status_dict['state']}"
168
+ )
135
169
  break
136
170
 
137
171
  # Read the CSV and add rows into the rtree index.
@@ -141,15 +175,31 @@ class GEE(DataSource):
141
175
  shp = shapely.geometry.shape(json.loads(row[".geo"]))
142
176
  if "E" in row["time"]:
143
177
  unix_time = float(row["time"]) / 1000
144
- ts = datetime.fromtimestamp(unix_time, tz=timezone.utc)
178
+ ts = datetime.fromtimestamp(unix_time, tz=UTC)
145
179
  else:
146
- ts = datetime.fromisoformat(row["time"]).replace(
147
- tzinfo=timezone.utc
148
- )
180
+ ts = datetime.fromisoformat(row["time"]).replace(tzinfo=UTC)
149
181
  geometry = STGeometry(WGS84_PROJECTION, shp, (ts, ts))
150
182
  item = Item(row["system:index"], geometry)
151
183
  rtree_index.insert(shp.bounds, json.dumps(item.serialize()))
152
184
 
def get_item_by_name(self, name: str) -> Item:
    """Look up a single item of the collection by its ``system:index``.

    This queries Earth Engine directly (two blocking ``getInfo`` calls) instead
    of consulting the local rtree index.

    Args:
        name: the ``system:index`` of the image to look up.

    Returns:
        an Item whose geometry is the image footprint transformed to EPSG:4326
        and whose time range is the degenerate range (ts, ts) at the image date.
    """
    image = (
        self.get_collection().filter(ee.Filter.eq("system:index", name)).first()
    )
    footprint = image.geometry().transform(proj="EPSG:4326", maxError=0.001)
    shp = shapely.geometry.shape(footprint.getInfo())
    date_str = image.date().format().getInfo()
    timestamp = datetime.fromisoformat(date_str).replace(tzinfo=UTC)
    return Item(name, STGeometry(WGS84_PROJECTION, shp, (timestamp, timestamp)))
202
+
153
203
  def get_items(
154
204
  self, geometries: list[STGeometry], query_config: QueryConfig
155
205
  ) -> list[list[list[Item]]]:
@@ -176,7 +226,7 @@ class GEE(DataSource):
176
226
  continue
177
227
  cur_items.append(item)
178
228
 
179
- cur_items.sort(key=lambda item: item.geometry.time_range[0])
229
+ cur_items.sort(key=lambda item: item.geometry.time_range[0]) # type: ignore
180
230
 
181
231
  cur_groups = rslearn.data_sources.utils.match_candidate_items_to_window(
182
232
  geometry, cur_items, query_config
@@ -190,9 +240,143 @@ class GEE(DataSource):
190
240
  assert isinstance(serialized_item, dict)
191
241
  return Item.deserialize(serialized_item)
192
242
 
def item_to_image(self, item: Item) -> ee.image.Image:
    """Return the Earth Engine Image for ``item``, restricted to ``self.bands``.

    This is a separate method so subclasses can override it to post-process the
    image (for example rescaling or casting) before it is exported.
    """
    matching = self.get_collection().filter(ee.Filter.eq("system:index", item.name))
    return matching.first().select(self.bands)
253
+
def export_item(
    self,
    item: Item,
    blob_prefix: str,
    projection_and_bounds: tuple[Projection, PixelBounds] | None = None,
) -> None:
    """Export the item as GeoTIFF(s) under a GCS prefix and wait for completion.

    Starts an Earth Engine ``Export.image.toCloudStorage`` task that writes into
    ``self.gcs_bucket_name``. Its status is polled every 10 seconds until it
    reaches a terminal state. GEE may split the output into several files
    under ``blob_prefix``.

    Args:
        item: the item to export.
        blob_prefix: the prefix (folder) within the bucket to write to.
        projection_and_bounds: optionally export exactly this projection and
            pixel bounds instead of the native extent of the image.

    Raises:
        NoValidPixelsException: if GEE reports that the export region has no
            valid (un-masked) pixels.
        ExportException: if the task fails for any other reason.
        ValueError: if the task ends in a state other than COMPLETED or FAILED.
    """
    image = self.item_to_image(item)
    logger.info("Starting task to retrieve image %s", item.name)

    extent_kwargs: dict[str, Any]
    if projection_and_bounds is not None:
        dst_projection, bounds = projection_and_bounds
        transform = get_transform_from_projection_and_bounds(dst_projection, bounds)
        width = bounds[2] - bounds[0]
        height = bounds[3] - bounds[1]
        extent_kwargs = dict(
            crs=str(dst_projection.crs),
            crsTransform=[
                transform.a,
                transform.b,
                transform.c,
                transform.d,
                transform.e,
                transform.f,
            ],
            dimensions=f"{width}x{height}",
        )
    else:
        # Only the native-extent export needs the image's own projection, so the
        # blocking getInfo() round trip is made only in this branch.
        native_projection = image.select(self.bands[0]).projection().getInfo()
        # We pass scale instead of crsTransform since some images have positive
        # y resolution which means they are upside down and rasterio cannot merge
        # them.
        extent_kwargs = dict(
            crs=native_projection["crs"],
            scale=native_projection["transform"][0],
        )

    task = ee.batch.Export.image.toCloudStorage(
        image=image,
        description=item.name,
        bucket=self.gcs_bucket_name,
        fileNamePrefix=blob_prefix,
        maxPixels=10_000_000_000,
        fileFormat="GeoTIFF",
        skipEmptyTiles=True,
        **extent_kwargs,
    )
    task.start()
    while True:
        time.sleep(10)
        status_dict = task.status()
        state = status_dict["state"]
        if state in ["UNSUBMITTED", "READY", "RUNNING"]:
            continue
        if state == "COMPLETED":
            return
        if state != "FAILED":
            raise ValueError(f"got unexpected GEE task state {state}")
        # The task failed. Distinguish the benign "nothing to export" case from a
        # real error. Use .get() so a status without error_message does not turn
        # into an unrelated KeyError.
        error_message = status_dict.get("error_message", "unknown error")
        if error_message == NoValidPixelsException.GEE_MESSAGE:
            raise NoValidPixelsException()
        raise ExportException(f"GEE task failed: {error_message}")
327
+
def _merge_rasters(
    self,
    blobs: list[storage.Blob],
    crs_bounds: tuple[float, float, float, float] | None = None,
    res: float | None = None,
) -> tuple[npt.NDArray, Projection, PixelBounds]:
    """Merge multiple rasters split up during export by GEE.

    GEE can produce multiple rasters if it determines the file size exceeds its
    internal limit. In that case they are downloaded to a temporary directory
    and stitched back together with ``rasterio.merge.merge``.

    Args:
        blobs: the list of GCS blobs where the rasters were written.
        crs_bounds: generate merged output under this bounds, in CRS coordinates
            (not pixel units).
        res: generate merged output under this resolution.

    Returns:
        a tuple (array, projection, bounds) where the projection and bounds
        indicate the extent of the array.

    Raises:
        ValueError: if ``blobs`` is empty.
    """
    # Local import: only needed here to guarantee every dataset gets closed.
    from contextlib import ExitStack

    if not blobs:
        raise ValueError("no exported rasters to merge")

    # The ExitStack is exited before the temporary directory, so all datasets
    # are closed before their files are removed. This holds even when a later
    # download or the merge itself raises.
    with tempfile.TemporaryDirectory() as tmp_dir_name, ExitStack() as stack:
        rasterio_datasets = []
        for blob in blobs:
            local_fname = os.path.join(tmp_dir_name, blob.name.split("/")[-1])
            blob.download_to_filename(local_fname)
            rasterio_datasets.append(stack.enter_context(rasterio.open(local_fname)))

        merge_kwargs: dict[str, Any] = dict(
            sources=rasterio_datasets,
            bounds=crs_bounds,
            res=res,
        )
        if self.dtype is not None:
            merge_kwargs["dtype"] = self.dtype.value
        array, transform = rasterio.merge.merge(**merge_kwargs)
        projection, bounds = get_raster_projection_and_bounds_from_transform(
            rasterio_datasets[0].crs,
            transform,
            array.shape[2],
            array.shape[1],
        )

    return array, projection, bounds
376
+
def ingest(
    self,
    tile_store: TileStoreWithLayer,
    items: list[Item],
    geometries: list[list[STGeometry]],
) -> None:
    """Ingest items into the given tile store.

    Each item whose raster (for ``self.bands``) is not already in the tile store
    is exported in its native projection to GCS, downloaded, and written to the
    tile store. The intermediate GCS files are always deleted afterwards, even
    if writing to the tile store fails.

    Args:
        tile_store: the tile store to ingest into
        items: the items to ingest
        geometries: a list of geometries needed for each item. Not used here:
            the full native extent of each image is ingested.
    """
    for item in items:
        if tile_store.is_raster_ready(item.name, self.bands):
            continue

        # Export the item to GCS. The pid in the prefix presumably keeps
        # concurrent ingestion processes from sharing an export folder.
        blob_prefix = f"{self.collection_name}/{item.name}.{os.getpid()}/"
        self.export_item(item, blob_prefix)

        # See what files the export produced.
        # If there are multiple, then we merge them into one file since that's the
        # simplest way to handle it.
        blobs = list(self.bucket.list_blobs(prefix=blob_prefix))
        try:
            if len(blobs) == 1:
                with tempfile.TemporaryDirectory() as tmp_dir_name:
                    local_fname = os.path.join(
                        tmp_dir_name, blobs[0].name.split("/")[-1]
                    )
                    blobs[0].download_to_filename(local_fname)
                    tile_store.write_raster_file(
                        item.name, self.bands, UPath(local_fname)
                    )
            else:
                array, projection, bounds = self._merge_rasters(blobs)
                tile_store.write_raster(
                    item.name, self.bands, projection, bounds, array
                )
        finally:
            # The exports are only intermediates under a pid-specific prefix, so
            # a failed write must not leave them orphaned in the bucket.
            for blob in blobs:
                blob.delete()
421
+
def is_raster_ready(
    self, layer_name: str, item_name: str, bands: list[str]
) -> bool:
    """Checks if this raster has been written to the store.

    When GEE is used directly as a tile store nothing is stored ahead of time.
    Each ``read_raster`` call exports the requested extent from Earth Engine on
    demand, so rasters are always considered ready.

    Args:
        layer_name: the layer name or alias.
        item_name: the item.
        bands: the list of bands identifying which specific raster to read.

    Returns:
        always True.
    """
    # Always ready since reads are served directly from Google Earth Engine.
    return True
438
+
def get_raster_bands(self, layer_name: str, item_name: str) -> list[list[str]]:
    """List the band sets available for an item.

    Every item is served on demand with the same configured bands, so the
    result is always the single band set ``self.bands``, whatever the layer or
    item.

    Args:
        layer_name: the layer name or alias (unused).
        item_name: the item (unused).

    Returns:
        a one-element list containing ``self.bands``.
    """
    band_sets: list[list[str]] = [self.bands]
    return band_sets
452
+
def get_raster_bounds(
    self, layer_name: str, item_name: str, bands: list[str], projection: Projection
) -> PixelBounds:
    """Get the bounds of the raster in the specified projection.

    The bounds are derived from the item's footprint (looked up from Earth
    Engine via ``get_item_by_name``), reprojected to ``projection``, and
    expanded outward to whole pixels.

    Args:
        layer_name: the layer name or alias.
        item_name: the item to check.
        bands: the list of bands identifying which specific raster to read. These
            bands must match the bands of a stored raster.
        projection: the projection to get the raster's bounds in.

    Returns:
        the bounds of the raster in the projection.
    """
    import math

    item = self.get_item_by_name(item_name)
    min_x, min_y, max_x, max_y = item.geometry.to_projection(projection).shp.bounds
    # floor/ceil rather than int(): int() truncates toward zero, which would
    # shrink the max edges and, for negative coordinates, the min edges too,
    # leaving part of the footprint uncovered.
    return (
        math.floor(min_x),
        math.floor(min_y),
        math.ceil(max_x),
        math.ceil(max_y),
    )
476
+
def read_raster(
    self,
    layer_name: str,
    item_name: str,
    bands: list[str],
    projection: Projection,
    bounds: PixelBounds,
    resampling: Resampling = Resampling.bilinear,
) -> npt.NDArray[Any]:
    """Read raster data by exporting the requested extent from Earth Engine.

    The item is exported with exactly ``projection`` and ``bounds`` to a
    temporary GCS prefix and read back. The intermediate GCS files are then
    deleted.

    Args:
        layer_name: the layer name or alias.
        item_name: the item to read.
        bands: the list of bands identifying which specific raster to read. These
            bands must match the bands of a stored raster.
        projection: the projection to read in.
        bounds: the bounds to read.
        resampling: the resampling method used when warping a single exported
            file onto the requested grid.

    Returns:
        the raster data as a (bands, height, width) array. If GEE reports that
        the region has no valid pixels, an all-zero float32 array of that shape.

    Raises:
        NotImplementedError: if GEE split the export into several files and
            ``projection`` does not satisfy x_resolution == -y_resolution.
    """
    # Local import: WarpedVRT lives in rasterio.vrt, which is not imported at the
    # top of this module.
    import rasterio.vrt

    width = bounds[2] - bounds[0]
    height = bounds[3] - bounds[1]

    # Extract the requested extent and export to GCS.
    bounds_str = f"{bounds[0]}_{bounds[1]}_{bounds[2]}_{bounds[3]}"
    item = self.get_item_by_name(item_name)
    blob_prefix = f"{self.collection_name}/{item.name}.{bounds_str}.{os.getpid()}/"

    try:
        self.export_item(
            item, blob_prefix, projection_and_bounds=(projection, bounds)
        )
    except NoValidPixelsException:
        # No valid pixels means the result should be empty.
        logger.info(
            "No valid pixels in item %s with projection=%s, bounds=%s, returning empty image",
            item.name,
            projection,
            bounds,
        )
        return np.zeros((len(bands), height, width), dtype=np.float32)

    blobs = list(self.bucket.list_blobs(prefix=blob_prefix))
    try:
        if len(blobs) == 1:
            # With a single output, we can simply read it with vrt.
            wanted_transform = get_transform_from_projection_and_bounds(
                projection, bounds
            )
            buf = io.BytesIO()
            blobs[0].download_to_file(buf)
            buf.seek(0)
            with rasterio.open(buf) as src, rasterio.vrt.WarpedVRT(
                src,
                crs=projection.crs,
                transform=wanted_transform,
                width=width,
                height=height,
                resampling=resampling,
            ) as vrt:
                return vrt.read()

        # With multiple outputs, we need to merge them together.
        # We can set the bounds in CRS coordinates when we do the merging.
        if projection.x_resolution != -projection.y_resolution:
            raise NotImplementedError(
                "Only projection with x_res=-y_res is supported for GEE direct materialization"
            )
        crs_bounds = (
            bounds[0] * projection.x_resolution,
            bounds[3] * projection.y_resolution,
            bounds[2] * projection.x_resolution,
            bounds[1] * projection.y_resolution,
        )
        src_array, _, src_bounds = self._merge_rasters(
            blobs, crs_bounds=crs_bounds, res=projection.x_resolution
        )

        # We copy the array if its bounds don't match exactly.
        if src_bounds == bounds:
            return src_array
        dst_array = np.zeros(
            (src_array.shape[0], height, width),
            dtype=src_array.dtype,
        )
        copy_spatial_array(src_array, dst_array, src_bounds[0:2], bounds[0:2])
        return dst_array
    finally:
        # The exported GeoTIFFs are only intermediates. Delete them so repeated
        # reads do not accumulate files in the bucket.
        for blob in blobs:
            blob.delete()
565
+
def materialize(
    self,
    window: Window,
    item_groups: list[list[Item]],
    layer_name: str,
    layer_cfg: LayerConfig,
) -> None:
    """Materialize the window's raster layer by reading directly from GEE.

    The data source itself is wrapped as the tile store for ``layer_name`` (GEE
    also subclasses TileStore), and the standard RasterMaterializer is used.

    Args:
        window: the window to materialize
        item_groups: the items from get_items
        layer_name: the name of this layer
        layer_cfg: the config of this layer
    """
    materializer = RasterMaterializer()
    gee_tile_store = TileStoreWithLayer(self, layer_name)
    materializer.materialize(
        gee_tile_store,
        window,
        layer_name,
        layer_cfg,
        item_groups,
    )
588
+
589
+
class GoogleSatelliteEmbeddings(GEE):
    """GEE data source for the Google Satellite Embeddings.

    Exports the 64 embedding bands (A00 through A63) of the annual collection,
    linearly rescaled and converted to uint16.

    See here for details:
    https://developers.google.com/earth-engine/datasets/catalog/GOOGLE_SATELLITE_EMBEDDING_V1_ANNUAL
    """

    COLLECTION_NAME = "GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL"

    # Rescaling applied before export: value * EMBEDDING_SCALE + EMBEDDING_OFFSET,
    # then cast to uint16. NOTE(review): this presumably assumes embedding values
    # lie roughly in [-1, 1] so results land in [0, 16384]; confirm against the
    # dataset documentation.
    EMBEDDING_SCALE = 8192
    EMBEDDING_OFFSET = 8192

    def __init__(
        self,
        gcs_bucket_name: str,
        index_cache_dir: str,
        service_account_name: str,
        service_account_credentials: str,
        context: DataSourceContext = DataSourceContext(),
    ) -> None:
        """Create a new GoogleSatelliteEmbeddings. See GEE for the arguments."""
        super().__init__(
            bands=[f"A{idx:02d}" for idx in range(64)],
            collection_name=self.COLLECTION_NAME,
            gcs_bucket_name=gcs_bucket_name,
            index_cache_dir=index_cache_dir,
            service_account_name=service_account_name,
            service_account_credentials=service_account_credentials,
            context=context,
        )

    # Override to add conversion to uint16.
    def item_to_image(self, item: Item) -> ee.image.Image:
        """Get the Image corresponding to the Item, rescaled to uint16.

        The lookup and band selection are delegated to GEE.item_to_image so the
        two implementations cannot drift apart.
        """
        image = super().item_to_image(item)
        return (
            image.multiply(self.EMBEDDING_SCALE)
            .add(self.EMBEDDING_OFFSET)
            .toUint16()
        )