orca-sdk 0.0.91 → 0.0.93 (py3-none-any.whl)
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- orca_sdk/_generated_api_client/api/__init__.py +4 -0
- orca_sdk/_generated_api_client/api/memoryset/suggest_cascading_edits_memoryset_name_or_id_memory_memory_id_cascading_edits_post.py +233 -0
- orca_sdk/_generated_api_client/models/__init__.py +4 -0
- orca_sdk/_generated_api_client/models/base_label_prediction_result.py +9 -1
- orca_sdk/_generated_api_client/models/cascade_edit_suggestions_request.py +154 -0
- orca_sdk/_generated_api_client/models/cascading_edit_suggestion.py +92 -0
- orca_sdk/_generated_api_client/models/classification_evaluation_result.py +62 -0
- orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +1 -0
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +8 -0
- orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +8 -8
- orca_sdk/_generated_api_client/models/labeled_memory.py +8 -0
- orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +8 -0
- orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics.py +8 -0
- orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +8 -0
- orca_sdk/_generated_api_client/models/prediction_request.py +16 -7
- orca_sdk/_shared/__init__.py +1 -0
- orca_sdk/_shared/metrics.py +195 -0
- orca_sdk/_shared/metrics_test.py +169 -0
- orca_sdk/_utils/data_parsing.py +31 -2
- orca_sdk/_utils/data_parsing_test.py +18 -15
- orca_sdk/_utils/tqdm_file_reader.py +12 -0
- orca_sdk/classification_model.py +170 -27
- orca_sdk/classification_model_test.py +74 -32
- orca_sdk/conftest.py +86 -25
- orca_sdk/datasource.py +22 -12
- orca_sdk/embedding_model_test.py +6 -5
- orca_sdk/memoryset.py +78 -0
- orca_sdk/memoryset_test.py +197 -123
- orca_sdk/telemetry.py +3 -0
- {orca_sdk-0.0.91.dist-info → orca_sdk-0.0.93.dist-info}/METADATA +3 -1
- {orca_sdk-0.0.91.dist-info → orca_sdk-0.0.93.dist-info}/RECORD +32 -25
- {orca_sdk-0.0.91.dist-info → orca_sdk-0.0.93.dist-info}/WHEEL +0 -0
orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py

@@ -38,6 +38,7 @@ class LabelPredictionMemoryLookup:
         memory_version (int):
         created_at (datetime.datetime):
         updated_at (datetime.datetime):
+        edited_at (datetime.datetime):
         metrics (MemoryMetrics):
         label (int):
         label_name (Union[None, str]):
@@ -54,6 +55,7 @@ class LabelPredictionMemoryLookup:
     memory_version: int
     created_at: datetime.datetime
     updated_at: datetime.datetime
+    edited_at: datetime.datetime
     metrics: "MemoryMetrics"
     label: int
     label_name: Union[None, str]
@@ -81,6 +83,8 @@ class LabelPredictionMemoryLookup:
 
         updated_at = self.updated_at.isoformat()
 
+        edited_at = self.edited_at.isoformat()
+
         metrics = self.metrics.to_dict()
 
         label = self.label
@@ -106,6 +110,7 @@ class LabelPredictionMemoryLookup:
                 "memory_version": memory_version,
                 "created_at": created_at,
                 "updated_at": updated_at,
+                "edited_at": edited_at,
                 "metrics": metrics,
                 "label": label,
                 "label_name": label_name,
@@ -148,6 +153,8 @@ class LabelPredictionMemoryLookup:
 
         updated_at = isoparse(d.pop("updated_at"))
 
+        edited_at = isoparse(d.pop("edited_at"))
+
         metrics = MemoryMetrics.from_dict(d.pop("metrics"))
 
         label = d.pop("label")
@@ -174,6 +181,7 @@ class LabelPredictionMemoryLookup:
             memory_version=memory_version,
             created_at=created_at,
             updated_at=updated_at,
+            edited_at=edited_at,
             metrics=metrics,
             label=label,
             label_name=label_name,
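The same edited_at pattern recurs in every generated model below: a datetime attribute, isoformat() in to_dict, and dateutil's isoparse in from_dict. A minimal sketch of that round-trip; the Memory class here is a hypothetical stand-in, not the generated SDK code:

# Sketch of the isoformat/isoparse round-trip the generated models use.
import datetime
from dataclasses import dataclass
from typing import Any

from dateutil.parser import isoparse


@dataclass
class Memory:
    edited_at: datetime.datetime

    def to_dict(self) -> dict[str, Any]:
        # datetimes are serialized as ISO 8601 strings
        return {"edited_at": self.edited_at.isoformat()}

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "Memory":
        # isoparse restores the datetime, including timezone info if present
        return cls(edited_at=isoparse(d.pop("edited_at")))


memory = Memory(edited_at=datetime.datetime(2024, 5, 1, tzinfo=datetime.timezone.utc))
assert Memory.from_dict(memory.to_dict()) == memory

Serializing datetimes as ISO 8601 strings keeps the wire format JSON-safe, while isoparse restores timezone-aware values losslessly.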
orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py

@@ -34,10 +34,10 @@ class LabelPredictionWithMemoriesAndFeedback:
         anomaly_score (Union[None, float]):
         label (int):
         label_name (Union[None, str]):
+        logits (List[float]):
         timestamp (datetime.datetime):
         input_value (str):
         input_embedding (List[float]):
-        logits (List[float]):
         expected_label (Union[None, int]):
         expected_label_name (Union[None, str]):
         memories (List['LabelPredictionMemoryLookup']):
@@ -56,10 +56,10 @@ class LabelPredictionWithMemoriesAndFeedback:
     anomaly_score: Union[None, float]
     label: int
     label_name: Union[None, str]
+    logits: List[float]
     timestamp: datetime.datetime
     input_value: str
     input_embedding: List[float]
-    logits: List[float]
     expected_label: Union[None, int]
     expected_label_name: Union[None, str]
     memories: List["LabelPredictionMemoryLookup"]
@@ -86,6 +86,8 @@ class LabelPredictionWithMemoriesAndFeedback:
         label_name: Union[None, str]
         label_name = self.label_name
 
+        logits = self.logits
+
         timestamp = self.timestamp.isoformat()
 
         input_value: str
@@ -93,8 +95,6 @@ class LabelPredictionWithMemoriesAndFeedback:
 
         input_embedding = self.input_embedding
 
-        logits = self.logits
-
         expected_label: Union[None, int]
         expected_label = self.expected_label
 
@@ -136,10 +136,10 @@ class LabelPredictionWithMemoriesAndFeedback:
                 "anomaly_score": anomaly_score,
                 "label": label,
                 "label_name": label_name,
+                "logits": logits,
                 "timestamp": timestamp,
                 "input_value": input_value,
                 "input_embedding": input_embedding,
-                "logits": logits,
                 "expected_label": expected_label,
                 "expected_label_name": expected_label_name,
                 "memories": memories,
@@ -182,6 +182,8 @@ class LabelPredictionWithMemoriesAndFeedback:
 
         label_name = _parse_label_name(d.pop("label_name"))
 
+        logits = cast(List[float], d.pop("logits"))
+
         timestamp = isoparse(d.pop("timestamp"))
 
         def _parse_input_value(data: object) -> str:
@@ -191,8 +193,6 @@ class LabelPredictionWithMemoriesAndFeedback:
 
         input_embedding = cast(List[float], d.pop("input_embedding"))
 
-        logits = cast(List[float], d.pop("logits"))
-
         def _parse_expected_label(data: object) -> Union[None, int]:
             if data is None:
                 return data
@@ -251,10 +251,10 @@ class LabelPredictionWithMemoriesAndFeedback:
             anomaly_score=anomaly_score,
             label=label,
             label_name=label_name,
+            logits=logits,
             timestamp=timestamp,
             input_value=input_value,
             input_embedding=input_embedding,
-            logits=logits,
             expected_label=expected_label,
             expected_label_name=expected_label_name,
             memories=memories,
orca_sdk/_generated_api_client/models/labeled_memory.py

@@ -38,6 +38,7 @@ class LabeledMemory:
         memory_version (int):
         created_at (datetime.datetime):
         updated_at (datetime.datetime):
+        edited_at (datetime.datetime):
         metrics (LabeledMemoryMetrics): Metrics computed for a labeled memory.
         label (int):
         label_name (Union[None, str]):
@@ -51,6 +52,7 @@ class LabeledMemory:
     memory_version: int
     created_at: datetime.datetime
     updated_at: datetime.datetime
+    edited_at: datetime.datetime
     metrics: "LabeledMemoryMetrics"
     label: int
     label_name: Union[None, str]
@@ -75,6 +77,8 @@ class LabeledMemory:
 
         updated_at = self.updated_at.isoformat()
 
+        edited_at = self.edited_at.isoformat()
+
         metrics = self.metrics.to_dict()
 
         label = self.label
@@ -94,6 +98,7 @@ class LabeledMemory:
                 "memory_version": memory_version,
                 "created_at": created_at,
                 "updated_at": updated_at,
+                "edited_at": edited_at,
                 "metrics": metrics,
                 "label": label,
                 "label_name": label_name,
@@ -133,6 +138,8 @@ class LabeledMemory:
 
         updated_at = isoparse(d.pop("updated_at"))
 
+        edited_at = isoparse(d.pop("edited_at"))
+
         metrics = LabeledMemoryMetrics.from_dict(d.pop("metrics"))
 
         label = d.pop("label")
@@ -153,6 +160,7 @@ class LabeledMemory:
             memory_version=memory_version,
             created_at=created_at,
             updated_at=updated_at,
+            edited_at=edited_at,
             metrics=metrics,
             label=label,
             label_name=label_name,
orca_sdk/_generated_api_client/models/labeled_memory_lookup.py

@@ -38,6 +38,7 @@ class LabeledMemoryLookup:
         memory_version (int):
         created_at (datetime.datetime):
         updated_at (datetime.datetime):
+        edited_at (datetime.datetime):
         metrics (MemoryMetrics):
         label (int):
         label_name (Union[None, str]):
@@ -52,6 +53,7 @@ class LabeledMemoryLookup:
     memory_version: int
     created_at: datetime.datetime
     updated_at: datetime.datetime
+    edited_at: datetime.datetime
     metrics: "MemoryMetrics"
     label: int
     label_name: Union[None, str]
@@ -77,6 +79,8 @@ class LabeledMemoryLookup:
 
         updated_at = self.updated_at.isoformat()
 
+        edited_at = self.edited_at.isoformat()
+
         metrics = self.metrics.to_dict()
 
         label = self.label
@@ -98,6 +102,7 @@ class LabeledMemoryLookup:
                 "memory_version": memory_version,
                 "created_at": created_at,
                 "updated_at": updated_at,
+                "edited_at": edited_at,
                 "metrics": metrics,
                 "label": label,
                 "label_name": label_name,
@@ -138,6 +143,8 @@ class LabeledMemoryLookup:
 
         updated_at = isoparse(d.pop("updated_at"))
 
+        edited_at = isoparse(d.pop("edited_at"))
+
         metrics = MemoryMetrics.from_dict(d.pop("metrics"))
 
         label = d.pop("label")
@@ -160,6 +167,7 @@ class LabeledMemoryLookup:
             memory_version=memory_version,
             created_at=created_at,
             updated_at=updated_at,
+            edited_at=edited_at,
             metrics=metrics,
             label=label,
             label_name=label_name,
orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics.py

@@ -40,6 +40,7 @@ class LabeledMemoryWithFeedbackMetrics:
         memory_version (int):
         created_at (datetime.datetime):
         updated_at (datetime.datetime):
+        edited_at (datetime.datetime):
         metrics (LabeledMemoryMetrics): Metrics computed for a labeled memory.
         label (int):
         label_name (Union[None, str]):
@@ -55,6 +56,7 @@ class LabeledMemoryWithFeedbackMetrics:
     memory_version: int
     created_at: datetime.datetime
     updated_at: datetime.datetime
+    edited_at: datetime.datetime
     metrics: "LabeledMemoryMetrics"
     label: int
     label_name: Union[None, str]
@@ -81,6 +83,8 @@ class LabeledMemoryWithFeedbackMetrics:
 
         updated_at = self.updated_at.isoformat()
 
+        edited_at = self.edited_at.isoformat()
+
         metrics = self.metrics.to_dict()
 
         label = self.label
@@ -104,6 +108,7 @@ class LabeledMemoryWithFeedbackMetrics:
                 "memory_version": memory_version,
                 "created_at": created_at,
                 "updated_at": updated_at,
+                "edited_at": edited_at,
                 "metrics": metrics,
                 "label": label,
                 "label_name": label_name,
@@ -148,6 +153,8 @@ class LabeledMemoryWithFeedbackMetrics:
 
         updated_at = isoparse(d.pop("updated_at"))
 
+        edited_at = isoparse(d.pop("edited_at"))
+
         metrics = LabeledMemoryMetrics.from_dict(d.pop("metrics"))
 
         label = d.pop("label")
@@ -172,6 +179,7 @@ class LabeledMemoryWithFeedbackMetrics:
             memory_version=memory_version,
             created_at=created_at,
             updated_at=updated_at,
+            edited_at=edited_at,
             metrics=metrics,
             label=label,
             label_name=label_name,
orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py

@@ -43,6 +43,7 @@ class LabeledMemorysetMetadata:
         label_names (List[str]):
         created_at (datetime.datetime):
         updated_at (datetime.datetime):
+        memories_updated_at (datetime.datetime):
         insertion_task_id (str):
         insertion_status (TaskStatus): Status of task in the task queue
         metrics (MemorysetMetrics):
@@ -59,6 +60,7 @@ class LabeledMemorysetMetadata:
     label_names: List[str]
     created_at: datetime.datetime
     updated_at: datetime.datetime
+    memories_updated_at: datetime.datetime
     insertion_task_id: str
     insertion_status: TaskStatus
     metrics: "MemorysetMetrics"
@@ -97,6 +99,8 @@ class LabeledMemorysetMetadata:
 
         updated_at = self.updated_at.isoformat()
 
+        memories_updated_at = self.memories_updated_at.isoformat()
+
         insertion_task_id = self.insertion_task_id
 
         insertion_status = (
@@ -120,6 +124,7 @@ class LabeledMemorysetMetadata:
                 "label_names": label_names,
                 "created_at": created_at,
                 "updated_at": updated_at,
+                "memories_updated_at": memories_updated_at,
                 "insertion_task_id": insertion_task_id,
                 "insertion_status": insertion_status,
                 "metrics": metrics,
@@ -180,6 +185,8 @@ class LabeledMemorysetMetadata:
 
         updated_at = isoparse(d.pop("updated_at"))
 
+        memories_updated_at = isoparse(d.pop("memories_updated_at"))
+
         insertion_task_id = d.pop("insertion_task_id")
 
         insertion_status = TaskStatus(d.pop("insertion_status"))
@@ -198,6 +205,7 @@ class LabeledMemorysetMetadata:
             label_names=label_names,
             created_at=created_at,
             updated_at=updated_at,
+            memories_updated_at=memories_updated_at,
             insertion_task_id=insertion_task_id,
             insertion_status=insertion_status,
             metrics=metrics,
orca_sdk/_generated_api_client/models/prediction_request.py
(Note: the content of the removed lines below was truncated by the diff viewer and is reproduced as shown.)

@@ -28,14 +28,16 @@ class PredictionRequest:
         expected_labels (Union[List[int], None, Unset]):
         tags (Union[Unset, List[str]]):
         memoryset_override_id (Union[None, Unset, str]):
-
+        save_telemetry (Union[Unset, bool]): Default: True.
+        save_telemetry_synchronously (Union[Unset, bool]): Default: False.
     """
 
     input_values: List[str]
     expected_labels: Union[List[int], None, Unset] = UNSET
     tags: Union[Unset, List[str]] = UNSET
     memoryset_override_id: Union[None, Unset, str] = UNSET
-
+    save_telemetry: Union[Unset, bool] = True
+    save_telemetry_synchronously: Union[Unset, bool] = False
     additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
 
     def to_dict(self) -> dict[str, Any]:
@@ -62,7 +64,9 @@ class PredictionRequest:
         else:
             memoryset_override_id = self.memoryset_override_id
 
-
+        save_telemetry = self.save_telemetry
+
+        save_telemetry_synchronously = self.save_telemetry_synchronously
 
         field_dict: dict[str, Any] = {}
         field_dict.update(self.additional_properties)
@@ -77,8 +81,10 @@ class PredictionRequest:
             field_dict["tags"] = tags
         if memoryset_override_id is not UNSET:
             field_dict["memoryset_override_id"] = memoryset_override_id
-        if
-        field_dict["
+        if save_telemetry is not UNSET:
+            field_dict["save_telemetry"] = save_telemetry
+        if save_telemetry_synchronously is not UNSET:
+            field_dict["save_telemetry_synchronously"] = save_telemetry_synchronously
 
         return field_dict
 
@@ -156,14 +162,17 @@ class PredictionRequest:
 
         memoryset_override_id = _parse_memoryset_override_id(d.pop("memoryset_override_id", UNSET))
 
-
+        save_telemetry = d.pop("save_telemetry", UNSET)
+
+        save_telemetry_synchronously = d.pop("save_telemetry_synchronously", UNSET)
 
         prediction_request = cls(
             input_values=input_values,
             expected_labels=expected_labels,
             tags=tags,
             memoryset_override_id=memoryset_override_id,
-
+            save_telemetry=save_telemetry,
+            save_telemetry_synchronously=save_telemetry_synchronously,
         )
 
         prediction_request.additional_properties = d
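A sketch of how the two new telemetry flags surface in a request payload, assuming the class as generated above is re-exported from the models package in the usual way; the input values are illustrative only:

from orca_sdk._generated_api_client.models import PredictionRequest

request = PredictionRequest(
    input_values=["Is this order refundable?"],
    tags=["support"],
    save_telemetry=True,
    save_telemetry_synchronously=False,
)
# Because the generated defaults are concrete booleans rather than UNSET,
# both flags are serialized even when left at their defaults:
body = request.to_dict()
# {'input_values': ['Is this order refundable?'], 'tags': ['support'],
#  'save_telemetry': True, 'save_telemetry_synchronously': False}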
orca_sdk/_shared/__init__.py (new file)

@@ -0,0 +1 @@
+from .metrics import calculate_pr_curve, calculate_roc_curve, compute_classifier_metrics
orca_sdk/_shared/metrics.py (new file)

@@ -0,0 +1,195 @@
+"""
+This module contains metrics for usage with the Hugging Face Trainer.
+
+IMPORTANT:
+- This is a shared file between OrcaLib and the Orca SDK.
+- Please ensure that it does not have any dependencies on the OrcaLib code.
+- Make sure to edit this file in orcalib/shared and NOT in orca_sdk, since it will be overwritten there.
+
+"""
+
+from typing import Literal, Tuple, TypedDict
+
+import numpy as np
+from numpy.typing import NDArray
+from scipy.special import softmax
+from sklearn.metrics import accuracy_score, auc, f1_score, log_loss
+from sklearn.metrics import precision_recall_curve as sklearn_precision_recall_curve
+from sklearn.metrics import roc_auc_score
+from sklearn.metrics import roc_curve as sklearn_roc_curve
+from transformers.trainer_utils import EvalPrediction
+
+
+class ClassificationMetrics(TypedDict):
+    accuracy: float
+    f1_score: float
+    roc_auc: float | None  # receiver operating characteristic area under the curve (if all classes are present)
+    pr_auc: float | None  # precision-recall area under the curve (only for binary classification)
+    log_loss: float  # cross-entropy loss for probabilities
+
+
+def compute_classifier_metrics(eval_pred: EvalPrediction) -> ClassificationMetrics:
+    """
+    Compute standard metrics for classifier with Hugging Face Trainer.
+
+    Args:
+        eval_pred: The predictions containing logits and expected labels as given by the Trainer.
+
+    Returns:
+        A dictionary containing the accuracy, f1 score, and ROC AUC score.
+    """
+    logits, references = eval_pred
+    if isinstance(logits, tuple):
+        logits = logits[0]
+    if not isinstance(logits, np.ndarray):
+        raise ValueError("Logits must be a numpy array")
+    if not isinstance(references, np.ndarray):
+        raise ValueError(
+            "Multiple label columns found, use the `label_names` training argument to specify which one to use"
+        )
+
+    if not (logits > 0).all():
+        # convert logits to probabilities with softmax if necessary
+        probabilities = softmax(logits)
+    elif not np.allclose(logits.sum(-1, keepdims=True), 1.0):
+        # convert logits to probabilities through normalization if necessary
+        probabilities = logits / logits.sum(-1, keepdims=True)
+    else:
+        probabilities = logits
+
+    return classification_scores(references, probabilities)
+
+
+def classification_scores(
+    references: NDArray[np.int64],
+    probabilities: NDArray[np.float32],
+    average: Literal["micro", "macro", "weighted", "binary"] | None = None,
+    multi_class: Literal["ovr", "ovo"] = "ovr",
+) -> ClassificationMetrics:
+    if probabilities.ndim == 1:
+        # convert 1D probabilities (binary) to 2D logits
+        probabilities = np.column_stack([1 - probabilities, probabilities])
+    elif probabilities.ndim == 2:
+        if probabilities.shape[1] < 2:
+            raise ValueError("Use a different metric function for regression tasks")
+    else:
+        raise ValueError("Probabilities must be 1 or 2 dimensional")
+
+    predictions = np.argmax(probabilities, axis=-1)
+
+    num_classes_references = len(set(references))
+    num_classes_predictions = len(set(predictions))
+
+    if average is None:
+        average = "binary" if num_classes_references == 2 else "weighted"
+
+    accuracy = accuracy_score(references, predictions)
+    f1 = f1_score(references, predictions, average=average)
+    loss = log_loss(references, probabilities)
+
+    if num_classes_references == num_classes_predictions:
+        # special case for binary classification: https://github.com/scikit-learn/scikit-learn/issues/20186
+        if num_classes_references == 2:
+            roc_auc = roc_auc_score(references, probabilities[:, 1])
+            precisions, recalls, _ = calculate_pr_curve(references, probabilities[:, 1])
+            pr_auc = auc(recalls, precisions)
+        else:
+            roc_auc = roc_auc_score(references, probabilities, multi_class=multi_class)
+            pr_auc = None
+    else:
+        roc_auc = None
+        pr_auc = None
+
+    return {
+        "accuracy": float(accuracy),
+        "f1_score": float(f1),
+        "roc_auc": float(roc_auc) if roc_auc is not None else None,
+        "pr_auc": float(pr_auc) if pr_auc is not None else None,
+        "log_loss": float(loss),
+    }
+
+
+def calculate_pr_curve(
+    references: NDArray[np.int64],
+    probabilities: NDArray[np.float32],
+    max_length: int = 100,
+) -> Tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.float32]]:
+    if probabilities.ndim == 1:
+        probabilities_slice = probabilities
+    elif probabilities.ndim == 2:
+        probabilities_slice = probabilities[:, 1]
+    else:
+        raise ValueError("Probabilities must be 1 or 2 dimensional")
+
+    if len(probabilities_slice) != len(references):
+        raise ValueError("Probabilities and references must have the same length")
+
+    precisions, recalls, thresholds = sklearn_precision_recall_curve(references, probabilities_slice)
+
+    # Convert all arrays to float32 immediately after getting them
+    precisions = precisions.astype(np.float32)
+    recalls = recalls.astype(np.float32)
+    thresholds = thresholds.astype(np.float32)
+
+    # Concatenate with 0 to include the lowest threshold
+    thresholds = np.concatenate(([0], thresholds))
+
+    # Sort by threshold
+    sorted_indices = np.argsort(thresholds)
+    thresholds = thresholds[sorted_indices]
+    precisions = precisions[sorted_indices]
+    recalls = recalls[sorted_indices]
+
+    if len(precisions) > max_length:
+        new_thresholds = np.linspace(0, 1, max_length, dtype=np.float32)
+        new_precisions = np.interp(new_thresholds, thresholds, precisions)
+        new_recalls = np.interp(new_thresholds, thresholds, recalls)
+        thresholds = new_thresholds
+        precisions = new_precisions
+        recalls = new_recalls
+
+    return precisions.astype(np.float32), recalls.astype(np.float32), thresholds.astype(np.float32)
+
+
+def calculate_roc_curve(
+    references: NDArray[np.int64],
+    probabilities: NDArray[np.float32],
+    max_length: int = 100,
+) -> Tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.float32]]:
+    if probabilities.ndim == 1:
+        probabilities_slice = probabilities
+    elif probabilities.ndim == 2:
+        probabilities_slice = probabilities[:, 1]
+    else:
+        raise ValueError("Probabilities must be 1 or 2 dimensional")
+
+    if len(probabilities_slice) != len(references):
+        raise ValueError("Probabilities and references must have the same length")
+
+    # Convert probabilities to float32 before calling sklearn_roc_curve
+    probabilities_slice = probabilities_slice.astype(np.float32)
+    fpr, tpr, thresholds = sklearn_roc_curve(references, probabilities_slice)
+
+    # Convert all arrays to float32 immediately after getting them
+    fpr = fpr.astype(np.float32)
+    tpr = tpr.astype(np.float32)
+    thresholds = thresholds.astype(np.float32)
+
+    # We set the first threshold to 1.0 instead of inf for reasonable values in interpolation
+    thresholds[0] = 1.0
+
+    # Sort by threshold
+    sorted_indices = np.argsort(thresholds)
+    thresholds = thresholds[sorted_indices]
+    fpr = fpr[sorted_indices]
+    tpr = tpr[sorted_indices]
+
+    if len(fpr) > max_length:
+        new_thresholds = np.linspace(0, 1, max_length, dtype=np.float32)
+        new_fpr = np.interp(new_thresholds, thresholds, fpr)
+        new_tpr = np.interp(new_thresholds, thresholds, tpr)
+        thresholds = new_thresholds
+        fpr = new_fpr
+        tpr = new_tpr
+
+    return fpr.astype(np.float32), tpr.astype(np.float32), thresholds.astype(np.float32)
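A usage sketch for the new metrics helpers, with illustrative inputs; it assumes the module is importable via the orca_sdk._shared re-exports added above. Passing row-normalized probabilities takes the pass-through branch in compute_classifier_metrics, so no softmax is applied:

import numpy as np
from transformers.trainer_utils import EvalPrediction

from orca_sdk._shared import calculate_roc_curve, compute_classifier_metrics

# Three binary predictions, already row-normalized probabilities
probabilities = np.array([[0.9, 0.1], [0.2, 0.8], [0.3, 0.7]], dtype=np.float32)
labels = np.array([0, 1, 1])

metrics = compute_classifier_metrics(EvalPrediction(predictions=probabilities, label_ids=labels))
# roughly {'accuracy': 1.0, 'f1_score': 1.0, 'roc_auc': 1.0, 'pr_auc': 1.0, 'log_loss': ~0.23}

# The curve helpers return float32 arrays capped at max_length points for plotting
fpr, tpr, thresholds = calculate_roc_curve(labels, probabilities[:, 1])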