rslearn 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/dataset.py +23 -4
- rslearn/data_sources/planetary_computer.py +52 -0
- rslearn/dataset/handler_summaries.py +1 -0
- rslearn/dataset/manage.py +16 -2
- rslearn/models/anysat.py +5 -1
- rslearn/models/dinov3.py +6 -1
- rslearn/models/feature_center_crop.py +50 -0
- rslearn/models/olmoearth_pretrain/model.py +88 -27
- rslearn/models/prithvi.py +9 -1
- rslearn/train/lightning_module.py +0 -3
- rslearn/train/prediction_writer.py +25 -8
- rslearn/train/tasks/classification.py +2 -2
- rslearn/train/tasks/detection.py +5 -5
- rslearn/train/tasks/embedding.py +116 -0
- rslearn/train/tasks/per_pixel_regression.py +5 -4
- rslearn/train/tasks/regression.py +5 -5
- rslearn/train/transforms/pad.py +3 -3
- rslearn/utils/raster_format.py +38 -0
- {rslearn-0.0.11.dist-info → rslearn-0.0.13.dist-info}/METADATA +3 -2
- {rslearn-0.0.11.dist-info → rslearn-0.0.13.dist-info}/RECORD +25 -31
- rslearn-0.0.13.dist-info/licenses/NOTICE +115 -0
- rslearn/models/copernicusfm.py +0 -228
- rslearn/models/copernicusfm_src/__init__.py +0 -1
- rslearn/models/copernicusfm_src/aurora/area.py +0 -50
- rslearn/models/copernicusfm_src/aurora/fourier.py +0 -134
- rslearn/models/copernicusfm_src/dynamic_hypernetwork.py +0 -523
- rslearn/models/copernicusfm_src/flexivit/patch_embed.py +0 -260
- rslearn/models/copernicusfm_src/flexivit/utils.py +0 -69
- rslearn/models/copernicusfm_src/model_vit.py +0 -348
- rslearn/models/copernicusfm_src/util/pos_embed.py +0 -216
- {rslearn-0.0.11.dist-info → rslearn-0.0.13.dist-info}/WHEEL +0 -0
- {rslearn-0.0.11.dist-info → rslearn-0.0.13.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.11.dist-info → rslearn-0.0.13.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.11.dist-info → rslearn-0.0.13.dist-info}/top_level.txt +0 -0
rslearn/config/dataset.py
CHANGED
|
@@ -125,7 +125,8 @@ class BandSetConfig:
|
|
|
125
125
|
self,
|
|
126
126
|
config_dict: dict[str, Any],
|
|
127
127
|
dtype: DType,
|
|
128
|
-
bands: list[str],
|
|
128
|
+
bands: list[str] | None = None,
|
|
129
|
+
num_bands: int | None = None,
|
|
129
130
|
format: dict[str, Any] | None = None,
|
|
130
131
|
zoom_offset: int = 0,
|
|
131
132
|
remap: dict[str, Any] | None = None,
|
|
@@ -137,7 +138,10 @@ class BandSetConfig:
|
|
|
137
138
|
Args:
|
|
138
139
|
config_dict: the config dict used to configure this BandSetConfig
|
|
139
140
|
dtype: the pixel value type to store tiles in
|
|
140
|
-
bands: list of band names in this BandSetConfig
|
|
141
|
+
bands: list of band names in this BandSetConfig. One of bands or num_bands
|
|
142
|
+
must be set.
|
|
143
|
+
num_bands: the number of bands in this band set. The bands will be named
|
|
144
|
+
B00, B01, B02, etc.
|
|
141
145
|
format: the format to store tiles in, defaults to geotiff
|
|
142
146
|
zoom_offset: store images at a resolution higher or lower than the window
|
|
143
147
|
resolution. This enables keeping source data at its native resolution,
|
|
@@ -155,6 +159,14 @@ class BandSetConfig:
|
|
|
155
159
|
materialization when creating mosaics, to determine which parts of the
|
|
156
160
|
source images should be copied.
|
|
157
161
|
"""
|
|
162
|
+
if (bands is None and num_bands is None) or (
|
|
163
|
+
bands is not None and num_bands is not None
|
|
164
|
+
):
|
|
165
|
+
raise ValueError("exactly one of bands and num_bands must be set")
|
|
166
|
+
if bands is None:
|
|
167
|
+
assert num_bands is not None
|
|
168
|
+
bands = [f"B{idx}" for idx in range(num_bands)]
|
|
169
|
+
|
|
158
170
|
if class_names is not None and len(bands) != len(class_names):
|
|
159
171
|
raise ValueError(
|
|
160
172
|
f"the number of class lists ({len(class_names)}) does not match the number of bands ({len(bands)})"
|
|
@@ -187,9 +199,16 @@ class BandSetConfig:
|
|
|
187
199
|
kwargs = dict(
|
|
188
200
|
config_dict=config,
|
|
189
201
|
dtype=DType(config["dtype"]),
|
|
190
|
-
bands=config["bands"],
|
|
191
202
|
)
|
|
192
|
-
for k in [
|
|
203
|
+
for k in [
|
|
204
|
+
"bands",
|
|
205
|
+
"num_bands",
|
|
206
|
+
"format",
|
|
207
|
+
"zoom_offset",
|
|
208
|
+
"remap",
|
|
209
|
+
"class_names",
|
|
210
|
+
"nodata_vals",
|
|
211
|
+
]:
|
|
193
212
|
if k in config:
|
|
194
213
|
kwargs[k] = config[k]
|
|
195
214
|
return BandSetConfig(**kwargs) # type: ignore
|
|
@@ -827,3 +827,55 @@ class Sentinel1(PlanetaryComputer):
|
|
|
827
827
|
kwargs[k] = d[k]
|
|
828
828
|
|
|
829
829
|
return Sentinel1(**kwargs)
|
|
830
|
+
|
|
831
|
+
|
|
832
|
+
class Naip(PlanetaryComputer):
|
|
833
|
+
"""A data source for NAIP data on Microsoft Planetary Computer.
|
|
834
|
+
|
|
835
|
+
See https://planetarycomputer.microsoft.com/dataset/naip.
|
|
836
|
+
"""
|
|
837
|
+
|
|
838
|
+
COLLECTION_NAME = "naip"
|
|
839
|
+
ASSET_BANDS = {"image": ["R", "G", "B", "NIR"]}
|
|
840
|
+
|
|
841
|
+
def __init__(
|
|
842
|
+
self,
|
|
843
|
+
**kwargs: Any,
|
|
844
|
+
):
|
|
845
|
+
"""Initialize a new Naip instance.
|
|
846
|
+
|
|
847
|
+
Args:
|
|
848
|
+
band_names: list of bands to try to ingest.
|
|
849
|
+
kwargs: additional arguments to pass to PlanetaryComputer.
|
|
850
|
+
"""
|
|
851
|
+
super().__init__(
|
|
852
|
+
collection_name=self.COLLECTION_NAME,
|
|
853
|
+
asset_bands=self.ASSET_BANDS,
|
|
854
|
+
**kwargs,
|
|
855
|
+
)
|
|
856
|
+
|
|
857
|
+
@staticmethod
|
|
858
|
+
def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Naip":
|
|
859
|
+
"""Creates a new Naip instance from a configuration dictionary."""
|
|
860
|
+
if config.data_source is None:
|
|
861
|
+
raise ValueError("config.data_source is required")
|
|
862
|
+
d = config.data_source.config_dict
|
|
863
|
+
kwargs = {}
|
|
864
|
+
|
|
865
|
+
if "timeout_seconds" in d:
|
|
866
|
+
kwargs["timeout"] = timedelta(seconds=d["timeout_seconds"])
|
|
867
|
+
|
|
868
|
+
if "cache_dir" in d:
|
|
869
|
+
kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
|
|
870
|
+
|
|
871
|
+
simple_optionals = [
|
|
872
|
+
"query",
|
|
873
|
+
"sort_by",
|
|
874
|
+
"sort_ascending",
|
|
875
|
+
"max_items_per_client",
|
|
876
|
+
]
|
|
877
|
+
for k in simple_optionals:
|
|
878
|
+
if k in d:
|
|
879
|
+
kwargs[k] = d[k]
|
|
880
|
+
|
|
881
|
+
return Naip(**kwargs)
|
rslearn/dataset/manage.py
CHANGED
|
@@ -118,6 +118,7 @@ def prepare_dataset_windows(
|
|
|
118
118
|
duration_seconds=time.monotonic() - layer_start_time,
|
|
119
119
|
windows_prepared=0,
|
|
120
120
|
windows_skipped=len(windows),
|
|
121
|
+
windows_rejected=0,
|
|
121
122
|
get_items_attempts=0,
|
|
122
123
|
)
|
|
123
124
|
)
|
|
@@ -141,6 +142,7 @@ def prepare_dataset_windows(
|
|
|
141
142
|
duration_seconds=time.monotonic() - layer_start_time,
|
|
142
143
|
windows_prepared=0,
|
|
143
144
|
windows_skipped=len(windows),
|
|
145
|
+
windows_rejected=0,
|
|
144
146
|
get_items_attempts=0,
|
|
145
147
|
)
|
|
146
148
|
)
|
|
@@ -181,6 +183,9 @@ def prepare_dataset_windows(
|
|
|
181
183
|
attempts_counter=attempts_counter,
|
|
182
184
|
)
|
|
183
185
|
|
|
186
|
+
windows_prepared = 0
|
|
187
|
+
windows_rejected = 0
|
|
188
|
+
min_matches = data_source_cfg.query_config.min_matches
|
|
184
189
|
for window, result in zip(needed_windows, results):
|
|
185
190
|
layer_datas = window.load_layer_datas()
|
|
186
191
|
layer_datas[layer_name] = WindowLayerData(
|
|
@@ -191,13 +196,22 @@ def prepare_dataset_windows(
|
|
|
191
196
|
)
|
|
192
197
|
window.save_layer_datas(layer_datas)
|
|
193
198
|
|
|
199
|
+
# If result is empty and min_matches > 0, window was rejected due to min_matches
|
|
200
|
+
if len(result) == 0 and min_matches > 0:
|
|
201
|
+
windows_rejected += 1
|
|
202
|
+
else:
|
|
203
|
+
windows_prepared += 1
|
|
204
|
+
|
|
205
|
+
windows_skipped = len(windows) - len(needed_windows)
|
|
206
|
+
|
|
194
207
|
layer_summaries.append(
|
|
195
208
|
LayerPrepareSummary(
|
|
196
209
|
layer_name=layer_name,
|
|
197
210
|
data_source_name=data_source_cfg.name,
|
|
198
211
|
duration_seconds=time.monotonic() - layer_start_time,
|
|
199
|
-
windows_prepared=
|
|
200
|
-
windows_skipped=
|
|
212
|
+
windows_prepared=windows_prepared,
|
|
213
|
+
windows_skipped=windows_skipped,
|
|
214
|
+
windows_rejected=windows_rejected,
|
|
201
215
|
get_items_attempts=attempts_counter.value,
|
|
202
216
|
)
|
|
203
217
|
)
|
rslearn/models/anysat.py
CHANGED
rslearn/models/dinov3.py
CHANGED
|
@@ -1,4 +1,9 @@
|
|
|
1
|
-
"""DinoV3 model.
|
|
1
|
+
"""DinoV3 model.
|
|
2
|
+
|
|
3
|
+
This code loads the DINOv3 model. You must obtain the model separately from Meta to use
|
|
4
|
+
it. See https://github.com/facebookresearch/dinov3 for applicable license and copyright
|
|
5
|
+
information.
|
|
6
|
+
"""
|
|
2
7
|
|
|
3
8
|
from enum import StrEnum
|
|
4
9
|
from pathlib import Path
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
"""Apply center cropping on a feature map."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class FeatureCenterCrop(torch.nn.Module):
|
|
9
|
+
"""Apply center cropping on the input feature maps."""
|
|
10
|
+
|
|
11
|
+
def __init__(
|
|
12
|
+
self,
|
|
13
|
+
sizes: list[tuple[int, int]],
|
|
14
|
+
) -> None:
|
|
15
|
+
"""Create a new FeatureCenterCrop.
|
|
16
|
+
|
|
17
|
+
Only the center of each feature map will be retained and passed to the next
|
|
18
|
+
module.
|
|
19
|
+
|
|
20
|
+
Args:
|
|
21
|
+
sizes: a list of (height, width) tuples, with one tuple for each input
|
|
22
|
+
feature map.
|
|
23
|
+
"""
|
|
24
|
+
super().__init__()
|
|
25
|
+
self.sizes = sizes
|
|
26
|
+
|
|
27
|
+
def forward(
|
|
28
|
+
self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
|
|
29
|
+
) -> list[torch.Tensor]:
|
|
30
|
+
"""Apply center cropping on the feature maps.
|
|
31
|
+
|
|
32
|
+
Args:
|
|
33
|
+
features: list of feature maps at different resolutions.
|
|
34
|
+
inputs: original inputs (ignored).
|
|
35
|
+
|
|
36
|
+
Returns:
|
|
37
|
+
center cropped feature maps.
|
|
38
|
+
"""
|
|
39
|
+
new_features = []
|
|
40
|
+
for i, feat in enumerate(features):
|
|
41
|
+
height, width = self.sizes[i]
|
|
42
|
+
if feat.shape[2] < height or feat.shape[3] < width:
|
|
43
|
+
raise ValueError(
|
|
44
|
+
"feature map is smaller than the desired height and width"
|
|
45
|
+
)
|
|
46
|
+
start_h = feat.shape[2] // 2 - height // 2
|
|
47
|
+
start_w = feat.shape[3] // 2 - width // 2
|
|
48
|
+
feat = feat[:, :, start_h : start_h + height, start_w : start_w + width]
|
|
49
|
+
new_features.append(feat)
|
|
50
|
+
return new_features
|
|
@@ -9,6 +9,11 @@ from einops import rearrange
|
|
|
9
9
|
from olmo_core.config import Config
|
|
10
10
|
from olmo_core.distributed.checkpoint import load_model_and_optim_state
|
|
11
11
|
from olmoearth_pretrain.data.constants import Modality
|
|
12
|
+
from olmoearth_pretrain.model_loader import (
|
|
13
|
+
ModelID,
|
|
14
|
+
load_model_from_id,
|
|
15
|
+
load_model_from_path,
|
|
16
|
+
)
|
|
12
17
|
from olmoearth_pretrain.nn.flexihelios import Encoder, TokensAndMasks
|
|
13
18
|
from olmoearth_pretrain.train.masking import MaskedOlmoEarthSample, MaskValue
|
|
14
19
|
from upath import UPath
|
|
@@ -31,54 +36,115 @@ AUTOCAST_DTYPE_MAP = {
|
|
|
31
36
|
"float32": torch.float32,
|
|
32
37
|
}
|
|
33
38
|
|
|
39
|
+
EMBEDDING_SIZES = {
|
|
40
|
+
ModelID.OLMOEARTH_V1_NANO: 128,
|
|
41
|
+
ModelID.OLMOEARTH_V1_TINY: 192,
|
|
42
|
+
ModelID.OLMOEARTH_V1_BASE: 768,
|
|
43
|
+
ModelID.OLMOEARTH_V1_LARGE: 1024,
|
|
44
|
+
}
|
|
45
|
+
|
|
34
46
|
|
|
35
47
|
class OlmoEarth(torch.nn.Module):
|
|
36
48
|
"""A wrapper to support the OlmoEarth model."""
|
|
37
49
|
|
|
38
50
|
def __init__(
|
|
39
51
|
self,
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
52
|
+
patch_size: int,
|
|
53
|
+
model_id: ModelID | None = None,
|
|
54
|
+
model_path: str | None = None,
|
|
55
|
+
checkpoint_path: str | None = None,
|
|
56
|
+
selector: list[str | int] = ["encoder"],
|
|
44
57
|
forward_kwargs: dict[str, Any] = {},
|
|
45
58
|
random_initialization: bool = False,
|
|
46
59
|
embedding_size: int | None = None,
|
|
47
|
-
patch_size: int | None = None,
|
|
48
60
|
autocast_dtype: str | None = "bfloat16",
|
|
49
61
|
):
|
|
50
62
|
"""Create a new OlmoEarth model.
|
|
51
63
|
|
|
52
64
|
Args:
|
|
53
|
-
|
|
54
|
-
|
|
65
|
+
patch_size: token spatial patch size to use.
|
|
66
|
+
model_id: the model ID to load. One of model_id or model_path or checkpoint_path must be
|
|
67
|
+
set.
|
|
68
|
+
model_path: the path to load the model from. One of model_id or model_path or checkpoint_path must be
|
|
69
|
+
set. Same structure as the HF-hosted `model_id` models: bundle with a config.json and weights.pth.
|
|
70
|
+
checkpoint_path: the checkpoint directory to load from, if model_id or model_path is not
|
|
71
|
+
set. It should contain a distributed checkpoint with a config.json file as well as model_and_optim
|
|
72
|
+
folder.
|
|
55
73
|
selector: an optional sequence of attribute names or list indices to select
|
|
56
|
-
the sub-module that should be applied on the input images.
|
|
74
|
+
the sub-module that should be applied on the input images. Defaults to
|
|
75
|
+
["encoder"] to select only the transformer encoder.
|
|
57
76
|
forward_kwargs: additional arguments to pass to forward pass besides the
|
|
58
77
|
MaskedOlmoEarthSample.
|
|
59
78
|
random_initialization: whether to skip loading the checkpoint so the
|
|
60
79
|
weights are randomly initialized. In this case, the checkpoint is only
|
|
61
80
|
used to define the model architecture.
|
|
62
81
|
embedding_size: optional embedding size to report via
|
|
63
|
-
get_backbone_channels.
|
|
64
|
-
patch_size: optional patch size to report via get_backbone_channels.
|
|
82
|
+
get_backbone_channels (if model_id is not set).
|
|
65
83
|
autocast_dtype: which dtype to use for autocasting, or set None to disable.
|
|
66
84
|
"""
|
|
85
|
+
if (
|
|
86
|
+
sum(
|
|
87
|
+
[
|
|
88
|
+
model_id is not None,
|
|
89
|
+
model_path is not None,
|
|
90
|
+
checkpoint_path is not None,
|
|
91
|
+
]
|
|
92
|
+
)
|
|
93
|
+
!= 1
|
|
94
|
+
):
|
|
95
|
+
raise ValueError(
|
|
96
|
+
"exactly one of model_id, model_path, or checkpoint_path must be set"
|
|
97
|
+
)
|
|
98
|
+
|
|
67
99
|
super().__init__()
|
|
68
|
-
|
|
100
|
+
self.patch_size = patch_size
|
|
69
101
|
self.forward_kwargs = forward_kwargs
|
|
70
102
|
self.embedding_size = embedding_size
|
|
71
|
-
self.patch_size = patch_size
|
|
72
103
|
|
|
73
104
|
if autocast_dtype is not None:
|
|
74
105
|
self.autocast_dtype = AUTOCAST_DTYPE_MAP[autocast_dtype]
|
|
75
106
|
else:
|
|
76
107
|
self.autocast_dtype = None
|
|
77
108
|
|
|
109
|
+
if model_id is not None:
|
|
110
|
+
# Load from Hugging Face.
|
|
111
|
+
model = load_model_from_id(model_id, load_weights=not random_initialization)
|
|
112
|
+
if self.embedding_size is None and model_id in EMBEDDING_SIZES:
|
|
113
|
+
self.embedding_size = EMBEDDING_SIZES[model_id]
|
|
114
|
+
|
|
115
|
+
elif model_path is not None:
|
|
116
|
+
# Load from path.
|
|
117
|
+
model = load_model_from_path(
|
|
118
|
+
UPath(model_path), load_weights=not random_initialization
|
|
119
|
+
)
|
|
120
|
+
|
|
121
|
+
else:
|
|
122
|
+
# Load the distributed model checkpoint by path through Olmo Core
|
|
123
|
+
model = self._load_model_from_checkpoint(
|
|
124
|
+
UPath(checkpoint_path), random_initialization
|
|
125
|
+
)
|
|
126
|
+
|
|
127
|
+
# Select just the portion of the model that we actually want to use.
|
|
128
|
+
for part in selector:
|
|
129
|
+
if isinstance(part, str):
|
|
130
|
+
model = getattr(model, part)
|
|
131
|
+
else:
|
|
132
|
+
model = model[part]
|
|
133
|
+
self.model = model
|
|
134
|
+
|
|
135
|
+
def _load_model_from_checkpoint(
|
|
136
|
+
self, checkpoint_upath: UPath, random_initialization: bool
|
|
137
|
+
) -> torch.nn.Module:
|
|
138
|
+
"""Load the OlmoEarth pre-trained model from a distributed checkpoint folder.
|
|
139
|
+
|
|
140
|
+
The folder should contain config.json as well as the model_and_optim folder
|
|
141
|
+
that contains the distributed checkpoint. This is the format produced by
|
|
142
|
+
pre-training runs in olmoearth_pretrain.
|
|
143
|
+
"""
|
|
78
144
|
# Load the model config and initialize it.
|
|
79
145
|
# We avoid loading the train module here because it depends on running within
|
|
80
146
|
# olmo_core.
|
|
81
|
-
with (
|
|
147
|
+
with (checkpoint_upath / "config.json").open() as f:
|
|
82
148
|
config_dict = json.load(f)
|
|
83
149
|
model_config = Config.from_dict(config_dict["model"])
|
|
84
150
|
|
|
@@ -86,22 +152,14 @@ class OlmoEarth(torch.nn.Module):
|
|
|
86
152
|
|
|
87
153
|
# Load the checkpoint.
|
|
88
154
|
if not random_initialization:
|
|
89
|
-
train_module_dir =
|
|
155
|
+
train_module_dir = checkpoint_upath / "model_and_optim"
|
|
90
156
|
if train_module_dir.exists():
|
|
91
157
|
load_model_and_optim_state(str(train_module_dir), model)
|
|
92
158
|
logger.info(f"loaded OlmoEarth encoder from {train_module_dir}")
|
|
93
159
|
else:
|
|
94
160
|
logger.info(f"could not find OlmoEarth encoder at {train_module_dir}")
|
|
95
|
-
else:
|
|
96
|
-
logger.info("skipping loading OlmoEarth encoder")
|
|
97
161
|
|
|
98
|
-
|
|
99
|
-
for part in selector:
|
|
100
|
-
if isinstance(part, str):
|
|
101
|
-
model = getattr(model, part)
|
|
102
|
-
else:
|
|
103
|
-
model = model[part]
|
|
104
|
-
self.model = model
|
|
162
|
+
return model
|
|
105
163
|
|
|
106
164
|
def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
|
|
107
165
|
"""Compute feature maps from the OlmoEarth backbone.
|
|
@@ -167,13 +225,16 @@ class OlmoEarth(torch.nn.Module):
|
|
|
167
225
|
if isinstance(self.model, Encoder):
|
|
168
226
|
# Encoder has a fast_pass argument to indicate mask is not needed.
|
|
169
227
|
tokens_and_masks = self.model(
|
|
170
|
-
sample,
|
|
228
|
+
sample,
|
|
229
|
+
fast_pass=True,
|
|
230
|
+
patch_size=self.patch_size,
|
|
231
|
+
**self.forward_kwargs,
|
|
171
232
|
)["tokens_and_masks"]
|
|
172
233
|
else:
|
|
173
234
|
# Other models like STEncoder do not have this option supported.
|
|
174
|
-
tokens_and_masks = self.model(
|
|
175
|
-
|
|
176
|
-
]
|
|
235
|
+
tokens_and_masks = self.model(
|
|
236
|
+
sample, patch_size=self.patch_size, **self.forward_kwargs
|
|
237
|
+
)["tokens_and_masks"]
|
|
177
238
|
|
|
178
239
|
# Apply temporal/modality pooling so we just have one feature per patch.
|
|
179
240
|
features = []
|
rslearn/models/prithvi.py
CHANGED
|
@@ -1,4 +1,12 @@
|
|
|
1
|
-
"""Prithvi V2.
|
|
1
|
+
"""Prithvi V2.
|
|
2
|
+
|
|
3
|
+
This code is adapted from https://github.com/NASA-IMPACT/Prithvi-WxC
|
|
4
|
+
|
|
5
|
+
The code is released under:
|
|
6
|
+
|
|
7
|
+
MIT License
|
|
8
|
+
Copyright (c) 2024 Inter Agency Implementation and Advanced Concepts
|
|
9
|
+
"""
|
|
2
10
|
|
|
3
11
|
import json
|
|
4
12
|
import logging
|
|
@@ -94,7 +94,6 @@ class RslearnLightningModule(L.LightningModule):
|
|
|
94
94
|
restore_config: RestoreConfig | None = None,
|
|
95
95
|
print_parameters: bool = False,
|
|
96
96
|
print_model: bool = False,
|
|
97
|
-
strict_loading: bool = True,
|
|
98
97
|
# Deprecated options.
|
|
99
98
|
lr: float = 1e-3,
|
|
100
99
|
plateau: bool = False,
|
|
@@ -118,7 +117,6 @@ class RslearnLightningModule(L.LightningModule):
|
|
|
118
117
|
print_parameters: whether to print the list of model parameters after model
|
|
119
118
|
initialization
|
|
120
119
|
print_model: whether to print the model after model initialization
|
|
121
|
-
strict_loading: whether to strictly load the model parameters.
|
|
122
120
|
lr: deprecated.
|
|
123
121
|
plateau: deprecated.
|
|
124
122
|
plateau_factor: deprecated.
|
|
@@ -132,7 +130,6 @@ class RslearnLightningModule(L.LightningModule):
|
|
|
132
130
|
self.visualize_dir = visualize_dir
|
|
133
131
|
self.metrics_file = metrics_file
|
|
134
132
|
self.restore_config = restore_config
|
|
135
|
-
self.strict_loading = strict_loading
|
|
136
133
|
|
|
137
134
|
self.scheduler_factory: SchedulerFactory | None = None
|
|
138
135
|
if scheduler:
|
|
@@ -22,7 +22,11 @@ from rslearn.log_utils import get_logger
|
|
|
22
22
|
from rslearn.utils.array import copy_spatial_array
|
|
23
23
|
from rslearn.utils.feature import Feature
|
|
24
24
|
from rslearn.utils.geometry import PixelBounds
|
|
25
|
-
from rslearn.utils.raster_format import
|
|
25
|
+
from rslearn.utils.raster_format import (
|
|
26
|
+
RasterFormat,
|
|
27
|
+
adjust_projection_and_bounds_for_array,
|
|
28
|
+
load_raster_format,
|
|
29
|
+
)
|
|
26
30
|
from rslearn.utils.vector_format import VectorFormat, load_vector_format
|
|
27
31
|
|
|
28
32
|
from .lightning_module import RslearnLightningModule
|
|
@@ -68,15 +72,18 @@ class VectorMerger(PatchPredictionMerger):
|
|
|
68
72
|
class RasterMerger(PatchPredictionMerger):
|
|
69
73
|
"""Merger for raster data that copies the rasters to the output."""
|
|
70
74
|
|
|
71
|
-
def __init__(self, padding: int | None = None):
|
|
75
|
+
def __init__(self, padding: int | None = None, downsample_factor: int = 1):
|
|
72
76
|
"""Create a new RasterMerger.
|
|
73
77
|
|
|
74
78
|
Args:
|
|
75
79
|
padding: the padding around the individual patch outputs to remove. This is
|
|
76
80
|
typically used when leveraging overlapping patches. Portions of outputs
|
|
77
81
|
at the border of the window will still be retained.
|
|
82
|
+
downsample_factor: the factor by which the rasters output by the task are
|
|
83
|
+
lower in resolution relative to the window resolution.
|
|
78
84
|
"""
|
|
79
85
|
self.padding = padding
|
|
86
|
+
self.downsample_factor = downsample_factor
|
|
80
87
|
|
|
81
88
|
def merge(
|
|
82
89
|
self, window: Window, outputs: Sequence[PendingPatchOutput]
|
|
@@ -87,8 +94,8 @@ class RasterMerger(PatchPredictionMerger):
|
|
|
87
94
|
merged_image = np.zeros(
|
|
88
95
|
(
|
|
89
96
|
num_channels,
|
|
90
|
-
window.bounds[3] - window.bounds[1],
|
|
91
|
-
window.bounds[2] - window.bounds[0],
|
|
97
|
+
(window.bounds[3] - window.bounds[1]) // self.downsample_factor,
|
|
98
|
+
(window.bounds[2] - window.bounds[0]) // self.downsample_factor,
|
|
92
99
|
),
|
|
93
100
|
dtype=dtype,
|
|
94
101
|
)
|
|
@@ -104,7 +111,10 @@ class RasterMerger(PatchPredictionMerger):
|
|
|
104
111
|
# If the output is not on the left or top boundary, then we should apply
|
|
105
112
|
# the padding (if set).
|
|
106
113
|
src = output.output
|
|
107
|
-
src_offset = (
|
|
114
|
+
src_offset = (
|
|
115
|
+
output.bounds[0] // self.downsample_factor,
|
|
116
|
+
output.bounds[1] // self.downsample_factor,
|
|
117
|
+
)
|
|
108
118
|
if self.padding is not None and output.bounds[0] != window.bounds[0]:
|
|
109
119
|
src = src[:, :, self.padding :]
|
|
110
120
|
src_offset = (src_offset[0] + self.padding, src_offset[1])
|
|
@@ -116,7 +126,10 @@ class RasterMerger(PatchPredictionMerger):
|
|
|
116
126
|
src=src,
|
|
117
127
|
dst=merged_image,
|
|
118
128
|
src_offset=src_offset,
|
|
119
|
-
dst_offset=(
|
|
129
|
+
dst_offset=(
|
|
130
|
+
window.bounds[0] // self.downsample_factor,
|
|
131
|
+
window.bounds[1] // self.downsample_factor,
|
|
132
|
+
),
|
|
120
133
|
)
|
|
121
134
|
|
|
122
135
|
return merged_image
|
|
@@ -330,9 +343,13 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
330
343
|
self.output_layer, self.layer_config.band_sets[0].bands
|
|
331
344
|
)
|
|
332
345
|
assert isinstance(self.format, RasterFormat)
|
|
333
|
-
|
|
334
|
-
|
|
346
|
+
|
|
347
|
+
# In case the merged_output is at a different resolution than the window,
|
|
348
|
+
# get adjusted projection and bounds for writing it.
|
|
349
|
+
projection, bounds = adjust_projection_and_bounds_for_array(
|
|
350
|
+
window.projection, window.bounds, merged_output
|
|
335
351
|
)
|
|
352
|
+
self.format.encode_raster(raster_dir, projection, bounds, merged_output)
|
|
336
353
|
|
|
337
354
|
elif self.layer_config.layer_type == LayerType.VECTOR:
|
|
338
355
|
layer_dir = window.get_layer_dir(self.output_layer)
|
|
@@ -49,8 +49,8 @@ class ClassificationTask(BasicTask):
|
|
|
49
49
|
features with matching properties.
|
|
50
50
|
read_class_id: whether to read an integer class ID instead of the class
|
|
51
51
|
name.
|
|
52
|
-
allow_invalid: instead of throwing error when no
|
|
53
|
-
at a window, simply mark the example invalid for this task
|
|
52
|
+
allow_invalid: instead of throwing error when no classification label is
|
|
53
|
+
found at a window, simply mark the example invalid for this task
|
|
54
54
|
skip_unknown_categories: whether to skip examples with categories that are
|
|
55
55
|
not passed via classes, instead of throwing error
|
|
56
56
|
prob_property: when predicting, write probabilities in addition to class ID
|
rslearn/train/tasks/detection.py
CHANGED
|
@@ -72,11 +72,11 @@ class DetectionTask(BasicTask):
|
|
|
72
72
|
f1_metric_kwargs: dict[str, Any] = {},
|
|
73
73
|
**kwargs: Any,
|
|
74
74
|
) -> None:
|
|
75
|
-
"""Initialize a new
|
|
75
|
+
"""Initialize a new DetectionTask.
|
|
76
76
|
|
|
77
77
|
Args:
|
|
78
|
-
property_name: the property from which to extract the class name.
|
|
79
|
-
|
|
78
|
+
property_name: the property from which to extract the class name. Features
|
|
79
|
+
without this property name are ignored.
|
|
80
80
|
classes: a list of class names.
|
|
81
81
|
filters: optional list of (property_name, property_value) to only consider
|
|
82
82
|
features with matching properties.
|
|
@@ -86,8 +86,8 @@ class DetectionTask(BasicTask):
|
|
|
86
86
|
not passed via classes, instead of throwing error
|
|
87
87
|
skip_empty_examples: whether to skip examples with zero labels.
|
|
88
88
|
colors: optional colors for each class
|
|
89
|
-
box_size: force all boxes to be this size, centered at the
|
|
90
|
-
geometry. Required for Point geometries.
|
|
89
|
+
box_size: force all boxes to be two times this size, centered at the
|
|
90
|
+
centroid of the geometry. Required for Point geometries.
|
|
91
91
|
clip_boxes: whether to clip boxes to the image bounds.
|
|
92
92
|
exclude_by_center: before optionally clipping boxes, exclude boxes if the
|
|
93
93
|
center is outside the image bounds.
|