orca-sdk 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186)
  1. orca_sdk/__init__.py +10 -4
  2. orca_sdk/_shared/__init__.py +10 -0
  3. orca_sdk/_shared/metrics.py +393 -0
  4. orca_sdk/_shared/metrics_test.py +273 -0
  5. orca_sdk/_utils/analysis_ui.py +12 -10
  6. orca_sdk/_utils/analysis_ui_style.css +0 -3
  7. orca_sdk/_utils/auth.py +31 -29
  8. orca_sdk/_utils/data_parsing.py +28 -2
  9. orca_sdk/_utils/data_parsing_test.py +15 -15
  10. orca_sdk/_utils/pagination.py +126 -0
  11. orca_sdk/_utils/pagination_test.py +132 -0
  12. orca_sdk/_utils/prediction_result_ui.py +67 -21
  13. orca_sdk/_utils/tqdm_file_reader.py +12 -0
  14. orca_sdk/_utils/value_parser.py +45 -0
  15. orca_sdk/_utils/value_parser_test.py +39 -0
  16. orca_sdk/async_client.py +3795 -0
  17. orca_sdk/classification_model.py +601 -129
  18. orca_sdk/classification_model_test.py +415 -117
  19. orca_sdk/client.py +3787 -0
  20. orca_sdk/conftest.py +184 -38
  21. orca_sdk/credentials.py +162 -20
  22. orca_sdk/credentials_test.py +100 -16
  23. orca_sdk/datasource.py +268 -68
  24. orca_sdk/datasource_test.py +266 -18
  25. orca_sdk/embedding_model.py +434 -82
  26. orca_sdk/embedding_model_test.py +66 -33
  27. orca_sdk/job.py +343 -0
  28. orca_sdk/job_test.py +108 -0
  29. orca_sdk/memoryset.py +1690 -324
  30. orca_sdk/memoryset_test.py +456 -119
  31. orca_sdk/regression_model.py +694 -0
  32. orca_sdk/regression_model_test.py +378 -0
  33. orca_sdk/telemetry.py +460 -143
  34. orca_sdk/telemetry_test.py +43 -24
  35. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.3.dist-info}/METADATA +34 -16
  36. orca_sdk-0.1.3.dist-info/RECORD +41 -0
  37. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.3.dist-info}/WHEEL +1 -1
  38. orca_sdk/_generated_api_client/__init__.py +0 -3
  39. orca_sdk/_generated_api_client/api/__init__.py +0 -193
  40. orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
  41. orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
  42. orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
  43. orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
  44. orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
  45. orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
  46. orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
  47. orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
  48. orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
  49. orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
  50. orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
  51. orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
  52. orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
  53. orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
  54. orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
  55. orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
  56. orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
  57. orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
  58. orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
  59. orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
  60. orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
  61. orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
  62. orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
  63. orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
  64. orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
  65. orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
  66. orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
  67. orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
  68. orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
  69. orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
  70. orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
  71. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
  72. orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
  73. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
  74. orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
  75. orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
  76. orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
  77. orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
  78. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
  79. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
  80. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
  81. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
  82. orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
  83. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
  84. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
  85. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
  86. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
  87. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
  88. orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
  89. orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
  90. orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
  91. orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
  92. orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
  93. orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
  94. orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
  95. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
  96. orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
  97. orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
  98. orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
  99. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
  100. orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
  101. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
  102. orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
  103. orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
  104. orca_sdk/_generated_api_client/client.py +0 -216
  105. orca_sdk/_generated_api_client/errors.py +0 -38
  106. orca_sdk/_generated_api_client/models/__init__.py +0 -159
  107. orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
  108. orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
  109. orca_sdk/_generated_api_client/models/base_model.py +0 -55
  110. orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
  111. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
  112. orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
  113. orca_sdk/_generated_api_client/models/column_info.py +0 -114
  114. orca_sdk/_generated_api_client/models/column_type.py +0 -14
  115. orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
  116. orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
  117. orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
  118. orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
  119. orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
  120. orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
  121. orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
  122. orca_sdk/_generated_api_client/models/embed_request.py +0 -127
  123. orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
  124. orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
  125. orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
  126. orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
  127. orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
  128. orca_sdk/_generated_api_client/models/filter_item.py +0 -231
  129. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
  130. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
  131. orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
  132. orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
  133. orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
  134. orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
  135. orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
  136. orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
  137. orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
  138. orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
  139. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
  140. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
  141. orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
  142. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
  143. orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
  144. orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
  145. orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
  146. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
  147. orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
  148. orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
  149. orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
  150. orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
  151. orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
  152. orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
  153. orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
  154. orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
  155. orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
  156. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
  157. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
  158. orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
  159. orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
  160. orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
  161. orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
  162. orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
  163. orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
  164. orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
  165. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
  166. orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
  167. orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
  168. orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
  169. orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
  170. orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
  171. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
  172. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
  173. orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
  174. orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
  175. orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
  176. orca_sdk/_generated_api_client/models/task.py +0 -198
  177. orca_sdk/_generated_api_client/models/task_status.py +0 -14
  178. orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
  179. orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
  180. orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
  181. orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
  182. orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
  183. orca_sdk/_generated_api_client/py.typed +0 -1
  184. orca_sdk/_generated_api_client/types.py +0 -56
  185. orca_sdk/_utils/task.py +0 -73
  186. orca_sdk-0.1.1.dist-info/RECORD +0 -175
orca_sdk/classification_model.py (file 17 above):

@@ -3,39 +3,67 @@ from __future__ import annotations
 import logging
 from contextlib import contextmanager
 from datetime import datetime
-from typing import Any, Generator, Iterable, Literal, cast, overload
-from uuid import UUID
-
-from ._generated_api_client.api import (
-    create_evaluation,
-    create_model,
-    delete_model,
-    get_evaluation,
-    get_model,
-    list_models,
-    list_predictions,
-    predict_gpu,
-    record_prediction_feedback,
+from typing import (
+    Any,
+    Generator,
+    Iterable,
+    Literal,
+    cast,
+    overload,
 )
-from ._generated_api_client.models import (
-    ClassificationEvaluationResult,
-    CreateRACModelRequest,
-    EvaluationRequest,
-    ListPredictionsRequest,
+
+from datasets import Dataset
+
+from ._shared.metrics import ClassificationMetrics, calculate_classification_metrics
+from ._utils.common import UNSET, CreateMode, DropMode
+from .async_client import OrcaAsyncClient
+from .client import (
+    BootstrapClassificationModelMeta,
+    BootstrapClassificationModelResult,
+    ClassificationModelMetadata,
+    OrcaClient,
+    PredictiveModelUpdate,
+    RACHeadType,
 )
-from ._generated_api_client.models import (
-    ListPredictionsRequestSortItemItemType0 as PredictionSortColumns,
+from .datasource import Datasource
+from .job import Job
+from .memoryset import (
+    FilterItem,
+    FilterItemTuple,
+    LabeledMemoryset,
+    _is_metric_column,
+    _parse_filter_item_from_tuple,
 )
-from ._generated_api_client.models import (
-    ListPredictionsRequestSortItemItemType1 as PredictionSortDirection,
+from .telemetry import (
+    ClassificationPrediction,
+    TelemetryMode,
+    _get_telemetry_config,
+    _parse_feedback,
 )
-from ._generated_api_client.models import RACHeadType, RACModelMetadata
-from ._generated_api_client.models.prediction_request import PredictionRequest
-from ._utils.common import CreateMode, DropMode
-from ._utils.task import wait_for_task
-from .datasource import Datasource
-from .memoryset import LabeledMemoryset
-from .telemetry import LabelPrediction, _parse_feedback
+
+
+class BootstrappedClassificationModel:
+
+    datasource: Datasource | None
+    memoryset: LabeledMemoryset | None
+    classification_model: ClassificationModel | None
+    agent_output: BootstrapClassificationModelResult | None
+
+    def __init__(self, metadata: BootstrapClassificationModelMeta):
+        self.datasource = Datasource.open(metadata["datasource_meta"]["id"])
+        self.memoryset = LabeledMemoryset.open(metadata["memoryset_meta"]["id"])
+        self.classification_model = ClassificationModel.open(metadata["model_meta"]["id"])
+        self.agent_output = metadata["agent_output"]
+
+    def __repr__(self):
+        return (
+            "BootstrappedClassificationModel({\n"
+            f"    datasource: {self.datasource},\n"
+            f"    memoryset: {self.memoryset},\n"
+            f"    classification_model: {self.classification_model},\n"
+            f"    agent_output: {self.agent_output},\n"
+            "})"
+        )
 
 
 class ClassificationModel:
@@ -45,17 +73,20 @@ class ClassificationModel:
     Attributes:
         id: Unique identifier for the model
         name: Unique name of the model
+        description: Optional description of the model
         memoryset: Memoryset that the model uses
         head_type: Classification head type of the model
         num_classes: Number of distinct classes the model can predict
         memory_lookup_count: Number of memories the model uses for each prediction
         weigh_memories: If using a KNN head, whether the model weighs memories by their lookup score
         min_memory_weight: If using a KNN head, minimum lookup score memories have to be over to not be ignored
+        locked: Whether the model is locked to prevent accidental deletion
        created_at: When the model was created
     """
 
     id: str
     name: str
+    description: str | None
     memoryset: LabeledMemoryset
     head_type: RACHeadType
     num_classes: int
@@ -63,23 +94,26 @@ class ClassificationModel:
     weigh_memories: bool | None
     min_memory_weight: float | None
     version: int
+    locked: bool
     created_at: datetime
 
-    def __init__(self, metadata: RACModelMetadata):
+    def __init__(self, metadata: ClassificationModelMetadata):
         # for internal use only, do not document
-        self.id = metadata.id
-        self.name = metadata.name
-        self.memoryset = LabeledMemoryset.open(metadata.memoryset_id)
-        self.head_type = metadata.head_type
-        self.num_classes = metadata.num_classes
-        self.memory_lookup_count = metadata.memory_lookup_count
-        self.weigh_memories = metadata.weigh_memories
-        self.min_memory_weight = metadata.min_memory_weight
-        self.version = metadata.version
-        self.created_at = metadata.created_at
+        self.id = metadata["id"]
+        self.name = metadata["name"]
+        self.description = metadata["description"]
+        self.memoryset = LabeledMemoryset.open(metadata["memoryset_id"])
+        self.head_type = metadata["head_type"]
+        self.num_classes = metadata["num_classes"]
+        self.memory_lookup_count = metadata["memory_lookup_count"]
+        self.weigh_memories = metadata["weigh_memories"]
+        self.min_memory_weight = metadata["min_memory_weight"]
+        self.version = metadata["version"]
+        self.locked = metadata["locked"]
+        self.created_at = datetime.fromisoformat(metadata["created_at"])
 
         self._memoryset_override_id: str | None = None
-        self._last_prediction: LabelPrediction | None = None
+        self._last_prediction: ClassificationPrediction | None = None
         self._last_prediction_was_batch: bool = False
 
     def __eq__(self, other) -> bool:
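The constructor now consumes a plain mapping instead of a generated RACModelMetadata object, and parses created_at from an ISO timestamp string. A minimal sketch of the mapping shape implied by the keys above (illustrative values only; the real ClassificationModelMetadata definition lives in orca_sdk/client.py, which this diff does not show):

    # Hypothetical example of the metadata mapping read by __init__;
    # field names are taken from the diff above, values are placeholders.
    metadata = {
        "id": "0192b1a2-...",            # placeholder identifier
        "name": "sentiment-model",
        "description": None,
        "memoryset_id": "0192b1a2-...",  # placeholder identifier
        "head_type": "KNN",
        "num_classes": 2,
        "memory_lookup_count": 15,
        "weigh_memories": True,
        "min_memory_weight": None,
        "version": 1,
        "locked": False,
        "created_at": "2025-01-01T00:00:00+00:00",  # parsed with datetime.fromisoformat
    }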
@@ -97,7 +131,7 @@ class ClassificationModel:
         )
 
     @property
-    def last_prediction(self) -> LabelPrediction:
+    def last_prediction(self) -> ClassificationPrediction:
         """
         Last prediction made by the model
 
@@ -119,8 +153,9 @@ class ClassificationModel:
         cls,
         name: str,
         memoryset: LabeledMemoryset,
-        head_type: Literal["BMMOE", "FF", "KNN", "MMOE"] = "KNN",
+        head_type: RACHeadType = "KNN",
         *,
+        description: str | None = None,
         num_classes: int | None = None,
         memory_lookup_count: int | None = None,
         weigh_memories: bool = True,
@@ -141,6 +176,8 @@ class ClassificationModel:
             min_memory_weight: If using a KNN head, minimum lookup score memories have to be over to not be ignored
             if_exists: What to do if a model with the same name already exists, defaults to
                 `"error"`. Other option is `"open"` to open the existing model.
+            description: Optional description for the model, this will be used in agentic flows,
+                so make sure it is concise and describes the purpose of your model.
 
         Returns:
             Handle to the new model in the OrcaCloud
@@ -182,16 +219,19 @@ class ClassificationModel:
 
             return existing
 
-        metadata = create_model(
-            body=CreateRACModelRequest(
-                name=name,
-                memoryset_id=memoryset.id,
-                head_type=RACHeadType(head_type),
-                memory_lookup_count=memory_lookup_count,
-                num_classes=num_classes,
-                weigh_memories=weigh_memories,
-                min_memory_weight=min_memory_weight,
-            ),
+        client = OrcaClient._resolve_client()
+        metadata = client.POST(
+            "/classification_model",
+            json={
+                "name": name,
+                "memoryset_name_or_id": memoryset.id,
+                "head_type": head_type,
+                "memory_lookup_count": memory_lookup_count,
+                "num_classes": num_classes,
+                "weigh_memories": weigh_memories,
+                "min_memory_weight": min_memory_weight,
+                "description": description,
+            },
         )
         return cls(metadata)
 
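Model creation now goes through the hand-written OrcaClient with a plain JSON payload instead of the generated CreateRACModelRequest. A usage sketch against the new signature (assuming ClassificationModel and LabeledMemoryset are exported from the package root, and that a memoryset named "sentiment-memories" already exists; if_exists is documented in the hunk above):

    from orca_sdk import ClassificationModel, LabeledMemoryset

    memoryset = LabeledMemoryset.open("sentiment-memories")  # assumed to exist
    model = ClassificationModel.create(
        "sentiment-model",
        memoryset,
        head_type="KNN",                                        # default head type
        description="Classifies airline tweets by sentiment",  # new in 0.1.3
        if_exists="open",  # open the existing model instead of raising
    )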
@@ -209,7 +249,8 @@ class ClassificationModel:
         Raises:
             LookupError: If the classification model does not exist
         """
-        return cls(get_model(name))
+        client = OrcaClient._resolve_client()
+        return cls(client.GET("/classification_model/{name_or_id}", params={"name_or_id": name}))
 
     @classmethod
     def exists(cls, name_or_id: str) -> bool:
@@ -236,7 +277,8 @@ class ClassificationModel:
         Returns:
             List of handles to all classification models in the OrcaCloud
         """
-        return [cls(metadata) for metadata in list_models()]
+        client = OrcaClient._resolve_client()
+        return [cls(metadata) for metadata in client.GET("/classification_model")]
 
     @classmethod
     def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
@@ -255,73 +297,334 @@ class ClassificationModel:
             LookupError: If the classification model does not exist and if_not_exists is `"error"`
         """
         try:
-            delete_model(name_or_id)
+            client = OrcaClient._resolve_client()
+            client.DELETE("/classification_model/{name_or_id}", params={"name_or_id": name_or_id})
             logging.info(f"Deleted model {name_or_id}")
         except LookupError:
             if if_not_exists == "error":
                 raise
 
+    def refresh(self):
+        """Refresh the model data from the OrcaCloud"""
+        self.__dict__.update(self.open(self.name).__dict__)
+
+    def set(self, *, description: str | None = UNSET, locked: bool = UNSET) -> None:
+        """
+        Update editable attributes of the model.
+
+        Note:
+            If a field is not provided, it will default to [UNSET][orca_sdk.UNSET] and not be updated.
+
+        Params:
+            description: Value to set for the description
+            locked: Value to set for the locked status
+
+        Examples:
+            Update the description:
+            >>> model.set(description="New description")
+
+            Remove description:
+            >>> model.set(description=None)
+
+            Lock the model:
+            >>> model.set(locked=True)
+        """
+        update: PredictiveModelUpdate = {}
+        if description is not UNSET:
+            update["description"] = description
+        if locked is not UNSET:
+            update["locked"] = locked
+        client = OrcaClient._resolve_client()
+        client.PATCH("/classification_model/{name_or_id}", params={"name_or_id": self.id}, json=update)
+        self.refresh()
+
+    def lock(self) -> None:
+        """Lock the model to prevent accidental deletion"""
+        self.set(locked=True)
+
+    def unlock(self) -> None:
+        """Unlock the model to allow deletion"""
+        self.set(locked=False)
+
     @overload
     def predict(
-        self, value: list[str], expected_labels: list[int] | None = None, tags: set[str] = set()
-    ) -> list[LabelPrediction]:
+        self,
+        value: list[str],
+        expected_labels: list[int] | None = None,
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: TelemetryMode = "on",
+        prompt: str | None = None,
+        use_lookup_cache: bool = True,
+        timeout_seconds: int = 10,
+    ) -> list[ClassificationPrediction]:
         pass
 
     @overload
-    def predict(self, value: str, expected_labels: int | None = None, tags: set[str] = set()) -> LabelPrediction:
+    def predict(
+        self,
+        value: str,
+        expected_labels: int | None = None,
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: TelemetryMode = "on",
+        prompt: str | None = None,
+        use_lookup_cache: bool = True,
+        timeout_seconds: int = 10,
+    ) -> ClassificationPrediction:
         pass
 
     def predict(
-        self, value: list[str] | str, expected_labels: list[int] | int | None = None, tags: set[str] = set()
-    ) -> list[LabelPrediction] | LabelPrediction:
+        self,
+        value: list[str] | str,
+        expected_labels: list[int] | list[str] | int | str | None = None,
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: TelemetryMode = "on",
+        prompt: str | None = None,
+        use_lookup_cache: bool = True,
+        timeout_seconds: int = 10,
+    ) -> list[ClassificationPrediction] | ClassificationPrediction:
         """
         Predict label(s) for the given input value(s) grounded in similar memories
 
         Params:
             value: Value(s) to predict the labels of
             expected_labels: Expected label(s) for the given input to record for model evaluation
+            filters: Optional filters to apply during memory lookup
             tags: Tags to add to the prediction(s)
+            save_telemetry: Whether to save telemetry for the prediction(s). One of
+                * `"off"`: Do not save telemetry
+                * `"on"`: Save telemetry asynchronously unless the `ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY`
+                    environment variable is set.
+                * `"sync"`: Save telemetry synchronously
+                * `"async"`: Save telemetry asynchronously
+            prompt: Optional prompt to use for instruction-tuned embedding models
+            use_lookup_cache: Whether to use cached lookup results for faster predictions
+            timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
 
         Returns:
             Label prediction or list of label predictions
 
+        Raises:
+            ValueError: If timeout_seconds is not a positive integer
+            TimeoutError: If the request times out after the specified duration
+
         Examples:
             Predict the label for a single value:
             >>> prediction = model.predict("I am happy", tags={"test"})
-            LabelPrediction({label: <positive: 1>, confidence: 0.95, input_value: 'I am happy' })
+            ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
 
             Predict the labels for a list of values:
             >>> predictions = model.predict(["I am happy", "I am sad"], expected_labels=[1, 0])
             [
-                LabelPrediction({label: <positive: 1>, confidence: 0.95, input_value: 'I am happy'}),
-                LabelPrediction({label: <negative: 0>, confidence: 0.05, input_value: 'I am sad'}),
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
+                ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
             ]
+
+            Using a prompt with an instruction-tuned embedding model:
+            >>> prediction = model.predict("I am happy", prompt="Represent this text for sentiment classification:")
+            ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
         """
-        response = predict_gpu(
-            self.id,
-            body=PredictionRequest(
-                input_values=value if isinstance(value, list) else [value],
-                memoryset_override_id=self._memoryset_override_id,
-                expected_labels=(
-                    expected_labels
-                    if isinstance(expected_labels, list)
-                    else [expected_labels]
-                    if expected_labels is not None
-                    else None
-                ),
-                tags=list(tags),
-            ),
+
+        if timeout_seconds <= 0:
+            raise ValueError("timeout_seconds must be a positive integer")
+
+        parsed_filters = [
+            _parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
+        ]
+
+        if any(_is_metric_column(filter[0]) for filter in filters):
+            raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")
+
+        if isinstance(expected_labels, int):
+            expected_labels = [expected_labels]
+        elif isinstance(expected_labels, str):
+            expected_labels = [self.memoryset.label_names.index(expected_labels)]
+        elif isinstance(expected_labels, list):
+            expected_labels = [
+                self.memoryset.label_names.index(label) if isinstance(label, str) else label
+                for label in expected_labels
+            ]
+
+        telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
+        client = OrcaClient._resolve_client()
+        response = client.POST(
+            "/gpu/classification_model/{name_or_id}/prediction",
+            params={"name_or_id": self.id},
+            json={
+                "input_values": value if isinstance(value, list) else [value],
+                "memoryset_override_name_or_id": self._memoryset_override_id,
+                "expected_labels": expected_labels,
+                "tags": list(tags or set()),
+                "save_telemetry": telemetry_on,
+                "save_telemetry_synchronously": telemetry_sync,
+                "filters": cast(list[FilterItem], parsed_filters),
+                "prompt": prompt,
+                "use_lookup_cache": use_lookup_cache,
+            },
+            timeout=timeout_seconds,
         )
+
+        if telemetry_on and any(p["prediction_id"] is None for p in response):
+            raise RuntimeError("Failed to save prediction to database.")
+
         predictions = [
-            LabelPrediction(
-                prediction_id=prediction.prediction_id,
-                label=prediction.label,
-                label_name=prediction.label_name,
-                confidence=prediction.confidence,
+            ClassificationPrediction(
+                prediction_id=prediction["prediction_id"],
+                label=prediction["label"],
+                label_name=prediction["label_name"],
+                score=None,
+                confidence=prediction["confidence"],
+                anomaly_score=prediction["anomaly_score"],
                 memoryset=self.memoryset,
                 model=self,
+                logits=prediction["logits"],
+                input_value=input_value,
             )
-            for prediction in response
+            for prediction, input_value in zip(response, value if isinstance(value, list) else [value])
+        ]
+        self._last_prediction_was_batch = isinstance(value, list)
+        self._last_prediction = predictions[-1]
+        return predictions if isinstance(value, list) else predictions[0]
+
+    @overload
+    async def apredict(
+        self,
+        value: list[str],
+        expected_labels: list[int] | None = None,
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: TelemetryMode = "on",
+        prompt: str | None = None,
+        use_lookup_cache: bool = True,
+        timeout_seconds: int = 10,
+    ) -> list[ClassificationPrediction]:
+        pass
+
+    @overload
+    async def apredict(
+        self,
+        value: str,
+        expected_labels: int | None = None,
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: TelemetryMode = "on",
+        prompt: str | None = None,
+        use_lookup_cache: bool = True,
+        timeout_seconds: int = 10,
+    ) -> ClassificationPrediction:
+        pass
+
+    async def apredict(
+        self,
+        value: list[str] | str,
+        expected_labels: list[int] | list[str] | int | str | None = None,
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: TelemetryMode = "on",
+        prompt: str | None = None,
+        use_lookup_cache: bool = True,
+        timeout_seconds: int = 10,
+    ) -> list[ClassificationPrediction] | ClassificationPrediction:
+        """
+        Asynchronously predict label(s) for the given input value(s) grounded in similar memories
+
+        Params:
+            value: Value(s) to predict the labels of
+            expected_labels: Expected label(s) for the given input to record for model evaluation
+            filters: Optional filters to apply during memory lookup
+            tags: Tags to add to the prediction(s)
+            save_telemetry: Whether to save telemetry for the prediction(s). One of
+                * `"off"`: Do not save telemetry
+                * `"on"`: Save telemetry asynchronously unless the `ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY`
+                    environment variable is set.
+                * `"sync"`: Save telemetry synchronously
+                * `"async"`: Save telemetry asynchronously
+            prompt: Optional prompt to use for instruction-tuned embedding models
+            use_lookup_cache: Whether to use cached lookup results for faster predictions
+            timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
+
+        Returns:
+            Label prediction or list of label predictions.
+
+        Raises:
+            ValueError: If timeout_seconds is not a positive integer
+            TimeoutError: If the request times out after the specified duration
+
+        Examples:
+            Predict the label for a single value:
+            >>> prediction = await model.apredict("I am happy", tags={"test"})
+            ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
+
+            Predict the labels for a list of values:
+            >>> predictions = await model.apredict(["I am happy", "I am sad"], expected_labels=[1, 0])
+            [
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
+                ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
+            ]
+
+            Using a prompt with an instruction-tuned embedding model:
+            >>> prediction = await model.apredict("I am happy", prompt="Represent this text for sentiment classification:")
+            ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
+        """
+
+        if timeout_seconds <= 0:
+            raise ValueError("timeout_seconds must be a positive integer")
+
+        parsed_filters = [
+            _parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
+        ]
+
+        if any(_is_metric_column(filter[0]) for filter in filters):
+            raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")
+
+        if isinstance(expected_labels, int):
+            expected_labels = [expected_labels]
+        elif isinstance(expected_labels, str):
+            expected_labels = [self.memoryset.label_names.index(expected_labels)]
+        elif isinstance(expected_labels, list):
+            expected_labels = [
+                self.memoryset.label_names.index(label) if isinstance(label, str) else label
+                for label in expected_labels
+            ]
+
+        telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
+        client = OrcaAsyncClient._resolve_client()
+        response = await client.POST(
+            "/gpu/classification_model/{name_or_id}/prediction",
+            params={"name_or_id": self.id},
+            json={
+                "input_values": value if isinstance(value, list) else [value],
+                "memoryset_override_name_or_id": self._memoryset_override_id,
+                "expected_labels": expected_labels,
+                "tags": list(tags or set()),
+                "save_telemetry": telemetry_on,
+                "save_telemetry_synchronously": telemetry_sync,
+                "filters": cast(list[FilterItem], parsed_filters),
+                "prompt": prompt,
+                "use_lookup_cache": use_lookup_cache,
+            },
+            timeout=timeout_seconds,
+        )
+
+        if telemetry_on and any(p["prediction_id"] is None for p in response):
+            raise RuntimeError("Failed to save prediction to database.")
+
+        predictions = [
+            ClassificationPrediction(
+                prediction_id=prediction["prediction_id"],
+                label=prediction["label"],
+                label_name=prediction["label_name"],
+                score=None,
+                confidence=prediction["confidence"],
+                anomaly_score=prediction["anomaly_score"],
+                memoryset=self.memoryset,
+                model=self,
+                logits=prediction["logits"],
+                input_value=input_value,
+            )
+            for prediction, input_value in zip(response, value if isinstance(value, list) else [value])
         ]
         self._last_prediction_was_batch = isinstance(value, list)
         self._last_prediction = predictions[-1]
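predict and the new apredict share the same per-call options: memory-lookup filters, a telemetry mode, an optional embedding prompt, a lookup cache toggle, and a timeout. A hedged sketch of both call styles (the filter tuple is illustrative; the exact FilterItemTuple shape and supported operators are defined in orca_sdk/memoryset.py, not shown in this diff):

    import asyncio

    # synchronous call with the new per-call options
    prediction = model.predict(
        "I am happy",
        filters=[("source", "==", "twitter")],  # hypothetical filter tuple
        tags={"smoke-test"},
        save_telemetry="sync",
        timeout_seconds=30,
    )
    print(prediction.label_name, prediction.confidence, prediction.anomaly_score)

    # asynchronous variant added in 0.1.3, backed by OrcaAsyncClient
    async def main():
        return await model.apredict(["I am happy", "I am sad"], expected_labels=[1, 0])

    predictions = asyncio.run(main())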
@@ -332,8 +635,9 @@ class ClassificationModel:
         limit: int = 100,
         offset: int = 0,
         tag: str | None = None,
-        sort: list[tuple[PredictionSortColumns, PredictionSortDirection]] = [],
-    ) -> list[LabelPrediction]:
+        sort: list[tuple[Literal["anomaly_score", "confidence", "timestamp"], Literal["asc", "desc"]]] = [],
+        expected_label_match: bool | None = None,
+    ) -> list[ClassificationPrediction]:
         """
         Get a list of predictions made by this model
 
@@ -343,6 +647,8 @@ class ClassificationModel:
             tag: Optional tag to filter predictions by
             sort: Optional list of columns and directions to sort the predictions by.
                 Predictions can be sorted by `timestamp` or `confidence`.
+            expected_label_match: Optional filter to only include predictions where the expected
+                label does (`True`) or doesn't (`False`) match the predicted label
 
         Returns:
             List of label predictions
@@ -351,78 +657,212 @@ class ClassificationModel:
         Examples:
             Get the last 3 predictions:
            >>> predictions = model.predictions(limit=3, sort=[("timestamp", "desc")])
             [
-                LabeledPrediction({label: <positive: 1>, confidence: 0.95, input_value: 'I am happy'}),
-                LabeledPrediction({label: <negative: 0>, confidence: 0.05, input_value: 'I am sad'}),
-                LabeledPrediction({label: <positive: 1>, confidence: 0.90, input_value: 'I am ecstatic'}),
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
+                ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am ecstatic'}),
             ]
 
 
             Get the second most confident prediction:
             >>> predictions = model.predictions(sort=[("confidence", "desc")], offset=1, limit=1)
-            [LabeledPrediction({label: <positive: 1>, confidence: 0.90, input_value: 'I am having a good day'})]
+            [ClassificationPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am having a good day'})]
+
+            Get predictions where the expected label doesn't match the predicted label:
+            >>> predictions = model.predictions(expected_label_match=False)
+            [ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy', expected_label: 0})]
         """
-        predictions = list_predictions(
-            body=ListPredictionsRequest(
-                model_id=self.id,
-                limit=limit,
-                offset=offset,
-                sort=cast(list[list[PredictionSortColumns | PredictionSortDirection]], sort),
-                tag=tag,
-            ),
+        client = OrcaClient._resolve_client()
+        predictions = client.POST(
+            "/telemetry/prediction",
+            json={
+                "model_id": self.id,
+                "limit": limit,
+                "offset": offset,
+                "sort": [list(sort_item) for sort_item in sort],
+                "tag": tag,
+                "expected_label_match": expected_label_match,
+            },
         )
         return [
-            LabelPrediction(
-                prediction_id=prediction.prediction_id,
-                label=prediction.label,
-                label_name=prediction.label_name,
-                confidence=prediction.confidence,
+            ClassificationPrediction(
+                prediction_id=prediction["prediction_id"],
+                label=prediction["label"],
+                label_name=prediction["label_name"],
+                score=None,
+                confidence=prediction["confidence"],
+                anomaly_score=prediction["anomaly_score"],
                 memoryset=self.memoryset,
                 model=self,
                 telemetry=prediction,
             )
             for prediction in predictions
+            if "label" in prediction
         ]
 
-    def evaluate(
+    def _evaluate_datasource(
         self,
         datasource: Datasource,
+        value_column: str,
+        label_column: str,
+        record_predictions: bool,
+        tags: set[str] | None,
+        background: bool = False,
+    ) -> ClassificationMetrics | Job[ClassificationMetrics]:
+        client = OrcaClient._resolve_client()
+        response = client.POST(
+            "/classification_model/{model_name_or_id}/evaluation",
+            params={"model_name_or_id": self.id},
+            json={
+                "datasource_name_or_id": datasource.id,
+                "datasource_label_column": label_column,
+                "datasource_value_column": value_column,
+                "memoryset_override_name_or_id": self._memoryset_override_id,
+                "record_telemetry": record_predictions,
+                "telemetry_tags": list(tags) if tags else None,
+            },
+        )
+
+        def get_value():
+            client = OrcaClient._resolve_client()
+            res = client.GET(
+                "/classification_model/{model_name_or_id}/evaluation/{task_id}",
+                params={"model_name_or_id": self.id, "task_id": response["task_id"]},
+            )
+            assert res["result"] is not None
+            return ClassificationMetrics(
+                coverage=res["result"].get("coverage"),
+                f1_score=res["result"].get("f1_score"),
+                accuracy=res["result"].get("accuracy"),
+                loss=res["result"].get("loss"),
+                anomaly_score_mean=res["result"].get("anomaly_score_mean"),
+                anomaly_score_median=res["result"].get("anomaly_score_median"),
+                anomaly_score_variance=res["result"].get("anomaly_score_variance"),
+                roc_auc=res["result"].get("roc_auc"),
+                pr_auc=res["result"].get("pr_auc"),
+                pr_curve=res["result"].get("pr_curve"),
+                roc_curve=res["result"].get("roc_curve"),
+            )
+
+        job = Job(response["task_id"], get_value)
+        return job if background else job.result()
+
+    def _evaluate_dataset(
+        self,
+        dataset: Dataset,
+        value_column: str,
+        label_column: str,
+        record_predictions: bool,
+        tags: set[str],
+        batch_size: int,
+    ) -> ClassificationMetrics:
+        if len(dataset) == 0:
+            raise ValueError("Evaluation dataset cannot be empty")
+
+        if any(x is None for x in dataset[label_column]):
+            raise ValueError("Evaluation dataset cannot contain None values in the label column")
+
+        predictions = [
+            prediction
+            for i in range(0, len(dataset), batch_size)
+            for prediction in self.predict(
+                dataset[i : i + batch_size][value_column],
+                expected_labels=dataset[i : i + batch_size][label_column],
+                tags=tags,
+                save_telemetry="sync" if record_predictions else "off",
+            )
+        ]
+
+        return calculate_classification_metrics(
+            expected_labels=dataset[label_column],
+            logits=[p.logits for p in predictions],
+            anomaly_scores=[p.anomaly_score for p in predictions],
+            include_curves=True,
+        )
+
+    @overload
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
+        *,
         value_column: str = "value",
         label_column: str = "label",
         record_predictions: bool = False,
-        tags: set[str] | None = None,
-    ) -> dict[str, float]:
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
+        background: Literal[True],
+    ) -> Job[ClassificationMetrics]:
+        pass
+
+    @overload
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
+        *,
+        value_column: str = "value",
+        label_column: str = "label",
+        record_predictions: bool = False,
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
+        background: Literal[False] = False,
+    ) -> ClassificationMetrics:
+        pass
+
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
+        *,
+        value_column: str = "value",
+        label_column: str = "label",
+        record_predictions: bool = False,
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
+        background: bool = False,
+    ) -> ClassificationMetrics | Job[ClassificationMetrics]:
         """
-        Evaluate the classification model on a given datasource
+        Evaluate the classification model on a given dataset or datasource
 
         Params:
-            datasource: Datasource to evaluate the model on
+            data: Dataset or Datasource to evaluate the model on
             value_column: Name of the column that contains the input values to the model
             label_column: Name of the column containing the expected labels
-            record_predictions: Whether to record [`LabelPrediction`][orca_sdk.telemetry.LabelPrediction]s for analysis
-            tags: Optional tags to add to the recorded [`LabelPrediction`][orca_sdk.telemetry.LabelPrediction]s
+            record_predictions: Whether to record [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s for analysis
+            tags: Optional tags to add to the recorded [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s
+            batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)
+            background: Whether to run the operation in the background and return a job handle
 
         Returns:
-            Dictionary with evaluation metrics
+            ClassificationMetrics containing accuracy, F1 score, ROC AUC, PR AUC, and anomaly score statistics
 
         Examples:
             >>> model.evaluate(datasource, value_column="text", label_column="airline_sentiment")
-            { "f1_score": 0.85, "roc_auc": 0.85, "pr_auc": 0.85, "accuracy": 0.85, "loss": 0.35 }
+            ClassificationMetrics({
+                accuracy: 0.8500,
+                f1_score: 0.8500,
+                roc_auc: 0.8500,
+                pr_auc: 0.8500,
+                anomaly_score: 0.3500 ± 0.0500,
+            })
         """
-        response = create_evaluation(
-            self.id,
-            body=EvaluationRequest(
-                datasource_id=datasource.id,
-                datasource_label_column=label_column,
-                datasource_value_column=value_column,
-                memoryset_override_id=self._memoryset_override_id,
-                record_telemetry=record_predictions,
-                telemetry_tags=list(tags) if tags else None,
-            ),
-        )
-        wait_for_task(response.task_id, description="Running evaluation")
-        response = get_evaluation(self.id, UUID(response.task_id))
-        assert response.result is not None
-        return response.result.to_dict()
+        if isinstance(data, Datasource):
+            return self._evaluate_datasource(
+                datasource=data,
+                value_column=value_column,
+                label_column=label_column,
+                record_predictions=record_predictions,
+                tags=tags,
+                background=background,
+            )
+        elif isinstance(data, Dataset):
+            return self._evaluate_dataset(
+                dataset=data,
+                value_column=value_column,
+                label_column=label_column,
+                record_predictions=record_predictions,
+                tags=tags,
+                batch_size=batch_size,
+            )
+        else:
+            raise ValueError(f"Invalid data type: {type(data)}")
 
     def finetune(self, datasource: Datasource):
         # do not document until implemented
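evaluate now dispatches on the input type: a Datasource is evaluated server-side (optionally in the background via a Job handle), while a Hugging Face Dataset is predicted client-side in batches and scored locally with calculate_classification_metrics. A sketch of both paths (dataset contents are illustrative; background only applies to the Datasource path):

    from datasets import Dataset

    # server-side evaluation of a Datasource; background=True returns a Job
    job = model.evaluate(datasource, value_column="text", label_column="airline_sentiment", background=True)
    metrics = job.result()

    # client-side evaluation of an in-memory Dataset, batched through predict()
    dataset = Dataset.from_dict({"value": ["I am happy", "I am sad"], "label": [1, 0]})
    metrics = model.evaluate(dataset, batch_size=50)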
@@ -492,8 +932,40 @@ class ClassificationModel:
             ValueError: If the value does not match previous value types for the category, or is a
                 [`float`][float] that is not between `-1.0` and `+1.0`.
         """
-        record_prediction_feedback(
-            body=[
+        client = OrcaClient._resolve_client()
+        client.PUT(
+            "/telemetry/prediction/feedback",
+            json=[
                 _parse_feedback(f) for f in (cast(list[dict], [feedback]) if isinstance(feedback, dict) else feedback)
             ],
         )
+
+    @staticmethod
+    def bootstrap_model(
+        model_description: str,
+        label_names: list[str],
+        initial_examples: list[tuple[str, str]],
+        num_examples_per_label: int,
+        background: bool = False,
+    ) -> Job[BootstrappedClassificationModel] | BootstrappedClassificationModel:
+        client = OrcaClient._resolve_client()
+        response = client.POST(
+            "/agents/bootstrap_classification_model",
+            json={
+                "model_description": model_description,
+                "label_names": label_names,
+                "initial_examples": [{"text": text, "label_name": label_name} for text, label_name in initial_examples],
+                "num_examples_per_label": num_examples_per_label,
+            },
+        )
+
+        def get_result() -> BootstrappedClassificationModel:
+            client = OrcaClient._resolve_client()
+            res = client.GET(
+                "/agents/bootstrap_classification_model/{task_id}", params={"task_id": response["task_id"]}
+            )
+            assert res["result"] is not None
+            return BootstrappedClassificationModel(res["result"])
+
+        job = Job(response["task_id"], get_result)
+        return job if background else job.result()
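bootstrap_model submits an agentic task that creates a datasource, memoryset, and model in one flow, then wraps the pieces in the BootstrappedClassificationModel handle defined at the top of this file. A hedged usage sketch (labels, examples, and counts are illustrative):

    bootstrapped = ClassificationModel.bootstrap_model(
        model_description="Classify support emails as refund, complaint, or praise",
        label_names=["refund", "complaint", "praise"],
        initial_examples=[
            ("Where is my money?", "refund"),  # (text, label_name) pairs
            ("Great service, thank you!", "praise"),
        ],
        num_examples_per_label=25,
    )
    print(bootstrapped.classification_model)

    # pass background=True to get a Job[BootstrappedClassificationModel] instead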