rslearn 0.0.16__py3-none-any.whl → 0.0.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/__init__.py +2 -0
- rslearn/config/dataset.py +55 -4
- rslearn/dataset/add_windows.py +1 -1
- rslearn/dataset/dataset.py +9 -65
- rslearn/dataset/materialize.py +5 -5
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +26 -80
- rslearn/lightning_cli.py +10 -3
- rslearn/main.py +11 -36
- rslearn/models/anysat.py +11 -9
- rslearn/models/clay/clay.py +8 -9
- rslearn/models/clip.py +18 -15
- rslearn/models/component.py +99 -0
- rslearn/models/concatenate_features.py +21 -11
- rslearn/models/conv.py +15 -8
- rslearn/models/croma.py +13 -8
- rslearn/models/detr/detr.py +25 -14
- rslearn/models/dinov3.py +11 -6
- rslearn/models/faster_rcnn.py +19 -9
- rslearn/models/feature_center_crop.py +12 -9
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/galileo.py +23 -18
- rslearn/models/module_wrapper.py +26 -57
- rslearn/models/molmo.py +16 -14
- rslearn/models/multitask.py +102 -73
- rslearn/models/olmoearth_pretrain/model.py +20 -17
- rslearn/models/panopticon.py +8 -7
- rslearn/models/pick_features.py +18 -24
- rslearn/models/pooling_decoder.py +22 -14
- rslearn/models/presto/presto.py +16 -10
- rslearn/models/presto/single_file_presto.py +4 -10
- rslearn/models/prithvi.py +12 -8
- rslearn/models/resize_features.py +21 -7
- rslearn/models/sam2_enc.py +11 -9
- rslearn/models/satlaspretrain.py +15 -9
- rslearn/models/simple_time_series.py +31 -17
- rslearn/models/singletask.py +24 -17
- rslearn/models/ssl4eo_s12.py +15 -10
- rslearn/models/swin.py +22 -13
- rslearn/models/terramind.py +24 -7
- rslearn/models/trunk.py +6 -3
- rslearn/models/unet.py +18 -9
- rslearn/models/upsample.py +22 -9
- rslearn/train/all_patches_dataset.py +22 -18
- rslearn/train/dataset.py +69 -54
- rslearn/train/lightning_module.py +51 -32
- rslearn/train/model_context.py +54 -0
- rslearn/train/prediction_writer.py +111 -41
- rslearn/train/tasks/classification.py +34 -15
- rslearn/train/tasks/detection.py +24 -31
- rslearn/train/tasks/embedding.py +33 -29
- rslearn/train/tasks/multi_task.py +7 -7
- rslearn/train/tasks/per_pixel_regression.py +41 -19
- rslearn/train/tasks/regression.py +38 -21
- rslearn/train/tasks/segmentation.py +33 -15
- rslearn/train/tasks/task.py +3 -2
- {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/METADATA +58 -25
- {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/RECORD +65 -62
- rslearn/dataset/index.py +0 -173
- rslearn/models/registry.py +0 -22
- {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/WHEEL +0 -0
- {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: rslearn
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.18
|
|
4
4
|
Summary: A library for developing remote sensing datasets and models
|
|
5
5
|
Author: OlmoEarth Team
|
|
6
6
|
License: Apache License
|
|
@@ -343,10 +343,12 @@ directory `/path/to/dataset` and corresponding configuration file at
|
|
|
343
343
|
"bands": ["R", "G", "B"]
|
|
344
344
|
}],
|
|
345
345
|
"data_source": {
|
|
346
|
-
"
|
|
347
|
-
"
|
|
348
|
-
|
|
349
|
-
|
|
346
|
+
"class_path": "rslearn.data_sources.gcp_public_data.Sentinel2",
|
|
347
|
+
"init_args": {
|
|
348
|
+
"index_cache_dir": "cache/sentinel2/",
|
|
349
|
+
"sort_by": "cloud_cover",
|
|
350
|
+
"use_rtree_index": false
|
|
351
|
+
}
|
|
350
352
|
}
|
|
351
353
|
}
|
|
352
354
|
}
|
|
@@ -453,8 +455,10 @@ automate this process. Update the dataset `config.json` with a new layer:
|
|
|
453
455
|
}],
|
|
454
456
|
"resampling_method": "nearest",
|
|
455
457
|
"data_source": {
|
|
456
|
-
"
|
|
457
|
-
"
|
|
458
|
+
"class_path": "rslearn.data_sources.local_files.LocalFiles",
|
|
459
|
+
"init_args": {
|
|
460
|
+
"src_dir": "file:///path/to/world_cover_tifs/"
|
|
461
|
+
}
|
|
458
462
|
}
|
|
459
463
|
}
|
|
460
464
|
},
|
|
@@ -516,8 +520,7 @@ model:
|
|
|
516
520
|
data:
|
|
517
521
|
class_path: rslearn.train.data_module.RslearnDataModule
|
|
518
522
|
init_args:
|
|
519
|
-
|
|
520
|
-
path: /path/to/dataset/
|
|
523
|
+
path: ${DATASET_PATH}
|
|
521
524
|
# This defines the layers that should be read for each window.
|
|
522
525
|
# The key ("image" / "targets") is what the data will be called in the model,
|
|
523
526
|
# while the layers option specifies which layers will be read.
|
|
@@ -615,7 +618,9 @@ trainer:
|
|
|
615
618
|
...
|
|
616
619
|
- class_path: rslearn.train.prediction_writer.RslearnWriter
|
|
617
620
|
init_args:
|
|
618
|
-
|
|
621
|
+
# We need to include this argument, but it will be overridden with the dataset
|
|
622
|
+
# path from data.init_args.path.
|
|
623
|
+
path: placeholder
|
|
619
624
|
output_layer: output
|
|
620
625
|
```
|
|
621
626
|
|
|
@@ -768,24 +773,43 @@ This will produce PNGs in the vis directory. The visualizations are produced by
|
|
|
768
773
|
SegmentationTask and overriding the visualize function.
|
|
769
774
|
|
|
770
775
|
|
|
771
|
-
###
|
|
776
|
+
### Checkpoint and Logging Management
|
|
777
|
+
|
|
778
|
+
Above, we needed to configure the checkpoint directory in the model config (the
|
|
779
|
+
`dirpath` option under `lightning.pytorch.callbacks.ModelCheckpoint`), and explicitly
|
|
780
|
+
specify the checkpoint path when applying the model. Additionally, metrics are logged
|
|
781
|
+
to the local filesystem and not well organized.
|
|
772
782
|
|
|
773
|
-
We can
|
|
783
|
+
We can instead let rslearn automatically manage checkpoints, along with logging to
|
|
784
|
+
Weights & Biases. To do so, we add project_name, run_name, and management_dir options
|
|
785
|
+
to the model config. The project_name corresponds to the W&B project, and the run name
|
|
786
|
+
corresponds to the W&B name. The management_dir is a directory to store project data;
|
|
787
|
+
rslearn determines a per-project directory at `{management_dir}/{project_name}/{run_name}/`
|
|
788
|
+
and uses it to store checkpoints.
|
|
774
789
|
|
|
775
790
|
```yaml
|
|
791
|
+
model:
|
|
792
|
+
# ...
|
|
793
|
+
data:
|
|
794
|
+
# ...
|
|
776
795
|
trainer:
|
|
777
796
|
# ...
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
name: version_00
|
|
797
|
+
project_name: land_cover_model
|
|
798
|
+
run_name: version_00
|
|
799
|
+
# This sets the option via the MANAGEMENT_DIR environment variable.
|
|
800
|
+
management_dir: ${MANAGEMENT_DIR}
|
|
783
801
|
```
|
|
784
802
|
|
|
785
|
-
Now,
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
803
|
+
Now, set the `MANAGEMENT_DIR` environment variable and run `model fit`:
|
|
804
|
+
|
|
805
|
+
```
|
|
806
|
+
export MANAGEMENT_DIR=./project_data
|
|
807
|
+
rslearn model fit --config land_cover_model.yaml
|
|
808
|
+
```
|
|
809
|
+
|
|
810
|
+
The training and validation loss and accuracy metric should now be logged to W&B. The
|
|
811
|
+
accuracy metric is provided by SegmentationTask, and additional metrics can be enabled
|
|
812
|
+
by passing the relevant init_args to the task, e.g. mean IoU and F1:
|
|
789
813
|
|
|
790
814
|
```yaml
|
|
791
815
|
class_path: rslearn.train.tasks.segmentation.SegmentationTask
|
|
@@ -796,6 +820,13 @@ passing the relevant init_args to the task, e.g. mean IoU and F1:
|
|
|
796
820
|
enable_f1_metric: true
|
|
797
821
|
```
|
|
798
822
|
|
|
823
|
+
When calling `model test` and `model predict` with management_dir set, rslearn will
|
|
824
|
+
automatically load the best checkpoint from the project directory, or raise an error if
|
|
825
|
+
no existing checkpoint exists. This behavior can be overridden with the
|
|
826
|
+
`--load_checkpoint_mode` and `--load_checkpoint_required` options (see `--help` for
|
|
827
|
+
details). Logging will be enabled during fit but not test/predict, and this can also
|
|
828
|
+
be overridden, using `--log_mode`.
|
|
829
|
+
|
|
799
830
|
|
|
800
831
|
### Inputting Multiple Sentinel-2 Images
|
|
801
832
|
|
|
@@ -818,10 +849,12 @@ query_config section. This can replace the sentinel2 layer:
|
|
|
818
849
|
"bands": ["R", "G", "B"]
|
|
819
850
|
}],
|
|
820
851
|
"data_source": {
|
|
821
|
-
"
|
|
822
|
-
"
|
|
823
|
-
|
|
824
|
-
|
|
852
|
+
"class_path": "rslearn.data_sources.gcp_public_data.Sentinel2",
|
|
853
|
+
"init_args": {
|
|
854
|
+
"index_cache_dir": "cache/sentinel2/",
|
|
855
|
+
"sort_by": "cloud_cover",
|
|
856
|
+
"use_rtree_index": false
|
|
857
|
+
},
|
|
825
858
|
"query_config": {
|
|
826
859
|
"max_matches": 3
|
|
827
860
|
}
|
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
rslearn/__init__.py,sha256=fFmAen3vxZyosEfPbG0W46IttujYGVxzrGkJ0YutmmY,73
|
|
2
2
|
rslearn/arg_parser.py,sha256=GNlJncO6Ck_dCNrcg7z_SSG61I-2gKn3Ix2tAxIk9CI,1428
|
|
3
3
|
rslearn/const.py,sha256=FUCfsvFAs-QarEDJ0grdy0C1HjUjLpNFYGo5I2Vpc5Y,449
|
|
4
|
-
rslearn/lightning_cli.py,sha256=
|
|
4
|
+
rslearn/lightning_cli.py,sha256=Cihdf3dOQ17b_n4432Y6LmCQ5XFDghW4rGb4fqw-b6g,17525
|
|
5
5
|
rslearn/log_utils.py,sha256=unD9gShiuO7cx5Nnq8qqVQ4qrbOOwFVgcHxN5bXuiAo,941
|
|
6
|
-
rslearn/main.py,sha256=
|
|
6
|
+
rslearn/main.py,sha256=CLVLpJJXc8f8BCG9SBdOuw_taADu3majD1TNSOVC6Ws,28565
|
|
7
7
|
rslearn/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
8
8
|
rslearn/template_params.py,sha256=Vop0Ha-S44ctCa9lvSZRjrMETznJZlR5y_gJrVIwrPg,791
|
|
9
|
-
rslearn/config/__init__.py,sha256=
|
|
10
|
-
rslearn/config/dataset.py,sha256=
|
|
9
|
+
rslearn/config/__init__.py,sha256=n1qpZ0ImshTtLYl5mC73BORYyUcjPJyHiyZkqUY1hiY,474
|
|
10
|
+
rslearn/config/dataset.py,sha256=fp7HT8FeKyOmkR77pR5tVZYAAnuetCLY6zGFVcdPHyY,23087
|
|
11
11
|
rslearn/data_sources/__init__.py,sha256=zzuZUxrlEIw84YpD2I0HJvCoLDB29LbmnKTXiJykzGU,660
|
|
12
12
|
rslearn/data_sources/aws_landsat.py,sha256=0ZQtmd2NCnvLy4vFSB1AlmoguJbiQB_e_T4eS1tnW9Q,20443
|
|
13
13
|
rslearn/data_sources/aws_open_data.py,sha256=lrHnMJTH3NAaRdNjxwCIxSq8rq90IvV4ho-qAG6Hdgc,29348
|
|
@@ -34,59 +34,61 @@ rslearn/data_sources/worldcover.py,sha256=n7gi-JRytxkvkUhKT--dVziMcWSSyMbZA7ZCzL
|
|
|
34
34
|
rslearn/data_sources/worldpop.py,sha256=S3RSc5kTwSs2bcREjNarBsqf3MBX5CN0eHj7Qkx4K74,5625
|
|
35
35
|
rslearn/data_sources/xyz_tiles.py,sha256=P601CvUmoVDC_ZRVhPKaIoPYYCq5UhZ-v7DaGlN5y_0,13797
|
|
36
36
|
rslearn/dataset/__init__.py,sha256=bHtBlEEBCekO-gaJqiww0-VjvZTE5ahx0llleo8bfP8,289
|
|
37
|
-
rslearn/dataset/add_windows.py,sha256=
|
|
38
|
-
rslearn/dataset/dataset.py,sha256=
|
|
37
|
+
rslearn/dataset/add_windows.py,sha256=NwIvku6zxCJ9kgVFa5phJc0Gj1Y1bCzh6TLb9nEGl0s,8462
|
|
38
|
+
rslearn/dataset/dataset.py,sha256=A-iXXdGvTCus1kCKC58dk1mEidMe29X3Lox9JEto_TQ,3112
|
|
39
39
|
rslearn/dataset/handler_summaries.py,sha256=wI99RDk5erCWkzl1A7Uc4chatQ9KWIr4F_0Hxr9Co6s,2607
|
|
40
|
-
rslearn/dataset/index.py,sha256=Wni5m6h4gisRB54fPLnCfUrRTEsJ5EvwS0fs9sYc2wg,6025
|
|
41
40
|
rslearn/dataset/manage.py,sha256=-lGSIgk6Z7-verF_POwe4n5w9eSpgyt4nEOcOj382rc,18971
|
|
42
|
-
rslearn/dataset/materialize.py,sha256=
|
|
41
|
+
rslearn/dataset/materialize.py,sha256=o05OeLk_wWEOsw15oc5yjpD4J-twGCTfXAtxyAQsQ9I,20974
|
|
43
42
|
rslearn/dataset/remap.py,sha256=6MaImsY02GNACpvRM81RvWmjZWRfAHxo_R3Ox6XLF6A,2723
|
|
44
|
-
rslearn/dataset/window.py,sha256=
|
|
43
|
+
rslearn/dataset/window.py,sha256=X4q8YzcSOTtwKxCPf71QLMoyKUtYMSnZu0kPnmVSUx4,10644
|
|
44
|
+
rslearn/dataset/storage/__init__.py,sha256=R50AVV5LH2g7ol0-jyvGcB390VsclXGbJXz4fmkn9as,52
|
|
45
|
+
rslearn/dataset/storage/file.py,sha256=g9HZ3CD4QcgyVNsBaXhjIKQgDOAeZ4R08sJ7ntx4wo8,6815
|
|
46
|
+
rslearn/dataset/storage/storage.py,sha256=DxZ7iwV938PiLwdQzb5EXSb4Mj8bRGmOTmA9fzq_Ge8,4840
|
|
45
47
|
rslearn/models/__init__.py,sha256=_vWoF9d2Slah8-6XhYhdU4SRsy_CNxXjCGQTD2yvu3Q,22
|
|
46
|
-
rslearn/models/anysat.py,sha256
|
|
47
|
-
rslearn/models/clip.py,sha256=
|
|
48
|
-
rslearn/models/
|
|
49
|
-
rslearn/models/
|
|
50
|
-
rslearn/models/
|
|
51
|
-
rslearn/models/
|
|
52
|
-
rslearn/models/
|
|
53
|
-
rslearn/models/
|
|
54
|
-
rslearn/models/
|
|
55
|
-
rslearn/models/
|
|
56
|
-
rslearn/models/
|
|
57
|
-
rslearn/models/
|
|
58
|
-
rslearn/models/
|
|
59
|
-
rslearn/models/
|
|
60
|
-
rslearn/models/
|
|
61
|
-
rslearn/models/
|
|
62
|
-
rslearn/models/
|
|
63
|
-
rslearn/models/resize_features.py,sha256=
|
|
64
|
-
rslearn/models/sam2_enc.py,sha256=
|
|
65
|
-
rslearn/models/satlaspretrain.py,sha256=
|
|
66
|
-
rslearn/models/simple_time_series.py,sha256=
|
|
67
|
-
rslearn/models/singletask.py,sha256=
|
|
68
|
-
rslearn/models/ssl4eo_s12.py,sha256=
|
|
69
|
-
rslearn/models/swin.py,sha256=
|
|
48
|
+
rslearn/models/anysat.py,sha256=-1uE2kSfR34lPld5AXEuU8KCfVIa3YAPqs-filxNiWY,8026
|
|
49
|
+
rslearn/models/clip.py,sha256=TgCPA7IEsnONFbWQxQlLguvAXzwp_Y90-Vk9MZnwXak,2337
|
|
50
|
+
rslearn/models/component.py,sha256=1vRK9K7hEgm0eyRhyPyLX4SBWZxiCXzO0fb_bwPQjQg,3015
|
|
51
|
+
rslearn/models/concatenate_features.py,sha256=Attemr5KurxlOojpclD0Pd5Cu2KHpNdpXe8jCSjpJ9U,3818
|
|
52
|
+
rslearn/models/conv.py,sha256=dEAAfhPo4bFlZPSAQjzqZTpP-hdJ394TytYssVK-fDA,2001
|
|
53
|
+
rslearn/models/croma.py,sha256=x8bTFOJB-yR9PydSEjCeV9WnzpalT9QLtozlICNDhyE,10820
|
|
54
|
+
rslearn/models/dinov3.py,sha256=zlICRIeOaxREpR75bzF4nS8P-ZUvLdrJLBU3cBTYgec,6458
|
|
55
|
+
rslearn/models/faster_rcnn.py,sha256=yYRk3attz_GyhJA6jE1ss4ybT_knbLNT1lMRrkz22PI,8614
|
|
56
|
+
rslearn/models/feature_center_crop.py,sha256=_Mu3E4iJLBug9I4ZIBIpB_VJo-xGterHmhtIFGaHR34,1808
|
|
57
|
+
rslearn/models/fpn.py,sha256=qm7nKMgsZrCoAdz8ASmNKU2nvZ6USm5CedMfy_w_gwE,2079
|
|
58
|
+
rslearn/models/module_wrapper.py,sha256=XjRgmhss9_beBEE77t9iySsIvYpV7R-DeyBlOBG000I,2004
|
|
59
|
+
rslearn/models/molmo.py,sha256=pPtCy7eg-xN-iRKCtNL_hpxagfJRyYc9o3NVb-ifonI,2182
|
|
60
|
+
rslearn/models/multitask.py,sha256=bpFxvtFowRyT-tvRSdY7AKbEx_i1y7sToEzZgTMcF4s,16264
|
|
61
|
+
rslearn/models/panopticon.py,sha256=4-KHbgAHUk-ZX04PFRIJF_BRO_euOZO0NfzjB3CEP7Y,5891
|
|
62
|
+
rslearn/models/pick_features.py,sha256=fI9SYubqpCWOAHYGVUSg5sgD31dsnAR9mNuLmqfIeL8,1110
|
|
63
|
+
rslearn/models/pooling_decoder.py,sha256=zrMH6wUExCa-XD1q9CIFD2ScgiasapyJs9plhcUxhIs,4767
|
|
64
|
+
rslearn/models/prithvi.py,sha256=SzdWm0CjYCZIviES3yBXec6TmvrWGQ6pf37w4BaODlY,40326
|
|
65
|
+
rslearn/models/resize_features.py,sha256=U7ZIVwwToJJnwchFG59wLWWP9eikHDB_1c4OtpubxHU,1693
|
|
66
|
+
rslearn/models/sam2_enc.py,sha256=iq2JBqloJ7xkg7Qz5M1HymRfEkLGGn2msF00R_nGBgM,3593
|
|
67
|
+
rslearn/models/satlaspretrain.py,sha256=ytlke5IlxlUEhUoIRcm0pz7umppppZBskCux9YE8mVg,3281
|
|
68
|
+
rslearn/models/simple_time_series.py,sha256=bmy3SsFocpbQwTk4s9gBMAYvz89IlUL_2pAFMZLew_8,13402
|
|
69
|
+
rslearn/models/singletask.py,sha256=9DM9a9-Mv3vVQqRhPOIXG2HHuVqVa_zuvgafeeYh4r0,1903
|
|
70
|
+
rslearn/models/ssl4eo_s12.py,sha256=FTDSkHCaQxw0B-R-8NIe50lZRBYFdI-v68lk9NrnwII,3737
|
|
71
|
+
rslearn/models/swin.py,sha256=QKiV00bu6dw9p0a9PbqDOe5zsnaFHuu6dOfEn41ms7I,6017
|
|
70
72
|
rslearn/models/task_embedding.py,sha256=Z6sf61BLCtvdrdnvjh8500b-KiFp3GeWbT4mOqpaCKk,9100
|
|
71
|
-
rslearn/models/terramind.py,sha256=
|
|
72
|
-
rslearn/models/trunk.py,sha256=
|
|
73
|
-
rslearn/models/unet.py,sha256=
|
|
74
|
-
rslearn/models/upsample.py,sha256=
|
|
73
|
+
rslearn/models/terramind.py,sha256=R_daJOIzCnaC_Z3s5n9YtBa9Gf_HgQatgCOwwy4MtSA,8732
|
|
74
|
+
rslearn/models/trunk.py,sha256=1GCH9iyLIytoHVntLSMwfH9duQpe1W4DPmOClLpPKjc,4778
|
|
75
|
+
rslearn/models/unet.py,sha256=HuuINvkB1-5w9ZOTXZCWkVxJShruPKCol8pKeA3zw_4,7251
|
|
76
|
+
rslearn/models/upsample.py,sha256=JvfnktT6Dgcql3cSoySWXZ7dmkDkfpRo6vDkpz8KFAQ,1326
|
|
75
77
|
rslearn/models/use_croma.py,sha256=OSBqMuLp-pDtqPNWAVBfmX4wckmyYCKtUDdGCjJk_K8,17966
|
|
76
|
-
rslearn/models/clay/clay.py,sha256=
|
|
78
|
+
rslearn/models/clay/clay.py,sha256=q2vcqyRFhCJvzLHwHLxHfWGyvVt819Xmk7TzuhXdPFI,8478
|
|
77
79
|
rslearn/models/clay/configs/metadata.yaml,sha256=rZTFh4Yb9htEfbQNOPl4HTbFogEhzwIRqFzG-1uT01Y,4652
|
|
78
80
|
rslearn/models/detr/__init__.py,sha256=GGAnTIhyuvl34IRrJ_4gXjm_01OlM5rbQQ3c3TGfbK8,84
|
|
79
81
|
rslearn/models/detr/box_ops.py,sha256=ORCF6EwMpMBB_VgQT05SjR47dCR2rN2gPhL_gsuUWJs,3236
|
|
80
|
-
rslearn/models/detr/detr.py,sha256=
|
|
82
|
+
rslearn/models/detr/detr.py,sha256=ZiYNwJ3zWqZdvvzc0CXDzdvcwSCO_wSOCWlChemTsX8,19178
|
|
81
83
|
rslearn/models/detr/matcher.py,sha256=4h_xFlgTMEJvJ6aLZUamrKZ72L5hDk9wPglNZ81JBg8,4533
|
|
82
84
|
rslearn/models/detr/position_encoding.py,sha256=8FFoBT-Jtgqk7D4qDBTbVLOeAdmjdjtJTC608TaX6yY,3869
|
|
83
85
|
rslearn/models/detr/transformer.py,sha256=aK4HO7AkCZn7xGHP3Iq91w2iFPVshugOILYAjVjroCw,13971
|
|
84
86
|
rslearn/models/detr/util.py,sha256=NMHhHbkIo7PoBUVbDqa2ZknJBTswmaxFCGHrPtFKnGg,676
|
|
85
87
|
rslearn/models/galileo/__init__.py,sha256=QQa0C29nuPRva0KtGiMHQ2ZB02n9SSwj_wqTKPz18NM,112
|
|
86
|
-
rslearn/models/galileo/galileo.py,sha256=
|
|
88
|
+
rslearn/models/galileo/galileo.py,sha256=PR90sPKxBfrpYolAfDf-AyBkNVARr6UbQmYc2XzcNPc,21365
|
|
87
89
|
rslearn/models/galileo/single_file_galileo.py,sha256=l5tlmmdr2eieHNH-M7rVIvcptkv0Fuk3vKXFW691ezA,56143
|
|
88
90
|
rslearn/models/olmoearth_pretrain/__init__.py,sha256=AjRvbjBdadCdPh-EdvySH76sVAQ8NGQaJt11Tsn1D5I,36
|
|
89
|
-
rslearn/models/olmoearth_pretrain/model.py,sha256=
|
|
91
|
+
rslearn/models/olmoearth_pretrain/model.py,sha256=mGGoe45WEhFUJMQBYgKyJ_peDtsXKxNrgiYLz9u8I9o,10882
|
|
90
92
|
rslearn/models/olmoearth_pretrain/norm.py,sha256=rHjFyWkpNLYMx9Ow7TsU-jGm9Sjx7FVf0p4R__ohx2c,3266
|
|
91
93
|
rslearn/models/panopticon_data/sensors/drone.yaml,sha256=xqWS-_QMtJyRoWXJm-igoSur9hAmCFdqkPin8DT5qpw,431
|
|
92
94
|
rslearn/models/panopticon_data/sensors/enmap.yaml,sha256=b2j6bSgYR2yKR9DRm3SPIzSVYlHf51ny_p-1B4B9sB4,13431
|
|
@@ -101,18 +103,19 @@ rslearn/models/panopticon_data/sensors/sentinel2.yaml,sha256=qYJ92x-GHO0ZdCrTtCj
|
|
|
101
103
|
rslearn/models/panopticon_data/sensors/superdove.yaml,sha256=QpIRyopdV4hAez_EIsDwhGFT4VtTk7UgzQveyc8t8fc,795
|
|
102
104
|
rslearn/models/panopticon_data/sensors/wv23.yaml,sha256=SWYSlkka6UViKAz6YI8aqwQ-Ayo-S5kmNa9rO3iGW6o,1172
|
|
103
105
|
rslearn/models/presto/__init__.py,sha256=eZrB-XKi_vYqZhpyAOwppJi4dRuMtYVAdbq7KRygze0,64
|
|
104
|
-
rslearn/models/presto/presto.py,sha256=
|
|
105
|
-
rslearn/models/presto/single_file_presto.py,sha256
|
|
106
|
+
rslearn/models/presto/presto.py,sha256=SQCxY-jibe_PQs5yvoEOzjUhv8jaCoEkDE0pvSVhEUY,9542
|
|
107
|
+
rslearn/models/presto/single_file_presto.py,sha256=-P00xjhj9dx3O6HqWpQmG9dPk_i6bT_t8vhX4uQm5tA,30242
|
|
106
108
|
rslearn/tile_stores/__init__.py,sha256=-cW1J7So60SEP5ZLHCPdaFBV5CxvV3QlOhaFnUkhTJ0,1675
|
|
107
109
|
rslearn/tile_stores/default.py,sha256=PYaDNvBxhJTDKJGw0EjDTSE1OKajR7_iJpMbOjj-mE8,15054
|
|
108
110
|
rslearn/tile_stores/tile_store.py,sha256=9AeYduDYPp_Ia2NMlq6osptpz_AFGIOQcLJrqZ_m-z0,10469
|
|
109
111
|
rslearn/train/__init__.py,sha256=fnJyY4aHs5zQqbDKSfXsJZXY_M9fbTsf7dRYaPwZr2M,30
|
|
110
|
-
rslearn/train/all_patches_dataset.py,sha256=
|
|
112
|
+
rslearn/train/all_patches_dataset.py,sha256=qgx1tHZOGOFxUB-HQ7jDxJXE_cvySQfDE6cMr60VS7s,18206
|
|
111
113
|
rslearn/train/data_module.py,sha256=pgut8rEWHIieZ7RR8dUvhtlNqk0egEdznYF3tCvqdHg,23552
|
|
112
|
-
rslearn/train/dataset.py,sha256=
|
|
113
|
-
rslearn/train/lightning_module.py,sha256=
|
|
114
|
+
rslearn/train/dataset.py,sha256=bsgpUoUKv9AeXzCzBqumiq8b-NZ5PjkFTvLNG82Vx2Q,34179
|
|
115
|
+
rslearn/train/lightning_module.py,sha256=HA9e-74oUZR5s7piQP9Mwxwz0vw0-p4HdmijfgBSpU0,14776
|
|
116
|
+
rslearn/train/model_context.py,sha256=TJHCw8xXLZrRHqGzJr30QARATgG_pJLbMBmyiIcZXRM,1449
|
|
114
117
|
rslearn/train/optimizer.py,sha256=EKSqkmERalDA0bF32Gey7n6z69KLyaUWKlRsGJfKBmE,927
|
|
115
|
-
rslearn/train/prediction_writer.py,sha256=
|
|
118
|
+
rslearn/train/prediction_writer.py,sha256=rW0BUaYT_F1QqmpnQlbrLiLya1iBfC5Pb78G_NlF-vA,15956
|
|
116
119
|
rslearn/train/scheduler.py,sha256=wFbmycMHgL6nRYeYalDjb0G8YVo8VD3T3sABS61jJ7c,2318
|
|
117
120
|
rslearn/train/callbacks/__init__.py,sha256=VNV0ArZyYMvl3dGK2wl6F046khYJ1dEBlJS6G_SYNm0,47
|
|
118
121
|
rslearn/train/callbacks/adapters.py,sha256=yfv8nyCj3jmo2_dNkFrjukKxh0MHsf2xKqWwMF0QUtY,1869
|
|
@@ -120,14 +123,14 @@ rslearn/train/callbacks/freeze_unfreeze.py,sha256=8fIzBMhCKKjpTffIeAdhdSjsBd8NjT
|
|
|
120
123
|
rslearn/train/callbacks/gradients.py,sha256=4YqCf0tBb6E5FnyFYbveXfQFlgNPyxIXb2FCWX4-6qs,5075
|
|
121
124
|
rslearn/train/callbacks/peft.py,sha256=wEOKsS3RhsRaZTXn_Kz2wdsZdIiIaZPdCJWtdJBurT8,4156
|
|
122
125
|
rslearn/train/tasks/__init__.py,sha256=dag1u72x1-me6y0YcOubUo5MYZ0Tjf6-dOir9UeFNMs,75
|
|
123
|
-
rslearn/train/tasks/classification.py,sha256=
|
|
124
|
-
rslearn/train/tasks/detection.py,sha256=
|
|
125
|
-
rslearn/train/tasks/embedding.py,sha256=
|
|
126
|
-
rslearn/train/tasks/multi_task.py,sha256=
|
|
127
|
-
rslearn/train/tasks/per_pixel_regression.py,sha256=
|
|
128
|
-
rslearn/train/tasks/regression.py,sha256=
|
|
129
|
-
rslearn/train/tasks/segmentation.py,sha256=
|
|
130
|
-
rslearn/train/tasks/task.py,sha256=
|
|
126
|
+
rslearn/train/tasks/classification.py,sha256=2_Iz3g9ifdtMo6a2sRtnAoYGP2Om8JPP7rm2AfwoDLc,14190
|
|
127
|
+
rslearn/train/tasks/detection.py,sha256=DrPLF_63WU99Qh1yILxSJWrYrl_2mCGxlTX2SznCei0,21938
|
|
128
|
+
rslearn/train/tasks/embedding.py,sha256=98ykdmfaxQjsH0UrUdTGmz1f0hCMPcNYTzt1YFqNQwQ,3869
|
|
129
|
+
rslearn/train/tasks/multi_task.py,sha256=piauvjg4j6eBEZnIFrKKxyYPWKdAJ4yUCFxt_ngsxDY,6125
|
|
130
|
+
rslearn/train/tasks/per_pixel_regression.py,sha256=jME-AFC74dfNAF5cBBmdmIgOV-KfkHt_z1GIOifuPJw,9975
|
|
131
|
+
rslearn/train/tasks/regression.py,sha256=8U0bcSofefmP9Drpbo7PO9xAkVptAuso2rOQyPOUhzo,12624
|
|
132
|
+
rslearn/train/tasks/segmentation.py,sha256=D9VaxsuiF0R8nvOKL821wPHey2GTMCgIqm8rZtnAXdk,22778
|
|
133
|
+
rslearn/train/tasks/task.py,sha256=CwGEFXquSjWXoJqbpBE1mQFCJl7NmMvgreAydphvu6U,3942
|
|
131
134
|
rslearn/train/transforms/__init__.py,sha256=BkCAzm4f-8TEhPIuyvCj7eJGh36aMkZFYlq-H_jkSvY,778
|
|
132
135
|
rslearn/train/transforms/concatenate.py,sha256=sdVLJIyr9Nj2tzXEzvWFQnjJjyRSuhR_Faf6UlMIvbg,1568
|
|
133
136
|
rslearn/train/transforms/crop.py,sha256=4jA3JJsC0ghicPHbfsNJ0d3WpChyvftY73ONiwQaif0,4214
|
|
@@ -153,10 +156,10 @@ rslearn/utils/spatial_index.py,sha256=eomJAUgzmjir8j9HZnSgQoJHwN9H0wGTjmJkMkLLfs
|
|
|
153
156
|
rslearn/utils/sqlite_index.py,sha256=YGOJi66544e6JNtfSft6YIlHklFdSJO2duxQ4TJ2iu4,2920
|
|
154
157
|
rslearn/utils/time.py,sha256=2ilSLG94_sxLP3y5RSV5L5CG8CoND_dbdzYEHVtN-I8,387
|
|
155
158
|
rslearn/utils/vector_format.py,sha256=4ZDYpfBLLxguJkiIaavTagiQK2Sv4Rz9NumbHlq-3Lw,15041
|
|
156
|
-
rslearn-0.0.
|
|
157
|
-
rslearn-0.0.
|
|
158
|
-
rslearn-0.0.
|
|
159
|
-
rslearn-0.0.
|
|
160
|
-
rslearn-0.0.
|
|
161
|
-
rslearn-0.0.
|
|
162
|
-
rslearn-0.0.
|
|
159
|
+
rslearn-0.0.18.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
|
|
160
|
+
rslearn-0.0.18.dist-info/licenses/NOTICE,sha256=wLPr6rwV_jCg-xEknNGwhnkfRfuoOE9MZ-lru2yZyLI,5070
|
|
161
|
+
rslearn-0.0.18.dist-info/METADATA,sha256=n-P9W0g_cQJ1BFYHxpIvC6EcJjbFOCoNlmHdSiN7sNo,37853
|
|
162
|
+
rslearn-0.0.18.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
163
|
+
rslearn-0.0.18.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
|
|
164
|
+
rslearn-0.0.18.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
|
|
165
|
+
rslearn-0.0.18.dist-info/RECORD,,
|
rslearn/dataset/index.py
DELETED
|
@@ -1,173 +0,0 @@
|
|
|
1
|
-
"""Index about windows in the dataset."""
|
|
2
|
-
|
|
3
|
-
import json
|
|
4
|
-
import multiprocessing
|
|
5
|
-
from typing import TYPE_CHECKING
|
|
6
|
-
|
|
7
|
-
import tqdm
|
|
8
|
-
from upath import UPath
|
|
9
|
-
|
|
10
|
-
from .window import (
|
|
11
|
-
Window,
|
|
12
|
-
WindowLayerData,
|
|
13
|
-
)
|
|
14
|
-
|
|
15
|
-
if TYPE_CHECKING:
|
|
16
|
-
from .dataset import Dataset
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
def get_window_layer_datas(window: Window) -> list[WindowLayerData]:
|
|
20
|
-
"""Helper function for multiprocessing to load window layer datas."""
|
|
21
|
-
return list(window.load_layer_datas().values())
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
def get_window_completed_layers(window: Window) -> list[tuple[str, int]]:
|
|
25
|
-
"""Helper function for multiprocessing to load window completed layers."""
|
|
26
|
-
return window.list_completed_layers()
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
class DatasetIndex:
|
|
30
|
-
"""Manage an index about windows in the dataset.
|
|
31
|
-
|
|
32
|
-
The index is just a single file containing information about all windows in the
|
|
33
|
-
dataset, so that this information does not need to be loaded from per-window files.
|
|
34
|
-
|
|
35
|
-
The information includes the window metadata, the window layer datas (matching data
|
|
36
|
-
source items), and the completed layers.
|
|
37
|
-
|
|
38
|
-
Currently the index must be manually maintained. It can be created for relatively
|
|
39
|
-
static datasets, and updated each time the dataset is modified.
|
|
40
|
-
"""
|
|
41
|
-
|
|
42
|
-
FNAME = "dataset_index.json"
|
|
43
|
-
|
|
44
|
-
def __init__(
|
|
45
|
-
self,
|
|
46
|
-
windows: list[Window],
|
|
47
|
-
layer_datas: dict[str, list[WindowLayerData]],
|
|
48
|
-
completed_layers: dict[str, list[tuple[str, int]]],
|
|
49
|
-
):
|
|
50
|
-
"""Create a new DatasetIndex.
|
|
51
|
-
|
|
52
|
-
Args:
|
|
53
|
-
windows: the windows in the dataset.
|
|
54
|
-
layer_datas: map from window name to the layer datas for that window.
|
|
55
|
-
completed_layers: map from window name to its list of completed layers.
|
|
56
|
-
Each element is (layer_name, group_idx).
|
|
57
|
-
"""
|
|
58
|
-
self.windows = windows
|
|
59
|
-
self.layer_datas = layer_datas
|
|
60
|
-
self.completed_layers = completed_layers
|
|
61
|
-
|
|
62
|
-
for window in self.windows:
|
|
63
|
-
window.index = self
|
|
64
|
-
|
|
65
|
-
def get_windows(
|
|
66
|
-
self,
|
|
67
|
-
groups: list[str] | None = None,
|
|
68
|
-
names: list[str] | None = None,
|
|
69
|
-
) -> list[Window]:
|
|
70
|
-
"""Get the windows in the dataset.
|
|
71
|
-
|
|
72
|
-
Args:
|
|
73
|
-
groups: an optional list of groups to filter by
|
|
74
|
-
names: an optional list of window names to filter by
|
|
75
|
-
"""
|
|
76
|
-
windows = self.windows
|
|
77
|
-
if groups is not None:
|
|
78
|
-
group_set = set(groups)
|
|
79
|
-
windows = [window for window in windows if window.group in group_set]
|
|
80
|
-
if names is not None:
|
|
81
|
-
name_set = set(names)
|
|
82
|
-
windows = [window for window in windows if window.name in name_set]
|
|
83
|
-
return windows
|
|
84
|
-
|
|
85
|
-
def save_index(self, ds_path: UPath) -> None:
|
|
86
|
-
"""Save the index to the specified file."""
|
|
87
|
-
encoded_windows = [window.get_metadata() for window in self.windows]
|
|
88
|
-
|
|
89
|
-
encoded_layer_datas = {}
|
|
90
|
-
for window_name, layer_data_list in self.layer_datas.items():
|
|
91
|
-
encoded_layer_datas[window_name] = [
|
|
92
|
-
layer_data.serialize() for layer_data in layer_data_list
|
|
93
|
-
]
|
|
94
|
-
|
|
95
|
-
encoded_index = {
|
|
96
|
-
"windows": encoded_windows,
|
|
97
|
-
"layer_datas": encoded_layer_datas,
|
|
98
|
-
"completed_layers": self.completed_layers,
|
|
99
|
-
}
|
|
100
|
-
with (ds_path / self.FNAME).open("w") as f:
|
|
101
|
-
json.dump(encoded_index, f)
|
|
102
|
-
|
|
103
|
-
@staticmethod
|
|
104
|
-
def load_index(ds_path: UPath) -> "DatasetIndex":
|
|
105
|
-
"""Load the DatasetIndex for the specified dataset."""
|
|
106
|
-
with (ds_path / DatasetIndex.FNAME).open() as f:
|
|
107
|
-
encoded_index = json.load(f)
|
|
108
|
-
|
|
109
|
-
windows = []
|
|
110
|
-
for encoded_window in encoded_index["windows"]:
|
|
111
|
-
window = Window.from_metadata(
|
|
112
|
-
path=Window.get_window_root(
|
|
113
|
-
ds_path, encoded_window["group"], encoded_window["name"]
|
|
114
|
-
),
|
|
115
|
-
metadata=encoded_window,
|
|
116
|
-
)
|
|
117
|
-
windows.append(window)
|
|
118
|
-
|
|
119
|
-
layer_datas = {}
|
|
120
|
-
for window_name, encoded_layer_data_list in encoded_index[
|
|
121
|
-
"layer_datas"
|
|
122
|
-
].items():
|
|
123
|
-
layer_datas[window_name] = [
|
|
124
|
-
WindowLayerData.deserialize(encoded_layer_data)
|
|
125
|
-
for encoded_layer_data in encoded_layer_data_list
|
|
126
|
-
]
|
|
127
|
-
|
|
128
|
-
completed_layers = {}
|
|
129
|
-
for window_name, encoded_layer_list in encoded_index[
|
|
130
|
-
"completed_layers"
|
|
131
|
-
].items():
|
|
132
|
-
completed_layers[window_name] = [
|
|
133
|
-
(layer_name, group_idx)
|
|
134
|
-
for (layer_name, group_idx) in encoded_layer_list
|
|
135
|
-
]
|
|
136
|
-
|
|
137
|
-
return DatasetIndex(
|
|
138
|
-
windows=windows,
|
|
139
|
-
layer_datas=layer_datas,
|
|
140
|
-
completed_layers=completed_layers,
|
|
141
|
-
)
|
|
142
|
-
|
|
143
|
-
@staticmethod
|
|
144
|
-
def build_index(dataset: "Dataset", workers: int) -> "DatasetIndex":
|
|
145
|
-
"""Build a new DatasetIndex for the specified dataset."""
|
|
146
|
-
# Load windows.
|
|
147
|
-
windows = dataset.load_windows(workers=workers, no_index=True)
|
|
148
|
-
|
|
149
|
-
# Load layer datas.
|
|
150
|
-
p = multiprocessing.Pool(workers)
|
|
151
|
-
layer_data_outputs = p.imap(get_window_layer_datas, windows)
|
|
152
|
-
layer_data_outputs = tqdm.tqdm(
|
|
153
|
-
layer_data_outputs, total=len(windows), desc="Loading window layer datas"
|
|
154
|
-
)
|
|
155
|
-
layer_datas = {}
|
|
156
|
-
for window, cur_layer_datas in zip(windows, layer_data_outputs):
|
|
157
|
-
layer_datas[window.name] = cur_layer_datas
|
|
158
|
-
|
|
159
|
-
# Load completed layers.
|
|
160
|
-
completed_layer_outputs = p.imap(get_window_completed_layers, windows)
|
|
161
|
-
completed_layer_outputs = tqdm.tqdm(
|
|
162
|
-
completed_layer_outputs, total=len(windows), desc="Loading completed layers"
|
|
163
|
-
)
|
|
164
|
-
completed_layers = {} # window name -> list of (layer name, group idx)
|
|
165
|
-
for window, cur_completed_layers in zip(windows, completed_layer_outputs):
|
|
166
|
-
completed_layers[window.name] = cur_completed_layers
|
|
167
|
-
p.close()
|
|
168
|
-
|
|
169
|
-
return DatasetIndex(
|
|
170
|
-
windows=windows,
|
|
171
|
-
layer_datas=layer_datas,
|
|
172
|
-
completed_layers=completed_layers,
|
|
173
|
-
)
|
rslearn/models/registry.py
DELETED
|
@@ -1,22 +0,0 @@
|
|
|
1
|
-
"""Model registry."""
|
|
2
|
-
|
|
3
|
-
from collections.abc import Callable
|
|
4
|
-
from typing import Any, TypeVar
|
|
5
|
-
|
|
6
|
-
_ModelT = TypeVar("_ModelT")
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
class _ModelRegistry(dict[str, type[Any]]):
|
|
10
|
-
"""Registry for Model classes."""
|
|
11
|
-
|
|
12
|
-
def register(self, name: str) -> Callable[[type[_ModelT]], type[_ModelT]]:
|
|
13
|
-
"""Decorator to register a model class."""
|
|
14
|
-
|
|
15
|
-
def decorator(cls: type[_ModelT]) -> type[_ModelT]:
|
|
16
|
-
self[name] = cls
|
|
17
|
-
return cls
|
|
18
|
-
|
|
19
|
-
return decorator
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
Models = _ModelRegistry()
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|