rslearn 0.0.19__tar.gz → 0.0.20__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rslearn-0.0.19/rslearn.egg-info → rslearn-0.0.20}/PKG-INFO +1 -1
- {rslearn-0.0.19 → rslearn-0.0.20}/pyproject.toml +1 -1
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/anysat.py +35 -33
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/clip.py +5 -2
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/croma.py +11 -3
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/dinov3.py +2 -1
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/faster_rcnn.py +2 -1
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/galileo/galileo.py +58 -31
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/module_wrapper.py +6 -1
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/molmo.py +4 -2
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/olmoearth_pretrain/model.py +93 -29
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/olmoearth_pretrain/norm.py +5 -3
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/panopticon.py +3 -1
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/presto/presto.py +45 -15
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/prithvi.py +9 -7
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/sam2_enc.py +3 -1
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/satlaspretrain.py +4 -1
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/simple_time_series.py +36 -16
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/ssl4eo_s12.py +19 -14
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/swin.py +3 -1
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/terramind.py +5 -4
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/all_patches_dataset.py +34 -14
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/dataset.py +66 -10
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/model_context.py +35 -1
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/tasks/classification.py +8 -2
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/tasks/detection.py +3 -2
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/tasks/multi_task.py +2 -3
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/tasks/per_pixel_regression.py +14 -5
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/tasks/regression.py +8 -2
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/tasks/segmentation.py +13 -4
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/tasks/task.py +2 -2
- rslearn-0.0.20/rslearn/train/transforms/concatenate.py +89 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/transforms/crop.py +22 -8
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/transforms/flip.py +13 -5
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/transforms/mask.py +11 -2
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/transforms/normalize.py +46 -15
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/transforms/pad.py +15 -3
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/transforms/resize.py +18 -9
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/transforms/select_bands.py +11 -2
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/transforms/sentinel1.py +18 -3
- {rslearn-0.0.19 → rslearn-0.0.20/rslearn.egg-info}/PKG-INFO +1 -1
- rslearn-0.0.19/rslearn/train/transforms/concatenate.py +0 -49
- {rslearn-0.0.19 → rslearn-0.0.20}/LICENSE +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/NOTICE +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/README.md +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/arg_parser.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/config/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/config/dataset.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/const.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/aws_landsat.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/aws_open_data.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/aws_sentinel1.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/climate_data_store.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/copernicus.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/data_source.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/earthdaily.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/earthdata_srtm.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/eurocrops.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/gcp_public_data.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/google_earth_engine.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/local_files.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/openstreetmap.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/planet.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/planet_basemap.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/planetary_computer.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/usda_cdl.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/usgs_landsat.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/utils.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/vector_source.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/worldcereal.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/worldcover.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/worldpop.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/data_sources/xyz_tiles.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/dataset/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/dataset/add_windows.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/dataset/dataset.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/dataset/handler_summaries.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/dataset/manage.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/dataset/materialize.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/dataset/remap.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/dataset/storage/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/dataset/storage/file.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/dataset/storage/storage.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/dataset/window.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/lightning_cli.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/log_utils.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/main.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/attention_pooling.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/clay/clay.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/clay/configs/metadata.yaml +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/component.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/concatenate_features.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/conv.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/detr/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/detr/box_ops.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/detr/detr.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/detr/matcher.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/detr/position_encoding.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/detr/transformer.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/detr/util.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/feature_center_crop.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/fpn.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/galileo/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/galileo/single_file_galileo.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/multitask.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/olmoearth_pretrain/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/drone.yaml +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/enmap.yaml +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/goes.yaml +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/himawari.yaml +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/intuition.yaml +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/landsat8.yaml +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/modis_terra.yaml +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/sentinel1.yaml +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/sentinel2.yaml +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/superdove.yaml +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/panopticon_data/sensors/wv23.yaml +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/pick_features.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/pooling_decoder.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/presto/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/presto/single_file_presto.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/resize_features.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/singletask.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/task_embedding.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/trunk.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/unet.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/upsample.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/models/use_croma.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/py.typed +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/template_params.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/tile_stores/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/tile_stores/default.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/tile_stores/tile_store.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/callbacks/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/callbacks/adapters.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/callbacks/freeze_unfreeze.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/callbacks/gradients.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/callbacks/peft.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/data_module.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/lightning_module.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/optimizer.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/prediction_writer.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/scheduler.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/tasks/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/tasks/embedding.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/transforms/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/train/transforms/transform.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/__init__.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/array.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/feature.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/fsspec.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/geometry.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/get_utm_ups_crs.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/grid_index.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/jsonargparse.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/mp.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/raster_format.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/rtree_index.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/spatial_index.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/sqlite_index.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/time.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn/utils/vector_format.py +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn.egg-info/SOURCES.txt +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn.egg-info/dependency_links.txt +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn.egg-info/entry_points.txt +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn.egg-info/requires.txt +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/rslearn.egg-info/top_level.txt +0 -0
- {rslearn-0.0.19 → rslearn-0.0.20}/setup.cfg +0 -0
**rslearn/models/anysat.py**

```diff
@@ -4,6 +4,8 @@ This code loads the AnySat model from torch hub. See
 https://github.com/gastruc/AnySat for applicable license and copyright information.
 """
 
+from datetime import datetime
+
 import torch
 from einops import rearrange
 
@@ -53,7 +55,6 @@ class AnySat(FeatureExtractor):
         self,
         modalities: list[str],
         patch_size_meters: int,
-        dates: dict[str, list[int]],
         output: str = "patch",
         output_modality: str | None = None,
         hub_repo: str = "gastruc/anysat",
@@ -85,14 +86,6 @@ class AnySat(FeatureExtractor):
             if m not in MODALITY_RESOLUTIONS:
                 raise ValueError(f"Invalid modality: {m}")
 
-        if not all(m in TIME_SERIES_MODALITIES for m in dates.keys()):
-            raise ValueError("`dates` keys must be time-series modalities only.")
-        for m in modalities:
-            if m in TIME_SERIES_MODALITIES and m not in dates:
-                raise ValueError(
-                    f"Missing required dates for time-series modality '{m}'."
-                )
-
         if patch_size_meters % 10 != 0:
             raise ValueError(
                 "In AnySat, `patch_size` is in meters and must be a multiple of 10."
@@ -106,7 +99,6 @@ class AnySat(FeatureExtractor):
 
         self.modalities = modalities
        self.patch_size_meters = int(patch_size_meters)
-        self.dates = dates
         self.output = output
         self.output_modality = output_modality
 
@@ -119,6 +111,20 @@ class AnySat(FeatureExtractor):
         )
         self._embed_dim = 768  # base width, 'dense' returns 2x
 
+    @staticmethod
+    def time_ranges_to_doy(
+        time_ranges: list[tuple[datetime, datetime]],
+        device: torch.device,
+    ) -> torch.Tensor:
+        """Turn the time ranges stored in a RasterImage to timestamps accepted by AnySat.
+
+        AnySat uses the day of year (doy) for each timestamp, so we take the midpoint of
+        the time range. For some inputs (e.g. Sentinel 2) we take an image from a specific
+        time so that start_time == end_time == mid_time.
+        """
+        doys = [(t[0] + ((t[1] - t[0]) / 2)).timetuple().tm_yday for t in time_ranges]
+        return torch.tensor(doys, dtype=torch.int32, device=device)
+
     def forward(self, context: ModelContext) -> FeatureMaps:
         """Forward pass for the AnySat model.
 
@@ -139,17 +145,29 @@ class AnySat(FeatureExtractor):
                 raise ValueError(f"Modality '{modality}' not present in inputs.")
 
             cur = torch.stack(
-                [inp[modality] for inp in inputs], dim=0
-            )  # (B, C,
+                [inp[modality].image for inp in inputs], dim=0
+            )  # (B, C, T, H, W)
 
             if modality in TIME_SERIES_MODALITIES:
-                num_dates = len(self.dates[modality])
-                num_bands = cur.shape[1] // num_dates
-                cur = rearrange(
-                    cur, "b (t c) h w -> b t c h w", t=num_dates, c=num_bands
-                )
+                num_bands = cur.shape[1]
+                cur = rearrange(cur, "b c t h w -> b t c h w")
                 H, W = cur.shape[-2], cur.shape[-1]
+
+                if inputs[0][modality].timestamps is None:
+                    raise ValueError(
+                        f"Require timestamps for time series modality {modality}"
+                    )
+                timestamps = torch.stack(
+                    [
+                        self.time_ranges_to_doy(inp[modality].timestamps, cur.device)  # type: ignore
+                        for inp in inputs
+                    ],
+                    dim=0,
+                )
+                batch[f"{modality}_dates"] = timestamps
             else:
+                # take the first (assumed only) timestep
+                cur = cur[:, :, 0]
                 num_bands = cur.shape[1]
                 H, W = cur.shape[-2], cur.shape[-1]
 
@@ -173,22 +191,6 @@ class AnySat(FeatureExtractor):
                 "All modalities must share the same spatial extent (H*res, W*res)."
             )
 
-        # Add *_dates
-        to_add = {}
-        for modality, x in list(batch.items()):
-            if modality in TIME_SERIES_MODALITIES:
-                B, T = x.shape[0], x.shape[1]
-                d = torch.as_tensor(
-                    self.dates[modality], dtype=torch.long, device=x.device
-                )
-                if d.ndim != 1 or d.numel() != T:
-                    raise ValueError(
-                        f"dates for '{modality}' must be 1D length {T}, got {tuple(d.shape)}"
-                    )
-                to_add[f"{modality}_dates"] = d.unsqueeze(0).repeat(B, 1)
-
-        batch.update(to_add)
-
         kwargs = {"patch_size": self.patch_size_meters, "output": self.output}
         if self.output == "dense":
             kwargs["output_modality"] = self.output_modality
```
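With the constructor's `dates` argument gone, AnySat day-of-year values are now derived from each input's time ranges instead of being hard-coded per modality. Below is a minimal standalone sketch of the midpoint computation that `time_ranges_to_doy` performs; the `midpoint_doys` helper name is ours, not rslearn's:

```python
from datetime import datetime

import torch


def midpoint_doys(time_ranges: list[tuple[datetime, datetime]]) -> torch.Tensor:
    """Day-of-year of the midpoint of each (start, end) range, as in the hunk above."""
    doys = [(t0 + (t1 - t0) / 2).timetuple().tm_yday for t0, t1 in time_ranges]
    return torch.tensor(doys, dtype=torch.int32)


# A point-in-time acquisition has start == end, so the midpoint is that instant.
ranges = [
    (datetime(2024, 1, 1), datetime(2024, 1, 31)),  # midpoint Jan 16 -> doy 16
    (datetime(2024, 7, 1), datetime(2024, 7, 1)),   # single instant -> doy 183
]
print(midpoint_doys(ranges))  # tensor([ 16, 183], dtype=torch.int32)
```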
**rslearn/models/clip.py**

```diff
@@ -43,9 +43,12 @@ class CLIP(FeatureExtractor):
             a FeatureMaps with one feature map from the ViT, which is always Bx24x24x1024.
         """
         inputs = context.inputs
-        device = inputs[0]["image"].device
+        device = inputs[0]["image"].image.device
         clip_inputs = self.processor(
-            images=[inp["image"].cpu().numpy().transpose(1, 2, 0) for inp in inputs],
+            images=[
+                inp["image"].single_ts_to_chw_tensor().cpu().numpy().transpose(1, 2, 0)
+                for inp in inputs
+            ],
             return_tensors="pt",
             padding=True,
         )
```
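A pattern repeated across the remaining model hunks: inputs arrive as `RasterImage` objects rather than bare tensors, with the pixel data in `.image` (laid out `(C, T, H, W)`), optional per-timestep `.timestamps`, and a `single_ts_to_chw_tensor()` accessor for single-timestep models. The real class lives in `rslearn/train/model_context.py` (changed +35 -1 above); the sketch below is a hypothetical stand-in inferred from how these hunks call it, not the actual implementation:

```python
from dataclasses import dataclass
from datetime import datetime

import torch


@dataclass
class RasterImageSketch:
    """Hypothetical stand-in for rslearn's RasterImage; fields inferred from the hunks."""

    image: torch.Tensor  # (C, T, H, W)
    timestamps: list[tuple[datetime, datetime]] | None = None  # one (start, end) per timestep

    def single_ts_to_chw_tensor(self) -> torch.Tensor:
        # Single-timestep backbones (CLIP, DINOv3, CROMA, ...) take the first and
        # only timestep, collapsing (C, T, H, W) to (C, H, W).
        if self.image.shape[1] != 1:
            raise ValueError(f"expected a single timestep, got {self.image.shape[1]}")
        return self.image[:, 0]


x = RasterImageSketch(image=torch.zeros(3, 1, 224, 224))
print(x.single_ts_to_chw_tensor().shape)  # torch.Size([3, 224, 224])
```

Multitemporal inputs destined for single-timestep backbones are instead expected to go through the simple time series wrapper, per the comment in the `module_wrapper.py` hunk below.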
**rslearn/models/croma.py**

```diff
@@ -175,10 +175,16 @@ class Croma(FeatureExtractor):
         sentinel1: torch.Tensor | None = None
         sentinel2: torch.Tensor | None = None
         if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL1]:
-            sentinel1 = torch.stack(
+            sentinel1 = torch.stack(
+                [inp["sentinel1"].single_ts_to_chw_tensor() for inp in context.inputs],
+                dim=0,
+            )
             sentinel1 = self._resize_image(sentinel1) if self.do_resizing else sentinel1
         if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL2]:
-            sentinel2 = torch.stack(
+            sentinel2 = torch.stack(
+                [inp["sentinel2"].single_ts_to_chw_tensor() for inp in context.inputs],
+                dim=0,
+            )
             sentinel2 = self._resize_image(sentinel2) if self.do_resizing else sentinel2
 
         outputs = self.model(
@@ -294,5 +300,7 @@ class CromaNormalize(Transform):
         for modality in MODALITY_BANDS.keys():
             if modality not in input_dict:
                 continue
-            input_dict[modality] = self.apply_image(
+            input_dict[modality].image = self.apply_image(
+                input_dict[modality].image, modality
+            )
         return input_dict, target_dict
```
**rslearn/models/dinov3.py**

```diff
@@ -104,7 +104,8 @@ class DinoV3(FeatureExtractor):
             a FeatureMaps with one feature map.
         """
         cur = torch.stack(
-            [inp["image"] for inp in context.inputs],
+            [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs],
+            dim=0,
         )  # (B, C, H, W)
 
         if self.do_resizing and (
```
**rslearn/models/faster_rcnn.py**

```diff
@@ -210,7 +210,8 @@ class FasterRCNN(Predictor):
                 ),
             )
 
-        image_list = [inp["image"] for inp in context.inputs]
+        # take the first (and assumed to be only) timestep
+        image_list = [inp["image"].image[:, 0] for inp in context.inputs]
         images, targets = self.noop_transform(image_list, targets)
 
         feature_dict = collections.OrderedDict()
```
**rslearn/models/galileo/galileo.py**

```diff
@@ -3,6 +3,7 @@
 import math
 import tempfile
 from contextlib import nullcontext
+from datetime import datetime
 from enum import StrEnum
 from typing import cast
 
@@ -411,6 +412,23 @@ class GalileoModel(FeatureExtractor):
             months=months,
         )
 
+    @staticmethod
+    def time_ranges_to_timestamps(
+        time_ranges: list[tuple[datetime, datetime]],
+        device: torch.device,
+    ) -> torch.Tensor:
+        """Turn the time ranges stored in a RasterImage to timestamps accepted by Galileo.
+
+        Galileo only uses the month associated with each timestamp, so we take the midpoint of
+        the time range. For some inputs (e.g. Sentinel 2) we take an image from a specific
+        time so that start_time == end_time == mid_time.
+        """
+        mid_ranges = [t[0] + ((t[1] - t[0]) / 2) for t in time_ranges]
+        # months are indexed 0-11
+        return torch.tensor(
+            [d.month - 1 for d in mid_ranges], dtype=torch.int32, device=device
+        )
+
     def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute feature maps from the Galileo backbone.
 
@@ -418,16 +436,16 @@ class GalileoModel(FeatureExtractor):
             context: the model context. Input dicts should contain keys corresponding to Galileo.input_keys
                 (also documented below) and values are tensors of the following shapes,
                 per input key:
-                "s1": B
-                "s2": B
-                "era5": B
-                "tc": B
-                "viirs": B
-                "srtm": B C H W (SRTM has no temporal dimension)
-                "dw": B C H W (Dynamic World should be averaged over time)
-                "wc": B C H W (WorldCereal has no temporal dimension)
-                "landscan": B C H W (we will average over the H, W dimensions)
-                "latlon": B C H W (we will average over the H, W dimensions)
+                "s1": B C T H W
+                "s2": B C T H W
+                "era5": B C T H W (we will average over the H, W dimensions)
+                "tc": B C T H W (we will average over the H, W dimensions)
+                "viirs": B C T H W (we will average over the H, W dimensions)
+                "srtm": B C 1 H W (SRTM has no temporal dimension)
+                "dw": B C 1 H W (Dynamic World should be averaged over time)
+                "wc": B C 1 H W (WorldCereal has no temporal dimension)
+                "landscan": B C 1 H W (we will average over the H, W dimensions)
+                "latlon": B C 1 H W (we will average over the H, W dimensions)
 
             The output will be an embedding representing the pooled tokens. If there is
             only a single token per h/w dimension (i.e. patch_size == h,w), then we will take
@@ -436,15 +454,35 @@ class GalileoModel(FeatureExtractor):
             If there are many spatial tokens per h/w dimension (patch_size > h,w), then we will
             take a pool of the space_time unmasked tokens (i.e. of the s1 and s2 tokens).
         """
+        space_time_modalities = ["s1", "s2"]
+        time_modalities = ["era5", "tc", "viirs"]
         stacked_inputs = {}
+        months: torch.Tensor | None = None
         for key in context.inputs[0].keys():
             # assume all the keys in an input are consistent
             if key in self.input_keys:
                 stacked_inputs[key] = torch.stack(
-                    [inp[key] for inp in context.inputs], dim=0
+                    [inp[key].image for inp in context.inputs], dim=0
                 )
+                if key in space_time_modalities + time_modalities:
+                    if months is None:
+                        if context.inputs[0][key].timestamps is not None:
+                            months = torch.stack(
+                                [
+                                    self.time_ranges_to_timestamps(
+                                        inp[key].timestamps,  # type: ignore
+                                        device=stacked_inputs[key].device,
+                                    )
+                                    for inp in context.inputs
+                                ],
+                                dim=0,
+                            )
+
+        if months is not None:
+            stacked_inputs["months"] = months
+
         s_t_channels = []
-        for space_time_modality in ["s1", "s2"]:
+        for space_time_modality in space_time_modalities:
             if space_time_modality not in stacked_inputs:
                 continue
             if space_time_modality == "s1":
@@ -452,36 +490,27 @@ class GalileoModel(FeatureExtractor):
             else:
                 s_t_channels += self.s_t_channels_s2
             cur = stacked_inputs[space_time_modality]
-
-            num_bands = len(S2_BANDS) if space_time_modality == "s2" else len(S1_BANDS)
-            num_timesteps = cur.shape[1] // num_bands
-            cur = rearrange(cur, "b (t c) h w -> b h w t c", t=num_timesteps)
+            cur = rearrange(cur, "b c t h w -> b h w t c")
             stacked_inputs[space_time_modality] = cur
 
         for space_modality in ["srtm", "dw", "wc"]:
             if space_modality not in stacked_inputs:
                 continue
+            # take the first (and assumed only) timestep
+            stacked_inputs[space_modality] = stacked_inputs[space_modality][:, :, 0]
             stacked_inputs[space_modality] = rearrange(
                 stacked_inputs[space_modality], "b c h w -> b h w c"
             )
 
-        for time_modality in ["era5", "tc", "viirs"]:
+        for time_modality in time_modalities:
             if time_modality not in stacked_inputs:
                 continue
             cur = stacked_inputs[time_modality]
-            # Check if it's single or multitemporal, and reshape accordingly
-            num_bands = {
-                "era5": len(ERA5_BANDS),
-                "tc": len(TC_BANDS),
-                "viirs": len(VIIRS_BANDS),
-            }[time_modality]
-            num_timesteps = cur.shape[1] // num_bands
             # take the average over the h, w bands since Galileo
             # treats it as a pixel-timeseries
             cur = rearrange(
-                torch.nanmean(cur, dim=(-1, -2)),
-                "b (t c) -> b t c",
-                t=num_timesteps,
+                torch.nanmean(cur, dim=(-1, -2)),
+                "b c t -> b t c",
             )
             stacked_inputs[time_modality] = cur
 
@@ -489,9 +518,8 @@ class GalileoModel(FeatureExtractor):
             if static_modality not in stacked_inputs:
                 continue
             cur = stacked_inputs[static_modality]
-            stacked_inputs[static_modality] = torch.nanmean(
-                cur, dim=(2, 3)
-            )
+            stacked_inputs[static_modality] = torch.nanmean(cur, dim=(2, 3, 4))
+
         galileo_input = self.construct_galileo_input(**stacked_inputs, normalize=True)
         h = galileo_input.s_t_x.shape[1]
         if h < self.patch_size:
@@ -511,7 +539,6 @@ class GalileoModel(FeatureExtractor):
             torch_context = torch.amp.autocast(
                 device_type=device.type, dtype=self.autocast_dtype
             )
-
         with torch_context:
             outputs = self.model(
                 s_t_x=galileo_input.s_t_x,
```
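The recurring shape change in this file, as in the backbones above, is that time is now a dedicated axis (`b c t h w`) instead of being packed into the channel axis (`b (t c) h w`), which is why the `num_bands`/`num_timesteps` bookkeeping disappears. A small sketch contrasting the two layouts (batch and sizes chosen arbitrarily):

```python
import torch
from einops import rearrange

B, C, T, H, W = 2, 3, 4, 8, 8
new_style = torch.randn(B, C, T, H, W)  # explicit time axis, as in 0.0.20

# Old 0.0.19 layout: timesteps packed into channels as (t c).
old_style = rearrange(new_style, "b c t h w -> b (t c) h w")

# Old code had to know the band count to recover the time axis...
via_old = rearrange(old_style, "b (t c) h w -> b h w t c", t=T, c=C)
# ...while the new code just permutes axes.
via_new = rearrange(new_style, "b c t h w -> b h w t c")

assert torch.equal(via_old, via_new)
print(via_new.shape)  # torch.Size([2, 8, 8, 4, 3])
```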
**rslearn/models/module_wrapper.py**

```diff
@@ -53,7 +53,12 @@ class EncoderModuleWrapper(FeatureExtractor):
         Returns:
             the output from the last wrapped module.
         """
-        images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
+        # take the first and only timestep. Currently no intermediate
+        # components support multi temporal inputs, so if the input is
+        # multitemporal it should be wrapped in a simple time series wrapper.
+        images = torch.stack(
+            [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+        )
         cur: Any = FeatureMaps([images])
         for m in self.encoder_modules:
             cur = m(cur, context)
```
**rslearn/models/molmo.py**

```diff
@@ -47,11 +47,13 @@ class Molmo(FeatureExtractor):
             a FeatureMaps. Molmo produces features at one scale, so it will contain one
             feature map that is a Bx24x24x2048 tensor.
         """
-        device = context.inputs[0]["image"].device
+        device = context.inputs[0]["image"].image.device
         molmo_inputs_list = []
         # Process each one so we can isolate just the full image without any crops.
         for inp in context.inputs:
-            image = inp["image"].cpu().numpy().transpose(1, 2, 0)
+            image = (
+                inp["image"].single_ts_to_chw_tensor().cpu().numpy().transpose(1, 2, 0)
+            )
             processed = self.processor.process(
                 images=[image],
                 text="",
```
**rslearn/models/olmoearth_pretrain/model.py**

```diff
@@ -1,26 +1,27 @@
 """OlmoEarth model wrapper for fine-tuning in rslearn."""
 
 import json
+import warnings
 from contextlib import nullcontext
+from datetime import datetime
 from typing import Any
 
 import torch
 from einops import rearrange
-from olmoearth_pretrain.config import Config
-from olmo_core.distributed.checkpoint import load_model_and_optim_state
+from olmoearth_pretrain.config import Config, require_olmo_core
 from olmoearth_pretrain.data.constants import Modality
+from olmoearth_pretrain.datatypes import MaskedOlmoEarthSample, MaskValue
 from olmoearth_pretrain.model_loader import (
     ModelID,
     load_model_from_id,
     load_model_from_path,
 )
 from olmoearth_pretrain.nn.flexihelios import Encoder, TokensAndMasks
-from olmoearth_pretrain.train.masking import MaskedOlmoEarthSample, MaskValue
 from upath import UPath
 
 from rslearn.log_utils import get_logger
 from rslearn.models.component import FeatureExtractor, FeatureMaps, TokenFeatureMaps
-from rslearn.train.model_context import ModelContext
+from rslearn.train.model_context import ModelContext, RasterImage
 
 logger = get_logger(__name__)
 
@@ -61,6 +62,7 @@ class OlmoEarth(FeatureExtractor):
         embedding_size: int | None = None,
         autocast_dtype: str | None = "bfloat16",
         token_pooling: bool = True,
+        use_legacy_timestamps: bool = True,
     ):
         """Create a new OlmoEarth model.
 
@@ -87,7 +89,15 @@ class OlmoEarth(FeatureExtractor):
             token_pooling: whether or not to pool the tokens. If True, the output will be BxCxHxW. If False,
                 there will be an extra dimension, N, (BxCxHxWxN) representing the temporal and channel
                 dimensions.
+            use_legacy_timestamps: In our original implementation of OlmoEarth, we applied timestamps starting
+                from 0 (instead of the actual timestamps of the input). The option to do this is preserved
+                for backwards compatibility with finetuned models which were trained against this implementation.
         """
+        if use_legacy_timestamps:
+            warnings.warn(
+                "For new projects, don't use legacy timesteps.", DeprecationWarning
+            )
+
         if (
             sum(
                 [
@@ -138,6 +148,7 @@ class OlmoEarth(FeatureExtractor):
                 model = model[part]
         self.model = model
         self.token_pooling = token_pooling
+        self.use_legacy_timestamps = use_legacy_timestamps
 
     def _load_model_from_checkpoint(
         self, checkpoint_upath: UPath, random_initialization: bool
@@ -148,9 +159,12 @@ class OlmoEarth(FeatureExtractor):
         that contains the distributed checkpoint. This is the format produced by
         pre-training runs in olmoearth_pretrain.
         """
-        # Load the model config and initialize it.
         # We avoid loading the train module here because it depends on running within
         # olmo_core.
+        # Only pull in olmo_core when trying to load a distributed checkpoint to avoid dependency.
+        require_olmo_core("_load_model_from_checkpoint")
+        from olmo_core.distributed.checkpoint import load_model_and_optim_state
+
         with (checkpoint_upath / "config.json").open() as f:
             config_dict = json.load(f)
         model_config = Config.from_dict(config_dict["model"])
@@ -165,6 +179,32 @@ class OlmoEarth(FeatureExtractor):
 
         return model
 
+    @staticmethod
+    def time_ranges_to_timestamps(
+        time_ranges: list[tuple[datetime, datetime]],
+        max_timestamps: int,
+        device: torch.device,
+    ) -> torch.Tensor:
+        """Turn the time ranges stored in a RasterImage to timestamps accepted by OlmoEarth.
+
+        OlmoEarth only uses the month associated with each timestamp, so we take the midpoint of
+        the time range. For some inputs (e.g. Sentinel 2) we take an image from a specific
+        time so that start_time == end_time == mid_time.
+        """
+        timestamps = torch.zeros((max_timestamps, 3), dtype=torch.int32, device=device)
+        mid_ranges = [t[0] + ((t[1] - t[0]) / 2) for t in time_ranges]
+        timestamps[: len(time_ranges), 0] = torch.tensor(
+            [d.day for d in mid_ranges], dtype=torch.int32
+        )
+        # months are indexed 0-11
+        timestamps[: len(time_ranges), 1] = torch.tensor(
+            [d.month - 1 for d in mid_ranges], dtype=torch.int32
+        )
+        timestamps[: len(time_ranges), 2] = torch.tensor(
+            [d.year for d in mid_ranges], dtype=torch.int32
+        )
+        return timestamps
+
     def _prepare_modality_inputs(
         self, context: ModelContext
     ) -> tuple[MaskedOlmoEarthSample, list[str], torch.device]:
@@ -190,43 +230,55 @@ class OlmoEarth(FeatureExtractor):
         # We'll have to fix all that.
         max_timesteps = 1
         modality_data = {}
+        # we will just store the longest time range
+        # per instance in the batch. This means it may not be
+        # aligned per modality
+        timestamps_per_instance: list[list[tuple[datetime, datetime]]] = [[]] * len(
+            context.inputs
+        )
         for modality in MODALITY_NAMES:
             if modality not in context.inputs[0]:
                 continue
             present_modalities.append(modality)
-            tensors = [
+            tensors = []
+            for idx, inp in enumerate(context.inputs):
+                assert isinstance(inp, RasterImage)
+                tensors.append(inp[modality].image)
+                cur_timestamps = inp[modality].timestamps
+                if cur_timestamps is not None and len(cur_timestamps) > len(
+                    timestamps_per_instance[idx]
+                ):
+                    timestamps_per_instance[idx] = cur_timestamps
+            tensors = [inp[modality].image for inp in context.inputs]
             device = tensors[0].device
-
-            max_t = max(t.shape[0] for t in tensors) // num_bands
+            max_t = max(t.shape[1] for t in tensors)
             max_timesteps = max(max_timesteps, max_t)
             modality_data[modality] = (
                 tensors,
-                num_bands,
                 len(Modality.get(modality).band_sets),
             )
 
         # Second pass: pad and process each modality with global max_timesteps
         for modality in present_modalities:
-            tensors, num_bands, num_band_sets = modality_data[modality]
-            target_ch = max_timesteps * num_bands
+            tensors, num_band_sets = modality_data[modality]
 
             # Pad tensors to target_ch and track original timesteps for masking
             padded = []
             original_timesteps = []
             for t in tensors:
-                orig_t = t.shape[0] // num_bands
+                orig_t = t.shape[1]
                 original_timesteps.append(orig_t)
-                if t.shape[0] < target_ch:
+                if orig_t < max_timesteps:
                     pad = torch.zeros(
-                        (target_ch - t.shape[0],) + t.shape[1:],
+                        t.shape[:1] + (max_timesteps - orig_t,) + t.shape[2:],
                         dtype=t.dtype,
                         device=device,
                     )
-                    t = torch.cat([t, pad], dim=0)
+                    t = torch.cat([t, pad], dim=1)
                 padded.append(t)
 
             cur = torch.stack(padded, dim=0)
-            cur = rearrange(cur, "b (t c) h w -> b h w t c", t=max_timesteps)
+            cur = rearrange(cur, "b c t h w -> b h w t c")
             kwargs[modality] = cur
 
         # Create mask: ONLINE_ENCODER for valid, MISSING for padded timesteps
@@ -242,19 +294,31 @@ class OlmoEarth(FeatureExtractor):
                     mask[sample_idx, :, :, orig_t:, :] = MaskValue.MISSING.value
             kwargs[f"{modality}_mask"] = mask
 
-        # Note that only months (0 to 11) are used in OlmoEarth position encoding.
-        timestamps = torch.zeros(
-            (len(context.inputs), max_timesteps, 3),
-            dtype=torch.int32,
-            device=device,
-        )
-        timestamps[:, :, 0] = 1  # day
-        timestamps[:, :, 1] = torch.arange(max_timesteps, device=device)[
-            None, :
-        ]  # month
-        timestamps[:, :, 2] = 2024  # year
-        kwargs["timestamps"] = timestamps
-
+        if self.use_legacy_timestamps:
+            # Note that only months (0 to 11) are used in OlmoEarth position encoding.
+            timestamps = torch.zeros(
+                (len(context.inputs), max_timesteps, 3),
+                dtype=torch.int32,
+                device=device,
+            )
+            timestamps[:, :, 0] = 1  # day
+            timestamps[:, :, 1] = torch.arange(max_timesteps, device=device)[
+                None, :
+            ]  # month
+            timestamps[:, :, 2] = 2024  # year
+            kwargs["timestamps"] = timestamps
+        else:
+            if max([len(t) for t in timestamps_per_instance]) == 0:
+                # Timestamps is required.
+                raise ValueError("No inputs had timestamps.")
+            # Note that only months (0 to 11) are used in OlmoEarth position encoding.
+            kwargs["timestamps"] = torch.stack(
+                [
+                    self.time_ranges_to_timestamps(time_range, max_timesteps, device)
+                    for time_range in timestamps_per_instance
+                ],
+                dim=0,
+            )
 
         return MaskedOlmoEarthSample(**kwargs), present_modalities, device
 
```
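OlmoEarth encodes each timestep as a `(day, month_index, year)` triple, zero-padded out to the longest time series in the batch; per the comments above, only the 0-11 month index actually feeds the position encoding, and the legacy path (`use_legacy_timestamps=True`, still the default) synthesizes `day=1, month=0..T-1, year=2024` instead of reading real timestamps. A standalone sketch of what the new non-legacy path produces, using the same midpoint logic as the diff (the `to_timestamps` name is ours):

```python
from datetime import datetime

import torch


def to_timestamps(
    time_ranges: list[tuple[datetime, datetime]], max_timestamps: int
) -> torch.Tensor:
    """(max_timestamps, 3) tensor of (day, month-1, year); padded rows stay zero."""
    out = torch.zeros((max_timestamps, 3), dtype=torch.int32)
    mids = [t0 + (t1 - t0) / 2 for t0, t1 in time_ranges]
    out[: len(mids), 0] = torch.tensor([d.day for d in mids], dtype=torch.int32)
    out[: len(mids), 1] = torch.tensor([d.month - 1 for d in mids], dtype=torch.int32)
    out[: len(mids), 2] = torch.tensor([d.year for d in mids], dtype=torch.int32)
    return out


ranges = [(datetime(2024, 3, 1), datetime(2024, 3, 31))]
print(to_timestamps(ranges, max_timestamps=2))
# tensor([[  16,    2, 2024],
#         [   0,    0,    0]], dtype=torch.int32)
```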
**rslearn/models/olmoearth_pretrain/norm.py**

```diff
@@ -64,8 +64,8 @@ class OlmoEarthNormalize(Transform):
             band_norms = self.norm_config[modality_name]
             image = input_dict[modality_name]
             # Keep a set of indices to make sure that we normalize all of them.
-            needed_band_indices = set(range(image.shape[0]))
-            num_timesteps = image.shape[0] // len(cur_band_names)
+            needed_band_indices = set(range(image.image.shape[0]))
+            num_timesteps = image.image.shape[0] // len(cur_band_names)
 
             for band, norm_dict in band_norms.items():
                 # If multitemporal, normalize each timestep separately.
@@ -73,7 +73,9 @@ class OlmoEarthNormalize(Transform):
                     band_idx = cur_band_names.index(band) + t * len(cur_band_names)
                     min_val = norm_dict["mean"] - self.std_multiplier * norm_dict["std"]
                     max_val = norm_dict["mean"] + self.std_multiplier * norm_dict["std"]
-                    image[band_idx] = (image[band_idx] - min_val) / (max_val - min_val)
+                    image.image[band_idx] = (image.image[band_idx] - min_val) / (
+                        max_val - min_val
+                    )
                     needed_band_indices.remove(band_idx)
 
             if len(needed_band_indices) > 0:
```
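The normalization math itself is untouched by this hunk: each band is still min-max scaled into the window mean ± `std_multiplier`·std, i.e. x' = (x - (mean - k·std)) / (2·k·std); only the indexing now goes through the `RasterImage.image` tensor. A quick numeric check of that formula (example stats, not rslearn defaults):

```python
mean, std, k = 100.0, 10.0, 2.0  # k is the std_multiplier
min_val = mean - k * std  # 80.0
max_val = mean + k * std  # 120.0

x = 100.0  # a pixel exactly at the mean
print((x - min_val) / (max_val - min_val))  # 0.5 -- the mean maps to mid-range
```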
**rslearn/models/panopticon.py**

```diff
@@ -142,7 +142,9 @@ class Panopticon(FeatureExtractor):
     def forward(self, context: ModelContext) -> FeatureMaps:
         """Forward pass through the panopticon model."""
         batch_inputs = {
-            key: torch.stack([inp[key] for inp in context.inputs], dim=0)
+            key: torch.stack(
+                [inp[key].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+            )
             for key in context.inputs[0].keys()
         }
         panopticon_inputs = self.prepare_input(batch_inputs)
```