valor-lite 0.37.1 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (49)
  1. valor_lite/LICENSE +21 -0
  2. valor_lite/__init__.py +0 -0
  3. valor_lite/cache/__init__.py +11 -0
  4. valor_lite/cache/compute.py +154 -0
  5. valor_lite/cache/ephemeral.py +302 -0
  6. valor_lite/cache/persistent.py +529 -0
  7. valor_lite/classification/__init__.py +14 -0
  8. valor_lite/classification/annotation.py +45 -0
  9. valor_lite/classification/computation.py +378 -0
  10. valor_lite/classification/evaluator.py +879 -0
  11. valor_lite/classification/loader.py +97 -0
  12. valor_lite/classification/metric.py +535 -0
  13. valor_lite/classification/numpy_compatibility.py +13 -0
  14. valor_lite/classification/shared.py +184 -0
  15. valor_lite/classification/utilities.py +314 -0
  16. valor_lite/exceptions.py +20 -0
  17. valor_lite/object_detection/__init__.py +17 -0
  18. valor_lite/object_detection/annotation.py +238 -0
  19. valor_lite/object_detection/computation.py +841 -0
  20. valor_lite/object_detection/evaluator.py +805 -0
  21. valor_lite/object_detection/loader.py +292 -0
  22. valor_lite/object_detection/metric.py +850 -0
  23. valor_lite/object_detection/shared.py +185 -0
  24. valor_lite/object_detection/utilities.py +396 -0
  25. valor_lite/schemas.py +11 -0
  26. valor_lite/semantic_segmentation/__init__.py +15 -0
  27. valor_lite/semantic_segmentation/annotation.py +123 -0
  28. valor_lite/semantic_segmentation/computation.py +165 -0
  29. valor_lite/semantic_segmentation/evaluator.py +414 -0
  30. valor_lite/semantic_segmentation/loader.py +205 -0
  31. valor_lite/semantic_segmentation/metric.py +275 -0
  32. valor_lite/semantic_segmentation/shared.py +149 -0
  33. valor_lite/semantic_segmentation/utilities.py +88 -0
  34. valor_lite/text_generation/__init__.py +15 -0
  35. valor_lite/text_generation/annotation.py +56 -0
  36. valor_lite/text_generation/computation.py +611 -0
  37. valor_lite/text_generation/llm/__init__.py +0 -0
  38. valor_lite/text_generation/llm/exceptions.py +14 -0
  39. valor_lite/text_generation/llm/generation.py +903 -0
  40. valor_lite/text_generation/llm/instructions.py +814 -0
  41. valor_lite/text_generation/llm/integrations.py +226 -0
  42. valor_lite/text_generation/llm/utilities.py +43 -0
  43. valor_lite/text_generation/llm/validators.py +68 -0
  44. valor_lite/text_generation/manager.py +697 -0
  45. valor_lite/text_generation/metric.py +381 -0
  46. valor_lite-0.37.1.dist-info/METADATA +174 -0
  47. valor_lite-0.37.1.dist-info/RECORD +49 -0
  48. valor_lite-0.37.1.dist-info/WHEEL +5 -0
  49. valor_lite-0.37.1.dist-info/top_level.txt +1 -0
valor_lite/semantic_segmentation/loader.py
@@ -0,0 +1,205 @@
+ import numpy as np
+ import pyarrow as pa
+ from tqdm import tqdm
+
+ from valor_lite.cache import FileCacheWriter, MemoryCacheWriter
+ from valor_lite.semantic_segmentation.annotation import Segmentation
+ from valor_lite.semantic_segmentation.computation import compute_intermediates
+ from valor_lite.semantic_segmentation.evaluator import Builder
+
+
+ class Loader(Builder):
+     def __init__(
+         self,
+         writer: MemoryCacheWriter | FileCacheWriter,
+         metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
+     ):
+         super().__init__(
+             writer=writer,
+             metadata_fields=metadata_fields,
+         )
+
+         # internal state
+         self._labels: dict[str, int] = {}
+         self._index_to_label: dict[int, str] = {}
+         self._datum_count = 0
+
+     def _add_label(self, value: str) -> int:
+         idx = self._labels.get(value, None)
+         if idx is None:
+             idx = len(self._labels)
+             self._labels[value] = idx
+             self._index_to_label[idx] = value
+         return idx
+
+     def add_data(
+         self,
+         segmentations: list[Segmentation],
+         show_progress: bool = False,
+     ):
+         """
+         Adds segmentations to the cache.
+
+         Parameters
+         ----------
+         segmentations : list[Segmentation]
+             A list of Segmentation objects.
+         show_progress : bool, default=False
+             Toggle for tqdm progress bar.
+         """
+
+         disable_tqdm = not show_progress
+         for segmentation in tqdm(segmentations, disable=disable_tqdm):
+
+             groundtruth_labels = -1 * np.ones(
+                 len(segmentation.groundtruths), dtype=np.int64
+             )
+             for idx, groundtruth in enumerate(segmentation.groundtruths):
+                 label_idx = self._add_label(groundtruth.label)
+                 groundtruth_labels[idx] = label_idx
+
+             prediction_labels = -1 * np.ones(
+                 len(segmentation.predictions), dtype=np.int64
+             )
+             for idx, prediction in enumerate(segmentation.predictions):
+                 label_idx = self._add_label(prediction.label)
+                 prediction_labels[idx] = label_idx
+
+             if segmentation.groundtruths:
+                 combined_groundtruths = np.stack(
+                     [
+                         groundtruth.mask.flatten()
+                         for groundtruth in segmentation.groundtruths
+                     ],
+                     axis=0,
+                 )
+             else:
+                 combined_groundtruths = np.zeros(
+                     (1, segmentation.shape[0] * segmentation.shape[1]),
+                     dtype=np.bool_,
+                 )
+
+             if segmentation.predictions:
+                 combined_predictions = np.stack(
+                     [
+                         prediction.mask.flatten()
+                         for prediction in segmentation.predictions
+                     ],
+                     axis=0,
+                 )
+             else:
+                 combined_predictions = np.zeros(
+                     (1, segmentation.shape[0] * segmentation.shape[1]),
+                     dtype=np.bool_,
+                 )
+
+             n_labels = len(self._labels)
+             counts = compute_intermediates(
+                 groundtruths=combined_groundtruths,
+                 predictions=combined_predictions,
+                 groundtruth_labels=groundtruth_labels,
+                 prediction_labels=prediction_labels,
+                 n_labels=n_labels,
+             )
+
+             # prepare metadata
+             datum_metadata = (
+                 segmentation.metadata if segmentation.metadata else {}
+             )
+             gt_metadata = {
+                 self._labels[gt.label]: gt.metadata
+                 for gt in segmentation.groundtruths
+                 if gt.metadata
+             }
+             pd_metadata = {
+                 self._labels[pd.label]: pd.metadata
+                 for pd in segmentation.predictions
+                 if pd.metadata
+             }
+
+             # cache formatting
+             rows = []
+             for idx in range(n_labels):
+                 label = self._index_to_label[idx]
+                 for pidx in range(n_labels):
+                     # write non-zero intersections to cache
+                     if counts[idx + 1, pidx + 1] > 0:
+                         plabel = self._index_to_label[pidx]
+                         rows.append(
+                             {
+                                 # metadata
+                                 **datum_metadata,
+                                 **gt_metadata.get(idx, {}),
+                                 **pd_metadata.get(pidx, {}),
+                                 # datum
+                                 "datum_uid": segmentation.uid,
+                                 "datum_id": self._datum_count,
+                                 # groundtruth
+                                 "gt_label": label,
+                                 "gt_label_id": idx,
+                                 # prediction
+                                 "pd_label": plabel,
+                                 "pd_label_id": pidx,
+                                 # pair
+                                 "count": counts[idx + 1, pidx + 1],
+                             }
+                         )
+                 # write all unmatched to preserve labels
+                 rows.extend(
+                     [
+                         {
+                             # metadata
+                             **datum_metadata,
+                             **gt_metadata.get(idx, {}),
+                             # datum
+                             "datum_uid": segmentation.uid,
+                             "datum_id": self._datum_count,
+                             # groundtruth
+                             "gt_label": label,
+                             "gt_label_id": idx,
+                             # prediction
+                             "pd_label": None,
+                             "pd_label_id": -1,
+                             # pair
+                             "count": counts[idx + 1, 0],
+                         },
+                         {
+                             # metadata
+                             **datum_metadata,
+                             **gt_metadata.get(idx, {}),
+                             **pd_metadata.get(idx, {}),
+                             # datum
+                             "datum_uid": segmentation.uid,
+                             "datum_id": self._datum_count,
+                             # groundtruth
+                             "gt_label": None,
+                             "gt_label_id": -1,
+                             # prediction
+                             "pd_label": label,
+                             "pd_label_id": idx,
+                             # pair
+                             "count": counts[0, idx + 1],
+                         },
+                     ]
+                 )
+             rows.append(
+                 {
+                     # metadata
+                     **datum_metadata,
+                     # datum
+                     "datum_uid": segmentation.uid,
+                     "datum_id": self._datum_count,
+                     # groundtruth
+                     "gt_label": None,
+                     "gt_label_id": -1,
+                     # prediction
+                     "pd_label": None,
+                     "pd_label_id": -1,
+                     # pair
+                     "count": counts[0, 0],
+                 }
+             )
+             self._writer.write_rows(rows)
+
+             # update datum count
+             self._datum_count += 1
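
For readers skimming the loader above: its indexing into `counts` implies an (n_labels + 1) x (n_labels + 1) pixel-count matrix where row/column 0 holds unmatched (background) pixels and row/column i + 1 corresponds to label index i. The sketch below is a minimal, self-contained illustration of that layout and of the row-flattening loop; it uses plain NumPy with made-up labels and counts rather than the package's own compute_intermediates output.

# A minimal sketch (plain NumPy, no valor-lite imports) of the count-matrix layout
# that the loader's indexing suggests. Labels and numbers are illustrative only.
import numpy as np

index_to_label = {0: "sky", 1: "road"}
n_labels = len(index_to_label)

# counts[g + 1, p + 1] = pixels labeled g in the ground truth and p in the prediction
# counts[g + 1, 0]     = ground-truth pixels of label g with no matching prediction
# counts[0, p + 1]     = predicted pixels of label p with no matching ground truth
# counts[0, 0]         = pixels unlabeled in both
counts = np.array(
    [
        [10, 2, 1],
        [3, 40, 5],
        [0, 4, 35],
    ],
    dtype=np.uint64,
)

for g in range(n_labels):
    for p in range(n_labels):
        if counts[g + 1, p + 1] > 0:
            print(index_to_label[g], "->", index_to_label[p], int(counts[g + 1, p + 1]))
    print(index_to_label[g], "-> (unmatched)", int(counts[g + 1, 0]))
    print("(unmatched) ->", index_to_label[g], int(counts[0, g + 1]))
print("background:", int(counts[0, 0]))
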
valor_lite/semantic_segmentation/metric.py
@@ -0,0 +1,275 @@
+ from dataclasses import dataclass
+ from enum import Enum
+
+ from valor_lite.schemas import BaseMetric
+
+
+ class MetricType(Enum):
+     Precision = "Precision"
+     Recall = "Recall"
+     Accuracy = "Accuracy"
+     F1 = "F1"
+     IOU = "IOU"
+     mIOU = "mIOU"
+     ConfusionMatrix = "ConfusionMatrix"
+
+
+ @dataclass
+ class Metric(BaseMetric):
+     """
+     Semantic Segmentation Metric.
+
+     Attributes
+     ----------
+     type : str
+         The metric type.
+     value : int | float | dict
+         The metric value.
+     parameters : dict[str, Any]
+         A dictionary containing metric parameters.
+     """
+
+     def __post_init__(self):
+         if not isinstance(self.type, str):
+             raise TypeError(
+                 f"Metric type should be of type 'str': {self.type}"
+             )
+         elif not isinstance(self.value, (int, float, dict)):
+             raise TypeError(
+                 f"Metric value must be of type 'int', 'float' or 'dict': {self.value}"
+             )
+         elif not isinstance(self.parameters, dict):
+             raise TypeError(
+                 f"Metric parameters must be of type 'dict[str, Any]': {self.parameters}"
+             )
+         elif not all([isinstance(k, str) for k in self.parameters.keys()]):
+             raise TypeError(
+                 f"Metric parameter dictionary should only have keys with type 'str': {self.parameters}"
+             )
+
+     @classmethod
+     def precision(
+         cls,
+         value: float,
+         label: str,
+     ):
+         """
+         Precision metric for a specific class label.
+
+         Precision is calculated as the number of true-positive pixels divided by
+         the sum of all true-positive and false-positive pixels.
+
+         Parameters
+         ----------
+         value : float
+             The computed precision value.
+         label : str
+             The label for which the precision is calculated.
+
+         Returns
+         -------
+         Metric
+         """
+         return cls(
+             type=MetricType.Precision.value,
+             value=value,
+             parameters={
+                 "label": label,
+             },
+         )
+
+     @classmethod
+     def recall(
+         cls,
+         value: float,
+         label: str,
+     ):
+         """
+         Recall metric for a specific class label.
+
+         Recall is calculated as the number of true-positive pixels divided by
+         the sum of all true-positive and false-negative pixels.
+
+         Parameters
+         ----------
+         value : float
+             The computed recall value.
+         label : str
+             The label for which the recall is calculated.
+
+         Returns
+         -------
+         Metric
+         """
+         return cls(
+             type=MetricType.Recall.value,
+             value=value,
+             parameters={
+                 "label": label,
+             },
+         )
+
+     @classmethod
+     def f1_score(
+         cls,
+         value: float,
+         label: str,
+     ):
+         """
+         F1 score for a specific class label.
+
+         Parameters
+         ----------
+         value : float
+             The computed F1 score.
+         label : str
+             The label for which the F1 score is calculated.
+
+         Returns
+         -------
+         Metric
+         """
+         return cls(
+             type=MetricType.F1.value,
+             value=value,
+             parameters={
+                 "label": label,
+             },
+         )
+
+     @classmethod
+     def iou(
+         cls,
+         value: float,
+         label: str,
+     ):
+         """
+         Intersection over Union (IOU) ratio for a specific class label.
+
+         Parameters
+         ----------
+         value : float
+             The computed IOU ratio.
+         label : str
+             The label for which the IOU is calculated.
+
+         Returns
+         -------
+         Metric
+         """
+         return cls(
+             type=MetricType.IOU.value,
+             value=value,
+             parameters={
+                 "label": label,
+             },
+         )
+
+     @classmethod
+     def mean_iou(cls, value: float):
+         """
+         Mean Intersection over Union (mIOU) ratio.
+
+         The mIOU value is computed by averaging IOU over all labels.
+
+         Parameters
+         ----------
+         value : float
+             The mIOU value.
+
+         Returns
+         -------
+         Metric
+         """
+         return cls(type=MetricType.mIOU.value, value=value, parameters={})
+
+     @classmethod
+     def accuracy(cls, value: float):
+         """
+         Accuracy metric computed over all labels.
+
+         Parameters
+         ----------
+         value : float
+             The accuracy value.
+
+         Returns
+         -------
+         Metric
+         """
+         return cls(type=MetricType.Accuracy.value, value=value, parameters={})
+
+     @classmethod
+     def confusion_matrix(
+         cls,
+         confusion_matrix: dict[
+             str,  # ground truth label value
+             dict[
+                 str,  # prediction label value
+                 dict[str, float],  # iou
+             ],
+         ],
+         unmatched_predictions: dict[
+             str,  # prediction label value
+             dict[str, float],  # pixel ratio
+         ],
+         unmatched_ground_truths: dict[
+             str,  # ground truth label value
+             dict[str, float],  # pixel ratio
+         ],
+     ):
+         """
+         The confusion matrix and related metrics for semantic segmentation tasks.
+
+         This metric encapsulates detailed information about the model's performance, including
+         correct predictions, misclassifications, unmatched predictions (a subset of false positives),
+         and unmatched ground truths (a subset of false negatives). It provides counts for each
+         category to facilitate in-depth analysis.
+
+         Confusion Matrix Format:
+         {
+             <ground truth label>: {
+                 <prediction label>: {
+                     'iou': <float>,
+                 },
+             },
+         }
+
+         Unmatched Predictions Format:
+         {
+             <prediction label>: {
+                 'ratio': <float>,
+             },
+         }
+
+         Unmatched Ground Truths Format:
+         {
+             <ground truth label>: {
+                 'ratio': <float>,
+             },
+         }
+
+         Parameters
+         ----------
+         confusion_matrix : dict
+             Nested dictionaries representing the Intersection over Union (IOU) scores for each
+             ground truth label and prediction label pair.
+         unmatched_predictions : dict
+             Dictionary representing the pixel ratios for predicted labels that do not correspond
+             to any ground truth labels (false positives).
+         unmatched_ground_truths : dict
+             Dictionary representing the pixel ratios for ground truth labels that were not predicted
+             (false negatives).
+
+         Returns
+         -------
+         Metric
+         """
+         return cls(
+             type=MetricType.ConfusionMatrix.value,
+             value={
+                 "confusion_matrix": confusion_matrix,
+                 "unmatched_predictions": unmatched_predictions,
+                 "unmatched_ground_truths": unmatched_ground_truths,
+             },
+             parameters={},
+         )
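
A short usage sketch of the factory classmethods defined in this file. It assumes valor-lite 0.37.1 is installed; the label names and values are illustrative only.

from valor_lite.semantic_segmentation.metric import Metric

# Build metrics through the factory classmethods shown above.
p = Metric.precision(value=0.91, label="road")
m = Metric.mean_iou(value=0.74)

print(p.type, p.value, p.parameters)  # Precision 0.91 {'label': 'road'}
print(m.type, m.value, m.parameters)  # mIOU 0.74 {}

# __post_init__ rejects malformed fields, e.g. non-string parameter keys.
try:
    Metric(type="Precision", value=0.5, parameters={1: "road"})
except TypeError as e:
    print(e)
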
valor_lite/semantic_segmentation/shared.py
@@ -0,0 +1,149 @@
+ from dataclasses import dataclass
+ from pathlib import Path
+
+ import numpy as np
+ import pyarrow as pa
+ import pyarrow.compute as pc
+
+ from valor_lite.cache import FileCacheReader, MemoryCacheReader
+
+
+ @dataclass
+ class EvaluatorInfo:
+     number_of_rows: int = 0
+     number_of_datums: int = 0
+     number_of_labels: int = 0
+     number_of_pixels: int = 0
+     number_of_groundtruth_pixels: int = 0
+     number_of_prediction_pixels: int = 0
+     metadata_fields: list[tuple[str, str | pa.DataType]] | None = None
+
+
+ def generate_cache_path(path: str | Path) -> Path:
+     """Generate cache path from parent directory."""
+     return Path(path) / "counts"
+
+
+ def generate_metadata_path(path: str | Path) -> Path:
+     """Generate metadata path from parent directory."""
+     return Path(path) / "metadata.json"
+
+
+ def generate_schema(
+     metadata_fields: list[tuple[str, str | pa.DataType]] | None
+ ) -> pa.Schema:
+     """Generate PyArrow schema from metadata fields."""
+
+     metadata_fields = metadata_fields if metadata_fields else []
+     reserved_fields = [
+         ("datum_uid", pa.string()),
+         ("datum_id", pa.int64()),
+         # groundtruth
+         ("gt_label", pa.string()),
+         ("gt_label_id", pa.int64()),
+         # prediction
+         ("pd_label", pa.string()),
+         ("pd_label_id", pa.int64()),
+         # pair
+         ("count", pa.uint64()),
+     ]
+
+     # validate
+     reserved_field_names = {f[0] for f in reserved_fields}
+     metadata_field_names = {f[0] for f in metadata_fields}
+     if conflicting := reserved_field_names & metadata_field_names:
+         raise ValueError(
+             f"metadata fields {conflicting} conflict with reserved fields"
+         )
+
+     return pa.schema(
+         [
+             *reserved_fields,
+             *metadata_fields,
+         ]
+     )
+
+
+ def encode_metadata_fields(
+     metadata_fields: list[tuple[str, str | pa.DataType]] | None
+ ) -> dict[str, str]:
+     """Encode metadata fields into JSON format."""
+     metadata_fields = metadata_fields if metadata_fields else []
+     return {k: str(v) for k, v in metadata_fields}
+
+
+ def decode_metadata_fields(
+     encoded_metadata_fields: dict[str, str]
+ ) -> list[tuple[str, str | pa.DataType]]:
+     """Decode metadata fields from JSON format."""
+     return [(k, v) for k, v in encoded_metadata_fields.items()]
+
+
+ def extract_labels(
+     reader: MemoryCacheReader | FileCacheReader,
+     index_to_label_override: dict[int, str] | None = None,
+ ) -> dict[int, str]:
+     if index_to_label_override is not None:
+         return index_to_label_override
+
+     index_to_label = {}
+     for tbl in reader.iterate_tables(
+         columns=[
+             "gt_label_id",
+             "gt_label",
+             "pd_label_id",
+             "pd_label",
+         ]
+     ):
+
+         # get gt labels
+         gt_label_ids = tbl["gt_label_id"].to_numpy()
+         gt_label_ids, gt_indices = np.unique(gt_label_ids, return_index=True)
+         gt_labels = tbl["gt_label"].take(gt_indices).to_pylist()
+         gt_labels = dict(zip(gt_label_ids.astype(int).tolist(), gt_labels))
+         gt_labels.pop(-1, None)
+         index_to_label.update(gt_labels)
+
+         # get pd labels
+         pd_label_ids = tbl["pd_label_id"].to_numpy()
+         pd_label_ids, pd_indices = np.unique(pd_label_ids, return_index=True)
+         pd_labels = tbl["pd_label"].take(pd_indices).to_pylist()
+         pd_labels = dict(zip(pd_label_ids.astype(int).tolist(), pd_labels))
+         pd_labels.pop(-1, None)
+         index_to_label.update(pd_labels)
+
+     return index_to_label
+
+
+ def extract_counts(
+     reader: MemoryCacheReader | FileCacheReader,
+     datums: pc.Expression | None = None,
+     groundtruths: pc.Expression | None = None,
+     predictions: pc.Expression | None = None,
+ ):
+     n_dts, n_total, n_gts, n_pds = 0, 0, 0, 0
+     for tbl in reader.iterate_tables(filter=datums):
+
+         # count datums
+         n_dts += int(np.unique(tbl["datum_id"].to_numpy()).shape[0])
+
+         # count pixels
+         n_total += int(tbl["count"].to_numpy().sum())
+
+         # count groundtruth pixels
+         gt_tbl = tbl
+         gt_expr = pc.field("gt_label_id") >= 0
+         if groundtruths is not None:
+             gt_expr &= groundtruths
+         gt_tbl = tbl.filter(gt_expr)
+         n_gts += int(gt_tbl["count"].to_numpy().sum())
+
+         # count prediction pixels
+         pd_tbl = tbl
+         pd_expr = pc.field("pd_label_id") >= 0
+         if predictions is not None:
+             pd_expr &= predictions
+         pd_tbl = tbl.filter(pd_expr)
+         n_pds += int(pd_tbl["count"].to_numpy().sum())
+
+     return n_dts, n_total, n_gts, n_pds
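
A brief sketch of the schema helpers above. It assumes valor-lite 0.37.1 and pyarrow are installed; the "weather" metadata field is an invented example.

import pyarrow as pa

from valor_lite.semantic_segmentation.shared import (
    decode_metadata_fields,
    encode_metadata_fields,
    generate_schema,
)

metadata_fields = [("weather", pa.string())]

# Reserved columns come first, user metadata columns are appended.
schema = generate_schema(metadata_fields)
print(schema.names)
# ['datum_uid', 'datum_id', 'gt_label', 'gt_label_id', 'pd_label', 'pd_label_id', 'count', 'weather']

# Round-trip the metadata field list through its JSON-friendly encoding.
encoded = encode_metadata_fields(metadata_fields)
print(encoded)                          # {'weather': 'string'}
print(decode_metadata_fields(encoded))  # [('weather', 'string')]

# Reusing a reserved column name raises a ValueError.
try:
    generate_schema([("count", pa.int64())])
except ValueError as e:
    print(e)
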
valor_lite/semantic_segmentation/utilities.py
@@ -0,0 +1,88 @@
+ from collections import defaultdict
+
+ from valor_lite.semantic_segmentation.metric import Metric, MetricType
+
+
+ def unpack_precision_recall_iou_into_metric_lists(
+     results: tuple,
+     index_to_label: dict[int, str],
+ ) -> dict[MetricType, list[Metric]]:
+
+     n_labels = len(index_to_label)
+     (
+         precision,
+         recall,
+         f1_score,
+         accuracy,
+         ious,
+         unmatched_prediction_ratios,
+         unmatched_ground_truth_ratios,
+     ) = results
+
+     metrics = defaultdict(list)
+
+     metrics[MetricType.Accuracy] = [
+         Metric.accuracy(
+             value=float(accuracy),
+         )
+     ]
+
+     metrics[MetricType.ConfusionMatrix] = [
+         Metric.confusion_matrix(
+             confusion_matrix={
+                 index_to_label[gt_label_idx]: {
+                     index_to_label[pd_label_idx]: {
+                         "iou": float(ious[gt_label_idx, pd_label_idx])
+                     }
+                     for pd_label_idx in range(n_labels)
+                 }
+                 for gt_label_idx in range(n_labels)
+             },
+             unmatched_predictions={
+                 index_to_label[pd_label_idx]: {
+                     "ratio": float(unmatched_prediction_ratios[pd_label_idx])
+                 }
+                 for pd_label_idx in range(n_labels)
+             },
+             unmatched_ground_truths={
+                 index_to_label[gt_label_idx]: {
+                     "ratio": float(unmatched_ground_truth_ratios[gt_label_idx])
+                 }
+                 for gt_label_idx in range(n_labels)
+             },
+         )
+     ]
+
+     metrics[MetricType.mIOU] = [
+         Metric.mean_iou(
+             value=float(ious.diagonal().mean()),
+         )
+     ]
+
+     for label_idx, label in index_to_label.items():
+         metrics[MetricType.Precision].append(
+             Metric.precision(
+                 value=float(precision[label_idx]),
+                 label=label,
+             )
+         )
+         metrics[MetricType.Recall].append(
+             Metric.recall(
+                 value=float(recall[label_idx]),
+                 label=label,
+             )
+         )
+         metrics[MetricType.F1].append(
+             Metric.f1_score(
+                 value=float(f1_score[label_idx]),
+                 label=label,
+             )
+         )
+         metrics[MetricType.IOU].append(
+             Metric.iou(
+                 value=float(ious[label_idx, label_idx]),
+                 label=label,
+             )
+         )
+
+     return metrics
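
A toy walkthrough of the unpacking helper above, assuming valor-lite 0.37.1 is installed. The shapes of the entries in `results` (per-label 1-D vectors, an n_labels x n_labels IOU matrix, and a scalar accuracy) are inferred from how the function indexes them and are not documented in this hunk, so treat them as assumptions; the numbers are invented.

import numpy as np

from valor_lite.semantic_segmentation.metric import MetricType
from valor_lite.semantic_segmentation.utilities import (
    unpack_precision_recall_iou_into_metric_lists,
)

index_to_label = {0: "sky", 1: "road"}
results = (
    np.array([0.9, 0.8]),      # precision per label
    np.array([0.85, 0.75]),    # recall per label
    np.array([0.87, 0.77]),    # F1 per label
    0.82,                      # accuracy (scalar)
    np.array([[0.8, 0.1],      # IOU matrix (ground truth x prediction)
              [0.05, 0.7]]),
    np.array([0.02, 0.03]),    # unmatched prediction ratios per label
    np.array([0.04, 0.01]),    # unmatched ground truth ratios per label
)

metrics = unpack_precision_recall_iou_into_metric_lists(results, index_to_label)
print([m.value for m in metrics[MetricType.IOU]])  # [0.8, 0.7]
print(metrics[MetricType.mIOU][0].value)           # 0.75
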