rslearn 0.0.22-py3-none-any.whl → 0.0.23-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rslearn/data_sources/planetary_computer.py CHANGED
@@ -567,3 +567,53 @@ class Naip(PlanetaryComputer):
             context=context,
             **kwargs,
         )
+
+
+class CopDemGlo30(PlanetaryComputer):
+    """A data source for Copernicus DEM GLO-30 (30m) on Microsoft Planetary Computer.
+
+    See https://planetarycomputer.microsoft.com/dataset/cop-dem-glo-30.
+    """
+
+    COLLECTION_NAME = "cop-dem-glo-30"
+    DATA_ASSET = "data"
+
+    def __init__(
+        self,
+        band_name: str = "DEM",
+        context: DataSourceContext = DataSourceContext(),
+        **kwargs: Any,
+    ):
+        """Initialize a new CopDemGlo30 instance.
+
+        Args:
+            band_name: band name to use if the layer config is missing from the
+                context.
+            context: the data source context.
+            kwargs: additional arguments to pass to PlanetaryComputer.
+        """
+        if context.layer_config is not None:
+            if len(context.layer_config.band_sets) != 1:
+                raise ValueError("expected a single band set")
+            if len(context.layer_config.band_sets[0].bands) != 1:
+                raise ValueError("expected band set to have a single band")
+            band_name = context.layer_config.band_sets[0].bands[0]
+
+        super().__init__(
+            collection_name=self.COLLECTION_NAME,
+            asset_bands={self.DATA_ASSET: [band_name]},
+            # Skip since all items should have the same asset(s).
+            skip_items_missing_assets=True,
+            context=context,
+            **kwargs,
+        )
+
+    def _stac_item_to_item(self, stac_item: Any) -> SourceItem:
+        # Copernicus DEM is static; ignore item timestamps so it matches any window.
+        item = super()._stac_item_to_item(stac_item)
+        item.geometry = STGeometry(item.geometry.projection, item.geometry.shp, None)
+        return item
+
+    def _get_search_time_range(self, geometry: STGeometry) -> None:
+        # Copernicus DEM is static; do not filter STAC searches by time.
+        return None
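A minimal usage sketch of the new data source (my illustration, not part of the diff; it assumes the default DataSourceContext carries no layer config, so the "DEM" band-name default applies):

from rslearn.data_sources.planetary_computer import CopDemGlo30

# Band name falls back to "DEM" since no layer config is attached.
dem_source = CopDemGlo30()
# Because _get_search_time_range returns None and item time ranges are cleared,
# this source matches windows regardless of their time range.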
rslearn/data_sources/stac.py CHANGED
@@ -1,6 +1,7 @@
 """A partial data source implementation providing get_items using a STAC API."""

 import json
+from datetime import datetime
 from typing import Any

 import shapely
@@ -132,6 +133,24 @@ class StacDataSource(ItemLookupDataSource[SourceItem]):

         return SourceItem(stac_item.id, geom, asset_urls, properties)

+    def _get_search_time_range(
+        self, geometry: STGeometry
+    ) -> datetime | tuple[datetime, datetime] | None:
+        """Get time range to include in STAC API search.
+
+        By default, we filter STAC searches to the window's time range. Subclasses can
+        override this to disable time filtering for "static" datasets.
+
+        Args:
+            geometry: the geometry we are searching for.
+
+        Returns:
+            the time range (or timestamp) to pass to the STAC search, or None to avoid
+            temporal filtering in the search request.
+        """
+        # Note: StacClient.search accepts either a datetime or a (start, end) tuple.
+        return geometry.time_range
+
     def get_item_by_name(self, name: str) -> SourceItem:
         """Gets an item by name.

@@ -191,10 +210,11 @@ class StacDataSource(ItemLookupDataSource[SourceItem]):
         # for each requested geometry.
         wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
         logger.debug("performing STAC search for geometry %s", wgs84_geometry)
+        search_time_range = self._get_search_time_range(wgs84_geometry)
         stac_items = self.client.search(
             collections=[self.collection_name],
             intersects=json.loads(shapely.to_geojson(wgs84_geometry.shp)),
-            date_time=wgs84_geometry.time_range,
+            date_time=search_time_range,
             query=self.query,
             limit=self.limit,
         )
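The new hook turns "static dataset" handling into a one-method override. A hedged sketch of a hypothetical subclass (StaticSource is my name; only StacDataSource and the hook come from the diff):

from rslearn.data_sources.stac import StacDataSource


class StaticSource(StacDataSource):
    """Hypothetical data source whose items should match any window time range."""

    def _get_search_time_range(self, geometry: "STGeometry") -> None:
        # Returning None omits the temporal filter from the STAC search above.
        return None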
rslearn/main.py CHANGED
@@ -2,6 +2,7 @@

 import argparse
 import multiprocessing
+import os
 import random
 import sys
 import time
@@ -45,6 +46,7 @@ handler_registry = {}
 ItemType = TypeVar("ItemType", bound="Item")

 MULTIPROCESSING_CONTEXT = "forkserver"
+MP_CONTEXT_ENV_VAR = "RSLEARN_MULTIPROCESSING_CONTEXT"


 def register_handler(category: Any, command: str) -> Callable:
@@ -837,7 +839,8 @@ def model_predict() -> None:
 def main() -> None:
     """CLI entrypoint."""
     try:
-        multiprocessing.set_start_method(MULTIPROCESSING_CONTEXT)
+        mp_context = os.environ.get(MP_CONTEXT_ENV_VAR, MULTIPROCESSING_CONTEXT)
+        multiprocessing.set_start_method(mp_context)
     except RuntimeError as e:
         logger.error(
             f"Multiprocessing context already set to {multiprocessing.get_context()}: "
rslearn/train/lightning_module.py CHANGED
@@ -210,11 +210,30 @@ class RslearnLightningModule(L.LightningModule):
            # Fail silently for single-dataset case, which is okay
            pass

+    def on_validation_epoch_end(self) -> None:
+        """Compute and log validation metrics at epoch end.
+
+        We manually compute and log metrics here (instead of passing the MetricCollection
+        to log_dict) because MetricCollection.compute() properly flattens dict-returning
+        metrics, while log_dict expects each metric to return a scalar tensor.
+        """
+        metrics = self.val_metrics.compute()
+        self.log_dict(metrics)
+        self.val_metrics.reset()
+
     def on_test_epoch_end(self) -> None:
-        """Optionally save the test metrics to a file."""
+        """Compute and log test metrics at epoch end, optionally save to file.
+
+        We manually compute and log metrics here (instead of passing the MetricCollection
+        to log_dict) because MetricCollection.compute() properly flattens dict-returning
+        metrics, while log_dict expects each metric to return a scalar tensor.
+        """
+        metrics = self.test_metrics.compute()
+        self.log_dict(metrics)
+        self.test_metrics.reset()
+
         if self.metrics_file:
             with open(self.metrics_file, "w") as f:
-                metrics = self.test_metrics.compute()
                 metrics_dict = {k: v.item() for k, v in metrics.items()}
                 json.dump(metrics_dict, f, indent=4)
                 logger.info(f"Saved metrics to {self.metrics_file}")
@@ -300,9 +319,6 @@ class RslearnLightningModule(L.LightningModule):
             sync_dist=True,
         )
         self.val_metrics.update(outputs, targets)
-        self.log_dict(
-            self.val_metrics, batch_size=batch_size, on_epoch=True, sync_dist=True
-        )

     def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         """Compute the test loss and additional metrics.
@@ -340,9 +356,6 @@ class RslearnLightningModule(L.LightningModule):
             sync_dist=True,
         )
         self.test_metrics.update(outputs, targets)
-        self.log_dict(
-            self.test_metrics, batch_size=batch_size, on_epoch=True, sync_dist=True
-        )

         if self.visualize_dir:
             for inp, target, output, metadata in zip(
rslearn/train/tasks/multi_task.py CHANGED
@@ -118,13 +118,16 @@ class MultiTask(Task):

     def get_metrics(self) -> MetricCollection:
         """Get metrics for this task."""
-        metrics = []
+        # Flatten metrics into a single dict with task_name/ prefix to avoid nested
+        # MetricCollections. Nested collections cause issues because MetricCollection
+        # has postfix=None, which breaks MetricCollection.compute().
+        all_metrics = {}
         for task_name, task in self.tasks.items():
-            cur_metrics = {}
             for metric_name, metric in task.get_metrics().items():
-                cur_metrics[metric_name] = MetricWrapper(task_name, metric)
-            metrics.append(MetricCollection(cur_metrics, prefix=f"{task_name}/"))
-        return MetricCollection(metrics)
+                all_metrics[f"{task_name}/{metric_name}"] = MetricWrapper(
+                    task_name, metric
+                )
+        return MetricCollection(all_metrics)


 class MetricWrapper(Metric):
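A plain-dict illustration of the resulting metric names (no torchmetrics needed; the task and metric names are invented):

task_metrics = {"segment": ["accuracy", "mean_iou"], "regress": ["mse"]}
all_metrics = {
    f"{task_name}/{metric_name}": metric_name  # MetricWrapper(task_name, metric) in rslearn
    for task_name, names in task_metrics.items()
    for metric_name in names
}
print(sorted(all_metrics))  # ['regress/mse', 'segment/accuracy', 'segment/mean_iou']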
rslearn/train/tasks/per_pixel_regression.py CHANGED
@@ -100,7 +100,7 @@ class PerPixelRegressionTask(BasicTask):
             raise ValueError(
                 f"PerPixelRegressionTask output must be an HW tensor, but got shape {raw_output.shape}"
             )
-        return (raw_output / self.scale_factor).cpu().numpy()
+        return (raw_output[None, :, :] / self.scale_factor).cpu().numpy()

     def visualize(
         self,
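The one-line fix adds a leading channel axis, so the HW regression output is returned in the CHW layout that raster-style outputs elsewhere in rslearn appear to use. A quick shape check (the scale factor value is illustrative):

import torch

raw_output = torch.rand(64, 64)            # H x W regression output
scaled = raw_output[None, :, :] / 100.0    # 1 x H x W
print(scaled.cpu().numpy().shape)          # (1, 64, 64)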
rslearn/train/tasks/segmentation.py CHANGED
@@ -53,6 +53,7 @@ class SegmentationTask(BasicTask):
         enable_accuracy_metric: bool = True,
         enable_miou_metric: bool = False,
         enable_f1_metric: bool = False,
+        report_metric_per_class: bool = False,
         f1_metric_thresholds: list[list[float]] = [[0.5]],
         metric_kwargs: dict[str, Any] = {},
         miou_metric_kwargs: dict[str, Any] = {},
@@ -74,6 +75,8 @@ class SegmentationTask(BasicTask):
             enable_accuracy_metric: whether to enable the accuracy metric (default
                 true).
             enable_f1_metric: whether to enable the F1 metric (default false).
+            report_metric_per_class: whether to report chosen metrics for each class, in
+                addition to the average score across classes.
             enable_miou_metric: whether to enable the mean IoU metric (default false).
             f1_metric_thresholds: list of list of thresholds to apply for F1 metric.
                 Each inner list is used to initialize a separate F1 metric where the
@@ -107,6 +110,7 @@ class SegmentationTask(BasicTask):
         self.enable_accuracy_metric = enable_accuracy_metric
         self.enable_f1_metric = enable_f1_metric
         self.enable_miou_metric = enable_miou_metric
+        self.report_metric_per_class = report_metric_per_class
         self.f1_metric_thresholds = f1_metric_thresholds
         self.metric_kwargs = metric_kwargs
         self.miou_metric_kwargs = miou_metric_kwargs
@@ -237,29 +241,41 @@ class SegmentationTask(BasicTask):
             # Metric name can't contain "." so change to ",".
             suffix = "_" + str(thresholds[0]).replace(".", ",")

+            # Create one metric per type - it returns a dict with "avg" and optionally per-class keys
             metrics["F1" + suffix] = SegmentationMetric(
-                F1Metric(num_classes=self.num_classes, score_thresholds=thresholds)
+                F1Metric(
+                    num_classes=self.num_classes,
+                    score_thresholds=thresholds,
+                    report_per_class=self.report_metric_per_class,
+                ),
             )
             metrics["precision" + suffix] = SegmentationMetric(
                 F1Metric(
                     num_classes=self.num_classes,
                     score_thresholds=thresholds,
                     metric_mode="precision",
-                )
+                    report_per_class=self.report_metric_per_class,
+                ),
             )
             metrics["recall" + suffix] = SegmentationMetric(
                 F1Metric(
                     num_classes=self.num_classes,
                     score_thresholds=thresholds,
                     metric_mode="recall",
-                )
+                    report_per_class=self.report_metric_per_class,
+                ),
             )

         if self.enable_miou_metric:
-            miou_metric_kwargs: dict[str, Any] = dict(num_classes=self.num_classes)
+            miou_metric_kwargs: dict[str, Any] = dict(
+                num_classes=self.num_classes,
+                report_per_class=self.report_metric_per_class,
+            )
             if self.nodata_value is not None:
                 miou_metric_kwargs["nodata_value"] = self.nodata_value
             miou_metric_kwargs.update(self.miou_metric_kwargs)
+
+            # Create one metric - it returns a dict with "avg" and optionally per-class keys
            metrics["mean_iou"] = SegmentationMetric(
                 MeanIoUMetric(**miou_metric_kwargs),
                 pass_probabilities=False,
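For readers who want per-class results outside a full task config, here is wiring that mirrors get_metrics() (a sketch, assuming rslearn 0.0.23 is installed; the import path follows the RECORD listing below):

from rslearn.train.tasks.segmentation import MeanIoUMetric, SegmentationMetric

miou = SegmentationMetric(
    MeanIoUMetric(num_classes=3, report_per_class=True),
    pass_probabilities=False,
)
# After updates, miou.compute() returns {"avg": ..., "cls_0": ..., "cls_1": ..., "cls_2": ...}.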
@@ -274,6 +290,20 @@ class SegmentationTask(BasicTask):
 class SegmentationHead(Predictor):
     """Head for segmentation task."""

+    def __init__(self, weights: list[float] | None = None, dice_loss: bool = False):
+        """Initialize a new SegmentationHead.
+
+        Args:
+            weights: per-class weights for the cross entropy loss (length C).
+            dice_loss: whether to add a Dice loss term to the cross entropy loss.
+        """
+        super().__init__()
+        if weights is not None:
+            self.register_buffer("weights", torch.Tensor(weights))
+        else:
+            self.weights = None
+        self.dice_loss = dice_loss
+
     def forward(
         self,
         intermediates: Any,
@@ -308,7 +338,7 @@ class SegmentationHead(Predictor):
         labels = torch.stack([target["classes"] for target in targets], dim=0)
         mask = torch.stack([target["valid"] for target in targets], dim=0)
         per_pixel_loss = torch.nn.functional.cross_entropy(
-            logits, labels, reduction="none"
+            logits, labels, weight=self.weights, reduction="none"
         )
         mask_sum = torch.sum(mask)
         if mask_sum > 0:
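The weight tensor behaves as in torch.nn.functional.cross_entropy generally: each pixel's loss is scaled by the weight of its target class. A standalone check (shapes and weight values are illustrative):

import torch

logits = torch.randn(2, 3, 4, 4)          # B, C, H, W
labels = torch.randint(0, 3, (2, 4, 4))   # B, H, W
weights = torch.tensor([1.0, 2.0, 0.5])   # per-class weights, length C
per_pixel_loss = torch.nn.functional.cross_entropy(
    logits, labels, weight=weights, reduction="none"
)
print(per_pixel_loss.shape)  # torch.Size([2, 4, 4])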
@@ -318,6 +348,9 @@ class SegmentationHead(Predictor):
             # If there are no valid pixels, we avoid dividing by zero and just let
             # the summed mask loss be zero.
             losses["cls"] = torch.sum(per_pixel_loss * mask)
+        if self.dice_loss:
+            dice_loss = DiceLoss()(outputs, labels, mask)
+            losses["dice"] = dice_loss

         return ModelOutput(
             outputs=outputs,
@@ -333,6 +366,7 @@ class SegmentationMetric(Metric):
         metric: Metric,
         pass_probabilities: bool = True,
         class_idx: int | None = None,
+        output_key: str | None = None,
     ):
         """Initialize a new SegmentationMetric.

@@ -341,12 +375,19 @@ class SegmentationMetric(Metric):
                 classes from the targets and masking out invalid pixels.
             pass_probabilities: whether to pass predicted probabilities to the metric.
                 If False, argmax is applied to pass the predicted classes instead.
-            class_idx: if metric returns value for multiple classes, select this class.
+            class_idx: if set, return only this class index's value. For backward
+                compatibility with configs using standard torchmetrics. Internally
+                converted to output_key="cls_{class_idx}".
+            output_key: if the wrapped metric returns a dict (or a tensor that gets
+                converted to a dict), return only this key's value. For standard
+                torchmetrics with average=None, tensors are converted to dicts with
+                keys "cls_0", "cls_1", etc. If None, the full dict is returned.
         """
         super().__init__()
         self.metric = metric
         self.pass_probablities = pass_probabilities
         self.class_idx = class_idx
+        self.output_key = output_key

     def update(
         self, preds: list[Any] | torch.Tensor, targets: list[dict[str, Any]]
@@ -376,10 +417,32 @@ class SegmentationMetric(Metric):
         self.metric.update(preds, labels)

     def compute(self) -> Any:
-        """Returns the computed metric."""
+        """Returns the computed metric.
+
+        If the wrapped metric returns a multi-element tensor (e.g., standard torchmetrics
+        with average=None), it is converted to a dict with keys like "cls_0", "cls_1",
+        etc. This allows uniform handling via output_key for both standard torchmetrics
+        and custom dict-returning metrics.
+        """
         result = self.metric.compute()
+
+        # Convert multi-element tensors to dict for uniform handling.
+        # This supports standard torchmetrics with average=None which return per-class tensors.
+        if isinstance(result, torch.Tensor) and result.ndim >= 1:
+            result = {f"cls_{i}": result[i] for i in range(len(result))}
+
+        if self.output_key is not None:
+            if not isinstance(result, dict):
+                raise TypeError(
+                    f"output_key is set to '{self.output_key}' but metric returned "
+                    f"{type(result).__name__} instead of dict"
+                )
+            return result[self.output_key]
         if self.class_idx is not None:
-            result = result[self.class_idx]
+            # For backward compatibility: class_idx can index into the converted dict.
+            if isinstance(result, dict):
+                return result[f"cls_{self.class_idx}"]
+            return result[self.class_idx]
         return result

     def reset(self) -> None:
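The tensor-to-dict conversion in isolation (pure torch; the values are illustrative):

import torch

# What a torchmetric with average=None might return: one value per class.
result = torch.tensor([0.70, 0.90])
result = {f"cls_{i}": result[i] for i in range(len(result))}
print(result["cls_1"])  # what output_key="cls_1" would select: tensor(0.9000)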
@@ -404,6 +467,7 @@ class F1Metric(Metric):
         num_classes: int,
         score_thresholds: list[float],
         metric_mode: str = "f1",
+        report_per_class: bool = False,
     ):
         """Create a new F1Metric.

@@ -413,11 +477,14 @@ class F1Metric(Metric):
                 metric is the best F1 across score thresholds.
             metric_mode: set to "precision" or "recall" to return that instead of F1
                 (default "f1")
+            report_per_class: whether to include per-class scores in the output dict.
+                If False, only returns the "avg" key.
         """
         super().__init__()
         self.num_classes = num_classes
         self.score_thresholds = score_thresholds
         self.metric_mode = metric_mode
+        self.report_per_class = report_per_class

         assert self.metric_mode in ["f1", "precision", "recall"]
@@ -462,9 +529,10 @@ class F1Metric(Metric):
         """Compute metric.

         Returns:
-            the best F1 score across score thresholds and classes.
+            dict with "avg" key containing mean score across classes.
+            If report_per_class is True, also includes "cls_N" keys for each class N.
         """
-        best_scores = []
+        cls_best_scores = {}

         for cls_idx in range(self.num_classes):
             best_score = None
@@ -501,9 +569,12 @@ class F1Metric(Metric):
                 if best_score is None or score > best_score:
                     best_score = score

-            best_scores.append(best_score)
+            cls_best_scores[f"cls_{cls_idx}"] = best_score

-        return torch.mean(torch.stack(best_scores))
+        report_scores = {"avg": torch.mean(torch.stack(list(cls_best_scores.values())))}
+        if self.report_per_class:
+            report_scores.update(cls_best_scores)
+        return report_scores


 class MeanIoUMetric(Metric):
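The shared "avg"-plus-per-class reporting convention in isolation (pure torch; the scores are invented):

import torch

cls_best_scores = {"cls_0": torch.tensor(0.81), "cls_1": torch.tensor(0.64)}
report_scores = {"avg": torch.mean(torch.stack(list(cls_best_scores.values())))}
report_per_class = True
if report_per_class:
    report_scores.update(cls_best_scores)
print(report_scores)  # {'avg': tensor(0.7250), 'cls_0': tensor(0.8100), 'cls_1': tensor(0.6400)}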
@@ -523,7 +594,7 @@ class MeanIoUMetric(Metric):
         num_classes: int,
         nodata_value: int | None = None,
         ignore_missing_classes: bool = False,
-        class_idx: int | None = None,
+        report_per_class: bool = False,
     ):
         """Create a new MeanIoUMetric.

@@ -535,15 +606,14 @@ class MeanIoUMetric(Metric):
             ignore_missing_classes: whether to ignore classes that don't appear in
                 either the predictions or the ground truth. If false, the IoU for a
                 missing class will be 0.
-            class_idx: only compute and return the IoU for this class. This option is
-                provided so the user can get per-class IoU results, since Lightning
-                only supports scalar return values from metrics.
+            report_per_class: whether to include per-class IoU scores in the output dict.
+                If False, only returns the "avg" key.
         """
         super().__init__()
         self.num_classes = num_classes
         self.nodata_value = nodata_value
         self.ignore_missing_classes = ignore_missing_classes
-        self.class_idx = class_idx
+        self.report_per_class = report_per_class

         self.add_state(
             "intersections", default=torch.zeros(self.num_classes), dist_reduce_fx="sum"
@@ -584,9 +654,11 @@ class MeanIoUMetric(Metric):
         """Compute metric.

         Returns:
-            the mean IoU across classes.
+            dict with "avg" containing the mean IoU across classes.
+            If report_per_class is True, also includes "cls_N" keys for each valid class N.
         """
-        per_class_scores = []
+        cls_scores = {}
+        valid_scores = []

         for cls_idx in range(self.num_classes):
             # Check if nodata_value is set and is one of the classes
@@ -599,6 +671,56 @@ class MeanIoUMetric(Metric):
             if union == 0 and self.ignore_missing_classes:
                 continue

-            per_class_scores.append(intersection / union)
+            score = intersection / union
+            cls_scores[f"cls_{cls_idx}"] = score
+            valid_scores.append(score)
+
+        report_scores = {"avg": torch.mean(torch.stack(valid_scores))}
+        if self.report_per_class:
+            report_scores.update(cls_scores)
+        return report_scores
+
+
+class DiceLoss(torch.nn.Module):
+    """Mean Dice loss for segmentation.
+
+    This is the mean of the per-class Dice losses (1 - 2*intersection/union).
+    The per-class intersection is the number of pixels across all examples where
+    the predicted label and ground truth label are both that class, and the per-class
+    union is defined similarly.
+    """
+
+    def __init__(self, smooth: float = 1e-7):
+        """Initialize a new DiceLoss."""
+        super().__init__()
+        self.smooth = smooth
+
+    def forward(
+        self, inputs: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
+    ) -> torch.Tensor:
+        """Compute the Dice loss.
+
+        Returns:
+            the mean Dice loss across classes.
+        """
+        num_classes = inputs.shape[1]
+        targets_one_hot = (
+            torch.nn.functional.one_hot(targets, num_classes)
+            .permute(0, 3, 1, 2)
+            .float()
+        )
+
+        # Expand mask to [B, C, H, W].
+        mask = mask.unsqueeze(1).expand_as(inputs)
+
+        dice_per_class = []
+        for c in range(num_classes):
+            pred_c = inputs[:, c] * mask[:, c]
+            target_c = targets_one_hot[:, c] * mask[:, c]
+
+            intersection = (pred_c * target_c).sum()
+            union = pred_c.sum() + target_c.sum()
+            dice_c = (2.0 * intersection + self.smooth) / (union + self.smooth)
+            dice_per_class.append(dice_c)

-        return torch.mean(torch.stack(per_class_scores))
+        return 1 - torch.stack(dice_per_class).mean()
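A standalone usage sketch of the new loss (import path per the RECORD listing below; the shapes are illustrative):

import torch

from rslearn.train.tasks.segmentation import DiceLoss

probs = torch.softmax(torch.randn(2, 3, 8, 8), dim=1)  # B, C, H, W probabilities
labels = torch.randint(0, 3, (2, 8, 8))                # B, H, W class indices
mask = torch.ones(2, 8, 8)                             # 1 = valid pixel
loss = DiceLoss()(probs, labels, mask)
print(float(loss))  # scalar Dice loss in [0, 1)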
rslearn-0.0.22.dist-info/METADATA → rslearn-0.0.23.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: rslearn
-Version: 0.0.22
+Version: 0.0.23
 Summary: A library for developing remote sensing datasets and models
 Author: OlmoEarth Team
 License: Apache License
rslearn-0.0.22.dist-info/RECORD → rslearn-0.0.23.dist-info/RECORD
@@ -3,7 +3,7 @@ rslearn/arg_parser.py,sha256=Go1MyEflcau_cziirmNd7Yhxa0WtXTAljIVE4f5H1GE,1194
 rslearn/const.py,sha256=FUCfsvFAs-QarEDJ0grdy0C1HjUjLpNFYGo5I2Vpc5Y,449
 rslearn/lightning_cli.py,sha256=1eTeffUlFqBe2KnyuYyXJdNKYQClCA-PV1xr0vJyJao,17972
 rslearn/log_utils.py,sha256=unD9gShiuO7cx5Nnq8qqVQ4qrbOOwFVgcHxN5bXuiAo,941
-rslearn/main.py,sha256=jRMYeU3-QvYSkTAJB69S1mHEft7-5_-RomzX1B-b8GM,28581
+rslearn/main.py,sha256=rrDEoa0xCkDflH-HN2SaHt0hb-rLfXWP-kJKISZAe9s,28714
 rslearn/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rslearn/template_params.py,sha256=Vop0Ha-S44ctCa9lvSZRjrMETznJZlR5y_gJrVIwrPg,791
 rslearn/config/__init__.py,sha256=n1qpZ0ImshTtLYl5mC73BORYyUcjPJyHiyZkqUY1hiY,474
@@ -25,9 +25,9 @@ rslearn/data_sources/local_files.py,sha256=mo5W_BxBl89EPTIHNDEpXM6qBjrP225KK0Pcm
 rslearn/data_sources/openstreetmap.py,sha256=TzZfouc2Z4_xjx2v_uv7aPn4tVW3flRVQN4qBfl507E,18161
 rslearn/data_sources/planet.py,sha256=6FWQ0bl1k3jwvwp4EVGi2qs3OD1QhnKOKP36mN4HELI,9446
 rslearn/data_sources/planet_basemap.py,sha256=e9R6FlagJjg8Z6Rc1dC6zK3xMkCohz8eohXqXmd29xg,9670
-rslearn/data_sources/planetary_computer.py,sha256=nTJ6Jh6CNBdCEIsn7G_xLQ0Nige5evdPdqLYmWTdDl4,20722
+rslearn/data_sources/planetary_computer.py,sha256=8kVatSXnwPUZljVOjj9vnVbOsmWhRdROi5YTiCmYmII,22594
 rslearn/data_sources/soilgrids.py,sha256=rwO4goFPQ7lx420FvYBHYFXdihnZqn_-IjdqtxQ9j2g,12455
-rslearn/data_sources/stac.py,sha256=1qUbTD1fNSvBCX9QIXtyb9mGQ4K8ubRNIeEJs_I3QFU,9889
+rslearn/data_sources/stac.py,sha256=l7V1QzvpNtoH_funiTSl1J8Lj1P3nMj24_fRpgCAslQ,10692
 rslearn/data_sources/usda_cdl.py,sha256=_WvxZkm0fbXfniRs6NT8iVCbTTmVPflDhsFT2ci6_Dk,6879
 rslearn/data_sources/usgs_landsat.py,sha256=kPOb3hsZe5-guUcFZZkwzcRpYZ3Zo7Bk4E829q_xiyU,18516
 rslearn/data_sources/utils.py,sha256=v_90ALOuts7RHNcx-j8o-aQ_aFjh8ZhXrmsaa9uEGDA,11651
@@ -116,7 +116,7 @@ rslearn/train/__init__.py,sha256=fnJyY4aHs5zQqbDKSfXsJZXY_M9fbTsf7dRYaPwZr2M,30
 rslearn/train/all_patches_dataset.py,sha256=EVoYCmS3g4OfWPt5CZzwHVx9isbnWh5HIGA0RBqPFeA,21145
 rslearn/train/data_module.py,sha256=pgut8rEWHIieZ7RR8dUvhtlNqk0egEdznYF3tCvqdHg,23552
 rslearn/train/dataset.py,sha256=Jy1jU3GigfHaFeX9rbveX9bqy2Pd5Wh_vquD6_aFnS8,36522
-rslearn/train/lightning_module.py,sha256=7WBAgdJhcMHueLKE2DthSFmNYvlNUh1dB4sibkqCsRA,14761
+rslearn/train/lightning_module.py,sha256=V4YoEg9PrwrgG4q9Dmv_9OBrSIK-SRPzjWtZRIfmPFg,15366
 rslearn/train/model_context.py,sha256=6o66BY6okBK-D5e0JUwPd7fxD_XehVaqxdQkJJKmQ3E,2580
 rslearn/train/optimizer.py,sha256=EKSqkmERalDA0bF32Gey7n6z69KLyaUWKlRsGJfKBmE,927
 rslearn/train/prediction_writer.py,sha256=rW0BUaYT_F1QqmpnQlbrLiLya1iBfC5Pb78G_NlF-vA,15956
@@ -130,10 +130,10 @@ rslearn/train/tasks/__init__.py,sha256=dag1u72x1-me6y0YcOubUo5MYZ0Tjf6-dOir9UeFN
 rslearn/train/tasks/classification.py,sha256=72ZBcbunMsdPYQN53S-4GfiLIDrr1X3Hni07dBJ0pu0,14261
 rslearn/train/tasks/detection.py,sha256=B0tfB7UGIbRtjnye3PhzLmfeQ4X7ImO3A-_LeNhBA54,21988
 rslearn/train/tasks/embedding.py,sha256=NdJEAaDWlWYzvOBVf7eIHfFOzqTgavfFH1J1gMbAMVo,3891
-rslearn/train/tasks/multi_task.py,sha256=1ML9mZ-kM3JfElisLOWBUn4k12gsKTFjoYYgamnyxt8,6124
-rslearn/train/tasks/per_pixel_regression.py,sha256=znCLFaZbGx8lvIkntDXjcX7yy7giyyBdWN-TwTGaPV4,10197
+rslearn/train/tasks/multi_task.py,sha256=32hvwyVsHqt7N_M3zXsTErK1K7-0-BPHzt7iGNehyaI,6314
+rslearn/train/tasks/per_pixel_regression.py,sha256=Clrod6LQGjgNC0IAR4HLY7eCGWMHj2mk4d4moZCl4Qc,10209
 rslearn/train/tasks/regression.py,sha256=bVS_ApZSpbL0NaaM8Mu5Bsu4SBUyLpVtrPslulvvZHs,12695
-rslearn/train/tasks/segmentation.py,sha256=ie9ZV-sklLjQs35caiEglC1xff6dxeug_N-f_A8VosA,23034
+rslearn/train/tasks/segmentation.py,sha256=Y3Sm2oOzR3yJCpagwBmp1yCwa024MQN2v1PcpiaWBf8,28425
 rslearn/train/tasks/task.py,sha256=nMPunl9OlnOimr48saeTnwKMQ7Du4syGrwNKVQq4FL4,4110
 rslearn/train/transforms/__init__.py,sha256=BkCAzm4f-8TEhPIuyvCj7eJGh36aMkZFYlq-H_jkSvY,778
 rslearn/train/transforms/concatenate.py,sha256=hVVBaxIdk1Cx8JHPirj54TGpbWAJx5y_xD7k1rmGmT0,3166
@@ -162,10 +162,10 @@ rslearn/utils/sqlite_index.py,sha256=YGOJi66544e6JNtfSft6YIlHklFdSJO2duxQ4TJ2iu4
 rslearn/utils/stac.py,sha256=z93N5ZeEe1oUikX5ILMA5sQEZX276sAeMjsg0TShnSk,5776
 rslearn/utils/time.py,sha256=2ilSLG94_sxLP3y5RSV5L5CG8CoND_dbdzYEHVtN-I8,387
 rslearn/utils/vector_format.py,sha256=4ZDYpfBLLxguJkiIaavTagiQK2Sv4Rz9NumbHlq-3Lw,15041
-rslearn-0.0.22.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
-rslearn-0.0.22.dist-info/licenses/NOTICE,sha256=wLPr6rwV_jCg-xEknNGwhnkfRfuoOE9MZ-lru2yZyLI,5070
-rslearn-0.0.22.dist-info/METADATA,sha256=UArAfc_JYTffP8-cOwQf5mxh6XUtsRv5cwzFiLWNzLU,37936
-rslearn-0.0.22.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-rslearn-0.0.22.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
-rslearn-0.0.22.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
-rslearn-0.0.22.dist-info/RECORD,,
+rslearn-0.0.23.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
+rslearn-0.0.23.dist-info/licenses/NOTICE,sha256=wLPr6rwV_jCg-xEknNGwhnkfRfuoOE9MZ-lru2yZyLI,5070
+rslearn-0.0.23.dist-info/METADATA,sha256=YFo7HcByJFrlgbSqcCUat2Z7nn1RU0aQzR0InaDSKEg,37936
+rslearn-0.0.23.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
+rslearn-0.0.23.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
+rslearn-0.0.23.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
+rslearn-0.0.23.dist-info/RECORD,,
rslearn-0.0.22.dist-info/WHEEL → rslearn-0.0.23.dist-info/WHEEL
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (80.9.0)
+Generator: setuptools (80.10.1)
 Root-Is-Purelib: true
 Tag: py3-none-any