orca-sdk 0.0.94__py3-none-any.whl → 0.0.96__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. orca_sdk/__init__.py +13 -4
  2. orca_sdk/_generated_api_client/api/__init__.py +80 -34
  3. orca_sdk/_generated_api_client/api/classification_model/create_classification_model_gpu_classification_model_post.py +170 -0
  4. orca_sdk/_generated_api_client/api/classification_model/{get_model_classification_model_name_or_id_get.py → delete_classification_model_classification_model_name_or_id_delete.py} +20 -20
  5. orca_sdk/_generated_api_client/api/classification_model/{delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py → delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py} +4 -4
  6. orca_sdk/_generated_api_client/api/classification_model/{create_evaluation_classification_model_model_name_or_id_evaluation_post.py → evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py} +14 -14
  7. orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +156 -0
  8. orca_sdk/_generated_api_client/api/classification_model/{get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py → get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py} +16 -16
  9. orca_sdk/_generated_api_client/api/classification_model/{list_evaluations_classification_model_model_name_or_id_evaluation_get.py → list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py} +16 -16
  10. orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +127 -0
  11. orca_sdk/_generated_api_client/api/classification_model/{predict_gpu_classification_model_name_or_id_prediction_post.py → predict_label_gpu_classification_model_name_or_id_prediction_post.py} +14 -14
  12. orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +183 -0
  13. orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +24 -0
  14. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +22 -22
  15. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +22 -22
  16. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +38 -16
  17. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +29 -12
  18. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +12 -12
  19. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +17 -14
  20. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +72 -19
  21. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +31 -12
  22. orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +49 -20
  23. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +38 -16
  24. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +54 -29
  25. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +44 -26
  26. orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +22 -22
  27. orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
  28. orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +150 -0
  29. orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
  30. orca_sdk/_generated_api_client/api/{classification_model/create_model_classification_model_post.py → regression_model/create_regression_model_gpu_regression_model_post.py} +27 -27
  31. orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
  32. orca_sdk/_generated_api_client/api/{classification_model/delete_model_classification_model_name_or_id_delete.py → regression_model/delete_regression_model_regression_model_name_or_id_delete.py} +5 -5
  33. orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +183 -0
  34. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +170 -0
  35. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +156 -0
  36. orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +161 -0
  37. orca_sdk/_generated_api_client/api/{classification_model/list_models_classification_model_get.py → regression_model/list_regression_models_regression_model_get.py} +17 -17
  38. orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +190 -0
  39. orca_sdk/_generated_api_client/api/{classification_model/update_model_classification_model_name_or_id_patch.py → regression_model/update_regression_model_regression_model_name_or_id_patch.py} +27 -27
  40. orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +156 -0
  41. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +35 -12
  42. orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +20 -12
  43. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +35 -12
  44. orca_sdk/_generated_api_client/models/__init__.py +84 -24
  45. orca_sdk/_generated_api_client/models/base_score_prediction_result.py +108 -0
  46. orca_sdk/_generated_api_client/models/{evaluation_request.py → classification_evaluation_request.py} +13 -45
  47. orca_sdk/_generated_api_client/models/{classification_evaluation_result.py → classification_metrics.py} +106 -56
  48. orca_sdk/_generated_api_client/models/{rac_model_metadata.py → classification_model_metadata.py} +51 -43
  49. orca_sdk/_generated_api_client/models/{prediction_request.py → classification_prediction_request.py} +31 -6
  50. orca_sdk/_generated_api_client/models/{clone_labeled_memoryset_request.py → clone_memoryset_request.py} +5 -5
  51. orca_sdk/_generated_api_client/models/column_info.py +31 -0
  52. orca_sdk/_generated_api_client/models/{create_rac_model_request.py → create_classification_model_request.py} +25 -57
  53. orca_sdk/_generated_api_client/models/{create_labeled_memoryset_request.py → create_memoryset_request.py} +73 -56
  54. orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +66 -0
  55. orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +13 -0
  56. orca_sdk/_generated_api_client/models/create_regression_model_request.py +137 -0
  57. orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +187 -0
  58. orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +10 -0
  59. orca_sdk/_generated_api_client/models/evaluation_response.py +22 -9
  60. orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +140 -0
  61. orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +140 -0
  62. orca_sdk/_generated_api_client/models/memory_type.py +9 -0
  63. orca_sdk/_generated_api_client/models/{labeled_memoryset_metadata.py → memoryset_metadata.py} +73 -13
  64. orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +55 -0
  65. orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +13 -0
  66. orca_sdk/_generated_api_client/models/{labeled_memoryset_update.py → memoryset_update.py} +19 -31
  67. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +1 -0
  68. orca_sdk/_generated_api_client/models/{paginated_labeled_memory_with_feedback_metrics.py → paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py} +37 -10
  69. orca_sdk/_generated_api_client/models/{precision_recall_curve.py → pr_curve.py} +5 -13
  70. orca_sdk/_generated_api_client/models/{rac_model_update.py → predictive_model_update.py} +14 -5
  71. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +11 -1
  72. orca_sdk/_generated_api_client/models/rar_head_type.py +8 -0
  73. orca_sdk/_generated_api_client/models/regression_evaluation_request.py +148 -0
  74. orca_sdk/_generated_api_client/models/regression_metrics.py +172 -0
  75. orca_sdk/_generated_api_client/models/regression_model_metadata.py +177 -0
  76. orca_sdk/_generated_api_client/models/regression_prediction_request.py +195 -0
  77. orca_sdk/_generated_api_client/models/roc_curve.py +0 -8
  78. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +196 -0
  79. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +68 -0
  80. orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +252 -0
  81. orca_sdk/_generated_api_client/models/scored_memory.py +172 -0
  82. orca_sdk/_generated_api_client/models/scored_memory_insert.py +128 -0
  83. orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +68 -0
  84. orca_sdk/_generated_api_client/models/scored_memory_lookup.py +180 -0
  85. orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +68 -0
  86. orca_sdk/_generated_api_client/models/scored_memory_metadata.py +68 -0
  87. orca_sdk/_generated_api_client/models/scored_memory_update.py +171 -0
  88. orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +68 -0
  89. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +193 -0
  90. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +68 -0
  91. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +68 -0
  92. orca_sdk/_generated_api_client/models/update_prediction_request.py +20 -0
  93. orca_sdk/_shared/__init__.py +9 -1
  94. orca_sdk/_shared/metrics.py +257 -87
  95. orca_sdk/_shared/metrics_test.py +136 -77
  96. orca_sdk/_utils/data_parsing.py +0 -3
  97. orca_sdk/_utils/data_parsing_test.py +0 -3
  98. orca_sdk/_utils/prediction_result_ui.py +55 -23
  99. orca_sdk/classification_model.py +183 -172
  100. orca_sdk/classification_model_test.py +147 -157
  101. orca_sdk/conftest.py +76 -26
  102. orca_sdk/datasource_test.py +0 -1
  103. orca_sdk/embedding_model.py +136 -14
  104. orca_sdk/embedding_model_test.py +10 -6
  105. orca_sdk/job.py +329 -0
  106. orca_sdk/job_test.py +48 -0
  107. orca_sdk/memoryset.py +882 -161
  108. orca_sdk/memoryset_test.py +56 -23
  109. orca_sdk/regression_model.py +647 -0
  110. orca_sdk/regression_model_test.py +337 -0
  111. orca_sdk/telemetry.py +223 -106
  112. orca_sdk/telemetry_test.py +34 -30
  113. {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.96.dist-info}/METADATA +2 -4
  114. {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.96.dist-info}/RECORD +115 -69
  115. orca_sdk/_utils/task.py +0 -73
  116. {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.96.dist-info}/WHEEL +0 -0
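
The headline changes in 0.0.96 are regression model support (the new `regression_model.py` module and `regression_model` client endpoints), a `Job` handle for background tasks (`job.py`), scored-memory variants of the memoryset models, and a reworked classification surface: `LabelPrediction` becomes `ClassificationPrediction`, `update_metadata` is replaced by `set` plus `lock`/`unlock`, `predict` gains `filters` and a `save_telemetry` mode string, and `evaluate` returns structured `ClassificationMetrics` instead of a dict. The source diff shown below covers `orca_sdk/classification_model.py`. For orientation, here is a minimal usage sketch of the updated surface, inferred only from the signatures visible in that diff; the model name and evaluation data are illustrative placeholders.

```python
# Minimal sketch of the 0.0.96 ClassificationModel surface, inferred from the diff below.
# "my-sentiment-model" and the evaluation rows are illustrative placeholders.
from datasets import Dataset

from orca_sdk.classification_model import ClassificationModel

model = ClassificationModel.open("my-sentiment-model")

# predict() now takes a single save_telemetry mode ("off" | "on" | "sync" | "async")
# instead of the old save_telemetry / save_telemetry_synchronously booleans.
prediction = model.predict("I am happy", tags={"smoke-test"}, save_telemetry="sync")
print(prediction.label, prediction.confidence, prediction.anomaly_score)

# update_metadata() is replaced by set(), which also toggles the new locked flag;
# lock() and unlock() are thin wrappers around it.
model.set(description="Routes support tickets by sentiment")
model.lock()

# evaluate() now returns ClassificationMetrics (or a Job handle when called with
# background=True on a Datasource) instead of a plain dict.
eval_data = Dataset.from_dict({"value": ["I am happy", "I am sad"], "label": [1, 0]})
metrics = model.evaluate(eval_data, value_column="value", label_column="label")
print(metrics)
```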
@@ -5,37 +5,29 @@ import os
 from contextlib import contextmanager
 from datetime import datetime
 from typing import Any, Generator, Iterable, Literal, cast, overload
-from uuid import UUID, uuid4
+from uuid import UUID
 
-import numpy as np
-
-import numpy as np
 from datasets import Dataset
-from sklearn.metrics import (
-    accuracy_score,
-    auc,
-    f1_score,
-    roc_auc_score,
-)
 
 from ._generated_api_client.api import (
-    create_evaluation,
-    create_model,
-    delete_model,
-    get_evaluation,
-    get_model,
-    list_models,
+    create_classification_model_gpu,
+    delete_classification_model,
+    evaluate_classification_model,
+    get_classification_model,
+    get_classification_model_evaluation,
+    list_classification_models,
     list_predictions,
-    predict_gpu,
+    predict_label_gpu,
     record_prediction_feedback,
-    update_model,
+    update_classification_model,
 )
 from ._generated_api_client.models import (
-    ClassificationEvaluationResult,
-    CreateRACModelRequest,
-    EvaluationRequest,
+    ClassificationEvaluationRequest,
+    ClassificationModelMetadata,
+    ClassificationPredictionRequest,
+    CreateClassificationModelRequest,
+    LabelPredictionWithMemoriesAndFeedback,
     ListPredictionsRequest,
-    PrecisionRecallCurve,
 )
 from ._generated_api_client.models import (
     PredictionSortItemItemType0 as PredictionSortColumns,
@@ -44,18 +36,21 @@ from ._generated_api_client.models import (
     PredictionSortItemItemType1 as PredictionSortDirection,
 )
 from ._generated_api_client.models import (
+    PredictiveModelUpdate,
     RACHeadType,
-    RACModelMetadata,
-    RACModelUpdate,
-    ROCCurve,
 )
-from ._generated_api_client.models.prediction_request import PredictionRequest
-from ._shared.metrics import calculate_pr_curve, calculate_roc_curve
+from ._generated_api_client.types import UNSET as CLIENT_UNSET
+from ._shared.metrics import ClassificationMetrics, calculate_classification_metrics
 from ._utils.common import UNSET, CreateMode, DropMode
-from ._utils.task import wait_for_task
 from .datasource import Datasource
-from .memoryset import LabeledMemoryset
-from .telemetry import LabelPrediction, _parse_feedback
+from .job import Job
+from .memoryset import (
+    FilterItem,
+    FilterItemTuple,
+    LabeledMemoryset,
+    _parse_filter_item_from_tuple,
+)
+from .telemetry import ClassificationPrediction, _parse_feedback
 
 
 class ClassificationModel:
@@ -72,6 +67,7 @@ class ClassificationModel:
         memory_lookup_count: Number of memories the model uses for each prediction
         weigh_memories: If using a KNN head, whether the model weighs memories by their lookup score
         min_memory_weight: If using a KNN head, minimum lookup score memories have to be over to not be ignored
+        locked: Whether the model is locked to prevent accidental deletion
         created_at: When the model was created
     """
 
@@ -85,9 +81,10 @@ class ClassificationModel:
     weigh_memories: bool | None
     min_memory_weight: float | None
     version: int
+    locked: bool
    created_at: datetime
 
-    def __init__(self, metadata: RACModelMetadata):
+    def __init__(self, metadata: ClassificationModelMetadata):
         # for internal use only, do not document
         self.id = metadata.id
         self.name = metadata.name
@@ -99,10 +96,11 @@ class ClassificationModel:
         self.weigh_memories = metadata.weigh_memories
         self.min_memory_weight = metadata.min_memory_weight
         self.version = metadata.version
+        self.locked = metadata.locked
         self.created_at = metadata.created_at
 
         self._memoryset_override_id: str | None = None
-        self._last_prediction: LabelPrediction | None = None
+        self._last_prediction: ClassificationPrediction | None = None
         self._last_prediction_was_batch: bool = False
 
     def __eq__(self, other) -> bool:
@@ -120,7 +118,7 @@ class ClassificationModel:
         )
 
     @property
-    def last_prediction(self) -> LabelPrediction:
+    def last_prediction(self) -> ClassificationPrediction:
         """
         Last prediction made by the model
 
@@ -208,8 +206,8 @@ class ClassificationModel:
 
             return existing
 
-        metadata = create_model(
-            body=CreateRACModelRequest(
+        metadata = create_classification_model_gpu(
+            body=CreateClassificationModelRequest(
                 name=name,
                 memoryset_id=memoryset.id,
                 head_type=RACHeadType(head_type),
@@ -236,7 +234,7 @@ class ClassificationModel:
         Raises:
             LookupError: If the classification model does not exist
         """
-        return cls(get_model(name))
+        return cls(get_classification_model(name))
 
     @classmethod
     def exists(cls, name_or_id: str) -> bool:
@@ -263,7 +261,7 @@ class ClassificationModel:
         Returns:
            List of handles to all classification models in the OrcaCloud
         """
-        return [cls(metadata) for metadata in list_models()]
+        return [cls(metadata) for metadata in list_classification_models()]
 
     @classmethod
     def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
@@ -282,7 +280,7 @@ class ClassificationModel:
             LookupError: If the classification model does not exist and if_not_exists is `"error"`
         """
         try:
-            delete_model(name_or_id)
+            delete_classification_model(name_or_id)
             logging.info(f"Deleted model {name_or_id}")
         except LookupError:
             if if_not_exists == "error":
@@ -290,34 +288,53 @@ class ClassificationModel:
 
     def refresh(self):
         """Refresh the model data from the OrcaCloud"""
-        self.__dict__.update(ClassificationModel.open(self.name).__dict__)
+        self.__dict__.update(self.open(self.name).__dict__)
 
-    def update_metadata(self, *, description: str | None = UNSET) -> None:
+    def set(self, *, description: str | None = UNSET, locked: bool = UNSET) -> None:
         """
-        Update editable classification model metadata properties.
+        Update editable attributes of the model.
+
+        Note:
+            If a field is not provided, it will default to [UNSET][orca_sdk.UNSET] and not be updated.
 
         Params:
-            description: Value to set for the description, defaults to `[UNSET]` if not provided.
+            description: Value to set for the description
+            locked: Value to set for the locked status
 
         Examples:
             Update the description:
-            >>> model.update(description="New description")
+            >>> model.set(description="New description")
 
             Remove description:
-            >>> model.update(description=None)
+            >>> model.set(description=None)
+
+            Lock the model:
+            >>> model.set(locked=True)
         """
-        update_model(self.id, body=RACModelUpdate(description=description))
+        update_data = PredictiveModelUpdate(
+            description=CLIENT_UNSET if description is UNSET else description,
+            locked=CLIENT_UNSET if locked is UNSET else locked,
+        )
+        update_classification_model(self.id, body=update_data)
         self.refresh()
 
+    def lock(self) -> None:
+        """Lock the model to prevent accidental deletion"""
+        self.set(locked=True)
+
+    def unlock(self) -> None:
+        """Unlock the model to allow deletion"""
+        self.set(locked=False)
+
     @overload
     def predict(
         self,
         value: list[str],
         expected_labels: list[int] | None = None,
-        tags: set[str] = set(),
-        save_telemetry: bool = True,
-        save_telemetry_synchronously: bool = False,
-    ) -> list[LabelPrediction]:
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: Literal["off", "on", "sync", "async"] = "on",
+    ) -> list[ClassificationPrediction]:
         pass
 
     @overload
@@ -325,20 +342,20 @@ class ClassificationModel:
         self,
         value: str,
         expected_labels: int | None = None,
-        tags: set[str] = set(),
-        save_telemetry: bool = True,
-        save_telemetry_synchronously: bool = False,
-    ) -> LabelPrediction:
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: Literal["off", "on", "sync", "async"] = "on",
+    ) -> ClassificationPrediction:
         pass
 
     def predict(
         self,
         value: list[str] | str,
         expected_labels: list[int] | int | None = None,
-        tags: set[str] = set(),
-        save_telemetry: bool = True,
-        save_telemetry_synchronously: bool = False,
-    ) -> list[LabelPrediction] | LabelPrediction:
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: Literal["off", "on", "sync", "async"] = "on",
+    ) -> list[ClassificationPrediction] | ClassificationPrediction:
         """
         Predict label(s) for the given input value(s) grounded in similar memories
 
@@ -346,10 +363,12 @@ class ClassificationModel:
             value: Value(s) to get predict the labels of
             expected_labels: Expected label(s) for the given input to record for model evaluation
             tags: Tags to add to the prediction(s)
-            save_telemetry: Whether to enable telemetry for the prediction(s)
-            save_telemetry_synchronously: Whether to save telemetry synchronously. If `False`, telemetry will be saved
-                asynchronously in the background. This may result in a delay in the telemetry being available. Please note that this
-                may be overriden by the ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY environment variable.
+            save_telemetry: Whether to save telemetry for the prediction(s). One of
+                * `"off"`: Do not save telemetry
+                * `"on"`: Save telemetry asynchronously unless the `ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY`
+                    environment variable is set.
+                * `"sync"`: Save telemetry synchronously
+                * `"async"`: Save telemetry asynchronously
 
         Returns:
             Label prediction or list of label predictions
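
The `save_telemetry` mode string documented above replaces the old `save_telemetry` / `save_telemetry_synchronously` boolean pair; `predict` also gains a `filters` argument of memoryset `FilterItemTuple` entries, whose tuple layout is defined in `memoryset.py` and not shown in this diff, so it is omitted here. A hedged sketch of the telemetry modes, with the model handle and inputs as placeholders:

```python
# Illustrative only: the model name and input text are placeholders.
import os

from orca_sdk.classification_model import ClassificationModel

model = ClassificationModel.open("my-sentiment-model")

model.predict("I am happy", save_telemetry="off")   # no telemetry is recorded
model.predict("I am happy", save_telemetry="on")    # recorded asynchronously (default)
model.predict("I am happy", save_telemetry="sync")  # recorded before predict() returns

# Setting this environment variable forces synchronous persistence even for "on"/"async".
os.environ["ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY"] = "1"
model.predict("I am happy", save_telemetry="on")    # now persisted synchronously
```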
@@ -357,49 +376,51 @@ class ClassificationModel:
         Examples:
             Predict the label for a single value:
             >>> prediction = model.predict("I am happy", tags={"test"})
-            LabelPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
+            ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
 
             Predict the labels for a list of values:
             >>> predictions = model.predict(["I am happy", "I am sad"], expected_labels=[1, 0])
             [
-                LabelPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
-                LabelPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
+                ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
             ]
         """
 
-        if "ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY" in os.environ:
-            env_var = os.environ["ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY"]
-            logging.info(
-                f"ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY is set to {env_var} which will override the parameter save_telemetry_synchronously = {save_telemetry_synchronously}"
-            )
-            save_telemetry_synchronously = env_var.lower() == "true"
+        parsed_filters = [
+            _parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
+        ]
+
+        if not all(isinstance(filter, FilterItem) for filter in parsed_filters):
+            raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")
 
-        response = predict_gpu(
+        response = predict_label_gpu(
             self.id,
-            body=PredictionRequest(
+            body=ClassificationPredictionRequest(
                 input_values=value if isinstance(value, list) else [value],
                 memoryset_override_id=self._memoryset_override_id,
                 expected_labels=(
                     expected_labels
                     if isinstance(expected_labels, list)
-                    else [expected_labels]
-                    if expected_labels is not None
-                    else None
+                    else [expected_labels] if expected_labels is not None else None
+                ),
+                tags=list(tags or set()),
+                save_telemetry=save_telemetry != "off",
+                save_telemetry_synchronously=(
+                    os.getenv("ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY", "0") != "0" or save_telemetry == "sync"
                 ),
-                tags=list(tags),
-                save_telemetry=save_telemetry,
-                save_telemetry_synchronously=save_telemetry_synchronously,
+                filters=cast(list[FilterItem], parsed_filters),
             ),
         )
 
-        if save_telemetry and any(p.prediction_id is None for p in response):
+        if save_telemetry != "off" and any(p.prediction_id is None for p in response):
             raise RuntimeError("Failed to save prediction to database.")
 
         predictions = [
-            LabelPrediction(
+            ClassificationPrediction(
                 prediction_id=prediction.prediction_id,
                 label=prediction.label,
                 label_name=prediction.label_name,
+                score=None,
                 confidence=prediction.confidence,
                 anomaly_score=prediction.anomaly_score,
                 memoryset=self.memoryset,
@@ -420,7 +441,7 @@ class ClassificationModel:
         tag: str | None = None,
         sort: list[tuple[PredictionSortColumns, PredictionSortDirection]] = [],
         expected_label_match: bool | None = None,
-    ) -> list[LabelPrediction]:
+    ) -> list[ClassificationPrediction]:
         """
         Get a list of predictions made by this model
 
@@ -440,19 +461,19 @@ class ClassificationModel:
             Get the last 3 predictions:
             >>> predictions = model.predictions(limit=3, sort=[("timestamp", "desc")])
             [
-                LabeledPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
-                LabeledPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
-                LabeledPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am ecstatic'}),
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
+                ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am ecstatic'}),
             ]
 
 
             Get second most confident prediction:
             >>> predictions = model.predictions(sort=[("confidence", "desc")], offset=1, limit=1)
-            [LabeledPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am having a good day'})]
+            [ClassificationPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am having a good day'})]
 
             Get predictions where the expected label doesn't match the predicted label:
             >>> predictions = model.predictions(expected_label_match=False)
-            [LabeledPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy', expected_label: 0})]
+            [ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy', expected_label: 0})]
         """
         predictions = list_predictions(
             body=ListPredictionsRequest(
@@ -465,10 +486,11 @@ class ClassificationModel:
             ),
         )
         return [
-            LabelPrediction(
+            ClassificationPrediction(
                 prediction_id=prediction.prediction_id,
                 label=prediction.label,
                 label_name=prediction.label_name,
+                score=None,
                 confidence=prediction.confidence,
                 anomaly_score=prediction.anomaly_score,
                 memoryset=self.memoryset,
@@ -476,59 +498,9 @@ class ClassificationModel:
                 telemetry=prediction,
             )
             for prediction in predictions
+            if isinstance(prediction, LabelPredictionWithMemoriesAndFeedback)
         ]
 
-    def _calculate_metrics(
-        self,
-        predictions: list[LabelPrediction],
-        expected_labels: list[int],
-    ) -> ClassificationEvaluationResult:
-        targets_array = np.array(expected_labels)
-        predictions_array = np.array([p.label for p in predictions])
-
-        logits_array = np.array([p.logits for p in predictions])
-
-        f1 = float(f1_score(targets_array, predictions_array, average="weighted"))
-        accuracy = float(accuracy_score(targets_array, predictions_array))
-
-        # Only compute ROC AUC and PR AUC for binary classification
-        unique_classes = np.unique(targets_array)
-
-        pr_curve = None
-        roc_curve = None
-
-        if len(unique_classes) == 2:
-            try:
-                precisions, recalls, pr_thresholds = calculate_pr_curve(targets_array, logits_array)
-                pr_auc = float(auc(recalls, precisions))
-
-                pr_curve = PrecisionRecallCurve(
-                    precisions=precisions.tolist(),
-                    recalls=recalls.tolist(),
-                    thresholds=pr_thresholds.tolist(),
-                    auc=pr_auc,
-                )
-
-                fpr, tpr, roc_thresholds = calculate_roc_curve(targets_array, logits_array)
-                roc_auc = float(roc_auc_score(targets_array, logits_array[:, 1]))
-
-                roc_curve = ROCCurve(
-                    false_positive_rates=fpr.tolist(),
-                    true_positive_rates=tpr.tolist(),
-                    thresholds=roc_thresholds.tolist(),
-                    auc=roc_auc,
-                )
-            except ValueError as e:
-                logging.warning(f"Error calculating PR and ROC curves: {e}")
-
-        return ClassificationEvaluationResult(
-            f1_score=f1,
-            accuracy=accuracy,
-            loss=0.0,
-            precision_recall_curve=pr_curve,
-            roc_curve=roc_curve,
-        )
-
     def _evaluate_datasource(
         self,
         datasource: Datasource,
@@ -536,10 +508,11 @@ class ClassificationModel:
         label_column: str,
         record_predictions: bool,
         tags: set[str] | None,
-    ) -> dict[str, Any]:
-        response = create_evaluation(
+        background: bool = False,
+    ) -> ClassificationMetrics | Job[ClassificationMetrics]:
+        response = evaluate_classification_model(
             self.id,
-            body=EvaluationRequest(
+            body=ClassificationEvaluationRequest(
                 datasource_id=datasource.id,
                 datasource_label_column=label_column,
                 datasource_value_column=value_column,
@@ -548,10 +521,13 @@ class ClassificationModel:
                 telemetry_tags=list(tags) if tags else None,
             ),
         )
-        wait_for_task(response.task_id, description="Running evaluation")
-        response = get_evaluation(self.id, UUID(response.task_id))
-        assert response.result is not None
-        return response.result.to_dict()
+
+        job = Job(
+            response.task_id,
+            lambda: (r := get_classification_model_evaluation(self.id, UUID(response.task_id)).result)
+            and ClassificationMetrics(**r.to_dict()),
+        )
+        return job if background else job.result()
 
     def _evaluate_dataset(
         self,
@@ -561,34 +537,64 @@ class ClassificationModel:
         record_predictions: bool,
         tags: set[str],
         batch_size: int,
-    ) -> dict[str, Any]:
-        predictions = []
-        expected_labels = []
-
-        for i in range(0, len(dataset), batch_size):
-            batch = dataset[i : i + batch_size]
-            predictions.extend(
-                self.predict(
-                    batch[value_column],
-                    expected_labels=batch[label_column],
-                    tags=tags,
-                    save_telemetry=record_predictions,
-                    save_telemetry_synchronously=(not record_predictions),
-                )
+    ) -> ClassificationMetrics:
+        predictions = [
+            prediction
+            for i in range(0, len(dataset), batch_size)
+            for prediction in self.predict(
+                dataset[i : i + batch_size][value_column],
+                expected_labels=dataset[i : i + batch_size][label_column],
+                tags=tags,
+                save_telemetry="sync" if record_predictions else "off",
             )
-            expected_labels.extend(batch[label_column])
+        ]
+
+        return calculate_classification_metrics(
+            expected_labels=dataset[label_column],
+            logits=[p.logits for p in predictions],
+            anomaly_scores=[p.anomaly_score for p in predictions],
+            include_curves=True,
+        )
+
+    @overload
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
+        *,
+        value_column: str = "value",
+        label_column: str = "label",
+        record_predictions: bool = False,
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
+        background: Literal[True],
+    ) -> Job[ClassificationMetrics]:
+        pass
 
-        return self._calculate_metrics(predictions, expected_labels).to_dict()
+    @overload
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
+        *,
+        value_column: str = "value",
+        label_column: str = "label",
+        record_predictions: bool = False,
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
+        background: Literal[False] = False,
+    ) -> ClassificationMetrics:
+        pass
 
     def evaluate(
         self,
        data: Datasource | Dataset,
+        *,
         value_column: str = "value",
         label_column: str = "label",
         record_predictions: bool = False,
         tags: set[str] = {"evaluation"},
         batch_size: int = 100,
-    ) -> dict[str, Any]:
+        background: bool = False,
+    ) -> ClassificationMetrics | Job[ClassificationMetrics]:
         """
         Evaluate the classification model on a given dataset or datasource
 
@@ -596,21 +602,23 @@ class ClassificationModel:
             data: Dataset or Datasource to evaluate the model on
             value_column: Name of the column that contains the input values to the model
             label_column: Name of the column containing the expected labels
-            record_predictions: Whether to record [`LabelPrediction`][orca_sdk.telemetry.LabelPrediction]s for analysis
-            tags: Optional tags to add to the recorded [`LabelPrediction`][orca_sdk.telemetry.LabelPrediction]s
+            record_predictions: Whether to record [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s for analysis
+            tags: Optional tags to add to the recorded [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s
             batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)
+            background: Whether to run the operation in the background and return a job handle
 
         Returns:
-            Dictionary with evaluation metrics, including anomaly score statistics (mean, median, variance)
+            EvaluationResult containing metrics including accuracy, F1 score, ROC AUC, PR AUC, and anomaly score statistics
 
         Examples:
-            Evaluate using a Datasource:
             >>> model.evaluate(datasource, value_column="text", label_column="airline_sentiment")
-            { "f1_score": 0.85, "roc_auc": 0.85, "pr_auc": 0.85, "accuracy": 0.85, "loss": 0.35, ... }
-
-            Evaluate using a Dataset:
-            >>> model.evaluate(dataset, value_column="text", label_column="sentiment")
-            { "f1_score": 0.85, "roc_auc": 0.85, "pr_auc": 0.85, "accuracy": 0.85, "loss": 0.35, ... }
+            ClassificationMetrics({
+                accuracy: 0.8500,
+                f1_score: 0.8500,
+                roc_auc: 0.8500,
+                pr_auc: 0.8500,
+                anomaly_score: 0.3500 ± 0.0500,
+            })
         """
         if isinstance(data, Datasource):
             return self._evaluate_datasource(
@@ -619,8 +627,9 @@ class ClassificationModel:
                 label_column=label_column,
                 record_predictions=record_predictions,
                 tags=tags,
+                background=background,
             )
-        else:
+        elif isinstance(data, Dataset):
            return self._evaluate_dataset(
                dataset=data,
                value_column=value_column,
@@ -629,6 +638,8 @@ class ClassificationModel:
                 tags=tags,
                 batch_size=batch_size,
             )
+        else:
+            raise ValueError(f"Invalid data type: {type(data)}")
 
     def finetune(self, datasource: Datasource):
         # do not document until implemented
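
The `_evaluate_datasource` rewrite above swaps the blocking `wait_for_task` helper for the new `Job` handle, and `evaluate` forwards a `background` flag on the Datasource path. A hedged sketch of what that looks like to a caller; the datasource handle is a placeholder, and only `evaluate(..., background=True)` returning a `Job` whose `result()` blocks is assumed from the code above:

```python
# Illustrative only: assumes an existing orca_sdk Datasource handle; how it is
# created is not shown in this diff.
from orca_sdk.classification_model import ClassificationModel
from orca_sdk.datasource import Datasource

model = ClassificationModel.open("my-sentiment-model")
datasource: Datasource = ...  # placeholder for an existing datasource handle

# Kick off the evaluation task server-side and return immediately with a Job handle.
job = model.evaluate(datasource, value_column="text", label_column="label", background=True)

# ... do other work while the task runs ...

metrics = job.result()  # blocks until the task finishes, then returns ClassificationMetrics
print(metrics)
```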