orca-sdk 0.0.93__py3-none-any.whl → 0.0.95__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (125)
  1. orca_sdk/__init__.py +13 -4
  2. orca_sdk/_generated_api_client/api/__init__.py +84 -34
  3. orca_sdk/_generated_api_client/api/classification_model/create_classification_model_classification_model_post.py +170 -0
  4. orca_sdk/_generated_api_client/api/classification_model/{get_model_classification_model_name_or_id_get.py → delete_classification_model_classification_model_name_or_id_delete.py} +20 -20
  5. orca_sdk/_generated_api_client/api/classification_model/{delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py → delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py} +4 -4
  6. orca_sdk/_generated_api_client/api/classification_model/{create_evaluation_classification_model_model_name_or_id_evaluation_post.py → evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py} +14 -14
  7. orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +156 -0
  8. orca_sdk/_generated_api_client/api/classification_model/{get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py → get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py} +16 -16
  9. orca_sdk/_generated_api_client/api/classification_model/{list_evaluations_classification_model_model_name_or_id_evaluation_get.py → list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py} +16 -16
  10. orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +127 -0
  11. orca_sdk/_generated_api_client/api/classification_model/{predict_gpu_classification_model_name_or_id_prediction_post.py → predict_label_gpu_classification_model_name_or_id_prediction_post.py} +14 -14
  12. orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +183 -0
  13. orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +172 -0
  14. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +22 -22
  15. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +22 -22
  16. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +38 -16
  17. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +29 -12
  18. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +12 -12
  19. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +17 -14
  20. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +72 -19
  21. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +31 -12
  22. orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +49 -20
  23. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +38 -16
  24. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +54 -29
  25. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +44 -26
  26. orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +22 -22
  27. orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
  28. orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +150 -0
  29. orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
  30. orca_sdk/_generated_api_client/api/{classification_model/create_model_classification_model_post.py → regression_model/create_regression_model_regression_model_post.py} +27 -27
  31. orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
  32. orca_sdk/_generated_api_client/api/{classification_model/delete_model_classification_model_name_or_id_delete.py → regression_model/delete_regression_model_regression_model_name_or_id_delete.py} +5 -5
  33. orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +183 -0
  34. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +170 -0
  35. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +156 -0
  36. orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +161 -0
  37. orca_sdk/_generated_api_client/api/{classification_model/list_models_classification_model_get.py → regression_model/list_regression_models_regression_model_get.py} +17 -17
  38. orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +190 -0
  39. orca_sdk/_generated_api_client/api/{classification_model/update_model_classification_model_name_or_id_patch.py → regression_model/update_regression_model_regression_model_name_or_id_patch.py} +27 -27
  40. orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +156 -0
  41. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +60 -10
  42. orca_sdk/_generated_api_client/api/telemetry/count_predictions_telemetry_prediction_count_post.py +10 -10
  43. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +35 -12
  44. orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +20 -12
  45. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +35 -12
  46. orca_sdk/_generated_api_client/models/__init__.py +90 -24
  47. orca_sdk/_generated_api_client/models/base_score_prediction_result.py +108 -0
  48. orca_sdk/_generated_api_client/models/{evaluation_request.py → classification_evaluation_request.py} +13 -45
  49. orca_sdk/_generated_api_client/models/{classification_evaluation_result.py → classification_metrics.py} +106 -56
  50. orca_sdk/_generated_api_client/models/{rac_model_metadata.py → classification_model_metadata.py} +51 -43
  51. orca_sdk/_generated_api_client/models/{prediction_request.py → classification_prediction_request.py} +31 -6
  52. orca_sdk/_generated_api_client/models/{clone_labeled_memoryset_request.py → clone_memoryset_request.py} +5 -5
  53. orca_sdk/_generated_api_client/models/column_info.py +31 -0
  54. orca_sdk/_generated_api_client/models/count_predictions_request.py +195 -0
  55. orca_sdk/_generated_api_client/models/{create_rac_model_request.py → create_classification_model_request.py} +25 -57
  56. orca_sdk/_generated_api_client/models/{create_labeled_memoryset_request.py → create_memoryset_request.py} +73 -56
  57. orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +66 -0
  58. orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +13 -0
  59. orca_sdk/_generated_api_client/models/create_regression_model_request.py +137 -0
  60. orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +187 -0
  61. orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +10 -0
  62. orca_sdk/_generated_api_client/models/evaluation_response.py +22 -9
  63. orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +140 -0
  64. orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +140 -0
  65. orca_sdk/_generated_api_client/models/http_validation_error.py +86 -0
  66. orca_sdk/_generated_api_client/models/list_predictions_request.py +62 -0
  67. orca_sdk/_generated_api_client/models/memory_type.py +9 -0
  68. orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py +0 -20
  69. orca_sdk/_generated_api_client/models/{labeled_memoryset_metadata.py → memoryset_metadata.py} +73 -13
  70. orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +55 -0
  71. orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +13 -0
  72. orca_sdk/_generated_api_client/models/{labeled_memoryset_update.py → memoryset_update.py} +19 -31
  73. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +1 -0
  74. orca_sdk/_generated_api_client/models/{paginated_labeled_memory_with_feedback_metrics.py → paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py} +37 -10
  75. orca_sdk/_generated_api_client/models/{precision_recall_curve.py → pr_curve.py} +5 -13
  76. orca_sdk/_generated_api_client/models/{rac_model_update.py → predictive_model_update.py} +14 -5
  77. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +11 -1
  78. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +5 -0
  79. orca_sdk/_generated_api_client/models/rar_head_type.py +8 -0
  80. orca_sdk/_generated_api_client/models/regression_evaluation_request.py +148 -0
  81. orca_sdk/_generated_api_client/models/regression_metrics.py +172 -0
  82. orca_sdk/_generated_api_client/models/regression_model_metadata.py +177 -0
  83. orca_sdk/_generated_api_client/models/regression_prediction_request.py +195 -0
  84. orca_sdk/_generated_api_client/models/roc_curve.py +0 -8
  85. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +196 -0
  86. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +68 -0
  87. orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +252 -0
  88. orca_sdk/_generated_api_client/models/scored_memory.py +172 -0
  89. orca_sdk/_generated_api_client/models/scored_memory_insert.py +128 -0
  90. orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +68 -0
  91. orca_sdk/_generated_api_client/models/scored_memory_lookup.py +180 -0
  92. orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +68 -0
  93. orca_sdk/_generated_api_client/models/scored_memory_metadata.py +68 -0
  94. orca_sdk/_generated_api_client/models/scored_memory_update.py +171 -0
  95. orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +68 -0
  96. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +193 -0
  97. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +68 -0
  98. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +68 -0
  99. orca_sdk/_generated_api_client/models/update_prediction_request.py +20 -0
  100. orca_sdk/_generated_api_client/models/validation_error.py +99 -0
  101. orca_sdk/_shared/__init__.py +9 -1
  102. orca_sdk/_shared/metrics.py +257 -87
  103. orca_sdk/_shared/metrics_test.py +136 -77
  104. orca_sdk/_utils/data_parsing.py +0 -3
  105. orca_sdk/_utils/data_parsing_test.py +0 -3
  106. orca_sdk/_utils/prediction_result_ui.py +55 -23
  107. orca_sdk/classification_model.py +184 -174
  108. orca_sdk/classification_model_test.py +178 -142
  109. orca_sdk/conftest.py +77 -26
  110. orca_sdk/datasource.py +34 -0
  111. orca_sdk/datasource_test.py +9 -1
  112. orca_sdk/embedding_model.py +136 -14
  113. orca_sdk/embedding_model_test.py +10 -6
  114. orca_sdk/job.py +329 -0
  115. orca_sdk/job_test.py +48 -0
  116. orca_sdk/memoryset.py +882 -161
  117. orca_sdk/memoryset_test.py +58 -23
  118. orca_sdk/regression_model.py +647 -0
  119. orca_sdk/regression_model_test.py +338 -0
  120. orca_sdk/telemetry.py +225 -106
  121. orca_sdk/telemetry_test.py +34 -30
  122. {orca_sdk-0.0.93.dist-info → orca_sdk-0.0.95.dist-info}/METADATA +2 -4
  123. {orca_sdk-0.0.93.dist-info → orca_sdk-0.0.95.dist-info}/RECORD +124 -74
  124. orca_sdk/_utils/task.py +0 -73
  125. {orca_sdk-0.0.93.dist-info → orca_sdk-0.0.95.dist-info}/WHEEL +0 -0
orca_sdk/classification_model.py
@@ -5,37 +5,29 @@ import os
 from contextlib import contextmanager
 from datetime import datetime
 from typing import Any, Generator, Iterable, Literal, cast, overload
-from uuid import UUID, uuid4
+from uuid import UUID
 
-import numpy as np
-
-import numpy as np
 from datasets import Dataset
-from sklearn.metrics import (
-    accuracy_score,
-    auc,
-    f1_score,
-    roc_auc_score,
-)
 
 from ._generated_api_client.api import (
-    create_evaluation,
-    create_model,
-    delete_model,
-    get_evaluation,
-    get_model,
-    list_models,
+    create_classification_model,
+    delete_classification_model,
+    evaluate_classification_model,
+    get_classification_model,
+    get_classification_model_evaluation,
+    list_classification_models,
     list_predictions,
-    predict_gpu,
+    predict_label_gpu,
     record_prediction_feedback,
-    update_model,
+    update_classification_model,
 )
 from ._generated_api_client.models import (
-    ClassificationEvaluationResult,
-    CreateRACModelRequest,
-    EvaluationRequest,
+    ClassificationEvaluationRequest,
+    ClassificationModelMetadata,
+    ClassificationPredictionRequest,
+    CreateClassificationModelRequest,
+    LabelPredictionWithMemoriesAndFeedback,
     ListPredictionsRequest,
-    PrecisionRecallCurve,
 )
 from ._generated_api_client.models import (
     PredictionSortItemItemType0 as PredictionSortColumns,
@@ -43,19 +35,19 @@ from ._generated_api_client.models import (
 from ._generated_api_client.models import (
     PredictionSortItemItemType1 as PredictionSortDirection,
 )
-from ._generated_api_client.models import (
-    RACHeadType,
-    RACModelMetadata,
-    RACModelUpdate,
-    ROCCurve,
-)
-from ._generated_api_client.models.prediction_request import PredictionRequest
-from ._shared.metrics import calculate_pr_curve, calculate_roc_curve
+from ._generated_api_client.models import PredictiveModelUpdate, RACHeadType
+from ._generated_api_client.types import UNSET as CLIENT_UNSET
+from ._shared.metrics import ClassificationMetrics, calculate_classification_metrics
 from ._utils.common import UNSET, CreateMode, DropMode
-from ._utils.task import wait_for_task
 from .datasource import Datasource
-from .memoryset import LabeledMemoryset
-from .telemetry import LabelPrediction, _parse_feedback
+from .job import Job
+from .memoryset import (
+    FilterItem,
+    FilterItemTuple,
+    LabeledMemoryset,
+    _parse_filter_item_from_tuple,
+)
+from .telemetry import ClassificationPrediction, _parse_feedback
 
 
 class ClassificationModel:
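
The import changes above rename the user-facing prediction class: `LabelPrediction` from 0.0.93 is now `ClassificationPrediction`, still defined in the telemetry module (see the docstring cross-references later in this diff). A minimal sketch of the corresponding change in downstream code, assuming only the class rename:

    # 0.0.93
    # from orca_sdk.telemetry import LabelPrediction

    # 0.0.95 -- same module, renamed class
    from orca_sdk.telemetry import ClassificationPrediction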
@@ -72,6 +64,7 @@ class ClassificationModel:
         memory_lookup_count: Number of memories the model uses for each prediction
         weigh_memories: If using a KNN head, whether the model weighs memories by their lookup score
         min_memory_weight: If using a KNN head, minimum lookup score memories have to be over to not be ignored
+        locked: Whether the model is locked to prevent accidental deletion
         created_at: When the model was created
     """
 
@@ -85,9 +78,10 @@ class ClassificationModel:
     weigh_memories: bool | None
     min_memory_weight: float | None
     version: int
+    locked: bool
     created_at: datetime
 
-    def __init__(self, metadata: RACModelMetadata):
+    def __init__(self, metadata: ClassificationModelMetadata):
         # for internal use only, do not document
         self.id = metadata.id
         self.name = metadata.name
@@ -99,10 +93,11 @@ class ClassificationModel:
         self.weigh_memories = metadata.weigh_memories
         self.min_memory_weight = metadata.min_memory_weight
         self.version = metadata.version
+        self.locked = metadata.locked
         self.created_at = metadata.created_at
 
         self._memoryset_override_id: str | None = None
-        self._last_prediction: LabelPrediction | None = None
+        self._last_prediction: ClassificationPrediction | None = None
         self._last_prediction_was_batch: bool = False
 
     def __eq__(self, other) -> bool:
@@ -120,7 +115,7 @@ class ClassificationModel:
         )
 
     @property
-    def last_prediction(self) -> LabelPrediction:
+    def last_prediction(self) -> ClassificationPrediction:
         """
         Last prediction made by the model
 
@@ -208,8 +203,8 @@ class ClassificationModel:
 
             return existing
 
-        metadata = create_model(
-            body=CreateRACModelRequest(
+        metadata = create_classification_model(
+            body=CreateClassificationModelRequest(
                 name=name,
                 memoryset_id=memoryset.id,
                 head_type=RACHeadType(head_type),
@@ -236,7 +231,7 @@ class ClassificationModel:
         Raises:
             LookupError: If the classification model does not exist
         """
-        return cls(get_model(name))
+        return cls(get_classification_model(name))
 
     @classmethod
     def exists(cls, name_or_id: str) -> bool:
@@ -263,7 +258,7 @@ class ClassificationModel:
         Returns:
             List of handles to all classification models in the OrcaCloud
         """
-        return [cls(metadata) for metadata in list_models()]
+        return [cls(metadata) for metadata in list_classification_models()]
 
     @classmethod
     def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
@@ -282,7 +277,7 @@ class ClassificationModel:
             LookupError: If the classification model does not exist and if_not_exists is `"error"`
         """
         try:
-            delete_model(name_or_id)
+            delete_classification_model(name_or_id)
             logging.info(f"Deleted model {name_or_id}")
         except LookupError:
             if if_not_exists == "error":
@@ -290,34 +285,53 @@ class ClassificationModel:
 
     def refresh(self):
         """Refresh the model data from the OrcaCloud"""
-        self.__dict__.update(ClassificationModel.open(self.name).__dict__)
+        self.__dict__.update(self.open(self.name).__dict__)
 
-    def update_metadata(self, *, description: str | None = UNSET) -> None:
+    def set(self, *, description: str | None = UNSET, locked: bool = UNSET) -> None:
         """
-        Update editable classification model metadata properties.
+        Update editable attributes of the model.
+
+        Note:
+            If a field is not provided, it will default to [UNSET][orca_sdk.UNSET] and not be updated.
 
         Params:
-            description: Value to set for the description, defaults to `[UNSET]` if not provided.
+            description: Value to set for the description
+            locked: Value to set for the locked status
 
         Examples:
            Update the description:
-            >>> model.update(description="New description")
+            >>> model.set(description="New description")
 
            Remove description:
-            >>> model.update(description=None)
+            >>> model.set(description=None)
+
+            Lock the model:
+            >>> model.set(locked=True)
        """
-        update_model(self.id, body=RACModelUpdate(description=description))
+        update_data = PredictiveModelUpdate(
+            description=CLIENT_UNSET if description is UNSET else description,
+            locked=CLIENT_UNSET if locked is UNSET else locked,
+        )
+        update_classification_model(self.id, body=update_data)
         self.refresh()
 
+    def lock(self) -> None:
+        """Lock the model to prevent accidental deletion"""
+        self.set(locked=True)
+
+    def unlock(self) -> None:
+        """Unlock the model to allow deletion"""
+        self.set(locked=False)
+
     @overload
     def predict(
         self,
         value: list[str],
         expected_labels: list[int] | None = None,
-        tags: set[str] = set(),
-        save_telemetry: bool = True,
-        save_telemetry_synchronously: bool = False,
-    ) -> list[LabelPrediction]:
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: Literal["off", "on", "sync", "async"] = "on",
+    ) -> list[ClassificationPrediction]:
         pass
 
     @overload
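
This hunk replaces `update_metadata` with a more general `set` method and adds `lock`/`unlock` helpers on top of it. A short usage sketch based only on the signatures and docstring examples above (the model name is illustrative, and the exact behavior of `drop` on a locked model is not shown in this diff):

    from orca_sdk.classification_model import ClassificationModel

    model = ClassificationModel.open("sentiment-routing")
    model.set(description="Routes support tickets", locked=True)  # UNSET fields are left untouched
    model.unlock()                                                 # shorthand for model.set(locked=False)
    ClassificationModel.drop(model.name)                           # deletion allowed again once unlocked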
@@ -325,20 +339,20 @@ class ClassificationModel:
         self,
         value: str,
         expected_labels: int | None = None,
-        tags: set[str] = set(),
-        save_telemetry: bool = True,
-        save_telemetry_synchronously: bool = False,
-    ) -> LabelPrediction:
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: Literal["off", "on", "sync", "async"] = "on",
+    ) -> ClassificationPrediction:
         pass
 
     def predict(
         self,
         value: list[str] | str,
         expected_labels: list[int] | int | None = None,
-        tags: set[str] = set(),
-        save_telemetry: bool = True,
-        save_telemetry_synchronously: bool = False,
-    ) -> list[LabelPrediction] | LabelPrediction:
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: Literal["off", "on", "sync", "async"] = "on",
+    ) -> list[ClassificationPrediction] | ClassificationPrediction:
         """
         Predict label(s) for the given input value(s) grounded in similar memories
 
@@ -346,10 +360,12 @@ class ClassificationModel:
             value: Value(s) to get predict the labels of
             expected_labels: Expected label(s) for the given input to record for model evaluation
             tags: Tags to add to the prediction(s)
-            save_telemetry: Whether to enable telemetry for the prediction(s)
-            save_telemetry_synchronously: Whether to save telemetry synchronously. If `False`, telemetry will be saved
-                asynchronously in the background. This may result in a delay in the telemetry being available. Please note that this
-                may be overriden by the ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY environment variable.
+            save_telemetry: Whether to save telemetry for the prediction(s). One of
+                * `"off"`: Do not save telemetry
+                * `"on"`: Save telemetry asynchronously unless the `ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY`
+                    environment variable is set.
+                * `"sync"`: Save telemetry synchronously
+                * `"async"`: Save telemetry asynchronously
 
         Returns:
             Label prediction or list of label predictions
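
The two booleans from 0.0.93 (`save_telemetry`, `save_telemetry_synchronously`) collapse into a single `save_telemetry` literal. A sketch of the new call styles, reusing the model handle from the earlier sketch (input values and tags are illustrative):

    # default: telemetry is saved asynchronously in the background
    prediction = model.predict("I am happy", tags={"smoke-test"})

    # block until telemetry is persisted, e.g. in tests that immediately call model.predictions()
    prediction = model.predict("I am happy", save_telemetry="sync")

    # skip telemetry entirely for throwaway calls
    predictions = model.predict(["I am happy", "I am sad"], save_telemetry="off")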
@@ -357,26 +373,26 @@ class ClassificationModel:
         Examples:
             Predict the label for a single value:
             >>> prediction = model.predict("I am happy", tags={"test"})
-            LabelPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
+            ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
 
             Predict the labels for a list of values:
             >>> predictions = model.predict(["I am happy", "I am sad"], expected_labels=[1, 0])
             [
-                LabelPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
-                LabelPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
+                ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
             ]
         """
 
-        if "ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY" in os.environ:
-            env_var = os.environ["ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY"]
-            logging.info(
-                f"ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY is set to {env_var} which will override the parameter save_telemetry_synchronously = {save_telemetry_synchronously}"
-            )
-            save_telemetry_synchronously = env_var.lower() == "true"
+        parsed_filters = [
+            _parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
+        ]
 
-        response = predict_gpu(
+        if not all(isinstance(filter, FilterItem) for filter in parsed_filters):
+            raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")
+
+        response = predict_label_gpu(
             self.id,
-            body=PredictionRequest(
+            body=ClassificationPredictionRequest(
                 input_values=value if isinstance(value, list) else [value],
                 memoryset_override_id=self._memoryset_override_id,
                 expected_labels=(
@@ -384,27 +400,32 @@ class ClassificationModel:
                     if isinstance(expected_labels, list)
                     else [expected_labels] if expected_labels is not None else None
                 ),
-                tags=list(tags),
-                save_telemetry=save_telemetry,
-                save_telemetry_synchronously=save_telemetry_synchronously,
+                tags=list(tags or set()),
+                save_telemetry=save_telemetry != "off",
+                save_telemetry_synchronously=(
+                    os.getenv("ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY", "0") != "0" or save_telemetry == "sync"
+                ),
+                filters=cast(list[FilterItem], parsed_filters),
             ),
         )
 
-        if save_telemetry and any(p.prediction_id is None for p in response):
+        if save_telemetry != "off" and any(p.prediction_id is None for p in response):
             raise RuntimeError("Failed to save prediction to database.")
 
         predictions = [
-            LabelPrediction(
+            ClassificationPrediction(
                 prediction_id=prediction.prediction_id,
                 label=prediction.label,
                 label_name=prediction.label_name,
+                score=None,
                 confidence=prediction.confidence,
                 anomaly_score=prediction.anomaly_score,
                 memoryset=self.memoryset,
                 model=self,
                 logits=prediction.logits,
+                input_value=input_value,
             )
-            for prediction in response
+            for prediction, input_value in zip(response, value if isinstance(value, list) else [value])
         ]
         self._last_prediction_was_batch = isinstance(value, list)
         self._last_prediction = predictions[-1]
@@ -417,7 +438,7 @@ class ClassificationModel:
         tag: str | None = None,
         sort: list[tuple[PredictionSortColumns, PredictionSortDirection]] = [],
         expected_label_match: bool | None = None,
-    ) -> list[LabelPrediction]:
+    ) -> list[ClassificationPrediction]:
         """
         Get a list of predictions made by this model
 
@@ -437,19 +458,19 @@ class ClassificationModel:
             Get the last 3 predictions:
             >>> predictions = model.predictions(limit=3, sort=[("timestamp", "desc")])
             [
-                LabeledPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
-                LabeledPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
-                LabeledPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am ecstatic'}),
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
+                ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am ecstatic'}),
             ]
 
 
             Get second most confident prediction:
             >>> predictions = model.predictions(sort=[("confidence", "desc")], offset=1, limit=1)
-            [LabeledPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am having a good day'})]
+            [ClassificationPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am having a good day'})]
 
             Get predictions where the expected label doesn't match the predicted label:
             >>> predictions = model.predictions(expected_label_match=False)
-            [LabeledPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy', expected_label: 0})]
+            [ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy', expected_label: 0})]
         """
         predictions = list_predictions(
             body=ListPredictionsRequest(
@@ -462,10 +483,11 @@ class ClassificationModel:
             ),
         )
         return [
-            LabelPrediction(
+            ClassificationPrediction(
                 prediction_id=prediction.prediction_id,
                 label=prediction.label,
                 label_name=prediction.label_name,
+                score=None,
                 confidence=prediction.confidence,
                 anomaly_score=prediction.anomaly_score,
                 memoryset=self.memoryset,
@@ -473,60 +495,9 @@ class ClassificationModel:
                 telemetry=prediction,
             )
             for prediction in predictions
+            if isinstance(prediction, LabelPredictionWithMemoriesAndFeedback)
         ]
 
-    def _calculate_metrics(
-        self,
-        predictions: list[LabelPrediction],
-        expected_labels: list[int],
-    ) -> ClassificationEvaluationResult:
-
-        targets_array = np.array(expected_labels)
-        predictions_array = np.array([p.label for p in predictions])
-
-        logits_array = np.array([p.logits for p in predictions])
-
-        f1 = float(f1_score(targets_array, predictions_array, average="weighted"))
-        accuracy = float(accuracy_score(targets_array, predictions_array))
-
-        # Only compute ROC AUC and PR AUC for binary classification
-        unique_classes = np.unique(targets_array)
-
-        pr_curve = None
-        roc_curve = None
-
-        if len(unique_classes) == 2:
-            try:
-                precisions, recalls, pr_thresholds = calculate_pr_curve(targets_array, logits_array)
-                pr_auc = float(auc(recalls, precisions))
-
-                pr_curve = PrecisionRecallCurve(
-                    precisions=precisions.tolist(),
-                    recalls=recalls.tolist(),
-                    thresholds=pr_thresholds.tolist(),
-                    auc=pr_auc,
-                )
-
-                fpr, tpr, roc_thresholds = calculate_roc_curve(targets_array, logits_array)
-                roc_auc = float(roc_auc_score(targets_array, logits_array[:, 1]))
-
-                roc_curve = ROCCurve(
-                    false_positive_rates=fpr.tolist(),
-                    true_positive_rates=tpr.tolist(),
-                    thresholds=roc_thresholds.tolist(),
-                    auc=roc_auc,
-                )
-            except ValueError as e:
-                logging.warning(f"Error calculating PR and ROC curves: {e}")
-
-        return ClassificationEvaluationResult(
-            f1_score=f1,
-            accuracy=accuracy,
-            loss=0.0,
-            precision_recall_curve=pr_curve,
-            roc_curve=roc_curve,
-        )
-
     def _evaluate_datasource(
         self,
         datasource: Datasource,
@@ -534,10 +505,11 @@ class ClassificationModel:
         label_column: str,
         record_predictions: bool,
         tags: set[str] | None,
-    ) -> dict[str, Any]:
-        response = create_evaluation(
+        background: bool = False,
+    ) -> ClassificationMetrics | Job[ClassificationMetrics]:
+        response = evaluate_classification_model(
             self.id,
-            body=EvaluationRequest(
+            body=ClassificationEvaluationRequest(
                 datasource_id=datasource.id,
                 datasource_label_column=label_column,
                 datasource_value_column=value_column,
@@ -546,10 +518,13 @@ class ClassificationModel:
                 telemetry_tags=list(tags) if tags else None,
             ),
         )
-        wait_for_task(response.task_id, description="Running evaluation")
-        response = get_evaluation(self.id, UUID(response.task_id))
-        assert response.result is not None
-        return response.result.to_dict()
+
+        job = Job(
+            response.task_id,
+            lambda: (r := get_classification_model_evaluation(self.id, UUID(response.task_id)).result)
+            and ClassificationMetrics(**r.to_dict()),
+        )
+        return job if background else job.result()
 
     def _evaluate_dataset(
         self,
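
With `wait_for_task` removed, a datasource evaluation now produces a `Job` handle that resolves to `ClassificationMetrics`, and `evaluate(..., background=True)` surfaces that handle to callers. A sketch of the background path; only `Job.result()` is demonstrated by this diff, so any further polling API on `Job` is left out here:

    job = model.evaluate(datasource, value_column="text", label_column="sentiment", background=True)
    # ... do other work while the evaluation task runs in the OrcaCloud ...
    metrics = job.result()  # blocks until the evaluation task completes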
@@ -559,34 +534,64 @@ class ClassificationModel:
         record_predictions: bool,
         tags: set[str],
         batch_size: int,
-    ) -> dict[str, Any]:
-        predictions = []
-        expected_labels = []
-
-        for i in range(0, len(dataset), batch_size):
-            batch = dataset[i : i + batch_size]
-            predictions.extend(
-                self.predict(
-                    batch[value_column],
-                    expected_labels=batch[label_column],
-                    tags=tags,
-                    save_telemetry=record_predictions,
-                    save_telemetry_synchronously=(not record_predictions),
-                )
+    ) -> ClassificationMetrics:
+        predictions = [
+            prediction
+            for i in range(0, len(dataset), batch_size)
+            for prediction in self.predict(
+                dataset[i : i + batch_size][value_column],
+                expected_labels=dataset[i : i + batch_size][label_column],
+                tags=tags,
+                save_telemetry="sync" if record_predictions else "off",
             )
-            expected_labels.extend(batch[label_column])
+        ]
+
+        return calculate_classification_metrics(
+            expected_labels=dataset[label_column],
+            logits=[p.logits for p in predictions],
+            anomaly_scores=[p.anomaly_score for p in predictions],
+            include_curves=True,
+        )
+
+    @overload
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
+        *,
+        value_column: str = "value",
+        label_column: str = "label",
+        record_predictions: bool = False,
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
+        background: Literal[True],
+    ) -> Job[ClassificationMetrics]:
+        pass
 
-        return self._calculate_metrics(predictions, expected_labels).to_dict()
+    @overload
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
+        *,
+        value_column: str = "value",
+        label_column: str = "label",
+        record_predictions: bool = False,
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
+        background: Literal[False] = False,
+    ) -> ClassificationMetrics:
+        pass
 
     def evaluate(
         self,
         data: Datasource | Dataset,
+        *,
         value_column: str = "value",
         label_column: str = "label",
         record_predictions: bool = False,
         tags: set[str] = {"evaluation"},
         batch_size: int = 100,
-    ) -> dict[str, Any]:
+        background: bool = False,
+    ) -> ClassificationMetrics | Job[ClassificationMetrics]:
         """
         Evaluate the classification model on a given dataset or datasource
 
@@ -594,21 +599,23 @@ class ClassificationModel:
             data: Dataset or Datasource to evaluate the model on
             value_column: Name of the column that contains the input values to the model
             label_column: Name of the column containing the expected labels
-            record_predictions: Whether to record [`LabelPrediction`][orca_sdk.telemetry.LabelPrediction]s for analysis
-            tags: Optional tags to add to the recorded [`LabelPrediction`][orca_sdk.telemetry.LabelPrediction]s
+            record_predictions: Whether to record [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s for analysis
+            tags: Optional tags to add to the recorded [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s
             batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)
+            background: Whether to run the operation in the background and return a job handle
 
         Returns:
-            Dictionary with evaluation metrics, including anomaly score statistics (mean, median, variance)
+            EvaluationResult containing metrics including accuracy, F1 score, ROC AUC, PR AUC, and anomaly score statistics
 
         Examples:
-            Evaluate using a Datasource:
             >>> model.evaluate(datasource, value_column="text", label_column="airline_sentiment")
-            { "f1_score": 0.85, "roc_auc": 0.85, "pr_auc": 0.85, "accuracy": 0.85, "loss": 0.35, ... }
-
-            Evaluate using a Dataset:
-            >>> model.evaluate(dataset, value_column="text", label_column="sentiment")
-            { "f1_score": 0.85, "roc_auc": 0.85, "pr_auc": 0.85, "accuracy": 0.85, "loss": 0.35, ... }
+            ClassificationMetrics({
+                accuracy: 0.8500,
+                f1_score: 0.8500,
+                roc_auc: 0.8500,
+                pr_auc: 0.8500,
+                anomaly_score: 0.3500 ± 0.0500,
+            })
         """
         if isinstance(data, Datasource):
             return self._evaluate_datasource(
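
The docstring example above covers the Datasource path, which runs server-side. Passing a Hugging Face `Dataset` instead routes through client-side batched prediction (`_evaluate_dataset`). A sketch with illustrative column names and data:

    from datasets import Dataset

    eval_data = Dataset.from_dict({"text": ["I am happy", "I am sad"], "sentiment": [1, 0]})
    metrics = model.evaluate(
        eval_data,
        value_column="text",
        label_column="sentiment",
        batch_size=100,
        record_predictions=True,  # predictions are stored with the default "evaluation" tag
    )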
@@ -617,8 +624,9 @@ class ClassificationModel:
                 label_column=label_column,
                 record_predictions=record_predictions,
                 tags=tags,
+                background=background,
             )
-        else:
+        elif isinstance(data, Dataset):
             return self._evaluate_dataset(
                 dataset=data,
                 value_column=value_column,
@@ -627,6 +635,8 @@ class ClassificationModel:
                 tags=tags,
                 batch_size=batch_size,
             )
+        else:
+            raise ValueError(f"Invalid data type: {type(data)}")
 
     def finetune(self, datasource: Datasource):
         # do not document until implemented
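
Taken together, the user-facing changes to `ClassificationModel` in 0.0.95 are mostly renames plus richer return types. A rough migration sketch for code written against 0.0.93, based only on the hunks above (attribute access on `ClassificationMetrics` is assumed from the repr shown in the new docstring):

    # 0.0.93
    model.update_metadata(description="New description")
    prediction = model.predict("I am happy", save_telemetry_synchronously=True)
    results = model.evaluate(dataset, value_column="text", label_column="sentiment")
    f1 = results["f1_score"]  # evaluate returned a plain dict

    # 0.0.95
    model.set(description="New description")
    prediction = model.predict("I am happy", save_telemetry="sync")
    metrics = model.evaluate(dataset, value_column="text", label_column="sentiment")
    f1 = metrics.f1_score  # evaluate now returns ClassificationMetrics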