rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
@@ -2,10 +2,11 @@
2
2
 
3
3
  import io
4
4
  import json
5
+ import os
5
6
  import tempfile
6
7
  import xml.etree.ElementTree as ET
7
8
  from collections.abc import Callable, Generator
8
- from datetime import datetime, timedelta, timezone
9
+ from datetime import UTC, datetime
9
10
  from enum import Enum
10
11
  from typing import Any, BinaryIO
11
12
 
@@ -14,7 +15,6 @@ import dateutil.parser
14
15
  import fiona
15
16
  import fiona.transform
16
17
  import numpy.typing as npt
17
- import pytimeparse
18
18
  import rasterio
19
19
  import shapely
20
20
  import tqdm
@@ -22,27 +22,21 @@ from rasterio.crs import CRS
22
22
  from upath import UPath
23
23
 
24
24
  import rslearn.data_sources.utils
25
- import rslearn.utils.mgrs
26
- from rslearn.config import LayerConfig, RasterLayerConfig
27
25
  from rslearn.const import SHAPEFILE_AUX_EXTENSIONS, WGS84_EPSG, WGS84_PROJECTION
28
- from rslearn.tile_stores import PrefixedTileStore, TileStore
29
- from rslearn.utils import (
30
- GridIndex,
31
- Projection,
32
- STGeometry,
33
- daterange,
34
- )
26
+ from rslearn.tile_stores import TileStoreWithLayer
27
+ from rslearn.utils import GridIndex, Projection, STGeometry, daterange
35
28
  from rslearn.utils.fsspec import get_upath_local, join_upath, open_atomic
29
+ from rslearn.utils.raster_format import get_raster_projection_and_bounds
36
30
 
37
- from .copernicus import get_harmonize_callback
31
+ from .copernicus import get_harmonize_callback, get_sentinel2_tiles
38
32
  from .data_source import (
39
33
  DataSource,
34
+ DataSourceContext,
40
35
  Item,
41
36
  ItemLookupDataSource,
42
37
  QueryConfig,
43
38
  RetrieveItemDataSource,
44
39
  )
45
- from .raster_source import get_needed_projections, ingest_raster
46
40
 
47
41
 
48
42
  class NaipItem(Item):
@@ -66,7 +60,7 @@ class NaipItem(Item):
66
60
  return d
67
61
 
68
62
  @staticmethod
69
- def deserialize(d: dict) -> Item:
63
+ def deserialize(d: dict) -> "NaipItem":
70
64
  """Deserializes an item from a JSON-decoded dictionary."""
71
65
  item = super(NaipItem, NaipItem).deserialize(d)
72
66
  return NaipItem(
@@ -89,16 +83,15 @@ class Naip(DataSource):
89
83
 
90
84
  def __init__(
91
85
  self,
92
- config: LayerConfig,
93
- index_cache_dir: UPath,
86
+ index_cache_dir: str,
94
87
  use_rtree_index: bool = False,
95
88
  states: list[str] | None = None,
96
89
  years: list[int] | None = None,
90
+ context: DataSourceContext = DataSourceContext(),
97
91
  ) -> None:
98
92
  """Initialize a new Naip instance.
99
93
 
100
94
  Args:
101
- config: the LayerConfig of the layer containing this data source.
102
95
  index_cache_dir: directory to cache index shapefiles.
103
96
  use_rtree_index: whether to create an rtree index to enable faster lookups
104
97
  (default false)
@@ -106,40 +99,30 @@ class Naip(DataSource):
106
99
  the search. If use_rtree_index is enabled, the rtree will only be
107
100
  populated with data from these states.
108
101
  years: optional list of years to restrict the search
102
+ context: the data source context.
109
103
  """
110
- self.config = config
111
- self.index_cache_dir = index_cache_dir
104
+ # If context is provided, we join the directory with the dataset path,
105
+ # otherwise we treat it directly as UPath.
106
+ if context.ds_path is not None:
107
+ self.index_cache_dir = join_upath(context.ds_path, index_cache_dir)
108
+ else:
109
+ self.index_cache_dir = UPath(index_cache_dir)
110
+
112
111
  self.states = states
113
112
  self.years = years
114
113
 
115
- self.bucket = boto3.resource("s3").Bucket(self.bucket_name)
114
+ self.index_cache_dir.mkdir(parents=True, exist_ok=True)
116
115
 
116
+ self.bucket = boto3.resource("s3").Bucket(self.bucket_name)
117
+ self.rtree_index: Any | None = None
117
118
  if use_rtree_index:
118
119
  from rslearn.utils.rtree_index import RtreeIndex, get_cached_rtree
119
120
 
120
- def build_fn(index: RtreeIndex):
121
+ def build_fn(index: RtreeIndex) -> None:
121
122
  for item in self._read_index_shapefiles(desc="Building rtree index"):
122
123
  index.insert(item.geometry.shp.bounds, json.dumps(item.serialize()))
123
124
 
124
- self.rtree_tmp_dir = tempfile.TemporaryDirectory()
125
- self.rtree_index = get_cached_rtree(
126
- self.index_cache_dir, self.rtree_tmp_dir.name, build_fn
127
- )
128
- else:
129
- self.rtree_index = None
130
-
131
- @staticmethod
132
- def from_config(config: LayerConfig, ds_path: UPath) -> "Naip":
133
- """Creates a new Naip instance from a configuration dictionary."""
134
- assert isinstance(config, RasterLayerConfig)
135
- d = config.data_source.config_dict
136
- kwargs = dict(
137
- config=config,
138
- index_cache_dir=join_upath(ds_path, d["index_cache_dir"]),
139
- )
140
- if "use_rtree_index" in d:
141
- kwargs["use_rtree_index"] = d["use_rtree_index"]
142
- return Naip(**kwargs)
125
+ self.rtree_index = get_cached_rtree(self.index_cache_dir, build_fn)
143
126
 
144
127
  def _download_manifest(self) -> UPath:
145
128
  """Download the manifest that enumerates files in the bucket.
@@ -149,7 +132,7 @@ class Naip(DataSource):
149
132
  """
150
133
  manifest_path = self.index_cache_dir / self.manifest_fname
151
134
  if not manifest_path.exists():
152
- with manifest_path.open("wb") as dst:
135
+ with open_atomic(manifest_path, "wb") as dst:
153
136
  self.bucket.download_fileobj(
154
137
  self.manifest_fname,
155
138
  dst,
@@ -195,7 +178,9 @@ class Naip(DataSource):
195
178
  blob_path, dst, ExtraArgs={"RequestPayer": "requester"}
196
179
  )
197
180
 
198
- def _read_index_shapefiles(self, desc=None) -> Generator[NaipItem, None, None]:
181
+ def _read_index_shapefiles(
182
+ self, desc: str | None = None
183
+ ) -> Generator[NaipItem, None, None]:
199
184
  """Read the index shapefiles and yield NaipItems corresponding to each image."""
200
185
  self._download_index_shapefiles()
201
186
 
@@ -275,7 +260,7 @@ class Naip(DataSource):
275
260
  else:
276
261
  src_img_date = fname_parts[5]
277
262
  time = datetime.strptime(src_img_date, "%Y%m%d").replace(
278
- tzinfo=timezone.utc
263
+ tzinfo=UTC
279
264
  )
280
265
 
281
266
  geometry = STGeometry(WGS84_PROJECTION, shp, (time, time))
@@ -288,7 +273,7 @@ class Naip(DataSource):
288
273
 
289
274
  def get_items(
290
275
  self, geometries: list[STGeometry], query_config: QueryConfig
291
- ) -> list[list[list[Item]]]:
276
+ ) -> list[list[list[NaipItem]]]:
292
277
  """Get a list of items in the data source intersecting the given geometries.
293
278
 
294
279
  Args:
@@ -302,7 +287,7 @@ class Naip(DataSource):
302
287
  geometry.to_projection(WGS84_PROJECTION) for geometry in geometries
303
288
  ]
304
289
 
305
- items = [[] for _ in geometries]
290
+ items: list = [[] for _ in geometries]
306
291
  if self.rtree_index:
307
292
  for idx, geometry in enumerate(wgs84_geometries):
308
293
  encoded_items = self.rtree_index.query(geometry.shp.bounds)
@@ -331,15 +316,15 @@ class Naip(DataSource):
331
316
  groups.append(cur_groups)
332
317
  return groups
333
318
 
334
- def deserialize_item(self, serialized_item: Any) -> Item:
319
+ def deserialize_item(self, serialized_item: Any) -> NaipItem:
335
320
  """Deserializes an item from JSON-decoded data."""
336
321
  assert isinstance(serialized_item, dict)
337
322
  return NaipItem.deserialize(serialized_item)
338
323
 
339
324
  def ingest(
340
325
  self,
341
- tile_store: TileStore,
342
- items: list[Item],
326
+ tile_store: TileStoreWithLayer,
327
+ items: list[NaipItem],
343
328
  geometries: list[list[STGeometry]],
344
329
  ) -> None:
345
330
  """Ingest items into the given tile store.
@@ -349,29 +334,17 @@ class Naip(DataSource):
349
334
  items: the items to ingest
350
335
  geometries: a list of geometries needed for each item
351
336
  """
352
- for item, cur_geometries in zip(items, geometries):
337
+ for item in items:
353
338
  bands = ["R", "G", "B", "IR"]
354
- cur_tile_store = PrefixedTileStore(tile_store, (item.name, "_".join(bands)))
355
- needed_projections = get_needed_projections(
356
- cur_tile_store, bands, self.config.band_sets, cur_geometries
357
- )
358
- if not needed_projections:
339
+ if tile_store.is_raster_ready(item.name, bands):
359
340
  continue
360
341
 
361
- buf = io.BytesIO()
362
- self.bucket.download_fileobj(
363
- item.blob_path, buf, ExtraArgs={"RequestPayer": "requester"}
364
- )
365
- buf.seek(0)
366
- with rasterio.open(buf) as raster:
367
- for projection in needed_projections:
368
- ingest_raster(
369
- tile_store=cur_tile_store,
370
- raster=raster,
371
- projection=projection,
372
- time_range=item.geometry.time_range,
373
- layer_config=self.config,
374
- )
342
+ with tempfile.TemporaryDirectory() as tmp_dir:
343
+ fname = os.path.join(tmp_dir, item.blob_path.split("/")[-1])
344
+ self.bucket.download_file(
345
+ item.blob_path, fname, ExtraArgs={"RequestPayer": "requester"}
346
+ )
347
+ tile_store.write_raster_file(item.name, bands, UPath(fname))
375
348
 
376
349
 
377
350
  class Sentinel2Modality(Enum):
@@ -407,7 +380,7 @@ class Sentinel2Item(Item):
407
380
  return d
408
381
 
409
382
  @staticmethod
410
- def deserialize(d: dict) -> Item:
383
+ def deserialize(d: dict) -> "Sentinel2Item":
411
384
  """Deserializes an item from a JSON-decoded dictionary."""
412
385
  if "name" not in d:
413
386
  d["name"] = d["blob_path"].split("/")[-1].split(".tif")[0]
@@ -420,7 +393,9 @@ class Sentinel2Item(Item):
420
393
  )
421
394
 
422
395
 
423
- class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
396
+ class Sentinel2(
397
+ ItemLookupDataSource[Sentinel2Item], RetrieveItemDataSource[Sentinel2Item]
398
+ ):
424
399
  """A data source for Sentinel-2 L1C and L2A imagery on AWS.
425
400
 
426
401
  Specifically, uses the sentinel-s2-l1c and sentinel-s2-l2a S3 buckets maintained by
@@ -474,61 +449,39 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
474
449
 
475
450
  def __init__(
476
451
  self,
477
- config: LayerConfig,
478
452
  modality: Sentinel2Modality,
479
- metadata_cache_dir: UPath,
480
- max_time_delta: timedelta = timedelta(days=30),
453
+ metadata_cache_dir: str,
481
454
  sort_by: str | None = None,
482
455
  harmonize: bool = False,
456
+ context: DataSourceContext = DataSourceContext(),
483
457
  ) -> None:
484
458
  """Initialize a new Sentinel2 instance.
485
459
 
486
460
  Args:
487
- config: the LayerConfig of the layer containing this data source.
488
461
  modality: L1C or L2A.
489
462
  metadata_cache_dir: directory to cache product metadata files.
490
- max_time_delta: maximum time before a query start time or after a
491
- query end time to look for products. This is required due to the large
492
- number of available products, and defaults to 30 days.
493
463
  sort_by: can be "cloud_cover", default arbitrary order; only has effect for
494
464
  SpaceMode.WITHIN.
495
465
  harmonize: harmonize pixel values across different processing baselines,
496
466
  see https://developers.google.com/earth-engine/datasets/catalog/COPERNICUS_S2_SR_HARMONIZED
467
+ context: the data source context.
497
468
  """ # noqa: E501
498
- self.config = config
469
+ # If context is provided, we join the directory with the dataset path,
470
+ # otherwise we treat it directly as UPath.
471
+ if context.ds_path is not None:
472
+ self.metadata_cache_dir = join_upath(context.ds_path, metadata_cache_dir)
473
+ else:
474
+ self.metadata_cache_dir = UPath(metadata_cache_dir)
475
+
499
476
  self.modality = modality
500
- self.metadata_cache_dir = metadata_cache_dir
501
- self.max_time_delta = max_time_delta
502
477
  self.sort_by = sort_by
503
478
  self.harmonize = harmonize
504
479
 
505
480
  bucket_name = self.bucket_names[modality]
506
481
  self.bucket = boto3.resource("s3").Bucket(bucket_name)
507
482
 
508
- @staticmethod
509
- def from_config(config: LayerConfig, ds_path: UPath) -> "Sentinel2":
510
- """Creates a new Sentinel2 instance from a configuration dictionary."""
511
- assert isinstance(config, RasterLayerConfig)
512
- d = config.data_source.config_dict
513
- kwargs = dict(
514
- config=config,
515
- modality=Sentinel2Modality(d["modality"]),
516
- metadata_cache_dir=join_upath(ds_path, d["metadata_cache_dir"]),
517
- )
518
-
519
- if "max_time_delta" in d:
520
- kwargs["max_time_delta"] = timedelta(
521
- seconds=pytimeparse.parse(d["max_time_delta"])
522
- )
523
- simple_optionals = ["sort_by", "harmonize"]
524
- for k in simple_optionals:
525
- if k in d:
526
- kwargs[k] = d[k]
527
-
528
- return Sentinel2(**kwargs)
529
-
530
483
  def _read_products(
531
- self, needed_cell_months: set[tuple[str, int, int, int]]
484
+ self, needed_cell_months: set[tuple[str, int, int]]
532
485
  ) -> Generator[Sentinel2Item, None, None]:
533
486
  """Read productInfo.json files and yield relevant Sentinel2Items.
534
487
 
@@ -603,7 +556,7 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
603
556
 
604
557
  def get_items(
605
558
  self, geometries: list[STGeometry], query_config: QueryConfig
606
- ) -> list[list[list[Item]]]:
559
+ ) -> list[list[list[Sentinel2Item]]]:
607
560
  """Get a list of items in the data source intersecting the given geometries.
608
561
 
609
562
  Args:
@@ -626,14 +579,14 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
626
579
  raise ValueError(
627
580
  "Sentinel2 on AWS requires geometry time ranges to be set"
628
581
  )
629
- for cell_id in rslearn.utils.mgrs.for_each_cell(wgs84_geometry.shp.bounds):
582
+ for cell_id in get_sentinel2_tiles(wgs84_geometry, self.metadata_cache_dir):
630
583
  for ts in daterange(
631
- wgs84_geometry.time_range[0] - self.max_time_delta,
632
- wgs84_geometry.time_range[1] + self.max_time_delta,
584
+ wgs84_geometry.time_range[0],
585
+ wgs84_geometry.time_range[1],
633
586
  ):
634
587
  needed_cell_months.add((cell_id, ts.year, ts.month))
635
588
 
636
- items_by_cell = {}
589
+ items_by_cell: dict[str, list[Sentinel2Item]] = {}
637
590
  for item in self._read_products(needed_cell_months):
638
591
  cell_id = "".join(item.blob_path.split("/")[1:4])
639
592
  if cell_id not in items_by_cell:
@@ -643,7 +596,7 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
643
596
  groups = []
644
597
  for geometry, wgs84_geometry in zip(geometries, wgs84_geometries):
645
598
  items = []
646
- for cell_id in rslearn.utils.mgrs.for_each_cell(wgs84_geometry.shp.bounds):
599
+ for cell_id in get_sentinel2_tiles(wgs84_geometry, self.metadata_cache_dir):
647
600
  for item in items_by_cell.get(cell_id, []):
648
601
  try:
649
602
  item_geom = item.geometry.to_projection(geometry.projection)
@@ -666,7 +619,7 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
666
619
 
667
620
  return groups
668
621
 
669
- def get_item_by_name(self, name: str) -> Item:
622
+ def get_item_by_name(self, name: str) -> Sentinel2Item:
670
623
  """Gets an item by name."""
671
624
  # Product name is like:
672
625
  # S2B_MSIL1C_20240201T230819_N0510_R015_T51CWM_20240202T012755.
@@ -685,12 +638,14 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
685
638
  return item
686
639
  raise ValueError(f"item {name} not found")
687
640
 
688
- def deserialize_item(self, serialized_item: Any) -> Item:
641
+ def deserialize_item(self, serialized_item: Any) -> Sentinel2Item:
689
642
  """Deserializes an item from JSON-decoded data."""
690
643
  assert isinstance(serialized_item, dict)
691
644
  return Sentinel2Item.deserialize(serialized_item)
692
645
 
693
- def retrieve_item(self, item: Item) -> Generator[tuple[str, BinaryIO], None, None]:
646
+ def retrieve_item(
647
+ self, item: Sentinel2Item
648
+ ) -> Generator[tuple[str, BinaryIO], None, None]:
694
649
  """Retrieves the rasters corresponding to an item as file streams."""
695
650
  for fname, _ in self.band_fnames[self.modality]:
696
651
  buf = io.BytesIO()
@@ -701,7 +656,7 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
701
656
  yield (fname, buf)
702
657
 
703
658
  def _get_harmonize_callback(
704
- self, item: Item
659
+ self, item: Sentinel2Item
705
660
  ) -> Callable[[npt.NDArray], npt.NDArray] | None:
706
661
  """Gets the harmonization callback for the given item.
707
662
 
@@ -715,6 +670,8 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
715
670
  return None
716
671
  # Search metadata XML for the RADIO_ADD_OFFSET tag.
717
672
  # This contains the per-band offset, but we assume all bands have the same offset.
673
+ if item.geometry.time_range is None:
674
+ raise ValueError("Sentinel2 on AWS requires geometry time ranges to be set")
718
675
  ts = item.geometry.time_range[0]
719
676
  metadata_fname = (
720
677
  f"products/{ts.year}/{ts.month}/{ts.day}/{item.name}/metadata.xml"
@@ -724,13 +681,15 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
724
681
  metadata_fname, buf, ExtraArgs={"RequestPayer": "requester"}
725
682
  )
726
683
  buf.seek(0)
727
- tree = ET.ElementTree(ET.fromstring(buf.getvalue()))
684
+ tree: ET.ElementTree[ET.Element[str]] = ET.ElementTree(
685
+ ET.fromstring(buf.getvalue())
686
+ )
728
687
  return get_harmonize_callback(tree)
729
688
 
730
689
  def ingest(
731
690
  self,
732
- tile_store: TileStore,
733
- items: list[Item],
691
+ tile_store: TileStoreWithLayer,
692
+ items: list[Sentinel2Item],
734
693
  geometries: list[list[STGeometry]],
735
694
  ) -> None:
736
695
  """Ingest items into the given tile store.
@@ -740,42 +699,43 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
740
699
  items: the items to ingest
741
700
  geometries: a list of geometries needed for each item
742
701
  """
743
- for item, cur_geometries in zip(items, geometries):
744
- harmonize_callback = self._get_harmonize_callback(item)
745
-
702
+ for item in items:
746
703
  for fname, band_names in self.band_fnames[self.modality]:
747
- cur_tile_store = PrefixedTileStore(
748
- tile_store, (item.name, "_".join(band_names))
749
- )
750
- needed_projections = get_needed_projections(
751
- cur_tile_store, band_names, self.config.band_sets, cur_geometries
752
- )
753
- if not needed_projections:
704
+ if tile_store.is_raster_ready(item.name, band_names):
754
705
  continue
755
706
 
756
- buf = io.BytesIO()
757
- try:
758
- self.bucket.download_fileobj(
759
- item.blob_path + fname,
760
- buf,
761
- ExtraArgs={"RequestPayer": "requester"},
762
- )
763
- except Exception as e:
764
- # TODO: sometimes for some reason object doesn't exist
765
- # we should probably investigate further why it happens
766
- # and then should create the layer here and mark it completed
767
- print(
768
- f"warning: got error {e} downloading {item.blob_path + fname}"
769
- )
770
- continue
771
- buf.seek(0)
772
- with rasterio.open(buf) as raster:
773
- for projection in needed_projections:
774
- ingest_raster(
775
- tile_store=cur_tile_store,
776
- raster=raster,
777
- projection=projection,
778
- time_range=item.geometry.time_range,
779
- layer_config=self.config,
780
- array_callback=harmonize_callback,
707
+ with tempfile.TemporaryDirectory() as tmp_dir:
708
+ local_fname = os.path.join(tmp_dir, fname.split("/")[-1])
709
+
710
+ try:
711
+ self.bucket.download_file(
712
+ item.blob_path + fname,
713
+ local_fname,
714
+ ExtraArgs={"RequestPayer": "requester"},
715
+ )
716
+ except Exception as e:
717
+ # TODO: sometimes for some reason object doesn't exist
718
+ # we should probably investigate further why it happens
719
+ # and then should create the layer here and mark it completed
720
+ print(
721
+ f"warning: got error {e} downloading {item.blob_path + fname}"
722
+ )
723
+ continue
724
+
725
+ harmonize_callback = self._get_harmonize_callback(item)
726
+
727
+ if harmonize_callback is not None:
728
+ # In this case we need to read the array, convert the pixel
729
+ # values, and pass modified array directly to the TileStore.
730
+ with rasterio.open(local_fname) as src:
731
+ array = src.read()
732
+ projection, bounds = get_raster_projection_and_bounds(src)
733
+ array = harmonize_callback(array)
734
+ tile_store.write_raster(
735
+ item.name, band_names, projection, bounds, array
736
+ )
737
+
738
+ else:
739
+ tile_store.write_raster_file(
740
+ item.name, band_names, UPath(local_fname)
781
741
  )
@@ -0,0 +1,131 @@
1
+ """Data source for Sentinel-1 on AWS."""
2
+
3
+ import os
4
+ import tempfile
5
+ from typing import Any
6
+
7
+ import boto3
8
+ from upath import UPath
9
+
10
+ from rslearn.data_sources.copernicus import (
11
+ CopernicusItem,
12
+ Sentinel1OrbitDirection,
13
+ Sentinel1Polarisation,
14
+ Sentinel1ProductType,
15
+ )
16
+ from rslearn.data_sources.copernicus import Sentinel1 as CopernicusSentinel1
17
+ from rslearn.log_utils import get_logger
18
+ from rslearn.tile_stores import TileStore, TileStoreWithLayer
19
+ from rslearn.utils.geometry import STGeometry
20
+
21
+ from .data_source import DataSource, DataSourceContext, QueryConfig
22
+
23
+ WRS2_GRID_SIZE = 1.0
24
+
25
+ logger = get_logger(__name__)
26
+
27
+
28
+ class Sentinel1(DataSource, TileStore):
29
+ """A data source for Sentinel-1 GRD imagery on AWS.
30
+
31
+ Specifically, uses the sentinel-s1-l1c S3 bucket maintained by Sinergise. See
32
+ https://aws.amazon.com/marketplace/pp/prodview-uxrsbvhd35ifw for details about the
33
+ bucket.
34
+
35
+ We use the Copernicus API for metadata search. So the bucket is only used for
36
+ downloading the images.
37
+
38
+ Currently, it only supports GRD IW DV scenes.
39
+ """
40
+
41
+ bucket_name = "sentinel-s1-l1c"
42
+ bands = ["vv", "vh"]
43
+
44
+ def __init__(
45
+ self,
46
+ orbit_direction: Sentinel1OrbitDirection | None = None,
47
+ context: DataSourceContext = DataSourceContext(),
48
+ ) -> None:
49
+ """Initialize a new Sentinel1 instance.
50
+
51
+ Args:
52
+ orbit_direction: optional orbit direction to filter by.
53
+ context: the data source context.
54
+ """
55
+ self.client = boto3.client("s3")
56
+ self.bucket = boto3.resource("s3").Bucket(self.bucket_name)
57
+ self.sentinel1 = CopernicusSentinel1(
58
+ product_type=Sentinel1ProductType.IW_GRDH,
59
+ polarisation=Sentinel1Polarisation.VV_VH,
60
+ orbit_direction=orbit_direction,
61
+ )
62
+
63
+ def get_items(
64
+ self, geometries: list[STGeometry], query_config: QueryConfig
65
+ ) -> list[list[list[CopernicusItem]]]:
66
+ """Get a list of items in the data source intersecting the given geometries.
67
+
68
+ Args:
69
+ geometries: the spatiotemporal geometries
70
+ query_config: the query configuration
71
+
72
+ Returns:
73
+ List of groups of items that should be retrieved for each geometry.
74
+ """
75
+ return self.sentinel1.get_items(geometries, query_config)
76
+
77
+ def get_item_by_name(self, name: str) -> CopernicusItem:
78
+ """Gets an item by name."""
79
+ return self.sentinel1.get_item_by_name(name)
80
+
81
+ def deserialize_item(self, serialized_item: Any) -> CopernicusItem:
82
+ """Deserializes an item from JSON-decoded data."""
83
+ assert isinstance(serialized_item, dict)
84
+ return CopernicusItem.deserialize(serialized_item)
85
+
86
+ def ingest(
87
+ self,
88
+ tile_store: TileStoreWithLayer,
89
+ items: list[CopernicusItem],
90
+ geometries: list[list[STGeometry]],
91
+ ) -> None:
92
+ """Ingest items into the given tile store.
93
+
94
+ Args:
95
+ tile_store: the tile store to ingest into
96
+ items: the items to ingest
97
+ geometries: a list of geometries needed for each item
98
+ """
99
+ for item in items:
100
+ for band in self.bands:
101
+ band_names = [band]
102
+ if tile_store.is_raster_ready(item.name, band_names):
103
+ continue
104
+
105
+ # Item name is like "S1C_IW_GRDH_1SDV_20250528T172106_20250528T172131_002534_00545C_B433.SAFE".
106
+ item_name_prefix = item.name.split(".")[0]
107
+ time_str = item_name_prefix.split("_")[4]
108
+ if len(time_str) != 15:
109
+ raise ValueError(
110
+ f"expected 15-character time string but got {time_str}"
111
+ )
112
+ # We convert to int here since path in bucket isn't padded with leading 0s.
113
+ year = int(time_str[0:4])
114
+ month = int(time_str[4:6])
115
+ day = int(time_str[6:8])
116
+ blob_path = f"GRD/{year}/{month}/{day}/IW/DV/{item_name_prefix}/measurement/iw-{band}.tiff"
117
+
118
+ with tempfile.TemporaryDirectory() as tmp_dir:
119
+ fname = os.path.join(tmp_dir, f"{band}.tif")
120
+ try:
121
+ self.bucket.download_file(
122
+ blob_path,
123
+ fname,
124
+ ExtraArgs={"RequestPayer": "requester"},
125
+ )
126
+ except:
127
+ logger.error(
128
+ f"encountered error while downloading s3://{self.bucket_name}/{blob_path}"
129
+ )
130
+ raise
131
+ tile_store.write_raster_file(item.name, band_names, UPath(fname))