rslearn 0.0.16__tar.gz → 0.0.18__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176) hide show
  1. {rslearn-0.0.16/rslearn.egg-info → rslearn-0.0.18}/PKG-INFO +58 -25
  2. {rslearn-0.0.16 → rslearn-0.0.18}/README.md +57 -24
  3. {rslearn-0.0.16 → rslearn-0.0.18}/pyproject.toml +1 -1
  4. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/config/__init__.py +2 -0
  5. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/config/dataset.py +55 -4
  6. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/dataset/add_windows.py +1 -1
  7. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/dataset/dataset.py +9 -65
  8. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/dataset/materialize.py +5 -5
  9. rslearn-0.0.18/rslearn/dataset/storage/__init__.py +1 -0
  10. rslearn-0.0.18/rslearn/dataset/storage/file.py +202 -0
  11. rslearn-0.0.18/rslearn/dataset/storage/storage.py +140 -0
  12. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/dataset/window.py +26 -80
  13. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/lightning_cli.py +10 -3
  14. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/main.py +11 -36
  15. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/anysat.py +11 -9
  16. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/clay/clay.py +8 -9
  17. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/clip.py +18 -15
  18. rslearn-0.0.18/rslearn/models/component.py +99 -0
  19. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/concatenate_features.py +21 -11
  20. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/conv.py +15 -8
  21. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/croma.py +13 -8
  22. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/detr/detr.py +25 -14
  23. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/dinov3.py +11 -6
  24. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/faster_rcnn.py +19 -9
  25. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/feature_center_crop.py +12 -9
  26. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/fpn.py +19 -8
  27. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/galileo/galileo.py +23 -18
  28. rslearn-0.0.18/rslearn/models/module_wrapper.py +60 -0
  29. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/molmo.py +16 -14
  30. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/multitask.py +102 -73
  31. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/olmoearth_pretrain/model.py +20 -17
  32. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/panopticon.py +8 -7
  33. rslearn-0.0.18/rslearn/models/pick_features.py +40 -0
  34. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/pooling_decoder.py +22 -14
  35. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/presto/presto.py +16 -10
  36. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/presto/single_file_presto.py +4 -10
  37. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/prithvi.py +12 -8
  38. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/resize_features.py +21 -7
  39. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/sam2_enc.py +11 -9
  40. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/satlaspretrain.py +15 -9
  41. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/simple_time_series.py +31 -17
  42. rslearn-0.0.18/rslearn/models/singletask.py +58 -0
  43. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/ssl4eo_s12.py +15 -10
  44. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/swin.py +22 -13
  45. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/terramind.py +24 -7
  46. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/trunk.py +6 -3
  47. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/unet.py +18 -9
  48. rslearn-0.0.18/rslearn/models/upsample.py +48 -0
  49. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/all_patches_dataset.py +22 -18
  50. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/dataset.py +69 -54
  51. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/lightning_module.py +51 -32
  52. rslearn-0.0.18/rslearn/train/model_context.py +54 -0
  53. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/prediction_writer.py +111 -41
  54. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/tasks/classification.py +34 -15
  55. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/tasks/detection.py +24 -31
  56. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/tasks/embedding.py +33 -29
  57. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/tasks/multi_task.py +7 -7
  58. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/tasks/per_pixel_regression.py +41 -19
  59. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/tasks/regression.py +38 -21
  60. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/tasks/segmentation.py +33 -15
  61. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/tasks/task.py +3 -2
  62. {rslearn-0.0.16 → rslearn-0.0.18/rslearn.egg-info}/PKG-INFO +58 -25
  63. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn.egg-info/SOURCES.txt +5 -2
  64. rslearn-0.0.16/rslearn/dataset/index.py +0 -173
  65. rslearn-0.0.16/rslearn/models/module_wrapper.py +0 -91
  66. rslearn-0.0.16/rslearn/models/pick_features.py +0 -46
  67. rslearn-0.0.16/rslearn/models/registry.py +0 -22
  68. rslearn-0.0.16/rslearn/models/singletask.py +0 -51
  69. rslearn-0.0.16/rslearn/models/upsample.py +0 -35
  70. {rslearn-0.0.16 → rslearn-0.0.18}/LICENSE +0 -0
  71. {rslearn-0.0.16 → rslearn-0.0.18}/NOTICE +0 -0
  72. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/__init__.py +0 -0
  73. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/arg_parser.py +0 -0
  74. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/const.py +0 -0
  75. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/data_sources/__init__.py +0 -0
  76. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/data_sources/aws_landsat.py +0 -0
  77. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/data_sources/aws_open_data.py +0 -0
  78. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/data_sources/aws_sentinel1.py +0 -0
  79. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/data_sources/climate_data_store.py +0 -0
  80. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/data_sources/copernicus.py +0 -0
  81. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/data_sources/data_source.py +0 -0
  82. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/data_sources/earthdaily.py +0 -0
  83. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/data_sources/earthdata_srtm.py +0 -0
  84. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/data_sources/eurocrops.py +0 -0
  85. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/data_sources/gcp_public_data.py +0 -0
  86. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/data_sources/google_earth_engine.py +0 -0
  87. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/data_sources/local_files.py +0 -0
  88. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/data_sources/openstreetmap.py +0 -0
  89. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/data_sources/planet.py +0 -0
  90. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/data_sources/planet_basemap.py +0 -0
  91. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/data_sources/planetary_computer.py +0 -0
  92. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/data_sources/usda_cdl.py +0 -0
  93. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/data_sources/usgs_landsat.py +0 -0
  94. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/data_sources/utils.py +0 -0
  95. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/data_sources/vector_source.py +0 -0
  96. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/data_sources/worldcereal.py +0 -0
  97. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/data_sources/worldcover.py +0 -0
  98. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/data_sources/worldpop.py +0 -0
  99. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/data_sources/xyz_tiles.py +0 -0
  100. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/dataset/__init__.py +0 -0
  101. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/dataset/handler_summaries.py +0 -0
  102. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/dataset/manage.py +0 -0
  103. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/dataset/remap.py +0 -0
  104. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/log_utils.py +0 -0
  105. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/__init__.py +0 -0
  106. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/clay/configs/metadata.yaml +0 -0
  107. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/detr/__init__.py +0 -0
  108. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/detr/box_ops.py +0 -0
  109. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/detr/matcher.py +0 -0
  110. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/detr/position_encoding.py +0 -0
  111. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/detr/transformer.py +0 -0
  112. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/detr/util.py +0 -0
  113. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/galileo/__init__.py +0 -0
  114. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/galileo/single_file_galileo.py +0 -0
  115. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/olmoearth_pretrain/__init__.py +0 -0
  116. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/olmoearth_pretrain/norm.py +0 -0
  117. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/panopticon_data/sensors/drone.yaml +0 -0
  118. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/panopticon_data/sensors/enmap.yaml +0 -0
  119. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/panopticon_data/sensors/goes.yaml +0 -0
  120. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/panopticon_data/sensors/himawari.yaml +0 -0
  121. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/panopticon_data/sensors/intuition.yaml +0 -0
  122. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/panopticon_data/sensors/landsat8.yaml +0 -0
  123. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/panopticon_data/sensors/modis_terra.yaml +0 -0
  124. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +0 -0
  125. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/panopticon_data/sensors/sentinel1.yaml +0 -0
  126. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/panopticon_data/sensors/sentinel2.yaml +0 -0
  127. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/panopticon_data/sensors/superdove.yaml +0 -0
  128. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/panopticon_data/sensors/wv23.yaml +0 -0
  129. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/presto/__init__.py +0 -0
  130. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/task_embedding.py +0 -0
  131. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/models/use_croma.py +0 -0
  132. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/py.typed +0 -0
  133. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/template_params.py +0 -0
  134. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/tile_stores/__init__.py +0 -0
  135. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/tile_stores/default.py +0 -0
  136. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/tile_stores/tile_store.py +0 -0
  137. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/__init__.py +0 -0
  138. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/callbacks/__init__.py +0 -0
  139. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/callbacks/adapters.py +0 -0
  140. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/callbacks/freeze_unfreeze.py +0 -0
  141. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/callbacks/gradients.py +0 -0
  142. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/callbacks/peft.py +0 -0
  143. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/data_module.py +0 -0
  144. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/optimizer.py +0 -0
  145. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/scheduler.py +0 -0
  146. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/tasks/__init__.py +0 -0
  147. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/transforms/__init__.py +0 -0
  148. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/transforms/concatenate.py +0 -0
  149. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/transforms/crop.py +0 -0
  150. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/transforms/flip.py +0 -0
  151. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/transforms/mask.py +0 -0
  152. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/transforms/normalize.py +0 -0
  153. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/transforms/pad.py +0 -0
  154. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/transforms/select_bands.py +0 -0
  155. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/transforms/sentinel1.py +0 -0
  156. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/train/transforms/transform.py +0 -0
  157. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/utils/__init__.py +0 -0
  158. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/utils/array.py +0 -0
  159. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/utils/feature.py +0 -0
  160. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/utils/fsspec.py +0 -0
  161. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/utils/geometry.py +0 -0
  162. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/utils/get_utm_ups_crs.py +0 -0
  163. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/utils/grid_index.py +0 -0
  164. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/utils/jsonargparse.py +0 -0
  165. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/utils/mp.py +0 -0
  166. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/utils/raster_format.py +0 -0
  167. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/utils/rtree_index.py +0 -0
  168. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/utils/spatial_index.py +0 -0
  169. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/utils/sqlite_index.py +0 -0
  170. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/utils/time.py +0 -0
  171. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn/utils/vector_format.py +0 -0
  172. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn.egg-info/dependency_links.txt +0 -0
  173. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn.egg-info/entry_points.txt +0 -0
  174. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn.egg-info/requires.txt +0 -0
  175. {rslearn-0.0.16 → rslearn-0.0.18}/rslearn.egg-info/top_level.txt +0 -0
  176. {rslearn-0.0.16 → rslearn-0.0.18}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rslearn
3
- Version: 0.0.16
3
+ Version: 0.0.18
4
4
  Summary: A library for developing remote sensing datasets and models
5
5
  Author: OlmoEarth Team
6
6
  License: Apache License
@@ -343,10 +343,12 @@ directory `/path/to/dataset` and corresponding configuration file at
343
343
  "bands": ["R", "G", "B"]
344
344
  }],
345
345
  "data_source": {
346
- "name": "rslearn.data_sources.gcp_public_data.Sentinel2",
347
- "index_cache_dir": "cache/sentinel2/",
348
- "sort_by": "cloud_cover",
349
- "use_rtree_index": false
346
+ "class_path": "rslearn.data_sources.gcp_public_data.Sentinel2",
347
+ "init_args": {
348
+ "index_cache_dir": "cache/sentinel2/",
349
+ "sort_by": "cloud_cover",
350
+ "use_rtree_index": false
351
+ }
350
352
  }
351
353
  }
352
354
  }
@@ -453,8 +455,10 @@ automate this process. Update the dataset `config.json` with a new layer:
453
455
  }],
454
456
  "resampling_method": "nearest",
455
457
  "data_source": {
456
- "name": "rslearn.data_sources.local_files.LocalFiles",
457
- "src_dir": "file:///path/to/world_cover_tifs/"
458
+ "class_path": "rslearn.data_sources.local_files.LocalFiles",
459
+ "init_args": {
460
+ "src_dir": "file:///path/to/world_cover_tifs/"
461
+ }
458
462
  }
459
463
  }
460
464
  },
@@ -516,8 +520,7 @@ model:
516
520
  data:
517
521
  class_path: rslearn.train.data_module.RslearnDataModule
518
522
  init_args:
519
- # Replace this with the dataset path.
520
- path: /path/to/dataset/
523
+ path: ${DATASET_PATH}
521
524
  # This defines the layers that should be read for each window.
522
525
  # The key ("image" / "targets") is what the data will be called in the model,
523
526
  # while the layers option specifies which layers will be read.
@@ -615,7 +618,9 @@ trainer:
615
618
  ...
616
619
  - class_path: rslearn.train.prediction_writer.RslearnWriter
617
620
  init_args:
618
- path: /path/to/dataset/
621
+ # We need to include this argument, but it will be overridden with the dataset
622
+ # path from data.init_args.path.
623
+ path: placeholder
619
624
  output_layer: output
620
625
  ```
621
626
 
@@ -768,24 +773,43 @@ This will produce PNGs in the vis directory. The visualizations are produced by
768
773
  SegmentationTask and overriding the visualize function.
769
774
 
770
775
 
771
- ### Logging to Weights & Biases
776
+ ### Checkpoint and Logging Management
777
+
778
+ Above, we needed to configure the checkpoint directory in the model config (the
779
+ `dirpath` option under `lightning.pytorch.callbacks.ModelCheckpoint`), and explicitly
780
+ specify the checkpoint path when applying the model. Additionally, metrics are logged
781
+ to the local filesystem and not well organized.
772
782
 
773
- We can log to W&B by setting the logger under trainer in the model configuration file:
783
+ We can instead let rslearn automatically manage checkpoints, along with logging to
784
+ Weights & Biases. To do so, we add project_name, run_name, and management_dir options
785
+ to the model config. The project_name corresponds to the W&B project, and the run_name
786
+ corresponds to the W&B run name. The management_dir is a directory to store project data;
787
+ rslearn determines a per-run directory at `{management_dir}/{project_name}/{run_name}/`
788
+ and uses it to store checkpoints.
774
789
 
775
790
  ```yaml
791
+ model:
792
+ # ...
793
+ data:
794
+ # ...
776
795
  trainer:
777
796
  # ...
778
- logger:
779
- class_path: lightning.pytorch.loggers.WandbLogger
780
- init_args:
781
- project: land_cover_model
782
- name: version_00
797
+ project_name: land_cover_model
798
+ run_name: version_00
799
+ # This sets the option via the MANAGEMENT_DIR environment variable.
800
+ management_dir: ${MANAGEMENT_DIR}
783
801
  ```
784
802
 
785
- Now, runs with this model configuration should show on W&B. For `model fit` runs,
786
- the training and validation loss and accuracy metric will be logged. The accuracy
787
- metric is provided by SegmentationTask, and additional metrics can be enabled by
788
- passing the relevant init_args to the task, e.g. mean IoU and F1:
803
+ Now, set the `MANAGEMENT_DIR` environment variable and run `model fit`:
804
+
805
+ ```
806
+ export MANAGEMENT_DIR=./project_data
807
+ rslearn model fit --config land_cover_model.yaml
808
+ ```
809
+
810
+ The training and validation loss and accuracy metric should now be logged to W&B. The
811
+ accuracy metric is provided by SegmentationTask, and additional metrics can be enabled
812
+ by passing the relevant init_args to the task, e.g. mean IoU and F1:
789
813
 
790
814
  ```yaml
791
815
  class_path: rslearn.train.tasks.segmentation.SegmentationTask
@@ -796,6 +820,13 @@ passing the relevant init_args to the task, e.g. mean IoU and F1:
796
820
  enable_f1_metric: true
797
821
  ```
798
822
 
823
+ When calling `model test` and `model predict` with management_dir set, rslearn will
824
+ automatically load the best checkpoint from the project directory, or raise an error if
825
+ no checkpoint exists. This behavior can be overridden with the
826
+ `--load_checkpoint_mode` and `--load_checkpoint_required` options (see `--help` for
827
+ details). Logging will be enabled during fit but not test/predict, and this can also
828
+ be overridden, using `--log_mode`.
829
+
799
830
 
800
831
  ### Inputting Multiple Sentinel-2 Images
801
832
 
@@ -818,10 +849,12 @@ query_config section. This can replace the sentinel2 layer:
818
849
  "bands": ["R", "G", "B"]
819
850
  }],
820
851
  "data_source": {
821
- "name": "rslearn.data_sources.gcp_public_data.Sentinel2",
822
- "index_cache_dir": "cache/sentinel2/",
823
- "sort_by": "cloud_cover",
824
- "use_rtree_index": false,
852
+ "class_path": "rslearn.data_sources.gcp_public_data.Sentinel2",
853
+ "init_args": {
854
+ "index_cache_dir": "cache/sentinel2/",
855
+ "sort_by": "cloud_cover",
856
+ "use_rtree_index": false
857
+ },
825
858
  "query_config": {
826
859
  "max_matches": 3
827
860
  }
@@ -79,10 +79,12 @@ directory `/path/to/dataset` and corresponding configuration file at
79
79
  "bands": ["R", "G", "B"]
80
80
  }],
81
81
  "data_source": {
82
- "name": "rslearn.data_sources.gcp_public_data.Sentinel2",
83
- "index_cache_dir": "cache/sentinel2/",
84
- "sort_by": "cloud_cover",
85
- "use_rtree_index": false
82
+ "class_path": "rslearn.data_sources.gcp_public_data.Sentinel2",
83
+ "init_args": {
84
+ "index_cache_dir": "cache/sentinel2/",
85
+ "sort_by": "cloud_cover",
86
+ "use_rtree_index": false
87
+ }
86
88
  }
87
89
  }
88
90
  }
@@ -189,8 +191,10 @@ automate this process. Update the dataset `config.json` with a new layer:
189
191
  }],
190
192
  "resampling_method": "nearest",
191
193
  "data_source": {
192
- "name": "rslearn.data_sources.local_files.LocalFiles",
193
- "src_dir": "file:///path/to/world_cover_tifs/"
194
+ "class_path": "rslearn.data_sources.local_files.LocalFiles",
195
+ "init_args": {
196
+ "src_dir": "file:///path/to/world_cover_tifs/"
197
+ }
194
198
  }
195
199
  }
196
200
  },
@@ -252,8 +256,7 @@ model:
252
256
  data:
253
257
  class_path: rslearn.train.data_module.RslearnDataModule
254
258
  init_args:
255
- # Replace this with the dataset path.
256
- path: /path/to/dataset/
259
+ path: ${DATASET_PATH}
257
260
  # This defines the layers that should be read for each window.
258
261
  # The key ("image" / "targets") is what the data will be called in the model,
259
262
  # while the layers option specifies which layers will be read.
@@ -351,7 +354,9 @@ trainer:
351
354
  ...
352
355
  - class_path: rslearn.train.prediction_writer.RslearnWriter
353
356
  init_args:
354
- path: /path/to/dataset/
357
+ # We need to include this argument, but it will be overridden with the dataset
358
+ # path from data.init_args.path.
359
+ path: placeholder
355
360
  output_layer: output
356
361
  ```
357
362
 
@@ -504,24 +509,43 @@ This will produce PNGs in the vis directory. The visualizations are produced by
504
509
  SegmentationTask and overriding the visualize function.
505
510
 
506
511
 
507
- ### Logging to Weights & Biases
512
+ ### Checkpoint and Logging Management
513
+
514
+ Above, we needed to configure the checkpoint directory in the model config (the
515
+ `dirpath` option under `lightning.pytorch.callbacks.ModelCheckpoint`), and explicitly
516
+ specify the checkpoint path when applying the model. Additionally, metrics are logged
517
+ to the local filesystem and not well organized.
508
518
 
509
- We can log to W&B by setting the logger under trainer in the model configuration file:
519
+ We can instead let rslearn automatically manage checkpoints, along with logging to
520
+ Weights & Biases. To do so, we add project_name, run_name, and management_dir options
521
+ to the model config. The project_name corresponds to the W&B project, and the run_name
522
+ corresponds to the W&B run name. The management_dir is a directory to store project data;
523
+ rslearn determines a per-run directory at `{management_dir}/{project_name}/{run_name}/`
524
+ and uses it to store checkpoints.
510
525
 
511
526
  ```yaml
527
+ model:
528
+ # ...
529
+ data:
530
+ # ...
512
531
  trainer:
513
532
  # ...
514
- logger:
515
- class_path: lightning.pytorch.loggers.WandbLogger
516
- init_args:
517
- project: land_cover_model
518
- name: version_00
533
+ project_name: land_cover_model
534
+ run_name: version_00
535
+ # This sets the option via the MANAGEMENT_DIR environment variable.
536
+ management_dir: ${MANAGEMENT_DIR}
519
537
  ```
520
538
 
521
- Now, runs with this model configuration should show on W&B. For `model fit` runs,
522
- the training and validation loss and accuracy metric will be logged. The accuracy
523
- metric is provided by SegmentationTask, and additional metrics can be enabled by
524
- passing the relevant init_args to the task, e.g. mean IoU and F1:
539
+ Now, set the `MANAGEMENT_DIR` environment variable and run `model fit`:
540
+
541
+ ```
542
+ export MANAGEMENT_DIR=./project_data
543
+ rslearn model fit --config land_cover_model.yaml
544
+ ```
545
+
546
+ The training and validation loss and accuracy metric should now be logged to W&B. The
547
+ accuracy metric is provided by SegmentationTask, and additional metrics can be enabled
548
+ by passing the relevant init_args to the task, e.g. mean IoU and F1:
525
549
 
526
550
  ```yaml
527
551
  class_path: rslearn.train.tasks.segmentation.SegmentationTask
@@ -532,6 +556,13 @@ passing the relevant init_args to the task, e.g. mean IoU and F1:
532
556
  enable_f1_metric: true
533
557
  ```
534
558
 
559
+ When calling `model test` and `model predict` with management_dir set, rslearn will
560
+ automatically load the best checkpoint from the project directory, or raise an error if
561
+ no checkpoint exists. This behavior can be overridden with the
562
+ `--load_checkpoint_mode` and `--load_checkpoint_required` options (see `--help` for
563
+ details). Logging will be enabled during fit but not test/predict, and this can also
564
+ be overridden, using `--log_mode`.
565
+
535
566
 
536
567
  ### Inputting Multiple Sentinel-2 Images
537
568
 
@@ -554,10 +585,12 @@ query_config section. This can replace the sentinel2 layer:
554
585
  "bands": ["R", "G", "B"]
555
586
  }],
556
587
  "data_source": {
557
- "name": "rslearn.data_sources.gcp_public_data.Sentinel2",
558
- "index_cache_dir": "cache/sentinel2/",
559
- "sort_by": "cloud_cover",
560
- "use_rtree_index": false,
588
+ "class_path": "rslearn.data_sources.gcp_public_data.Sentinel2",
589
+ "init_args": {
590
+ "index_cache_dir": "cache/sentinel2/",
591
+ "sort_by": "cloud_cover",
592
+ "use_rtree_index": false
593
+ },
561
594
  "query_config": {
562
595
  "max_matches": 3
563
596
  }
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "rslearn"
3
- version = "0.0.16"
3
+ version = "0.0.18"
4
4
  description = "A library for developing remote sensing datasets and models"
5
5
  authors = [
6
6
  { name = "OlmoEarth Team" },
@@ -10,6 +10,7 @@ from .dataset import (
10
10
  LayerType,
11
11
  QueryConfig,
12
12
  SpaceMode,
13
+ StorageConfig,
13
14
  TimeMode,
14
15
  )
15
16
 
@@ -23,5 +24,6 @@ __all__ = [
23
24
  "LayerType",
24
25
  "QueryConfig",
25
26
  "SpaceMode",
27
+ "StorageConfig",
26
28
  "TimeMode",
27
29
  ]
@@ -31,6 +31,7 @@ from rslearn.utils.vector_format import VectorFormat
31
31
 
32
32
  if TYPE_CHECKING:
33
33
  from rslearn.data_sources.data_source import DataSource
34
+ from rslearn.dataset.storage.storage import WindowStorageFactory
34
35
 
35
36
  logger = get_logger("__name__")
36
37
 
@@ -132,7 +133,11 @@ class BandSetConfig(BaseModel):
132
133
  bands.
133
134
  """
134
135
 
135
- dtype: DType = Field(description="Pixel value type to store the data under")
136
+ model_config = ConfigDict(extra="forbid")
137
+
138
+ dtype: DType = Field(
139
+ description="Pixel value type to store the data under. This is used during dataset materialize and model predict."
140
+ )
136
141
  bands: list[str] = Field(
137
142
  default_factory=lambda: [],
138
143
  description="List of band names in this BandSetConfig. One of bands or num_bands must be set.",
@@ -244,6 +249,9 @@ class BandSetConfig(BaseModel):
244
249
  "use `{'class_path': '...', 'init_args': {...}}` instead.",
245
250
  DeprecationWarning,
246
251
  )
252
+ logger.warning(
253
+ "BandSet.format uses legacy format; support will be removed after 2026-03-01."
254
+ )
247
255
 
248
256
  legacy_name_to_class_path = {
249
257
  "image_tile": "rslearn.utils.raster_format.ImageTileRasterFormat",
@@ -326,7 +334,7 @@ class TimeMode(StrEnum):
326
334
  class QueryConfig(BaseModel):
327
335
  """A configuration for querying items in a data source."""
328
336
 
329
- model_config = ConfigDict(frozen=True)
337
+ model_config = ConfigDict(frozen=True, extra="forbid")
330
338
 
331
339
  space_mode: SpaceMode = Field(
332
340
  default=SpaceMode.MOSAIC,
@@ -360,7 +368,7 @@ class QueryConfig(BaseModel):
360
368
  class DataSourceConfig(BaseModel):
361
369
  """Configuration for a DataSource in a dataset layer."""
362
370
 
363
- model_config = ConfigDict(frozen=True)
371
+ model_config = ConfigDict(frozen=True, extra="forbid")
364
372
 
365
373
  class_path: str = Field(description="Class path for the data source.")
366
374
  init_args: dict[str, Any] = Field(
@@ -409,6 +417,9 @@ class DataSourceConfig(BaseModel):
409
417
  "use `{'class_path': '...', 'init_args': {...}, ...}` instead.",
410
418
  DeprecationWarning,
411
419
  )
420
+ logger.warning(
421
+ "Data source configuration uses legacy format; support will be removed after 2026-03-01."
422
+ )
412
423
 
413
424
  # Split the dict into the base config that is in the pydantic model, and the
414
425
  # source-specific options that should be moved to init_args dict.
@@ -463,7 +474,7 @@ class CompositingMethod(StrEnum):
463
474
  class LayerConfig(BaseModel):
464
475
  """Configuration of a layer in a dataset."""
465
476
 
466
- model_config = ConfigDict(frozen=True)
477
+ model_config = ConfigDict(frozen=True, extra="forbid")
467
478
 
468
479
  type: LayerType = Field(description="The LayerType (raster or vector).")
469
480
  data_source: DataSourceConfig | None = Field(
@@ -586,11 +597,51 @@ class LayerConfig(BaseModel):
586
597
  return vector_format
587
598
 
588
599
 
600
+ class StorageConfig(BaseModel):
601
+ """Configuration for the WindowStorageFactory (window metadata storage backend)."""
602
+
603
+ model_config = ConfigDict(frozen=True, extra="forbid")
604
+
605
+ class_path: str = Field(
606
+ default="rslearn.dataset.storage.file.FileWindowStorageFactory",
607
+ description="Class path for the WindowStorageFactory.",
608
+ )
609
+ init_args: dict[str, Any] = Field(
610
+ default_factory=lambda: {},
611
+ description="jsonargparse init args for the WindowStorageFactory.",
612
+ )
613
+
614
+ def instantiate_window_storage_factory(self) -> "WindowStorageFactory":
615
+ """Instantiate the WindowStorageFactory specified by this config."""
616
+ from rslearn.dataset.storage.storage import WindowStorageFactory
617
+ from rslearn.utils.jsonargparse import init_jsonargparse
618
+
619
+ init_jsonargparse()
620
+ parser = jsonargparse.ArgumentParser()
621
+ parser.add_argument("--wsf", type=WindowStorageFactory)
622
+ cfg = parser.parse_object(
623
+ {
624
+ "wsf": dict(
625
+ class_path=self.class_path,
626
+ init_args=self.init_args,
627
+ )
628
+ }
629
+ )
630
+ wsf = parser.instantiate_classes(cfg).wsf
631
+ return wsf
632
+
633
+
589
634
  class DatasetConfig(BaseModel):
590
635
  """Overall dataset configuration."""
591
636
 
637
+ model_config = ConfigDict(extra="forbid")
638
+
592
639
  layers: dict[str, LayerConfig] = Field(description="Layers in the dataset.")
593
640
  tile_store: dict[str, Any] = Field(
594
641
  default={"class_path": "rslearn.tile_stores.default.DefaultTileStore"},
595
642
  description="jsonargparse configuration for the TileStore.",
596
643
  )
644
+ storage: StorageConfig = Field(
645
+ default_factory=lambda: StorageConfig(),
646
+ description="jsonargparse configuration for the WindowStorageFactory.",
647
+ )
@@ -131,7 +131,7 @@ def add_windows_from_geometries(
131
131
  f"_{time_range[0].isoformat()}_{time_range[1].isoformat()}"
132
132
  )
133
133
  window = Window(
134
- path=dataset.path / "windows" / group / cur_window_name,
134
+ storage=dataset.storage,
135
135
  group=group,
136
136
  name=cur_window_name,
137
137
  projection=cur_projection,
@@ -1,9 +1,8 @@
1
1
  """rslearn dataset class."""
2
2
 
3
3
  import json
4
- import multiprocessing
4
+ from typing import Any
5
5
 
6
- import tqdm
7
6
  from upath import UPath
8
7
 
9
8
  from rslearn.config import DatasetConfig
@@ -11,7 +10,6 @@ from rslearn.log_utils import get_logger
11
10
  from rslearn.template_params import substitute_env_vars_in_string
12
11
  from rslearn.tile_stores import TileStore, load_tile_store
13
12
 
14
- from .index import DatasetIndex
15
13
  from .window import Window
16
14
 
17
15
  logger = get_logger(__name__)
@@ -68,80 +66,26 @@ class Dataset:
68
66
  self.layers[layer_name] = layer_config
69
67
 
70
68
  self.tile_store_config = config.tile_store
71
-
72
- def _get_index(self) -> DatasetIndex | None:
73
- index_fname = self.path / DatasetIndex.FNAME
74
- if not index_fname.exists():
75
- return None
76
- return DatasetIndex.load_index(self.path)
69
+ self.storage = (
70
+ config.storage.instantiate_window_storage_factory().get_storage(
71
+ self.path
72
+ )
73
+ )
77
74
 
78
75
  def load_windows(
79
76
  self,
80
77
  groups: list[str] | None = None,
81
78
  names: list[str] | None = None,
82
- show_progress: bool = False,
83
- workers: int = 0,
84
- no_index: bool = False,
79
+ **kwargs: Any,
85
80
  ) -> list[Window]:
86
81
  """Load the windows in the dataset.
87
82
 
88
83
  Args:
89
84
  groups: an optional list of groups to filter loading
90
85
  names: an optional list of window names to filter loading
91
- show_progress: whether to show tqdm progress bar
92
- workers: number of parallel workers, default 0 (use main thread only to load windows)
93
- no_index: don't use the dataset index even if it exists.
86
+ kwargs: optional keyword arguments to pass to WindowStorage.get_windows.
94
87
  """
95
- # Load from index if it exists.
96
- # We never use the index if names is set since loading the index will likely be
97
- # slower than loading a few windows.
98
- if not no_index and names is None:
99
- dataset_index = self._get_index()
100
- if dataset_index is not None:
101
- return dataset_index.get_windows(groups=groups, names=names)
102
-
103
- # Avoid directory does not exist errors later.
104
- if not (self.path / "windows").exists():
105
- return []
106
-
107
- window_dirs = []
108
- if not groups:
109
- groups = []
110
- for p in (self.path / "windows").iterdir():
111
- groups.append(p.name)
112
- for group in groups:
113
- group_dir = self.path / "windows" / group
114
- if not group_dir.exists():
115
- logger.warning(
116
- f"Skipping group directory {group_dir} since it does not exist"
117
- )
118
- continue
119
- if names:
120
- cur_names = names
121
- else:
122
- cur_names = []
123
- for p in group_dir.iterdir():
124
- cur_names.append(p.name)
125
-
126
- for window_name in cur_names:
127
- window_dir = group_dir / window_name
128
- window_dirs.append(window_dir)
129
-
130
- if workers == 0:
131
- windows = [Window.load(window_dir) for window_dir in window_dirs]
132
- else:
133
- p = multiprocessing.Pool(workers)
134
- outputs = p.imap_unordered(Window.load, window_dirs)
135
- if show_progress:
136
- outputs = tqdm.tqdm(
137
- outputs, total=len(window_dirs), desc="Loading windows"
138
- )
139
- windows = []
140
- for window in outputs:
141
- windows.append(window)
142
- p.close()
143
-
144
- return windows
88
+ return self.storage.get_windows(groups=groups, names=names, **kwargs)
145
89
 
146
90
  def get_tile_store(self) -> TileStore:
147
91
  """Get the tile store associated with this dataset.
@@ -161,7 +161,7 @@ def build_first_valid_composite(
161
161
  nodata_vals: list[Any],
162
162
  bands: list[str],
163
163
  bounds: PixelBounds,
164
- band_dtype: Any,
164
+ band_dtype: npt.DTypeLike,
165
165
  tile_store: TileStoreWithLayer,
166
166
  projection: Projection,
167
167
  remapper: Remapper | None,
@@ -233,7 +233,7 @@ def read_and_stack_raster_windows(
233
233
  projection: Projection,
234
234
  nodata_vals: list[Any],
235
235
  remapper: Remapper | None,
236
- band_dtype: Any,
236
+ band_dtype: npt.DTypeLike,
237
237
  resampling_method: Resampling = Resampling.bilinear,
238
238
  ) -> npt.NDArray[np.generic]:
239
239
  """Create a stack of extent aligned raster windows.
@@ -326,7 +326,7 @@ def build_mean_composite(
326
326
  nodata_vals: list[Any],
327
327
  bands: list[str],
328
328
  bounds: PixelBounds,
329
- band_dtype: Any,
329
+ band_dtype: npt.DTypeLike,
330
330
  tile_store: TileStoreWithLayer,
331
331
  projection: Projection,
332
332
  remapper: Remapper | None,
@@ -383,7 +383,7 @@ def build_median_composite(
383
383
  nodata_vals: list[Any],
384
384
  bands: list[str],
385
385
  bounds: PixelBounds,
386
- band_dtype: Any,
386
+ band_dtype: npt.DTypeLike,
387
387
  tile_store: TileStoreWithLayer,
388
388
  projection: Projection,
389
389
  remapper: Remapper | None,
@@ -471,7 +471,7 @@ def build_composite(
471
471
  nodata_vals=nodata_vals,
472
472
  bands=band_cfg.bands,
473
473
  bounds=bounds,
474
- band_dtype=band_cfg.dtype.value,
474
+ band_dtype=band_cfg.dtype.get_numpy_dtype(),
475
475
  tile_store=tile_store,
476
476
  projection=projection,
477
477
  resampling_method=layer_cfg.resampling_method.get_rasterio_resampling(),
@@ -0,0 +1 @@
1
+ """Storage backends for rslearn window metadata."""