rslearn 0.0.16__py3-none-any.whl → 0.0.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. rslearn/config/__init__.py +2 -0
  2. rslearn/config/dataset.py +55 -4
  3. rslearn/dataset/add_windows.py +1 -1
  4. rslearn/dataset/dataset.py +9 -65
  5. rslearn/dataset/materialize.py +5 -5
  6. rslearn/dataset/storage/__init__.py +1 -0
  7. rslearn/dataset/storage/file.py +202 -0
  8. rslearn/dataset/storage/storage.py +140 -0
  9. rslearn/dataset/window.py +26 -80
  10. rslearn/lightning_cli.py +10 -3
  11. rslearn/main.py +11 -36
  12. rslearn/models/anysat.py +11 -9
  13. rslearn/models/clay/clay.py +8 -9
  14. rslearn/models/clip.py +18 -15
  15. rslearn/models/component.py +99 -0
  16. rslearn/models/concatenate_features.py +21 -11
  17. rslearn/models/conv.py +15 -8
  18. rslearn/models/croma.py +13 -8
  19. rslearn/models/detr/detr.py +25 -14
  20. rslearn/models/dinov3.py +11 -6
  21. rslearn/models/faster_rcnn.py +19 -9
  22. rslearn/models/feature_center_crop.py +12 -9
  23. rslearn/models/fpn.py +19 -8
  24. rslearn/models/galileo/galileo.py +23 -18
  25. rslearn/models/module_wrapper.py +26 -57
  26. rslearn/models/molmo.py +16 -14
  27. rslearn/models/multitask.py +102 -73
  28. rslearn/models/olmoearth_pretrain/model.py +20 -17
  29. rslearn/models/panopticon.py +8 -7
  30. rslearn/models/pick_features.py +18 -24
  31. rslearn/models/pooling_decoder.py +22 -14
  32. rslearn/models/presto/presto.py +16 -10
  33. rslearn/models/presto/single_file_presto.py +4 -10
  34. rslearn/models/prithvi.py +12 -8
  35. rslearn/models/resize_features.py +21 -7
  36. rslearn/models/sam2_enc.py +11 -9
  37. rslearn/models/satlaspretrain.py +15 -9
  38. rslearn/models/simple_time_series.py +31 -17
  39. rslearn/models/singletask.py +24 -17
  40. rslearn/models/ssl4eo_s12.py +15 -10
  41. rslearn/models/swin.py +22 -13
  42. rslearn/models/terramind.py +24 -7
  43. rslearn/models/trunk.py +6 -3
  44. rslearn/models/unet.py +18 -9
  45. rslearn/models/upsample.py +22 -9
  46. rslearn/train/all_patches_dataset.py +22 -18
  47. rslearn/train/dataset.py +69 -54
  48. rslearn/train/lightning_module.py +51 -32
  49. rslearn/train/model_context.py +54 -0
  50. rslearn/train/prediction_writer.py +111 -41
  51. rslearn/train/tasks/classification.py +34 -15
  52. rslearn/train/tasks/detection.py +24 -31
  53. rslearn/train/tasks/embedding.py +33 -29
  54. rslearn/train/tasks/multi_task.py +7 -7
  55. rslearn/train/tasks/per_pixel_regression.py +41 -19
  56. rslearn/train/tasks/regression.py +38 -21
  57. rslearn/train/tasks/segmentation.py +33 -15
  58. rslearn/train/tasks/task.py +3 -2
  59. {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/METADATA +58 -25
  60. {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/RECORD +65 -62
  61. rslearn/dataset/index.py +0 -173
  62. rslearn/models/registry.py +0 -22
  63. {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/WHEEL +0 -0
  64. {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/entry_points.txt +0 -0
  65. {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/licenses/LICENSE +0 -0
  66. {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/licenses/NOTICE +0 -0
  67. {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rslearn
3
- Version: 0.0.16
3
+ Version: 0.0.18
4
4
  Summary: A library for developing remote sensing datasets and models
5
5
  Author: OlmoEarth Team
6
6
  License: Apache License
@@ -343,10 +343,12 @@ directory `/path/to/dataset` and corresponding configuration file at
343
343
  "bands": ["R", "G", "B"]
344
344
  }],
345
345
  "data_source": {
346
- "name": "rslearn.data_sources.gcp_public_data.Sentinel2",
347
- "index_cache_dir": "cache/sentinel2/",
348
- "sort_by": "cloud_cover",
349
- "use_rtree_index": false
346
+ "class_path": "rslearn.data_sources.gcp_public_data.Sentinel2",
347
+ "init_args": {
348
+ "index_cache_dir": "cache/sentinel2/",
349
+ "sort_by": "cloud_cover",
350
+ "use_rtree_index": false
351
+ }
350
352
  }
351
353
  }
352
354
  }
@@ -453,8 +455,10 @@ automate this process. Update the dataset `config.json` with a new layer:
453
455
  }],
454
456
  "resampling_method": "nearest",
455
457
  "data_source": {
456
- "name": "rslearn.data_sources.local_files.LocalFiles",
457
- "src_dir": "file:///path/to/world_cover_tifs/"
458
+ "class_path": "rslearn.data_sources.local_files.LocalFiles",
459
+ "init_args": {
460
+ "src_dir": "file:///path/to/world_cover_tifs/"
461
+ }
458
462
  }
459
463
  }
460
464
  },
@@ -516,8 +520,7 @@ model:
516
520
  data:
517
521
  class_path: rslearn.train.data_module.RslearnDataModule
518
522
  init_args:
519
- # Replace this with the dataset path.
520
- path: /path/to/dataset/
523
+ path: ${DATASET_PATH}
521
524
  # This defines the layers that should be read for each window.
522
525
  # The key ("image" / "targets") is what the data will be called in the model,
523
526
  # while the layers option specifies which layers will be read.
@@ -615,7 +618,9 @@ trainer:
615
618
  ...
616
619
  - class_path: rslearn.train.prediction_writer.RslearnWriter
617
620
  init_args:
618
- path: /path/to/dataset/
621
+ # We need to include this argument, but it will be overridden with the dataset
622
+ # path from data.init_args.path.
623
+ path: placeholder
619
624
  output_layer: output
620
625
  ```
621
626
 
@@ -768,24 +773,43 @@ This will produce PNGs in the vis directory. The visualizations are produced by
768
773
  SegmentationTask and overriding the visualize function.
769
774
 
770
775
 
771
- ### Logging to Weights & Biases
776
+ ### Checkpoint and Logging Management
777
+
778
+ Above, we needed to configure the checkpoint directory in the model config (the
779
+ `dirpath` option under `lightning.pytorch.callbacks.ModelCheckpoint`), and explicitly
780
+ specify the checkpoint path when applying the model. Additionally, metrics are logged
781
+ to the local filesystem and not well organized.
772
782
 
773
- We can log to W&B by setting the logger under trainer in the model configuration file:
783
+ We can instead let rslearn automatically manage checkpoints, along with logging to
784
+ Weights & Biases. To do so, we add project_name, run_name, and management_dir options
785
+ to the model config. The project_name corresponds to the W&B project, and the run name
786
+ corresponds to the W&B name. The management_dir is a directory to store project data;
787
+ rslearn determines a per-project directory at `{management_dir}/{project_name}/{run_name}/`
788
+ and uses it to store checkpoints.
774
789
 
775
790
  ```yaml
791
+ model:
792
+ # ...
793
+ data:
794
+ # ...
776
795
  trainer:
777
796
  # ...
778
- logger:
779
- class_path: lightning.pytorch.loggers.WandbLogger
780
- init_args:
781
- project: land_cover_model
782
- name: version_00
797
+ project_name: land_cover_model
798
+ run_name: version_00
799
+ # This sets the option via the MANAGEMENT_DIR environment variable.
800
+ management_dir: ${MANAGEMENT_DIR}
783
801
  ```
784
802
 
785
- Now, runs with this model configuration should show on W&B. For `model fit` runs,
786
- the training and validation loss and accuracy metric will be logged. The accuracy
787
- metric is provided by SegmentationTask, and additional metrics can be enabled by
788
- passing the relevant init_args to the task, e.g. mean IoU and F1:
803
+ Now, set the `MANAGEMENT_DIR` environment variable and run `model fit`:
804
+
805
+ ```
806
+ export MANAGEMENT_DIR=./project_data
807
+ rslearn model fit --config land_cover_model.yaml
808
+ ```
809
+
810
+ The training and validation loss and accuracy metric should now be logged to W&B. The
811
+ accuracy metric is provided by SegmentationTask, and additional metrics can be enabled
812
+ by passing the relevant init_args to the task, e.g. mean IoU and F1:
789
813
 
790
814
  ```yaml
791
815
  class_path: rslearn.train.tasks.segmentation.SegmentationTask
@@ -796,6 +820,13 @@ passing the relevant init_args to the task, e.g. mean IoU and F1:
796
820
  enable_f1_metric: true
797
821
  ```
798
822
 
823
+ When calling `model test` and `model predict` with management_dir set, rslearn will
824
+ automatically load the best checkpoint from the project directory, or raise an error if
825
+ no existing checkpoint exists. This behavior can be overridden with the
826
+ `--load_checkpoint_mode` and `--load_checkpoint_required` options (see `--help` for
827
+ details). Logging will be enabled during fit but not test/predict, and this can also
828
+ be overridden, using `--log_mode`.
829
+
799
830
 
800
831
  ### Inputting Multiple Sentinel-2 Images
801
832
 
@@ -818,10 +849,12 @@ query_config section. This can replace the sentinel2 layer:
818
849
  "bands": ["R", "G", "B"]
819
850
  }],
820
851
  "data_source": {
821
- "name": "rslearn.data_sources.gcp_public_data.Sentinel2",
822
- "index_cache_dir": "cache/sentinel2/",
823
- "sort_by": "cloud_cover",
824
- "use_rtree_index": false,
852
+ "class_path": "rslearn.data_sources.gcp_public_data.Sentinel2",
853
+ "init_args": {
854
+ "index_cache_dir": "cache/sentinel2/",
855
+ "sort_by": "cloud_cover",
856
+ "use_rtree_index": false
857
+ },
825
858
  "query_config": {
826
859
  "max_matches": 3
827
860
  }
@@ -1,13 +1,13 @@
1
1
  rslearn/__init__.py,sha256=fFmAen3vxZyosEfPbG0W46IttujYGVxzrGkJ0YutmmY,73
2
2
  rslearn/arg_parser.py,sha256=GNlJncO6Ck_dCNrcg7z_SSG61I-2gKn3Ix2tAxIk9CI,1428
3
3
  rslearn/const.py,sha256=FUCfsvFAs-QarEDJ0grdy0C1HjUjLpNFYGo5I2Vpc5Y,449
4
- rslearn/lightning_cli.py,sha256=x8i2QJvEBaYdqh2_f0-ety7_sNEH9UCKRZUPkqWYZdU,17169
4
+ rslearn/lightning_cli.py,sha256=Cihdf3dOQ17b_n4432Y6LmCQ5XFDghW4rGb4fqw-b6g,17525
5
5
  rslearn/log_utils.py,sha256=unD9gShiuO7cx5Nnq8qqVQ4qrbOOwFVgcHxN5bXuiAo,941
6
- rslearn/main.py,sha256=0g1SRO975eC9DTzKqJnwlWHgVo2Pvotyr72KoJBgjew,29060
6
+ rslearn/main.py,sha256=CLVLpJJXc8f8BCG9SBdOuw_taADu3majD1TNSOVC6Ws,28565
7
7
  rslearn/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
8
  rslearn/template_params.py,sha256=Vop0Ha-S44ctCa9lvSZRjrMETznJZlR5y_gJrVIwrPg,791
9
- rslearn/config/__init__.py,sha256=a8xTvYSnpfIzniHgcnSeob5jo5PVBfacpakA_150MME,434
10
- rslearn/config/dataset.py,sha256=iUFuwzlM9z6n1pGCd40SmAVj3fG6zTXWrlH6eenfon8,21143
9
+ rslearn/config/__init__.py,sha256=n1qpZ0ImshTtLYl5mC73BORYyUcjPJyHiyZkqUY1hiY,474
10
+ rslearn/config/dataset.py,sha256=fp7HT8FeKyOmkR77pR5tVZYAAnuetCLY6zGFVcdPHyY,23087
11
11
  rslearn/data_sources/__init__.py,sha256=zzuZUxrlEIw84YpD2I0HJvCoLDB29LbmnKTXiJykzGU,660
12
12
  rslearn/data_sources/aws_landsat.py,sha256=0ZQtmd2NCnvLy4vFSB1AlmoguJbiQB_e_T4eS1tnW9Q,20443
13
13
  rslearn/data_sources/aws_open_data.py,sha256=lrHnMJTH3NAaRdNjxwCIxSq8rq90IvV4ho-qAG6Hdgc,29348
@@ -34,59 +34,61 @@ rslearn/data_sources/worldcover.py,sha256=n7gi-JRytxkvkUhKT--dVziMcWSSyMbZA7ZCzL
34
34
  rslearn/data_sources/worldpop.py,sha256=S3RSc5kTwSs2bcREjNarBsqf3MBX5CN0eHj7Qkx4K74,5625
35
35
  rslearn/data_sources/xyz_tiles.py,sha256=P601CvUmoVDC_ZRVhPKaIoPYYCq5UhZ-v7DaGlN5y_0,13797
36
36
  rslearn/dataset/__init__.py,sha256=bHtBlEEBCekO-gaJqiww0-VjvZTE5ahx0llleo8bfP8,289
37
- rslearn/dataset/add_windows.py,sha256=pwCEvwLE1jQCoqQxw6CJ-sP46ayWppFa2hGYIB6VVkc,8494
38
- rslearn/dataset/dataset.py,sha256=qmZmFfQOHoKVx6_sBYtBR5H4GTNMgETvq0S4XrqafQU,5165
37
+ rslearn/dataset/add_windows.py,sha256=NwIvku6zxCJ9kgVFa5phJc0Gj1Y1bCzh6TLb9nEGl0s,8462
38
+ rslearn/dataset/dataset.py,sha256=A-iXXdGvTCus1kCKC58dk1mEidMe29X3Lox9JEto_TQ,3112
39
39
  rslearn/dataset/handler_summaries.py,sha256=wI99RDk5erCWkzl1A7Uc4chatQ9KWIr4F_0Hxr9Co6s,2607
40
- rslearn/dataset/index.py,sha256=Wni5m6h4gisRB54fPLnCfUrRTEsJ5EvwS0fs9sYc2wg,6025
41
40
  rslearn/dataset/manage.py,sha256=-lGSIgk6Z7-verF_POwe4n5w9eSpgyt4nEOcOj382rc,18971
42
- rslearn/dataset/materialize.py,sha256=x7FewmdqFUviLtPZGZIfAcw1rd0wfKZhk_N-uN-tQms,20922
41
+ rslearn/dataset/materialize.py,sha256=o05OeLk_wWEOsw15oc5yjpD4J-twGCTfXAtxyAQsQ9I,20974
43
42
  rslearn/dataset/remap.py,sha256=6MaImsY02GNACpvRM81RvWmjZWRfAHxo_R3Ox6XLF6A,2723
44
- rslearn/dataset/window.py,sha256=I5RqZ12jlIXhohw4qews1x_I4tSDpml709DZRtLiN24,12546
43
+ rslearn/dataset/window.py,sha256=X4q8YzcSOTtwKxCPf71QLMoyKUtYMSnZu0kPnmVSUx4,10644
44
+ rslearn/dataset/storage/__init__.py,sha256=R50AVV5LH2g7ol0-jyvGcB390VsclXGbJXz4fmkn9as,52
45
+ rslearn/dataset/storage/file.py,sha256=g9HZ3CD4QcgyVNsBaXhjIKQgDOAeZ4R08sJ7ntx4wo8,6815
46
+ rslearn/dataset/storage/storage.py,sha256=DxZ7iwV938PiLwdQzb5EXSb4Mj8bRGmOTmA9fzq_Ge8,4840
45
47
  rslearn/models/__init__.py,sha256=_vWoF9d2Slah8-6XhYhdU4SRsy_CNxXjCGQTD2yvu3Q,22
46
- rslearn/models/anysat.py,sha256=3Oh2gWxicVdUzOjevBEZf0PuolmCy0KC5Ad7JY-0Plc,7949
47
- rslearn/models/clip.py,sha256=u5aqYnVB4Jag7o1h8EzPDAc1t2BAHeALA9FcUwP5tfo,2238
48
- rslearn/models/concatenate_features.py,sha256=qQPwKF-wz18hC1EA1OS81C6RPjYggPrZ5grCuHIM4aY,3434
49
- rslearn/models/conv.py,sha256=fWyByeswIOKKzyPmP3erYUlZaKEV0huWHA4CyKTBbfY,1703
50
- rslearn/models/croma.py,sha256=n7yunpT7lo8vWWaOpx4yt8jZSXjgWqfgZcZKFW5zuEQ,10591
51
- rslearn/models/dinov3.py,sha256=9k9kNlXCorQQwKjLGptooANd48TUBsITQ1e4fUomlM4,6337
52
- rslearn/models/faster_rcnn.py,sha256=uaxX6-E1f0BibaA9sorEg3be83C7kTdTc39pC5jRqwE,8286
53
- rslearn/models/feature_center_crop.py,sha256=24eOrvLEGGVWPw7kPHyUes5HtYNAX7GZ_NpqDGMILEY,1553
54
- rslearn/models/fpn.py,sha256=s3cz29I14FaSuvBvLOcwCrqVsaRBxG5GjLlqap4WgPc,1603
55
- rslearn/models/module_wrapper.py,sha256=H2zb-8Au4t31kawW_4JEKHsaXFjpYDawb31ZEauKcxU,2728
56
- rslearn/models/molmo.py,sha256=mVrARBhZciMzOgOOjGB5AHlPIf2iO9IBSJmdyKSl1L8,2061
57
- rslearn/models/multitask.py,sha256=j2Kiwj_dUiUp_CIUr25bS8HiyeoFlr1PGqjTfpgIGLc,14672
58
- rslearn/models/panopticon.py,sha256=woNEs53wVc5D-NxbSDEPRZ_mYe8vllnuldmADjvhfDQ,5806
59
- rslearn/models/pick_features.py,sha256=y8e4tJFhyG7ZuVSElWhQ5-Aer4ZKJCEH9wLGJU7WqGI,1551
60
- rslearn/models/pooling_decoder.py,sha256=unr2fSE_QmJHPi3dKtopqMtb1Kn-2h94LgwwAVP9vZg,4437
61
- rslearn/models/prithvi.py,sha256=AIzcO5xk1ggR0MjbfhIzqPVgUKFN7odxygmgyAelfW8,40143
62
- rslearn/models/registry.py,sha256=yCcrOvLkbn07Xtln1j7hAB_kmGw0MGsiR2TloJq9Bmk,504
63
- rslearn/models/resize_features.py,sha256=asKXWrLHIBrU6GaAV0Ory9YuK7IK104XjhkB4ljzI3A,1289
64
- rslearn/models/sam2_enc.py,sha256=gNlPokr7eNxO2KvnzDMXNxYM2WRO0YkQPjR4110n6cw,3508
65
- rslearn/models/satlaspretrain.py,sha256=b6FR_il6MnWU4UpB9OxInZSK9n0IS0PcQuLrWH4YD8g,3046
66
- rslearn/models/simple_time_series.py,sha256=oTg_akabYFBExJu7JCpbuM211-ZgQS4WerG2nEYrIZY,12774
67
- rslearn/models/singletask.py,sha256=z4vN9Yvzz0I-U4KJdVZxLJK2ZV-MIv9tzwCGcOWoUPY,1604
68
- rslearn/models/ssl4eo_s12.py,sha256=sOGEHcDo-rNdmEyoLu2AVEqfxRM_cv6zpfAmyn5c6tw,3553
69
- rslearn/models/swin.py,sha256=bMlGePXMFou4A_YSUZzjHgN9NniGXaCWdGQ31xHDKis,5511
48
+ rslearn/models/anysat.py,sha256=-1uE2kSfR34lPld5AXEuU8KCfVIa3YAPqs-filxNiWY,8026
49
+ rslearn/models/clip.py,sha256=TgCPA7IEsnONFbWQxQlLguvAXzwp_Y90-Vk9MZnwXak,2337
50
+ rslearn/models/component.py,sha256=1vRK9K7hEgm0eyRhyPyLX4SBWZxiCXzO0fb_bwPQjQg,3015
51
+ rslearn/models/concatenate_features.py,sha256=Attemr5KurxlOojpclD0Pd5Cu2KHpNdpXe8jCSjpJ9U,3818
52
+ rslearn/models/conv.py,sha256=dEAAfhPo4bFlZPSAQjzqZTpP-hdJ394TytYssVK-fDA,2001
53
+ rslearn/models/croma.py,sha256=x8bTFOJB-yR9PydSEjCeV9WnzpalT9QLtozlICNDhyE,10820
54
+ rslearn/models/dinov3.py,sha256=zlICRIeOaxREpR75bzF4nS8P-ZUvLdrJLBU3cBTYgec,6458
55
+ rslearn/models/faster_rcnn.py,sha256=yYRk3attz_GyhJA6jE1ss4ybT_knbLNT1lMRrkz22PI,8614
56
+ rslearn/models/feature_center_crop.py,sha256=_Mu3E4iJLBug9I4ZIBIpB_VJo-xGterHmhtIFGaHR34,1808
57
+ rslearn/models/fpn.py,sha256=qm7nKMgsZrCoAdz8ASmNKU2nvZ6USm5CedMfy_w_gwE,2079
58
+ rslearn/models/module_wrapper.py,sha256=XjRgmhss9_beBEE77t9iySsIvYpV7R-DeyBlOBG000I,2004
59
+ rslearn/models/molmo.py,sha256=pPtCy7eg-xN-iRKCtNL_hpxagfJRyYc9o3NVb-ifonI,2182
60
+ rslearn/models/multitask.py,sha256=bpFxvtFowRyT-tvRSdY7AKbEx_i1y7sToEzZgTMcF4s,16264
61
+ rslearn/models/panopticon.py,sha256=4-KHbgAHUk-ZX04PFRIJF_BRO_euOZO0NfzjB3CEP7Y,5891
62
+ rslearn/models/pick_features.py,sha256=fI9SYubqpCWOAHYGVUSg5sgD31dsnAR9mNuLmqfIeL8,1110
63
+ rslearn/models/pooling_decoder.py,sha256=zrMH6wUExCa-XD1q9CIFD2ScgiasapyJs9plhcUxhIs,4767
64
+ rslearn/models/prithvi.py,sha256=SzdWm0CjYCZIviES3yBXec6TmvrWGQ6pf37w4BaODlY,40326
65
+ rslearn/models/resize_features.py,sha256=U7ZIVwwToJJnwchFG59wLWWP9eikHDB_1c4OtpubxHU,1693
66
+ rslearn/models/sam2_enc.py,sha256=iq2JBqloJ7xkg7Qz5M1HymRfEkLGGn2msF00R_nGBgM,3593
67
+ rslearn/models/satlaspretrain.py,sha256=ytlke5IlxlUEhUoIRcm0pz7umppppZBskCux9YE8mVg,3281
68
+ rslearn/models/simple_time_series.py,sha256=bmy3SsFocpbQwTk4s9gBMAYvz89IlUL_2pAFMZLew_8,13402
69
+ rslearn/models/singletask.py,sha256=9DM9a9-Mv3vVQqRhPOIXG2HHuVqVa_zuvgafeeYh4r0,1903
70
+ rslearn/models/ssl4eo_s12.py,sha256=FTDSkHCaQxw0B-R-8NIe50lZRBYFdI-v68lk9NrnwII,3737
71
+ rslearn/models/swin.py,sha256=QKiV00bu6dw9p0a9PbqDOe5zsnaFHuu6dOfEn41ms7I,6017
70
72
  rslearn/models/task_embedding.py,sha256=Z6sf61BLCtvdrdnvjh8500b-KiFp3GeWbT4mOqpaCKk,9100
71
- rslearn/models/terramind.py,sha256=5POVk_y29LlbVswa6ojd9gdB70iO41yB9Y2aqVY4WdQ,8327
72
- rslearn/models/trunk.py,sha256=H1QPQGAKsmocq3OiF66GW8MQI4LffupTDrgZR4Ta7QM,4708
73
- rslearn/models/unet.py,sha256=WUgLgvvlgV8l_6MIDBl6aX1HNOkb24DfnVRIyYXHCjo,6865
74
- rslearn/models/upsample.py,sha256=3kWbyWZIk56JJxj8en9pieitbrk3XnbIsTKlEkiDQQY,938
73
+ rslearn/models/terramind.py,sha256=R_daJOIzCnaC_Z3s5n9YtBa9Gf_HgQatgCOwwy4MtSA,8732
74
+ rslearn/models/trunk.py,sha256=1GCH9iyLIytoHVntLSMwfH9duQpe1W4DPmOClLpPKjc,4778
75
+ rslearn/models/unet.py,sha256=HuuINvkB1-5w9ZOTXZCWkVxJShruPKCol8pKeA3zw_4,7251
76
+ rslearn/models/upsample.py,sha256=JvfnktT6Dgcql3cSoySWXZ7dmkDkfpRo6vDkpz8KFAQ,1326
75
77
  rslearn/models/use_croma.py,sha256=OSBqMuLp-pDtqPNWAVBfmX4wckmyYCKtUDdGCjJk_K8,17966
76
- rslearn/models/clay/clay.py,sha256=29CGCOysx9duEX4Y6LUNHXck_sHjCFrlV4w8CP_hKmI,8460
78
+ rslearn/models/clay/clay.py,sha256=q2vcqyRFhCJvzLHwHLxHfWGyvVt819Xmk7TzuhXdPFI,8478
77
79
  rslearn/models/clay/configs/metadata.yaml,sha256=rZTFh4Yb9htEfbQNOPl4HTbFogEhzwIRqFzG-1uT01Y,4652
78
80
  rslearn/models/detr/__init__.py,sha256=GGAnTIhyuvl34IRrJ_4gXjm_01OlM5rbQQ3c3TGfbK8,84
79
81
  rslearn/models/detr/box_ops.py,sha256=ORCF6EwMpMBB_VgQT05SjR47dCR2rN2gPhL_gsuUWJs,3236
80
- rslearn/models/detr/detr.py,sha256=otLmmyUm05e4MUyvQBoqo-RKnx3hbodTXvfPQWvuTEI,18737
82
+ rslearn/models/detr/detr.py,sha256=ZiYNwJ3zWqZdvvzc0CXDzdvcwSCO_wSOCWlChemTsX8,19178
81
83
  rslearn/models/detr/matcher.py,sha256=4h_xFlgTMEJvJ6aLZUamrKZ72L5hDk9wPglNZ81JBg8,4533
82
84
  rslearn/models/detr/position_encoding.py,sha256=8FFoBT-Jtgqk7D4qDBTbVLOeAdmjdjtJTC608TaX6yY,3869
83
85
  rslearn/models/detr/transformer.py,sha256=aK4HO7AkCZn7xGHP3Iq91w2iFPVshugOILYAjVjroCw,13971
84
86
  rslearn/models/detr/util.py,sha256=NMHhHbkIo7PoBUVbDqa2ZknJBTswmaxFCGHrPtFKnGg,676
85
87
  rslearn/models/galileo/__init__.py,sha256=QQa0C29nuPRva0KtGiMHQ2ZB02n9SSwj_wqTKPz18NM,112
86
- rslearn/models/galileo/galileo.py,sha256=jUHA64YvVC3Fz5fevc_9dFJfZaINODRDrhSGLIiOZcw,21115
88
+ rslearn/models/galileo/galileo.py,sha256=PR90sPKxBfrpYolAfDf-AyBkNVARr6UbQmYc2XzcNPc,21365
87
89
  rslearn/models/galileo/single_file_galileo.py,sha256=l5tlmmdr2eieHNH-M7rVIvcptkv0Fuk3vKXFW691ezA,56143
88
90
  rslearn/models/olmoearth_pretrain/__init__.py,sha256=AjRvbjBdadCdPh-EdvySH76sVAQ8NGQaJt11Tsn1D5I,36
89
- rslearn/models/olmoearth_pretrain/model.py,sha256=ZJgoyy7vwB0PUMJtHF-sdJ-uSBqnUXMDBco0Dx4cAes,10670
91
+ rslearn/models/olmoearth_pretrain/model.py,sha256=mGGoe45WEhFUJMQBYgKyJ_peDtsXKxNrgiYLz9u8I9o,10882
90
92
  rslearn/models/olmoearth_pretrain/norm.py,sha256=rHjFyWkpNLYMx9Ow7TsU-jGm9Sjx7FVf0p4R__ohx2c,3266
91
93
  rslearn/models/panopticon_data/sensors/drone.yaml,sha256=xqWS-_QMtJyRoWXJm-igoSur9hAmCFdqkPin8DT5qpw,431
92
94
  rslearn/models/panopticon_data/sensors/enmap.yaml,sha256=b2j6bSgYR2yKR9DRm3SPIzSVYlHf51ny_p-1B4B9sB4,13431
@@ -101,18 +103,19 @@ rslearn/models/panopticon_data/sensors/sentinel2.yaml,sha256=qYJ92x-GHO0ZdCrTtCj
101
103
  rslearn/models/panopticon_data/sensors/superdove.yaml,sha256=QpIRyopdV4hAez_EIsDwhGFT4VtTk7UgzQveyc8t8fc,795
102
104
  rslearn/models/panopticon_data/sensors/wv23.yaml,sha256=SWYSlkka6UViKAz6YI8aqwQ-Ayo-S5kmNa9rO3iGW6o,1172
103
105
  rslearn/models/presto/__init__.py,sha256=eZrB-XKi_vYqZhpyAOwppJi4dRuMtYVAdbq7KRygze0,64
104
- rslearn/models/presto/presto.py,sha256=8mZnc0jk_r_JikybHQNyyHg6t7JNPmoPmgoivyNf-U8,9177
105
- rslearn/models/presto/single_file_presto.py,sha256=Kbwp8V7pO8HHM2vlCPpjekQiFiDryW8zQkWmt1g05BY,30381
106
+ rslearn/models/presto/presto.py,sha256=SQCxY-jibe_PQs5yvoEOzjUhv8jaCoEkDE0pvSVhEUY,9542
107
+ rslearn/models/presto/single_file_presto.py,sha256=-P00xjhj9dx3O6HqWpQmG9dPk_i6bT_t8vhX4uQm5tA,30242
106
108
  rslearn/tile_stores/__init__.py,sha256=-cW1J7So60SEP5ZLHCPdaFBV5CxvV3QlOhaFnUkhTJ0,1675
107
109
  rslearn/tile_stores/default.py,sha256=PYaDNvBxhJTDKJGw0EjDTSE1OKajR7_iJpMbOjj-mE8,15054
108
110
  rslearn/tile_stores/tile_store.py,sha256=9AeYduDYPp_Ia2NMlq6osptpz_AFGIOQcLJrqZ_m-z0,10469
109
111
  rslearn/train/__init__.py,sha256=fnJyY4aHs5zQqbDKSfXsJZXY_M9fbTsf7dRYaPwZr2M,30
110
- rslearn/train/all_patches_dataset.py,sha256=xFJ96HU3CodrUBzXTsgrmEShosKH79T2SxI0xDVSH3Q,18217
112
+ rslearn/train/all_patches_dataset.py,sha256=qgx1tHZOGOFxUB-HQ7jDxJXE_cvySQfDE6cMr60VS7s,18206
111
113
  rslearn/train/data_module.py,sha256=pgut8rEWHIieZ7RR8dUvhtlNqk0egEdznYF3tCvqdHg,23552
112
- rslearn/train/dataset.py,sha256=OLPBtVf7tmHodoMnB_gI-jLQq2xQ9aXz38Hq8kBgbp0,33944
113
- rslearn/train/lightning_module.py,sha256=ZLBiId3secUlVs2yzkN-mwVv4rMdh5TkdZYl4vv_Cw0,14466
114
+ rslearn/train/dataset.py,sha256=bsgpUoUKv9AeXzCzBqumiq8b-NZ5PjkFTvLNG82Vx2Q,34179
115
+ rslearn/train/lightning_module.py,sha256=HA9e-74oUZR5s7piQP9Mwxwz0vw0-p4HdmijfgBSpU0,14776
116
+ rslearn/train/model_context.py,sha256=TJHCw8xXLZrRHqGzJr30QARATgG_pJLbMBmyiIcZXRM,1449
114
117
  rslearn/train/optimizer.py,sha256=EKSqkmERalDA0bF32Gey7n6z69KLyaUWKlRsGJfKBmE,927
115
- rslearn/train/prediction_writer.py,sha256=D2CCLlwlElMoMxnPiI6B9Q9HafGspuwoqYD8TKq98pk,13173
118
+ rslearn/train/prediction_writer.py,sha256=rW0BUaYT_F1QqmpnQlbrLiLya1iBfC5Pb78G_NlF-vA,15956
116
119
  rslearn/train/scheduler.py,sha256=wFbmycMHgL6nRYeYalDjb0G8YVo8VD3T3sABS61jJ7c,2318
117
120
  rslearn/train/callbacks/__init__.py,sha256=VNV0ArZyYMvl3dGK2wl6F046khYJ1dEBlJS6G_SYNm0,47
118
121
  rslearn/train/callbacks/adapters.py,sha256=yfv8nyCj3jmo2_dNkFrjukKxh0MHsf2xKqWwMF0QUtY,1869
@@ -120,14 +123,14 @@ rslearn/train/callbacks/freeze_unfreeze.py,sha256=8fIzBMhCKKjpTffIeAdhdSjsBd8NjT
120
123
  rslearn/train/callbacks/gradients.py,sha256=4YqCf0tBb6E5FnyFYbveXfQFlgNPyxIXb2FCWX4-6qs,5075
121
124
  rslearn/train/callbacks/peft.py,sha256=wEOKsS3RhsRaZTXn_Kz2wdsZdIiIaZPdCJWtdJBurT8,4156
122
125
  rslearn/train/tasks/__init__.py,sha256=dag1u72x1-me6y0YcOubUo5MYZ0Tjf6-dOir9UeFNMs,75
123
- rslearn/train/tasks/classification.py,sha256=8nSv0caf2PzV3Pmme_iN4WQIac4ry3hdW6FRHbh4L1M,13152
124
- rslearn/train/tasks/detection.py,sha256=9j9webusrjGexvUmZ7gl3NTBS63Qq511VFlB2WbLi5Y,22302
125
- rslearn/train/tasks/embedding.py,sha256=DK3l1aQ3d5gQUT1h3cD6vcUaNKvSsH26RHx2Bbzutbg,3667
126
- rslearn/train/tasks/multi_task.py,sha256=dBWsnbvQ0CReNsbDHmZ_-sXjUE0H4S2OPcbJwMquG9g,6016
127
- rslearn/train/tasks/per_pixel_regression.py,sha256=W8dbLyIiPgFI3gA_aZQX0pSFRWLP2v6tthsFbKhcDVg,8783
128
- rslearn/train/tasks/regression.py,sha256=zZhrrZ1qxjrdLjKWC9McRivDXCcKiYfdLC-kaMeVkDc,11547
129
- rslearn/train/tasks/segmentation.py,sha256=xEni3CLDyetviv84XrpJg5xeJU87WHGFKTVfIeemGIY,21868
130
- rslearn/train/tasks/task.py,sha256=4w2xKL_U5JAtdj2dYoVv82h6xTtgUsA3IvIOcXyZecs,3887
126
+ rslearn/train/tasks/classification.py,sha256=2_Iz3g9ifdtMo6a2sRtnAoYGP2Om8JPP7rm2AfwoDLc,14190
127
+ rslearn/train/tasks/detection.py,sha256=DrPLF_63WU99Qh1yILxSJWrYrl_2mCGxlTX2SznCei0,21938
128
+ rslearn/train/tasks/embedding.py,sha256=98ykdmfaxQjsH0UrUdTGmz1f0hCMPcNYTzt1YFqNQwQ,3869
129
+ rslearn/train/tasks/multi_task.py,sha256=piauvjg4j6eBEZnIFrKKxyYPWKdAJ4yUCFxt_ngsxDY,6125
130
+ rslearn/train/tasks/per_pixel_regression.py,sha256=jME-AFC74dfNAF5cBBmdmIgOV-KfkHt_z1GIOifuPJw,9975
131
+ rslearn/train/tasks/regression.py,sha256=8U0bcSofefmP9Drpbo7PO9xAkVptAuso2rOQyPOUhzo,12624
132
+ rslearn/train/tasks/segmentation.py,sha256=D9VaxsuiF0R8nvOKL821wPHey2GTMCgIqm8rZtnAXdk,22778
133
+ rslearn/train/tasks/task.py,sha256=CwGEFXquSjWXoJqbpBE1mQFCJl7NmMvgreAydphvu6U,3942
131
134
  rslearn/train/transforms/__init__.py,sha256=BkCAzm4f-8TEhPIuyvCj7eJGh36aMkZFYlq-H_jkSvY,778
132
135
  rslearn/train/transforms/concatenate.py,sha256=sdVLJIyr9Nj2tzXEzvWFQnjJjyRSuhR_Faf6UlMIvbg,1568
133
136
  rslearn/train/transforms/crop.py,sha256=4jA3JJsC0ghicPHbfsNJ0d3WpChyvftY73ONiwQaif0,4214
@@ -153,10 +156,10 @@ rslearn/utils/spatial_index.py,sha256=eomJAUgzmjir8j9HZnSgQoJHwN9H0wGTjmJkMkLLfs
153
156
  rslearn/utils/sqlite_index.py,sha256=YGOJi66544e6JNtfSft6YIlHklFdSJO2duxQ4TJ2iu4,2920
154
157
  rslearn/utils/time.py,sha256=2ilSLG94_sxLP3y5RSV5L5CG8CoND_dbdzYEHVtN-I8,387
155
158
  rslearn/utils/vector_format.py,sha256=4ZDYpfBLLxguJkiIaavTagiQK2Sv4Rz9NumbHlq-3Lw,15041
156
- rslearn-0.0.16.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
157
- rslearn-0.0.16.dist-info/licenses/NOTICE,sha256=wLPr6rwV_jCg-xEknNGwhnkfRfuoOE9MZ-lru2yZyLI,5070
158
- rslearn-0.0.16.dist-info/METADATA,sha256=h0p9V4jlSLDsrC2_owCn0xEKL7Kka74mEsE_pj-tJf0,36319
159
- rslearn-0.0.16.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
160
- rslearn-0.0.16.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
161
- rslearn-0.0.16.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
162
- rslearn-0.0.16.dist-info/RECORD,,
159
+ rslearn-0.0.18.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
160
+ rslearn-0.0.18.dist-info/licenses/NOTICE,sha256=wLPr6rwV_jCg-xEknNGwhnkfRfuoOE9MZ-lru2yZyLI,5070
161
+ rslearn-0.0.18.dist-info/METADATA,sha256=n-P9W0g_cQJ1BFYHxpIvC6EcJjbFOCoNlmHdSiN7sNo,37853
162
+ rslearn-0.0.18.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
163
+ rslearn-0.0.18.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
164
+ rslearn-0.0.18.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
165
+ rslearn-0.0.18.dist-info/RECORD,,
rslearn/dataset/index.py DELETED
@@ -1,173 +0,0 @@
1
- """Index about windows in the dataset."""
2
-
3
- import json
4
- import multiprocessing
5
- from typing import TYPE_CHECKING
6
-
7
- import tqdm
8
- from upath import UPath
9
-
10
- from .window import (
11
- Window,
12
- WindowLayerData,
13
- )
14
-
15
- if TYPE_CHECKING:
16
- from .dataset import Dataset
17
-
18
-
19
- def get_window_layer_datas(window: Window) -> list[WindowLayerData]:
20
- """Helper function for multiprocessing to load window layer datas."""
21
- return list(window.load_layer_datas().values())
22
-
23
-
24
- def get_window_completed_layers(window: Window) -> list[tuple[str, int]]:
25
- """Helper function for multiprocessing to load window completed layers."""
26
- return window.list_completed_layers()
27
-
28
-
29
- class DatasetIndex:
30
- """Manage an index about windows in the dataset.
31
-
32
- The index is just a single file containing information about all windows in the
33
- dataset, so that this information does not need to be loaded from per-window files.
34
-
35
- The information includes the window metadata, the window layer datas (matching data
36
- source items), and the completed layers.
37
-
38
- Currently the index must be manually maintained. It can be created for relatively
39
- static datasets, and updated each time the dataset is modified.
40
- """
41
-
42
- FNAME = "dataset_index.json"
43
-
44
- def __init__(
45
- self,
46
- windows: list[Window],
47
- layer_datas: dict[str, list[WindowLayerData]],
48
- completed_layers: dict[str, list[tuple[str, int]]],
49
- ):
50
- """Create a new DatasetIndex.
51
-
52
- Args:
53
- windows: the windows in the dataset.
54
- layer_datas: map from window name to the layer datas for that window.
55
- completed_layers: map from window name to its list of completed layers.
56
- Each element is (layer_name, group_idx).
57
- """
58
- self.windows = windows
59
- self.layer_datas = layer_datas
60
- self.completed_layers = completed_layers
61
-
62
- for window in self.windows:
63
- window.index = self
64
-
65
- def get_windows(
66
- self,
67
- groups: list[str] | None = None,
68
- names: list[str] | None = None,
69
- ) -> list[Window]:
70
- """Get the windows in the dataset.
71
-
72
- Args:
73
- groups: an optional list of groups to filter by
74
- names: an optional list of window names to filter by
75
- """
76
- windows = self.windows
77
- if groups is not None:
78
- group_set = set(groups)
79
- windows = [window for window in windows if window.group in group_set]
80
- if names is not None:
81
- name_set = set(names)
82
- windows = [window for window in windows if window.name in name_set]
83
- return windows
84
-
85
- def save_index(self, ds_path: UPath) -> None:
86
- """Save the index to the specified file."""
87
- encoded_windows = [window.get_metadata() for window in self.windows]
88
-
89
- encoded_layer_datas = {}
90
- for window_name, layer_data_list in self.layer_datas.items():
91
- encoded_layer_datas[window_name] = [
92
- layer_data.serialize() for layer_data in layer_data_list
93
- ]
94
-
95
- encoded_index = {
96
- "windows": encoded_windows,
97
- "layer_datas": encoded_layer_datas,
98
- "completed_layers": self.completed_layers,
99
- }
100
- with (ds_path / self.FNAME).open("w") as f:
101
- json.dump(encoded_index, f)
102
-
103
- @staticmethod
104
- def load_index(ds_path: UPath) -> "DatasetIndex":
105
- """Load the DatasetIndex for the specified dataset."""
106
- with (ds_path / DatasetIndex.FNAME).open() as f:
107
- encoded_index = json.load(f)
108
-
109
- windows = []
110
- for encoded_window in encoded_index["windows"]:
111
- window = Window.from_metadata(
112
- path=Window.get_window_root(
113
- ds_path, encoded_window["group"], encoded_window["name"]
114
- ),
115
- metadata=encoded_window,
116
- )
117
- windows.append(window)
118
-
119
- layer_datas = {}
120
- for window_name, encoded_layer_data_list in encoded_index[
121
- "layer_datas"
122
- ].items():
123
- layer_datas[window_name] = [
124
- WindowLayerData.deserialize(encoded_layer_data)
125
- for encoded_layer_data in encoded_layer_data_list
126
- ]
127
-
128
- completed_layers = {}
129
- for window_name, encoded_layer_list in encoded_index[
130
- "completed_layers"
131
- ].items():
132
- completed_layers[window_name] = [
133
- (layer_name, group_idx)
134
- for (layer_name, group_idx) in encoded_layer_list
135
- ]
136
-
137
- return DatasetIndex(
138
- windows=windows,
139
- layer_datas=layer_datas,
140
- completed_layers=completed_layers,
141
- )
142
-
143
- @staticmethod
144
- def build_index(dataset: "Dataset", workers: int) -> "DatasetIndex":
145
- """Build a new DatasetIndex for the specified dataset."""
146
- # Load windows.
147
- windows = dataset.load_windows(workers=workers, no_index=True)
148
-
149
- # Load layer datas.
150
- p = multiprocessing.Pool(workers)
151
- layer_data_outputs = p.imap(get_window_layer_datas, windows)
152
- layer_data_outputs = tqdm.tqdm(
153
- layer_data_outputs, total=len(windows), desc="Loading window layer datas"
154
- )
155
- layer_datas = {}
156
- for window, cur_layer_datas in zip(windows, layer_data_outputs):
157
- layer_datas[window.name] = cur_layer_datas
158
-
159
- # Load completed layers.
160
- completed_layer_outputs = p.imap(get_window_completed_layers, windows)
161
- completed_layer_outputs = tqdm.tqdm(
162
- completed_layer_outputs, total=len(windows), desc="Loading completed layers"
163
- )
164
- completed_layers = {} # window name -> list of (layer name, group idx)
165
- for window, cur_completed_layers in zip(windows, completed_layer_outputs):
166
- completed_layers[window.name] = cur_completed_layers
167
- p.close()
168
-
169
- return DatasetIndex(
170
- windows=windows,
171
- layer_datas=layer_datas,
172
- completed_layers=completed_layers,
173
- )
@@ -1,22 +0,0 @@
1
- """Model registry."""
2
-
3
- from collections.abc import Callable
4
- from typing import Any, TypeVar
5
-
6
- _ModelT = TypeVar("_ModelT")
7
-
8
-
9
- class _ModelRegistry(dict[str, type[Any]]):
10
- """Registry for Model classes."""
11
-
12
- def register(self, name: str) -> Callable[[type[_ModelT]], type[_ModelT]]:
13
- """Decorator to register a model class."""
14
-
15
- def decorator(cls: type[_ModelT]) -> type[_ModelT]:
16
- self[name] = cls
17
- return cls
18
-
19
- return decorator
20
-
21
-
22
- Models = _ModelRegistry()