orca-sdk 0.0.90__py3-none-any.whl → 0.0.92__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/_generated_api_client/api/__init__.py +12 -0
- orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +12 -12
- orca_sdk/_generated_api_client/api/classification_model/update_model_classification_model_name_or_id_patch.py +183 -0
- orca_sdk/_generated_api_client/api/memoryset/batch_delete_memoryset_batch_delete_memoryset_post.py +168 -0
- orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +183 -0
- orca_sdk/_generated_api_client/models/__init__.py +8 -2
- orca_sdk/_generated_api_client/models/{label_prediction_result.py → base_label_prediction_result.py} +24 -9
- orca_sdk/_generated_api_client/models/delete_memorysets_request.py +70 -0
- orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +8 -8
- orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +8 -0
- orca_sdk/_generated_api_client/models/labeled_memoryset_update.py +113 -0
- orca_sdk/_generated_api_client/models/prediction_request.py +9 -0
- orca_sdk/_generated_api_client/models/rac_model_update.py +82 -0
- orca_sdk/_shared/__init__.py +1 -0
- orca_sdk/_shared/metrics.py +195 -0
- orca_sdk/_shared/metrics_test.py +169 -0
- orca_sdk/_utils/analysis_ui.py +1 -1
- orca_sdk/_utils/analysis_ui_style.css +0 -3
- orca_sdk/classification_model.py +191 -23
- orca_sdk/classification_model_test.py +75 -22
- orca_sdk/conftest.py +13 -1
- orca_sdk/embedding_model.py +2 -0
- orca_sdk/memoryset.py +13 -0
- orca_sdk/memoryset_test.py +27 -6
- orca_sdk/telemetry.py +13 -2
- orca_sdk/telemetry_test.py +6 -0
- {orca_sdk-0.0.90.dist-info → orca_sdk-0.0.92.dist-info}/METADATA +3 -1
- {orca_sdk-0.0.90.dist-info → orca_sdk-0.0.92.dist-info}/RECORD +29 -20
- {orca_sdk-0.0.90.dist-info → orca_sdk-0.0.92.dist-info}/WHEEL +0 -0
orca_sdk/_shared/metrics.py
ADDED
@@ -0,0 +1,195 @@
+"""
+This module contains metrics for usage with the Hugging Face Trainer.
+
+IMPORTANT:
+- This is a shared file between OrcaLib and the Orca SDK.
+- Please ensure that it does not have any dependencies on the OrcaLib code.
+- Make sure to edit this file in orcalib/shared and NOT in orca_sdk, since it will be overwritten there.
+"""
+
+from typing import Literal, Tuple, TypedDict
+
+import numpy as np
+from numpy.typing import NDArray
+from scipy.special import softmax
+from sklearn.metrics import accuracy_score, auc, f1_score, log_loss
+from sklearn.metrics import precision_recall_curve as sklearn_precision_recall_curve
+from sklearn.metrics import roc_auc_score
+from sklearn.metrics import roc_curve as sklearn_roc_curve
+from transformers.trainer_utils import EvalPrediction
+
+
+class ClassificationMetrics(TypedDict):
+    accuracy: float
+    f1_score: float
+    roc_auc: float | None  # receiver operating characteristic area under the curve (if all classes are present)
+    pr_auc: float | None  # precision-recall area under the curve (only for binary classification)
+    log_loss: float  # cross-entropy loss for probabilities
+
+
+def compute_classifier_metrics(eval_pred: EvalPrediction) -> ClassificationMetrics:
+    """
+    Compute standard metrics for classifier with Hugging Face Trainer.
+
+    Args:
+        eval_pred: The predictions containing logits and expected labels as given by the Trainer.
+
+    Returns:
+        A dictionary containing the accuracy, f1 score, and ROC AUC score.
+    """
+    logits, references = eval_pred
+    if isinstance(logits, tuple):
+        logits = logits[0]
+    if not isinstance(logits, np.ndarray):
+        raise ValueError("Logits must be a numpy array")
+    if not isinstance(references, np.ndarray):
+        raise ValueError(
+            "Multiple label columns found, use the `label_names` training argument to specify which one to use"
+        )
+
+    if not (logits > 0).all():
+        # convert logits to probabilities with softmax if necessary
+        probabilities = softmax(logits)
+    elif not np.allclose(logits.sum(-1, keepdims=True), 1.0):
+        # convert logits to probabilities through normalization if necessary
+        probabilities = logits / logits.sum(-1, keepdims=True)
+    else:
+        probabilities = logits
+
+    return classification_scores(references, probabilities)
+
+
+def classification_scores(
+    references: NDArray[np.int64],
+    probabilities: NDArray[np.float32],
+    average: Literal["micro", "macro", "weighted", "binary"] | None = None,
+    multi_class: Literal["ovr", "ovo"] = "ovr",
+) -> ClassificationMetrics:
+    if probabilities.ndim == 1:
+        # convert 1D probabilities (binary) to 2D logits
+        probabilities = np.column_stack([1 - probabilities, probabilities])
+    elif probabilities.ndim == 2:
+        if probabilities.shape[1] < 2:
+            raise ValueError("Use a different metric function for regression tasks")
+    else:
+        raise ValueError("Probabilities must be 1 or 2 dimensional")
+
+    predictions = np.argmax(probabilities, axis=-1)
+
+    num_classes_references = len(set(references))
+    num_classes_predictions = len(set(predictions))
+
+    if average is None:
+        average = "binary" if num_classes_references == 2 else "weighted"
+
+    accuracy = accuracy_score(references, predictions)
+    f1 = f1_score(references, predictions, average=average)
+    loss = log_loss(references, probabilities)
+
+    if num_classes_references == num_classes_predictions:
+        # special case for binary classification: https://github.com/scikit-learn/scikit-learn/issues/20186
+        if num_classes_references == 2:
+            roc_auc = roc_auc_score(references, probabilities[:, 1])
+            precisions, recalls, _ = calculate_pr_curve(references, probabilities[:, 1])
+            pr_auc = auc(recalls, precisions)
+        else:
+            roc_auc = roc_auc_score(references, probabilities, multi_class=multi_class)
+            pr_auc = None
+    else:
+        roc_auc = None
+        pr_auc = None
+
+    return {
+        "accuracy": float(accuracy),
+        "f1_score": float(f1),
+        "roc_auc": float(roc_auc) if roc_auc is not None else None,
+        "pr_auc": float(pr_auc) if pr_auc is not None else None,
+        "log_loss": float(loss),
+    }
+
+
+def calculate_pr_curve(
+    references: NDArray[np.int64],
+    probabilities: NDArray[np.float32],
+    max_length: int = 100,
+) -> Tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.float32]]:
+    if probabilities.ndim == 1:
+        probabilities_slice = probabilities
+    elif probabilities.ndim == 2:
+        probabilities_slice = probabilities[:, 1]
+    else:
+        raise ValueError("Probabilities must be 1 or 2 dimensional")
+
+    if len(probabilities_slice) != len(references):
+        raise ValueError("Probabilities and references must have the same length")
+
+    precisions, recalls, thresholds = sklearn_precision_recall_curve(references, probabilities_slice)
+
+    # Convert all arrays to float32 immediately after getting them
+    precisions = precisions.astype(np.float32)
+    recalls = recalls.astype(np.float32)
+    thresholds = thresholds.astype(np.float32)
+
+    # Concatenate with 0 to include the lowest threshold
+    thresholds = np.concatenate(([0], thresholds))
+
+    # Sort by threshold
+    sorted_indices = np.argsort(thresholds)
+    thresholds = thresholds[sorted_indices]
+    precisions = precisions[sorted_indices]
+    recalls = recalls[sorted_indices]
+
+    if len(precisions) > max_length:
+        new_thresholds = np.linspace(0, 1, max_length, dtype=np.float32)
+        new_precisions = np.interp(new_thresholds, thresholds, precisions)
+        new_recalls = np.interp(new_thresholds, thresholds, recalls)
+        thresholds = new_thresholds
+        precisions = new_precisions
+        recalls = new_recalls
+
+    return precisions.astype(np.float32), recalls.astype(np.float32), thresholds.astype(np.float32)
+
+
+def calculate_roc_curve(
+    references: NDArray[np.int64],
+    probabilities: NDArray[np.float32],
+    max_length: int = 100,
+) -> Tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.float32]]:
+    if probabilities.ndim == 1:
+        probabilities_slice = probabilities
+    elif probabilities.ndim == 2:
+        probabilities_slice = probabilities[:, 1]
+    else:
+        raise ValueError("Probabilities must be 1 or 2 dimensional")
+
+    if len(probabilities_slice) != len(references):
+        raise ValueError("Probabilities and references must have the same length")
+
+    # Convert probabilities to float32 before calling sklearn_roc_curve
+    probabilities_slice = probabilities_slice.astype(np.float32)
+    fpr, tpr, thresholds = sklearn_roc_curve(references, probabilities_slice)
+
+    # Convert all arrays to float32 immediately after getting them
+    fpr = fpr.astype(np.float32)
+    tpr = tpr.astype(np.float32)
+    thresholds = thresholds.astype(np.float32)
+
+    # We set the first threshold to 1.0 instead of inf for reasonable values in interpolation
+    thresholds[0] = 1.0
+
+    # Sort by threshold
+    sorted_indices = np.argsort(thresholds)
+    thresholds = thresholds[sorted_indices]
+    fpr = fpr[sorted_indices]
+    tpr = tpr[sorted_indices]
+
+    if len(fpr) > max_length:
+        new_thresholds = np.linspace(0, 1, max_length, dtype=np.float32)
+        new_fpr = np.interp(new_thresholds, thresholds, fpr)
+        new_tpr = np.interp(new_thresholds, thresholds, tpr)
+        thresholds = new_thresholds
+        fpr = new_fpr
+        tpr = new_tpr
+
+    return fpr.astype(np.float32), tpr.astype(np.float32), thresholds.astype(np.float32)
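
The new shared metrics module is written to plug into the Hugging Face Trainer (its docstring says as much): compute_classifier_metrics accepts the EvalPrediction the Trainer passes to its compute_metrics hook and returns the ClassificationMetrics dict defined above. A minimal wiring sketch, assuming the module is importable as orca_sdk._shared.metrics and that a fine-tuning setup (model and tokenized datasets) already exists; none of that setup is part of this diff:

from transformers import Trainer, TrainingArguments

from orca_sdk._shared.metrics import compute_classifier_metrics

trainer = Trainer(
    model=model,                                  # placeholder: any sequence-classification model
    args=TrainingArguments(output_dir="out", per_device_eval_batch_size=32),
    train_dataset=train_ds,                       # placeholder tokenized datasets
    eval_dataset=eval_ds,
    compute_metrics=compute_classifier_metrics,   # returns accuracy, f1_score, roc_auc, pr_auc, log_loss
)
print(trainer.evaluate())                         # Trainer reports these keys prefixed with "eval_"
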
orca_sdk/_shared/metrics_test.py
ADDED
@@ -0,0 +1,169 @@
+"""
+IMPORTANT:
+- This is a shared file between OrcaLib and the Orca SDK.
+- Please ensure that it does not have any dependencies on the OrcaLib code.
+- Make sure to edit this file in orcalib/shared and NOT in orca_sdk, since it will be overwritten there.
+"""
+
+from typing import Literal
+
+import numpy as np
+import pytest
+
+from .metrics import (
+    EvalPrediction,
+    calculate_pr_curve,
+    calculate_roc_curve,
+    classification_scores,
+    compute_classifier_metrics,
+    softmax,
+)
+
+
+def test_binary_metrics():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([0.1, 0.9, 0.8, 0.3, 0.2])
+
+    metrics = classification_scores(y_true, y_score)
+
+    assert metrics["accuracy"] == 0.8
+    assert metrics["f1_score"] == 0.8
+    assert metrics["roc_auc"] is not None
+    assert metrics["roc_auc"] > 0.8
+    assert metrics["roc_auc"] < 1.0
+    assert metrics["pr_auc"] is not None
+    assert metrics["pr_auc"] > 0.8
+    assert metrics["pr_auc"] < 1.0
+    assert metrics["log_loss"] is not None
+    assert metrics["log_loss"] > 0.0
+
+
+def test_multiclass_metrics_with_2_classes():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([[0.9, 0.1], [0.1, 0.9], [0.2, 0.8], [0.7, 0.3], [0.8, 0.2]])
+
+    metrics = classification_scores(y_true, y_score)
+
+    assert metrics["accuracy"] == 0.8
+    assert metrics["f1_score"] == 0.8
+    assert metrics["roc_auc"] is not None
+    assert metrics["roc_auc"] > 0.8
+    assert metrics["roc_auc"] < 1.0
+    assert metrics["pr_auc"] is not None
+    assert metrics["pr_auc"] > 0.8
+    assert metrics["pr_auc"] < 1.0
+    assert metrics["log_loss"] is not None
+    assert metrics["log_loss"] > 0.0
+
+
+@pytest.mark.parametrize(
+    "average, multiclass",
+    [("micro", "ovr"), ("macro", "ovr"), ("weighted", "ovr"), ("micro", "ovo"), ("macro", "ovo"), ("weighted", "ovo")],
+)
+def test_multiclass_metrics_with_3_classes(
+    average: Literal["micro", "macro", "weighted"], multiclass: Literal["ovr", "ovo"]
+):
+    y_true = np.array([0, 1, 1, 0, 2])
+    y_score = np.array([[0.9, 0.1, 0.0], [0.1, 0.9, 0.0], [0.2, 0.8, 0.0], [0.7, 0.3, 0.0], [0.0, 0.0, 1.0]])
+
+    metrics = classification_scores(y_true, y_score, average=average, multi_class=multiclass)
+
+    assert metrics["accuracy"] == 1.0
+    assert metrics["f1_score"] == 1.0
+    assert metrics["roc_auc"] is not None
+    assert metrics["roc_auc"] > 0.8
+    assert metrics["pr_auc"] is None
+    assert metrics["log_loss"] is not None
+    assert metrics["log_loss"] > 0.0
+
+
+def test_does_not_modify_logits_unless_necessary():
+    logits = np.array([[0.1, 0.9], [0.2, 0.8], [0.7, 0.3], [0.8, 0.2]])
+    references = np.array([0, 1, 0, 1])
+    metrics = compute_classifier_metrics(EvalPrediction(logits, references))
+    assert metrics["log_loss"] == classification_scores(references, logits)["log_loss"]
+
+
+def test_normalizes_logits_if_necessary():
+    logits = np.array([[1.2, 3.9], [1.2, 5.8], [1.2, 2.7], [1.2, 1.3]])
+    references = np.array([0, 1, 0, 1])
+    metrics = compute_classifier_metrics(EvalPrediction(logits, references))
+    assert (
+        metrics["log_loss"] == classification_scores(references, logits / logits.sum(axis=1, keepdims=True))["log_loss"]
+    )
+
+
+def test_softmaxes_logits_if_necessary():
+    logits = np.array([[-1.2, 3.9], [1.2, -5.8], [1.2, 2.7], [1.2, 1.3]])
+    references = np.array([0, 1, 0, 1])
+    metrics = compute_classifier_metrics(EvalPrediction(logits, references))
+    assert metrics["log_loss"] == classification_scores(references, softmax(logits))["log_loss"]
+
+
+def test_precision_recall_curve():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
+
+    precision, recall, thresholds = calculate_pr_curve(y_true, y_score)
+    assert precision is not None
+    assert recall is not None
+    assert thresholds is not None
+
+    assert len(precision) == len(recall) == len(thresholds) == 6
+    assert precision[0] == 0.6
+    assert recall[0] == 1.0
+    assert precision[-1] == 1.0
+    assert recall[-1] == 0.0
+
+    # test that thresholds are sorted
+    assert np.all(np.diff(thresholds) >= 0)
+
+
+def test_roc_curve():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
+
+    fpr, tpr, thresholds = calculate_roc_curve(y_true, y_score)
+    assert fpr is not None
+    assert tpr is not None
+    assert thresholds is not None
+
+    assert len(fpr) == len(tpr) == len(thresholds) == 6
+    assert fpr[0] == 1.0
+    assert tpr[0] == 1.0
+    assert fpr[-1] == 0.0
+    assert tpr[-1] == 0.0
+
+    # test that thresholds are sorted
+    assert np.all(np.diff(thresholds) >= 0)
+
+
+def test_precision_recall_curve_max_length():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
+
+    precision, recall, thresholds = calculate_pr_curve(y_true, y_score, max_length=5)
+    assert len(precision) == len(recall) == len(thresholds) == 5
+
+    assert precision[0] == 0.6
+    assert recall[0] == 1.0
+    assert precision[-1] == 1.0
+    assert recall[-1] == 0.0
+
+    # test that thresholds are sorted
+    assert np.all(np.diff(thresholds) >= 0)
+
+
+def test_roc_curve_max_length():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
+
+    fpr, tpr, thresholds = calculate_roc_curve(y_true, y_score, max_length=5)
+    assert len(fpr) == len(tpr) == len(thresholds) == 5
+    assert fpr[0] == 1.0
+    assert tpr[0] == 1.0
+    assert fpr[-1] == 0.0
+    assert tpr[-1] == 0.0

+    # test that thresholds are sorted
+    assert np.all(np.diff(thresholds) >= 0)
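
Beyond the Trainer hook, the curve helpers can be called directly; they return arrays sorted by threshold and, when the raw curve has more than max_length points, linearly interpolated down to max_length. A quick sketch on the same toy data the tests use, with the import path assumed from the file location above:

import numpy as np

from orca_sdk._shared.metrics import calculate_pr_curve, calculate_roc_curve, classification_scores

y_true = np.array([0, 1, 1, 0, 1])
y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])

# Full metric dict: accuracy, f1_score, roc_auc, pr_auc, log_loss
print(classification_scores(y_true, y_score))

# Threshold-sorted curves, capped at max_length points via interpolation when longer
precisions, recalls, pr_thresholds = calculate_pr_curve(y_true, y_score, max_length=50)
fpr, tpr, roc_thresholds = calculate_roc_curve(y_true, y_score, max_length=50)
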
orca_sdk/_utils/analysis_ui.py
CHANGED
@@ -152,7 +152,7 @@ def display_suggested_memory_relabels(memoryset: LabeledMemoryset):
         predicted_label_name = label_names[predicted_label]
         predicted_label_confidence = mem.metrics.get("neighbor_predicted_label_confidence", 0)
 
-        with gr.Row(equal_height=True, variant="panel"
+        with gr.Row(equal_height=True, variant="panel"):
             with gr.Column(scale=9):
                 assert isinstance(mem.value, str)
                 gr.Markdown(mem.value, label="Value", height=50)
orca_sdk/classification_model.py
CHANGED
@@ -6,6 +6,15 @@ from datetime import datetime
 from typing import Any, Generator, Iterable, Literal, cast, overload
 from uuid import UUID
 
+import numpy as np
+from datasets import Dataset
+from sklearn.metrics import (
+    accuracy_score,
+    auc,
+    f1_score,
+    roc_auc_score,
+)
+
 from ._generated_api_client.api import (
     create_evaluation,
     create_model,
@@ -16,11 +25,14 @@ from ._generated_api_client.api import (
     list_predictions,
     predict_gpu,
     record_prediction_feedback,
+    update_model,
 )
 from ._generated_api_client.models import (
+    ClassificationEvaluationResult,
     CreateRACModelRequest,
     EvaluationRequest,
     ListPredictionsRequest,
+    PrecisionRecallCurve,
 )
 from ._generated_api_client.models import (
     PredictionSortItemItemType0 as PredictionSortColumns,
@@ -31,9 +43,12 @@ from ._generated_api_client.models import (
 from ._generated_api_client.models import (
     RACHeadType,
     RACModelMetadata,
+    RACModelUpdate,
+    ROCCurve,
 )
 from ._generated_api_client.models.prediction_request import PredictionRequest
-from .
+from ._shared.metrics import calculate_pr_curve, calculate_roc_curve
+from ._utils.common import UNSET, CreateMode, DropMode
 from ._utils.task import wait_for_task
 from .datasource import Datasource
 from .memoryset import LabeledMemoryset
@@ -270,18 +285,53 @@ class ClassificationModel:
             if if_not_exists == "error":
                 raise
 
+    def refresh(self):
+        """Refresh the model data from the OrcaCloud"""
+        self.__dict__.update(ClassificationModel.open(self.name).__dict__)
+
+    def update_metadata(self, *, description: str | None = UNSET) -> None:
+        """
+        Update editable classification model metadata properties.
+
+        Params:
+            description: Value to set for the description, defaults to `[UNSET]` if not provided.
+
+        Examples:
+            Update the description:
+            >>> model.update(description="New description")
+
+            Remove description:
+            >>> model.update(description=None)
+        """
+        update_model(self.id, body=RACModelUpdate(description=description))
+        self.refresh()
+
     @overload
     def predict(
-        self,
+        self,
+        value: list[str],
+        expected_labels: list[int] | None = None,
+        tags: set[str] = set(),
+        disable_telemetry: bool = False,
     ) -> list[LabelPrediction]:
         pass
 
     @overload
-    def predict(
+    def predict(
+        self,
+        value: str,
+        expected_labels: int | None = None,
+        tags: set[str] = set(),
+        disable_telemetry: bool = False,
+    ) -> LabelPrediction:
         pass
 
     def predict(
-        self,
+        self,
+        value: list[str] | str,
+        expected_labels: list[int] | int | None = None,
+        tags: set[str] = set(),
+        disable_telemetry: bool = False,
     ) -> list[LabelPrediction] | LabelPrediction:
         """
         Predict label(s) for the given input value(s) grounded in similar memories
@@ -290,6 +340,7 @@ class ClassificationModel:
             value: Value(s) to get predict the labels of
             expected_labels: Expected label(s) for the given input to record for model evaluation
             tags: Tags to add to the prediction(s)
+            disable_telemetry: Whether to disable telemetry for the prediction(s)
 
         Returns:
             Label prediction or list of label predictions
@@ -318,8 +369,13 @@ class ClassificationModel:
                     else [expected_labels] if expected_labels is not None else None
                 ),
                 tags=list(tags),
+                disable_telemetry=disable_telemetry,
            ),
        )
+
+        if not disable_telemetry and any(p.prediction_id is None for p in response):
+            raise RuntimeError("Failed to save prediction to database.")
+
         predictions = [
             LabelPrediction(
                 prediction_id=prediction.prediction_id,
@@ -329,6 +385,7 @@ class ClassificationModel:
                 anomaly_score=prediction.anomaly_score,
                 memoryset=self.memoryset,
                 model=self,
+                logits=prediction.logits,
             )
             for prediction in response
         ]
@@ -401,46 +458,157 @@ class ClassificationModel:
             for prediction in predictions
         ]
 
-    def
+    def _calculate_metrics(
+        self,
+        predictions: list[LabelPrediction],
+        expected_labels: list[int],
+    ) -> ClassificationEvaluationResult:
+
+        targets_array = np.array(expected_labels)
+        predictions_array = np.array([p.label for p in predictions])
+
+        logits_array = np.array([p.logits for p in predictions])
+
+        f1 = float(f1_score(targets_array, predictions_array, average="weighted"))
+        accuracy = float(accuracy_score(targets_array, predictions_array))
+
+        # Only compute ROC AUC and PR AUC for binary classification
+        unique_classes = np.unique(targets_array)
+
+        pr_curve = None
+        roc_curve = None
+
+        if len(unique_classes) == 2:
+            try:
+                precisions, recalls, pr_thresholds = calculate_pr_curve(targets_array, logits_array)
+                pr_auc = float(auc(recalls, precisions))
+
+                pr_curve = PrecisionRecallCurve(
+                    precisions=precisions.tolist(),
+                    recalls=recalls.tolist(),
+                    thresholds=pr_thresholds.tolist(),
+                    auc=pr_auc,
+                )
+
+                fpr, tpr, roc_thresholds = calculate_roc_curve(targets_array, logits_array)
+                roc_auc = float(roc_auc_score(targets_array, logits_array[:, 1]))
+
+                roc_curve = ROCCurve(
+                    false_positive_rates=fpr.tolist(),
+                    true_positive_rates=tpr.tolist(),
+                    thresholds=roc_thresholds.tolist(),
+                    auc=roc_auc,
+                )
+            except ValueError as e:
+                logging.warning(f"Error calculating PR and ROC curves: {e}")
+
+        return ClassificationEvaluationResult(
+            f1_score=f1,
+            accuracy=accuracy,
+            loss=0.0,
+            precision_recall_curve=pr_curve,
+            roc_curve=roc_curve,
+        )
+
+    def _evaluate_datasource(
         self,
         datasource: Datasource,
+        value_column: str,
+        label_column: str,
+        record_predictions: bool,
+        tags: set[str] | None,
+    ) -> dict[str, Any]:
+        response = create_evaluation(
+            self.id,
+            body=EvaluationRequest(
+                datasource_id=datasource.id,
+                datasource_label_column=label_column,
+                datasource_value_column=value_column,
+                memoryset_override_id=self._memoryset_override_id,
+                record_telemetry=record_predictions,
+                telemetry_tags=list(tags) if tags else None,
+            ),
+        )
+        wait_for_task(response.task_id, description="Running evaluation")
+        response = get_evaluation(self.id, UUID(response.task_id))
+        assert response.result is not None
+        return response.result.to_dict()
+
+    def _evaluate_dataset(
+        self,
+        dataset: Dataset,
+        value_column: str,
+        label_column: str,
+        record_predictions: bool,
+        tags: set[str],
+        batch_size: int,
+    ) -> dict[str, Any]:
+        predictions = []
+        expected_labels = []
+
+        for i in range(0, len(dataset), batch_size):
+            batch = dataset[i : i + batch_size]
+            predictions.extend(
+                self.predict(
+                    batch[value_column],
+                    expected_labels=batch[label_column],
+                    tags=tags,
+                    disable_telemetry=(not record_predictions),
+                )
+            )
+            expected_labels.extend(batch[label_column])
+
+        return self._calculate_metrics(predictions, expected_labels).to_dict()
+
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
         value_column: str = "value",
         label_column: str = "label",
         record_predictions: bool = False,
-        tags: set[str]
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
     ) -> dict[str, Any]:
         """
-        Evaluate the classification model on a given datasource
+        Evaluate the classification model on a given dataset or datasource
 
         Params:
-
+            data: Dataset or Datasource to evaluate the model on
             value_column: Name of the column that contains the input values to the model
             label_column: Name of the column containing the expected labels
             record_predictions: Whether to record [`LabelPrediction`][orca_sdk.telemetry.LabelPrediction]s for analysis
             tags: Optional tags to add to the recorded [`LabelPrediction`][orca_sdk.telemetry.LabelPrediction]s
+            batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)
 
         Returns:
             Dictionary with evaluation metrics
 
         Examples:
+            Evaluate using a Datasource:
             >>> model.evaluate(datasource, value_column="text", label_column="airline_sentiment")
             { "f1_score": 0.85, "roc_auc": 0.85, "pr_auc": 0.85, "accuracy": 0.85, "loss": 0.35, ... }
+
+            Evaluate using a Dataset:
+            >>> model.evaluate(dataset, value_column="text", label_column="sentiment")
+            { "f1_score": 0.85, "roc_auc": 0.85, "pr_auc": 0.85, "accuracy": 0.85, "loss": 0.35, ... }
         """
-
-        self.
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if isinstance(data, Datasource):
+            return self._evaluate_datasource(
+                datasource=data,
+                value_column=value_column,
+                label_column=label_column,
+                record_predictions=record_predictions,
+                tags=tags,
+            )
+        else:
+            return self._evaluate_dataset(
+                dataset=data,
+                value_column=value_column,
+                label_column=label_column,
+                record_predictions=record_predictions,
+                tags=tags,
+                batch_size=batch_size,
+            )
 
     def finetune(self, datasource: Datasource):
         # do not document until implemented