rslearn 0.0.24__tar.gz → 0.0.25__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184) hide show
  1. {rslearn-0.0.24/rslearn.egg-info → rslearn-0.0.25}/PKG-INFO +14 -1
  2. {rslearn-0.0.24 → rslearn-0.0.25}/README.md +12 -0
  3. {rslearn-0.0.24 → rslearn-0.0.25}/pyproject.toml +2 -1
  4. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/dataset.py +44 -3
  5. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/tasks/detection.py +1 -18
  6. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/tasks/segmentation.py +1 -19
  7. rslearn-0.0.25/rslearn/utils/colors.py +20 -0
  8. rslearn-0.0.25/rslearn/vis/__init__.py +1 -0
  9. rslearn-0.0.25/rslearn/vis/normalization.py +127 -0
  10. rslearn-0.0.25/rslearn/vis/render_raster_label.py +96 -0
  11. rslearn-0.0.25/rslearn/vis/render_sensor_image.py +27 -0
  12. rslearn-0.0.25/rslearn/vis/render_vector_label.py +439 -0
  13. rslearn-0.0.25/rslearn/vis/utils.py +99 -0
  14. rslearn-0.0.25/rslearn/vis/vis_server.py +574 -0
  15. {rslearn-0.0.24 → rslearn-0.0.25/rslearn.egg-info}/PKG-INFO +14 -1
  16. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn.egg-info/SOURCES.txt +9 -1
  17. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn.egg-info/requires.txt +1 -0
  18. {rslearn-0.0.24 → rslearn-0.0.25}/LICENSE +0 -0
  19. {rslearn-0.0.24 → rslearn-0.0.25}/NOTICE +0 -0
  20. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/__init__.py +0 -0
  21. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/arg_parser.py +0 -0
  22. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/config/__init__.py +0 -0
  23. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/config/dataset.py +0 -0
  24. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/const.py +0 -0
  25. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/data_sources/__init__.py +0 -0
  26. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/data_sources/aws_landsat.py +0 -0
  27. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/data_sources/aws_open_data.py +0 -0
  28. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/data_sources/aws_sentinel1.py +0 -0
  29. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/data_sources/aws_sentinel2_element84.py +0 -0
  30. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/data_sources/climate_data_store.py +0 -0
  31. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/data_sources/copernicus.py +0 -0
  32. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/data_sources/data_source.py +0 -0
  33. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/data_sources/earthdaily.py +0 -0
  34. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/data_sources/earthdata_srtm.py +0 -0
  35. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/data_sources/eurocrops.py +0 -0
  36. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/data_sources/gcp_public_data.py +0 -0
  37. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/data_sources/google_earth_engine.py +0 -0
  38. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/data_sources/local_files.py +0 -0
  39. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/data_sources/openstreetmap.py +0 -0
  40. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/data_sources/planet.py +0 -0
  41. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/data_sources/planet_basemap.py +0 -0
  42. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/data_sources/planetary_computer.py +0 -0
  43. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/data_sources/soilgrids.py +0 -0
  44. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/data_sources/stac.py +0 -0
  45. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/data_sources/usda_cdl.py +0 -0
  46. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/data_sources/usgs_landsat.py +0 -0
  47. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/data_sources/utils.py +0 -0
  48. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/data_sources/vector_source.py +0 -0
  49. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/data_sources/worldcereal.py +0 -0
  50. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/data_sources/worldcover.py +0 -0
  51. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/data_sources/worldpop.py +0 -0
  52. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/data_sources/xyz_tiles.py +0 -0
  53. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/dataset/__init__.py +0 -0
  54. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/dataset/add_windows.py +0 -0
  55. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/dataset/dataset.py +0 -0
  56. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/dataset/handler_summaries.py +0 -0
  57. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/dataset/manage.py +0 -0
  58. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/dataset/materialize.py +0 -0
  59. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/dataset/remap.py +0 -0
  60. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/dataset/storage/__init__.py +0 -0
  61. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/dataset/storage/file.py +0 -0
  62. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/dataset/storage/storage.py +0 -0
  63. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/dataset/window.py +0 -0
  64. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/lightning_cli.py +0 -0
  65. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/log_utils.py +0 -0
  66. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/main.py +0 -0
  67. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/__init__.py +0 -0
  68. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/anysat.py +0 -0
  69. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/attention_pooling.py +0 -0
  70. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/clay/clay.py +0 -0
  71. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/clay/configs/metadata.yaml +0 -0
  72. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/clip.py +0 -0
  73. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/component.py +0 -0
  74. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/concatenate_features.py +0 -0
  75. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/conv.py +0 -0
  76. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/croma.py +0 -0
  77. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/detr/__init__.py +0 -0
  78. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/detr/box_ops.py +0 -0
  79. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/detr/detr.py +0 -0
  80. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/detr/matcher.py +0 -0
  81. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/detr/position_encoding.py +0 -0
  82. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/detr/transformer.py +0 -0
  83. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/detr/util.py +0 -0
  84. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/dinov3.py +0 -0
  85. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/faster_rcnn.py +0 -0
  86. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/feature_center_crop.py +0 -0
  87. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/fpn.py +0 -0
  88. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/galileo/__init__.py +0 -0
  89. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/galileo/galileo.py +0 -0
  90. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/galileo/single_file_galileo.py +0 -0
  91. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/module_wrapper.py +0 -0
  92. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/molmo.py +0 -0
  93. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/multitask.py +0 -0
  94. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/olmoearth_pretrain/__init__.py +0 -0
  95. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/olmoearth_pretrain/model.py +0 -0
  96. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/olmoearth_pretrain/norm.py +0 -0
  97. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/panopticon.py +0 -0
  98. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/panopticon_data/sensors/drone.yaml +0 -0
  99. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/panopticon_data/sensors/enmap.yaml +0 -0
  100. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/panopticon_data/sensors/goes.yaml +0 -0
  101. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/panopticon_data/sensors/himawari.yaml +0 -0
  102. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/panopticon_data/sensors/intuition.yaml +0 -0
  103. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/panopticon_data/sensors/landsat8.yaml +0 -0
  104. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/panopticon_data/sensors/modis_terra.yaml +0 -0
  105. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +0 -0
  106. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/panopticon_data/sensors/sentinel1.yaml +0 -0
  107. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/panopticon_data/sensors/sentinel2.yaml +0 -0
  108. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/panopticon_data/sensors/superdove.yaml +0 -0
  109. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/panopticon_data/sensors/wv23.yaml +0 -0
  110. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/pick_features.py +0 -0
  111. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/pooling_decoder.py +0 -0
  112. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/presto/__init__.py +0 -0
  113. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/presto/presto.py +0 -0
  114. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/presto/single_file_presto.py +0 -0
  115. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/prithvi.py +0 -0
  116. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/resize_features.py +0 -0
  117. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/sam2_enc.py +0 -0
  118. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/satlaspretrain.py +0 -0
  119. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/simple_time_series.py +0 -0
  120. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/singletask.py +0 -0
  121. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/ssl4eo_s12.py +0 -0
  122. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/swin.py +0 -0
  123. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/task_embedding.py +0 -0
  124. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/terramind.py +0 -0
  125. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/trunk.py +0 -0
  126. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/unet.py +0 -0
  127. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/upsample.py +0 -0
  128. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/models/use_croma.py +0 -0
  129. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/py.typed +0 -0
  130. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/template_params.py +0 -0
  131. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/tile_stores/__init__.py +0 -0
  132. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/tile_stores/default.py +0 -0
  133. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/tile_stores/tile_store.py +0 -0
  134. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/__init__.py +0 -0
  135. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/all_patches_dataset.py +0 -0
  136. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/callbacks/__init__.py +0 -0
  137. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/callbacks/adapters.py +0 -0
  138. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/callbacks/freeze_unfreeze.py +0 -0
  139. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/callbacks/gradients.py +0 -0
  140. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/callbacks/peft.py +0 -0
  141. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/data_module.py +0 -0
  142. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/lightning_module.py +0 -0
  143. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/model_context.py +0 -0
  144. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/optimizer.py +0 -0
  145. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/prediction_writer.py +0 -0
  146. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/scheduler.py +0 -0
  147. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/tasks/__init__.py +0 -0
  148. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/tasks/classification.py +0 -0
  149. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/tasks/embedding.py +0 -0
  150. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/tasks/multi_task.py +0 -0
  151. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/tasks/per_pixel_regression.py +0 -0
  152. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/tasks/regression.py +0 -0
  153. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/tasks/task.py +0 -0
  154. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/transforms/__init__.py +0 -0
  155. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/transforms/concatenate.py +0 -0
  156. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/transforms/crop.py +0 -0
  157. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/transforms/flip.py +0 -0
  158. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/transforms/mask.py +0 -0
  159. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/transforms/normalize.py +0 -0
  160. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/transforms/pad.py +0 -0
  161. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/transforms/resize.py +0 -0
  162. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/transforms/select_bands.py +0 -0
  163. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/transforms/sentinel1.py +0 -0
  164. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/train/transforms/transform.py +0 -0
  165. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/utils/__init__.py +0 -0
  166. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/utils/array.py +0 -0
  167. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/utils/feature.py +0 -0
  168. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/utils/fsspec.py +0 -0
  169. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/utils/geometry.py +0 -0
  170. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/utils/get_utm_ups_crs.py +0 -0
  171. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/utils/grid_index.py +0 -0
  172. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/utils/jsonargparse.py +0 -0
  173. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/utils/mp.py +0 -0
  174. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/utils/raster_format.py +0 -0
  175. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/utils/rtree_index.py +0 -0
  176. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/utils/spatial_index.py +0 -0
  177. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/utils/sqlite_index.py +0 -0
  178. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/utils/stac.py +0 -0
  179. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/utils/time.py +0 -0
  180. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn/utils/vector_format.py +0 -0
  181. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn.egg-info/dependency_links.txt +0 -0
  182. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn.egg-info/entry_points.txt +0 -0
  183. {rslearn-0.0.24 → rslearn-0.0.25}/rslearn.egg-info/top_level.txt +0 -0
  184. {rslearn-0.0.24 → rslearn-0.0.25}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rslearn
3
- Version: 0.0.24
3
+ Version: 0.0.25
4
4
  Summary: A library for developing remote sensing datasets and models
5
5
  Author: OlmoEarth Team
6
6
  License: Apache License
@@ -214,6 +214,7 @@ License-File: LICENSE
214
214
  License-File: NOTICE
215
215
  Requires-Dist: boto3>=1.39
216
216
  Requires-Dist: fiona>=1.10
217
+ Requires-Dist: flask>=3.0.0
217
218
  Requires-Dist: fsspec>=2025.10.0
218
219
  Requires-Dist: jsonargparse>=4.35.0
219
220
  Requires-Dist: lightning>=2.5.1.post0
@@ -482,6 +483,18 @@ We can visualize both the GeoTIFFs together in qgis:
482
483
  qgis $DATASET_PATH/windows/default/seattle_54912_-527360/layers/*/*/geotiff.tif
483
484
  ```
484
485
 
486
+ We can also visualize samples using the visualization module:
487
+ ```
488
+ python -m rslearn.vis.vis_server \
489
+ $DATASET_PATH \
490
+ --layers sentinel2 \ # image modality layers
491
+ --label_layers label_raster \ # label layers
492
+ --bands '{"sentinel2": ["B04", "B03", "B02"]}' \ # specify bands wanted for each image modality
493
+ --normalization '{"sentinel2": "sentinel2_rgb"}' \ # specify normalization wanted for each image modality
494
+ --task_type segmentation \ # segmentation, detection, or classification
495
+ --max_samples 100 \ # number of datapoints to randomly sample and visualize
496
+ --port 8000
497
+ ```
485
498
 
486
499
  ### Training a Model
487
500
 
@@ -216,6 +216,18 @@ We can visualize both the GeoTIFFs together in qgis:
216
216
  qgis $DATASET_PATH/windows/default/seattle_54912_-527360/layers/*/*/geotiff.tif
217
217
  ```
218
218
 
219
+ We can also visualize samples using the visualization module:
220
+ ```
221
+ python -m rslearn.vis.vis_server \
222
+ $DATASET_PATH \
223
+ --layers sentinel2 \ # image modality layers
224
+ --label_layers label_raster \ # label layers
225
+ --bands '{"sentinel2": ["B04", "B03", "B02"]}' \ # specify bands wanted for each image modality
226
+ --normalization '{"sentinel2": "sentinel2_rgb"}' \ # specify normalization wanted for each image modality
227
+ --task_type segmentation \ # segmentation, detection, or classification
228
+ --max_samples 100 \ # number of datapoints to randomly sample and visualize
229
+ --port 8000
230
+ ```
219
231
 
220
232
  ### Training a Model
221
233
 
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "rslearn"
3
- version = "0.0.24"
3
+ version = "0.0.25"
4
4
  description = "A library for developing remote sensing datasets and models"
5
5
  authors = [
6
6
  { name = "OlmoEarth Team" },
@@ -11,6 +11,7 @@ requires-python = ">=3.11"
11
11
  dependencies = [
12
12
  "boto3>=1.39",
13
13
  "fiona>=1.10",
14
+ "flask>=3.0.0",
14
15
  "fsspec>=2025.10.0", # this is used both directly and indirectly (via universal_pathlib) in our code
15
16
  "jsonargparse>=4.35.0",
16
17
  "lightning>=2.5.1.post0",
@@ -445,6 +445,7 @@ class SplitConfig:
445
445
  overlap_ratio: float | None = None,
446
446
  load_all_patches: bool | None = None,
447
447
  skip_targets: bool | None = None,
448
+ output_layer_name_skip_inference_if_exists: str | None = None,
448
449
  ) -> None:
449
450
  """Initialize a new SplitConfig.
450
451
 
@@ -467,6 +468,10 @@ class SplitConfig:
467
468
  for each window, read all patches as separate sequential items in the
468
469
  dataset.
469
470
  skip_targets: whether to skip targets when loading inputs
471
+ output_layer_name_skip_inference_if_exists: optional name of the output layer used during prediction.
472
+ If set, windows that already
473
+ have this layer completed will be skipped (useful for resuming
474
+ partial inference runs).
470
475
  """
471
476
  self.groups = groups
472
477
  self.names = names
@@ -477,6 +482,9 @@ class SplitConfig:
477
482
  self.sampler = sampler
478
483
  self.patch_size = patch_size
479
484
  self.skip_targets = skip_targets
485
+ self.output_layer_name_skip_inference_if_exists = (
486
+ output_layer_name_skip_inference_if_exists
487
+ )
480
488
 
481
489
  # Note that load_all_patches are handled by the RslearnDataModule rather than
482
490
  # the ModelDataset.
@@ -504,6 +512,7 @@ class SplitConfig:
504
512
  overlap_ratio=self.overlap_ratio,
505
513
  load_all_patches=self.load_all_patches,
506
514
  skip_targets=self.skip_targets,
515
+ output_layer_name_skip_inference_if_exists=self.output_layer_name_skip_inference_if_exists,
507
516
  )
508
517
  if other.groups:
509
518
  result.groups = other.groups
@@ -527,6 +536,10 @@ class SplitConfig:
527
536
  result.load_all_patches = other.load_all_patches
528
537
  if other.skip_targets is not None:
529
538
  result.skip_targets = other.skip_targets
539
+ if other.output_layer_name_skip_inference_if_exists is not None:
540
+ result.output_layer_name_skip_inference_if_exists = (
541
+ other.output_layer_name_skip_inference_if_exists
542
+ )
530
543
  return result
531
544
 
532
545
  def get_patch_size(self) -> tuple[int, int] | None:
@@ -549,16 +562,26 @@ class SplitConfig:
549
562
  """Returns whether skip_targets is enabled (default False)."""
550
563
  return True if self.skip_targets is True else False
551
564
 
565
+ def get_output_layer_name_skip_inference_if_exists(self) -> str | None:
566
+ """Returns output layer to use for resume checks (default None)."""
567
+ return self.output_layer_name_skip_inference_if_exists
568
+
552
569
 
553
- def check_window(inputs: dict[str, DataInput], window: Window) -> Window | None:
570
+ def check_window(
571
+ inputs: dict[str, DataInput],
572
+ window: Window,
573
+ output_layer_name_skip_inference_if_exists: str | None = None,
574
+ ) -> Window | None:
554
575
  """Verify that the window has the required layers based on the specified inputs.
555
576
 
556
577
  Args:
557
578
  inputs: the inputs to the dataset.
558
579
  window: the window to check.
580
+ output_layer_name_skip_inference_if_exists: optional name of the output layer to check for existence.
559
581
 
560
582
  Returns:
561
- the window if it has all the required inputs or None otherwise
583
+ the window if it has all the required inputs and does not need to be skipped
584
+ due to an existing output layer; or None otherwise
562
585
  """
563
586
 
564
587
  # Make sure window has all the needed layers.
@@ -588,6 +611,16 @@ def check_window(inputs: dict[str, DataInput], window: Window) -> Window | None:
588
611
  )
589
612
  return None
590
613
 
614
+ # Optionally skip windows that already have the specified output layer completed.
615
+ if output_layer_name_skip_inference_if_exists is not None:
616
+ if window.is_layer_completed(output_layer_name_skip_inference_if_exists):
617
+ logger.debug(
618
+ "Skipping window %s since output layer '%s' already exists",
619
+ window.name,
620
+ output_layer_name_skip_inference_if_exists,
621
+ )
622
+ return None
623
+
591
624
  return window
592
625
 
593
626
 
@@ -648,7 +681,14 @@ class ModelDataset(torch.utils.data.Dataset):
648
681
  new_windows = []
649
682
  if workers == 0:
650
683
  for window in windows:
651
- if check_window(self.inputs, window) is None:
684
+ if (
685
+ check_window(
686
+ self.inputs,
687
+ window,
688
+ output_layer_name_skip_inference_if_exists=self.split_config.get_output_layer_name_skip_inference_if_exists(),
689
+ )
690
+ is None
691
+ ):
652
692
  continue
653
693
  new_windows.append(window)
654
694
  else:
@@ -660,6 +700,7 @@ class ModelDataset(torch.utils.data.Dataset):
660
700
  dict(
661
701
  inputs=self.inputs,
662
702
  window=window,
703
+ output_layer_name_skip_inference_if_exists=self.split_config.get_output_layer_name_skip_inference_if_exists(),
663
704
  )
664
705
  for window in windows
665
706
  ],
@@ -14,27 +14,10 @@ from torchmetrics import Metric, MetricCollection
14
14
 
15
15
  from rslearn.train.model_context import RasterImage, SampleMetadata
16
16
  from rslearn.utils import Feature, STGeometry
17
+ from rslearn.utils.colors import DEFAULT_COLORS
17
18
 
18
19
  from .task import BasicTask
19
20
 
20
- DEFAULT_COLORS = [
21
- (255, 0, 0),
22
- (0, 255, 0),
23
- (0, 0, 255),
24
- (255, 255, 0),
25
- (0, 255, 255),
26
- (255, 0, 255),
27
- (0, 128, 0),
28
- (255, 160, 122),
29
- (139, 69, 19),
30
- (128, 128, 128),
31
- (255, 255, 255),
32
- (143, 188, 143),
33
- (95, 158, 160),
34
- (255, 200, 0),
35
- (128, 0, 0),
36
- ]
37
-
38
21
 
39
22
  class DetectionTask(BasicTask):
40
23
  """A point or bounding box detection task."""
@@ -17,28 +17,10 @@ from rslearn.train.model_context import (
17
17
  SampleMetadata,
18
18
  )
19
19
  from rslearn.utils import Feature
20
+ from rslearn.utils.colors import DEFAULT_COLORS
20
21
 
21
22
  from .task import BasicTask
22
23
 
23
- # TODO: This is duplicated code fix it
24
- DEFAULT_COLORS = [
25
- (255, 0, 0),
26
- (0, 255, 0),
27
- (0, 0, 255),
28
- (255, 255, 0),
29
- (0, 255, 255),
30
- (255, 0, 255),
31
- (0, 128, 0),
32
- (255, 160, 122),
33
- (139, 69, 19),
34
- (128, 128, 128),
35
- (255, 255, 255),
36
- (143, 188, 143),
37
- (95, 158, 160),
38
- (255, 200, 0),
39
- (128, 0, 0),
40
- ]
41
-
42
24
 
43
25
  class SegmentationTask(BasicTask):
44
26
  """A segmentation (per-pixel classification) task."""
@@ -0,0 +1,20 @@
1
"""Default color palette for visualizations."""

# RGB triples used to color class indices in rendered labels; a class with
# index i is drawn with DEFAULT_COLORS[i]. The first entry is black.
DEFAULT_COLORS = [
    (0, 0, 0),
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 0),
    (0, 255, 255),
    (255, 0, 255),
    (0, 128, 0),
    (255, 160, 122),
    (139, 69, 19),
    (128, 128, 128),
    (255, 255, 255),
    (143, 188, 143),
    (95, 158, 160),
    (255, 200, 0),
    (128, 0, 0),
]
@@ -0,0 +1 @@
1
+ """Visualization module for rslearn datasets."""
@@ -0,0 +1,127 @@
1
+ """Normalization functions for raster data visualization."""
2
+
3
+ from collections.abc import Callable
4
+ from enum import StrEnum
5
+
6
+ import numpy as np
7
+
8
+ from rslearn.log_utils import get_logger
9
+
10
+ logger = get_logger(__name__)
11
+
12
+
13
+ class NormalizationMethod(StrEnum):
14
+ """Normalization methods for raster data visualization."""
15
+
16
+ SENTINEL2_RGB = "sentinel2_rgb"
17
+ """Divide by 10 and clip (for Sentinel-2 B04/B03/B02 bands)."""
18
+
19
+ PERCENTILE = "percentile"
20
+ """Use 2-98 percentile clipping."""
21
+
22
+ MINMAX = "minmax"
23
+ """Use min-max stretch."""
24
+
25
+
26
+ def _normalize_sentinel2_rgb(band: np.ndarray) -> np.ndarray:
27
+ """Normalize band using Sentinel-2 RGB method (divide by 10 and clip).
28
+
29
+ Args:
30
+ band: Input band data
31
+
32
+ Returns:
33
+ Normalized band as uint8 array
34
+ """
35
+ band = band / 10.0
36
+ band = np.clip(band, 0, 255).astype(np.uint8)
37
+ return band
38
+
39
+
40
+ def _normalize_percentile(band: np.ndarray) -> np.ndarray:
41
+ """Normalize band using 2-98 percentile clipping.
42
+
43
+ Args:
44
+ band: Input band data
45
+
46
+ Returns:
47
+ Normalized band as uint8 array
48
+ """
49
+ valid_pixels = band[~np.isnan(band)]
50
+ if len(valid_pixels) == 0:
51
+ return np.zeros_like(band, dtype=np.uint8)
52
+ vmin, vmax = np.nanpercentile(valid_pixels, (2, 98))
53
+ if vmax == vmin:
54
+ return np.zeros_like(band, dtype=np.uint8)
55
+ band = np.clip(band, vmin, vmax)
56
+ band = ((band - vmin) / (vmax - vmin) * 255).astype(np.uint8)
57
+ return band
58
+
59
+
60
+ def _normalize_minmax(band: np.ndarray) -> np.ndarray:
61
+ """Normalize band using min-max stretch.
62
+
63
+ Args:
64
+ band: Input band data
65
+
66
+ Returns:
67
+ Normalized band as uint8 array
68
+ """
69
+ vmin, vmax = np.nanmin(band), np.nanmax(band)
70
+ if vmax == vmin:
71
+ return np.zeros_like(band, dtype=np.uint8)
72
+ band = np.clip(band, vmin, vmax)
73
+ band = ((band - vmin) / (vmax - vmin) * 255).astype(np.uint8)
74
+ return band
75
+
76
+
77
# Dispatch table: each NormalizationMethod maps to the function implementing
# it. normalize_band() looks methods up here rather than branching.
_NORMALIZATION_FUNCTIONS: dict[NormalizationMethod, Callable[[np.ndarray], np.ndarray]] = {
    NormalizationMethod.SENTINEL2_RGB: _normalize_sentinel2_rgb,
    NormalizationMethod.PERCENTILE: _normalize_percentile,
    NormalizationMethod.MINMAX: _normalize_minmax,
}
84
+
85
+
86
def normalize_band(
    band: np.ndarray, method: str | NormalizationMethod = "sentinel2_rgb"
) -> np.ndarray:
    """Normalize band to 0-255 range.

    Args:
        band: Input band data
        method: Normalization method (string or NormalizationMethod enum)
            - 'sentinel2_rgb': Divide by 10 and clip (for B04/B03/B02)
            - 'percentile': Use 2-98 percentile clipping
            - 'minmax': Use min-max stretch

    Returns:
        Normalized band as uint8 array

    Raises:
        ValueError: if the method is not recognized.
    """
    # Coerce plain strings to the enum; an unknown string raises ValueError
    # from the enum constructor.
    if isinstance(method, str):
        method_enum = NormalizationMethod(method)
    else:
        method_enum = method
    normalize_func = _NORMALIZATION_FUNCTIONS.get(method_enum)
    if normalize_func is None:
        raise ValueError(f"Unknown normalization method: {method_enum}")
    return normalize_func(band)
106
+
107
+
108
def normalize_array(
    array: np.ndarray, method: str | NormalizationMethod = "sentinel2_rgb"
) -> np.ndarray:
    """Normalize a multi-band array to 0-255 range.

    Args:
        array: Input array with shape (channels, height, width) from
            RasterFormat.decode_raster, or (height, width) for a single band
        method: Normalization method (applied per-band, string or
            NormalizationMethod enum)

    Returns:
        Normalized array as uint8 with shape (height, width, channels), or
        (height, width) for 2D input
    """
    # Single-band 2D input is one band: normalize it as a whole. (Looping
    # over the last axis, as below, would incorrectly normalize each image
    # column independently.)
    if array.ndim == 2:
        return normalize_band(array, method)

    if array.ndim == 3:
        # Move channels last: (C, H, W) -> (H, W, C).
        array = np.moveaxis(array, 0, -1)

    normalized = np.zeros_like(array, dtype=np.uint8)
    for i in range(array.shape[-1]):
        normalized[..., i] = normalize_band(array[..., i], method)

    return normalized
@@ -0,0 +1,96 @@
1
+ """Functions for rendering raster label masks (e.g., segmentation masks)."""
2
+
3
+ import numpy as np
4
+ from PIL import Image
5
+ from rasterio.warp import Resampling
6
+
7
+ from rslearn.config import DType, LayerConfig
8
+ from rslearn.dataset import Window
9
+ from rslearn.log_utils import get_logger
10
+ from rslearn.train.dataset import DataInput, read_raster_layer_for_data_input
11
+ from rslearn.utils.geometry import PixelBounds, ResolutionFactor
12
+
13
+ logger = get_logger(__name__)
14
+
15
+
16
def render_raster_label(
    label_array: np.ndarray,
    label_colors: dict[str, tuple[int, int, int]],
    layer_config: "LayerConfig",
) -> np.ndarray:
    """Render a raster label array as a colored mask numpy array.

    Args:
        label_array: Raster label array with shape (bands, height, width) -
            typically single band
        label_colors: Dictionary mapping label class names to RGB color tuples
        layer_config: LayerConfig object (to access class_names if available)

    Returns:
        Array with shape (height, width, 3) as uint8

    Raises:
        ValueError: if layer_config does not define class_names.
    """
    # Use the first band when a (bands, height, width) stack is given.
    if label_array.ndim == 3:
        label_values = label_array[0, :, :]
    else:
        label_values = label_array

    if not layer_config.class_names:
        raise ValueError(
            "class_names must be specified in config for raster label layer"
        )

    height, width = label_values.shape
    mask_img = np.zeros((height, width, 3), dtype=np.uint8)
    # NaN marks no-data pixels; they stay black via valid_mask below.
    valid_mask = ~np.isnan(label_values)

    label_int = label_values.astype(np.int32)
    for idx, class_name in enumerate(layer_config.class_names):
        # Classes without an assigned color fall back to black.
        color = label_colors.get(str(class_name), (0, 0, 0))
        mask = (label_int == idx) & valid_mask
        mask_img[mask] = color

    # mask_img is already a uint8 RGB array; the former PIL round-trip
    # (Image.fromarray followed by np.array) was a no-op allocation.
    return mask_img
54
+
55
+
56
def read_raster_layer(
    window: Window,
    layer_name: str,
    layer_config: LayerConfig,
    band_names: list[str],
    group_idx: int = 0,
    bounds: PixelBounds | None = None,
) -> np.ndarray:
    """Read a raster layer for visualization.

    This reads bands from potentially multiple band sets to get the requested bands.
    Uses read_raster_layer_for_data_input from rslearn.train.dataset.

    Args:
        window: The window to read from
        layer_name: The layer name
        layer_config: The layer configuration
        band_names: List of band names to read (e.g., ["B04", "B03", "B02"])
        group_idx: The item group index (default 0)
        bounds: Optional bounds to read. If None, uses window.bounds

    Returns:
        Array with shape (bands, height, width) as float32
    """
    read_bounds = window.bounds if bounds is None else bounds

    # Describe the read as a raster DataInput so the training-pipeline reader
    # can assemble the requested bands across multiple band sets.
    raster_input = DataInput(
        data_type="raster",
        layers=[layer_name],
        bands=band_names,
        dtype=DType.FLOAT32,
        resolution_factor=ResolutionFactor(),  # default 1/1, no rescaling
        resampling=Resampling.nearest,
    )

    tensor = read_raster_layer_for_data_input(
        window, read_bounds, layer_name, group_idx, layer_config, raster_input
    )
    return tensor.numpy().astype(np.float32)
@@ -0,0 +1,27 @@
1
+ """Functions for rendering raster sensor images (e.g., Sentinel-2, Landsat)."""
2
+
3
+ import numpy as np
4
+
5
+ from .normalization import normalize_array
6
+
7
+
8
def render_sensor_image(
    array: np.ndarray,
    normalization_method: str,
) -> np.ndarray:
    """Render a raster sensor image array as a numpy array.

    Args:
        array: Array with shape (channels, height, width) from RasterFormat.decode_raster
        normalization_method: Normalization method to apply

    Returns:
        Array with shape (height, width, channels) as uint8
    """
    rendered = normalize_array(array, normalization_method)

    # Display is RGB, so keep at most the first three channels.
    if rendered.shape[-1] > 3:
        return rendered[:, :, :3]
    return rendered