rslearn 0.0.9__tar.gz → 0.0.12__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (174) hide show
  1. rslearn-0.0.12/NOTICE +115 -0
  2. {rslearn-0.0.9/rslearn.egg-info → rslearn-0.0.12}/PKG-INFO +3 -1
  3. {rslearn-0.0.9 → rslearn-0.0.12}/pyproject.toml +3 -1
  4. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/anysat.py +5 -1
  5. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/dinov3.py +6 -1
  6. rslearn-0.0.12/rslearn/models/feature_center_crop.py +50 -0
  7. rslearn-0.0.12/rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  8. rslearn-0.0.12/rslearn/models/olmoearth_pretrain/model.py +263 -0
  9. rslearn-0.0.12/rslearn/models/olmoearth_pretrain/norm.py +84 -0
  10. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/pooling_decoder.py +43 -0
  11. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/prithvi.py +9 -1
  12. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/lightning_module.py +0 -3
  13. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/tasks/classification.py +2 -2
  14. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/tasks/detection.py +5 -5
  15. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/tasks/per_pixel_regression.py +5 -4
  16. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/tasks/regression.py +5 -5
  17. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/transforms/pad.py +3 -3
  18. {rslearn-0.0.9 → rslearn-0.0.12/rslearn.egg-info}/PKG-INFO +3 -1
  19. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn.egg-info/SOURCES.txt +5 -9
  20. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn.egg-info/requires.txt +1 -0
  21. rslearn-0.0.9/rslearn/models/copernicusfm.py +0 -228
  22. rslearn-0.0.9/rslearn/models/copernicusfm_src/__init__.py +0 -1
  23. rslearn-0.0.9/rslearn/models/copernicusfm_src/aurora/area.py +0 -50
  24. rslearn-0.0.9/rslearn/models/copernicusfm_src/aurora/fourier.py +0 -134
  25. rslearn-0.0.9/rslearn/models/copernicusfm_src/dynamic_hypernetwork.py +0 -523
  26. rslearn-0.0.9/rslearn/models/copernicusfm_src/flexivit/patch_embed.py +0 -260
  27. rslearn-0.0.9/rslearn/models/copernicusfm_src/flexivit/utils.py +0 -69
  28. rslearn-0.0.9/rslearn/models/copernicusfm_src/model_vit.py +0 -348
  29. rslearn-0.0.9/rslearn/models/copernicusfm_src/util/pos_embed.py +0 -216
  30. {rslearn-0.0.9 → rslearn-0.0.12}/LICENSE +0 -0
  31. {rslearn-0.0.9 → rslearn-0.0.12}/README.md +0 -0
  32. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/__init__.py +0 -0
  33. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/arg_parser.py +0 -0
  34. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/config/__init__.py +0 -0
  35. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/config/dataset.py +0 -0
  36. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/const.py +0 -0
  37. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/__init__.py +0 -0
  38. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/aws_landsat.py +0 -0
  39. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/aws_open_data.py +0 -0
  40. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/aws_sentinel1.py +0 -0
  41. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/climate_data_store.py +0 -0
  42. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/copernicus.py +0 -0
  43. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/data_source.py +0 -0
  44. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/earthdaily.py +0 -0
  45. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/earthdata_srtm.py +0 -0
  46. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/eurocrops.py +0 -0
  47. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/gcp_public_data.py +0 -0
  48. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/geotiff.py +0 -0
  49. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/google_earth_engine.py +0 -0
  50. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/local_files.py +0 -0
  51. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/openstreetmap.py +0 -0
  52. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/planet.py +0 -0
  53. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/planet_basemap.py +0 -0
  54. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/planetary_computer.py +0 -0
  55. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/raster_source.py +0 -0
  56. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/usda_cdl.py +0 -0
  57. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/usgs_landsat.py +0 -0
  58. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/utils.py +0 -0
  59. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/vector_source.py +0 -0
  60. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/worldcereal.py +0 -0
  61. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/worldcover.py +0 -0
  62. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/worldpop.py +0 -0
  63. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/data_sources/xyz_tiles.py +0 -0
  64. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/dataset/__init__.py +0 -0
  65. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/dataset/add_windows.py +0 -0
  66. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/dataset/dataset.py +0 -0
  67. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/dataset/handler_summaries.py +0 -0
  68. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/dataset/index.py +0 -0
  69. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/dataset/manage.py +0 -0
  70. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/dataset/materialize.py +0 -0
  71. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/dataset/remap.py +0 -0
  72. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/dataset/window.py +0 -0
  73. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/log_utils.py +0 -0
  74. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/main.py +0 -0
  75. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/__init__.py +0 -0
  76. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/clay/clay.py +0 -0
  77. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/clay/configs/metadata.yaml +0 -0
  78. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/clip.py +0 -0
  79. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/conv.py +0 -0
  80. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/croma.py +0 -0
  81. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/detr/__init__.py +0 -0
  82. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/detr/box_ops.py +0 -0
  83. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/detr/detr.py +0 -0
  84. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/detr/matcher.py +0 -0
  85. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/detr/position_encoding.py +0 -0
  86. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/detr/transformer.py +0 -0
  87. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/detr/util.py +0 -0
  88. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/faster_rcnn.py +0 -0
  89. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/fpn.py +0 -0
  90. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/galileo/__init__.py +0 -0
  91. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/galileo/galileo.py +0 -0
  92. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/galileo/single_file_galileo.py +0 -0
  93. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/module_wrapper.py +0 -0
  94. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/molmo.py +0 -0
  95. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/multitask.py +0 -0
  96. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/panopticon.py +0 -0
  97. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/panopticon_data/sensors/drone.yaml +0 -0
  98. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/panopticon_data/sensors/enmap.yaml +0 -0
  99. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/panopticon_data/sensors/goes.yaml +0 -0
  100. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/panopticon_data/sensors/himawari.yaml +0 -0
  101. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/panopticon_data/sensors/intuition.yaml +0 -0
  102. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/panopticon_data/sensors/landsat8.yaml +0 -0
  103. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/panopticon_data/sensors/modis_terra.yaml +0 -0
  104. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +0 -0
  105. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/panopticon_data/sensors/sentinel1.yaml +0 -0
  106. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/panopticon_data/sensors/sentinel2.yaml +0 -0
  107. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/panopticon_data/sensors/superdove.yaml +0 -0
  108. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/panopticon_data/sensors/wv23.yaml +0 -0
  109. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/pick_features.py +0 -0
  110. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/presto/__init__.py +0 -0
  111. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/presto/presto.py +0 -0
  112. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/presto/single_file_presto.py +0 -0
  113. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/registry.py +0 -0
  114. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/resize_features.py +0 -0
  115. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/sam2_enc.py +0 -0
  116. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/satlaspretrain.py +0 -0
  117. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/simple_time_series.py +0 -0
  118. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/singletask.py +0 -0
  119. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/ssl4eo_s12.py +0 -0
  120. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/swin.py +0 -0
  121. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/task_embedding.py +0 -0
  122. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/terramind.py +0 -0
  123. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/trunk.py +0 -0
  124. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/unet.py +0 -0
  125. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/upsample.py +0 -0
  126. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/models/use_croma.py +0 -0
  127. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/py.typed +0 -0
  128. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/template_params.py +0 -0
  129. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/tile_stores/__init__.py +0 -0
  130. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/tile_stores/default.py +0 -0
  131. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/tile_stores/tile_store.py +0 -0
  132. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/__init__.py +0 -0
  133. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/callbacks/__init__.py +0 -0
  134. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/callbacks/adapters.py +0 -0
  135. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/callbacks/freeze_unfreeze.py +0 -0
  136. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/callbacks/gradients.py +0 -0
  137. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/callbacks/peft.py +0 -0
  138. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/data_module.py +0 -0
  139. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/dataset.py +0 -0
  140. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/optimizer.py +0 -0
  141. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/prediction_writer.py +0 -0
  142. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/scheduler.py +0 -0
  143. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/tasks/__init__.py +0 -0
  144. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/tasks/multi_task.py +0 -0
  145. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/tasks/segmentation.py +0 -0
  146. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/tasks/task.py +0 -0
  147. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/transforms/__init__.py +0 -0
  148. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/transforms/concatenate.py +0 -0
  149. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/transforms/crop.py +0 -0
  150. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/transforms/flip.py +0 -0
  151. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/transforms/mask.py +0 -0
  152. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/transforms/normalize.py +0 -0
  153. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/transforms/select_bands.py +0 -0
  154. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/transforms/sentinel1.py +0 -0
  155. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/train/transforms/transform.py +0 -0
  156. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/__init__.py +0 -0
  157. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/array.py +0 -0
  158. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/feature.py +0 -0
  159. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/fsspec.py +0 -0
  160. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/geometry.py +0 -0
  161. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/get_utm_ups_crs.py +0 -0
  162. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/grid_index.py +0 -0
  163. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/jsonargparse.py +0 -0
  164. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/mp.py +0 -0
  165. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/raster_format.py +0 -0
  166. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/rtree_index.py +0 -0
  167. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/spatial_index.py +0 -0
  168. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/sqlite_index.py +0 -0
  169. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/time.py +0 -0
  170. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn/utils/vector_format.py +0 -0
  171. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn.egg-info/dependency_links.txt +0 -0
  172. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn.egg-info/entry_points.txt +0 -0
  173. {rslearn-0.0.9 → rslearn-0.0.12}/rslearn.egg-info/top_level.txt +0 -0
  174. {rslearn-0.0.9 → rslearn-0.0.12}/setup.cfg +0 -0
rslearn-0.0.12/NOTICE ADDED
@@ -0,0 +1,115 @@
1
+ rslearn is released under Apache License 2.0
2
+ Copyright 2025 Allen Institute for AI
3
+
4
+ The following third party code is included in this repository.
5
+
6
+ ====================
7
+
8
+ rslearn.models.detr is adapted from https://github.com/facebookresearch/detr which is
9
+ released under Apache License 2.0.
10
+
11
+ Copyright 2020 - present, Facebook, Inc
12
+
13
+ ====================
14
+
15
+ rslearn.models.use_croma is copied from https://github.com/antofuller/CROMA
16
+
17
+ MIT License
18
+
19
+ Copyright (c) 2023 Anthony Fuller
20
+
21
+ Permission is hereby granted, free of charge, to any person obtaining a copy
22
+ of this software and associated documentation files (the "Software"), to deal
23
+ in the Software without restriction, including without limitation the rights
24
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
25
+ copies of the Software, and to permit persons to whom the Software is
26
+ furnished to do so, subject to the following conditions:
27
+
28
+ The above copyright notice and this permission notice shall be included in all
29
+ copies or substantial portions of the Software.
30
+
31
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
34
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
36
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
37
+ SOFTWARE.
38
+
39
+ ====================
40
+
41
+ rslearn.models.galileo is adapted from https://github.com/nasaharvest/galileo
42
+
43
+ MIT License
44
+
45
+ Copyright (c) 2024 Presto Authors
46
+
47
+ Permission is hereby granted, free of charge, to any person obtaining a copy
48
+ of this software and associated documentation files (the "Software"), to deal
49
+ in the Software without restriction, including without limitation the rights
50
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
51
+ copies of the Software, and to permit persons to whom the Software is
52
+ furnished to do so, subject to the following conditions:
53
+
54
+ The above copyright notice and this permission notice shall be included in all
55
+ copies or substantial portions of the Software.
56
+
57
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
58
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
59
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
60
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
61
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
62
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
63
+ SOFTWARE.
64
+
65
+ ====================
66
+
67
+ rslearn.models.presto is adapted from https://github.com/nasaharvest/presto
68
+
69
+ MIT License
70
+
71
+ Copyright (c) 2024 Presto Authors
72
+
73
+ Permission is hereby granted, free of charge, to any person obtaining a copy
74
+ of this software and associated documentation files (the "Software"), to deal
75
+ in the Software without restriction, including without limitation the rights
76
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
77
+ copies of the Software, and to permit persons to whom the Software is
78
+ furnished to do so, subject to the following conditions:
79
+
80
+ The above copyright notice and this permission notice shall be included in all
81
+ copies or substantial portions of the Software.
82
+
83
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
84
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
85
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
86
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
87
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
88
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
89
+ SOFTWARE.
90
+
91
+ ====================
92
+
93
+ rslearn.models.prithvi includes code adapted from https://github.com/NASA-IMPACT/Prithvi-WxC
94
+
95
+ MIT License
96
+
97
+ Copyright (c) 2024 Inter Agency Implementation and Advanced Concepts
98
+
99
+ Permission is hereby granted, free of charge, to any person obtaining a copy
100
+ of this software and associated documentation files (the "Software"), to deal
101
+ in the Software without restriction, including without limitation the rights
102
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
103
+ copies of the Software, and to permit persons to whom the Software is
104
+ furnished to do so, subject to the following conditions:
105
+
106
+ The above copyright notice and this permission notice shall be included in all
107
+ copies or substantial portions of the Software.
108
+
109
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
110
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
111
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
112
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
113
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
114
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
115
+ SOFTWARE.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rslearn
3
- Version: 0.0.9
3
+ Version: 0.0.12
4
4
  Summary: A library for developing remote sensing datasets and models
5
5
  Author: OlmoEarth Team
6
6
  License: Apache License
@@ -211,6 +211,7 @@ Project-URL: repository, https://github.com/allenai/rslearn
211
211
  Requires-Python: >=3.11
212
212
  Description-Content-Type: text/markdown
213
213
  License-File: LICENSE
214
+ License-File: NOTICE
214
215
  Requires-Dist: boto3>=1.39
215
216
  Requires-Dist: fiona>=1.10
216
217
  Requires-Dist: fsspec>=2025.9.0
@@ -243,6 +244,7 @@ Requires-Dist: planetary_computer>=1.0; extra == "extra"
243
244
  Requires-Dist: pycocotools>=2.0; extra == "extra"
244
245
  Requires-Dist: pystac_client>=0.9; extra == "extra"
245
246
  Requires-Dist: rtree>=1.4; extra == "extra"
247
+ Requires-Dist: termcolor>=3.0; extra == "extra"
246
248
  Requires-Dist: satlaspretrain_models>=0.3; extra == "extra"
247
249
  Requires-Dist: scipy>=1.16; extra == "extra"
248
250
  Requires-Dist: terratorch>=1.0.2; extra == "extra"
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "rslearn"
3
- version = "0.0.9"
3
+ version = "0.0.12"
4
4
  description = "A library for developing remote sensing datasets and models"
5
5
  authors = [
6
6
  { name = "OlmoEarth Team" },
@@ -47,6 +47,8 @@ extra = [
47
47
  "pycocotools>=2.0",
48
48
  "pystac_client>=0.9",
49
49
  "rtree>=1.4",
50
+ # Needed by DINOv3.
51
+ "termcolor>=3.0",
50
52
  "satlaspretrain_models>=0.3",
51
53
  "scipy>=1.16",
52
54
  "terratorch>=1.0.2",
@@ -1,4 +1,8 @@
1
- """AnySat model."""
1
+ """AnySat model.
2
+
3
+ This code loads the AnySat model from torch hub. See
4
+ https://github.com/gastruc/AnySat for applicable license and copyright information.
5
+ """
2
6
 
3
7
  from typing import Any
4
8
 
@@ -1,4 +1,9 @@
1
- """DinoV3 model."""
1
+ """DinoV3 model.
2
+
3
+ This code loads the DINOv3 model. You must obtain the model separately from Meta to use
4
+ it. See https://github.com/facebookresearch/dinov3 for applicable license and copyright
5
+ information.
6
+ """
2
7
 
3
8
  from enum import StrEnum
4
9
  from pathlib import Path
@@ -0,0 +1,50 @@
1
+ """Apply center cropping on a feature map."""
2
+
3
+ from typing import Any
4
+
5
+ import torch
6
+
7
+
8
+ class FeatureCenterCrop(torch.nn.Module):
9
+ """Apply center cropping on the input feature maps."""
10
+
11
+ def __init__(
12
+ self,
13
+ sizes: list[tuple[int, int]],
14
+ ) -> None:
15
+ """Create a new FeatureCenterCrop.
16
+
17
+ Only the center of each feature map will be retained and passed to the next
18
+ module.
19
+
20
+ Args:
21
+ sizes: a list of (height, width) tuples, with one tuple for each input
22
+ feature map.
23
+ """
24
+ super().__init__()
25
+ self.sizes = sizes
26
+
27
+ def forward(
28
+ self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
29
+ ) -> list[torch.Tensor]:
30
+ """Apply center cropping on the feature maps.
31
+
32
+ Args:
33
+ features: list of feature maps at different resolutions.
34
+ inputs: original inputs (ignored).
35
+
36
+ Returns:
37
+ center cropped feature maps.
38
+ """
39
+ new_features = []
40
+ for i, feat in enumerate(features):
41
+ height, width = self.sizes[i]
42
+ if feat.shape[2] < height or feat.shape[3] < width:
43
+ raise ValueError(
44
+ "feature map is smaller than the desired height and width"
45
+ )
46
+ start_h = feat.shape[2] // 2 - height // 2
47
+ start_w = feat.shape[3] // 2 - width // 2
48
+ feat = feat[:, :, start_h : start_h + height, start_w : start_w + width]
49
+ new_features.append(feat)
50
+ return new_features
@@ -0,0 +1 @@
1
+ """OlmoEarth model architecture."""
@@ -0,0 +1,263 @@
1
+ """OlmoEarth model wrapper for fine-tuning in rslearn."""
2
+
3
+ import json
4
+ from contextlib import nullcontext
5
+ from typing import Any
6
+
7
+ import torch
8
+ from einops import rearrange
9
+ from olmo_core.config import Config
10
+ from olmo_core.distributed.checkpoint import load_model_and_optim_state
11
+ from olmoearth_pretrain.data.constants import Modality
12
+ from olmoearth_pretrain.model_loader import (
13
+ ModelID,
14
+ load_model_from_id,
15
+ load_model_from_path,
16
+ )
17
+ from olmoearth_pretrain.nn.flexihelios import Encoder, TokensAndMasks
18
+ from olmoearth_pretrain.train.masking import MaskedOlmoEarthSample, MaskValue
19
+ from upath import UPath
20
+
21
+ from rslearn.log_utils import get_logger
22
+
23
+ logger = get_logger(__name__)
24
+
25
+ MODALITY_NAMES = [
26
+ "sentinel2_l2a",
27
+ "sentinel1",
28
+ "worldcover",
29
+ "openstreetmap_raster",
30
+ "landsat",
31
+ ]
32
+
33
+ AUTOCAST_DTYPE_MAP = {
34
+ "bfloat16": torch.bfloat16,
35
+ "float16": torch.float16,
36
+ "float32": torch.float32,
37
+ }
38
+
39
+ EMBEDDING_SIZES = {
40
+ ModelID.OLMOEARTH_V1_NANO: 128,
41
+ ModelID.OLMOEARTH_V1_TINY: 192,
42
+ ModelID.OLMOEARTH_V1_BASE: 768,
43
+ }
44
+
45
+
46
+ class OlmoEarth(torch.nn.Module):
47
+ """A wrapper to support the OlmoEarth model."""
48
+
49
+ def __init__(
50
+ self,
51
+ patch_size: int,
52
+ model_id: ModelID | None = None,
53
+ model_path: str | None = None,
54
+ checkpoint_path: str | None = None,
55
+ selector: list[str | int] = ["encoder"],
56
+ forward_kwargs: dict[str, Any] = {},
57
+ random_initialization: bool = False,
58
+ embedding_size: int | None = None,
59
+ autocast_dtype: str | None = "bfloat16",
60
+ ):
61
+ """Create a new OlmoEarth model.
62
+
63
+ Args:
64
+ patch_size: token spatial patch size to use.
65
+ model_id: the model ID to load. One of model_id or model_path or checkpoint_path must be
66
+ set.
67
+ model_path: the path to load the model from. One of model_id or model_path or checkpoint_path must be
68
+ set. Same structure as the HF-hosted `model_id` models: bundle with a config.json and weights.pth.
69
+ checkpoint_path: the checkpoint directory to load from, if model_id or model_path is not
70
+ set. It should contain a distributed checkpoint with a config.json file as well as model_and_optim
71
+ folder.
72
+ selector: an optional sequence of attribute names or list indices to select
73
+ the sub-module that should be applied on the input images. Defaults to
74
+ ["encoder"] to select only the transformer encoder.
75
+ forward_kwargs: additional arguments to pass to forward pass besides the
76
+ MaskedOlmoEarthSample.
77
+ random_initialization: whether to skip loading the checkpoint so the
78
+ weights are randomly initialized. In this case, the checkpoint is only
79
+ used to define the model architecture.
80
+ embedding_size: optional embedding size to report via
81
+ get_backbone_channels (if model_id is not set).
82
+ autocast_dtype: which dtype to use for autocasting, or set None to disable.
83
+ """
84
+ if (
85
+ sum(
86
+ [
87
+ model_id is not None,
88
+ model_path is not None,
89
+ checkpoint_path is not None,
90
+ ]
91
+ )
92
+ != 1
93
+ ):
94
+ raise ValueError(
95
+ "exactly one of model_id, model_path, or checkpoint_path must be set"
96
+ )
97
+
98
+ super().__init__()
99
+ self.patch_size = patch_size
100
+ self.forward_kwargs = forward_kwargs
101
+ self.embedding_size = embedding_size
102
+
103
+ if autocast_dtype is not None:
104
+ self.autocast_dtype = AUTOCAST_DTYPE_MAP[autocast_dtype]
105
+ else:
106
+ self.autocast_dtype = None
107
+
108
+ if model_id is not None:
109
+ # Load from Hugging Face.
110
+ model = load_model_from_id(model_id, load_weights=not random_initialization)
111
+ if self.embedding_size is None and model_id in EMBEDDING_SIZES:
112
+ self.embedding_size = EMBEDDING_SIZES[model_id]
113
+
114
+ elif model_path is not None:
115
+ # Load from path.
116
+ model = load_model_from_path(
117
+ UPath(model_path), load_weights=not random_initialization
118
+ )
119
+
120
+ else:
121
+ # Load the distributed model checkpoint by path through Olmo Core
122
+ model = self._load_model_from_checkpoint(
123
+ UPath(checkpoint_path), random_initialization
124
+ )
125
+
126
+ # Select just the portion of the model that we actually want to use.
127
+ for part in selector:
128
+ if isinstance(part, str):
129
+ model = getattr(model, part)
130
+ else:
131
+ model = model[part]
132
+ self.model = model
133
+
134
+ def _load_model_from_checkpoint(
135
+ self, checkpoint_upath: UPath, random_initialization: bool
136
+ ) -> torch.nn.Module:
137
+ """Load the OlmoEarth pre-trained model from a distributed checkpoint folder.
138
+
139
+ The folder should contain config.json as well as the model_and_optim folder
140
+ that contains the distributed checkpoint. This is the format produced by
141
+ pre-training runs in olmoearth_pretrain.
142
+ """
143
+ # Load the model config and initialize it.
144
+ # We avoid loading the train module here because it depends on running within
145
+ # olmo_core.
146
+ with (checkpoint_upath / "config.json").open() as f:
147
+ config_dict = json.load(f)
148
+ model_config = Config.from_dict(config_dict["model"])
149
+
150
+ model = model_config.build()
151
+
152
+ # Load the checkpoint.
153
+ if not random_initialization:
154
+ train_module_dir = checkpoint_upath / "model_and_optim"
155
+ if train_module_dir.exists():
156
+ load_model_and_optim_state(str(train_module_dir), model)
157
+ logger.info(f"loaded OlmoEarth encoder from {train_module_dir}")
158
+ else:
159
+ logger.info(f"could not find OlmoEarth encoder at {train_module_dir}")
160
+
161
+ return model
162
+
163
+ def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
164
+ """Compute feature maps from the OlmoEarth backbone.
165
+
166
+ Inputs:
167
+ inputs: input dicts. It should include keys corresponding to the modalities
168
+ that should be passed to the OlmoEarth model.
169
+ """
170
+ kwargs = {}
171
+ present_modalities = []
172
+ device = None
173
+ # Handle the case where some modalities are multitemporal and some are not.
174
+ # We assume all multitemporal modalities have the same number of timesteps.
175
+ max_timesteps = 1
176
+ for modality in MODALITY_NAMES:
177
+ if modality not in inputs[0]:
178
+ continue
179
+ present_modalities.append(modality)
180
+ cur = torch.stack([inp[modality] for inp in inputs], dim=0)
181
+ device = cur.device
182
+ # Check if it's single or multitemporal, and reshape accordingly
183
+ num_bands = Modality.get(modality).num_bands
184
+ num_timesteps = cur.shape[1] // num_bands
185
+ max_timesteps = max(max_timesteps, num_timesteps)
186
+ cur = rearrange(cur, "b (t c) h w -> b h w t c", t=num_timesteps)
187
+ kwargs[modality] = cur
188
+ # Create mask array which is BHWTS (without channels but with band sets).
189
+ num_band_sets = len(Modality.get(modality).band_sets)
190
+ mask_shape = cur.shape[0:4] + (num_band_sets,)
191
+ mask = (
192
+ torch.ones(mask_shape, dtype=torch.int32, device=device)
193
+ * MaskValue.ONLINE_ENCODER.value
194
+ )
195
+ kwargs[f"{modality}_mask"] = mask
196
+
197
+ # Timestamps is required.
198
+ # Note that only months (0 to 11) are used in OlmoEarth position encoding.
199
+ # For now, we assign same timestamps to all inputs, but later we should handle varying timestamps per input.
200
+ timestamps = torch.zeros(
201
+ (len(inputs), max_timesteps, 3), dtype=torch.int32, device=device
202
+ )
203
+ timestamps[:, :, 0] = 1 # day
204
+ timestamps[:, :, 1] = torch.arange(max_timesteps, device=device)[
205
+ None, :
206
+ ] # month
207
+ timestamps[:, :, 2] = 2024 # year
208
+ kwargs["timestamps"] = timestamps
209
+
210
+ sample = MaskedOlmoEarthSample(**kwargs)
211
+
212
+ # Decide context based on self.autocast_dtype.
213
+ if self.autocast_dtype is None:
214
+ context = nullcontext()
215
+ else:
216
+ assert device is not None
217
+ context = torch.amp.autocast(
218
+ device_type=device.type, dtype=self.autocast_dtype
219
+ )
220
+
221
+ with context:
222
+ # Currently we assume the provided model always returns a TokensAndMasks object.
223
+ tokens_and_masks: TokensAndMasks
224
+ if isinstance(self.model, Encoder):
225
+ # Encoder has a fast_pass argument to indicate mask is not needed.
226
+ tokens_and_masks = self.model(
227
+ sample,
228
+ fast_pass=True,
229
+ patch_size=self.patch_size,
230
+ **self.forward_kwargs,
231
+ )["tokens_and_masks"]
232
+ else:
233
+ # Other models like STEncoder do not have this option supported.
234
+ tokens_and_masks = self.model(
235
+ sample, patch_size=self.patch_size, **self.forward_kwargs
236
+ )["tokens_and_masks"]
237
+
238
+ # Apply temporal/modality pooling so we just have one feature per patch.
239
+ features = []
240
+ for modality in present_modalities:
241
+ modality_features = getattr(tokens_and_masks, modality)
242
+ # Pool over band sets and timesteps (BHWTSC -> BHWC).
243
+ pooled = modality_features.mean(dim=[3, 4])
244
+ # We want BHWC -> BCHW.
245
+ pooled = rearrange(pooled, "b h w c -> b c h w")
246
+ features.append(pooled)
247
+ # Pool over the modalities, so we get one BCHW feature map.
248
+ pooled = torch.stack(features, dim=0).mean(dim=0)
249
+ return [pooled]
250
+
251
+ def get_backbone_channels(self) -> list:
252
+ """Returns the output channels of this model when used as a backbone.
253
+
254
+ The output channels is a list of (downsample_factor, depth) that corresponds
255
+ to the feature maps that the backbone returns. For example, an element [2, 32]
256
+ indicates that the corresponding feature map is 1/2 the input resolution and
257
+ has 32 channels.
258
+
259
+ Returns:
260
+ the output channels of the backbone as a list of (downsample_factor, depth)
261
+ tuples.
262
+ """
263
+ return [(self.patch_size, self.embedding_size)]
@@ -0,0 +1,84 @@
1
+ """Normalization transforms."""
2
+
3
+ import json
4
+ from typing import Any
5
+
6
+ from olmoearth_pretrain.data.normalize import load_computed_config
7
+
8
+ from rslearn.log_utils import get_logger
9
+ from rslearn.train.transforms.transform import Transform
10
+
11
+ logger = get_logger(__file__)
12
+
13
+
14
+ class OlmoEarthNormalize(Transform):
15
+ """Normalize using OlmoEarth JSON config.
16
+
17
+ For Sentinel-1 data, the values should be converted to decibels before being passed
18
+ to this transform.
19
+ """
20
+
21
+ def __init__(
22
+ self,
23
+ band_names: dict[str, list[str]],
24
+ std_multiplier: float | None = 2,
25
+ config_fname: str | None = None,
26
+ ) -> None:
27
+ """Initialize a new OlmoEarthNormalize.
28
+
29
+ Args:
30
+ band_names: map from modality name to the list of bands in that modality in
31
+ the order they are being loaded. Note that this order must match the
32
+ expected order for the OlmoEarth model.
33
+ std_multiplier: the std multiplier matching the one used for the model
34
+ training in OlmoEarth.
35
+ config_fname: load the normalization configuration from this file, instead
36
+ of getting it from OlmoEarth.
37
+ """
38
+ super().__init__()
39
+ self.band_names = band_names
40
+ self.std_multiplier = std_multiplier
41
+
42
+ if config_fname is None:
43
+ self.norm_config = load_computed_config()
44
+ else:
45
+ logger.warning(
46
+ f"Loading normalization config from {config_fname}. This argument is deprecated and will be removed in a future version."
47
+ )
48
+ with open(config_fname) as f:
49
+ self.norm_config = json.load(f)
50
+
51
+ def forward(
52
+ self, input_dict: dict[str, Any], target_dict: dict[str, Any]
53
+ ) -> tuple[dict[str, Any], dict[str, Any]]:
54
+ """Apply normalization over the inputs and targets.
55
+
56
+ Args:
57
+ input_dict: the input
58
+ target_dict: the target
59
+
60
+ Returns:
61
+ normalized (input_dict, target_dict) tuple
62
+ """
63
+ for modality_name, cur_band_names in self.band_names.items():
64
+ band_norms = self.norm_config[modality_name]
65
+ image = input_dict[modality_name]
66
+ # Keep a set of indices to make sure that we normalize all of them.
67
+ needed_band_indices = set(range(image.shape[0]))
68
+ num_timesteps = image.shape[0] // len(cur_band_names)
69
+
70
+ for band, norm_dict in band_norms.items():
71
+ # If multitemporal, normalize each timestep separately.
72
+ for t in range(num_timesteps):
73
+ band_idx = cur_band_names.index(band) + t * len(cur_band_names)
74
+ min_val = norm_dict["mean"] - self.std_multiplier * norm_dict["std"]
75
+ max_val = norm_dict["mean"] + self.std_multiplier * norm_dict["std"]
76
+ image[band_idx] = (image[band_idx] - min_val) / (max_val - min_val)
77
+ needed_band_indices.remove(band_idx)
78
+
79
+ if len(needed_band_indices) > 0:
80
+ raise ValueError(
81
+ f"for modality {modality_name}, bands {needed_band_indices} were unexpectedly not normalized"
82
+ )
83
+
84
+ return input_dict, target_dict
@@ -76,3 +76,46 @@ class PoolingDecoder(torch.nn.Module):
76
76
  features = torch.amax(features, dim=(2, 3))
77
77
  features = self.fc_layers(features)
78
78
  return self.output_layer(features)
79
+
80
+
81
+ class SegmentationPoolingDecoder(PoolingDecoder):
82
+ """Like PoolingDecoder, but copy output to all pixels.
83
+
84
+ This allows the model to produce a global output while still being compatible
85
+ with SegmentationTask. This only makes sense for very small windows, since the
86
+ output probabilities will be the same at all pixels. The main use case is to train
87
+ for a classification-like task on small windows, but still produce a raster during
88
+ inference on large windows.
89
+ """
90
+
91
+ def __init__(
92
+ self,
93
+ in_channels: int,
94
+ out_channels: int,
95
+ image_key: str = "image",
96
+ **kwargs: Any,
97
+ ):
98
+ """Create a new SegmentationPoolingDecoder.
99
+
100
+ Args:
101
+ in_channels: input channels (channels in the last feature map passed to
102
+ this module)
103
+ out_channels: channels for the output flat feature vector
104
+ image_key: the key in inputs for the image from which the expected width
105
+ and height are derived.
106
+ kwargs: other arguments to pass to PoolingDecoder.
107
+ """
108
+ super().__init__(in_channels=in_channels, out_channels=out_channels, **kwargs)
109
+ self.image_key = image_key
110
+
111
+ def forward(
112
+ self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
113
+ ) -> torch.Tensor:
114
+ """Extend PoolingDecoder forward to upsample the output to a segmentation mask.
115
+
116
+ This only works when all of the pixels have the same segmentation target.
117
+ """
118
+ output_probs = super().forward(features, inputs)
119
+ # BC -> BCHW
120
+ h, w = inputs[0][self.image_key].shape[1:3]
121
+ return output_probs[:, :, None, None].repeat([1, 1, h, w])
@@ -1,4 +1,12 @@
1
- """Prithvi V2."""
1
+ """Prithvi V2.
2
+
3
+ This code is adapted from https://github.com/NASA-IMPACT/Prithvi-WxC
4
+
5
+ The code is released under:
6
+
7
+ MIT License
8
+ Copyright (c) 2024 Inter Agency Implementation and Advanced Concepts
9
+ """
2
10
 
3
11
  import json
4
12
  import logging
@@ -94,7 +94,6 @@ class RslearnLightningModule(L.LightningModule):
94
94
  restore_config: RestoreConfig | None = None,
95
95
  print_parameters: bool = False,
96
96
  print_model: bool = False,
97
- strict_loading: bool = True,
98
97
  # Deprecated options.
99
98
  lr: float = 1e-3,
100
99
  plateau: bool = False,
@@ -118,7 +117,6 @@ class RslearnLightningModule(L.LightningModule):
118
117
  print_parameters: whether to print the list of model parameters after model
119
118
  initialization
120
119
  print_model: whether to print the model after model initialization
121
- strict_loading: whether to strictly load the model parameters.
122
120
  lr: deprecated.
123
121
  plateau: deprecated.
124
122
  plateau_factor: deprecated.
@@ -132,7 +130,6 @@ class RslearnLightningModule(L.LightningModule):
132
130
  self.visualize_dir = visualize_dir
133
131
  self.metrics_file = metrics_file
134
132
  self.restore_config = restore_config
135
- self.strict_loading = strict_loading
136
133
 
137
134
  self.scheduler_factory: SchedulerFactory | None = None
138
135
  if scheduler:
@@ -49,8 +49,8 @@ class ClassificationTask(BasicTask):
49
49
  features with matching properties.
50
50
  read_class_id: whether to read an integer class ID instead of the class
51
51
  name.
52
- allow_invalid: instead of throwing error when no regression label is found
53
- at a window, simply mark the example invalid for this task
52
+ allow_invalid: instead of throwing error when no classification label is
53
+ found at a window, simply mark the example invalid for this task
54
54
  skip_unknown_categories: whether to skip examples with categories that are
55
55
  not passed via classes, instead of throwing error
56
56
  prob_property: when predicting, write probabilities in addition to class ID