rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/train/tasks/multi_task.py
@@ -3,9 +3,9 @@
 from typing import Any
 
 import numpy.typing as npt
-import torch
 from torchmetrics import Metric, MetricCollection
 
+from rslearn.train.model_context import RasterImage, SampleMetadata
 from rslearn.utils import Feature
 
 from .task import Task
@@ -29,8 +29,8 @@ class MultiTask(Task):
 
     def process_inputs(
         self,
-        raw_inputs: dict[str, torch.Tensor | list[Feature]],
-        metadata: dict[str, Any],
+        raw_inputs: dict[str, RasterImage | list[Feature]],
+        metadata: SampleMetadata,
         load_targets: bool = True,
     ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Processes the data into targets.
@@ -46,7 +46,14 @@ class MultiTask(Task):
         """
         input_dict = {}
         target_dict = {}
-        for task_name, task in self.tasks.items():
+        if metadata.dataset_source is None:
+            # No multi-dataset, so always compute across all tasks
+            task_iter = list(self.tasks.items())
+        else:
+            # Multi-dataset, so only compute for the task in this dataset
+            task_iter = [(metadata.dataset_source, self.tasks[metadata.dataset_source])]
+
+        for task_name, task in task_iter:
             cur_raw_inputs = {}
             for k, v in self.input_mapping[task_name].items():
                 if k not in raw_inputs:
@@ -62,12 +69,13 @@ class MultiTask(Task):
         return input_dict, target_dict
 
     def process_output(
-        self, raw_output: Any, metadata: dict[str, Any]
-    ) -> npt.NDArray[Any] | list[Feature]:
+        self, raw_output: Any, metadata: SampleMetadata
+    ) -> dict[str, Any]:
         """Processes an output into raster or vector data.
 
         Args:
-            raw_output: the output from prediction head.
+            raw_output: the output from prediction head. It must be a dict mapping from
+                task name to per-task output for this sample.
             metadata: metadata about the patch being read
 
         Returns:
@@ -75,9 +83,11 @@ class MultiTask(Task):
         """
         processed_output = {}
         for task_name, task in self.tasks.items():
-            processed_output[task_name] = task.process_output(
-                raw_output[task_name], metadata
-            )
+            if task_name in raw_output:
+                # In multi-dataset training, we may not have all datasets in the batch
+                processed_output[task_name] = task.process_output(
+                    raw_output[task_name], metadata
+                )
         return processed_output
 
     def visualize(
@@ -146,10 +156,14 @@ class MetricWrapper(Metric):
             preds: the predictions
             targets: the targets
         """
-        self.metric.update(
-            [pred[self.task_name] for pred in preds],
-            [target[self.task_name] for target in targets],
-        )
+        try:
+            self.metric.update(
+                [pred[self.task_name] for pred in preds],
+                [target[self.task_name] for target in targets],
+            )
+        except KeyError:
+            # In multi-dataset training, we may not have all datasets in the batch
+            pass
 
     def compute(self) -> Any:
         """Returns the computed metric."""
rslearn/train/tasks/per_pixel_regression.py (new file)
@@ -0,0 +1,291 @@
+"""Per-pixel regression task."""
+
+from typing import Any, Literal
+
+import numpy as np
+import numpy.typing as npt
+import torch
+import torchmetrics
+from torchmetrics import Metric, MetricCollection
+
+from rslearn.models.component import FeatureMaps, Predictor
+from rslearn.train.model_context import (
+    ModelContext,
+    ModelOutput,
+    RasterImage,
+    SampleMetadata,
+)
+from rslearn.utils.feature import Feature
+
+from .task import BasicTask
+
+
+class PerPixelRegressionTask(BasicTask):
+    """A per-pixel regression task."""
+
+    def __init__(
+        self,
+        scale_factor: float = 1,
+        metric_mode: Literal["mse", "l1"] = "mse",
+        nodata_value: float | None = None,
+        **kwargs: Any,
+    ) -> None:
+        """Initialize a new PerPixelRegressionTask.
+
+        Args:
+            scale_factor: multiply ground truth values by this factor before using it for
+                training.
+            metric_mode: what metric to use, either "mse" (default) or "l1"
+            nodata_value: optional value to treat as invalid. The loss will be masked
+                at pixels where the ground truth value is equal to nodata_value.
+            kwargs: other arguments to pass to BasicTask
+        """
+        super().__init__(**kwargs)
+        self.scale_factor = scale_factor
+        self.metric_mode = metric_mode
+        self.nodata_value = nodata_value
+
+    def process_inputs(
+        self,
+        raw_inputs: dict[str, RasterImage | list[Feature]],
+        metadata: SampleMetadata,
+        load_targets: bool = True,
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Processes the data into targets.
+
+        Args:
+            raw_inputs: raster or vector data to process
+            metadata: metadata about the patch being read
+            load_targets: whether to load the targets or only inputs
+
+        Returns:
+            tuple (input_dict, target_dict) containing the processed inputs and targets
+            that are compatible with both metrics and loss functions
+        """
+        if not load_targets:
+            return {}, {}
+
+        assert isinstance(raw_inputs["targets"], RasterImage)
+        assert raw_inputs["targets"].image.shape[0] == 1
+        assert raw_inputs["targets"].image.shape[1] == 1
+        labels = raw_inputs["targets"].image[0, 0, :, :].float() * self.scale_factor
+
+        if self.nodata_value is not None:
+            valid = (
+                raw_inputs["targets"].image[0, 0, :, :] != self.nodata_value
+            ).float()
+        else:
+            valid = torch.ones(labels.shape, dtype=torch.float32)
+
+        return {}, {
+            "values": labels,
+            "valid": valid,
+        }
+
+    def process_output(
+        self, raw_output: Any, metadata: SampleMetadata
+    ) -> npt.NDArray[Any] | list[Feature]:
+        """Processes an output into raster or vector data.
+
+        Args:
+            raw_output: the output from prediction head, which must be an HW tensor.
+            metadata: metadata about the patch being read
+
+        Returns:
+            either raster or vector data.
+        """
+        if not isinstance(raw_output, torch.Tensor):
+            raise ValueError("output for PerPixelRegressionTask must be a tensor")
+        if len(raw_output.shape) != 2:
+            raise ValueError(
+                f"PerPixelRegressionTask output must be an HW tensor, but got shape {raw_output.shape}"
+            )
+        return (raw_output / self.scale_factor).cpu().numpy()
+
+    def visualize(
+        self,
+        input_dict: dict[str, Any],
+        target_dict: dict[str, Any] | None,
+        output: Any,
+    ) -> dict[str, npt.NDArray[Any]]:
+        """Visualize the outputs and targets.
+
+        Args:
+            input_dict: the input dict from process_inputs
+            target_dict: the target dict from process_inputs
+            output: the prediction
+
+        Returns:
+            a dictionary mapping image name to visualization image
+        """
+        image = super().visualize(input_dict, target_dict, output)["image"]
+        if target_dict is None:
+            raise ValueError("target_dict is required for visualization")
+        gt_values = target_dict["classes"].cpu().numpy()
+        pred_values = output.cpu().numpy()[0, :, :]
+        gt_vis = np.clip(gt_values * 255, 0, 255).astype(np.uint8)
+        pred_vis = np.clip(pred_values * 255, 0, 255).astype(np.uint8)
+        return {
+            "image": np.array(image),
+            "gt": gt_vis,
+            "pred": pred_vis,
+        }
+
+    def get_metrics(self) -> MetricCollection:
+        """Get the metrics for this task."""
+        metric_dict: dict[str, Metric] = {}
+
+        if self.metric_mode == "mse":
+            metric_dict["mse"] = PerPixelRegressionMetricWrapper(
+                metric=torchmetrics.MeanSquaredError(), scale_factor=self.scale_factor
+            )
+        elif self.metric_mode == "l1":
+            metric_dict["l1"] = PerPixelRegressionMetricWrapper(
+                metric=torchmetrics.MeanAbsoluteError(), scale_factor=self.scale_factor
+            )
+
+        return MetricCollection(metric_dict)
+
+
+class PerPixelRegressionHead(Predictor):
+    """Head for per-pixel regression task."""
+
+    def __init__(
+        self, loss_mode: Literal["mse", "l1"] = "mse", use_sigmoid: bool = False
+    ):
+        """Initialize a new RegressionHead.
+
+        Args:
+            loss_mode: the loss function to use, either "mse" (default) or "l1".
+            use_sigmoid: whether to apply a sigmoid activation on the output. This
+                requires targets to be between 0-1.
+        """
+        super().__init__()
+
+        if loss_mode not in ["mse", "l1"]:
+            raise ValueError("invalid loss mode")
+
+        self.loss_mode = loss_mode
+        self.use_sigmoid = use_sigmoid
+
+    def forward(
+        self,
+        intermediates: Any,
+        context: ModelContext,
+        targets: list[dict[str, Any]] | None = None,
+    ) -> ModelOutput:
+        """Compute the regression outputs and loss from logits and targets.
+
+        Args:
+            intermediates: output from previous component, which must be a FeatureMaps
+                with one feature map corresponding to the logits. The channel dimension
+                size must be 1.
+            context: the model context.
+            targets: must contain values key that stores the regression labels, and
+                valid key containing mask image indicating where the labels are valid.
+
+        Returns:
+            tuple of outputs and loss dict. The output is a BHW tensor so that the
+            per-sample output is an HW tensor.
+        """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError(
+                "the input to PerPixelRegressionHead must be a FeatureMaps"
+            )
+        if len(intermediates.feature_maps) != 1:
+            raise ValueError(
+                "the input to PerPixelRegressionHead must have one feature map"
+            )
+        if intermediates.feature_maps[0].shape[1] != 1:
+            raise ValueError(
+                f"the input to PerPixelRegressionHead must have channel dimension size 1, but got {intermediates.feature_maps[0].shape}"
+            )
+
+        logits = intermediates.feature_maps[0][:, 0, :, :]
+
+        if self.use_sigmoid:
+            outputs = torch.nn.functional.sigmoid(logits)
+        else:
+            outputs = logits
+
+        losses = {}
+        if targets:
+            labels = torch.stack([target["values"] for target in targets])
+            mask = torch.stack([target["valid"] for target in targets])
+
+            if self.loss_mode == "mse":
+                scores = torch.square(outputs - labels)
+            elif self.loss_mode == "l1":
+                scores = torch.abs(outputs - labels)
+            else:
+                assert False
+
+            # Compute average but only over valid pixels.
+            mask_total = mask.sum()
+            if mask_total == 0:
+                # Just average over all pixels but it will be zero.
+                losses["regress"] = (scores * mask).mean()
+            else:
+                losses["regress"] = (scores * mask).sum() / mask_total
+
+        return ModelOutput(
+            outputs=outputs,
+            loss_dict=losses,
+        )
+
+
+class PerPixelRegressionMetricWrapper(Metric):
+    """Metric for per-pixel regression task."""
+
+    def __init__(self, metric: Metric, scale_factor: float, **kwargs: Any) -> None:
+        """Initialize a new PerPixelRegressionMetricWrapper.
+
+        Args:
+            metric: the underlying torchmetric to apply, which should accept a flat
+                tensor of predicted values followed by a flat tensor of target values
+            scale_factor: scale factor to undo so that metric is based on original
+                values
+            kwargs: other arguments to pass to super constructor
+        """
+        super().__init__(**kwargs)
+        self.metric = metric
+        self.scale_factor = scale_factor
+
+    def update(
+        self, preds: list[Any] | torch.Tensor, targets: list[dict[str, Any]]
+    ) -> None:
+        """Update metric.
+
+        Args:
+            preds: the predictions
+            targets: the targets
+        """
+        if not isinstance(preds, torch.Tensor):
+            preds = torch.stack(preds)
+        labels = torch.stack([target["values"] for target in targets])
+
+        # Sub-select the valid labels.
+        # We flatten the prediction and label images at valid pixels.
+        if len(preds.shape) == 4:
+            assert preds.shape[1] == 1
+            preds = preds[:, 0, :, :]
+        mask = torch.stack([target["valid"] > 0 for target in targets])
+        preds = preds[mask]
+        labels = labels[mask]
+        if len(preds) == 0:
+            return
+
+        self.metric.update(preds, labels)
+
+    def compute(self) -> Any:
+        """Returns the computed metric."""
+        return self.metric.compute()
+
+    def reset(self) -> None:
+        """Reset metric."""
+        super().reset()
+        self.metric.reset()
+
+    def plot(self, *args: list[Any], **kwargs: dict[str, Any]) -> Any:
+        """Returns a plot of the metric."""
+        return self.metric.plot(*args, **kwargs)
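
To clarify how PerPixelRegressionHead combines the "values" and "valid" targets, here is a minimal standalone sketch of the masked loss (illustration only; the function name masked_regression_loss is not part of rslearn):

import torch

def masked_regression_loss(
    outputs: torch.Tensor,  # BHW predictions
    labels: torch.Tensor,   # BHW ground-truth values (already multiplied by scale_factor)
    mask: torch.Tensor,     # BHW validity mask, 1.0 where the label is usable
    loss_mode: str = "mse",
) -> torch.Tensor:
    # Per-pixel error, averaged over valid pixels only.
    scores = torch.square(outputs - labels) if loss_mode == "mse" else torch.abs(outputs - labels)
    mask_total = mask.sum()
    if mask_total == 0:
        # No valid pixels: the masked mean is zero but keeps the computation graph intact.
        return (scores * mask).mean()
    return (scores * mask).sum() / mask_total

# Example: a 1x4x4 batch where half the pixels are treated as nodata.
outputs = torch.rand(1, 4, 4)
labels = torch.rand(1, 4, 4)
mask = (torch.arange(16).reshape(1, 4, 4) % 2).float()
print(masked_regression_loss(outputs, labels, mask))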