rslearn 0.0.9__tar.gz → 0.0.11__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rslearn-0.0.9/rslearn.egg-info → rslearn-0.0.11}/PKG-INFO +2 -1
- {rslearn-0.0.9 → rslearn-0.0.11}/pyproject.toml +3 -1
- rslearn-0.0.11/rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn-0.0.11/rslearn/models/olmoearth_pretrain/model.py +203 -0
- rslearn-0.0.11/rslearn/models/olmoearth_pretrain/norm.py +84 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/pooling_decoder.py +43 -0
- {rslearn-0.0.9 → rslearn-0.0.11/rslearn.egg-info}/PKG-INFO +2 -1
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn.egg-info/SOURCES.txt +3 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn.egg-info/requires.txt +1 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/LICENSE +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/README.md +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/__init__.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/arg_parser.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/config/__init__.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/config/dataset.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/const.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/data_sources/__init__.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/data_sources/aws_landsat.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/data_sources/aws_open_data.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/data_sources/aws_sentinel1.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/data_sources/climate_data_store.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/data_sources/copernicus.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/data_sources/data_source.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/data_sources/earthdaily.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/data_sources/earthdata_srtm.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/data_sources/eurocrops.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/data_sources/gcp_public_data.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/data_sources/geotiff.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/data_sources/google_earth_engine.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/data_sources/local_files.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/data_sources/openstreetmap.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/data_sources/planet.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/data_sources/planet_basemap.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/data_sources/planetary_computer.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/data_sources/raster_source.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/data_sources/usda_cdl.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/data_sources/usgs_landsat.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/data_sources/utils.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/data_sources/vector_source.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/data_sources/worldcereal.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/data_sources/worldcover.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/data_sources/worldpop.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/data_sources/xyz_tiles.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/dataset/__init__.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/dataset/add_windows.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/dataset/dataset.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/dataset/handler_summaries.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/dataset/index.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/dataset/manage.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/dataset/materialize.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/dataset/remap.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/dataset/window.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/log_utils.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/main.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/__init__.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/anysat.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/clay/clay.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/clay/configs/metadata.yaml +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/clip.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/conv.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/copernicusfm.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/copernicusfm_src/__init__.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/copernicusfm_src/aurora/area.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/copernicusfm_src/aurora/fourier.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/copernicusfm_src/dynamic_hypernetwork.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/copernicusfm_src/flexivit/patch_embed.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/copernicusfm_src/flexivit/utils.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/copernicusfm_src/model_vit.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/copernicusfm_src/util/pos_embed.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/croma.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/detr/__init__.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/detr/box_ops.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/detr/detr.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/detr/matcher.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/detr/position_encoding.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/detr/transformer.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/detr/util.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/dinov3.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/faster_rcnn.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/fpn.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/galileo/__init__.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/galileo/galileo.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/galileo/single_file_galileo.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/module_wrapper.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/molmo.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/multitask.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/panopticon.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/panopticon_data/sensors/drone.yaml +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/panopticon_data/sensors/enmap.yaml +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/panopticon_data/sensors/goes.yaml +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/panopticon_data/sensors/himawari.yaml +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/panopticon_data/sensors/intuition.yaml +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/panopticon_data/sensors/landsat8.yaml +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/panopticon_data/sensors/modis_terra.yaml +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/panopticon_data/sensors/sentinel1.yaml +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/panopticon_data/sensors/sentinel2.yaml +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/panopticon_data/sensors/superdove.yaml +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/panopticon_data/sensors/wv23.yaml +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/pick_features.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/presto/__init__.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/presto/presto.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/presto/single_file_presto.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/prithvi.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/registry.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/resize_features.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/sam2_enc.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/satlaspretrain.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/simple_time_series.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/singletask.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/ssl4eo_s12.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/swin.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/task_embedding.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/terramind.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/trunk.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/unet.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/upsample.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/use_croma.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/py.typed +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/template_params.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/tile_stores/__init__.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/tile_stores/default.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/tile_stores/tile_store.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/__init__.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/callbacks/__init__.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/callbacks/adapters.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/callbacks/freeze_unfreeze.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/callbacks/gradients.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/callbacks/peft.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/data_module.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/dataset.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/lightning_module.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/optimizer.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/prediction_writer.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/scheduler.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/tasks/__init__.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/tasks/classification.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/tasks/detection.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/tasks/multi_task.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/tasks/per_pixel_regression.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/tasks/regression.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/tasks/segmentation.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/tasks/task.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/transforms/__init__.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/transforms/concatenate.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/transforms/crop.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/transforms/flip.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/transforms/mask.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/transforms/normalize.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/transforms/pad.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/transforms/select_bands.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/transforms/sentinel1.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/train/transforms/transform.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/utils/__init__.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/utils/array.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/utils/feature.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/utils/fsspec.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/utils/geometry.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/utils/get_utm_ups_crs.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/utils/grid_index.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/utils/jsonargparse.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/utils/mp.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/utils/raster_format.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/utils/rtree_index.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/utils/spatial_index.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/utils/sqlite_index.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/utils/time.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn/utils/vector_format.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn.egg-info/dependency_links.txt +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn.egg-info/entry_points.txt +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/rslearn.egg-info/top_level.txt +0 -0
- {rslearn-0.0.9 → rslearn-0.0.11}/setup.cfg +0 -0

{rslearn-0.0.9/rslearn.egg-info → rslearn-0.0.11}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: rslearn
-Version: 0.0.9
+Version: 0.0.11
 Summary: A library for developing remote sensing datasets and models
 Author: OlmoEarth Team
 License: Apache License
@@ -243,6 +243,7 @@ Requires-Dist: planetary_computer>=1.0; extra == "extra"
 Requires-Dist: pycocotools>=2.0; extra == "extra"
 Requires-Dist: pystac_client>=0.9; extra == "extra"
 Requires-Dist: rtree>=1.4; extra == "extra"
+Requires-Dist: termcolor>=3.0; extra == "extra"
 Requires-Dist: satlaspretrain_models>=0.3; extra == "extra"
 Requires-Dist: scipy>=1.16; extra == "extra"
 Requires-Dist: terratorch>=1.0.2; extra == "extra"

{rslearn-0.0.9 → rslearn-0.0.11}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "rslearn"
-version = "0.0.9"
+version = "0.0.11"
 description = "A library for developing remote sensing datasets and models"
 authors = [
     { name = "OlmoEarth Team" },
@@ -47,6 +47,8 @@ extra = [
     "pycocotools>=2.0",
     "pystac_client>=0.9",
     "rtree>=1.4",
+    # Needed by DINOv3.
+    "termcolor>=3.0",
     "satlaspretrain_models>=0.3",
     "scipy>=1.16",
     "terratorch>=1.0.2",

rslearn-0.0.11/rslearn/models/olmoearth_pretrain/__init__.py

@@ -0,0 +1 @@
+"""OlmoEarth model architecture."""

rslearn-0.0.11/rslearn/models/olmoearth_pretrain/model.py

@@ -0,0 +1,203 @@
+"""OlmoEarth model wrapper for fine-tuning in rslearn."""
+
+import json
+from contextlib import nullcontext
+from typing import Any
+
+import torch
+from einops import rearrange
+from olmo_core.config import Config
+from olmo_core.distributed.checkpoint import load_model_and_optim_state
+from olmoearth_pretrain.data.constants import Modality
+from olmoearth_pretrain.nn.flexihelios import Encoder, TokensAndMasks
+from olmoearth_pretrain.train.masking import MaskedOlmoEarthSample, MaskValue
+from upath import UPath
+
+from rslearn.log_utils import get_logger
+
+logger = get_logger(__name__)
+
+MODALITY_NAMES = [
+    "sentinel2_l2a",
+    "sentinel1",
+    "worldcover",
+    "openstreetmap_raster",
+    "landsat",
+]
+
+AUTOCAST_DTYPE_MAP = {
+    "bfloat16": torch.bfloat16,
+    "float16": torch.float16,
+    "float32": torch.float32,
+}
+
+
+class OlmoEarth(torch.nn.Module):
+    """A wrapper to support the OlmoEarth model."""
+
+    def __init__(
+        self,
+        # TODO: we should accept model ID instead of checkpoint_path once we are closer
+        # to being ready for release.
+        checkpoint_path: str,
+        selector: list[str | int] = [],
+        forward_kwargs: dict[str, Any] = {},
+        random_initialization: bool = False,
+        embedding_size: int | None = None,
+        patch_size: int | None = None,
+        autocast_dtype: str | None = "bfloat16",
+    ):
+        """Create a new OlmoEarth model.
+
+        Args:
+            checkpoint_path: the checkpoint directory to load. It should contain
+                config.json file as well as model_and_optim folder.
+            selector: an optional sequence of attribute names or list indices to select
+                the sub-module that should be applied on the input images.
+            forward_kwargs: additional arguments to pass to forward pass besides the
+                MaskedOlmoEarthSample.
+            random_initialization: whether to skip loading the checkpoint so the
+                weights are randomly initialized. In this case, the checkpoint is only
+                used to define the model architecture.
+            embedding_size: optional embedding size to report via
+                get_backbone_channels.
+            patch_size: optional patch size to report via get_backbone_channels.
+            autocast_dtype: which dtype to use for autocasting, or set None to disable.
+        """
+        super().__init__()
+        _checkpoint_path = UPath(checkpoint_path)
+        self.forward_kwargs = forward_kwargs
+        self.embedding_size = embedding_size
+        self.patch_size = patch_size
+
+        if autocast_dtype is not None:
+            self.autocast_dtype = AUTOCAST_DTYPE_MAP[autocast_dtype]
+        else:
+            self.autocast_dtype = None
+
+        # Load the model config and initialize it.
+        # We avoid loading the train module here because it depends on running within
+        # olmo_core.
+        with (_checkpoint_path / "config.json").open() as f:
+            config_dict = json.load(f)
+            model_config = Config.from_dict(config_dict["model"])
+
+        model = model_config.build()
+
+        # Load the checkpoint.
+        if not random_initialization:
+            train_module_dir = _checkpoint_path / "model_and_optim"
+            if train_module_dir.exists():
+                load_model_and_optim_state(str(train_module_dir), model)
+                logger.info(f"loaded OlmoEarth encoder from {train_module_dir}")
+            else:
+                logger.info(f"could not find OlmoEarth encoder at {train_module_dir}")
+        else:
+            logger.info("skipping loading OlmoEarth encoder")
+
+        # Select just the portion of the model that we actually want to use.
+        for part in selector:
+            if isinstance(part, str):
+                model = getattr(model, part)
+            else:
+                model = model[part]
+        self.model = model
+
+    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+        """Compute feature maps from the OlmoEarth backbone.
+
+        Inputs:
+            inputs: input dicts. It should include keys corresponding to the modalities
+                that should be passed to the OlmoEarth model.
+        """
+        kwargs = {}
+        present_modalities = []
+        device = None
+        # Handle the case where some modalities are multitemporal and some are not.
+        # We assume all multitemporal modalities have the same number of timesteps.
+        max_timesteps = 1
+        for modality in MODALITY_NAMES:
+            if modality not in inputs[0]:
+                continue
+            present_modalities.append(modality)
+            cur = torch.stack([inp[modality] for inp in inputs], dim=0)
+            device = cur.device
+            # Check if it's single or multitemporal, and reshape accordingly
+            num_bands = Modality.get(modality).num_bands
+            num_timesteps = cur.shape[1] // num_bands
+            max_timesteps = max(max_timesteps, num_timesteps)
+            cur = rearrange(cur, "b (t c) h w -> b h w t c", t=num_timesteps)
+            kwargs[modality] = cur
+            # Create mask array which is BHWTS (without channels but with band sets).
+            num_band_sets = len(Modality.get(modality).band_sets)
+            mask_shape = cur.shape[0:4] + (num_band_sets,)
+            mask = (
+                torch.ones(mask_shape, dtype=torch.int32, device=device)
+                * MaskValue.ONLINE_ENCODER.value
+            )
+            kwargs[f"{modality}_mask"] = mask
+
+        # Timestamps is required.
+        # Note that only months (0 to 11) are used in OlmoEarth position encoding.
+        # For now, we assign same timestamps to all inputs, but later we should handle varying timestamps per input.
+        timestamps = torch.zeros(
+            (len(inputs), max_timesteps, 3), dtype=torch.int32, device=device
+        )
+        timestamps[:, :, 0] = 1  # day
+        timestamps[:, :, 1] = torch.arange(max_timesteps, device=device)[
+            None, :
+        ]  # month
+        timestamps[:, :, 2] = 2024  # year
+        kwargs["timestamps"] = timestamps
+
+        sample = MaskedOlmoEarthSample(**kwargs)
+
+        # Decide context based on self.autocast_dtype.
+        if self.autocast_dtype is None:
+            context = nullcontext()
+        else:
+            assert device is not None
+            context = torch.amp.autocast(
+                device_type=device.type, dtype=self.autocast_dtype
+            )
+
+        with context:
+            # Currently we assume the provided model always returns a TokensAndMasks object.
+            tokens_and_masks: TokensAndMasks
+            if isinstance(self.model, Encoder):
+                # Encoder has a fast_pass argument to indicate mask is not needed.
+                tokens_and_masks = self.model(
+                    sample, fast_pass=True, **self.forward_kwargs
+                )["tokens_and_masks"]
+            else:
+                # Other models like STEncoder do not have this option supported.
+                tokens_and_masks = self.model(sample, **self.forward_kwargs)[
+                    "tokens_and_masks"
+                ]
+
+        # Apply temporal/modality pooling so we just have one feature per patch.
+        features = []
+        for modality in present_modalities:
+            modality_features = getattr(tokens_and_masks, modality)
+            # Pool over band sets and timesteps (BHWTSC -> BHWC).
+            pooled = modality_features.mean(dim=[3, 4])
+            # We want BHWC -> BCHW.
+            pooled = rearrange(pooled, "b h w c -> b c h w")
+            features.append(pooled)
+        # Pool over the modalities, so we get one BCHW feature map.
+        pooled = torch.stack(features, dim=0).mean(dim=0)
+        return [pooled]
+
+    def get_backbone_channels(self) -> list:
+        """Returns the output channels of this model when used as a backbone.
+
+        The output channels is a list of (downsample_factor, depth) that corresponds
+        to the feature maps that the backbone returns. For example, an element [2, 32]
+        indicates that the corresponding feature map is 1/2 the input resolution and
+        has 32 channels.
+
+        Returns:
+            the output channels of the backbone as a list of (downsample_factor, depth)
+            tuples.
+        """
+        return [(self.patch_size, self.embedding_size)]
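
For context, a minimal sketch of how this new OlmoEarth wrapper might be instantiated and driven directly. The checkpoint path, selector value, embedding size, patch size, and input channel count below are placeholders rather than values taken from this release; an actual call needs a real OlmoEarth checkpoint directory containing config.json and model_and_optim.

import torch

from rslearn.models.olmoearth_pretrain.model import OlmoEarth

# Hypothetical checkpoint directory and sub-module selector.
model = OlmoEarth(
    checkpoint_path="/path/to/olmoearth_checkpoint",
    selector=["encoder"],   # assumed attribute name; depends on the checkpoint config
    embedding_size=768,     # assumed; only reported via get_backbone_channels()
    patch_size=8,           # assumed
)
# One dict per batch element; each modality tensor is (bands * timesteps, H, W).
# The 12-channel Sentinel-2 L2A input here is made up for illustration.
inputs = [{"sentinel2_l2a": torch.randn(12, 64, 64)}]
features = model(inputs)    # list containing a single BCHW feature map
print(features[0].shape)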

rslearn-0.0.11/rslearn/models/olmoearth_pretrain/norm.py

@@ -0,0 +1,84 @@
+"""Normalization transforms."""
+
+import json
+from typing import Any
+
+from olmoearth_pretrain.data.normalize import load_computed_config
+
+from rslearn.log_utils import get_logger
+from rslearn.train.transforms.transform import Transform
+
+logger = get_logger(__file__)
+
+
+class OlmoEarthNormalize(Transform):
+    """Normalize using OlmoEarth JSON config.
+
+    For Sentinel-1 data, the values should be converted to decibels before being passed
+    to this transform.
+    """
+
+    def __init__(
+        self,
+        band_names: dict[str, list[str]],
+        std_multiplier: float | None = 2,
+        config_fname: str | None = None,
+    ) -> None:
+        """Initialize a new OlmoEarthNormalize.
+
+        Args:
+            band_names: map from modality name to the list of bands in that modality in
+                the order they are being loaded. Note that this order must match the
+                expected order for the OlmoEarth model.
+            std_multiplier: the std multiplier matching the one used for the model
+                training in OlmoEarth.
+            config_fname: load the normalization configuration from this file, instead
+                of getting it from OlmoEarth.
+        """
+        super().__init__()
+        self.band_names = band_names
+        self.std_multiplier = std_multiplier
+
+        if config_fname is None:
+            self.norm_config = load_computed_config()
+        else:
+            logger.warning(
+                f"Loading normalization config from {config_fname}. This argument is deprecated and will be removed in a future version."
+            )
+            with open(config_fname) as f:
+                self.norm_config = json.load(f)
+
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Apply normalization over the inputs and targets.
+
+        Args:
+            input_dict: the input
+            target_dict: the target
+
+        Returns:
+            normalized (input_dicts, target_dicts) tuple
+        """
+        for modality_name, cur_band_names in self.band_names.items():
+            band_norms = self.norm_config[modality_name]
+            image = input_dict[modality_name]
+            # Keep a set of indices to make sure that we normalize all of them.
+            needed_band_indices = set(range(image.shape[0]))
+            num_timesteps = image.shape[0] // len(cur_band_names)
+
+            for band, norm_dict in band_norms.items():
+                # If multitemporal, normalize each timestep separately.
+                for t in range(num_timesteps):
+                    band_idx = cur_band_names.index(band) + t * len(cur_band_names)
+                    min_val = norm_dict["mean"] - self.std_multiplier * norm_dict["std"]
+                    max_val = norm_dict["mean"] + self.std_multiplier * norm_dict["std"]
+                    image[band_idx] = (image[band_idx] - min_val) / (max_val - min_val)
+                    needed_band_indices.remove(band_idx)
+
+            if len(needed_band_indices) > 0:
+                raise ValueError(
+                    f"for modality {modality_name}, bands {needed_band_indices} were unexpectedly not normalized"
+                )
+
+        return input_dict, target_dict
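
The new OlmoEarthNormalize transform rescales each band with a min/max window of std_multiplier standard deviations around the band mean, i.e. x -> (x - (mean - k*std)) / (2*k*std). A standalone sketch of just that arithmetic, using made-up statistics (mean=1000, std=500, k=2) and independent of the rslearn classes:

import torch

mean, std, k = 1000.0, 500.0, 2.0
min_val = mean - k * std   # 0.0
max_val = mean + k * std   # 2000.0

band = torch.tensor([0.0, 1000.0, 2000.0, 3000.0])
normalized = (band - min_val) / (max_val - min_val)
print(normalized)  # tensor([0.0000, 0.5000, 1.0000, 1.5000]); values beyond 2 std fall outside [0, 1]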

{rslearn-0.0.9 → rslearn-0.0.11}/rslearn/models/pooling_decoder.py

@@ -76,3 +76,46 @@ class PoolingDecoder(torch.nn.Module):
         features = torch.amax(features, dim=(2, 3))
         features = self.fc_layers(features)
         return self.output_layer(features)
+
+
+class SegmentationPoolingDecoder(PoolingDecoder):
+    """Like PoolingDecoder, but copy output to all pixels.
+
+    This allows for the model to produce a global output while still being compatible
+    with SegmentationTask. This only makes sense for very small windows, since the
+    output probabilities will be the same at all pixels. The main use case is to train
+    for a classification-like task on small windows, but still produce a raster during
+    inference on large windows.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        image_key: str = "image",
+        **kwargs: Any,
+    ):
+        """Create a new SegmentationPoolingDecoder.
+
+        Args:
+            in_channels: input channels (channels in the last feature map passed to
+                this module)
+            out_channels: channels for the output flat feature vector
+            image_key: the key in inputs for the image from which the expected width
+                and height is derived.
+            kwargs: other arguments to pass to PoolingDecoder.
+        """
+        super().__init__(in_channels=in_channels, out_channels=out_channels, **kwargs)
+        self.image_key = image_key
+
+    def forward(
+        self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
+    ) -> torch.Tensor:
+        """Extend PoolingDecoder forward to upsample the output to a segmentation mask.
+
+        This only works when all of the pixels have the same segmentation target.
+        """
+        output_probs = super().forward(features, inputs)
+        # BC -> BCHW
+        h, w = inputs[0][self.image_key].shape[1:3]
+        return output_probs[:, :, None, None].repeat([1, 1, h, w])
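
The core of the new SegmentationPoolingDecoder is the final broadcast: a per-window prediction of shape (B, C) is tiled across the window so it can be scored as a (B, C, H, W) segmentation raster. A minimal sketch of just that step with made-up shapes (the real module reads H and W from inputs[0][image_key]):

import torch

output_probs = torch.randn(2, 3)   # 2 windows, 3 classes
h, w = 16, 16                      # spatial size of the input window
raster = output_probs[:, :, None, None].repeat([1, 1, h, w])
print(raster.shape)                                    # torch.Size([2, 3, 16, 16])
assert torch.equal(raster[:, :, 0, 0], output_probs)   # every pixel carries the same logits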

{rslearn-0.0.9 → rslearn-0.0.11/rslearn.egg-info}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: rslearn
-Version: 0.0.9
+Version: 0.0.11
 Summary: A library for developing remote sensing datasets and models
 Author: OlmoEarth Team
 License: Apache License
@@ -243,6 +243,7 @@ Requires-Dist: planetary_computer>=1.0; extra == "extra"
 Requires-Dist: pycocotools>=2.0; extra == "extra"
 Requires-Dist: pystac_client>=0.9; extra == "extra"
 Requires-Dist: rtree>=1.4; extra == "extra"
+Requires-Dist: termcolor>=3.0; extra == "extra"
 Requires-Dist: satlaspretrain_models>=0.3; extra == "extra"
 Requires-Dist: scipy>=1.16; extra == "extra"
 Requires-Dist: terratorch>=1.0.2; extra == "extra"

{rslearn-0.0.9 → rslearn-0.0.11}/rslearn.egg-info/SOURCES.txt

@@ -102,6 +102,9 @@ rslearn/models/detr/util.py
 rslearn/models/galileo/__init__.py
 rslearn/models/galileo/galileo.py
 rslearn/models/galileo/single_file_galileo.py
+rslearn/models/olmoearth_pretrain/__init__.py
+rslearn/models/olmoearth_pretrain/model.py
+rslearn/models/olmoearth_pretrain/norm.py
 rslearn/models/panopticon_data/sensors/drone.yaml
 rslearn/models/panopticon_data/sensors/enmap.yaml
 rslearn/models/panopticon_data/sensors/goes.yaml