rslearn 0.0.18__tar.gz → 0.0.19__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rslearn-0.0.18/rslearn.egg-info → rslearn-0.0.19}/PKG-INFO +1 -1
- {rslearn-0.0.18 → rslearn-0.0.19}/pyproject.toml +1 -1
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/arg_parser.py +2 -9
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/config/dataset.py +15 -16
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/dataset/dataset.py +28 -22
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/lightning_cli.py +22 -11
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/main.py +1 -1
- rslearn-0.0.19/rslearn/models/attention_pooling.py +177 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/component.py +12 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/olmoearth_pretrain/model.py +125 -34
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/simple_time_series.py +7 -1
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/all_patches_dataset.py +67 -19
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/dataset.py +36 -43
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/scheduler.py +15 -0
- rslearn-0.0.19/rslearn/train/transforms/resize.py +74 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/geometry.py +73 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/jsonargparse.py +66 -0
- {rslearn-0.0.18 → rslearn-0.0.19/rslearn.egg-info}/PKG-INFO +1 -1
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn.egg-info/SOURCES.txt +2 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/LICENSE +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/NOTICE +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/README.md +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/__init__.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/config/__init__.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/const.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/__init__.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/aws_landsat.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/aws_open_data.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/aws_sentinel1.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/climate_data_store.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/copernicus.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/data_source.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/earthdaily.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/earthdata_srtm.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/eurocrops.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/gcp_public_data.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/google_earth_engine.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/local_files.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/openstreetmap.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/planet.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/planet_basemap.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/planetary_computer.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/usda_cdl.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/usgs_landsat.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/utils.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/vector_source.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/worldcereal.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/worldcover.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/worldpop.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/data_sources/xyz_tiles.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/dataset/__init__.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/dataset/add_windows.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/dataset/handler_summaries.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/dataset/manage.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/dataset/materialize.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/dataset/remap.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/dataset/storage/__init__.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/dataset/storage/file.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/dataset/storage/storage.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/dataset/window.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/log_utils.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/__init__.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/anysat.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/clay/clay.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/clay/configs/metadata.yaml +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/clip.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/concatenate_features.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/conv.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/croma.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/detr/__init__.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/detr/box_ops.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/detr/detr.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/detr/matcher.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/detr/position_encoding.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/detr/transformer.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/detr/util.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/dinov3.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/faster_rcnn.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/feature_center_crop.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/fpn.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/galileo/__init__.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/galileo/galileo.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/galileo/single_file_galileo.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/module_wrapper.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/molmo.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/multitask.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/olmoearth_pretrain/__init__.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/olmoearth_pretrain/norm.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/panopticon.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/panopticon_data/sensors/drone.yaml +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/panopticon_data/sensors/enmap.yaml +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/panopticon_data/sensors/goes.yaml +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/panopticon_data/sensors/himawari.yaml +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/panopticon_data/sensors/intuition.yaml +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/panopticon_data/sensors/landsat8.yaml +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/panopticon_data/sensors/modis_terra.yaml +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/panopticon_data/sensors/sentinel1.yaml +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/panopticon_data/sensors/sentinel2.yaml +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/panopticon_data/sensors/superdove.yaml +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/panopticon_data/sensors/wv23.yaml +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/pick_features.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/pooling_decoder.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/presto/__init__.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/presto/presto.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/presto/single_file_presto.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/prithvi.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/resize_features.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/sam2_enc.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/satlaspretrain.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/singletask.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/ssl4eo_s12.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/swin.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/task_embedding.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/terramind.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/trunk.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/unet.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/upsample.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/use_croma.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/py.typed +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/template_params.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/tile_stores/__init__.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/tile_stores/default.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/tile_stores/tile_store.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/__init__.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/callbacks/__init__.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/callbacks/adapters.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/callbacks/freeze_unfreeze.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/callbacks/gradients.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/callbacks/peft.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/data_module.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/lightning_module.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/model_context.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/optimizer.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/prediction_writer.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/tasks/__init__.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/tasks/classification.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/tasks/detection.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/tasks/embedding.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/tasks/multi_task.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/tasks/per_pixel_regression.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/tasks/regression.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/tasks/segmentation.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/tasks/task.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/transforms/__init__.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/transforms/concatenate.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/transforms/crop.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/transforms/flip.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/transforms/mask.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/transforms/normalize.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/transforms/pad.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/transforms/select_bands.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/transforms/sentinel1.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/train/transforms/transform.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/__init__.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/array.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/feature.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/fsspec.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/get_utm_ups_crs.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/grid_index.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/mp.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/raster_format.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/rtree_index.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/spatial_index.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/sqlite_index.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/time.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn/utils/vector_format.py +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn.egg-info/dependency_links.txt +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn.egg-info/entry_points.txt +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn.egg-info/requires.txt +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/rslearn.egg-info/top_level.txt +0 -0
- {rslearn-0.0.18 → rslearn-0.0.19}/setup.cfg +0 -0

{rslearn-0.0.18 → rslearn-0.0.19}/rslearn/arg_parser.py

@@ -1,6 +1,5 @@
 """Custom Lightning ArgumentParser with environment variable substitution support."""
 
-import os
 from typing import Any
 
 from jsonargparse import Namespace
@@ -21,11 +20,7 @@ class RslearnArgumentParser(LightningArgumentParser):
     def parse_string(
         self,
         cfg_str: str,
-        cfg_path: str = "",
-        ext_vars: dict | None = None,
-        env: bool | None = None,
-        defaults: bool = True,
-        with_meta: bool | None = None,
+        *args: Any,
         **kwargs: Any,
     ) -> Namespace:
         """Pre-processes string for environment variable substitution before parsing."""
@@ -33,6 +28,4 @@ class RslearnArgumentParser(LightningArgumentParser):
         substituted_cfg_str = substitute_env_vars_in_string(cfg_str)
 
         # Call the parent method with the substituted config
-        return super().parse_string(
-            substituted_cfg_str, cfg_path, ext_vars, env, defaults, with_meta, **kwargs
-        )
+        return super().parse_string(substituted_cfg_str, *args, **kwargs)
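
The rewritten override above only pre-processes the config string and forwards everything else to the parent parser untouched, so it no longer has to mirror the parent's full parameter list. Below is a minimal, standalone sketch of that forwarding pattern; it is not rslearn code, and it uses os.path.expandvars as a stand-in because the placeholder syntax handled by substitute_env_vars_in_string is not shown in this diff.

import os

from jsonargparse import ArgumentParser


class EnvSubstitutingParser(ArgumentParser):
    """Hypothetical parser: substitute environment variables, then delegate."""

    def parse_string(self, cfg_str: str, *args, **kwargs):
        # Forwarding *args/**kwargs keeps this override compatible with whatever
        # signature the parent parse_string exposes.
        return super().parse_string(os.path.expandvars(cfg_str), *args, **kwargs)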

{rslearn-0.0.18 → rslearn-0.0.19}/rslearn/config/dataset.py

@@ -25,7 +25,7 @@ from rasterio.enums import Resampling
 from upath import UPath
 
 from rslearn.log_utils import get_logger
-from rslearn.utils import PixelBounds, Projection
+from rslearn.utils.geometry import PixelBounds, Projection, ResolutionFactor
 from rslearn.utils.raster_format import RasterFormat
 from rslearn.utils.vector_format import VectorFormat
 
@@ -215,22 +215,12 @@ class BandSetConfig(BaseModel):
         Returns:
             tuple of updated projection and bounds with zoom offset applied
         """
-        if self.zoom_offset
-
-            projection = Projection(
-                projection.crs,
-                projection.x_resolution / (2**self.zoom_offset),
-                projection.y_resolution / (2**self.zoom_offset),
-            )
-            if self.zoom_offset > 0:
-                zoom_factor = 2**self.zoom_offset
-                bounds = tuple(x * zoom_factor for x in bounds) # type: ignore
+        if self.zoom_offset >= 0:
+            factor = ResolutionFactor(numerator=2**self.zoom_offset)
         else:
-
-
-
-            )
-        return projection, bounds
+            factor = ResolutionFactor(denominator=2 ** (-self.zoom_offset))
+
+        return (factor.multiply_projection(projection), factor.multiply_bounds(bounds))
 
     @field_validator("format", mode="before")
     @classmethod
@@ -645,3 +635,12 @@ class DatasetConfig(BaseModel):
         default_factory=lambda: StorageConfig(),
         description="jsonargparse configuration for the WindowStorageFactory.",
     )
+
+    @field_validator("layers", mode="after")
+    @classmethod
+    def layer_names_validator(cls, v: dict[str, LayerConfig]) -> dict[str, LayerConfig]:
+        """Ensure layer names don't contain periods, since we use periods to distinguish different materialized groups within a layer."""
+        for layer_name in v.keys():
+            if "." in layer_name:
+                raise ValueError(f"layer names must not contain periods: {layer_name}")
+        return v
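
The zoom-offset rewrite above delegates the powers-of-two arithmetic to a ResolutionFactor helper in rslearn/utils/geometry.py (73 lines added in this release, not shown in this diff). The snippet below is a hypothetical stand-in rather than the real class; it only illustrates the arithmetic implied by the call sites above, treating the factor as a rational numerator/denominator that scales pixel bounds up and per-pixel resolution down.

from dataclasses import dataclass


@dataclass
class RationalFactorSketch:
    """Hypothetical stand-in for ResolutionFactor; the real API may differ."""

    numerator: int = 1
    denominator: int = 1

    def multiply_bounds(self, bounds: tuple[int, int, int, int]) -> tuple[int, ...]:
        # A finer resolution (zoom_offset > 0) means more pixels for the same extent.
        return tuple(x * self.numerator // self.denominator for x in bounds)

    def multiply_resolution(self, resolution: float) -> float:
        # Projection units per pixel shrink by the same factor.
        return resolution * self.denominator / self.numerator


# zoom_offset = 1 maps to numerator=2: the pixel grid doubles and the per-pixel
# resolution halves, matching the removed 2**self.zoom_offset branches.
factor = RationalFactorSketch(numerator=2)
assert factor.multiply_bounds((0, 0, 100, 100)) == (0, 0, 200, 200)
assert factor.multiply_resolution(10.0) == 5.0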

{rslearn-0.0.18 → rslearn-0.0.19}/rslearn/dataset/dataset.py

@@ -23,7 +23,7 @@ class Dataset:
     .. code-block:: none
 
         dataset/
-            config.json
+            config.json # optional, if config provided as runtime object
             windows/
                 group1/
                     epsg:3857_10_623565_1528020/
@@ -40,37 +40,43 @@ class Dataset:
     materialize.
     """
 
-    def __init__(self, path: UPath, disabled_layers: list[str] = []) -> None:
+    def __init__(
+        self,
+        path: UPath,
+        disabled_layers: list[str] = [],
+        dataset_config: DatasetConfig | None = None,
+    ) -> None:
         """Initializes a new Dataset.
 
         Args:
             path: the root directory of the dataset
             disabled_layers: list of layers to disable
+            dataset_config: optional dataset configuration to use instead of loading from the dataset directory
         """
         self.path = path
 
-
-
-
-
-
-
-
-        for layer_name, layer_config in config.layers.items():
-            # Layer names must not contain period, since we use period to
-            # distinguish different materialized groups within a layer.
-            assert "." not in layer_name, "layer names must not contain periods"
-            if layer_name in disabled_layers:
-                logger.warning(f"Layer {layer_name} is disabled")
-                continue
-            self.layers[layer_name] = layer_config
-
-        self.tile_store_config = config.tile_store
-        self.storage = (
-            config.storage.instantiate_window_storage_factory().get_storage(
-                self.path
+        if dataset_config is None:
+            # Load dataset configuration from the dataset directory.
+            with (self.path / "config.json").open("r") as f:
+                config_content = f.read()
+            config_content = substitute_env_vars_in_string(config_content)
+            dataset_config = DatasetConfig.model_validate(
+                json.loads(config_content)
             )
+
+        self.layers = {}
+        for layer_name, layer_config in dataset_config.layers.items():
+            if layer_name in disabled_layers:
+                logger.warning(f"Layer {layer_name} is disabled")
+                continue
+            self.layers[layer_name] = layer_config
+
+        self.tile_store_config = dataset_config.tile_store
+        self.storage = (
+            dataset_config.storage.instantiate_window_storage_factory().get_storage(
+                self.path
            )
+        )
 
     def load_windows(
         self,
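
With the new dataset_config parameter, callers can hand Dataset an already-parsed configuration instead of letting the constructor read dataset/config.json itself. A minimal usage sketch follows, assuming rslearn 0.0.19 is installed; the path is hypothetical, and note that loading the file manually like this skips the environment-variable substitution the constructor applies when it reads config.json on its own.

import json

from upath import UPath

from rslearn.config.dataset import DatasetConfig
from rslearn.dataset.dataset import Dataset

ds_path = UPath("/data/my_dataset")  # hypothetical dataset root containing config.json

# Parse the config ourselves, mirroring what Dataset.__init__ now does internally.
with (ds_path / "config.json").open("r") as f:
    config = DatasetConfig.model_validate(json.loads(f.read()))

# Tweak the in-memory config here if needed, then skip the constructor's own read.
dataset = Dataset(ds_path, dataset_config=config)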

{rslearn-0.0.18 → rslearn-0.0.19}/rslearn/lightning_cli.py

@@ -21,6 +21,7 @@ from rslearn.log_utils import get_logger
 from rslearn.train.data_module import RslearnDataModule
 from rslearn.train.lightning_module import RslearnLightningModule
 from rslearn.utils.fsspec import open_atomic
+from rslearn.utils.jsonargparse import init_jsonargparse
 
 WANDB_ID_FNAME = "wandb_id"
 
@@ -390,8 +391,15 @@ class RslearnLightningCLI(LightningCLI):
 
         Sets the dataset path for any configured RslearnPredictionWriter callbacks.
         """
-        subcommand = self.config.subcommand
-        c = self.config[subcommand]
+        if not hasattr(self.config, "subcommand"):
+            logger.warning(
+                "Config does not have subcommand attribute, assuming we are in run=False mode"
+            )
+            subcommand = None
+            c = self.config
+        else:
+            subcommand = self.config.subcommand
+            c = self.config[subcommand]
 
         # If there is a RslearnPredictionWriter, set its path.
         prediction_writer_callback = None
@@ -415,16 +423,17 @@ class RslearnLightningCLI(LightningCLI):
         if subcommand == "predict":
            c.return_predictions = False
 
-        #
+        # Default to DDP with find_unused_parameters. Likely won't get called with unified config
         if subcommand == "fit":
-            c.trainer.strategy
-
-
-
-
-
-
-
+            if not c.trainer.strategy:
+                c.trainer.strategy = jsonargparse.Namespace(
+                    {
+                        "class_path": "lightning.pytorch.strategies.DDPStrategy",
+                        "init_args": jsonargparse.Namespace(
+                            {"find_unused_parameters": True}
+                        ),
+                    }
+                )
 
         if c.management_dir:
             self.enable_project_management(c.management_dir)
@@ -432,6 +441,8 @@ class RslearnLightningCLI(LightningCLI):
 
 def model_handler() -> None:
     """Handler for any rslearn model X commands."""
+    init_jsonargparse()
+
     RslearnLightningCLI(
         model_class=RslearnLightningModule,
         datamodule_class=RslearnDataModule,
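
For reference, the class_path/init_args namespace injected above is the jsonargparse representation of an explicit trainer strategy entry; when the CLI instantiates the Trainer it resolves to the Lightning strategy below. The default is only applied for fit, and only when trainer.strategy is unset, so a user-provided strategy still takes precedence. A short sketch of the equivalent object (not rslearn code):

from lightning.pytorch.strategies import DDPStrategy

# Equivalent of {"class_path": "lightning.pytorch.strategies.DDPStrategy",
#                "init_args": {"find_unused_parameters": True}}.
strategy = DDPStrategy(find_unused_parameters=True)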

{rslearn-0.0.18 → rslearn-0.0.19}/rslearn/main.py

@@ -380,7 +380,7 @@ def apply_on_windows(
 
 def apply_on_windows_args(f: Callable[..., Any], args: argparse.Namespace) -> None:
     """Call apply_on_windows with arguments passed via command-line interface."""
-    dataset = Dataset(UPath(args.root), args.disabled_layers)
+    dataset = Dataset(UPath(args.root), disabled_layers=args.disabled_layers)
     apply_on_windows(
         f=f,
         dataset=dataset,

rslearn-0.0.19/rslearn/models/attention_pooling.py (new file)

@@ -0,0 +1,177 @@
+"""An attention pooling layer."""
+
+import math
+from typing import Any
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from torch import nn
+
+from rslearn.models.component import (
+    FeatureMaps,
+    IntermediateComponent,
+    TokenFeatureMaps,
+)
+from rslearn.train.model_context import ModelContext
+
+
+class SimpleAttentionPool(IntermediateComponent):
+    """Simple Attention Pooling.
+
+    Given a token feature map of shape BCHWN,
+    learn an attention layer which aggregates over
+    the N dimension.
+
+    This is done simply by learning a mapping D->1 which is the weight
+    which should be assigned to each token during averaging:
+
+    output = sum [feat_token * W(feat_token) for feat_token in feat_tokens]
+    """
+
+    def __init__(self, in_dim: int, hidden_linear: bool = False) -> None:
+        """Initialize the simple attention pooling layer.
+
+        Args:
+            in_dim: the encoding dimension D
+            hidden_linear: whether to apply an additional linear transformation D -> D
+                to the feat tokens. If this is True, a ReLU activation is applied
+                after the first linear transformation.
+        """
+        super().__init__()
+        if hidden_linear:
+            self.hidden_linear = nn.Linear(in_features=in_dim, out_features=in_dim)
+        else:
+            self.hidden_linear = None
+        self.linear = nn.Linear(in_features=in_dim, out_features=1)
+
+    def forward_for_map(self, feat_tokens: torch.Tensor) -> torch.Tensor:
+        """Attention pooling for a single feature map (BCHWN tensor)."""
+        B, D, H, W, N = feat_tokens.shape
+        feat_tokens = rearrange(feat_tokens, "b d h w n -> (b h w) n d")
+        if self.hidden_linear is not None:
+            feat_tokens = torch.nn.functional.relu(self.hidden_linear(feat_tokens))
+        attention_scores = torch.nn.functional.softmax(self.linear(feat_tokens), dim=1)
+        feat_tokens = (attention_scores * feat_tokens).sum(dim=1)
+        return rearrange(feat_tokens, "(b h w) d -> b d h w", b=B, h=H, w=W)
+
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+        """Forward pass for attention pooling linear probe.
+
+        Args:
+            intermediates: the output from the previous component, which must be a TokenFeatureMaps.
+                We pool over the final dimension in the TokenFeatureMaps. If multiple maps
+                are passed, we apply the same linear layers to all of them.
+            context: the model context.
+            feat_tokens (torch.Tensor): Input feature tokens of shape (B, C, H, W, N).
+
+        Returns:
+            torch.Tensor:
+                - output, attentioned pool over the last dimension (B, C, H, W)
+        """
+        if not isinstance(intermediates, TokenFeatureMaps):
+            raise ValueError("input to Attention Pool must be a TokenFeatureMaps")
+
+        features = []
+        for feat in intermediates.feature_maps:
+            features.append(self.forward_for_map(feat))
+        return FeatureMaps(features)
+
+
+class AttentionPool(IntermediateComponent):
+    """Attention Pooling.
+
+    Given a feature map of shape BCHWN,
+    learn an attention layer which aggregates over
+    the N dimension.
+
+    We do this by learning a query token, and applying a standard
+    attention mechanism against this learned query token.
+    """
+
+    def __init__(self, in_dim: int, num_heads: int, linear_on_kv: bool = True) -> None:
+        """Initialize the attention pooling layer.
+
+        Args:
+            in_dim: the encoding dimension D
+            num_heads: the number of heads to use
+            linear_on_kv: Whether to apply a linear layer on the input tokens
+                to create the key and value tokens.
+        """
+        super().__init__()
+        self.query_token: nn.Parameter = nn.Parameter(torch.empty(in_dim))
+        if linear_on_kv:
+            self.k_linear = nn.Linear(in_dim, in_dim)
+            self.v_linear = nn.Linear(in_dim, in_dim)
+        else:
+            self.k_linear = None
+            self.v_linear = None
+        if in_dim % num_heads != 0:
+            raise ValueError(
+                f"in_dim must be divisible by num_heads. Got {in_dim} and {num_heads}."
+            )
+        self.num_heads = num_heads
+        self.init_weights()
+
+    def init_weights(self) -> None:
+        """Initialize weights for the probe."""
+        nn.init.trunc_normal_(self.query_token, std=0.02)
+
+    def forward_for_map(self, feat_tokens: torch.Tensor) -> torch.Tensor:
+        """Attention pooling for a single feature map (BCHWN tensor)."""
+        B, D, H, W, N = feat_tokens.shape
+        feat_tokens = rearrange(feat_tokens, "b d h w n -> (b h w) n d")
+        collapsed_dim = B * H * W
+        q = self.query_token.expand(collapsed_dim, 1, -1)
+        q = q.reshape(
+            collapsed_dim, 1, self.num_heads, D // self.num_heads
+        ) # [B, 1, head, D_head]
+        q = rearrange(q, "b h n d -> b n h d")
+        if self.k_linear is not None:
+            assert self.v_linear is not None
+            k = self.k_linear(feat_tokens).reshape(
+                collapsed_dim, N, self.num_heads, D // self.num_heads
+            )
+            v = self.v_linear(feat_tokens).reshape(
+                collapsed_dim, N, self.num_heads, D // self.num_heads
+            )
+        else:
+            k = feat_tokens.reshape(
+                collapsed_dim, N, self.num_heads, D // self.num_heads
+            )
+            v = feat_tokens.reshape(
+                collapsed_dim, N, self.num_heads, D // self.num_heads
+            )
+        k = rearrange(k, "b n h d -> b h n d")
+        v = rearrange(v, "b n h d -> b h n d")
+
+        # Compute attention scores
+        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(
+            D // self.num_heads
+        )
+        attn_weights = F.softmax(attn_scores, dim=-1)
+        x = torch.matmul(attn_weights, v) # [B, head, 1, D_head]
+        return x.reshape(B, D, H, W)
+
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+        """Forward pass for attention pooling linear probe.
+
+        Args:
+            intermediates: the output from the previous component, which must be a TokenFeatureMaps.
+                We pool over the final dimension in the TokenFeatureMaps. If multiple feature
+                maps are passed, we apply the same attention weights (query token and linear k, v layers)
+                to all the maps.
+            context: the model context.
+            feat_tokens (torch.Tensor): Input feature tokens of shape (B, C, H, W, N).
+
+        Returns:
+            torch.Tensor:
+                - output, attentioned pool over the last dimension (B, C, H, W)
+        """
+        if not isinstance(intermediates, TokenFeatureMaps):
+            raise ValueError("input to Attention Pool must be a TokenFeatureMaps")
+
+        features = []
+        for feat in intermediates.feature_maps:
+            features.append(self.forward_for_map(feat))
+        return FeatureMaps(features)
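
A quick shape-level usage sketch for the two new modules, assuming rslearn 0.0.19 and torch are installed. forward_for_map is called directly here to avoid constructing a ModelContext; in a model config these components would sit after a component that emits TokenFeatureMaps. All tensor sizes are arbitrary examples.

import torch

from rslearn.models.attention_pooling import AttentionPool, SimpleAttentionPool

# B=2 samples, D=128 channels, a 16x16 spatial grid, N=6 unpooled tokens per cell.
feat_tokens = torch.randn(2, 128, 16, 16, 6)

simple_pool = SimpleAttentionPool(in_dim=128, hidden_linear=True)
assert simple_pool.forward_for_map(feat_tokens).shape == (2, 128, 16, 16)

multi_head = AttentionPool(in_dim=128, num_heads=8)
assert multi_head.forward_for_map(feat_tokens).shape == (2, 128, 16, 16)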

{rslearn-0.0.18 → rslearn-0.0.19}/rslearn/models/component.py

@@ -91,6 +91,18 @@ class FeatureMaps:
     feature_maps: list[torch.Tensor]
 
 
+@dataclass
+class TokenFeatureMaps:
+    """An intermediate output type for multi-resolution BCHWN feature maps with a token dimension.
+
+    Unlike `FeatureMaps`, these include an additional dimension for unpooled tokens.
+    """
+
+    # List of BxCxHxWxN feature maps at different scales, ordered from highest resolution
+    # (most fine-grained) to lowest resolution (coarsest).
+    feature_maps: list[torch.Tensor]
+
+
 @dataclass
 class FeatureVector:
     """An intermediate output type for a flat feature vector."""
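
TokenFeatureMaps carries the same multi-scale list as FeatureMaps but with a trailing token dimension N that downstream components are expected to pool away. A small illustration, assuming rslearn 0.0.19 and torch are installed; the mean pooling here is only a placeholder for the learned attention pooling added above, and the sizes are arbitrary.

import torch

from rslearn.models.component import FeatureMaps, TokenFeatureMaps

tokens = TokenFeatureMaps(
    feature_maps=[
        torch.randn(2, 128, 32, 32, 4),  # finest scale, N=4 tokens per cell
        torch.randn(2, 128, 16, 16, 4),  # coarser scale
    ]
)
# Pooling over the trailing N dimension yields the BxCxHxW maps FeatureMaps expects.
pooled = FeatureMaps(feature_maps=[t.mean(dim=-1) for t in tokens.feature_maps])
assert pooled.feature_maps[0].shape == (2, 128, 32, 32)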