rslearn 0.0.27__py3-none-any.whl → 0.0.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -10,6 +10,7 @@ import torchmetrics.classification
10
10
  from torchmetrics import Metric, MetricCollection
11
11
 
12
12
  from rslearn.models.component import FeatureMaps, Predictor
13
+ from rslearn.train.metrics import ConfusionMatrixMetric
13
14
  from rslearn.train.model_context import (
14
15
  ModelContext,
15
16
  ModelOutput,
@@ -43,6 +44,8 @@ class SegmentationTask(BasicTask):
43
44
  other_metrics: dict[str, Metric] = {},
44
45
  output_probs: bool = False,
45
46
  output_class_idx: int | None = None,
47
+ enable_confusion_matrix: bool = False,
48
+ class_names: list[str] | None = None,
46
49
  **kwargs: Any,
47
50
  ) -> None:
48
51
  """Initialize a new SegmentationTask.
@@ -80,6 +83,10 @@ class SegmentationTask(BasicTask):
80
83
  during prediction.
81
84
  output_class_idx: if set along with output_probs, only output the probability
82
85
  for this specific class index (single-channel output).
86
+ enable_confusion_matrix: whether to compute confusion matrix (default false).
87
+ If true, it requires wandb to be initialized for logging.
88
+ class_names: optional list of class names for labeling confusion matrix axes.
89
+ If not provided, classes will be labeled as "class_0", "class_1", etc.
83
90
  kwargs: additional arguments to pass to BasicTask
84
91
  """
85
92
  super().__init__(**kwargs)
@@ -106,6 +113,8 @@ class SegmentationTask(BasicTask):
106
113
  self.other_metrics = other_metrics
107
114
  self.output_probs = output_probs
108
115
  self.output_class_idx = output_class_idx
116
+ self.enable_confusion_matrix = enable_confusion_matrix
117
+ self.class_names = class_names
109
118
 
110
119
  def process_inputs(
111
120
  self,
@@ -285,6 +294,14 @@ class SegmentationTask(BasicTask):
285
294
  if self.other_metrics:
286
295
  metrics.update(self.other_metrics)
287
296
 
297
+ if self.enable_confusion_matrix:
298
+ metrics["confusion_matrix"] = SegmentationMetric(
299
+ ConfusionMatrixMetric(
300
+ num_classes=self.num_classes,
301
+ class_names=self.class_names,
302
+ ),
303
+ )
304
+
288
305
  return MetricCollection(metrics)
289
306
 
290
307
 
rslearn/utils/fsspec.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  import os
4
4
  import tempfile
5
- from collections.abc import Generator
5
+ from collections.abc import Generator, Iterator
6
6
  from contextlib import contextmanager
7
7
  from typing import Any
8
8
 
@@ -16,6 +16,56 @@ from rslearn.log_utils import get_logger
16
16
  logger = get_logger(__name__)
17
17
 
18
18
 
19
def iter_nonhidden(path: UPath) -> Iterator[UPath]:
    """Iterate over non-hidden entries in a directory.

    Hidden entries are those whose basename starts with "." (e.g. ".DS_Store").
    If ``path`` does not exist or is not a directory, nothing is yielded.

    Args:
        path: the directory to iterate.

    Yields:
        non-hidden UPath entries in the directory.
    """
    # iterdir() may return a lazy generator (pathlib.Path.iterdir is one on
    # Python 3.12 and earlier, and UPath implementations may be too). In that
    # case the directory listing -- and any FileNotFoundError /
    # NotADirectoryError it raises -- only happens on the first next(), i.e.
    # outside a try that wraps just the call. Materialize the listing inside
    # the try so those errors are actually caught here.
    try:
        entries = list(path.iterdir())
    except (FileNotFoundError, NotADirectoryError):
        return

    for p in entries:
        if not p.name.startswith("."):
            yield p
39
+
40
+
41
def iter_nonhidden_subdirs(path: UPath) -> Iterator[UPath]:
    """Yield each non-hidden child of ``path`` that is a directory.

    Args:
        path: the directory whose children should be listed.

    Yields:
        the non-hidden child directories, in the order they are listed.
    """
    yield from (child for child in iter_nonhidden(path) if child.is_dir())
53
+
54
+
55
def iter_nonhidden_files(path: UPath) -> Iterator[UPath]:
    """Yield each non-hidden child of ``path`` that is a regular file.

    Args:
        path: the directory whose children should be listed.

    Yields:
        the non-hidden child files, in the order they are listed.
    """
    yield from (child for child in iter_nonhidden(path) if child.is_file())
67
+
68
+
19
69
  @contextmanager
20
70
  def get_upath_local(
21
71
  path: UPath, extra_paths: list[UPath] = []
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rslearn
3
- Version: 0.0.27
3
+ Version: 0.0.29
4
4
  Summary: A library for developing remote sensing datasets and models
5
5
  Author: OlmoEarth Team
6
6
  License: Apache License
@@ -9,6 +9,7 @@ rslearn/template_params.py,sha256=Vop0Ha-S44ctCa9lvSZRjrMETznJZlR5y_gJrVIwrPg,79
9
9
  rslearn/config/__init__.py,sha256=n1qpZ0ImshTtLYl5mC73BORYyUcjPJyHiyZkqUY1hiY,474
10
10
  rslearn/config/dataset.py,sha256=abxIUFDAYmCd4pzGnkPnW_pYyws1yhXFWJ5HnVU4WHo,23942
11
11
  rslearn/data_sources/__init__.py,sha256=FZQckYwsnnLokMeYmi0ktUyQd9bAHyLN1_-Xc3qYLag,767
12
+ rslearn/data_sources/aws_google_satellite_embedding_v1.py,sha256=ga5G8uDXdMj2pw8qGC2PD11PBg8Spuf3b4QEwsVJaBY,12805
12
13
  rslearn/data_sources/aws_landsat.py,sha256=bJmwBbUV4vjKBNp1MHt4sHhnIjhMis_jOI3FpksQc6w,16435
13
14
  rslearn/data_sources/aws_open_data.py,sha256=fum34DqqDiuiiYBfZtGFrNNOLylE9o3o7Cyb2e0Eo0g,29101
14
15
  rslearn/data_sources/aws_sentinel1.py,sha256=LfJLDhsd_6h_JinD8PbiiAyxajkIdvAc--5BJryUKlo,4674
@@ -16,7 +17,7 @@ rslearn/data_sources/aws_sentinel2_element84.py,sha256=qeCuiSlvhWChSY3AwYsKT6nZU
16
17
  rslearn/data_sources/climate_data_store.py,sha256=mqLJfYubD6m9VwxpLunoIv_MNFN6Ue1hBBVj552e8uQ,18289
17
18
  rslearn/data_sources/copernicus.py,sha256=ushAgYGxU2MzPcUNnEvEfPgO0RCC9Rbjzi189xq0jgc,35001
18
19
  rslearn/data_sources/data_source.py,sha256=xojlCoAnGTCHKbEx98JkW0oYzAKBbgGMNc0kicEjHWk,4863
19
- rslearn/data_sources/direct_materialize_data_source.py,sha256=PBmxsLBNakJeX1s92pc4FCuHANfhkYHS2vt60RGdkj0,11276
20
+ rslearn/data_sources/direct_materialize_data_source.py,sha256=UnFuCSJED9-YSFp12-MosV8bMFj6AqCb75a9ADu_Cxw,11030
20
21
  rslearn/data_sources/earthdaily.py,sha256=qUtHUG1oV5IlCWXVovUcYxQhqdNDKWaEe-BKnooWX88,14623
21
22
  rslearn/data_sources/earthdatahub.py,sha256=KRf1VnxPI9jsT0utEkeYvsCwu7LXo9t-RvMi8gXehag,15889
22
23
  rslearn/data_sources/eurocrops.py,sha256=dJ4d0xvt-rID_HuAchyucFJBuAQL-Kk1h_qm6GOH-mE,8641
@@ -32,7 +33,7 @@ rslearn/data_sources/soilgrids.py,sha256=qbnnCIOa6tlN8wxmNCzAj60-pghKEbRxa7lVIgM
32
33
  rslearn/data_sources/stac.py,sha256=Gj8TZ5pifVzWPCuzgphrle2ekQ02OET54rj-02sR2nw,10705
33
34
  rslearn/data_sources/usda_cdl.py,sha256=3GhcgTB50T7GA44nB9WwItqDJliELquw_YbiAVxh6kc,6808
34
35
  rslearn/data_sources/usgs_landsat.py,sha256=IsQOhWY8nwmgixJu1uMSR4CqsC3igcP3TArdBXkETd8,10178
35
- rslearn/data_sources/utils.py,sha256=NOC0qOyYVS6f8EUrSeP4mH0XZbSrtTLV-gKGbCC6ccg,16586
36
+ rslearn/data_sources/utils.py,sha256=EAVFCYzjFvuHWd7E2ghTub9f-bbDhq83p3x9IJDjgvk,16843
36
37
  rslearn/data_sources/vector_source.py,sha256=NCa7CxIrGKe9yRT0NyyFKFQboDGDZ1h7663PV9OfMOM,44
37
38
  rslearn/data_sources/worldcereal.py,sha256=OWZA0pvQQiKvuA5AVAc0lw8JStMEeF4DYOh0n2vdg6I,21521
38
39
  rslearn/data_sources/worldcover.py,sha256=ahyrGoXMAGWsIUDHSrqPywiK7ycwUD3E3BruNMxpo90,6057
@@ -47,7 +48,7 @@ rslearn/dataset/materialize.py,sha256=VoL5Qf5pGcQV4QMlO5vrcu7w0Sl1NdIRLUVk0kSCMO
47
48
  rslearn/dataset/remap.py,sha256=6MaImsY02GNACpvRM81RvWmjZWRfAHxo_R3Ox6XLF6A,2723
48
49
  rslearn/dataset/window.py,sha256=X4q8YzcSOTtwKxCPf71QLMoyKUtYMSnZu0kPnmVSUx4,10644
49
50
  rslearn/dataset/storage/__init__.py,sha256=R50AVV5LH2g7ol0-jyvGcB390VsclXGbJXz4fmkn9as,52
50
- rslearn/dataset/storage/file.py,sha256=g9HZ3CD4QcgyVNsBaXhjIKQgDOAeZ4R08sJ7ntx4wo8,6815
51
+ rslearn/dataset/storage/file.py,sha256=GJJgH_eHknLbMQwoC3mOoXJKl6Ha3oNXzz62FIEMWlg,7130
51
52
  rslearn/dataset/storage/storage.py,sha256=DxZ7iwV938PiLwdQzb5EXSb4Mj8bRGmOTmA9fzq_Ge8,4840
52
53
  rslearn/models/__init__.py,sha256=_vWoF9d2Slah8-6XhYhdU4SRsy_CNxXjCGQTD2yvu3Q,22
53
54
  rslearn/models/anysat.py,sha256=nzk6hB83ltNFNXYRNA1rTvq2AQcAhwyvgBaZui1M37o,8107
@@ -61,6 +62,7 @@ rslearn/models/dinov3.py,sha256=Q9X7VTwzjllLSEvc235C9BY_jMnIoSybsiOkeA58uHo,6472
61
62
  rslearn/models/faster_rcnn.py,sha256=yOipLPmVHbadvYCR9xfCYgmkU9Mot6fgDK-kKicVTlo,8685
62
63
  rslearn/models/feature_center_crop.py,sha256=_Mu3E4iJLBug9I4ZIBIpB_VJo-xGterHmhtIFGaHR34,1808
63
64
  rslearn/models/fpn.py,sha256=qm7nKMgsZrCoAdz8ASmNKU2nvZ6USm5CedMfy_w_gwE,2079
65
+ rslearn/models/global_pool.py,sha256=Bl48AVJ7g70hPmVLJbK1y_JN9_FTyANc_7tr6YOHANY,2782
64
66
  rslearn/models/module_wrapper.py,sha256=73JspaglnNabUGZB2EiCYF_dZ3-Kicg_OpoTfUWHONk,2271
65
67
  rslearn/models/molmo.py,sha256=lXnevwTCNyc1XcnJUB5_pK1G2AJGYMvQYU21mZFf5u0,2246
66
68
  rslearn/models/multitask.py,sha256=bpFxvtFowRyT-tvRSdY7AKbEx_i1y7sToEzZgTMcF4s,16264
@@ -112,14 +114,15 @@ rslearn/models/presto/__init__.py,sha256=eZrB-XKi_vYqZhpyAOwppJi4dRuMtYVAdbq7KRy
112
114
  rslearn/models/presto/presto.py,sha256=fkyHB85Hfx5L-4yejSFAFv83gk9VFqAR1GTgggtq0EA,11049
113
115
  rslearn/models/presto/single_file_presto.py,sha256=-P00xjhj9dx3O6HqWpQmG9dPk_i6bT_t8vhX4uQm5tA,30242
114
116
  rslearn/tile_stores/__init__.py,sha256=-cW1J7So60SEP5ZLHCPdaFBV5CxvV3QlOhaFnUkhTJ0,1675
115
- rslearn/tile_stores/default.py,sha256=PYaDNvBxhJTDKJGw0EjDTSE1OKajR7_iJpMbOjj-mE8,15054
117
+ rslearn/tile_stores/default.py,sha256=AG2j0FCNi_4cnXqLjRIef5wMqMJ5_YtSkTIhk7qJQVQ,15134
116
118
  rslearn/tile_stores/tile_store.py,sha256=9AeYduDYPp_Ia2NMlq6osptpz_AFGIOQcLJrqZ_m-z0,10469
117
119
  rslearn/train/__init__.py,sha256=fnJyY4aHs5zQqbDKSfXsJZXY_M9fbTsf7dRYaPwZr2M,30
118
120
  rslearn/train/all_crops_dataset.py,sha256=CWnqbSjRXJZQsudljvpA07oldiP4fZTmjwrT0sjVnq4,21399
119
- rslearn/train/data_module.py,sha256=yPShftkHJ2bhJ4carwYYb9c3PkkP7ArzXQyu37EuAxk,23718
120
- rslearn/train/dataset.py,sha256=AAkdHX3q8VD1Geq9yIatiPkM5blA2luVPPwDmXDB6Z8,43284
121
+ rslearn/train/data_module.py,sha256=G1TRhXg8SPewYy0BTZN5KpeLPK72qIaH15ePfUwrxgM,23865
122
+ rslearn/train/dataset.py,sha256=vCmm6yrW2bAc5A94aBwQe-SOIGdVcZYMM2oBYRq2_sw,45253
121
123
  rslearn/train/dataset_index.py,sha256=S5iXhQga5gnnkDqThXXlyjIwkJBPVWiUfDPx3iVs-pw,5306
122
- rslearn/train/lightning_module.py,sha256=m0aIGk5xO5y12DEiwSl6eAko6X-gQ78_Wsbvz4Hb_NE,15364
124
+ rslearn/train/lightning_module.py,sha256=n4hasJBVlAmMhvf2yaFo0gy1vGz5haQkJpZdCSKlJ8A,17482
125
+ rslearn/train/metrics.py,sha256=RknMf2n09D5XBCf0YM4Zmm0XI-pFbRtsmbY51ipVMPk,4799
123
126
  rslearn/train/model_context.py,sha256=8DMWGj5xCRmRDo_38lkhkUMHfK_yg3XZrUJQIz5a1vA,3200
124
127
  rslearn/train/optimizer.py,sha256=EKSqkmERalDA0bF32Gey7n6z69KLyaUWKlRsGJfKBmE,927
125
128
  rslearn/train/prediction_writer.py,sha256=cRFehEtr0iBuVqzE69a0B4Lvb8ywxLeyon34KWI86H0,16961
@@ -130,13 +133,13 @@ rslearn/train/callbacks/freeze_unfreeze.py,sha256=8fIzBMhCKKjpTffIeAdhdSjsBd8NjT
130
133
  rslearn/train/callbacks/gradients.py,sha256=4YqCf0tBb6E5FnyFYbveXfQFlgNPyxIXb2FCWX4-6qs,5075
131
134
  rslearn/train/callbacks/peft.py,sha256=wEOKsS3RhsRaZTXn_Kz2wdsZdIiIaZPdCJWtdJBurT8,4156
132
135
  rslearn/train/tasks/__init__.py,sha256=dag1u72x1-me6y0YcOubUo5MYZ0Tjf6-dOir9UeFNMs,75
133
- rslearn/train/tasks/classification.py,sha256=H-Ayqm59IxwrczC8lUV5J5vg-JILhQhTiVlyaTpBs2k,14259
136
+ rslearn/train/tasks/classification.py,sha256=_3cRa8ojd9sG2ELRW_BvByZh2YFdCBaklR8Kv9LAgOY,14864
134
137
  rslearn/train/tasks/detection.py,sha256=uDMGtsCMSk9OGXn-vpFKBAyHyVN0ji2NCfqBgg1BQyw,21725
135
138
  rslearn/train/tasks/embedding.py,sha256=NdJEAaDWlWYzvOBVf7eIHfFOzqTgavfFH1J1gMbAMVo,3891
136
139
  rslearn/train/tasks/multi_task.py,sha256=32hvwyVsHqt7N_M3zXsTErK1K7-0-BPHzt7iGNehyaI,6314
137
- rslearn/train/tasks/per_pixel_regression.py,sha256=njShN-U9fx3SPcCxGgbDlZAp3DT_GlTt0BRZS416gnw,10387
138
- rslearn/train/tasks/regression.py,sha256=_TGlj3PA14Iye0duf25TcGZQpFXU9fGfNilzpNuPS78,12693
139
- rslearn/train/tasks/segmentation.py,sha256=dn1yo1dIArKvW9Giw8-LZyIZ87q76eslL0mk58GyApo,29663
140
+ rslearn/train/tasks/per_pixel_regression.py,sha256=3m_BTP2akadYe3IuAlCG2bd_alfNyom55-pFrI2q4PE,10928
141
+ rslearn/train/tasks/regression.py,sha256=TzHL42gm3aIdev0R7_uz_TSYbAwSvQPjCD42y1p9_7Y,13269
142
+ rslearn/train/tasks/segmentation.py,sha256=b9XS09EQvum89eoW3vWqFMKuCRtznODteKIr1hFnIz4,30531
140
143
  rslearn/train/tasks/task.py,sha256=nMPunl9OlnOimr48saeTnwKMQ7Du4syGrwNKVQq4FL4,4110
141
144
  rslearn/train/transforms/__init__.py,sha256=BkCAzm4f-8TEhPIuyvCj7eJGh36aMkZFYlq-H_jkSvY,778
142
145
  rslearn/train/transforms/concatenate.py,sha256=S8f1svzwb5UmeAgzXe4Af_hFvt5o0tQctIE6t3QYuPI,2625
@@ -153,7 +156,7 @@ rslearn/utils/__init__.py,sha256=GZc1erpEfXTc32yjEDbt5rnMrnXEBY7WVm3v4NlwwWY,620
153
156
  rslearn/utils/array.py,sha256=RC7ygtPnQwU6Lb9kwORvNxatJcaJ76JPsykQvndAfes,2444
154
157
  rslearn/utils/colors.py,sha256=ELY9_buH06TOVPLrDAyf2S0G--ZiOxnnP8Ujim6_3ig,369
155
158
  rslearn/utils/feature.py,sha256=lsg0WThZDJzo1mrbaL04dXYI5G3x-n5FG9aEjj7uUaI,1649
156
- rslearn/utils/fsspec.py,sha256=h3fER_bkewzR9liEAULXguTIvXLUXA17pC_yZoWN5Tk,5902
159
+ rslearn/utils/fsspec.py,sha256=TcEUgXKvsmtKHv5JVOI2Vp4WNfVNeTok0x4JgZaD1iw,7052
157
160
  rslearn/utils/geometry.py,sha256=VzLoxtwdV3uC3szowT-bGuCFF6ge8eK0m01lq8q-01Q,22423
158
161
  rslearn/utils/get_utm_ups_crs.py,sha256=kUrcyjCK7KWvuP1XR-nURPeRqYeRO-3L8QUJ1QTF9Ps,3599
159
162
  rslearn/utils/grid_index.py,sha256=hRmrtgpqN1pLa-djnZtgSXqKJlbgGyttGnCEmPLD0zo,2347
@@ -175,10 +178,10 @@ rslearn/vis/render_sensor_image.py,sha256=D0ynK6ABPV046970lIKwF98klpSCtrsUvZTwtZ
175
178
  rslearn/vis/render_vector_label.py,sha256=ncwgRKCYCJCK1-wTpjgksOiDDebku37LpAyq6wsg4jg,14939
176
179
  rslearn/vis/utils.py,sha256=Zop3dEmyaXUYhPiGdYzrTO8BRXWscP2dEZy2myQUnNk,2765
177
180
  rslearn/vis/vis_server.py,sha256=kIGnhTy-yfu5lBOVCoo8VVG259i974JPszudCePbzfI,20157
178
- rslearn-0.0.27.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
179
- rslearn-0.0.27.dist-info/licenses/NOTICE,sha256=wLPr6rwV_jCg-xEknNGwhnkfRfuoOE9MZ-lru2yZyLI,5070
180
- rslearn-0.0.27.dist-info/METADATA,sha256=_6aHsGpgH_T2MQSb9qKeYshP6zdzwu3CmI2sBhD3dqk,38714
181
- rslearn-0.0.27.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
182
- rslearn-0.0.27.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
183
- rslearn-0.0.27.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
184
- rslearn-0.0.27.dist-info/RECORD,,
181
+ rslearn-0.0.29.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
182
+ rslearn-0.0.29.dist-info/licenses/NOTICE,sha256=wLPr6rwV_jCg-xEknNGwhnkfRfuoOE9MZ-lru2yZyLI,5070
183
+ rslearn-0.0.29.dist-info/METADATA,sha256=PMjB15sAZg5VA7qHkMrhG_2hSULc0Wopxr7G3op20Hg,38714
184
+ rslearn-0.0.29.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
185
+ rslearn-0.0.29.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
186
+ rslearn-0.0.29.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
187
+ rslearn-0.0.29.dist-info/RECORD,,