rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166) hide show
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,153 @@
1
+ """Data from worldpop.org."""
2
+
3
+ import random
4
+ from datetime import timedelta
5
+ from html.parser import HTMLParser
6
+ from urllib.parse import urljoin
7
+
8
+ import requests
9
+ from upath import UPath
10
+
11
+ from rslearn.config import LayerType
12
+ from rslearn.data_sources import DataSourceContext
13
+ from rslearn.data_sources.local_files import LocalFiles
14
+ from rslearn.log_utils import get_logger
15
+ from rslearn.utils.fsspec import join_upath, open_atomic
16
+
17
+ logger = get_logger(__name__)
18
+
19
+
20
+ class LinkExtractor(HTMLParser):
21
+ """Extract links from HTML.
22
+
23
+ The links attribute will be filled with the href attribute of all links that appear
24
+ on the HTML page.
25
+ """
26
+
27
+ def __init__(self) -> None:
28
+ """Create a new LinkExtractor."""
29
+ super().__init__()
30
+ self.links: list[str] = []
31
+
32
+ def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
33
+ """Handle start of tag from the HTML parsing."""
34
+ if tag.lower() != "a":
35
+ return
36
+ for name, value in attrs:
37
+ if name.lower() != "href":
38
+ continue
39
+ if value is None:
40
+ continue
41
+ self.links.append(value)
42
+
43
+
44
class WorldPop(LocalFiles):
    """World population data from worldpop.org.

    Currently, this only supports the WorldPop Constrained 2020 100 m Resolution
    dataset. See https://hub.worldpop.org/project/categories?id=3 for details.

    The data is split by country. We implement with LocalFiles data source for
    simplicity, but it means that all of the data must be downloaded first.
    """

    # Index pages that list one subfolder per country.
    INDEX_URLS = [
        "https://data.worldpop.org/GIS/Population/Global_2000_2020_Constrained/2020/BSGM/",
        "https://data.worldpop.org/GIS/Population/Global_2000_2020_Constrained/2020/maxar_v1/",
    ]
    # Suffix identifying the per-country population GeoTIFF.
    FILENAME_SUFFIX = "_ppp_2020_constrained.tif"

    def __init__(
        self,
        worldpop_dir: str,
        timeout: timedelta = timedelta(seconds=30),
        context: DataSourceContext = DataSourceContext(),
    ):
        """Create a new WorldPop.

        Args:
            worldpop_dir: the directory to extract the WorldPop GeoTIFF files. For
                high performance, this should be a local directory; if the dataset is
                remote, prefix with a protocol ("file://") to use a local directory
                instead of a path relative to the dataset path.
            timeout: timeout for HTTP requests.
            context: the data source context.
        """
        # Resolve worldpop_dir relative to the dataset path when one is set.
        if context.ds_path is None:
            worldpop_upath = UPath(worldpop_dir)
        else:
            worldpop_upath = join_upath(context.ds_path, worldpop_dir)
        worldpop_upath.mkdir(parents=True, exist_ok=True)
        self.download_worldpop_data(worldpop_upath, timeout)
        super().__init__(
            src_dir=worldpop_upath,
            layer_type=LayerType.RASTER,
            context=context,
        )

    def download_worldpop_data(self, worldpop_dir: UPath, timeout: timedelta) -> None:
        """Download and extract the WorldPop data.

        If the data was previously downloaded, this function returns quickly.

        Args:
            worldpop_dir: the directory to download to.
            timeout: timeout for HTTP requests.
        """
        # A marker file indicates a previously completed download.
        marker = worldpop_dir / "completed"
        if marker.exists():
            return

        timeout_seconds = timeout.total_seconds()

        # Scan the index URLs to get all the per-country subfolders.
        # These should be four characters with slash at the end, like "USA/".
        country_urls: list[str] = []
        for index_url in self.INDEX_URLS:
            logger.info(f"Getting per-country subfolders from {index_url}")
            response = requests.get(index_url, timeout=timeout_seconds)
            response.raise_for_status()
            link_parser = LinkExtractor()
            link_parser.feed(response.text)
            for href in link_parser.links:
                if len(href) == 4 and href[3] == "/":
                    country_urls.append(urljoin(index_url, href))

        logger.info(f"Got {len(country_urls)} country subfolders to download")
        # Shuffling here enables the user to run multiple processes to speed up the
        # download.
        random.shuffle(country_urls)

        # Now iterate over the country-level URLs and download the GeoTIFF.
        for country_url in country_urls:
            response = requests.get(country_url, timeout=timeout_seconds)
            response.raise_for_status()
            link_parser = LinkExtractor()
            link_parser.feed(response.text)
            tif_links = [
                urljoin(country_url, href)
                for href in link_parser.links
                if href.endswith(self.FILENAME_SUFFIX)
            ]
            # Each country page is expected to expose exactly one GeoTIFF.
            if len(tif_links) != 1:
                raise ValueError(
                    f"expected {country_url} to contain one GeoTIFF ending in {self.FILENAME_SUFFIX} but got {link_parser.links}"
                )

            tif_url = tif_links[0]
            dst_fname = worldpop_dir / tif_url.split("/")[-1]
            # Skip files already downloaded (supports resuming/parallel runs).
            if dst_fname.exists():
                continue

            logger.info(f"Downloading from {tif_url} to {dst_fname}")
            with requests.get(tif_url, stream=True, timeout=timeout_seconds) as r:
                r.raise_for_status()
                # open_atomic ensures a partially-downloaded file is never
                # left at the destination path.
                with open_atomic(dst_fname, "wb") as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)

        marker.touch()
@@ -13,15 +13,17 @@ import rasterio.warp
13
13
  import shapely
14
14
  from PIL import Image
15
15
  from rasterio.crs import CRS
16
- from upath import UPath
16
+ from rasterio.enums import Resampling
17
17
 
18
- from rslearn.config import LayerConfig, QueryConfig, RasterLayerConfig
18
+ from rslearn.config import LayerConfig, QueryConfig
19
19
  from rslearn.dataset import Window
20
+ from rslearn.dataset.materialize import RasterMaterializer
21
+ from rslearn.tile_stores import TileStore, TileStoreWithLayer
20
22
  from rslearn.utils import PixelBounds, Projection, STGeometry
21
23
  from rslearn.utils.array import copy_spatial_array
24
+ from rslearn.utils.raster_format import get_transform_from_projection_and_bounds
22
25
 
23
- from .data_source import DataSource, Item
24
- from .raster_source import ArrayWithTransform, materialize_raster
26
+ from .data_source import DataSource, DataSourceContext, Item
25
27
  from .utils import match_candidate_items_to_window
26
28
 
27
29
  WEB_MERCATOR_EPSG = 3857
@@ -81,58 +83,24 @@ def read_from_tile_callback(
81
83
  return data
82
84
 
83
85
 
84
- class XyzItem(Item):
85
- """An item in the XyzTiles data source.
86
-
87
- Each item represents one layer of tiles. Often there is only one itm in the data
88
- source, but if there are multiple then they should correspond to different time
89
- ranges.
90
- """
91
-
92
- def __init__(self, name: str, geometry: STGeometry, url_template: str):
93
- """Creates a new XyzItem.
94
-
95
- Args:
96
- name: unique name of the item
97
- geometry: the spatial and temporal extent of the item
98
- url_template: the URL template for an xyz tile.
99
- """
100
- super().__init__(name, geometry)
101
- self.url_template = url_template
102
-
103
- def serialize(self) -> dict:
104
- """Serializes the item to a JSON-encodable dictionary."""
105
- d = super().serialize()
106
- d["url_template"] = self.url_template
107
- return d
108
-
109
- @staticmethod
110
- def deserialize(d: dict) -> Item:
111
- """Deserializes an item from a JSON-decoded dictionary."""
112
- item = super(XyzItem, XyzItem).deserialize(d)
113
- return XyzItem(
114
- name=item.name, geometry=item.geometry, url_template=d["url_template"]
115
- )
116
-
117
-
118
- class XyzTiles(DataSource):
86
+ class XyzTiles(DataSource, TileStore):
119
87
  """A data source for web xyz image tiles.
120
88
 
121
89
  These tiles are usually in WebMercator projection, but different CRS can be
122
90
  configured here.
123
91
  """
124
92
 
125
- item_name = "xyz_tiles"
126
-
127
93
  def __init__(
128
94
  self,
129
95
  url_templates: list[str],
130
96
  time_ranges: list[tuple[datetime, datetime]],
131
97
  zoom: int,
132
- crs: CRS = CRS.from_epsg(WEB_MERCATOR_EPSG),
98
+ crs: str | CRS = CRS.from_epsg(WEB_MERCATOR_EPSG),
133
99
  total_units: float = WEB_MERCATOR_UNITS,
134
100
  offset: float = WEB_MERCATOR_UNITS / 2,
135
101
  tile_size: int = 256,
102
+ band_names: list[str] = ["R", "G", "B"],
103
+ context: DataSourceContext = DataSourceContext(),
136
104
  ):
137
105
  """Initialize an XyzTiles instance.
138
106
 
@@ -152,14 +120,22 @@ class XyzTiles(DataSource):
152
120
  the pixel size to map from projection coordinates to pixel coordinates.
153
121
  offset: offset added to projection units when converting to tile positions.
154
122
  tile_size: size in pixels of each tile. Tiles must be square.
123
+ band_names: what to name the bands that we read.
124
+ context: the data source context.
155
125
  """
156
126
  self.url_templates = url_templates
157
127
  self.time_ranges = time_ranges
158
128
  self.zoom = zoom
159
- self.crs = crs
160
129
  self.total_units = total_units
161
130
  self.offset = offset
162
131
  self.tile_size = tile_size
132
+ self.band_names = band_names
133
+
134
+ # Convert to CRS if needed.
135
+ if isinstance(crs, str):
136
+ self.crs = CRS.from_string(crs)
137
+ else:
138
+ self.crs = crs
163
139
 
164
140
  # Compute total number of pixels (a function of the zoom level and tile size).
165
141
  self.total_pixels = tile_size * (2**zoom)
@@ -169,7 +145,7 @@ class XyzTiles(DataSource):
169
145
  self.pixel_offset = int(self.offset / self.pixel_size)
170
146
  # Compute the extent in pixel coordinates as an STGeometry.
171
147
  # Note that pixel coordinates are prior to applying the offset.
172
- shp = shapely.box(
148
+ self.shp = shapely.box(
173
149
  -self.total_pixels // 2,
174
150
  -self.total_pixels // 2,
175
151
  self.total_pixels // 2,
@@ -179,32 +155,10 @@ class XyzTiles(DataSource):
179
155
 
180
156
  self.items = []
181
157
  for url_template, time_range in zip(self.url_templates, self.time_ranges):
182
- geometry = STGeometry(self.projection, shp, time_range)
183
- item = XyzItem(self.item_name, geometry, url_template)
158
+ geometry = STGeometry(self.projection, self.shp, time_range)
159
+ item = Item(url_template, geometry)
184
160
  self.items.append(item)
185
161
 
186
- @staticmethod
187
- def from_config(config: LayerConfig, ds_path: UPath) -> "XyzTiles":
188
- """Creates a new XyzTiles instance from a configuration dictionary."""
189
- d = config.data_source.config_dict
190
- time_ranges = []
191
- for str1, str2 in d["time_ranges"]:
192
- time1 = datetime.fromisoformat(str1)
193
- time2 = datetime.fromisoformat(str2)
194
- time_ranges.append((time1, time2))
195
- kwargs = dict(
196
- url_templates=d["url_templates"], zoom=d["zoom"], time_ranges=time_ranges
197
- )
198
- if "crs" in d:
199
- kwargs["crs"] = CRS.from_string(d["crs"])
200
- if "total_units" in d:
201
- kwargs["total_units"] = d["total_units"]
202
- if "offset" in d:
203
- kwargs["offset"] = d["offset"]
204
- if "tile_size" in d:
205
- kwargs["tile_size"] = d["tile_size"]
206
- return XyzTiles(**kwargs)
207
-
208
162
  def get_items(
209
163
  self, geometries: list[STGeometry], query_config: QueryConfig
210
164
  ) -> list[list[list[Item]]]:
@@ -232,7 +186,7 @@ class XyzTiles(DataSource):
232
186
 
233
187
    def deserialize_item(self, serialized_item: Any) -> Item:
        """Deserializes an item from JSON-decoded data."""
        # Items here are plain Item objects (the item name stores the URL
        # template), so the base deserialization is sufficient.
        return Item.deserialize(serialized_item)
236
190
 
237
191
  def read_tile(self, url_template: str, col: int, row: int) -> npt.NDArray[Any]:
238
192
  """Read the tile at specified column and row.
@@ -249,8 +203,11 @@ class XyzTiles(DataSource):
249
203
  url = url.replace("{x}", str(col))
250
204
  url = url.replace("{y}", str(row))
251
205
  url = url.replace("{z}", str(self.zoom))
252
- image = Image.open(urllib.request.urlopen(url))
253
- return np.array(image).transpose(2, 0, 1)
206
+ image = np.array(Image.open(urllib.request.urlopen(url)))
207
+ # Handle grayscale images (add single-band channel dimension).
208
+ if len(image.shape) == 2:
209
+ image = image[:, :, None]
210
+ return image.transpose(2, 0, 1)
254
211
 
255
212
  def read_bounds(self, url_template: str, bounds: PixelBounds) -> npt.NDArray[Any]:
256
213
  """Reads the portion of the raster in the specified bounds.
@@ -275,6 +232,122 @@ class XyzTiles(DataSource):
275
232
  self.tile_size,
276
233
  )
277
234
 
235
+ def is_raster_ready(
236
+ self, layer_name: str, item_name: str, bands: list[str]
237
+ ) -> bool:
238
+ """Checks if this raster has been written to the store.
239
+
240
+ Args:
241
+ layer_name: the layer name or alias.
242
+ item_name: the item.
243
+ bands: the list of bands identifying which specific raster to read.
244
+
245
+ Returns:
246
+ whether there is a raster in the store matching the source, item, and
247
+ bands.
248
+ """
249
+ # Always ready since we wrap accesses to the XYZ tile URL.
250
+ return True
251
+
252
+ def get_raster_bands(self, layer_name: str, item_name: str) -> list[list[str]]:
253
+ """Get the sets of bands that have been stored for the specified item.
254
+
255
+ Args:
256
+ layer_name: the layer name or alias.
257
+ item_name: the item.
258
+
259
+ Returns:
260
+ a list of lists of bands that are in the tile store (with one raster
261
+ stored corresponding to each inner list). If no rasters are ready for
262
+ this item, returns empty list.
263
+ """
264
+ return [self.band_names]
265
+
266
+ def get_raster_bounds(
267
+ self, layer_name: str, item_name: str, bands: list[str], projection: Projection
268
+ ) -> PixelBounds:
269
+ """Get the bounds of the raster in the specified projection.
270
+
271
+ Args:
272
+ layer_name: the layer name or alias.
273
+ item_name: the item to check.
274
+ bands: the list of bands identifying which specific raster to read. These
275
+ bands must match the bands of a stored raster.
276
+ projection: the projection to get the raster's bounds in.
277
+
278
+ Returns:
279
+ the bounds of the raster in the projection.
280
+ """
281
+ geom = STGeometry(self.projection, self.shp, None).to_projection(projection)
282
+ return (
283
+ int(geom.shp.bounds[0]),
284
+ int(geom.shp.bounds[1]),
285
+ int(geom.shp.bounds[2]),
286
+ int(geom.shp.bounds[3]),
287
+ )
288
+
289
+ def read_raster(
290
+ self,
291
+ layer_name: str,
292
+ item_name: str,
293
+ bands: list[str],
294
+ projection: Projection,
295
+ bounds: PixelBounds,
296
+ resampling: Resampling = Resampling.bilinear,
297
+ ) -> npt.NDArray[Any]:
298
+ """Read raster data from the store.
299
+
300
+ Args:
301
+ layer_name: the layer name or alias.
302
+ item_name: the item to read.
303
+ bands: the list of bands identifying which specific raster to read. These
304
+ bands must match the bands of a stored raster.
305
+ projection: the projection to read in.
306
+ bounds: the bounds to read.
307
+ resampling: the resampling method to use in case reprojection is needed.
308
+
309
+ Returns:
310
+ the raster data
311
+ """
312
+ # Validate bands.
313
+ if bands != self.band_names:
314
+ raise ValueError(
315
+ f"expected request for bands {self.band_names} but requested {bands}"
316
+ )
317
+
318
+ # Read a raster matching the given bounds but projected onto the projection of
319
+ # the xyz tiles.
320
+ request_geometry = STGeometry(projection, shapely.box(*bounds), None)
321
+ projected_geometry = request_geometry.to_projection(self.projection)
322
+ projected_bounds = (
323
+ math.floor(projected_geometry.shp.bounds[0]),
324
+ math.floor(projected_geometry.shp.bounds[1]),
325
+ math.ceil(projected_geometry.shp.bounds[2]),
326
+ math.ceil(projected_geometry.shp.bounds[3]),
327
+ )
328
+ # The item name is the URL template.
329
+ url_template = item_name
330
+ array = self.read_bounds(url_template, projected_bounds)
331
+ # Now project it back to the requested geometry.
332
+ src_transform = get_transform_from_projection_and_bounds(
333
+ self.projection, projected_bounds
334
+ )
335
+ dst_transform = get_transform_from_projection_and_bounds(projection, bounds)
336
+ dst_array = np.zeros(
337
+ (array.shape[0], bounds[3] - bounds[1], bounds[2] - bounds[0]),
338
+ dtype=array.dtype,
339
+ )
340
+ rasterio.warp.reproject(
341
+ source=array,
342
+ src_crs=self.projection.crs,
343
+ src_transform=src_transform,
344
+ destination=dst_array,
345
+ dst_crs=projection.crs,
346
+ dst_transform=dst_transform,
347
+ resampling=resampling,
348
+ )
349
+ return dst_array
350
+
278
351
  def materialize(
279
352
  self,
280
353
  window: Window,
@@ -290,40 +363,10 @@ class XyzTiles(DataSource):
290
363
  layer_name: the name of this layer
291
364
  layer_cfg: the config of this layer
292
365
  """
293
- assert len(item_groups) == 1 and len(item_groups[0]) == 1
294
- item = item_groups[0][0]
295
- assert isinstance(item, XyzItem)
296
-
297
- # Read a raster matching the bounds of the window's bounds projected onto the
298
- # projection of the xyz tiles.
299
- assert isinstance(layer_cfg, RasterLayerConfig)
300
- band_cfg = layer_cfg.band_sets[0]
301
- window_projection, window_bounds = band_cfg.get_final_projection_and_bounds(
302
- window.projection, window.bounds
303
- )
304
- window_geometry = STGeometry(
305
- window_projection, shapely.box(*window_bounds), None
366
+ RasterMaterializer().materialize(
367
+ TileStoreWithLayer(self, layer_name),
368
+ window,
369
+ layer_name,
370
+ layer_cfg,
371
+ item_groups,
306
372
  )
307
- projected_geometry = window_geometry.to_projection(self.projection)
308
- projected_bounds = [
309
- math.floor(projected_geometry.shp.bounds[0]),
310
- math.floor(projected_geometry.shp.bounds[1]),
311
- math.ceil(projected_geometry.shp.bounds[2]),
312
- math.ceil(projected_geometry.shp.bounds[3]),
313
- ]
314
- projected_raster = self.read_bounds(item.url_template, projected_bounds)
315
-
316
- # Attach the transform to the raster.
317
- src_transform = rasterio.transform.Affine(
318
- self.projection.x_resolution,
319
- 0,
320
- projected_bounds[0] * self.projection.x_resolution,
321
- 0,
322
- self.projection.y_resolution,
323
- projected_bounds[1] * self.projection.y_resolution,
324
- )
325
- array_with_transform = ArrayWithTransform(
326
- projected_raster, self.projection.crs, src_transform
327
- )
328
-
329
- materialize_raster(array_with_transform, window, layer_name, band_cfg)
@@ -1,6 +1,12 @@
1
1
  """rslearn dataset storage and operations."""
2
2
 
3
3
  from .dataset import Dataset
4
- from .window import Window, WindowLayerData
4
+ from .window import Window, WindowLayerData, get_window_layer_dir, get_window_raster_dir
5
5
 
6
- __all__ = ("Dataset", "Window", "WindowLayerData")
6
+ __all__ = (
7
+ "Dataset",
8
+ "Window",
9
+ "WindowLayerData",
10
+ "get_window_layer_dir",
11
+ "get_window_raster_dir",
12
+ )
@@ -25,7 +25,7 @@ def add_windows_from_geometries(
25
25
  window_size: int | None = None,
26
26
  time_range: tuple[datetime, datetime] | None = None,
27
27
  use_utm: bool = False,
28
- ):
28
+ ) -> list[Window]:
29
29
  """Create windows based on a list of STGeometry.
30
30
 
31
31
  Args:
@@ -131,7 +131,7 @@ def add_windows_from_geometries(
131
131
  f"_{time_range[0].isoformat()}_{time_range[1].isoformat()}"
132
132
  )
133
133
  window = Window(
134
- path=dataset.path / "windows" / group / cur_window_name,
134
+ storage=dataset.storage,
135
135
  group=group,
136
136
  name=cur_window_name,
137
137
  projection=cur_projection,
@@ -1,16 +1,19 @@
1
1
  """rslearn dataset class."""
2
2
 
3
3
  import json
4
- import multiprocessing
4
+ from typing import Any
5
5
 
6
- import tqdm
7
6
  from upath import UPath
8
7
 
9
- from rslearn.config import TileStoreConfig, load_layer_config
8
+ from rslearn.config import DatasetConfig
9
+ from rslearn.log_utils import get_logger
10
+ from rslearn.template_params import substitute_env_vars_in_string
10
11
  from rslearn.tile_stores import TileStore, load_tile_store
11
12
 
12
13
  from .window import Window
13
14
 
15
+ logger = get_logger(__name__)
16
+
14
17
 
15
18
  class Dataset:
16
19
  """A rslearn dataset.
@@ -20,7 +23,7 @@ class Dataset:
20
23
  .. code-block:: none
21
24
 
22
25
  dataset/
23
- config.json
26
+ config.json # optional, if config provided as runtime object
24
27
  windows/
25
28
  group1/
26
29
  epsg:3857_10_623565_1528020/
@@ -37,72 +40,58 @@ class Dataset:
37
40
  materialize.
38
41
  """
39
42
 
40
- def __init__(self, path: UPath) -> None:
43
+ def __init__(
44
+ self,
45
+ path: UPath,
46
+ disabled_layers: list[str] = [],
47
+ dataset_config: DatasetConfig | None = None,
48
+ ) -> None:
41
49
  """Initializes a new Dataset.
42
50
 
43
51
  Args:
44
52
  path: the root directory of the dataset
53
+ disabled_layers: list of layers to disable
54
+ dataset_config: optional dataset configuration to use instead of loading from the dataset directory
45
55
  """
46
56
  self.path = path
47
57
 
48
- # Load dataset configuration.
49
- with (self.path / "config.json").open("r") as f:
50
- config = json.load(f)
51
- self.layers = {
52
- layer_name: load_layer_config(d)
53
- for layer_name, d in config["layers"].items()
54
- }
55
- self.tile_store_config = TileStoreConfig.from_config(config["tile_store"])
56
- self.materializer_name = config.get("materialize")
58
+ if dataset_config is None:
59
+ # Load dataset configuration from the dataset directory.
60
+ with (self.path / "config.json").open("r") as f:
61
+ config_content = f.read()
62
+ config_content = substitute_env_vars_in_string(config_content)
63
+ dataset_config = DatasetConfig.model_validate(
64
+ json.loads(config_content)
65
+ )
66
+
67
+ self.layers = {}
68
+ for layer_name, layer_config in dataset_config.layers.items():
69
+ if layer_name in disabled_layers:
70
+ logger.warning(f"Layer {layer_name} is disabled")
71
+ continue
72
+ self.layers[layer_name] = layer_config
73
+
74
+ self.tile_store_config = dataset_config.tile_store
75
+ self.storage = (
76
+ dataset_config.storage.instantiate_window_storage_factory().get_storage(
77
+ self.path
78
+ )
79
+ )
57
80
 
58
81
    def load_windows(
        self,
        groups: list[str] | None = None,
        names: list[str] | None = None,
        **kwargs: Any,
    ) -> list[Window]:
        """Load the windows in the dataset.

        Args:
            groups: an optional list of groups to filter loading
            names: an optional list of window names to filter loading
            kwargs: optional keyword arguments to pass to WindowStorage.get_windows.

        Returns:
            the list of matching Window objects.
        """
        # Window enumeration and filtering are delegated entirely to the
        # configured storage backend.
        return self.storage.get_windows(groups=groups, names=names, **kwargs)
106
95
 
107
96
  def get_tile_store(self) -> TileStore:
108
97
  """Get the tile store associated with this dataset.