rslearn 0.0.4__tar.gz → 0.0.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rslearn-0.0.4/rslearn.egg-info → rslearn-0.0.6}/PKG-INFO +4 -1
- {rslearn-0.0.4 → rslearn-0.0.6}/pyproject.toml +6 -2
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/arg_parser.py +1 -22
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/data_sources/copernicus.py +6 -4
- rslearn-0.0.6/rslearn/data_sources/eurocrops.py +246 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/data_sources/local_files.py +11 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/data_sources/openstreetmap.py +2 -4
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/dataset/dataset.py +4 -1
- rslearn-0.0.6/rslearn/models/copernicusfm.py +216 -0
- rslearn-0.0.6/rslearn/models/copernicusfm_src/__init__.py +1 -0
- rslearn-0.0.6/rslearn/models/copernicusfm_src/aurora/area.py +50 -0
- rslearn-0.0.6/rslearn/models/copernicusfm_src/aurora/fourier.py +134 -0
- rslearn-0.0.6/rslearn/models/copernicusfm_src/dynamic_hypernetwork.py +523 -0
- rslearn-0.0.6/rslearn/models/copernicusfm_src/flexivit/patch_embed.py +260 -0
- rslearn-0.0.6/rslearn/models/copernicusfm_src/flexivit/utils.py +69 -0
- rslearn-0.0.6/rslearn/models/copernicusfm_src/model_vit.py +348 -0
- rslearn-0.0.6/rslearn/models/copernicusfm_src/util/pos_embed.py +216 -0
- rslearn-0.0.6/rslearn/models/panopticon.py +167 -0
- rslearn-0.0.6/rslearn/models/presto/__init__.py +5 -0
- rslearn-0.0.6/rslearn/models/presto/presto.py +247 -0
- rslearn-0.0.6/rslearn/models/presto/single_file_presto.py +932 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/unet.py +15 -0
- rslearn-0.0.6/rslearn/template_params.py +26 -0
- {rslearn-0.0.4 → rslearn-0.0.6/rslearn.egg-info}/PKG-INFO +4 -1
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn.egg-info/SOURCES.txt +15 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/LICENSE +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/README.md +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/__init__.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/config/__init__.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/config/dataset.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/const.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/data_sources/__init__.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/data_sources/aws_landsat.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/data_sources/aws_open_data.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/data_sources/aws_sentinel1.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/data_sources/climate_data_store.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/data_sources/data_source.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/data_sources/earthdaily.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/data_sources/earthdata_srtm.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/data_sources/gcp_public_data.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/data_sources/geotiff.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/data_sources/google_earth_engine.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/data_sources/planet.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/data_sources/planet_basemap.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/data_sources/planetary_computer.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/data_sources/raster_source.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/data_sources/usda_cdl.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/data_sources/usgs_landsat.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/data_sources/utils.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/data_sources/vector_source.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/data_sources/worldcereal.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/data_sources/worldcover.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/data_sources/worldpop.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/data_sources/xyz_tiles.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/dataset/__init__.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/dataset/add_windows.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/dataset/index.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/dataset/manage.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/dataset/materialize.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/dataset/remap.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/dataset/window.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/log_utils.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/main.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/__init__.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/clip.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/conv.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/croma.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/detr/__init__.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/detr/box_ops.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/detr/detr.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/detr/matcher.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/detr/position_encoding.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/detr/transformer.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/detr/util.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/faster_rcnn.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/fpn.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/module_wrapper.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/molmo.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/multitask.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/pick_features.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/pooling_decoder.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/registry.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/sam2_enc.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/satlaspretrain.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/simple_time_series.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/singletask.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/ssl4eo_s12.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/swin.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/task_embedding.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/terramind.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/trunk.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/upsample.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/models/use_croma.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/py.typed +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/tile_stores/__init__.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/tile_stores/default.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/tile_stores/tile_store.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/train/__init__.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/train/callbacks/__init__.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/train/callbacks/adapters.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/train/callbacks/freeze_unfreeze.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/train/callbacks/gradients.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/train/callbacks/peft.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/train/data_module.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/train/dataset.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/train/lightning_module.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/train/optimizer.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/train/prediction_writer.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/train/scheduler.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/train/tasks/__init__.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/train/tasks/classification.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/train/tasks/detection.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/train/tasks/multi_task.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/train/tasks/per_pixel_regression.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/train/tasks/regression.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/train/tasks/segmentation.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/train/tasks/task.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/train/transforms/__init__.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/train/transforms/concatenate.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/train/transforms/crop.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/train/transforms/flip.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/train/transforms/mask.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/train/transforms/normalize.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/train/transforms/pad.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/train/transforms/transform.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/utils/__init__.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/utils/array.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/utils/feature.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/utils/fsspec.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/utils/geometry.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/utils/get_utm_ups_crs.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/utils/grid_index.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/utils/jsonargparse.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/utils/mp.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/utils/raster_format.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/utils/rtree_index.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/utils/spatial_index.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/utils/sqlite_index.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/utils/time.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn/utils/vector_format.py +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn.egg-info/dependency_links.txt +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn.egg-info/entry_points.txt +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn.egg-info/requires.txt +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/rslearn.egg-info/top_level.txt +0 -0
- {rslearn-0.0.4 → rslearn-0.0.6}/setup.cfg +0 -0
--- rslearn-0.0.4/PKG-INFO
+++ rslearn-0.0.6/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: rslearn
-Version: 0.0.4
+Version: 0.0.6
 Summary: A library for developing remote sensing datasets and models
 Author: OlmoEarth Team
 License: Apache License
@@ -205,6 +205,9 @@ License: Apache License
 See the License for the specific language governing permissions and
 limitations under the License.
 
+Project-URL: homepage, https://github.com/allenai/rslearn
+Project-URL: issues, https://github.com/allenai/rslearn/issues
+Project-URL: repository, https://github.com/allenai/rslearn
 Requires-Python: >=3.11
 Description-Content-Type: text/markdown
 License-File: LICENSE

--- rslearn-0.0.4/pyproject.toml
+++ rslearn-0.0.6/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "rslearn"
-version = "0.0.4"
+version = "0.0.6"
 description = "A library for developing remote sensing datasets and models"
 authors = [
     { name = "OlmoEarth Team" },
@@ -8,7 +8,6 @@ authors = [
 readme = "README.md"
 license = {file = "LICENSE"}
 requires-python = ">=3.11"
-
 dependencies = [
     "boto3>=1.39",
     "class_registry>=2.1",
@@ -67,6 +66,11 @@ dev = [
     "pytest-xdist",
 ]
 
+[project.urls]
+homepage = "https://github.com/allenai/rslearn"
+issues = "https://github.com/allenai/rslearn/issues"
+repository = "https://github.com/allenai/rslearn"
+
 [build-system]
 requires = ["setuptools>=61"]
 build-backend = "setuptools.build_meta"

--- rslearn-0.0.4/rslearn/arg_parser.py
+++ rslearn-0.0.6/rslearn/arg_parser.py
@@ -1,33 +1,12 @@
 """Custom Lightning ArgumentParser with environment variable substitution support."""
 
 import os
-import re
 from typing import Any
 
 from jsonargparse import Namespace
 from lightning.pytorch.cli import LightningArgumentParser
 
-
-def substitute_env_vars_in_string(content: str) -> str:
-    """Substitute environment variables in a string.
-
-    Replaces ${VAR_NAME} patterns with os.getenv(VAR_NAME, "") values.
-    This works on raw string content before YAML parsing.
-
-    Args:
-        content: The string content containing template variables
-
-    Returns:
-        The string with environment variables substituted
-    """
-    pattern = r"\$\{([^}]+)\}"
-
-    def replace_variable(match_obj: re.Match[str]) -> str:
-        var_name = match_obj.group(1)
-        env_value = os.getenv(var_name, "")
-        return env_value if env_value is not None else ""
-
-    return re.sub(pattern, replace_variable, content)
+from rslearn.template_params import substitute_env_vars_in_string
 
 
 class RslearnArgumentParser(LightningArgumentParser):

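The helper removed here was moved into rslearn/template_params.py (new in this release, +26 lines) so the argument parser and the dataset loader can share it. A minimal sketch of the shared helper's behavior, based on the removed implementation above (the DS_ROOT variable name is only an example):

    import os
    from rslearn.template_params import substitute_env_vars_in_string

    os.environ["DS_ROOT"] = "/data/rslearn"
    # ${VAR} patterns are replaced with os.getenv(VAR, ""); unset variables
    # therefore expand to empty strings rather than raising an error.
    assert substitute_env_vars_in_string("${DS_ROOT}/windows") == "/data/rslearn/windows"
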
--- rslearn-0.0.4/rslearn/data_sources/copernicus.py
+++ rslearn-0.0.6/rslearn/data_sources/copernicus.py
@@ -319,7 +319,6 @@ class Copernicus(DataSource):
                 then we attempt to read the username/password from COPERNICUS_USERNAME
                 and COPERNICUS_PASSWORD (this is useful since access tokens are only
                 valid for an hour).
-            password: set API username/password instead of access token.
             query_filter: filter string to include when searching for items. This will
                 be appended to other name, geographic, and sensing time filters where
                 applicable. For example, "Collection/Name eq 'SENTINEL-2'". See the API
@@ -368,6 +367,7 @@ class Copernicus(DataSource):
             "order_by",
             "sort_by",
             "sort_desc",
+            "timeout",
         ]
         for k in simple_optionals:
             if k in d:
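The recurring simple_optionals pattern (here and in the Sentinel2/Sentinel1 hunks below) copies any of the listed keys present in the decoded config dict d into the constructor kwargs, so adding "timeout" to each list is all that is needed to expose it as a layer option. A standalone sketch of the pattern (the values in d are made up for illustration):

    from typing import Any

    d = {"sort_by": "ContentDate/Start", "timeout": 60}  # example decoded config
    kwargs: dict[str, Any] = {}
    simple_optionals = ["harmonize", "access_token", "order_by", "sort_by", "sort_desc", "timeout"]
    for k in simple_optionals:
        if k in d:
            kwargs[k] = d[k]
    # kwargs now holds {"sort_by": "ContentDate/Start", "timeout": 60}
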
@@ -709,6 +709,8 @@ class Sentinel2(Copernicus):
         "B12": ["B12"],
         "B8A": ["B8A"],
         "TCI": ["R", "G", "B"],
+        # L1C-only products.
+        "B10": ["B10"],
         # L2A-only products.
         "AOT": ["AOT"],
         "WVP": ["WVP"],
@@ -809,17 +811,16 @@ class Sentinel2(Copernicus):
 
         kwargs: dict[str, Any] = dict(
             assets=list(needed_assets),
+            product_type=Sentinel2ProductType[d["product_type"]],
         )
 
-        if "product_type" in d:
-            kwargs["product_type"] = Sentinel2ProductType(d["product_type"])
-
         simple_optionals = [
             "harmonize",
             "access_token",
             "order_by",
             "sort_by",
             "sort_desc",
+            "timeout",
         ]
         for k in simple_optionals:
             if k in d:
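Note the behavior change in this hunk: product_type becomes required rather than optional, and it is resolved with Sentinel2ProductType[...] (lookup by member name) instead of Sentinel2ProductType(...) (lookup by member value). A quick illustration of the Python semantics, using a hypothetical enum rather than the real Sentinel2ProductType:

    from enum import Enum

    class Example(Enum):  # hypothetical members for illustration only
        L1C = "S2MSI1C"
        L2A = "S2MSI2A"

    # Enum[name] indexes by member name; Enum(value) looks up by member value.
    assert Example["L2A"] is Example("S2MSI2A")
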
@@ -965,6 +966,7 @@ class Sentinel1(Copernicus):
             "order_by",
             "sort_by",
             "sort_desc",
+            "timeout",
         ]
         for k in simple_optionals:
             if k in d:

--- /dev/null
+++ rslearn-0.0.6/rslearn/data_sources/eurocrops.py
@@ -0,0 +1,246 @@
+"""Data source for vector EuroCrops crop type data."""
+
+import glob
+import os
+import tempfile
+import zipfile
+from datetime import UTC, datetime, timedelta
+from typing import Any
+
+import fiona
+import requests
+from rasterio.crs import CRS
+from upath import UPath
+
+from rslearn.config import QueryConfig, VectorLayerConfig
+from rslearn.const import WGS84_PROJECTION
+from rslearn.data_sources import DataSource, Item
+from rslearn.data_sources.utils import match_candidate_items_to_window
+from rslearn.log_utils import get_logger
+from rslearn.tile_stores import TileStoreWithLayer
+from rslearn.utils.feature import Feature
+from rslearn.utils.geometry import Projection, STGeometry, get_global_geometry
+
+logger = get_logger(__name__)
+
+
+class EuroCropsItem(Item):
+    """An item in the EuroCrops data source.
+
+    For simplicity, we have just one item per year, so each item combines all of the
+    country-level files for that year.
+    """
+
+    def __init__(self, name: str, geometry: STGeometry, zip_fnames: list[str]):
+        """Creates a new EuroCropsItem.
+
+        Args:
+            name: unique name of the item. It is just the year that this item
+                corresponds to.
+            geometry: the spatial and temporal extent of the item
+            zip_fnames: the filenames of the zip files that contain country-level crop
+                type data for this year.
+        """
+        super().__init__(name, geometry)
+        self.zip_fnames = zip_fnames
+
+    def serialize(self) -> dict:
+        """Serializes the item to a JSON-encodable dictionary."""
+        d = super().serialize()
+        d["zip_fnames"] = self.zip_fnames
+        return d
+
+    @staticmethod
+    def deserialize(d: dict) -> "EuroCropsItem":
+        """Deserializes an item from a JSON-decoded dictionary."""
+        item = super(EuroCropsItem, EuroCropsItem).deserialize(d)
+        return EuroCropsItem(
+            name=item.name, geometry=item.geometry, zip_fnames=d["zip_fnames"]
+        )
+
+
+class EuroCrops(DataSource[EuroCropsItem]):
+    """A data source for EuroCrops vector data (v11).
+
+    See https://zenodo.org/records/14094196 for details.
+
+    While the source data is split into country-level files, this data source uses one
+    item per year for simplicity. So each item corresponds to all of the country-level
+    files for that year.
+
+    Note that the RO_ny.zip file is not used.
+    """
+
+    BASE_URL = "https://zenodo.org/records/14094196/files/"
+    FILENAMES_BY_YEAR = {
+        2018: [
+            "FR_2018.zip",
+        ],
+        2019: [
+            "DK_2019.zip",
+        ],
+        2020: [
+            "ES_NA_2020.zip",
+            "FI_2020.zip",
+            "HR_2020.zip",
+            "NL_2020.zip",
+        ],
+        2021: [
+            "AT_2021.zip",
+            "BE_VLG_2021.zip",
+            "BE_WAL_2021.zip",
+            "EE_2021.zip",
+            "LT_2021.zip",
+            "LV_2021.zip",
+            "PT_2021.zip",
+            "SE_2021.zip",
+            "SI_2021.zip",
+            "SK_2021.zip",
+        ],
+        2023: [
+            "CZ_2023.zip",
+            "DE_BB_2023.zip",
+            "DE_LS_2021.zip",
+            "DE_NRW_2021.zip",
+            "ES_2023.zip",
+            "IE_2023.zip",
+        ],
+    }
+    TIMEOUT = timedelta(seconds=10)
+
+    @staticmethod
+    def from_config(config: VectorLayerConfig, ds_path: UPath) -> "EuroCrops":
+        """Creates a new EuroCrops instance from a configuration dictionary."""
+        if config.data_source is None:
+            raise ValueError("data_source is required")
+        return EuroCrops()
+
+    def _get_all_items(self) -> list[EuroCropsItem]:
+        """Get a list of all available items in the data source."""
+        items: list[EuroCropsItem] = []
+        for year, fnames in self.FILENAMES_BY_YEAR.items():
+            items.append(
+                EuroCropsItem(
+                    str(year),
+                    get_global_geometry(
+                        time_range=(
+                            datetime(year, 1, 1, tzinfo=UTC),
+                            datetime(year + 1, 1, 1, tzinfo=UTC),
+                        ),
+                    ),
+                    fnames,
+                )
+            )
+        return items
+
+    def get_items(
+        self, geometries: list[STGeometry], query_config: QueryConfig
+    ) -> list[list[list[EuroCropsItem]]]:
+        """Get a list of items in the data source intersecting the given geometries.
+
+        Args:
+            geometries: the spatiotemporal geometries
+            query_config: the query configuration
+
+        Returns:
+            List of groups of items that should be retrieved for each geometry.
+        """
+        wgs84_geometries = [
+            geometry.to_projection(WGS84_PROJECTION) for geometry in geometries
+        ]
+        all_items = self._get_all_items()
+        groups = []
+        for geometry in wgs84_geometries:
+            cur_groups = match_candidate_items_to_window(
+                geometry, all_items, query_config
+            )
+            groups.append(cur_groups)
+        return groups
+
+    def deserialize_item(self, serialized_item: Any) -> EuroCropsItem:
+        """Deserializes an item from JSON-decoded data."""
+        return EuroCropsItem.deserialize(serialized_item)
+
+    def _extract_features(self, fname: str) -> list[Feature]:
+        """Download the given zip file, extract shapefile, and return list of features."""
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # Download the zip file.
+            url = self.BASE_URL + fname
+            logger.debug(f"Downloading zip file from {url}")
+            response = requests.get(
+                url,
+                stream=True,
+                timeout=self.TIMEOUT.total_seconds(),
+                allow_redirects=False,
+            )
+            response.raise_for_status()
+            zip_fname = os.path.join(tmp_dir, "data.zip")
+            with open(zip_fname, "wb") as f:
+                for chunk in response.iter_content(chunk_size=8192):
+                    f.write(chunk)
+
+            # Extract all of the files and look for shapefile filename.
+            logger.debug(f"Extracting zip file {fname}")
+            with zipfile.ZipFile(zip_fname) as zip_f:
+                zip_f.extractall(path=tmp_dir)
+
+            # The shapefiles or geopackage files can appear at any level in the hierarchy.
+            # Most zip files contain one but some contain multiple (one per region).
+            shp_fnames = glob.glob(
+                "**/*.shp", root_dir=tmp_dir, recursive=True
+            ) + glob.glob("**/*.gpkg", root_dir=tmp_dir, recursive=True)
+            if len(shp_fnames) == 0:
+                tmp_dir_fnames = os.listdir(tmp_dir)
+                raise ValueError(
+                    f"expected {fname} to contain .shp file but none found (matches={shp_fnames}, ls={tmp_dir_fnames})"
+                )
+
+            # Load the features from the shapefile(s).
+            features = []
+            for shp_fname in shp_fnames:
+                logger.debug(f"Loading feature list from {shp_fname}")
+                with fiona.open(os.path.join(tmp_dir, shp_fname)) as src:
+                    crs = CRS.from_wkt(src.crs.to_wkt())
+                    # Normal GeoJSON should have coordinates in CRS coordinates, i.e. it
+                    # should be 1 projection unit/pixel.
+                    projection = Projection(crs, 1, 1)
+
+                    for feat in src:
+                        features.append(
+                            Feature.from_geojson(
+                                projection,
+                                {
+                                    "type": "Feature",
+                                    "geometry": dict(feat.geometry),
+                                    "properties": dict(feat.properties),
+                                },
+                            )
+                        )
+
+        return features
+
+    def ingest(
+        self,
+        tile_store: TileStoreWithLayer,
+        items: list[EuroCropsItem],
+        geometries: list[list[STGeometry]],
+    ) -> None:
+        """Ingest items into the given tile store.
+
+        Args:
+            tile_store: the tile store to ingest into
+            items: the items to ingest
+            geometries: a list of geometries needed for each item
+        """
+        for item in items:
+            if tile_store.is_vector_ready(item.name):
+                continue
+
+            # Get features across all shapefiles.
+            features: list[Feature] = []
+            for fname in item.zip_fnames:
+                logger.debug(f"Getting features from {fname} for item {item.name}")
+                features.extend(self._extract_features(fname))
+
+            logger.debug(f"Writing features for {item.name} to the tile store")
+            tile_store.write_vector(item.name, features)

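Based only on the code above, a short sketch of how the yearly items are organized: each item name is a year, mapping to the country-level zip files ingested for that year.

    from rslearn.data_sources.eurocrops import EuroCrops

    ds = EuroCrops()
    for item in ds._get_all_items():
        print(item.name, item.zip_fnames)
    # e.g. "2020" -> ["ES_NA_2020.zip", "FI_2020.zip", "HR_2020.zip", "NL_2020.zip"]
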
--- rslearn-0.0.4/rslearn/data_sources/local_files.py
+++ rslearn-0.0.6/rslearn/data_sources/local_files.py
@@ -232,6 +232,17 @@ class RasterImporter(Importer):
         projection = Projection(crs, x_resolution, y_resolution)
         geometry = STGeometry(projection, shp, None)
 
+        if geometry.is_too_large():
+            geometry = get_global_geometry(time_range=None)
+            logger.warning(
+                "Global geometry detected: this geometry will be matched against all "
+                "windows in the rslearn dataset. When using settings like "
+                "max_matches=1 and space_mode=MOSAIC, this may cause windows outside "
+                "the geometry's valid bounds to be materialized from the global raster "
+                "instead of a more appropriate source. Consider using COMPOSITE mode, "
+                "or increasing max_matches if this behavior is unintended."
+            )
+
         if spec.name:
             item_name = spec.name
         else:

--- rslearn-0.0.4/rslearn/data_sources/openstreetmap.py
+++ rslearn-0.0.6/rslearn/data_sources/openstreetmap.py
@@ -1,4 +1,4 @@
-"""Data source for
+"""Data source for OpenStreetMap vector features."""
 
 import json
 import shutil
@@ -392,7 +392,7 @@ class OpenStreetMap(DataSource[OsmItem]):
         bounds_fname: UPath,
         categories: dict[str, Filter],
     ):
-        """Initialize a new
+        """Initialize a new OpenStreetMap instance.
 
         Args:
             config: the configuration of this layer.
@@ -508,8 +508,6 @@ class OpenStreetMap(DataSource[OsmItem]):
             items: the items to ingest
             geometries: a list of geometries needed for each item
         """
-        item_names = [item.name for item in items]
-        item_names.sort()
         for cur_item, cur_geometries in zip(items, geometries):
            if tile_store.is_vector_ready(cur_item.name):
                continue

--- rslearn-0.0.4/rslearn/dataset/dataset.py
+++ rslearn-0.0.6/rslearn/dataset/dataset.py
@@ -8,6 +8,7 @@ from upath import UPath
 
 from rslearn.config import load_layer_config
 from rslearn.log_utils import get_logger
+from rslearn.template_params import substitute_env_vars_in_string
 from rslearn.tile_stores import TileStore, load_tile_store
 
 from .index import DatasetIndex
@@ -52,7 +53,9 @@ class Dataset:
 
         # Load dataset configuration.
         with (self.path / "config.json").open("r") as f:
-            config = json.load(f)
+            config_content = f.read()
+            config_content = substitute_env_vars_in_string(config_content)
+            config = json.loads(config_content)
         self.layers = {}
         for layer_name, d in config["layers"].items():
             # Layer names must not contain period, since we use period to

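With this change, a dataset's config.json gets the same ${VAR} substitution as the Lightning YAML configs, applied to the raw text before JSON parsing. A hedged sketch (the environment variable and config keys are made up for illustration):

    import json
    import os

    from rslearn.template_params import substitute_env_vars_in_string

    os.environ["TILE_STORE_ROOT"] = "/mnt/tiles"  # hypothetical variable
    raw = '{"layers": {}, "tile_store": {"root_dir": "${TILE_STORE_ROOT}"}}'  # hypothetical keys
    config = json.loads(substitute_env_vars_in_string(raw))
    assert config["tile_store"]["root_dir"] == "/mnt/tiles"
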
--- /dev/null
+++ rslearn-0.0.6/rslearn/models/copernicusfm.py
@@ -0,0 +1,216 @@
+"""Copernicus FM model."""
+
+import logging
+import math
+from enum import Enum
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from upath import UPath
+
+from .copernicusfm_src.model_vit import vit_base_patch16
+
+logger = logging.getLogger(__name__)
+
+
+class CopernicusFMModality(Enum):
+    """Modality for Copernicus FM."""
+
+    SENTINEL2_L2A = "sentinel2_l2a"
+    SENTINEL1 = "sentinel1"
+
+
+MODALITY_TO_WAVELENGTH_BANDWIDTHS: dict[str, dict[str, list]] = {
+    # https://github.com/zhu-xlab/Copernicus-FM/blob/main/Copernicus-Bench/src/configs/dataset/cobench_eurosat_s2.yaml
+    CopernicusFMModality.SENTINEL2_L2A.value: {
+        "band_names": [
+            "B01",
+            "B02",
+            "B03",
+            "B04",
+            "B05",
+            "B06",
+            "B07",
+            "B08",
+            "B8A",
+            "B09",
+            "B10",
+            "B11",
+            "B12",
+        ],
+        "band_wavelengths": [
+            440,
+            490,
+            560,
+            665,
+            705,
+            740,
+            783,
+            842,
+            860,
+            940,
+            1370,
+            1610,
+            2190,
+        ],
+        "band_bandwidths": [20, 65, 35, 30, 15, 15, 20, 115, 20, 20, 30, 90, 180],
+    },
+    # https://github.com/zhu-xlab/Copernicus-FM/blob/main/Copernicus-Bench/src/configs/dataset/cobench_eurosat_s1.yaml
+    CopernicusFMModality.SENTINEL1.value: {
+        "band_names": ["vv", "vh"],
+        "band_wavelengths": [50000000, 50000000],
+        "band_bandwidths": [1e9, 1e9],
+    },
+}
+
+
+class CopernicusFM(torch.nn.Module):
+    """Wrapper for Copernicus FM to ingest Masked Helios Sample."""
+
+    image_resolution = 224
+    patch_size = 16
+    input_mode = "spectral"
+    # Don't need this as band order is provided
+    supported_modalities = [
+        CopernicusFMModality.SENTINEL2_L2A.value,
+        CopernicusFMModality.SENTINEL1.value,
+    ]
+
+    def __init__(
+        self,
+        band_order: dict[str, list[str]],
+        load_directory: str | None,
+    ) -> None:
+        """Initialize the Copernicus FM wrapper.
+
+        Args:
+            band_order: The band order for each modality
+            load_directory: The directory to load from, if None no weights are loaded
+        """
+        super().__init__()
+
+        # global_pool=True so that we initialize the fc_norm layer
+        self.band_order = band_order
+        self.model = vit_base_patch16(num_classes=10, global_pool=True)
+        if load_directory is not None:
+            check_point = torch.load(
+                UPath(load_directory) / "CopernicusFM_ViT_base_varlang_e100.pth",
+                weights_only=True,
+            )
+            if "model" in check_point:
+                state_dict = check_point["model"]
+            else:
+                state_dict = check_point
+            self.model.load_state_dict(state_dict, strict=False)
+
+        # take MODALITY_TO_WAVELENGTH_BANDWIDTHS and rearrange it so that it has the same
+        # ordering as the Helios band orders, defined by Modality.band_order
+        self.modality_to_wavelength_bandwidths = {}
+        for modality in self.supported_modalities:
+            wavelength_bandwidths = MODALITY_TO_WAVELENGTH_BANDWIDTHS[modality]
+            wavelengths = []
+            bandwidths = []
+            modality_band_order = self.band_order.get(modality, None)
+            if modality_band_order is None:
+                logger.warning(
+                    f"Band order for modality {modality} not found in band_order dictionary, unable to use this modality unless specified"
+                )
+                continue
+            for b in modality_band_order:
+                cfm_idx = wavelength_bandwidths["band_names"].index(b)
+                wavelengths.append(wavelength_bandwidths["band_wavelengths"][cfm_idx])
+                bandwidths.append(wavelength_bandwidths["band_bandwidths"][cfm_idx])
+            self.modality_to_wavelength_bandwidths[modality] = {
+                "band_bandwidths": bandwidths,
+                "band_wavelengths": wavelengths,
+            }
+
+    def _resize_data(self, data: torch.Tensor) -> torch.Tensor:
+        """Process individual modality data.
+
+        Args:
+            data: Input tensor of shape [B, C, H, W]
+
+        Returns:
+            list of tensors of shape [B, C, H, W]
+        """
+        # Get original dimensions
+        original_height = data.shape[2]
+        new_height = self.patch_size if original_height == 1 else self.image_resolution
+        data = F.interpolate(
+            data,
+            size=(new_height, new_height),
+            mode="bilinear",
+            align_corners=False,
+        )
+        return data
+
+    def prepare_input(
+        self,
+        inputs: dict[str, torch.Tensor],
+    ) -> tuple[torch.Tensor, list[int], list[int]]:
+        """Prepare input for the CopernicusFM model from MaskedHeliosSample."""
+        wavelengths: list[int] = []
+        bandwidths: list[int] = []
+        all_processed_data: list[list[torch.Tensor]] = []
+        for modality in inputs.keys():
+            if modality not in self.supported_modalities:
+                logger.debug(
+                    f"Skipping modality {modality} as it is not in the supported "
+                    f"modalities list {self.supported_modalities}"
+                )
+                continue
+
+            data = inputs[modality]
+
+            if data is None:
+                continue
+
+            all_processed_data.append(self._resize_data(data))
+            wavelengths.extend(
+                self.modality_to_wavelength_bandwidths[modality]["band_wavelengths"]
+            )
+            bandwidths.extend(
+                self.modality_to_wavelength_bandwidths[modality]["band_bandwidths"]
+            )
+
+        concatenated_processed_data = torch.cat(all_processed_data, dim=1)
+        return concatenated_processed_data, wavelengths, bandwidths
+
+    def forward(
+        self,
+        inputs: list[dict[str, torch.Tensor]],
+    ) -> torch.Tensor:
+        """Forward pass through CopernicusFM model."""
+        batch_inputs = {
+            key: torch.stack([inp[key] for inp in inputs], dim=0)
+            for key in inputs[0].keys()
+        }
+        # Prepare input
+        data, wavelengths, bandwidths = self.prepare_input(batch_inputs)
+        meta = torch.full(
+            (1, 4), float("nan"), device=data.device
+        )  # [lon, lat, delta_time, patch_token_area], assume unknown
+        # "The embed tensor contains the encoded image features, which can be used for downstream tasks."
+        _, timestep_output = self.model(
+            data,
+            meta,
+            wavelengths,
+            bandwidths,
+            None,
+            self.input_mode,
+            self.patch_size,
+        )
+        # no norm, following
+        # https://github.com/zhu-xlab/Copernicus-FM/blob/main/Copernicus-Bench/src/foundation_models/CopernicusFM/models_dwv_seg.py
+        side = math.isqrt(timestep_output.shape[1])
+        output_features = rearrange(
+            timestep_output, "b (h w) c -> b c h w ", h=side, w=side
+        )
+        return [output_features]
+
+    def get_backbone_channels(self) -> list[tuple[int, int]]:
+        """Returns the output channels of this model when used as a backbone."""
+        # TODO: load this from a constant depending on the model size
+        return [(self.patch_size, 768)]

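A minimal usage sketch for the new wrapper, assuming a single Sentinel-2 L2A input with four bands; shapes follow the code above, and weights are skipped, so this is illustrative only:

    import torch

    from rslearn.models.copernicusfm import CopernicusFM

    model = CopernicusFM(
        band_order={"sentinel2_l2a": ["B02", "B03", "B04", "B08"]},
        load_directory=None,  # no pretrained checkpoint for this sketch
    )
    # One sample, C x H x W; forward() stacks the list into a batch.
    inputs = [{"sentinel2_l2a": torch.randn(4, 224, 224)}]
    features = model(inputs)
    print(features[0].shape)  # expect [1, 768, 14, 14] with patch_size=16
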
--- /dev/null
+++ rslearn-0.0.6/rslearn/models/copernicusfm_src/__init__.py
@@ -0,0 +1 @@
+# mypy: ignore-errors