rslearn 0.0.27__tar.gz → 0.0.28__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (190) hide show
  1. {rslearn-0.0.27/rslearn.egg-info → rslearn-0.0.28}/PKG-INFO +1 -1
  2. {rslearn-0.0.27 → rslearn-0.0.28}/pyproject.toml +1 -1
  3. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/dataset/storage/file.py +16 -12
  4. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/tile_stores/default.py +4 -2
  5. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/data_module.py +10 -7
  6. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/dataset.py +118 -74
  7. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/lightning_module.py +59 -3
  8. rslearn-0.0.28/rslearn/train/metrics.py +162 -0
  9. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/tasks/classification.py +13 -0
  10. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/tasks/per_pixel_regression.py +19 -6
  11. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/tasks/regression.py +18 -2
  12. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/tasks/segmentation.py +17 -0
  13. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/utils/fsspec.py +51 -1
  14. {rslearn-0.0.27 → rslearn-0.0.28/rslearn.egg-info}/PKG-INFO +1 -1
  15. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn.egg-info/SOURCES.txt +1 -0
  16. {rslearn-0.0.27 → rslearn-0.0.28}/LICENSE +0 -0
  17. {rslearn-0.0.27 → rslearn-0.0.28}/NOTICE +0 -0
  18. {rslearn-0.0.27 → rslearn-0.0.28}/README.md +0 -0
  19. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/__init__.py +0 -0
  20. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/arg_parser.py +0 -0
  21. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/config/__init__.py +0 -0
  22. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/config/dataset.py +0 -0
  23. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/const.py +0 -0
  24. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/__init__.py +0 -0
  25. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/aws_landsat.py +0 -0
  26. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/aws_open_data.py +0 -0
  27. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/aws_sentinel1.py +0 -0
  28. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/aws_sentinel2_element84.py +0 -0
  29. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/climate_data_store.py +0 -0
  30. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/copernicus.py +0 -0
  31. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/data_source.py +0 -0
  32. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/direct_materialize_data_source.py +0 -0
  33. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/earthdaily.py +0 -0
  34. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/earthdatahub.py +0 -0
  35. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/eurocrops.py +0 -0
  36. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/gcp_public_data.py +0 -0
  37. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/google_earth_engine.py +0 -0
  38. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/hf_srtm.py +0 -0
  39. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/local_files.py +0 -0
  40. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/openstreetmap.py +0 -0
  41. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/planet.py +0 -0
  42. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/planet_basemap.py +0 -0
  43. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/planetary_computer.py +0 -0
  44. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/soilgrids.py +0 -0
  45. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/stac.py +0 -0
  46. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/usda_cdl.py +0 -0
  47. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/usgs_landsat.py +0 -0
  48. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/utils.py +0 -0
  49. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/vector_source.py +0 -0
  50. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/worldcereal.py +0 -0
  51. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/worldcover.py +0 -0
  52. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/worldpop.py +0 -0
  53. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/data_sources/xyz_tiles.py +0 -0
  54. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/dataset/__init__.py +0 -0
  55. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/dataset/add_windows.py +0 -0
  56. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/dataset/dataset.py +0 -0
  57. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/dataset/handler_summaries.py +0 -0
  58. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/dataset/manage.py +0 -0
  59. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/dataset/materialize.py +0 -0
  60. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/dataset/remap.py +0 -0
  61. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/dataset/storage/__init__.py +0 -0
  62. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/dataset/storage/storage.py +0 -0
  63. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/dataset/window.py +0 -0
  64. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/lightning_cli.py +0 -0
  65. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/log_utils.py +0 -0
  66. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/main.py +0 -0
  67. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/__init__.py +0 -0
  68. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/anysat.py +0 -0
  69. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/attention_pooling.py +0 -0
  70. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/clay/clay.py +0 -0
  71. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/clay/configs/metadata.yaml +0 -0
  72. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/clip.py +0 -0
  73. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/component.py +0 -0
  74. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/concatenate_features.py +0 -0
  75. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/conv.py +0 -0
  76. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/croma.py +0 -0
  77. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/detr/__init__.py +0 -0
  78. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/detr/box_ops.py +0 -0
  79. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/detr/detr.py +0 -0
  80. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/detr/matcher.py +0 -0
  81. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/detr/position_encoding.py +0 -0
  82. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/detr/transformer.py +0 -0
  83. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/detr/util.py +0 -0
  84. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/dinov3.py +0 -0
  85. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/faster_rcnn.py +0 -0
  86. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/feature_center_crop.py +0 -0
  87. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/fpn.py +0 -0
  88. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/galileo/__init__.py +0 -0
  89. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/galileo/galileo.py +0 -0
  90. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/galileo/single_file_galileo.py +0 -0
  91. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/module_wrapper.py +0 -0
  92. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/molmo.py +0 -0
  93. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/multitask.py +0 -0
  94. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/olmoearth_pretrain/__init__.py +0 -0
  95. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/olmoearth_pretrain/model.py +0 -0
  96. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/olmoearth_pretrain/norm.py +0 -0
  97. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/panopticon.py +0 -0
  98. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/panopticon_data/sensors/drone.yaml +0 -0
  99. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/panopticon_data/sensors/enmap.yaml +0 -0
  100. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/panopticon_data/sensors/goes.yaml +0 -0
  101. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/panopticon_data/sensors/himawari.yaml +0 -0
  102. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/panopticon_data/sensors/intuition.yaml +0 -0
  103. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/panopticon_data/sensors/landsat8.yaml +0 -0
  104. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/panopticon_data/sensors/modis_terra.yaml +0 -0
  105. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +0 -0
  106. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/panopticon_data/sensors/sentinel1.yaml +0 -0
  107. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/panopticon_data/sensors/sentinel2.yaml +0 -0
  108. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/panopticon_data/sensors/superdove.yaml +0 -0
  109. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/panopticon_data/sensors/wv23.yaml +0 -0
  110. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/pick_features.py +0 -0
  111. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/pooling_decoder.py +0 -0
  112. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/presto/__init__.py +0 -0
  113. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/presto/presto.py +0 -0
  114. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/presto/single_file_presto.py +0 -0
  115. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/prithvi.py +0 -0
  116. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/resize_features.py +0 -0
  117. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/sam2_enc.py +0 -0
  118. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/satlaspretrain.py +0 -0
  119. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/simple_time_series.py +0 -0
  120. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/singletask.py +0 -0
  121. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/ssl4eo_s12.py +0 -0
  122. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/swin.py +0 -0
  123. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/task_embedding.py +0 -0
  124. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/terramind.py +0 -0
  125. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/trunk.py +0 -0
  126. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/unet.py +0 -0
  127. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/upsample.py +0 -0
  128. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/models/use_croma.py +0 -0
  129. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/py.typed +0 -0
  130. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/template_params.py +0 -0
  131. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/tile_stores/__init__.py +0 -0
  132. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/tile_stores/tile_store.py +0 -0
  133. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/__init__.py +0 -0
  134. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/all_crops_dataset.py +0 -0
  135. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/callbacks/__init__.py +0 -0
  136. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/callbacks/adapters.py +0 -0
  137. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/callbacks/freeze_unfreeze.py +0 -0
  138. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/callbacks/gradients.py +0 -0
  139. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/callbacks/peft.py +0 -0
  140. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/dataset_index.py +0 -0
  141. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/model_context.py +0 -0
  142. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/optimizer.py +0 -0
  143. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/prediction_writer.py +0 -0
  144. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/scheduler.py +0 -0
  145. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/tasks/__init__.py +0 -0
  146. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/tasks/detection.py +0 -0
  147. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/tasks/embedding.py +0 -0
  148. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/tasks/multi_task.py +0 -0
  149. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/tasks/task.py +0 -0
  150. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/transforms/__init__.py +0 -0
  151. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/transforms/concatenate.py +0 -0
  152. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/transforms/crop.py +0 -0
  153. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/transforms/flip.py +0 -0
  154. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/transforms/mask.py +0 -0
  155. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/transforms/normalize.py +0 -0
  156. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/transforms/pad.py +0 -0
  157. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/transforms/resize.py +0 -0
  158. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/transforms/select_bands.py +0 -0
  159. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/transforms/sentinel1.py +0 -0
  160. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/train/transforms/transform.py +0 -0
  161. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/utils/__init__.py +0 -0
  162. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/utils/array.py +0 -0
  163. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/utils/colors.py +0 -0
  164. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/utils/feature.py +0 -0
  165. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/utils/geometry.py +0 -0
  166. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/utils/get_utm_ups_crs.py +0 -0
  167. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/utils/grid_index.py +0 -0
  168. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/utils/jsonargparse.py +0 -0
  169. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/utils/m2m_api.py +0 -0
  170. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/utils/mp.py +0 -0
  171. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/utils/raster_format.py +0 -0
  172. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/utils/retry_session.py +0 -0
  173. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/utils/rtree_index.py +0 -0
  174. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/utils/spatial_index.py +0 -0
  175. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/utils/sqlite_index.py +0 -0
  176. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/utils/stac.py +0 -0
  177. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/utils/time.py +0 -0
  178. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/utils/vector_format.py +0 -0
  179. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/vis/__init__.py +0 -0
  180. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/vis/normalization.py +0 -0
  181. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/vis/render_raster_label.py +0 -0
  182. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/vis/render_sensor_image.py +0 -0
  183. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/vis/render_vector_label.py +0 -0
  184. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/vis/utils.py +0 -0
  185. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn/vis/vis_server.py +0 -0
  186. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn.egg-info/dependency_links.txt +0 -0
  187. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn.egg-info/entry_points.txt +0 -0
  188. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn.egg-info/requires.txt +0 -0
  189. {rslearn-0.0.27 → rslearn-0.0.28}/rslearn.egg-info/top_level.txt +0 -0
  190. {rslearn-0.0.27 → rslearn-0.0.28}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rslearn
3
- Version: 0.0.27
3
+ Version: 0.0.28
4
4
  Summary: A library for developing remote sensing datasets and models
5
5
  Author: OlmoEarth Team
6
6
  License: Apache License
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "rslearn"
3
- version = "0.0.27"
3
+ version = "0.0.28"
4
4
  description = "A library for developing remote sensing datasets and models"
5
5
  authors = [
6
6
  { name = "OlmoEarth Team" },
@@ -15,7 +15,7 @@ from rslearn.dataset.window import (
15
15
  get_window_layer_dir,
16
16
  )
17
17
  from rslearn.log_utils import get_logger
18
- from rslearn.utils.fsspec import open_atomic
18
+ from rslearn.utils.fsspec import iter_nonhidden_subdirs, open_atomic
19
19
  from rslearn.utils.mp import star_imap_unordered
20
20
 
21
21
  from .storage import WindowStorage, WindowStorageFactory
@@ -77,8 +77,8 @@ class FileWindowStorage(WindowStorage):
77
77
  window_dirs = []
78
78
  if not groups:
79
79
  groups = []
80
- for p in (self.path / "windows").iterdir():
81
- groups.append(p.name)
80
+ for group_dir in iter_nonhidden_subdirs(self.path / "windows"):
81
+ groups.append(group_dir.name)
82
82
  for group in groups:
83
83
  group_dir = self.path / "windows" / group
84
84
  if not group_dir.exists():
@@ -86,16 +86,20 @@ class FileWindowStorage(WindowStorage):
86
86
  f"Skipping group directory {group_dir} since it does not exist"
87
87
  )
88
88
  continue
89
+ if not group_dir.is_dir():
90
+ logger.warning(
91
+ f"Skipping group path {group_dir} since it is not a directory"
92
+ )
93
+ continue
89
94
  if names:
90
- cur_names = names
95
+ for window_name in names:
96
+ window_dir = group_dir / window_name
97
+ if not window_dir.is_dir():
98
+ continue
99
+ window_dirs.append(window_dir)
91
100
  else:
92
- cur_names = []
93
- for p in group_dir.iterdir():
94
- cur_names.append(p.name)
95
-
96
- for window_name in cur_names:
97
- window_dir = group_dir / window_name
98
- window_dirs.append(window_dir)
101
+ for window_dir in iter_nonhidden_subdirs(group_dir):
102
+ window_dirs.append(window_dir)
99
103
 
100
104
  if workers == 0:
101
105
  windows = [load_window(self, window_dir) for window_dir in window_dirs]
@@ -162,7 +166,7 @@ class FileWindowStorage(WindowStorage):
162
166
  return []
163
167
 
164
168
  completed_layers = []
165
- for layer_dir in layers_directory.iterdir():
169
+ for layer_dir in iter_nonhidden_subdirs(layers_directory):
166
170
  layer_name, group_idx = get_layer_and_group_from_dir_name(layer_dir.name)
167
171
  if not self.is_layer_completed(group, name, layer_name, group_idx):
168
172
  continue
@@ -15,6 +15,8 @@ from upath import UPath
15
15
  from rslearn.const import WGS84_PROJECTION
16
16
  from rslearn.utils.feature import Feature
17
17
  from rslearn.utils.fsspec import (
18
+ iter_nonhidden_files,
19
+ iter_nonhidden_subdirs,
18
20
  join_upath,
19
21
  open_atomic,
20
22
  open_rasterio_upath_reader,
@@ -129,7 +131,7 @@ class DefaultTileStore(TileStore):
129
131
  ValueError: if no file is found.
130
132
  """
131
133
  raster_dir = self._get_raster_dir(layer_name, item_name, bands)
132
- for fname in raster_dir.iterdir():
134
+ for fname in iter_nonhidden_files(raster_dir):
133
135
  # Ignore completed sentinel files, bands files, as well as temporary files created by
134
136
  # open_atomic (in case this tile store is on local filesystem).
135
137
  if fname.name == COMPLETED_FNAME:
@@ -175,7 +177,7 @@ class DefaultTileStore(TileStore):
175
177
  return []
176
178
 
177
179
  bands: list[list[str]] = []
178
- for raster_dir in item_dir.iterdir():
180
+ for raster_dir in iter_nonhidden_subdirs(item_dir):
179
181
  if not (raster_dir / BANDS_FNAME).exists():
180
182
  # This is likely a legacy directory where the bands are only encoded in
181
183
  # the directory name, so we have to rely on that.
@@ -108,10 +108,10 @@ class RslearnDataModule(L.LightningDataModule):
108
108
  self.use_in_memory_all_crops_dataset = use_in_memory_all_crops_dataset
109
109
  self.index_mode = index_mode
110
110
  self.split_configs = {
111
- "train": default_config.update(train_config),
112
- "val": default_config.update(val_config),
113
- "test": default_config.update(test_config),
114
- "predict": default_config.update(predict_config),
111
+ "train": SplitConfig.merge_and_validate([default_config, train_config]),
112
+ "val": SplitConfig.merge_and_validate([default_config, val_config]),
113
+ "test": SplitConfig.merge_and_validate([default_config, test_config]),
114
+ "predict": SplitConfig.merge_and_validate([default_config, predict_config]),
115
115
  }
116
116
 
117
117
  def setup(
@@ -141,7 +141,7 @@ class RslearnDataModule(L.LightningDataModule):
141
141
  task=self.task,
142
142
  workers=self.init_workers,
143
143
  name=self.name,
144
- fix_patch_pick=(split != "train"),
144
+ fix_crop_pick=(split != "train"),
145
145
  index_mode=self.index_mode,
146
146
  )
147
147
  logger.info(f"got {len(dataset)} examples in split {split}")
@@ -203,13 +203,16 @@ class RslearnDataModule(L.LightningDataModule):
203
203
  # Enable persistent workers unless we are using main process.
204
204
  persistent_workers = self.num_workers > 0
205
205
 
206
- # If using all patches, limit number of workers to the number of windows.
206
+ # If using all crops, limit number of workers to the number of windows.
207
207
  # Otherwise it has to distribute the same window to different workers which can
208
208
  # cause issues for RslearnWriter.
209
209
  # If the number of windows is 0, then we can set positive number of workers
210
210
  # since they won't yield anything anyway.
211
211
  num_workers = self.num_workers
212
- if split_config.load_all_crops and len(dataset.get_dataset_examples()) > 0:
212
+ if (
213
+ split_config.get_load_all_crops()
214
+ and len(dataset.get_dataset_examples()) > 0
215
+ ):
213
216
  num_workers = min(num_workers, len(dataset.get_dataset_examples()))
214
217
 
215
218
  kwargs: dict[str, Any] = dict(
@@ -496,53 +496,6 @@ class SplitConfig:
496
496
  overlap_ratio: deprecated, use overlap_pixels instead
497
497
  load_all_patches: deprecated, use load_all_crops instead
498
498
  """
499
- # Handle deprecated load_all_patches parameter
500
- if load_all_patches is not None:
501
- warnings.warn(
502
- "load_all_patches is deprecated, use load_all_crops instead",
503
- FutureWarning,
504
- stacklevel=2,
505
- )
506
- if load_all_crops is not None:
507
- raise ValueError(
508
- "Cannot specify both load_all_patches and load_all_crops"
509
- )
510
- load_all_crops = load_all_patches
511
- # Handle deprecated patch_size parameter
512
- if patch_size is not None:
513
- warnings.warn(
514
- "patch_size is deprecated, use crop_size instead",
515
- FutureWarning,
516
- stacklevel=2,
517
- )
518
- if crop_size is not None:
519
- raise ValueError("Cannot specify both patch_size and crop_size")
520
- crop_size = patch_size
521
-
522
- # Normalize crop_size to tuple[int, int] | None
523
- self.crop_size: tuple[int, int] | None = None
524
- if crop_size is not None:
525
- if isinstance(crop_size, int):
526
- self.crop_size = (crop_size, crop_size)
527
- else:
528
- self.crop_size = crop_size
529
-
530
- # Handle deprecated overlap_ratio parameter
531
- if overlap_ratio is not None:
532
- warnings.warn(
533
- "overlap_ratio is deprecated, use overlap_pixels instead",
534
- FutureWarning,
535
- stacklevel=2,
536
- )
537
- if overlap_pixels is not None:
538
- raise ValueError("Cannot specify both overlap_ratio and overlap_pixels")
539
- if self.crop_size is None:
540
- raise ValueError("overlap_ratio requires crop_size to be set")
541
- overlap_pixels = round(self.crop_size[0] * overlap_ratio)
542
-
543
- if overlap_pixels is not None and overlap_pixels < 0:
544
- raise ValueError("overlap_pixels must be non-negative")
545
-
546
499
  self.groups = groups
547
500
  self.names = names
548
501
  self.tags = tags
@@ -555,13 +508,22 @@ class SplitConfig:
555
508
  output_layer_name_skip_inference_if_exists
556
509
  )
557
510
 
558
- # Note that load_all_crops is handled by the RslearnDataModule rather than the
559
- # ModelDataset.
560
- self.load_all_crops = load_all_crops
561
- self.overlap_pixels = overlap_pixels
511
+ # These have deprecated equivalents -- we store both raw values since we don't
512
+ # have a complete picture until the final merged SplitConfig is computed. We
513
+ # raise deprecation warnings in merge_and_validate and we disambiguate them in
514
+ # get_ functions (so the variables should never be accessed directly).
515
+ self._crop_size = crop_size
516
+ self._patch_size = patch_size
517
+ self._overlap_pixels = overlap_pixels
518
+ self._overlap_ratio = overlap_ratio
519
+ self._load_all_crops = load_all_crops
520
+ self._load_all_patches = load_all_patches
562
521
 
563
- def update(self, other: "SplitConfig") -> "SplitConfig":
564
- """Override settings in this SplitConfig with those in another.
522
+ def _merge(self, other: "SplitConfig") -> "SplitConfig":
523
+ """Merge settings from another SplitConfig into this one.
524
+
525
+ Args:
526
+ other: the config to merge in (its non-None values override self's)
565
527
 
566
528
  Returns:
567
529
  the resulting SplitConfig combining the settings.
@@ -574,9 +536,12 @@ class SplitConfig:
574
536
  num_patches=self.num_patches,
575
537
  transforms=self.transforms,
576
538
  sampler=self.sampler,
577
- crop_size=self.crop_size,
578
- overlap_pixels=self.overlap_pixels,
579
- load_all_crops=self.load_all_crops,
539
+ crop_size=self._crop_size,
540
+ patch_size=self._patch_size,
541
+ overlap_pixels=self._overlap_pixels,
542
+ overlap_ratio=self._overlap_ratio,
543
+ load_all_crops=self._load_all_crops,
544
+ load_all_patches=self._load_all_patches,
580
545
  skip_targets=self.skip_targets,
581
546
  output_layer_name_skip_inference_if_exists=self.output_layer_name_skip_inference_if_exists,
582
547
  )
@@ -594,12 +559,18 @@ class SplitConfig:
594
559
  result.transforms = other.transforms
595
560
  if other.sampler:
596
561
  result.sampler = other.sampler
597
- if other.crop_size:
598
- result.crop_size = other.crop_size
599
- if other.overlap_pixels is not None:
600
- result.overlap_pixels = other.overlap_pixels
601
- if other.load_all_crops is not None:
602
- result.load_all_crops = other.load_all_crops
562
+ if other._crop_size is not None:
563
+ result._crop_size = other._crop_size
564
+ if other._patch_size is not None:
565
+ result._patch_size = other._patch_size
566
+ if other._overlap_pixels is not None:
567
+ result._overlap_pixels = other._overlap_pixels
568
+ if other._overlap_ratio is not None:
569
+ result._overlap_ratio = other._overlap_ratio
570
+ if other._load_all_crops is not None:
571
+ result._load_all_crops = other._load_all_crops
572
+ if other._load_all_patches is not None:
573
+ result._load_all_patches = other._load_all_patches
603
574
  if other.skip_targets is not None:
604
575
  result.skip_targets = other.skip_targets
605
576
  if other.output_layer_name_skip_inference_if_exists is not None:
@@ -608,17 +579,90 @@ class SplitConfig:
608
579
  )
609
580
  return result
610
581
 
582
+ @staticmethod
583
+ def merge_and_validate(configs: list["SplitConfig"]) -> "SplitConfig":
584
+ """Merge a list of SplitConfigs and validate the result.
585
+
586
+ Args:
587
+ configs: list of SplitConfig to merge. Later configs override earlier ones.
588
+
589
+ Returns:
590
+ the merged and validated SplitConfig.
591
+ """
592
+ if not configs:
593
+ return SplitConfig()
594
+
595
+ result = configs[0]
596
+ for config in configs[1:]:
597
+ result = result._merge(config)
598
+
599
+ # Emit deprecation warnings
600
+ if result._patch_size is not None:
601
+ warnings.warn(
602
+ "patch_size is deprecated, use crop_size instead",
603
+ FutureWarning,
604
+ stacklevel=2,
605
+ )
606
+ if result._overlap_ratio is not None:
607
+ warnings.warn(
608
+ "overlap_ratio is deprecated, use overlap_pixels instead",
609
+ FutureWarning,
610
+ stacklevel=2,
611
+ )
612
+ if result._load_all_patches is not None:
613
+ warnings.warn(
614
+ "load_all_patches is deprecated, use load_all_crops instead",
615
+ FutureWarning,
616
+ stacklevel=2,
617
+ )
618
+
619
+ # Check for conflicting parameters
620
+ if result._crop_size is not None and result._patch_size is not None:
621
+ raise ValueError("Cannot specify both crop_size and patch_size")
622
+ if result._overlap_pixels is not None and result._overlap_ratio is not None:
623
+ raise ValueError("Cannot specify both overlap_pixels and overlap_ratio")
624
+ if result._load_all_crops is not None and result._load_all_patches is not None:
625
+ raise ValueError("Cannot specify both load_all_crops and load_all_patches")
626
+
627
+ # Validate overlap_pixels is non-negative
628
+ if result._overlap_pixels is not None and result._overlap_pixels < 0:
629
+ raise ValueError("overlap_pixels must be non-negative")
630
+
631
+ # overlap_pixels requires load_all_crops.
632
+ if result.get_overlap_pixels() > 0 and not result.get_load_all_crops():
633
+ raise ValueError(
634
+ "overlap_pixels requires load_all_crops to be True since overlap is only used during sliding window inference"
635
+ )
636
+
637
+ return result
638
+
611
639
  def get_crop_size(self) -> tuple[int, int] | None:
612
- """Get crop size as tuple."""
613
- return self.crop_size
640
+ """Get crop size as tuple, handling deprecated patch_size."""
641
+ size = self._crop_size if self._crop_size is not None else self._patch_size
642
+ if size is None:
643
+ return None
644
+ if isinstance(size, int):
645
+ return (size, size)
646
+ return size
614
647
 
615
648
  def get_overlap_pixels(self) -> int:
616
- """Get the overlap pixels (default 0)."""
617
- return self.overlap_pixels if self.overlap_pixels is not None else 0
649
+ """Get the overlap pixels (default 0), handling deprecated overlap_ratio."""
650
+ if self._overlap_pixels is not None:
651
+ return self._overlap_pixels
652
+ if self._overlap_ratio is not None:
653
+ crop_size = self.get_crop_size()
654
+ if crop_size is None:
655
+ raise ValueError("overlap_ratio requires crop_size to be set")
656
+ return round(crop_size[0] * self._overlap_ratio)
657
+ return 0
618
658
 
619
659
  def get_load_all_crops(self) -> bool:
620
- """Returns whether loading all patches is enabled (default False)."""
621
- return True if self.load_all_crops is True else False
660
+ """Returns whether loading all crops is enabled (default False)."""
661
+ if self._load_all_crops is not None:
662
+ return self._load_all_crops
663
+ if self._load_all_patches is not None:
664
+ return self._load_all_patches
665
+ return False
622
666
 
623
667
  def get_skip_targets(self) -> bool:
624
668
  """Returns whether skip_targets is enabled (default False)."""
@@ -697,7 +741,7 @@ class ModelDataset(torch.utils.data.Dataset):
697
741
  task: Task,
698
742
  workers: int,
699
743
  name: str | None = None,
700
- fix_patch_pick: bool = False,
744
+ fix_crop_pick: bool = False,
701
745
  index_mode: IndexMode = IndexMode.OFF,
702
746
  ) -> None:
703
747
  """Instantiate a new ModelDataset.
@@ -709,7 +753,7 @@ class ModelDataset(torch.utils.data.Dataset):
709
753
  task: the task to train on
710
754
  workers: number of workers to use for initializing the dataset
711
755
  name: name of the dataset
712
- fix_patch_pick: if True, fix the patch pick to be the same every time
756
+ fix_crop_pick: if True, fix the crop pick to be the same every time
713
757
  for a given window. Useful for testing (default: False)
714
758
  index_mode: controls dataset index caching behavior (default: IndexMode.OFF)
715
759
  """
@@ -718,14 +762,14 @@ class ModelDataset(torch.utils.data.Dataset):
718
762
  self.inputs = inputs
719
763
  self.task = task
720
764
  self.name = name
721
- self.fix_patch_pick = fix_patch_pick
765
+ self.fix_crop_pick = fix_crop_pick
722
766
  if split_config.transforms:
723
767
  self.transforms = Sequential(*split_config.transforms)
724
768
  else:
725
769
  self.transforms = rslearn.train.transforms.transform.Identity()
726
770
 
727
771
  # Get normalized crop size from the SplitConfig.
728
- # But if load all patches is enabled, this is handled by AllCropsDataset, so
772
+ # But if load_all_crops is enabled, this is handled by AllCropsDataset, so
729
773
  # here we instead load the entire windows.
730
774
  if split_config.get_load_all_crops():
731
775
  self.crop_size = None
@@ -952,7 +996,7 @@ class ModelDataset(torch.utils.data.Dataset):
952
996
  """Get a list of examples in the dataset.
953
997
 
954
998
  If load_all_crops is False, this is a list of Windows. Otherwise, this is a
955
- list of (window, crop_bounds, (crop_idx, # patches)) tuples.
999
+ list of (window, crop_bounds, (crop_idx, # crops)) tuples.
956
1000
  """
957
1001
  if self.dataset_examples is None:
958
1002
  logger.debug(
@@ -985,7 +1029,7 @@ class ModelDataset(torch.utils.data.Dataset):
985
1029
  """
986
1030
  dataset_examples = self.get_dataset_examples()
987
1031
  example = dataset_examples[idx]
988
- rng = random.Random(idx if self.fix_patch_pick else None)
1032
+ rng = random.Random(idx if self.fix_crop_pick else None)
989
1033
 
990
1034
  # Select bounds to read.
991
1035
  if self.crop_size:
@@ -6,12 +6,14 @@ from typing import Any
6
6
 
7
7
  import lightning as L
8
8
  import torch
9
+ import wandb
9
10
  from lightning.pytorch.utilities.types import OptimizerLRSchedulerConfig
10
11
  from PIL import Image
11
12
  from upath import UPath
12
13
 
13
14
  from rslearn.log_utils import get_logger
14
15
 
16
+ from .metrics import NonScalarMetricOutput
15
17
  from .model_context import ModelContext, ModelOutput
16
18
  from .optimizer import AdamW, OptimizerFactory
17
19
  from .scheduler import PlateauScheduler, SchedulerFactory
@@ -210,15 +212,53 @@ class RslearnLightningModule(L.LightningModule):
210
212
  # Fail silently for single-dataset case, which is okay
211
213
  pass
212
214
 
215
+ def _log_non_scalar_metric(self, name: str, value: NonScalarMetricOutput) -> None:
216
+ """Log a non-scalar metric to wandb.
217
+
218
+ Args:
219
+ name: the metric name (e.g., "val_confusion_matrix")
220
+ value: the non-scalar metric output
221
+ """
222
+ # The non-scalar metrics are logging directly without Lightning
223
+ # So we need to skip logging during sanity check.
224
+ if self.trainer.sanity_checking:
225
+ return
226
+
227
+ # Wandb is required for logging non-scalar metrics.
228
+ if not wandb.run:
229
+ logger.warning(
230
+ f"Weights & Biases is not initialized, skipping logging of {name}"
231
+ )
232
+ return
233
+
234
+ value.log_to_wandb(name)
235
+
213
236
  def on_validation_epoch_end(self) -> None:
214
237
  """Compute and log validation metrics at epoch end.
215
238
 
216
239
  We manually compute and log metrics here (instead of passing the MetricCollection
217
240
  to log_dict) because MetricCollection.compute() properly flattens dict-returning
218
241
  metrics, while log_dict expects each metric to return a scalar tensor.
242
+
243
+ Non-scalar metrics (like confusion matrices) are logged separately using
244
+ logger-specific APIs.
219
245
  """
220
246
  metrics = self.val_metrics.compute()
221
- self.log_dict(metrics)
247
+
248
+ # Separate scalar and non-scalar metrics
249
+ scalar_metrics = {}
250
+ for k, v in metrics.items():
251
+ if isinstance(v, NonScalarMetricOutput):
252
+ self._log_non_scalar_metric(k, v)
253
+ elif isinstance(v, torch.Tensor) and v.dim() > 0 and v.numel() > 1:
254
+ raise ValueError(
255
+ f"Metric '{k}' returned a non-scalar tensor with shape {v.shape}. "
256
+ "Wrap it in a NonScalarMetricOutput subclass."
257
+ )
258
+ else:
259
+ scalar_metrics[k] = v
260
+
261
+ self.log_dict(scalar_metrics)
222
262
  self.val_metrics.reset()
223
263
 
224
264
  def on_test_epoch_end(self) -> None:
@@ -227,14 +267,30 @@ class RslearnLightningModule(L.LightningModule):
227
267
  We manually compute and log metrics here (instead of passing the MetricCollection
228
268
  to log_dict) because MetricCollection.compute() properly flattens dict-returning
229
269
  metrics, while log_dict expects each metric to return a scalar tensor.
270
+
271
+ Non-scalar metrics (like confusion matrices) are logged separately.
230
272
  """
231
273
  metrics = self.test_metrics.compute()
232
- self.log_dict(metrics)
274
+
275
+ # Separate scalar and non-scalar metrics
276
+ scalar_metrics = {}
277
+ for k, v in metrics.items():
278
+ if isinstance(v, NonScalarMetricOutput):
279
+ self._log_non_scalar_metric(k, v)
280
+ elif isinstance(v, torch.Tensor) and v.dim() > 0 and v.numel() > 1:
281
+ raise ValueError(
282
+ f"Metric '{k}' returned a non-scalar tensor with shape {v.shape}. "
283
+ "Wrap it in a NonScalarMetricOutput subclass."
284
+ )
285
+ else:
286
+ scalar_metrics[k] = v
287
+
288
+ self.log_dict(scalar_metrics)
233
289
  self.test_metrics.reset()
234
290
 
235
291
  if self.metrics_file:
236
292
  with open(self.metrics_file, "w") as f:
237
- metrics_dict = {k: v.item() for k, v in metrics.items()}
293
+ metrics_dict = {k: v.item() for k, v in scalar_metrics.items()}
238
294
  json.dump(metrics_dict, f, indent=4)
239
295
  logger.info(f"Saved metrics to {self.metrics_file}")
240
296
 
@@ -0,0 +1,162 @@
1
+ """Metric output classes for non-scalar metrics."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from dataclasses import dataclass
5
+
6
+ import torch
7
+ import wandb
8
+ from torchmetrics import Metric
9
+
10
+ from rslearn.log_utils import get_logger
11
+
12
+ logger = get_logger(__name__)
13
+
14
+
15
@dataclass
class NonScalarMetricOutput(ABC):
    """Base class for non-scalar metric outputs that need special logging.

    Concrete subclasses define how the metric is emitted by implementing
    log_to_wandb (Weights & Biases is the only supported logging backend).
    """

    @abstractmethod
    def log_to_wandb(self, name: str) -> None:
        """Log this metric to wandb.

        Args:
            name: the metric name
        """
31
+
32
+
33
@dataclass
class ConfusionMatrixOutput(NonScalarMetricOutput):
    """Confusion matrix metric output.

    Args:
        confusion_matrix: confusion matrix of shape (num_classes, num_classes)
            where cm[i, j] is the count of samples with true label i and
            predicted label j. Distributed reduction may leave extra leading
            dimensions; these are summed out before logging.
        class_names: optional list of class names for axis labels
    """

    confusion_matrix: torch.Tensor
    class_names: list[str] | None = None

    def _reduced_matrix(self) -> torch.Tensor:
        """Return the confusion matrix as a 2D CPU tensor.

        Sums out any extra leading dimensions that distributed reduction may
        have stacked onto the state (loops in case more than one was added).
        """
        cm = self.confusion_matrix.detach().cpu()
        while cm.dim() > 2:
            cm = cm.sum(dim=0)
        return cm

    def _expand_confusion_matrix(self) -> tuple[list[int], list[int]]:
        """Expand confusion matrix to (preds, labels) pairs for wandb.

        wandb.plot.confusion_matrix consumes per-sample predictions rather than
        aggregated counts, so cell (i, j) is expanded into cm[i, j] repeated
        (pred=j, label=i) pairs.

        Returns:
            Tuple of (preds, labels) as lists of integers.
        """
        cm = self._reduced_matrix()

        if cm.sum().item() == 0:
            return [], []

        preds: list[int] = []
        labels: list[int] = []
        for true_label in range(cm.shape[0]):
            for pred_label in range(cm.shape[1]):
                count = int(cm[true_label, pred_label].item())
                if count > 0:
                    preds.extend([pred_label] * count)
                    labels.extend([true_label] * count)

        return preds, labels

    def log_to_wandb(self, name: str) -> None:
        """Log confusion matrix to wandb.

        Args:
            name: the metric name (e.g., "val_confusion_matrix")
        """
        preds, labels = self._expand_confusion_matrix()

        if len(preds) == 0:
            logger.warning(f"No samples to log for {name}")
            return

        # Use the last dimension for the class count: unlike shape[0], it is
        # correct even when distributed reduction stacked extra leading
        # dimensions onto the raw confusion matrix tensor.
        num_classes = self.confusion_matrix.shape[-1]
        if self.class_names is None:
            class_names = [str(i) for i in range(num_classes)]
        else:
            class_names = self.class_names

        wandb.log(
            {
                name: wandb.plot.confusion_matrix(
                    preds=preds,
                    y_true=labels,
                    class_names=class_names,
                    title=name,
                ),
            },
        )
102
+
103
+
104
class ConfusionMatrixMetric(Metric):
    """Confusion matrix metric that works on flattened inputs.

    Expects preds of shape (N, C) and labels of shape (N,).
    Should be wrapped by ClassificationMetric or SegmentationMetric
    which handle the task-specific preprocessing.

    Args:
        num_classes: number of classes
        class_names: optional list of class names for labeling
    """

    def __init__(
        self,
        num_classes: int,
        class_names: list[str] | None = None,
    ):
        """Initialize a new ConfusionMatrixMetric.

        Args:
            num_classes: number of classes
            class_names: optional list of class names for labeling
        """
        super().__init__()
        self.num_classes = num_classes
        self.class_names = class_names
        # dist_reduce_fx="sum" adds per-rank matrices elementwise on sync.
        self.add_state(
            "confusion_matrix",
            default=torch.zeros(num_classes, num_classes, dtype=torch.long),
            dist_reduce_fx="sum",
        )

    def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
        """Update metric.

        Args:
            preds: predictions of shape (N, C) - probabilities
            labels: ground truth of shape (N,) - class indices
        """
        if len(preds) == 0:
            return

        pred_classes = preds.argmax(dim=1)  # (N,)

        # Only count samples whose label is a valid class index. This matches
        # the per-cell masking formulation, which ignored out-of-range labels.
        valid = (labels >= 0) & (labels < self.num_classes)
        # Flatten (label, pred) pairs to a single index and histogram them in
        # one O(N) bincount pass instead of O(num_classes^2) masked reductions.
        flat_indices = (
            labels[valid].long() * self.num_classes + pred_classes[valid].long()
        )
        counts = torch.bincount(
            flat_indices, minlength=self.num_classes * self.num_classes
        )
        self.confusion_matrix += counts.reshape(
            self.num_classes, self.num_classes
        ).to(self.confusion_matrix)

    def compute(self) -> ConfusionMatrixOutput:
        """Returns the confusion matrix wrapped in ConfusionMatrixOutput."""
        return ConfusionMatrixOutput(
            confusion_matrix=self.confusion_matrix,
            class_names=self.class_names,
        )

    def reset(self) -> None:
        """Reset metric state (delegates to torchmetrics.Metric)."""
        super().reset()