rslearn 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/__init__.py +2 -10
- rslearn/config/dataset.py +420 -420
- rslearn/data_sources/__init__.py +8 -31
- rslearn/data_sources/aws_landsat.py +13 -24
- rslearn/data_sources/aws_open_data.py +21 -46
- rslearn/data_sources/aws_sentinel1.py +3 -14
- rslearn/data_sources/climate_data_store.py +21 -40
- rslearn/data_sources/copernicus.py +30 -91
- rslearn/data_sources/data_source.py +26 -0
- rslearn/data_sources/earthdaily.py +13 -38
- rslearn/data_sources/earthdata_srtm.py +14 -32
- rslearn/data_sources/eurocrops.py +5 -9
- rslearn/data_sources/gcp_public_data.py +46 -43
- rslearn/data_sources/google_earth_engine.py +31 -44
- rslearn/data_sources/local_files.py +91 -100
- rslearn/data_sources/openstreetmap.py +21 -51
- rslearn/data_sources/planet.py +12 -30
- rslearn/data_sources/planet_basemap.py +4 -25
- rslearn/data_sources/planetary_computer.py +58 -141
- rslearn/data_sources/usda_cdl.py +15 -26
- rslearn/data_sources/usgs_landsat.py +4 -29
- rslearn/data_sources/utils.py +9 -0
- rslearn/data_sources/worldcereal.py +47 -54
- rslearn/data_sources/worldcover.py +16 -14
- rslearn/data_sources/worldpop.py +15 -18
- rslearn/data_sources/xyz_tiles.py +11 -30
- rslearn/dataset/dataset.py +6 -6
- rslearn/dataset/manage.py +14 -20
- rslearn/dataset/materialize.py +9 -45
- rslearn/lightning_cli.py +377 -1
- rslearn/main.py +3 -3
- rslearn/models/concatenate_features.py +93 -0
- rslearn/models/olmoearth_pretrain/model.py +2 -5
- rslearn/tile_stores/__init__.py +0 -11
- rslearn/train/dataset.py +4 -12
- rslearn/train/prediction_writer.py +16 -32
- rslearn/train/tasks/classification.py +2 -1
- rslearn/utils/fsspec.py +20 -0
- rslearn/utils/jsonargparse.py +79 -0
- rslearn/utils/raster_format.py +1 -41
- rslearn/utils/vector_format.py +1 -38
- {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/METADATA +58 -25
- {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/RECORD +48 -49
- rslearn/data_sources/geotiff.py +0 -1
- rslearn/data_sources/raster_source.py +0 -23
- {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/WHEEL +0 -0
- {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/top_level.txt +0 -0
|
@@ -12,10 +12,8 @@ from lightning.pytorch.callbacks import BasePredictionWriter
|
|
|
12
12
|
from upath import UPath
|
|
13
13
|
|
|
14
14
|
from rslearn.config import (
|
|
15
|
+
LayerConfig,
|
|
15
16
|
LayerType,
|
|
16
|
-
RasterFormatConfig,
|
|
17
|
-
RasterLayerConfig,
|
|
18
|
-
VectorLayerConfig,
|
|
19
17
|
)
|
|
20
18
|
from rslearn.dataset import Dataset, Window
|
|
21
19
|
from rslearn.log_utils import get_logger
|
|
@@ -25,9 +23,8 @@ from rslearn.utils.geometry import PixelBounds
|
|
|
25
23
|
from rslearn.utils.raster_format import (
|
|
26
24
|
RasterFormat,
|
|
27
25
|
adjust_projection_and_bounds_for_array,
|
|
28
|
-
load_raster_format,
|
|
29
26
|
)
|
|
30
|
-
from rslearn.utils.vector_format import VectorFormat
|
|
27
|
+
from rslearn.utils.vector_format import VectorFormat
|
|
31
28
|
|
|
32
29
|
from .lightning_module import RslearnLightningModule
|
|
33
30
|
from .tasks.task import Task
|
|
@@ -150,7 +147,7 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
150
147
|
selector: list[str] | None = None,
|
|
151
148
|
merger: PatchPredictionMerger | None = None,
|
|
152
149
|
output_path: str | Path | None = None,
|
|
153
|
-
layer_config:
|
|
150
|
+
layer_config: LayerConfig | None = None,
|
|
154
151
|
):
|
|
155
152
|
"""Create a new RslearnWriter.
|
|
156
153
|
|
|
@@ -178,43 +175,31 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
178
175
|
)
|
|
179
176
|
|
|
180
177
|
# Handle dataset and layer config
|
|
181
|
-
self.layer_config:
|
|
178
|
+
self.layer_config: LayerConfig
|
|
182
179
|
if layer_config:
|
|
183
180
|
self.layer_config = layer_config
|
|
184
|
-
self.dataset = None if self.output_path else Dataset(self.path)
|
|
185
181
|
else:
|
|
186
|
-
|
|
187
|
-
if self.output_layer not in
|
|
182
|
+
dataset = Dataset(self.path)
|
|
183
|
+
if self.output_layer not in dataset.layers:
|
|
188
184
|
raise KeyError(
|
|
189
185
|
f"Output layer '{self.output_layer}' not found in dataset layers."
|
|
190
186
|
)
|
|
191
|
-
|
|
192
|
-
# Type narrowing to ensure compatibility
|
|
193
|
-
if isinstance(raw_layer_config, (RasterLayerConfig | VectorLayerConfig)):
|
|
194
|
-
self.layer_config = raw_layer_config
|
|
195
|
-
else:
|
|
196
|
-
raise ValueError(
|
|
197
|
-
f"Layer config must be RasterLayerConfig or VectorLayerConfig, got {type(raw_layer_config)}"
|
|
198
|
-
)
|
|
187
|
+
self.layer_config = dataset.layers[self.output_layer]
|
|
199
188
|
|
|
200
189
|
self.format: RasterFormat | VectorFormat
|
|
201
|
-
if self.layer_config.
|
|
202
|
-
assert isinstance(self.layer_config, RasterLayerConfig)
|
|
190
|
+
if self.layer_config.type == LayerType.RASTER:
|
|
203
191
|
band_cfg = self.layer_config.band_sets[0]
|
|
204
|
-
self.format =
|
|
205
|
-
|
|
206
|
-
)
|
|
207
|
-
elif self.layer_config.layer_type == LayerType.VECTOR:
|
|
208
|
-
assert isinstance(self.layer_config, VectorLayerConfig)
|
|
209
|
-
self.format = load_vector_format(self.layer_config.format)
|
|
192
|
+
self.format = band_cfg.instantiate_raster_format()
|
|
193
|
+
elif self.layer_config.type == LayerType.VECTOR:
|
|
194
|
+
self.format = self.layer_config.instantiate_vector_format()
|
|
210
195
|
else:
|
|
211
|
-
raise ValueError(f"invalid layer type {self.layer_config.
|
|
196
|
+
raise ValueError(f"invalid layer type {self.layer_config.type}")
|
|
212
197
|
|
|
213
198
|
if merger is not None:
|
|
214
199
|
self.merger = merger
|
|
215
|
-
elif self.layer_config.
|
|
200
|
+
elif self.layer_config.type == LayerType.RASTER:
|
|
216
201
|
self.merger = RasterMerger()
|
|
217
|
-
elif self.layer_config.
|
|
202
|
+
elif self.layer_config.type == LayerType.VECTOR:
|
|
218
203
|
self.merger = VectorMerger()
|
|
219
204
|
|
|
220
205
|
# Map from window name to pending data to write.
|
|
@@ -337,8 +322,7 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
337
322
|
logger.debug(f"Merging and writing for window {window.name}")
|
|
338
323
|
merged_output = self.merger.merge(window, pending_output)
|
|
339
324
|
|
|
340
|
-
if self.layer_config.
|
|
341
|
-
assert isinstance(self.layer_config, RasterLayerConfig)
|
|
325
|
+
if self.layer_config.type == LayerType.RASTER:
|
|
342
326
|
raster_dir = window.get_raster_dir(
|
|
343
327
|
self.output_layer, self.layer_config.band_sets[0].bands
|
|
344
328
|
)
|
|
@@ -351,7 +335,7 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
351
335
|
)
|
|
352
336
|
self.format.encode_raster(raster_dir, projection, bounds, merged_output)
|
|
353
337
|
|
|
354
|
-
elif self.layer_config.
|
|
338
|
+
elif self.layer_config.type == LayerType.VECTOR:
|
|
355
339
|
layer_dir = window.get_layer_dir(self.output_layer)
|
|
356
340
|
assert isinstance(self.format, VectorFormat)
|
|
357
341
|
self.format.encode_vector(layer_dir, merged_output)
|
|
@@ -26,7 +26,7 @@ class ClassificationTask(BasicTask):
|
|
|
26
26
|
def __init__(
|
|
27
27
|
self,
|
|
28
28
|
property_name: str,
|
|
29
|
-
classes: list,
|
|
29
|
+
classes: list[str],
|
|
30
30
|
filters: list[tuple[str, str]] = [],
|
|
31
31
|
read_class_id: bool = False,
|
|
32
32
|
allow_invalid: bool = False,
|
|
@@ -176,6 +176,7 @@ class ClassificationTask(BasicTask):
|
|
|
176
176
|
# For multiclass classification or when using the default threshold
|
|
177
177
|
class_idx = probs.argmax().item()
|
|
178
178
|
|
|
179
|
+
value: str | int
|
|
179
180
|
if not self.read_class_id:
|
|
180
181
|
value = self.classes[class_idx] # type: ignore
|
|
181
182
|
else:
|
rslearn/utils/fsspec.py
CHANGED
|
@@ -156,3 +156,23 @@ def open_rasterio_upath_writer(
|
|
|
156
156
|
with path.open("wb") as f:
|
|
157
157
|
with rasterio.open(f, "w", **kwargs) as raster:
|
|
158
158
|
yield raster
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def get_relative_suffix(base_dir: UPath, fname: UPath) -> str:
    """Get the suffix of fname relative to base_dir.

    The comparison is performed on the ``path`` attribute of both arguments and
    only matches whole path components, so a sibling such as ``/data/ds2/x`` is
    not treated as being inside ``/data/ds``.

    Args:
        base_dir: the base directory.
        fname: a filename within the base directory.

    Returns:
        the suffix on base_dir that would yield the given filename, without a
            leading slash. This is the empty string when fname equals base_dir.

    Raises:
        ValueError: if fname is not located under base_dir.
    """
    base_path = base_dir.path
    file_path = fname.path
    if not file_path.startswith(base_path):
        raise ValueError(
            f"filename {file_path} must start with base directory {base_path}"
        )

    suffix = file_path[len(base_path) :]

    # A plain string prefix match also accepts paths that merely share leading
    # characters with the base directory (base "/a/b", file "/a/bc/d"), which
    # would silently yield a bogus suffix. Require the match to end on a path
    # separator unless the base is empty or the two paths are identical.
    ends_on_boundary = (
        not base_path
        or base_path.endswith("/")
        or not suffix
        or suffix.startswith("/")
    )
    if not ends_on_boundary:
        raise ValueError(
            f"filename {file_path} must start with base directory {base_path}"
        )

    if suffix.startswith("/"):
        suffix = suffix[1:]
    return suffix
|
rslearn/utils/jsonargparse.py
CHANGED
|
@@ -1,7 +1,18 @@
|
|
|
1
1
|
"""Custom serialization for jsonargparse."""
|
|
2
2
|
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from typing import TYPE_CHECKING, Any
|
|
5
|
+
|
|
3
6
|
import jsonargparse
|
|
4
7
|
from rasterio.crs import CRS
|
|
8
|
+
from upath import UPath
|
|
9
|
+
|
|
10
|
+
from rslearn.config.dataset import LayerConfig
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from rslearn.data_sources.data_source import DataSourceContext
|
|
14
|
+
|
|
15
|
+
INITIALIZED = False
|
|
5
16
|
|
|
6
17
|
|
|
7
18
|
def crs_serializer(v: CRS) -> str:
|
|
@@ -28,6 +39,74 @@ def crs_deserializer(v: str) -> CRS:
|
|
|
28
39
|
return CRS.from_string(v)
|
|
29
40
|
|
|
30
41
|
|
|
42
|
+
def datetime_serializer(v: datetime) -> str:
    """Encode a datetime as a string for jsonargparse.

    Args:
        v: the datetime to encode.

    Returns:
        the value of ``v.isoformat()``.
    """
    encoded = v.isoformat()
    return encoded
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def datetime_deserializer(v: str) -> datetime:
    """Decode a datetime from its string form for jsonargparse.

    Args:
        v: the encoded datetime, in a format accepted by
            ``datetime.fromisoformat``.

    Returns:
        the decoded datetime object.
    """
    decoded = datetime.fromisoformat(v)
    return decoded
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def data_source_context_serializer(v: "DataSourceContext") -> dict[str, Any]:
    """Serialize DataSourceContext for jsonargparse.

    Args:
        v: the DataSourceContext to encode.

    Returns:
        a dict with two keys: "ds_path" holds ``str(v.ds_path)`` and
            "layer_config" holds ``v.layer_config.model_dump(mode="json")``.
            Either value is None when the corresponding attribute is None.
    """
    ds_path = None
    if v.ds_path is not None:
        ds_path = str(v.ds_path)

    layer_config = None
    if v.layer_config is not None:
        layer_config = v.layer_config.model_dump(mode="json")

    return {"ds_path": ds_path, "layer_config": layer_config}
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def data_source_context_deserializer(v: dict[str, Any]) -> "DataSourceContext":
    """Deserialize DataSourceContext for jsonargparse.

    Args:
        v: a dict with "ds_path" and "layer_config" keys. Each value may be
            None.

    Returns:
        a DataSourceContext whose ds_path is ``UPath(v["ds_path"])`` and whose
            layer_config is ``LayerConfig.model_validate(v["layer_config"])``,
            with None passed through for either field.
    """
    # We lazily import these to avoid cyclic dependency.
    from rslearn.data_sources.data_source import DataSourceContext

    ds_path = None
    if v["ds_path"] is not None:
        ds_path = UPath(v["ds_path"])

    layer_config = None
    if v["layer_config"] is not None:
        layer_config = LayerConfig.model_validate(v["layer_config"])

    return DataSourceContext(ds_path=ds_path, layer_config=layer_config)
|
|
92
|
+
|
|
93
|
+
|
|
31
94
|
def init_jsonargparse() -> None:
    """Initialize custom jsonargparse serializers.

    Registers serializer/deserializer pairs for CRS, datetime, and
    DataSourceContext. The module-level INITIALIZED flag is set once all
    registrations have completed, and later calls return immediately.
    """
    global INITIALIZED
    if INITIALIZED:
        return

    builtin_registrations = (
        (CRS, crs_serializer, crs_deserializer),
        (datetime, datetime_serializer, datetime_deserializer),
    )
    for type_class, serializer, deserializer in builtin_registrations:
        jsonargparse.typing.register_type(type_class, serializer, deserializer)

    # Imported inside the function instead of at module level; presumably this
    # avoids an import cycle with rslearn.data_sources -- TODO confirm.
    from rslearn.data_sources.data_source import DataSourceContext

    jsonargparse.typing.register_type(
        DataSourceContext,
        data_source_context_serializer,
        data_source_context_deserializer,
    )

    INITIALIZED = True
|
rslearn/utils/raster_format.py
CHANGED
|
@@ -2,8 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
import hashlib
|
|
4
4
|
import json
|
|
5
|
-
from
|
|
6
|
-
from typing import Any, BinaryIO, TypeVar
|
|
5
|
+
from typing import Any, BinaryIO
|
|
7
6
|
|
|
8
7
|
import affine
|
|
9
8
|
import numpy as np
|
|
@@ -14,34 +13,12 @@ from rasterio.crs import CRS
|
|
|
14
13
|
from rasterio.enums import Resampling
|
|
15
14
|
from upath import UPath
|
|
16
15
|
|
|
17
|
-
from rslearn.config import RasterFormatConfig
|
|
18
16
|
from rslearn.const import TILE_SIZE
|
|
19
17
|
from rslearn.log_utils import get_logger
|
|
20
18
|
from rslearn.utils.fsspec import open_rasterio_upath_reader, open_rasterio_upath_writer
|
|
21
19
|
|
|
22
20
|
from .geometry import PixelBounds, Projection
|
|
23
21
|
|
|
24
|
-
_RasterFormatT = TypeVar("_RasterFormatT", bound="RasterFormat")
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
class _RasterFormatRegistry(dict[str, type["RasterFormat"]]):
|
|
28
|
-
"""Registry for RasterFormat classes."""
|
|
29
|
-
|
|
30
|
-
def register(
|
|
31
|
-
self, name: str
|
|
32
|
-
) -> Callable[[type[_RasterFormatT]], type[_RasterFormatT]]:
|
|
33
|
-
"""Decorator to register a raster format class."""
|
|
34
|
-
|
|
35
|
-
def decorator(cls: type[_RasterFormatT]) -> type[_RasterFormatT]:
|
|
36
|
-
self[name] = cls
|
|
37
|
-
return cls
|
|
38
|
-
|
|
39
|
-
return decorator
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
RasterFormats = _RasterFormatRegistry()
|
|
43
|
-
|
|
44
|
-
|
|
45
22
|
logger = get_logger(__name__)
|
|
46
23
|
|
|
47
24
|
|
|
@@ -219,7 +196,6 @@ class RasterFormat:
|
|
|
219
196
|
raise NotImplementedError
|
|
220
197
|
|
|
221
198
|
|
|
222
|
-
@RasterFormats.register("image_tile")
|
|
223
199
|
class ImageTileRasterFormat(RasterFormat):
|
|
224
200
|
"""A RasterFormat that stores data in image tiles corresponding to grid cells.
|
|
225
201
|
|
|
@@ -468,7 +444,6 @@ class ImageTileRasterFormat(RasterFormat):
|
|
|
468
444
|
)
|
|
469
445
|
|
|
470
446
|
|
|
471
|
-
@RasterFormats.register("geotiff")
|
|
472
447
|
class GeotiffRasterFormat(RasterFormat):
|
|
473
448
|
"""A raster format that uses one big, tiled GeoTIFF with small block size."""
|
|
474
449
|
|
|
@@ -623,7 +598,6 @@ class GeotiffRasterFormat(RasterFormat):
|
|
|
623
598
|
return GeotiffRasterFormat(**kwargs)
|
|
624
599
|
|
|
625
600
|
|
|
626
|
-
@RasterFormats.register("single_image")
|
|
627
601
|
class SingleImageRasterFormat(RasterFormat):
|
|
628
602
|
"""A raster format that produces a single image called image.png/jpg.
|
|
629
603
|
|
|
@@ -775,17 +749,3 @@ class SingleImageRasterFormat(RasterFormat):
|
|
|
775
749
|
if "format" in config:
|
|
776
750
|
kwargs["format"] = config["format"]
|
|
777
751
|
return SingleImageRasterFormat(**kwargs)
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
def load_raster_format(config: RasterFormatConfig) -> RasterFormat:
|
|
781
|
-
"""Loads a RasterFormat from a RasterFormatConfig.
|
|
782
|
-
|
|
783
|
-
Args:
|
|
784
|
-
config: the RasterFormatConfig configuration object specifying the
|
|
785
|
-
RasterFormat.
|
|
786
|
-
|
|
787
|
-
Returns:
|
|
788
|
-
the loaded RasterFormat implementation
|
|
789
|
-
"""
|
|
790
|
-
cls = RasterFormats[config.name]
|
|
791
|
-
return cls.from_config(config.name, config.config_dict)
|
rslearn/utils/vector_format.py
CHANGED
|
@@ -1,15 +1,13 @@
|
|
|
1
1
|
"""Classes for writing vector data to a UPath."""
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
-
from collections.abc import Callable
|
|
5
4
|
from enum import Enum
|
|
6
|
-
from typing import Any
|
|
5
|
+
from typing import Any
|
|
7
6
|
|
|
8
7
|
import shapely
|
|
9
8
|
from rasterio.crs import CRS
|
|
10
9
|
from upath import UPath
|
|
11
10
|
|
|
12
|
-
from rslearn.config import VectorFormatConfig
|
|
13
11
|
from rslearn.const import WGS84_PROJECTION
|
|
14
12
|
from rslearn.log_utils import get_logger
|
|
15
13
|
from rslearn.utils.fsspec import open_atomic
|
|
@@ -18,25 +16,6 @@ from .feature import Feature
|
|
|
18
16
|
from .geometry import PixelBounds, Projection, STGeometry, safely_reproject_and_clip
|
|
19
17
|
|
|
20
18
|
logger = get_logger(__name__)
|
|
21
|
-
_VectorFormatT = TypeVar("_VectorFormatT", bound="VectorFormat")
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
class _VectorFormatRegistry(dict[str, type["VectorFormat"]]):
|
|
25
|
-
"""Registry for VectorFormat classes."""
|
|
26
|
-
|
|
27
|
-
def register(
|
|
28
|
-
self, name: str
|
|
29
|
-
) -> Callable[[type[_VectorFormatT]], type[_VectorFormatT]]:
|
|
30
|
-
"""Decorator to register a vector format class."""
|
|
31
|
-
|
|
32
|
-
def decorator(cls: type[_VectorFormatT]) -> type[_VectorFormatT]:
|
|
33
|
-
self[name] = cls
|
|
34
|
-
return cls
|
|
35
|
-
|
|
36
|
-
return decorator
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
VectorFormats = _VectorFormatRegistry()
|
|
40
19
|
|
|
41
20
|
|
|
42
21
|
class VectorFormat:
|
|
@@ -85,7 +64,6 @@ class VectorFormat:
|
|
|
85
64
|
raise NotImplementedError
|
|
86
65
|
|
|
87
66
|
|
|
88
|
-
@VectorFormats.register("tile")
|
|
89
67
|
class TileVectorFormat(VectorFormat):
|
|
90
68
|
"""TileVectorFormat stores data in GeoJSON files corresponding to grid cells.
|
|
91
69
|
|
|
@@ -275,7 +253,6 @@ class GeojsonCoordinateMode(Enum):
|
|
|
275
253
|
WGS84 = "wgs84"
|
|
276
254
|
|
|
277
255
|
|
|
278
|
-
@VectorFormats.register("geojson")
|
|
279
256
|
class GeojsonVectorFormat(VectorFormat):
|
|
280
257
|
"""A vector format that uses one big GeoJSON."""
|
|
281
258
|
|
|
@@ -429,17 +406,3 @@ class GeojsonVectorFormat(VectorFormat):
|
|
|
429
406
|
if "coordinate_mode" in config:
|
|
430
407
|
kwargs["coordinate_mode"] = GeojsonCoordinateMode(config["coordinate_mode"])
|
|
431
408
|
return GeojsonVectorFormat(**kwargs)
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
def load_vector_format(config: VectorFormatConfig) -> VectorFormat:
|
|
435
|
-
"""Loads a VectorFormat from a VectorFormatConfig.
|
|
436
|
-
|
|
437
|
-
Args:
|
|
438
|
-
config: the VectorFormatConfig configuration object specifying the
|
|
439
|
-
VectorFormat.
|
|
440
|
-
|
|
441
|
-
Returns:
|
|
442
|
-
the loaded VectorFormat implementation
|
|
443
|
-
"""
|
|
444
|
-
cls = VectorFormats[config.name]
|
|
445
|
-
return cls.from_config(config.name, config.config_dict)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: rslearn
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.17
|
|
4
4
|
Summary: A library for developing remote sensing datasets and models
|
|
5
5
|
Author: OlmoEarth Team
|
|
6
6
|
License: Apache License
|
|
@@ -343,10 +343,12 @@ directory `/path/to/dataset` and corresponding configuration file at
|
|
|
343
343
|
"bands": ["R", "G", "B"]
|
|
344
344
|
}],
|
|
345
345
|
"data_source": {
|
|
346
|
-
"
|
|
347
|
-
"
|
|
348
|
-
|
|
349
|
-
|
|
346
|
+
"class_path": "rslearn.data_sources.gcp_public_data.Sentinel2",
|
|
347
|
+
"init_args": {
|
|
348
|
+
"index_cache_dir": "cache/sentinel2/",
|
|
349
|
+
"sort_by": "cloud_cover",
|
|
350
|
+
"use_rtree_index": false
|
|
351
|
+
}
|
|
350
352
|
}
|
|
351
353
|
}
|
|
352
354
|
}
|
|
@@ -453,8 +455,10 @@ automate this process. Update the dataset `config.json` with a new layer:
|
|
|
453
455
|
}],
|
|
454
456
|
"resampling_method": "nearest",
|
|
455
457
|
"data_source": {
|
|
456
|
-
"
|
|
457
|
-
"
|
|
458
|
+
"class_path": "rslearn.data_sources.local_files.LocalFiles",
|
|
459
|
+
"init_args": {
|
|
460
|
+
"src_dir": "file:///path/to/world_cover_tifs/"
|
|
461
|
+
}
|
|
458
462
|
}
|
|
459
463
|
}
|
|
460
464
|
},
|
|
@@ -516,8 +520,7 @@ model:
|
|
|
516
520
|
data:
|
|
517
521
|
class_path: rslearn.train.data_module.RslearnDataModule
|
|
518
522
|
init_args:
|
|
519
|
-
|
|
520
|
-
path: /path/to/dataset/
|
|
523
|
+
path: ${DATASET_PATH}
|
|
521
524
|
# This defines the layers that should be read for each window.
|
|
522
525
|
# The key ("image" / "targets") is what the data will be called in the model,
|
|
523
526
|
# while the layers option specifies which layers will be read.
|
|
@@ -615,7 +618,9 @@ trainer:
|
|
|
615
618
|
...
|
|
616
619
|
- class_path: rslearn.train.prediction_writer.RslearnWriter
|
|
617
620
|
init_args:
|
|
618
|
-
|
|
621
|
+
# We need to include this argument, but it will be overridden with the dataset
|
|
622
|
+
# path from data.init_args.path.
|
|
623
|
+
path: placeholder
|
|
619
624
|
output_layer: output
|
|
620
625
|
```
|
|
621
626
|
|
|
@@ -768,24 +773,43 @@ This will produce PNGs in the vis directory. The visualizations are produced by
|
|
|
768
773
|
SegmentationTask and overriding the visualize function.
|
|
769
774
|
|
|
770
775
|
|
|
771
|
-
###
|
|
776
|
+
### Checkpoint and Logging Management
|
|
777
|
+
|
|
778
|
+
Above, we needed to configure the checkpoint directory in the model config (the
|
|
779
|
+
`dirpath` option under `lightning.pytorch.callbacks.ModelCheckpoint`), and explicitly
|
|
780
|
+
specify the checkpoint path when applying the model. Additionally, metrics are logged
|
|
781
|
+
to the local filesystem and are not well organized.
|
|
772
782
|
|
|
773
|
-
We can
|
|
783
|
+
We can instead let rslearn automatically manage checkpoints, along with logging to
|
|
784
|
+
Weights & Biases. To do so, we add project_name, run_name, and management_dir options
|
|
785
|
+
to the model config. The project_name corresponds to the W&B project, and the run name
|
|
786
|
+
corresponds to the W&B name. The management_dir is a directory to store project data;
|
|
787
|
+
rslearn determines a per-project directory at `{management_dir}/{project_name}/{run_name}/`
|
|
788
|
+
and uses it to store checkpoints.
|
|
774
789
|
|
|
775
790
|
```yaml
|
|
791
|
+
model:
|
|
792
|
+
# ...
|
|
793
|
+
data:
|
|
794
|
+
# ...
|
|
776
795
|
trainer:
|
|
777
796
|
# ...
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
name: version_00
|
|
797
|
+
project_name: land_cover_model
|
|
798
|
+
run_name: version_00
|
|
799
|
+
# This sets the option via the MANAGEMENT_DIR environment variable.
|
|
800
|
+
management_dir: ${MANAGEMENT_DIR}
|
|
783
801
|
```
|
|
784
802
|
|
|
785
|
-
Now,
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
803
|
+
Now, set the `MANAGEMENT_DIR` environment variable and run `model fit`:
|
|
804
|
+
|
|
805
|
+
```
|
|
806
|
+
export MANAGEMENT_DIR=./project_data
|
|
807
|
+
rslearn model fit --config land_cover_model.yaml
|
|
808
|
+
```
|
|
809
|
+
|
|
810
|
+
The training and validation loss and accuracy metric should now be logged to W&B. The
|
|
811
|
+
accuracy metric is provided by SegmentationTask, and additional metrics can be enabled
|
|
812
|
+
by passing the relevant init_args to the task, e.g. mean IoU and F1:
|
|
789
813
|
|
|
790
814
|
```yaml
|
|
791
815
|
class_path: rslearn.train.tasks.segmentation.SegmentationTask
|
|
@@ -796,6 +820,13 @@ passing the relevant init_args to the task, e.g. mean IoU and F1:
|
|
|
796
820
|
enable_f1_metric: true
|
|
797
821
|
```
|
|
798
822
|
|
|
823
|
+
When calling `model test` and `model predict` with management_dir set, rslearn will
|
|
824
|
+
automatically load the best checkpoint from the project directory, or raise an error if
|
|
825
|
+
no checkpoint exists. This behavior can be overridden with the
|
|
826
|
+
`--load_checkpoint_mode` and `--load_checkpoint_required` options (see `--help` for
|
|
827
|
+
details). Logging will be enabled during fit but not test/predict, and this can also
|
|
828
|
+
be overridden, using `--log_mode`.
|
|
829
|
+
|
|
799
830
|
|
|
800
831
|
### Inputting Multiple Sentinel-2 Images
|
|
801
832
|
|
|
@@ -818,10 +849,12 @@ query_config section. This can replace the sentinel2 layer:
|
|
|
818
849
|
"bands": ["R", "G", "B"]
|
|
819
850
|
}],
|
|
820
851
|
"data_source": {
|
|
821
|
-
"
|
|
822
|
-
"
|
|
823
|
-
|
|
824
|
-
|
|
852
|
+
"class_path": "rslearn.data_sources.gcp_public_data.Sentinel2",
|
|
853
|
+
"init_args": {
|
|
854
|
+
"index_cache_dir": "cache/sentinel2/",
|
|
855
|
+
"sort_by": "cloud_cover",
|
|
856
|
+
"use_rtree_index": false
|
|
857
|
+
},
|
|
825
858
|
"query_config": {
|
|
826
859
|
"max_matches": 3
|
|
827
860
|
}
|