rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff shows the content of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0

The largest change shown is to `rslearn/train/lightning_module.py` (+192 -62 in the listing above): optimizer and scheduler configuration move to factory classes, parameter restoration moves to `on_fit_start`, and logging gains `sync_dist=True`. The diff is reconstructed below from the rendered viewer; a few removed lines were truncated mid-line in the rendering and are kept as shown.

```diff
@@ -1,5 +1,6 @@
 """Default LightningModule for rslearn."""
 
+import json
 import os
 from typing import Any
 
@@ -7,12 +8,17 @@ import lightning as L
 import torch
 from lightning.pytorch.utilities.types import OptimizerLRSchedulerConfig
 from PIL import Image
-from torch.optim import AdamW
-from torch.optim.lr_scheduler import ReduceLROnPlateau
 from upath import UPath
 
+from rslearn.log_utils import get_logger
+
+from .model_context import ModelContext, ModelOutput
+from .optimizer import AdamW, OptimizerFactory
+from .scheduler import PlateauScheduler, SchedulerFactory
 from .tasks import Task
 
+logger = get_logger(__name__)
+
 
 class RestoreConfig:
     """Configuration for restoring model parameters.
@@ -36,7 +42,7 @@
         restore_path_options: additional options for the restore_path to pass to
             fsspec.
         selector: path in the torch dict containing the model parameters.
-        ignore_prefixes: prefixes to
+        ignore_prefixes: prefixes to ignore from the state dict.
         remap_prefixes: list of (old_prefix, new_prefix) to rename parameters
             starting with old_prefix to start with new_prefix instead.
     """
@@ -47,9 +53,9 @@
 
     def get_state_dict(self) -> dict[str, Any]:
         """Returns the state dict configured in this RestoreConfig."""
-
+        logger.info(f"loading state dict from {self.restore_path}")
         with self.restore_path.open("rb") as f:
-            state_dict = torch.load(f)
+            state_dict = torch.load(f, map_location="cpu", weights_only=True)
         for k in self.selector:
             state_dict = state_dict[k]
 
@@ -82,48 +88,71 @@ class RslearnLightningModule(L.LightningModule):
         self,
         model: torch.nn.Module,
         task: Task,
+        optimizer: OptimizerFactory | None = None,
+        scheduler: SchedulerFactory | None = None,
+        visualize_dir: str | None = None,
+        metrics_file: str | None = None,
+        restore_config: RestoreConfig | None = None,
+        print_parameters: bool = False,
+        print_model: bool = False,
+        # Deprecated options.
         lr: float = 1e-3,
         plateau: bool = False,
         plateau_factor: float = 0.1,
         plateau_patience: int = 10,
         plateau_min_lr: float = 0,
         plateau_cooldown: int = 0,
-        visualize_dir: str | None = None,
-        restore_config: RestoreConfig | None = None,
-        print_parameters: bool = False,
-        print_model: bool = False,
     ):
         """Initialize a new RslearnLightningModule.
 
         Args:
             model: the model
             task: the task to train on
-
-
-            plateau_factor: on plateau, factor to multiply learning rate by
-            plateau_patience: number of iterations with no improvement in val loss
-                before reducing learning rate
-            plateau_min_lr: minimum learning rate to reduce to
-            plateau_cooldown: number of iterations after reducing learning rate before
-                resetting plateau scheduler
+            optimizer: the optimizer factory.
+            scheduler: the learning rate scheduler factory.
             visualize_dir: during validation or testing, output visualizations to this
                 directory
+            metrics_file: file to save metrics to
             restore_config: specification of configuration to restore parameters from
                 a non-Lightning checkpoint.
             print_parameters: whether to print the list of model parameters after model
                 initialization
             print_model: whether to print the model after model initialization
+            lr: deprecated.
+            plateau: deprecated.
+            plateau_factor: deprecated.
+            plateau_patience: deprecated.
+            plateau_min_lr: deprecated.
+            plateau_cooldown: deprecated.
         """
         super().__init__()
         self.model = model
         self.task = task
-        self.lr = lr
-        self.plateau = plateau
-        self.plateau_factor = plateau_factor
-        self.plateau_patience = plateau_patience
-        self.plateau_min_lr = plateau_min_lr
-        self.plateau_cooldown = plateau_cooldown
         self.visualize_dir = visualize_dir
+        self.metrics_file = metrics_file
+        self.restore_config = restore_config
+
+        self.scheduler_factory: SchedulerFactory | None = None
+        if scheduler:
+            self.scheduler_factory = scheduler
+        elif plateau:
+            logger.warning(
+                "The plateau argument to RslearnLightningModule is deprecated and will be removed in a future version"
+            )
+            self.scheduler_factory = PlateauScheduler(
+                factor=plateau_factor,
+                patience=plateau_patience,
+                min_lr=plateau_min_lr,
+                cooldown=plateau_cooldown,
+            )
+
+        if optimizer:
+            self.optimizer_factory = optimizer
+        else:
+            logger.warning(
+                "Defaulting the optimizer to AdamW since an OptimizerFactory was not provided. In a future version, the optimizer will be a required argument."
+            )
+            self.optimizer_factory = AdamW(lr=lr)
 
         if print_parameters:
             for name, param in self.named_parameters():
@@ -132,23 +161,26 @@ class RslearnLightningModule(L.LightningModule):
         if print_model:
             print(self.model)
 
-        if restore_config:
-            state_dict = restore_config.get_state_dict()
-            missing_keys, unexpected_keys = self.model.load_state_dict(
-                state_dict, strict=False
-            )
-            if missing_keys or unexpected_keys:
-                print(
-                    f"warning: restore yielded missing_keys={missing_keys} and unexpected_keys={unexpected_keys}"
-                )
-
         self.epochs = 0
 
         metrics = self.task.get_metrics()
         self.val_metrics = metrics.clone(prefix="val_")
         self.test_metrics = metrics.clone(prefix="test_")
 
-        self.schedulers = {}
+        self.schedulers: dict = {}
+
+    def on_fit_start(self) -> None:
+        """Called when the fit begins."""
+        # Only restore if doing a fresh fit.
+        if self.trainer.ckpt_path is None and self.restore_config:
+            state_dict = self.restore_config.get_state_dict()
+            missing_keys, unexpected_keys = self.model.load_state_dict(
+                state_dict, strict=False
+            )
+            if missing_keys or unexpected_keys:
+                logger.warning(
+                    f"restore yielded missing_keys={missing_keys} and unexpected_keys={unexpected_keys}"
+                )
 
     def configure_optimizers(self) -> OptimizerLRSchedulerConfig:
         """Initialize the optimizer and learning rate scheduler.
@@ -156,27 +188,37 @@ class RslearnLightningModule(L.LightningModule):
         Returns:
             Optimizer and learning rate scheduler.
        """
-
-        optimizer = AdamW(params, lr=self.lr)
+        optimizer = self.optimizer_factory.build(self)
         d = dict(
             optimizer=optimizer,
         )
-        if self.plateau:
-            scheduler = ReduceLROnPlateau(
-                optimizer,
-                factor=self.plateau_factor,
-                patience=self.plateau_patience,
-                min_lr=self.plateau_min_lr,
-                cooldown=self.plateau_cooldown,
-            )
+        if self.scheduler_factory is not None:
+            scheduler = self.scheduler_factory.build(optimizer)
             d["lr_scheduler"] = {
                 "scheduler": scheduler,
                 "monitor": "train_loss",
                 "interval": "epoch",
             }
-            self.schedulers["
+            self.schedulers["scheduler"] = scheduler
         return d
 
+    def on_train_epoch_start(self) -> None:
+        """If we are in a multi-dataset distributed strategy, set the epoch."""
+        try:
+            self.trainer.train_dataloader.batch_sampler.set_epoch(self.current_epoch)
+        except AttributeError:
+            # Fail silently for single-dataset case, which is okay
+            pass
+
+    def on_test_epoch_end(self) -> None:
+        """Optionally save the test metrics to a file."""
+        if self.metrics_file:
+            with open(self.metrics_file, "w") as f:
+                metrics = self.test_metrics.compute()
+                metrics_dict = {k: v.item() for k, v in metrics.items()}
+                json.dump(metrics_dict, f, indent=4)
+            logger.info(f"Saved metrics to {self.metrics_file}")
+
     def training_step(
         self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> torch.Tensor:
```
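Together these changes replace the hard-coded AdamW/ReduceLROnPlateau setup with pluggable factories that are built lazily in `configure_optimizers()`. A minimal sketch of the new-style construction, using only the signatures visible in this diff (`my_model` and `my_task` are placeholders for a real `torch.nn.Module` and rslearn `Task`):

```python
# Sketch only: factory classes and import paths are taken from this diff;
# my_model and my_task are hypothetical stand-ins.
from rslearn.train.lightning_module import RslearnLightningModule
from rslearn.train.optimizer import AdamW
from rslearn.train.scheduler import PlateauScheduler

lm = RslearnLightningModule(
    model=my_model,
    task=my_task,
    optimizer=AdamW(lr=1e-4, weight_decay=0.01),
    scheduler=PlateauScheduler(factor=0.1, patience=10),
    metrics_file="test_metrics.json",  # written out by on_test_epoch_end()
)
```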
The remaining hunks move the train/val/test/predict steps onto the new `ModelContext`/`ModelOutput` flow and add per-stage forward hooks:

```diff
@@ -190,9 +232,16 @@ class RslearnLightningModule(L.LightningModule):
         Returns:
             The loss tensor.
         """
-        inputs, targets,
+        inputs, targets, metadatas = batch
+        context = ModelContext(
+            inputs=inputs,
+            metadatas=metadatas,
+        )
         batch_size = len(inputs)
-
+        model_outputs = self(context, targets)
+        self.on_train_forward(context, targets, model_outputs)
+
+        loss_dict = model_outputs.loss_dict
         train_loss = sum(loss_dict.values())
         self.log_dict(
             {"train_" + k: v for k, v in loss_dict.items()},
@@ -200,6 +249,7 @@ class RslearnLightningModule(L.LightningModule):
             prog_bar=True,
             on_step=False,
             on_epoch=True,
+            sync_dist=True,
         )
         self.log(
             "train_loss",
@@ -207,6 +257,7 @@ class RslearnLightningModule(L.LightningModule):
             batch_size=batch_size,
             on_step=False,
             on_epoch=True,
+            sync_dist=True,
         )
         return train_loss
 
@@ -220,15 +271,24 @@ class RslearnLightningModule(L.LightningModule):
             batch_idx: Integer displaying index of this batch.
             dataloader_idx: Index of the current dataloader.
         """
-        inputs, targets,
+        inputs, targets, metadatas = batch
+        context = ModelContext(
+            inputs=inputs,
+            metadatas=metadatas,
+        )
         batch_size = len(inputs)
-
+        model_outputs = self(context, targets)
+        self.on_val_forward(context, targets, model_outputs)
+
+        loss_dict = model_outputs.loss_dict
+        outputs = model_outputs.outputs
         val_loss = sum(loss_dict.values())
         self.log_dict(
             {"val_" + k: v for k, v in loss_dict.items()},
             batch_size=batch_size,
             on_step=False,
             on_epoch=True,
+            sync_dist=True,
         )
         self.log(
             "val_loss",
@@ -237,9 +297,12 @@ class RslearnLightningModule(L.LightningModule):
             prog_bar=True,
             on_step=False,
             on_epoch=True,
+            sync_dist=True,
         )
         self.val_metrics.update(outputs, targets)
-        self.log_dict(
+        self.log_dict(
+            self.val_metrics, batch_size=batch_size, on_epoch=True, sync_dist=True
+        )
 
     def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         """Compute the test loss and additional metrics.
@@ -250,20 +313,36 @@ class RslearnLightningModule(L.LightningModule):
             dataloader_idx: Index of the current dataloader.
         """
         inputs, targets, metadatas = batch
+        context = ModelContext(
+            inputs=inputs,
+            metadatas=metadatas,
+        )
         batch_size = len(inputs)
-
+        model_outputs = self(context, targets)
+        self.on_test_forward(context, targets, model_outputs)
+
+        loss_dict = model_outputs.loss_dict
+        outputs = model_outputs.outputs
         test_loss = sum(loss_dict.values())
         self.log_dict(
             {"test_" + k: v for k, v in loss_dict.items()},
             batch_size=batch_size,
             on_step=False,
             on_epoch=True,
+            sync_dist=True,
         )
         self.log(
-            "test_loss",
+            "test_loss",
+            test_loss,
+            batch_size=batch_size,
+            on_step=False,
+            on_epoch=True,
+            sync_dist=True,
         )
         self.test_metrics.update(outputs, targets)
-        self.log_dict(
+        self.log_dict(
+            self.test_metrics, batch_size=batch_size, on_epoch=True, sync_dist=True
+        )
 
         if self.visualize_dir:
             for idx, (inp, target, output, metadata) in enumerate(
@@ -273,13 +352,13 @@ class RslearnLightningModule(L.LightningModule):
                 for image_suffix, image in images.items():
                     out_fname = os.path.join(
                         self.visualize_dir,
-                        f
+                        f"{metadata['window_name']}_{metadata['bounds'][0]}_{metadata['bounds'][1]}_{image_suffix}.png",
                     )
                     Image.fromarray(image).save(out_fname)
 
     def predict_step(
         self, batch: Any, batch_idx: int, dataloader_idx: int = 0
-    ) ->
+    ) -> ModelOutput:
         """Compute the predicted class probabilities.
 
         Args:
@@ -290,18 +369,69 @@ class RslearnLightningModule(L.LightningModule):
         Returns:
             Output predicted probabilities.
         """
-        inputs, _,
-
-
+        inputs, _, metadatas = batch
+        context = ModelContext(
+            inputs=inputs,
+            metadatas=metadatas,
+        )
+        model_outputs = self(context)
+        return model_outputs
 
-    def forward(
+    def forward(
+        self, context: ModelContext, targets: list[dict[str, Any]] | None = None
+    ) -> ModelOutput:
         """Forward pass of the model.
 
         Args:
-
-
+            context: the model context.
+            targets: the target dicts.
 
         Returns:
             Output of the model.
         """
-        return self.model(
+        return self.model(context, targets)
+
+    def on_train_forward(
+        self,
+        context: ModelContext,
+        targets: list[dict[str, Any]],
+        model_outputs: ModelOutput,
+    ) -> None:
+        """Hook to run after the forward pass of the model during training.
+
+        Args:
+            context: The model context.
+            targets: The target batch.
+            model_outputs: The output of the model.
+        """
+        pass
+
+    def on_val_forward(
+        self,
+        context: ModelContext,
+        targets: list[dict[str, Any]],
+        model_outputs: ModelOutput,
+    ) -> None:
+        """Hook to run after the forward pass of the model during validation.
+
+        Args:
+            context: The model context.
+            targets: The target batch.
+            model_outputs: The output of the model.
+        """
+        pass
+
+    def on_test_forward(
+        self,
+        context: ModelContext,
+        targets: list[dict[str, Any]],
+        model_outputs: ModelOutput,
+    ) -> None:
+        """Hook to run after the forward pass of the model during testing.
+
+        Args:
+            context: The model context.
+            targets: The target batch.
+            model_outputs: The output of the model.
+        """
+        pass
```
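The new hooks are no-ops by default; subclasses can override them to observe each forward pass without reimplementing the step methods. A hypothetical example (the hook signature comes from the diff; the body is illustrative only):

```python
# Illustrative sketch: the class and hook signature are from this diff,
# the printing inside the hook is ours.
from typing import Any

from rslearn.train.lightning_module import RslearnLightningModule
from rslearn.train.model_context import ModelContext, ModelOutput


class LossLoggingLM(RslearnLightningModule):
    """Prints per-batch validation losses after each forward pass."""

    def on_val_forward(
        self,
        context: ModelContext,
        targets: list[dict[str, Any]],
        model_outputs: ModelOutput,
    ) -> None:
        # loss_dict maps loss names to scalar tensors (see ModelOutput below).
        for name, value in model_outputs.loss_dict.items():
            print(f"batch of {len(context.inputs)}: {name}={value.item():.4f}")
```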
`rslearn/train/model_context.py` is a new file (+88 in the listing above) defining the data classes referenced in the module:

```diff
@@ -0,0 +1,88 @@
+"""Data classes to provide various context to models."""
+
+from collections.abc import Iterable
+from dataclasses import dataclass, field
+from datetime import datetime
+from typing import Any
+
+import torch
+
+from rslearn.utils.geometry import PixelBounds, Projection
+
+
+@dataclass
+class RasterImage:
+    """A raster image is a torch.tensor containing the images and their associated timestamps."""
+
+    # image is a 4D CTHW tensor
+    image: torch.Tensor
+    # if timestamps is not None, len(timestamps) must match the T dimension of the tensor
+    timestamps: list[tuple[datetime, datetime]] | None = None
+
+    @property
+    def shape(self) -> torch.Size:
+        """The shape of the image."""
+        return self.image.shape
+
+    def dim(self) -> int:
+        """The dim of the image."""
+        return self.image.dim()
+
+    @property
+    def dtype(self) -> torch.dtype:
+        """The image dtype."""
+        return self.image.dtype
+
+    def single_ts_to_chw_tensor(self) -> torch.Tensor:
+        """Single timestep models expect single timestep inputs.
+
+        This function (1) checks this raster image only has 1 timestep and
+        (2) returns the tensor for that (single) timestep (going from CTHW to CHW).
+        """
+        if self.image.shape[1] != 1:
+            raise ValueError(f"Expected a single timestep, got {self.image.shape[1]}")
+        return self.image[:, 0]
+
+
+@dataclass
+class SampleMetadata:
+    """Metadata pertaining to an example."""
+
+    window_group: str
+    window_name: str
+    window_bounds: PixelBounds
+    patch_bounds: PixelBounds
+    patch_idx: int
+    num_patches_in_window: int
+    time_range: tuple[datetime, datetime] | None
+    projection: Projection
+
+    # Task name to differentiate different tasks.
+    dataset_source: str | None
+
+
+@dataclass
+class ModelContext:
+    """Context to pass to all model components."""
+
+    # One input dict per example in the batch.
+    inputs: list[dict[str, torch.Tensor | RasterImage]]
+    # One SampleMetadata per example in the batch.
+    metadatas: list[SampleMetadata]
+    # Arbitrary dict that components can add to.
+    context_dict: dict[str, Any] = field(default_factory=lambda: {})
+
+
+@dataclass
+class ModelOutput:
+    """The output from the Predictor.
+
+    Args:
+        outputs: output compatible with the configured Task.
+        loss_dict: map from loss names to scalar tensors.
+        metadata: arbitrary dict that can be used to store other outputs.
+    """
+
+    outputs: Iterable[Any]
+    loss_dict: dict[str, torch.Tensor]
+    metadata: dict[str, Any] = field(default_factory=lambda: {})
```
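A short sketch of the `RasterImage` contract described in the comments above (CTHW layout); the tensor here is dummy data for illustration:

```python
# Sketch exercising RasterImage as defined in model_context.py above.
from datetime import datetime, timezone

import torch

from rslearn.train.model_context import RasterImage

# 9 channels, 1 timestep, 64x64 pixels, with one (start, end) time range.
image = RasterImage(
    image=torch.zeros(9, 1, 64, 64),
    timestamps=[
        (
            datetime(2024, 1, 1, tzinfo=timezone.utc),
            datetime(2024, 1, 31, tzinfo=timezone.utc),
        )
    ],
)
# Single-timestep models can collapse CTHW to CHW.
chw = image.single_ts_to_chw_tensor()
assert chw.shape == (9, 64, 64)
```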
`rslearn/train/optimizer.py` is also new (+31 in the listing above):

```diff
@@ -0,0 +1,31 @@
+"""Optimizers for rslearn."""
+
+from dataclasses import asdict, dataclass
+
+import lightning as L
+import torch.optim
+from torch.optim import Optimizer
+
+
+class OptimizerFactory:
+    """A factory class that initializes the optimizer given the LightningModule."""
+
+    def build(self, lm: L.LightningModule) -> Optimizer:
+        """Build the optimizer configured by this factory class."""
+        raise NotImplementedError
+
+
+@dataclass
+class AdamW(OptimizerFactory):
+    """Factory for AdamW optimizer."""
+
+    lr: float = 0.001
+    betas: tuple[float, float] = (0.9, 0.999)
+    eps: float | None = None
+    weight_decay: float | None = None
+
+    def build(self, lm: L.LightningModule) -> Optimizer:
+        """Build the AdamW optimizer."""
+        params = [p for p in lm.parameters() if p.requires_grad]
+        kwargs = {k: v for k, v in asdict(self).items() if v is not None}
+        return torch.optim.AdamW(params, **kwargs)
```