rslearn 0.0.13__tar.gz → 0.0.15__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rslearn-0.0.13/rslearn.egg-info → rslearn-0.0.15}/PKG-INFO +1 -1
- {rslearn-0.0.13 → rslearn-0.0.15}/pyproject.toml +1 -1
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/config/dataset.py +0 -10
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/dataset/manage.py +14 -6
- rslearn-0.0.15/rslearn/lightning_cli.py +67 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/main.py +8 -62
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/clay/clay.py +14 -1
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/croma.py +26 -3
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/satlaspretrain.py +18 -4
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/terramind.py +19 -0
- rslearn-0.0.15/rslearn/train/all_patches_dataset.py +458 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/data_module.py +4 -2
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/dataset.py +10 -446
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/utils/array.py +6 -4
- {rslearn-0.0.13 → rslearn-0.0.15/rslearn.egg-info}/PKG-INFO +1 -1
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn.egg-info/SOURCES.txt +2 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/LICENSE +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/NOTICE +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/README.md +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/arg_parser.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/config/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/const.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/data_sources/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/data_sources/aws_landsat.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/data_sources/aws_open_data.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/data_sources/aws_sentinel1.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/data_sources/climate_data_store.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/data_sources/copernicus.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/data_sources/data_source.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/data_sources/earthdaily.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/data_sources/earthdata_srtm.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/data_sources/eurocrops.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/data_sources/gcp_public_data.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/data_sources/geotiff.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/data_sources/google_earth_engine.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/data_sources/local_files.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/data_sources/openstreetmap.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/data_sources/planet.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/data_sources/planet_basemap.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/data_sources/planetary_computer.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/data_sources/raster_source.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/data_sources/usda_cdl.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/data_sources/usgs_landsat.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/data_sources/utils.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/data_sources/vector_source.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/data_sources/worldcereal.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/data_sources/worldcover.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/data_sources/worldpop.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/data_sources/xyz_tiles.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/dataset/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/dataset/add_windows.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/dataset/dataset.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/dataset/handler_summaries.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/dataset/index.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/dataset/materialize.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/dataset/remap.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/dataset/window.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/log_utils.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/anysat.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/clay/configs/metadata.yaml +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/clip.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/conv.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/detr/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/detr/box_ops.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/detr/detr.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/detr/matcher.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/detr/position_encoding.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/detr/transformer.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/detr/util.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/dinov3.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/faster_rcnn.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/feature_center_crop.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/fpn.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/galileo/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/galileo/galileo.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/galileo/single_file_galileo.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/module_wrapper.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/molmo.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/multitask.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/olmoearth_pretrain/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/olmoearth_pretrain/model.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/olmoearth_pretrain/norm.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/panopticon.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/panopticon_data/sensors/drone.yaml +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/panopticon_data/sensors/enmap.yaml +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/panopticon_data/sensors/goes.yaml +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/panopticon_data/sensors/himawari.yaml +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/panopticon_data/sensors/intuition.yaml +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/panopticon_data/sensors/landsat8.yaml +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/panopticon_data/sensors/modis_terra.yaml +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/panopticon_data/sensors/sentinel1.yaml +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/panopticon_data/sensors/sentinel2.yaml +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/panopticon_data/sensors/superdove.yaml +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/panopticon_data/sensors/wv23.yaml +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/pick_features.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/pooling_decoder.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/presto/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/presto/presto.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/presto/single_file_presto.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/prithvi.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/registry.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/resize_features.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/sam2_enc.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/simple_time_series.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/singletask.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/ssl4eo_s12.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/swin.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/task_embedding.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/trunk.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/unet.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/upsample.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/use_croma.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/py.typed +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/template_params.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/tile_stores/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/tile_stores/default.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/tile_stores/tile_store.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/callbacks/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/callbacks/adapters.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/callbacks/freeze_unfreeze.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/callbacks/gradients.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/callbacks/peft.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/lightning_module.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/optimizer.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/prediction_writer.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/scheduler.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/tasks/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/tasks/classification.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/tasks/detection.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/tasks/embedding.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/tasks/multi_task.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/tasks/per_pixel_regression.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/tasks/regression.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/tasks/segmentation.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/tasks/task.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/transforms/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/transforms/concatenate.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/transforms/crop.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/transforms/flip.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/transforms/mask.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/transforms/normalize.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/transforms/pad.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/transforms/select_bands.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/transforms/sentinel1.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/train/transforms/transform.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/utils/__init__.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/utils/feature.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/utils/fsspec.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/utils/geometry.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/utils/get_utm_ups_crs.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/utils/grid_index.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/utils/jsonargparse.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/utils/mp.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/utils/raster_format.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/utils/rtree_index.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/utils/spatial_index.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/utils/sqlite_index.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/utils/time.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn/utils/vector_format.py +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn.egg-info/dependency_links.txt +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn.egg-info/entry_points.txt +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn.egg-info/requires.txt +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/rslearn.egg-info/top_level.txt +0 -0
- {rslearn-0.0.13 → rslearn-0.0.15}/setup.cfg +0 -0
{rslearn-0.0.13 → rslearn-0.0.15}/rslearn/config/dataset.py

```diff
@@ -8,7 +8,6 @@ from typing import Any
 import numpy as np
 import numpy.typing as npt
 import pytimeparse
-import torch
 from rasterio.enums import Resampling
 
 from rslearn.utils import PixelBounds, Projection
@@ -49,15 +48,6 @@ class DType(Enum):
             return np.float32
         raise ValueError(f"unable to handle numpy dtype {self}")
 
-    def get_torch_dtype(self) -> torch.dtype:
-        """Returns pytorch dtype object corresponding to this DType."""
-        if self == DType.INT32:
-            return torch.int32
-        elif self == DType.FLOAT32:
-            return torch.float32
-        else:
-            raise ValueError(f"unable to handle torch dtype {self}")
-
 
 RESAMPLING_METHODS = {
     "nearest": Resampling.nearest,
```
{rslearn-0.0.13 → rslearn-0.0.15}/rslearn/dataset/manage.py

```diff
@@ -124,12 +124,24 @@ def prepare_dataset_windows(
             )
             continue
         data_source_cfg = layer_cfg.data_source
+        min_matches = data_source_cfg.query_config.min_matches
 
         # Get windows that need to be prepared for this layer.
+        # Also track which windows are skipped vs previously rejected.
         needed_windows = []
+        windows_skipped = 0
+        windows_rejected = 0
         for window in windows:
             layer_datas = window.load_layer_datas()
             if layer_name in layer_datas and not force:
+                # Window already has layer data - check if it was previously rejected
+                layer_data = layer_datas[layer_name]
+                if len(layer_data.serialized_item_groups) == 0 and min_matches > 0:
+                    # Previously rejected due to min_matches
+                    windows_rejected += 1
+                else:
+                    # Successfully prepared previously
+                    windows_skipped += 1
                 continue
             needed_windows.append(window)
         logger.info(f"Preparing {len(needed_windows)} windows for layer {layer_name}")
@@ -141,8 +153,8 @@ def prepare_dataset_windows(
                     data_source_name=data_source_cfg.name,
                     duration_seconds=time.monotonic() - layer_start_time,
                     windows_prepared=0,
-                    windows_skipped=
-                    windows_rejected=
+                    windows_skipped=windows_skipped,
+                    windows_rejected=windows_rejected,
                     get_items_attempts=0,
                 )
             )
@@ -184,8 +196,6 @@ def prepare_dataset_windows(
        )
 
        windows_prepared = 0
-        windows_rejected = 0
-        min_matches = data_source_cfg.query_config.min_matches
        for window, result in zip(needed_windows, results):
            layer_datas = window.load_layer_datas()
            layer_datas[layer_name] = WindowLayerData(
@@ -202,8 +212,6 @@ def prepare_dataset_windows(
            else:
                windows_prepared += 1
 
-        windows_skipped = len(windows) - len(needed_windows)
-
        layer_summaries.append(
            LayerPrepareSummary(
                layer_name=layer_name,
```
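The key behavioral change above: a window that already has layer data is now classified up front by whether its stored item groups are empty under a positive `min_matches`, instead of computing `windows_skipped` after the fact. A minimal sketch of that classification rule, using a stand-in dataclass for rslearn's `WindowLayerData`:

```python
# Sketch of the skipped-vs-rejected rule from the diff; LayerData is a
# stand-in for rslearn's WindowLayerData, which stores matched item groups.
from dataclasses import dataclass, field


@dataclass
class LayerData:
    serialized_item_groups: list = field(default_factory=list)


def classify(layer_data: LayerData, min_matches: int) -> str:
    # An empty item-group list under a positive min_matches means the window
    # was previously rejected by the query, not successfully prepared.
    if len(layer_data.serialized_item_groups) == 0 and min_matches > 0:
        return "rejected"
    return "skipped"


assert classify(LayerData([]), min_matches=1) == "rejected"
assert classify(LayerData([["item-a"]]), min_matches=1) == "skipped"
assert classify(LayerData([]), min_matches=0) == "skipped"
```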
rslearn-0.0.15/rslearn/lightning_cli.py (new file)

```diff
@@ -0,0 +1,67 @@
+"""LightningCLI for rslearn."""
+
+import sys
+
+from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
+
+from rslearn.arg_parser import RslearnArgumentParser
+from rslearn.train.data_module import RslearnDataModule
+from rslearn.train.lightning_module import RslearnLightningModule
+
+
+class RslearnLightningCLI(LightningCLI):
+    """LightningCLI that links data.tasks to model.tasks and supports environment variables."""
+
+    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
+        """Link data.tasks to model.tasks.
+
+        Args:
+            parser: the argument parser
+        """
+        # Link data.tasks to model.tasks
+        parser.link_arguments(
+            "data.init_args.task", "model.init_args.task", apply_on="instantiate"
+        )
+
+    def before_instantiate_classes(self) -> None:
+        """Called before Lightning class initialization.
+
+        Sets the dataset path for any configured RslearnPredictionWriter callbacks.
+        """
+        subcommand = self.config.subcommand
+        c = self.config[subcommand]
+
+        # If there is a RslearnPredictionWriter, set its path.
+        prediction_writer_callback = None
+        if "callbacks" in c.trainer:
+            for existing_callback in c.trainer.callbacks:
+                if (
+                    existing_callback.class_path
+                    == "rslearn.train.prediction_writer.RslearnWriter"
+                ):
+                    prediction_writer_callback = existing_callback
+        if prediction_writer_callback:
+            prediction_writer_callback.init_args.path = c.data.init_args.path
+
+        # Disable the sampler replacement, since the rslearn data module will set the
+        # sampler as needed.
+        c.trainer.use_distributed_sampler = False
+
+        # For predict, make sure that return_predictions is False.
+        # Otherwise all the predictions would be stored in memory which can lead to
+        # high memory consumption.
+        if subcommand == "predict":
+            c.return_predictions = False
+
+
+def model_handler() -> None:
+    """Handler for any rslearn model X commands."""
+    RslearnLightningCLI(
+        model_class=RslearnLightningModule,
+        datamodule_class=RslearnDataModule,
+        args=sys.argv[2:],
+        subclass_mode_model=True,
+        subclass_mode_data=True,
+        save_config_kwargs={"overwrite": True},
+        parser_class=RslearnArgumentParser,
+    )
```
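Because `model_handler` hands `sys.argv[2:]` to `LightningCLI`, the leading `rslearn model` tokens are dropped and the Lightning subcommand (`fit`, `validate`, `test`, or `predict`) becomes the first argument the parser sees. A small illustration of that slicing; the config filename is hypothetical:

```python
# Illustration of the argv slicing in model_handler(); "train.yaml" is a
# hypothetical config path, not a file shipped with rslearn.
import sys

sys.argv = ["rslearn", "model", "fit", "--config", "train.yaml"]
lightning_args = sys.argv[2:]
# LightningCLI receives ["fit", "--config", "train.yaml"], so "fit" is parsed
# as the subcommand that before_instantiate_classes() later reads back from
# self.config.subcommand.
assert lightning_args == ["fit", "--config", "train.yaml"]
```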
{rslearn-0.0.13 → rslearn-0.0.15}/rslearn/main.py

```diff
@@ -10,11 +10,9 @@ from datetime import UTC, datetime, timedelta
 from typing import Any, TypeVar
 
 import tqdm
-from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
 from rasterio.crs import CRS
 from upath import UPath
 
-from rslearn.arg_parser import RslearnArgumentParser
 from rslearn.config import LayerConfig
 from rslearn.const import WGS84_EPSG
 from rslearn.data_sources import Item, data_source_from_config
@@ -38,8 +36,6 @@ from rslearn.dataset.manage import (
 )
 from rslearn.log_utils import get_logger
 from rslearn.tile_stores import get_tile_store_with_layer
-from rslearn.train.data_module import RslearnDataModule
-from rslearn.train.lightning_module import RslearnLightningModule
 from rslearn.utils import Projection, STGeometry
 
 logger = get_logger(__name__)
@@ -831,85 +827,35 @@ def dataset_build_index() -> None:
     index.save_index(ds_path)
 
 
-class RslearnLightningCLI(LightningCLI):
-    """LightningCLI that links data.tasks to model.tasks and supports environment variables."""
-
-    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
-        """Link data.tasks to model.tasks.
-
-        Args:
-            parser: the argument parser
-        """
-        # Link data.tasks to model.tasks
-        parser.link_arguments(
-            "data.init_args.task", "model.init_args.task", apply_on="instantiate"
-        )
-
-    def before_instantiate_classes(self) -> None:
-        """Called before Lightning class initialization.
-
-        Sets the dataset path for any configured RslearnPredictionWriter callbacks.
-        """
-        subcommand = self.config.subcommand
-        c = self.config[subcommand]
-
-        # If there is a RslearnPredictionWriter, set its path.
-        prediction_writer_callback = None
-        if "callbacks" in c.trainer:
-            for existing_callback in c.trainer.callbacks:
-                if (
-                    existing_callback.class_path
-                    == "rslearn.train.prediction_writer.RslearnWriter"
-                ):
-                    prediction_writer_callback = existing_callback
-        if prediction_writer_callback:
-            prediction_writer_callback.init_args.path = c.data.init_args.path
-
-        # Disable the sampler replacement, since the rslearn data module will set the
-        # sampler as needed.
-        c.trainer.use_distributed_sampler = False
-
-        # For predict, make sure that return_predictions is False.
-        # Otherwise all the predictions would be stored in memory which can lead to
-        # high memory consumption.
-        if subcommand == "predict":
-            c.return_predictions = False
-
-
-def model_handler() -> None:
-    """Handler for any rslearn model X commands."""
-    RslearnLightningCLI(
-        model_class=RslearnLightningModule,
-        datamodule_class=RslearnDataModule,
-        args=sys.argv[2:],
-        subclass_mode_model=True,
-        subclass_mode_data=True,
-        save_config_kwargs={"overwrite": True},
-        parser_class=RslearnArgumentParser,
-    )
-
-
 @register_handler("model", "fit")
 def model_fit() -> None:
     """Handler for rslearn model fit."""
+    from .lightning_cli import model_handler
+
     model_handler()
 
 
 @register_handler("model", "validate")
 def model_validate() -> None:
     """Handler for rslearn model validate."""
+    from .lightning_cli import model_handler
+
     model_handler()
 
 
 @register_handler("model", "test")
 def model_test() -> None:
     """Handler for rslearn model test."""
+    from .lightning_cli import model_handler
+
     model_handler()
 
 
 @register_handler("model", "predict")
 def model_predict() -> None:
     """Handler for rslearn model predict."""
+    from .lightning_cli import model_handler
+
     model_handler()
 
 
```
{rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/clay/clay.py

```diff
@@ -8,6 +8,7 @@ from importlib.resources import files
 from typing import Any
 
 import torch
+import torch.nn.functional as F
 import yaml
 from einops import rearrange
 from huggingface_hub import hf_hub_download
@@ -30,6 +31,7 @@ PATCH_SIZE = 8
 CLAY_MODALITIES = ["sentinel-2-l2a", "sentinel-1-rtc", "landsat-c2l1", "naip"]
 CONFIG_DIR = files("rslearn.models.clay.configs")
 CLAY_METADATA_PATH = str(CONFIG_DIR / "metadata.yaml")
+DEFAULT_IMAGE_RESOLUTION = 128  # image resolution during pretraining
 
 
 def get_clay_checkpoint_path(
@@ -49,6 +51,7 @@ class Clay(torch.nn.Module):
         modality: str = "sentinel-2-l2a",
         checkpoint_path: str | None = None,
         metadata_path: str = CLAY_METADATA_PATH,
+        do_resizing: bool = False,
     ) -> None:
         """Initialize the Clay model.
 
@@ -57,6 +60,7 @@ class Clay(torch.nn.Module):
             modality: The modality to use (subset of CLAY_MODALITIES).
             checkpoint_path: Path to clay-v1.5.ckpt, if None, fetch from HF Hub.
             metadata_path: Path to metadata.yaml.
+            do_resizing: Whether to resize the image to the input resolution.
         """
         super().__init__()
 
@@ -95,6 +99,14 @@ class Clay(torch.nn.Module):
 
         self.model_size = model_size
         self.modality = modality
+        self.do_resizing = do_resizing
+
+    def _resize_image(self, image: torch.Tensor, original_hw: int) -> torch.Tensor:
+        """Resize the image to the input resolution."""
+        new_hw = self.patch_size if original_hw == 1 else DEFAULT_IMAGE_RESOLUTION
+        return F.interpolate(
+            image, size=(new_hw, new_hw), mode="bilinear", align_corners=False
+        )
 
     def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
         """Forward pass for the Clay model.
@@ -114,7 +126,8 @@ class Clay(torch.nn.Module):
         chips = torch.stack(
             [inp[self.modality] for inp in inputs], dim=0
         )  # (B, C, H, W)
-
+        if self.do_resizing:
+            chips = self._resize_image(chips, chips.shape[2])
         order = self.metadata[self.modality]["band_order"]
         wavelengths = []
         for band in self.metadata[self.modality]["band_order"]:
```
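The resizing rule added to `Clay` has two cases: a 1×1 (single-pixel) chip is upsampled to one patch per side, and anything else is resized to the 128×128 pretraining resolution. A standalone sketch of the rule, using the module's `PATCH_SIZE = 8` and `DEFAULT_IMAGE_RESOLUTION = 128` from the diff:

```python
# Standalone version of Clay._resize_image; PATCH_SIZE and
# DEFAULT_IMAGE_RESOLUTION match the constants in rslearn/models/clay/clay.py.
import torch
import torch.nn.functional as F

PATCH_SIZE = 8
DEFAULT_IMAGE_RESOLUTION = 128


def resize_image(image: torch.Tensor, patch_size: int = PATCH_SIZE) -> torch.Tensor:
    original_hw = image.shape[2]
    # Single-pixel inputs become one patch; everything else goes to 128x128.
    new_hw = patch_size if original_hw == 1 else DEFAULT_IMAGE_RESOLUTION
    return F.interpolate(
        image, size=(new_hw, new_hw), mode="bilinear", align_corners=False
    )


assert resize_image(torch.zeros(2, 4, 1, 1)).shape == (2, 4, 8, 8)
assert resize_image(torch.zeros(2, 4, 64, 64)).shape == (2, 4, 128, 128)
```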
{rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/croma.py

```diff
@@ -7,6 +7,7 @@ from enum import Enum
 from typing import Any
 
 import torch
+import torch.nn.functional as F
 from einops import rearrange
 from upath import UPath
 
@@ -99,6 +100,7 @@ class Croma(torch.nn.Module):
         modality: CromaModality,
         pretrained_path: str | None = None,
         image_resolution: int = DEFAULT_IMAGE_RESOLUTION,
+        do_resizing: bool = False,
     ) -> None:
         """Instantiate a new Croma instance.
 
@@ -107,12 +109,21 @@ class Croma(torch.nn.Module):
             modality: the modalities to configure the model to accept.
             pretrained_path: the local path to the pretrained weights. Otherwise it is
                 downloaded and cached in temp directory.
-            image_resolution: the width and height of the input images.
+            image_resolution: the width and height of the input images passed to the model. if do_resizing is True, the image will be resized to this resolution.
+            do_resizing: Whether to resize the image to the input resolution.
         """
         super().__init__()
         self.size = size
         self.modality = modality
-        self.
+        self.do_resizing = do_resizing
+        if not do_resizing:
+            self.image_resolution = image_resolution
+        else:
+            # With single pixel input, we always resample to the patch size.
+            if image_resolution == 1:
+                self.image_resolution = PATCH_SIZE
+            else:
+                self.image_resolution = DEFAULT_IMAGE_RESOLUTION
 
         # Cache the CROMA weights to a deterministic path in temporary directory if the
         # path is not provided by the user.
@@ -137,7 +148,16 @@ class Croma(torch.nn.Module):
             pretrained_path=pretrained_path,
             size=size.value,
             modality=modality.value,
-            image_resolution=image_resolution,
+            image_resolution=self.image_resolution,
+        )
+
+    def _resize_image(self, image: torch.Tensor) -> torch.Tensor:
+        """Resize the image to the input resolution."""
+        return F.interpolate(
+            image,
+            size=(self.image_resolution, self.image_resolution),
+            mode="bilinear",
+            align_corners=False,
         )
 
     def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
@@ -151,8 +171,11 @@ class Croma(torch.nn.Module):
         sentinel2: torch.Tensor | None = None
         if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL1]:
             sentinel1 = torch.stack([inp["sentinel1"] for inp in inputs], dim=0)
+            sentinel1 = self._resize_image(sentinel1) if self.do_resizing else sentinel1
         if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL2]:
             sentinel2 = torch.stack([inp["sentinel2"] for inp in inputs], dim=0)
+            sentinel2 = self._resize_image(sentinel2) if self.do_resizing else sentinel2
+
         outputs = self.model(
             SAR_images=sentinel1,
             optical_images=sentinel2,
```
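For CROMA, the effective input resolution is fixed at construction time when `do_resizing` is enabled: single-pixel inputs resample to one patch, everything else to the default pretraining resolution. A sketch of that selection; the constant values below are placeholders, not necessarily those defined in `rslearn/models/croma.py`:

```python
# Mirrors the Croma constructor logic above; these constant values are
# placeholder assumptions, not copied from rslearn/models/croma.py.
PATCH_SIZE = 8
DEFAULT_IMAGE_RESOLUTION = 120


def effective_resolution(image_resolution: int, do_resizing: bool) -> int:
    if not do_resizing:
        return image_resolution
    # With single-pixel input, always resample to the patch size.
    return PATCH_SIZE if image_resolution == 1 else DEFAULT_IMAGE_RESOLUTION


assert effective_resolution(96, do_resizing=False) == 96
assert effective_resolution(1, do_resizing=True) == PATCH_SIZE
assert effective_resolution(96, do_resizing=True) == DEFAULT_IMAGE_RESOLUTION
```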
{rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/satlaspretrain.py

```diff
@@ -4,15 +4,14 @@ from typing import Any
 
 import satlaspretrain_models
 import torch
+import torch.nn.functional as F
 
 
 class SatlasPretrain(torch.nn.Module):
     """SatlasPretrain backbones."""
 
     def __init__(
-        self,
-        model_identifier: str,
-        fpn: bool = False,
+        self, model_identifier: str, fpn: bool = False, resize_to_pretrain: bool = False
     ) -> None:
         """Instantiate a new SatlasPretrain instance.
 
@@ -21,6 +20,8 @@ class SatlasPretrain(torch.nn.Module):
             https://github.com/allenai/satlaspretrain_models
             fpn: whether to include the feature pyramid network, otherwise only the
                 Swin-v2-Transformer is used.
+            resize_to_pretrain: whether to resize inputs to the pretraining input
+                size (512 x 512)
         """
         super().__init__()
         weights_manager = satlaspretrain_models.Weights()
@@ -49,6 +50,19 @@ class SatlasPretrain(torch.nn.Module):
             [16, 1024],
             [32, 2048],
         ]
+        self.resize_to_pretrain = resize_to_pretrain
+
+    def maybe_resize(self, data: torch.Tensor) -> list[torch.Tensor]:
+        """Resize to pretraining sizes if resize_to_pretrain == True."""
+        if self.resize_to_pretrain:
+            return F.interpolate(
+                data,
+                size=(512, 512),
+                mode="bilinear",
+                align_corners=False,
+            )
+        else:
+            return data
 
     def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
         """Compute feature maps from the SatlasPretrain backbone.
@@ -58,7 +72,7 @@ class SatlasPretrain(torch.nn.Module):
             process.
         """
         images = torch.stack([inp["image"] for inp in inputs], dim=0)
-        return self.model(images)
+        return self.model(self.maybe_resize(images))
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
```
{rslearn-0.0.13 → rslearn-0.0.15}/rslearn/models/terramind.py

```diff
@@ -4,6 +4,7 @@ from enum import Enum
 from typing import Any
 
 import torch
+import torch.nn.functional as F
 from einops import rearrange
 from terratorch.registry import BACKBONE_REGISTRY
 
@@ -18,6 +19,8 @@ class TerramindSize(str, Enum):
     LARGE = "large"
 
 
+# Pretraining image size for Terramind
+IMAGE_SIZE = 224
 # Default patch size for Terramind
 PATCH_SIZE = 16
 
@@ -89,12 +92,14 @@ class Terramind(torch.nn.Module):
         self,
         model_size: TerramindSize,
         modalities: list[str] = ["S2L2A"],
+        do_resizing: bool = False,
     ) -> None:
         """Initialize the Terramind model.
 
         Args:
             model_size: The size of the Terramind model.
             modalities: The modalities to use.
+            do_resizing: Whether to resize the input images to the pretraining resolution.
         """
         super().__init__()
 
@@ -116,6 +121,7 @@ class Terramind(torch.nn.Module):
 
         self.model_size = model_size
         self.modalities = modalities
+        self.do_resizing = do_resizing
 
     def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
         """Forward pass for the Terramind model.
@@ -132,6 +138,19 @@ class Terramind(torch.nn.Module):
             if modality not in inputs[0]:
                 continue
             cur = torch.stack([inp[modality] for inp in inputs], dim=0)  # (B, C, H, W)
+            if self.do_resizing and (
+                cur.shape[2] != IMAGE_SIZE or cur.shape[3] != IMAGE_SIZE
+            ):
+                if cur.shape[2] == 1 and cur.shape[3] == 1:
+                    new_height, new_width = PATCH_SIZE, PATCH_SIZE
+                else:
+                    new_height, new_width = IMAGE_SIZE, IMAGE_SIZE
+                cur = F.interpolate(
+                    cur,
+                    size=(new_height, new_width),
+                    mode="bilinear",
+                    align_corners=False,
+                )
             model_inputs[modality] = cur
 
         # By default, the patch embeddings are averaged over all modalities to reduce output tokens
```
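Terramind's version of the change only resizes when a chip is not already at the 224×224 pretraining size, again special-casing single-pixel inputs to one 16×16 patch. A standalone sketch of the guard, using the `IMAGE_SIZE` and `PATCH_SIZE` constants from the diff:

```python
# Standalone version of the resize guard in Terramind.forward; IMAGE_SIZE and
# PATCH_SIZE match the constants added in rslearn/models/terramind.py.
import torch
import torch.nn.functional as F

IMAGE_SIZE = 224
PATCH_SIZE = 16


def maybe_resize(cur: torch.Tensor, do_resizing: bool) -> torch.Tensor:
    # Only resize when enabled and the chip is not already at IMAGE_SIZE.
    if do_resizing and (cur.shape[2] != IMAGE_SIZE or cur.shape[3] != IMAGE_SIZE):
        if cur.shape[2] == 1 and cur.shape[3] == 1:
            size = (PATCH_SIZE, PATCH_SIZE)  # single pixel -> one patch
        else:
            size = (IMAGE_SIZE, IMAGE_SIZE)  # anything else -> pretraining size
        cur = F.interpolate(cur, size=size, mode="bilinear", align_corners=False)
    return cur


assert maybe_resize(torch.zeros(1, 12, 1, 1), True).shape[-2:] == (16, 16)
assert maybe_resize(torch.zeros(1, 12, 224, 224), True).shape[-2:] == (224, 224)
assert maybe_resize(torch.zeros(1, 12, 64, 64), False).shape[-2:] == (64, 64)
```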
|