rslearn 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/dataset.py +22 -13
- rslearn/data_sources/__init__.py +8 -0
- rslearn/data_sources/aws_landsat.py +27 -18
- rslearn/data_sources/aws_open_data.py +41 -42
- rslearn/data_sources/copernicus.py +148 -2
- rslearn/data_sources/data_source.py +17 -10
- rslearn/data_sources/gcp_public_data.py +177 -100
- rslearn/data_sources/geotiff.py +1 -0
- rslearn/data_sources/google_earth_engine.py +17 -15
- rslearn/data_sources/local_files.py +59 -32
- rslearn/data_sources/openstreetmap.py +27 -23
- rslearn/data_sources/planet.py +10 -9
- rslearn/data_sources/planet_basemap.py +303 -0
- rslearn/data_sources/raster_source.py +23 -13
- rslearn/data_sources/usgs_landsat.py +56 -27
- rslearn/data_sources/utils.py +13 -6
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/xyz_tiles.py +8 -9
- rslearn/dataset/add_windows.py +1 -1
- rslearn/dataset/dataset.py +16 -5
- rslearn/dataset/manage.py +9 -4
- rslearn/dataset/materialize.py +26 -5
- rslearn/dataset/window.py +5 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +123 -59
- rslearn/models/clip.py +62 -0
- rslearn/models/conv.py +56 -0
- rslearn/models/faster_rcnn.py +2 -19
- rslearn/models/fpn.py +1 -1
- rslearn/models/module_wrapper.py +43 -0
- rslearn/models/molmo.py +65 -0
- rslearn/models/multitask.py +1 -1
- rslearn/models/pooling_decoder.py +4 -2
- rslearn/models/satlaspretrain.py +4 -7
- rslearn/models/simple_time_series.py +61 -55
- rslearn/models/ssl4eo_s12.py +9 -9
- rslearn/models/swin.py +22 -21
- rslearn/models/unet.py +4 -2
- rslearn/models/upsample.py +35 -0
- rslearn/tile_stores/file.py +6 -3
- rslearn/tile_stores/tile_store.py +19 -7
- rslearn/train/callbacks/freeze_unfreeze.py +3 -3
- rslearn/train/data_module.py +5 -4
- rslearn/train/dataset.py +79 -36
- rslearn/train/lightning_module.py +15 -11
- rslearn/train/prediction_writer.py +22 -11
- rslearn/train/tasks/classification.py +9 -8
- rslearn/train/tasks/detection.py +94 -37
- rslearn/train/tasks/multi_task.py +1 -1
- rslearn/train/tasks/regression.py +8 -4
- rslearn/train/tasks/segmentation.py +23 -19
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +6 -2
- rslearn/train/transforms/crop.py +6 -2
- rslearn/train/transforms/flip.py +5 -1
- rslearn/train/transforms/normalize.py +9 -5
- rslearn/train/transforms/pad.py +1 -1
- rslearn/train/transforms/transform.py +3 -3
- rslearn/utils/__init__.py +4 -5
- rslearn/utils/array.py +2 -2
- rslearn/utils/feature.py +1 -1
- rslearn/utils/fsspec.py +70 -1
- rslearn/utils/geometry.py +155 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +81 -73
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/utils.py +11 -3
- rslearn/utils/vector_format.py +113 -17
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/METADATA +32 -27
- rslearn-0.0.2.dist-info/RECORD +94 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/WHEEL +1 -1
- rslearn/utils/mgrs.py +0 -24
- rslearn-0.0.1.dist-info/RECORD +0 -88
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/top_level.txt +0 -0
rslearn/main.py
CHANGED
|
@@ -1,39 +1,43 @@
|
|
|
1
1
|
"""Entrypoint for the rslearn command-line interface."""
|
|
2
2
|
|
|
3
3
|
import argparse
|
|
4
|
-
import logging
|
|
5
4
|
import multiprocessing
|
|
6
5
|
import random
|
|
7
6
|
import sys
|
|
8
7
|
from collections.abc import Callable
|
|
9
8
|
from datetime import datetime, timezone
|
|
10
9
|
from pathlib import Path
|
|
10
|
+
from typing import Any, TypeVar
|
|
11
11
|
|
|
12
12
|
import tqdm
|
|
13
13
|
import wandb
|
|
14
|
-
from lightning.pytorch.cli import LightningCLI
|
|
14
|
+
from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
|
|
15
15
|
from rasterio.crs import CRS
|
|
16
16
|
from upath import UPath
|
|
17
17
|
|
|
18
18
|
from rslearn.config import LayerConfig
|
|
19
19
|
from rslearn.const import WGS84_EPSG
|
|
20
20
|
from rslearn.data_sources import Item, data_source_from_config
|
|
21
|
-
from rslearn.dataset import Dataset, Window
|
|
21
|
+
from rslearn.dataset import Dataset, Window, WindowLayerData
|
|
22
22
|
from rslearn.dataset.add_windows import add_windows_from_box, add_windows_from_file
|
|
23
23
|
from rslearn.dataset.manage import materialize_dataset_windows, prepare_dataset_windows
|
|
24
|
+
from rslearn.log_utils import get_logger
|
|
24
25
|
from rslearn.tile_stores import get_tile_store_for_layer
|
|
25
26
|
from rslearn.train.data_module import RslearnDataModule
|
|
26
27
|
from rslearn.train.lightning_module import RslearnLightningModule
|
|
27
|
-
from rslearn.utils import Projection, STGeometry
|
|
28
|
+
from rslearn.utils import Projection, STGeometry, parse_disabled_layers
|
|
29
|
+
|
|
30
|
+
logger = get_logger(__name__)
|
|
28
31
|
|
|
29
|
-
logging.basicConfig()
|
|
30
32
|
handler_registry = {}
|
|
31
33
|
|
|
34
|
+
ItemType = TypeVar("ItemType", bound="Item")
|
|
35
|
+
|
|
32
36
|
|
|
33
|
-
def register_handler(category, command):
|
|
37
|
+
def register_handler(category: Any, command: str) -> Callable:
|
|
34
38
|
"""Register a new handler for a command."""
|
|
35
39
|
|
|
36
|
-
def decorator(f):
|
|
40
|
+
def decorator(f: Callable) -> Callable:
|
|
37
41
|
handler_registry[(category, command)] = f
|
|
38
42
|
return f
|
|
39
43
|
|
|
@@ -61,7 +65,7 @@ def parse_time_range(
|
|
|
61
65
|
|
|
62
66
|
|
|
63
67
|
@register_handler("dataset", "add_windows")
|
|
64
|
-
def add_windows():
|
|
68
|
+
def add_windows() -> None:
|
|
65
69
|
"""Handler for the rslearn dataset add_windows command."""
|
|
66
70
|
parser = argparse.ArgumentParser(
|
|
67
71
|
prog="rslearn dataset add_windows",
|
|
@@ -156,7 +160,13 @@ def add_windows():
|
|
|
156
160
|
)
|
|
157
161
|
args = parser.parse_args(args=sys.argv[3:])
|
|
158
162
|
|
|
159
|
-
def parse_projection(
|
|
163
|
+
def parse_projection(
|
|
164
|
+
crs_str: str | None,
|
|
165
|
+
resolution: float | None,
|
|
166
|
+
x_res: float,
|
|
167
|
+
y_res: float,
|
|
168
|
+
default_crs: CRS | None = None,
|
|
169
|
+
) -> Projection | None:
|
|
160
170
|
if not crs_str:
|
|
161
171
|
if default_crs:
|
|
162
172
|
crs = default_crs
|
|
@@ -197,7 +207,8 @@ def add_windows():
|
|
|
197
207
|
box = [float(value) for value in args.box.split(",")]
|
|
198
208
|
|
|
199
209
|
windows = add_windows_from_box(
|
|
200
|
-
box
|
|
210
|
+
# TODO: we should have an object for box
|
|
211
|
+
box=box, # type: ignore
|
|
201
212
|
src_projection=parse_projection(
|
|
202
213
|
args.src_crs, args.src_resolution, args.src_x_res, args.src_y_res
|
|
203
214
|
),
|
|
@@ -210,10 +221,10 @@ def add_windows():
|
|
|
210
221
|
else:
|
|
211
222
|
raise Exception("one of box or fname must be specified")
|
|
212
223
|
|
|
213
|
-
|
|
224
|
+
logger.info(f"created {len(windows)} windows")
|
|
214
225
|
|
|
215
226
|
|
|
216
|
-
def add_apply_on_windows_args(parser: argparse.ArgumentParser):
|
|
227
|
+
def add_apply_on_windows_args(parser: argparse.ArgumentParser) -> None:
|
|
217
228
|
"""Add arguments for handlers that use the apply_on_windows helper.
|
|
218
229
|
|
|
219
230
|
Args:
|
|
@@ -263,7 +274,7 @@ def apply_on_windows(
|
|
|
263
274
|
batch_size: int = 1,
|
|
264
275
|
jobs_per_process: int | None = None,
|
|
265
276
|
use_initial_job: bool = True,
|
|
266
|
-
):
|
|
277
|
+
) -> None:
|
|
267
278
|
"""A helper to apply a function on windows in a dataset.
|
|
268
279
|
|
|
269
280
|
Args:
|
|
@@ -293,11 +304,11 @@ def apply_on_windows(
|
|
|
293
304
|
windows = dataset.load_windows(
|
|
294
305
|
groups=groups, names=names, workers=workers, show_progress=True
|
|
295
306
|
)
|
|
296
|
-
|
|
307
|
+
logger.info(f"found {len(windows)} windows")
|
|
297
308
|
|
|
298
309
|
if hasattr(f, "get_jobs"):
|
|
299
310
|
jobs = f.get_jobs(windows, workers)
|
|
300
|
-
|
|
311
|
+
logger.info(f"got {len(jobs)} jobs")
|
|
301
312
|
else:
|
|
302
313
|
jobs = windows
|
|
303
314
|
|
|
@@ -323,9 +334,9 @@ def apply_on_windows(
|
|
|
323
334
|
p.close()
|
|
324
335
|
|
|
325
336
|
|
|
326
|
-
def apply_on_windows_args(f: Callable[
|
|
337
|
+
def apply_on_windows_args(f: Callable[..., None], args: argparse.Namespace) -> None:
|
|
327
338
|
"""Call apply_on_windows with arguments passed via command-line interface."""
|
|
328
|
-
dataset = Dataset(UPath(args.root))
|
|
339
|
+
dataset = Dataset(UPath(args.root), args.disabled_layers)
|
|
329
340
|
apply_on_windows(
|
|
330
341
|
f,
|
|
331
342
|
dataset,
|
|
@@ -341,16 +352,16 @@ def apply_on_windows_args(f: Callable[[list[Window]], None], args: argparse.Name
|
|
|
341
352
|
class PrepareHandler:
|
|
342
353
|
"""apply_on_windows handler for the rslearn dataset prepare command."""
|
|
343
354
|
|
|
344
|
-
def __init__(self, force: bool):
|
|
355
|
+
def __init__(self, force: bool) -> None:
|
|
345
356
|
"""Initialize a new PrepareHandler.
|
|
346
357
|
|
|
347
358
|
Args:
|
|
348
359
|
force: force prepare
|
|
349
360
|
"""
|
|
350
361
|
self.force = force
|
|
351
|
-
self.dataset = None
|
|
362
|
+
self.dataset: Dataset | None = None
|
|
352
363
|
|
|
353
|
-
def set_dataset(self, dataset: Dataset):
|
|
364
|
+
def set_dataset(self, dataset: Dataset) -> None:
|
|
354
365
|
"""Captures the dataset from apply_on_windows_args.
|
|
355
366
|
|
|
356
367
|
Args:
|
|
@@ -358,13 +369,16 @@ class PrepareHandler:
|
|
|
358
369
|
"""
|
|
359
370
|
self.dataset = dataset
|
|
360
371
|
|
|
361
|
-
def __call__(self, windows: list[Window]):
|
|
372
|
+
def __call__(self, windows: list[Window]) -> None:
|
|
362
373
|
"""Prepares the windows from apply_on_windows."""
|
|
374
|
+
logger.info(f"Running prepare on {len(windows)} windows")
|
|
375
|
+
if self.dataset is None:
|
|
376
|
+
raise ValueError("dataset not set")
|
|
363
377
|
prepare_dataset_windows(self.dataset, windows, self.force)
|
|
364
378
|
|
|
365
379
|
|
|
366
380
|
@register_handler("dataset", "prepare")
|
|
367
|
-
def dataset_prepare():
|
|
381
|
+
def dataset_prepare() -> None:
|
|
368
382
|
"""Handler for the rslearn dataset prepare command."""
|
|
369
383
|
parser = argparse.ArgumentParser(
|
|
370
384
|
prog="rslearn dataset prepare",
|
|
@@ -377,6 +391,12 @@ def dataset_prepare():
|
|
|
377
391
|
action=argparse.BooleanOptionalAction,
|
|
378
392
|
help="Prepare windows even if they were previously prepared",
|
|
379
393
|
)
|
|
394
|
+
parser.add_argument(
|
|
395
|
+
"--disabled-layers",
|
|
396
|
+
type=parse_disabled_layers,
|
|
397
|
+
default="",
|
|
398
|
+
help="List of layers to disable e.g 'layer1,layer2'",
|
|
399
|
+
)
|
|
380
400
|
add_apply_on_windows_args(parser)
|
|
381
401
|
args = parser.parse_args(args=sys.argv[3:])
|
|
382
402
|
|
|
@@ -384,7 +404,9 @@ def dataset_prepare():
|
|
|
384
404
|
apply_on_windows_args(fn, args)
|
|
385
405
|
|
|
386
406
|
|
|
387
|
-
def _load_window_layer_datas(
|
|
407
|
+
def _load_window_layer_datas(
|
|
408
|
+
window: Window,
|
|
409
|
+
) -> tuple[Window, dict[str, WindowLayerData]]:
|
|
388
410
|
# Helper for IngestHandler to use with multiprocessing.
|
|
389
411
|
return window, window.load_layer_datas()
|
|
390
412
|
|
|
@@ -392,11 +414,12 @@ def _load_window_layer_datas(window: Window):
|
|
|
392
414
|
class IngestHandler:
|
|
393
415
|
"""apply_on_windows handler for the rslearn dataset ingest command."""
|
|
394
416
|
|
|
395
|
-
def __init__(self):
|
|
417
|
+
def __init__(self, ignore_errors: bool = False) -> None:
|
|
396
418
|
"""Initialize a new IngestHandler."""
|
|
397
|
-
self.dataset = None
|
|
419
|
+
self.dataset: Dataset | None = None
|
|
420
|
+
self.ignore_errors = ignore_errors
|
|
398
421
|
|
|
399
|
-
def set_dataset(self, dataset: Dataset):
|
|
422
|
+
def set_dataset(self, dataset: Dataset) -> None:
|
|
400
423
|
"""Captures the dataset from apply_on_windows_args.
|
|
401
424
|
|
|
402
425
|
Args:
|
|
@@ -404,7 +427,9 @@ class IngestHandler:
|
|
|
404
427
|
"""
|
|
405
428
|
self.dataset = dataset
|
|
406
429
|
|
|
407
|
-
def __call__(
|
|
430
|
+
def __call__(
|
|
431
|
+
self, jobs: list[tuple[str, LayerConfig, Item, list[STGeometry]]]
|
|
432
|
+
) -> None:
|
|
408
433
|
"""Ingest the specified items.
|
|
409
434
|
|
|
410
435
|
The items are computed from list of windows via IngestHandler.get_jobs.
|
|
@@ -412,13 +437,16 @@ class IngestHandler:
|
|
|
412
437
|
Args:
|
|
413
438
|
jobs: list of (layer_name, item, geometries) tuples to ingest.
|
|
414
439
|
"""
|
|
440
|
+
logger.info(f"Running ingest for {len(jobs)} jobs")
|
|
415
441
|
import gc
|
|
416
442
|
|
|
443
|
+
if self.dataset is None:
|
|
444
|
+
raise ValueError("dataset not set")
|
|
417
445
|
tile_store = self.dataset.get_tile_store()
|
|
418
446
|
|
|
419
447
|
# Group jobs by layer name.
|
|
420
|
-
jobs_by_layer = {}
|
|
421
|
-
configs_by_layer = {}
|
|
448
|
+
jobs_by_layer: dict = {}
|
|
449
|
+
configs_by_layer: dict = {}
|
|
422
450
|
for layer_name, layer_cfg, item, geometries in jobs:
|
|
423
451
|
if layer_name not in jobs_by_layer:
|
|
424
452
|
jobs_by_layer[layer_name] = []
|
|
@@ -437,13 +465,31 @@ class IngestHandler:
|
|
|
437
465
|
geometries=[geometries for _, geometries in items_and_geometries],
|
|
438
466
|
)
|
|
439
467
|
except Exception as e:
|
|
440
|
-
|
|
468
|
+
if not self.ignore_errors:
|
|
469
|
+
raise
|
|
470
|
+
|
|
471
|
+
logger.error(
|
|
441
472
|
"warning: got error while ingesting "
|
|
442
473
|
+ f"{len(items_and_geometries)} items: {e}"
|
|
443
474
|
)
|
|
444
475
|
|
|
445
476
|
gc.collect()
|
|
446
477
|
|
|
478
|
+
def _load_layer_data_for_windows(
    self, windows: list[Window], workers: int
) -> list[tuple[Window, dict[str, WindowLayerData]]]:
    """Load the layer datas of every window, optionally in parallel.

    Args:
        windows: the windows whose layer datas should be loaded.
        workers: number of worker processes to use; 0 loads everything serially
            in the current process.

    Returns:
        list of (window, layer_datas) pairs. With workers > 0 the results come
        from imap_unordered, so their order follows completion order rather than
        the order of ``windows``.
    """
    if workers == 0:
        return [_load_window_layer_datas(window) for window in windows]
    # Using the pool as a context manager guarantees it is shut down even if a
    # worker raises mid-iteration, so no worker processes are leaked. All results
    # are consumed (list(...)) before the with-block exits.
    with multiprocessing.Pool(workers) as p:
        outputs = p.imap_unordered(_load_window_layer_datas, windows)
        return list(
            tqdm.tqdm(outputs, total=len(windows), desc="Loading window layer datas")
        )
|
|
492
|
+
|
|
447
493
|
def get_jobs(
|
|
448
494
|
self, windows: list[Window], workers: int
|
|
449
495
|
) -> list[tuple[str, LayerConfig, Item, list[STGeometry]]]:
|
|
@@ -455,17 +501,12 @@ class IngestHandler:
|
|
|
455
501
|
This makes sure that jobs are grouped by item rather than by window, which
|
|
456
502
|
makes sense because there's no reason to ingest the same item twice.
|
|
457
503
|
"""
|
|
504
|
+
if self.dataset is None:
|
|
505
|
+
raise ValueError("dataset not set")
|
|
458
506
|
# TODO: avoid duplicating ingest_dataset_windows...
|
|
459
507
|
|
|
460
508
|
# Load layer datas of each window.
|
|
461
|
-
|
|
462
|
-
outputs = p.imap_unordered(_load_window_layer_datas, windows)
|
|
463
|
-
windows_and_layer_datas = []
|
|
464
|
-
for window, layer_datas in tqdm.tqdm(
|
|
465
|
-
outputs, total=len(windows), desc="Loading window layer datas"
|
|
466
|
-
):
|
|
467
|
-
windows_and_layer_datas.append((window, layer_datas))
|
|
468
|
-
p.close()
|
|
509
|
+
windows_and_layer_datas = self._load_layer_data_for_windows(windows, workers)
|
|
469
510
|
|
|
470
511
|
jobs: list[tuple[str, LayerConfig, Item, list[STGeometry]]] = []
|
|
471
512
|
for layer_name, layer_cfg in self.dataset.layers.items():
|
|
@@ -476,7 +517,7 @@ class IngestHandler:
|
|
|
476
517
|
|
|
477
518
|
data_source = data_source_from_config(layer_cfg, self.dataset.path)
|
|
478
519
|
|
|
479
|
-
geometries_by_item = {}
|
|
520
|
+
geometries_by_item: dict = {}
|
|
480
521
|
for window, layer_datas in windows_and_layer_datas:
|
|
481
522
|
if layer_name not in layer_datas:
|
|
482
523
|
continue
|
|
@@ -484,7 +525,9 @@ class IngestHandler:
|
|
|
484
525
|
layer_data = layer_datas[layer_name]
|
|
485
526
|
for group in layer_data.serialized_item_groups:
|
|
486
527
|
for serialized_item in group:
|
|
487
|
-
item = data_source.deserialize_item(
|
|
528
|
+
item = data_source.deserialize_item( # type: ignore
|
|
529
|
+
serialized_item
|
|
530
|
+
)
|
|
488
531
|
if item not in geometries_by_item:
|
|
489
532
|
geometries_by_item[item] = []
|
|
490
533
|
geometries_by_item[item].append(geometry)
|
|
@@ -492,32 +535,45 @@ class IngestHandler:
|
|
|
492
535
|
for item, geometries in geometries_by_item.items():
|
|
493
536
|
jobs.append((layer_name, layer_cfg, item, geometries))
|
|
494
537
|
|
|
495
|
-
|
|
538
|
+
logger.info(f"computed {len(jobs)} ingest jobs from {len(windows)} windows")
|
|
496
539
|
return jobs
|
|
497
540
|
|
|
498
541
|
|
|
499
542
|
@register_handler("dataset", "ingest")
|
|
500
|
-
def dataset_ingest():
|
|
543
|
+
def dataset_ingest() -> None:
|
|
501
544
|
"""Handler for the rslearn dataset ingest command."""
|
|
502
545
|
parser = argparse.ArgumentParser(
|
|
503
546
|
prog="rslearn dataset ingest",
|
|
504
547
|
description="rslearn dataset ingest: ingest items in retrieved data sources",
|
|
505
548
|
)
|
|
549
|
+
parser.add_argument(
|
|
550
|
+
"--disabled-layers",
|
|
551
|
+
type=parse_disabled_layers,
|
|
552
|
+
default="",
|
|
553
|
+
help="List of layers to disable e.g 'layer1,layer2'",
|
|
554
|
+
)
|
|
555
|
+
parser.add_argument(
|
|
556
|
+
"--ignore-errors",
|
|
557
|
+
type=bool,
|
|
558
|
+
default=False,
|
|
559
|
+
help="Ignore ingestion errors in individual jobs",
|
|
560
|
+
action=argparse.BooleanOptionalAction,
|
|
561
|
+
)
|
|
506
562
|
add_apply_on_windows_args(parser)
|
|
507
563
|
args = parser.parse_args(args=sys.argv[3:])
|
|
508
564
|
|
|
509
|
-
fn = IngestHandler()
|
|
565
|
+
fn = IngestHandler(ignore_errors=args.ignore_errors)
|
|
510
566
|
apply_on_windows_args(fn, args)
|
|
511
567
|
|
|
512
568
|
|
|
513
569
|
class MaterializeHandler:
|
|
514
570
|
"""apply_on_windows handler for the rslearn dataset materialize command."""
|
|
515
571
|
|
|
516
|
-
def __init__(self):
|
|
572
|
+
def __init__(self) -> None:
|
|
517
573
|
"""Initialize a MaterializeHandler."""
|
|
518
|
-
self.dataset = None
|
|
574
|
+
self.dataset: Dataset | None = None
|
|
519
575
|
|
|
520
|
-
def set_dataset(self, dataset: Dataset):
|
|
576
|
+
def set_dataset(self, dataset: Dataset) -> None:
|
|
521
577
|
"""Captures the dataset from apply_on_windows_args.
|
|
522
578
|
|
|
523
579
|
Args:
|
|
@@ -525,13 +581,16 @@ class MaterializeHandler:
|
|
|
525
581
|
"""
|
|
526
582
|
self.dataset = dataset
|
|
527
583
|
|
|
528
|
-
def __call__(self, windows: list[Window]):
|
|
584
|
+
def __call__(self, windows: list[Window]) -> None:
|
|
529
585
|
"""Materializes the windows from apply_on_windows."""
|
|
586
|
+
logger.info(f"Running Materialize with {len(windows)} windows")
|
|
587
|
+
if self.dataset is None:
|
|
588
|
+
raise ValueError("dataset not set")
|
|
530
589
|
materialize_dataset_windows(self.dataset, windows)
|
|
531
590
|
|
|
532
591
|
|
|
533
592
|
@register_handler("dataset", "materialize")
|
|
534
|
-
def dataset_materialize():
|
|
593
|
+
def dataset_materialize() -> None:
|
|
535
594
|
"""Handler for the rslearn dataset materialize command."""
|
|
536
595
|
parser = argparse.ArgumentParser(
|
|
537
596
|
prog="rslearn dataset materialize",
|
|
@@ -540,9 +599,14 @@ def dataset_materialize():
|
|
|
540
599
|
+ "materialize data from retrieved data sources"
|
|
541
600
|
),
|
|
542
601
|
)
|
|
602
|
+
parser.add_argument(
|
|
603
|
+
"--disabled-layers",
|
|
604
|
+
type=parse_disabled_layers,
|
|
605
|
+
default="",
|
|
606
|
+
help="List of layers to disable e.g 'layer1,layer2'",
|
|
607
|
+
)
|
|
543
608
|
add_apply_on_windows_args(parser)
|
|
544
609
|
args = parser.parse_args(args=sys.argv[3:])
|
|
545
|
-
|
|
546
610
|
fn = MaterializeHandler()
|
|
547
611
|
apply_on_windows_args(fn, args)
|
|
548
612
|
|
|
@@ -550,7 +614,7 @@ def dataset_materialize():
|
|
|
550
614
|
class RslearnLightningCLI(LightningCLI):
|
|
551
615
|
"""LightningCLI that links data.tasks to model.tasks."""
|
|
552
616
|
|
|
553
|
-
def add_arguments_to_parser(self, parser) -> None:
|
|
617
|
+
def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
|
|
554
618
|
"""Link data.tasks to model.tasks.
|
|
555
619
|
|
|
556
620
|
Args:
|
|
@@ -572,7 +636,7 @@ class RslearnLightningCLI(LightningCLI):
|
|
|
572
636
|
help="Whether to resume from specified wandb_run_id",
|
|
573
637
|
)
|
|
574
638
|
|
|
575
|
-
def before_instantiate_classes(self):
|
|
639
|
+
def before_instantiate_classes(self) -> None:
|
|
576
640
|
"""Called before Lightning class initialization.
|
|
577
641
|
|
|
578
642
|
Sets up wandb_run_id / wandb_resume arguments.
|
|
@@ -585,7 +649,7 @@ class RslearnLightningCLI(LightningCLI):
|
|
|
585
649
|
artifact_id = (
|
|
586
650
|
f"{c.trainer.logger.init_args.project}/model-{c.wandb_run_id}:latest"
|
|
587
651
|
)
|
|
588
|
-
|
|
652
|
+
logger.info(f"restoring from artifact {artifact_id} on wandb")
|
|
589
653
|
artifact = api.artifact(artifact_id, type="model")
|
|
590
654
|
artifact_dir = artifact.download()
|
|
591
655
|
c.ckpt_path = str(Path(artifact_dir) / "model.ckpt")
|
|
@@ -606,7 +670,7 @@ class RslearnLightningCLI(LightningCLI):
|
|
|
606
670
|
prediction_writer_callback.init_args.path = c.data.init_args.path
|
|
607
671
|
|
|
608
672
|
|
|
609
|
-
def model_handler():
|
|
673
|
+
def model_handler() -> None:
|
|
610
674
|
"""Handler for any rslearn model X commands."""
|
|
611
675
|
RslearnLightningCLI(
|
|
612
676
|
model_class=RslearnLightningModule,
|
|
@@ -619,30 +683,30 @@ def model_handler():
|
|
|
619
683
|
|
|
620
684
|
|
|
621
685
|
@register_handler("model", "fit")
|
|
622
|
-
def model_fit():
|
|
686
|
+
def model_fit() -> None:
|
|
623
687
|
"""Handler for rslearn model fit."""
|
|
624
688
|
model_handler()
|
|
625
689
|
|
|
626
690
|
|
|
627
691
|
@register_handler("model", "validate")
|
|
628
|
-
def model_validate():
|
|
692
|
+
def model_validate() -> None:
|
|
629
693
|
"""Handler for rslearn model validate."""
|
|
630
694
|
model_handler()
|
|
631
695
|
|
|
632
696
|
|
|
633
697
|
@register_handler("model", "test")
|
|
634
|
-
def model_test():
|
|
698
|
+
def model_test() -> None:
|
|
635
699
|
"""Handler for rslearn model test."""
|
|
636
700
|
model_handler()
|
|
637
701
|
|
|
638
702
|
|
|
639
703
|
@register_handler("model", "predict")
|
|
640
|
-
def model_predict():
|
|
704
|
+
def model_predict() -> None:
|
|
641
705
|
"""Handler for rslearn model predict."""
|
|
642
706
|
model_handler()
|
|
643
707
|
|
|
644
708
|
|
|
645
|
-
def main():
|
|
709
|
+
def main() -> None:
|
|
646
710
|
"""CLI entrypoint."""
|
|
647
711
|
parser = argparse.ArgumentParser(description="rslearn")
|
|
648
712
|
parser.add_argument(
|
|
@@ -653,7 +717,7 @@ def main():
|
|
|
653
717
|
|
|
654
718
|
handler = handler_registry.get((args.category, args.command))
|
|
655
719
|
if handler is None:
|
|
656
|
-
|
|
720
|
+
logger.error(f"Unknown command: {args.category} {args.command}")
|
|
657
721
|
sys.exit(1)
|
|
658
722
|
|
|
659
723
|
handler()
|
rslearn/models/clip.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
"""OpenAI CLIP models."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from transformers import AutoModelForZeroShotImageClassification, AutoProcessor
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class CLIP(torch.nn.Module):
    """CLIP image encoder.

    Wraps the vision tower of a Hugging Face CLIP checkpoint so that a batch of
    images is mapped to a single spatial feature map.
    """

    def __init__(
        self,
        model_name: str,
    ) -> None:
        """Instantiate a new CLIP instance.

        Args:
            model_name: the model name like "openai/clip-vit-large-patch14-336".
        """
        super().__init__()

        self.processor = AutoProcessor.from_pretrained(model_name)
        model = AutoModelForZeroShotImageClassification.from_pretrained(model_name)
        # Only the vision tower (`vision_model`) is retained as the encoder; the
        # rest of the loaded model is discarded.
        self.encoder = model.vision_model

        # Get number of features and token map size from encoder attributes.
        self.num_features = self.encoder.post_layernorm.normalized_shape[0]
        crop_size = self.processor.image_processor.crop_size
        stride = self.encoder.embeddings.patch_embedding.stride
        self.height = crop_size["height"] // stride[0]
        self.width = crop_size["width"] // stride[1]

    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
        """Compute outputs from the backbone.

        Args:
            inputs: input dicts that must include "image" key containing the image to
                process. The images should have values 0-255. Each image is
                transposed with ``(1, 2, 0)`` before being handed to the CLIP
                processor, so it is presumably channels-first (C, H, W).

        Returns:
            list of feature maps. The ViT produces features at one scale, so the list
                contains a single channels-first feature map of shape
                (batch, num_features, height, width), where height/width are the
                processor crop size divided by the patch stride (e.g. 1024 channels
                on a 24x24 grid for "openai/clip-vit-large-patch14-336").
        """
        device = inputs[0]["image"].device
        clip_inputs = self.processor(
            images=[inp["image"].cpu().numpy().transpose(1, 2, 0) for inp in inputs],
            return_tensors="pt",
            padding=True,
        )
        pixel_values = clip_inputs["pixel_values"].to(device)
        output = self.encoder(pixel_values=pixel_values)
        # Ignore class token output which is before the patch tokens.
        image_features = output.last_hidden_state[:, 1:, :]
        batch_size = image_features.shape[0]

        # Tokens (batch, height*width, num_features) -> grid
        # (batch, height, width, num_features) -> channels-first
        # (batch, num_features, height, width).
        return [
            image_features.reshape(
                batch_size, self.height, self.width, self.num_features
            ).permute(0, 3, 1, 2)
        ]
|
rslearn/models/conv.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""A single convolutional layer."""
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Conv(torch.nn.Module):
|
|
7
|
+
"""A single convolutional layer.
|
|
8
|
+
|
|
9
|
+
It inputs a set of feature maps; the conv layer is applied to each feature map
|
|
10
|
+
independently, and list of outputs is returned.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
def __init__(
|
|
14
|
+
self,
|
|
15
|
+
in_channels: int,
|
|
16
|
+
out_channels: int,
|
|
17
|
+
kernel_size: int,
|
|
18
|
+
padding: str = "same",
|
|
19
|
+
stride: int = 1,
|
|
20
|
+
activation: torch.nn.Module = torch.nn.ReLU(inplace=True),
|
|
21
|
+
):
|
|
22
|
+
"""Initialize a Conv.
|
|
23
|
+
|
|
24
|
+
Args:
|
|
25
|
+
in_channels: number of input channels.
|
|
26
|
+
out_channels: number of output channels.
|
|
27
|
+
kernel_size: kernel size
|
|
28
|
+
padding: either "same" or "valid" to control padding
|
|
29
|
+
stride: stride to apply.
|
|
30
|
+
activation: activation to apply after convolution
|
|
31
|
+
"""
|
|
32
|
+
super().__init__()
|
|
33
|
+
|
|
34
|
+
self.layer = torch.nn.Conv2d(
|
|
35
|
+
in_channels, out_channels, kernel_size, padding=padding, stride=stride
|
|
36
|
+
)
|
|
37
|
+
self.activation = activation
|
|
38
|
+
|
|
39
|
+
def forward(
|
|
40
|
+
self, features: list[torch.Tensor], inputs: list[torch.Tensor]
|
|
41
|
+
) -> list[torch.Tensor]:
|
|
42
|
+
"""Compute flat output vector from multi-scale feature map.
|
|
43
|
+
|
|
44
|
+
Args:
|
|
45
|
+
features: list of feature maps at different resolutions.
|
|
46
|
+
inputs: original inputs (ignored).
|
|
47
|
+
|
|
48
|
+
Returns:
|
|
49
|
+
flat feature vector
|
|
50
|
+
"""
|
|
51
|
+
new_features = []
|
|
52
|
+
for feat_map in features:
|
|
53
|
+
feat_map = self.layer(feat_map)
|
|
54
|
+
feat_map = self.activation(feat_map)
|
|
55
|
+
new_features.append(feat_map)
|
|
56
|
+
return new_features
|
rslearn/models/faster_rcnn.py
CHANGED
|
@@ -10,7 +10,7 @@ import torchvision
|
|
|
10
10
|
class NoopTransform(torch.nn.Module):
|
|
11
11
|
"""A placeholder transform used with torchvision detection model."""
|
|
12
12
|
|
|
13
|
-
def __init__(self):
|
|
13
|
+
def __init__(self) -> None:
|
|
14
14
|
"""Create a new NoopTransform."""
|
|
15
15
|
super().__init__()
|
|
16
16
|
|
|
@@ -46,23 +46,6 @@ class NoopTransform(torch.nn.Module):
|
|
|
46
46
|
)
|
|
47
47
|
return image_list, targets
|
|
48
48
|
|
|
49
|
-
def postprocess(
|
|
50
|
-
self, detections: dict[str, torch.Tensor], image_sizes, orig_sizes
|
|
51
|
-
) -> dict[str, torch.Tensor]:
|
|
52
|
-
"""Post-process the detections to reflect original image size.
|
|
53
|
-
|
|
54
|
-
Since we didn't transform the images, we don't need to do anything here.
|
|
55
|
-
|
|
56
|
-
Args:
|
|
57
|
-
detections: the raw detections
|
|
58
|
-
image_sizes: the transformed image sizes
|
|
59
|
-
orig_sizes: the original image sizes
|
|
60
|
-
|
|
61
|
-
Returns:
|
|
62
|
-
the post-processed detections (unmodified from the provided detections)
|
|
63
|
-
"""
|
|
64
|
-
return detections
|
|
65
|
-
|
|
66
49
|
|
|
67
50
|
class FasterRCNN(torch.nn.Module):
|
|
68
51
|
"""Faster R-CNN head for predicting bounding boxes.
|
|
@@ -80,7 +63,7 @@ class FasterRCNN(torch.nn.Module):
|
|
|
80
63
|
anchor_sizes: list[list[int]],
|
|
81
64
|
instance_segmentation: bool = False,
|
|
82
65
|
box_score_thresh: float = 0.05,
|
|
83
|
-
):
|
|
66
|
+
) -> None:
|
|
84
67
|
"""Create a new FasterRCNN.
|
|
85
68
|
|
|
86
69
|
Args:
|
rslearn/models/module_wrapper.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""Module wrappers."""
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class DecoderModuleWrapper(torch.nn.Module):
    """Adapt a single-feature-map module to the multi-scale decoder interface.

    The wrapped module maps one feature map to one new feature map. Decoder-style
    modules are called with a list of feature maps (one per scale) plus the
    original inputs, so this wrapper runs the module on every feature map
    separately and collects the results.
    """

    def __init__(
        self,
        module: torch.nn.Module,
    ):
        """Initialize a DecoderModuleWrapper.

        Args:
            module: the module to apply to each feature map.
        """
        super().__init__()
        self.module = module

    def forward(
        self, features: list[torch.Tensor], inputs: list[torch.Tensor]
    ) -> list[torch.Tensor]:
        """Run the wrapped module over every feature map.

        Args:
            features: list of feature maps at different resolutions.
            inputs: original inputs (not used).

        Returns:
            list holding the wrapped module's output for each feature map, in order.
        """
        return list(map(self.module, features))
|