rslearn-0.0.1-py3-none-any.whl → rslearn-0.0.2-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- rslearn/config/dataset.py +22 -13
- rslearn/data_sources/__init__.py +8 -0
- rslearn/data_sources/aws_landsat.py +27 -18
- rslearn/data_sources/aws_open_data.py +41 -42
- rslearn/data_sources/copernicus.py +148 -2
- rslearn/data_sources/data_source.py +17 -10
- rslearn/data_sources/gcp_public_data.py +177 -100
- rslearn/data_sources/geotiff.py +1 -0
- rslearn/data_sources/google_earth_engine.py +17 -15
- rslearn/data_sources/local_files.py +59 -32
- rslearn/data_sources/openstreetmap.py +27 -23
- rslearn/data_sources/planet.py +10 -9
- rslearn/data_sources/planet_basemap.py +303 -0
- rslearn/data_sources/raster_source.py +23 -13
- rslearn/data_sources/usgs_landsat.py +56 -27
- rslearn/data_sources/utils.py +13 -6
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/xyz_tiles.py +8 -9
- rslearn/dataset/add_windows.py +1 -1
- rslearn/dataset/dataset.py +16 -5
- rslearn/dataset/manage.py +9 -4
- rslearn/dataset/materialize.py +26 -5
- rslearn/dataset/window.py +5 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +123 -59
- rslearn/models/clip.py +62 -0
- rslearn/models/conv.py +56 -0
- rslearn/models/faster_rcnn.py +2 -19
- rslearn/models/fpn.py +1 -1
- rslearn/models/module_wrapper.py +43 -0
- rslearn/models/molmo.py +65 -0
- rslearn/models/multitask.py +1 -1
- rslearn/models/pooling_decoder.py +4 -2
- rslearn/models/satlaspretrain.py +4 -7
- rslearn/models/simple_time_series.py +61 -55
- rslearn/models/ssl4eo_s12.py +9 -9
- rslearn/models/swin.py +22 -21
- rslearn/models/unet.py +4 -2
- rslearn/models/upsample.py +35 -0
- rslearn/tile_stores/file.py +6 -3
- rslearn/tile_stores/tile_store.py +19 -7
- rslearn/train/callbacks/freeze_unfreeze.py +3 -3
- rslearn/train/data_module.py +5 -4
- rslearn/train/dataset.py +79 -36
- rslearn/train/lightning_module.py +15 -11
- rslearn/train/prediction_writer.py +22 -11
- rslearn/train/tasks/classification.py +9 -8
- rslearn/train/tasks/detection.py +94 -37
- rslearn/train/tasks/multi_task.py +1 -1
- rslearn/train/tasks/regression.py +8 -4
- rslearn/train/tasks/segmentation.py +23 -19
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +6 -2
- rslearn/train/transforms/crop.py +6 -2
- rslearn/train/transforms/flip.py +5 -1
- rslearn/train/transforms/normalize.py +9 -5
- rslearn/train/transforms/pad.py +1 -1
- rslearn/train/transforms/transform.py +3 -3
- rslearn/utils/__init__.py +4 -5
- rslearn/utils/array.py +2 -2
- rslearn/utils/feature.py +1 -1
- rslearn/utils/fsspec.py +70 -1
- rslearn/utils/geometry.py +155 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +81 -73
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/utils.py +11 -3
- rslearn/utils/vector_format.py +113 -17
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/METADATA +32 -27
- rslearn-0.0.2.dist-info/RECORD +94 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/WHEEL +1 -1
- rslearn/utils/mgrs.py +0 -24
- rslearn-0.0.1.dist-info/RECORD +0 -88
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/top_level.txt +0 -0
rslearn/data_sources/copernicus.py

@@ -1,10 +1,28 @@
 """Data source for raster data in ESA Copernicus API."""
 
+import functools
+import io
+import json
+import shutil
+import urllib.request
 import xml.etree.ElementTree as ET
+import zipfile
 from collections.abc import Callable
 
 import numpy as np
 import numpy.typing as npt
+from upath import UPath
+
+from rslearn.const import WGS84_PROJECTION
+from rslearn.log_utils import get_logger
+from rslearn.utils.fsspec import open_atomic
+from rslearn.utils.geometry import STGeometry, flatten_shape
+from rslearn.utils.grid_index import GridIndex
+
+SENTINEL2_TILE_URL = "https://sentiwiki.copernicus.eu/__attachments/1692737/S2A_OPER_GIP_TILPAR_MPC__20151209T095117_V20150622T000000_21000101T000000_B00.zip"
+SENTINEL2_KML_NAMESPACE = "{http://www.opengis.net/kml/2.2}"
+
+logger = get_logger(__name__)
 
 
 def get_harmonize_callback(
@@ -24,6 +42,8 @@ def get_harmonize_callback(
     """
     offset = None
     for el in tree.iter("RADIO_ADD_OFFSET"):
+        if el.text is None:
+            raise ValueError(f"text is missing in {el}")
         value = int(el.text)
         if offset is None:
             offset = value
@@ -36,7 +56,133 @@ def get_harmonize_callback(
     if offset is None or offset == 0:
         return None
 
-    def callback(array):
-        return np.clip(array, -offset, None) + offset
+    def callback(array: npt.NDArray) -> npt.NDArray:
+        return np.clip(array, -offset, None) + offset  # type: ignore
 
     return callback
+
+
+def _cache_sentinel2_tile_index(cache_dir: UPath) -> None:
+    """Cache the tiles from SENTINEL2_TILE_URL.
+
+    This way we just need to download it once.
+    """
+    json_fname = cache_dir / "tile_index.json"
+
+    if json_fname.exists():
+        return
+
+    logger.info(f"caching list of Sentinel-2 tiles to {json_fname}")
+
+    # Identify the Sentinel-2 tile names and bounds using the KML file.
+    # First, download the zip file and extract and parse the KML.
+    buf = io.BytesIO()
+    with urllib.request.urlopen(SENTINEL2_TILE_URL) as response:
+        shutil.copyfileobj(response, buf)
+    buf.seek(0)
+    with zipfile.ZipFile(buf, "r") as zipf:
+        member_names = zipf.namelist()
+        if len(member_names) != 1:
+            raise ValueError(
+                "Sentinel-2 tile zip file unexpectedly contains more than one file"
+            )
+
+        with zipf.open(member_names[0]) as memberf:
+            tree = ET.parse(memberf)
+
+    # Map from the tile name to the longitude/latitude bounds.
+    json_data: dict[str, tuple[float, float, float, float]] = {}
+
+    # The KML is list of Placemark so iterate over those.
+    for placemark_node in tree.iter(SENTINEL2_KML_NAMESPACE + "Placemark"):
+        # The <name> node specifies the Sentinel-2 tile name.
+        name_node = placemark_node.find(SENTINEL2_KML_NAMESPACE + "name")
+        if name_node is None or name_node.text is None:
+            raise ValueError("Sentinel-2 KML has Placemark without valid name node")
+
+        tile_name = name_node.text
+
+        # There may be one or more <coordinates> nodes depending on whether it is a
+        # MultiGeometry. Here we just iterate over all of the coordinates since we are
+        # only interested in the bounds in WGS-84 coordinates.
+        lons = []
+        lats = []
+        for coord_node in placemark_node.iter(SENTINEL2_KML_NAMESPACE + "coordinates"):
+            # It is list of space-separated coordinates like:
+            # 180,-73.0597374076,0 176.8646237862,-72.9914734628,0 ...
+            if coord_node.text is None:
+                raise ValueError("Sentinel-2 KML has coordinates node missing text")
+
+            point_strs = coord_node.text.strip().split()
+            for point_str in point_strs:
+                parts = point_str.split(",")
+                if len(parts) != 2 and len(parts) != 3:
+                    continue
+
+                lon = float(parts[0])
+                lat = float(parts[1])
+                lons.append(lon)
+                lats.append(lat)
+
+        if len(lons) == 0 or len(lats) == 0:
+            raise ValueError("Sentinel-2 KML has Placemark with no coordinates")
+
+        bounds = (
+            min(lons),
+            min(lats),
+            max(lons),
+            max(lats),
+        )
+        json_data[tile_name] = bounds
+
+    with open_atomic(json_fname, "w") as f:
+        json.dump(json_data, f)
+
+
+@functools.cache
+def load_sentinel2_tile_index(cache_dir: UPath) -> GridIndex:
+    """Load a GridIndex over Sentinel-2 tiles.
+
+    This function is cached so the GridIndex only needs to be constructed once (per
+    process).
+
+    Args:
+        cache_dir: the directory to cache the list of Sentinel-2 tiles.
+
+    Returns:
+        GridIndex over the tile names
+    """
+    _cache_sentinel2_tile_index(cache_dir)
+    json_fname = cache_dir / "tile_index.json"
+    with json_fname.open() as f:
+        json_data = json.load(f)
+
+    grid_index = GridIndex(0.5)
+    for tile_name, bounds in json_data.items():
+        grid_index.insert(bounds, tile_name)
+
+    return grid_index
+
+
+def get_sentinel2_tiles(geometry: STGeometry, cache_dir: UPath) -> list[str]:
+    """Get all Sentinel-2 tiles (like 01CCV) intersecting the given geometry.
+
+    Args:
+        geometry: the geometry to check.
+        cache_dir: directory to cache the tiles.
+
+    Returns:
+        list of Sentinel-2 tile names that intersect the geometry.
+    """
+    tile_index = load_sentinel2_tile_index(cache_dir)
+    wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
+    # If the shape is a collection, it could be cutting across prime meridian.
+    # So we query each component shape separately and collect the results to avoid
+    # issues.
+    # We assume the caller has already applied split_at_prime_meridian.
+    results = set()
+    for shp in flatten_shape(wgs84_geometry.shp):
+        for result in tile_index.query(shp.bounds):
+            assert isinstance(result, str)
            results.add(result)
+    return list(results)
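
The new helpers added to copernicus.py culminate in get_sentinel2_tiles(), which maps a query geometry to the names of the Sentinel-2 tiles it intersects, downloading and caching the official tile KML on first use. The sketch below shows one plausible way to call it; it is an illustration only, and the STGeometry constructor arguments, the shapely box shape, and the cache directory path are assumptions rather than details taken from this diff.

```python
# Usage sketch (not from the diff): assumes STGeometry(projection, shp, time_range)
# and that any writable directory works as the cache location.
from shapely.geometry import box
from upath import UPath

from rslearn.const import WGS84_PROJECTION
from rslearn.data_sources.copernicus import get_sentinel2_tiles
from rslearn.utils.geometry import STGeometry

# Axis-aligned WGS-84 box (lon/lat) around the Seattle area, with no time range.
geometry = STGeometry(WGS84_PROJECTION, box(-122.5, 47.4, -122.2, 47.8), None)

# The first call downloads the tile KML and writes tile_index.json to the cache
# directory; later calls reuse the cached JSON and the per-process GridIndex.
tiles = get_sentinel2_tiles(geometry, UPath("/tmp/rslearn_cache"))
print(tiles)  # e.g. a list of MGRS-style tile names such as ["10TET"]
```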
rslearn/data_sources/data_source.py

@@ -1,7 +1,7 @@
 """Base classes for rslearn data sources."""
 
 from collections.abc import Generator
-from typing import Any, BinaryIO
+from typing import Any, BinaryIO, Generic, TypeVar
 
 from rslearn.config import LayerConfig, QueryConfig
 from rslearn.dataset import Window
@@ -51,15 +51,20 @@ class Item:
         return hash(self.name)
 
 
-class DataSource:
+ItemType = TypeVar("ItemType", bound="Item")
+
+
+class DataSource(Generic[ItemType]):
     """A set of raster or vector files that can be retrieved.
 
     Data sources should support at least one of ingest and materialize.
     """
 
+    TIMEOUT = 1000000  # Set very high to start
+
     def get_items(
         self, geometries: list[STGeometry], query_config: QueryConfig
-    ) -> list[list[list[Item]]]:
+    ) -> list[list[list[ItemType]]]:
         """Get a list of items in the data source intersecting the given geometries.
 
         Args:
@@ -71,14 +76,14 @@ class DataSource:
         """
         raise NotImplementedError
 
-    def deserialize_item(self, serialized_item: Any) -> Item:
+    def deserialize_item(self, serialized_item: Any) -> ItemType:
         """Deserializes an item from JSON-decoded data."""
         raise NotImplementedError
 
     def ingest(
         self,
         tile_store: TileStore,
-        items: list[Item],
+        items: list[ItemType],
         geometries: list[list[STGeometry]],
     ) -> None:
         """Ingest items into the given tile store.
@@ -93,7 +98,7 @@ class DataSource:
     def materialize(
         self,
         window: Window,
-        item_groups: list[list[Item]],
+        item_groups: list[list[ItemType]],
         layer_name: str,
         layer_cfg: LayerConfig,
     ) -> None:
@@ -108,17 +113,19 @@ class DataSource:
         raise NotImplementedError
 
 
-class ItemLookupDataSource(DataSource):
+class ItemLookupDataSource(DataSource[ItemType]):
     """A data source that can look up items by name."""
 
-    def get_item_by_name(self, name: str) -> Item:
+    def get_item_by_name(self, name: str) -> ItemType:
         """Gets an item by name."""
         raise NotImplementedError
 
 
-class RetrieveItemDataSource(DataSource):
+class RetrieveItemDataSource(DataSource[ItemType]):
     """A data source that can retrieve items in their raw format."""
 
-    def retrieve_item(self, item: Item) -> Generator[tuple[str, BinaryIO], None, None]:
+    def retrieve_item(
+        self, item: ItemType
+    ) -> Generator[tuple[str, BinaryIO], None, None]:
         """Retrieves the rasters corresponding to an item as file streams."""
         raise NotImplementedError