rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
"""Data from worldpop.org."""
|
|
2
|
+
|
|
3
|
+
import random
|
|
4
|
+
from datetime import timedelta
|
|
5
|
+
from html.parser import HTMLParser
|
|
6
|
+
from urllib.parse import urljoin
|
|
7
|
+
|
|
8
|
+
import requests
|
|
9
|
+
from upath import UPath
|
|
10
|
+
|
|
11
|
+
from rslearn.config import LayerType
|
|
12
|
+
from rslearn.data_sources import DataSourceContext
|
|
13
|
+
from rslearn.data_sources.local_files import LocalFiles
|
|
14
|
+
from rslearn.log_utils import get_logger
|
|
15
|
+
from rslearn.utils.fsspec import join_upath, open_atomic
|
|
16
|
+
|
|
17
|
+
# Module-level logger; used below to report WorldPop download progress.
logger = get_logger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class LinkExtractor(HTMLParser):
|
|
21
|
+
"""Extract links from HTML.
|
|
22
|
+
|
|
23
|
+
The links attribute will be filled with the href attribute of all links that appear
|
|
24
|
+
on the HTML page.
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
def __init__(self) -> None:
|
|
28
|
+
"""Create a new LinkExtractor."""
|
|
29
|
+
super().__init__()
|
|
30
|
+
self.links: list[str] = []
|
|
31
|
+
|
|
32
|
+
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
|
|
33
|
+
"""Handle start of tag from the HTML parsing."""
|
|
34
|
+
if tag.lower() != "a":
|
|
35
|
+
return
|
|
36
|
+
for name, value in attrs:
|
|
37
|
+
if name.lower() != "href":
|
|
38
|
+
continue
|
|
39
|
+
if value is None:
|
|
40
|
+
continue
|
|
41
|
+
self.links.append(value)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class WorldPop(LocalFiles):
    """Population rasters from worldpop.org.

    Only the WorldPop Constrained 2020 100 m Resolution dataset is supported
    (see https://hub.worldpop.org/project/categories?id=3 for details). The data
    is published as one GeoTIFF per country. For simplicity this builds on the
    LocalFiles data source, which means every country file is downloaded up
    front before the data source can be used.
    """

    INDEX_URLS = [
        "https://data.worldpop.org/GIS/Population/Global_2000_2020_Constrained/2020/BSGM/",
        "https://data.worldpop.org/GIS/Population/Global_2000_2020_Constrained/2020/maxar_v1/",
    ]
    FILENAME_SUFFIX = "_ppp_2020_constrained.tif"

    def __init__(
        self,
        worldpop_dir: str,
        timeout: timedelta = timedelta(seconds=30),
        context: DataSourceContext = DataSourceContext(),
    ):
        """Create a new WorldPop data source, downloading the data if needed.

        Args:
            worldpop_dir: directory where the WorldPop GeoTIFFs are stored. When
                the data source context has a dataset path, this is joined onto
                it. Use a local directory for good performance; if the dataset
                itself is remote, give an explicit protocol (e.g. "file://") to
                point at a local directory rather than a dataset-relative path.
            timeout: timeout applied to each HTTP request.
            context: the data source context.
        """
        worldpop_upath = (
            UPath(worldpop_dir)
            if context.ds_path is None
            else join_upath(context.ds_path, worldpop_dir)
        )
        worldpop_upath.mkdir(parents=True, exist_ok=True)
        self.download_worldpop_data(worldpop_upath, timeout)
        super().__init__(
            src_dir=worldpop_upath,
            layer_type=LayerType.RASTER,
            context=context,
        )

    @staticmethod
    def _list_links(url: str, timeout: timedelta) -> list[str]:
        """Fetch an HTML page and return the href values of its anchors."""
        response = requests.get(url, timeout=timeout.total_seconds())
        response.raise_for_status()
        extractor = LinkExtractor()
        extractor.feed(response.text)
        return extractor.links

    @staticmethod
    def _download_file(url: str, dst: UPath, timeout: timedelta) -> None:
        """Stream the file at ``url`` into ``dst`` using ``open_atomic``."""
        with requests.get(url, stream=True, timeout=timeout.total_seconds()) as r:
            r.raise_for_status()
            with open_atomic(dst, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)

    def download_worldpop_data(self, worldpop_dir: UPath, timeout: timedelta) -> None:
        """Download the per-country WorldPop GeoTIFFs into ``worldpop_dir``.

        A ``completed`` marker file is created after the loop over all country
        folders finishes; if the marker already exists this returns immediately.
        Country files already present in the directory are not downloaded
        again, so an interrupted run can be resumed.

        Args:
            worldpop_dir: the directory to download to.
            timeout: timeout applied to each HTTP request.

        Raises:
            requests.HTTPError: if any listing or download request fails.
            ValueError: if a country folder does not list exactly one GeoTIFF
                ending in ``FILENAME_SUFFIX``.
        """
        completed_fname = worldpop_dir / "completed"
        if completed_fname.exists():
            return

        # Gather per-country subfolder URLs from every index page. A subfolder
        # link is four characters ending in a slash, e.g. "USA/".
        country_urls: list[str] = []
        for index_url in self.INDEX_URLS:
            logger.info(f"Getting per-country subfolders from {index_url}")
            hrefs = self._list_links(index_url, timeout)
            country_urls += [
                urljoin(index_url, href)
                for href in hrefs
                if len(href) == 4 and href.endswith("/")
            ]

        logger.info(f"Got {len(country_urls)} country subfolders to download")
        # Randomize the order so that several processes running at once tend to
        # work on different countries.
        random.shuffle(country_urls)

        for country_url in country_urls:
            hrefs = self._list_links(country_url, timeout)
            tif_urls = [
                urljoin(country_url, href)
                for href in hrefs
                if href.endswith(self.FILENAME_SUFFIX)
            ]
            if len(tif_urls) != 1:
                raise ValueError(
                    f"expected {country_url} to contain one GeoTIFF ending in {self.FILENAME_SUFFIX} but got {hrefs}"
                )

            tif_url = tif_urls[0]
            dst_fname = worldpop_dir / tif_url.split("/")[-1]
            if dst_fname.exists():
                continue

            logger.info(f"Downloading from {tif_url} to {dst_fname}")
            self._download_file(tif_url, dst_fname, timeout)

        completed_fname.touch()
|
|
@@ -13,15 +13,17 @@ import rasterio.warp
|
|
|
13
13
|
import shapely
|
|
14
14
|
from PIL import Image
|
|
15
15
|
from rasterio.crs import CRS
|
|
16
|
-
from
|
|
16
|
+
from rasterio.enums import Resampling
|
|
17
17
|
|
|
18
|
-
from rslearn.config import LayerConfig, QueryConfig
|
|
18
|
+
from rslearn.config import LayerConfig, QueryConfig
|
|
19
19
|
from rslearn.dataset import Window
|
|
20
|
+
from rslearn.dataset.materialize import RasterMaterializer
|
|
21
|
+
from rslearn.tile_stores import TileStore, TileStoreWithLayer
|
|
20
22
|
from rslearn.utils import PixelBounds, Projection, STGeometry
|
|
21
23
|
from rslearn.utils.array import copy_spatial_array
|
|
24
|
+
from rslearn.utils.raster_format import get_transform_from_projection_and_bounds
|
|
22
25
|
|
|
23
|
-
from .data_source import DataSource, Item
|
|
24
|
-
from .raster_source import ArrayWithTransform, materialize_raster
|
|
26
|
+
from .data_source import DataSource, DataSourceContext, Item
|
|
25
27
|
from .utils import match_candidate_items_to_window
|
|
26
28
|
|
|
27
29
|
WEB_MERCATOR_EPSG = 3857
|
|
@@ -81,58 +83,24 @@ def read_from_tile_callback(
|
|
|
81
83
|
return data
|
|
82
84
|
|
|
83
85
|
|
|
84
|
-
class
|
|
85
|
-
"""An item in the XyzTiles data source.
|
|
86
|
-
|
|
87
|
-
Each item represents one layer of tiles. Often there is only one itm in the data
|
|
88
|
-
source, but if there are multiple then they should correspond to different time
|
|
89
|
-
ranges.
|
|
90
|
-
"""
|
|
91
|
-
|
|
92
|
-
def __init__(self, name: str, geometry: STGeometry, url_template: str):
|
|
93
|
-
"""Creates a new XyzItem.
|
|
94
|
-
|
|
95
|
-
Args:
|
|
96
|
-
name: unique name of the item
|
|
97
|
-
geometry: the spatial and temporal extent of the item
|
|
98
|
-
url_template: the URL template for an xyz tile.
|
|
99
|
-
"""
|
|
100
|
-
super().__init__(name, geometry)
|
|
101
|
-
self.url_template = url_template
|
|
102
|
-
|
|
103
|
-
def serialize(self) -> dict:
|
|
104
|
-
"""Serializes the item to a JSON-encodable dictionary."""
|
|
105
|
-
d = super().serialize()
|
|
106
|
-
d["url_template"] = self.url_template
|
|
107
|
-
return d
|
|
108
|
-
|
|
109
|
-
@staticmethod
|
|
110
|
-
def deserialize(d: dict) -> Item:
|
|
111
|
-
"""Deserializes an item from a JSON-decoded dictionary."""
|
|
112
|
-
item = super(XyzItem, XyzItem).deserialize(d)
|
|
113
|
-
return XyzItem(
|
|
114
|
-
name=item.name, geometry=item.geometry, url_template=d["url_template"]
|
|
115
|
-
)
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
class XyzTiles(DataSource):
|
|
86
|
+
class XyzTiles(DataSource, TileStore):
|
|
119
87
|
"""A data source for web xyz image tiles.
|
|
120
88
|
|
|
121
89
|
These tiles are usually in WebMercator projection, but different CRS can be
|
|
122
90
|
configured here.
|
|
123
91
|
"""
|
|
124
92
|
|
|
125
|
-
item_name = "xyz_tiles"
|
|
126
|
-
|
|
127
93
|
def __init__(
|
|
128
94
|
self,
|
|
129
95
|
url_templates: list[str],
|
|
130
96
|
time_ranges: list[tuple[datetime, datetime]],
|
|
131
97
|
zoom: int,
|
|
132
|
-
crs: CRS = CRS.from_epsg(WEB_MERCATOR_EPSG),
|
|
98
|
+
crs: str | CRS = CRS.from_epsg(WEB_MERCATOR_EPSG),
|
|
133
99
|
total_units: float = WEB_MERCATOR_UNITS,
|
|
134
100
|
offset: float = WEB_MERCATOR_UNITS / 2,
|
|
135
101
|
tile_size: int = 256,
|
|
102
|
+
band_names: list[str] = ["R", "G", "B"],
|
|
103
|
+
context: DataSourceContext = DataSourceContext(),
|
|
136
104
|
):
|
|
137
105
|
"""Initialize an XyzTiles instance.
|
|
138
106
|
|
|
@@ -152,14 +120,22 @@ class XyzTiles(DataSource):
|
|
|
152
120
|
the pixel size to map from projection coordinates to pixel coordinates.
|
|
153
121
|
offset: offset added to projection units when converting to tile positions.
|
|
154
122
|
tile_size: size in pixels of each tile. Tiles must be square.
|
|
123
|
+
band_names: what to name the bands that we read.
|
|
124
|
+
context: the data source context.
|
|
155
125
|
"""
|
|
156
126
|
self.url_templates = url_templates
|
|
157
127
|
self.time_ranges = time_ranges
|
|
158
128
|
self.zoom = zoom
|
|
159
|
-
self.crs = crs
|
|
160
129
|
self.total_units = total_units
|
|
161
130
|
self.offset = offset
|
|
162
131
|
self.tile_size = tile_size
|
|
132
|
+
self.band_names = band_names
|
|
133
|
+
|
|
134
|
+
# Convert to CRS if needed.
|
|
135
|
+
if isinstance(crs, str):
|
|
136
|
+
self.crs = CRS.from_string(crs)
|
|
137
|
+
else:
|
|
138
|
+
self.crs = crs
|
|
163
139
|
|
|
164
140
|
# Compute total number of pixels (a function of the zoom level and tile size).
|
|
165
141
|
self.total_pixels = tile_size * (2**zoom)
|
|
@@ -169,7 +145,7 @@ class XyzTiles(DataSource):
|
|
|
169
145
|
self.pixel_offset = int(self.offset / self.pixel_size)
|
|
170
146
|
# Compute the extent in pixel coordinates as an STGeometry.
|
|
171
147
|
# Note that pixel coordinates are prior to applying the offset.
|
|
172
|
-
shp = shapely.box(
|
|
148
|
+
self.shp = shapely.box(
|
|
173
149
|
-self.total_pixels // 2,
|
|
174
150
|
-self.total_pixels // 2,
|
|
175
151
|
self.total_pixels // 2,
|
|
@@ -179,32 +155,10 @@ class XyzTiles(DataSource):
|
|
|
179
155
|
|
|
180
156
|
self.items = []
|
|
181
157
|
for url_template, time_range in zip(self.url_templates, self.time_ranges):
|
|
182
|
-
geometry = STGeometry(self.projection, shp, time_range)
|
|
183
|
-
item =
|
|
158
|
+
geometry = STGeometry(self.projection, self.shp, time_range)
|
|
159
|
+
item = Item(url_template, geometry)
|
|
184
160
|
self.items.append(item)
|
|
185
161
|
|
|
186
|
-
@staticmethod
|
|
187
|
-
def from_config(config: LayerConfig, ds_path: UPath) -> "XyzTiles":
|
|
188
|
-
"""Creates a new XyzTiles instance from a configuration dictionary."""
|
|
189
|
-
d = config.data_source.config_dict
|
|
190
|
-
time_ranges = []
|
|
191
|
-
for str1, str2 in d["time_ranges"]:
|
|
192
|
-
time1 = datetime.fromisoformat(str1)
|
|
193
|
-
time2 = datetime.fromisoformat(str2)
|
|
194
|
-
time_ranges.append((time1, time2))
|
|
195
|
-
kwargs = dict(
|
|
196
|
-
url_templates=d["url_templates"], zoom=d["zoom"], time_ranges=time_ranges
|
|
197
|
-
)
|
|
198
|
-
if "crs" in d:
|
|
199
|
-
kwargs["crs"] = CRS.from_string(d["crs"])
|
|
200
|
-
if "total_units" in d:
|
|
201
|
-
kwargs["total_units"] = d["total_units"]
|
|
202
|
-
if "offset" in d:
|
|
203
|
-
kwargs["offset"] = d["offset"]
|
|
204
|
-
if "tile_size" in d:
|
|
205
|
-
kwargs["tile_size"] = d["tile_size"]
|
|
206
|
-
return XyzTiles(**kwargs)
|
|
207
|
-
|
|
208
162
|
def get_items(
|
|
209
163
|
self, geometries: list[STGeometry], query_config: QueryConfig
|
|
210
164
|
) -> list[list[list[Item]]]:
|
|
@@ -232,7 +186,7 @@ class XyzTiles(DataSource):
|
|
|
232
186
|
|
|
233
187
|
def deserialize_item(self, serialized_item: Any) -> Item:
    """Rebuild an Item from its JSON-decoded form.

    This source creates plain ``Item`` objects (one per URL template), so the
    generic ``Item.deserialize`` is all that is needed.
    """
    item = Item.deserialize(serialized_item)
    return item
|
|
236
190
|
|
|
237
191
|
def read_tile(self, url_template: str, col: int, row: int) -> npt.NDArray[Any]:
|
|
238
192
|
"""Read the tile at specified column and row.
|
|
@@ -249,8 +203,11 @@ class XyzTiles(DataSource):
|
|
|
249
203
|
url = url.replace("{x}", str(col))
|
|
250
204
|
url = url.replace("{y}", str(row))
|
|
251
205
|
url = url.replace("{z}", str(self.zoom))
|
|
252
|
-
image = Image.open(urllib.request.urlopen(url))
|
|
253
|
-
|
|
206
|
+
image = np.array(Image.open(urllib.request.urlopen(url)))
|
|
207
|
+
# Handle grayscale images (add single-band channel dimension).
|
|
208
|
+
if len(image.shape) == 2:
|
|
209
|
+
image = image[:, :, None]
|
|
210
|
+
return image.transpose(2, 0, 1)
|
|
254
211
|
|
|
255
212
|
def read_bounds(self, url_template: str, bounds: PixelBounds) -> npt.NDArray[Any]:
|
|
256
213
|
"""Reads the portion of the raster in the specified bounds.
|
|
@@ -275,6 +232,122 @@ class XyzTiles(DataSource):
|
|
|
275
232
|
self.tile_size,
|
|
276
233
|
)
|
|
277
234
|
|
|
235
|
+
def is_raster_ready(
    self, layer_name: str, item_name: str, bands: list[str]
) -> bool:
    """Report whether the requested raster is available in the store.

    Reads go straight to the XYZ tile URL, so nothing has to be ingested
    beforehand and every raster is treated as ready.

    Args:
        layer_name: the layer name or alias (unused).
        item_name: the item (unused).
        bands: the list of bands identifying the raster (unused).

    Returns:
        always True.
    """
    return True
|
|
251
|
+
|
|
252
|
+
def get_raster_bands(self, layer_name: str, item_name: str) -> list[list[str]]:
    """Get the sets of bands that are available for the specified item.

    Every item of this source is served as a single raster containing the
    configured ``band_names``, so the result always has exactly one entry.

    Args:
        layer_name: the layer name or alias (unused).
        item_name: the item (unused).

    Returns:
        a one-element list holding a copy of ``self.band_names``.
    """
    # Return a copy: ``self.band_names`` may be the constructor's default list,
    # which is shared by every instance, and ``read_raster`` validates requests
    # against it. A caller mutating the result must not corrupt that state.
    return [list(self.band_names)]
|
|
265
|
+
|
|
266
|
+
def get_raster_bounds(
    self, layer_name: str, item_name: str, bands: list[str], projection: Projection
) -> PixelBounds:
    """Get the bounds of the raster in the specified projection.

    The raster covers the full tile-grid extent (``self.shp`` in this source's
    own projection). That extent is re-projected and rounded outward to whole
    pixels.

    Args:
        layer_name: the layer name or alias.
        item_name: the item to check.
        bands: the list of bands identifying which specific raster to read. These
            bands must match the bands of a stored raster.
        projection: the projection to get the raster's bounds in.

    Returns:
        integer pixel bounds, in the projection, that fully enclose the raster.
    """
    geom = STGeometry(self.projection, self.shp, None).to_projection(projection)
    min_x, min_y, max_x, max_y = geom.shp.bounds
    # Round outward (floor the minimums, ceil the maximums) so the bounds enclose
    # the whole extent. int() truncates toward zero, which shrinks the box at
    # negative minimums / positive maximums. This matches read_raster's rounding.
    return (
        math.floor(min_x),
        math.floor(min_y),
        math.ceil(max_x),
        math.ceil(max_y),
    )
|
|
288
|
+
|
|
289
|
+
def read_raster(
    self,
    layer_name: str,
    item_name: str,
    bands: list[str],
    projection: Projection,
    bounds: PixelBounds,
    resampling: Resampling = Resampling.bilinear,
) -> npt.NDArray[Any]:
    """Read raster data for an item, warped onto the requested pixel grid.

    The requested window is mapped into the tile grid's own projection, that
    region is read through ``read_bounds``, and the result is reprojected back
    onto ``projection``/``bounds``.

    Args:
        layer_name: the layer name or alias.
        item_name: the item to read; for this source it is the URL template.
        bands: the bands to read; must equal ``self.band_names``.
        projection: the projection to read in.
        bounds: the bounds to read.
        resampling: the resampling method to use in case reprojection is needed.

    Returns:
        an array of shape (num_bands, height, width) covering ``bounds``. The
        destination is zero-initialized before warping.

    Raises:
        ValueError: if ``bands`` differs from ``self.band_names``.
    """
    if bands != self.band_names:
        raise ValueError(
            f"expected request for bands {self.band_names} but requested {bands}"
        )

    # Map the requested window into the tile grid's projection, rounding the
    # result outward so the region we fetch fully covers it.
    request_box = shapely.box(*bounds)
    tile_geom = STGeometry(projection, request_box, None).to_projection(
        self.projection
    )
    min_x, min_y, max_x, max_y = tile_geom.shp.bounds
    tile_bounds = (
        math.floor(min_x),
        math.floor(min_y),
        math.ceil(max_x),
        math.ceil(max_y),
    )

    # Each item is created as Item(url_template, geometry), so its name is the
    # URL template to read from.
    source = self.read_bounds(item_name, tile_bounds)

    # Warp the fetched pixels back onto the caller's grid.
    source_transform = get_transform_from_projection_and_bounds(
        self.projection, tile_bounds
    )
    target_transform = get_transform_from_projection_and_bounds(projection, bounds)
    x1, y1, x2, y2 = bounds
    warped = np.zeros((source.shape[0], y2 - y1, x2 - x1), dtype=source.dtype)
    rasterio.warp.reproject(
        source=source,
        src_crs=self.projection.crs,
        src_transform=source_transform,
        destination=warped,
        dst_crs=projection.crs,
        dst_transform=target_transform,
        resampling=resampling,
    )
    return warped
|
|
350
|
+
|
|
278
351
|
def materialize(
|
|
279
352
|
self,
|
|
280
353
|
window: Window,
|
|
@@ -290,40 +363,10 @@ class XyzTiles(DataSource):
|
|
|
290
363
|
layer_name: the name of this layer
|
|
291
364
|
layer_cfg: the config of this layer
|
|
292
365
|
"""
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
assert isinstance(layer_cfg, RasterLayerConfig)
|
|
300
|
-
band_cfg = layer_cfg.band_sets[0]
|
|
301
|
-
window_projection, window_bounds = band_cfg.get_final_projection_and_bounds(
|
|
302
|
-
window.projection, window.bounds
|
|
303
|
-
)
|
|
304
|
-
window_geometry = STGeometry(
|
|
305
|
-
window_projection, shapely.box(*window_bounds), None
|
|
366
|
+
RasterMaterializer().materialize(
|
|
367
|
+
TileStoreWithLayer(self, layer_name),
|
|
368
|
+
window,
|
|
369
|
+
layer_name,
|
|
370
|
+
layer_cfg,
|
|
371
|
+
item_groups,
|
|
306
372
|
)
|
|
307
|
-
projected_geometry = window_geometry.to_projection(self.projection)
|
|
308
|
-
projected_bounds = [
|
|
309
|
-
math.floor(projected_geometry.shp.bounds[0]),
|
|
310
|
-
math.floor(projected_geometry.shp.bounds[1]),
|
|
311
|
-
math.ceil(projected_geometry.shp.bounds[2]),
|
|
312
|
-
math.ceil(projected_geometry.shp.bounds[3]),
|
|
313
|
-
]
|
|
314
|
-
projected_raster = self.read_bounds(item.url_template, projected_bounds)
|
|
315
|
-
|
|
316
|
-
# Attach the transform to the raster.
|
|
317
|
-
src_transform = rasterio.transform.Affine(
|
|
318
|
-
self.projection.x_resolution,
|
|
319
|
-
0,
|
|
320
|
-
projected_bounds[0] * self.projection.x_resolution,
|
|
321
|
-
0,
|
|
322
|
-
self.projection.y_resolution,
|
|
323
|
-
projected_bounds[1] * self.projection.y_resolution,
|
|
324
|
-
)
|
|
325
|
-
array_with_transform = ArrayWithTransform(
|
|
326
|
-
projected_raster, self.projection.crs, src_transform
|
|
327
|
-
)
|
|
328
|
-
|
|
329
|
-
materialize_raster(array_with_transform, window, layer_name, band_cfg)
|
rslearn/dataset/__init__.py
CHANGED
|
@@ -1,6 +1,12 @@
|
|
|
1
1
|
"""rslearn dataset storage and operations."""
|
|
2
2
|
|
|
3
3
|
from .dataset import Dataset
|
|
4
|
-
from .window import Window, WindowLayerData
|
|
4
|
+
from .window import Window, WindowLayerData, get_window_layer_dir, get_window_raster_dir
|
|
5
5
|
|
|
6
|
-
__all__ = (
|
|
6
|
+
__all__ = (
|
|
7
|
+
"Dataset",
|
|
8
|
+
"Window",
|
|
9
|
+
"WindowLayerData",
|
|
10
|
+
"get_window_layer_dir",
|
|
11
|
+
"get_window_raster_dir",
|
|
12
|
+
)
|
rslearn/dataset/add_windows.py
CHANGED
|
@@ -25,7 +25,7 @@ def add_windows_from_geometries(
|
|
|
25
25
|
window_size: int | None = None,
|
|
26
26
|
time_range: tuple[datetime, datetime] | None = None,
|
|
27
27
|
use_utm: bool = False,
|
|
28
|
-
):
|
|
28
|
+
) -> list[Window]:
|
|
29
29
|
"""Create windows based on a list of STGeometry.
|
|
30
30
|
|
|
31
31
|
Args:
|
|
@@ -131,7 +131,7 @@ def add_windows_from_geometries(
|
|
|
131
131
|
f"_{time_range[0].isoformat()}_{time_range[1].isoformat()}"
|
|
132
132
|
)
|
|
133
133
|
window = Window(
|
|
134
|
-
|
|
134
|
+
storage=dataset.storage,
|
|
135
135
|
group=group,
|
|
136
136
|
name=cur_window_name,
|
|
137
137
|
projection=cur_projection,
|
rslearn/dataset/dataset.py
CHANGED
|
@@ -1,16 +1,19 @@
|
|
|
1
1
|
"""rslearn dataset class."""
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
-
import
|
|
4
|
+
from typing import Any
|
|
5
5
|
|
|
6
|
-
import tqdm
|
|
7
6
|
from upath import UPath
|
|
8
7
|
|
|
9
|
-
from rslearn.config import
|
|
8
|
+
from rslearn.config import DatasetConfig
|
|
9
|
+
from rslearn.log_utils import get_logger
|
|
10
|
+
from rslearn.template_params import substitute_env_vars_in_string
|
|
10
11
|
from rslearn.tile_stores import TileStore, load_tile_store
|
|
11
12
|
|
|
12
13
|
from .window import Window
|
|
13
14
|
|
|
15
|
+
logger = get_logger(__name__)
|
|
16
|
+
|
|
14
17
|
|
|
15
18
|
class Dataset:
|
|
16
19
|
"""A rslearn dataset.
|
|
@@ -20,7 +23,7 @@ class Dataset:
|
|
|
20
23
|
.. code-block:: none
|
|
21
24
|
|
|
22
25
|
dataset/
|
|
23
|
-
config.json
|
|
26
|
+
config.json # optional, if config provided as runtime object
|
|
24
27
|
windows/
|
|
25
28
|
group1/
|
|
26
29
|
epsg:3857_10_623565_1528020/
|
|
@@ -37,72 +40,58 @@ class Dataset:
|
|
|
37
40
|
materialize.
|
|
38
41
|
"""
|
|
39
42
|
|
|
40
|
-
def __init__(
    self,
    path: UPath,
    disabled_layers: list[str] | None = None,
    dataset_config: DatasetConfig | None = None,
) -> None:
    """Initializes a new Dataset.

    Args:
        path: the root directory of the dataset
        disabled_layers: names of layers to disable; they are skipped (with a
            warning) when populating ``self.layers``. Defaults to none.
        dataset_config: optional dataset configuration to use instead of
            loading ``config.json`` from the dataset directory
    """
    self.path = path

    # Avoid a shared mutable default argument; None means "no disabled layers".
    if disabled_layers is None:
        disabled_layers = []

    if dataset_config is None:
        # Load dataset configuration from the dataset directory. JSON text is
        # UTF-8, so decode explicitly instead of using the locale default
        # (which differs on e.g. Windows and breaks non-ASCII configs).
        with (self.path / "config.json").open("r", encoding="utf-8") as f:
            config_content = f.read()
        # Apply substitute_env_vars_in_string to the raw text before parsing.
        config_content = substitute_env_vars_in_string(config_content)
        dataset_config = DatasetConfig.model_validate(json.loads(config_content))

    self.layers = {}
    for layer_name, layer_config in dataset_config.layers.items():
        if layer_name in disabled_layers:
            logger.warning(f"Layer {layer_name} is disabled")
            continue
        self.layers[layer_name] = layer_config

    self.tile_store_config = dataset_config.tile_store
    self.storage = (
        dataset_config.storage.instantiate_window_storage_factory().get_storage(
            self.path
        )
    )
|
|
57
80
|
|
|
58
81
|
def load_windows(
    self,
    groups: list[str] | None = None,
    names: list[str] | None = None,
    **kwargs: Any,
) -> list[Window]:
    """Load windows from this dataset's window storage.

    Args:
        groups: optional list of groups to filter loading by.
        names: optional list of window names to filter loading by.
        kwargs: extra keyword arguments forwarded unchanged to
            ``WindowStorage.get_windows``.

    Returns:
        the windows returned by the dataset's window storage.
    """
    storage = self.storage
    return storage.get_windows(names=names, groups=groups, **kwargs)
|
|
106
95
|
|
|
107
96
|
def get_tile_store(self) -> TileStore:
|
|
108
97
|
"""Get the tile store associated with this dataset.
|