rslearn 0.0.8__tar.gz → 0.0.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rslearn-0.0.8/rslearn.egg-info → rslearn-0.0.9}/PKG-INFO +2 -2
- {rslearn-0.0.8 → rslearn-0.0.9}/README.md +1 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/pyproject.toml +1 -2
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/data_sources/local_files.py +20 -3
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/data_sources/planetary_computer.py +79 -14
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/dataset/manage.py +2 -2
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/dataset/materialize.py +21 -2
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/dataset/remap.py +29 -4
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/dinov3.py +12 -11
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/galileo/galileo.py +58 -12
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/galileo/single_file_galileo.py +7 -1
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/presto/presto.py +11 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/prithvi.py +11 -0
- rslearn-0.0.9/rslearn/models/registry.py +22 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/tile_stores/default.py +3 -1
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/transforms/transform.py +23 -6
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/utils/raster_format.py +37 -4
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/utils/vector_format.py +35 -4
- {rslearn-0.0.8 → rslearn-0.0.9/rslearn.egg-info}/PKG-INFO +2 -2
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn.egg-info/requires.txt +0 -1
- rslearn-0.0.8/rslearn/models/registry.py +0 -5
- {rslearn-0.0.8 → rslearn-0.0.9}/LICENSE +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/__init__.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/arg_parser.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/config/__init__.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/config/dataset.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/const.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/data_sources/__init__.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/data_sources/aws_landsat.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/data_sources/aws_open_data.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/data_sources/aws_sentinel1.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/data_sources/climate_data_store.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/data_sources/copernicus.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/data_sources/data_source.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/data_sources/earthdaily.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/data_sources/earthdata_srtm.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/data_sources/eurocrops.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/data_sources/gcp_public_data.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/data_sources/geotiff.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/data_sources/google_earth_engine.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/data_sources/openstreetmap.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/data_sources/planet.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/data_sources/planet_basemap.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/data_sources/raster_source.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/data_sources/usda_cdl.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/data_sources/usgs_landsat.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/data_sources/utils.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/data_sources/vector_source.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/data_sources/worldcereal.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/data_sources/worldcover.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/data_sources/worldpop.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/data_sources/xyz_tiles.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/dataset/__init__.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/dataset/add_windows.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/dataset/dataset.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/dataset/handler_summaries.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/dataset/index.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/dataset/window.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/log_utils.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/main.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/__init__.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/anysat.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/clay/clay.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/clay/configs/metadata.yaml +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/clip.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/conv.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/copernicusfm.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/copernicusfm_src/__init__.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/copernicusfm_src/aurora/area.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/copernicusfm_src/aurora/fourier.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/copernicusfm_src/dynamic_hypernetwork.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/copernicusfm_src/flexivit/patch_embed.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/copernicusfm_src/flexivit/utils.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/copernicusfm_src/model_vit.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/copernicusfm_src/util/pos_embed.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/croma.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/detr/__init__.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/detr/box_ops.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/detr/detr.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/detr/matcher.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/detr/position_encoding.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/detr/transformer.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/detr/util.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/faster_rcnn.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/fpn.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/galileo/__init__.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/module_wrapper.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/molmo.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/multitask.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/panopticon.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/panopticon_data/sensors/drone.yaml +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/panopticon_data/sensors/enmap.yaml +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/panopticon_data/sensors/goes.yaml +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/panopticon_data/sensors/himawari.yaml +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/panopticon_data/sensors/intuition.yaml +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/panopticon_data/sensors/landsat8.yaml +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/panopticon_data/sensors/modis_terra.yaml +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/panopticon_data/sensors/sentinel1.yaml +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/panopticon_data/sensors/sentinel2.yaml +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/panopticon_data/sensors/superdove.yaml +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/panopticon_data/sensors/wv23.yaml +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/pick_features.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/pooling_decoder.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/presto/__init__.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/presto/single_file_presto.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/resize_features.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/sam2_enc.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/satlaspretrain.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/simple_time_series.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/singletask.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/ssl4eo_s12.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/swin.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/task_embedding.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/terramind.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/trunk.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/unet.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/upsample.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/models/use_croma.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/py.typed +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/template_params.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/tile_stores/__init__.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/tile_stores/tile_store.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/__init__.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/callbacks/__init__.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/callbacks/adapters.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/callbacks/freeze_unfreeze.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/callbacks/gradients.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/callbacks/peft.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/data_module.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/dataset.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/lightning_module.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/optimizer.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/prediction_writer.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/scheduler.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/tasks/__init__.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/tasks/classification.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/tasks/detection.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/tasks/multi_task.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/tasks/per_pixel_regression.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/tasks/regression.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/tasks/segmentation.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/tasks/task.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/transforms/__init__.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/transforms/concatenate.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/transforms/crop.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/transforms/flip.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/transforms/mask.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/transforms/normalize.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/transforms/pad.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/transforms/select_bands.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/train/transforms/sentinel1.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/utils/__init__.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/utils/array.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/utils/feature.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/utils/fsspec.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/utils/geometry.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/utils/get_utm_ups_crs.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/utils/grid_index.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/utils/jsonargparse.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/utils/mp.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/utils/rtree_index.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/utils/spatial_index.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/utils/sqlite_index.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn/utils/time.py +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn.egg-info/SOURCES.txt +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn.egg-info/dependency_links.txt +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn.egg-info/entry_points.txt +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/rslearn.egg-info/top_level.txt +0 -0
- {rslearn-0.0.8 → rslearn-0.0.9}/setup.cfg +0 -0
--- rslearn-0.0.8/rslearn.egg-info/PKG-INFO
+++ rslearn-0.0.9/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: rslearn
-Version: 0.0.8
+Version: 0.0.9
 Summary: A library for developing remote sensing datasets and models
 Author: OlmoEarth Team
 License: Apache License
@@ -212,7 +212,6 @@ Requires-Python: >=3.11
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: boto3>=1.39
-Requires-Dist: class_registry>=2.1
 Requires-Dist: fiona>=1.10
 Requires-Dist: fsspec>=2025.9.0
 Requires-Dist: jsonargparse>=4.35.0
@@ -284,6 +283,7 @@ Quick links:
 - [Examples](docs/Examples.md) contains more examples, including customizing different
   stages of rslearn with additional code.
 - [DatasetConfig](docs/DatasetConfig.md) documents the dataset configuration file.
+- [ModelConfig](docs/ModelConfig.md) documents the model configuration file.


 Setup
--- rslearn-0.0.8/README.md
+++ rslearn-0.0.9/README.md
@@ -21,6 +21,7 @@ Quick links:
 - [Examples](docs/Examples.md) contains more examples, including customizing different
   stages of rslearn with additional code.
 - [DatasetConfig](docs/DatasetConfig.md) documents the dataset configuration file.
+- [ModelConfig](docs/ModelConfig.md) documents the model configuration file.


 Setup
--- rslearn-0.0.8/pyproject.toml
+++ rslearn-0.0.9/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "rslearn"
-version = "0.0.8"
+version = "0.0.9"
 description = "A library for developing remote sensing datasets and models"
 authors = [
     { name = "OlmoEarth Team" },
@@ -10,7 +10,6 @@ license = {file = "LICENSE"}
 requires-python = ">=3.11"
 dependencies = [
     "boto3>=1.39",
-    "class_registry>=2.1",
     "fiona>=1.10",
     "fsspec>=2025.9.0", # this is used both directly and indirectly (via universal_pathlib) in our code
     "jsonargparse>=4.35.0",
--- rslearn-0.0.8/rslearn/data_sources/local_files.py
+++ rslearn-0.0.9/rslearn/data_sources/local_files.py
@@ -2,12 +2,12 @@

 import functools
 import json
+from collections.abc import Callable
 from typing import Any, Generic, TypeVar

 import fiona
 import shapely
 import shapely.geometry
-from class_registry import ClassRegistry
 from rasterio.crs import CRS
 from upath import UPath

@@ -23,7 +23,24 @@ from rslearn.utils.geometry import Projection, STGeometry, get_global_geometry
 from .data_source import DataSource, Item, QueryConfig

 logger = get_logger("__name__")
-
+_ImporterT = TypeVar("_ImporterT", bound="Importer")
+
+
+class _ImporterRegistry(dict[str, type["Importer"]]):
+    """Registry for Importer classes."""
+
+    def register(self, name: str) -> Callable[[type[_ImporterT]], type[_ImporterT]]:
+        """Decorator to register an importer class."""
+
+        def decorator(cls: type[_ImporterT]) -> type[_ImporterT]:
+            self[name] = cls
+            return cls
+
+        return decorator
+
+
+Importers = _ImporterRegistry()
+

 ItemType = TypeVar("ItemType", bound=Item)
 LayerConfigType = TypeVar("LayerConfigType", bound=LayerConfig)
@@ -425,7 +442,7 @@ class LocalFiles(DataSource):
         """
         self.config = config

-        self.importer = Importers[config.layer_type.value]
+        self.importer = Importers[config.layer_type.value]()
         self.src_dir = src_dir

     @staticmethod
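The hunks above replace the external `class_registry.ClassRegistry` with a small dict-based registry, and lookups now return the class itself, which is why call sites gain a trailing `()`. Below is a minimal, illustrative sketch of that pattern; the `RasterImporter` class and the `"raster"` key are hypothetical and not taken from the package.

```python
# Illustrative sketch only; names here are hypothetical.
from collections.abc import Callable
from typing import TypeVar

_T = TypeVar("_T")


class _Registry(dict[str, type]):
    """Dict subclass whose register() decorator stores classes under a string key."""

    def register(self, name: str) -> Callable[[type[_T]], type[_T]]:
        def decorator(cls: type[_T]) -> type[_T]:
            self[name] = cls
            return cls

        return decorator


Importers = _Registry()


@Importers.register("raster")
class RasterImporter:
    def describe(self) -> str:
        return "imports rasters"


# Lookup returns the class, so callers instantiate it explicitly, matching the
# change from Importers[config.layer_type.value] to Importers[...]().
importer = Importers["raster"]()
print(importer.describe())
```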
--- rslearn-0.0.8/rslearn/data_sources/planetary_computer.py
+++ rslearn-0.0.9/rslearn/data_sources/planetary_computer.py
@@ -83,6 +83,10 @@ class PlanetaryComputer(DataSource, TileStore):

     STAC_ENDPOINT = "https://planetarycomputer.microsoft.com/api/stac/v1"

+    # Default threshold for recreating the STAC client to prevent memory leaks
+    # from the pystac Catalog's resolved objects cache growing unbounded
+    DEFAULT_MAX_ITEMS_PER_CLIENT = 1000
+
     def __init__(
         self,
         collection_name: str,
@@ -93,6 +97,7 @@ class PlanetaryComputer(DataSource, TileStore):
         timeout: timedelta = timedelta(seconds=10),
         skip_items_missing_assets: bool = False,
         cache_dir: UPath | None = None,
+        max_items_per_client: int | None = None,
     ):
         """Initialize a new PlanetaryComputer instance.

@@ -109,6 +114,9 @@ class PlanetaryComputer(DataSource, TileStore):
             cache_dir: optional directory to cache items by name, including asset URLs.
                 If not set, there will be no cache and instead STAC requests will be
                 needed each time.
+            max_items_per_client: number of STAC items to process before recreating
+                the client to prevent memory leaks from the resolved objects cache.
+                Defaults to DEFAULT_MAX_ITEMS_PER_CLIENT.
         """
         self.collection_name = collection_name
         self.asset_bands = asset_bands
@@ -118,12 +126,15 @@ class PlanetaryComputer(DataSource, TileStore):
         self.timeout = timeout
         self.skip_items_missing_assets = skip_items_missing_assets
         self.cache_dir = cache_dir
+        self.max_items_per_client = (
+            max_items_per_client or self.DEFAULT_MAX_ITEMS_PER_CLIENT
+        )

         if self.cache_dir is not None:
             self.cache_dir.mkdir(parents=True, exist_ok=True)

         self.client: pystac_client.Client | None = None
-        self.
+        self._client_item_count = 0

     @staticmethod
     def from_config(config: RasterLayerConfig, ds_path: UPath) -> "PlanetaryComputer":
@@ -142,7 +153,12 @@ class PlanetaryComputer(DataSource, TileStore):
         if "cache_dir" in d:
             kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])

-        simple_optionals = [
+        simple_optionals = [
+            "query",
+            "sort_by",
+            "sort_ascending",
+            "max_items_per_client",
+        ]
         for k in simple_optionals:
             if k in d:
                 kwargs[k] = d[k]
@@ -151,20 +167,40 @@ class PlanetaryComputer(DataSource, TileStore):

     def _load_client(
         self,
-    ) ->
+    ) -> pystac_client.Client:
         """Lazily load pystac client.

         We don't load it when creating the data source because it takes time and caller
         may not be calling get_items. Additionally, loading it during the get_items
         call enables leveraging the retry loop functionality in
         prepare_dataset_windows.
-        """
-        if self.client is not None:
-            return self.client, self.collection

+        Note: We periodically recreate the client to prevent memory leaks from the
+        pystac Catalog's resolved objects cache, which grows unbounded as STAC items
+        are deserialized and cached. The cache cannot be cleared or disabled.
+        """
+        if self.client is None:
+            logger.info("Creating initial STAC client")
+            self.client = pystac_client.Client.open(self.STAC_ENDPOINT)
+            return self.client
+
+        if self._client_item_count < self.max_items_per_client:
+            return self.client
+
+        # Recreate client to clear the resolved objects cache
+        current_client = self.client
+        logger.debug(
+            "Recreating STAC client after processing %d items (threshold: %d)",
+            self._client_item_count,
+            self.max_items_per_client,
+        )
+        client_root = current_client.get_root()
+        client_root.clear_links()
+        client_root.clear_items()
+        client_root.clear_children()
+        self._client_item_count = 0
         self.client = pystac_client.Client.open(self.STAC_ENDPOINT)
-
-        return self.client, self.collection
+        return self.client

     def _stac_item_to_item(self, stac_item: pystac.Item) -> PlanetaryComputerItem:
         shp = shapely.geometry.shape(stac_item.geometry)
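The `_load_client` change above works around unbounded growth of pystac's resolved-object cache by counting processed items and reopening the client once a threshold is passed. The following stripped-down sketch shows just that rotation logic; it assumes `pystac_client` is installed, uses the endpoint from the diff, and is not the rslearn implementation.

```python
# Minimal sketch of the client-rotation pattern; illustrative only.
import pystac_client

STAC_ENDPOINT = "https://planetarycomputer.microsoft.com/api/stac/v1"
MAX_ITEMS_PER_CLIENT = 1000

_client: pystac_client.Client | None = None
_item_count = 0


def load_client() -> pystac_client.Client:
    """Return a cached client, reopening it once enough items have passed through."""
    global _client, _item_count
    if _client is None or _item_count >= MAX_ITEMS_PER_CLIENT:
        # The real code additionally clears links/items/children on the old Catalog
        # root before dropping it, so its resolved objects can be garbage collected.
        _client = pystac_client.Client.open(STAC_ENDPOINT)
        _item_count = 0
    return _client


def count_items(n: int) -> None:
    """Callers bump the counter after deserializing n STAC items."""
    global _item_count
    _item_count += n
```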
@@ -210,10 +246,26 @@ class PlanetaryComputer(DataSource, TileStore):

         # No cache or not in cache, so we need to make the STAC request.
         logger.debug("Getting STAC item {name}")
-
-
+        client = self._load_client()
+
+        search_result = client.search(ids=[name], collections=[self.collection_name])
+        stac_items = list(search_result.items())
+
+        if not stac_items:
+            raise ValueError(
+                f"Item {name} not found in collection {self.collection_name}"
+            )
+        if len(stac_items) > 1:
+            raise ValueError(
+                f"Multiple items found for ID {name} in collection {self.collection_name}"
+            )
+
+        stac_item = stac_items[0]
         item = self._stac_item_to_item(stac_item)

+        # Track items processed for client recreation threshold (after deserialization)
+        self._client_item_count += 1
+
         # Finally we cache it if cache_dir is set.
         if cache_fname is not None:
             with cache_fname.open("w") as f:
@@ -233,7 +285,7 @@ class PlanetaryComputer(DataSource, TileStore):
         Returns:
             List of groups of items that should be retrieved for each geometry.
         """
-        client
+        client = self._load_client()

         groups = []
         for geometry in geometries:
@@ -247,7 +299,9 @@ class PlanetaryComputer(DataSource, TileStore):
                 datetime=wgs84_geometry.time_range,
                 query=self.query,
             )
-            stac_items = [item for item in result.
+            stac_items = [item for item in result.items()]
+            # Track items processed for client recreation threshold (after deserialization)
+            self._client_item_count += len(stac_items)
             logger.debug("STAC search yielded %d items", len(stac_items))

             if self.skip_items_missing_assets:
@@ -580,7 +634,13 @@ class Sentinel2(PlanetaryComputer):
         if "cache_dir" in d:
             kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])

-        simple_optionals = [
+        simple_optionals = [
+            "harmonize",
+            "query",
+            "sort_by",
+            "sort_ascending",
+            "max_items_per_client",
+        ]
         for k in simple_optionals:
             if k in d:
                 kwargs[k] = d[k]
@@ -756,7 +816,12 @@ class Sentinel1(PlanetaryComputer):
         if "cache_dir" in d:
             kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])

-        simple_optionals = [
+        simple_optionals = [
+            "query",
+            "sort_by",
+            "sort_ascending",
+            "max_items_per_client",
+        ]
         for k in simple_optionals:
             if k in d:
                 kwargs[k] = d[k]
--- rslearn-0.0.8/rslearn/dataset/manage.py
+++ rslearn-0.0.9/rslearn/dataset/manage.py
@@ -396,9 +396,9 @@ def materialize_window(
     )

     if dataset.materializer_name:
-        materializer = Materializers[dataset.materializer_name]
+        materializer = Materializers[dataset.materializer_name]()
     else:
-        materializer = Materializers[layer_cfg.layer_type.value]
+        materializer = Materializers[layer_cfg.layer_type.value]()

     retry(
         fn=lambda: materializer.materialize(
--- rslearn-0.0.8/rslearn/dataset/materialize.py
+++ rslearn-0.0.9/rslearn/dataset/materialize.py
@@ -1,10 +1,10 @@
 """Classes to implement dataset materialization."""

+from collections.abc import Callable
 from typing import Any, Generic, TypeVar

 import numpy as np
 import numpy.typing as npt
-from class_registry import ClassRegistry
 from rasterio.enums import Resampling

 from rslearn.config import (
@@ -25,7 +25,26 @@ from rslearn.utils.vector_format import load_vector_format
 from .remap import Remapper, load_remapper
 from .window import Window

-
+_MaterializerT = TypeVar("_MaterializerT", bound="Materializer")
+
+
+class _MaterializerRegistry(dict[str, type["Materializer"]]):
+    """Registry for Materializer classes."""
+
+    def register(
+        self, name: str
+    ) -> Callable[[type[_MaterializerT]], type[_MaterializerT]]:
+        """Decorator to register a materializer class."""
+
+        def decorator(cls: type[_MaterializerT]) -> type[_MaterializerT]:
+            self[name] = cls
+            return cls
+
+        return decorator
+
+
+Materializers = _MaterializerRegistry()
+

 LayerConfigType = TypeVar("LayerConfigType", bound=LayerConfig)

--- rslearn-0.0.8/rslearn/dataset/remap.py
+++ rslearn-0.0.9/rslearn/dataset/remap.py
@@ -1,18 +1,42 @@
 """Classes to remap raster values."""

-from
+from collections.abc import Callable
+from typing import Any, TypeVar

 import numpy as np
 import numpy.typing as npt
-from class_registry import ClassRegistry

-
+_RemapperT = TypeVar("_RemapperT", bound="Remapper")
+
+
+class _RemapperRegistry(dict[str, type["Remapper"]]):
+    """Registry for Remapper classes."""
+
+    def register(self, name: str) -> Callable[[type[_RemapperT]], type[_RemapperT]]:
+        """Decorator to register a remapper class."""
+
+        def decorator(cls: type[_RemapperT]) -> type[_RemapperT]:
+            self[name] = cls
+            return cls
+
+        return decorator
+
+
+Remappers = _RemapperRegistry()
 """Registry of Remapper implementations."""


 class Remapper:
     """An abstract class that remaps pixel values based on layer configuration."""

+    def __init__(self, config: dict[str, Any]) -> None:
+        """Initialize a Remapper.
+
+        Args:
+            config: the config dict for this remapper.
+        """
+        pass
+
     def __call__(
         self, array: npt.NDArray[Any], dtype: npt.DTypeLike
     ) -> npt.NDArray[Any]:
@@ -67,4 +91,5 @@ class LinearRemapper(Remapper):

 def load_remapper(config: dict[str, Any]) -> Remapper:
     """Load a remapper from a configuration dictionary."""
-
+    cls = Remappers[config["name"]]
+    return cls(config)
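With the registry above, a `Remapper` subclass now receives its config dict in `__init__`, and `load_remapper` instantiates whatever class is registered under `config["name"]`. The example below is hypothetical and assumes rslearn 0.0.9 is installed; only `Remapper`, `Remappers`, and `load_remapper` come from the package, while the `"identity"` key and class are made up for illustration.

```python
# Hypothetical remapper registration; illustrative only.
from typing import Any

import numpy.typing as npt

from rslearn.dataset.remap import Remapper, Remappers, load_remapper


@Remappers.register("identity")
class IdentityRemapper(Remapper):
    """Toy remapper that only casts the array to the requested dtype."""

    def __init__(self, config: dict[str, Any]) -> None:
        super().__init__(config)

    def __call__(
        self, array: npt.NDArray[Any], dtype: npt.DTypeLike
    ) -> npt.NDArray[Any]:
        return array.astype(dtype)


remapper = load_remapper({"name": "identity"})
```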
--- rslearn-0.0.8/rslearn/models/dinov3.py
+++ rslearn-0.0.9/rslearn/models/dinov3.py
@@ -7,8 +7,8 @@ from typing import Any
 import torch
 import torchvision
 from einops import rearrange
-from torchvision.transforms import v2

+from rslearn.train.transforms.normalize import Normalize
 from rslearn.train.transforms.transform import Transform


@@ -139,15 +139,17 @@ class DinoV3Normalize(Transform):
         super().__init__()
         self.satellite = satellite
         if satellite:
-
-
-                std=(0.213, 0.156, 0.143),
-            )
+            mean = [0.430, 0.411, 0.296]
+            std = [0.213, 0.156, 0.143]
         else:
-
-
-
-
+            mean = [0.485, 0.456, 0.406]
+            std = [0.229, 0.224, 0.225]
+
+        self.normalize = Normalize(
+            [value * 255 for value in mean],
+            [value * 255 for value in std],
+            num_bands=3,
+        )

     def forward(
         self, input_dict: dict[str, Any], target_dict: dict[str, Any]
@@ -161,5 +163,4 @@ class DinoV3Normalize(Transform):
         Returns:
             normalized (input_dicts, target_dicts) tuple
         """
-
-        return input_dict, target_dict
+        return self.normalize(input_dict, target_dict)
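The `DinoV3Normalize` rewrite above delegates to rslearn's own `Normalize` transform and scales the 0-1 mean/std constants by 255, presumably because the incoming imagery is in the 0-255 range. The arithmetic is presumably the usual per-band `(x - mean) / std`, as in this standalone snippet; the values are copied from the diff and the tensor is a placeholder.

```python
# Standalone illustration of the scaling; not the rslearn Normalize transform itself.
import torch

mean = torch.tensor([0.485, 0.456, 0.406]) * 255  # ImageNet constants, non-satellite case
std = torch.tensor([0.229, 0.224, 0.225]) * 255

image = torch.rand(3, 64, 64) * 255  # placeholder 0-255 RGB image
normalized = (image - mean[:, None, None]) / std[:, None, None]
```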
--- rslearn-0.0.8/rslearn/models/galileo/galileo.py
+++ rslearn-0.0.9/rslearn/models/galileo/galileo.py
@@ -2,6 +2,7 @@

 import math
 import tempfile
+from contextlib import nullcontext
 from enum import StrEnum
 from typing import Any, cast

@@ -63,6 +64,11 @@ pretrained_weights: dict[GalileoSize, str] = {

 DEFAULT_NORMALIZER = Normalizer()

+AUTOCAST_DTYPE_MAP = {
+    "bfloat16": torch.bfloat16,
+    "float32": torch.float32,
+}
+

 class GalileoModel(nn.Module):
     """Galileo backbones."""
@@ -85,6 +91,7 @@ class GalileoModel(nn.Module):
         size: GalileoSize,
         patch_size: int = 4,
         pretrained_path: str | UPath | None = None,
+        autocast_dtype: str | None = "bfloat16",
     ) -> None:
         """Initialize the Galileo model.

@@ -93,6 +100,7 @@ class GalileoModel(nn.Module):
             patch_size: The patch size to use.
            pretrained_path: the local path to the pretrained weights. Otherwise it is
                downloaded and cached in temp directory.
+            autocast_dtype: which dtype to use for autocasting, or set None to disable.
         """
         super().__init__()
         if pretrained_path is None:
@@ -128,8 +136,14 @@ class GalileoModel(nn.Module):
             idx for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX) if "S1" in key
         ]

+        self.size = size
         self.patch_size = patch_size

+        if autocast_dtype is not None:
+            self.autocast_dtype = AUTOCAST_DTYPE_MAP[autocast_dtype]
+        else:
+            self.autocast_dtype = None
+
     @staticmethod
     def to_cartesian(
         lat: float | np.ndarray | torch.Tensor, lon: float | np.ndarray | torch.Tensor
@@ -484,18 +498,31 @@
             patch_size = h
         else:
             patch_size = self.patch_size
-
-
-
-
-
-
-
-
-
-
-
-
+
+        # Decide context based on self.autocast_dtype.
+        device = galileo_input.s_t_x.device
+        if self.autocast_dtype is None:
+            context = nullcontext()
+        else:
+            assert device is not None
+            context = torch.amp.autocast(
+                device_type=device.type, dtype=self.autocast_dtype
+            )
+
+        with context:
+            outputs = self.model(
+                s_t_x=galileo_input.s_t_x,
+                s_t_m=galileo_input.s_t_m,
+                sp_x=galileo_input.sp_x,
+                sp_m=galileo_input.sp_m,
+                t_x=galileo_input.t_x,
+                t_m=galileo_input.t_m,
+                st_x=galileo_input.st_x,
+                st_m=galileo_input.st_m,
+                months=galileo_input.months,
+                patch_size=patch_size,
+            )
+
         if h == patch_size:
             # only one spatial patch, so we can just take an average
             # of all the tokens to output b c_g 1 1
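The forward-pass hunk above wraps the Galileo encoder call in `torch.amp.autocast` when `autocast_dtype` is set, and in a `nullcontext` otherwise. The sketch below shows just that toggle, assuming torch is installed; the model and tensor are placeholders, not the GalileoModel code.

```python
# Illustrative autocast toggle mirroring the pattern in the diff.
from contextlib import nullcontext

import torch

AUTOCAST_DTYPE_MAP = {"bfloat16": torch.bfloat16, "float32": torch.float32}


def run(
    model: torch.nn.Module, x: torch.Tensor, autocast_dtype: str | None = "bfloat16"
) -> torch.Tensor:
    """Run model(x) under autocast when a dtype is configured, otherwise unmodified."""
    if autocast_dtype is None:
        context = nullcontext()
    else:
        context = torch.amp.autocast(
            device_type=x.device.type, dtype=AUTOCAST_DTYPE_MAP[autocast_dtype]
        )
    with context:
        return model(x)


y = run(torch.nn.Linear(4, 2), torch.randn(1, 4), autocast_dtype=None)
```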
@@ -515,3 +542,22 @@
                 "b h w c_g d -> b c_g d h w",
             ).mean(dim=1)
         ]
+
+    def get_backbone_channels(self) -> list:
+        """Returns the output channels of this model when used as a backbone.
+
+        The output channels is a list of (patch_size, depth) that corresponds
+        to the feature maps that the backbone returns.
+
+        Returns:
+            the output channels of the backbone as a list of (patch_size, depth) tuples.
+        """
+        if self.size == GalileoSize.BASE:
+            depth = 768
+        elif self.model_size == GalileoSize.TINY:
+            depth = 192
+        elif self.model_size == GalileoSize.NANO:
+            depth = 128
+        else:
+            raise ValueError(f"Invalid model size: {self.size}")
+        return [(self.patch_size, depth)]
--- rslearn-0.0.8/rslearn/models/galileo/single_file_galileo.py
+++ rslearn-0.0.9/rslearn/models/galileo/single_file_galileo.py
@@ -1469,7 +1469,13 @@ class Encoder(GalileoBase):
                 # we take the inverse of the mask because a value
                 # of True indicates the value *should* take part in
                 # attention
-
+                temp_mask = ~new_m.bool()
+                if temp_mask.all():
+                    # if all the tokens are used in attention we can pass a None mask
+                    # to the attention block
+                    temp_mask = None
+
+                x = blk(x=x, y=None, attn_mask=temp_mask)

                 if exit_ids_seq is not None:
                     assert exited_tokens is not None
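The single_file_galileo.py hunk passes `attn_mask=None` when every token participates in attention. An all-True boolean mask and no mask are mathematically equivalent, and passing `None` can let PyTorch select its fused attention kernels. The small check below illustrates the equivalence using `torch.nn.functional.scaled_dot_product_attention`; the shapes are placeholders and this is not Galileo code.

```python
# Small equivalence check; illustrative only.
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 2, 4, 8)
all_true = torch.ones(1, 1, 4, 4, dtype=torch.bool)  # True = attend, so nothing is masked

out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=all_true)
out_unmasked = F.scaled_dot_product_attention(q, k, v, attn_mask=None)
print(torch.allclose(out_masked, out_unmasked, atol=1e-5))
```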
--- rslearn-0.0.8/rslearn/models/presto/presto.py
+++ rslearn-0.0.9/rslearn/models/presto/presto.py
@@ -248,3 +248,14 @@ class Presto(nn.Module):
             output_features[batch_idx : batch_idx + self.pixel_batch_size] = output_b

         return [rearrange(output_features, "(b h w) d -> b d h w", h=h, w=w, b=b)]
+
+    def get_backbone_channels(self) -> list:
+        """Returns the output channels of this model when used as a backbone.
+
+        The output channels is a list of (patch_size, depth) that corresponds
+        to the feature maps that the backbone returns.
+
+        Returns:
+            the output channels of the backbone as a list of (patch_size, depth) tuples.
+        """
+        return [(1, 128)]
--- rslearn-0.0.8/rslearn/models/prithvi.py
+++ rslearn-0.0.9/rslearn/models/prithvi.py
@@ -173,6 +173,17 @@ class PrithviV2(nn.Module):
             features, num_timesteps
         )

+    def get_backbone_channels(self) -> list:
+        """Returns the output channels of this model when used as a backbone.
+
+        The output channels is a list of (patch_size, depth) that corresponds
+        to the feature maps that the backbone returns.
+
+        Returns:
+            the output channels of the backbone as a list of (patch_size, depth) tuples.
+        """
+        return [(1, 1024)]
+

 class PrithviNormalize(Transform):
     """Normalize inputs using Prithvi normalization.
--- /dev/null
+++ rslearn-0.0.9/rslearn/models/registry.py
@@ -0,0 +1,22 @@
+"""Model registry."""
+
+from collections.abc import Callable
+from typing import Any, TypeVar
+
+_ModelT = TypeVar("_ModelT")
+
+
+class _ModelRegistry(dict[str, type[Any]]):
+    """Registry for Model classes."""
+
+    def register(self, name: str) -> Callable[[type[_ModelT]], type[_ModelT]]:
+        """Decorator to register a model class."""
+
+        def decorator(cls: type[_ModelT]) -> type[_ModelT]:
+            self[name] = cls
+            return cls
+
+        return decorator
+
+
+Models = _ModelRegistry()
--- rslearn-0.0.8/rslearn/tile_stores/default.py
+++ rslearn-0.0.9/rslearn/tile_stores/default.py
@@ -130,10 +130,12 @@ class DefaultTileStore(TileStore):
         """
         raster_dir = self._get_raster_dir(layer_name, item_name, bands)
         for fname in raster_dir.iterdir():
-            # Ignore completed sentinel files as well as temporary files created by
+            # Ignore completed sentinel files, bands files, as well as temporary files created by
             # open_atomic (in case this tile store is on local filesystem).
             if fname.name == COMPLETED_FNAME:
                 continue
+            if fname.name == BANDS_FNAME:
+                continue
             if ".tmp." in fname.name:
                 continue
             return fname
--- rslearn-0.0.8/rslearn/train/transforms/transform.py
+++ rslearn-0.0.9/rslearn/train/transforms/transform.py
@@ -54,7 +54,7 @@ def read_selector(
         the item specified by the selector
     """
     d, selector = get_dict_and_subselector(input_dict, target_dict, selector)
-    parts = selector.split("/")
+    parts = selector.split("/") if selector else []
     cur = d
     for part in parts:
         cur = cur[part]
@@ -76,11 +76,28 @@ def write_selector(
         v: the value to write
     """
     d, selector = get_dict_and_subselector(input_dict, target_dict, selector)
-
-
-
-
-
+    if selector:
+        parts = selector.split("/")
+        cur = d
+        for part in parts[:-1]:
+            cur = cur[part]
+        cur[parts[-1]] = v
+    else:
+        # If the selector references the input or target dictionary directly, then we
+        # have a special case where instead of overwriting with v, we replace the keys
+        # with those in v. v must be a dictionary here, not a tensor, since otherwise
+        # it wouldn't match the type of the input or target dictionary.
+        if not isinstance(v, dict):
+            raise ValueError(
+                "when directly specifying the input or target dict, expected the value to be a dict"
+            )
+        if d == v:
+            # This may happen if the writer did not make a copy of the dictionary. In
+            # this case the code below would not update d correctly since it would also
+            # clear v.
+            return
+        d.clear()
+        d.update(v)


 class Transform(torch.nn.Module):
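
The `write_selector` hunk above distinguishes a "/"-separated path, which writes into a nested dict, from an empty selector, which replaces the keys of the input or target dict in place. The snippet below re-implements just that logic for illustration; it is not the rslearn function, which also resolves the input/target prefix via `get_dict_and_subselector`.

```python
# Illustrative re-implementation of the selector write logic from the hunk above.
from typing import Any


def write(d: dict[str, Any], selector: str, v: Any) -> None:
    if selector:
        parts = selector.split("/")
        cur = d
        for part in parts[:-1]:
            cur = cur[part]
        cur[parts[-1]] = v
    else:
        # Empty selector: replace the whole dict's contents in place.
        if not isinstance(v, dict):
            raise ValueError("expected a dict when targeting the whole dictionary")
        if d == v:
            # Guard against d and v referring to equal/same data, where d.clear()
            # would also empty v before the update.
            return
        d.clear()
        d.update(v)


data: dict[str, Any] = {"image": {"bands": 3}}
write(data, "image/bands", 4)             # nested write via path
write(data, "", {"image": {"bands": 1}})  # whole-dict replacement
print(data)
```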
|