rslearn 0.0.11__tar.gz → 0.0.13__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn-0.0.13/NOTICE +115 -0
- {rslearn-0.0.11/rslearn.egg-info → rslearn-0.0.13}/PKG-INFO +3 -2
- {rslearn-0.0.11 → rslearn-0.0.13}/pyproject.toml +2 -2
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/config/dataset.py +23 -4
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/data_sources/planetary_computer.py +52 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/dataset/handler_summaries.py +1 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/dataset/manage.py +16 -2
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/anysat.py +5 -1
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/dinov3.py +6 -1
- rslearn-0.0.13/rslearn/models/feature_center_crop.py +50 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/olmoearth_pretrain/model.py +88 -27
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/prithvi.py +9 -1
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/lightning_module.py +0 -3
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/prediction_writer.py +25 -8
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/tasks/classification.py +2 -2
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/tasks/detection.py +5 -5
- rslearn-0.0.13/rslearn/train/tasks/embedding.py +116 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/tasks/per_pixel_regression.py +5 -4
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/tasks/regression.py +5 -5
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/transforms/pad.py +3 -3
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/utils/raster_format.py +38 -0
- {rslearn-0.0.11 → rslearn-0.0.13/rslearn.egg-info}/PKG-INFO +3 -2
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn.egg-info/SOURCES.txt +3 -9
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn.egg-info/requires.txt +1 -1
- rslearn-0.0.11/rslearn/models/copernicusfm.py +0 -228
- rslearn-0.0.11/rslearn/models/copernicusfm_src/__init__.py +0 -1
- rslearn-0.0.11/rslearn/models/copernicusfm_src/aurora/area.py +0 -50
- rslearn-0.0.11/rslearn/models/copernicusfm_src/aurora/fourier.py +0 -134
- rslearn-0.0.11/rslearn/models/copernicusfm_src/dynamic_hypernetwork.py +0 -523
- rslearn-0.0.11/rslearn/models/copernicusfm_src/flexivit/patch_embed.py +0 -260
- rslearn-0.0.11/rslearn/models/copernicusfm_src/flexivit/utils.py +0 -69
- rslearn-0.0.11/rslearn/models/copernicusfm_src/model_vit.py +0 -348
- rslearn-0.0.11/rslearn/models/copernicusfm_src/util/pos_embed.py +0 -216
- {rslearn-0.0.11 → rslearn-0.0.13}/LICENSE +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/README.md +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/__init__.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/arg_parser.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/config/__init__.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/const.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/data_sources/__init__.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/data_sources/aws_landsat.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/data_sources/aws_open_data.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/data_sources/aws_sentinel1.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/data_sources/climate_data_store.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/data_sources/copernicus.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/data_sources/data_source.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/data_sources/earthdaily.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/data_sources/earthdata_srtm.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/data_sources/eurocrops.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/data_sources/gcp_public_data.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/data_sources/geotiff.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/data_sources/google_earth_engine.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/data_sources/local_files.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/data_sources/openstreetmap.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/data_sources/planet.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/data_sources/planet_basemap.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/data_sources/raster_source.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/data_sources/usda_cdl.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/data_sources/usgs_landsat.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/data_sources/utils.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/data_sources/vector_source.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/data_sources/worldcereal.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/data_sources/worldcover.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/data_sources/worldpop.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/data_sources/xyz_tiles.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/dataset/__init__.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/dataset/add_windows.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/dataset/dataset.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/dataset/index.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/dataset/materialize.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/dataset/remap.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/dataset/window.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/log_utils.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/main.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/__init__.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/clay/clay.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/clay/configs/metadata.yaml +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/clip.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/conv.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/croma.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/detr/__init__.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/detr/box_ops.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/detr/detr.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/detr/matcher.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/detr/position_encoding.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/detr/transformer.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/detr/util.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/faster_rcnn.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/fpn.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/galileo/__init__.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/galileo/galileo.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/galileo/single_file_galileo.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/module_wrapper.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/molmo.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/multitask.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/olmoearth_pretrain/__init__.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/olmoearth_pretrain/norm.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/panopticon.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/panopticon_data/sensors/drone.yaml +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/panopticon_data/sensors/enmap.yaml +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/panopticon_data/sensors/goes.yaml +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/panopticon_data/sensors/himawari.yaml +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/panopticon_data/sensors/intuition.yaml +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/panopticon_data/sensors/landsat8.yaml +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/panopticon_data/sensors/modis_terra.yaml +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/panopticon_data/sensors/sentinel1.yaml +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/panopticon_data/sensors/sentinel2.yaml +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/panopticon_data/sensors/superdove.yaml +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/panopticon_data/sensors/wv23.yaml +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/pick_features.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/pooling_decoder.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/presto/__init__.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/presto/presto.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/presto/single_file_presto.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/registry.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/resize_features.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/sam2_enc.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/satlaspretrain.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/simple_time_series.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/singletask.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/ssl4eo_s12.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/swin.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/task_embedding.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/terramind.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/trunk.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/unet.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/upsample.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/models/use_croma.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/py.typed +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/template_params.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/tile_stores/__init__.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/tile_stores/default.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/tile_stores/tile_store.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/__init__.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/callbacks/__init__.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/callbacks/adapters.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/callbacks/freeze_unfreeze.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/callbacks/gradients.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/callbacks/peft.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/data_module.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/dataset.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/optimizer.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/scheduler.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/tasks/__init__.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/tasks/multi_task.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/tasks/segmentation.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/tasks/task.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/transforms/__init__.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/transforms/concatenate.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/transforms/crop.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/transforms/flip.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/transforms/mask.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/transforms/normalize.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/transforms/select_bands.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/transforms/sentinel1.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/train/transforms/transform.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/utils/__init__.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/utils/array.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/utils/feature.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/utils/fsspec.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/utils/geometry.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/utils/get_utm_ups_crs.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/utils/grid_index.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/utils/jsonargparse.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/utils/mp.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/utils/rtree_index.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/utils/spatial_index.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/utils/sqlite_index.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/utils/time.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn/utils/vector_format.py +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn.egg-info/dependency_links.txt +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn.egg-info/entry_points.txt +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/rslearn.egg-info/top_level.txt +0 -0
- {rslearn-0.0.11 → rslearn-0.0.13}/setup.cfg +0 -0
rslearn-0.0.13/NOTICE
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
rslearn is released under Apache License 2.0
|
|
2
|
+
Copyright 2025 Allen Institute for AI
|
|
3
|
+
|
|
4
|
+
The following third party code is included in this repository.
|
|
5
|
+
|
|
6
|
+
====================
|
|
7
|
+
|
|
8
|
+
rslearn.models.detr is adapted from https://github.com/facebookresearch/detr which is
|
|
9
|
+
released under Apache License 2.0.
|
|
10
|
+
|
|
11
|
+
Copyright 2020 - present, Facebook, Inc
|
|
12
|
+
|
|
13
|
+
====================
|
|
14
|
+
|
|
15
|
+
rslearn.models.use_croma is copied from https://github.com/antofuller/CROMA
|
|
16
|
+
|
|
17
|
+
MIT License
|
|
18
|
+
|
|
19
|
+
Copyright (c) 2023 Anthony Fuller
|
|
20
|
+
|
|
21
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
22
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
23
|
+
in the Software without restriction, including without limitation the rights
|
|
24
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
25
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
26
|
+
furnished to do so, subject to the following conditions:
|
|
27
|
+
|
|
28
|
+
The above copyright notice and this permission notice shall be included in all
|
|
29
|
+
copies or substantial portions of the Software.
|
|
30
|
+
|
|
31
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
32
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
33
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
34
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
35
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
36
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
37
|
+
SOFTWARE.
|
|
38
|
+
|
|
39
|
+
====================
|
|
40
|
+
|
|
41
|
+
rslearn.models.galileo is adapted from https://github.com/nasaharvest/galileo
|
|
42
|
+
|
|
43
|
+
MIT License
|
|
44
|
+
|
|
45
|
+
Copyright (c) 2024 Presto Authors
|
|
46
|
+
|
|
47
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
48
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
49
|
+
in the Software without restriction, including without limitation the rights
|
|
50
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
51
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
52
|
+
furnished to do so, subject to the following conditions:
|
|
53
|
+
|
|
54
|
+
The above copyright notice and this permission notice shall be included in all
|
|
55
|
+
copies or substantial portions of the Software.
|
|
56
|
+
|
|
57
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
58
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
59
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
60
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
61
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
62
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
63
|
+
SOFTWARE.
|
|
64
|
+
|
|
65
|
+
====================
|
|
66
|
+
|
|
67
|
+
rslearn.models.presto is adapted from https://github.com/nasaharvest/presto
|
|
68
|
+
|
|
69
|
+
MIT License
|
|
70
|
+
|
|
71
|
+
Copyright (c) 2024 Presto Authors
|
|
72
|
+
|
|
73
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
74
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
75
|
+
in the Software without restriction, including without limitation the rights
|
|
76
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
77
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
78
|
+
furnished to do so, subject to the following conditions:
|
|
79
|
+
|
|
80
|
+
The above copyright notice and this permission notice shall be included in all
|
|
81
|
+
copies or substantial portions of the Software.
|
|
82
|
+
|
|
83
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
84
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
85
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
86
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
87
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
88
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
89
|
+
SOFTWARE.
|
|
90
|
+
|
|
91
|
+
====================
|
|
92
|
+
|
|
93
|
+
rslearn.models.prithvi includes code adapted from https://github.com/NASA-IMPACT/Prithvi-WxC
|
|
94
|
+
|
|
95
|
+
MIT License
|
|
96
|
+
|
|
97
|
+
Copyright (c) 2024 Inter Agency Implementation and Advanced Concepts
|
|
98
|
+
|
|
99
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
100
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
101
|
+
in the Software without restriction, including without limitation the rights
|
|
102
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
103
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
104
|
+
furnished to do so, subject to the following conditions:
|
|
105
|
+
|
|
106
|
+
The above copyright notice and this permission notice shall be included in all
|
|
107
|
+
copies or substantial portions of the Software.
|
|
108
|
+
|
|
109
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
110
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
111
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
112
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
113
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
114
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
115
|
+
SOFTWARE.
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: rslearn
|
|
3
|
-
Version: 0.0.11
|
|
3
|
+
Version: 0.0.13
|
|
4
4
|
Summary: A library for developing remote sensing datasets and models
|
|
5
5
|
Author: OlmoEarth Team
|
|
6
6
|
License: Apache License
|
|
@@ -211,9 +211,10 @@ Project-URL: repository, https://github.com/allenai/rslearn
|
|
|
211
211
|
Requires-Python: >=3.11
|
|
212
212
|
Description-Content-Type: text/markdown
|
|
213
213
|
License-File: LICENSE
|
|
214
|
+
License-File: NOTICE
|
|
214
215
|
Requires-Dist: boto3>=1.39
|
|
215
216
|
Requires-Dist: fiona>=1.10
|
|
216
|
-
Requires-Dist: fsspec>=2025.
|
|
217
|
+
Requires-Dist: fsspec>=2025.10.0
|
|
217
218
|
Requires-Dist: jsonargparse>=4.35.0
|
|
218
219
|
Requires-Dist: lightning>=2.5.1.post0
|
|
219
220
|
Requires-Dist: Pillow>=11.3
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "rslearn"
|
|
3
|
-
version = "0.0.11"
|
|
3
|
+
version = "0.0.13"
|
|
4
4
|
description = "A library for developing remote sensing datasets and models"
|
|
5
5
|
authors = [
|
|
6
6
|
{ name = "OlmoEarth Team" },
|
|
@@ -11,7 +11,7 @@ requires-python = ">=3.11"
|
|
|
11
11
|
dependencies = [
|
|
12
12
|
"boto3>=1.39",
|
|
13
13
|
"fiona>=1.10",
|
|
14
|
-
"fsspec>=2025.
|
|
14
|
+
"fsspec>=2025.10.0", # this is used both directly and indirectly (via universal_pathlib) in our code
|
|
15
15
|
"jsonargparse>=4.35.0",
|
|
16
16
|
"lightning>=2.5.1.post0",
|
|
17
17
|
"Pillow>=11.3",
|
|
@@ -125,7 +125,8 @@ class BandSetConfig:
|
|
|
125
125
|
self,
|
|
126
126
|
config_dict: dict[str, Any],
|
|
127
127
|
dtype: DType,
|
|
128
|
-
bands: list[str],
|
|
128
|
+
bands: list[str] | None = None,
|
|
129
|
+
num_bands: int | None = None,
|
|
129
130
|
format: dict[str, Any] | None = None,
|
|
130
131
|
zoom_offset: int = 0,
|
|
131
132
|
remap: dict[str, Any] | None = None,
|
|
@@ -137,7 +138,10 @@ class BandSetConfig:
|
|
|
137
138
|
Args:
|
|
138
139
|
config_dict: the config dict used to configure this BandSetConfig
|
|
139
140
|
dtype: the pixel value type to store tiles in
|
|
140
|
-
bands: list of band names in this BandSetConfig
|
|
141
|
+
bands: list of band names in this BandSetConfig. One of bands or num_bands
|
|
142
|
+
must be set.
|
|
143
|
+
num_bands: the number of bands in this band set. The bands will be named
|
|
144
|
+
B00, B01, B02, etc.
|
|
141
145
|
format: the format to store tiles in, defaults to geotiff
|
|
142
146
|
zoom_offset: store images at a resolution higher or lower than the window
|
|
143
147
|
resolution. This enables keeping source data at its native resolution,
|
|
@@ -155,6 +159,14 @@ class BandSetConfig:
|
|
|
155
159
|
materialization when creating mosaics, to determine which parts of the
|
|
156
160
|
source images should be copied.
|
|
157
161
|
"""
|
|
162
|
+
if (bands is None and num_bands is None) or (
|
|
163
|
+
bands is not None and num_bands is not None
|
|
164
|
+
):
|
|
165
|
+
raise ValueError("exactly one of bands and num_bands must be set")
|
|
166
|
+
if bands is None:
|
|
167
|
+
assert num_bands is not None
|
|
168
|
+
bands = [f"B{idx}" for idx in range(num_bands)]
|
|
169
|
+
|
|
158
170
|
if class_names is not None and len(bands) != len(class_names):
|
|
159
171
|
raise ValueError(
|
|
160
172
|
f"the number of class lists ({len(class_names)}) does not match the number of bands ({len(bands)})"
|
|
@@ -187,9 +199,16 @@ class BandSetConfig:
|
|
|
187
199
|
kwargs = dict(
|
|
188
200
|
config_dict=config,
|
|
189
201
|
dtype=DType(config["dtype"]),
|
|
190
|
-
bands=config["bands"],
|
|
191
202
|
)
|
|
192
|
-
for k in [
|
|
203
|
+
for k in [
|
|
204
|
+
"bands",
|
|
205
|
+
"num_bands",
|
|
206
|
+
"format",
|
|
207
|
+
"zoom_offset",
|
|
208
|
+
"remap",
|
|
209
|
+
"class_names",
|
|
210
|
+
"nodata_vals",
|
|
211
|
+
]:
|
|
193
212
|
if k in config:
|
|
194
213
|
kwargs[k] = config[k]
|
|
195
214
|
return BandSetConfig(**kwargs) # type: ignore
|
|
@@ -827,3 +827,55 @@ class Sentinel1(PlanetaryComputer):
|
|
|
827
827
|
kwargs[k] = d[k]
|
|
828
828
|
|
|
829
829
|
return Sentinel1(**kwargs)
|
|
830
|
+
|
|
831
|
+
|
|
832
|
+
class Naip(PlanetaryComputer):
|
|
833
|
+
"""A data source for NAIP data on Microsoft Planetary Computer.
|
|
834
|
+
|
|
835
|
+
See https://planetarycomputer.microsoft.com/dataset/naip.
|
|
836
|
+
"""
|
|
837
|
+
|
|
838
|
+
COLLECTION_NAME = "naip"
|
|
839
|
+
ASSET_BANDS = {"image": ["R", "G", "B", "NIR"]}
|
|
840
|
+
|
|
841
|
+
def __init__(
|
|
842
|
+
self,
|
|
843
|
+
**kwargs: Any,
|
|
844
|
+
):
|
|
845
|
+
"""Initialize a new Naip instance.
|
|
846
|
+
|
|
847
|
+
Args:
|
|
848
|
+
band_names: list of bands to try to ingest.
|
|
849
|
+
kwargs: additional arguments to pass to PlanetaryComputer.
|
|
850
|
+
"""
|
|
851
|
+
super().__init__(
|
|
852
|
+
collection_name=self.COLLECTION_NAME,
|
|
853
|
+
asset_bands=self.ASSET_BANDS,
|
|
854
|
+
**kwargs,
|
|
855
|
+
)
|
|
856
|
+
|
|
857
|
+
@staticmethod
|
|
858
|
+
def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Naip":
|
|
859
|
+
"""Creates a new Naip instance from a configuration dictionary."""
|
|
860
|
+
if config.data_source is None:
|
|
861
|
+
raise ValueError("config.data_source is required")
|
|
862
|
+
d = config.data_source.config_dict
|
|
863
|
+
kwargs = {}
|
|
864
|
+
|
|
865
|
+
if "timeout_seconds" in d:
|
|
866
|
+
kwargs["timeout"] = timedelta(seconds=d["timeout_seconds"])
|
|
867
|
+
|
|
868
|
+
if "cache_dir" in d:
|
|
869
|
+
kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
|
|
870
|
+
|
|
871
|
+
simple_optionals = [
|
|
872
|
+
"query",
|
|
873
|
+
"sort_by",
|
|
874
|
+
"sort_ascending",
|
|
875
|
+
"max_items_per_client",
|
|
876
|
+
]
|
|
877
|
+
for k in simple_optionals:
|
|
878
|
+
if k in d:
|
|
879
|
+
kwargs[k] = d[k]
|
|
880
|
+
|
|
881
|
+
return Naip(**kwargs)
|
|
@@ -118,6 +118,7 @@ def prepare_dataset_windows(
|
|
|
118
118
|
duration_seconds=time.monotonic() - layer_start_time,
|
|
119
119
|
windows_prepared=0,
|
|
120
120
|
windows_skipped=len(windows),
|
|
121
|
+
windows_rejected=0,
|
|
121
122
|
get_items_attempts=0,
|
|
122
123
|
)
|
|
123
124
|
)
|
|
@@ -141,6 +142,7 @@ def prepare_dataset_windows(
|
|
|
141
142
|
duration_seconds=time.monotonic() - layer_start_time,
|
|
142
143
|
windows_prepared=0,
|
|
143
144
|
windows_skipped=len(windows),
|
|
145
|
+
windows_rejected=0,
|
|
144
146
|
get_items_attempts=0,
|
|
145
147
|
)
|
|
146
148
|
)
|
|
@@ -181,6 +183,9 @@ def prepare_dataset_windows(
|
|
|
181
183
|
attempts_counter=attempts_counter,
|
|
182
184
|
)
|
|
183
185
|
|
|
186
|
+
windows_prepared = 0
|
|
187
|
+
windows_rejected = 0
|
|
188
|
+
min_matches = data_source_cfg.query_config.min_matches
|
|
184
189
|
for window, result in zip(needed_windows, results):
|
|
185
190
|
layer_datas = window.load_layer_datas()
|
|
186
191
|
layer_datas[layer_name] = WindowLayerData(
|
|
@@ -191,13 +196,22 @@ def prepare_dataset_windows(
|
|
|
191
196
|
)
|
|
192
197
|
window.save_layer_datas(layer_datas)
|
|
193
198
|
|
|
199
|
+
# If result is empty and min_matches > 0, window was rejected due to min_matches
|
|
200
|
+
if len(result) == 0 and min_matches > 0:
|
|
201
|
+
windows_rejected += 1
|
|
202
|
+
else:
|
|
203
|
+
windows_prepared += 1
|
|
204
|
+
|
|
205
|
+
windows_skipped = len(windows) - len(needed_windows)
|
|
206
|
+
|
|
194
207
|
layer_summaries.append(
|
|
195
208
|
LayerPrepareSummary(
|
|
196
209
|
layer_name=layer_name,
|
|
197
210
|
data_source_name=data_source_cfg.name,
|
|
198
211
|
duration_seconds=time.monotonic() - layer_start_time,
|
|
199
|
-
windows_prepared=
|
|
200
|
-
windows_skipped=
|
|
212
|
+
windows_prepared=windows_prepared,
|
|
213
|
+
windows_skipped=windows_skipped,
|
|
214
|
+
windows_rejected=windows_rejected,
|
|
201
215
|
get_items_attempts=attempts_counter.value,
|
|
202
216
|
)
|
|
203
217
|
)
|
|
@@ -1,4 +1,9 @@
|
|
|
1
|
-
"""DinoV3 model.
|
|
1
|
+
"""DinoV3 model.
|
|
2
|
+
|
|
3
|
+
This code loads the DINOv3 model. You must obtain the model separately from Meta to use
|
|
4
|
+
it. See https://github.com/facebookresearch/dinov3 for applicable license and copyright
|
|
5
|
+
information.
|
|
6
|
+
"""
|
|
2
7
|
|
|
3
8
|
from enum import StrEnum
|
|
4
9
|
from pathlib import Path
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
"""Apply center cropping on a feature map."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class FeatureCenterCrop(torch.nn.Module):
|
|
9
|
+
"""Apply center cropping on the input feature maps."""
|
|
10
|
+
|
|
11
|
+
def __init__(
|
|
12
|
+
self,
|
|
13
|
+
sizes: list[tuple[int, int]],
|
|
14
|
+
) -> None:
|
|
15
|
+
"""Create a new FeatureCenterCrop.
|
|
16
|
+
|
|
17
|
+
Only the center of each feature map will be retained and passed to the next
|
|
18
|
+
module.
|
|
19
|
+
|
|
20
|
+
Args:
|
|
21
|
+
sizes: a list of (height, width) tuples, with one tuple for each input
|
|
22
|
+
feature map.
|
|
23
|
+
"""
|
|
24
|
+
super().__init__()
|
|
25
|
+
self.sizes = sizes
|
|
26
|
+
|
|
27
|
+
def forward(
|
|
28
|
+
self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
|
|
29
|
+
) -> list[torch.Tensor]:
|
|
30
|
+
"""Apply center cropping on the feature maps.
|
|
31
|
+
|
|
32
|
+
Args:
|
|
33
|
+
features: list of feature maps at different resolutions.
|
|
34
|
+
inputs: original inputs (ignored).
|
|
35
|
+
|
|
36
|
+
Returns:
|
|
37
|
+
center cropped feature maps.
|
|
38
|
+
"""
|
|
39
|
+
new_features = []
|
|
40
|
+
for i, feat in enumerate(features):
|
|
41
|
+
height, width = self.sizes[i]
|
|
42
|
+
if feat.shape[2] < height or feat.shape[3] < width:
|
|
43
|
+
raise ValueError(
|
|
44
|
+
"feature map is smaller than the desired height and width"
|
|
45
|
+
)
|
|
46
|
+
start_h = feat.shape[2] // 2 - height // 2
|
|
47
|
+
start_w = feat.shape[3] // 2 - width // 2
|
|
48
|
+
feat = feat[:, :, start_h : start_h + height, start_w : start_w + width]
|
|
49
|
+
new_features.append(feat)
|
|
50
|
+
return new_features
|
|
@@ -9,6 +9,11 @@ from einops import rearrange
|
|
|
9
9
|
from olmo_core.config import Config
|
|
10
10
|
from olmo_core.distributed.checkpoint import load_model_and_optim_state
|
|
11
11
|
from olmoearth_pretrain.data.constants import Modality
|
|
12
|
+
from olmoearth_pretrain.model_loader import (
|
|
13
|
+
ModelID,
|
|
14
|
+
load_model_from_id,
|
|
15
|
+
load_model_from_path,
|
|
16
|
+
)
|
|
12
17
|
from olmoearth_pretrain.nn.flexihelios import Encoder, TokensAndMasks
|
|
13
18
|
from olmoearth_pretrain.train.masking import MaskedOlmoEarthSample, MaskValue
|
|
14
19
|
from upath import UPath
|
|
@@ -31,54 +36,115 @@ AUTOCAST_DTYPE_MAP = {
|
|
|
31
36
|
"float32": torch.float32,
|
|
32
37
|
}
|
|
33
38
|
|
|
39
|
+
EMBEDDING_SIZES = {
|
|
40
|
+
ModelID.OLMOEARTH_V1_NANO: 128,
|
|
41
|
+
ModelID.OLMOEARTH_V1_TINY: 192,
|
|
42
|
+
ModelID.OLMOEARTH_V1_BASE: 768,
|
|
43
|
+
ModelID.OLMOEARTH_V1_LARGE: 1024,
|
|
44
|
+
}
|
|
45
|
+
|
|
34
46
|
|
|
35
47
|
class OlmoEarth(torch.nn.Module):
|
|
36
48
|
"""A wrapper to support the OlmoEarth model."""
|
|
37
49
|
|
|
38
50
|
def __init__(
|
|
39
51
|
self,
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
52
|
+
patch_size: int,
|
|
53
|
+
model_id: ModelID | None = None,
|
|
54
|
+
model_path: str | None = None,
|
|
55
|
+
checkpoint_path: str | None = None,
|
|
56
|
+
selector: list[str | int] = ["encoder"],
|
|
44
57
|
forward_kwargs: dict[str, Any] = {},
|
|
45
58
|
random_initialization: bool = False,
|
|
46
59
|
embedding_size: int | None = None,
|
|
47
|
-
patch_size: int | None = None,
|
|
48
60
|
autocast_dtype: str | None = "bfloat16",
|
|
49
61
|
):
|
|
50
62
|
"""Create a new OlmoEarth model.
|
|
51
63
|
|
|
52
64
|
Args:
|
|
53
|
-
|
|
54
|
-
|
|
65
|
+
patch_size: token spatial patch size to use.
|
|
66
|
+
model_id: the model ID to load. One of model_id or model_path or checkpoint_path must be
|
|
67
|
+
set.
|
|
68
|
+
model_path: the path to load the model from. One of model_id or model_path or checkpoint_path must be
|
|
69
|
+
set. Same structure as the HF-hosted `model_id` models: bundle with a config.json and weights.pth.
|
|
70
|
+
checkpoint_path: the checkpoint directory to load from, if model_id or model_path is not
|
|
71
|
+
set. It should contain a distributed checkpoint with a config.json file as well as model_and_optim
|
|
72
|
+
folder.
|
|
55
73
|
selector: an optional sequence of attribute names or list indices to select
|
|
56
|
-
the sub-module that should be applied on the input images.
|
|
74
|
+
the sub-module that should be applied on the input images. Defaults to
|
|
75
|
+
["encoder"] to select only the transformer encoder.
|
|
57
76
|
forward_kwargs: additional arguments to pass to forward pass besides the
|
|
58
77
|
MaskedOlmoEarthSample.
|
|
59
78
|
random_initialization: whether to skip loading the checkpoint so the
|
|
60
79
|
weights are randomly initialized. In this case, the checkpoint is only
|
|
61
80
|
used to define the model architecture.
|
|
62
81
|
embedding_size: optional embedding size to report via
|
|
63
|
-
get_backbone_channels.
|
|
64
|
-
patch_size: optional patch size to report via get_backbone_channels.
|
|
82
|
+
get_backbone_channels (if model_id is not set).
|
|
65
83
|
autocast_dtype: which dtype to use for autocasting, or set None to disable.
|
|
66
84
|
"""
|
|
85
|
+
if (
|
|
86
|
+
sum(
|
|
87
|
+
[
|
|
88
|
+
model_id is not None,
|
|
89
|
+
model_path is not None,
|
|
90
|
+
checkpoint_path is not None,
|
|
91
|
+
]
|
|
92
|
+
)
|
|
93
|
+
!= 1
|
|
94
|
+
):
|
|
95
|
+
raise ValueError(
|
|
96
|
+
"exactly one of model_id, model_path, or checkpoint_path must be set"
|
|
97
|
+
)
|
|
98
|
+
|
|
67
99
|
super().__init__()
|
|
68
|
-
|
|
100
|
+
self.patch_size = patch_size
|
|
69
101
|
self.forward_kwargs = forward_kwargs
|
|
70
102
|
self.embedding_size = embedding_size
|
|
71
|
-
self.patch_size = patch_size
|
|
72
103
|
|
|
73
104
|
if autocast_dtype is not None:
|
|
74
105
|
self.autocast_dtype = AUTOCAST_DTYPE_MAP[autocast_dtype]
|
|
75
106
|
else:
|
|
76
107
|
self.autocast_dtype = None
|
|
77
108
|
|
|
109
|
+
if model_id is not None:
|
|
110
|
+
# Load from Hugging Face.
|
|
111
|
+
model = load_model_from_id(model_id, load_weights=not random_initialization)
|
|
112
|
+
if self.embedding_size is None and model_id in EMBEDDING_SIZES:
|
|
113
|
+
self.embedding_size = EMBEDDING_SIZES[model_id]
|
|
114
|
+
|
|
115
|
+
elif model_path is not None:
|
|
116
|
+
# Load from path.
|
|
117
|
+
model = load_model_from_path(
|
|
118
|
+
UPath(model_path), load_weights=not random_initialization
|
|
119
|
+
)
|
|
120
|
+
|
|
121
|
+
else:
|
|
122
|
+
# Load the distributed model checkpoint by path through Olmo Core
|
|
123
|
+
model = self._load_model_from_checkpoint(
|
|
124
|
+
UPath(checkpoint_path), random_initialization
|
|
125
|
+
)
|
|
126
|
+
|
|
127
|
+
# Select just the portion of the model that we actually want to use.
|
|
128
|
+
for part in selector:
|
|
129
|
+
if isinstance(part, str):
|
|
130
|
+
model = getattr(model, part)
|
|
131
|
+
else:
|
|
132
|
+
model = model[part]
|
|
133
|
+
self.model = model
|
|
134
|
+
|
|
135
|
+
def _load_model_from_checkpoint(
|
|
136
|
+
self, checkpoint_upath: UPath, random_initialization: bool
|
|
137
|
+
) -> torch.nn.Module:
|
|
138
|
+
"""Load the OlmoEarth pre-trained model from a distributed checkpoint folder.
|
|
139
|
+
|
|
140
|
+
The folder should contain config.json as well as the model_and_optim folder
|
|
141
|
+
that contains the distributed checkpoint. This is the format produced by
|
|
142
|
+
pre-training runs in olmoearth_pretrain.
|
|
143
|
+
"""
|
|
78
144
|
# Load the model config and initialize it.
|
|
79
145
|
# We avoid loading the train module here because it depends on running within
|
|
80
146
|
# olmo_core.
|
|
81
|
-
with (
|
|
147
|
+
with (checkpoint_upath / "config.json").open() as f:
|
|
82
148
|
config_dict = json.load(f)
|
|
83
149
|
model_config = Config.from_dict(config_dict["model"])
|
|
84
150
|
|
|
@@ -86,22 +152,14 @@ class OlmoEarth(torch.nn.Module):
|
|
|
86
152
|
|
|
87
153
|
# Load the checkpoint.
|
|
88
154
|
if not random_initialization:
|
|
89
|
-
train_module_dir =
|
|
155
|
+
train_module_dir = checkpoint_upath / "model_and_optim"
|
|
90
156
|
if train_module_dir.exists():
|
|
91
157
|
load_model_and_optim_state(str(train_module_dir), model)
|
|
92
158
|
logger.info(f"loaded OlmoEarth encoder from {train_module_dir}")
|
|
93
159
|
else:
|
|
94
160
|
logger.info(f"could not find OlmoEarth encoder at {train_module_dir}")
|
|
95
|
-
else:
|
|
96
|
-
logger.info("skipping loading OlmoEarth encoder")
|
|
97
161
|
|
|
98
|
-
|
|
99
|
-
for part in selector:
|
|
100
|
-
if isinstance(part, str):
|
|
101
|
-
model = getattr(model, part)
|
|
102
|
-
else:
|
|
103
|
-
model = model[part]
|
|
104
|
-
self.model = model
|
|
162
|
+
return model
|
|
105
163
|
|
|
106
164
|
def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
|
|
107
165
|
"""Compute feature maps from the OlmoEarth backbone.
|
|
@@ -167,13 +225,16 @@ class OlmoEarth(torch.nn.Module):
|
|
|
167
225
|
if isinstance(self.model, Encoder):
|
|
168
226
|
# Encoder has a fast_pass argument to indicate mask is not needed.
|
|
169
227
|
tokens_and_masks = self.model(
|
|
170
|
-
sample,
|
|
228
|
+
sample,
|
|
229
|
+
fast_pass=True,
|
|
230
|
+
patch_size=self.patch_size,
|
|
231
|
+
**self.forward_kwargs,
|
|
171
232
|
)["tokens_and_masks"]
|
|
172
233
|
else:
|
|
173
234
|
# Other models like STEncoder do not have this option supported.
|
|
174
|
-
tokens_and_masks = self.model(
|
|
175
|
-
|
|
176
|
-
]
|
|
235
|
+
tokens_and_masks = self.model(
|
|
236
|
+
sample, patch_size=self.patch_size, **self.forward_kwargs
|
|
237
|
+
)["tokens_and_masks"]
|
|
177
238
|
|
|
178
239
|
# Apply temporal/modality pooling so we just have one feature per patch.
|
|
179
240
|
features = []
|
|
@@ -1,4 +1,12 @@
|
|
|
1
|
-
"""Prithvi V2.
|
|
1
|
+
"""Prithvi V2.
|
|
2
|
+
|
|
3
|
+
This code is adapted from https://github.com/NASA-IMPACT/Prithvi-WxC
|
|
4
|
+
|
|
5
|
+
The code is released under:
|
|
6
|
+
|
|
7
|
+
MIT License
|
|
8
|
+
Copyright (c) 2024 Inter Agency Implementation and Advanced Concepts
|
|
9
|
+
"""
|
|
2
10
|
|
|
3
11
|
import json
|
|
4
12
|
import logging
|
|
@@ -94,7 +94,6 @@ class RslearnLightningModule(L.LightningModule):
|
|
|
94
94
|
restore_config: RestoreConfig | None = None,
|
|
95
95
|
print_parameters: bool = False,
|
|
96
96
|
print_model: bool = False,
|
|
97
|
-
strict_loading: bool = True,
|
|
98
97
|
# Deprecated options.
|
|
99
98
|
lr: float = 1e-3,
|
|
100
99
|
plateau: bool = False,
|
|
@@ -118,7 +117,6 @@ class RslearnLightningModule(L.LightningModule):
|
|
|
118
117
|
print_parameters: whether to print the list of model parameters after model
|
|
119
118
|
initialization
|
|
120
119
|
print_model: whether to print the model after model initialization
|
|
121
|
-
strict_loading: whether to strictly load the model parameters.
|
|
122
120
|
lr: deprecated.
|
|
123
121
|
plateau: deprecated.
|
|
124
122
|
plateau_factor: deprecated.
|
|
@@ -132,7 +130,6 @@ class RslearnLightningModule(L.LightningModule):
|
|
|
132
130
|
self.visualize_dir = visualize_dir
|
|
133
131
|
self.metrics_file = metrics_file
|
|
134
132
|
self.restore_config = restore_config
|
|
135
|
-
self.strict_loading = strict_loading
|
|
136
133
|
|
|
137
134
|
self.scheduler_factory: SchedulerFactory | None = None
|
|
138
135
|
if scheduler:
|