rslearn 0.0.9__py3-none-any.whl → 0.0.12__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (30)
  1. rslearn/models/anysat.py +5 -1
  2. rslearn/models/dinov3.py +6 -1
  3. rslearn/models/feature_center_crop.py +50 -0
  4. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  5. rslearn/models/olmoearth_pretrain/model.py +263 -0
  6. rslearn/models/olmoearth_pretrain/norm.py +84 -0
  7. rslearn/models/pooling_decoder.py +43 -0
  8. rslearn/models/prithvi.py +9 -1
  9. rslearn/train/lightning_module.py +0 -3
  10. rslearn/train/tasks/classification.py +2 -2
  11. rslearn/train/tasks/detection.py +5 -5
  12. rslearn/train/tasks/per_pixel_regression.py +5 -4
  13. rslearn/train/tasks/regression.py +5 -5
  14. rslearn/train/transforms/pad.py +3 -3
  15. {rslearn-0.0.9.dist-info → rslearn-0.0.12.dist-info}/METADATA +3 -1
  16. {rslearn-0.0.9.dist-info → rslearn-0.0.12.dist-info}/RECORD +21 -25
  17. rslearn-0.0.12.dist-info/licenses/NOTICE +115 -0
  18. rslearn/models/copernicusfm.py +0 -228
  19. rslearn/models/copernicusfm_src/__init__.py +0 -1
  20. rslearn/models/copernicusfm_src/aurora/area.py +0 -50
  21. rslearn/models/copernicusfm_src/aurora/fourier.py +0 -134
  22. rslearn/models/copernicusfm_src/dynamic_hypernetwork.py +0 -523
  23. rslearn/models/copernicusfm_src/flexivit/patch_embed.py +0 -260
  24. rslearn/models/copernicusfm_src/flexivit/utils.py +0 -69
  25. rslearn/models/copernicusfm_src/model_vit.py +0 -348
  26. rslearn/models/copernicusfm_src/util/pos_embed.py +0 -216
  27. {rslearn-0.0.9.dist-info → rslearn-0.0.12.dist-info}/WHEEL +0 -0
  28. {rslearn-0.0.9.dist-info → rslearn-0.0.12.dist-info}/entry_points.txt +0 -0
  29. {rslearn-0.0.9.dist-info → rslearn-0.0.12.dist-info}/licenses/LICENSE +0 -0
  30. {rslearn-0.0.9.dist-info → rslearn-0.0.12.dist-info}/top_level.txt +0 -0
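The hunks that follow come from each wheel's RECORD file, which lists every packaged path together with its sha256 digest and size, so the paired -/+ rows show exactly which files changed between 0.0.9 and 0.0.12. As a rough sketch (not part of the package), this is one way such a comparison could be reproduced locally with the Python standard library; it assumes both wheels have already been downloaded (for example via pip download rslearn==0.0.9 --no-deps and pip download rslearn==0.0.12 --no-deps), and the local filenames and the record_entries helper are illustrative only.

import zipfile

def record_entries(wheel_path: str, dist_info: str) -> dict[str, str]:
    # A wheel is a zip archive; read its RECORD file and map each packaged
    # path to the recorded digest string.
    with zipfile.ZipFile(wheel_path) as zf:
        text = zf.read(f"{dist_info}/RECORD").decode()
    entries = {}
    for line in text.splitlines():
        if not line.strip():
            continue
        path, digest, _size = line.rsplit(",", 2)
        entries[path] = digest
    return entries

# Hypothetical local filenames; the downloaded wheel filenames may differ.
old = record_entries("rslearn-0.0.9-py3-none-any.whl", "rslearn-0.0.9.dist-info")
new = record_entries("rslearn-0.0.12-py3-none-any.whl", "rslearn-0.0.12.dist-info")

for path in sorted(set(old) | set(new)):
    if old.get(path) != new.get(path):
        status = "removed" if path not in new else ("added" if path not in old else "changed")
        print(f"{status:8s} {path}")

Entries whose digests match are skipped, so the output mirrors the changed, added, and removed rows in the RECORD hunks below.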
@@ -44,21 +44,21 @@ rslearn/dataset/materialize.py,sha256=-z47svc_JqGhzkp8kq5Hd9fykWNqFEUCQezo887TWB
  rslearn/dataset/remap.py,sha256=6MaImsY02GNACpvRM81RvWmjZWRfAHxo_R3Ox6XLF6A,2723
  rslearn/dataset/window.py,sha256=I5RqZ12jlIXhohw4qews1x_I4tSDpml709DZRtLiN24,12546
  rslearn/models/__init__.py,sha256=_vWoF9d2Slah8-6XhYhdU4SRsy_CNxXjCGQTD2yvu3Q,22
- rslearn/models/anysat.py,sha256=3BnaiS1sYB4SnV6qRjHksiz_r9vUuZeGPUO2XUziFA0,7810
+ rslearn/models/anysat.py,sha256=3Oh2gWxicVdUzOjevBEZf0PuolmCy0KC5Ad7JY-0Plc,7949
  rslearn/models/clip.py,sha256=u5aqYnVB4Jag7o1h8EzPDAc1t2BAHeALA9FcUwP5tfo,2238
  rslearn/models/conv.py,sha256=fWyByeswIOKKzyPmP3erYUlZaKEV0huWHA4CyKTBbfY,1703
- rslearn/models/copernicusfm.py,sha256=3AiORuUre9sZYwydbrDgShwKtxeTLmExp7WQmJtBylg,7842
  rslearn/models/croma.py,sha256=cOazTp3l2PNJltKrmPqD5Gy4pi3CI03-X9G4T10cX2k,9529
- rslearn/models/dinov3.py,sha256=GKk5qXZPCEporATJdjaSWsDTfWDlAGRWBplFUJN5nRM,6146
+ rslearn/models/dinov3.py,sha256=9k9kNlXCorQQwKjLGptooANd48TUBsITQ1e4fUomlM4,6337
  rslearn/models/faster_rcnn.py,sha256=uaxX6-E1f0BibaA9sorEg3be83C7kTdTc39pC5jRqwE,8286
+ rslearn/models/feature_center_crop.py,sha256=24eOrvLEGGVWPw7kPHyUes5HtYNAX7GZ_NpqDGMILEY,1553
  rslearn/models/fpn.py,sha256=s3cz29I14FaSuvBvLOcwCrqVsaRBxG5GjLlqap4WgPc,1603
  rslearn/models/module_wrapper.py,sha256=H2zb-8Au4t31kawW_4JEKHsaXFjpYDawb31ZEauKcxU,2728
  rslearn/models/molmo.py,sha256=mVrARBhZciMzOgOOjGB5AHlPIf2iO9IBSJmdyKSl1L8,2061
  rslearn/models/multitask.py,sha256=j2Kiwj_dUiUp_CIUr25bS8HiyeoFlr1PGqjTfpgIGLc,14672
  rslearn/models/panopticon.py,sha256=woNEs53wVc5D-NxbSDEPRZ_mYe8vllnuldmADjvhfDQ,5806
  rslearn/models/pick_features.py,sha256=y8e4tJFhyG7ZuVSElWhQ5-Aer4ZKJCEH9wLGJU7WqGI,1551
- rslearn/models/pooling_decoder.py,sha256=jZfEQCfthfa21C9sEjgFHUcfhHMVlvG7_nDMw_1FLaE,2727
- rslearn/models/prithvi.py,sha256=SVM3ypJlVTkXQ69pPhB4UeJr87VnmADTCuyV365dbkU,39961
+ rslearn/models/pooling_decoder.py,sha256=unr2fSE_QmJHPi3dKtopqMtb1Kn-2h94LgwwAVP9vZg,4437
+ rslearn/models/prithvi.py,sha256=AIzcO5xk1ggR0MjbfhIzqPVgUKFN7odxygmgyAelfW8,40143
  rslearn/models/registry.py,sha256=yCcrOvLkbn07Xtln1j7hAB_kmGw0MGsiR2TloJq9Bmk,504
  rslearn/models/resize_features.py,sha256=asKXWrLHIBrU6GaAV0Ory9YuK7IK104XjhkB4ljzI3A,1289
  rslearn/models/sam2_enc.py,sha256=gNlPokr7eNxO2KvnzDMXNxYM2WRO0YkQPjR4110n6cw,3508
@@ -75,14 +75,6 @@ rslearn/models/upsample.py,sha256=3kWbyWZIk56JJxj8en9pieitbrk3XnbIsTKlEkiDQQY,93
  rslearn/models/use_croma.py,sha256=OSBqMuLp-pDtqPNWAVBfmX4wckmyYCKtUDdGCjJk_K8,17966
  rslearn/models/clay/clay.py,sha256=5RO5H8EM0tKjCwWMQ4xDkKkUCwKpm2K_Yw1alnhvVhU,7773
  rslearn/models/clay/configs/metadata.yaml,sha256=rZTFh4Yb9htEfbQNOPl4HTbFogEhzwIRqFzG-1uT01Y,4652
- rslearn/models/copernicusfm_src/__init__.py,sha256=8QLhisbHub6VJl6egijnrOPKK5QNAe5FJhfcxEelj4Y,22
- rslearn/models/copernicusfm_src/dynamic_hypernetwork.py,sha256=aWH5_PgmS8umIwRbGA42RuEx-stb13z1nBjyUhBtaN4,18218
- rslearn/models/copernicusfm_src/model_vit.py,sha256=3coM_xYILlFY2TJiACmQBSe2z16jSG80SVEad_3uB3Q,11396
- rslearn/models/copernicusfm_src/aurora/area.py,sha256=ssg9aXgoZktOsFcEXDEY9670aPUN_PHfCOfDMtpsz1s,1711
- rslearn/models/copernicusfm_src/aurora/fourier.py,sha256=bmoNV3P6CH8R6W2GFuVW8zT_frQVaL-PAgpN3aFS5fA,4414
- rslearn/models/copernicusfm_src/flexivit/patch_embed.py,sha256=EQgbsHBXDq0dTM9kApmmIqd5ZV2X9CPuA_AytbE51uM,9363
- rslearn/models/copernicusfm_src/flexivit/utils.py,sha256=tLBlzgT5bpwMSvyir46bPRWsMmRKh8s7VwMNuvSatGo,2192
- rslearn/models/copernicusfm_src/util/pos_embed.py,sha256=dUYuM_Nch2LB8jQ7UDTmFj36KWe4mM9bsY6dv5m_yZI,8511
  rslearn/models/detr/__init__.py,sha256=GGAnTIhyuvl34IRrJ_4gXjm_01OlM5rbQQ3c3TGfbK8,84
  rslearn/models/detr/box_ops.py,sha256=ORCF6EwMpMBB_VgQT05SjR47dCR2rN2gPhL_gsuUWJs,3236
  rslearn/models/detr/detr.py,sha256=otLmmyUm05e4MUyvQBoqo-RKnx3hbodTXvfPQWvuTEI,18737
@@ -93,6 +85,9 @@ rslearn/models/detr/util.py,sha256=NMHhHbkIo7PoBUVbDqa2ZknJBTswmaxFCGHrPtFKnGg,6
  rslearn/models/galileo/__init__.py,sha256=QQa0C29nuPRva0KtGiMHQ2ZB02n9SSwj_wqTKPz18NM,112
  rslearn/models/galileo/galileo.py,sha256=jUHA64YvVC3Fz5fevc_9dFJfZaINODRDrhSGLIiOZcw,21115
  rslearn/models/galileo/single_file_galileo.py,sha256=l5tlmmdr2eieHNH-M7rVIvcptkv0Fuk3vKXFW691ezA,56143
+ rslearn/models/olmoearth_pretrain/__init__.py,sha256=AjRvbjBdadCdPh-EdvySH76sVAQ8NGQaJt11Tsn1D5I,36
+ rslearn/models/olmoearth_pretrain/model.py,sha256=I_RWFbwzO5yCWpEcEQP8PeiD8M1QpeMtVrjl15evIHU,10632
+ rslearn/models/olmoearth_pretrain/norm.py,sha256=rHjFyWkpNLYMx9Ow7TsU-jGm9Sjx7FVf0p4R__ohx2c,3266
  rslearn/models/panopticon_data/sensors/drone.yaml,sha256=xqWS-_QMtJyRoWXJm-igoSur9hAmCFdqkPin8DT5qpw,431
  rslearn/models/panopticon_data/sensors/enmap.yaml,sha256=b2j6bSgYR2yKR9DRm3SPIzSVYlHf51ny_p-1B4B9sB4,13431
  rslearn/models/panopticon_data/sensors/goes.yaml,sha256=o00aoWCYqam0aB1rPmXq1MKe8hsKak_qyBG7BPL27Sc,152
@@ -114,7 +109,7 @@ rslearn/tile_stores/tile_store.py,sha256=9AeYduDYPp_Ia2NMlq6osptpz_AFGIOQcLJrqZ_
  rslearn/train/__init__.py,sha256=fnJyY4aHs5zQqbDKSfXsJZXY_M9fbTsf7dRYaPwZr2M,30
  rslearn/train/data_module.py,sha256=K-nQgnOZn-KGq_G2pVOQFtWRrlWih0212i_bkXZ2bEE,23515
  rslearn/train/dataset.py,sha256=YiskNlYYcKqZxyw0Xzop1RGLbjMc-oK_rmhrSMVbTQg,51857
- rslearn/train/lightning_module.py,sha256=ge2z8trU7cMvxBeqUXC1tB44pftzitw7DRsIa6asBS4,14623
+ rslearn/train/lightning_module.py,sha256=ZLBiId3secUlVs2yzkN-mwVv4rMdh5TkdZYl4vv_Cw0,14466
  rslearn/train/optimizer.py,sha256=EKSqkmERalDA0bF32Gey7n6z69KLyaUWKlRsGJfKBmE,927
  rslearn/train/prediction_writer.py,sha256=YNs92QqPrqbREZXoE-aPa_oKQW0C9LvZAY129vyvI08,13288
  rslearn/train/scheduler.py,sha256=wFbmycMHgL6nRYeYalDjb0G8YVo8VD3T3sABS61jJ7c,2318
@@ -124,11 +119,11 @@ rslearn/train/callbacks/freeze_unfreeze.py,sha256=8fIzBMhCKKjpTffIeAdhdSjsBd8NjT
  rslearn/train/callbacks/gradients.py,sha256=4YqCf0tBb6E5FnyFYbveXfQFlgNPyxIXb2FCWX4-6qs,5075
  rslearn/train/callbacks/peft.py,sha256=wEOKsS3RhsRaZTXn_Kz2wdsZdIiIaZPdCJWtdJBurT8,4156
  rslearn/train/tasks/__init__.py,sha256=dag1u72x1-me6y0YcOubUo5MYZ0Tjf6-dOir9UeFNMs,75
- rslearn/train/tasks/classification.py,sha256=DI0_Wzs-9rNPWokvfxi1BIA6QyqNee42SpptQx82WHM,13182
- rslearn/train/tasks/detection.py,sha256=OoZzC8ZbmhyZ30tD-4cB-3Jj0AN6Y7hg0wk27rDguCE,22297
+ rslearn/train/tasks/classification.py,sha256=kahVdXPU6fDwDCdqlrjZGb9uA-PYG74DbQQ0kJUt-Eg,13186
+ rslearn/train/tasks/detection.py,sha256=9j9webusrjGexvUmZ7gl3NTBS63Qq511VFlB2WbLi5Y,22302
  rslearn/train/tasks/multi_task.py,sha256=dBWsnbvQ0CReNsbDHmZ_-sXjUE0H4S2OPcbJwMquG9g,6016
- rslearn/train/tasks/per_pixel_regression.py,sha256=tkVntKFzPlWFxdupPlMfhIRWlJ0UCgxg_FGhcA2-wjE,8649
- rslearn/train/tasks/regression.py,sha256=_PoxOfWNseujD4IWsuTL82fAAXgtco4WdfkNXQ68Nbg,11497
+ rslearn/train/tasks/per_pixel_regression.py,sha256=W8dbLyIiPgFI3gA_aZQX0pSFRWLP2v6tthsFbKhcDVg,8783
+ rslearn/train/tasks/regression.py,sha256=zZhrrZ1qxjrdLjKWC9McRivDXCcKiYfdLC-kaMeVkDc,11547
  rslearn/train/tasks/segmentation.py,sha256=xEni3CLDyetviv84XrpJg5xeJU87WHGFKTVfIeemGIY,21868
  rslearn/train/tasks/task.py,sha256=4w2xKL_U5JAtdj2dYoVv82h6xTtgUsA3IvIOcXyZecs,3887
  rslearn/train/transforms/__init__.py,sha256=BkCAzm4f-8TEhPIuyvCj7eJGh36aMkZFYlq-H_jkSvY,778
@@ -137,7 +132,7 @@ rslearn/train/transforms/crop.py,sha256=4jA3JJsC0ghicPHbfsNJ0d3WpChyvftY73ONiwQa
  rslearn/train/transforms/flip.py,sha256=lkTeje3T8gNn2gt6957morXq1fGNho-apSpCvNp0_9o,3480
  rslearn/train/transforms/mask.py,sha256=pwt33XXWLwldLiar-PgVgBQzQd1qfL18SPz3LYQMoYM,2111
  rslearn/train/transforms/normalize.py,sha256=uyv2hE5hw5B2kCRHa4JIx0tfowm-C7bgumwINvvfyts,5014
- rslearn/train/transforms/pad.py,sha256=EDswS9KYRSloM3DQlbCz6S0WYqFQJvI433qMqTtqrZw,4686
+ rslearn/train/transforms/pad.py,sha256=pj4Ql8GSRrhg8KOZTNPB40Qq8CoCCHdGo04uficik84,4698
  rslearn/train/transforms/select_bands.py,sha256=uDfD9G8Z4VTt88QZsjj1FB20QEmzSefhKf7uDXYn77M,2441
  rslearn/train/transforms/sentinel1.py,sha256=FrLaYZs2AjqWQCun8DTFtgo1l0xLxqaFKtDNIehtpDg,1913
  rslearn/train/transforms/transform.py,sha256=n1Qzqix2dVvej-Q7iPzHeOQbqH79IBlvqPoymxhNVpE,4446
@@ -156,9 +151,10 @@ rslearn/utils/spatial_index.py,sha256=eomJAUgzmjir8j9HZnSgQoJHwN9H0wGTjmJkMkLLfs
  rslearn/utils/sqlite_index.py,sha256=YGOJi66544e6JNtfSft6YIlHklFdSJO2duxQ4TJ2iu4,2920
  rslearn/utils/time.py,sha256=2ilSLG94_sxLP3y5RSV5L5CG8CoND_dbdzYEHVtN-I8,387
  rslearn/utils/vector_format.py,sha256=EIChYCL6GLOILS2TO2JBkca1TuaWsSubWv6iRS3P2ds,16139
- rslearn-0.0.9.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
- rslearn-0.0.9.dist-info/METADATA,sha256=6BV8wt9tuo94FkoKjR3RcF3AbKNbU3IodkJtK4tASkE,36248
- rslearn-0.0.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- rslearn-0.0.9.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
- rslearn-0.0.9.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
- rslearn-0.0.9.dist-info/RECORD,,
+ rslearn-0.0.12.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
+ rslearn-0.0.12.dist-info/licenses/NOTICE,sha256=wLPr6rwV_jCg-xEknNGwhnkfRfuoOE9MZ-lru2yZyLI,5070
+ rslearn-0.0.12.dist-info/METADATA,sha256=0jHeiz1QCT56zOws1CGGFVM9TotMOWIboQmGASdZAwY,36318
+ rslearn-0.0.12.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ rslearn-0.0.12.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
+ rslearn-0.0.12.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
+ rslearn-0.0.12.dist-info/RECORD,,
@@ -0,0 +1,115 @@
+ rslearn is released under Apache License 2.0
+ Copyright 2025 Allen Institute for AI
+
+ The following third party code is included in this repository.
+
+ ====================
+
+ rslearn.models.detr is adapted from https://github.com/facebookresearch/detr which is
+ released under Apache License 2.0.
+
+ Copyright 2020 - present, Facebook, Inc
+
+ ====================
+
+ rslearn.models.use_croma is copied from https://github.com/antofuller/CROMA
+
+ MIT License
+
+ Copyright (c) 2023 Anthony Fuller
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+
+ ====================
+
+ rslearn.models.galileo is adapted from https://github.com/nasaharvest/galileo
+
+ MIT License
+
+ Copyright (c) 2024 Presto Authors
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+
+ ====================
+
+ rslearn.models.presto is adapted from https://github.com/nasaharvest/presto
+
+ MIT License
+
+ Copyright (c) 2024 Presto Authors
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+
+ ====================
+
+ rslearn.models.prithvi includes code adapted from https://github.com/NASA-IMPACT/Prithvi-WxC
+
+ MIT License
+
+ Copyright (c) 2024 Inter Agency Implementation and Advanced Concepts
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
@@ -1,228 +0,0 @@
- """Copernicus FM model."""
-
- import logging
- import math
- from enum import Enum
- from pathlib import Path
-
- import torch
- import torch.nn.functional as F
- from einops import rearrange
- from huggingface_hub import hf_hub_download
-
- from .copernicusfm_src.model_vit import vit_base_patch16
-
- logger = logging.getLogger(__name__)
-
-
- class CopernicusFMModality(Enum):
-     """Modality for Copernicus FM."""
-
-     SENTINEL2_L2A = "sentinel2_l2a"
-     SENTINEL1 = "sentinel1"
-
-
- MODALITY_TO_WAVELENGTH_BANDWIDTHS: dict[str, dict[str, list]] = {
-     # https://github.com/zhu-xlab/Copernicus-FM/blob/main/Copernicus-Bench/src/configs/dataset/cobench_eurosat_s2.yaml
-     CopernicusFMModality.SENTINEL2_L2A.value: {
-         "band_names": [
-             "B01",
-             "B02",
-             "B03",
-             "B04",
-             "B05",
-             "B06",
-             "B07",
-             "B08",
-             "B8A",
-             "B09",
-             "B10",
-             "B11",
-             "B12",
-         ],
-         "band_wavelengths": [
-             440,
-             490,
-             560,
-             665,
-             705,
-             740,
-             783,
-             842,
-             860,
-             940,
-             1370,
-             1610,
-             2190,
-         ],
-         "band_bandwidths": [20, 65, 35, 30, 15, 15, 20, 115, 20, 20, 30, 90, 180],
-     },
-     # https://github.com/zhu-xlab/Copernicus-FM/blob/main/Copernicus-Bench/src/configs/dataset/cobench_eurosat_s1.yaml
-     CopernicusFMModality.SENTINEL1.value: {
-         "band_names": ["vv", "vh"],
-         "band_wavelengths": [50000000, 50000000],
-         "band_bandwidths": [1e9, 1e9],
-     },
- }
-
- HF_REPO_ID = "wangyi111/Copernicus-FM"
- HF_REPO_REVISION = "e1db406d517a122c8373802e1c130c5fc4789f84"
- HF_FILENAME = "CopernicusFM_ViT_base_varlang_e100.pth"
-
-
- class CopernicusFM(torch.nn.Module):
-     """Wrapper for Copernicus FM to ingest Masked Helios Sample."""
-
-     image_resolution = 224
-     patch_size = 16
-     input_mode = "spectral"
-     # Don't need this as band order is provided
-     supported_modalities = [
-         CopernicusFMModality.SENTINEL2_L2A.value,
-         CopernicusFMModality.SENTINEL1.value,
-     ]
-
-     def __init__(
-         self,
-         band_order: dict[str, list[str]],
-         cache_dir: str | Path | None = None,
-     ) -> None:
-         """Initialize the Copernicus FM wrapper.
-
-         Args:
-             band_order: The band order for each modality that will be used. The bands
-                 can be provided in any order, and any subset can be used.
-             cache_dir: The directory to cache the weights. If None, a default directory
-                 managed by huggingface_hub is used. The weights are downloaded from
-                 Hugging Face (https://huggingface.co/wangyi111/Copernicus-FM).
-         """
-         super().__init__()
-
-         # Make sure all keys in band_order are in supported_modalities.
-         for modality_name in band_order.keys():
-             if modality_name in self.supported_modalities:
-                 continue
-             raise ValueError(
-                 f"band_order contains unsupported modality {modality_name}"
-             )
-
-         # global_pool=True so that we initialize the fc_norm layer
-         self.model = vit_base_patch16(num_classes=10, global_pool=True)
-
-         # Load weights, downloading if needed.
-         local_fname = hf_hub_download(
-             repo_id=HF_REPO_ID,
-             revision=HF_REPO_REVISION,
-             filename=HF_FILENAME,
-             local_dir=cache_dir,
-         ) # nosec
-         state_dict = torch.load(local_fname, weights_only=True)
-         self.model.load_state_dict(state_dict, strict=False)
-
-         # take MODALITY_TO_WAVELENGTH_BANDWIDTHS and rearrange it so that it has the same
-         # ordering as the user-provided band order.
-         self.modality_to_wavelength_bandwidths = {}
-         for modality in self.supported_modalities:
-             if modality not in band_order:
-                 continue
-
-             wavelength_bandwidths = MODALITY_TO_WAVELENGTH_BANDWIDTHS[modality]
-             wavelengths = []
-             bandwidths = []
-             for b in band_order[modality]:
-                 cfm_idx = wavelength_bandwidths["band_names"].index(b)
-                 wavelengths.append(wavelength_bandwidths["band_wavelengths"][cfm_idx])
-                 bandwidths.append(wavelength_bandwidths["band_bandwidths"][cfm_idx])
-             self.modality_to_wavelength_bandwidths[modality] = {
-                 "band_bandwidths": bandwidths,
-                 "band_wavelengths": wavelengths,
-             }
-
-     def _resize_data(self, data: torch.Tensor) -> torch.Tensor:
-         """Process individual modality data.
-
-         Args:
-             data: Input tensor of shape [B, C, H, W]
-
-         Returns:
-             list of tensors of shape [B, C, H, W]
-         """
-         # Get original dimensions
-         original_height = data.shape[2]
-         new_height = self.patch_size if original_height == 1 else self.image_resolution
-         data = F.interpolate(
-             data,
-             size=(new_height, new_height),
-             mode="bilinear",
-             align_corners=False,
-         )
-         return data
-
-     def prepare_input(
-         self,
-         inputs: dict[str, torch.Tensor],
-     ) -> tuple[torch.Tensor, list[int], list[int]]:
-         """Prepare input for the CopernicusFM model from MaskedHeliosSample."""
-         wavelengths: list[int] = []
-         bandwidths: list[int] = []
-         all_processed_data: list[list[torch.Tensor]] = []
-         for modality in inputs.keys():
-             if modality not in self.supported_modalities:
-                 logger.debug(
-                     f"Skipping modality {modality} as it is not in the supported "
-                     f"modalities list {self.supported_modalities}"
-                 )
-                 continue
-
-             data = inputs[modality]
-
-             if data is None:
-                 continue
-
-             all_processed_data.append(self._resize_data(data))
-             wavelengths.extend(
-                 self.modality_to_wavelength_bandwidths[modality]["band_wavelengths"]
-             )
-             bandwidths.extend(
-                 self.modality_to_wavelength_bandwidths[modality]["band_bandwidths"]
-             )
-
-         concatenated_processed_data = torch.cat(all_processed_data, dim=1)
-         return concatenated_processed_data, wavelengths, bandwidths
-
-     def forward(
-         self,
-         inputs: list[dict[str, torch.Tensor]],
-     ) -> torch.Tensor:
-         """Forward pass through CopernicusFM model."""
-         batch_inputs = {
-             key: torch.stack([inp[key] for inp in inputs], dim=0)
-             for key in inputs[0].keys()
-         }
-         # Prepare input
-         data, wavelengths, bandwidths = self.prepare_input(batch_inputs)
-         meta = torch.full(
-             (1, 4), float("nan"), device=data.device
-         ) # [lon, lat, delta_time, patch_token_area], assume unknown
-         # "The embed tensor contains the encoded image features, which can be used for downstream tasks."
-         _, timestep_output = self.model(
-             data,
-             meta,
-             wavelengths,
-             bandwidths,
-             None,
-             self.input_mode,
-             self.patch_size,
-         )
-         # no norm, following
-         # https://github.com/zhu-xlab/Copernicus-FM/blob/main/Copernicus-Bench/src/foundation_models/CopernicusFM/models_dwv_seg.py
-         side = math.isqrt(timestep_output.shape[1])
-         output_features = rearrange(
-             timestep_output, "b (h w) c -> b c h w ", h=side, w=side
-         )
-         return [output_features]
-
-     def get_backbone_channels(self) -> list[tuple[int, int]]:
-         """Returns the output channels of this model when used as a backbone."""
-         # TODO: load this from a constant depending on the model size
-         return [(self.patch_size, 768)]
@@ -1 +0,0 @@
- # mypy: ignore-errors
@@ -1,50 +0,0 @@
- """Copyright (c) Microsoft Corporation. Licensed under the MIT license."""
-
- import torch
-
- __all__ = ["area", "radius_earth"]
-
-
- # float: Radius of the earth in kilometers.
- radius_earth = 6378137 / 1000
-
-
- def area(polygon: torch.Tensor) -> torch.Tensor:
-     """Compute the area of a polygon specified by latitudes and longitudes in degrees.
-
-     This function is a PyTorch port of the PyPI package `area`. In particular, it is heavily
-     inspired by the following file:
-
-         https://github.com/scisco/area/blob/9d9549d6ebffcbe4bffe11b71efa2d406d1c9fe9/area/__init__.py
-
-     Args:
-         polygon (:class:`torch.Tensor`): Polygon of the shape `(*b, n, 2)` where `b` is an optional
-             multidimensional batch size, `n` is the number of points of the polygon, and 2
-             concatenates first latitudes and then longitudes. The polygon does not have be closed.
-
-     Returns:
-         :class:`torch.Tensor`: Area in square kilometers.
-     """
-     # Be sure to close the loop.
-     polygon = torch.cat((polygon, polygon[..., -1:, :]), axis=-2)
-
-     area = torch.zeros(polygon.shape[:-2], dtype=polygon.dtype, device=polygon.device)
-     n = polygon.shape[-2] # Number of points of the polygon
-
-     rad = torch.deg2rad # Convert degrees to radians.
-
-     if n > 2:
-         for i in range(n):
-             i_lower = i
-             i_middle = (i + 1) % n
-             i_upper = (i + 2) % n
-
-             lon_lower = polygon[..., i_lower, 1]
-             lat_middle = polygon[..., i_middle, 0]
-             lon_upper = polygon[..., i_upper, 1]
-
-             area = area + (rad(lon_upper) - rad(lon_lower)) * torch.sin(rad(lat_middle))
-
-     area = area * radius_earth * radius_earth / 2
-
-     return torch.abs(area)
@@ -1,134 +0,0 @@
- # type: ignore
- """Copyright (c) Microsoft Corporation. Licensed under the MIT license."""
-
- import math
-
- import numpy as np
- import torch
- import torch.nn as nn
-
- from .area import area, radius_earth
-
- __all__ = [
-     "FourierExpansion",
-     "pos_expansion",
-     "scale_expansion",
-     "lead_time_expansion",
-     "levels_expansion",
-     "absolute_time_expansion",
- ]
-
-
- class FourierExpansion(nn.Module):
-     """A Fourier series-style expansion into a high-dimensional space.
-
-     Attributes:
-         lower (float): Lower wavelength.
-         upper (float): Upper wavelength.
-         assert_range (bool): Assert that the encoded tensor is within the specified wavelength
-             range.
-     """
-
-     def __init__(self, lower: float, upper: float, assert_range: bool = True) -> None:
-         """Initialise.
-
-         Args:
-             lower (float): Lower wavelength.
-             upper (float): Upper wavelength.
-             assert_range (bool, optional): Assert that the encoded tensor is within the specified
-                 wavelength range. Defaults to `True`.
-         """
-         super().__init__()
-         self.lower = lower
-         self.upper = upper
-         self.assert_range = assert_range
-
-     def forward(self, x: torch.Tensor, d: int) -> torch.Tensor:
-         """Perform the expansion.
-
-         Adds a dimension of length `d` to the end of the shape of `x`.
-
-         Args:
-             x (:class:`torch.Tensor`): Input to expand of shape `(..., n)`. All elements of `x` must
-                 lie within `[self.lower, self.upper]` if `self.assert_range` is `True`.
-             d (int): Dimensionality. Must be a multiple of two.
-
-         Raises:
-             AssertionError: If `self.assert_range` is `True` and not all elements of `x` are not
-                 within `[self.lower, self.upper]`.
-             ValueError: If `d` is not a multiple of two.
-
-         Returns:
-             torch.Tensor: Fourier series-style expansion of `x` of shape `(..., n, d)`.
-         """
-         # If the input is not within the configured range, the embedding might be ambiguous!
-         in_range = torch.logical_and(
-             self.lower <= x.abs(), torch.all(x.abs() <= self.upper)
-         )
-         in_range_or_zero = torch.all(
-             torch.logical_or(in_range, x == 0)
-         ) # Allow zeros to pass through.
-         if self.assert_range and not in_range_or_zero:
-             raise AssertionError(
-                 f"The input tensor is not within the configured range"
-                 f" `[{self.lower}, {self.upper}]`."
-             )
-
-         # We will use half of the dimensionality for `sin` and the other half for `cos`.
-         if not (d % 2 == 0):
-             raise ValueError("The dimensionality must be a multiple of two.")
-
-         # Always perform the expansion with `float64`s to avoid numerical accuracy shenanigans.
-         x = x.double()
-
-         wavelengths = torch.logspace(
-             math.log10(self.lower),
-             math.log10(self.upper),
-             d // 2,
-             base=10,
-             device=x.device,
-             dtype=x.dtype,
-         )
-         prod = torch.einsum("...i,j->...ij", x, 2 * np.pi / wavelengths)
-         encoding = torch.cat((torch.sin(prod), torch.cos(prod)), dim=-1)
-
-         return encoding.float() # Cast to `float32` to avoid incompatibilities.
-
-
- # Determine a reasonable smallest value for the scale embedding by assuming a smallest delta in
- # latitudes and longitudes.
- _delta = 0.01 # Reasonable smallest delta in latitude and longitude
- _min_patch_area: float = area(
-     torch.tensor(
-         [
-             # The smallest patches will be at the poles. Just use the north pole.
-             [90, 0],
-             [90, _delta],
-             [90 - _delta, _delta],
-             [90 - _delta, 0],
-         ],
-         dtype=torch.float64,
-     )
- ).item()
- _area_earth = 4 * np.pi * radius_earth * radius_earth
-
- pos_expansion = FourierExpansion(_delta, 720)
-
-
- scale_expansion = FourierExpansion(_min_patch_area, _area_earth)
-
-
- lead_time_expansion = FourierExpansion(1 / 60, 24 * 7 * 3)
-
- levels_expansion = FourierExpansion(0.01, 1e5)
-
- absolute_time_expansion = FourierExpansion(1, 24 * 365.25, assert_range=False)
-
- ### new for SSL4EO-S ###
- # min wavelength: ultraviolet light (100 nm)
- # max wavelength: radio waves (1 m)
- spectrum_central_expansion = FourierExpansion(1e-7, 1)
-
- # min bandwidth: 10nm
- # max bandwidth: 1m
- spectrum_width_expansion = FourierExpansion(1e-7, 1)