rslearn 0.0.22.tar.gz → 0.0.23.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rslearn-0.0.22/rslearn.egg-info → rslearn-0.0.23}/PKG-INFO +1 -1
- {rslearn-0.0.22 → rslearn-0.0.23}/pyproject.toml +1 -1
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/data_sources/planetary_computer.py +50 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/data_sources/stac.py +21 -1
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/main.py +4 -1
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/lightning_module.py +21 -8
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/tasks/multi_task.py +8 -5
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/tasks/per_pixel_regression.py +1 -1
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/tasks/segmentation.py +143 -21
- {rslearn-0.0.22 → rslearn-0.0.23/rslearn.egg-info}/PKG-INFO +1 -1
- {rslearn-0.0.22 → rslearn-0.0.23}/LICENSE +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/NOTICE +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/README.md +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/arg_parser.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/config/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/config/dataset.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/const.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/data_sources/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/data_sources/aws_landsat.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/data_sources/aws_open_data.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/data_sources/aws_sentinel1.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/data_sources/aws_sentinel2_element84.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/data_sources/climate_data_store.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/data_sources/copernicus.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/data_sources/data_source.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/data_sources/earthdaily.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/data_sources/earthdata_srtm.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/data_sources/eurocrops.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/data_sources/gcp_public_data.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/data_sources/google_earth_engine.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/data_sources/local_files.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/data_sources/openstreetmap.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/data_sources/planet.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/data_sources/planet_basemap.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/data_sources/soilgrids.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/data_sources/usda_cdl.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/data_sources/usgs_landsat.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/data_sources/utils.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/data_sources/vector_source.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/data_sources/worldcereal.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/data_sources/worldcover.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/data_sources/worldpop.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/data_sources/xyz_tiles.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/dataset/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/dataset/add_windows.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/dataset/dataset.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/dataset/handler_summaries.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/dataset/manage.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/dataset/materialize.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/dataset/remap.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/dataset/storage/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/dataset/storage/file.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/dataset/storage/storage.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/dataset/window.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/lightning_cli.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/log_utils.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/anysat.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/attention_pooling.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/clay/clay.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/clay/configs/metadata.yaml +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/clip.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/component.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/concatenate_features.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/conv.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/croma.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/detr/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/detr/box_ops.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/detr/detr.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/detr/matcher.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/detr/position_encoding.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/detr/transformer.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/detr/util.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/dinov3.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/faster_rcnn.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/feature_center_crop.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/fpn.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/galileo/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/galileo/galileo.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/galileo/single_file_galileo.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/module_wrapper.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/molmo.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/multitask.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/olmoearth_pretrain/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/olmoearth_pretrain/model.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/olmoearth_pretrain/norm.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/panopticon.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/panopticon_data/sensors/drone.yaml +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/panopticon_data/sensors/enmap.yaml +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/panopticon_data/sensors/goes.yaml +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/panopticon_data/sensors/himawari.yaml +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/panopticon_data/sensors/intuition.yaml +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/panopticon_data/sensors/landsat8.yaml +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/panopticon_data/sensors/modis_terra.yaml +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/panopticon_data/sensors/sentinel1.yaml +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/panopticon_data/sensors/sentinel2.yaml +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/panopticon_data/sensors/superdove.yaml +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/panopticon_data/sensors/wv23.yaml +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/pick_features.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/pooling_decoder.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/presto/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/presto/presto.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/presto/single_file_presto.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/prithvi.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/resize_features.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/sam2_enc.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/satlaspretrain.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/simple_time_series.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/singletask.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/ssl4eo_s12.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/swin.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/task_embedding.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/terramind.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/trunk.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/unet.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/upsample.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/models/use_croma.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/py.typed +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/template_params.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/tile_stores/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/tile_stores/default.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/tile_stores/tile_store.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/all_patches_dataset.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/callbacks/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/callbacks/adapters.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/callbacks/freeze_unfreeze.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/callbacks/gradients.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/callbacks/peft.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/data_module.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/dataset.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/model_context.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/optimizer.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/prediction_writer.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/scheduler.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/tasks/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/tasks/classification.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/tasks/detection.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/tasks/embedding.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/tasks/regression.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/tasks/task.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/transforms/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/transforms/concatenate.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/transforms/crop.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/transforms/flip.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/transforms/mask.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/transforms/normalize.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/transforms/pad.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/transforms/resize.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/transforms/select_bands.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/transforms/sentinel1.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/train/transforms/transform.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/utils/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/utils/array.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/utils/feature.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/utils/fsspec.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/utils/geometry.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/utils/get_utm_ups_crs.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/utils/grid_index.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/utils/jsonargparse.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/utils/mp.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/utils/raster_format.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/utils/rtree_index.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/utils/spatial_index.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/utils/sqlite_index.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/utils/stac.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/utils/time.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn/utils/vector_format.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn.egg-info/SOURCES.txt +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn.egg-info/dependency_links.txt +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn.egg-info/entry_points.txt +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn.egg-info/requires.txt +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/rslearn.egg-info/top_level.txt +0 -0
- {rslearn-0.0.22 → rslearn-0.0.23}/setup.cfg +0 -0
rslearn/data_sources/planetary_computer.py

@@ -567,3 +567,53 @@ class Naip(PlanetaryComputer):
             context=context,
             **kwargs,
         )
+
+
+class CopDemGlo30(PlanetaryComputer):
+    """A data source for Copernicus DEM GLO-30 (30m) on Microsoft Planetary Computer.
+
+    See https://planetarycomputer.microsoft.com/dataset/cop-dem-glo-30.
+    """
+
+    COLLECTION_NAME = "cop-dem-glo-30"
+    DATA_ASSET = "data"
+
+    def __init__(
+        self,
+        band_name: str = "DEM",
+        context: DataSourceContext = DataSourceContext(),
+        **kwargs: Any,
+    ):
+        """Initialize a new CopDemGlo30 instance.
+
+        Args:
+            band_name: band name to use if the layer config is missing from the
+                context.
+            context: the data source context.
+            kwargs: additional arguments to pass to PlanetaryComputer.
+        """
+        if context.layer_config is not None:
+            if len(context.layer_config.band_sets) != 1:
+                raise ValueError("expected a single band set")
+            if len(context.layer_config.band_sets[0].bands) != 1:
+                raise ValueError("expected band set to have a single band")
+            band_name = context.layer_config.band_sets[0].bands[0]
+
+        super().__init__(
+            collection_name=self.COLLECTION_NAME,
+            asset_bands={self.DATA_ASSET: [band_name]},
+            # Skip since all items should have the same asset(s).
+            skip_items_missing_assets=True,
+            context=context,
+            **kwargs,
+        )
+
+    def _stac_item_to_item(self, stac_item: Any) -> SourceItem:
+        # Copernicus DEM is static; ignore item timestamps so it matches any window.
+        item = super()._stac_item_to_item(stac_item)
+        item.geometry = STGeometry(item.geometry.projection, item.geometry.shp, None)
+        return item
+
+    def _get_search_time_range(self, geometry: STGeometry) -> None:
+        # Copernicus DEM is static; do not filter STAC searches by time.
+        return None
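The two overrides above form a reusable pattern for temporally static collections. Below is a minimal sketch of that pattern under stated assumptions: the `StaticCollection` name is hypothetical, and the import paths are inferred from this package's module layout rather than confirmed by the diff.

```python
from typing import Any

# Assumed import paths, based on the module layout shown in this diff.
from rslearn.data_sources.planetary_computer import PlanetaryComputer
from rslearn.utils.geometry import STGeometry


# Hypothetical subclass: any static Planetary Computer collection could reuse
# the CopDemGlo30 approach of ignoring time entirely.
class StaticCollection(PlanetaryComputer):
    def _stac_item_to_item(self, stac_item: Any):
        # Drop the item timestamp so the item matches windows with any time range.
        item = super()._stac_item_to_item(stac_item)
        item.geometry = STGeometry(item.geometry.projection, item.geometry.shp, None)
        return item

    def _get_search_time_range(self, geometry: STGeometry) -> None:
        # Returning None makes the STAC search purely spatial.
        return None
```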
rslearn/data_sources/stac.py

@@ -1,6 +1,7 @@
 """A partial data source implementation providing get_items using a STAC API."""

 import json
+from datetime import datetime
 from typing import Any

 import shapely

@@ -132,6 +133,24 @@ class StacDataSource(ItemLookupDataSource[SourceItem]):

         return SourceItem(stac_item.id, geom, asset_urls, properties)

+    def _get_search_time_range(
+        self, geometry: STGeometry
+    ) -> datetime | tuple[datetime, datetime] | None:
+        """Get time range to include in STAC API search.
+
+        By default, we filter STAC searches to the window's time range. Subclasses can
+        override this to disable time filtering for "static" datasets.
+
+        Args:
+            geometry: the geometry we are searching for.
+
+        Returns:
+            the time range (or timestamp) to pass to the STAC search, or None to avoid
+            temporal filtering in the search request.
+        """
+        # Note: StacClient.search accepts either a datetime or a (start, end) tuple.
+        return geometry.time_range
+
     def get_item_by_name(self, name: str) -> SourceItem:
         """Gets an item by name.

@@ -191,10 +210,11 @@ class StacDataSource(ItemLookupDataSource[SourceItem]):
         # for each requested geometry.
         wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
         logger.debug("performing STAC search for geometry %s", wgs84_geometry)
+        search_time_range = self._get_search_time_range(wgs84_geometry)
         stac_items = self.client.search(
             collections=[self.collection_name],
             intersects=json.loads(shapely.to_geojson(wgs84_geometry.shp)),
-            date_time=wgs84_geometry.time_range,
+            date_time=search_time_range,
             query=self.query,
             limit=self.limit,
         )
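Per the note in `_get_search_time_range`, the search client's `date_time` argument is flexible. A small sketch of the three value shapes the hook may now produce (the timestamps are illustrative only):

```python
from datetime import datetime, timezone

# A single timestamp...
at = datetime(2024, 6, 1, tzinfo=timezone.utc)
# ...a (start, end) tuple, which is what geometry.time_range yields by default...
between = (
    datetime(2024, 1, 1, tzinfo=timezone.utc),
    datetime(2025, 1, 1, tzinfo=timezone.utc),
)
# ...or None to skip temporal filtering, as CopDemGlo30 now does.
no_filter = None
```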
rslearn/main.py

@@ -2,6 +2,7 @@

 import argparse
 import multiprocessing
+import os
 import random
 import sys
 import time

@@ -45,6 +46,7 @@ handler_registry = {}
 ItemType = TypeVar("ItemType", bound="Item")

 MULTIPROCESSING_CONTEXT = "forkserver"
+MP_CONTEXT_ENV_VAR = "RSLEARN_MULTIPROCESSING_CONTEXT"


 def register_handler(category: Any, command: str) -> Callable:

@@ -837,7 +839,8 @@ def model_predict() -> None:
 def main() -> None:
     """CLI entrypoint."""
     try:
-        multiprocessing.set_start_method(MULTIPROCESSING_CONTEXT)
+        mp_context = os.environ.get(MP_CONTEXT_ENV_VAR, MULTIPROCESSING_CONTEXT)
+        multiprocessing.set_start_method(mp_context)
     except RuntimeError as e:
         logger.error(
             f"Multiprocessing context already set to {multiprocessing.get_context()}: "
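With the new environment variable, the start method can be switched without code changes, e.g. `RSLEARN_MULTIPROCESSING_CONTEXT=spawn rslearn ...` in the shell. A sketch of the fallback logic in isolation:

```python
import os

MULTIPROCESSING_CONTEXT = "forkserver"  # library default
MP_CONTEXT_ENV_VAR = "RSLEARN_MULTIPROCESSING_CONTEXT"

# The environment variable wins when set; otherwise the default applies.
mp_context = os.environ.get(MP_CONTEXT_ENV_VAR, MULTIPROCESSING_CONTEXT)
print(mp_context)  # "forkserver" unless the variable overrides it
```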
rslearn/train/lightning_module.py

@@ -210,11 +210,30 @@ class RslearnLightningModule(L.LightningModule):
             # Fail silently for single-dataset case, which is okay
             pass

+    def on_validation_epoch_end(self) -> None:
+        """Compute and log validation metrics at epoch end.
+
+        We manually compute and log metrics here (instead of passing the MetricCollection
+        to log_dict) because MetricCollection.compute() properly flattens dict-returning
+        metrics, while log_dict expects each metric to return a scalar tensor.
+        """
+        metrics = self.val_metrics.compute()
+        self.log_dict(metrics)
+        self.val_metrics.reset()
+
     def on_test_epoch_end(self) -> None:
-        """
+        """Compute and log test metrics at epoch end, optionally save to file.
+
+        We manually compute and log metrics here (instead of passing the MetricCollection
+        to log_dict) because MetricCollection.compute() properly flattens dict-returning
+        metrics, while log_dict expects each metric to return a scalar tensor.
+        """
+        metrics = self.test_metrics.compute()
+        self.log_dict(metrics)
+        self.test_metrics.reset()
+
         if self.metrics_file:
             with open(self.metrics_file, "w") as f:
-                metrics = self.test_metrics.compute()
                 metrics_dict = {k: v.item() for k, v in metrics.items()}
                 json.dump(metrics_dict, f, indent=4)
                 logger.info(f"Saved metrics to {self.metrics_file}")

@@ -300,9 +319,6 @@ class RslearnLightningModule(L.LightningModule):
             sync_dist=True,
         )
         self.val_metrics.update(outputs, targets)
-        self.log_dict(
-            self.val_metrics, batch_size=batch_size, on_epoch=True, sync_dist=True
-        )

     def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         """Compute the test loss and additional metrics.

@@ -340,9 +356,6 @@ class RslearnLightningModule(L.LightningModule):
             sync_dist=True,
         )
         self.test_metrics.update(outputs, targets)
-        self.log_dict(
-            self.test_metrics, batch_size=batch_size, on_epoch=True, sync_dist=True
-        )

         if self.visualize_dir:
             for inp, target, output, metadata in zip(
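The reason for the manual compute-then-log flow: `log_dict` wants a mapping of name to scalar tensor, while the new per-class segmentation metrics return dicts; `MetricCollection.compute()` flattens those first. An illustrative sketch (the key names and values are made up, and the exact flattened key format is determined by torchmetrics):

```python
import torch

# What a flattened compute() result might look like with per-class reporting on:
metrics = {
    "segment/mean_iou_avg": torch.tensor(0.71),
    "segment/mean_iou_cls_0": torch.tensor(0.88),
    "segment/mean_iou_cls_1": torch.tensor(0.54),
}
# Every value is a scalar tensor, so Lightning's log_dict accepts it directly.
```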
rslearn/train/tasks/multi_task.py

@@ -118,13 +118,16 @@ class MultiTask(Task):

     def get_metrics(self) -> MetricCollection:
         """Get metrics for this task."""
-        metrics = {}
+        # Flatten metrics into a single dict with task_name/ prefix to avoid nested
+        # MetricCollections. Nested collections cause issues because MetricCollection
+        # has postfix=None which breaks MetricCollection.compute().
+        all_metrics = {}
         for task_name, task in self.tasks.items():
-            cur_metrics = {}
             for metric_name, metric in task.get_metrics().items():
-                cur_metrics[metric_name] = MetricWrapper(task_name, metric)
-            metrics[task_name] = MetricCollection(cur_metrics)
-        return MetricCollection(metrics)
+                all_metrics[f"{task_name}/{metric_name}"] = MetricWrapper(
+                    task_name, metric
+                )
+        return MetricCollection(all_metrics)


 class MetricWrapper(Metric):
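For two hypothetical tasks, the flattened collection would look like the sketch below; the task names and wrapped torchmetrics are placeholders, and `MetricWrapper` is the class from this module:

```python
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy
from torchmetrics.regression import MeanSquaredError

# One flat level of "task_name/metric_name" keys; MetricWrapper routes each
# task's outputs and targets to the wrapped metric.
all_metrics = {
    "segment/accuracy": MetricWrapper("segment", MulticlassAccuracy(num_classes=3)),
    "regress/mse": MetricWrapper("regress", MeanSquaredError()),
}
collection = MetricCollection(all_metrics)
```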
rslearn/train/tasks/per_pixel_regression.py

@@ -100,7 +100,7 @@ class PerPixelRegressionTask(BasicTask):
             raise ValueError(
                 f"PerPixelRegressionTask output must be an HW tensor, but got shape {raw_output.shape}"
             )
-        return (raw_output / self.scale_factor).cpu().numpy()
+        return (raw_output[None, :, :] / self.scale_factor).cpu().numpy()

     def visualize(
         self,
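The one-line fix promotes the HW output to CHW with a single channel, matching the channel-first layout used for raster outputs elsewhere in the package; a quick shape check:

```python
import torch

raw_output = torch.rand(64, 64)   # HW tensor produced by the model
chw = raw_output[None, :, :]      # prepend a channel axis -> shape (1, 64, 64)
assert chw.shape == (1, 64, 64)
```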
rslearn/train/tasks/segmentation.py

@@ -53,6 +53,7 @@ class SegmentationTask(BasicTask):
         enable_accuracy_metric: bool = True,
         enable_miou_metric: bool = False,
         enable_f1_metric: bool = False,
+        report_metric_per_class: bool = False,
         f1_metric_thresholds: list[list[float]] = [[0.5]],
         metric_kwargs: dict[str, Any] = {},
         miou_metric_kwargs: dict[str, Any] = {},

@@ -74,6 +75,8 @@ class SegmentationTask(BasicTask):
             enable_accuracy_metric: whether to enable the accuracy metric (default
                 true).
             enable_f1_metric: whether to enable the F1 metric (default false).
+            report_metric_per_class: whether to report chosen metrics for each class, in
+                addition to the average score across classes.
             enable_miou_metric: whether to enable the mean IoU metric (default false).
             f1_metric_thresholds: list of list of thresholds to apply for F1 metric.
                 Each inner list is used to initialize a separate F1 metric where the

@@ -107,6 +110,7 @@ class SegmentationTask(BasicTask):
         self.enable_accuracy_metric = enable_accuracy_metric
         self.enable_f1_metric = enable_f1_metric
         self.enable_miou_metric = enable_miou_metric
+        self.report_metric_per_class = report_metric_per_class
         self.f1_metric_thresholds = f1_metric_thresholds
         self.metric_kwargs = metric_kwargs
         self.miou_metric_kwargs = miou_metric_kwargs

@@ -237,29 +241,41 @@ class SegmentationTask(BasicTask):
             # Metric name can't contain "." so change to ",".
             suffix = "_" + str(thresholds[0]).replace(".", ",")

+            # Create one metric per type - it returns a dict with "avg" and optionally per-class keys
             metrics["F1" + suffix] = SegmentationMetric(
-                F1Metric(num_classes=self.num_classes, score_thresholds=thresholds),
+                F1Metric(
+                    num_classes=self.num_classes,
+                    score_thresholds=thresholds,
+                    report_per_class=self.report_metric_per_class,
+                ),
             )
             metrics["precision" + suffix] = SegmentationMetric(
                 F1Metric(
                     num_classes=self.num_classes,
                     score_thresholds=thresholds,
                     metric_mode="precision",
-                ),
+                    report_per_class=self.report_metric_per_class,
+                ),
             )
             metrics["recall" + suffix] = SegmentationMetric(
                 F1Metric(
                     num_classes=self.num_classes,
                     score_thresholds=thresholds,
                     metric_mode="recall",
-                ),
+                    report_per_class=self.report_metric_per_class,
+                ),
             )

         if self.enable_miou_metric:
-            miou_metric_kwargs: dict[str, Any] = dict(num_classes=self.num_classes)
+            miou_metric_kwargs: dict[str, Any] = dict(
+                num_classes=self.num_classes,
+                report_per_class=self.report_metric_per_class,
+            )
             if self.nodata_value is not None:
                 miou_metric_kwargs["nodata_value"] = self.nodata_value
             miou_metric_kwargs.update(self.miou_metric_kwargs)
+
+            # Create one metric - it returns a dict with "avg" and optionally per-class keys
             metrics["mean_iou"] = SegmentationMetric(
                 MeanIoUMetric(**miou_metric_kwargs),
                 pass_probabilities=False,

@@ -274,6 +290,20 @@ class SegmentationTask(BasicTask):
 class SegmentationHead(Predictor):
     """Head for segmentation task."""

+    def __init__(self, weights: list[float] | None = None, dice_loss: bool = False):
+        """Initialize a new SegmentationHead.
+
+        Args:
+            weights: weights for cross entropy loss (Tensor of size C)
+            dice_loss: whether to add dice loss to cross entropy
+        """
+        super().__init__()
+        if weights is not None:
+            self.register_buffer("weights", torch.Tensor(weights))
+        else:
+            self.weights = None
+        self.dice_loss = dice_loss
+
     def forward(
         self,
         intermediates: Any,
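For illustration, enabling both new options might look like the following; the two-class weighting is a hypothetical example, not taken from the release:

```python
from rslearn.train.tasks.segmentation import SegmentationHead

# Upweight a rare positive class (index 1) and add the Dice term to the loss.
head = SegmentationHead(weights=[1.0, 5.0], dice_loss=True)
```

Registering `weights` as a buffer rather than a plain attribute means the tensor moves with the module on `.to(device)` and is included in the state dict, which is why the constructor distinguishes the two cases.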
def forward(
|
|
278
308
|
self,
|
|
279
309
|
intermediates: Any,
|
|
@@ -308,7 +338,7 @@ class SegmentationHead(Predictor):
|
|
|
308
338
|
labels = torch.stack([target["classes"] for target in targets], dim=0)
|
|
309
339
|
mask = torch.stack([target["valid"] for target in targets], dim=0)
|
|
310
340
|
per_pixel_loss = torch.nn.functional.cross_entropy(
|
|
311
|
-
logits, labels, reduction="none"
|
|
341
|
+
logits, labels, weight=self.weights, reduction="none"
|
|
312
342
|
)
|
|
313
343
|
mask_sum = torch.sum(mask)
|
|
314
344
|
if mask_sum > 0:
|
|

@@ -318,6 +348,9 @@ class SegmentationHead(Predictor):
         # If there are no valid pixels, we avoid dividing by zero and just let
         # the summed mask loss be zero.
         losses["cls"] = torch.sum(per_pixel_loss * mask)
+        if self.dice_loss:
+            dice_loss = DiceLoss()(outputs, labels, mask)
+            losses["dice"] = dice_loss

         return ModelOutput(
             outputs=outputs,

@@ -333,6 +366,7 @@ class SegmentationMetric(Metric):
         metric: Metric,
         pass_probabilities: bool = True,
         class_idx: int | None = None,
+        output_key: str | None = None,
     ):
         """Initialize a new SegmentationMetric.

@@ -341,12 +375,19 @@ class SegmentationMetric(Metric):
                 classes from the targets and masking out invalid pixels.
             pass_probabilities: whether to pass predicted probabilities to the metric.
                 If False, argmax is applied to pass the predicted classes instead.
-            class_idx: if
+            class_idx: if set, return only this class index's value. For backward
+                compatibility with configs using standard torchmetrics. Internally
+                converted to output_key="cls_{class_idx}".
+            output_key: if the wrapped metric returns a dict (or a tensor that gets
+                converted to a dict), return only this key's value. For standard
+                torchmetrics with average=None, tensors are converted to dicts with
+                keys "cls_0", "cls_1", etc. If None, the full dict is returned.
         """
         super().__init__()
         self.metric = metric
         self.pass_probablities = pass_probabilities
         self.class_idx = class_idx
+        self.output_key = output_key

     def update(
         self, preds: list[Any] | torch.Tensor, targets: list[dict[str, Any]]

@@ -376,10 +417,32 @@ class SegmentationMetric(Metric):
         self.metric.update(preds, labels)

     def compute(self) -> Any:
-        """Returns the computed metric."""
+        """Returns the computed metric.
+
+        If the wrapped metric returns a multi-element tensor (e.g., standard torchmetrics
+        with average=None), it is converted to a dict with keys like "cls_0", "cls_1", etc.
+        This allows uniform handling via output_key for both standard torchmetrics and
+        custom dict-returning metrics.
+        """
         result = self.metric.compute()
+
+        # Convert multi-element tensors to dict for uniform handling.
+        # This supports standard torchmetrics with average=None which return per-class tensors.
+        if isinstance(result, torch.Tensor) and result.ndim >= 1:
+            result = {f"cls_{i}": result[i] for i in range(len(result))}
+
+        if self.output_key is not None:
+            if not isinstance(result, dict):
+                raise TypeError(
+                    f"output_key is set to '{self.output_key}' but metric returned "
+                    f"{type(result).__name__} instead of dict"
+                )
+            return result[self.output_key]
         if self.class_idx is not None:
-            return result[self.class_idx]
+            # For backward compatibility: class_idx can index into the converted dict
+            if isinstance(result, dict):
+                return result[f"cls_{self.class_idx}"]
+            return result[self.class_idx]
         return result

     def reset(self) -> None:
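A standalone sketch of the tensor-to-dict conversion this `compute()` now performs; the scores are illustrative:

```python
import torch

# Per-class result like a torchmetrics metric with average=None would return.
result = torch.tensor([0.91, 0.42, 0.73])
as_dict = {f"cls_{i}": result[i] for i in range(len(result))}
# output_key="cls_1" selects tensor(0.42); the legacy class_idx=1 resolves to
# the same entry via the converted dict.
```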

@@ -404,6 +467,7 @@ class F1Metric(Metric):
         num_classes: int,
         score_thresholds: list[float],
         metric_mode: str = "f1",
+        report_per_class: bool = False,
     ):
         """Create a new F1Metric.

@@ -413,11 +477,14 @@ class F1Metric(Metric):
             metric is the best F1 across score thresholds.
             metric_mode: set to "precision" or "recall" to return that instead of F1
                 (default "f1")
+            report_per_class: whether to include per-class scores in the output dict.
+                If False, only returns the "avg" key.
         """
         super().__init__()
         self.num_classes = num_classes
         self.score_thresholds = score_thresholds
         self.metric_mode = metric_mode
+        self.report_per_class = report_per_class

         assert self.metric_mode in ["f1", "precision", "recall"]

@@ -462,9 +529,10 @@ class F1Metric(Metric):
         """Compute metric.

         Returns:
-
+            dict with "avg" key containing mean score across classes.
+            If report_per_class is True, also includes "cls_N" keys for each class N.
         """
-
+        cls_best_scores = {}

         for cls_idx in range(self.num_classes):
             best_score = None

@@ -501,9 +569,12 @@ class F1Metric(Metric):
             if best_score is None or score > best_score:
                 best_score = score

-
+            cls_best_scores[f"cls_{cls_idx}"] = best_score

-
+        report_scores = {"avg": torch.mean(torch.stack(list(cls_best_scores.values())))}
+        if self.report_per_class:
+            report_scores.update(cls_best_scores)
+        return report_scores


 class MeanIoUMetric(Metric):

@@ -523,7 +594,7 @@ class MeanIoUMetric(Metric):
         num_classes: int,
         nodata_value: int | None = None,
         ignore_missing_classes: bool = False,
-
+        report_per_class: bool = False,
     ):
         """Create a new MeanIoUMetric.

@@ -535,15 +606,14 @@ class MeanIoUMetric(Metric):
             ignore_missing_classes: whether to ignore classes that don't appear in
                 either the predictions or the ground truth. If false, the IoU for a
                 missing class will be 0.
-
-
-                only supports scalar return values from metrics.
+            report_per_class: whether to include per-class IoU scores in the output dict.
+                If False, only returns the "avg" key.
         """
         super().__init__()
         self.num_classes = num_classes
         self.nodata_value = nodata_value
         self.ignore_missing_classes = ignore_missing_classes
-        self.
+        self.report_per_class = report_per_class

         self.add_state(
             "intersections", default=torch.zeros(self.num_classes), dist_reduce_fx="sum"

@@ -584,9 +654,11 @@ class MeanIoUMetric(Metric):
         """Compute metric.

         Returns:
-            the mean IoU across classes.
+            dict with "avg" containing the mean IoU across classes.
+            If report_per_class is True, also includes "cls_N" keys for each valid class N.
         """
-
+        cls_scores = {}
+        valid_scores = []

         for cls_idx in range(self.num_classes):
             # Check if nodata_value is set and is one of the classes
# Check if nodata_value is set and is one of the classes
|
|
@@ -599,6 +671,56 @@ class MeanIoUMetric(Metric):
|
|
|
599
671
|
if union == 0 and self.ignore_missing_classes:
|
|
600
672
|
continue
|
|
601
673
|
|
|
602
|
-
|
|
674
|
+
score = intersection / union
|
|
675
|
+
cls_scores[f"cls_{cls_idx}"] = score
|
|
676
|
+
valid_scores.append(score)
|
|
677
|
+
|
|
678
|
+
report_scores = {"avg": torch.mean(torch.stack(valid_scores))}
|
|
679
|
+
if self.report_per_class:
|
|
680
|
+
report_scores.update(cls_scores)
|
|
681
|
+
return report_scores
|
|
682
|
+
|
|
683
|
+
|
|
684
|
+
class DiceLoss(torch.nn.Module):
|
|
685
|
+
"""Mean Dice Loss for segmentation.
|
|
686
|
+
|
|
687
|
+
This is the mean of the per-class dice loss (1 - 2*intersection / union scores).
|
|
688
|
+
The per-class intersection is the number of pixels across all examples where
|
|
689
|
+
the predicted label and ground truth label are both that class, and the per-class
|
|
690
|
+
union is defined similarly.
|
|
691
|
+
"""
|
|
692
|
+
|
|
693
|
+
def __init__(self, smooth: float = 1e-7):
|
|
694
|
+
"""Initialize a new DiceLoss."""
|
|
695
|
+
super().__init__()
|
|
696
|
+
self.smooth = smooth
|
|
697
|
+
|
|
698
|
+
def forward(
|
|
699
|
+
self, inputs: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
|
|
700
|
+
) -> torch.Tensor:
|
|
701
|
+
"""Compute Dice Loss.
|
|
702
|
+
|
|
703
|
+
Returns:
|
|
704
|
+
the mean Dicen Loss across classes
|
|
705
|
+
"""
|
|
706
|
+
num_classes = inputs.shape[1]
|
|
707
|
+
targets_one_hot = (
|
|
708
|
+
torch.nn.functional.one_hot(targets, num_classes)
|
|
709
|
+
.permute(0, 3, 1, 2)
|
|
710
|
+
.float()
|
|
711
|
+
)
|
|
712
|
+
|
|
713
|
+
# Expand mask to [B, C, H, W]
|
|
714
|
+
mask = mask.unsqueeze(1).expand_as(inputs)
|
|
715
|
+
|
|
716
|
+
dice_per_class = []
|
|
717
|
+
for c in range(num_classes):
|
|
718
|
+
pred_c = inputs[:, c] * mask[:, c]
|
|
719
|
+
target_c = targets_one_hot[:, c] * mask[:, c]
|
|
720
|
+
|
|
721
|
+
intersection = (pred_c * target_c).sum()
|
|
722
|
+
union = pred_c.sum() + target_c.sum()
|
|
723
|
+
dice_c = (2.0 * intersection + self.smooth) / (union + self.smooth)
|
|
724
|
+
dice_per_class.append(dice_c)
|
|
603
725
|
|
|
604
|
-
return torch.
|
|
726
|
+
return 1 - torch.stack(dice_per_class).mean()
|
|
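A self-contained sketch exercising the new `DiceLoss` on random data, with shapes following the B, C, H, W convention used by `SegmentationHead` (the inputs are assumed to be class probabilities):

```python
import torch

from rslearn.train.tasks.segmentation import DiceLoss

probs = torch.rand(2, 3, 8, 8).softmax(dim=1)  # B, C, H, W class probabilities
labels = torch.randint(0, 3, (2, 8, 8))        # B, H, W ground-truth classes
valid = torch.ones(2, 8, 8)                    # 1 where the pixel is valid

loss = DiceLoss()(probs, labels, valid)        # scalar tensor in [0, 1)
```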