rslearn 0.0.13__tar.gz → 0.0.14__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rslearn-0.0.13/rslearn.egg-info → rslearn-0.0.14}/PKG-INFO +1 -1
- {rslearn-0.0.13 → rslearn-0.0.14}/pyproject.toml +1 -1
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/config/dataset.py +0 -10
- rslearn-0.0.14/rslearn/lightning_cli.py +67 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/main.py +8 -62
- rslearn-0.0.14/rslearn/train/all_patches_dataset.py +458 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/data_module.py +4 -2
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/dataset.py +10 -446
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/utils/array.py +6 -4
- {rslearn-0.0.13 → rslearn-0.0.14/rslearn.egg-info}/PKG-INFO +1 -1
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn.egg-info/SOURCES.txt +2 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/LICENSE +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/NOTICE +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/README.md +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/arg_parser.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/config/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/const.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/data_sources/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/data_sources/aws_landsat.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/data_sources/aws_open_data.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/data_sources/aws_sentinel1.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/data_sources/climate_data_store.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/data_sources/copernicus.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/data_sources/data_source.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/data_sources/earthdaily.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/data_sources/earthdata_srtm.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/data_sources/eurocrops.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/data_sources/gcp_public_data.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/data_sources/geotiff.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/data_sources/google_earth_engine.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/data_sources/local_files.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/data_sources/openstreetmap.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/data_sources/planet.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/data_sources/planet_basemap.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/data_sources/planetary_computer.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/data_sources/raster_source.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/data_sources/usda_cdl.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/data_sources/usgs_landsat.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/data_sources/utils.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/data_sources/vector_source.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/data_sources/worldcereal.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/data_sources/worldcover.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/data_sources/worldpop.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/data_sources/xyz_tiles.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/dataset/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/dataset/add_windows.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/dataset/dataset.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/dataset/handler_summaries.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/dataset/index.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/dataset/manage.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/dataset/materialize.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/dataset/remap.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/dataset/window.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/log_utils.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/anysat.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/clay/clay.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/clay/configs/metadata.yaml +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/clip.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/conv.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/croma.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/detr/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/detr/box_ops.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/detr/detr.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/detr/matcher.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/detr/position_encoding.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/detr/transformer.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/detr/util.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/dinov3.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/faster_rcnn.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/feature_center_crop.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/fpn.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/galileo/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/galileo/galileo.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/galileo/single_file_galileo.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/module_wrapper.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/molmo.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/multitask.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/olmoearth_pretrain/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/olmoearth_pretrain/model.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/olmoearth_pretrain/norm.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/panopticon.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/panopticon_data/sensors/drone.yaml +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/panopticon_data/sensors/enmap.yaml +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/panopticon_data/sensors/goes.yaml +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/panopticon_data/sensors/himawari.yaml +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/panopticon_data/sensors/intuition.yaml +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/panopticon_data/sensors/landsat8.yaml +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/panopticon_data/sensors/modis_terra.yaml +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/panopticon_data/sensors/sentinel1.yaml +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/panopticon_data/sensors/sentinel2.yaml +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/panopticon_data/sensors/superdove.yaml +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/panopticon_data/sensors/wv23.yaml +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/pick_features.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/pooling_decoder.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/presto/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/presto/presto.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/presto/single_file_presto.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/prithvi.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/registry.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/resize_features.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/sam2_enc.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/satlaspretrain.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/simple_time_series.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/singletask.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/ssl4eo_s12.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/swin.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/task_embedding.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/terramind.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/trunk.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/unet.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/upsample.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/models/use_croma.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/py.typed +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/template_params.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/tile_stores/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/tile_stores/default.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/tile_stores/tile_store.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/callbacks/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/callbacks/adapters.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/callbacks/freeze_unfreeze.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/callbacks/gradients.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/callbacks/peft.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/lightning_module.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/optimizer.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/prediction_writer.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/scheduler.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/tasks/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/tasks/classification.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/tasks/detection.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/tasks/embedding.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/tasks/multi_task.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/tasks/per_pixel_regression.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/tasks/regression.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/tasks/segmentation.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/tasks/task.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/transforms/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/transforms/concatenate.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/transforms/crop.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/transforms/flip.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/transforms/mask.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/transforms/normalize.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/transforms/pad.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/transforms/select_bands.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/transforms/sentinel1.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/train/transforms/transform.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/utils/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/utils/feature.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/utils/fsspec.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/utils/geometry.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/utils/get_utm_ups_crs.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/utils/grid_index.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/utils/jsonargparse.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/utils/mp.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/utils/raster_format.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/utils/rtree_index.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/utils/spatial_index.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/utils/sqlite_index.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/utils/time.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn/utils/vector_format.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn.egg-info/dependency_links.txt +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn.egg-info/entry_points.txt +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn.egg-info/requires.txt +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/rslearn.egg-info/top_level.txt +0 -0
- {rslearn-0.0.13 → rslearn-0.0.14}/setup.cfg +0 -0
rslearn/config/dataset.py

@@ -8,7 +8,6 @@ from typing import Any
 import numpy as np
 import numpy.typing as npt
 import pytimeparse
-import torch
 from rasterio.enums import Resampling
 
 from rslearn.utils import PixelBounds, Projection
@@ -49,15 +48,6 @@ class DType(Enum):
             return np.float32
         raise ValueError(f"unable to handle numpy dtype {self}")
 
-    def get_torch_dtype(self) -> torch.dtype:
-        """Returns pytorch dtype object corresponding to this DType."""
-        if self == DType.INT32:
-            return torch.int32
-        elif self == DType.FLOAT32:
-            return torch.float32
-        else:
-            raise ValueError(f"unable to handle torch dtype {self}")
-
 
 RESAMPLING_METHODS = {
     "nearest": Resampling.nearest,
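Note (not part of the diff): with DType.get_torch_dtype() removed, code that still needs a torch dtype can derive it from the numpy dtype instead. A minimal sketch, assuming the enum's numpy accessor is named get_numpy_dtype() (the method the surviving context lines appear to belong to):

import numpy as np
import torch

from rslearn.config.dataset import DType

np_dtype = DType.FLOAT32.get_numpy_dtype()  # assumed accessor name
# Round-trip a zero-dim array to recover the matching torch dtype (torch.float32 here).
torch_dtype = torch.from_numpy(np.zeros((), dtype=np_dtype)).dtype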
rslearn/lightning_cli.py (new file)

@@ -0,0 +1,67 @@
+"""LightningCLI for rslearn."""
+
+import sys
+
+from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
+
+from rslearn.arg_parser import RslearnArgumentParser
+from rslearn.train.data_module import RslearnDataModule
+from rslearn.train.lightning_module import RslearnLightningModule
+
+
+class RslearnLightningCLI(LightningCLI):
+    """LightningCLI that links data.tasks to model.tasks and supports environment variables."""
+
+    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
+        """Link data.tasks to model.tasks.
+
+        Args:
+            parser: the argument parser
+        """
+        # Link data.tasks to model.tasks
+        parser.link_arguments(
+            "data.init_args.task", "model.init_args.task", apply_on="instantiate"
+        )
+
+    def before_instantiate_classes(self) -> None:
+        """Called before Lightning class initialization.
+
+        Sets the dataset path for any configured RslearnPredictionWriter callbacks.
+        """
+        subcommand = self.config.subcommand
+        c = self.config[subcommand]
+
+        # If there is a RslearnPredictionWriter, set its path.
+        prediction_writer_callback = None
+        if "callbacks" in c.trainer:
+            for existing_callback in c.trainer.callbacks:
+                if (
+                    existing_callback.class_path
+                    == "rslearn.train.prediction_writer.RslearnWriter"
+                ):
+                    prediction_writer_callback = existing_callback
+        if prediction_writer_callback:
+            prediction_writer_callback.init_args.path = c.data.init_args.path
+
+        # Disable the sampler replacement, since the rslearn data module will set the
+        # sampler as needed.
+        c.trainer.use_distributed_sampler = False
+
+        # For predict, make sure that return_predictions is False.
+        # Otherwise all the predictions would be stored in memory which can lead to
+        # high memory consumption.
+        if subcommand == "predict":
+            c.return_predictions = False
+
+
+def model_handler() -> None:
+    """Handler for any rslearn model X commands."""
+    RslearnLightningCLI(
+        model_class=RslearnLightningModule,
+        datamodule_class=RslearnDataModule,
+        args=sys.argv[2:],
+        subclass_mode_model=True,
+        subclass_mode_data=True,
+        save_config_kwargs={"overwrite": True},
+        parser_class=RslearnArgumentParser,
+    )
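For orientation (not part of the diff): model_handler() builds the CLI from sys.argv[2:], which lines up with the `rslearn model fit|validate|test|predict` handlers registered in main.py below. A minimal sketch of an equivalent programmatic invocation, where the config file name is a hypothetical placeholder; because of the link_arguments() call above, the task only needs to be configured under data.init_args.task and the instantiated object is reused for model.init_args.task:

import sys

from rslearn.lightning_cli import model_handler

# Roughly equivalent to running `rslearn model fit --config my_config.yaml` in a shell.
# "my_config.yaml" is a hypothetical placeholder for a Lightning-style YAML config.
sys.argv = ["rslearn", "model", "fit", "--config", "my_config.yaml"]
model_handler()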
rslearn/main.py

@@ -10,11 +10,9 @@ from datetime import UTC, datetime, timedelta
 from typing import Any, TypeVar
 
 import tqdm
-from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
 from rasterio.crs import CRS
 from upath import UPath
 
-from rslearn.arg_parser import RslearnArgumentParser
 from rslearn.config import LayerConfig
 from rslearn.const import WGS84_EPSG
 from rslearn.data_sources import Item, data_source_from_config
@@ -38,8 +36,6 @@ from rslearn.dataset.manage import (
 )
 from rslearn.log_utils import get_logger
 from rslearn.tile_stores import get_tile_store_with_layer
-from rslearn.train.data_module import RslearnDataModule
-from rslearn.train.lightning_module import RslearnLightningModule
 from rslearn.utils import Projection, STGeometry
 
 logger = get_logger(__name__)
@@ -831,85 +827,35 @@ def dataset_build_index() -> None:
     index.save_index(ds_path)
 
 
-class RslearnLightningCLI(LightningCLI):
-    """LightningCLI that links data.tasks to model.tasks and supports environment variables."""
-
-    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
-        """Link data.tasks to model.tasks.
-
-        Args:
-            parser: the argument parser
-        """
-        # Link data.tasks to model.tasks
-        parser.link_arguments(
-            "data.init_args.task", "model.init_args.task", apply_on="instantiate"
-        )
-
-    def before_instantiate_classes(self) -> None:
-        """Called before Lightning class initialization.
-
-        Sets the dataset path for any configured RslearnPredictionWriter callbacks.
-        """
-        subcommand = self.config.subcommand
-        c = self.config[subcommand]
-
-        # If there is a RslearnPredictionWriter, set its path.
-        prediction_writer_callback = None
-        if "callbacks" in c.trainer:
-            for existing_callback in c.trainer.callbacks:
-                if (
-                    existing_callback.class_path
-                    == "rslearn.train.prediction_writer.RslearnWriter"
-                ):
-                    prediction_writer_callback = existing_callback
-        if prediction_writer_callback:
-            prediction_writer_callback.init_args.path = c.data.init_args.path
-
-        # Disable the sampler replacement, since the rslearn data module will set the
-        # sampler as needed.
-        c.trainer.use_distributed_sampler = False
-
-        # For predict, make sure that return_predictions is False.
-        # Otherwise all the predictions would be stored in memory which can lead to
-        # high memory consumption.
-        if subcommand == "predict":
-            c.return_predictions = False
-
-
-def model_handler() -> None:
-    """Handler for any rslearn model X commands."""
-    RslearnLightningCLI(
-        model_class=RslearnLightningModule,
-        datamodule_class=RslearnDataModule,
-        args=sys.argv[2:],
-        subclass_mode_model=True,
-        subclass_mode_data=True,
-        save_config_kwargs={"overwrite": True},
-        parser_class=RslearnArgumentParser,
-    )
-
-
 @register_handler("model", "fit")
 def model_fit() -> None:
     """Handler for rslearn model fit."""
+    from .lightning_cli import model_handler
+
     model_handler()
 
 
 @register_handler("model", "validate")
 def model_validate() -> None:
     """Handler for rslearn model validate."""
+    from .lightning_cli import model_handler
+
     model_handler()
 
 
 @register_handler("model", "test")
 def model_test() -> None:
     """Handler for rslearn model test."""
+    from .lightning_cli import model_handler
+
     model_handler()
 
 
 @register_handler("model", "predict")
 def model_predict() -> None:
     """Handler for rslearn model predict."""
+    from .lightning_cli import model_handler
+
     model_handler()
 
 
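The apparent motivation for moving the CLI out of main.py and importing model_handler lazily inside each handler is that dataset-oriented commands no longer need PyTorch Lightning at import time. A rough check sketch (assumes rslearn 0.0.14 is installed; the expected output is an inference from the refactor, not something the diff asserts):

import sys

import rslearn.main  # noqa: F401  # command dispatch module

# Lightning should only be imported once a `model` subcommand calls model_handler().
print("lightning" in sys.modules)  # expected: False right after import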
rslearn/train/all_patches_dataset.py (new file)

@@ -0,0 +1,458 @@
+"""Wrapper around ModelDataset to load all patches (crops) in a window."""
+
+import itertools
+from collections.abc import Iterable, Iterator
+from typing import Any
+
+import shapely
+import torch
+
+from rslearn.dataset import Window
+from rslearn.train.dataset import ModelDataset
+from rslearn.utils.geometry import PixelBounds, STGeometry
+
+
+def get_window_patch_options(
+    patch_size: tuple[int, int],
+    overlap_size: tuple[int, int],
+    bounds: PixelBounds,
+) -> list[PixelBounds]:
+    """Get the bounds of each input patch within the window bounds.
+
+    This is used when running inference on all patches (crops) of a large window, to
+    compute the position of each patch.
+
+    Args:
+        patch_size: the size of the patches to extract.
+        overlap_size: the size of the overlap between patches.
+        bounds: the window bounds to divide up into smaller patches.
+
+    Returns:
+        a list of patch bounds within the overall bounds. The rightmost and
+        bottommost patches may extend beyond the provided bounds.
+    """
+    # We stride the patches by patch_size - overlap_size until the last patch.
+    # We handle the last patch with a special case to ensure it does not exceed the
+    # window bounds. Instead, it may overlap the previous patch.
+    cols = list(
+        range(
+            bounds[0],
+            bounds[2] - patch_size[0],
+            patch_size[0] - overlap_size[0],
+        )
+    ) + [bounds[2] - patch_size[0]]
+    rows = list(
+        range(
+            bounds[1],
+            bounds[3] - patch_size[1],
+            patch_size[1] - overlap_size[1],
+        )
+    ) + [bounds[3] - patch_size[1]]
+
+    patch_bounds: list[PixelBounds] = []
+    for col in cols:
+        for row in rows:
+            patch_bounds.append((col, row, col + patch_size[0], row + patch_size[1]))
+    return patch_bounds
+
+
+def pad_slice_protect(
+    raw_inputs: dict[str, Any],
+    passthrough_inputs: dict[str, Any],
+    patch_size: tuple[int, int],
+) -> tuple[dict[str, Any], dict[str, Any]]:
+    """Pad tensors in-place by patch size to protect slicing near right/bottom edges.
+
+    Args:
+        raw_inputs: the raw inputs to pad.
+        passthrough_inputs: the passthrough inputs to pad.
+        patch_size: the size of the patches to extract.
+
+    Returns:
+        a tuple of (raw_inputs, passthrough_inputs).
+    """
+    for d in [raw_inputs, passthrough_inputs]:
+        for input_name, value in list(d.items()):
+            if not isinstance(value, torch.Tensor):
+                continue
+            d[input_name] = torch.nn.functional.pad(
+                value, pad=(0, patch_size[0], 0, patch_size[1])
+            )
+    return raw_inputs, passthrough_inputs
+
+
+class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
+    """This wraps a ModelDataset to iterate over all patches in that dataset.
+
+    This should be used when SplitConfig.load_all_patches is enabled. The ModelDataset
+    is configured with no patch size (load entire windows), and the dataset is wrapped
+    in an AllPatchesDataset.
+
+    Similar to DistributedSampler, we add extra samples at each rank to ensure
+    consistent number of batches across all ranks.
+    """
+
+    def __init__(
+        self,
+        dataset: ModelDataset,
+        patch_size: tuple[int, int],
+        overlap_ratio: float = 0.0,
+        rank: int = 0,
+        world_size: int = 1,
+    ):
+        """Create a new IterableAllPatchesDataset.
+
+        Args:
+            dataset: the ModelDataset to wrap.
+            patch_size: the size of the patches to extract.
+            overlap_ratio: whether to include overlap between the patches. Note that
+                the right/bottom-most patches may still overlap since we ensure that
+                all patches are contained in the window bounds.
+            rank: the global rank of this train worker process.
+            world_size: the total number of train worker processes.
+        """
+        super().__init__()
+        self.dataset = dataset
+        self.patch_size = patch_size
+        self.overlap_size = (
+            round(self.patch_size[0] * overlap_ratio),
+            round(self.patch_size[1] * overlap_ratio),
+        )
+        self.rank = rank
+        self.world_size = world_size
+        self.windows = self.dataset.get_dataset_examples()
+
+    def set_name(self, name: str) -> None:
+        """Sets dataset name.
+
+        Args:
+            name: dataset name
+        """
+        self.dataset.set_name(name)
+
+    def get_window_num_patches(self, bounds: PixelBounds) -> int:
+        """Get the number of patches for these bounds.
+
+        This corresponds to the length of the list returned by get_patch_options.
+        """
+        num_cols = (
+            len(
+                range(
+                    bounds[0],
+                    bounds[2] - self.patch_size[0],
+                    self.patch_size[0] - self.overlap_size[0],
+                )
+            )
+            + 1
+        )
+        num_rows = (
+            len(
+                range(
+                    bounds[1],
+                    bounds[3] - self.patch_size[1],
+                    self.patch_size[1] - self.overlap_size[1],
+                )
+            )
+            + 1
+        )
+        return num_cols * num_rows
+
+    def _get_worker_iteration_data(self) -> tuple[Iterable[int], int]:
+        """Get the windows we should iterate over.
+
+        This is split both by training worker (self.rank) and data loader worker (via
+        get_worker_info).
+
+        We also compute the total number of samples that each data loader worker should
+        yield. This is important for DDP to ensure that all ranks see the same number
+        of batches.
+
+        Returns:
+            a tuple (window_ids, num_samples_per_worker).
+        """
+        # Figure out the total number of data loader workers and our worker ID.
+        worker_info = torch.utils.data.get_worker_info()
+        if worker_info is None:
+            worker_id = 0
+            num_workers = 1
+        else:
+            worker_id = worker_info.id
+            num_workers = worker_info.num_workers
+        global_worker_id = self.rank * num_workers + worker_id
+        global_num_workers = self.world_size * num_workers
+
+        # Split up the windows evenly among the workers.
+        # We compute this for all workers since we will need to see the maximum number
+        # of samples under this assignment across workers.
+        window_indexes = range(len(self.windows))
+        windows_by_worker = [
+            window_indexes[cur_rank :: self.world_size][cur_worker_id::num_workers]
+            for cur_rank in range(self.world_size)
+            for cur_worker_id in range(num_workers)
+        ]
+
+        # Now compute the maximum number of samples across workers.
+        max_num_patches = 0
+        for worker_windows in windows_by_worker:
+            worker_num_patches = 0
+            for window_id in worker_windows:
+                worker_num_patches += self.get_window_num_patches(
+                    self.windows[window_id].bounds
+                )
+            max_num_patches = max(max_num_patches, worker_num_patches)
+
+        # Each worker needs at least one window, otherwise it won't be able to pad.
+        # Unless there are zero windows total, which is fine.
+        # Previously we would address this by borrowing the windows from another
+        # worker, but this causes issues with RslearnWriter: if we yield the same
+        # window from parallel workers, it may end up writing an empty output for that
+        # window in the end.
+        # So now we raise an error instead, and require the number of workers to be
+        # less than the number of windows.
+        if len(windows_by_worker[global_worker_id]) == 0 and max_num_patches > 0:
+            raise ValueError(
+                f"the number of workers {global_num_workers} must be <= the number of windows {len(self.windows)}"
+            )
+
+        return (windows_by_worker[global_worker_id], max_num_patches)
+
+    def __iter__(
+        self,
+    ) -> Iterator[tuple[dict[str, Any], dict[str, Any], dict[str, Any]]]:
+        """Iterate over all patches in each element of the underlying ModelDataset."""
+        # Iterate over the window IDs until we have returned enough samples.
+        window_ids, num_samples_needed = self._get_worker_iteration_data()
+        num_samples_returned = 0
+
+        for iteration_idx in itertools.count():
+            for window_id in window_ids:
+                raw_inputs, passthrough_inputs, metadata = self.dataset.get_raw_inputs(
+                    window_id
+                )
+                bounds = metadata["bounds"]
+
+                # For simplicity, pad tensors by patch size to ensure that any patch bounds
+                # extending outside the window bounds will not have issues when we slice
+                # the tensors later.
+                pad_slice_protect(raw_inputs, passthrough_inputs, self.patch_size)
+
+                # Now iterate over the patches and extract/yield the crops.
+                # Note that, in case user is leveraging RslearnWriter, it is important that
+                # the patch_idx be increasing (as we iterate) within one window.
+                patches = get_window_patch_options(
+                    self.patch_size, self.overlap_size, bounds
+                )
+                for patch_idx, patch_bounds in enumerate(patches):
+                    cur_geom = STGeometry(
+                        metadata["projection"], shapely.box(*patch_bounds), None
+                    )
+                    start_offset = (
+                        patch_bounds[0] - bounds[0],
+                        patch_bounds[1] - bounds[1],
+                    )
+                    end_offset = (
+                        patch_bounds[2] - bounds[0],
+                        patch_bounds[3] - bounds[1],
+                    )
+
+                    # Define a helper function to handle each input dict.
+                    def crop_input_dict(d: dict[str, Any]) -> dict[str, Any]:
+                        cropped = {}
+                        for input_name, value in d.items():
+                            if isinstance(value, torch.Tensor):
+                                # Crop the CHW tensor.
+                                cropped[input_name] = value[
+                                    :,
+                                    start_offset[1] : end_offset[1],
+                                    start_offset[0] : end_offset[0],
+                                ].clone()
+                            elif isinstance(value, list):
+                                cropped[input_name] = [
+                                    feat
+                                    for feat in value
+                                    if cur_geom.intersects(feat.geometry)
+                                ]
+                            else:
+                                raise ValueError(
+                                    "got input that is neither tensor nor feature list"
+                                )
+                        return cropped
+
+                    cur_raw_inputs = crop_input_dict(raw_inputs)
+                    cur_passthrough_inputs = crop_input_dict(passthrough_inputs)
+
+                    # Adjust the metadata as well.
+                    cur_metadata = metadata.copy()
+                    cur_metadata["bounds"] = patch_bounds
+                    cur_metadata["patch_idx"] = patch_idx
+                    cur_metadata["num_patches"] = len(patches)
+
+                    # Now we can compute input and target dicts via the task.
+                    input_dict, target_dict = self.dataset.task.process_inputs(
+                        cur_raw_inputs,
+                        metadata=cur_metadata,
+                        load_targets=not self.dataset.split_config.get_skip_targets(),
+                    )
+                    input_dict.update(cur_passthrough_inputs)
+                    input_dict, target_dict = self.dataset.transforms(
+                        input_dict, target_dict
+                    )
+                    input_dict["dataset_source"] = self.dataset.name
+
+                    if num_samples_returned < num_samples_needed:
+                        yield input_dict, target_dict, cur_metadata
+                        num_samples_returned += 1
+                    else:
+                        assert iteration_idx > 0
+
+            if num_samples_returned >= num_samples_needed:
+                break
+
+    def get_dataset_examples(self) -> list[Window]:
+        """Returns a list of windows in this dataset."""
+        return self.dataset.get_dataset_examples()
+
+
+class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
+    """This wraps a ModelDataset to iterate over all patches in that dataset.
+
+    This should be used when SplitConfig.load_all_patches is enabled.
+
+    This is a simpler version of IterableAllPatchesDataset that caches all windows in memory.
+    This is useful for small datasets that fit in memory.
+    """
+
+    def __init__(
+        self,
+        dataset: ModelDataset,
+        patch_size: tuple[int, int],
+        overlap_ratio: float = 0.0,
+    ):
+        """Create a new InMemoryAllPatchesDataset.
+
+        Args:
+            dataset: the ModelDataset to wrap.
+            patch_size: the size of the patches to extract.
+            overlap_ratio: whether to include overlap between the patches. Note that
+                the right/bottom-most patches may still overlap since we ensure that
+                all patches are contained in the window bounds.
+        """
+        super().__init__()
+        self.dataset = dataset
+        self.patch_size = patch_size
+        self.overlap_size = (
+            round(self.patch_size[0] * overlap_ratio),
+            round(self.patch_size[1] * overlap_ratio),
+        )
+        self.windows = self.dataset.get_dataset_examples()
+        self.window_cache: dict[
+            int, tuple[dict[str, Any], dict[str, Any], dict[str, Any]]
+        ] = {}
+
+        # Precompute the batch boundaries for each window
+        self.patches = []
+        for window_id, window in enumerate(self.windows):
+            patch_bounds = get_window_patch_options(
+                self.patch_size, self.overlap_size, window.bounds
+            )
+            for i, patch_bound in enumerate(patch_bounds):
+                self.patches.append((window_id, patch_bound, (i, len(patch_bounds))))
+
+    def get_raw_inputs(
+        self, index: int
+    ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
+        """Get the raw inputs for a single patch. Retrieve from cache if possible.
+
+        Also crops/pads the tensors by patch size to protect slicing near right/bottom edges.
+
+        Args:
+            index: the index of the patch.
+
+        Returns:
+            a tuple of (raw_inputs, passthrough_inputs, metadata).
+        """
+        if index in self.window_cache:
+            return self.window_cache[index]
+
+        raw_inputs, passthrough_inputs, metadata = self.dataset.get_raw_inputs(index)
+        pad_slice_protect(raw_inputs, passthrough_inputs, self.patch_size)
+
+        self.window_cache[index] = (raw_inputs, passthrough_inputs, metadata)
+        return self.window_cache[index]
+
+    @staticmethod
+    def _crop_input_dict(
+        d: dict[str, Any],
+        start_offset: tuple[int, int],
+        end_offset: tuple[int, int],
+        cur_geom: STGeometry,
+    ) -> dict[str, Any]:
+        """Crop a dictionary of inputs to the given bounds."""
+        cropped = {}
+        for input_name, value in d.items():
+            if isinstance(value, torch.Tensor):
+                cropped[input_name] = value[
+                    :,
+                    start_offset[1] : end_offset[1],
+                    start_offset[0] : end_offset[0],
+                ].clone()
+            elif isinstance(value, list):
+                cropped[input_name] = [
+                    feat for feat in value if cur_geom.intersects(feat.geometry)
+                ]
+            else:
+                raise ValueError("got input that is neither tensor nor feature list")
+        return cropped
+
+    def __len__(self) -> int:
+        """Return the total number of patches in the dataset."""
+        return len(self.patches)
+
+    def __getitem__(
+        self, index: int
+    ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
+        """Return (input_dict, target_dict, metadata) for a single flattened patch."""
+        (window_id, patch_bounds, (patch_idx, num_patches)) = self.patches[index]
+        raw_inputs, passthrough_inputs, metadata = self.get_raw_inputs(window_id)
+        bounds = metadata["bounds"]
+
+        cur_geom = STGeometry(metadata["projection"], shapely.box(*patch_bounds), None)
+        start_offset = (patch_bounds[0] - bounds[0], patch_bounds[1] - bounds[1])
+        end_offset = (patch_bounds[2] - bounds[0], patch_bounds[3] - bounds[1])
+
+        cur_raw_inputs = self._crop_input_dict(
+            raw_inputs, start_offset, end_offset, cur_geom
+        )
+        cur_passthrough_inputs = self._crop_input_dict(
+            passthrough_inputs, start_offset, end_offset, cur_geom
+        )
+
+        # Adjust the metadata as well.
+        cur_metadata = metadata.copy()
+        cur_metadata["bounds"] = patch_bounds
+        cur_metadata["patch_idx"] = patch_idx
+        cur_metadata["num_patches"] = num_patches
+
+        # Now we can compute input and target dicts via the task.
+        input_dict, target_dict = self.dataset.task.process_inputs(
+            cur_raw_inputs,
+            metadata=cur_metadata,
+            load_targets=not self.dataset.split_config.get_skip_targets(),
+        )
+        input_dict.update(cur_passthrough_inputs)
+        input_dict, target_dict = self.dataset.transforms(input_dict, target_dict)
+        input_dict["dataset_source"] = self.dataset.name
+
+        return input_dict, target_dict, cur_metadata
+
+    def get_dataset_examples(self) -> list[Window]:
+        """Returns a list of windows in this dataset."""
+        return self.dataset.get_dataset_examples()
+
+    def set_name(self, name: str) -> None:
+        """Sets dataset name.
+
+        Args:
+            name: dataset name
+        """
+        self.dataset.set_name(name)
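A worked example for the tiling helper above (illustrative, not part of the diff): get_window_patch_options() strides patches by patch_size minus overlap_size and pins the last column/row to the window edge, so the final patches may overlap their neighbors rather than spill past the bounds.

from rslearn.train.all_patches_dataset import get_window_patch_options

# Tile a 100x100-pixel window into 64x64 patches with no overlap.
bounds = (0, 0, 100, 100)  # (xmin, ymin, xmax, ymax)
patches = get_window_patch_options(patch_size=(64, 64), overlap_size=(0, 0), bounds=bounds)
# Column/row starts are [0, 36], where 36 = 100 - 64 pins the last patch to the edge:
# [(0, 0, 64, 64), (0, 36, 64, 100), (36, 0, 100, 64), (36, 36, 100, 100)]
print(patches)

Wrapping a dataset follows the constructor shown above; a sketch for the in-memory variant, assuming model_dataset is an already-constructed ModelDataset:

from rslearn.train.all_patches_dataset import InMemoryAllPatchesDataset

patch_dataset = InMemoryAllPatchesDataset(model_dataset, patch_size=(128, 128))
print(len(patch_dataset))  # total number of (window, patch) pairs
input_dict, target_dict, metadata = patch_dataset[0]
print(metadata["patch_idx"], metadata["num_patches"], metadata["bounds"])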
rslearn/train/data_module.py

@@ -15,10 +15,12 @@ from rslearn.dataset import Dataset
 from rslearn.log_utils import get_logger
 from rslearn.train.tasks import Task
 
-from .dataset import (
-    DataInput,
+from .all_patches_dataset import (
     InMemoryAllPatchesDataset,
     IterableAllPatchesDataset,
+)
+from .dataset import (
+    DataInput,
     ModelDataset,
     MultiDataset,
     RetryDataset,