rslearn 0.0.9__tar.gz → 0.0.12__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn-0.0.12/NOTICE +115 -0
- {rslearn-0.0.9/rslearn.egg-info → rslearn-0.0.12}/PKG-INFO +3 -1
- {rslearn-0.0.9 → rslearn-0.0.12}/pyproject.toml +3 -1
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/anysat.py +5 -1
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/dinov3.py +6 -1
- rslearn-0.0.12/rslearn/models/feature_center_crop.py +50 -0
- rslearn-0.0.12/rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn-0.0.12/rslearn/models/olmoearth_pretrain/model.py +263 -0
- rslearn-0.0.12/rslearn/models/olmoearth_pretrain/norm.py +84 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/pooling_decoder.py +43 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/prithvi.py +9 -1
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/lightning_module.py +0 -3
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/tasks/classification.py +2 -2
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/tasks/detection.py +5 -5
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/tasks/per_pixel_regression.py +5 -4
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/tasks/regression.py +5 -5
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/transforms/pad.py +3 -3
- {rslearn-0.0.9 → rslearn-0.0.12/rslearn.egg-info}/PKG-INFO +3 -1
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn.egg-info/SOURCES.txt +5 -9
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn.egg-info/requires.txt +1 -0
- rslearn-0.0.9/rslearn/models/copernicusfm.py +0 -228
- rslearn-0.0.9/rslearn/models/copernicusfm_src/__init__.py +0 -1
- rslearn-0.0.9/rslearn/models/copernicusfm_src/aurora/area.py +0 -50
- rslearn-0.0.9/rslearn/models/copernicusfm_src/aurora/fourier.py +0 -134
- rslearn-0.0.9/rslearn/models/copernicusfm_src/dynamic_hypernetwork.py +0 -523
- rslearn-0.0.9/rslearn/models/copernicusfm_src/flexivit/patch_embed.py +0 -260
- rslearn-0.0.9/rslearn/models/copernicusfm_src/flexivit/utils.py +0 -69
- rslearn-0.0.9/rslearn/models/copernicusfm_src/model_vit.py +0 -348
- rslearn-0.0.9/rslearn/models/copernicusfm_src/util/pos_embed.py +0 -216
- {rslearn-0.0.9 → rslearn-0.0.12}/LICENSE +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/README.md +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/__init__.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/arg_parser.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/config/__init__.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/config/dataset.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/const.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/__init__.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/aws_landsat.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/aws_open_data.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/aws_sentinel1.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/climate_data_store.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/copernicus.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/data_source.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/earthdaily.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/earthdata_srtm.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/eurocrops.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/gcp_public_data.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/geotiff.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/google_earth_engine.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/local_files.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/openstreetmap.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/planet.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/planet_basemap.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/planetary_computer.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/raster_source.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/usda_cdl.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/usgs_landsat.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/utils.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/vector_source.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/worldcereal.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/worldcover.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/worldpop.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/xyz_tiles.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/dataset/__init__.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/dataset/add_windows.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/dataset/dataset.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/dataset/handler_summaries.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/dataset/index.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/dataset/manage.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/dataset/materialize.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/dataset/remap.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/dataset/window.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/log_utils.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/main.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/__init__.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/clay/clay.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/clay/configs/metadata.yaml +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/clip.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/conv.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/croma.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/detr/__init__.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/detr/box_ops.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/detr/detr.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/detr/matcher.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/detr/position_encoding.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/detr/transformer.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/detr/util.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/faster_rcnn.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/fpn.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/galileo/__init__.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/galileo/galileo.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/galileo/single_file_galileo.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/module_wrapper.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/molmo.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/multitask.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/panopticon.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/panopticon_data/sensors/drone.yaml +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/panopticon_data/sensors/enmap.yaml +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/panopticon_data/sensors/goes.yaml +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/panopticon_data/sensors/himawari.yaml +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/panopticon_data/sensors/intuition.yaml +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/panopticon_data/sensors/landsat8.yaml +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/panopticon_data/sensors/modis_terra.yaml +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/panopticon_data/sensors/sentinel1.yaml +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/panopticon_data/sensors/sentinel2.yaml +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/panopticon_data/sensors/superdove.yaml +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/panopticon_data/sensors/wv23.yaml +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/pick_features.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/presto/__init__.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/presto/presto.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/presto/single_file_presto.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/registry.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/resize_features.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/sam2_enc.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/satlaspretrain.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/simple_time_series.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/singletask.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/ssl4eo_s12.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/swin.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/task_embedding.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/terramind.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/trunk.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/unet.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/upsample.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/use_croma.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/py.typed +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/template_params.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/tile_stores/__init__.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/tile_stores/default.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/tile_stores/tile_store.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/__init__.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/callbacks/__init__.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/callbacks/adapters.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/callbacks/freeze_unfreeze.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/callbacks/gradients.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/callbacks/peft.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/data_module.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/dataset.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/optimizer.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/prediction_writer.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/scheduler.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/tasks/__init__.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/tasks/multi_task.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/tasks/segmentation.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/tasks/task.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/transforms/__init__.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/transforms/concatenate.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/transforms/crop.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/transforms/flip.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/transforms/mask.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/transforms/normalize.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/transforms/select_bands.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/transforms/sentinel1.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/transforms/transform.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/__init__.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/array.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/feature.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/fsspec.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/geometry.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/get_utm_ups_crs.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/grid_index.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/jsonargparse.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/mp.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/raster_format.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/rtree_index.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/spatial_index.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/sqlite_index.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/time.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/vector_format.py +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn.egg-info/dependency_links.txt +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn.egg-info/entry_points.txt +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/rslearn.egg-info/top_level.txt +0 -0
- {rslearn-0.0.9 → rslearn-0.0.12}/setup.cfg +0 -0
rslearn-0.0.12/NOTICE
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
rslearn is released under Apache License 2.0
|
|
2
|
+
Copyright 2025 Allen Institute for AI
|
|
3
|
+
|
|
4
|
+
The following third party code is included in this repository.
|
|
5
|
+
|
|
6
|
+
====================
|
|
7
|
+
|
|
8
|
+
rslearn.models.detr is adapted from https://github.com/facebookresearch/detr which is
|
|
9
|
+
released under Apache License 2.0.
|
|
10
|
+
|
|
11
|
+
Copyright 2020 - present, Facebook, Inc
|
|
12
|
+
|
|
13
|
+
====================
|
|
14
|
+
|
|
15
|
+
rslearn.models.use_croma is copied from https://github.com/antofuller/CROMA
|
|
16
|
+
|
|
17
|
+
MIT License
|
|
18
|
+
|
|
19
|
+
Copyright (c) 2023 Anthony Fuller
|
|
20
|
+
|
|
21
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
22
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
23
|
+
in the Software without restriction, including without limitation the rights
|
|
24
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
25
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
26
|
+
furnished to do so, subject to the following conditions:
|
|
27
|
+
|
|
28
|
+
The above copyright notice and this permission notice shall be included in all
|
|
29
|
+
copies or substantial portions of the Software.
|
|
30
|
+
|
|
31
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
32
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
33
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
34
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
35
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
36
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
37
|
+
SOFTWARE.
|
|
38
|
+
|
|
39
|
+
====================
|
|
40
|
+
|
|
41
|
+
rslearn.models.galileo is adapted from https://github.com/nasaharvest/galileo
|
|
42
|
+
|
|
43
|
+
MIT License
|
|
44
|
+
|
|
45
|
+
Copyright (c) 2024 Presto Authors
|
|
46
|
+
|
|
47
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
48
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
49
|
+
in the Software without restriction, including without limitation the rights
|
|
50
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
51
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
52
|
+
furnished to do so, subject to the following conditions:
|
|
53
|
+
|
|
54
|
+
The above copyright notice and this permission notice shall be included in all
|
|
55
|
+
copies or substantial portions of the Software.
|
|
56
|
+
|
|
57
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
58
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
59
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
60
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
61
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
62
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
63
|
+
SOFTWARE.
|
|
64
|
+
|
|
65
|
+
====================
|
|
66
|
+
|
|
67
|
+
rslearn.models.presto is adapted from https://github.com/nasaharvest/presto
|
|
68
|
+
|
|
69
|
+
MIT License
|
|
70
|
+
|
|
71
|
+
Copyright (c) 2024 Presto Authors
|
|
72
|
+
|
|
73
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
74
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
75
|
+
in the Software without restriction, including without limitation the rights
|
|
76
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
77
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
78
|
+
furnished to do so, subject to the following conditions:
|
|
79
|
+
|
|
80
|
+
The above copyright notice and this permission notice shall be included in all
|
|
81
|
+
copies or substantial portions of the Software.
|
|
82
|
+
|
|
83
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
84
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
85
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
86
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
87
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
88
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
89
|
+
SOFTWARE.
|
|
90
|
+
|
|
91
|
+
====================
|
|
92
|
+
|
|
93
|
+
rslearn.models.prithvi includes code adapted from https://github.com/NASA-IMPACT/Prithvi-WxC
|
|
94
|
+
|
|
95
|
+
MIT License
|
|
96
|
+
|
|
97
|
+
Copyright (c) 2024 Inter Agency Implementation and Advanced Concepts
|
|
98
|
+
|
|
99
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
100
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
101
|
+
in the Software without restriction, including without limitation the rights
|
|
102
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
103
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
104
|
+
furnished to do so, subject to the following conditions:
|
|
105
|
+
|
|
106
|
+
The above copyright notice and this permission notice shall be included in all
|
|
107
|
+
copies or substantial portions of the Software.
|
|
108
|
+
|
|
109
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
110
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
111
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
112
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
113
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
114
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
115
|
+
SOFTWARE.
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: rslearn
|
|
3
|
-
Version: 0.0.9
|
|
3
|
+
Version: 0.0.12
|
|
4
4
|
Summary: A library for developing remote sensing datasets and models
|
|
5
5
|
Author: OlmoEarth Team
|
|
6
6
|
License: Apache License
|
|
@@ -211,6 +211,7 @@ Project-URL: repository, https://github.com/allenai/rslearn
|
|
|
211
211
|
Requires-Python: >=3.11
|
|
212
212
|
Description-Content-Type: text/markdown
|
|
213
213
|
License-File: LICENSE
|
|
214
|
+
License-File: NOTICE
|
|
214
215
|
Requires-Dist: boto3>=1.39
|
|
215
216
|
Requires-Dist: fiona>=1.10
|
|
216
217
|
Requires-Dist: fsspec>=2025.9.0
|
|
@@ -243,6 +244,7 @@ Requires-Dist: planetary_computer>=1.0; extra == "extra"
|
|
|
243
244
|
Requires-Dist: pycocotools>=2.0; extra == "extra"
|
|
244
245
|
Requires-Dist: pystac_client>=0.9; extra == "extra"
|
|
245
246
|
Requires-Dist: rtree>=1.4; extra == "extra"
|
|
247
|
+
Requires-Dist: termcolor>=3.0; extra == "extra"
|
|
246
248
|
Requires-Dist: satlaspretrain_models>=0.3; extra == "extra"
|
|
247
249
|
Requires-Dist: scipy>=1.16; extra == "extra"
|
|
248
250
|
Requires-Dist: terratorch>=1.0.2; extra == "extra"
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "rslearn"
|
|
3
|
-
version = "0.0.9"
|
|
3
|
+
version = "0.0.12"
|
|
4
4
|
description = "A library for developing remote sensing datasets and models"
|
|
5
5
|
authors = [
|
|
6
6
|
{ name = "OlmoEarth Team" },
|
|
@@ -47,6 +47,8 @@ extra = [
|
|
|
47
47
|
"pycocotools>=2.0",
|
|
48
48
|
"pystac_client>=0.9",
|
|
49
49
|
"rtree>=1.4",
|
|
50
|
+
# Needed by DINOv3.
|
|
51
|
+
"termcolor>=3.0",
|
|
50
52
|
"satlaspretrain_models>=0.3",
|
|
51
53
|
"scipy>=1.16",
|
|
52
54
|
"terratorch>=1.0.2",
|
|
@@ -1,4 +1,9 @@
|
|
|
1
|
-
"""DinoV3 model.
|
|
1
|
+
"""DinoV3 model.
|
|
2
|
+
|
|
3
|
+
This code loads the DINOv3 model. You must obtain the model separately from Meta to use
|
|
4
|
+
it. See https://github.com/facebookresearch/dinov3 for applicable license and copyright
|
|
5
|
+
information.
|
|
6
|
+
"""
|
|
2
7
|
|
|
3
8
|
from enum import StrEnum
|
|
4
9
|
from pathlib import Path
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
"""Apply center cropping on a feature map."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class FeatureCenterCrop(torch.nn.Module):
    """Keep only a centered window of each input feature map."""

    def __init__(
        self,
        sizes: list[tuple[int, int]],
    ) -> None:
        """Initialize the center-cropping module.

        Each feature map handed to forward is trimmed around its center so that
        only a window of the configured size survives and is passed on to the
        next module.

        Args:
            sizes: one (height, width) pair per input feature map, given in the
                same order as the feature maps passed to forward.
        """
        super().__init__()
        self.sizes = sizes

    def forward(
        self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
    ) -> list[torch.Tensor]:
        """Crop the central window out of every feature map.

        Args:
            features: feature maps, possibly at different resolutions. Dimensions
                2 and 3 of each tensor are treated as height and width.
            inputs: original inputs (unused).

        Returns:
            a new list holding the center-cropped feature maps, in input order.

        Raises:
            ValueError: if a feature map is smaller than its requested size.
        """
        cropped: list[torch.Tensor] = []
        for idx, feature_map in enumerate(features):
            # Index lookup (rather than zip) so a missing size entry still fails.
            crop_h, crop_w = self.sizes[idx]
            if not (feature_map.shape[2] >= crop_h and feature_map.shape[3] >= crop_w):
                raise ValueError(
                    "feature map is smaller than the desired height and width"
                )
            # Offset of the window so that it sits around the map's midpoint.
            top = feature_map.shape[2] // 2 - crop_h // 2
            left = feature_map.shape[3] // 2 - crop_w // 2
            window = feature_map[:, :, top : top + crop_h, left : left + crop_w]
            cropped.append(window)
        return cropped
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""OlmoEarth model architecture."""
|
|
@@ -0,0 +1,263 @@
|
|
|
1
|
+
"""OlmoEarth model wrapper for fine-tuning in rslearn."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from contextlib import nullcontext
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from einops import rearrange
|
|
9
|
+
from olmo_core.config import Config
|
|
10
|
+
from olmo_core.distributed.checkpoint import load_model_and_optim_state
|
|
11
|
+
from olmoearth_pretrain.data.constants import Modality
|
|
12
|
+
from olmoearth_pretrain.model_loader import (
|
|
13
|
+
ModelID,
|
|
14
|
+
load_model_from_id,
|
|
15
|
+
load_model_from_path,
|
|
16
|
+
)
|
|
17
|
+
from olmoearth_pretrain.nn.flexihelios import Encoder, TokensAndMasks
|
|
18
|
+
from olmoearth_pretrain.train.masking import MaskedOlmoEarthSample, MaskValue
|
|
19
|
+
from upath import UPath
|
|
20
|
+
|
|
21
|
+
from rslearn.log_utils import get_logger
|
|
22
|
+
|
|
23
|
+
# Module-level logger; used when loading checkpoints to report whether weights were found.
logger = get_logger(__name__)
|
|
24
|
+
|
|
25
|
+
# Input-dict keys that the OlmoEarth wrapper looks for, in the order they are
# checked; any key absent from the inputs is simply skipped.
MODALITY_NAMES = [
    "sentinel2_l2a",
    "sentinel1",
    "worldcover",
    "openstreetmap_raster",
    "landsat",
]

# Maps the autocast_dtype constructor argument (a string) to the torch dtype of
# the same name.
AUTOCAST_DTYPE_MAP = {
    dtype_name: getattr(torch, dtype_name)
    for dtype_name in ("bfloat16", "float16", "float32")
}
|
|
38
|
+
|
|
39
|
+
# Embedding size reported for each known model ID. The constructor falls back
# to this table when loading by model_id and no explicit embedding_size is set.
EMBEDDING_SIZES = {
    ModelID.OLMOEARTH_V1_NANO: 128,
    ModelID.OLMOEARTH_V1_TINY: 192,
    ModelID.OLMOEARTH_V1_BASE: 768,
}
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class OlmoEarth(torch.nn.Module):
|
|
47
|
+
"""A wrapper to support the OlmoEarth model."""
|
|
48
|
+
|
|
49
|
+
def __init__(
    self,
    patch_size: int,
    model_id: ModelID | None = None,
    model_path: str | None = None,
    checkpoint_path: str | None = None,
    selector: list[str | int] = ["encoder"],
    forward_kwargs: dict[str, Any] = {},
    random_initialization: bool = False,
    embedding_size: int | None = None,
    autocast_dtype: str | None = "bfloat16",
):
    """Create a new OlmoEarth model.

    Args:
        patch_size: token spatial patch size to use.
        model_id: the model ID to load. Exactly one of model_id, model_path, or
            checkpoint_path must be set.
        model_path: the path to load the model from. Exactly one of model_id,
            model_path, or checkpoint_path must be set. Same structure as the
            HF-hosted `model_id` models: bundle with a config.json and weights.pth.
        checkpoint_path: the checkpoint directory to load from, if model_id or
            model_path is not set. It should contain a distributed checkpoint with
            a config.json file as well as model_and_optim folder.
        selector: an optional sequence of attribute names or list indices to select
            the sub-module that should be applied on the input images. Defaults to
            ["encoder"] to select only the transformer encoder.
        forward_kwargs: additional arguments to pass to forward pass besides the
            MaskedOlmoEarthSample. The mapping is copied, so this instance never
            shares it with the caller or with other instances.
        random_initialization: whether to skip loading the checkpoint so the
            weights are randomly initialized. In this case, the checkpoint is only
            used to define the model architecture.
        embedding_size: optional embedding size to report via
            get_backbone_channels (if model_id is not set).
        autocast_dtype: which dtype to use for autocasting (a key of
            AUTOCAST_DTYPE_MAP, e.g. "bfloat16"), or set None to disable.

    Raises:
        ValueError: if not exactly one of model_id, model_path, or
            checkpoint_path is set, or if autocast_dtype is not supported.
    """
    num_sources = sum(
        source is not None for source in (model_id, model_path, checkpoint_path)
    )
    if num_sources != 1:
        raise ValueError(
            "exactly one of model_id, model_path, or checkpoint_path must be set"
        )

    super().__init__()
    self.patch_size = patch_size
    # Copy: the default `{}` is a single dict shared by every call, so storing
    # it directly would let one instance's mutations leak into all others.
    self.forward_kwargs = dict(forward_kwargs)
    self.embedding_size = embedding_size

    if autocast_dtype is None:
        self.autocast_dtype = None
    elif autocast_dtype in AUTOCAST_DTYPE_MAP:
        self.autocast_dtype = AUTOCAST_DTYPE_MAP[autocast_dtype]
    else:
        raise ValueError(
            f"unsupported autocast_dtype {autocast_dtype!r}; expected one of "
            f"{sorted(AUTOCAST_DTYPE_MAP)} or None"
        )

    if model_id is not None:
        # Load from Hugging Face.
        model = load_model_from_id(model_id, load_weights=not random_initialization)
        if self.embedding_size is None and model_id in EMBEDDING_SIZES:
            self.embedding_size = EMBEDDING_SIZES[model_id]
    elif model_path is not None:
        # Load from a model bundle path (config.json + weights.pth).
        model = load_model_from_path(
            UPath(model_path), load_weights=not random_initialization
        )
    else:
        # Load the distributed model checkpoint by path through Olmo Core.
        model = self._load_model_from_checkpoint(
            UPath(checkpoint_path), random_initialization
        )

    # Select just the portion of the model that we actually want to use:
    # strings are attribute names, anything else is used as an index.
    for part in selector:
        model = getattr(model, part) if isinstance(part, str) else model[part]
    self.model = model
|
|
133
|
+
|
|
134
|
+
def _load_model_from_checkpoint(
    self, checkpoint_upath: UPath, random_initialization: bool
) -> torch.nn.Module:
    """Load the OlmoEarth pre-trained model from a distributed checkpoint folder.

    The folder should contain config.json as well as the model_and_optim folder
    that contains the distributed checkpoint. This is the format produced by
    pre-training runs in olmoearth_pretrain.

    Args:
        checkpoint_upath: the checkpoint folder.
        random_initialization: if True, only build the architecture described by
            config.json and skip loading any weights.

    Returns:
        the model built from the "model" section of config.json, with the
        distributed checkpoint weights loaded when they are requested and found.
    """
    # Load the model config and initialize it.
    # We avoid loading the train module here because it depends on running within
    # olmo_core.
    with (checkpoint_upath / "config.json").open() as f:
        config_dict = json.load(f)
    model_config = Config.from_dict(config_dict["model"])
    model = model_config.build()

    if random_initialization:
        return model

    # Load the checkpoint.
    train_module_dir = checkpoint_upath / "model_and_optim"
    if train_module_dir.exists():
        load_model_and_optim_state(str(train_module_dir), model)
        logger.info(f"loaded OlmoEarth encoder from {train_module_dir}")
    else:
        # Pre-trained weights were requested but are missing, so the model keeps
        # whatever weights model_config.build() produced (presumably a fresh
        # initialization). Surface this loudly instead of at info level.
        logger.warning(f"could not find OlmoEarth encoder at {train_module_dir}")

    return model
|
|
162
|
+
|
|
163
|
+
def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
    """Compute feature maps from the OlmoEarth backbone.

    Args:
        inputs: input dicts. It should include keys corresponding to the modalities
            that should be passed to the OlmoEarth model. Each modality tensor is
            laid out as ((T*C), H, W), i.e. timesteps stacked along the channel
            axis.

    Returns:
        a single-element list holding one BCHW feature map (one feature per patch),
        averaged over band sets, timesteps, and present modalities.

    Raises:
        ValueError: if none of the OlmoEarth modalities are present in the inputs.
    """
    kwargs = {}
    present_modalities = []
    device = None
    # Handle the case where some modalities are multitemporal and some are not.
    # We assume all multitemporal modalities have the same number of timesteps.
    max_timesteps = 1
    for modality in MODALITY_NAMES:
        if modality not in inputs[0]:
            continue
        present_modalities.append(modality)
        cur = torch.stack([inp[modality] for inp in inputs], dim=0)
        device = cur.device
        # Check if it's single or multitemporal, and reshape accordingly
        num_bands = Modality.get(modality).num_bands
        num_timesteps = cur.shape[1] // num_bands
        max_timesteps = max(max_timesteps, num_timesteps)
        cur = rearrange(cur, "b (t c) h w -> b h w t c", t=num_timesteps)
        kwargs[modality] = cur
        # Create mask array which is BHWTS (without channels but with band sets).
        num_band_sets = len(Modality.get(modality).band_sets)
        mask_shape = cur.shape[0:4] + (num_band_sets,)
        mask = (
            torch.ones(mask_shape, dtype=torch.int32, device=device)
            * MaskValue.ONLINE_ENCODER.value
        )
        kwargs[f"{modality}_mask"] = mask

    if not present_modalities:
        # Without this check the failure surfaces later as an opaque error from
        # stacking an empty list of features.
        raise ValueError(
            f"none of the OlmoEarth modalities {list(MODALITY_NAMES)} were found in "
            f"the inputs (got keys {list(inputs[0].keys())})"
        )

    # Timestamps is required.
    # Note that only months (0 to 11) are used in OlmoEarth position encoding.
    # For now, we assign same timestamps to all inputs, but later we should handle
    # varying timestamps per input.
    timestamps = torch.zeros(
        (len(inputs), max_timesteps, 3), dtype=torch.int32, device=device
    )
    timestamps[:, :, 0] = 1  # day
    # Wrap the synthetic month index so series longer than 12 timesteps still
    # produce valid months (0 to 11).
    timestamps[:, :, 1] = (
        torch.arange(max_timesteps, device=device)[None, :] % 12
    )  # month
    timestamps[:, :, 2] = 2024  # year
    kwargs["timestamps"] = timestamps

    sample = MaskedOlmoEarthSample(**kwargs)

    # Decide context based on self.autocast_dtype.
    if self.autocast_dtype is None:
        context = nullcontext()
    else:
        assert device is not None
        context = torch.amp.autocast(
            device_type=device.type, dtype=self.autocast_dtype
        )

    with context:
        # Currently we assume the provided model always returns a TokensAndMasks object.
        tokens_and_masks: TokensAndMasks
        if isinstance(self.model, Encoder):
            # Encoder has a fast_pass argument to indicate mask is not needed.
            tokens_and_masks = self.model(
                sample,
                fast_pass=True,
                patch_size=self.patch_size,
                **self.forward_kwargs,
            )["tokens_and_masks"]
        else:
            # Other models like STEncoder do not have this option supported.
            tokens_and_masks = self.model(
                sample, patch_size=self.patch_size, **self.forward_kwargs
            )["tokens_and_masks"]

    # Apply temporal/modality pooling so we just have one feature per patch.
    features = []
    for modality in present_modalities:
        modality_features = getattr(tokens_and_masks, modality)
        # Pool over band sets and timesteps (BHWTSC -> BHWC).
        pooled = modality_features.mean(dim=[3, 4])
        # We want BHWC -> BCHW.
        pooled = rearrange(pooled, "b h w c -> b c h w")
        features.append(pooled)
    # Pool over the modalities, so we get one BCHW feature map.
    pooled = torch.stack(features, dim=0).mean(dim=0)
    return [pooled]
|
|
250
|
+
|
|
251
|
+
def get_backbone_channels(self) -> list:
    """Describe the feature maps this model yields when used as a backbone.

    Each entry is a (downsample_factor, depth) pair for one returned feature map.
    For example, (2, 32) means the feature map is at 1/2 the input resolution and
    carries 32 channels.

    Returns:
        a list of (downsample_factor, depth) tuples, one per feature map.
    """
    # forward() returns a single map with one feature per patch, so the spatial
    # downsampling equals the patch size and the depth is the embedding size.
    downsample_factor = self.patch_size
    depth = self.embedding_size
    return [(downsample_factor, depth)]
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
"""Normalization transforms."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from olmoearth_pretrain.data.normalize import load_computed_config
|
|
7
|
+
|
|
8
|
+
from rslearn.log_utils import get_logger
|
|
9
|
+
from rslearn.train.transforms.transform import Transform
|
|
10
|
+
|
|
11
|
+
# Module-level logger; OlmoEarthNormalize uses it to warn when the deprecated
# config_fname argument is passed.
logger = get_logger(__file__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class OlmoEarthNormalize(Transform):
    """Normalize using OlmoEarth JSON config.

    Each band value x is rescaled as (x - min_val) / (max_val - min_val), where
    min_val = mean - std_multiplier * std and max_val = mean + std_multiplier * std
    are taken from the per-band normalization config. Values outside that range are
    not clipped. Normalization is applied in place on the input tensors.

    For Sentinel-1 data, the values should be converted to decibels before being passed
    to this transform.
    """

    def __init__(
        self,
        band_names: dict[str, list[str]],
        std_multiplier: float | None = 2,
        config_fname: str | None = None,
    ) -> None:
        """Initialize a new OlmoEarthNormalize.

        Args:
            band_names: map from modality name to the list of bands in that modality in
                the order they are being loaded. Note that this order must match the
                expected order for the OlmoEarth model.
            std_multiplier: the std multiplier matching the one used for the model
                training in OlmoEarth. NOTE(review): the annotation allows None, but
                forward() multiplies by this value, so None currently fails with a
                TypeError at normalization time; confirm the intended semantics.
            config_fname: load the normalization configuration from this file, instead
                of getting it from OlmoEarth. Deprecated.
        """
        super().__init__()
        self.band_names = band_names
        self.std_multiplier = std_multiplier

        if config_fname is None:
            self.norm_config = load_computed_config()
        else:
            logger.warning(
                f"Loading normalization config from {config_fname}. This argument is deprecated and will be removed in a future version."
            )
            with open(config_fname) as f:
                self.norm_config = json.load(f)

    def forward(
        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Apply normalization over the inputs and targets.

        Images are expected with timesteps stacked along the first axis, i.e. band
        b of timestep t lives at index t * len(band_names[modality]) + b.

        Args:
            input_dict: the input
            target_dict: the target

        Returns:
            normalized (input_dicts, target_dicts) tuple

        Raises:
            ValueError: if a band in the normalization config is not listed in
                band_names for its modality, or if some loaded band index was not
                covered by the normalization config.
        """
        for modality_name, cur_band_names in self.band_names.items():
            band_norms = self.norm_config[modality_name]
            image = input_dict[modality_name]
            bands_per_timestep = len(cur_band_names)
            # Keep a set of indices to make sure that we normalize all of them.
            needed_band_indices = set(range(image.shape[0]))
            num_timesteps = image.shape[0] // bands_per_timestep

            for band, norm_dict in band_norms.items():
                try:
                    band_offset = cur_band_names.index(band)
                except ValueError:
                    raise ValueError(
                        f"for modality {modality_name}, band {band} from the "
                        f"normalization config is not in band_names {cur_band_names}"
                    ) from None
                # The normalization range does not depend on the timestep, so
                # compute it once per band.
                min_val = norm_dict["mean"] - self.std_multiplier * norm_dict["std"]
                max_val = norm_dict["mean"] + self.std_multiplier * norm_dict["std"]
                # If multitemporal, normalize each timestep separately.
                for t in range(num_timesteps):
                    band_idx = band_offset + t * bands_per_timestep
                    image[band_idx] = (image[band_idx] - min_val) / (max_val - min_val)
                    needed_band_indices.remove(band_idx)

            if len(needed_band_indices) > 0:
                raise ValueError(
                    f"for modality {modality_name}, bands {needed_band_indices} were unexpectedly not normalized"
                )

        return input_dict, target_dict
|
|
@@ -76,3 +76,46 @@ class PoolingDecoder(torch.nn.Module):
|
|
|
76
76
|
features = torch.amax(features, dim=(2, 3))
|
|
77
77
|
features = self.fc_layers(features)
|
|
78
78
|
return self.output_layer(features)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class SegmentationPoolingDecoder(PoolingDecoder):
    """PoolingDecoder variant that broadcasts its single prediction to every pixel.

    The pooled, per-example output is tiled over the spatial extent of the input
    image, so the module can be used with SegmentationTask while still producing
    one global prediction per example. Every pixel receives the same values, so this
    is only sensible for very small windows. The intended use is training a
    classification-like task on small windows and then emitting a raster when
    running inference on large windows.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        image_key: str = "image",
        **kwargs: Any,
    ):
        """Build the decoder.

        Args:
            in_channels: number of channels in the last feature map given to this
                module.
            out_channels: number of channels in the flat output vector.
            image_key: key in each input dict whose image determines the output
                height and width.
            kwargs: additional keyword arguments forwarded to PoolingDecoder.
        """
        super().__init__(in_channels=in_channels, out_channels=out_channels, **kwargs)
        self.image_key = image_key

    def forward(
        self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
    ) -> torch.Tensor:
        """Run PoolingDecoder, then tile its output across the image grid.

        This is only correct when every pixel shares the same segmentation target.
        """
        flat_output = super().forward(features, inputs)
        # Output size is taken from the first example's image, read as (C, H, W).
        height, width = inputs[0][self.image_key].shape[1:3]
        # (B, C) -> (B, C, 1, 1) -> (B, C, H, W).
        spatial = flat_output.unsqueeze(2).unsqueeze(3)
        return spatial.repeat(1, 1, height, width)
|
|
@@ -1,4 +1,12 @@
|
|
|
1
|
-
"""Prithvi V2.
|
|
1
|
+
"""Prithvi V2.
|
|
2
|
+
|
|
3
|
+
This code is adapted from https://github.com/NASA-IMPACT/Prithvi-WxC
|
|
4
|
+
|
|
5
|
+
The code is released under:
|
|
6
|
+
|
|
7
|
+
MIT License
|
|
8
|
+
Copyright (c) 2024 Inter Agency Implementation and Advanced Concepts
|
|
9
|
+
"""
|
|
2
10
|
|
|
3
11
|
import json
|
|
4
12
|
import logging
|
|
@@ -94,7 +94,6 @@ class RslearnLightningModule(L.LightningModule):
|
|
|
94
94
|
restore_config: RestoreConfig | None = None,
|
|
95
95
|
print_parameters: bool = False,
|
|
96
96
|
print_model: bool = False,
|
|
97
|
-
strict_loading: bool = True,
|
|
98
97
|
# Deprecated options.
|
|
99
98
|
lr: float = 1e-3,
|
|
100
99
|
plateau: bool = False,
|
|
@@ -118,7 +117,6 @@ class RslearnLightningModule(L.LightningModule):
|
|
|
118
117
|
print_parameters: whether to print the list of model parameters after model
|
|
119
118
|
initialization
|
|
120
119
|
print_model: whether to print the model after model initialization
|
|
121
|
-
strict_loading: whether to strictly load the model parameters.
|
|
122
120
|
lr: deprecated.
|
|
123
121
|
plateau: deprecated.
|
|
124
122
|
plateau_factor: deprecated.
|
|
@@ -132,7 +130,6 @@ class RslearnLightningModule(L.LightningModule):
|
|
|
132
130
|
self.visualize_dir = visualize_dir
|
|
133
131
|
self.metrics_file = metrics_file
|
|
134
132
|
self.restore_config = restore_config
|
|
135
|
-
self.strict_loading = strict_loading
|
|
136
133
|
|
|
137
134
|
self.scheduler_factory: SchedulerFactory | None = None
|
|
138
135
|
if scheduler:
|
|
@@ -49,8 +49,8 @@ class ClassificationTask(BasicTask):
|
|
|
49
49
|
features with matching properties.
|
|
50
50
|
read_class_id: whether to read an integer class ID instead of the class
|
|
51
51
|
name.
|
|
52
|
-
allow_invalid: instead of throwing error when no
|
|
53
|
-
at a window, simply mark the example invalid for this task
|
|
52
|
+
allow_invalid: instead of throwing error when no classification label is
|
|
53
|
+
found at a window, simply mark the example invalid for this task
|
|
54
54
|
skip_unknown_categories: whether to skip examples with categories that are
|
|
55
55
|
not passed via classes, instead of throwing error
|
|
56
56
|
prob_property: when predicting, write probabilities in addition to class ID
|