rslearn 0.0.14__tar.gz → 0.0.15__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rslearn-0.0.14/rslearn.egg-info → rslearn-0.0.15}/PKG-INFO +1 -1
- {rslearn-0.0.14 → rslearn-0.0.15}/pyproject.toml +1 -1
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/dataset/manage.py +14 -6
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/clay/clay.py +14 -1
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/croma.py +26 -3
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/satlaspretrain.py +18 -4
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/terramind.py +19 -0
- {rslearn-0.0.14 → rslearn-0.0.15/rslearn.egg-info}/PKG-INFO +1 -1
- {rslearn-0.0.14 → rslearn-0.0.15}/LICENSE +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/NOTICE +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/README.md +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/__init__.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/arg_parser.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/config/__init__.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/config/dataset.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/const.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/data_sources/__init__.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/data_sources/aws_landsat.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/data_sources/aws_open_data.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/data_sources/aws_sentinel1.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/data_sources/climate_data_store.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/data_sources/copernicus.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/data_sources/data_source.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/data_sources/earthdaily.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/data_sources/earthdata_srtm.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/data_sources/eurocrops.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/data_sources/gcp_public_data.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/data_sources/geotiff.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/data_sources/google_earth_engine.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/data_sources/local_files.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/data_sources/openstreetmap.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/data_sources/planet.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/data_sources/planet_basemap.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/data_sources/planetary_computer.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/data_sources/raster_source.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/data_sources/usda_cdl.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/data_sources/usgs_landsat.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/data_sources/utils.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/data_sources/vector_source.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/data_sources/worldcereal.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/data_sources/worldcover.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/data_sources/worldpop.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/data_sources/xyz_tiles.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/dataset/__init__.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/dataset/add_windows.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/dataset/dataset.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/dataset/handler_summaries.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/dataset/index.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/dataset/materialize.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/dataset/remap.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/dataset/window.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/lightning_cli.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/log_utils.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/main.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/__init__.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/anysat.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/clay/configs/metadata.yaml +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/clip.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/conv.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/detr/__init__.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/detr/box_ops.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/detr/detr.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/detr/matcher.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/detr/position_encoding.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/detr/transformer.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/detr/util.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/dinov3.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/faster_rcnn.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/feature_center_crop.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/fpn.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/galileo/__init__.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/galileo/galileo.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/galileo/single_file_galileo.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/module_wrapper.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/molmo.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/multitask.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/olmoearth_pretrain/__init__.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/olmoearth_pretrain/model.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/olmoearth_pretrain/norm.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/panopticon.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/panopticon_data/sensors/drone.yaml +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/panopticon_data/sensors/enmap.yaml +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/panopticon_data/sensors/goes.yaml +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/panopticon_data/sensors/himawari.yaml +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/panopticon_data/sensors/intuition.yaml +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/panopticon_data/sensors/landsat8.yaml +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/panopticon_data/sensors/modis_terra.yaml +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/panopticon_data/sensors/sentinel1.yaml +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/panopticon_data/sensors/sentinel2.yaml +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/panopticon_data/sensors/superdove.yaml +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/panopticon_data/sensors/wv23.yaml +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/pick_features.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/pooling_decoder.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/presto/__init__.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/presto/presto.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/presto/single_file_presto.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/prithvi.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/registry.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/resize_features.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/sam2_enc.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/simple_time_series.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/singletask.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/ssl4eo_s12.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/swin.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/task_embedding.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/trunk.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/unet.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/upsample.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/models/use_croma.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/py.typed +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/template_params.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/tile_stores/__init__.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/tile_stores/default.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/tile_stores/tile_store.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/__init__.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/all_patches_dataset.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/callbacks/__init__.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/callbacks/adapters.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/callbacks/freeze_unfreeze.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/callbacks/gradients.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/callbacks/peft.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/data_module.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/dataset.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/lightning_module.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/optimizer.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/prediction_writer.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/scheduler.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/tasks/__init__.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/tasks/classification.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/tasks/detection.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/tasks/embedding.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/tasks/multi_task.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/tasks/per_pixel_regression.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/tasks/regression.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/tasks/segmentation.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/tasks/task.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/transforms/__init__.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/transforms/concatenate.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/transforms/crop.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/transforms/flip.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/transforms/mask.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/transforms/normalize.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/transforms/pad.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/transforms/select_bands.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/transforms/sentinel1.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/train/transforms/transform.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/utils/__init__.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/utils/array.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/utils/feature.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/utils/fsspec.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/utils/geometry.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/utils/get_utm_ups_crs.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/utils/grid_index.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/utils/jsonargparse.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/utils/mp.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/utils/raster_format.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/utils/rtree_index.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/utils/spatial_index.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/utils/sqlite_index.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/utils/time.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn/utils/vector_format.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn.egg-info/SOURCES.txt +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn.egg-info/dependency_links.txt +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn.egg-info/entry_points.txt +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn.egg-info/requires.txt +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/rslearn.egg-info/top_level.txt +0 -0
- {rslearn-0.0.14 → rslearn-0.0.15}/setup.cfg +0 -0
```diff
--- rslearn-0.0.14/rslearn/dataset/manage.py
+++ rslearn-0.0.15/rslearn/dataset/manage.py
@@ -124,12 +124,24 @@ def prepare_dataset_windows(
             )
             continue
         data_source_cfg = layer_cfg.data_source
+        min_matches = data_source_cfg.query_config.min_matches
 
         # Get windows that need to be prepared for this layer.
+        # Also track which windows are skipped vs previously rejected.
         needed_windows = []
+        windows_skipped = 0
+        windows_rejected = 0
         for window in windows:
             layer_datas = window.load_layer_datas()
             if layer_name in layer_datas and not force:
+                # Window already has layer data - check if it was previously rejected
+                layer_data = layer_datas[layer_name]
+                if len(layer_data.serialized_item_groups) == 0 and min_matches > 0:
+                    # Previously rejected due to min_matches
+                    windows_rejected += 1
+                else:
+                    # Successfully prepared previously
+                    windows_skipped += 1
                 continue
             needed_windows.append(window)
         logger.info(f"Preparing {len(needed_windows)} windows for layer {layer_name}")
@@ -141,8 +153,8 @@ def prepare_dataset_windows(
                     data_source_name=data_source_cfg.name,
                     duration_seconds=time.monotonic() - layer_start_time,
                     windows_prepared=0,
-                    windows_skipped=
-                    windows_rejected=
+                    windows_skipped=windows_skipped,
+                    windows_rejected=windows_rejected,
                     get_items_attempts=0,
                 )
             )
@@ -184,8 +196,6 @@ def prepare_dataset_windows(
         )
 
         windows_prepared = 0
-        windows_rejected = 0
-        min_matches = data_source_cfg.query_config.min_matches
         for window, result in zip(needed_windows, results):
            layer_datas = window.load_layer_datas()
            layer_datas[layer_name] = WindowLayerData(
@@ -202,8 +212,6 @@ def prepare_dataset_windows(
             else:
                 windows_prepared += 1
 
-        windows_skipped = len(windows) - len(needed_windows)
-
         layer_summaries.append(
             LayerPrepareSummary(
                 layer_name=layer_name,
```
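In 0.0.15, `prepare_dataset_windows` classifies already-prepared windows during the initial scan instead of deriving `windows_skipped` afterwards: a window whose stored layer data has an empty `serialized_item_groups` while `min_matches > 0` now counts as previously rejected. A minimal sketch of that decision, using a hypothetical stand-in for rslearn's `WindowLayerData`:

```python
from dataclasses import dataclass, field


@dataclass
class LayerDataStub:
    """Hypothetical stand-in for WindowLayerData; only the field consulted
    by the new skip/reject logic is modeled here."""

    serialized_item_groups: list = field(default_factory=list)


def classify_prepared_window(layer_data: LayerDataStub, min_matches: int) -> str:
    """Mirror the 0.0.15 decision for a window that already has layer data."""
    if len(layer_data.serialized_item_groups) == 0 and min_matches > 0:
        return "rejected"  # previously rejected due to min_matches
    return "skipped"  # successfully prepared previously


assert classify_prepared_window(LayerDataStub(), min_matches=1) == "rejected"
assert classify_prepared_window(LayerDataStub([["item-a"]]), min_matches=1) == "skipped"
```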
```diff
--- rslearn-0.0.14/rslearn/models/clay/clay.py
+++ rslearn-0.0.15/rslearn/models/clay/clay.py
@@ -8,6 +8,7 @@ from importlib.resources import files
 from typing import Any
 
 import torch
+import torch.nn.functional as F
 import yaml
 from einops import rearrange
 from huggingface_hub import hf_hub_download
@@ -30,6 +31,7 @@ PATCH_SIZE = 8
 CLAY_MODALITIES = ["sentinel-2-l2a", "sentinel-1-rtc", "landsat-c2l1", "naip"]
 CONFIG_DIR = files("rslearn.models.clay.configs")
 CLAY_METADATA_PATH = str(CONFIG_DIR / "metadata.yaml")
+DEFAULT_IMAGE_RESOLUTION = 128  # image resolution during pretraining
 
 
 def get_clay_checkpoint_path(
@@ -49,6 +51,7 @@ class Clay(torch.nn.Module):
         modality: str = "sentinel-2-l2a",
         checkpoint_path: str | None = None,
         metadata_path: str = CLAY_METADATA_PATH,
+        do_resizing: bool = False,
     ) -> None:
         """Initialize the Clay model.
 
@@ -57,6 +60,7 @@ class Clay(torch.nn.Module):
             modality: The modality to use (subset of CLAY_MODALITIES).
             checkpoint_path: Path to clay-v1.5.ckpt, if None, fetch from HF Hub.
             metadata_path: Path to metadata.yaml.
+            do_resizing: Whether to resize the image to the input resolution.
         """
         super().__init__()
 
@@ -95,6 +99,14 @@ class Clay(torch.nn.Module):
 
         self.model_size = model_size
         self.modality = modality
+        self.do_resizing = do_resizing
+
+    def _resize_image(self, image: torch.Tensor, original_hw: int) -> torch.Tensor:
+        """Resize the image to the input resolution."""
+        new_hw = self.patch_size if original_hw == 1 else DEFAULT_IMAGE_RESOLUTION
+        return F.interpolate(
+            image, size=(new_hw, new_hw), mode="bilinear", align_corners=False
+        )
 
     def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
         """Forward pass for the Clay model.
@@ -114,7 +126,8 @@ class Clay(torch.nn.Module):
         chips = torch.stack(
             [inp[self.modality] for inp in inputs], dim=0
         )  # (B, C, H, W)
-
+        if self.do_resizing:
+            chips = self._resize_image(chips, chips.shape[2])
         order = self.metadata[self.modality]["band_order"]
         wavelengths = []
         for band in self.metadata[self.modality]["band_order"]:
```
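The resizing is opt-in via the new `do_resizing` flag. A hedged usage sketch (the `model_size` value and the 10-band input are illustrative; the channel count must match the modality's `band_order` in metadata.yaml):

```python
import torch

from rslearn.models.clay.clay import Clay

# "large" is an assumed model_size value for illustration. With
# do_resizing=True, inputs are bilinearly resampled to the 128-pixel
# pretraining resolution (or to the patch size for 1x1 pixel inputs).
model = Clay(model_size="large", modality="sentinel-2-l2a", do_resizing=True)
inputs = [{"sentinel-2-l2a": torch.randn(10, 32, 32)}]  # (C, H, W) per sample
features = model(inputs)
```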
```diff
--- rslearn-0.0.14/rslearn/models/croma.py
+++ rslearn-0.0.15/rslearn/models/croma.py
@@ -7,6 +7,7 @@ from enum import Enum
 from typing import Any
 
 import torch
+import torch.nn.functional as F
 from einops import rearrange
 from upath import UPath
 
@@ -99,6 +100,7 @@ class Croma(torch.nn.Module):
         modality: CromaModality,
         pretrained_path: str | None = None,
         image_resolution: int = DEFAULT_IMAGE_RESOLUTION,
+        do_resizing: bool = False,
     ) -> None:
         """Instantiate a new Croma instance.
 
@@ -107,12 +109,21 @@ class Croma(torch.nn.Module):
             modality: the modalities to configure the model to accept.
             pretrained_path: the local path to the pretrained weights. Otherwise it is
                 downloaded and cached in temp directory.
-            image_resolution: the width and height of the input images.
+            image_resolution: the width and height of the input images passed to the model. if do_resizing is True, the image will be resized to this resolution.
+            do_resizing: Whether to resize the image to the input resolution.
         """
         super().__init__()
         self.size = size
         self.modality = modality
-        self.
+        self.do_resizing = do_resizing
+        if not do_resizing:
+            self.image_resolution = image_resolution
+        else:
+            # With single pixel input, we always resample to the patch size.
+            if image_resolution == 1:
+                self.image_resolution = PATCH_SIZE
+            else:
+                self.image_resolution = DEFAULT_IMAGE_RESOLUTION
 
         # Cache the CROMA weights to a deterministic path in temporary directory if the
         # path is not provided by the user.
@@ -137,7 +148,16 @@ class Croma(torch.nn.Module):
             pretrained_path=pretrained_path,
             size=size.value,
             modality=modality.value,
-            image_resolution=image_resolution,
+            image_resolution=self.image_resolution,
+        )
+
+    def _resize_image(self, image: torch.Tensor) -> torch.Tensor:
+        """Resize the image to the input resolution."""
+        return F.interpolate(
+            image,
+            size=(self.image_resolution, self.image_resolution),
+            mode="bilinear",
+            align_corners=False,
         )
 
     def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
@@ -151,8 +171,11 @@ class Croma(torch.nn.Module):
         sentinel2: torch.Tensor | None = None
         if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL1]:
             sentinel1 = torch.stack([inp["sentinel1"] for inp in inputs], dim=0)
+            sentinel1 = self._resize_image(sentinel1) if self.do_resizing else sentinel1
         if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL2]:
             sentinel2 = torch.stack([inp["sentinel2"] for inp in inputs], dim=0)
+            sentinel2 = self._resize_image(sentinel2) if self.do_resizing else sentinel2
+
         outputs = self.model(
             SAR_images=sentinel1,
             optical_images=sentinel2,
```
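The effective input resolution is now derived in the constructor rather than taken verbatim. A condensed restatement of that branch (`PATCH_SIZE` and `DEFAULT_IMAGE_RESOLUTION` are module-level constants in rslearn.models.croma whose values are not shown in this diff):

```python
def effective_resolution(
    image_resolution: int,
    do_resizing: bool,
    patch_size: int,  # stands in for croma.PATCH_SIZE
    default_resolution: int,  # stands in for croma.DEFAULT_IMAGE_RESOLUTION
) -> int:
    """Condensed restatement of the 0.0.15 Croma resolution selection."""
    if not do_resizing:
        return image_resolution
    if image_resolution == 1:
        # With single pixel input, always resample to the patch size.
        return patch_size
    return default_resolution
```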
```diff
--- rslearn-0.0.14/rslearn/models/satlaspretrain.py
+++ rslearn-0.0.15/rslearn/models/satlaspretrain.py
@@ -4,15 +4,14 @@ from typing import Any
 
 import satlaspretrain_models
 import torch
+import torch.nn.functional as F
 
 
 class SatlasPretrain(torch.nn.Module):
     """SatlasPretrain backbones."""
 
     def __init__(
-        self,
-        model_identifier: str,
-        fpn: bool = False,
+        self, model_identifier: str, fpn: bool = False, resize_to_pretrain: bool = False
     ) -> None:
         """Instantiate a new SatlasPretrain instance.
 
@@ -21,6 +20,8 @@ class SatlasPretrain(torch.nn.Module):
                 https://github.com/allenai/satlaspretrain_models
             fpn: whether to include the feature pyramid network, otherwise only the
                 Swin-v2-Transformer is used.
+            resize_to_pretrain: whether to resize inputs to the pretraining input
+                size (512 x 512)
         """
         super().__init__()
         weights_manager = satlaspretrain_models.Weights()
@@ -49,6 +50,19 @@ class SatlasPretrain(torch.nn.Module):
             [16, 1024],
             [32, 2048],
         ]
+        self.resize_to_pretrain = resize_to_pretrain
+
+    def maybe_resize(self, data: torch.Tensor) -> list[torch.Tensor]:
+        """Resize to pretraining sizes if resize_to_pretrain == True."""
+        if self.resize_to_pretrain:
+            return F.interpolate(
+                data,
+                size=(512, 512),
+                mode="bilinear",
+                align_corners=False,
+            )
+        else:
+            return data
 
     def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
         """Compute feature maps from the SatlasPretrain backbone.
@@ -58,7 +72,7 @@ class SatlasPretrain(torch.nn.Module):
             process.
         """
         images = torch.stack([inp["image"] for inp in inputs], dim=0)
-        return self.model(images)
+        return self.model(self.maybe_resize(images))
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
```
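Usage stays a one-argument change. A hedged sketch ("Sentinel2_SwinB_SI_RGB" is an example identifier from the satlaspretrain_models weights list; any valid identifier works):

```python
import torch

from rslearn.models.satlaspretrain import SatlasPretrain

# resize_to_pretrain=True bilinearly resamples inputs to the 512x512
# pretraining size before the Swin backbone; otherwise inputs pass through
# unchanged.
model = SatlasPretrain("Sentinel2_SwinB_SI_RGB", resize_to_pretrain=True)
feature_maps = model([{"image": torch.randn(3, 128, 128)}])  # (C, H, W) per sample
```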
```diff
--- rslearn-0.0.14/rslearn/models/terramind.py
+++ rslearn-0.0.15/rslearn/models/terramind.py
@@ -4,6 +4,7 @@ from enum import Enum
 from typing import Any
 
 import torch
+import torch.nn.functional as F
 from einops import rearrange
 from terratorch.registry import BACKBONE_REGISTRY
 
@@ -18,6 +19,8 @@ class TerramindSize(str, Enum):
     LARGE = "large"
 
 
+# Pretraining image size for Terramind
+IMAGE_SIZE = 224
 # Default patch size for Terramind
 PATCH_SIZE = 16
 
@@ -89,12 +92,14 @@ class Terramind(torch.nn.Module):
         self,
         model_size: TerramindSize,
         modalities: list[str] = ["S2L2A"],
+        do_resizing: bool = False,
     ) -> None:
         """Initialize the Terramind model.
 
         Args:
             model_size: The size of the Terramind model.
             modalities: The modalities to use.
+            do_resizing: Whether to resize the input images to the pretraining resolution.
         """
         super().__init__()
 
@@ -116,6 +121,7 @@ class Terramind(torch.nn.Module):
 
         self.model_size = model_size
         self.modalities = modalities
+        self.do_resizing = do_resizing
 
     def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
         """Forward pass for the Terramind model.
@@ -132,6 +138,19 @@ class Terramind(torch.nn.Module):
             if modality not in inputs[0]:
                 continue
             cur = torch.stack([inp[modality] for inp in inputs], dim=0)  # (B, C, H, W)
+            if self.do_resizing and (
+                cur.shape[2] != IMAGE_SIZE or cur.shape[3] != IMAGE_SIZE
+            ):
+                if cur.shape[2] == 1 and cur.shape[3] == 1:
+                    new_height, new_width = PATCH_SIZE, PATCH_SIZE
+                else:
+                    new_height, new_width = IMAGE_SIZE, IMAGE_SIZE
+                cur = F.interpolate(
+                    cur,
+                    size=(new_height, new_width),
+                    mode="bilinear",
+                    align_corners=False,
+                )
             model_inputs[modality] = cur
 
         # By default, the patch embeddings are averaged over all modalities to reduce output tokens
```
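As with the other backbones, the flag defaults to False. A hedged usage sketch (TerramindSize.LARGE is the only size value visible in this diff; the 12-band S2L2A input is illustrative):

```python
import torch

from rslearn.models.terramind import Terramind, TerramindSize

# With do_resizing=True, any input that is not already 224x224 is bilinearly
# resampled to the pretraining size; 1x1 pixel inputs go to the 16-pixel
# patch size instead.
model = Terramind(model_size=TerramindSize.LARGE, modalities=["S2L2A"], do_resizing=True)
features = model([{"S2L2A": torch.randn(12, 64, 64)}])
```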