rslearn 0.0.22__py3-none-any.whl → 0.0.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,7 +3,7 @@
3
3
  import os
4
4
  import tempfile
5
5
  import xml.etree.ElementTree as ET
6
- from datetime import timedelta
6
+ from datetime import datetime, timedelta
7
7
  from typing import Any
8
8
 
9
9
  import affine
@@ -12,6 +12,7 @@ import planetary_computer
12
12
  import rasterio
13
13
  import requests
14
14
  from rasterio.enums import Resampling
15
+ from typing_extensions import override
15
16
  from upath import UPath
16
17
 
17
18
  from rslearn.config import LayerConfig
@@ -24,11 +25,104 @@ from rslearn.tile_stores import TileStore, TileStoreWithLayer
24
25
  from rslearn.utils.fsspec import join_upath
25
26
  from rslearn.utils.geometry import PixelBounds, Projection, STGeometry
26
27
  from rslearn.utils.raster_format import get_raster_projection_and_bounds
28
+ from rslearn.utils.stac import StacClient, StacItem
27
29
 
28
30
  from .copernicus import get_harmonize_callback
29
31
 
30
32
  logger = get_logger(__name__)
31
33
 
34
+ # Max limit accepted by Planetary Computer API.
35
+ PLANETARY_COMPUTER_LIMIT = 1000
36
+
37
+
38
class PlanetaryComputerStacClient(StacClient):
    """A StacClient subclass that handles Planetary Computer's pagination limits.

    The Planetary Computer STAC API does not support standard pagination and caps
    each response at PLANETARY_COMPUTER_LIMIT items. If the initial query returns
    that many items, this client re-issues the query sorted by ID and pages through
    the results using an ``id > last_id`` query filter.
    """

    @staticmethod
    def _query_after_id(
        query: dict[str, Any] | None, last_id: str | None
    ) -> dict[str, Any] | None:
        """Build the query for one ID-sorted page.

        Any constraints the caller already placed on "id" (e.g. "in" or
        "startsWith") are preserved; only the "gt" operator is set. Overwriting an
        existing "gt" is safe: every item returned so far already satisfied it, so
        last_id is strictly more restrictive.

        Args:
            query: the caller's STAC query, or None.
            last_id: ID of the last item on the previous page, or None for the
                first page.

        Returns:
            the combined query, or None if it would be empty.
        """
        combined: dict[str, Any] = dict(query) if query else {}
        if last_id is not None:
            id_filter: dict[str, Any] = dict(combined.get("id") or {})
            id_filter["gt"] = last_id
            combined["id"] = id_filter
        return combined or None

    @override
    def search(
        self,
        collections: list[str] | None = None,
        bbox: tuple[float, float, float, float] | None = None,
        intersects: dict[str, Any] | None = None,
        date_time: datetime | tuple[datetime, datetime] | None = None,
        ids: list[str] | None = None,
        limit: int | None = None,
        query: dict[str, Any] | None = None,
        sortby: list[dict[str, str]] | None = None,
    ) -> list[StacItem]:
        """Execute a STAC item search, paginating by ID when needed.

        Args:
            collections: forwarded to StacClient.search.
            bbox: forwarded to StacClient.search.
            intersects: forwarded to StacClient.search.
            date_time: forwarded to StacClient.search.
            ids: forwarded to StacClient.search.
            limit: ignored. Every request uses PLANETARY_COMPUTER_LIMIT, because a
                response of exactly that size is how truncation is detected.
            query: forwarded to StacClient.search. During ID pagination an
                ``id > last_id`` condition is merged into it.
            sortby: must be None, since this client sorts by ID to paginate.

        Returns:
            all items matching the search.

        Raises:
            ValueError: if sortby is set.
            RuntimeError: if an ID-sorted page does not advance past the previous
                page (e.g. the server ignored the id filter). Without this check
                the loop would never terminate.
        """
        # We use sortby for pagination, so the caller must not set it.
        if sortby is not None:
            raise ValueError("sortby must not be set for PlanetaryComputerStacClient")

        # First try a plain query with the PC limit to detect whether pagination is
        # needed. We could always sort by ID and paginate, but we treat the common
        # case specially because skipping the sort seems to speed up the query.
        stac_items = super().search(
            collections=collections,
            bbox=bbox,
            intersects=intersects,
            date_time=date_time,
            ids=ids,
            limit=PLANETARY_COMPUTER_LIMIT,
            query=query,
        )

        # Fewer than the PC limit means nothing was truncated.
        if len(stac_items) < PLANETARY_COMPUTER_LIMIT:
            return stac_items

        # We hit the limit. The unsorted page has no usable ordering, so re-fetch
        # from the start sorted by ID and page with id > last_id.
        logger.debug(
            "Initial request returned %d items (at limit), switching to ID pagination",
            len(stac_items),
        )

        all_items: list[StacItem] = []
        last_id: str | None = None

        while True:
            page = super().search(
                collections=collections,
                bbox=bbox,
                intersects=intersects,
                date_time=date_time,
                ids=ids,
                limit=PLANETARY_COMPUTER_LIMIT,
                query=self._query_after_id(query, last_id),
                sortby=[{"field": "id", "direction": "asc"}],
            )
            all_items.extend(page)

            # A short page means we've fetched everything.
            if len(page) < PLANETARY_COMPUTER_LIMIT:
                break

            next_id = page[-1].id
            if next_id == last_id:
                raise RuntimeError(
                    f"Planetary Computer pagination did not advance past id {last_id}"
                )
            last_id = next_id
            logger.debug(
                "Got %d items, paginating with id > %s",
                len(page),
                last_id,
            )

        logger.debug("Total items fetched: %d", len(all_items))
        return all_items
125
+
32
126
 
33
127
  class PlanetaryComputer(StacDataSource, TileStore):
34
128
  """Modality-agnostic data source for data on Microsoft Planetary Computer.
@@ -100,6 +194,10 @@ class PlanetaryComputer(StacDataSource, TileStore):
100
194
  required_assets=required_assets,
101
195
  cache_dir=cache_upath,
102
196
  )
197
+
198
+ # Replace the client with PlanetaryComputerStacClient to handle PC's pagination limits.
199
+ self.client = PlanetaryComputerStacClient(self.STAC_ENDPOINT)
200
+
103
201
  self.asset_bands = asset_bands
104
202
  self.timeout = timeout
105
203
  self.skip_items_missing_assets = skip_items_missing_assets
@@ -567,3 +665,53 @@ class Naip(PlanetaryComputer):
567
665
  context=context,
568
666
  **kwargs,
569
667
  )
668
+
669
+
670
class CopDemGlo30(PlanetaryComputer):
    """Copernicus DEM GLO-30 (30m) data source backed by Microsoft Planetary Computer.

    Dataset page: https://planetarycomputer.microsoft.com/dataset/cop-dem-glo-30.
    The DEM is treated as static: item timestamps are dropped and STAC searches are
    not filtered by time.
    """

    COLLECTION_NAME = "cop-dem-glo-30"
    DATA_ASSET = "data"

    def __init__(
        self,
        band_name: str = "DEM",
        context: DataSourceContext = DataSourceContext(),
        **kwargs: Any,
    ):
        """Create a CopDemGlo30 data source.

        Args:
            band_name: band name assigned to the "data" asset. Only used when the
                context carries no layer config; otherwise the single band of the
                layer's single band set is used instead.
            context: the data source context.
            kwargs: extra keyword arguments forwarded to PlanetaryComputer.

        Raises:
            ValueError: if the layer config does not have exactly one band set
                containing exactly one band.
        """
        layer_config = context.layer_config
        if layer_config is not None:
            band_sets = layer_config.band_sets
            if len(band_sets) != 1:
                raise ValueError("expected a single band set")
            configured_bands = band_sets[0].bands
            if len(configured_bands) != 1:
                raise ValueError("expected band set to have a single band")
            band_name = configured_bands[0]

        super().__init__(
            collection_name=self.COLLECTION_NAME,
            asset_bands={self.DATA_ASSET: [band_name]},
            # Every item should carry the same asset(s), so skipping is harmless.
            skip_items_missing_assets=True,
            context=context,
            **kwargs,
        )

    def _stac_item_to_item(self, stac_item: Any) -> SourceItem:
        """Convert a STAC item, then strip its time range so it matches any window."""
        item = super()._stac_item_to_item(stac_item)
        geom = item.geometry
        # The DEM is static, so the item's timestamp is not meaningful.
        item.geometry = STGeometry(geom.projection, geom.shp, None)
        return item

    def _get_search_time_range(self, geometry: STGeometry) -> None:
        """Disable temporal filtering of STAC searches, since the DEM is static."""
        return None
@@ -1,6 +1,7 @@
1
1
  """A partial data source implementation providing get_items using a STAC API."""
2
2
 
3
3
  import json
4
+ from datetime import datetime
4
5
  from typing import Any
5
6
 
6
7
  import shapely
@@ -11,6 +12,7 @@ from rslearn.const import WGS84_PROJECTION
11
12
  from rslearn.data_sources.data_source import Item, ItemLookupDataSource
12
13
  from rslearn.data_sources.utils import match_candidate_items_to_window
13
14
  from rslearn.log_utils import get_logger
15
+ from rslearn.utils.fsspec import open_atomic
14
16
  from rslearn.utils.geometry import STGeometry
15
17
  from rslearn.utils.stac import StacClient, StacItem
16
18
 
@@ -132,6 +134,24 @@ class StacDataSource(ItemLookupDataSource[SourceItem]):
132
134
 
133
135
  return SourceItem(stac_item.id, geom, asset_urls, properties)
134
136
 
137
+ def _get_search_time_range(
138
+ self, geometry: STGeometry
139
+ ) -> datetime | tuple[datetime, datetime] | None:
140
+ """Get time range to include in STAC API search.
141
+
142
+ By default, we filter STAC searches to the window's time range. Subclasses can
143
+ override this to disable time filtering for "static" datasets.
144
+
145
+ Args:
146
+ geometry: the geometry we are searching for.
147
+
148
+ Returns:
149
+ the time range (or timestamp) to pass to the STAC search, or None to avoid
150
+ temporal filtering in the search request.
151
+ """
152
+ # Note: StacClient.search accepts either a datetime or a (start, end) tuple.
153
+ return geometry.time_range
154
+
135
155
  def get_item_by_name(self, name: str) -> SourceItem:
136
156
  """Gets an item by name.
137
157
 
@@ -168,7 +188,7 @@ class StacDataSource(ItemLookupDataSource[SourceItem]):
168
188
 
169
189
  # Finally we cache it if cache_dir is set.
170
190
  if cache_fname is not None:
171
- with cache_fname.open("w") as f:
191
+ with open_atomic(cache_fname, "w") as f:
172
192
  json.dump(item.serialize(), f)
173
193
 
174
194
  return item
@@ -191,10 +211,11 @@ class StacDataSource(ItemLookupDataSource[SourceItem]):
191
211
  # for each requested geometry.
192
212
  wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
193
213
  logger.debug("performing STAC search for geometry %s", wgs84_geometry)
214
+ search_time_range = self._get_search_time_range(wgs84_geometry)
194
215
  stac_items = self.client.search(
195
216
  collections=[self.collection_name],
196
217
  intersects=json.loads(shapely.to_geojson(wgs84_geometry.shp)),
197
- date_time=wgs84_geometry.time_range,
218
+ date_time=search_time_range,
198
219
  query=self.query,
199
220
  limit=self.limit,
200
221
  )
@@ -239,7 +260,7 @@ class StacDataSource(ItemLookupDataSource[SourceItem]):
239
260
  cache_fname = self.cache_dir / f"{item.name}.json"
240
261
  if cache_fname.exists():
241
262
  continue
242
- with cache_fname.open("w") as f:
263
+ with open_atomic(cache_fname, "w") as f:
243
264
  json.dump(item.serialize(), f)
244
265
 
245
266
  cur_groups = match_candidate_items_to_window(
rslearn/main.py CHANGED
@@ -2,6 +2,7 @@
2
2
 
3
3
  import argparse
4
4
  import multiprocessing
5
+ import os
5
6
  import random
6
7
  import sys
7
8
  import time
@@ -45,6 +46,7 @@ handler_registry = {}
45
46
  ItemType = TypeVar("ItemType", bound="Item")
46
47
 
47
48
  MULTIPROCESSING_CONTEXT = "forkserver"
49
+ MP_CONTEXT_ENV_VAR = "RSLEARN_MULTIPROCESSING_CONTEXT"
48
50
 
49
51
 
50
52
  def register_handler(category: Any, command: str) -> Callable:
@@ -837,7 +839,8 @@ def model_predict() -> None:
837
839
  def main() -> None:
838
840
  """CLI entrypoint."""
839
841
  try:
840
- multiprocessing.set_start_method(MULTIPROCESSING_CONTEXT)
842
+ mp_context = os.environ.get(MP_CONTEXT_ENV_VAR, MULTIPROCESSING_CONTEXT)
843
+ multiprocessing.set_start_method(mp_context)
841
844
  except RuntimeError as e:
842
845
  logger.error(
843
846
  f"Multiprocessing context already set to {multiprocessing.get_context()}: "
@@ -180,7 +180,7 @@ class SimpleTimeSeries(FeatureExtractor):
180
180
  # want to pass 2 timesteps to the model.
181
181
  # TODO is probably to make this behaviour clearer but lets leave it like
182
182
  # this for now to not break things.
183
- num_timesteps = images.shape[1] // image_channels
183
+ num_timesteps = image_channels // images.shape[1]
184
184
  batched_timesteps = images.shape[2] // num_timesteps
185
185
  images = rearrange(
186
186
  images,
@@ -210,11 +210,30 @@ class RslearnLightningModule(L.LightningModule):
210
210
  # Fail silently for single-dataset case, which is okay
211
211
  pass
212
212
 
213
def on_validation_epoch_end(self) -> None:
    """Log the aggregated validation metrics once per epoch, then reset them.

    The computed dict is logged rather than handing the MetricCollection itself
    to log_dict. MetricCollection.compute() flattens dict-returning metrics,
    whereas log_dict expects each metric to return a scalar tensor.
    """
    val_metrics = self.val_metrics
    results = val_metrics.compute()
    self.log_dict(results)
    # Clear accumulated state so the next epoch starts fresh.
    val_metrics.reset()
223
+
213
224
  def on_test_epoch_end(self) -> None:
214
- """Optionally save the test metrics to a file."""
225
+ """Compute and log test metrics at epoch end, optionally save to file.
226
+
227
+ We manually compute and log metrics here (instead of passing the MetricCollection
228
+ to log_dict) because MetricCollection.compute() properly flattens dict-returning
229
+ metrics, while log_dict expects each metric to return a scalar tensor.
230
+ """
231
+ metrics = self.test_metrics.compute()
232
+ self.log_dict(metrics)
233
+ self.test_metrics.reset()
234
+
215
235
  if self.metrics_file:
216
236
  with open(self.metrics_file, "w") as f:
217
- metrics = self.test_metrics.compute()
218
237
  metrics_dict = {k: v.item() for k, v in metrics.items()}
219
238
  json.dump(metrics_dict, f, indent=4)
220
239
  logger.info(f"Saved metrics to {self.metrics_file}")
@@ -300,9 +319,6 @@ class RslearnLightningModule(L.LightningModule):
300
319
  sync_dist=True,
301
320
  )
302
321
  self.val_metrics.update(outputs, targets)
303
- self.log_dict(
304
- self.val_metrics, batch_size=batch_size, on_epoch=True, sync_dist=True
305
- )
306
322
 
307
323
  def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
308
324
  """Compute the test loss and additional metrics.
@@ -340,9 +356,6 @@ class RslearnLightningModule(L.LightningModule):
340
356
  sync_dist=True,
341
357
  )
342
358
  self.test_metrics.update(outputs, targets)
343
- self.log_dict(
344
- self.test_metrics, batch_size=batch_size, on_epoch=True, sync_dist=True
345
- )
346
359
 
347
360
  if self.visualize_dir:
348
361
  for inp, target, output, metadata in zip(
@@ -118,13 +118,16 @@ class MultiTask(Task):
118
118
 
119
119
def get_metrics(self) -> MetricCollection:
    """Get metrics for this task.

    Every sub-task metric is wrapped in a MetricWrapper and registered in one
    flat MetricCollection under the key "<task_name>/<metric_name>". Nesting one
    MetricCollection per task is avoided because nested collections (with
    postfix=None) break MetricCollection.compute().
    """
    return MetricCollection(
        {
            f"{task_name}/{metric_name}": MetricWrapper(task_name, metric)
            for task_name, task in self.tasks.items()
            for metric_name, metric in task.get_metrics().items()
        }
    )
128
131
 
129
132
 
130
133
  class MetricWrapper(Metric):
@@ -100,7 +100,7 @@ class PerPixelRegressionTask(BasicTask):
100
100
  raise ValueError(
101
101
  f"PerPixelRegressionTask output must be an HW tensor, but got shape {raw_output.shape}"
102
102
  )
103
- return (raw_output / self.scale_factor).cpu().numpy()
103
+ return (raw_output[None, :, :] / self.scale_factor).cpu().numpy()
104
104
 
105
105
  def visualize(
106
106
  self,
@@ -53,11 +53,14 @@ class SegmentationTask(BasicTask):
53
53
  enable_accuracy_metric: bool = True,
54
54
  enable_miou_metric: bool = False,
55
55
  enable_f1_metric: bool = False,
56
+ report_metric_per_class: bool = False,
56
57
  f1_metric_thresholds: list[list[float]] = [[0.5]],
57
58
  metric_kwargs: dict[str, Any] = {},
58
59
  miou_metric_kwargs: dict[str, Any] = {},
59
60
  prob_scales: list[float] | None = None,
60
61
  other_metrics: dict[str, Metric] = {},
62
+ output_probs: bool = False,
63
+ output_class_idx: int | None = None,
61
64
  **kwargs: Any,
62
65
  ) -> None:
63
66
  """Initialize a new SegmentationTask.
@@ -74,6 +77,8 @@ class SegmentationTask(BasicTask):
74
77
  enable_accuracy_metric: whether to enable the accuracy metric (default
75
78
  true).
76
79
  enable_f1_metric: whether to enable the F1 metric (default false).
80
+ report_metric_per_class: whether to report chosen metrics for each class, in
81
+ addition to the average score across classes.
77
82
  enable_miou_metric: whether to enable the mean IoU metric (default false).
78
83
  f1_metric_thresholds: list of list of thresholds to apply for F1 metric.
79
84
  Each inner list is used to initialize a separate F1 metric where the
@@ -89,6 +94,10 @@ class SegmentationTask(BasicTask):
89
94
  this is only applied during prediction, not when computing val or test
90
95
  metrics.
91
96
  other_metrics: additional metrics to configure on this task.
97
+ output_probs: if True, output raw softmax probabilities instead of class IDs
98
+ during prediction.
99
+ output_class_idx: if set along with output_probs, only output the probability
100
+ for this specific class index (single-channel output).
92
101
  kwargs: additional arguments to pass to BasicTask
93
102
  """
94
103
  super().__init__(**kwargs)
@@ -107,11 +116,14 @@ class SegmentationTask(BasicTask):
107
116
  self.enable_accuracy_metric = enable_accuracy_metric
108
117
  self.enable_f1_metric = enable_f1_metric
109
118
  self.enable_miou_metric = enable_miou_metric
119
+ self.report_metric_per_class = report_metric_per_class
110
120
  self.f1_metric_thresholds = f1_metric_thresholds
111
121
  self.metric_kwargs = metric_kwargs
112
122
  self.miou_metric_kwargs = miou_metric_kwargs
113
123
  self.prob_scales = prob_scales
114
124
  self.other_metrics = other_metrics
125
+ self.output_probs = output_probs
126
+ self.output_class_idx = output_class_idx
115
127
 
116
128
  def process_inputs(
117
129
  self,
@@ -167,7 +179,9 @@ class SegmentationTask(BasicTask):
167
179
  metadata: metadata about the patch being read
168
180
 
169
181
  Returns:
170
- CHW numpy array with one channel, containing the predicted class IDs.
182
+ CHW numpy array. If output_probs is False, returns one channel with
183
+ predicted class IDs. If output_probs is True, returns softmax probabilities
184
+ (num_classes channels, or 1 channel if output_class_idx is set).
171
185
  """
172
186
  if not isinstance(raw_output, torch.Tensor) or len(raw_output.shape) != 3:
173
187
  raise ValueError("the output for SegmentationTask must be a CHW tensor")
@@ -179,6 +193,15 @@ class SegmentationTask(BasicTask):
179
193
  self.prob_scales, device=raw_output.device, dtype=raw_output.dtype
180
194
  )[:, None, None]
181
195
  )
196
+
197
+ if self.output_probs:
198
+ # Return raw softmax probabilities
199
+ probs = raw_output.cpu().numpy()
200
+ if self.output_class_idx is not None:
201
+ # Return only the specified class probability
202
+ return probs[self.output_class_idx : self.output_class_idx + 1, :, :]
203
+ return probs
204
+
182
205
  classes = raw_output.argmax(dim=0).cpu().numpy()
183
206
  return classes[None, :, :]
184
207
 
@@ -237,29 +260,41 @@ class SegmentationTask(BasicTask):
237
260
  # Metric name can't contain "." so change to ",".
238
261
  suffix = "_" + str(thresholds[0]).replace(".", ",")
239
262
 
263
+ # Create one metric per type - it returns a dict with "avg" and optionally per-class keys
240
264
  metrics["F1" + suffix] = SegmentationMetric(
241
- F1Metric(num_classes=self.num_classes, score_thresholds=thresholds)
265
+ F1Metric(
266
+ num_classes=self.num_classes,
267
+ score_thresholds=thresholds,
268
+ report_per_class=self.report_metric_per_class,
269
+ ),
242
270
  )
243
271
  metrics["precision" + suffix] = SegmentationMetric(
244
272
  F1Metric(
245
273
  num_classes=self.num_classes,
246
274
  score_thresholds=thresholds,
247
275
  metric_mode="precision",
248
- )
276
+ report_per_class=self.report_metric_per_class,
277
+ ),
249
278
  )
250
279
  metrics["recall" + suffix] = SegmentationMetric(
251
280
  F1Metric(
252
281
  num_classes=self.num_classes,
253
282
  score_thresholds=thresholds,
254
283
  metric_mode="recall",
255
- )
284
+ report_per_class=self.report_metric_per_class,
285
+ ),
256
286
  )
257
287
 
258
288
  if self.enable_miou_metric:
259
- miou_metric_kwargs: dict[str, Any] = dict(num_classes=self.num_classes)
289
+ miou_metric_kwargs: dict[str, Any] = dict(
290
+ num_classes=self.num_classes,
291
+ report_per_class=self.report_metric_per_class,
292
+ )
260
293
  if self.nodata_value is not None:
261
294
  miou_metric_kwargs["nodata_value"] = self.nodata_value
262
295
  miou_metric_kwargs.update(self.miou_metric_kwargs)
296
+
297
+ # Create one metric - it returns a dict with "avg" and optionally per-class keys
263
298
  metrics["mean_iou"] = SegmentationMetric(
264
299
  MeanIoUMetric(**miou_metric_kwargs),
265
300
  pass_probabilities=False,
@@ -274,6 +309,20 @@ class SegmentationTask(BasicTask):
274
309
  class SegmentationHead(Predictor):
275
310
  """Head for segmentation task."""
276
311
 
312
def __init__(self, weights: list[float] | None = None, dice_loss: bool = False):
    """Initialize a new SegmentationHead.

    Args:
        weights: optional per-class weights for the cross entropy loss, one value
            per class; converted to a tensor of size C. When given, they are
            registered as the "weights" buffer.
        dice_loss: whether to add a Dice loss term alongside the cross entropy
            loss.
    """
    super().__init__()
    if weights is not None:
        # A buffer (rather than a plain attribute) follows the module across
        # devices and is saved in its state dict.
        self.register_buffer("weights", torch.Tensor(weights))
    else:
        self.weights = None
    self.dice_loss = dice_loss
325
+
277
326
  def forward(
278
327
  self,
279
328
  intermediates: Any,
@@ -308,7 +357,7 @@ class SegmentationHead(Predictor):
308
357
  labels = torch.stack([target["classes"] for target in targets], dim=0)
309
358
  mask = torch.stack([target["valid"] for target in targets], dim=0)
310
359
  per_pixel_loss = torch.nn.functional.cross_entropy(
311
- logits, labels, reduction="none"
360
+ logits, labels, weight=self.weights, reduction="none"
312
361
  )
313
362
  mask_sum = torch.sum(mask)
314
363
  if mask_sum > 0:
@@ -318,6 +367,9 @@ class SegmentationHead(Predictor):
318
367
  # If there are no valid pixels, we avoid dividing by zero and just let
319
368
  # the summed mask loss be zero.
320
369
  losses["cls"] = torch.sum(per_pixel_loss * mask)
370
+ if self.dice_loss:
371
+ dice_loss = DiceLoss()(outputs, labels, mask)
372
+ losses["dice"] = dice_loss
321
373
 
322
374
  return ModelOutput(
323
375
  outputs=outputs,
@@ -333,6 +385,7 @@ class SegmentationMetric(Metric):
333
385
  metric: Metric,
334
386
  pass_probabilities: bool = True,
335
387
  class_idx: int | None = None,
388
+ output_key: str | None = None,
336
389
  ):
337
390
  """Initialize a new SegmentationMetric.
338
391
 
@@ -341,12 +394,19 @@ class SegmentationMetric(Metric):
341
394
  classes from the targets and masking out invalid pixels.
342
395
  pass_probabilities: whether to pass predicted probabilities to the metric.
343
396
  If False, argmax is applied to pass the predicted classes instead.
344
- class_idx: if metric returns value for multiple classes, select this class.
397
+ class_idx: if set, return only this class index's value. For backward
398
+ compatibility with configs using standard torchmetrics. Internally
399
+ converted to output_key="cls_{class_idx}".
400
+ output_key: if the wrapped metric returns a dict (or a tensor that gets
401
+ converted to a dict), return only this key's value. For standard
402
+ torchmetrics with average=None, tensors are converted to dicts with
403
+ keys "cls_0", "cls_1", etc. If None, the full dict is returned.
345
404
  """
346
405
  super().__init__()
347
406
  self.metric = metric
348
407
  self.pass_probablities = pass_probabilities
349
408
  self.class_idx = class_idx
409
+ self.output_key = output_key
350
410
 
351
411
  def update(
352
412
  self, preds: list[Any] | torch.Tensor, targets: list[dict[str, Any]]
@@ -376,10 +436,32 @@ class SegmentationMetric(Metric):
376
436
  self.metric.update(preds, labels)
377
437
 
378
438
def compute(self) -> Any:
    """Return the wrapped metric's value, optionally narrowed to a single entry.

    A tensor result with at least one dimension (e.g. a standard torchmetrics
    metric with average=None) is first converted to a dict keyed "cls_0",
    "cls_1", and so on. This lets standard torchmetrics and the custom
    dict-returning metrics be narrowed the same way. Then:

    - if output_key is set, that key's value is returned; the result must be a
      dict, otherwise TypeError is raised;
    - otherwise, if class_idx is set, the "cls_{class_idx}" entry is returned for
      dict results, or result[class_idx] for other results;
    - otherwise the (possibly converted) result is returned as is.
    """
    value = self.metric.compute()

    is_per_class_tensor = isinstance(value, torch.Tensor) and value.ndim >= 1
    if is_per_class_tensor:
        value = {f"cls_{idx}": value[idx] for idx in range(len(value))}

    if self.output_key is not None:
        if isinstance(value, dict):
            return value[self.output_key]
        raise TypeError(
            f"output_key is set to '{self.output_key}' but metric returned "
            f"{type(value).__name__} instead of dict"
        )

    if self.class_idx is None:
        return value
    # Backward compatibility: class_idx can index into the converted dict.
    if isinstance(value, dict):
        return value[f"cls_{self.class_idx}"]
    return value[self.class_idx]
384
466
 
385
467
  def reset(self) -> None:
@@ -404,6 +486,7 @@ class F1Metric(Metric):
404
486
  num_classes: int,
405
487
  score_thresholds: list[float],
406
488
  metric_mode: str = "f1",
489
+ report_per_class: bool = False,
407
490
  ):
408
491
  """Create a new F1Metric.
409
492
 
@@ -413,11 +496,14 @@ class F1Metric(Metric):
413
496
  metric is the best F1 across score thresholds.
414
497
  metric_mode: set to "precision" or "recall" to return that instead of F1
415
498
  (default "f1")
499
+ report_per_class: whether to include per-class scores in the output dict.
500
+ If False, only returns the "avg" key.
416
501
  """
417
502
  super().__init__()
418
503
  self.num_classes = num_classes
419
504
  self.score_thresholds = score_thresholds
420
505
  self.metric_mode = metric_mode
506
+ self.report_per_class = report_per_class
421
507
 
422
508
  assert self.metric_mode in ["f1", "precision", "recall"]
423
509
 
@@ -462,9 +548,10 @@ class F1Metric(Metric):
462
548
  """Compute metric.
463
549
 
464
550
  Returns:
465
- the best F1 score across score thresholds and classes.
551
+ dict with "avg" key containing mean score across classes.
552
+ If report_per_class is True, also includes "cls_N" keys for each class N.
466
553
  """
467
- best_scores = []
554
+ cls_best_scores = {}
468
555
 
469
556
  for cls_idx in range(self.num_classes):
470
557
  best_score = None
@@ -501,9 +588,12 @@ class F1Metric(Metric):
501
588
  if best_score is None or score > best_score:
502
589
  best_score = score
503
590
 
504
- best_scores.append(best_score)
591
+ cls_best_scores[f"cls_{cls_idx}"] = best_score
505
592
 
506
- return torch.mean(torch.stack(best_scores))
593
+ report_scores = {"avg": torch.mean(torch.stack(list(cls_best_scores.values())))}
594
+ if self.report_per_class:
595
+ report_scores.update(cls_best_scores)
596
+ return report_scores
507
597
 
508
598
 
509
599
  class MeanIoUMetric(Metric):
@@ -523,7 +613,7 @@ class MeanIoUMetric(Metric):
523
613
  num_classes: int,
524
614
  nodata_value: int | None = None,
525
615
  ignore_missing_classes: bool = False,
526
- class_idx: int | None = None,
616
+ report_per_class: bool = False,
527
617
  ):
528
618
  """Create a new MeanIoUMetric.
529
619
 
@@ -535,15 +625,14 @@ class MeanIoUMetric(Metric):
535
625
  ignore_missing_classes: whether to ignore classes that don't appear in
536
626
  either the predictions or the ground truth. If false, the IoU for a
537
627
  missing class will be 0.
538
- class_idx: only compute and return the IoU for this class. This option is
539
- provided so the user can get per-class IoU results, since Lightning
540
- only supports scalar return values from metrics.
628
+ report_per_class: whether to include per-class IoU scores in the output dict.
629
+ If False, only returns the "avg" key.
541
630
  """
542
631
  super().__init__()
543
632
  self.num_classes = num_classes
544
633
  self.nodata_value = nodata_value
545
634
  self.ignore_missing_classes = ignore_missing_classes
546
- self.class_idx = class_idx
635
+ self.report_per_class = report_per_class
547
636
 
548
637
  self.add_state(
549
638
  "intersections", default=torch.zeros(self.num_classes), dist_reduce_fx="sum"
@@ -584,9 +673,11 @@ class MeanIoUMetric(Metric):
584
673
  """Compute metric.
585
674
 
586
675
  Returns:
587
- the mean IoU across classes.
676
+ dict with "avg" containing the mean IoU across classes.
677
+ If report_per_class is True, also includes "cls_N" keys for each valid class N.
588
678
  """
589
- per_class_scores = []
679
+ cls_scores = {}
680
+ valid_scores = []
590
681
 
591
682
  for cls_idx in range(self.num_classes):
592
683
  # Check if nodata_value is set and is one of the classes
@@ -599,6 +690,56 @@ class MeanIoUMetric(Metric):
599
690
  if union == 0 and self.ignore_missing_classes:
600
691
  continue
601
692
 
602
- per_class_scores.append(intersection / union)
693
+ score = intersection / union
694
+ cls_scores[f"cls_{cls_idx}"] = score
695
+ valid_scores.append(score)
696
+
697
+ report_scores = {"avg": torch.mean(torch.stack(valid_scores))}
698
+ if self.report_per_class:
699
+ report_scores.update(cls_scores)
700
+ return report_scores
701
+
702
+
703
class DiceLoss(torch.nn.Module):
    """Mean Dice Loss for segmentation.

    This is the mean over classes of the per-class Dice loss
    (1 - 2 * intersection / union). For each class, the intersection is the sum
    over all valid pixels of all examples of the predicted score for that class
    times the one-hot ground truth. The union is the sum of the predicted scores
    plus the sum of the one-hot ground truth over those pixels. A small smoothing
    constant is added to the numerator and denominator, so a class absent from
    both prediction and ground truth gets a Dice score of 1 (zero loss).
    """

    def __init__(self, smooth: float = 1e-7):
        """Initialize a new DiceLoss.

        Args:
            smooth: constant added to the numerator and denominator of each
                per-class Dice score to avoid division by zero.
        """
        super().__init__()
        self.smooth = smooth

    def forward(
        self, inputs: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """Compute Dice Loss.

        Args:
            inputs: BCHW per-class scores. NOTE(review): presumably softmax
                probabilities; confirm with the caller.
            targets: BHW integer class labels.
            mask: BHW validity mask. Pixels where it is zero are excluded, and
                their labels may be arbitrary (even outside [0, C)).

        Returns:
            the mean Dice Loss across classes, as a scalar tensor.
        """
        num_classes = inputs.shape[1]

        # Labels at invalid pixels may be e.g. a nodata value >= num_classes,
        # which one_hot rejects. Those pixels are masked out below anyway, so
        # replace their labels with 0 first.
        valid = mask.bool()
        safe_targets = torch.where(valid, targets, torch.zeros_like(targets))
        targets_one_hot = (
            torch.nn.functional.one_hot(safe_targets, num_classes)
            .permute(0, 3, 1, 2)
            .float()
        )

        # Broadcast the BHW mask over the class dimension.
        class_mask = mask.unsqueeze(1)
        pred = inputs * class_mask
        target = targets_one_hot * class_mask

        # Reduce over batch and spatial dimensions, keeping one value per class.
        dims = (0, 2, 3)
        intersection = (pred * target).sum(dim=dims)
        union = pred.sum(dim=dims) + target.sum(dim=dims)
        dice_per_class = (2.0 * intersection + self.smooth) / (union + self.smooth)

        return 1 - dice_per_class.mean()
@@ -476,6 +476,7 @@ class GeotiffRasterFormat(RasterFormat):
476
476
  bounds: PixelBounds,
477
477
  array: npt.NDArray[Any],
478
478
  fname: str | None = None,
479
+ nodata_val: int | float | None = None,
479
480
  ) -> None:
480
481
  """Encodes raster data.
481
482
 
@@ -485,6 +486,7 @@ class GeotiffRasterFormat(RasterFormat):
485
486
  bounds: the bounds of the raster data in the projection
486
487
  array: the raster data
487
488
  fname: override the filename to save as
489
+ nodata_val: set the nodata value when writing the raster.
488
490
  """
489
491
  if fname is None:
490
492
  fname = self.fname
@@ -520,6 +522,9 @@ class GeotiffRasterFormat(RasterFormat):
520
522
  profile["tiled"] = True
521
523
  profile["blockxsize"] = self.block_size
522
524
  profile["blockysize"] = self.block_size
525
+ # Set nodata_val if provided.
526
+ if nodata_val is not None:
527
+ profile["nodata"] = nodata_val
523
528
 
524
529
  profile.update(self.geotiff_options)
525
530
 
@@ -535,6 +540,7 @@ class GeotiffRasterFormat(RasterFormat):
535
540
  bounds: PixelBounds,
536
541
  resampling: Resampling = Resampling.bilinear,
537
542
  fname: str | None = None,
543
+ nodata_val: int | float | None = None,
538
544
  ) -> npt.NDArray[Any]:
539
545
  """Decodes raster data.
540
546
 
@@ -544,6 +550,16 @@ class GeotiffRasterFormat(RasterFormat):
544
550
  bounds: the bounds to read in the given projection.
545
551
  resampling: resampling method to use in case resampling is needed.
546
552
  fname: override the filename to read from
553
+ nodata_val: override the nodata value in the raster when reading. Pixels in
554
+ bounds that are not present in the source raster will be initialized to
555
+ this value. Note that, if the raster specifies a nodata value, and
556
+ some source pixels have that value, they will still be read under their
557
+ original value; overriding the nodata value is primarily useful if the
558
+ user wants out of bounds pixels to have a different value from the
559
+ source pixels, e.g. if the source data has background and foreground
560
+ classes (with background being nodata) but we want to read it in a
561
+ different projection and have out of bounds pixels be a third "invalid"
562
+ value.
547
563
 
548
564
  Returns:
549
565
  the raster data
@@ -561,6 +577,7 @@ class GeotiffRasterFormat(RasterFormat):
561
577
  width=bounds[2] - bounds[0],
562
578
  height=bounds[3] - bounds[1],
563
579
  resampling=resampling,
580
+ src_nodata=nodata_val,
564
581
  ) as vrt:
565
582
  return vrt.read()
566
583
 
rslearn/utils/stac.py CHANGED
@@ -101,6 +101,7 @@ class StacClient:
101
101
  ids: list[str] | None = None,
102
102
  limit: int | None = None,
103
103
  query: dict[str, Any] | None = None,
104
+ sortby: list[dict[str, str]] | None = None,
104
105
  ) -> list[StacItem]:
105
106
  """Execute a STAC item search.
106
107
 
@@ -117,6 +118,7 @@ class StacClient:
117
118
  limit: number of items per page. We will read all the pages.
118
119
  query: query dict, if STAC query extension is supported by this API. See
119
120
  https://github.com/stac-api-extensions/query.
121
+ sortby: list of sort specifications, e.g. [{"field": "id", "direction": "asc"}].
120
122
 
121
123
  Returns:
122
124
  list of matching STAC items.
@@ -142,6 +144,8 @@ class StacClient:
142
144
  request_data["limit"] = limit
143
145
  if query is not None:
144
146
  request_data["query"] = query
147
+ if sortby is not None:
148
+ request_data["sortby"] = sortby
145
149
 
146
150
  # Handle pagination.
147
151
  cur_url = self.endpoint + "/search"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rslearn
3
- Version: 0.0.22
3
+ Version: 0.0.24
4
4
  Summary: A library for developing remote sensing datasets and models
5
5
  Author: OlmoEarth Team
6
6
  License: Apache License
@@ -3,7 +3,7 @@ rslearn/arg_parser.py,sha256=Go1MyEflcau_cziirmNd7Yhxa0WtXTAljIVE4f5H1GE,1194
3
3
  rslearn/const.py,sha256=FUCfsvFAs-QarEDJ0grdy0C1HjUjLpNFYGo5I2Vpc5Y,449
4
4
  rslearn/lightning_cli.py,sha256=1eTeffUlFqBe2KnyuYyXJdNKYQClCA-PV1xr0vJyJao,17972
5
5
  rslearn/log_utils.py,sha256=unD9gShiuO7cx5Nnq8qqVQ4qrbOOwFVgcHxN5bXuiAo,941
6
- rslearn/main.py,sha256=jRMYeU3-QvYSkTAJB69S1mHEft7-5_-RomzX1B-b8GM,28581
6
+ rslearn/main.py,sha256=rrDEoa0xCkDflH-HN2SaHt0hb-rLfXWP-kJKISZAe9s,28714
7
7
  rslearn/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
8
  rslearn/template_params.py,sha256=Vop0Ha-S44ctCa9lvSZRjrMETznJZlR5y_gJrVIwrPg,791
9
9
  rslearn/config/__init__.py,sha256=n1qpZ0ImshTtLYl5mC73BORYyUcjPJyHiyZkqUY1hiY,474
@@ -25,9 +25,9 @@ rslearn/data_sources/local_files.py,sha256=mo5W_BxBl89EPTIHNDEpXM6qBjrP225KK0Pcm
25
25
  rslearn/data_sources/openstreetmap.py,sha256=TzZfouc2Z4_xjx2v_uv7aPn4tVW3flRVQN4qBfl507E,18161
26
26
  rslearn/data_sources/planet.py,sha256=6FWQ0bl1k3jwvwp4EVGi2qs3OD1QhnKOKP36mN4HELI,9446
27
27
  rslearn/data_sources/planet_basemap.py,sha256=e9R6FlagJjg8Z6Rc1dC6zK3xMkCohz8eohXqXmd29xg,9670
28
- rslearn/data_sources/planetary_computer.py,sha256=nTJ6Jh6CNBdCEIsn7G_xLQ0Nige5evdPdqLYmWTdDl4,20722
28
+ rslearn/data_sources/planetary_computer.py,sha256=k6CO5Yim2I-frlD8r2_uBo0CQFw89mN3_5mrv0Xk2WU,26449
29
29
  rslearn/data_sources/soilgrids.py,sha256=rwO4goFPQ7lx420FvYBHYFXdihnZqn_-IjdqtxQ9j2g,12455
30
- rslearn/data_sources/stac.py,sha256=1qUbTD1fNSvBCX9QIXtyb9mGQ4K8ubRNIeEJs_I3QFU,9889
30
+ rslearn/data_sources/stac.py,sha256=Xty1JDueAAonNVLRo8vfNBhlHrVLjhmZ-uRBYbrGvtA,10753
31
31
  rslearn/data_sources/usda_cdl.py,sha256=_WvxZkm0fbXfniRs6NT8iVCbTTmVPflDhsFT2ci6_Dk,6879
32
32
  rslearn/data_sources/usgs_landsat.py,sha256=kPOb3hsZe5-guUcFZZkwzcRpYZ3Zo7Bk4E829q_xiyU,18516
33
33
  rslearn/data_sources/utils.py,sha256=v_90ALOuts7RHNcx-j8o-aQ_aFjh8ZhXrmsaa9uEGDA,11651
@@ -69,7 +69,7 @@ rslearn/models/prithvi.py,sha256=J45eC1pd4l5AGlr19Qjrjrw5PPwvYE9bNM5qCFoznmg,403
69
69
  rslearn/models/resize_features.py,sha256=U7ZIVwwToJJnwchFG59wLWWP9eikHDB_1c4OtpubxHU,1693
70
70
  rslearn/models/sam2_enc.py,sha256=WZOtlp0FjaVztW4gEVIcsFQdKArS9iblRODP0b6Oc8M,3641
71
71
  rslearn/models/satlaspretrain.py,sha256=2R48ulbtd44Qy2FYJCkllE2Wk35eZxkc79ruSgkmgcQ,3384
72
- rslearn/models/simple_time_series.py,sha256=Nfk5E3d9W-4AyLQiy-P8p-JvxmFYE3FBrvOgttjXSMw,14678
72
+ rslearn/models/simple_time_series.py,sha256=farQwt_nJVyAbgaM2UzdyqpDuIO0SLmHr9e9EVPSWCE,14678
73
73
  rslearn/models/singletask.py,sha256=9DM9a9-Mv3vVQqRhPOIXG2HHuVqVa_zuvgafeeYh4r0,1903
74
74
  rslearn/models/ssl4eo_s12.py,sha256=DOlpIj6NfjIlWyJ27m9Xo8TMlovBDstFq0ARnmAJ6qY,3919
75
75
  rslearn/models/swin.py,sha256=Xqr3SswbHP6IhwT2atZMAPF2TUzQqfxvihksb8WSeRo,6065
@@ -116,7 +116,7 @@ rslearn/train/__init__.py,sha256=fnJyY4aHs5zQqbDKSfXsJZXY_M9fbTsf7dRYaPwZr2M,30
116
116
  rslearn/train/all_patches_dataset.py,sha256=EVoYCmS3g4OfWPt5CZzwHVx9isbnWh5HIGA0RBqPFeA,21145
117
117
  rslearn/train/data_module.py,sha256=pgut8rEWHIieZ7RR8dUvhtlNqk0egEdznYF3tCvqdHg,23552
118
118
  rslearn/train/dataset.py,sha256=Jy1jU3GigfHaFeX9rbveX9bqy2Pd5Wh_vquD6_aFnS8,36522
119
- rslearn/train/lightning_module.py,sha256=7WBAgdJhcMHueLKE2DthSFmNYvlNUh1dB4sibkqCsRA,14761
119
+ rslearn/train/lightning_module.py,sha256=V4YoEg9PrwrgG4q9Dmv_9OBrSIK-SRPzjWtZRIfmPFg,15366
120
120
  rslearn/train/model_context.py,sha256=6o66BY6okBK-D5e0JUwPd7fxD_XehVaqxdQkJJKmQ3E,2580
121
121
  rslearn/train/optimizer.py,sha256=EKSqkmERalDA0bF32Gey7n6z69KLyaUWKlRsGJfKBmE,927
122
122
  rslearn/train/prediction_writer.py,sha256=rW0BUaYT_F1QqmpnQlbrLiLya1iBfC5Pb78G_NlF-vA,15956
@@ -130,10 +130,10 @@ rslearn/train/tasks/__init__.py,sha256=dag1u72x1-me6y0YcOubUo5MYZ0Tjf6-dOir9UeFN
130
130
  rslearn/train/tasks/classification.py,sha256=72ZBcbunMsdPYQN53S-4GfiLIDrr1X3Hni07dBJ0pu0,14261
131
131
  rslearn/train/tasks/detection.py,sha256=B0tfB7UGIbRtjnye3PhzLmfeQ4X7ImO3A-_LeNhBA54,21988
132
132
  rslearn/train/tasks/embedding.py,sha256=NdJEAaDWlWYzvOBVf7eIHfFOzqTgavfFH1J1gMbAMVo,3891
133
- rslearn/train/tasks/multi_task.py,sha256=1ML9mZ-kM3JfElisLOWBUn4k12gsKTFjoYYgamnyxt8,6124
134
- rslearn/train/tasks/per_pixel_regression.py,sha256=znCLFaZbGx8lvIkntDXjcX7yy7giyyBdWN-TwTGaPV4,10197
133
+ rslearn/train/tasks/multi_task.py,sha256=32hvwyVsHqt7N_M3zXsTErK1K7-0-BPHzt7iGNehyaI,6314
134
+ rslearn/train/tasks/per_pixel_regression.py,sha256=Clrod6LQGjgNC0IAR4HLY7eCGWMHj2mk4d4moZCl4Qc,10209
135
135
  rslearn/train/tasks/regression.py,sha256=bVS_ApZSpbL0NaaM8Mu5Bsu4SBUyLpVtrPslulvvZHs,12695
136
- rslearn/train/tasks/segmentation.py,sha256=ie9ZV-sklLjQs35caiEglC1xff6dxeug_N-f_A8VosA,23034
136
+ rslearn/train/tasks/segmentation.py,sha256=LZeuveHhMQsjNOQfMcwqSI4Ux3k9zfa58A2eZHSif8Y,29391
137
137
  rslearn/train/tasks/task.py,sha256=nMPunl9OlnOimr48saeTnwKMQ7Du4syGrwNKVQq4FL4,4110
138
138
  rslearn/train/transforms/__init__.py,sha256=BkCAzm4f-8TEhPIuyvCj7eJGh36aMkZFYlq-H_jkSvY,778
139
139
  rslearn/train/transforms/concatenate.py,sha256=hVVBaxIdk1Cx8JHPirj54TGpbWAJx5y_xD7k1rmGmT0,3166
@@ -155,17 +155,17 @@ rslearn/utils/get_utm_ups_crs.py,sha256=kUrcyjCK7KWvuP1XR-nURPeRqYeRO-3L8QUJ1QTF
155
155
  rslearn/utils/grid_index.py,sha256=hRmrtgpqN1pLa-djnZtgSXqKJlbgGyttGnCEmPLD0zo,2347
156
156
  rslearn/utils/jsonargparse.py,sha256=TRyZA151KzhjJlZczIHYguEA-YxCDYaZ2IwCRgx76nM,4791
157
157
  rslearn/utils/mp.py,sha256=XYmVckI5TOQuCKc49NJyirDJyFgvb4AI-gGypG2j680,1399
158
- rslearn/utils/raster_format.py,sha256=qZpbODF4I7BsOxf43O6vTmH2TSNw6N8PP0wsFUVvdIw,26267
158
+ rslearn/utils/raster_format.py,sha256=fwotJBadwqYSdK8UokiKOk1pOF8JMim3kP_VwLWivPI,27382
159
159
  rslearn/utils/rtree_index.py,sha256=j0Zwrq3pXuAJ-hKpiRFQ7VNtvO3fZYk-Em2uBPAqfx4,6460
160
160
  rslearn/utils/spatial_index.py,sha256=eomJAUgzmjir8j9HZnSgQoJHwN9H0wGTjmJkMkLLfsU,762
161
161
  rslearn/utils/sqlite_index.py,sha256=YGOJi66544e6JNtfSft6YIlHklFdSJO2duxQ4TJ2iu4,2920
162
- rslearn/utils/stac.py,sha256=z93N5ZeEe1oUikX5ILMA5sQEZX276sAeMjsg0TShnSk,5776
162
+ rslearn/utils/stac.py,sha256=c8NwOCKWvUwA-FSKlxZn-t7RZYweuye53OufT0bAK4A,5996
163
163
  rslearn/utils/time.py,sha256=2ilSLG94_sxLP3y5RSV5L5CG8CoND_dbdzYEHVtN-I8,387
164
164
  rslearn/utils/vector_format.py,sha256=4ZDYpfBLLxguJkiIaavTagiQK2Sv4Rz9NumbHlq-3Lw,15041
165
- rslearn-0.0.22.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
166
- rslearn-0.0.22.dist-info/licenses/NOTICE,sha256=wLPr6rwV_jCg-xEknNGwhnkfRfuoOE9MZ-lru2yZyLI,5070
167
- rslearn-0.0.22.dist-info/METADATA,sha256=UArAfc_JYTffP8-cOwQf5mxh6XUtsRv5cwzFiLWNzLU,37936
168
- rslearn-0.0.22.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
169
- rslearn-0.0.22.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
170
- rslearn-0.0.22.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
171
- rslearn-0.0.22.dist-info/RECORD,,
165
+ rslearn-0.0.24.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
166
+ rslearn-0.0.24.dist-info/licenses/NOTICE,sha256=wLPr6rwV_jCg-xEknNGwhnkfRfuoOE9MZ-lru2yZyLI,5070
167
+ rslearn-0.0.24.dist-info/METADATA,sha256=gV5mgeYPYiKWrEu7D8acOubWvg76Nn_4ICvlD7iTpcs,37936
168
+ rslearn-0.0.24.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
169
+ rslearn-0.0.24.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
170
+ rslearn-0.0.24.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
171
+ rslearn-0.0.24.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.9.0)
2
+ Generator: setuptools (80.10.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5