rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/main.py
CHANGED
|
@@ -1,39 +1,56 @@
|
|
|
1
1
|
"""Entrypoint for the rslearn command-line interface."""
|
|
2
2
|
|
|
3
3
|
import argparse
|
|
4
|
-
import logging
|
|
5
4
|
import multiprocessing
|
|
6
5
|
import random
|
|
7
6
|
import sys
|
|
7
|
+
import time
|
|
8
8
|
from collections.abc import Callable
|
|
9
|
-
from datetime import datetime,
|
|
10
|
-
from
|
|
9
|
+
from datetime import UTC, datetime, timedelta
|
|
10
|
+
from typing import Any, TypeVar
|
|
11
11
|
|
|
12
12
|
import tqdm
|
|
13
|
-
import wandb
|
|
14
|
-
from lightning.pytorch.cli import LightningCLI
|
|
15
13
|
from rasterio.crs import CRS
|
|
16
14
|
from upath import UPath
|
|
17
15
|
|
|
18
16
|
from rslearn.config import LayerConfig
|
|
19
17
|
from rslearn.const import WGS84_EPSG
|
|
20
|
-
from rslearn.data_sources import Item
|
|
21
|
-
from rslearn.dataset import Dataset, Window
|
|
18
|
+
from rslearn.data_sources import Item
|
|
19
|
+
from rslearn.dataset import Dataset, Window, WindowLayerData
|
|
22
20
|
from rslearn.dataset.add_windows import add_windows_from_box, add_windows_from_file
|
|
23
|
-
from rslearn.dataset.
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
21
|
+
from rslearn.dataset.handler_summaries import (
|
|
22
|
+
ErrorOutcome,
|
|
23
|
+
IngestCounts,
|
|
24
|
+
IngestDatasetJobsSummary,
|
|
25
|
+
LayerIngestSummary,
|
|
26
|
+
MaterializeDatasetWindowsSummary,
|
|
27
|
+
PrepareDatasetWindowsSummary,
|
|
28
|
+
UnknownIngestCounts,
|
|
29
|
+
)
|
|
30
|
+
from rslearn.dataset.manage import (
|
|
31
|
+
AttemptsCounter,
|
|
32
|
+
materialize_dataset_windows,
|
|
33
|
+
prepare_dataset_windows,
|
|
34
|
+
retry,
|
|
35
|
+
)
|
|
36
|
+
from rslearn.dataset.storage.file import FileWindowStorage
|
|
37
|
+
from rslearn.log_utils import get_logger
|
|
38
|
+
from rslearn.tile_stores import get_tile_store_with_layer
|
|
27
39
|
from rslearn.utils import Projection, STGeometry
|
|
28
40
|
|
|
29
|
-
|
|
41
|
+
logger = get_logger(__name__)
|
|
42
|
+
|
|
30
43
|
handler_registry = {}
|
|
31
44
|
|
|
45
|
+
ItemType = TypeVar("ItemType", bound="Item")
|
|
46
|
+
|
|
47
|
+
MULTIPROCESSING_CONTEXT = "forkserver"
|
|
32
48
|
|
|
33
|
-
|
|
49
|
+
|
|
50
|
+
def register_handler(category: Any, command: str) -> Callable:
|
|
34
51
|
"""Register a new handler for a command."""
|
|
35
52
|
|
|
36
|
-
def decorator(f):
|
|
53
|
+
def decorator(f: Callable) -> Callable:
|
|
37
54
|
handler_registry[(category, command)] = f
|
|
38
55
|
return f
|
|
39
56
|
|
|
@@ -47,7 +64,7 @@ def parse_time(time_str: str) -> datetime:
|
|
|
47
64
|
"""
|
|
48
65
|
ts = datetime.fromisoformat(time_str)
|
|
49
66
|
if not ts.tzinfo:
|
|
50
|
-
ts = ts.replace(tzinfo=
|
|
67
|
+
ts = ts.replace(tzinfo=UTC)
|
|
51
68
|
return ts
|
|
52
69
|
|
|
53
70
|
|
|
@@ -60,8 +77,13 @@ def parse_time_range(
|
|
|
60
77
|
return (parse_time(start), parse_time(end))
|
|
61
78
|
|
|
62
79
|
|
|
80
|
+
def parse_disabled_layers(disabled_layers: str) -> list[str]:
|
|
81
|
+
"""Parse the disabled layers string."""
|
|
82
|
+
return disabled_layers.split(",") if disabled_layers else []
|
|
83
|
+
|
|
84
|
+
|
|
63
85
|
@register_handler("dataset", "add_windows")
|
|
64
|
-
def add_windows():
|
|
86
|
+
def add_windows() -> None:
|
|
65
87
|
"""Handler for the rslearn dataset add_windows command."""
|
|
66
88
|
parser = argparse.ArgumentParser(
|
|
67
89
|
prog="rslearn dataset add_windows",
|
|
@@ -156,7 +178,13 @@ def add_windows():
|
|
|
156
178
|
)
|
|
157
179
|
args = parser.parse_args(args=sys.argv[3:])
|
|
158
180
|
|
|
159
|
-
def parse_projection(
|
|
181
|
+
def parse_projection(
|
|
182
|
+
crs_str: str | None,
|
|
183
|
+
resolution: float | None,
|
|
184
|
+
x_res: float,
|
|
185
|
+
y_res: float,
|
|
186
|
+
default_crs: CRS | None = None,
|
|
187
|
+
) -> Projection | None:
|
|
160
188
|
if not crs_str:
|
|
161
189
|
if default_crs:
|
|
162
190
|
crs = default_crs
|
|
@@ -197,7 +225,8 @@ def add_windows():
|
|
|
197
225
|
box = [float(value) for value in args.box.split(",")]
|
|
198
226
|
|
|
199
227
|
windows = add_windows_from_box(
|
|
200
|
-
box
|
|
228
|
+
# TODO: we should have an object for box
|
|
229
|
+
box=box, # type: ignore
|
|
201
230
|
src_projection=parse_projection(
|
|
202
231
|
args.src_crs, args.src_resolution, args.src_x_res, args.src_y_res
|
|
203
232
|
),
|
|
@@ -210,10 +239,10 @@ def add_windows():
|
|
|
210
239
|
else:
|
|
211
240
|
raise Exception("one of box or fname must be specified")
|
|
212
241
|
|
|
213
|
-
|
|
242
|
+
logger.info(f"created {len(windows)} windows")
|
|
214
243
|
|
|
215
244
|
|
|
216
|
-
def add_apply_on_windows_args(parser: argparse.ArgumentParser):
|
|
245
|
+
def add_apply_on_windows_args(parser: argparse.ArgumentParser) -> None:
|
|
217
246
|
"""Add arguments for handlers that use the apply_on_windows helper.
|
|
218
247
|
|
|
219
248
|
Args:
|
|
@@ -223,10 +252,14 @@ def add_apply_on_windows_args(parser: argparse.ArgumentParser):
|
|
|
223
252
|
"--root", type=str, required=True, help="Dataset root directory"
|
|
224
253
|
)
|
|
225
254
|
parser.add_argument(
|
|
226
|
-
"--group",
|
|
255
|
+
"--group",
|
|
256
|
+
type=str,
|
|
257
|
+
nargs="*",
|
|
258
|
+
default=None,
|
|
259
|
+
help="Only prepare windows in these groups",
|
|
227
260
|
)
|
|
228
261
|
parser.add_argument(
|
|
229
|
-
"--window", type=str, default=None, help="Only prepare
|
|
262
|
+
"--window", type=str, nargs="*", default=None, help="Only prepare these windows"
|
|
230
263
|
)
|
|
231
264
|
parser.add_argument(
|
|
232
265
|
"--workers",
|
|
@@ -234,6 +267,12 @@ def add_apply_on_windows_args(parser: argparse.ArgumentParser):
|
|
|
234
267
|
default=0,
|
|
235
268
|
help="Number of worker processes (default 0 to use main process only)",
|
|
236
269
|
)
|
|
270
|
+
parser.add_argument(
|
|
271
|
+
"--load-workers",
|
|
272
|
+
type=int,
|
|
273
|
+
default=None,
|
|
274
|
+
help="Number of workers for loading windows (defaults to --workers)",
|
|
275
|
+
)
|
|
237
276
|
parser.add_argument(
|
|
238
277
|
"--batch-size",
|
|
239
278
|
type=int,
|
|
@@ -255,25 +294,31 @@ def add_apply_on_windows_args(parser: argparse.ArgumentParser):
|
|
|
255
294
|
|
|
256
295
|
|
|
257
296
|
def apply_on_windows(
|
|
258
|
-
f: Callable[[list[Window]],
|
|
297
|
+
f: Callable[[list[Window]], Any],
|
|
259
298
|
dataset: Dataset,
|
|
260
|
-
group: str | None = None,
|
|
261
|
-
|
|
299
|
+
group: str | list[str] | None = None,
|
|
300
|
+
names: list[str] | None = None,
|
|
262
301
|
workers: int = 0,
|
|
302
|
+
load_workers: int | None = None,
|
|
263
303
|
batch_size: int = 1,
|
|
264
304
|
jobs_per_process: int | None = None,
|
|
265
305
|
use_initial_job: bool = True,
|
|
266
|
-
):
|
|
306
|
+
) -> None:
|
|
267
307
|
"""A helper to apply a function on windows in a dataset.
|
|
268
308
|
|
|
269
309
|
Args:
|
|
270
310
|
f: the function to apply on lists of windows.
|
|
271
311
|
dataset: the dataset.
|
|
272
312
|
group: optional, only apply on windows in this group.
|
|
273
|
-
|
|
313
|
+
names: optional, only apply on windows with these names.
|
|
274
314
|
workers: the number of parallel workers to use, default 0 (main thread only).
|
|
315
|
+
load_workers: optional different number of workers to use for loading the
|
|
316
|
+
windows. If set, workers controls the number of workers to process the
|
|
317
|
+
jobs, while load_workers controls the number of workers to use for reading
|
|
318
|
+
windows from the rslearn dataset. Workers is only passed if the window
|
|
319
|
+
storage is FileWindowStorage.
|
|
275
320
|
batch_size: if workers > 0, the maximum number of windows to pass to the
|
|
276
|
-
function.
|
|
321
|
+
function.
|
|
277
322
|
jobs_per_process: optional, terminate processes after they have handled this
|
|
278
323
|
many jobs. This is useful if there is a memory leak in a dependency.
|
|
279
324
|
use_initial_job: if workers > 0, by default, an initial job is run on the first
|
|
@@ -284,30 +329,33 @@ def apply_on_windows(
|
|
|
284
329
|
if hasattr(f, "set_dataset"):
|
|
285
330
|
f.set_dataset(dataset)
|
|
286
331
|
|
|
287
|
-
groups
|
|
288
|
-
|
|
289
|
-
|
|
332
|
+
# Handle group. It can be None (load all groups) or list of groups. But it can also
|
|
333
|
+
# just be group name, in which case we must convert to list.
|
|
334
|
+
groups: list[str] | None
|
|
335
|
+
if isinstance(group, str):
|
|
290
336
|
groups = [group]
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
337
|
+
else:
|
|
338
|
+
groups = group
|
|
339
|
+
|
|
340
|
+
# Load the windows. We pass workers and show_progress if it is FileWindowStorage.
|
|
341
|
+
kwargs: dict[str, Any] = {}
|
|
342
|
+
if isinstance(dataset.storage, FileWindowStorage):
|
|
343
|
+
if load_workers is None:
|
|
344
|
+
load_workers = workers
|
|
345
|
+
kwargs["workers"] = load_workers
|
|
346
|
+
kwargs["show_progress"] = True
|
|
347
|
+
windows = dataset.load_windows(groups=groups, names=names, **kwargs)
|
|
348
|
+
logger.info(f"found {len(windows)} windows")
|
|
297
349
|
|
|
298
350
|
if hasattr(f, "get_jobs"):
|
|
299
|
-
jobs = f.get_jobs(windows,
|
|
300
|
-
|
|
351
|
+
jobs = f.get_jobs(windows, load_workers)
|
|
352
|
+
logger.info(f"got {len(jobs)} jobs")
|
|
301
353
|
else:
|
|
302
354
|
jobs = windows
|
|
303
355
|
|
|
304
|
-
if workers == 0:
|
|
305
|
-
f(jobs)
|
|
306
|
-
return
|
|
307
|
-
|
|
308
356
|
random.shuffle(jobs)
|
|
309
357
|
|
|
310
|
-
if use_initial_job:
|
|
358
|
+
if use_initial_job and len(jobs) > 0:
|
|
311
359
|
# Apply directly on first window to get any initialization out of the way.
|
|
312
360
|
f([jobs[0]])
|
|
313
361
|
jobs = jobs[1:]
|
|
@@ -316,41 +364,59 @@ def apply_on_windows(
|
|
|
316
364
|
for i in range(0, len(jobs), batch_size):
|
|
317
365
|
batches.append(jobs[i : i + batch_size])
|
|
318
366
|
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
367
|
+
num_batches = len(batches)
|
|
368
|
+
if workers == 0:
|
|
369
|
+
# Process batches sequentially but with same error handling as parallel
|
|
370
|
+
for batch in tqdm.tqdm(batches, total=num_batches):
|
|
371
|
+
f(batch)
|
|
372
|
+
else:
|
|
373
|
+
# Process batches in parallel
|
|
374
|
+
p = multiprocessing.Pool(processes=workers, maxtasksperchild=jobs_per_process)
|
|
375
|
+
outputs = p.imap_unordered(f, batches)
|
|
376
|
+
for _ in tqdm.tqdm(outputs, total=num_batches):
|
|
377
|
+
pass
|
|
378
|
+
p.close()
|
|
324
379
|
|
|
325
380
|
|
|
326
|
-
def apply_on_windows_args(f: Callable[
|
|
381
|
+
def apply_on_windows_args(f: Callable[..., Any], args: argparse.Namespace) -> None:
|
|
327
382
|
"""Call apply_on_windows with arguments passed via command-line interface."""
|
|
328
|
-
dataset = Dataset(UPath(args.root))
|
|
383
|
+
dataset = Dataset(UPath(args.root), disabled_layers=args.disabled_layers)
|
|
329
384
|
apply_on_windows(
|
|
330
|
-
f,
|
|
331
|
-
dataset,
|
|
332
|
-
args.group,
|
|
333
|
-
args.window,
|
|
334
|
-
args.workers,
|
|
335
|
-
args.
|
|
336
|
-
args.
|
|
337
|
-
args.
|
|
385
|
+
f=f,
|
|
386
|
+
dataset=dataset,
|
|
387
|
+
group=args.group,
|
|
388
|
+
names=args.window,
|
|
389
|
+
workers=args.workers,
|
|
390
|
+
load_workers=args.load_workers,
|
|
391
|
+
batch_size=args.batch_size,
|
|
392
|
+
jobs_per_process=args.jobs_per_process,
|
|
393
|
+
use_initial_job=args.use_initial_job,
|
|
338
394
|
)
|
|
339
395
|
|
|
340
396
|
|
|
341
397
|
class PrepareHandler:
|
|
342
398
|
"""apply_on_windows handler for the rslearn dataset prepare command."""
|
|
343
399
|
|
|
344
|
-
def __init__(
|
|
400
|
+
def __init__(
|
|
401
|
+
self,
|
|
402
|
+
force: bool,
|
|
403
|
+
retry_max_attempts: int = 0,
|
|
404
|
+
retry_backoff: timedelta = timedelta(minutes=1),
|
|
405
|
+
) -> None:
|
|
345
406
|
"""Initialize a new PrepareHandler.
|
|
346
407
|
|
|
347
408
|
Args:
|
|
348
409
|
force: force prepare
|
|
410
|
+
retry_max_attempts: set greater than zero to retry for this many attempts in
|
|
411
|
+
case of error.
|
|
412
|
+
retry_backoff: how long to wait before retrying (see retry).
|
|
349
413
|
"""
|
|
350
414
|
self.force = force
|
|
351
|
-
self.dataset = None
|
|
415
|
+
self.dataset: Dataset | None = None
|
|
416
|
+
self.retry_max_attempts = retry_max_attempts
|
|
417
|
+
self.retry_backoff = retry_backoff
|
|
352
418
|
|
|
353
|
-
def set_dataset(self, dataset: Dataset):
|
|
419
|
+
def set_dataset(self, dataset: Dataset) -> None:
|
|
354
420
|
"""Captures the dataset from apply_on_windows_args.
|
|
355
421
|
|
|
356
422
|
Args:
|
|
@@ -358,13 +424,22 @@ class PrepareHandler:
|
|
|
358
424
|
"""
|
|
359
425
|
self.dataset = dataset
|
|
360
426
|
|
|
361
|
-
def __call__(self, windows: list[Window]):
|
|
427
|
+
def __call__(self, windows: list[Window]) -> PrepareDatasetWindowsSummary:
|
|
362
428
|
"""Prepares the windows from apply_on_windows."""
|
|
363
|
-
|
|
429
|
+
logger.info(f"Running prepare on {len(windows)} windows")
|
|
430
|
+
if self.dataset is None:
|
|
431
|
+
raise ValueError("dataset not set")
|
|
432
|
+
return prepare_dataset_windows(
|
|
433
|
+
self.dataset,
|
|
434
|
+
windows,
|
|
435
|
+
self.force,
|
|
436
|
+
retry_max_attempts=self.retry_max_attempts,
|
|
437
|
+
retry_backoff=self.retry_backoff,
|
|
438
|
+
)
|
|
364
439
|
|
|
365
440
|
|
|
366
441
|
@register_handler("dataset", "prepare")
|
|
367
|
-
def dataset_prepare():
|
|
442
|
+
def dataset_prepare() -> None:
|
|
368
443
|
"""Handler for the rslearn dataset prepare command."""
|
|
369
444
|
parser = argparse.ArgumentParser(
|
|
370
445
|
prog="rslearn dataset prepare",
|
|
@@ -377,14 +452,38 @@ def dataset_prepare():
|
|
|
377
452
|
action=argparse.BooleanOptionalAction,
|
|
378
453
|
help="Prepare windows even if they were previously prepared",
|
|
379
454
|
)
|
|
455
|
+
parser.add_argument(
|
|
456
|
+
"--disabled-layers",
|
|
457
|
+
type=parse_disabled_layers,
|
|
458
|
+
default="",
|
|
459
|
+
help="List of layers to disable e.g 'layer1,layer2'",
|
|
460
|
+
)
|
|
461
|
+
parser.add_argument(
|
|
462
|
+
"--retry-max-attempts",
|
|
463
|
+
type=int,
|
|
464
|
+
default=0,
|
|
465
|
+
help="Retry for this many attempts",
|
|
466
|
+
)
|
|
467
|
+
parser.add_argument(
|
|
468
|
+
"--retry-backoff-seconds",
|
|
469
|
+
type=int,
|
|
470
|
+
default=0,
|
|
471
|
+
help="Backoff time (seconds) between retries",
|
|
472
|
+
)
|
|
380
473
|
add_apply_on_windows_args(parser)
|
|
381
474
|
args = parser.parse_args(args=sys.argv[3:])
|
|
382
475
|
|
|
383
|
-
fn = PrepareHandler(
|
|
476
|
+
fn = PrepareHandler(
|
|
477
|
+
args.force,
|
|
478
|
+
retry_max_attempts=args.retry_max_attempts,
|
|
479
|
+
retry_backoff=timedelta(seconds=args.retry_backoff_seconds),
|
|
480
|
+
)
|
|
384
481
|
apply_on_windows_args(fn, args)
|
|
385
482
|
|
|
386
483
|
|
|
387
|
-
def _load_window_layer_datas(
|
|
484
|
+
def _load_window_layer_datas(
|
|
485
|
+
window: Window,
|
|
486
|
+
) -> tuple[Window, dict[str, WindowLayerData]]:
|
|
388
487
|
# Helper for IngestHandler to use with multiprocessing.
|
|
389
488
|
return window, window.load_layer_datas()
|
|
390
489
|
|
|
@@ -392,11 +491,19 @@ def _load_window_layer_datas(window: Window):
|
|
|
392
491
|
class IngestHandler:
|
|
393
492
|
"""apply_on_windows handler for the rslearn dataset ingest command."""
|
|
394
493
|
|
|
395
|
-
def __init__(
|
|
494
|
+
def __init__(
|
|
495
|
+
self,
|
|
496
|
+
ignore_errors: bool = False,
|
|
497
|
+
retry_max_attempts: int = 0,
|
|
498
|
+
retry_backoff: timedelta = timedelta(minutes=1),
|
|
499
|
+
) -> None:
|
|
396
500
|
"""Initialize a new IngestHandler."""
|
|
397
|
-
self.dataset = None
|
|
501
|
+
self.dataset: Dataset | None = None
|
|
502
|
+
self.ignore_errors = ignore_errors
|
|
503
|
+
self.retry_max_attempts = retry_max_attempts
|
|
504
|
+
self.retry_backoff = retry_backoff
|
|
398
505
|
|
|
399
|
-
def set_dataset(self, dataset: Dataset):
|
|
506
|
+
def set_dataset(self, dataset: Dataset) -> None:
|
|
400
507
|
"""Captures the dataset from apply_on_windows_args.
|
|
401
508
|
|
|
402
509
|
Args:
|
|
@@ -404,21 +511,32 @@ class IngestHandler:
|
|
|
404
511
|
"""
|
|
405
512
|
self.dataset = dataset
|
|
406
513
|
|
|
407
|
-
def __call__(
|
|
514
|
+
def __call__(
|
|
515
|
+
self, jobs: list[tuple[str, LayerConfig, Item, list[STGeometry]]]
|
|
516
|
+
) -> IngestDatasetJobsSummary:
|
|
408
517
|
"""Ingest the specified items.
|
|
409
518
|
|
|
410
519
|
The items are computed from list of windows via IngestHandler.get_jobs.
|
|
411
520
|
|
|
412
521
|
Args:
|
|
413
|
-
jobs: list of (layer_name, item, geometries) tuples to ingest.
|
|
522
|
+
jobs: list of (layer_name, layer_cfg, item, geometries) tuples to ingest.
|
|
523
|
+
|
|
524
|
+
Returns:
|
|
525
|
+
summary of the ingest jobs operation fit for telemetry purposes.
|
|
414
526
|
"""
|
|
527
|
+
start_time = time.monotonic()
|
|
528
|
+
layer_summaries: list[LayerIngestSummary] = []
|
|
529
|
+
|
|
530
|
+
logger.info(f"Running ingest for {len(jobs)} jobs")
|
|
415
531
|
import gc
|
|
416
532
|
|
|
533
|
+
if self.dataset is None:
|
|
534
|
+
raise ValueError("dataset not set")
|
|
417
535
|
tile_store = self.dataset.get_tile_store()
|
|
418
536
|
|
|
419
537
|
# Group jobs by layer name.
|
|
420
|
-
jobs_by_layer = {}
|
|
421
|
-
configs_by_layer = {}
|
|
538
|
+
jobs_by_layer: dict = {}
|
|
539
|
+
configs_by_layer: dict = {}
|
|
422
540
|
for layer_name, layer_cfg, item, geometries in jobs:
|
|
423
541
|
if layer_name not in jobs_by_layer:
|
|
424
542
|
jobs_by_layer[layer_name] = []
|
|
@@ -426,24 +544,81 @@ class IngestHandler:
|
|
|
426
544
|
configs_by_layer[layer_name] = layer_cfg
|
|
427
545
|
|
|
428
546
|
for layer_name, items_and_geometries in jobs_by_layer.items():
|
|
429
|
-
|
|
547
|
+
layer_tile_store = get_tile_store_with_layer(
|
|
548
|
+
tile_store, layer_name, layer_cfg
|
|
549
|
+
)
|
|
430
550
|
layer_cfg = self.dataset.layers[layer_name]
|
|
431
|
-
data_source =
|
|
551
|
+
data_source = layer_cfg.instantiate_data_source(self.dataset.path)
|
|
432
552
|
|
|
553
|
+
attempts_counter = AttemptsCounter()
|
|
554
|
+
ingest_counts: IngestCounts | UnknownIngestCounts
|
|
433
555
|
try:
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
556
|
+
retry(
|
|
557
|
+
lambda: data_source.ingest(
|
|
558
|
+
tile_store=layer_tile_store,
|
|
559
|
+
items=[item for item, _ in items_and_geometries],
|
|
560
|
+
geometries=[
|
|
561
|
+
geometries for _, geometries in items_and_geometries
|
|
562
|
+
],
|
|
563
|
+
),
|
|
564
|
+
retry_max_attempts=self.retry_max_attempts,
|
|
565
|
+
retry_backoff=self.retry_backoff,
|
|
566
|
+
attempts_counter=attempts_counter,
|
|
567
|
+
)
|
|
568
|
+
ingest_counts = IngestCounts(
|
|
569
|
+
items_ingested=len(items_and_geometries),
|
|
570
|
+
geometries_ingested=sum(
|
|
571
|
+
len(geometries) for _, geometries in items_and_geometries
|
|
572
|
+
),
|
|
438
573
|
)
|
|
439
574
|
except Exception as e:
|
|
440
|
-
|
|
575
|
+
if not self.ignore_errors:
|
|
576
|
+
raise
|
|
577
|
+
|
|
578
|
+
ingest_counts = UnknownIngestCounts(
|
|
579
|
+
items_attempted=len(items_and_geometries),
|
|
580
|
+
geometries_attempted=sum(
|
|
581
|
+
len(geometries) for _, geometries in items_and_geometries
|
|
582
|
+
),
|
|
583
|
+
)
|
|
584
|
+
logger.error(
|
|
441
585
|
"warning: got error while ingesting "
|
|
442
586
|
+ f"{len(items_and_geometries)} items: {e}"
|
|
443
587
|
)
|
|
444
588
|
|
|
589
|
+
layer_summaries.append(
|
|
590
|
+
LayerIngestSummary(
|
|
591
|
+
layer_name=layer_name,
|
|
592
|
+
data_source_name=getattr(layer_cfg.data_source, "name", "N/A"),
|
|
593
|
+
duration_seconds=time.monotonic() - start_time,
|
|
594
|
+
ingest_counts=ingest_counts,
|
|
595
|
+
ingest_attempts=attempts_counter.value,
|
|
596
|
+
)
|
|
597
|
+
)
|
|
598
|
+
|
|
445
599
|
gc.collect()
|
|
446
600
|
|
|
601
|
+
return IngestDatasetJobsSummary(
|
|
602
|
+
duration_seconds=time.monotonic() - start_time,
|
|
603
|
+
num_jobs=len(jobs),
|
|
604
|
+
layer_summaries=layer_summaries,
|
|
605
|
+
)
|
|
606
|
+
|
|
607
|
+
def _load_layer_data_for_windows(
|
|
608
|
+
self, windows: list[Window], workers: int
|
|
609
|
+
) -> list[tuple[Window, dict[str, WindowLayerData]]]:
|
|
610
|
+
if workers == 0:
|
|
611
|
+
return [(_load_window_layer_datas(window)) for window in windows]
|
|
612
|
+
p = multiprocessing.Pool(workers)
|
|
613
|
+
outputs = p.imap_unordered(_load_window_layer_datas, windows)
|
|
614
|
+
windows_and_layer_datas = []
|
|
615
|
+
for window, layer_datas in tqdm.tqdm(
|
|
616
|
+
outputs, total=len(windows), desc="Loading window layer datas"
|
|
617
|
+
):
|
|
618
|
+
windows_and_layer_datas.append((window, layer_datas))
|
|
619
|
+
p.close()
|
|
620
|
+
return windows_and_layer_datas
|
|
621
|
+
|
|
447
622
|
def get_jobs(
|
|
448
623
|
self, windows: list[Window], workers: int
|
|
449
624
|
) -> list[tuple[str, LayerConfig, Item, list[STGeometry]]]:
|
|
@@ -455,17 +630,12 @@ class IngestHandler:
|
|
|
455
630
|
This makes sure that jobs are grouped by item rather than by window, which
|
|
456
631
|
makes sense because there's no reason to ingest the same item twice.
|
|
457
632
|
"""
|
|
633
|
+
if self.dataset is None:
|
|
634
|
+
raise ValueError("dataset not set")
|
|
458
635
|
# TODO: avoid duplicating ingest_dataset_windows...
|
|
459
636
|
|
|
460
637
|
# Load layer datas of each window.
|
|
461
|
-
|
|
462
|
-
outputs = p.imap_unordered(_load_window_layer_datas, windows)
|
|
463
|
-
windows_and_layer_datas = []
|
|
464
|
-
for window, layer_datas in tqdm.tqdm(
|
|
465
|
-
outputs, total=len(windows), desc="Loading window layer datas"
|
|
466
|
-
):
|
|
467
|
-
windows_and_layer_datas.append((window, layer_datas))
|
|
468
|
-
p.close()
|
|
638
|
+
windows_and_layer_datas = self._load_layer_data_for_windows(windows, workers)
|
|
469
639
|
|
|
470
640
|
jobs: list[tuple[str, LayerConfig, Item, list[STGeometry]]] = []
|
|
471
641
|
for layer_name, layer_cfg in self.dataset.layers.items():
|
|
@@ -474,9 +644,9 @@ class IngestHandler:
|
|
|
474
644
|
if not layer_cfg.data_source.ingest:
|
|
475
645
|
continue
|
|
476
646
|
|
|
477
|
-
data_source =
|
|
647
|
+
data_source = layer_cfg.instantiate_data_source(self.dataset.path)
|
|
478
648
|
|
|
479
|
-
geometries_by_item = {}
|
|
649
|
+
geometries_by_item: dict = {}
|
|
480
650
|
for window, layer_datas in windows_and_layer_datas:
|
|
481
651
|
if layer_name not in layer_datas:
|
|
482
652
|
continue
|
|
@@ -484,7 +654,9 @@ class IngestHandler:
|
|
|
484
654
|
layer_data = layer_datas[layer_name]
|
|
485
655
|
for group in layer_data.serialized_item_groups:
|
|
486
656
|
for serialized_item in group:
|
|
487
|
-
item = data_source.deserialize_item(
|
|
657
|
+
item = data_source.deserialize_item( # type: ignore
|
|
658
|
+
serialized_item
|
|
659
|
+
)
|
|
488
660
|
if item not in geometries_by_item:
|
|
489
661
|
geometries_by_item[item] = []
|
|
490
662
|
geometries_by_item[item].append(geometry)
|
|
@@ -492,32 +664,69 @@ class IngestHandler:
|
|
|
492
664
|
for item, geometries in geometries_by_item.items():
|
|
493
665
|
jobs.append((layer_name, layer_cfg, item, geometries))
|
|
494
666
|
|
|
495
|
-
|
|
667
|
+
logger.info(f"computed {len(jobs)} ingest jobs from {len(windows)} windows")
|
|
496
668
|
return jobs
|
|
497
669
|
|
|
498
670
|
|
|
499
671
|
@register_handler("dataset", "ingest")
|
|
500
|
-
def dataset_ingest():
|
|
672
|
+
def dataset_ingest() -> None:
|
|
501
673
|
"""Handler for the rslearn dataset ingest command."""
|
|
502
674
|
parser = argparse.ArgumentParser(
|
|
503
675
|
prog="rslearn dataset ingest",
|
|
504
676
|
description="rslearn dataset ingest: ingest items in retrieved data sources",
|
|
505
677
|
)
|
|
678
|
+
parser.add_argument(
|
|
679
|
+
"--disabled-layers",
|
|
680
|
+
type=parse_disabled_layers,
|
|
681
|
+
default="",
|
|
682
|
+
help="List of layers to disable e.g 'layer1,layer2'",
|
|
683
|
+
)
|
|
684
|
+
parser.add_argument(
|
|
685
|
+
"--ignore-errors",
|
|
686
|
+
type=bool,
|
|
687
|
+
default=False,
|
|
688
|
+
help="Ignore ingestion errors in individual jobs",
|
|
689
|
+
action=argparse.BooleanOptionalAction,
|
|
690
|
+
)
|
|
691
|
+
parser.add_argument(
|
|
692
|
+
"--retry-max-attempts",
|
|
693
|
+
type=int,
|
|
694
|
+
default=0,
|
|
695
|
+
help="Retry for this many attempts",
|
|
696
|
+
)
|
|
697
|
+
parser.add_argument(
|
|
698
|
+
"--retry-backoff-seconds",
|
|
699
|
+
type=int,
|
|
700
|
+
default=0,
|
|
701
|
+
help="Backoff time (seconds) between retries",
|
|
702
|
+
)
|
|
506
703
|
add_apply_on_windows_args(parser)
|
|
507
704
|
args = parser.parse_args(args=sys.argv[3:])
|
|
508
705
|
|
|
509
|
-
fn = IngestHandler(
|
|
706
|
+
fn = IngestHandler(
|
|
707
|
+
ignore_errors=args.ignore_errors,
|
|
708
|
+
retry_max_attempts=args.retry_max_attempts,
|
|
709
|
+
retry_backoff=timedelta(seconds=args.retry_backoff_seconds),
|
|
710
|
+
)
|
|
510
711
|
apply_on_windows_args(fn, args)
|
|
511
712
|
|
|
512
713
|
|
|
513
714
|
class MaterializeHandler:
|
|
514
715
|
"""apply_on_windows handler for the rslearn dataset materialize command."""
|
|
515
716
|
|
|
516
|
-
def __init__(
|
|
717
|
+
def __init__(
|
|
718
|
+
self,
|
|
719
|
+
ignore_errors: bool = False,
|
|
720
|
+
retry_max_attempts: int = 0,
|
|
721
|
+
retry_backoff: timedelta = timedelta(minutes=1),
|
|
722
|
+
) -> None:
|
|
517
723
|
"""Initialize a MaterializeHandler."""
|
|
518
|
-
self.dataset = None
|
|
724
|
+
self.dataset: Dataset | None = None
|
|
725
|
+
self.ignore_errors = ignore_errors
|
|
726
|
+
self.retry_max_attempts = retry_max_attempts
|
|
727
|
+
self.retry_backoff = retry_backoff
|
|
519
728
|
|
|
520
|
-
def set_dataset(self, dataset: Dataset):
|
|
729
|
+
def set_dataset(self, dataset: Dataset) -> None:
|
|
521
730
|
"""Captures the dataset from apply_on_windows_args.
|
|
522
731
|
|
|
523
732
|
Args:
|
|
@@ -525,13 +734,31 @@ class MaterializeHandler:
|
|
|
525
734
|
"""
|
|
526
735
|
self.dataset = dataset
|
|
527
736
|
|
|
528
|
-
def __call__(
|
|
737
|
+
def __call__(
|
|
738
|
+
self, windows: list[Window]
|
|
739
|
+
) -> MaterializeDatasetWindowsSummary | ErrorOutcome:
|
|
529
740
|
"""Materializes the windows from apply_on_windows."""
|
|
530
|
-
|
|
741
|
+
logger.info(f"Running Materialize with {len(windows)} windows")
|
|
742
|
+
start_time = time.monotonic()
|
|
743
|
+
if self.dataset is None:
|
|
744
|
+
raise ValueError("dataset not set")
|
|
745
|
+
try:
|
|
746
|
+
return materialize_dataset_windows(
|
|
747
|
+
self.dataset,
|
|
748
|
+
windows,
|
|
749
|
+
retry_max_attempts=self.retry_max_attempts,
|
|
750
|
+
retry_backoff=self.retry_backoff,
|
|
751
|
+
)
|
|
752
|
+
except Exception as e:
|
|
753
|
+
if not self.ignore_errors:
|
|
754
|
+
logger.error(f"Error materializing windows: {e}")
|
|
755
|
+
raise
|
|
756
|
+
logger.warning(f"Ignoring error while materializing windows: {e}")
|
|
757
|
+
return ErrorOutcome(duration_seconds=time.monotonic() - start_time)
|
|
531
758
|
|
|
532
759
|
|
|
533
760
|
@register_handler("dataset", "materialize")
|
|
534
|
-
def dataset_materialize():
|
|
761
|
+
def dataset_materialize() -> None:
|
|
535
762
|
"""Handler for the rslearn dataset materialize command."""
|
|
536
763
|
parser = argparse.ArgumentParser(
|
|
537
764
|
prog="rslearn dataset materialize",
|
|
@@ -540,110 +767,87 @@ def dataset_materialize():
|
|
|
540
767
|
+ "materialize data from retrieved data sources"
|
|
541
768
|
),
|
|
542
769
|
)
|
|
770
|
+
parser.add_argument(
|
|
771
|
+
"--disabled-layers",
|
|
772
|
+
type=parse_disabled_layers,
|
|
773
|
+
default="",
|
|
774
|
+
help="List of layers to disable e.g 'layer1,layer2'",
|
|
775
|
+
)
|
|
776
|
+
parser.add_argument(
|
|
777
|
+
"--ignore-errors",
|
|
778
|
+
type=bool,
|
|
779
|
+
default=False,
|
|
780
|
+
help="Ignore errors in individual jobs",
|
|
781
|
+
action=argparse.BooleanOptionalAction,
|
|
782
|
+
)
|
|
783
|
+
parser.add_argument(
|
|
784
|
+
"--retry-max-attempts",
|
|
785
|
+
type=int,
|
|
786
|
+
default=0,
|
|
787
|
+
help="Retry for this many attempts",
|
|
788
|
+
)
|
|
789
|
+
parser.add_argument(
|
|
790
|
+
"--retry-backoff-seconds",
|
|
791
|
+
type=int,
|
|
792
|
+
default=0,
|
|
793
|
+
help="Backoff time (seconds) between retries",
|
|
794
|
+
)
|
|
543
795
|
add_apply_on_windows_args(parser)
|
|
544
796
|
args = parser.parse_args(args=sys.argv[3:])
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
class RslearnLightningCLI(LightningCLI):
|
|
551
|
-
"""LightningCLI that links data.tasks to model.tasks."""
|
|
552
|
-
|
|
553
|
-
def add_arguments_to_parser(self, parser) -> None:
|
|
554
|
-
"""Link data.tasks to model.tasks.
|
|
555
|
-
|
|
556
|
-
Args:
|
|
557
|
-
parser: the argument parser
|
|
558
|
-
"""
|
|
559
|
-
parser.link_arguments(
|
|
560
|
-
"data.init_args.task", "model.init_args.task", apply_on="instantiate"
|
|
561
|
-
)
|
|
562
|
-
parser.add_argument(
|
|
563
|
-
"--wandb_run_id",
|
|
564
|
-
default="",
|
|
565
|
-
type=str,
|
|
566
|
-
help="W&B run ID to load checkpoint from",
|
|
567
|
-
)
|
|
568
|
-
parser.add_argument(
|
|
569
|
-
"--wandb_resume",
|
|
570
|
-
default=False,
|
|
571
|
-
type=bool,
|
|
572
|
-
help="Whether to resume from specified wandb_run_id",
|
|
573
|
-
)
|
|
574
|
-
|
|
575
|
-
def before_instantiate_classes(self):
|
|
576
|
-
"""Called before Lightning class initialization.
|
|
577
|
-
|
|
578
|
-
Sets up wandb_run_id / wandb_resume arguments.
|
|
579
|
-
"""
|
|
580
|
-
subcommand = self.config.subcommand
|
|
581
|
-
c = self.config[subcommand]
|
|
582
|
-
|
|
583
|
-
if c.wandb_run_id:
|
|
584
|
-
api = wandb.Api()
|
|
585
|
-
artifact_id = (
|
|
586
|
-
f"{c.trainer.logger.init_args.project}/model-{c.wandb_run_id}:latest"
|
|
587
|
-
)
|
|
588
|
-
print(f"restoring from artifact {artifact_id} on wandb")
|
|
589
|
-
artifact = api.artifact(artifact_id, type="model")
|
|
590
|
-
artifact_dir = artifact.download()
|
|
591
|
-
c.ckpt_path = str(Path(artifact_dir) / "model.ckpt")
|
|
592
|
-
|
|
593
|
-
if c.wandb_resume:
|
|
594
|
-
c.trainer.logger.init_args.id = c.wandb_run_id
|
|
595
|
-
|
|
596
|
-
# If there is a RslearnPredictionWriter, set its path.
|
|
597
|
-
prediction_writer_callback = None
|
|
598
|
-
if "callbacks" in c.trainer:
|
|
599
|
-
for existing_callback in c.trainer.callbacks:
|
|
600
|
-
if (
|
|
601
|
-
existing_callback.class_path
|
|
602
|
-
== "rslearn.train.prediction_writer.RslearnWriter"
|
|
603
|
-
):
|
|
604
|
-
prediction_writer_callback = existing_callback
|
|
605
|
-
if prediction_writer_callback:
|
|
606
|
-
prediction_writer_callback.init_args.path = c.data.init_args.path
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
def model_handler():
|
|
610
|
-
"""Handler for any rslearn model X commands."""
|
|
611
|
-
RslearnLightningCLI(
|
|
612
|
-
model_class=RslearnLightningModule,
|
|
613
|
-
datamodule_class=RslearnDataModule,
|
|
614
|
-
args=sys.argv[2:],
|
|
615
|
-
subclass_mode_model=True,
|
|
616
|
-
subclass_mode_data=True,
|
|
617
|
-
save_config_kwargs={"overwrite": True},
|
|
797
|
+
fn = MaterializeHandler(
|
|
798
|
+
ignore_errors=args.ignore_errors,
|
|
799
|
+
retry_max_attempts=args.retry_max_attempts,
|
|
800
|
+
retry_backoff=timedelta(seconds=args.retry_backoff_seconds),
|
|
618
801
|
)
|
|
802
|
+
apply_on_windows_args(fn, args)
|
|
619
803
|
|
|
620
804
|
|
|
621
805
|
@register_handler("model", "fit")
|
|
622
|
-
def model_fit():
|
|
806
|
+
def model_fit() -> None:
|
|
623
807
|
"""Handler for rslearn model fit."""
|
|
808
|
+
from .lightning_cli import model_handler
|
|
809
|
+
|
|
624
810
|
model_handler()
|
|
625
811
|
|
|
626
812
|
|
|
627
813
|
@register_handler("model", "validate")
|
|
628
|
-
def model_validate():
|
|
814
|
+
def model_validate() -> None:
|
|
629
815
|
"""Handler for rslearn model validate."""
|
|
816
|
+
from .lightning_cli import model_handler
|
|
817
|
+
|
|
630
818
|
model_handler()
|
|
631
819
|
|
|
632
820
|
|
|
633
821
|
@register_handler("model", "test")
|
|
634
|
-
def model_test():
|
|
822
|
+
def model_test() -> None:
|
|
635
823
|
"""Handler for rslearn model test."""
|
|
824
|
+
from .lightning_cli import model_handler
|
|
825
|
+
|
|
636
826
|
model_handler()
|
|
637
827
|
|
|
638
828
|
|
|
639
829
|
@register_handler("model", "predict")
|
|
640
|
-
def model_predict():
|
|
830
|
+
def model_predict() -> None:
|
|
641
831
|
"""Handler for rslearn model predict."""
|
|
832
|
+
from .lightning_cli import model_handler
|
|
833
|
+
|
|
642
834
|
model_handler()
|
|
643
835
|
|
|
644
836
|
|
|
645
|
-
def main():
|
|
837
|
+
def main() -> None:
|
|
646
838
|
"""CLI entrypoint."""
|
|
839
|
+
try:
|
|
840
|
+
multiprocessing.set_start_method(MULTIPROCESSING_CONTEXT)
|
|
841
|
+
except RuntimeError as e:
|
|
842
|
+
logger.error(
|
|
843
|
+
f"Multiprocessing context already set to {multiprocessing.get_context()}: "
|
|
844
|
+
+ f"ignoring {e}"
|
|
845
|
+
)
|
|
846
|
+
except Exception as e:
|
|
847
|
+
logger.error(f"Failed to set multiprocessing context: {e}")
|
|
848
|
+
raise
|
|
849
|
+
finally:
|
|
850
|
+
logger.info(f"Using multiprocessing context: {multiprocessing.get_context()}")
|
|
647
851
|
parser = argparse.ArgumentParser(description="rslearn")
|
|
648
852
|
parser.add_argument(
|
|
649
853
|
"category", help="Command category: dataset, annotate, or model"
|
|
@@ -653,12 +857,11 @@ def main():
|
|
|
653
857
|
|
|
654
858
|
handler = handler_registry.get((args.category, args.command))
|
|
655
859
|
if handler is None:
|
|
656
|
-
|
|
860
|
+
logger.error(f"Unknown command: {args.category} {args.command}")
|
|
657
861
|
sys.exit(1)
|
|
658
862
|
|
|
659
863
|
handler()
|
|
660
864
|
|
|
661
865
|
|
|
662
866
|
if __name__ == "__main__":
|
|
663
|
-
multiprocessing.set_start_method("forkserver")
|
|
664
867
|
main()
|