rslearn 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/__init__.py +2 -10
- rslearn/config/dataset.py +414 -420
- rslearn/data_sources/__init__.py +8 -31
- rslearn/data_sources/aws_landsat.py +13 -24
- rslearn/data_sources/aws_open_data.py +21 -46
- rslearn/data_sources/aws_sentinel1.py +3 -14
- rslearn/data_sources/climate_data_store.py +21 -40
- rslearn/data_sources/copernicus.py +30 -91
- rslearn/data_sources/data_source.py +26 -0
- rslearn/data_sources/earthdaily.py +13 -38
- rslearn/data_sources/earthdata_srtm.py +14 -32
- rslearn/data_sources/eurocrops.py +5 -9
- rslearn/data_sources/gcp_public_data.py +46 -43
- rslearn/data_sources/google_earth_engine.py +31 -44
- rslearn/data_sources/local_files.py +91 -100
- rslearn/data_sources/openstreetmap.py +21 -51
- rslearn/data_sources/planet.py +12 -30
- rslearn/data_sources/planet_basemap.py +4 -25
- rslearn/data_sources/planetary_computer.py +58 -141
- rslearn/data_sources/usda_cdl.py +15 -26
- rslearn/data_sources/usgs_landsat.py +4 -29
- rslearn/data_sources/utils.py +9 -0
- rslearn/data_sources/worldcereal.py +47 -54
- rslearn/data_sources/worldcover.py +16 -14
- rslearn/data_sources/worldpop.py +15 -18
- rslearn/data_sources/xyz_tiles.py +11 -30
- rslearn/dataset/dataset.py +6 -6
- rslearn/dataset/manage.py +28 -26
- rslearn/dataset/materialize.py +9 -45
- rslearn/lightning_cli.py +370 -1
- rslearn/main.py +3 -3
- rslearn/models/clay/clay.py +14 -1
- rslearn/models/concatenate_features.py +93 -0
- rslearn/models/croma.py +26 -3
- rslearn/models/satlaspretrain.py +18 -4
- rslearn/models/terramind.py +19 -0
- rslearn/tile_stores/__init__.py +0 -11
- rslearn/train/dataset.py +4 -12
- rslearn/train/prediction_writer.py +16 -32
- rslearn/train/tasks/classification.py +2 -1
- rslearn/utils/fsspec.py +20 -0
- rslearn/utils/jsonargparse.py +79 -0
- rslearn/utils/raster_format.py +1 -41
- rslearn/utils/vector_format.py +1 -38
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/METADATA +1 -1
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/RECORD +51 -52
- rslearn/data_sources/geotiff.py +0 -1
- rslearn/data_sources/raster_source.py +0 -23
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/WHEEL +0 -0
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/top_level.txt +0 -0
rslearn/lightning_cli.py
CHANGED
|
@@ -1,12 +1,107 @@
|
|
|
1
1
|
"""LightningCLI for rslearn."""
|
|
2
2
|
|
|
3
|
+
import hashlib
|
|
4
|
+
import json
|
|
5
|
+
import os
|
|
6
|
+
import shutil
|
|
3
7
|
import sys
|
|
8
|
+
import tempfile
|
|
4
9
|
|
|
10
|
+
import fsspec
|
|
11
|
+
import jsonargparse
|
|
12
|
+
import wandb
|
|
13
|
+
from lightning.pytorch import LightningModule, Trainer
|
|
14
|
+
from lightning.pytorch.callbacks import Callback
|
|
5
15
|
from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
|
|
16
|
+
from lightning.pytorch.utilities import rank_zero_only
|
|
17
|
+
from upath import UPath
|
|
6
18
|
|
|
7
19
|
from rslearn.arg_parser import RslearnArgumentParser
|
|
20
|
+
from rslearn.log_utils import get_logger
|
|
8
21
|
from rslearn.train.data_module import RslearnDataModule
|
|
9
22
|
from rslearn.train.lightning_module import RslearnLightningModule
|
|
23
|
+
from rslearn.utils.fsspec import open_atomic
|
|
24
|
+
|
|
25
|
+
WANDB_ID_FNAME = "wandb_id"
|
|
26
|
+
|
|
27
|
+
logger = get_logger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def get_cached_checkpoint(checkpoint_fname: UPath) -> str:
    """Return a path on the local filesystem that holds the given checkpoint.

    Checkpoints already on the local filesystem are used in place. Any other
    checkpoint is copied once to
    ``<tempdir>/rslearn_cache/checkpoints/<sha256 of the checkpoint URI>.ckpt`` and
    that copy is reused by later calls. The cache is never cleaned up automatically.

    Args:
        checkpoint_fname: the checkpoint to load; it may live on any fsspec
            filesystem.

    Returns:
        the local filename of a file with the same contents as checkpoint_fname.
    """
    if isinstance(checkpoint_fname.fs, fsspec.implementations.local.LocalFileSystem):
        return checkpoint_fname.path

    uri = str(checkpoint_fname)
    cache_dir = os.path.join(tempfile.gettempdir(), "rslearn_cache", "checkpoints")
    # The name is derived from the URI only, so the same remote checkpoint always
    # maps to the same cache entry.
    # NOTE(review): this path is predictable and lives in a possibly shared temp
    # directory; checkpoints are presumably unpickled when loaded, so a file planted
    # here by another user would be trusted. Consider a per-user cache directory.
    local_fname = os.path.join(cache_dir, hashlib.sha256(uri.encode()).hexdigest() + ".ckpt")

    if os.path.exists(local_fname):
        logger.info("using cached checkpoint for %s at %s", uri, local_fname)
    else:
        logger.info("caching checkpoint %s to %s", uri, local_fname)
        os.makedirs(cache_dir, exist_ok=True)
        # The atomic write means a partially-copied file never becomes visible at
        # local_fname, so the exists() check above only sees complete copies.
        with checkpoint_fname.open("rb") as src, open_atomic(
            UPath(local_fname), "wb"
        ) as dst:
            shutil.copyfileobj(src, dst)

    return local_fname
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class SaveWandbRunIdCallback(Callback):
    """Callback to save the wandb run ID to project directory in case of resume.

    When fitting starts (on rank zero only), it writes the active W&B run ID to
    ``<project_dir>/<WANDB_ID_FNAME>`` so a later invocation can resume the same run.
    It also uploads the experiment configuration to the run's W&B config.
    """

    def __init__(
        self,
        project_dir: str,
        config_str: str | None,
    ) -> None:
        """Create a new SaveWandbRunIdCallback.

        Args:
            project_dir: the project directory (any path that UPath accepts).
            config_str: the JSON-encoded configuration of this experiment, or None to
                skip uploading the configuration to W&B.
        """
        self.project_dir = project_dir
        self.config_str = config_str

    @rank_zero_only
    def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Save the W&B run ID and upload the configuration when fitting starts.

        Args:
            trainer: the Trainer object.
            pl_module: the LightningModule object.
        """
        run = wandb.run
        if run is None:
            # No active W&B run (e.g. a non-W&B logger is configured), so there is
            # no run ID to save; previously this crashed with an AttributeError.
            logger.warning(
                "no active W&B run at fit start, not saving W&B run ID to %s",
                self.project_dir,
            )
            return

        project_dir = UPath(self.project_dir)
        project_dir.mkdir(parents=True, exist_ok=True)
        with (project_dir / WANDB_ID_FNAME).open("w") as f:
            f.write(run.id)

        # config_str is expected to contain the CLI's "project_name" argument, so its
        # presence in the W&B config means the configuration was already uploaded
        # (e.g. when resuming an existing run) and should not be re-uploaded.
        if self.config_str is not None and "project_name" not in wandb.config:
            wandb.config.update(json.loads(self.config_str))
|
|
10
105
|
|
|
11
106
|
|
|
12
107
|
class RslearnLightningCLI(LightningCLI):
|
|
@@ -23,6 +118,266 @@ class RslearnLightningCLI(LightningCLI):
|
|
|
23
118
|
"data.init_args.task", "model.init_args.task", apply_on="instantiate"
|
|
24
119
|
)
|
|
25
120
|
|
|
121
|
+
# Project management option to have rslearn manage checkpoints and W&B run.
|
|
122
|
+
parser.add_argument(
|
|
123
|
+
"--management_dir",
|
|
124
|
+
type=str | None,
|
|
125
|
+
help="Enable project management, and use this directory to store checkpoints and configs. If enabled, rslearn will automatically manages checkpoint directory/loading and W&B run",
|
|
126
|
+
default=None,
|
|
127
|
+
)
|
|
128
|
+
parser.add_argument(
|
|
129
|
+
"--project_name",
|
|
130
|
+
type=str | None,
|
|
131
|
+
help="The project name (used with --management_dir)",
|
|
132
|
+
default=None,
|
|
133
|
+
)
|
|
134
|
+
parser.add_argument(
|
|
135
|
+
"--run_name",
|
|
136
|
+
type=str | None,
|
|
137
|
+
help="A unique name for this experiment (used with --management_dir)",
|
|
138
|
+
default=None,
|
|
139
|
+
)
|
|
140
|
+
parser.add_argument(
|
|
141
|
+
"--run_description",
|
|
142
|
+
type=str,
|
|
143
|
+
help="Optional description of this experiment (used with --management_dir)",
|
|
144
|
+
default="",
|
|
145
|
+
)
|
|
146
|
+
parser.add_argument(
|
|
147
|
+
"--load_checkpoint_mode",
|
|
148
|
+
type=str,
|
|
149
|
+
help="Which checkpoint to load, if any (used with --management_dir). 'none' never loads any checkpoint, 'last' loads the most recent checkpoint, and 'best' loads the best checkpoint. 'auto' will use 'last' during fit and 'best' during val/test/predict.",
|
|
150
|
+
default="auto",
|
|
151
|
+
)
|
|
152
|
+
parser.add_argument(
|
|
153
|
+
"--load_checkpoint_required",
|
|
154
|
+
type=str,
|
|
155
|
+
help="Whether to fail if the expected checkpoint based on load_checkpoint_mode does not exist (used with --management_dir). 'yes' will fail while 'no' won't. 'auto' will use 'no' during fit and 'yes' during val/test/predict.",
|
|
156
|
+
default="auto",
|
|
157
|
+
)
|
|
158
|
+
parser.add_argument(
|
|
159
|
+
"--log_mode",
|
|
160
|
+
type=str,
|
|
161
|
+
help="Whether to log to W&B (used with --management_dir). 'yes' will enable logging, 'no' will disable logging, and 'auto' will use 'yes' during fit and 'no' during val/test/predict.",
|
|
162
|
+
default="auto",
|
|
163
|
+
)
|
|
164
|
+
|
|
165
|
+
def _get_checkpoint_path(
    self,
    project_dir: UPath,
    load_checkpoint_mode: str,
    load_checkpoint_required: str,
    stage: str,
) -> str | None:
    """Get path to checkpoint to load from, or None to not restore checkpoint.

    Args:
        project_dir: the project directory determined from the project management
            directory.
        load_checkpoint_mode: "none" to not load any checkpoint, "last" to load the
            most recent checkpoint, "best" to load the best checkpoint. "auto" to
            use "last" during fit and "best" during val/test/predict.
        load_checkpoint_required: "yes" to fail if no checkpoint exists, "no" to
            ignore. "auto" will use "no" during fit and "yes" during
            val/test/predict.
        stage: the lightning stage (fit/val/test/predict).

    Returns:
        the path to the checkpoint for setting c.ckpt_path, or None if no
        checkpoint should be restored.

    Raises:
        ValueError: if load_checkpoint_mode is unknown, if a checkpoint is required
            while load_checkpoint_mode is "none", or if a required checkpoint is not
            found.
    """
    # Resolve auto options if used.
    if load_checkpoint_mode == "auto":
        load_checkpoint_mode = "last" if stage == "fit" else "best"
    if load_checkpoint_required == "auto":
        load_checkpoint_required = "no" if stage == "fit" else "yes"

    if load_checkpoint_required == "yes" and load_checkpoint_mode == "none":
        raise ValueError(
            "load_checkpoint_required cannot be set when load_checkpoint_mode is none"
        )

    if load_checkpoint_mode == "none":
        # Explicitly asked to start from scratch. The check above already rejected
        # a required checkpoint in this mode. (Previously "none" fell through to the
        # "unknown load_checkpoint_mode" error.)
        return None

    ckpt_path: str | None = None

    if load_checkpoint_mode == "best":
        # Checkpoints should be either:
        # - last.ckpt
        # - of the form "A=B-C=D-....ckpt" with one key being epoch=X
        # So we want the one with the highest epoch, and only use last.ckpt if
        # it's the only option.
        # User should set save_top_k=1 so there's just one, otherwise we won't
        # actually know which one is the best.
        best_checkpoint: UPath | None = None
        best_epochs: int | None = None
        # The project directory may not exist yet (nothing trained there); treat
        # that as "no checkpoints" instead of letting iterdir() raise.
        options = project_dir.iterdir() if project_dir.exists() else []
        for option in options:
            if not option.name.endswith(".ckpt"):
                continue

            # Try to see what epochs this checkpoint is at.
            # If it is some other format, then set it 0 so we only use it if it's
            # the only option.
            # If it is last.ckpt then we set it -100 to only use it if there is not
            # even another format like "best.ckpt".
            extracted_epochs = -100 if option.name == "last.ckpt" else 0
            for part in option.name.split(".ckpt")[0].split("-"):
                kv_parts = part.split("=")
                if len(kv_parts) != 2 or kv_parts[0] != "epoch":
                    continue
                try:
                    extracted_epochs = int(kv_parts[1])
                except ValueError:
                    # Not an integer epoch (unexpected naming); ignore this part.
                    continue

            if best_epochs is None or extracted_epochs > best_epochs:
                best_checkpoint = option
                best_epochs = extracted_epochs

        if best_checkpoint is not None:
            # Cache the checkpoint so we only need to download once in case we
            # reuse it later. We only cache in "best" mode since this is the only
            # scenario where we expect to keep reusing the same checkpoint.
            ckpt_path = get_cached_checkpoint(best_checkpoint)

    elif load_checkpoint_mode == "last":
        last_checkpoint_path = project_dir / "last.ckpt"
        if last_checkpoint_path.exists():
            ckpt_path = str(last_checkpoint_path)

    else:
        raise ValueError(f"unknown load_checkpoint_mode {load_checkpoint_mode}")

    if load_checkpoint_required == "yes" and ckpt_path is None:
        raise ValueError(
            "load_checkpoint_required is set but no checkpoint was found"
        )

    return ckpt_path
|
|
265
|
+
|
|
266
|
+
def enable_project_management(self, management_dir: str) -> None:
    """Enable project management in the specified directory.

    The run's files live in ``<management_dir>/<project_name>/<run_name>``. This
    method mutates the parsed config of the current subcommand in place:
      - when logging is enabled, it configures a W&B logger and a
        SaveWandbRunIdCallback; otherwise it replaces the logger config;
      - during fit, it points the ModelCheckpoint callback (adding one if needed) at
        the project directory;
      - it sets c.ckpt_path from an existing checkpoint according to
        load_checkpoint_mode / load_checkpoint_required;
      - when fitting resumes from a checkpoint, it resumes the saved W&B run.

    Args:
        management_dir: the directory to store checkpoints and W&B.

    Raises:
        ValueError: if project_name or run_name is unset, log_mode is unknown, or the
            checkpoint options are invalid or a required checkpoint is missing.
    """
    subcommand = self.config.subcommand
    c = self.config[subcommand]

    # Project name and run name are required with project management.
    if not c.project_name or not c.run_name:
        raise ValueError(
            "project name and run name must be set when using project management"
        )

    # Get project directory within the project management directory.
    project_dir = UPath(management_dir) / c.project_name / c.run_name

    def find_callback(*class_paths: str) -> jsonargparse.Namespace | None:
        # Make sure c.trainer.callbacks is a list we can append to, then return the
        # last configured callback whose class_path is one of class_paths.
        if "callbacks" not in c.trainer or not c.trainer.callbacks:
            c.trainer.callbacks = []
        found = None
        for existing_callback in c.trainer.callbacks:
            if existing_callback.class_path in class_paths:
                found = existing_callback
        return found

    # Decide whether to log to W&B.
    if c.log_mode == "yes":
        should_log = True
    elif c.log_mode == "no":
        should_log = False
    elif c.log_mode == "auto":
        should_log = subcommand == "fit"
    else:
        raise ValueError(f"unknown log_mode {c.log_mode}")

    if should_log:
        # Add the W&B logger if it isn't already set, and (re-)configure it.
        if not c.trainer.logger:
            c.trainer.logger = jsonargparse.Namespace(
                {
                    "class_path": "lightning.pytorch.loggers.WandbLogger",
                    "init_args": jsonargparse.Namespace(),
                }
            )
        c.trainer.logger.init_args.project = c.project_name
        c.trainer.logger.init_args.name = c.run_name
        if c.run_description:
            c.trainer.logger.init_args.notes = c.run_description

        # Add callback to save the W&B run ID and upload the config to W&B. Also
        # match the fully-qualified class path, in case the entry was written with
        # (or normalized to) the full import path, so it is not added twice.
        existing_wandb_callback = find_callback(
            "SaveWandbRunIdCallback",
            f"{SaveWandbRunIdCallback.__module__}.{SaveWandbRunIdCallback.__qualname__}",
        )
        if existing_wandb_callback is None:
            config_str = json.dumps(
                c.as_dict(), default=lambda _: "<not serializable>"
            )
            c.trainer.callbacks.append(
                jsonargparse.Namespace(
                    {
                        "class_path": "SaveWandbRunIdCallback",
                        "init_args": jsonargparse.Namespace(
                            {
                                "project_dir": str(project_dir),
                                "config_str": config_str,
                            }
                        ),
                    }
                )
            )
    else:
        c.trainer.logger = jsonargparse.Namespace({})

    if subcommand == "fit":
        # Set the checkpoint directory to match the project directory.
        checkpoint_callback = find_callback(
            "lightning.pytorch.callbacks.ModelCheckpoint"
        )
        if checkpoint_callback is None:
            checkpoint_callback = jsonargparse.Namespace(
                {
                    "class_path": "lightning.pytorch.callbacks.ModelCheckpoint",
                    "init_args": jsonargparse.Namespace(
                        {
                            "save_last": True,
                            "save_top_k": 1,
                            "monitor": "val_loss",
                        }
                    ),
                }
            )
            c.trainer.callbacks.append(checkpoint_callback)
        checkpoint_callback.init_args.dirpath = str(project_dir)

    # Load existing checkpoint.
    checkpoint_path = self._get_checkpoint_path(
        project_dir=project_dir,
        load_checkpoint_mode=c.load_checkpoint_mode,
        load_checkpoint_required=c.load_checkpoint_required,
        stage=subcommand,
    )
    if checkpoint_path is not None:
        logger.info("found checkpoint to resume from at %s", checkpoint_path)
        c.ckpt_path = checkpoint_path

    # If we are resuming from a checkpoint for training, we also try to resume the
    # W&B run. Without a checkpoint, training starts from scratch, so a fresh W&B
    # run is used rather than appending to the old one.
    wandb_id_path = project_dir / WANDB_ID_FNAME
    if (
        subcommand == "fit"
        and should_log
        and checkpoint_path is not None
        and wandb_id_path.exists()
    ):
        with wandb_id_path.open("r") as f:
            wandb_id = f.read().strip()
        c.trainer.logger.init_args.id = wandb_id
|
|
380
|
+
|
|
26
381
|
def before_instantiate_classes(self) -> None:
|
|
27
382
|
"""Called before Lightning class initialization.
|
|
28
383
|
|
|
@@ -33,7 +388,7 @@ class RslearnLightningCLI(LightningCLI):
|
|
|
33
388
|
|
|
34
389
|
# If there is a RslearnPredictionWriter, set its path.
|
|
35
390
|
prediction_writer_callback = None
|
|
36
|
-
if "callbacks" in c.trainer:
|
|
391
|
+
if "callbacks" in c.trainer and c.trainer.callbacks:
|
|
37
392
|
for existing_callback in c.trainer.callbacks:
|
|
38
393
|
if (
|
|
39
394
|
existing_callback.class_path
|
|
@@ -53,6 +408,20 @@ class RslearnLightningCLI(LightningCLI):
|
|
|
53
408
|
if subcommand == "predict":
|
|
54
409
|
c.return_predictions = False
|
|
55
410
|
|
|
411
|
+
# For now we use DDP strategy with find_unused_parameters=True.
|
|
412
|
+
if subcommand == "fit":
|
|
413
|
+
c.trainer.strategy = jsonargparse.Namespace(
|
|
414
|
+
{
|
|
415
|
+
"class_path": "lightning.pytorch.strategies.DDPStrategy",
|
|
416
|
+
"init_args": jsonargparse.Namespace(
|
|
417
|
+
{"find_unused_parameters": True}
|
|
418
|
+
),
|
|
419
|
+
}
|
|
420
|
+
)
|
|
421
|
+
|
|
422
|
+
if c.management_dir:
|
|
423
|
+
self.enable_project_management(c.management_dir)
|
|
424
|
+
|
|
56
425
|
|
|
57
426
|
def model_handler() -> None:
|
|
58
427
|
"""Handler for any rslearn model X commands."""
|
rslearn/main.py
CHANGED
|
@@ -15,7 +15,7 @@ from upath import UPath
|
|
|
15
15
|
|
|
16
16
|
from rslearn.config import LayerConfig
|
|
17
17
|
from rslearn.const import WGS84_EPSG
|
|
18
|
-
from rslearn.data_sources import Item
|
|
18
|
+
from rslearn.data_sources import Item
|
|
19
19
|
from rslearn.dataset import Dataset, Window, WindowLayerData
|
|
20
20
|
from rslearn.dataset.add_windows import add_windows_from_box, add_windows_from_file
|
|
21
21
|
from rslearn.dataset.handler_summaries import (
|
|
@@ -544,7 +544,7 @@ class IngestHandler:
|
|
|
544
544
|
tile_store, layer_name, layer_cfg
|
|
545
545
|
)
|
|
546
546
|
layer_cfg = self.dataset.layers[layer_name]
|
|
547
|
-
data_source =
|
|
547
|
+
data_source = layer_cfg.instantiate_data_source(self.dataset.path)
|
|
548
548
|
|
|
549
549
|
attempts_counter = AttemptsCounter()
|
|
550
550
|
ingest_counts: IngestCounts | UnknownIngestCounts
|
|
@@ -640,7 +640,7 @@ class IngestHandler:
|
|
|
640
640
|
if not layer_cfg.data_source.ingest:
|
|
641
641
|
continue
|
|
642
642
|
|
|
643
|
-
data_source =
|
|
643
|
+
data_source = layer_cfg.instantiate_data_source(self.dataset.path)
|
|
644
644
|
|
|
645
645
|
geometries_by_item: dict = {}
|
|
646
646
|
for window, layer_datas in windows_and_layer_datas:
|
rslearn/models/clay/clay.py
CHANGED
|
@@ -8,6 +8,7 @@ from importlib.resources import files
|
|
|
8
8
|
from typing import Any
|
|
9
9
|
|
|
10
10
|
import torch
|
|
11
|
+
import torch.nn.functional as F
|
|
11
12
|
import yaml
|
|
12
13
|
from einops import rearrange
|
|
13
14
|
from huggingface_hub import hf_hub_download
|
|
@@ -30,6 +31,7 @@ PATCH_SIZE = 8
|
|
|
30
31
|
CLAY_MODALITIES = ["sentinel-2-l2a", "sentinel-1-rtc", "landsat-c2l1", "naip"]
|
|
31
32
|
CONFIG_DIR = files("rslearn.models.clay.configs")
|
|
32
33
|
CLAY_METADATA_PATH = str(CONFIG_DIR / "metadata.yaml")
|
|
34
|
+
DEFAULT_IMAGE_RESOLUTION = 128 # image resolution during pretraining
|
|
33
35
|
|
|
34
36
|
|
|
35
37
|
def get_clay_checkpoint_path(
|
|
@@ -49,6 +51,7 @@ class Clay(torch.nn.Module):
|
|
|
49
51
|
modality: str = "sentinel-2-l2a",
|
|
50
52
|
checkpoint_path: str | None = None,
|
|
51
53
|
metadata_path: str = CLAY_METADATA_PATH,
|
|
54
|
+
do_resizing: bool = False,
|
|
52
55
|
) -> None:
|
|
53
56
|
"""Initialize the Clay model.
|
|
54
57
|
|
|
@@ -57,6 +60,7 @@ class Clay(torch.nn.Module):
|
|
|
57
60
|
modality: The modality to use (subset of CLAY_MODALITIES).
|
|
58
61
|
checkpoint_path: Path to clay-v1.5.ckpt, if None, fetch from HF Hub.
|
|
59
62
|
metadata_path: Path to metadata.yaml.
|
|
63
|
+
do_resizing: Whether to resize the image to the input resolution.
|
|
60
64
|
"""
|
|
61
65
|
super().__init__()
|
|
62
66
|
|
|
@@ -95,6 +99,14 @@ class Clay(torch.nn.Module):
|
|
|
95
99
|
|
|
96
100
|
self.model_size = model_size
|
|
97
101
|
self.modality = modality
|
|
102
|
+
self.do_resizing = do_resizing
|
|
103
|
+
|
|
104
|
+
def _resize_image(self, image: torch.Tensor, original_hw: int) -> torch.Tensor:
    """Bilinearly resample ``image`` to the square size used by the Clay model.

    A single-pixel input (``original_hw == 1``) is upsampled to one patch
    (``self.patch_size``). Any other input goes to ``DEFAULT_IMAGE_RESOLUTION``, the
    pretraining resolution.

    Args:
        image: the batch of images; forward passes the stacked chips, presumably
            (B, C, H, W) — TODO confirm.
        original_hw: the input's spatial size (forward passes ``chips.shape[2]``).

    Returns:
        the resampled batch, square in its last two dimensions.
    """
    if original_hw == 1:
        side = self.patch_size
    else:
        side = DEFAULT_IMAGE_RESOLUTION
    return F.interpolate(
        image,
        size=(side, side),
        mode="bilinear",
        align_corners=False,
    )
|
|
98
110
|
|
|
99
111
|
def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
|
|
100
112
|
"""Forward pass for the Clay model.
|
|
@@ -114,7 +126,8 @@ class Clay(torch.nn.Module):
|
|
|
114
126
|
chips = torch.stack(
|
|
115
127
|
[inp[self.modality] for inp in inputs], dim=0
|
|
116
128
|
) # (B, C, H, W)
|
|
117
|
-
|
|
129
|
+
if self.do_resizing:
|
|
130
|
+
chips = self._resize_image(chips, chips.shape[2])
|
|
118
131
|
order = self.metadata[self.modality]["band_order"]
|
|
119
132
|
wavelengths = []
|
|
120
133
|
for band in self.metadata[self.modality]["band_order"]:
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
"""Concatenate feature map with features from input data."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ConcatenateFeatures(torch.nn.Module):
|
|
9
|
+
"""Concatenate feature map with additional raw data inputs."""
|
|
10
|
+
|
|
11
|
+
def __init__(
|
|
12
|
+
self,
|
|
13
|
+
key: str,
|
|
14
|
+
in_channels: int | None = None,
|
|
15
|
+
conv_channels: int = 64,
|
|
16
|
+
out_channels: int | None = None,
|
|
17
|
+
num_conv_layers: int = 1,
|
|
18
|
+
kernel_size: int = 3,
|
|
19
|
+
final_relu: bool = False,
|
|
20
|
+
):
|
|
21
|
+
"""Create a new ConcatenateFeatures.
|
|
22
|
+
|
|
23
|
+
Args:
|
|
24
|
+
key: the key of the input_dict to concatenate.
|
|
25
|
+
in_channels: number of input channels of the additional features.
|
|
26
|
+
conv_channels: number of channels of the convolutional layers.
|
|
27
|
+
out_channels: number of output channels of the additional features.
|
|
28
|
+
num_conv_layers: number of convolutional layers to apply to the additional features.
|
|
29
|
+
kernel_size: kernel size of the convolutional layers.
|
|
30
|
+
final_relu: whether to apply a ReLU activation to the final output, default False.
|
|
31
|
+
"""
|
|
32
|
+
super().__init__()
|
|
33
|
+
self.key = key
|
|
34
|
+
|
|
35
|
+
if num_conv_layers > 0:
|
|
36
|
+
if in_channels is None or out_channels is None:
|
|
37
|
+
raise ValueError(
|
|
38
|
+
"in_channels and out_channels must be specified if num_conv_layers > 0"
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
conv_layers = []
|
|
42
|
+
for i in range(num_conv_layers):
|
|
43
|
+
conv_in = in_channels if i == 0 else conv_channels
|
|
44
|
+
conv_out = out_channels if i == num_conv_layers - 1 else conv_channels
|
|
45
|
+
conv_layers.append(
|
|
46
|
+
torch.nn.Conv2d(
|
|
47
|
+
in_channels=conv_in,
|
|
48
|
+
out_channels=conv_out,
|
|
49
|
+
kernel_size=kernel_size,
|
|
50
|
+
padding="same",
|
|
51
|
+
)
|
|
52
|
+
)
|
|
53
|
+
if i < num_conv_layers - 1 or final_relu:
|
|
54
|
+
conv_layers.append(torch.nn.ReLU(inplace=True))
|
|
55
|
+
|
|
56
|
+
self.conv_layers = torch.nn.Sequential(*conv_layers)
|
|
57
|
+
|
|
58
|
+
def forward(
|
|
59
|
+
self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
|
|
60
|
+
) -> list[torch.Tensor]:
|
|
61
|
+
"""Concatenate the feature map with the raw data inputs.
|
|
62
|
+
|
|
63
|
+
Args:
|
|
64
|
+
features: list of feature maps at different resolutions.
|
|
65
|
+
inputs: original inputs.
|
|
66
|
+
|
|
67
|
+
Returns:
|
|
68
|
+
concatenated feature maps.
|
|
69
|
+
"""
|
|
70
|
+
if not features:
|
|
71
|
+
raise ValueError("Expected at least one feature map, got none.")
|
|
72
|
+
|
|
73
|
+
add_data = torch.stack([input_data[self.key] for input_data in inputs], dim=0)
|
|
74
|
+
add_features = self.conv_layers(add_data)
|
|
75
|
+
|
|
76
|
+
new_features: list[torch.Tensor] = []
|
|
77
|
+
for feature_map in features:
|
|
78
|
+
# Shape of feature map: BCHW
|
|
79
|
+
feat_h, feat_w = feature_map.shape[2], feature_map.shape[3]
|
|
80
|
+
|
|
81
|
+
resized_add_features = add_features
|
|
82
|
+
# Resize additional features to match each feature map size if needed
|
|
83
|
+
if add_features.shape[2] != feat_h or add_features.shape[3] != feat_w:
|
|
84
|
+
resized_add_features = torch.nn.functional.interpolate(
|
|
85
|
+
add_features,
|
|
86
|
+
size=(feat_h, feat_w),
|
|
87
|
+
mode="bilinear",
|
|
88
|
+
align_corners=False,
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
new_features.append(torch.cat([feature_map, resized_add_features], dim=1))
|
|
92
|
+
|
|
93
|
+
return new_features
|
rslearn/models/croma.py
CHANGED
|
@@ -7,6 +7,7 @@ from enum import Enum
|
|
|
7
7
|
from typing import Any
|
|
8
8
|
|
|
9
9
|
import torch
|
|
10
|
+
import torch.nn.functional as F
|
|
10
11
|
from einops import rearrange
|
|
11
12
|
from upath import UPath
|
|
12
13
|
|
|
@@ -99,6 +100,7 @@ class Croma(torch.nn.Module):
|
|
|
99
100
|
modality: CromaModality,
|
|
100
101
|
pretrained_path: str | None = None,
|
|
101
102
|
image_resolution: int = DEFAULT_IMAGE_RESOLUTION,
|
|
103
|
+
do_resizing: bool = False,
|
|
102
104
|
) -> None:
|
|
103
105
|
"""Instantiate a new Croma instance.
|
|
104
106
|
|
|
@@ -107,12 +109,21 @@ class Croma(torch.nn.Module):
|
|
|
107
109
|
modality: the modalities to configure the model to accept.
|
|
108
110
|
pretrained_path: the local path to the pretrained weights. Otherwise it is
|
|
109
111
|
downloaded and cached in temp directory.
|
|
110
|
-
image_resolution: the width and height of the input images.
|
|
112
|
+
image_resolution: the width and height of the input images passed to the model. if do_resizing is True, the image will be resized to this resolution.
|
|
113
|
+
do_resizing: Whether to resize the image to the input resolution.
|
|
111
114
|
"""
|
|
112
115
|
super().__init__()
|
|
113
116
|
self.size = size
|
|
114
117
|
self.modality = modality
|
|
115
|
-
self.
|
|
118
|
+
self.do_resizing = do_resizing
|
|
119
|
+
if not do_resizing:
|
|
120
|
+
self.image_resolution = image_resolution
|
|
121
|
+
else:
|
|
122
|
+
# With single pixel input, we always resample to the patch size.
|
|
123
|
+
if image_resolution == 1:
|
|
124
|
+
self.image_resolution = PATCH_SIZE
|
|
125
|
+
else:
|
|
126
|
+
self.image_resolution = DEFAULT_IMAGE_RESOLUTION
|
|
116
127
|
|
|
117
128
|
# Cache the CROMA weights to a deterministic path in temporary directory if the
|
|
118
129
|
# path is not provided by the user.
|
|
@@ -137,7 +148,16 @@ class Croma(torch.nn.Module):
|
|
|
137
148
|
pretrained_path=pretrained_path,
|
|
138
149
|
size=size.value,
|
|
139
150
|
modality=modality.value,
|
|
140
|
-
image_resolution=image_resolution,
|
|
151
|
+
image_resolution=self.image_resolution,
|
|
152
|
+
)
|
|
153
|
+
|
|
154
|
+
def _resize_image(self, image: torch.Tensor) -> torch.Tensor:
    """Bilinearly resample ``image`` to ``self.image_resolution`` on both spatial axes.

    forward calls this on the stacked Sentinel-1/Sentinel-2 inputs only when
    ``self.do_resizing`` is set.
    """
    side = self.image_resolution
    target_size = (side, side)
    return F.interpolate(image, size=target_size, mode="bilinear", align_corners=False)
|
|
142
162
|
|
|
143
163
|
def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
|
|
@@ -151,8 +171,11 @@ class Croma(torch.nn.Module):
|
|
|
151
171
|
sentinel2: torch.Tensor | None = None
|
|
152
172
|
if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL1]:
|
|
153
173
|
sentinel1 = torch.stack([inp["sentinel1"] for inp in inputs], dim=0)
|
|
174
|
+
sentinel1 = self._resize_image(sentinel1) if self.do_resizing else sentinel1
|
|
154
175
|
if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL2]:
|
|
155
176
|
sentinel2 = torch.stack([inp["sentinel2"] for inp in inputs], dim=0)
|
|
177
|
+
sentinel2 = self._resize_image(sentinel2) if self.do_resizing else sentinel2
|
|
178
|
+
|
|
156
179
|
outputs = self.model(
|
|
157
180
|
SAR_images=sentinel1,
|
|
158
181
|
optical_images=sentinel2,
|