orca-sdk 0.0.90__py3-none-any.whl → 0.0.92__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. orca_sdk/_generated_api_client/api/__init__.py +12 -0
  2. orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +12 -12
  3. orca_sdk/_generated_api_client/api/classification_model/update_model_classification_model_name_or_id_patch.py +183 -0
  4. orca_sdk/_generated_api_client/api/memoryset/batch_delete_memoryset_batch_delete_memoryset_post.py +168 -0
  5. orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +183 -0
  6. orca_sdk/_generated_api_client/models/__init__.py +8 -2
  7. orca_sdk/_generated_api_client/models/{label_prediction_result.py → base_label_prediction_result.py} +24 -9
  8. orca_sdk/_generated_api_client/models/delete_memorysets_request.py +70 -0
  9. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +8 -8
  10. orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +8 -0
  11. orca_sdk/_generated_api_client/models/labeled_memoryset_update.py +113 -0
  12. orca_sdk/_generated_api_client/models/prediction_request.py +9 -0
  13. orca_sdk/_generated_api_client/models/rac_model_update.py +82 -0
  14. orca_sdk/_shared/__init__.py +1 -0
  15. orca_sdk/_shared/metrics.py +195 -0
  16. orca_sdk/_shared/metrics_test.py +169 -0
  17. orca_sdk/_utils/analysis_ui.py +1 -1
  18. orca_sdk/_utils/analysis_ui_style.css +0 -3
  19. orca_sdk/classification_model.py +191 -23
  20. orca_sdk/classification_model_test.py +75 -22
  21. orca_sdk/conftest.py +13 -1
  22. orca_sdk/embedding_model.py +2 -0
  23. orca_sdk/memoryset.py +13 -0
  24. orca_sdk/memoryset_test.py +27 -6
  25. orca_sdk/telemetry.py +13 -2
  26. orca_sdk/telemetry_test.py +6 -0
  27. {orca_sdk-0.0.90.dist-info → orca_sdk-0.0.92.dist-info}/METADATA +3 -1
  28. {orca_sdk-0.0.90.dist-info → orca_sdk-0.0.92.dist-info}/RECORD +29 -20
  29. {orca_sdk-0.0.90.dist-info → orca_sdk-0.0.92.dist-info}/WHEEL +0 -0
orca_sdk/_shared/metrics.py
@@ -0,0 +1,195 @@
+"""
+This module contains metrics for usage with the Hugging Face Trainer.
+
+IMPORTANT:
+- This is a shared file between OrcaLib and the Orca SDK.
+- Please ensure that it does not have any dependencies on the OrcaLib code.
+- Make sure to edit this file in orcalib/shared and NOT in orca_sdk, since it will be overwritten there.
+
+"""
+
+from typing import Literal, Tuple, TypedDict
+
+import numpy as np
+from numpy.typing import NDArray
+from scipy.special import softmax
+from sklearn.metrics import accuracy_score, auc, f1_score, log_loss
+from sklearn.metrics import precision_recall_curve as sklearn_precision_recall_curve
+from sklearn.metrics import roc_auc_score
+from sklearn.metrics import roc_curve as sklearn_roc_curve
+from transformers.trainer_utils import EvalPrediction
+
+
+class ClassificationMetrics(TypedDict):
+    accuracy: float
+    f1_score: float
+    roc_auc: float | None  # receiver operating characteristic area under the curve (if all classes are present)
+    pr_auc: float | None  # precision-recall area under the curve (only for binary classification)
+    log_loss: float  # cross-entropy loss for probabilities
+
+
+def compute_classifier_metrics(eval_pred: EvalPrediction) -> ClassificationMetrics:
+    """
+    Compute standard metrics for classifier with Hugging Face Trainer.
+
+    Args:
+        eval_pred: The predictions containing logits and expected labels as given by the Trainer.
+
+    Returns:
+        A dictionary containing the accuracy, f1 score, and ROC AUC score.
+    """
+    logits, references = eval_pred
+    if isinstance(logits, tuple):
+        logits = logits[0]
+    if not isinstance(logits, np.ndarray):
+        raise ValueError("Logits must be a numpy array")
+    if not isinstance(references, np.ndarray):
+        raise ValueError(
+            "Multiple label columns found, use the `label_names` training argument to specify which one to use"
+        )
+
+    if not (logits > 0).all():
+        # convert logits to probabilities with softmax if necessary
+        probabilities = softmax(logits)
+    elif not np.allclose(logits.sum(-1, keepdims=True), 1.0):
+        # convert logits to probabilities through normalization if necessary
+        probabilities = logits / logits.sum(-1, keepdims=True)
+    else:
+        probabilities = logits
+
+    return classification_scores(references, probabilities)
+
+
+def classification_scores(
+    references: NDArray[np.int64],
+    probabilities: NDArray[np.float32],
+    average: Literal["micro", "macro", "weighted", "binary"] | None = None,
+    multi_class: Literal["ovr", "ovo"] = "ovr",
+) -> ClassificationMetrics:
+    if probabilities.ndim == 1:
+        # convert 1D probabilities (binary) to 2D logits
+        probabilities = np.column_stack([1 - probabilities, probabilities])
+    elif probabilities.ndim == 2:
+        if probabilities.shape[1] < 2:
+            raise ValueError("Use a different metric function for regression tasks")
+    else:
+        raise ValueError("Probabilities must be 1 or 2 dimensional")
+
+    predictions = np.argmax(probabilities, axis=-1)
+
+    num_classes_references = len(set(references))
+    num_classes_predictions = len(set(predictions))
+
+    if average is None:
+        average = "binary" if num_classes_references == 2 else "weighted"
+
+    accuracy = accuracy_score(references, predictions)
+    f1 = f1_score(references, predictions, average=average)
+    loss = log_loss(references, probabilities)
+
+    if num_classes_references == num_classes_predictions:
+        # special case for binary classification: https://github.com/scikit-learn/scikit-learn/issues/20186
+        if num_classes_references == 2:
+            roc_auc = roc_auc_score(references, probabilities[:, 1])
+            precisions, recalls, _ = calculate_pr_curve(references, probabilities[:, 1])
+            pr_auc = auc(recalls, precisions)
+        else:
+            roc_auc = roc_auc_score(references, probabilities, multi_class=multi_class)
+            pr_auc = None
+    else:
+        roc_auc = None
+        pr_auc = None
+
+    return {
+        "accuracy": float(accuracy),
+        "f1_score": float(f1),
+        "roc_auc": float(roc_auc) if roc_auc is not None else None,
+        "pr_auc": float(pr_auc) if pr_auc is not None else None,
+        "log_loss": float(loss),
+    }
+
+
+def calculate_pr_curve(
+    references: NDArray[np.int64],
+    probabilities: NDArray[np.float32],
+    max_length: int = 100,
+) -> Tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.float32]]:
+    if probabilities.ndim == 1:
+        probabilities_slice = probabilities
+    elif probabilities.ndim == 2:
+        probabilities_slice = probabilities[:, 1]
+    else:
+        raise ValueError("Probabilities must be 1 or 2 dimensional")
+
+    if len(probabilities_slice) != len(references):
+        raise ValueError("Probabilities and references must have the same length")
+
+    precisions, recalls, thresholds = sklearn_precision_recall_curve(references, probabilities_slice)
+
+    # Convert all arrays to float32 immediately after getting them
+    precisions = precisions.astype(np.float32)
+    recalls = recalls.astype(np.float32)
+    thresholds = thresholds.astype(np.float32)
+
+    # Concatenate with 0 to include the lowest threshold
+    thresholds = np.concatenate(([0], thresholds))
+
+    # Sort by threshold
+    sorted_indices = np.argsort(thresholds)
+    thresholds = thresholds[sorted_indices]
+    precisions = precisions[sorted_indices]
+    recalls = recalls[sorted_indices]
+
+    if len(precisions) > max_length:
+        new_thresholds = np.linspace(0, 1, max_length, dtype=np.float32)
+        new_precisions = np.interp(new_thresholds, thresholds, precisions)
+        new_recalls = np.interp(new_thresholds, thresholds, recalls)
+        thresholds = new_thresholds
+        precisions = new_precisions
+        recalls = new_recalls
+
+    return precisions.astype(np.float32), recalls.astype(np.float32), thresholds.astype(np.float32)
+
+
+def calculate_roc_curve(
+    references: NDArray[np.int64],
+    probabilities: NDArray[np.float32],
+    max_length: int = 100,
+) -> Tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.float32]]:
+    if probabilities.ndim == 1:
+        probabilities_slice = probabilities
+    elif probabilities.ndim == 2:
+        probabilities_slice = probabilities[:, 1]
+    else:
+        raise ValueError("Probabilities must be 1 or 2 dimensional")
+
+    if len(probabilities_slice) != len(references):
+        raise ValueError("Probabilities and references must have the same length")
+
+    # Convert probabilities to float32 before calling sklearn_roc_curve
+    probabilities_slice = probabilities_slice.astype(np.float32)
+    fpr, tpr, thresholds = sklearn_roc_curve(references, probabilities_slice)
+
+    # Convert all arrays to float32 immediately after getting them
+    fpr = fpr.astype(np.float32)
+    tpr = tpr.astype(np.float32)
+    thresholds = thresholds.astype(np.float32)
+
+    # We set the first threshold to 1.0 instead of inf for reasonable values in interpolation
+    thresholds[0] = 1.0
+
+    # Sort by threshold
+    sorted_indices = np.argsort(thresholds)
+    thresholds = thresholds[sorted_indices]
+    fpr = fpr[sorted_indices]
+    tpr = tpr[sorted_indices]
+
+    if len(fpr) > max_length:
+        new_thresholds = np.linspace(0, 1, max_length, dtype=np.float32)
+        new_fpr = np.interp(new_thresholds, thresholds, fpr)
+        new_tpr = np.interp(new_thresholds, thresholds, tpr)
+        thresholds = new_thresholds
+        fpr = new_fpr
+        tpr = new_tpr
+
+    return fpr.astype(np.float32), tpr.astype(np.float32), thresholds.astype(np.float32)
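The new shared metrics module is designed to plug into the Hugging Face `Trainer` via its `compute_metrics` hook, and `classification_scores` can also be called directly on labels and probabilities. A minimal sketch of both uses follows; importing from the private `_shared` package and the `my_model`/`train_ds`/`eval_ds` names are assumptions for illustration, not part of this diff:

    import numpy as np
    from orca_sdk._shared.metrics import classification_scores, compute_classifier_metrics

    # Direct use on labels and positive-class probabilities (binary case)
    y_true = np.array([0, 1, 1, 0, 1])
    y_prob = np.array([0.2, 0.9, 0.7, 0.4, 0.6], dtype=np.float32)
    print(classification_scores(y_true, y_prob))  # accuracy, f1_score, roc_auc, pr_auc, log_loss

    # As a Trainer hook; model and datasets below are placeholders
    # from transformers import Trainer, TrainingArguments
    # trainer = Trainer(
    #     model=my_model,
    #     args=TrainingArguments(output_dir="out"),
    #     train_dataset=train_ds,
    #     eval_dataset=eval_ds,
    #     compute_metrics=compute_classifier_metrics,
    # )
    # trainer.evaluate()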
orca_sdk/_shared/metrics_test.py
@@ -0,0 +1,169 @@
+"""
+IMPORTANT:
+- This is a shared file between OrcaLib and the Orca SDK.
+- Please ensure that it does not have any dependencies on the OrcaLib code.
+- Make sure to edit this file in orcalib/shared and NOT in orca_sdk, since it will be overwritten there.
+"""
+
+from typing import Literal
+
+import numpy as np
+import pytest
+
+from .metrics import (
+    EvalPrediction,
+    calculate_pr_curve,
+    calculate_roc_curve,
+    classification_scores,
+    compute_classifier_metrics,
+    softmax,
+)
+
+
+def test_binary_metrics():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([0.1, 0.9, 0.8, 0.3, 0.2])
+
+    metrics = classification_scores(y_true, y_score)
+
+    assert metrics["accuracy"] == 0.8
+    assert metrics["f1_score"] == 0.8
+    assert metrics["roc_auc"] is not None
+    assert metrics["roc_auc"] > 0.8
+    assert metrics["roc_auc"] < 1.0
+    assert metrics["pr_auc"] is not None
+    assert metrics["pr_auc"] > 0.8
+    assert metrics["pr_auc"] < 1.0
+    assert metrics["log_loss"] is not None
+    assert metrics["log_loss"] > 0.0
+
+
+def test_multiclass_metrics_with_2_classes():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([[0.9, 0.1], [0.1, 0.9], [0.2, 0.8], [0.7, 0.3], [0.8, 0.2]])
+
+    metrics = classification_scores(y_true, y_score)
+
+    assert metrics["accuracy"] == 0.8
+    assert metrics["f1_score"] == 0.8
+    assert metrics["roc_auc"] is not None
+    assert metrics["roc_auc"] > 0.8
+    assert metrics["roc_auc"] < 1.0
+    assert metrics["pr_auc"] is not None
+    assert metrics["pr_auc"] > 0.8
+    assert metrics["pr_auc"] < 1.0
+    assert metrics["log_loss"] is not None
+    assert metrics["log_loss"] > 0.0
+
+
+@pytest.mark.parametrize(
+    "average, multiclass",
+    [("micro", "ovr"), ("macro", "ovr"), ("weighted", "ovr"), ("micro", "ovo"), ("macro", "ovo"), ("weighted", "ovo")],
+)
+def test_multiclass_metrics_with_3_classes(
+    average: Literal["micro", "macro", "weighted"], multiclass: Literal["ovr", "ovo"]
+):
+    y_true = np.array([0, 1, 1, 0, 2])
+    y_score = np.array([[0.9, 0.1, 0.0], [0.1, 0.9, 0.0], [0.2, 0.8, 0.0], [0.7, 0.3, 0.0], [0.0, 0.0, 1.0]])
+
+    metrics = classification_scores(y_true, y_score, average=average, multi_class=multiclass)
+
+    assert metrics["accuracy"] == 1.0
+    assert metrics["f1_score"] == 1.0
+    assert metrics["roc_auc"] is not None
+    assert metrics["roc_auc"] > 0.8
+    assert metrics["pr_auc"] is None
+    assert metrics["log_loss"] is not None
+    assert metrics["log_loss"] > 0.0
+
+
+def test_does_not_modify_logits_unless_necessary():
+    logits = np.array([[0.1, 0.9], [0.2, 0.8], [0.7, 0.3], [0.8, 0.2]])
+    references = np.array([0, 1, 0, 1])
+    metrics = compute_classifier_metrics(EvalPrediction(logits, references))
+    assert metrics["log_loss"] == classification_scores(references, logits)["log_loss"]
+
+
+def test_normalizes_logits_if_necessary():
+    logits = np.array([[1.2, 3.9], [1.2, 5.8], [1.2, 2.7], [1.2, 1.3]])
+    references = np.array([0, 1, 0, 1])
+    metrics = compute_classifier_metrics(EvalPrediction(logits, references))
+    assert (
+        metrics["log_loss"] == classification_scores(references, logits / logits.sum(axis=1, keepdims=True))["log_loss"]
+    )
+
+
+def test_softmaxes_logits_if_necessary():
+    logits = np.array([[-1.2, 3.9], [1.2, -5.8], [1.2, 2.7], [1.2, 1.3]])
+    references = np.array([0, 1, 0, 1])
+    metrics = compute_classifier_metrics(EvalPrediction(logits, references))
+    assert metrics["log_loss"] == classification_scores(references, softmax(logits))["log_loss"]
+
+
+def test_precision_recall_curve():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
+
+    precision, recall, thresholds = calculate_pr_curve(y_true, y_score)
+    assert precision is not None
+    assert recall is not None
+    assert thresholds is not None
+
+    assert len(precision) == len(recall) == len(thresholds) == 6
+    assert precision[0] == 0.6
+    assert recall[0] == 1.0
+    assert precision[-1] == 1.0
+    assert recall[-1] == 0.0
+
+    # test that thresholds are sorted
+    assert np.all(np.diff(thresholds) >= 0)
+
+
+def test_roc_curve():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
+
+    fpr, tpr, thresholds = calculate_roc_curve(y_true, y_score)
+    assert fpr is not None
+    assert tpr is not None
+    assert thresholds is not None
+
+    assert len(fpr) == len(tpr) == len(thresholds) == 6
+    assert fpr[0] == 1.0
+    assert tpr[0] == 1.0
+    assert fpr[-1] == 0.0
+    assert tpr[-1] == 0.0
+
+    # test that thresholds are sorted
+    assert np.all(np.diff(thresholds) >= 0)
+
+
+def test_precision_recall_curve_max_length():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
+
+    precision, recall, thresholds = calculate_pr_curve(y_true, y_score, max_length=5)
+    assert len(precision) == len(recall) == len(thresholds) == 5
+
+    assert precision[0] == 0.6
+    assert recall[0] == 1.0
+    assert precision[-1] == 1.0
+    assert recall[-1] == 0.0
+
+    # test that thresholds are sorted
+    assert np.all(np.diff(thresholds) >= 0)
+
+
+def test_roc_curve_max_length():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
+
+    fpr, tpr, thresholds = calculate_roc_curve(y_true, y_score, max_length=5)
+    assert len(fpr) == len(tpr) == len(thresholds) == 5
+    assert fpr[0] == 1.0
+    assert tpr[0] == 1.0
+    assert fpr[-1] == 0.0
+    assert tpr[-1] == 0.0
+
+    # test that thresholds are sorted
+    assert np.all(np.diff(thresholds) >= 0)
orca_sdk/_utils/analysis_ui.py
@@ -152,7 +152,7 @@ def display_suggested_memory_relabels(memoryset: LabeledMemoryset):
         predicted_label_name = label_names[predicted_label]
         predicted_label_confidence = mem.metrics.get("neighbor_predicted_label_confidence", 0)
 
-        with gr.Row(equal_height=True, variant="panel", elem_classes="white" if i % 2 == 0 else None):
+        with gr.Row(equal_height=True, variant="panel"):
             with gr.Column(scale=9):
                 assert isinstance(mem.value, str)
                 gr.Markdown(mem.value, label="Value", height=50)
orca_sdk/_utils/analysis_ui_style.css
@@ -1,6 +1,3 @@
-.white {
-  background-color: white;
-}
 .centered input {
   margin: auto;
 }
orca_sdk/classification_model.py
@@ -6,6 +6,15 @@ from datetime import datetime
 from typing import Any, Generator, Iterable, Literal, cast, overload
 from uuid import UUID
 
+import numpy as np
+from datasets import Dataset
+from sklearn.metrics import (
+    accuracy_score,
+    auc,
+    f1_score,
+    roc_auc_score,
+)
+
 from ._generated_api_client.api import (
     create_evaluation,
     create_model,
@@ -16,11 +25,14 @@ from ._generated_api_client.api import (
     list_predictions,
     predict_gpu,
     record_prediction_feedback,
+    update_model,
 )
 from ._generated_api_client.models import (
+    ClassificationEvaluationResult,
     CreateRACModelRequest,
     EvaluationRequest,
     ListPredictionsRequest,
+    PrecisionRecallCurve,
 )
 from ._generated_api_client.models import (
     PredictionSortItemItemType0 as PredictionSortColumns,
@@ -31,9 +43,12 @@ from ._generated_api_client.models import (
 from ._generated_api_client.models import (
     RACHeadType,
     RACModelMetadata,
+    RACModelUpdate,
+    ROCCurve,
 )
 from ._generated_api_client.models.prediction_request import PredictionRequest
-from ._utils.common import CreateMode, DropMode
+from ._shared.metrics import calculate_pr_curve, calculate_roc_curve
+from ._utils.common import UNSET, CreateMode, DropMode
 from ._utils.task import wait_for_task
 from .datasource import Datasource
 from .memoryset import LabeledMemoryset
@@ -270,18 +285,53 @@ class ClassificationModel:
             if if_not_exists == "error":
                 raise
 
+    def refresh(self):
+        """Refresh the model data from the OrcaCloud"""
+        self.__dict__.update(ClassificationModel.open(self.name).__dict__)
+
+    def update_metadata(self, *, description: str | None = UNSET) -> None:
+        """
+        Update editable classification model metadata properties.
+
+        Params:
+            description: Value to set for the description, defaults to `[UNSET]` if not provided.
+
+        Examples:
+            Update the description:
+            >>> model.update_metadata(description="New description")
+
+            Remove the description:
+            >>> model.update_metadata(description=None)
+        """
+        update_model(self.id, body=RACModelUpdate(description=description))
+        self.refresh()
+
     @overload
     def predict(
-        self, value: list[str], expected_labels: list[int] | None = None, tags: set[str] = set()
+        self,
+        value: list[str],
+        expected_labels: list[int] | None = None,
+        tags: set[str] = set(),
+        disable_telemetry: bool = False,
     ) -> list[LabelPrediction]:
         pass
 
     @overload
-    def predict(self, value: str, expected_labels: int | None = None, tags: set[str] = set()) -> LabelPrediction:
+    def predict(
+        self,
+        value: str,
+        expected_labels: int | None = None,
+        tags: set[str] = set(),
+        disable_telemetry: bool = False,
+    ) -> LabelPrediction:
         pass
 
     def predict(
-        self, value: list[str] | str, expected_labels: list[int] | int | None = None, tags: set[str] = set()
+        self,
+        value: list[str] | str,
+        expected_labels: list[int] | int | None = None,
+        tags: set[str] = set(),
+        disable_telemetry: bool = False,
     ) -> list[LabelPrediction] | LabelPrediction:
         """
         Predict label(s) for the given input value(s) grounded in similar memories
@@ -290,6 +340,7 @@ class ClassificationModel:
             value: Value(s) to predict the labels of
             expected_labels: Expected label(s) for the given input to record for model evaluation
             tags: Tags to add to the prediction(s)
+            disable_telemetry: Whether to disable telemetry for the prediction(s)
 
         Returns:
             Label prediction or list of label predictions
@@ -318,8 +369,13 @@ class ClassificationModel:
                     else [expected_labels] if expected_labels is not None else None
                 ),
                 tags=list(tags),
+                disable_telemetry=disable_telemetry,
             ),
         )
+
+        if not disable_telemetry and any(p.prediction_id is None for p in response):
+            raise RuntimeError("Failed to save prediction to database.")
+
         predictions = [
             LabelPrediction(
                 prediction_id=prediction.prediction_id,
@@ -329,6 +385,7 @@ class ClassificationModel:
                 anomaly_score=prediction.anomaly_score,
                 memoryset=self.memoryset,
                 model=self,
+                logits=prediction.logits,
             )
             for prediction in response
         ]
@@ -401,46 +458,157 @@ class ClassificationModel:
             for prediction in predictions
         ]
 
-    def evaluate(
+    def _calculate_metrics(
+        self,
+        predictions: list[LabelPrediction],
+        expected_labels: list[int],
+    ) -> ClassificationEvaluationResult:
+
+        targets_array = np.array(expected_labels)
+        predictions_array = np.array([p.label for p in predictions])
+
+        logits_array = np.array([p.logits for p in predictions])
+
+        f1 = float(f1_score(targets_array, predictions_array, average="weighted"))
+        accuracy = float(accuracy_score(targets_array, predictions_array))
+
+        # Only compute ROC AUC and PR AUC for binary classification
+        unique_classes = np.unique(targets_array)
+
+        pr_curve = None
+        roc_curve = None
+
+        if len(unique_classes) == 2:
+            try:
+                precisions, recalls, pr_thresholds = calculate_pr_curve(targets_array, logits_array)
+                pr_auc = float(auc(recalls, precisions))
+
+                pr_curve = PrecisionRecallCurve(
+                    precisions=precisions.tolist(),
+                    recalls=recalls.tolist(),
+                    thresholds=pr_thresholds.tolist(),
+                    auc=pr_auc,
+                )
+
+                fpr, tpr, roc_thresholds = calculate_roc_curve(targets_array, logits_array)
+                roc_auc = float(roc_auc_score(targets_array, logits_array[:, 1]))
+
+                roc_curve = ROCCurve(
+                    false_positive_rates=fpr.tolist(),
+                    true_positive_rates=tpr.tolist(),
+                    thresholds=roc_thresholds.tolist(),
+                    auc=roc_auc,
+                )
+            except ValueError as e:
+                logging.warning(f"Error calculating PR and ROC curves: {e}")
+
+        return ClassificationEvaluationResult(
+            f1_score=f1,
+            accuracy=accuracy,
+            loss=0.0,
+            precision_recall_curve=pr_curve,
+            roc_curve=roc_curve,
+        )
+
+    def _evaluate_datasource(
         self,
         datasource: Datasource,
+        value_column: str,
+        label_column: str,
+        record_predictions: bool,
+        tags: set[str] | None,
+    ) -> dict[str, Any]:
+        response = create_evaluation(
+            self.id,
+            body=EvaluationRequest(
+                datasource_id=datasource.id,
+                datasource_label_column=label_column,
+                datasource_value_column=value_column,
+                memoryset_override_id=self._memoryset_override_id,
+                record_telemetry=record_predictions,
+                telemetry_tags=list(tags) if tags else None,
+            ),
+        )
+        wait_for_task(response.task_id, description="Running evaluation")
+        response = get_evaluation(self.id, UUID(response.task_id))
+        assert response.result is not None
+        return response.result.to_dict()
+
+    def _evaluate_dataset(
+        self,
+        dataset: Dataset,
+        value_column: str,
+        label_column: str,
+        record_predictions: bool,
+        tags: set[str],
+        batch_size: int,
+    ) -> dict[str, Any]:
+        predictions = []
+        expected_labels = []
+
+        for i in range(0, len(dataset), batch_size):
+            batch = dataset[i : i + batch_size]
+            predictions.extend(
+                self.predict(
+                    batch[value_column],
+                    expected_labels=batch[label_column],
+                    tags=tags,
+                    disable_telemetry=(not record_predictions),
+                )
+            )
+            expected_labels.extend(batch[label_column])
+
+        return self._calculate_metrics(predictions, expected_labels).to_dict()
+
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
         value_column: str = "value",
         label_column: str = "label",
         record_predictions: bool = False,
-        tags: set[str] | None = None,
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
     ) -> dict[str, Any]:
         """
-        Evaluate the classification model on a given datasource
+        Evaluate the classification model on a given dataset or datasource
 
         Params:
-            datasource: Datasource to evaluate the model on
+            data: Dataset or Datasource to evaluate the model on
             value_column: Name of the column that contains the input values to the model
             label_column: Name of the column containing the expected labels
             record_predictions: Whether to record [`LabelPrediction`][orca_sdk.telemetry.LabelPrediction]s for analysis
             tags: Optional tags to add to the recorded [`LabelPrediction`][orca_sdk.telemetry.LabelPrediction]s
+            batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)
 
         Returns:
             Dictionary with evaluation metrics
 
         Examples:
+            Evaluate using a Datasource:
             >>> model.evaluate(datasource, value_column="text", label_column="airline_sentiment")
             { "f1_score": 0.85, "roc_auc": 0.85, "pr_auc": 0.85, "accuracy": 0.85, "loss": 0.35, ... }
+
+            Evaluate using a Dataset:
+            >>> model.evaluate(dataset, value_column="text", label_column="sentiment")
+            { "f1_score": 0.85, "roc_auc": 0.85, "pr_auc": 0.85, "accuracy": 0.85, "loss": 0.35, ... }
         """
-        response = create_evaluation(
-            self.id,
-            body=EvaluationRequest(
-                datasource_id=datasource.id,
-                datasource_label_column=label_column,
-                datasource_value_column=value_column,
-                memoryset_override_id=self._memoryset_override_id,
-                record_telemetry=record_predictions,
-                telemetry_tags=list(tags) if tags else None,
-            ),
-        )
-        wait_for_task(response.task_id, description="Running evaluation")
-        response = get_evaluation(self.id, UUID(response.task_id))
-        assert response.result is not None
-        return response.result.to_dict()
+        if isinstance(data, Datasource):
+            return self._evaluate_datasource(
+                datasource=data,
+                value_column=value_column,
+                label_column=label_column,
+                record_predictions=record_predictions,
+                tags=tags,
+            )
+        else:
+            return self._evaluate_dataset(
+                dataset=data,
+                value_column=value_column,
+                label_column=label_column,
+                record_predictions=record_predictions,
+                tags=tags,
+                batch_size=batch_size,
+            )
 
     def finetune(self, datasource: Datasource):
         # do not document until implemented
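Taken together, the classification_model.py changes add `refresh`/`update_metadata`, a `disable_telemetry` flag on `predict`, and client-side evaluation over an in-memory Hugging Face Dataset alongside the existing Datasource path. A rough usage sketch against the 0.0.92 surface; the model name, dataset contents, and the top-level `ClassificationModel` import are illustrative assumptions, not taken from this diff:

    from datasets import Dataset
    from orca_sdk import ClassificationModel

    model = ClassificationModel.open("my-sentiment-model")  # hypothetical existing model

    # New in 0.0.92: editable metadata
    model.update_metadata(description="Tuned on Q2 support tickets")

    # New in 0.0.92: evaluate directly on a Hugging Face Dataset, batched client-side
    eval_data = Dataset.from_dict(
        {"text": ["great flight", "lost my luggage again"], "sentiment": [1, 0]}
    )
    metrics = model.evaluate(eval_data, value_column="text", label_column="sentiment", batch_size=50)
    print(metrics["accuracy"], metrics["f1_score"])

    # New in 0.0.92: predictions can skip telemetry recording entirely
    prediction = model.predict("delayed for three hours", disable_telemetry=True)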