orca-sdk 0.0.94-py3-none-any.whl → 0.0.95-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. orca_sdk/__init__.py +13 -4
  2. orca_sdk/_generated_api_client/api/__init__.py +80 -34
  3. orca_sdk/_generated_api_client/api/classification_model/create_classification_model_classification_model_post.py +170 -0
  4. orca_sdk/_generated_api_client/api/classification_model/{get_model_classification_model_name_or_id_get.py → delete_classification_model_classification_model_name_or_id_delete.py} +20 -20
  5. orca_sdk/_generated_api_client/api/classification_model/{delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py → delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py} +4 -4
  6. orca_sdk/_generated_api_client/api/classification_model/{create_evaluation_classification_model_model_name_or_id_evaluation_post.py → evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py} +14 -14
  7. orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +156 -0
  8. orca_sdk/_generated_api_client/api/classification_model/{get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py → get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py} +16 -16
  9. orca_sdk/_generated_api_client/api/classification_model/{list_evaluations_classification_model_model_name_or_id_evaluation_get.py → list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py} +16 -16
  10. orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +127 -0
  11. orca_sdk/_generated_api_client/api/classification_model/{predict_gpu_classification_model_name_or_id_prediction_post.py → predict_label_gpu_classification_model_name_or_id_prediction_post.py} +14 -14
  12. orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +183 -0
  13. orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +24 -0
  14. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +22 -22
  15. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +22 -22
  16. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +38 -16
  17. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +29 -12
  18. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +12 -12
  19. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +17 -14
  20. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +72 -19
  21. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +31 -12
  22. orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +49 -20
  23. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +38 -16
  24. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +54 -29
  25. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +44 -26
  26. orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +22 -22
  27. orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
  28. orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +150 -0
  29. orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
  30. orca_sdk/_generated_api_client/api/{classification_model/create_model_classification_model_post.py → regression_model/create_regression_model_regression_model_post.py} +27 -27
  31. orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
  32. orca_sdk/_generated_api_client/api/{classification_model/delete_model_classification_model_name_or_id_delete.py → regression_model/delete_regression_model_regression_model_name_or_id_delete.py} +5 -5
  33. orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +183 -0
  34. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +170 -0
  35. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +156 -0
  36. orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +161 -0
  37. orca_sdk/_generated_api_client/api/{classification_model/list_models_classification_model_get.py → regression_model/list_regression_models_regression_model_get.py} +17 -17
  38. orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +190 -0
  39. orca_sdk/_generated_api_client/api/{classification_model/update_model_classification_model_name_or_id_patch.py → regression_model/update_regression_model_regression_model_name_or_id_patch.py} +27 -27
  40. orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +156 -0
  41. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +35 -12
  42. orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +20 -12
  43. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +35 -12
  44. orca_sdk/_generated_api_client/models/__init__.py +84 -24
  45. orca_sdk/_generated_api_client/models/base_score_prediction_result.py +108 -0
  46. orca_sdk/_generated_api_client/models/{evaluation_request.py → classification_evaluation_request.py} +13 -45
  47. orca_sdk/_generated_api_client/models/{classification_evaluation_result.py → classification_metrics.py} +106 -56
  48. orca_sdk/_generated_api_client/models/{rac_model_metadata.py → classification_model_metadata.py} +51 -43
  49. orca_sdk/_generated_api_client/models/{prediction_request.py → classification_prediction_request.py} +31 -6
  50. orca_sdk/_generated_api_client/models/{clone_labeled_memoryset_request.py → clone_memoryset_request.py} +5 -5
  51. orca_sdk/_generated_api_client/models/column_info.py +31 -0
  52. orca_sdk/_generated_api_client/models/{create_rac_model_request.py → create_classification_model_request.py} +25 -57
  53. orca_sdk/_generated_api_client/models/{create_labeled_memoryset_request.py → create_memoryset_request.py} +73 -56
  54. orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +66 -0
  55. orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +13 -0
  56. orca_sdk/_generated_api_client/models/create_regression_model_request.py +137 -0
  57. orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +187 -0
  58. orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +10 -0
  59. orca_sdk/_generated_api_client/models/evaluation_response.py +22 -9
  60. orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +140 -0
  61. orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +140 -0
  62. orca_sdk/_generated_api_client/models/memory_type.py +9 -0
  63. orca_sdk/_generated_api_client/models/{labeled_memoryset_metadata.py → memoryset_metadata.py} +73 -13
  64. orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +55 -0
  65. orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +13 -0
  66. orca_sdk/_generated_api_client/models/{labeled_memoryset_update.py → memoryset_update.py} +19 -31
  67. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +1 -0
  68. orca_sdk/_generated_api_client/models/{paginated_labeled_memory_with_feedback_metrics.py → paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py} +37 -10
  69. orca_sdk/_generated_api_client/models/{precision_recall_curve.py → pr_curve.py} +5 -13
  70. orca_sdk/_generated_api_client/models/{rac_model_update.py → predictive_model_update.py} +14 -5
  71. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +11 -1
  72. orca_sdk/_generated_api_client/models/rar_head_type.py +8 -0
  73. orca_sdk/_generated_api_client/models/regression_evaluation_request.py +148 -0
  74. orca_sdk/_generated_api_client/models/regression_metrics.py +172 -0
  75. orca_sdk/_generated_api_client/models/regression_model_metadata.py +177 -0
  76. orca_sdk/_generated_api_client/models/regression_prediction_request.py +195 -0
  77. orca_sdk/_generated_api_client/models/roc_curve.py +0 -8
  78. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +196 -0
  79. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +68 -0
  80. orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +252 -0
  81. orca_sdk/_generated_api_client/models/scored_memory.py +172 -0
  82. orca_sdk/_generated_api_client/models/scored_memory_insert.py +128 -0
  83. orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +68 -0
  84. orca_sdk/_generated_api_client/models/scored_memory_lookup.py +180 -0
  85. orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +68 -0
  86. orca_sdk/_generated_api_client/models/scored_memory_metadata.py +68 -0
  87. orca_sdk/_generated_api_client/models/scored_memory_update.py +171 -0
  88. orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +68 -0
  89. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +193 -0
  90. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +68 -0
  91. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +68 -0
  92. orca_sdk/_generated_api_client/models/update_prediction_request.py +20 -0
  93. orca_sdk/_shared/__init__.py +9 -1
  94. orca_sdk/_shared/metrics.py +257 -87
  95. orca_sdk/_shared/metrics_test.py +136 -77
  96. orca_sdk/_utils/data_parsing.py +0 -3
  97. orca_sdk/_utils/data_parsing_test.py +0 -3
  98. orca_sdk/_utils/prediction_result_ui.py +55 -23
  99. orca_sdk/classification_model.py +183 -175
  100. orca_sdk/classification_model_test.py +147 -157
  101. orca_sdk/conftest.py +76 -26
  102. orca_sdk/datasource_test.py +0 -1
  103. orca_sdk/embedding_model.py +136 -14
  104. orca_sdk/embedding_model_test.py +10 -6
  105. orca_sdk/job.py +329 -0
  106. orca_sdk/job_test.py +48 -0
  107. orca_sdk/memoryset.py +882 -161
  108. orca_sdk/memoryset_test.py +56 -23
  109. orca_sdk/regression_model.py +647 -0
  110. orca_sdk/regression_model_test.py +338 -0
  111. orca_sdk/telemetry.py +223 -106
  112. orca_sdk/telemetry_test.py +34 -30
  113. {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.95.dist-info}/METADATA +2 -4
  114. {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.95.dist-info}/RECORD +115 -69
  115. orca_sdk/_utils/task.py +0 -73
  116. {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.95.dist-info}/WHEEL +0 -0
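Much of the listing above is the RAC-era generated client being renamed to classification-specific endpoints, with a parallel regression_model package added alongside it. As a rough orientation before the classification_model.py diff below, a hedged sketch of how the renamed client functions line up (new names are taken from the renamed files above and the import hunk below; the mapping is illustrative, not exhaustive):

# Illustrative mapping only: the 0.0.94 names appear as comments, the 0.0.95
# names are imported below; all of them live in the generated client package.
from orca_sdk._generated_api_client.api import (
    create_classification_model,          # was: create_model
    get_classification_model,             # was: get_model
    list_classification_models,           # was: list_models
    delete_classification_model,          # was: delete_model
    evaluate_classification_model,        # was: create_evaluation
    get_classification_model_evaluation,  # was: get_evaluation
    predict_label_gpu,                    # was: predict_gpu
    update_classification_model,          # was: update_model
)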
--- orca_sdk/classification_model.py (0.0.94)
+++ orca_sdk/classification_model.py (0.0.95)
@@ -5,37 +5,29 @@ import os
 from contextlib import contextmanager
 from datetime import datetime
 from typing import Any, Generator, Iterable, Literal, cast, overload
-from uuid import UUID, uuid4
+from uuid import UUID

-import numpy as np
-
-import numpy as np
 from datasets import Dataset
-from sklearn.metrics import (
-    accuracy_score,
-    auc,
-    f1_score,
-    roc_auc_score,
-)

 from ._generated_api_client.api import (
-    create_evaluation,
-    create_model,
-    delete_model,
-    get_evaluation,
-    get_model,
-    list_models,
+    create_classification_model,
+    delete_classification_model,
+    evaluate_classification_model,
+    get_classification_model,
+    get_classification_model_evaluation,
+    list_classification_models,
     list_predictions,
-    predict_gpu,
+    predict_label_gpu,
     record_prediction_feedback,
-    update_model,
+    update_classification_model,
 )
 from ._generated_api_client.models import (
-    ClassificationEvaluationResult,
-    CreateRACModelRequest,
-    EvaluationRequest,
+    ClassificationEvaluationRequest,
+    ClassificationModelMetadata,
+    ClassificationPredictionRequest,
+    CreateClassificationModelRequest,
+    LabelPredictionWithMemoriesAndFeedback,
     ListPredictionsRequest,
-    PrecisionRecallCurve,
 )
 from ._generated_api_client.models import (
     PredictionSortItemItemType0 as PredictionSortColumns,
@@ -43,19 +35,19 @@ from ._generated_api_client.models import (
 from ._generated_api_client.models import (
     PredictionSortItemItemType1 as PredictionSortDirection,
 )
-from ._generated_api_client.models import (
-    RACHeadType,
-    RACModelMetadata,
-    RACModelUpdate,
-    ROCCurve,
-)
-from ._generated_api_client.models.prediction_request import PredictionRequest
-from ._shared.metrics import calculate_pr_curve, calculate_roc_curve
+from ._generated_api_client.models import PredictiveModelUpdate, RACHeadType
+from ._generated_api_client.types import UNSET as CLIENT_UNSET
+from ._shared.metrics import ClassificationMetrics, calculate_classification_metrics
 from ._utils.common import UNSET, CreateMode, DropMode
-from ._utils.task import wait_for_task
 from .datasource import Datasource
-from .memoryset import LabeledMemoryset
-from .telemetry import LabelPrediction, _parse_feedback
+from .job import Job
+from .memoryset import (
+    FilterItem,
+    FilterItemTuple,
+    LabeledMemoryset,
+    _parse_filter_item_from_tuple,
+)
+from .telemetry import ClassificationPrediction, _parse_feedback


 class ClassificationModel:
@@ -72,6 +64,7 @@ class ClassificationModel:
         memory_lookup_count: Number of memories the model uses for each prediction
         weigh_memories: If using a KNN head, whether the model weighs memories by their lookup score
         min_memory_weight: If using a KNN head, minimum lookup score memories have to be over to not be ignored
+        locked: Whether the model is locked to prevent accidental deletion
         created_at: When the model was created
     """

@@ -85,9 +78,10 @@ class ClassificationModel:
     weigh_memories: bool | None
     min_memory_weight: float | None
     version: int
+    locked: bool
     created_at: datetime

-    def __init__(self, metadata: RACModelMetadata):
+    def __init__(self, metadata: ClassificationModelMetadata):
         # for internal use only, do not document
         self.id = metadata.id
         self.name = metadata.name
@@ -99,10 +93,11 @@ class ClassificationModel:
         self.weigh_memories = metadata.weigh_memories
         self.min_memory_weight = metadata.min_memory_weight
         self.version = metadata.version
+        self.locked = metadata.locked
         self.created_at = metadata.created_at

         self._memoryset_override_id: str | None = None
-        self._last_prediction: LabelPrediction | None = None
+        self._last_prediction: ClassificationPrediction | None = None
         self._last_prediction_was_batch: bool = False

     def __eq__(self, other) -> bool:
@@ -120,7 +115,7 @@ class ClassificationModel:
         )

     @property
-    def last_prediction(self) -> LabelPrediction:
+    def last_prediction(self) -> ClassificationPrediction:
         """
         Last prediction made by the model

@@ -208,8 +203,8 @@ class ClassificationModel:

             return existing

-        metadata = create_model(
-            body=CreateRACModelRequest(
+        metadata = create_classification_model(
+            body=CreateClassificationModelRequest(
                 name=name,
                 memoryset_id=memoryset.id,
                 head_type=RACHeadType(head_type),
@@ -236,7 +231,7 @@ class ClassificationModel:
         Raises:
             LookupError: If the classification model does not exist
         """
-        return cls(get_model(name))
+        return cls(get_classification_model(name))

     @classmethod
     def exists(cls, name_or_id: str) -> bool:
@@ -263,7 +258,7 @@ class ClassificationModel:
         Returns:
             List of handles to all classification models in the OrcaCloud
         """
-        return [cls(metadata) for metadata in list_models()]
+        return [cls(metadata) for metadata in list_classification_models()]

     @classmethod
     def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
@@ -282,7 +277,7 @@ class ClassificationModel:
             LookupError: If the classification model does not exist and if_not_exists is `"error"`
         """
         try:
-            delete_model(name_or_id)
+            delete_classification_model(name_or_id)
             logging.info(f"Deleted model {name_or_id}")
         except LookupError:
             if if_not_exists == "error":
@@ -290,34 +285,53 @@ class ClassificationModel:

     def refresh(self):
         """Refresh the model data from the OrcaCloud"""
-        self.__dict__.update(ClassificationModel.open(self.name).__dict__)
+        self.__dict__.update(self.open(self.name).__dict__)

-    def update_metadata(self, *, description: str | None = UNSET) -> None:
+    def set(self, *, description: str | None = UNSET, locked: bool = UNSET) -> None:
         """
-        Update editable classification model metadata properties.
+        Update editable attributes of the model.
+
+        Note:
+            If a field is not provided, it will default to [UNSET][orca_sdk.UNSET] and not be updated.

         Params:
-            description: Value to set for the description, defaults to `[UNSET]` if not provided.
+            description: Value to set for the description
+            locked: Value to set for the locked status

         Examples:
             Update the description:
-            >>> model.update(description="New description")
+            >>> model.set(description="New description")

             Remove description:
-            >>> model.update(description=None)
+            >>> model.set(description=None)
+
+            Lock the model:
+            >>> model.set(locked=True)
         """
-        update_model(self.id, body=RACModelUpdate(description=description))
+        update_data = PredictiveModelUpdate(
+            description=CLIENT_UNSET if description is UNSET else description,
+            locked=CLIENT_UNSET if locked is UNSET else locked,
+        )
+        update_classification_model(self.id, body=update_data)
         self.refresh()

+    def lock(self) -> None:
+        """Lock the model to prevent accidental deletion"""
+        self.set(locked=True)
+
+    def unlock(self) -> None:
+        """Unlock the model to allow deletion"""
+        self.set(locked=False)
+
     @overload
     def predict(
         self,
         value: list[str],
         expected_labels: list[int] | None = None,
-        tags: set[str] = set(),
-        save_telemetry: bool = True,
-        save_telemetry_synchronously: bool = False,
-    ) -> list[LabelPrediction]:
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: Literal["off", "on", "sync", "async"] = "on",
+    ) -> list[ClassificationPrediction]:
         pass

     @overload
@@ -325,20 +339,20 @@ class ClassificationModel:
         self,
         value: str,
         expected_labels: int | None = None,
-        tags: set[str] = set(),
-        save_telemetry: bool = True,
-        save_telemetry_synchronously: bool = False,
-    ) -> LabelPrediction:
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: Literal["off", "on", "sync", "async"] = "on",
+    ) -> ClassificationPrediction:
         pass

     def predict(
         self,
         value: list[str] | str,
         expected_labels: list[int] | int | None = None,
-        tags: set[str] = set(),
-        save_telemetry: bool = True,
-        save_telemetry_synchronously: bool = False,
-    ) -> list[LabelPrediction] | LabelPrediction:
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: Literal["off", "on", "sync", "async"] = "on",
+    ) -> list[ClassificationPrediction] | ClassificationPrediction:
         """
         Predict label(s) for the given input value(s) grounded in similar memories

@@ -346,10 +360,12 @@ class ClassificationModel:
             value: Value(s) to get predict the labels of
             expected_labels: Expected label(s) for the given input to record for model evaluation
             tags: Tags to add to the prediction(s)
-            save_telemetry: Whether to enable telemetry for the prediction(s)
-            save_telemetry_synchronously: Whether to save telemetry synchronously. If `False`, telemetry will be saved
-                asynchronously in the background. This may result in a delay in the telemetry being available. Please note that this
-                may be overriden by the ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY environment variable.
+            save_telemetry: Whether to save telemetry for the prediction(s). One of
+                * `"off"`: Do not save telemetry
+                * `"on"`: Save telemetry asynchronously unless the `ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY`
+                    environment variable is set.
+                * `"sync"`: Save telemetry synchronously
+                * `"async"`: Save telemetry asynchronously

         Returns:
             Label prediction or list of label predictions
@@ -357,49 +373,51 @@ class ClassificationModel:
         Examples:
             Predict the label for a single value:
             >>> prediction = model.predict("I am happy", tags={"test"})
-            LabelPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
+            ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })

             Predict the labels for a list of values:
             >>> predictions = model.predict(["I am happy", "I am sad"], expected_labels=[1, 0])
             [
-                LabelPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
-                LabelPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
+                ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
             ]
         """

-        if "ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY" in os.environ:
-            env_var = os.environ["ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY"]
-            logging.info(
-                f"ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY is set to {env_var} which will override the parameter save_telemetry_synchronously = {save_telemetry_synchronously}"
-            )
-            save_telemetry_synchronously = env_var.lower() == "true"
+        parsed_filters = [
+            _parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
+        ]

-        response = predict_gpu(
+        if not all(isinstance(filter, FilterItem) for filter in parsed_filters):
+            raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")
+
+        response = predict_label_gpu(
             self.id,
-            body=PredictionRequest(
+            body=ClassificationPredictionRequest(
                 input_values=value if isinstance(value, list) else [value],
                 memoryset_override_id=self._memoryset_override_id,
                 expected_labels=(
                     expected_labels
                     if isinstance(expected_labels, list)
-                    else [expected_labels]
-                    if expected_labels is not None
-                    else None
+                    else [expected_labels] if expected_labels is not None else None
+                ),
+                tags=list(tags or set()),
+                save_telemetry=save_telemetry != "off",
+                save_telemetry_synchronously=(
+                    os.getenv("ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY", "0") != "0" or save_telemetry == "sync"
                 ),
-                tags=list(tags),
-                save_telemetry=save_telemetry,
-                save_telemetry_synchronously=save_telemetry_synchronously,
+                filters=cast(list[FilterItem], parsed_filters),
             ),
         )

-        if save_telemetry and any(p.prediction_id is None for p in response):
+        if save_telemetry != "off" and any(p.prediction_id is None for p in response):
            raise RuntimeError("Failed to save prediction to database.")

         predictions = [
-            LabelPrediction(
+            ClassificationPrediction(
                 prediction_id=prediction.prediction_id,
                 label=prediction.label,
                 label_name=prediction.label_name,
+                score=None,
                 confidence=prediction.confidence,
                 anomaly_score=prediction.anomaly_score,
                 memoryset=self.memoryset,
@@ -420,7 +438,7 @@ class ClassificationModel:
         tag: str | None = None,
         sort: list[tuple[PredictionSortColumns, PredictionSortDirection]] = [],
         expected_label_match: bool | None = None,
-    ) -> list[LabelPrediction]:
+    ) -> list[ClassificationPrediction]:
         """
         Get a list of predictions made by this model

@@ -440,19 +458,19 @@ class ClassificationModel:
             Get the last 3 predictions:
             >>> predictions = model.predictions(limit=3, sort=[("timestamp", "desc")])
             [
-                LabeledPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
-                LabeledPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
-                LabeledPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am ecstatic'}),
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
+                ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am ecstatic'}),
             ]


             Get second most confident prediction:
             >>> predictions = model.predictions(sort=[("confidence", "desc")], offset=1, limit=1)
-            [LabeledPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am having a good day'})]
+            [ClassificationPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am having a good day'})]

             Get predictions where the expected label doesn't match the predicted label:
             >>> predictions = model.predictions(expected_label_match=False)
-            [LabeledPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy', expected_label: 0})]
+            [ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy', expected_label: 0})]
         """
         predictions = list_predictions(
             body=ListPredictionsRequest(
@@ -465,10 +483,11 @@ class ClassificationModel:
             ),
         )
         return [
-            LabelPrediction(
+            ClassificationPrediction(
                 prediction_id=prediction.prediction_id,
                 label=prediction.label,
                 label_name=prediction.label_name,
+                score=None,
                 confidence=prediction.confidence,
                 anomaly_score=prediction.anomaly_score,
                 memoryset=self.memoryset,
@@ -476,59 +495,9 @@ class ClassificationModel:
                 telemetry=prediction,
             )
             for prediction in predictions
+            if isinstance(prediction, LabelPredictionWithMemoriesAndFeedback)
         ]

-    def _calculate_metrics(
-        self,
-        predictions: list[LabelPrediction],
-        expected_labels: list[int],
-    ) -> ClassificationEvaluationResult:
-        targets_array = np.array(expected_labels)
-        predictions_array = np.array([p.label for p in predictions])
-
-        logits_array = np.array([p.logits for p in predictions])
-
-        f1 = float(f1_score(targets_array, predictions_array, average="weighted"))
-        accuracy = float(accuracy_score(targets_array, predictions_array))
-
-        # Only compute ROC AUC and PR AUC for binary classification
-        unique_classes = np.unique(targets_array)
-
-        pr_curve = None
-        roc_curve = None
-
-        if len(unique_classes) == 2:
-            try:
-                precisions, recalls, pr_thresholds = calculate_pr_curve(targets_array, logits_array)
-                pr_auc = float(auc(recalls, precisions))
-
-                pr_curve = PrecisionRecallCurve(
-                    precisions=precisions.tolist(),
-                    recalls=recalls.tolist(),
-                    thresholds=pr_thresholds.tolist(),
-                    auc=pr_auc,
-                )
-
-                fpr, tpr, roc_thresholds = calculate_roc_curve(targets_array, logits_array)
-                roc_auc = float(roc_auc_score(targets_array, logits_array[:, 1]))
-
-                roc_curve = ROCCurve(
-                    false_positive_rates=fpr.tolist(),
-                    true_positive_rates=tpr.tolist(),
-                    thresholds=roc_thresholds.tolist(),
-                    auc=roc_auc,
-                )
-            except ValueError as e:
-                logging.warning(f"Error calculating PR and ROC curves: {e}")
-
-        return ClassificationEvaluationResult(
-            f1_score=f1,
-            accuracy=accuracy,
-            loss=0.0,
-            precision_recall_curve=pr_curve,
-            roc_curve=roc_curve,
-        )
-
     def _evaluate_datasource(
         self,
         datasource: Datasource,
@@ -536,10 +505,11 @@ class ClassificationModel:
         label_column: str,
         record_predictions: bool,
         tags: set[str] | None,
-    ) -> dict[str, Any]:
-        response = create_evaluation(
+        background: bool = False,
+    ) -> ClassificationMetrics | Job[ClassificationMetrics]:
+        response = evaluate_classification_model(
             self.id,
-            body=EvaluationRequest(
+            body=ClassificationEvaluationRequest(
                 datasource_id=datasource.id,
                 datasource_label_column=label_column,
                 datasource_value_column=value_column,
@@ -548,10 +518,13 @@ class ClassificationModel:
                 telemetry_tags=list(tags) if tags else None,
             ),
         )
-        wait_for_task(response.task_id, description="Running evaluation")
-        response = get_evaluation(self.id, UUID(response.task_id))
-        assert response.result is not None
-        return response.result.to_dict()
+
+        job = Job(
+            response.task_id,
+            lambda: (r := get_classification_model_evaluation(self.id, UUID(response.task_id)).result)
+            and ClassificationMetrics(**r.to_dict()),
+        )
+        return job if background else job.result()

     def _evaluate_dataset(
         self,
@@ -561,34 +534,64 @@ class ClassificationModel:
         record_predictions: bool,
         tags: set[str],
         batch_size: int,
-    ) -> dict[str, Any]:
-        predictions = []
-        expected_labels = []
-
-        for i in range(0, len(dataset), batch_size):
-            batch = dataset[i : i + batch_size]
-            predictions.extend(
-                self.predict(
-                    batch[value_column],
-                    expected_labels=batch[label_column],
-                    tags=tags,
-                    save_telemetry=record_predictions,
-                    save_telemetry_synchronously=(not record_predictions),
-                )
+    ) -> ClassificationMetrics:
+        predictions = [
+            prediction
+            for i in range(0, len(dataset), batch_size)
+            for prediction in self.predict(
+                dataset[i : i + batch_size][value_column],
+                expected_labels=dataset[i : i + batch_size][label_column],
+                tags=tags,
+                save_telemetry="sync" if record_predictions else "off",
             )
-            expected_labels.extend(batch[label_column])
+        ]
+
+        return calculate_classification_metrics(
+            expected_labels=dataset[label_column],
+            logits=[p.logits for p in predictions],
+            anomaly_scores=[p.anomaly_score for p in predictions],
+            include_curves=True,
+        )
+
+    @overload
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
+        *,
+        value_column: str = "value",
+        label_column: str = "label",
+        record_predictions: bool = False,
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
+        background: Literal[True],
+    ) -> Job[ClassificationMetrics]:
+        pass

-        return self._calculate_metrics(predictions, expected_labels).to_dict()
+    @overload
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
+        *,
+        value_column: str = "value",
+        label_column: str = "label",
+        record_predictions: bool = False,
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
+        background: Literal[False] = False,
+    ) -> ClassificationMetrics:
+        pass

     def evaluate(
         self,
         data: Datasource | Dataset,
+        *,
         value_column: str = "value",
         label_column: str = "label",
         record_predictions: bool = False,
         tags: set[str] = {"evaluation"},
         batch_size: int = 100,
-    ) -> dict[str, Any]:
+        background: bool = False,
+    ) -> ClassificationMetrics | Job[ClassificationMetrics]:
         """
         Evaluate the classification model on a given dataset or datasource

@@ -596,21 +599,23 @@ class ClassificationModel:
             data: Dataset or Datasource to evaluate the model on
             value_column: Name of the column that contains the input values to the model
             label_column: Name of the column containing the expected labels
-            record_predictions: Whether to record [`LabelPrediction`][orca_sdk.telemetry.LabelPrediction]s for analysis
-            tags: Optional tags to add to the recorded [`LabelPrediction`][orca_sdk.telemetry.LabelPrediction]s
+            record_predictions: Whether to record [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s for analysis
+            tags: Optional tags to add to the recorded [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s
             batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)
+            background: Whether to run the operation in the background and return a job handle

         Returns:
-            Dictionary with evaluation metrics, including anomaly score statistics (mean, median, variance)
+            EvaluationResult containing metrics including accuracy, F1 score, ROC AUC, PR AUC, and anomaly score statistics

         Examples:
-            Evaluate using a Datasource:
             >>> model.evaluate(datasource, value_column="text", label_column="airline_sentiment")
-            { "f1_score": 0.85, "roc_auc": 0.85, "pr_auc": 0.85, "accuracy": 0.85, "loss": 0.35, ... }
-
-            Evaluate using a Dataset:
-            >>> model.evaluate(dataset, value_column="text", label_column="sentiment")
-            { "f1_score": 0.85, "roc_auc": 0.85, "pr_auc": 0.85, "accuracy": 0.85, "loss": 0.35, ... }
+            ClassificationMetrics({
+                accuracy: 0.8500,
+                f1_score: 0.8500,
+                roc_auc: 0.8500,
+                pr_auc: 0.8500,
+                anomaly_score: 0.3500 ± 0.0500,
+            })
         """
         if isinstance(data, Datasource):
             return self._evaluate_datasource(
@@ -619,8 +624,9 @@ class ClassificationModel:
                 label_column=label_column,
                 record_predictions=record_predictions,
                 tags=tags,
+                background=background,
             )
-        else:
+        elif isinstance(data, Dataset):
             return self._evaluate_dataset(
                 dataset=data,
                 value_column=value_column,
@@ -629,6 +635,8 @@ class ClassificationModel:
                 tags=tags,
                 batch_size=batch_size,
             )
+        else:
+            raise ValueError(f"Invalid data type: {type(data)}")

     def finetune(self, datasource: Datasource):
         # do not document until implemented
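Taken together, the classification_model.py hunks above replace the two boolean telemetry flags with a single `save_telemetry` literal, add a `locked` attribute along with `set()`, `lock()` and `unlock()`, and let `evaluate()` either return `ClassificationMetrics` directly or hand back a `Job` when `background=True`. A minimal usage sketch of the 0.0.95 surface, assuming `ClassificationModel` and `Datasource` are re-exported from the top-level `orca_sdk` package, that `Datasource.open()` mirrors `ClassificationModel.open()`, that `Job.result()` waits for the evaluation task (as its use inside `_evaluate_datasource()` suggests), and that the `"sentiment"` and `"reviews"` names are placeholders for resources that already exist in your OrcaCloud account:

from datasets import Dataset

from orca_sdk import ClassificationModel, Datasource  # assumed top-level re-exports

model = ClassificationModel.open("sentiment")  # "sentiment" is a placeholder model name

# Telemetry is now controlled by one literal ("off" | "on" | "sync" | "async")
# instead of the old save_telemetry / save_telemetry_synchronously booleans.
prediction = model.predict("I am happy", tags={"demo"}, save_telemetry="sync")
print(prediction.label, prediction.confidence)

# New in 0.0.95: editable attributes via set(), plus a lock against accidental deletion.
model.set(description="Sentiment classifier", locked=True)
model.unlock()

# Evaluating an in-memory Dataset returns ClassificationMetrics synchronously.
dataset = Dataset.from_dict({"value": ["great", "awful"], "label": [1, 0]})
metrics = model.evaluate(dataset, value_column="value", label_column="label")
print(metrics)

# Evaluating a Datasource with background=True returns a Job handle instead.
datasource = Datasource.open("reviews")  # "reviews" is a placeholder datasource name
job = model.evaluate(datasource, value_column="text", label_column="sentiment", background=True)
print(job.result())  # assumed to block until the evaluation task finishes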