orca-sdk 0.0.91-py3-none-any.whl → 0.0.92-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/_generated_api_client/models/base_label_prediction_result.py +9 -1
- orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +8 -8
- orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +8 -0
- orca_sdk/_shared/__init__.py +1 -0
- orca_sdk/_shared/metrics.py +195 -0
- orca_sdk/_shared/metrics_test.py +169 -0
- orca_sdk/classification_model.py +144 -19
- orca_sdk/classification_model_test.py +49 -22
- orca_sdk/telemetry.py +3 -0
- {orca_sdk-0.0.91.dist-info → orca_sdk-0.0.92.dist-info}/METADATA +3 -1
- {orca_sdk-0.0.91.dist-info → orca_sdk-0.0.92.dist-info}/RECORD +12 -9
- {orca_sdk-0.0.91.dist-info → orca_sdk-0.0.92.dist-info}/WHEEL +0 -0
orca_sdk/_generated_api_client/models/base_label_prediction_result.py
CHANGED

@@ -10,7 +10,7 @@
 
 # flake8: noqa: C901
 
-from typing import Any, Type, TypeVar, Union, cast
+from typing import Any, List, Type, TypeVar, Union, cast
 
 from attrs import define as _attrs_define
 from attrs import field as _attrs_field
@@ -28,6 +28,7 @@ class BaseLabelPredictionResult:
         anomaly_score (Union[None, float]):
         label (int):
         label_name (Union[None, str]):
+        logits (List[float]):
     """
 
     prediction_id: Union[None, str]
@@ -35,6 +36,7 @@ class BaseLabelPredictionResult:
     anomaly_score: Union[None, float]
     label: int
     label_name: Union[None, str]
+    logits: List[float]
    additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
 
    def to_dict(self) -> dict[str, Any]:
@@ -51,6 +53,8 @@ class BaseLabelPredictionResult:
         label_name: Union[None, str]
         label_name = self.label_name
 
+        logits = self.logits
+
         field_dict: dict[str, Any] = {}
         field_dict.update(self.additional_properties)
         field_dict.update(
@@ -60,6 +64,7 @@ class BaseLabelPredictionResult:
                 "anomaly_score": anomaly_score,
                 "label": label,
                 "label_name": label_name,
+                "logits": logits,
             }
         )
 
@@ -94,12 +99,15 @@ class BaseLabelPredictionResult:
 
         label_name = _parse_label_name(d.pop("label_name"))
 
+        logits = cast(List[float], d.pop("logits"))
+
         base_label_prediction_result = cls(
            prediction_id=prediction_id,
            confidence=confidence,
            anomaly_score=anomaly_score,
            label=label,
            label_name=label_name,
+            logits=logits,
        )
 
         base_label_prediction_result.additional_properties = d
orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py
CHANGED

@@ -34,10 +34,10 @@ class LabelPredictionWithMemoriesAndFeedback:
         anomaly_score (Union[None, float]):
         label (int):
         label_name (Union[None, str]):
+        logits (List[float]):
         timestamp (datetime.datetime):
         input_value (str):
         input_embedding (List[float]):
-        logits (List[float]):
         expected_label (Union[None, int]):
         expected_label_name (Union[None, str]):
         memories (List['LabelPredictionMemoryLookup']):
@@ -56,10 +56,10 @@ class LabelPredictionWithMemoriesAndFeedback:
     anomaly_score: Union[None, float]
     label: int
     label_name: Union[None, str]
+    logits: List[float]
     timestamp: datetime.datetime
     input_value: str
     input_embedding: List[float]
-    logits: List[float]
     expected_label: Union[None, int]
     expected_label_name: Union[None, str]
     memories: List["LabelPredictionMemoryLookup"]
@@ -86,6 +86,8 @@ class LabelPredictionWithMemoriesAndFeedback:
         label_name: Union[None, str]
         label_name = self.label_name
 
+        logits = self.logits
+
         timestamp = self.timestamp.isoformat()
 
         input_value: str
@@ -93,8 +95,6 @@ class LabelPredictionWithMemoriesAndFeedback:
 
         input_embedding = self.input_embedding
 
-        logits = self.logits
-
         expected_label: Union[None, int]
         expected_label = self.expected_label
 
@@ -136,10 +136,10 @@ class LabelPredictionWithMemoriesAndFeedback:
                 "anomaly_score": anomaly_score,
                 "label": label,
                 "label_name": label_name,
+                "logits": logits,
                 "timestamp": timestamp,
                 "input_value": input_value,
                 "input_embedding": input_embedding,
-                "logits": logits,
                 "expected_label": expected_label,
                 "expected_label_name": expected_label_name,
                 "memories": memories,
@@ -182,6 +182,8 @@ class LabelPredictionWithMemoriesAndFeedback:
 
         label_name = _parse_label_name(d.pop("label_name"))
 
+        logits = cast(List[float], d.pop("logits"))
+
         timestamp = isoparse(d.pop("timestamp"))
 
         def _parse_input_value(data: object) -> str:
@@ -191,8 +193,6 @@ class LabelPredictionWithMemoriesAndFeedback:
 
         input_embedding = cast(List[float], d.pop("input_embedding"))
 
-        logits = cast(List[float], d.pop("logits"))
-
         def _parse_expected_label(data: object) -> Union[None, int]:
             if data is None:
                 return data
@@ -251,10 +251,10 @@ class LabelPredictionWithMemoriesAndFeedback:
            anomaly_score=anomaly_score,
            label=label,
            label_name=label_name,
+            logits=logits,
            timestamp=timestamp,
            input_value=input_value,
            input_embedding=input_embedding,
-            logits=logits,
            expected_label=expected_label,
            expected_label_name=expected_label_name,
            memories=memories,
orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py
CHANGED

@@ -43,6 +43,7 @@ class LabeledMemorysetMetadata:
         label_names (List[str]):
         created_at (datetime.datetime):
         updated_at (datetime.datetime):
+        memories_updated_at (datetime.datetime):
         insertion_task_id (str):
         insertion_status (TaskStatus): Status of task in the task queue
         metrics (MemorysetMetrics):
@@ -59,6 +60,7 @@ class LabeledMemorysetMetadata:
     label_names: List[str]
     created_at: datetime.datetime
     updated_at: datetime.datetime
+    memories_updated_at: datetime.datetime
     insertion_task_id: str
     insertion_status: TaskStatus
     metrics: "MemorysetMetrics"
@@ -97,6 +99,8 @@ class LabeledMemorysetMetadata:
 
         updated_at = self.updated_at.isoformat()
 
+        memories_updated_at = self.memories_updated_at.isoformat()
+
         insertion_task_id = self.insertion_task_id
 
         insertion_status = (
@@ -120,6 +124,7 @@ class LabeledMemorysetMetadata:
                 "label_names": label_names,
                 "created_at": created_at,
                 "updated_at": updated_at,
+                "memories_updated_at": memories_updated_at,
                 "insertion_task_id": insertion_task_id,
                 "insertion_status": insertion_status,
                 "metrics": metrics,
@@ -180,6 +185,8 @@ class LabeledMemorysetMetadata:
 
         updated_at = isoparse(d.pop("updated_at"))
 
+        memories_updated_at = isoparse(d.pop("memories_updated_at"))
+
         insertion_task_id = d.pop("insertion_task_id")
 
         insertion_status = TaskStatus(d.pop("insertion_status"))
@@ -198,6 +205,7 @@ class LabeledMemorysetMetadata:
            label_names=label_names,
            created_at=created_at,
            updated_at=updated_at,
+            memories_updated_at=memories_updated_at,
            insertion_task_id=insertion_task_id,
            insertion_status=insertion_status,
            metrics=metrics,
orca_sdk/_shared/__init__.py
ADDED

@@ -0,0 +1 @@
+from .metrics import calculate_pr_curve, calculate_roc_curve, compute_classifier_metrics
orca_sdk/_shared/metrics.py
ADDED

@@ -0,0 +1,195 @@
+"""
+This module contains metrics for usage with the Hugging Face Trainer.
+
+IMPORTANT:
+- This is a shared file between OrcaLib and the Orca SDK.
+- Please ensure that it does not have any dependencies on the OrcaLib code.
+- Make sure to edit this file in orcalib/shared and NOT in orca_sdk, since it will be overwritten there.
+
+"""
+
+from typing import Literal, Tuple, TypedDict
+
+import numpy as np
+from numpy.typing import NDArray
+from scipy.special import softmax
+from sklearn.metrics import accuracy_score, auc, f1_score, log_loss
+from sklearn.metrics import precision_recall_curve as sklearn_precision_recall_curve
+from sklearn.metrics import roc_auc_score
+from sklearn.metrics import roc_curve as sklearn_roc_curve
+from transformers.trainer_utils import EvalPrediction
+
+
+class ClassificationMetrics(TypedDict):
+    accuracy: float
+    f1_score: float
+    roc_auc: float | None  # receiver operating characteristic area under the curve (if all classes are present)
+    pr_auc: float | None  # precision-recall area under the curve (only for binary classification)
+    log_loss: float  # cross-entropy loss for probabilities
+
+
+def compute_classifier_metrics(eval_pred: EvalPrediction) -> ClassificationMetrics:
+    """
+    Compute standard metrics for classifier with Hugging Face Trainer.
+
+    Args:
+        eval_pred: The predictions containing logits and expected labels as given by the Trainer.
+
+    Returns:
+        A dictionary containing the accuracy, f1 score, and ROC AUC score.
+    """
+    logits, references = eval_pred
+    if isinstance(logits, tuple):
+        logits = logits[0]
+    if not isinstance(logits, np.ndarray):
+        raise ValueError("Logits must be a numpy array")
+    if not isinstance(references, np.ndarray):
+        raise ValueError(
+            "Multiple label columns found, use the `label_names` training argument to specify which one to use"
+        )
+
+    if not (logits > 0).all():
+        # convert logits to probabilities with softmax if necessary
+        probabilities = softmax(logits)
+    elif not np.allclose(logits.sum(-1, keepdims=True), 1.0):
+        # convert logits to probabilities through normalization if necessary
+        probabilities = logits / logits.sum(-1, keepdims=True)
+    else:
+        probabilities = logits
+
+    return classification_scores(references, probabilities)
+
+
+def classification_scores(
+    references: NDArray[np.int64],
+    probabilities: NDArray[np.float32],
+    average: Literal["micro", "macro", "weighted", "binary"] | None = None,
+    multi_class: Literal["ovr", "ovo"] = "ovr",
+) -> ClassificationMetrics:
+    if probabilities.ndim == 1:
+        # convert 1D probabilities (binary) to 2D logits
+        probabilities = np.column_stack([1 - probabilities, probabilities])
+    elif probabilities.ndim == 2:
+        if probabilities.shape[1] < 2:
+            raise ValueError("Use a different metric function for regression tasks")
+    else:
+        raise ValueError("Probabilities must be 1 or 2 dimensional")
+
+    predictions = np.argmax(probabilities, axis=-1)
+
+    num_classes_references = len(set(references))
+    num_classes_predictions = len(set(predictions))
+
+    if average is None:
+        average = "binary" if num_classes_references == 2 else "weighted"
+
+    accuracy = accuracy_score(references, predictions)
+    f1 = f1_score(references, predictions, average=average)
+    loss = log_loss(references, probabilities)
+
+    if num_classes_references == num_classes_predictions:
+        # special case for binary classification: https://github.com/scikit-learn/scikit-learn/issues/20186
+        if num_classes_references == 2:
+            roc_auc = roc_auc_score(references, probabilities[:, 1])
+            precisions, recalls, _ = calculate_pr_curve(references, probabilities[:, 1])
+            pr_auc = auc(recalls, precisions)
+        else:
+            roc_auc = roc_auc_score(references, probabilities, multi_class=multi_class)
+            pr_auc = None
+    else:
+        roc_auc = None
+        pr_auc = None
+
+    return {
+        "accuracy": float(accuracy),
+        "f1_score": float(f1),
+        "roc_auc": float(roc_auc) if roc_auc is not None else None,
+        "pr_auc": float(pr_auc) if pr_auc is not None else None,
+        "log_loss": float(loss),
+    }
+
+
+def calculate_pr_curve(
+    references: NDArray[np.int64],
+    probabilities: NDArray[np.float32],
+    max_length: int = 100,
+) -> Tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.float32]]:
+    if probabilities.ndim == 1:
+        probabilities_slice = probabilities
+    elif probabilities.ndim == 2:
+        probabilities_slice = probabilities[:, 1]
+    else:
+        raise ValueError("Probabilities must be 1 or 2 dimensional")
+
+    if len(probabilities_slice) != len(references):
+        raise ValueError("Probabilities and references must have the same length")
+
+    precisions, recalls, thresholds = sklearn_precision_recall_curve(references, probabilities_slice)
+
+    # Convert all arrays to float32 immediately after getting them
+    precisions = precisions.astype(np.float32)
+    recalls = recalls.astype(np.float32)
+    thresholds = thresholds.astype(np.float32)
+
+    # Concatenate with 0 to include the lowest threshold
+    thresholds = np.concatenate(([0], thresholds))
+
+    # Sort by threshold
+    sorted_indices = np.argsort(thresholds)
+    thresholds = thresholds[sorted_indices]
+    precisions = precisions[sorted_indices]
+    recalls = recalls[sorted_indices]
+
+    if len(precisions) > max_length:
+        new_thresholds = np.linspace(0, 1, max_length, dtype=np.float32)
+        new_precisions = np.interp(new_thresholds, thresholds, precisions)
+        new_recalls = np.interp(new_thresholds, thresholds, recalls)
+        thresholds = new_thresholds
+        precisions = new_precisions
+        recalls = new_recalls
+
+    return precisions.astype(np.float32), recalls.astype(np.float32), thresholds.astype(np.float32)
+
+
+def calculate_roc_curve(
+    references: NDArray[np.int64],
+    probabilities: NDArray[np.float32],
+    max_length: int = 100,
+) -> Tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.float32]]:
+    if probabilities.ndim == 1:
+        probabilities_slice = probabilities
+    elif probabilities.ndim == 2:
+        probabilities_slice = probabilities[:, 1]
+    else:
+        raise ValueError("Probabilities must be 1 or 2 dimensional")
+
+    if len(probabilities_slice) != len(references):
+        raise ValueError("Probabilities and references must have the same length")
+
+    # Convert probabilities to float32 before calling sklearn_roc_curve
+    probabilities_slice = probabilities_slice.astype(np.float32)
+    fpr, tpr, thresholds = sklearn_roc_curve(references, probabilities_slice)
+
+    # Convert all arrays to float32 immediately after getting them
+    fpr = fpr.astype(np.float32)
+    tpr = tpr.astype(np.float32)
+    thresholds = thresholds.astype(np.float32)
+
+    # We set the first threshold to 1.0 instead of inf for reasonable values in interpolation
+    thresholds[0] = 1.0
+
+    # Sort by threshold
+    sorted_indices = np.argsort(thresholds)
+    thresholds = thresholds[sorted_indices]
+    fpr = fpr[sorted_indices]
+    tpr = tpr[sorted_indices]
+
+    if len(fpr) > max_length:
+        new_thresholds = np.linspace(0, 1, max_length, dtype=np.float32)
+        new_fpr = np.interp(new_thresholds, thresholds, fpr)
+        new_tpr = np.interp(new_thresholds, thresholds, tpr)
+        thresholds = new_thresholds
+        fpr = new_fpr
+        tpr = new_tpr
+
+    return fpr.astype(np.float32), tpr.astype(np.float32), thresholds.astype(np.float32)
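A quick sanity check of the new shared helpers — a minimal sketch, assuming numpy, scipy, scikit-learn, and transformers are installed; the input arrays below are made up for illustration and are not part of the package:

import numpy as np
from transformers.trainer_utils import EvalPrediction

from orca_sdk._shared.metrics import (
    calculate_pr_curve,
    calculate_roc_curve,
    compute_classifier_metrics,
)

# Raw logits containing negative values: compute_classifier_metrics softmaxes
# them before scoring; already-normalized probabilities pass through as-is.
logits = np.array([[2.0, -1.0], [-0.5, 1.5], [1.0, 0.2]])
labels = np.array([0, 1, 0])
metrics = compute_classifier_metrics(EvalPrediction(logits, labels))
print(metrics["accuracy"], metrics["roc_auc"], metrics["pr_auc"], metrics["log_loss"])

# The curve helpers accept either P(class 1) as a 1D array or a 2D probability
# matrix, and downsample to max_length points via linear interpolation.
probabilities = np.array([0.9, 0.2, 0.6, 0.4], dtype=np.float32)
references = np.array([1, 0, 1, 0])
precisions, recalls, pr_thresholds = calculate_pr_curve(references, probabilities)
fpr, tpr, roc_thresholds = calculate_roc_curve(references, probabilities)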
orca_sdk/_shared/metrics_test.py
ADDED

@@ -0,0 +1,169 @@
+"""
+IMPORTANT:
+- This is a shared file between OrcaLib and the Orca SDK.
+- Please ensure that it does not have any dependencies on the OrcaLib code.
+- Make sure to edit this file in orcalib/shared and NOT in orca_sdk, since it will be overwritten there.
+"""
+
+from typing import Literal
+
+import numpy as np
+import pytest
+
+from .metrics import (
+    EvalPrediction,
+    calculate_pr_curve,
+    calculate_roc_curve,
+    classification_scores,
+    compute_classifier_metrics,
+    softmax,
+)
+
+
+def test_binary_metrics():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([0.1, 0.9, 0.8, 0.3, 0.2])
+
+    metrics = classification_scores(y_true, y_score)
+
+    assert metrics["accuracy"] == 0.8
+    assert metrics["f1_score"] == 0.8
+    assert metrics["roc_auc"] is not None
+    assert metrics["roc_auc"] > 0.8
+    assert metrics["roc_auc"] < 1.0
+    assert metrics["pr_auc"] is not None
+    assert metrics["pr_auc"] > 0.8
+    assert metrics["pr_auc"] < 1.0
+    assert metrics["log_loss"] is not None
+    assert metrics["log_loss"] > 0.0
+
+
+def test_multiclass_metrics_with_2_classes():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([[0.9, 0.1], [0.1, 0.9], [0.2, 0.8], [0.7, 0.3], [0.8, 0.2]])
+
+    metrics = classification_scores(y_true, y_score)
+
+    assert metrics["accuracy"] == 0.8
+    assert metrics["f1_score"] == 0.8
+    assert metrics["roc_auc"] is not None
+    assert metrics["roc_auc"] > 0.8
+    assert metrics["roc_auc"] < 1.0
+    assert metrics["pr_auc"] is not None
+    assert metrics["pr_auc"] > 0.8
+    assert metrics["pr_auc"] < 1.0
+    assert metrics["log_loss"] is not None
+    assert metrics["log_loss"] > 0.0
+
+
+@pytest.mark.parametrize(
+    "average, multiclass",
+    [("micro", "ovr"), ("macro", "ovr"), ("weighted", "ovr"), ("micro", "ovo"), ("macro", "ovo"), ("weighted", "ovo")],
+)
+def test_multiclass_metrics_with_3_classes(
+    average: Literal["micro", "macro", "weighted"], multiclass: Literal["ovr", "ovo"]
+):
+    y_true = np.array([0, 1, 1, 0, 2])
+    y_score = np.array([[0.9, 0.1, 0.0], [0.1, 0.9, 0.0], [0.2, 0.8, 0.0], [0.7, 0.3, 0.0], [0.0, 0.0, 1.0]])
+
+    metrics = classification_scores(y_true, y_score, average=average, multi_class=multiclass)
+
+    assert metrics["accuracy"] == 1.0
+    assert metrics["f1_score"] == 1.0
+    assert metrics["roc_auc"] is not None
+    assert metrics["roc_auc"] > 0.8
+    assert metrics["pr_auc"] is None
+    assert metrics["log_loss"] is not None
+    assert metrics["log_loss"] > 0.0
+
+
+def test_does_not_modify_logits_unless_necessary():
+    logits = np.array([[0.1, 0.9], [0.2, 0.8], [0.7, 0.3], [0.8, 0.2]])
+    references = np.array([0, 1, 0, 1])
+    metrics = compute_classifier_metrics(EvalPrediction(logits, references))
+    assert metrics["log_loss"] == classification_scores(references, logits)["log_loss"]
+
+
+def test_normalizes_logits_if_necessary():
+    logits = np.array([[1.2, 3.9], [1.2, 5.8], [1.2, 2.7], [1.2, 1.3]])
+    references = np.array([0, 1, 0, 1])
+    metrics = compute_classifier_metrics(EvalPrediction(logits, references))
+    assert (
+        metrics["log_loss"] == classification_scores(references, logits / logits.sum(axis=1, keepdims=True))["log_loss"]
+    )
+
+
+def test_softmaxes_logits_if_necessary():
+    logits = np.array([[-1.2, 3.9], [1.2, -5.8], [1.2, 2.7], [1.2, 1.3]])
+    references = np.array([0, 1, 0, 1])
+    metrics = compute_classifier_metrics(EvalPrediction(logits, references))
+    assert metrics["log_loss"] == classification_scores(references, softmax(logits))["log_loss"]
+
+
+def test_precision_recall_curve():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
+
+    precision, recall, thresholds = calculate_pr_curve(y_true, y_score)
+    assert precision is not None
+    assert recall is not None
+    assert thresholds is not None
+
+    assert len(precision) == len(recall) == len(thresholds) == 6
+    assert precision[0] == 0.6
+    assert recall[0] == 1.0
+    assert precision[-1] == 1.0
+    assert recall[-1] == 0.0
+
+    # test that thresholds are sorted
+    assert np.all(np.diff(thresholds) >= 0)
+
+
+def test_roc_curve():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
+
+    fpr, tpr, thresholds = calculate_roc_curve(y_true, y_score)
+    assert fpr is not None
+    assert tpr is not None
+    assert thresholds is not None
+
+    assert len(fpr) == len(tpr) == len(thresholds) == 6
+    assert fpr[0] == 1.0
+    assert tpr[0] == 1.0
+    assert fpr[-1] == 0.0
+    assert tpr[-1] == 0.0
+
+    # test that thresholds are sorted
+    assert np.all(np.diff(thresholds) >= 0)
+
+
+def test_precision_recall_curve_max_length():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
+
+    precision, recall, thresholds = calculate_pr_curve(y_true, y_score, max_length=5)
+    assert len(precision) == len(recall) == len(thresholds) == 5
+
+    assert precision[0] == 0.6
+    assert recall[0] == 1.0
+    assert precision[-1] == 1.0
+    assert recall[-1] == 0.0
+
+    # test that thresholds are sorted
+    assert np.all(np.diff(thresholds) >= 0)
+
+
+def test_roc_curve_max_length():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
+
+    fpr, tpr, thresholds = calculate_roc_curve(y_true, y_score, max_length=5)
+    assert len(fpr) == len(tpr) == len(thresholds) == 5
+    assert fpr[0] == 1.0
+    assert tpr[0] == 1.0
+    assert fpr[-1] == 0.0
+    assert tpr[-1] == 0.0
+
+    # test that thresholds are sorted
+    assert np.all(np.diff(thresholds) >= 0)
orca_sdk/classification_model.py
CHANGED

@@ -6,6 +6,15 @@ from datetime import datetime
 from typing import Any, Generator, Iterable, Literal, cast, overload
 from uuid import UUID
 
+import numpy as np
+from datasets import Dataset
+from sklearn.metrics import (
+    accuracy_score,
+    auc,
+    f1_score,
+    roc_auc_score,
+)
+
 from ._generated_api_client.api import (
     create_evaluation,
     create_model,
@@ -19,9 +28,11 @@ from ._generated_api_client.api import (
     update_model,
 )
 from ._generated_api_client.models import (
+    ClassificationEvaluationResult,
     CreateRACModelRequest,
     EvaluationRequest,
     ListPredictionsRequest,
+    PrecisionRecallCurve,
 )
 from ._generated_api_client.models import (
     PredictionSortItemItemType0 as PredictionSortColumns,
@@ -33,8 +44,10 @@ from ._generated_api_client.models import (
     RACHeadType,
     RACModelMetadata,
     RACModelUpdate,
+    ROCCurve,
 )
 from ._generated_api_client.models.prediction_request import PredictionRequest
+from ._shared.metrics import calculate_pr_curve, calculate_roc_curve
 from ._utils.common import UNSET, CreateMode, DropMode
 from ._utils.task import wait_for_task
 from .datasource import Datasource
@@ -372,6 +385,7 @@ class ClassificationModel:
                 anomaly_score=prediction.anomaly_score,
                 memoryset=self.memoryset,
                 model=self,
+                logits=prediction.logits,
             )
             for prediction in response
         ]
@@ -444,46 +458,157 @@ class ClassificationModel:
             for prediction in predictions
         ]
 
-    def
+    def _calculate_metrics(
+        self,
+        predictions: list[LabelPrediction],
+        expected_labels: list[int],
+    ) -> ClassificationEvaluationResult:
+
+        targets_array = np.array(expected_labels)
+        predictions_array = np.array([p.label for p in predictions])
+
+        logits_array = np.array([p.logits for p in predictions])
+
+        f1 = float(f1_score(targets_array, predictions_array, average="weighted"))
+        accuracy = float(accuracy_score(targets_array, predictions_array))
+
+        # Only compute ROC AUC and PR AUC for binary classification
+        unique_classes = np.unique(targets_array)
+
+        pr_curve = None
+        roc_curve = None
+
+        if len(unique_classes) == 2:
+            try:
+                precisions, recalls, pr_thresholds = calculate_pr_curve(targets_array, logits_array)
+                pr_auc = float(auc(recalls, precisions))
+
+                pr_curve = PrecisionRecallCurve(
+                    precisions=precisions.tolist(),
+                    recalls=recalls.tolist(),
+                    thresholds=pr_thresholds.tolist(),
+                    auc=pr_auc,
+                )
+
+                fpr, tpr, roc_thresholds = calculate_roc_curve(targets_array, logits_array)
+                roc_auc = float(roc_auc_score(targets_array, logits_array[:, 1]))
+
+                roc_curve = ROCCurve(
+                    false_positive_rates=fpr.tolist(),
+                    true_positive_rates=tpr.tolist(),
+                    thresholds=roc_thresholds.tolist(),
+                    auc=roc_auc,
+                )
+            except ValueError as e:
+                logging.warning(f"Error calculating PR and ROC curves: {e}")
+
+        return ClassificationEvaluationResult(
+            f1_score=f1,
+            accuracy=accuracy,
+            loss=0.0,
+            precision_recall_curve=pr_curve,
+            roc_curve=roc_curve,
+        )
+
+    def _evaluate_datasource(
         self,
         datasource: Datasource,
+        value_column: str,
+        label_column: str,
+        record_predictions: bool,
+        tags: set[str] | None,
+    ) -> dict[str, Any]:
+        response = create_evaluation(
+            self.id,
+            body=EvaluationRequest(
+                datasource_id=datasource.id,
+                datasource_label_column=label_column,
+                datasource_value_column=value_column,
+                memoryset_override_id=self._memoryset_override_id,
+                record_telemetry=record_predictions,
+                telemetry_tags=list(tags) if tags else None,
+            ),
+        )
+        wait_for_task(response.task_id, description="Running evaluation")
+        response = get_evaluation(self.id, UUID(response.task_id))
+        assert response.result is not None
+        return response.result.to_dict()
+
+    def _evaluate_dataset(
+        self,
+        dataset: Dataset,
+        value_column: str,
+        label_column: str,
+        record_predictions: bool,
+        tags: set[str],
+        batch_size: int,
+    ) -> dict[str, Any]:
+        predictions = []
+        expected_labels = []
+
+        for i in range(0, len(dataset), batch_size):
+            batch = dataset[i : i + batch_size]
+            predictions.extend(
+                self.predict(
+                    batch[value_column],
+                    expected_labels=batch[label_column],
+                    tags=tags,
+                    disable_telemetry=(not record_predictions),
+                )
+            )
+            expected_labels.extend(batch[label_column])
+
+        return self._calculate_metrics(predictions, expected_labels).to_dict()
+
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
         value_column: str = "value",
         label_column: str = "label",
         record_predictions: bool = False,
-        tags: set[str]
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
     ) -> dict[str, Any]:
         """
-        Evaluate the classification model on a given datasource
+        Evaluate the classification model on a given dataset or datasource
 
         Params:
-
+            data: Dataset or Datasource to evaluate the model on
             value_column: Name of the column that contains the input values to the model
             label_column: Name of the column containing the expected labels
             record_predictions: Whether to record [`LabelPrediction`][orca_sdk.telemetry.LabelPrediction]s for analysis
             tags: Optional tags to add to the recorded [`LabelPrediction`][orca_sdk.telemetry.LabelPrediction]s
+            batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)
 
         Returns:
             Dictionary with evaluation metrics
 
         Examples:
+            Evaluate using a Datasource:
            >>> model.evaluate(datasource, value_column="text", label_column="airline_sentiment")
            { "f1_score": 0.85, "roc_auc": 0.85, "pr_auc": 0.85, "accuracy": 0.85, "loss": 0.35, ... }
+
+            Evaluate using a Dataset:
+            >>> model.evaluate(dataset, value_column="text", label_column="sentiment")
+            { "f1_score": 0.85, "roc_auc": 0.85, "pr_auc": 0.85, "accuracy": 0.85, "loss": 0.35, ... }
         """
-
-        self.
-
-
-
-
-
-
-
-
-
-
-
-
+        if isinstance(data, Datasource):
+            return self._evaluate_datasource(
+                datasource=data,
+                value_column=value_column,
+                label_column=label_column,
+                record_predictions=record_predictions,
+                tags=tags,
+            )
+        else:
+            return self._evaluate_dataset(
+                dataset=data,
+                value_column=value_column,
+                label_column=label_column,
+                record_predictions=record_predictions,
+                tags=tags,
+                batch_size=batch_size,
+            )
 
     def finetune(self, datasource: Datasource):
         # do not document until implemented
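A hedged usage sketch of the new dual-mode evaluate — `model` is assumed to be an existing ClassificationModel and the rows below are illustrative only:

from datasets import Dataset

eval_dataset = Dataset.from_list(
    [
        {"text": "great service and friendly crew", "label": 1},
        {"text": "they lost my luggage again", "label": 0},
    ]
)

# Dataset input is evaluated client-side: rows are batched through
# model.predict() and scored by _calculate_metrics, while a Datasource
# input still runs the evaluation server-side as before.
result = model.evaluate(eval_dataset, value_column="text", label_column="label", batch_size=100)
print(result["accuracy"], result["f1_score"])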
orca_sdk/classification_model_test.py
CHANGED

@@ -1,5 +1,6 @@
 from uuid import uuid4
 
+import numpy as np
 import pytest
 from datasets.arrow_dataset import Dataset
 
@@ -138,28 +139,47 @@ def test_delete_memoryset_before_model_constraint_violation(hf_dataset):
     LabeledMemoryset.drop(memoryset.id)
 
 
-def
-
-"
-
-
-
-
-
-)
-
-
-
-
-
-
-
-
-
-
-
+def test_evaluate_combined(model):
+    data = [
+        {"text": "chicken noodle soup is the best", "label": 1},
+        {"text": "cats are cute", "label": 0},
+        {"text": "soup is great for the winter", "label": 0},
+        {"text": "i love cats", "label": 1},
+    ]
+
+    eval_datasource = Datasource.from_list("eval_datasource", data)
+    result_datasource = model.evaluate(eval_datasource, value_column="text")
+
+    eval_dataset = Dataset.from_list(data)
+    result_dataset = model.evaluate(eval_dataset, value_column="text")
+
+    for result in [result_datasource, result_dataset]:
+        assert result is not None
+        assert isinstance(result, dict)
+        assert isinstance(result["accuracy"], float)
+        assert isinstance(result["f1_score"], float)
+        assert isinstance(result["loss"], float)
+        assert np.allclose(result["accuracy"], 0.5)
+        assert np.allclose(result["f1_score"], 0.5)
+
+        assert isinstance(result["precision_recall_curve"]["thresholds"], list)
+        assert isinstance(result["precision_recall_curve"]["precisions"], list)
+        assert isinstance(result["precision_recall_curve"]["recalls"], list)
+        assert isinstance(result["roc_curve"]["thresholds"], list)
+        assert isinstance(result["roc_curve"]["false_positive_rates"], list)
+        assert isinstance(result["roc_curve"]["true_positive_rates"], list)
+
+        assert np.allclose(result["roc_curve"]["thresholds"], [0.0, 0.8155114054679871, 0.834095299243927, 1.0])
+        assert np.allclose(result["roc_curve"]["false_positive_rates"], [1.0, 0.5, 0.0, 0.0])
+        assert np.allclose(result["roc_curve"]["true_positive_rates"], [1.0, 0.5, 0.5, 0.0])
+        assert np.allclose(result["roc_curve"]["auc"], 0.625)
+
+        assert np.allclose(
+            result["precision_recall_curve"]["thresholds"], [0.0, 0.0, 0.8155114054679871, 0.834095299243927]
+        )
+        assert np.allclose(result["precision_recall_curve"]["precisions"], [0.5, 0.5, 1.0, 1.0])
+        assert np.allclose(result["precision_recall_curve"]["recalls"], [1.0, 0.5, 0.5, 0.0])
+        assert np.allclose(result["precision_recall_curve"]["auc"], 0.75)
 
 
 def test_evaluate_with_telemetry(model):
@@ -188,6 +208,13 @@ def test_predict(model: ClassificationModel, label_names: list[str]):
     assert predictions[1].label_name == label_names[1]
     assert 0 <= predictions[1].confidence <= 1
 
+    assert predictions[0].logits is not None
+    assert predictions[1].logits is not None
+    assert len(predictions[0].logits) == 2
+    assert len(predictions[1].logits) == 2
+    assert predictions[0].logits[0] > predictions[0].logits[1]
+    assert predictions[1].logits[0] < predictions[1].logits[1]
+
 
 def test_predict_disable_telemetry(model: ClassificationModel, label_names: list[str]):
     predictions = model.predict(["Do you love soup?", "Are cats cute?"], disable_telemetry=True)
orca_sdk/telemetry.py
CHANGED

@@ -135,6 +135,7 @@ class LabelPrediction:
     anomaly_score: float | None
     memoryset: LabeledMemoryset
     model: ClassificationModel
+    logits: list[float] | None
 
     def __init__(
         self,
@@ -147,6 +148,7 @@ class LabelPrediction:
         memoryset: LabeledMemoryset | str,
         model: ClassificationModel | str,
         telemetry: LabelPredictionWithMemoriesAndFeedback | None = None,
+        logits: list[float] | None = None,
     ):
         # for internal use only, do not document
         from .classification_model import ClassificationModel
@@ -159,6 +161,7 @@ class LabelPrediction:
         self.memoryset = LabeledMemoryset.open(memoryset) if isinstance(memoryset, str) else memoryset
         self.model = ClassificationModel.open(model) if isinstance(model, str) else model
         self.__telemetry = telemetry if telemetry else None
+        self.logits = logits
 
     def __repr__(self):
         return (
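Downstream, the logits now surface on telemetry predictions; a minimal sketch of reading them (assumes an existing model, and that records created before this field existed read back as None):

predictions = model.predict(["Do you love soup?", "Are cats cute?"])
for prediction in predictions:
    # logits defaults to None, so guard before indexing
    if prediction.logits is not None:
        print(prediction.label_name, prediction.logits)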
{orca_sdk-0.0.91.dist-info → orca_sdk-0.0.92.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: orca_sdk
-Version: 0.0.91
+Version: 0.0.92
 Summary: SDK for interacting with Orca Services
 License: Apache-2.0
 Author: Orca DB Inc.
@@ -20,7 +20,9 @@ Requires-Dist: pandas (>=2.2.3,<3.0.0)
 Requires-Dist: pyarrow (>=18.0.0,<19.0.0)
 Requires-Dist: python-dateutil (>=2.8.0,<3.0.0)
 Requires-Dist: python-dotenv (>=1.1.0,<2.0.0)
+Requires-Dist: scikit-learn (>=1.6.1,<2.0.0)
 Requires-Dist: torch (>=2.5.1,<3.0.0)
+Requires-Dist: transformers (>=4.51.3,<5.0.0)
 Description-Content-Type: text/markdown
 
 <!--
{orca_sdk-0.0.91.dist-info → orca_sdk-0.0.92.dist-info}/RECORD
CHANGED

@@ -81,7 +81,7 @@ orca_sdk/_generated_api_client/models/__init__.py,sha256=3fjbYdRtS5POw4Ce2FfBdnU
 orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py,sha256=n5xhKNRx_QaAmXgveWSwLRlAjTHkuEGiH0-Vr1H6RsY,4256
 orca_sdk/_generated_api_client/models/api_key_metadata.py,sha256=jQrSe_X5hCgFYh8PwX-X0M6VINVGVhLBlKmv4qN5otA,3789
 orca_sdk/_generated_api_client/models/api_key_metadata_scope_item.py,sha256=umhWzrAt0ZEI9B7pLpnOEsc5Kc-dxeJdzHX7iHyjt4I,182
-orca_sdk/_generated_api_client/models/base_label_prediction_result.py,sha256=
+orca_sdk/_generated_api_client/models/base_label_prediction_result.py,sha256=wJBkJcUdI588tOXimOZ6lBIFGPAaStBrOC84m4-8CIw,3828
 orca_sdk/_generated_api_client/models/base_model.py,sha256=0UY9I_q-b6kOG0LYcw_C192PKRfmejYX9rZa7POCrTc,1563
 orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py,sha256=w5Ni3zFPMTf8XYsH-EZmlokx7vV0vpQiSnbXlktoQBY,6713
 orca_sdk/_generated_api_client/models/classification_evaluation_result.py,sha256=mdSZjv7qy6OreEjwNTV_VpfoeuZHdrnlCG8sr0elhoo,4715
@@ -121,7 +121,7 @@ orca_sdk/_generated_api_client/models/internal_server_error_response.py,sha256=R
 orca_sdk/_generated_api_client/models/label_class_metrics.py,sha256=Q3vWLw8F_IdwAwhunLp0f_l7PvP1gZN1XGCZQRJtbAY,3144
 orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py,sha256=DEwdX5532kHRpsKJe3wOgUWUTZOdeaJV30XvsI8dyOI,6005
 orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py,sha256=bsXtXEf25ch5qAdpnXWSi2qzCkQPZ4xhKcHWMxlgOhQ,2338
-orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py,sha256=
+orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py,sha256=ZxmUTIY02_eZz5EeO6xwehFyS4yYzt_Nw3v0pEbTclQ,9004
 orca_sdk/_generated_api_client/models/labeled_memory.py,sha256=BYG1PqvL3FXKQCuBTg3pLwIgA0Uv8KU5YoxvdR2zZxg,5205
 orca_sdk/_generated_api_client/models/labeled_memory_insert.py,sha256=O3rgrloH3eu9YPzP7X1AKRfq6wxx9Eznl_prpRiMVVM,3768
 orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py,sha256=b9T8i61YUIkNpbJzLwUztfUGNqwRzYJ51RfFpukNS5I,2295
@@ -134,7 +134,7 @@ orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py,s
 orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics.py,sha256=pr2cM9z2F4iAMW6N38xGCYD_fr8R5co70-p0TVRN94w,6307
 orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics_feedback_metrics.py,sha256=qSUl04WibCHV-1yoytEW2TI5in2cf1HCerpOJ8wej3w,2272
 orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics_metadata.py,sha256=wYnU5KuMTlUwIxpbrCe4obx40h_-FJExxoCOMd0-Qik,2366
-orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py,sha256=
+orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py,sha256=1GiEJSXBA6VnUoLYSNk2f5Zxyj8bi_kWvUGOUzHNKyQ,7669
 orca_sdk/_generated_api_client/models/labeled_memoryset_update.py,sha256=xd5obMpcK1zZiU-q4xQUbrWnkaIi176gcjZTBoAMlpQ,3586
 orca_sdk/_generated_api_client/models/list_memories_request.py,sha256=ZPp2FR8-tNMc9eAmErAHEpLf2xrvI_6NtGldSQfAfe4,3091
 orca_sdk/_generated_api_client/models/list_predictions_request.py,sha256=I20mJhJhx-sIeFeK1WNbmaTI07U2lhS840pURBZdYGo,9976
@@ -190,6 +190,9 @@ orca_sdk/_generated_api_client/models/unauthorized_error_response.py,sha256=Sr-p
 orca_sdk/_generated_api_client/models/update_prediction_request.py,sha256=HMPq_K0MlQY7beWn73LEhgjUNcBEjqGC8oFlB9t9em0,3573
 orca_sdk/_generated_api_client/py.typed,sha256=8ZJUsxZiuOy1oJeVhsTWQhTG_6pTVHVXk5hJL79ebTk,25
 orca_sdk/_generated_api_client/types.py,sha256=j7-uA7wWwN1cq0d7ULccN4vDm-1IzgnrxSyVktxvABM,1399
+orca_sdk/_shared/__init__.py,sha256=aXGbM6K8IN5V_7bPeTQZE2CZedV1i1IkynS7swq8D7k,89
+orca_sdk/_shared/metrics.py,sha256=FNGOSfZke3AVCf-j7FdYcq7nmH68RJ0SqD0r_LsLjeY,7565
+orca_sdk/_shared/metrics_test.py,sha256=Udv_JsHbYFYtP2W7iFHgnafOciD03te25qvrX9PUaQ8,5522
 orca_sdk/_utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 orca_sdk/_utils/analysis_ui.py,sha256=2ADUOxbLTcY0sYOpHeReTr13sQ7Yu4kZQ66RDeUuWZw,9216
 orca_sdk/_utils/analysis_ui_style.css,sha256=q_ba_-_KtgztepHg829zLzypaxKayl7ySC1-oYDzV3k,836
@@ -203,8 +206,8 @@ orca_sdk/_utils/prediction_result_ui.py,sha256=dudc21ka2Bqdtr_8wQaMVFxLGvrsZxWUZ
 orca_sdk/_utils/task.py,sha256=WOfFuRCoh6QHVDDYrGeq3Hi6NVihQQZJii0cBCONOWk,2400
 orca_sdk/_utils/value_parser.py,sha256=c3qMABCCDQcIjn9N1orYYnlRwDW9JWdGwW_2TDZPLdI,1286
 orca_sdk/_utils/value_parser_test.py,sha256=OybsiC-Obi32RRi9NIuwrVBRAnlyPMV1xVAaevSrb7M,1079
-orca_sdk/classification_model.py,sha256=
-orca_sdk/classification_model_test.py,sha256=
+orca_sdk/classification_model.py,sha256=j3b277NGeF2kDehwPN7s95KkEyjafvH52ip9t-dRFPk,26439
+orca_sdk/classification_model_test.py,sha256=gbqyjjnwVZB_Z7IHLIVJ7U3jY-xdKag9hIXULDelVqQ,13272
 orca_sdk/conftest.py,sha256=_7O6yVccU-_zteUTCX3j7j7ZfyKNBD7nYL-G8ln6qXY,4661
 orca_sdk/credentials.py,sha256=gq_4w_o-igCCLNR6TY1x4RzMYysKUCsXJvdi6nem-A0,3558
 orca_sdk/credentials_test.py,sha256=ETTyDZ9MEpb_X6yiRcgYGWNKCB2QZ5CLYB_unRGg1b8,1028
@@ -214,8 +217,8 @@ orca_sdk/embedding_model.py,sha256=Hw8NlwzWVK5ts8SF0lHIs7hL38hCTreEiIyoqHY-OFA,1
 orca_sdk/embedding_model_test.py,sha256=j6uGu9ZJSafDV7uFiJiG8SZVGvPQBgxxDcg7i1xbWho,6914
 orca_sdk/memoryset.py,sha256=xvaNn3YwG3fzk3MZhk3LeX_K5yRKP-yRf79bIAUBR-Y,56058
 orca_sdk/memoryset_test.py,sha256=w8-2RXFePg1pqC67uMpHSevjnW4P0GbNpqRjJXAmIa0,15122
-orca_sdk/telemetry.py,sha256=
+orca_sdk/telemetry.py,sha256=U53NI7_D1IpWqdV8NYuUrwvhpX0CF_PJvRRvOiFekno,16393
 orca_sdk/telemetry_test.py,sha256=7JfS0k7r9STMCkasCjXWL3KmbrdmVjVnFeYPCdT8jqQ,5059
-orca_sdk-0.0.91.dist-info/METADATA,sha256=
-orca_sdk-0.0.91.dist-info/WHEEL,sha256=
-orca_sdk-0.0.91.dist-info/RECORD,,
+orca_sdk-0.0.92.dist-info/METADATA,sha256=W1Ee5mRfJQH03lS5XQAztQC1cE4gupmRJaibugGAcrs,3229
+orca_sdk-0.0.92.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+orca_sdk-0.0.92.dist-info/RECORD,,
{orca_sdk-0.0.91.dist-info → orca_sdk-0.0.92.dist-info}/WHEEL
File without changes