rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/const.py CHANGED
@@ -1,23 +1,17 @@
1
1
  """Constants."""
2
2
 
3
- from rasterio.crs import CRS
4
-
5
- from rslearn.utils import PixelBounds, Projection
6
-
7
- WGS84_EPSG = 4326
8
- """The EPSG code for WGS-84."""
9
-
10
- WGS84_PROJECTION = Projection(CRS.from_epsg(WGS84_EPSG), 1, 1)
11
- """The Projection for WGS-84 assuming 1 degree per pixel.
12
-
13
- This can be used to create STGeometry with shapes in longitude/latitude coordinates.
14
- """
15
-
16
- WGS84_BOUNDS: PixelBounds = (-180, -90, 180, 90)
17
- """The bounds of the WGS-84 projection."""
3
+ from rslearn.utils.geometry import WGS84_BOUNDS, WGS84_EPSG, WGS84_PROJECTION
18
4
 
19
5
  TILE_SIZE = 512
20
6
  """Default tile size. TODO: remove this or move it elsewhere."""
21
7
 
22
8
  SHAPEFILE_AUX_EXTENSIONS = [".cpg", ".dbf", ".prj", ".sbn", ".sbx", ".shx", ".txt"]
23
9
  """Extensions of potential auxiliary files to .shp file."""
10
+
11
+ __all__ = (
12
+ "WGS84_PROJECTION",
13
+ "WGS84_EPSG",
14
+ "WGS84_BOUNDS",
15
+ "TILE_SIZE",
16
+ "SHAPEFILE_AUX_EXTENSIONS",
17
+ )
@@ -10,32 +10,17 @@ Each source supports operations to lookup items that match with spatiotemporal
10
10
  geometries, and ingest those items.
11
11
  """
12
12
 
13
- import importlib
14
-
15
- from upath import UPath
16
-
17
- from rslearn.config import LayerConfig
18
-
19
- from .data_source import DataSource, Item, ItemLookupDataSource, RetrieveItemDataSource
20
-
21
-
22
- def data_source_from_config(config: LayerConfig, ds_path: UPath) -> DataSource:
23
- """Loads a data source from config dict.
24
-
25
- Args:
26
- config: the LayerConfig containing this data source.
27
- ds_path: the dataset root directory.
28
- """
29
- name = config.data_source.name
30
- module_name = ".".join(name.split(".")[:-1])
31
- class_name = name.split(".")[-1]
32
- module = importlib.import_module(module_name)
33
- class_ = getattr(module, class_name)
34
- return class_.from_config(config, ds_path)
35
-
13
+ from .data_source import (
14
+ DataSource,
15
+ DataSourceContext,
16
+ Item,
17
+ ItemLookupDataSource,
18
+ RetrieveItemDataSource,
19
+ )
36
20
 
37
21
  __all__ = (
38
22
  "DataSource",
23
+ "DataSourceContext",
39
24
  "Item",
40
25
  "ItemLookupDataSource",
41
26
  "RetrieveItemDataSource",
@@ -2,33 +2,41 @@
2
2
 
3
3
  import io
4
4
  import json
5
+ import os
5
6
  import shutil
7
+ import tempfile
6
8
  import urllib.request
7
9
  import zipfile
8
10
  from collections.abc import Generator
9
- from datetime import timedelta
11
+ from datetime import datetime
10
12
  from typing import Any, BinaryIO
11
13
 
14
+ import affine
12
15
  import boto3
13
16
  import dateutil.parser
14
17
  import fiona
15
18
  import fiona.transform
16
- import pytimeparse
19
+ import numpy.typing as npt
17
20
  import rasterio
18
21
  import shapely
22
+ import shapely.geometry
19
23
  import tqdm
24
+ from rasterio.enums import Resampling
20
25
  from upath import UPath
21
26
 
22
27
  import rslearn.data_sources.utils
23
- import rslearn.utils.mgrs
24
- from rslearn.config import LayerConfig, RasterLayerConfig
28
+ from rslearn.config import LayerConfig
25
29
  from rslearn.const import SHAPEFILE_AUX_EXTENSIONS, WGS84_PROJECTION
26
- from rslearn.tile_stores import PrefixedTileStore, TileStore
27
- from rslearn.utils import STGeometry
30
+ from rslearn.dataset import Window
31
+ from rslearn.dataset.materialize import RasterMaterializer
32
+ from rslearn.tile_stores import TileStore, TileStoreWithLayer
28
33
  from rslearn.utils.fsspec import get_upath_local, join_upath, open_atomic
34
+ from rslearn.utils.geometry import PixelBounds, Projection, STGeometry
35
+ from rslearn.utils.grid_index import GridIndex
29
36
 
30
- from .data_source import DataSource, Item, QueryConfig
31
- from .raster_source import get_needed_projections, ingest_raster
37
+ from .data_source import DataSource, DataSourceContext, Item, QueryConfig
38
+
39
+ WRS2_GRID_SIZE = 1.0
32
40
 
33
41
 
34
42
  class LandsatOliTirsItem(Item):
@@ -36,7 +44,7 @@ class LandsatOliTirsItem(Item):
36
44
 
37
45
  def __init__(
38
46
  self, name: str, geometry: STGeometry, blob_path: str, cloud_cover: float
39
- ):
47
+ ) -> None:
40
48
  """Creates a new LandsatOliTirsItem.
41
49
 
42
50
  Args:
@@ -58,7 +66,7 @@ class LandsatOliTirsItem(Item):
58
66
  return d
59
67
 
60
68
  @staticmethod
61
- def deserialize(d: dict) -> Item:
69
+ def deserialize(d: dict) -> "LandsatOliTirsItem":
62
70
  """Deserializes an item from a JSON-decoded dictionary."""
63
71
  if "name" not in d:
64
72
  d["name"] = d["blob_path"].split("/")[-1].split(".tif")[0]
@@ -71,7 +79,7 @@ class LandsatOliTirsItem(Item):
71
79
  )
72
80
 
73
81
 
74
- class LandsatOliTirs(DataSource):
82
+ class LandsatOliTirs(DataSource, TileStore):
75
83
  """A data source for Landsat 8/9 OLI-TIRS imagery on AWS.
76
84
 
77
85
  Specifically, uses the usgs-landsat S3 bucket maintained by USGS. The data includes
@@ -90,53 +98,37 @@ class LandsatOliTirs(DataSource):
90
98
 
91
99
  def __init__(
92
100
  self,
93
- config: LayerConfig,
94
- metadata_cache_dir: UPath,
95
- max_time_delta: timedelta = timedelta(days=30),
101
+ metadata_cache_dir: str,
96
102
  sort_by: str | None = None,
103
+ context: DataSourceContext = DataSourceContext(),
97
104
  ) -> None:
98
105
  """Initialize a new LandsatOliTirs instance.
99
106
 
100
107
  Args:
101
- config: configuration of this layer
102
- metadata_cache_dir: directory to cache product metadata files.
103
- max_time_delta: maximum time before a query start time or after a
104
- query end time to look for products. This is required due to the large
105
- number of available products, and defaults to 30 days.
108
+ metadata_cache_dir: directory to cache product metadata files.
106
109
  sort_by: can be "cloud_cover", default arbitrary order; only has effect for
107
110
  SpaceMode.WITHIN.
111
+ context: the data source context.
108
112
  """
109
- self.config = config
110
- self.metadata_cache_dir = metadata_cache_dir
111
- self.max_time_delta = max_time_delta
113
+ # If the context has a dataset path, we join the directory with it;
114
+ # otherwise we treat it directly as UPath.
115
+ if context.ds_path is not None:
116
+ self.metadata_cache_dir = join_upath(context.ds_path, metadata_cache_dir)
117
+ else:
118
+ self.metadata_cache_dir = UPath(metadata_cache_dir)
119
+
112
120
  self.sort_by = sort_by
113
121
 
122
+ self.client = boto3.client("s3")
114
123
  self.bucket = boto3.resource("s3").Bucket(self.bucket_name)
115
-
116
124
  self.metadata_cache_dir.mkdir(parents=True, exist_ok=True)
117
125
 
118
- @staticmethod
119
- def from_config(config: LayerConfig, ds_path: UPath) -> "LandsatOliTirs":
120
- """Creates a new LandsatOliTirs instance from a configuration dictionary."""
121
- assert isinstance(config, RasterLayerConfig)
122
- d = config.data_source.config_dict
123
- kwargs = dict(
124
- config=config,
125
- metadata_cache_dir=join_upath(ds_path, d["metadata_cache_dir"]),
126
- )
127
- if "max_time_delta" in d:
128
- kwargs["max_time_delta"] = timedelta(
129
- seconds=pytimeparse.parse(d["max_time_delta"])
130
- )
131
- if "sort_by" in d:
132
- kwargs["sort_by"] = d["sort_by"]
133
-
134
- return LandsatOliTirs(**kwargs)
126
+ self.wrs2_index: GridIndex | None = None
135
127
 
136
128
  def _read_products(
137
129
  self, needed_year_pathrows: set[tuple[int, str, str]]
138
130
  ) -> Generator[LandsatOliTirsItem, None, None]:
139
- """Read _MTL.json files and yield relevant LandsatOliTirsItems.
131
+ """Read _stac.json files and yield relevant LandsatOliTirsItems.
140
132
 
141
133
  Args:
142
134
  needed_year_pathrows: set of (year, path, row) where we need to search for
@@ -155,7 +147,10 @@ class LandsatOliTirs(DataSource):
155
147
  for obj in self.bucket.objects.filter(
156
148
  Prefix=prefix, RequestPayer="requester"
157
149
  ):
158
- if not obj.key.endswith("_MTL.json"):
150
+ # Only read the _stac.json files.
151
+ # Previously we used _MTL.json but those files don't have the full
152
+ # geometry of the Landsat scene, only the bounding box.
153
+ if not obj.key.endswith("_stac.json"):
159
154
  continue
160
155
  # Load JSON data.
161
156
  buf = io.BytesIO()
@@ -163,33 +158,32 @@ class LandsatOliTirs(DataSource):
163
158
  obj.key, buf, ExtraArgs={"RequestPayer": "requester"}
164
159
  )
165
160
  buf.seek(0)
166
- product = json.load(buf)
167
- metadata = product["LANDSAT_METADATA_FILE"]
168
- image_attributes = metadata["IMAGE_ATTRIBUTES"]
169
- projection_attributes = metadata["PROJECTION_ATTRIBUTES"]
161
+ stac_data = json.load(buf)
170
162
 
171
163
  # Get polygon coordinates.
172
- coordinates = []
173
- for corner_id in ["UL", "UR", "LR", "LL"]:
174
- lon = projection_attributes[f"CORNER_{corner_id}_LON_PRODUCT"]
175
- lat = projection_attributes[f"CORNER_{corner_id}_LAT_PRODUCT"]
176
- coordinates.append((lon, lat))
164
+ shp = shapely.geometry.shape(stac_data["geometry"])
177
165
 
178
166
  # Get datetime.
179
- date_str = image_attributes["DATE_ACQUIRED"]
180
- time_str = image_attributes["SCENE_CENTER_TIME"]
181
- ts = dateutil.parser.isoparse(date_str + "T" + time_str)
182
-
183
- blob_path = obj.key.split("MTL.json")[0]
184
- geometry = STGeometry(
185
- WGS84_PROJECTION, shapely.Polygon(coordinates), [ts, ts]
186
- )
167
+ ts = dateutil.parser.isoparse(stac_data["properties"]["datetime"])
168
+
169
+ blob_path = obj.key.split("stac.json")[0]
170
+ time_range: tuple[datetime, datetime] = (ts, ts)
171
+ geometry = STGeometry(WGS84_PROJECTION, shp, time_range)
172
+ cloud_cover: float
173
+ if "eo:cloud_cover" in stac_data["properties"]:
174
+ cloud_cover = stac_data["properties"]["eo:cloud_cover"]
175
+ elif "landsat:cloud_cover_land" in stac_data["properties"]:
176
+ cloud_cover = stac_data["properties"][
177
+ "landsat:cloud_cover_land"
178
+ ]
179
+ else:
180
+ cloud_cover = -1
187
181
  items.append(
188
182
  LandsatOliTirsItem(
189
- name=metadata["PRODUCT_CONTENTS"]["LANDSAT_PRODUCT_ID"],
183
+ name=stac_data["id"],
190
184
  geometry=geometry,
191
185
  blob_path=blob_path,
192
- cloud_cover=image_attributes["CLOUD_COVER"],
186
+ cloud_cover=cloud_cover,
193
187
  )
194
188
  )
195
189
 
@@ -205,7 +199,7 @@ class LandsatOliTirs(DataSource):
205
199
 
206
200
  yield from items
207
201
 
208
- def get_wrs2_polygons(self) -> list[tuple[shapely.Geometry, str, str]]:
202
+ def _get_wrs2_polygons(self) -> list[tuple[shapely.Geometry, str, str]]:
209
203
  """Get polygons for each (path, row) in the WRS2 grid.
210
204
 
211
205
  Returns:
@@ -216,6 +210,7 @@ class LandsatOliTirs(DataSource):
216
210
  if not shp_fname.exists():
217
211
  # Download and extract zip to cache dir.
218
212
  zip_fname = self.metadata_cache_dir / f"{prefix}.zip"
213
+ print(f"Downloading {self.wrs2_url} to {zip_fname}")
219
214
  with urllib.request.urlopen(self.wrs2_url) as response:
220
215
  with zip_fname.open("wb") as f:
221
216
  shutil.copyfileobj(response, f)
@@ -257,9 +252,22 @@ class LandsatOliTirs(DataSource):
257
252
  polygons.append((shp, path, row))
258
253
  return polygons
259
254
 
255
+ def _get_wrs2_index(self) -> GridIndex:
256
+ """Get a grid index over the WRS2 polygons."""
257
+ if self.wrs2_index is not None:
258
+ return self.wrs2_index
259
+
260
+ # Index doesn't exist so we need to build it.
261
+ # We cache it with the object since it takes a bit of time to create it.
262
+ polygons = self._get_wrs2_polygons()
263
+ self.wrs2_index = GridIndex(WRS2_GRID_SIZE)
264
+ for polygon, path, row in polygons:
265
+ self.wrs2_index.insert(polygon.bounds, (polygon, path, row))
266
+ return self.wrs2_index
267
+
260
268
  def get_items(
261
269
  self, geometries: list[STGeometry], query_config: QueryConfig
262
- ) -> list[list[list[Item]]]:
270
+ ) -> list[list[list[LandsatOliTirsItem]]]:
263
271
  """Get a list of items in the data source intersecting the given geometries.
264
272
 
265
273
  Args:
@@ -269,7 +277,7 @@ class LandsatOliTirs(DataSource):
269
277
  Returns:
270
278
  List of groups of items that should be retrieved for each geometry.
271
279
  """
272
- wrs2_polygons = self.get_wrs2_polygons()
280
+ wrs2_index = self._get_wrs2_index()
273
281
  needed_year_pathrows = set()
274
282
  wgs84_geometries = [
275
283
  geometry.to_projection(WGS84_PROJECTION) for geometry in geometries
@@ -280,13 +288,13 @@ class LandsatOliTirs(DataSource):
280
288
  "Landsat on AWS requires geometry time ranges to be set"
281
289
  )
282
290
  cur_pathrows = set()
283
- for polygon, path, row in wrs2_polygons:
291
+ for polygon, path, row in wrs2_index.query(wgs84_geometry.shp.bounds):
284
292
  if wgs84_geometry.shp.intersects(polygon):
285
293
  cur_pathrows.add((path, row))
286
294
  for path, row in cur_pathrows:
287
295
  for year in range(
288
- (wgs84_geometry.time_range[0] - self.max_time_delta).year,
289
- (wgs84_geometry.time_range[1] + self.max_time_delta).year + 1,
296
+ wgs84_geometry.time_range[0].year,
297
+ wgs84_geometry.time_range[1].year + 1,
290
298
  ):
291
299
  needed_year_pathrows.add((year, path, row))
292
300
 
@@ -301,18 +309,22 @@ class LandsatOliTirs(DataSource):
301
309
  cur_items.append(item)
302
310
 
303
311
  if self.sort_by == "cloud_cover":
304
- items.sort(key=lambda item: item.cloud_cover)
312
+ cur_items.sort(
313
+ key=lambda item: item.cloud_cover if item.cloud_cover >= 0 else 100
314
+ )
305
315
  elif self.sort_by is not None:
306
316
  raise ValueError(f"invalid sort_by setting ({self.sort_by})")
307
317
 
308
- cur_groups = rslearn.data_sources.utils.match_candidate_items_to_window(
309
- geometry, cur_items, query_config
318
+ cur_groups: list[list[LandsatOliTirsItem]] = (
319
+ rslearn.data_sources.utils.match_candidate_items_to_window(
320
+ geometry, cur_items, query_config
321
+ )
310
322
  )
311
323
  groups.append(cur_groups)
312
324
 
313
325
  return groups
314
326
 
315
- def get_item_by_name(self, name: str) -> Item:
327
+ def get_item_by_name(self, name: str) -> LandsatOliTirsItem:
316
328
  """Gets an item by name."""
317
329
  # Product name is like LC08_L1TP_046027_20230715_20230724_02_T1.
318
330
  # We want to use _read_products so we need to extract:
@@ -330,12 +342,14 @@ class LandsatOliTirs(DataSource):
330
342
  return item
331
343
  raise ValueError(f"item {name} not found")
332
344
 
333
- def deserialize_item(self, serialized_item: Any) -> Item:
345
+ def deserialize_item(self, serialized_item: Any) -> LandsatOliTirsItem:
334
346
  """Deserializes an item from JSON-decoded data."""
335
347
  assert isinstance(serialized_item, dict)
336
348
  return LandsatOliTirsItem.deserialize(serialized_item)
337
349
 
338
- def retrieve_item(self, item: Item) -> Generator[tuple[str, BinaryIO], None, None]:
350
+ def retrieve_item(
351
+ self, item: LandsatOliTirsItem
352
+ ) -> Generator[tuple[str, BinaryIO], None, None]:
339
353
  """Retrieves the rasters corresponding to an item as file streams."""
340
354
  for band in self.bands:
341
355
  buf = io.BytesIO()
@@ -350,8 +364,8 @@ class LandsatOliTirs(DataSource):
350
364
 
351
365
  def ingest(
352
366
  self,
353
- tile_store: TileStore,
354
- items: list[Item],
367
+ tile_store: TileStoreWithLayer,
368
+ items: list[LandsatOliTirsItem],
355
369
  geometries: list[list[STGeometry]],
356
370
  ) -> None:
357
371
  """Ingest items into the given tile store.
@@ -364,28 +378,158 @@ class LandsatOliTirs(DataSource):
364
378
  for item, cur_geometries in zip(items, geometries):
365
379
  for band in self.bands:
366
380
  band_names = [band]
367
- cur_tile_store = PrefixedTileStore(
368
- tile_store, (item.name, "_".join(band_names))
369
- )
370
- needed_projections = get_needed_projections(
371
- cur_tile_store, band_names, self.config.band_sets, cur_geometries
372
- )
373
- if not needed_projections:
381
+ if tile_store.is_raster_ready(item.name, band_names):
374
382
  continue
375
383
 
376
- buf = io.BytesIO()
377
- self.bucket.download_fileobj(
378
- item.blob_path + f"{band}.TIF",
379
- buf,
380
- ExtraArgs={"RequestPayer": "requester"},
381
- )
382
- buf.seek(0)
383
- with rasterio.open(buf) as raster:
384
- for projection in needed_projections:
385
- ingest_raster(
386
- tile_store=cur_tile_store,
387
- raster=raster,
388
- projection=projection,
389
- time_range=item.geometry.time_range,
390
- layer_config=self.config,
391
- )
384
+ with tempfile.TemporaryDirectory() as tmp_dir:
385
+ fname = os.path.join(tmp_dir, f"{band}.tif")
386
+ self.bucket.download_file(
387
+ item.blob_path + f"{band}.TIF",
388
+ fname,
389
+ ExtraArgs={"RequestPayer": "requester"},
390
+ )
391
+ tile_store.write_raster_file(item.name, band_names, UPath(fname))
392
+
393
+ # The functions below are to emulate TileStore functionality so we can easily
394
+ # support materialization directly from the COGs.
395
+ def is_raster_ready(
396
+ self, layer_name: str, item_name: str, bands: list[str]
397
+ ) -> bool:
398
+ """Checks if this raster has been written to the store.
399
+
400
+ Args:
401
+ layer_name: the layer name or alias.
402
+ item_name: the item.
403
+ bands: the list of bands identifying which specific raster to read.
404
+
405
+ Returns:
406
+ whether there is a raster in the store matching the layer, item, and
407
+ bands.
408
+ """
409
+ # Always ready since we access it on AWS bucket.
410
+ return True
411
+
412
+ def get_raster_bands(self, layer_name: str, item_name: str) -> list[list[str]]:
413
+ """Get the sets of bands that have been stored for the specified item.
414
+
415
+ Args:
416
+ layer_name: the layer name or alias.
417
+ item_name: the item.
418
+
419
+ Returns:
420
+ a list of lists of bands that are in the tile store (with one raster
421
+ stored corresponding to each inner list). If no rasters are ready for
422
+ this item, returns empty list.
423
+ """
424
+ return [[band] for band in self.bands]
425
+
426
+ def get_raster_bounds(
427
+ self, layer_name: str, item_name: str, bands: list[str], projection: Projection
428
+ ) -> PixelBounds:
429
+ """Get the bounds of the raster in the specified projection.
430
+
431
+ Args:
432
+ layer_name: the layer name or alias.
433
+ item_name: the item to check.
434
+ bands: the list of bands identifying which specific raster to read. These
435
+ bands must match the bands of a stored raster.
436
+ projection: the projection to get the raster's bounds in.
437
+
438
+ Returns:
439
+ the bounds of the raster in the projection.
440
+ """
441
+ item = self.get_item_by_name(item_name)
442
+ geom = item.geometry.to_projection(projection)
443
+ return (
444
+ int(geom.shp.bounds[0]),
445
+ int(geom.shp.bounds[1]),
446
+ int(geom.shp.bounds[2]),
447
+ int(geom.shp.bounds[3]),
448
+ )
449
+
450
+ def read_raster(
451
+ self,
452
+ layer_name: str,
453
+ item_name: str,
454
+ bands: list[str],
455
+ projection: Projection,
456
+ bounds: PixelBounds,
457
+ resampling: Resampling = Resampling.bilinear,
458
+ ) -> npt.NDArray[Any]:
459
+ """Read raster data from the store.
460
+
461
+ Args:
462
+ layer_name: the layer name or alias.
463
+ item_name: the item to read.
464
+ bands: the list of bands identifying which specific raster to read. These
465
+ bands must match the bands of a stored raster.
466
+ projection: the projection to read in.
467
+ bounds: the bounds to read.
468
+ resampling: the resampling method to use in case reprojection is needed.
469
+
470
+ Returns:
471
+ the raster data
472
+ """
473
+ # Landsat assets have single band per asset.
474
+ assert len(bands) == 1
475
+ band = bands[0]
476
+
477
+ # Get the item since it has the blob path.
478
+ item = self.get_item_by_name(item_name)
479
+
480
+ # Create pre-signed URL for rasterio access.
481
+ # We do this because accessing via URL is much faster since rasterio can use
482
+ # the URL directly.
483
+ blob_key = item.blob_path + f"{band}.TIF"
484
+ url = self.client.generate_presigned_url(
485
+ "get_object",
486
+ Params={
487
+ "Bucket": self.bucket_name,
488
+ "Key": blob_key,
489
+ "RequestPayer": "requester",
490
+ },
491
+ )
492
+
493
+ # Construct the transform to use for the warped dataset.
494
+ wanted_transform = affine.Affine(
495
+ projection.x_resolution,
496
+ 0,
497
+ bounds[0] * projection.x_resolution,
498
+ 0,
499
+ projection.y_resolution,
500
+ bounds[1] * projection.y_resolution,
501
+ )
502
+
503
+ with rasterio.open(url) as src:
504
+ with rasterio.vrt.WarpedVRT(
505
+ src,
506
+ crs=projection.crs,
507
+ transform=wanted_transform,
508
+ width=bounds[2] - bounds[0],
509
+ height=bounds[3] - bounds[1],
510
+ resampling=resampling,
511
+ ) as vrt:
512
+ return vrt.read()
513
+
514
+ def materialize(
515
+ self,
516
+ window: Window,
517
+ item_groups: list[list[LandsatOliTirsItem]],
518
+ layer_name: str,
519
+ layer_cfg: LayerConfig,
520
+ ) -> None:
521
+ """Materialize data for the window.
522
+
523
+ Args:
524
+ window: the window to materialize
525
+ item_groups: the items from get_items
526
+ layer_name: the name of this layer
527
+ layer_cfg: the config of this layer
528
+ """
529
+ RasterMaterializer().materialize(
530
+ TileStoreWithLayer(self, layer_name),
531
+ window,
532
+ layer_name,
533
+ layer_cfg,
534
+ item_groups,
535
+ )