rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
|
@@ -1,15 +1,24 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""Regression task."""
|
|
2
2
|
|
|
3
|
-
from typing import Any
|
|
3
|
+
from typing import Any, Literal
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
import numpy.typing as npt
|
|
7
|
+
import shapely
|
|
7
8
|
import torch
|
|
8
9
|
import torchmetrics
|
|
9
10
|
from PIL import Image, ImageDraw
|
|
10
11
|
from torchmetrics import Metric, MetricCollection
|
|
11
12
|
|
|
12
|
-
from rslearn.
|
|
13
|
+
from rslearn.models.component import FeatureVector, Predictor
|
|
14
|
+
from rslearn.train.model_context import (
|
|
15
|
+
ModelContext,
|
|
16
|
+
ModelOutput,
|
|
17
|
+
RasterImage,
|
|
18
|
+
SampleMetadata,
|
|
19
|
+
)
|
|
20
|
+
from rslearn.utils.feature import Feature
|
|
21
|
+
from rslearn.utils.geometry import STGeometry
|
|
13
22
|
|
|
14
23
|
from .task import BasicTask
|
|
15
24
|
|
|
@@ -20,23 +29,29 @@ class RegressionTask(BasicTask):
|
|
|
20
29
|
def __init__(
|
|
21
30
|
self,
|
|
22
31
|
property_name: str,
|
|
23
|
-
filters: list[tuple[str, str]] | None,
|
|
32
|
+
filters: list[tuple[str, str]] | None = None,
|
|
24
33
|
allow_invalid: bool = False,
|
|
25
34
|
scale_factor: float = 1,
|
|
26
|
-
metric_mode:
|
|
27
|
-
|
|
28
|
-
|
|
35
|
+
metric_mode: Literal["mse", "l1"] = "mse",
|
|
36
|
+
use_accuracy_metric: bool = False,
|
|
37
|
+
within_factor: float = 0.1,
|
|
38
|
+
**kwargs: Any,
|
|
39
|
+
) -> None:
|
|
29
40
|
"""Initialize a new RegressionTask.
|
|
30
41
|
|
|
31
42
|
Args:
|
|
32
|
-
property_name: the property from which to extract the
|
|
33
|
-
value is read from the first matching feature.
|
|
43
|
+
property_name: the property from which to extract the ground truth
|
|
44
|
+
regression value. The value is read from the first matching feature.
|
|
34
45
|
filters: optional list of (property_name, property_value) to only consider
|
|
35
46
|
features with matching properties.
|
|
36
47
|
allow_invalid: instead of throwing error when no regression label is found
|
|
37
48
|
at a window, simply mark the example invalid for this task
|
|
38
|
-
scale_factor: multiply the label value by this factor
|
|
39
|
-
metric_mode: what metric to use, either mse or l1
|
|
49
|
+
scale_factor: multiply the label value by this factor for training
|
|
50
|
+
metric_mode: what metric to use, either "mse" (default) or "l1"
|
|
51
|
+
use_accuracy_metric: include metric that reports percentage of
|
|
52
|
+
examples where output is within a factor of the ground truth.
|
|
53
|
+
within_factor: the factor for accuracy metric. If it's 0.2, and ground
|
|
54
|
+
truth is 5.0, then values from 5.0*0.8 to 5.0*1.2 are accepted.
|
|
40
55
|
kwargs: other arguments to pass to BasicTask
|
|
41
56
|
"""
|
|
42
57
|
super().__init__(**kwargs)
|
|
@@ -45,14 +60,16 @@ class RegressionTask(BasicTask):
|
|
|
45
60
|
self.allow_invalid = allow_invalid
|
|
46
61
|
self.scale_factor = scale_factor
|
|
47
62
|
self.metric_mode = metric_mode
|
|
63
|
+
self.use_accuracy_metric = use_accuracy_metric
|
|
64
|
+
self.within_factor = within_factor
|
|
48
65
|
|
|
49
66
|
if not self.filters:
|
|
50
67
|
self.filters = []
|
|
51
68
|
|
|
52
69
|
def process_inputs(
|
|
53
70
|
self,
|
|
54
|
-
raw_inputs: dict[str,
|
|
55
|
-
metadata:
|
|
71
|
+
raw_inputs: dict[str, RasterImage | list[Feature]],
|
|
72
|
+
metadata: SampleMetadata,
|
|
56
73
|
load_targets: bool = True,
|
|
57
74
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
58
75
|
"""Processes the data into targets.
|
|
@@ -70,7 +87,10 @@ class RegressionTask(BasicTask):
|
|
|
70
87
|
return {}, {}
|
|
71
88
|
|
|
72
89
|
data = raw_inputs["targets"]
|
|
90
|
+
assert isinstance(data, list)
|
|
73
91
|
for feat in data:
|
|
92
|
+
if feat.properties is None or self.filters is None:
|
|
93
|
+
continue
|
|
74
94
|
for property_name, property_value in self.filters:
|
|
75
95
|
if feat.properties.get(property_name) != property_value:
|
|
76
96
|
continue
|
|
@@ -90,6 +110,35 @@ class RegressionTask(BasicTask):
|
|
|
90
110
|
"valid": torch.tensor(0, dtype=torch.float32),
|
|
91
111
|
}
|
|
92
112
|
|
|
113
|
+
def process_output(
    self, raw_output: Any, metadata: SampleMetadata
) -> list[Feature]:
    """Convert the head's scalar prediction into a single vector Feature.

    Args:
        raw_output: the prediction-head output for one example; it must be a
            0-dimensional (scalar) torch.Tensor.
        metadata: metadata about the patch being read; its projection and the
            first two entries of its patch_bounds are used for the geometry.

    Returns:
        a one-element list holding a Feature whose geometry is a point at
        (patch_bounds[0], patch_bounds[1]) in the patch projection, and whose
        properties map ``property_name`` to the predicted value divided by
        ``scale_factor`` (undoing the scaling applied to labels for training).

    Raises:
        ValueError: if raw_output is not a scalar Tensor.
    """
    is_scalar_tensor = isinstance(raw_output, torch.Tensor) and raw_output.dim() == 0
    if not is_scalar_tensor:
        raise ValueError("output for RegressionTask must be a scalar Tensor")

    predicted_value = raw_output.item() / self.scale_factor
    anchor = shapely.Point(metadata.patch_bounds[0], metadata.patch_bounds[1])
    geometry = STGeometry(metadata.projection, anchor, None)
    return [Feature(geometry, {self.property_name: predicted_value})]
|
|
141
|
+
|
|
93
142
|
def visualize(
|
|
94
143
|
self,
|
|
95
144
|
input_dict: dict[str, Any],
|
|
@@ -109,6 +158,8 @@ class RegressionTask(BasicTask):
|
|
|
109
158
|
image = super().visualize(input_dict, target_dict, output)["image"]
|
|
110
159
|
image = Image.fromarray(image)
|
|
111
160
|
draw = ImageDraw.Draw(image)
|
|
161
|
+
if target_dict is None:
|
|
162
|
+
raise ValueError("target_dict is required for visualization")
|
|
112
163
|
target = target_dict["value"] / self.scale_factor
|
|
113
164
|
output = output / self.scale_factor
|
|
114
165
|
text = f"Label: {target:.2f}\nOutput: {output:.2f}"
|
|
@@ -121,27 +172,36 @@ class RegressionTask(BasicTask):
|
|
|
121
172
|
|
|
122
173
|
def get_metrics(self) -> MetricCollection:
    """Get the metrics for this task.

    The collection contains the error metric selected by ``metric_mode``
    (keyed "mse" or "l1"). It also contains an "accuracy" metric
    (RegressionAccuracy using ``within_factor``) when ``use_accuracy_metric``
    is set. Every metric is wrapped in RegressionMetricWrapper together with
    this task's ``scale_factor``.

    Returns:
        the MetricCollection for this task.

    Raises:
        ValueError: if ``metric_mode`` is neither "mse" nor "l1".
    """
    metric_dict: dict[str, Metric] = {}

    if self.metric_mode == "mse":
        metric_dict["mse"] = RegressionMetricWrapper(
            metric=torchmetrics.MeanSquaredError(), scale_factor=self.scale_factor
        )
    elif self.metric_mode == "l1":
        metric_dict["l1"] = RegressionMetricWrapper(
            metric=torchmetrics.MeanAbsoluteError(), scale_factor=self.scale_factor
        )
    else:
        # Fail loudly instead of silently returning a collection without the
        # primary error metric (consistent with RegressionHead's loss_mode check).
        raise ValueError(f"unknown metric mode {self.metric_mode}")

    if self.use_accuracy_metric:
        metric_dict["accuracy"] = RegressionMetricWrapper(
            metric=RegressionAccuracy(self.within_factor),
            scale_factor=self.scale_factor,
        )

    return MetricCollection(metric_dict)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
class RegressionHead(Predictor):
|
|
138
196
|
"""Head for regression task."""
|
|
139
197
|
|
|
140
|
-
def __init__(
|
|
198
|
+
def __init__(
|
|
199
|
+
self, loss_mode: Literal["mse", "l1"] = "mse", use_sigmoid: bool = False
|
|
200
|
+
):
|
|
141
201
|
"""Initialize a new RegressionHead.
|
|
142
202
|
|
|
143
203
|
Args:
|
|
144
|
-
loss_mode: the loss function to use, either "mse" or "l1".
|
|
204
|
+
loss_mode: the loss function to use, either "mse" (default) or "l1".
|
|
145
205
|
use_sigmoid: whether to apply a sigmoid activation on the output. This
|
|
146
206
|
requires targets to be between 0-1.
|
|
147
207
|
"""
|
|
@@ -151,48 +211,59 @@ class RegressionHead(torch.nn.Module):
|
|
|
151
211
|
|
|
152
212
|
def forward(
    self,
    intermediates: Any,
    context: ModelContext,
    targets: list[dict[str, Any]] | None = None,
) -> ModelOutput:
    """Compute per-example regression outputs and, if targets are given, the loss.

    Args:
        intermediates: output of the preceding model component; it must be a
            FeatureVector whose ``feature_vector`` has size 1 along dimension 1
            (i.e. Bx1).
        context: the model context (not read by this head).
        targets: optional target dicts, one per example, each holding a "value"
            entry with the regression label and a "valid" entry flagging whether
            the example counts toward the loss.

    Returns:
        a ModelOutput whose ``outputs`` has one value per example (passed
        through a sigmoid when ``use_sigmoid`` is set), and whose ``loss_dict``
        holds a "regress" loss when targets are provided and is empty otherwise.

    Raises:
        ValueError: if the input is not a FeatureVector, its dimension 1 is not
            of size 1, or ``loss_mode`` is neither "mse" nor "l1".
    """
    if not isinstance(intermediates, FeatureVector):
        raise ValueError("the input to RegressionHead must be a FeatureVector")

    features = intermediates.feature_vector
    if features.shape[1] != 1:
        raise ValueError(
            "the input to RegressionHead must have channel dimension size 1, "
            f"but got shape {features.shape}"
        )

    # Drop the singleton channel so each example maps to exactly one value.
    predictions = features[:, 0]
    if self.use_sigmoid:
        predictions = torch.nn.functional.sigmoid(predictions)

    loss_dict: dict[str, torch.Tensor] = {}
    if targets:
        labels = torch.stack([t["value"] for t in targets])
        valid = torch.stack([t["valid"] for t in targets])
        if self.loss_mode == "mse":
            per_example = torch.square(predictions - labels)
        elif self.loss_mode == "l1":
            per_example = torch.abs(predictions - labels)
        else:
            raise ValueError(f"unknown loss mode {self.loss_mode}")
        # Invalid examples are zeroed out but still count toward the mean's
        # denominator.
        loss_dict["regress"] = torch.mean(per_example * valid)

    return ModelOutput(outputs=predictions, loss_dict=loss_dict)
|
|
190
261
|
|
|
191
262
|
|
|
192
263
|
class RegressionMetricWrapper(Metric):
|
|
193
264
|
"""Metric for regression task."""
|
|
194
265
|
|
|
195
|
-
def __init__(self, metric: Metric, scale_factor: float, **kwargs):
|
|
266
|
+
def __init__(self, metric: Metric, scale_factor: float, **kwargs: Any) -> None:
|
|
196
267
|
"""Initialize a new RegressionMetricWrapper.
|
|
197
268
|
|
|
198
269
|
Args:
|
|
@@ -206,14 +277,17 @@ class RegressionMetricWrapper(Metric):
|
|
|
206
277
|
self.metric = metric
|
|
207
278
|
self.scale_factor = scale_factor
|
|
208
279
|
|
|
209
|
-
def update(
|
|
280
|
+
def update(
|
|
281
|
+
self, preds: list[Any] | torch.Tensor, targets: list[dict[str, Any]]
|
|
282
|
+
) -> None:
|
|
210
283
|
"""Update metric.
|
|
211
284
|
|
|
212
285
|
Args:
|
|
213
286
|
preds: the predictions
|
|
214
287
|
targets: the targets
|
|
215
288
|
"""
|
|
216
|
-
preds
|
|
289
|
+
if not isinstance(preds, torch.Tensor):
|
|
290
|
+
preds = torch.stack(preds)
|
|
217
291
|
labels = torch.stack([target["value"] for target in targets])
|
|
218
292
|
|
|
219
293
|
# Sub-select the valid labels.
|
|
@@ -237,3 +311,46 @@ class RegressionMetricWrapper(Metric):
|
|
|
237
311
|
def plot(self, *args: list[Any], **kwargs: dict[str, Any]) -> Any:
|
|
238
312
|
"""Returns a plot of the metric."""
|
|
239
313
|
return self.metric.plot(*args, **kwargs)
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
class RegressionAccuracy(Metric):
    """Fraction of examples whose estimate is within some factor of ground truth.

    An estimate ``p`` for label ``y`` is counted as correct when it lies between
    ``y * (1 - factor)`` and ``y * (1 + factor)``, inclusive. The two endpoints
    are ordered with min/max, so negative labels are handled as well.
    """

    def __init__(self, factor: float, **kwargs: Any) -> None:
        """Initialize a new RegressionAccuracy.

        Args:
            factor: relative tolerance; if the estimate is within
                ``factor * |label|`` of the label it is marked correct (e.g. 0.2
                with label 5.0 accepts 4.0 through 6.0).
            kwargs: other arguments to pass to the torchmetrics Metric base class.
        """
        super().__init__(**kwargs)
        self.factor = factor
        # Registered as metric states so that Metric.reset() restores them, they
        # follow the module's device, and they are sum-reduced across processes
        # in distributed training (plain attributes get none of this).
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
        """Update metric.

        Args:
            preds: the predictions
            labels: the ground truth data (same shape as preds)
        """
        lower = labels * (1 - self.factor)
        upper = labels * (1 + self.factor)
        # For negative labels the bounds come out swapped; without reordering,
        # the interval would be empty and those examples could never be correct.
        lower, upper = torch.minimum(lower, upper), torch.maximum(lower, upper)
        decisions = (preds >= lower) & (preds <= upper)
        self.correct += torch.count_nonzero(decisions)
        # numel() also works for 0-d inputs, where len() would raise.
        self.total += decisions.numel()

    def compute(self) -> torch.Tensor:
        """Returns the computed metric.

        The result is NaN if no examples have been accumulated (total is 0).
        """
        return self.correct.float() / self.total

    def plot(self, *args: list[Any], **kwargs: dict[str, Any]) -> Any:
        """Returns a plot of the metric."""
        return None
|