rslearn 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/dataset.py +22 -13
- rslearn/data_sources/__init__.py +8 -0
- rslearn/data_sources/aws_landsat.py +27 -18
- rslearn/data_sources/aws_open_data.py +41 -42
- rslearn/data_sources/copernicus.py +148 -2
- rslearn/data_sources/data_source.py +17 -10
- rslearn/data_sources/gcp_public_data.py +177 -100
- rslearn/data_sources/geotiff.py +1 -0
- rslearn/data_sources/google_earth_engine.py +17 -15
- rslearn/data_sources/local_files.py +59 -32
- rslearn/data_sources/openstreetmap.py +27 -23
- rslearn/data_sources/planet.py +10 -9
- rslearn/data_sources/planet_basemap.py +303 -0
- rslearn/data_sources/raster_source.py +23 -13
- rslearn/data_sources/usgs_landsat.py +56 -27
- rslearn/data_sources/utils.py +13 -6
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/xyz_tiles.py +8 -9
- rslearn/dataset/add_windows.py +1 -1
- rslearn/dataset/dataset.py +16 -5
- rslearn/dataset/manage.py +9 -4
- rslearn/dataset/materialize.py +26 -5
- rslearn/dataset/window.py +5 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +123 -59
- rslearn/models/clip.py +62 -0
- rslearn/models/conv.py +56 -0
- rslearn/models/faster_rcnn.py +2 -19
- rslearn/models/fpn.py +1 -1
- rslearn/models/module_wrapper.py +43 -0
- rslearn/models/molmo.py +65 -0
- rslearn/models/multitask.py +1 -1
- rslearn/models/pooling_decoder.py +4 -2
- rslearn/models/satlaspretrain.py +4 -7
- rslearn/models/simple_time_series.py +61 -55
- rslearn/models/ssl4eo_s12.py +9 -9
- rslearn/models/swin.py +22 -21
- rslearn/models/unet.py +4 -2
- rslearn/models/upsample.py +35 -0
- rslearn/tile_stores/file.py +6 -3
- rslearn/tile_stores/tile_store.py +19 -7
- rslearn/train/callbacks/freeze_unfreeze.py +3 -3
- rslearn/train/data_module.py +5 -4
- rslearn/train/dataset.py +79 -36
- rslearn/train/lightning_module.py +15 -11
- rslearn/train/prediction_writer.py +22 -11
- rslearn/train/tasks/classification.py +9 -8
- rslearn/train/tasks/detection.py +94 -37
- rslearn/train/tasks/multi_task.py +1 -1
- rslearn/train/tasks/regression.py +8 -4
- rslearn/train/tasks/segmentation.py +23 -19
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +6 -2
- rslearn/train/transforms/crop.py +6 -2
- rslearn/train/transforms/flip.py +5 -1
- rslearn/train/transforms/normalize.py +9 -5
- rslearn/train/transforms/pad.py +1 -1
- rslearn/train/transforms/transform.py +3 -3
- rslearn/utils/__init__.py +4 -5
- rslearn/utils/array.py +2 -2
- rslearn/utils/feature.py +1 -1
- rslearn/utils/fsspec.py +70 -1
- rslearn/utils/geometry.py +155 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +81 -73
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/utils.py +11 -3
- rslearn/utils/vector_format.py +113 -17
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/METADATA +32 -27
- rslearn-0.0.2.dist-info/RECORD +94 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/WHEEL +1 -1
- rslearn/utils/mgrs.py +0 -24
- rslearn-0.0.1.dist-info/RECORD +0 -88
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/top_level.txt +0 -0
|
@@ -186,6 +186,8 @@ class XyzTiles(DataSource):
|
|
|
186
186
|
@staticmethod
|
|
187
187
|
def from_config(config: LayerConfig, ds_path: UPath) -> "XyzTiles":
|
|
188
188
|
"""Creates a new XyzTiles instance from a configuration dictionary."""
|
|
189
|
+
if config.data_source is None:
|
|
190
|
+
raise ValueError("data_source is required")
|
|
189
191
|
d = config.data_source.config_dict
|
|
190
192
|
time_ranges = []
|
|
191
193
|
for str1, str2 in d["time_ranges"]:
|
|
@@ -207,7 +209,7 @@ class XyzTiles(DataSource):
|
|
|
207
209
|
|
|
208
210
|
def get_items(
|
|
209
211
|
self, geometries: list[STGeometry], query_config: QueryConfig
|
|
210
|
-
) -> list[list[list[
|
|
212
|
+
) -> list[list[list[XyzItem]]]:
|
|
211
213
|
"""Get a list of items in the data source intersecting the given geometries.
|
|
212
214
|
|
|
213
215
|
In XyzTiles we treat the data source as containing a single item, i.e., the
|
|
@@ -278,7 +280,7 @@ class XyzTiles(DataSource):
|
|
|
278
280
|
def materialize(
|
|
279
281
|
self,
|
|
280
282
|
window: Window,
|
|
281
|
-
item_groups: list[list[
|
|
283
|
+
item_groups: list[list[XyzItem]],
|
|
282
284
|
layer_name: str,
|
|
283
285
|
layer_cfg: LayerConfig,
|
|
284
286
|
) -> None:
|
|
@@ -305,13 +307,10 @@ class XyzTiles(DataSource):
|
|
|
305
307
|
window_projection, shapely.box(*window_bounds), None
|
|
306
308
|
)
|
|
307
309
|
projected_geometry = window_geometry.to_projection(self.projection)
|
|
308
|
-
projected_bounds =
|
|
309
|
-
math.floor(projected_geometry.shp.bounds[
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
math.ceil(projected_geometry.shp.bounds[3]),
|
|
313
|
-
]
|
|
314
|
-
projected_raster = self.read_bounds(item.url_template, projected_bounds)
|
|
310
|
+
projected_bounds = tuple(
|
|
311
|
+
math.floor(projected_geometry.shp.bounds[i]) for i in range(4)
|
|
312
|
+
)
|
|
313
|
+
projected_raster = self.read_bounds(item.url_template, projected_bounds) # type: ignore
|
|
315
314
|
|
|
316
315
|
# Attach the transform to the raster.
|
|
317
316
|
src_transform = rasterio.transform.Affine(
|
rslearn/dataset/add_windows.py
CHANGED
rslearn/dataset/dataset.py
CHANGED
|
@@ -7,10 +7,13 @@ import tqdm
|
|
|
7
7
|
from upath import UPath
|
|
8
8
|
|
|
9
9
|
from rslearn.config import TileStoreConfig, load_layer_config
|
|
10
|
+
from rslearn.log_utils import get_logger
|
|
10
11
|
from rslearn.tile_stores import TileStore, load_tile_store
|
|
11
12
|
|
|
12
13
|
from .window import Window
|
|
13
14
|
|
|
15
|
+
logger = get_logger(__name__)
|
|
16
|
+
|
|
14
17
|
|
|
15
18
|
class Dataset:
|
|
16
19
|
"""A rslearn dataset.
|
|
@@ -37,21 +40,29 @@ class Dataset:
|
|
|
37
40
|
materialize.
|
|
38
41
|
"""
|
|
39
42
|
|
|
40
|
-
def __init__(self, path: UPath) -> None:
|
|
43
|
+
def __init__(self, path: UPath, disabled_layers: list[str] = []) -> None:
|
|
41
44
|
"""Initializes a new Dataset.
|
|
42
45
|
|
|
43
46
|
Args:
|
|
44
47
|
path: the root directory of the dataset
|
|
48
|
+
disabled_layers: list of layers to disable
|
|
45
49
|
"""
|
|
46
50
|
self.path = path
|
|
47
51
|
|
|
48
52
|
# Load dataset configuration.
|
|
53
|
+
|
|
49
54
|
with (self.path / "config.json").open("r") as f:
|
|
50
55
|
config = json.load(f)
|
|
51
|
-
self.layers = {
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
56
|
+
self.layers = {}
|
|
57
|
+
for layer_name, d in config["layers"].items():
|
|
58
|
+
# Layer names must not contain period, since we use period to
|
|
59
|
+
# distinguish different materialized groups within a layer.
|
|
60
|
+
assert "." not in layer_name, "layer names must not contain periods"
|
|
61
|
+
if layer_name in disabled_layers:
|
|
62
|
+
logger.warning(f"Layer {layer_name} is disabled")
|
|
63
|
+
continue
|
|
64
|
+
self.layers[layer_name] = load_layer_config(d)
|
|
65
|
+
|
|
55
66
|
self.tile_store_config = TileStoreConfig.from_config(config["tile_store"])
|
|
56
67
|
self.materializer_name = config.get("materialize")
|
|
57
68
|
|
rslearn/dataset/manage.py
CHANGED
|
@@ -1,15 +1,17 @@
|
|
|
1
1
|
"""Functions to manage datasets."""
|
|
2
2
|
|
|
3
3
|
import rslearn.data_sources
|
|
4
|
-
from rslearn.config import LayerConfig, LayerType
|
|
4
|
+
from rslearn.config import LayerConfig, LayerType, RasterLayerConfig
|
|
5
5
|
from rslearn.data_sources import DataSource, Item
|
|
6
|
+
from rslearn.log_utils import get_logger
|
|
6
7
|
from rslearn.tile_stores import TileStore, get_tile_store_for_layer
|
|
7
|
-
from rslearn.utils import logger
|
|
8
8
|
|
|
9
9
|
from .dataset import Dataset
|
|
10
10
|
from .materialize import Materializers
|
|
11
11
|
from .window import Window, WindowLayerData
|
|
12
12
|
|
|
13
|
+
logger = get_logger(__name__)
|
|
14
|
+
|
|
13
15
|
|
|
14
16
|
def prepare_dataset_windows(
|
|
15
17
|
dataset: Dataset, windows: list[Window], force: bool = False
|
|
@@ -37,7 +39,7 @@ def prepare_dataset_windows(
|
|
|
37
39
|
if layer_name in layer_datas and not force:
|
|
38
40
|
continue
|
|
39
41
|
needed_windows.append(window)
|
|
40
|
-
|
|
42
|
+
logger.info(f"Preparing {len(needed_windows)} windows for layer {layer_name}")
|
|
41
43
|
if len(needed_windows) == 0:
|
|
42
44
|
continue
|
|
43
45
|
|
|
@@ -101,7 +103,7 @@ def ingest_dataset_windows(dataset: Dataset, windows: list[Window]) -> None:
|
|
|
101
103
|
layer_cfg, dataset.path
|
|
102
104
|
)
|
|
103
105
|
|
|
104
|
-
geometries_by_item = {}
|
|
106
|
+
geometries_by_item: dict = {}
|
|
105
107
|
for window in windows:
|
|
106
108
|
layer_datas = window.load_layer_datas()
|
|
107
109
|
if layer_name not in layer_datas:
|
|
@@ -151,6 +153,7 @@ def is_window_ingested(
|
|
|
151
153
|
item = Item.deserialize(serialized_item)
|
|
152
154
|
|
|
153
155
|
if layer_cfg.layer_type == LayerType.RASTER:
|
|
156
|
+
assert isinstance(layer_cfg, RasterLayerConfig)
|
|
154
157
|
for band_set in layer_cfg.band_sets:
|
|
155
158
|
projection, _ = band_set.get_final_projection_and_bounds(
|
|
156
159
|
window.projection, window.bounds
|
|
@@ -229,6 +232,8 @@ def materialize_window(
|
|
|
229
232
|
item_group.append(item)
|
|
230
233
|
item_groups.append(item_group)
|
|
231
234
|
|
|
235
|
+
if layer_cfg.data_source is None:
|
|
236
|
+
raise ValueError("data_source is required")
|
|
232
237
|
if layer_cfg.data_source.ingest:
|
|
233
238
|
if not is_window_ingested(dataset, window, check_layer_name=layer_name):
|
|
234
239
|
logger.info(
|
rslearn/dataset/materialize.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""Classes to implement dataset materialization."""
|
|
2
2
|
|
|
3
|
-
from typing import Any
|
|
3
|
+
from typing import Any, Generic, TypeVar
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
import numpy.typing as npt
|
|
@@ -24,8 +24,10 @@ from .window import Window
|
|
|
24
24
|
|
|
25
25
|
Materializers = ClassRegistry()
|
|
26
26
|
|
|
27
|
+
LayerConfigType = TypeVar("LayerConfigType", bound=LayerConfig)
|
|
27
28
|
|
|
28
|
-
|
|
29
|
+
|
|
30
|
+
class Materializer(Generic[LayerConfigType]):
|
|
29
31
|
"""An abstract class that materializes data from a tile store."""
|
|
30
32
|
|
|
31
33
|
def materialize(
|
|
@@ -33,7 +35,7 @@ class Materializer:
|
|
|
33
35
|
tile_store: TileStore,
|
|
34
36
|
window: Window,
|
|
35
37
|
layer_name: str,
|
|
36
|
-
layer_cfg:
|
|
38
|
+
layer_cfg: LayerConfigType,
|
|
37
39
|
item_groups: list[list[Item]],
|
|
38
40
|
) -> None:
|
|
39
41
|
"""Materialize portions of items corresponding to this window into the dataset.
|
|
@@ -82,6 +84,8 @@ def read_raster_window_from_tiles(
|
|
|
82
84
|
dst_row_offset = intersection[1] - bounds[1]
|
|
83
85
|
|
|
84
86
|
src = ts_layer.read_raster(intersection)
|
|
87
|
+
if src is None:
|
|
88
|
+
raise ValueError(f"No raster data found for bounds {intersection}")
|
|
85
89
|
src = src[src_indexes, :, :]
|
|
86
90
|
if remapper:
|
|
87
91
|
src = remapper(src, dst.dtype)
|
|
@@ -97,7 +101,7 @@ def read_raster_window_from_tiles(
|
|
|
97
101
|
|
|
98
102
|
|
|
99
103
|
@Materializers.register("raster")
|
|
100
|
-
class RasterMaterializer(Materializer):
|
|
104
|
+
class RasterMaterializer(Materializer[RasterLayerConfig]):
|
|
101
105
|
"""A Materializer for raster data."""
|
|
102
106
|
|
|
103
107
|
def materialize(
|
|
@@ -105,7 +109,7 @@ class RasterMaterializer(Materializer):
|
|
|
105
109
|
tile_store: TileStore,
|
|
106
110
|
window: Window,
|
|
107
111
|
layer_name: str,
|
|
108
|
-
layer_cfg:
|
|
112
|
+
layer_cfg: RasterLayerConfig,
|
|
109
113
|
item_groups: list[list[Item]],
|
|
110
114
|
) -> None:
|
|
111
115
|
"""Materialize portions of items corresponding to this window into the dataset.
|
|
@@ -142,6 +146,12 @@ class RasterMaterializer(Materializer):
|
|
|
142
146
|
if band_cfg.remap:
|
|
143
147
|
remapper = load_remapper(band_cfg.remap)
|
|
144
148
|
|
|
149
|
+
if band_cfg.format is None or band_cfg.bands is None or bounds is None:
|
|
150
|
+
raise ValueError(
|
|
151
|
+
f"No raster format or bands specified for {layer_name} \
|
|
152
|
+
with {band_cfg}"
|
|
153
|
+
)
|
|
154
|
+
|
|
145
155
|
raster_format = load_raster_format(
|
|
146
156
|
RasterFormatConfig(band_cfg.format["name"], band_cfg.format)
|
|
147
157
|
)
|
|
@@ -182,6 +192,11 @@ class RasterMaterializer(Materializer):
|
|
|
182
192
|
ts_layer = layer_tile_store.get_layer(
|
|
183
193
|
(item.name, suffix, str(projection))
|
|
184
194
|
)
|
|
195
|
+
if ts_layer is None:
|
|
196
|
+
raise ValueError(
|
|
197
|
+
f"No tile store layer found for {item.name} {suffix} \
|
|
198
|
+
{projection}"
|
|
199
|
+
)
|
|
185
200
|
read_raster_window_from_tiles(
|
|
186
201
|
dst, ts_layer, bounds, src_indexes, dst_indexes, remapper
|
|
187
202
|
)
|
|
@@ -223,6 +238,8 @@ class VectorMaterializer(Materializer):
|
|
|
223
238
|
projection, bounds = layer_cfg.get_final_projection_and_bounds(
|
|
224
239
|
window.projection, window.bounds
|
|
225
240
|
)
|
|
241
|
+
if bounds is None:
|
|
242
|
+
raise ValueError(f"No bounds specified for {layer_name}")
|
|
226
243
|
vector_format = load_vector_format(layer_cfg.format)
|
|
227
244
|
|
|
228
245
|
out_layer_dirs: list[UPath] = []
|
|
@@ -241,6 +258,10 @@ class VectorMaterializer(Materializer):
|
|
|
241
258
|
ts_layer = get_tile_store_for_layer(
|
|
242
259
|
tile_store, layer_name, layer_cfg
|
|
243
260
|
).get_layer((item.name, str(projection)))
|
|
261
|
+
if ts_layer is None:
|
|
262
|
+
raise ValueError(
|
|
263
|
+
f"No tile store layer found for {item.name} {projection}"
|
|
264
|
+
)
|
|
244
265
|
cur_features = ts_layer.read_vector(bounds)
|
|
245
266
|
features.extend(cur_features)
|
|
246
267
|
|
rslearn/dataset/window.py
CHANGED
|
@@ -7,9 +7,12 @@ from typing import Any
|
|
|
7
7
|
import shapely
|
|
8
8
|
from upath import UPath
|
|
9
9
|
|
|
10
|
+
from rslearn.log_utils import get_logger
|
|
10
11
|
from rslearn.utils import Projection, STGeometry
|
|
11
12
|
from rslearn.utils.fsspec import open_atomic
|
|
12
13
|
|
|
14
|
+
logger = get_logger(__name__)
|
|
15
|
+
|
|
13
16
|
|
|
14
17
|
class WindowLayerData:
|
|
15
18
|
"""Layer data for retrieved layers specifying relevant items in the data source.
|
|
@@ -115,6 +118,7 @@ class Window:
|
|
|
115
118
|
"options": self.options,
|
|
116
119
|
}
|
|
117
120
|
metadata_path = self.path / "metadata.json"
|
|
121
|
+
logger.info(f"Saving window metadata to {metadata_path}")
|
|
118
122
|
with open_atomic(metadata_path, "w") as f:
|
|
119
123
|
json.dump(metadata, f)
|
|
120
124
|
|
|
@@ -141,6 +145,7 @@ class Window:
|
|
|
141
145
|
"""Save layer datas to items.json."""
|
|
142
146
|
json_data = [layer_data.serialize() for layer_data in layer_datas.values()]
|
|
143
147
|
items_fname = self.path / "items.json"
|
|
148
|
+
logger.info(f"Saving window items to {items_fname}")
|
|
144
149
|
with open_atomic(items_fname, "w") as f:
|
|
145
150
|
json.dump(json_data, f)
|
|
146
151
|
|
rslearn/log_utils.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""Logging utilities."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
import sys
|
|
6
|
+
|
|
7
|
+
# Layout of every console log line: timestamp, level, logger name, the calling
# function and its line number, then the message itself.
LOG_FORMAT = (
    "format=%(asctime)s "
    "loglevel=%(levelname)-6s "
    "logger=%(name)s "
    "%(funcName)s() L%(lineno)-4d "
    "%(message)s"
)
# Alternative, more verbose layout (currently unused):
# DETAILED_LOG_FORMAT = "format=%(asctime)s loglevel=%(levelname)-6s logger=%(name)s %(funcName)s() L%(lineno)-4d %(message)s call_trace=%(pathname)s L%(lineno)-4d" # noqa
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def get_logger(name: str) -> logging.Logger:
    """Get a logger that writes to stdout using LOG_FORMAT.

    The level comes from the RSLEARN_LOGLEVEL environment variable (default
    "INFO"). Level names are accepted case-insensitively ("debug" == "DEBUG")
    and surrounding whitespace is ignored; a purely numeric value such as "10"
    is treated as a numeric logging level.

    A stdout handler is attached only if the logger has no handlers yet, so
    calling this repeatedly for the same name does not stack handlers. The
    logger level is (re)applied on every call.

    Args:
        name: the logger name, typically the caller's ``__name__``.

    Returns:
        the configured ``logging.Logger``.

    Raises:
        ValueError: if RSLEARN_LOGLEVEL is not a recognized level name.
    """
    this_logger = logging.getLogger(name)

    # logging only recognizes upper-case level names ("INFO", not "info"), so
    # normalize the raw environment value before passing it to setLevel.
    raw_level = os.environ.get("RSLEARN_LOGLEVEL", "INFO").strip().upper()
    log_level = int(raw_level) if raw_level.isdigit() else raw_level

    if not this_logger.handlers:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(log_level)
        console_handler.setFormatter(logging.Formatter(LOG_FORMAT))
        this_logger.addHandler(console_handler)

    this_logger.setLevel(log_level)
    # NOTE(review): with propagate=True, records also reach ancestor handlers
    # (e.g. a root handler installed via logging.basicConfig), so a message may
    # be printed twice. Kept as-is for callers that rely on propagation.
    this_logger.propagate = True
    return this_logger
|