rslearn 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. rslearn/config/dataset.py +22 -13
  2. rslearn/data_sources/__init__.py +8 -0
  3. rslearn/data_sources/aws_landsat.py +27 -18
  4. rslearn/data_sources/aws_open_data.py +41 -42
  5. rslearn/data_sources/copernicus.py +148 -2
  6. rslearn/data_sources/data_source.py +17 -10
  7. rslearn/data_sources/gcp_public_data.py +177 -100
  8. rslearn/data_sources/geotiff.py +1 -0
  9. rslearn/data_sources/google_earth_engine.py +17 -15
  10. rslearn/data_sources/local_files.py +59 -32
  11. rslearn/data_sources/openstreetmap.py +27 -23
  12. rslearn/data_sources/planet.py +10 -9
  13. rslearn/data_sources/planet_basemap.py +303 -0
  14. rslearn/data_sources/raster_source.py +23 -13
  15. rslearn/data_sources/usgs_landsat.py +56 -27
  16. rslearn/data_sources/utils.py +13 -6
  17. rslearn/data_sources/vector_source.py +1 -0
  18. rslearn/data_sources/xyz_tiles.py +8 -9
  19. rslearn/dataset/add_windows.py +1 -1
  20. rslearn/dataset/dataset.py +16 -5
  21. rslearn/dataset/manage.py +9 -4
  22. rslearn/dataset/materialize.py +26 -5
  23. rslearn/dataset/window.py +5 -0
  24. rslearn/log_utils.py +24 -0
  25. rslearn/main.py +123 -59
  26. rslearn/models/clip.py +62 -0
  27. rslearn/models/conv.py +56 -0
  28. rslearn/models/faster_rcnn.py +2 -19
  29. rslearn/models/fpn.py +1 -1
  30. rslearn/models/module_wrapper.py +43 -0
  31. rslearn/models/molmo.py +65 -0
  32. rslearn/models/multitask.py +1 -1
  33. rslearn/models/pooling_decoder.py +4 -2
  34. rslearn/models/satlaspretrain.py +4 -7
  35. rslearn/models/simple_time_series.py +61 -55
  36. rslearn/models/ssl4eo_s12.py +9 -9
  37. rslearn/models/swin.py +22 -21
  38. rslearn/models/unet.py +4 -2
  39. rslearn/models/upsample.py +35 -0
  40. rslearn/tile_stores/file.py +6 -3
  41. rslearn/tile_stores/tile_store.py +19 -7
  42. rslearn/train/callbacks/freeze_unfreeze.py +3 -3
  43. rslearn/train/data_module.py +5 -4
  44. rslearn/train/dataset.py +79 -36
  45. rslearn/train/lightning_module.py +15 -11
  46. rslearn/train/prediction_writer.py +22 -11
  47. rslearn/train/tasks/classification.py +9 -8
  48. rslearn/train/tasks/detection.py +94 -37
  49. rslearn/train/tasks/multi_task.py +1 -1
  50. rslearn/train/tasks/regression.py +8 -4
  51. rslearn/train/tasks/segmentation.py +23 -19
  52. rslearn/train/transforms/__init__.py +1 -1
  53. rslearn/train/transforms/concatenate.py +6 -2
  54. rslearn/train/transforms/crop.py +6 -2
  55. rslearn/train/transforms/flip.py +5 -1
  56. rslearn/train/transforms/normalize.py +9 -5
  57. rslearn/train/transforms/pad.py +1 -1
  58. rslearn/train/transforms/transform.py +3 -3
  59. rslearn/utils/__init__.py +4 -5
  60. rslearn/utils/array.py +2 -2
  61. rslearn/utils/feature.py +1 -1
  62. rslearn/utils/fsspec.py +70 -1
  63. rslearn/utils/geometry.py +155 -3
  64. rslearn/utils/grid_index.py +5 -5
  65. rslearn/utils/mp.py +4 -3
  66. rslearn/utils/raster_format.py +81 -73
  67. rslearn/utils/rtree_index.py +64 -17
  68. rslearn/utils/sqlite_index.py +7 -1
  69. rslearn/utils/utils.py +11 -3
  70. rslearn/utils/vector_format.py +113 -17
  71. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/METADATA +32 -27
  72. rslearn-0.0.2.dist-info/RECORD +94 -0
  73. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/WHEEL +1 -1
  74. rslearn/utils/mgrs.py +0 -24
  75. rslearn-0.0.1.dist-info/RECORD +0 -88
  76. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/LICENSE +0 -0
  77. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/entry_points.txt +0 -0
  78. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/top_level.txt +0 -0
rslearn/main.py CHANGED
@@ -1,39 +1,43 @@
1
1
  """Entrypoint for the rslearn command-line interface."""
2
2
 
3
3
  import argparse
4
- import logging
5
4
  import multiprocessing
6
5
  import random
7
6
  import sys
8
7
  from collections.abc import Callable
9
8
  from datetime import datetime, timezone
10
9
  from pathlib import Path
10
+ from typing import Any, TypeVar
11
11
 
12
12
  import tqdm
13
13
  import wandb
14
- from lightning.pytorch.cli import LightningCLI
14
+ from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
15
15
  from rasterio.crs import CRS
16
16
  from upath import UPath
17
17
 
18
18
  from rslearn.config import LayerConfig
19
19
  from rslearn.const import WGS84_EPSG
20
20
  from rslearn.data_sources import Item, data_source_from_config
21
- from rslearn.dataset import Dataset, Window
21
+ from rslearn.dataset import Dataset, Window, WindowLayerData
22
22
  from rslearn.dataset.add_windows import add_windows_from_box, add_windows_from_file
23
23
  from rslearn.dataset.manage import materialize_dataset_windows, prepare_dataset_windows
24
+ from rslearn.log_utils import get_logger
24
25
  from rslearn.tile_stores import get_tile_store_for_layer
25
26
  from rslearn.train.data_module import RslearnDataModule
26
27
  from rslearn.train.lightning_module import RslearnLightningModule
27
- from rslearn.utils import Projection, STGeometry
28
+ from rslearn.utils import Projection, STGeometry, parse_disabled_layers
29
+
30
+ logger = get_logger(__name__)
28
31
 
29
- logging.basicConfig()
30
32
  handler_registry = {}
31
33
 
34
+ ItemType = TypeVar("ItemType", bound="Item")
35
+
32
36
 
33
- def register_handler(category, command):
37
+ def register_handler(category: Any, command: str) -> Callable:
34
38
  """Register a new handler for a command."""
35
39
 
36
- def decorator(f):
40
+ def decorator(f: Callable) -> Callable:
37
41
  handler_registry[(category, command)] = f
38
42
  return f
39
43
 
@@ -61,7 +65,7 @@ def parse_time_range(
61
65
 
62
66
 
63
67
  @register_handler("dataset", "add_windows")
64
- def add_windows():
68
+ def add_windows() -> None:
65
69
  """Handler for the rslearn dataset add_windows command."""
66
70
  parser = argparse.ArgumentParser(
67
71
  prog="rslearn dataset add_windows",
@@ -156,7 +160,13 @@ def add_windows():
156
160
  )
157
161
  args = parser.parse_args(args=sys.argv[3:])
158
162
 
159
- def parse_projection(crs_str, resolution, x_res, y_res, default_crs=None):
163
+ def parse_projection(
164
+ crs_str: str | None,
165
+ resolution: float | None,
166
+ x_res: float,
167
+ y_res: float,
168
+ default_crs: CRS | None = None,
169
+ ) -> Projection | None:
160
170
  if not crs_str:
161
171
  if default_crs:
162
172
  crs = default_crs
@@ -197,7 +207,8 @@ def add_windows():
197
207
  box = [float(value) for value in args.box.split(",")]
198
208
 
199
209
  windows = add_windows_from_box(
200
- box=box,
210
+ # TODO: we should have an object for box
211
+ box=box, # type: ignore
201
212
  src_projection=parse_projection(
202
213
  args.src_crs, args.src_resolution, args.src_x_res, args.src_y_res
203
214
  ),
@@ -210,10 +221,10 @@ def add_windows():
210
221
  else:
211
222
  raise Exception("one of box or fname must be specified")
212
223
 
213
- print(f"created {len(windows)} windows")
224
+ logger.info(f"created {len(windows)} windows")
214
225
 
215
226
 
216
- def add_apply_on_windows_args(parser: argparse.ArgumentParser):
227
+ def add_apply_on_windows_args(parser: argparse.ArgumentParser) -> None:
217
228
  """Add arguments for handlers that use the apply_on_windows helper.
218
229
 
219
230
  Args:
@@ -263,7 +274,7 @@ def apply_on_windows(
263
274
  batch_size: int = 1,
264
275
  jobs_per_process: int | None = None,
265
276
  use_initial_job: bool = True,
266
- ):
277
+ ) -> None:
267
278
  """A helper to apply a function on windows in a dataset.
268
279
 
269
280
  Args:
@@ -293,11 +304,11 @@ def apply_on_windows(
293
304
  windows = dataset.load_windows(
294
305
  groups=groups, names=names, workers=workers, show_progress=True
295
306
  )
296
- print(f"found {len(windows)} windows")
307
+ logger.info(f"found {len(windows)} windows")
297
308
 
298
309
  if hasattr(f, "get_jobs"):
299
310
  jobs = f.get_jobs(windows, workers)
300
- print(f"got {len(jobs)} jobs")
311
+ logger.info(f"got {len(jobs)} jobs")
301
312
  else:
302
313
  jobs = windows
303
314
 
@@ -323,9 +334,9 @@ def apply_on_windows(
323
334
  p.close()
324
335
 
325
336
 
326
- def apply_on_windows_args(f: Callable[[list[Window]], None], args: argparse.Namespace):
337
+ def apply_on_windows_args(f: Callable[..., None], args: argparse.Namespace) -> None:
327
338
  """Call apply_on_windows with arguments passed via command-line interface."""
328
- dataset = Dataset(UPath(args.root))
339
+ dataset = Dataset(UPath(args.root), args.disabled_layers)
329
340
  apply_on_windows(
330
341
  f,
331
342
  dataset,
@@ -341,16 +352,16 @@ def apply_on_windows_args(f: Callable[[list[Window]], None], args: argparse.Name
341
352
  class PrepareHandler:
342
353
  """apply_on_windows handler for the rslearn dataset prepare command."""
343
354
 
344
- def __init__(self, force: bool):
355
+ def __init__(self, force: bool) -> None:
345
356
  """Initialize a new PrepareHandler.
346
357
 
347
358
  Args:
348
359
  force: force prepare
349
360
  """
350
361
  self.force = force
351
- self.dataset = None
362
+ self.dataset: Dataset | None = None
352
363
 
353
- def set_dataset(self, dataset: Dataset):
364
+ def set_dataset(self, dataset: Dataset) -> None:
354
365
  """Captures the dataset from apply_on_windows_args.
355
366
 
356
367
  Args:
@@ -358,13 +369,16 @@ class PrepareHandler:
358
369
  """
359
370
  self.dataset = dataset
360
371
 
361
- def __call__(self, windows: list[Window]):
372
+ def __call__(self, windows: list[Window]) -> None:
362
373
  """Prepares the windows from apply_on_windows."""
374
+ logger.info(f"Running prepare on {len(windows)} windows")
375
+ if self.dataset is None:
376
+ raise ValueError("dataset not set")
363
377
  prepare_dataset_windows(self.dataset, windows, self.force)
364
378
 
365
379
 
366
380
  @register_handler("dataset", "prepare")
367
- def dataset_prepare():
381
+ def dataset_prepare() -> None:
368
382
  """Handler for the rslearn dataset prepare command."""
369
383
  parser = argparse.ArgumentParser(
370
384
  prog="rslearn dataset prepare",
@@ -377,6 +391,12 @@ def dataset_prepare():
377
391
  action=argparse.BooleanOptionalAction,
378
392
  help="Prepare windows even if they were previously prepared",
379
393
  )
394
+ parser.add_argument(
395
+ "--disabled-layers",
396
+ type=parse_disabled_layers,
397
+ default="",
398
+ help="List of layers to disable e.g 'layer1,layer2'",
399
+ )
380
400
  add_apply_on_windows_args(parser)
381
401
  args = parser.parse_args(args=sys.argv[3:])
382
402
 
@@ -384,7 +404,9 @@ def dataset_prepare():
384
404
  apply_on_windows_args(fn, args)
385
405
 
386
406
 
387
- def _load_window_layer_datas(window: Window):
407
+ def _load_window_layer_datas(
408
+ window: Window,
409
+ ) -> tuple[Window, dict[str, WindowLayerData]]:
388
410
  # Helper for IngestHandler to use with multiprocessing.
389
411
  return window, window.load_layer_datas()
390
412
 
@@ -392,11 +414,12 @@ def _load_window_layer_datas(window: Window):
392
414
  class IngestHandler:
393
415
  """apply_on_windows handler for the rslearn dataset ingest command."""
394
416
 
395
- def __init__(self):
417
+ def __init__(self, ignore_errors: bool = False) -> None:
396
418
  """Initialize a new IngestHandler."""
397
- self.dataset = None
419
+ self.dataset: Dataset | None = None
420
+ self.ignore_errors = ignore_errors
398
421
 
399
- def set_dataset(self, dataset: Dataset):
422
+ def set_dataset(self, dataset: Dataset) -> None:
400
423
  """Captures the dataset from apply_on_windows_args.
401
424
 
402
425
  Args:
@@ -404,7 +427,9 @@ class IngestHandler:
404
427
  """
405
428
  self.dataset = dataset
406
429
 
407
- def __call__(self, jobs: list[tuple[str, LayerConfig, Item, list[STGeometry]]]):
430
+ def __call__(
431
+ self, jobs: list[tuple[str, LayerConfig, Item, list[STGeometry]]]
432
+ ) -> None:
408
433
  """Ingest the specified items.
409
434
 
410
435
  The items are computed from list of windows via IngestHandler.get_jobs.
@@ -412,13 +437,16 @@ class IngestHandler:
412
437
  Args:
413
438
  jobs: list of (layer_name, item, geometries) tuples to ingest.
414
439
  """
440
+ logger.info(f"Running ingest for {len(jobs)} jobs")
415
441
  import gc
416
442
 
443
+ if self.dataset is None:
444
+ raise ValueError("dataset not set")
417
445
  tile_store = self.dataset.get_tile_store()
418
446
 
419
447
  # Group jobs by layer name.
420
- jobs_by_layer = {}
421
- configs_by_layer = {}
448
+ jobs_by_layer: dict = {}
449
+ configs_by_layer: dict = {}
422
450
  for layer_name, layer_cfg, item, geometries in jobs:
423
451
  if layer_name not in jobs_by_layer:
424
452
  jobs_by_layer[layer_name] = []
@@ -437,13 +465,31 @@ class IngestHandler:
437
465
  geometries=[geometries for _, geometries in items_and_geometries],
438
466
  )
439
467
  except Exception as e:
440
- print(
468
+ if not self.ignore_errors:
469
+ raise
470
+
471
+ logger.error(
441
472
  "warning: got error while ingesting "
442
473
  + f"{len(items_and_geometries)} items: {e}"
443
474
  )
444
475
 
445
476
  gc.collect()
446
477
 
478
+ def _load_layer_data_for_windows(
479
+ self, windows: list[Window], workers: int
480
+ ) -> list[tuple[Window, dict[str, WindowLayerData]]]:
481
+ if workers == 0:
482
+ return [(_load_window_layer_datas(window)) for window in windows]
483
+ p = multiprocessing.Pool(workers)
484
+ outputs = p.imap_unordered(_load_window_layer_datas, windows)
485
+ windows_and_layer_datas = []
486
+ for window, layer_datas in tqdm.tqdm(
487
+ outputs, total=len(windows), desc="Loading window layer datas"
488
+ ):
489
+ windows_and_layer_datas.append((window, layer_datas))
490
+ p.close()
491
+ return windows_and_layer_datas
492
+
447
493
  def get_jobs(
448
494
  self, windows: list[Window], workers: int
449
495
  ) -> list[tuple[str, LayerConfig, Item, list[STGeometry]]]:
@@ -455,17 +501,12 @@ class IngestHandler:
455
501
  This makes sure that jobs are grouped by item rather than by window, which
456
502
  makes sense because there's no reason to ingest the same item twice.
457
503
  """
504
+ if self.dataset is None:
505
+ raise ValueError("dataset not set")
458
506
  # TODO: avoid duplicating ingest_dataset_windows...
459
507
 
460
508
  # Load layer datas of each window.
461
- p = multiprocessing.Pool(workers)
462
- outputs = p.imap_unordered(_load_window_layer_datas, windows)
463
- windows_and_layer_datas = []
464
- for window, layer_datas in tqdm.tqdm(
465
- outputs, total=len(windows), desc="Loading window layer datas"
466
- ):
467
- windows_and_layer_datas.append((window, layer_datas))
468
- p.close()
509
+ windows_and_layer_datas = self._load_layer_data_for_windows(windows, workers)
469
510
 
470
511
  jobs: list[tuple[str, LayerConfig, Item, list[STGeometry]]] = []
471
512
  for layer_name, layer_cfg in self.dataset.layers.items():
@@ -476,7 +517,7 @@ class IngestHandler:
476
517
 
477
518
  data_source = data_source_from_config(layer_cfg, self.dataset.path)
478
519
 
479
- geometries_by_item = {}
520
+ geometries_by_item: dict = {}
480
521
  for window, layer_datas in windows_and_layer_datas:
481
522
  if layer_name not in layer_datas:
482
523
  continue
@@ -484,7 +525,9 @@ class IngestHandler:
484
525
  layer_data = layer_datas[layer_name]
485
526
  for group in layer_data.serialized_item_groups:
486
527
  for serialized_item in group:
487
- item = data_source.deserialize_item(serialized_item)
528
+ item = data_source.deserialize_item( # type: ignore
529
+ serialized_item
530
+ )
488
531
  if item not in geometries_by_item:
489
532
  geometries_by_item[item] = []
490
533
  geometries_by_item[item].append(geometry)
@@ -492,32 +535,45 @@ class IngestHandler:
492
535
  for item, geometries in geometries_by_item.items():
493
536
  jobs.append((layer_name, layer_cfg, item, geometries))
494
537
 
495
- print(f"computed {len(jobs)} ingest jobs from {len(windows)} windows")
538
+ logger.info(f"computed {len(jobs)} ingest jobs from {len(windows)} windows")
496
539
  return jobs
497
540
 
498
541
 
499
542
  @register_handler("dataset", "ingest")
500
- def dataset_ingest():
543
+ def dataset_ingest() -> None:
501
544
  """Handler for the rslearn dataset ingest command."""
502
545
  parser = argparse.ArgumentParser(
503
546
  prog="rslearn dataset ingest",
504
547
  description="rslearn dataset ingest: ingest items in retrieved data sources",
505
548
  )
549
+ parser.add_argument(
550
+ "--disabled-layers",
551
+ type=parse_disabled_layers,
552
+ default="",
553
+ help="List of layers to disable e.g 'layer1,layer2'",
554
+ )
555
+ parser.add_argument(
556
+ "--ignore-errors",
557
+ type=bool,
558
+ default=False,
559
+ help="Ignore ingestion errors in individual jobs",
560
+ action=argparse.BooleanOptionalAction,
561
+ )
506
562
  add_apply_on_windows_args(parser)
507
563
  args = parser.parse_args(args=sys.argv[3:])
508
564
 
509
- fn = IngestHandler()
565
+ fn = IngestHandler(ignore_errors=args.ignore_errors)
510
566
  apply_on_windows_args(fn, args)
511
567
 
512
568
 
513
569
  class MaterializeHandler:
514
570
  """apply_on_windows handler for the rslearn dataset materialize command."""
515
571
 
516
- def __init__(self):
572
+ def __init__(self) -> None:
517
573
  """Initialize a MaterializeHandler."""
518
- self.dataset = None
574
+ self.dataset: Dataset | None = None
519
575
 
520
- def set_dataset(self, dataset: Dataset):
576
+ def set_dataset(self, dataset: Dataset) -> None:
521
577
  """Captures the dataset from apply_on_windows_args.
522
578
 
523
579
  Args:
@@ -525,13 +581,16 @@ class MaterializeHandler:
525
581
  """
526
582
  self.dataset = dataset
527
583
 
528
- def __call__(self, windows: list[Window]):
584
+ def __call__(self, windows: list[Window]) -> None:
529
585
  """Materializes the windows from apply_on_windows."""
586
+ logger.info(f"Running Materialize with {len(windows)} windows")
587
+ if self.dataset is None:
588
+ raise ValueError("dataset not set")
530
589
  materialize_dataset_windows(self.dataset, windows)
531
590
 
532
591
 
533
592
  @register_handler("dataset", "materialize")
534
- def dataset_materialize():
593
+ def dataset_materialize() -> None:
535
594
  """Handler for the rslearn dataset materialize command."""
536
595
  parser = argparse.ArgumentParser(
537
596
  prog="rslearn dataset materialize",
@@ -540,9 +599,14 @@ def dataset_materialize():
540
599
  + "materialize data from retrieved data sources"
541
600
  ),
542
601
  )
602
+ parser.add_argument(
603
+ "--disabled-layers",
604
+ type=parse_disabled_layers,
605
+ default="",
606
+ help="List of layers to disable e.g 'layer1,layer2'",
607
+ )
543
608
  add_apply_on_windows_args(parser)
544
609
  args = parser.parse_args(args=sys.argv[3:])
545
-
546
610
  fn = MaterializeHandler()
547
611
  apply_on_windows_args(fn, args)
548
612
 
@@ -550,7 +614,7 @@ def dataset_materialize():
550
614
  class RslearnLightningCLI(LightningCLI):
551
615
  """LightningCLI that links data.tasks to model.tasks."""
552
616
 
553
- def add_arguments_to_parser(self, parser) -> None:
617
+ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
554
618
  """Link data.tasks to model.tasks.
555
619
 
556
620
  Args:
@@ -572,7 +636,7 @@ class RslearnLightningCLI(LightningCLI):
572
636
  help="Whether to resume from specified wandb_run_id",
573
637
  )
574
638
 
575
- def before_instantiate_classes(self):
639
+ def before_instantiate_classes(self) -> None:
576
640
  """Called before Lightning class initialization.
577
641
 
578
642
  Sets up wandb_run_id / wandb_resume arguments.
@@ -585,7 +649,7 @@ class RslearnLightningCLI(LightningCLI):
585
649
  artifact_id = (
586
650
  f"{c.trainer.logger.init_args.project}/model-{c.wandb_run_id}:latest"
587
651
  )
588
- print(f"restoring from artifact {artifact_id} on wandb")
652
+ logger.info(f"restoring from artifact {artifact_id} on wandb")
589
653
  artifact = api.artifact(artifact_id, type="model")
590
654
  artifact_dir = artifact.download()
591
655
  c.ckpt_path = str(Path(artifact_dir) / "model.ckpt")
@@ -606,7 +670,7 @@ class RslearnLightningCLI(LightningCLI):
606
670
  prediction_writer_callback.init_args.path = c.data.init_args.path
607
671
 
608
672
 
609
- def model_handler():
673
+ def model_handler() -> None:
610
674
  """Handler for any rslearn model X commands."""
611
675
  RslearnLightningCLI(
612
676
  model_class=RslearnLightningModule,
@@ -619,30 +683,30 @@ def model_handler():
619
683
 
620
684
 
621
685
  @register_handler("model", "fit")
622
- def model_fit():
686
+ def model_fit() -> None:
623
687
  """Handler for rslearn model fit."""
624
688
  model_handler()
625
689
 
626
690
 
627
691
  @register_handler("model", "validate")
628
- def model_validate():
692
+ def model_validate() -> None:
629
693
  """Handler for rslearn model validate."""
630
694
  model_handler()
631
695
 
632
696
 
633
697
  @register_handler("model", "test")
634
- def model_test():
698
+ def model_test() -> None:
635
699
  """Handler for rslearn model test."""
636
700
  model_handler()
637
701
 
638
702
 
639
703
  @register_handler("model", "predict")
640
- def model_predict():
704
+ def model_predict() -> None:
641
705
  """Handler for rslearn model predict."""
642
706
  model_handler()
643
707
 
644
708
 
645
- def main():
709
+ def main() -> None:
646
710
  """CLI entrypoint."""
647
711
  parser = argparse.ArgumentParser(description="rslearn")
648
712
  parser.add_argument(
@@ -653,7 +717,7 @@ def main():
653
717
 
654
718
  handler = handler_registry.get((args.category, args.command))
655
719
  if handler is None:
656
- print(f"Unknown command: {args.category} {args.command}", file=sys.stderr)
720
+ logger.error(f"Unknown command: {args.category} {args.command}")
657
721
  sys.exit(1)
658
722
 
659
723
  handler()
rslearn/models/clip.py ADDED
@@ -0,0 +1,62 @@
1
+ """OpenAI CLIP models."""
2
+
3
+ from typing import Any
4
+
5
+ import torch
6
+ from transformers import AutoModelForZeroShotImageClassification, AutoProcessor
7
+
8
+
9
class CLIP(torch.nn.Module):
    """CLIP image encoder.

    Wraps the vision tower of a Hugging Face zero-shot image classification model
    (e.g. "openai/clip-vit-large-patch14-336") and exposes its patch-token outputs
    as a single feature map.
    """

    def __init__(
        self,
        model_name: str,
    ) -> None:
        """Instantiate a new CLIP instance.

        Args:
            model_name: the model name like "openai/clip-vit-large-patch14-336".
        """
        super().__init__()

        self.processor = AutoProcessor.from_pretrained(model_name)
        model = AutoModelForZeroShotImageClassification.from_pretrained(model_name)
        # Only the vision tower is kept on this module.
        self.encoder = model.vision_model

        # Get number of features and token map size from encoder attributes.
        self.num_features = self.encoder.post_layernorm.normalized_shape[0]
        crop_size = self.processor.image_processor.crop_size
        stride = self.encoder.embeddings.patch_embedding.stride
        # The patch-token grid is the processor crop size divided by the patch stride.
        self.height = crop_size["height"] // stride[0]
        self.width = crop_size["width"] // stride[1]

    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
        """Compute outputs from the backbone.

        Args:
            inputs: input dicts that must include "image" key containing a CxHxW
                image tensor to process. The images should have values 0-255.

        Returns:
            list of feature maps. The ViT produces features at one scale, so the
            list contains a single feature map of shape
            (batch, num_features, height, width), e.g. Bx1024x24x24 for
            "openai/clip-vit-large-patch14-336".
        """
        device = inputs[0]["image"].device
        # The processor is given HxWxC numpy arrays, so each image is moved to CPU
        # and transposed from CxHxW.
        clip_inputs = self.processor(
            images=[inp["image"].cpu().numpy().transpose(1, 2, 0) for inp in inputs],
            return_tensors="pt",
            padding=True,
        )
        pixel_values = clip_inputs["pixel_values"].to(device)
        output = self.encoder(pixel_values=pixel_values)
        # Ignore class token output which is before the patch tokens.
        image_features = output.last_hidden_state[:, 1:, :]
        batch_size = image_features.shape[0]

        # (B, H*W, C) -> (B, H, W, C) -> (B, C, H, W)
        return [
            image_features.reshape(
                batch_size, self.height, self.width, self.num_features
            ).permute(0, 3, 1, 2)
        ]
rslearn/models/conv.py ADDED
@@ -0,0 +1,56 @@
1
+ """A single convolutional layer."""
2
+
3
+ import torch
4
+
5
+
6
+ class Conv(torch.nn.Module):
7
+ """A single convolutional layer.
8
+
9
+ It inputs a set of feature maps; the conv layer is applied to each feature map
10
+ independently, and list of outputs is returned.
11
+ """
12
+
13
+ def __init__(
14
+ self,
15
+ in_channels: int,
16
+ out_channels: int,
17
+ kernel_size: int,
18
+ padding: str = "same",
19
+ stride: int = 1,
20
+ activation: torch.nn.Module = torch.nn.ReLU(inplace=True),
21
+ ):
22
+ """Initialize a Conv.
23
+
24
+ Args:
25
+ in_channels: number of input channels.
26
+ out_channels: number of output channels.
27
+ kernel_size: kernel size
28
+ padding: either "same" or "valid" to control padding
29
+ stride: stride to apply.
30
+ activation: activation to apply after convolution
31
+ """
32
+ super().__init__()
33
+
34
+ self.layer = torch.nn.Conv2d(
35
+ in_channels, out_channels, kernel_size, padding=padding, stride=stride
36
+ )
37
+ self.activation = activation
38
+
39
+ def forward(
40
+ self, features: list[torch.Tensor], inputs: list[torch.Tensor]
41
+ ) -> list[torch.Tensor]:
42
+ """Compute flat output vector from multi-scale feature map.
43
+
44
+ Args:
45
+ features: list of feature maps at different resolutions.
46
+ inputs: original inputs (ignored).
47
+
48
+ Returns:
49
+ flat feature vector
50
+ """
51
+ new_features = []
52
+ for feat_map in features:
53
+ feat_map = self.layer(feat_map)
54
+ feat_map = self.activation(feat_map)
55
+ new_features.append(feat_map)
56
+ return new_features
rslearn/models/faster_rcnn.py CHANGED
@@ -10,7 +10,7 @@ import torchvision
10
10
  class NoopTransform(torch.nn.Module):
11
11
  """A placeholder transform used with torchvision detection model."""
12
12
 
13
- def __init__(self):
13
+ def __init__(self) -> None:
14
14
  """Create a new NoopTransform."""
15
15
  super().__init__()
16
16
 
@@ -46,23 +46,6 @@ class NoopTransform(torch.nn.Module):
46
46
  )
47
47
  return image_list, targets
48
48
 
49
- def postprocess(
50
- self, detections: dict[str, torch.Tensor], image_sizes, orig_sizes
51
- ) -> dict[str, torch.Tensor]:
52
- """Post-process the detections to reflect original image size.
53
-
54
- Since we didn't transform the images, we don't need to do anything here.
55
-
56
- Args:
57
- detections: the raw detections
58
- image_sizes: the transformed image sizes
59
- orig_sizes: the original image sizes
60
-
61
- Returns:
62
- the post-processed detections (unmodified from the provided detections)
63
- """
64
- return detections
65
-
66
49
 
67
50
  class FasterRCNN(torch.nn.Module):
68
51
  """Faster R-CNN head for predicting bounding boxes.
@@ -80,7 +63,7 @@ class FasterRCNN(torch.nn.Module):
80
63
  anchor_sizes: list[list[int]],
81
64
  instance_segmentation: bool = False,
82
65
  box_score_thresh: float = 0.05,
83
- ):
66
+ ) -> None:
84
67
  """Create a new FasterRCNN.
85
68
 
86
69
  Args:
rslearn/models/fpn.py CHANGED
@@ -32,7 +32,7 @@ class Fpn(torch.nn.Module):
32
32
  in_channels_list=in_channels, out_channels=out_channels
33
33
  )
34
34
 
35
- def forward(self, x: list[torch.Tensor]):
35
+ def forward(self, x: list[torch.Tensor]) -> list[torch.Tensor]:
36
36
  """Compute outputs of the FPN.
37
37
 
38
38
  Args:
rslearn/models/module_wrapper.py ADDED
@@ -0,0 +1,43 @@
1
+ """Module wrappers."""
2
+
3
+ import torch
4
+
5
+
6
class DecoderModuleWrapper(torch.nn.Module):
    """Adapt a single-feature-map module to the multi-scale decoder interface.

    The wrapped module maps one feature map to a new feature map. Most decoders
    operate on a list of multi-scale feature maps, so this wrapper applies the
    module to every entry of that list.
    """

    def __init__(
        self,
        module: torch.nn.Module,
    ):
        """Create a DecoderModuleWrapper around ``module``.

        Args:
            module: the module to apply to each feature map.
        """
        super().__init__()
        self.module = module

    def forward(
        self, features: list[torch.Tensor], inputs: list[torch.Tensor]
    ) -> list[torch.Tensor]:
        """Run the wrapped module over every feature map.

        Args:
            features: list of feature maps at different resolutions.
            inputs: original inputs (ignored).

        Returns:
            a list holding the wrapped module's output for each feature map, in
            the same order as ``features``.
        """
        return [self.module(feature_map) for feature_map in features]
+ return new_features