rslearn 0.0.12__tar.gz → 0.0.14__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rslearn-0.0.12/rslearn.egg-info → rslearn-0.0.14}/PKG-INFO +2 -2
- {rslearn-0.0.12 → rslearn-0.0.14}/pyproject.toml +2 -2
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/config/dataset.py +23 -14
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/data_sources/planetary_computer.py +52 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/dataset/handler_summaries.py +1 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/dataset/manage.py +16 -2
- rslearn-0.0.14/rslearn/lightning_cli.py +67 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/main.py +8 -62
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/olmoearth_pretrain/model.py +1 -0
- rslearn-0.0.14/rslearn/train/all_patches_dataset.py +458 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/data_module.py +4 -2
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/dataset.py +10 -446
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/prediction_writer.py +25 -8
- rslearn-0.0.14/rslearn/train/tasks/embedding.py +116 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/utils/array.py +6 -4
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/utils/raster_format.py +38 -0
- {rslearn-0.0.12 → rslearn-0.0.14/rslearn.egg-info}/PKG-INFO +2 -2
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn.egg-info/SOURCES.txt +3 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn.egg-info/requires.txt +1 -1
- {rslearn-0.0.12 → rslearn-0.0.14}/LICENSE +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/NOTICE +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/README.md +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/__init__.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/arg_parser.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/config/__init__.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/const.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/data_sources/__init__.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/data_sources/aws_landsat.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/data_sources/aws_open_data.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/data_sources/aws_sentinel1.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/data_sources/climate_data_store.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/data_sources/copernicus.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/data_sources/data_source.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/data_sources/earthdaily.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/data_sources/earthdata_srtm.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/data_sources/eurocrops.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/data_sources/gcp_public_data.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/data_sources/geotiff.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/data_sources/google_earth_engine.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/data_sources/local_files.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/data_sources/openstreetmap.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/data_sources/planet.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/data_sources/planet_basemap.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/data_sources/raster_source.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/data_sources/usda_cdl.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/data_sources/usgs_landsat.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/data_sources/utils.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/data_sources/vector_source.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/data_sources/worldcereal.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/data_sources/worldcover.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/data_sources/worldpop.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/data_sources/xyz_tiles.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/dataset/__init__.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/dataset/add_windows.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/dataset/dataset.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/dataset/index.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/dataset/materialize.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/dataset/remap.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/dataset/window.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/log_utils.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/__init__.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/anysat.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/clay/clay.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/clay/configs/metadata.yaml +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/clip.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/conv.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/croma.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/detr/__init__.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/detr/box_ops.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/detr/detr.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/detr/matcher.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/detr/position_encoding.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/detr/transformer.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/detr/util.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/dinov3.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/faster_rcnn.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/feature_center_crop.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/fpn.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/galileo/__init__.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/galileo/galileo.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/galileo/single_file_galileo.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/module_wrapper.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/molmo.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/multitask.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/olmoearth_pretrain/__init__.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/olmoearth_pretrain/norm.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/panopticon.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/panopticon_data/sensors/drone.yaml +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/panopticon_data/sensors/enmap.yaml +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/panopticon_data/sensors/goes.yaml +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/panopticon_data/sensors/himawari.yaml +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/panopticon_data/sensors/intuition.yaml +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/panopticon_data/sensors/landsat8.yaml +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/panopticon_data/sensors/modis_terra.yaml +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/panopticon_data/sensors/sentinel1.yaml +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/panopticon_data/sensors/sentinel2.yaml +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/panopticon_data/sensors/superdove.yaml +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/panopticon_data/sensors/wv23.yaml +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/pick_features.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/pooling_decoder.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/presto/__init__.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/presto/presto.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/presto/single_file_presto.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/prithvi.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/registry.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/resize_features.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/sam2_enc.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/satlaspretrain.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/simple_time_series.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/singletask.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/ssl4eo_s12.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/swin.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/task_embedding.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/terramind.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/trunk.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/unet.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/upsample.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/models/use_croma.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/py.typed +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/template_params.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/tile_stores/__init__.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/tile_stores/default.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/tile_stores/tile_store.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/__init__.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/callbacks/__init__.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/callbacks/adapters.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/callbacks/freeze_unfreeze.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/callbacks/gradients.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/callbacks/peft.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/lightning_module.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/optimizer.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/scheduler.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/tasks/__init__.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/tasks/classification.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/tasks/detection.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/tasks/multi_task.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/tasks/per_pixel_regression.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/tasks/regression.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/tasks/segmentation.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/tasks/task.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/transforms/__init__.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/transforms/concatenate.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/transforms/crop.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/transforms/flip.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/transforms/mask.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/transforms/normalize.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/transforms/pad.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/transforms/select_bands.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/transforms/sentinel1.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/train/transforms/transform.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/utils/__init__.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/utils/feature.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/utils/fsspec.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/utils/geometry.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/utils/get_utm_ups_crs.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/utils/grid_index.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/utils/jsonargparse.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/utils/mp.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/utils/rtree_index.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/utils/spatial_index.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/utils/sqlite_index.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/utils/time.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn/utils/vector_format.py +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn.egg-info/dependency_links.txt +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn.egg-info/entry_points.txt +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/rslearn.egg-info/top_level.txt +0 -0
- {rslearn-0.0.12 → rslearn-0.0.14}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: rslearn
|
|
3
|
-
Version: 0.0.12
|
|
3
|
+
Version: 0.0.14
|
|
4
4
|
Summary: A library for developing remote sensing datasets and models
|
|
5
5
|
Author: OlmoEarth Team
|
|
6
6
|
License: Apache License
|
|
@@ -214,7 +214,7 @@ License-File: LICENSE
|
|
|
214
214
|
License-File: NOTICE
|
|
215
215
|
Requires-Dist: boto3>=1.39
|
|
216
216
|
Requires-Dist: fiona>=1.10
|
|
217
|
-
Requires-Dist: fsspec>=2025.
|
|
217
|
+
Requires-Dist: fsspec>=2025.10.0
|
|
218
218
|
Requires-Dist: jsonargparse>=4.35.0
|
|
219
219
|
Requires-Dist: lightning>=2.5.1.post0
|
|
220
220
|
Requires-Dist: Pillow>=11.3
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "rslearn"
|
|
3
|
-
version = "0.0.12"
|
|
3
|
+
version = "0.0.14"
|
|
4
4
|
description = "A library for developing remote sensing datasets and models"
|
|
5
5
|
authors = [
|
|
6
6
|
{ name = "OlmoEarth Team" },
|
|
@@ -11,7 +11,7 @@ requires-python = ">=3.11"
|
|
|
11
11
|
dependencies = [
|
|
12
12
|
"boto3>=1.39",
|
|
13
13
|
"fiona>=1.10",
|
|
14
|
-
"fsspec>=2025.
|
|
14
|
+
"fsspec>=2025.10.0", # this is used both directly and indirectly (via universal_pathlib) in our code
|
|
15
15
|
"jsonargparse>=4.35.0",
|
|
16
16
|
"lightning>=2.5.1.post0",
|
|
17
17
|
"Pillow>=11.3",
|
|
@@ -8,7 +8,6 @@ from typing import Any
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
import numpy.typing as npt
|
|
10
10
|
import pytimeparse
|
|
11
|
-
import torch
|
|
12
11
|
from rasterio.enums import Resampling
|
|
13
12
|
|
|
14
13
|
from rslearn.utils import PixelBounds, Projection
|
|
@@ -49,15 +48,6 @@ class DType(Enum):
|
|
|
49
48
|
return np.float32
|
|
50
49
|
raise ValueError(f"unable to handle numpy dtype {self}")
|
|
51
50
|
|
|
52
|
-
def get_torch_dtype(self) -> torch.dtype:
|
|
53
|
-
"""Returns pytorch dtype object corresponding to this DType."""
|
|
54
|
-
if self == DType.INT32:
|
|
55
|
-
return torch.int32
|
|
56
|
-
elif self == DType.FLOAT32:
|
|
57
|
-
return torch.float32
|
|
58
|
-
else:
|
|
59
|
-
raise ValueError(f"unable to handle torch dtype {self}")
|
|
60
|
-
|
|
61
51
|
|
|
62
52
|
RESAMPLING_METHODS = {
|
|
63
53
|
"nearest": Resampling.nearest,
|
|
@@ -125,7 +115,8 @@ class BandSetConfig:
|
|
|
125
115
|
self,
|
|
126
116
|
config_dict: dict[str, Any],
|
|
127
117
|
dtype: DType,
|
|
128
|
-
bands: list[str],
|
|
118
|
+
bands: list[str] | None = None,
|
|
119
|
+
num_bands: int | None = None,
|
|
129
120
|
format: dict[str, Any] | None = None,
|
|
130
121
|
zoom_offset: int = 0,
|
|
131
122
|
remap: dict[str, Any] | None = None,
|
|
@@ -137,7 +128,10 @@ class BandSetConfig:
|
|
|
137
128
|
Args:
|
|
138
129
|
config_dict: the config dict used to configure this BandSetConfig
|
|
139
130
|
dtype: the pixel value type to store tiles in
|
|
140
|
-
bands: list of band names in this BandSetConfig
|
|
131
|
+
bands: list of band names in this BandSetConfig. One of bands or num_bands
|
|
132
|
+
must be set.
|
|
133
|
+
num_bands: the number of bands in this band set. The bands will be named
|
|
134
|
+
B0, B1, B2, etc.
|
|
141
135
|
format: the format to store tiles in, defaults to geotiff
|
|
142
136
|
zoom_offset: store images at a resolution higher or lower than the window
|
|
143
137
|
resolution. This enables keeping source data at its native resolution,
|
|
@@ -155,6 +149,14 @@ class BandSetConfig:
|
|
|
155
149
|
materialization when creating mosaics, to determine which parts of the
|
|
156
150
|
source images should be copied.
|
|
157
151
|
"""
|
|
152
|
+
if (bands is None and num_bands is None) or (
|
|
153
|
+
bands is not None and num_bands is not None
|
|
154
|
+
):
|
|
155
|
+
raise ValueError("exactly one of bands and num_bands must be set")
|
|
156
|
+
if bands is None:
|
|
157
|
+
assert num_bands is not None
|
|
158
|
+
bands = [f"B{idx}" for idx in range(num_bands)]
|
|
159
|
+
|
|
158
160
|
if class_names is not None and len(bands) != len(class_names):
|
|
159
161
|
raise ValueError(
|
|
160
162
|
f"the number of class lists ({len(class_names)}) does not match the number of bands ({len(bands)})"
|
|
@@ -187,9 +189,16 @@ class BandSetConfig:
|
|
|
187
189
|
kwargs = dict(
|
|
188
190
|
config_dict=config,
|
|
189
191
|
dtype=DType(config["dtype"]),
|
|
190
|
-
bands=config["bands"],
|
|
191
192
|
)
|
|
192
|
-
for k in [
|
|
193
|
+
for k in [
|
|
194
|
+
"bands",
|
|
195
|
+
"num_bands",
|
|
196
|
+
"format",
|
|
197
|
+
"zoom_offset",
|
|
198
|
+
"remap",
|
|
199
|
+
"class_names",
|
|
200
|
+
"nodata_vals",
|
|
201
|
+
]:
|
|
193
202
|
if k in config:
|
|
194
203
|
kwargs[k] = config[k]
|
|
195
204
|
return BandSetConfig(**kwargs) # type: ignore
|
|
@@ -827,3 +827,55 @@ class Sentinel1(PlanetaryComputer):
|
|
|
827
827
|
kwargs[k] = d[k]
|
|
828
828
|
|
|
829
829
|
return Sentinel1(**kwargs)
|
|
830
|
+
|
|
831
|
+
|
|
832
|
+
class Naip(PlanetaryComputer):
    """A data source for NAIP data on Microsoft Planetary Computer.

    See https://planetarycomputer.microsoft.com/dataset/naip.

    ASSET_BANDS maps the "image" asset to the bands R, G, B and NIR.
    """

    COLLECTION_NAME = "naip"
    ASSET_BANDS = {"image": ["R", "G", "B", "NIR"]}

    def __init__(
        self,
        **kwargs: Any,
    ):
        """Initialize a new Naip instance.

        The collection name and asset-to-band mapping are fixed to the NAIP values
        above. Passing collection_name or asset_bands in kwargs raises TypeError
        (duplicate keyword argument).

        Args:
            kwargs: additional arguments to pass to PlanetaryComputer. from_config
                may supply timeout, cache_dir, query, sort_by, sort_ascending and
                max_items_per_client.
        """
        super().__init__(
            collection_name=self.COLLECTION_NAME,
            asset_bands=self.ASSET_BANDS,
            **kwargs,
        )

    @staticmethod
    def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Naip":
        """Creates a new Naip instance from a raster layer configuration.

        Args:
            config: the raster layer config; its data_source.config_dict is read.
            ds_path: the dataset path, joined with cache_dir when that key is set.

        Returns:
            a Naip data source built from the optional config keys timeout_seconds,
            cache_dir, query, sort_by, sort_ascending and max_items_per_client.

        Raises:
            ValueError: if the layer config has no data source.
        """
        if config.data_source is None:
            raise ValueError("config.data_source is required")
        d = config.data_source.config_dict
        kwargs: dict[str, Any] = {}

        # The config stores the timeout in seconds; PlanetaryComputer takes a
        # timedelta.
        if "timeout_seconds" in d:
            kwargs["timeout"] = timedelta(seconds=d["timeout_seconds"])

        # cache_dir is combined with the dataset path via join_upath.
        if "cache_dir" in d:
            kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])

        # These options are forwarded to PlanetaryComputer unchanged.
        simple_optionals = [
            "query",
            "sort_by",
            "sort_ascending",
            "max_items_per_client",
        ]
        for k in simple_optionals:
            if k in d:
                kwargs[k] = d[k]

        return Naip(**kwargs)
|
|
@@ -118,6 +118,7 @@ def prepare_dataset_windows(
|
|
|
118
118
|
duration_seconds=time.monotonic() - layer_start_time,
|
|
119
119
|
windows_prepared=0,
|
|
120
120
|
windows_skipped=len(windows),
|
|
121
|
+
windows_rejected=0,
|
|
121
122
|
get_items_attempts=0,
|
|
122
123
|
)
|
|
123
124
|
)
|
|
@@ -141,6 +142,7 @@ def prepare_dataset_windows(
|
|
|
141
142
|
duration_seconds=time.monotonic() - layer_start_time,
|
|
142
143
|
windows_prepared=0,
|
|
143
144
|
windows_skipped=len(windows),
|
|
145
|
+
windows_rejected=0,
|
|
144
146
|
get_items_attempts=0,
|
|
145
147
|
)
|
|
146
148
|
)
|
|
@@ -181,6 +183,9 @@ def prepare_dataset_windows(
|
|
|
181
183
|
attempts_counter=attempts_counter,
|
|
182
184
|
)
|
|
183
185
|
|
|
186
|
+
windows_prepared = 0
|
|
187
|
+
windows_rejected = 0
|
|
188
|
+
min_matches = data_source_cfg.query_config.min_matches
|
|
184
189
|
for window, result in zip(needed_windows, results):
|
|
185
190
|
layer_datas = window.load_layer_datas()
|
|
186
191
|
layer_datas[layer_name] = WindowLayerData(
|
|
@@ -191,13 +196,22 @@ def prepare_dataset_windows(
|
|
|
191
196
|
)
|
|
192
197
|
window.save_layer_datas(layer_datas)
|
|
193
198
|
|
|
199
|
+
# If result is empty and min_matches > 0, window was rejected due to min_matches
|
|
200
|
+
if len(result) == 0 and min_matches > 0:
|
|
201
|
+
windows_rejected += 1
|
|
202
|
+
else:
|
|
203
|
+
windows_prepared += 1
|
|
204
|
+
|
|
205
|
+
windows_skipped = len(windows) - len(needed_windows)
|
|
206
|
+
|
|
194
207
|
layer_summaries.append(
|
|
195
208
|
LayerPrepareSummary(
|
|
196
209
|
layer_name=layer_name,
|
|
197
210
|
data_source_name=data_source_cfg.name,
|
|
198
211
|
duration_seconds=time.monotonic() - layer_start_time,
|
|
199
|
-
windows_prepared=
|
|
200
|
-
windows_skipped=
|
|
212
|
+
windows_prepared=windows_prepared,
|
|
213
|
+
windows_skipped=windows_skipped,
|
|
214
|
+
windows_rejected=windows_rejected,
|
|
201
215
|
get_items_attempts=attempts_counter.value,
|
|
202
216
|
)
|
|
203
217
|
)
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""LightningCLI for rslearn."""
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
|
|
5
|
+
from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
|
|
6
|
+
|
|
7
|
+
from rslearn.arg_parser import RslearnArgumentParser
|
|
8
|
+
from rslearn.train.data_module import RslearnDataModule
|
|
9
|
+
from rslearn.train.lightning_module import RslearnLightningModule
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class RslearnLightningCLI(LightningCLI):
    """LightningCLI customized for rslearn.

    It links the data module's task to the model's task. Before instantiation it
    also adjusts a few settings: the prediction writer path, the distributed
    sampler and return_predictions.
    """

    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        """Link data.init_args.task to model.init_args.task.

        The link uses apply_on="instantiate", so the model receives the task taken
        from the data module after the data module has been instantiated.

        Args:
            parser: the argument parser
        """
        parser.link_arguments(
            "data.init_args.task", "model.init_args.task", apply_on="instantiate"
        )

    def before_instantiate_classes(self) -> None:
        """Called before Lightning class initialization.

        Adjusts the parsed config of the active subcommand:

        - every RslearnWriter callback gets its path set to data.init_args.path;
        - the trainer's automatic distributed sampler replacement is disabled;
        - for predict, return_predictions is forced to False.
        """
        subcommand = self.config.subcommand
        c = self.config[subcommand]

        # Point every configured RslearnWriter at the dataset used by the data
        # module. trainer.callbacks may be missing or null when no callbacks are
        # configured; both are treated as "no callbacks".
        callbacks = c.trainer.callbacks if "callbacks" in c.trainer else None
        if callbacks is None:
            callbacks = []
        for callback in callbacks:
            if (
                callback.class_path
                == "rslearn.train.prediction_writer.RslearnWriter"
            ):
                callback.init_args.path = c.data.init_args.path

        # Disable the sampler replacement, since the rslearn data module will set the
        # sampler as needed.
        c.trainer.use_distributed_sampler = False

        # For predict, make sure that return_predictions is False.
        # Otherwise all the predictions would be stored in memory which can lead to
        # high memory consumption.
        if subcommand == "predict":
            c.return_predictions = False
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def model_handler() -> None:
    """Run RslearnLightningCLI for an ``rslearn model <subcommand>`` invocation.

    The first two command-line entries are skipped (presumably the program name
    and "model"). The remainder, starting with the Lightning subcommand, is handed
    to LightningCLI.
    """
    lightning_args = sys.argv[2:]
    RslearnLightningCLI(
        model_class=RslearnLightningModule,
        datamodule_class=RslearnDataModule,
        parser_class=RslearnArgumentParser,
        args=lightning_args,
        subclass_mode_model=True,
        subclass_mode_data=True,
        # Let the saved config file overwrite one that already exists.
        save_config_kwargs=dict(overwrite=True),
    )
|
|
@@ -10,11 +10,9 @@ from datetime import UTC, datetime, timedelta
|
|
|
10
10
|
from typing import Any, TypeVar
|
|
11
11
|
|
|
12
12
|
import tqdm
|
|
13
|
-
from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
|
|
14
13
|
from rasterio.crs import CRS
|
|
15
14
|
from upath import UPath
|
|
16
15
|
|
|
17
|
-
from rslearn.arg_parser import RslearnArgumentParser
|
|
18
16
|
from rslearn.config import LayerConfig
|
|
19
17
|
from rslearn.const import WGS84_EPSG
|
|
20
18
|
from rslearn.data_sources import Item, data_source_from_config
|
|
@@ -38,8 +36,6 @@ from rslearn.dataset.manage import (
|
|
|
38
36
|
)
|
|
39
37
|
from rslearn.log_utils import get_logger
|
|
40
38
|
from rslearn.tile_stores import get_tile_store_with_layer
|
|
41
|
-
from rslearn.train.data_module import RslearnDataModule
|
|
42
|
-
from rslearn.train.lightning_module import RslearnLightningModule
|
|
43
39
|
from rslearn.utils import Projection, STGeometry
|
|
44
40
|
|
|
45
41
|
logger = get_logger(__name__)
|
|
@@ -831,85 +827,35 @@ def dataset_build_index() -> None:
|
|
|
831
827
|
index.save_index(ds_path)
|
|
832
828
|
|
|
833
829
|
|
|
834
|
-
class RslearnLightningCLI(LightningCLI):
|
|
835
|
-
"""LightningCLI that links data.tasks to model.tasks and supports environment variables."""
|
|
836
|
-
|
|
837
|
-
def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
|
|
838
|
-
"""Link data.tasks to model.tasks.
|
|
839
|
-
|
|
840
|
-
Args:
|
|
841
|
-
parser: the argument parser
|
|
842
|
-
"""
|
|
843
|
-
# Link data.tasks to model.tasks
|
|
844
|
-
parser.link_arguments(
|
|
845
|
-
"data.init_args.task", "model.init_args.task", apply_on="instantiate"
|
|
846
|
-
)
|
|
847
|
-
|
|
848
|
-
def before_instantiate_classes(self) -> None:
|
|
849
|
-
"""Called before Lightning class initialization.
|
|
850
|
-
|
|
851
|
-
Sets the dataset path for any configured RslearnPredictionWriter callbacks.
|
|
852
|
-
"""
|
|
853
|
-
subcommand = self.config.subcommand
|
|
854
|
-
c = self.config[subcommand]
|
|
855
|
-
|
|
856
|
-
# If there is a RslearnPredictionWriter, set its path.
|
|
857
|
-
prediction_writer_callback = None
|
|
858
|
-
if "callbacks" in c.trainer:
|
|
859
|
-
for existing_callback in c.trainer.callbacks:
|
|
860
|
-
if (
|
|
861
|
-
existing_callback.class_path
|
|
862
|
-
== "rslearn.train.prediction_writer.RslearnWriter"
|
|
863
|
-
):
|
|
864
|
-
prediction_writer_callback = existing_callback
|
|
865
|
-
if prediction_writer_callback:
|
|
866
|
-
prediction_writer_callback.init_args.path = c.data.init_args.path
|
|
867
|
-
|
|
868
|
-
# Disable the sampler replacement, since the rslearn data module will set the
|
|
869
|
-
# sampler as needed.
|
|
870
|
-
c.trainer.use_distributed_sampler = False
|
|
871
|
-
|
|
872
|
-
# For predict, make sure that return_predictions is False.
|
|
873
|
-
# Otherwise all the predictions would be stored in memory which can lead to
|
|
874
|
-
# high memory consumption.
|
|
875
|
-
if subcommand == "predict":
|
|
876
|
-
c.return_predictions = False
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
def model_handler() -> None:
|
|
880
|
-
"""Handler for any rslearn model X commands."""
|
|
881
|
-
RslearnLightningCLI(
|
|
882
|
-
model_class=RslearnLightningModule,
|
|
883
|
-
datamodule_class=RslearnDataModule,
|
|
884
|
-
args=sys.argv[2:],
|
|
885
|
-
subclass_mode_model=True,
|
|
886
|
-
subclass_mode_data=True,
|
|
887
|
-
save_config_kwargs={"overwrite": True},
|
|
888
|
-
parser_class=RslearnArgumentParser,
|
|
889
|
-
)
|
|
890
|
-
|
|
891
|
-
|
|
892
830
|
@register_handler("model", "fit")
def model_fit() -> None:
    """Handler for rslearn model fit; delegates to the shared LightningCLI runner.

    The import is deferred to call time, so rslearn.lightning_cli (and the
    lightning package it imports) is not imported until this handler runs.
    """
    from .lightning_cli import model_handler as run_lightning_cli

    run_lightning_cli()
|
|
896
836
|
|
|
897
837
|
|
|
898
838
|
@register_handler("model", "validate")
def model_validate() -> None:
    """Handler for rslearn model validate; delegates to the shared LightningCLI runner.

    The import is deferred to call time, so rslearn.lightning_cli (and the
    lightning package it imports) is not imported until this handler runs.
    """
    from .lightning_cli import model_handler as run_lightning_cli

    run_lightning_cli()
|
|
902
844
|
|
|
903
845
|
|
|
904
846
|
@register_handler("model", "test")
def model_test() -> None:
    """Handler for rslearn model test; delegates to the shared LightningCLI runner.

    The import is deferred to call time, so rslearn.lightning_cli (and the
    lightning package it imports) is not imported until this handler runs.
    """
    from .lightning_cli import model_handler as run_lightning_cli

    run_lightning_cli()
|
|
908
852
|
|
|
909
853
|
|
|
910
854
|
@register_handler("model", "predict")
def model_predict() -> None:
    """Handler for rslearn model predict; delegates to the shared LightningCLI runner.

    The import is deferred to call time, so rslearn.lightning_cli (and the
    lightning package it imports) is not imported until this handler runs.
    """
    from .lightning_cli import model_handler as run_lightning_cli

    run_lightning_cli()
|
|
914
860
|
|
|
915
861
|
|