orca-sdk 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (185)
  1. orca_sdk/__init__.py +10 -4
  2. orca_sdk/_shared/__init__.py +10 -0
  3. orca_sdk/_shared/metrics.py +393 -0
  4. orca_sdk/_shared/metrics_test.py +273 -0
  5. orca_sdk/_utils/analysis_ui.py +12 -10
  6. orca_sdk/_utils/analysis_ui_style.css +0 -3
  7. orca_sdk/_utils/auth.py +27 -29
  8. orca_sdk/_utils/data_parsing.py +28 -2
  9. orca_sdk/_utils/data_parsing_test.py +15 -15
  10. orca_sdk/_utils/pagination.py +126 -0
  11. orca_sdk/_utils/pagination_test.py +132 -0
  12. orca_sdk/_utils/prediction_result_ui.py +67 -21
  13. orca_sdk/_utils/tqdm_file_reader.py +12 -0
  14. orca_sdk/_utils/value_parser.py +45 -0
  15. orca_sdk/_utils/value_parser_test.py +39 -0
  16. orca_sdk/classification_model.py +439 -129
  17. orca_sdk/classification_model_test.py +334 -104
  18. orca_sdk/client.py +3747 -0
  19. orca_sdk/conftest.py +164 -19
  20. orca_sdk/credentials.py +120 -18
  21. orca_sdk/credentials_test.py +20 -0
  22. orca_sdk/datasource.py +259 -68
  23. orca_sdk/datasource_test.py +242 -0
  24. orca_sdk/embedding_model.py +425 -82
  25. orca_sdk/embedding_model_test.py +39 -13
  26. orca_sdk/job.py +337 -0
  27. orca_sdk/job_test.py +108 -0
  28. orca_sdk/memoryset.py +1341 -305
  29. orca_sdk/memoryset_test.py +350 -111
  30. orca_sdk/regression_model.py +684 -0
  31. orca_sdk/regression_model_test.py +369 -0
  32. orca_sdk/telemetry.py +449 -143
  33. orca_sdk/telemetry_test.py +43 -24
  34. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.2.dist-info}/METADATA +34 -16
  35. orca_sdk-0.1.2.dist-info/RECORD +40 -0
  36. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.2.dist-info}/WHEEL +1 -1
  37. orca_sdk/_generated_api_client/__init__.py +0 -3
  38. orca_sdk/_generated_api_client/api/__init__.py +0 -193
  39. orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
  40. orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
  41. orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
  42. orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
  43. orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
  44. orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
  45. orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
  46. orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
  47. orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
  48. orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
  49. orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
  50. orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
  51. orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
  52. orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
  53. orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
  54. orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
  55. orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
  56. orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
  57. orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
  58. orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
  59. orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
  60. orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
  61. orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
  62. orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
  63. orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
  64. orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
  65. orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
  66. orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
  67. orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
  68. orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
  69. orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
  70. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
  71. orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
  72. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
  73. orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
  74. orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
  75. orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
  76. orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
  77. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
  78. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
  79. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
  80. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
  81. orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
  82. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
  83. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
  84. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
  85. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
  86. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
  87. orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
  88. orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
  89. orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
  90. orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
  91. orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
  92. orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
  93. orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
  94. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
  95. orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
  96. orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
  97. orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
  98. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
  99. orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
  100. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
  101. orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
  102. orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
  103. orca_sdk/_generated_api_client/client.py +0 -216
  104. orca_sdk/_generated_api_client/errors.py +0 -38
  105. orca_sdk/_generated_api_client/models/__init__.py +0 -159
  106. orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
  107. orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
  108. orca_sdk/_generated_api_client/models/base_model.py +0 -55
  109. orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
  110. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
  111. orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
  112. orca_sdk/_generated_api_client/models/column_info.py +0 -114
  113. orca_sdk/_generated_api_client/models/column_type.py +0 -14
  114. orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
  115. orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
  116. orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
  117. orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
  118. orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
  119. orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
  120. orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
  121. orca_sdk/_generated_api_client/models/embed_request.py +0 -127
  122. orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
  123. orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
  124. orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
  125. orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
  126. orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
  127. orca_sdk/_generated_api_client/models/filter_item.py +0 -231
  128. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
  129. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
  130. orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
  131. orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
  132. orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
  133. orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
  134. orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
  135. orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
  136. orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
  137. orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
  138. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
  139. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
  140. orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
  141. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
  142. orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
  143. orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
  144. orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
  145. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
  146. orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
  147. orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
  148. orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
  149. orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
  150. orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
  151. orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
  152. orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
  153. orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
  154. orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
  155. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
  156. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
  157. orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
  158. orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
  159. orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
  160. orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
  161. orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
  162. orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
  163. orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
  164. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
  165. orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
  166. orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
  167. orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
  168. orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
  169. orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
  170. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
  171. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
  172. orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
  173. orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
  174. orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
  175. orca_sdk/_generated_api_client/models/task.py +0 -198
  176. orca_sdk/_generated_api_client/models/task_status.py +0 -14
  177. orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
  178. orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
  179. orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
  180. orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
  181. orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
  182. orca_sdk/_generated_api_client/py.typed +0 -1
  183. orca_sdk/_generated_api_client/types.py +0 -56
  184. orca_sdk/_utils/task.py +0 -73
  185. orca_sdk-0.1.1.dist-info/RECORD +0 -175
orca_sdk/classification_model.py

@@ -4,38 +4,58 @@ import logging
 from contextlib import contextmanager
 from datetime import datetime
 from typing import Any, Generator, Iterable, Literal, cast, overload
-from uuid import UUID
-
-from ._generated_api_client.api import (
-    create_evaluation,
-    create_model,
-    delete_model,
-    get_evaluation,
-    get_model,
-    list_models,
-    list_predictions,
-    predict_gpu,
-    record_prediction_feedback,
-)
-from ._generated_api_client.models import (
-    ClassificationEvaluationResult,
-    CreateRACModelRequest,
-    EvaluationRequest,
-    ListPredictionsRequest,
+
+from datasets import Dataset
+
+from ._shared.metrics import ClassificationMetrics, calculate_classification_metrics
+from ._utils.common import UNSET, CreateMode, DropMode
+from .client import (
+    BootstrapClassificationModelMeta,
+    BootstrapClassificationModelResult,
+    ClassificationModelMetadata,
+    PredictiveModelUpdate,
+    RACHeadType,
+    orca_api,
 )
-from ._generated_api_client.models import (
-    ListPredictionsRequestSortItemItemType0 as PredictionSortColumns,
+from .datasource import Datasource
+from .job import Job
+from .memoryset import (
+    FilterItem,
+    FilterItemTuple,
+    LabeledMemoryset,
+    _is_metric_column,
+    _parse_filter_item_from_tuple,
 )
-from ._generated_api_client.models import (
-    ListPredictionsRequestSortItemItemType1 as PredictionSortDirection,
+from .telemetry import (
+    ClassificationPrediction,
+    TelemetryMode,
+    _get_telemetry_config,
+    _parse_feedback,
 )
-from ._generated_api_client.models import RACHeadType, RACModelMetadata
-from ._generated_api_client.models.prediction_request import PredictionRequest
-from ._utils.common import CreateMode, DropMode
-from ._utils.task import wait_for_task
-from .datasource import Datasource
-from .memoryset import LabeledMemoryset
-from .telemetry import LabelPrediction, _parse_feedback
+
+
+class BootstrappedClassificationModel:
+
+    datasource: Datasource | None
+    memoryset: LabeledMemoryset | None
+    classification_model: ClassificationModel | None
+    agent_output: BootstrapClassificationModelResult | None
+
+    def __init__(self, metadata: BootstrapClassificationModelMeta):
+        self.datasource = Datasource.open(metadata["datasource_meta"]["id"])
+        self.memoryset = LabeledMemoryset.open(metadata["memoryset_meta"]["id"])
+        self.classification_model = ClassificationModel.open(metadata["model_meta"]["id"])
+        self.agent_output = metadata["agent_output"]
+
+    def __repr__(self):
+        return (
+            "BootstrappedClassificationModel({\n"
+            f" datasource: {self.datasource},\n"
+            f" memoryset: {self.memoryset},\n"
+            f" classification_model: {self.classification_model},\n"
+            f" agent_output: {self.agent_output},\n"
+            "})"
+        )
 
 
 class ClassificationModel:
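The import block above is the crux of this release: the generated `_generated_api_client` wrappers are replaced by a single `orca_api` HTTP client, and `LabelPrediction` is superseded by `ClassificationPrediction`. A minimal migration sketch, based only on the removed and added import lines (assuming `orca_sdk.telemetry` exposes the class under its new name):

    # 0.1.1
    # from orca_sdk.telemetry import LabelPrediction

    # 0.1.2: the prediction result type was renamed
    from orca_sdk.telemetry import ClassificationPrediction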
@@ -45,17 +65,20 @@ class ClassificationModel:
     Attributes:
         id: Unique identifier for the model
         name: Unique name of the model
+        description: Optional description of the model
         memoryset: Memoryset that the model uses
         head_type: Classification head type of the model
         num_classes: Number of distinct classes the model can predict
         memory_lookup_count: Number of memories the model uses for each prediction
         weigh_memories: If using a KNN head, whether the model weighs memories by their lookup score
         min_memory_weight: If using a KNN head, minimum lookup score memories have to be over to not be ignored
+        locked: Whether the model is locked to prevent accidental deletion
        created_at: When the model was created
     """
 
     id: str
     name: str
+    description: str | None
     memoryset: LabeledMemoryset
     head_type: RACHeadType
     num_classes: int
@@ -63,23 +86,26 @@ class ClassificationModel:
     weigh_memories: bool | None
     min_memory_weight: float | None
     version: int
+    locked: bool
     created_at: datetime
 
-    def __init__(self, metadata: RACModelMetadata):
+    def __init__(self, metadata: ClassificationModelMetadata):
         # for internal use only, do not document
-        self.id = metadata.id
-        self.name = metadata.name
-        self.memoryset = LabeledMemoryset.open(metadata.memoryset_id)
-        self.head_type = metadata.head_type
-        self.num_classes = metadata.num_classes
-        self.memory_lookup_count = metadata.memory_lookup_count
-        self.weigh_memories = metadata.weigh_memories
-        self.min_memory_weight = metadata.min_memory_weight
-        self.version = metadata.version
-        self.created_at = metadata.created_at
+        self.id = metadata["id"]
+        self.name = metadata["name"]
+        self.description = metadata["description"]
+        self.memoryset = LabeledMemoryset.open(metadata["memoryset_id"])
+        self.head_type = metadata["head_type"]
+        self.num_classes = metadata["num_classes"]
+        self.memory_lookup_count = metadata["memory_lookup_count"]
+        self.weigh_memories = metadata["weigh_memories"]
+        self.min_memory_weight = metadata["min_memory_weight"]
+        self.version = metadata["version"]
+        self.locked = metadata["locked"]
+        self.created_at = datetime.fromisoformat(metadata["created_at"])
 
         self._memoryset_override_id: str | None = None
-        self._last_prediction: LabelPrediction | None = None
+        self._last_prediction: ClassificationPrediction | None = None
         self._last_prediction_was_batch: bool = False
 
     def __eq__(self, other) -> bool:
@@ -97,7 +123,7 @@ class ClassificationModel:
         )
 
     @property
-    def last_prediction(self) -> LabelPrediction:
+    def last_prediction(self) -> ClassificationPrediction:
         """
         Last prediction made by the model
 
@@ -119,8 +145,9 @@ class ClassificationModel:
         cls,
         name: str,
         memoryset: LabeledMemoryset,
-        head_type: Literal["BMMOE", "FF", "KNN", "MMOE"] = "KNN",
+        head_type: RACHeadType = "KNN",
         *,
+        description: str | None = None,
         num_classes: int | None = None,
         memory_lookup_count: int | None = None,
         weigh_memories: bool = True,
@@ -141,6 +168,8 @@ class ClassificationModel:
             min_memory_weight: If using a KNN head, minimum lookup score memories have to be over to not be ignored
             if_exists: What to do if a model with the same name already exists, defaults to
                 `"error"`. Other option is `"open"` to open the existing model.
+            description: Optional description for the model, this will be used in agentic flows,
+                so make sure it is concise and describes the purpose of your model.
 
         Returns:
             Handle to the new model in the OrcaCloud
@@ -182,16 +211,18 @@ class ClassificationModel:
 
             return existing
 
-        metadata = create_model(
-            body=CreateRACModelRequest(
-                name=name,
-                memoryset_id=memoryset.id,
-                head_type=RACHeadType(head_type),
-                memory_lookup_count=memory_lookup_count,
-                num_classes=num_classes,
-                weigh_memories=weigh_memories,
-                min_memory_weight=min_memory_weight,
-            ),
+        metadata = orca_api.POST(
+            "/classification_model",
+            json={
+                "name": name,
+                "memoryset_name_or_id": memoryset.id,
+                "head_type": head_type,
+                "memory_lookup_count": memory_lookup_count,
+                "num_classes": num_classes,
+                "weigh_memories": weigh_memories,
+                "min_memory_weight": min_memory_weight,
+                "description": description,
+            },
         )
         return cls(metadata)
 
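With the new request body above, model creation accepts a `description` and passes `head_type` through unchanged. A hedged usage sketch (the model and memoryset names are placeholders, the memoryset is assumed to already exist, and the top-level exports are assumed to match 0.1.1):

    from orca_sdk import ClassificationModel, LabeledMemoryset

    memoryset = LabeledMemoryset.open("my-memoryset")  # placeholder name
    model = ClassificationModel.create(
        "sentiment-model",  # placeholder name
        memoryset,
        head_type="KNN",  # RACHeadType values are accepted directly now
        description="Classifies airline tweets by sentiment",  # new in 0.1.2
        if_exists="open",  # reuse an existing model with the same name
    )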
@@ -209,7 +240,7 @@ class ClassificationModel:
         Raises:
             LookupError: If the classification model does not exist
         """
-        return cls(get_model(name))
+        return cls(orca_api.GET("/classification_model/{name_or_id}", params={"name_or_id": name}))
 
     @classmethod
     def exists(cls, name_or_id: str) -> bool:
  def exists(cls, name_or_id: str) -> bool:
@@ -236,7 +267,7 @@ class ClassificationModel:
236
267
  Returns:
237
268
  List of handles to all classification models in the OrcaCloud
238
269
  """
239
- return [cls(metadata) for metadata in list_models()]
270
+ return [cls(metadata) for metadata in orca_api.GET("/classification_model")]
240
271
 
241
272
  @classmethod
242
273
  def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
@@ -255,73 +286,189 @@ class ClassificationModel:
255
286
  LookupError: If the classification model does not exist and if_not_exists is `"error"`
256
287
  """
257
288
  try:
258
- delete_model(name_or_id)
289
+ orca_api.DELETE("/classification_model/{name_or_id}", params={"name_or_id": name_or_id})
259
290
  logging.info(f"Deleted model {name_or_id}")
260
291
  except LookupError:
261
292
  if if_not_exists == "error":
262
293
  raise
263
294
 
295
+ def refresh(self):
296
+ """Refresh the model data from the OrcaCloud"""
297
+ self.__dict__.update(self.open(self.name).__dict__)
298
+
299
+ def set(self, *, description: str | None = UNSET, locked: bool = UNSET) -> None:
300
+ """
301
+ Update editable attributes of the model.
302
+
303
+ Note:
304
+ If a field is not provided, it will default to [UNSET][orca_sdk.UNSET] and not be updated.
305
+
306
+ Params:
307
+ description: Value to set for the description
308
+ locked: Value to set for the locked status
309
+
310
+ Examples:
311
+ Update the description:
312
+ >>> model.set(description="New description")
313
+
314
+ Remove description:
315
+ >>> model.set(description=None)
316
+
317
+ Lock the model:
318
+ >>> model.set(locked=True)
319
+ """
320
+ update: PredictiveModelUpdate = {}
321
+ if description is not UNSET:
322
+ update["description"] = description
323
+ if locked is not UNSET:
324
+ update["locked"] = locked
325
+ orca_api.PATCH("/classification_model/{name_or_id}", params={"name_or_id": self.id}, json=update)
326
+ self.refresh()
327
+
328
+ def lock(self) -> None:
329
+ """Lock the model to prevent accidental deletion"""
330
+ self.set(locked=True)
331
+
332
+ def unlock(self) -> None:
333
+ """Unlock the model to allow deletion"""
334
+ self.set(locked=False)
335
+
264
336
  @overload
265
337
  def predict(
266
- self, value: list[str], expected_labels: list[int] | None = None, tags: set[str] = set()
267
- ) -> list[LabelPrediction]:
338
+ self,
339
+ value: list[str],
340
+ expected_labels: list[int] | None = None,
341
+ filters: list[FilterItemTuple] = [],
342
+ tags: set[str] | None = None,
343
+ save_telemetry: TelemetryMode = "on",
344
+ prompt: str | None = None,
345
+ use_lookup_cache: bool = True,
346
+ timeout_seconds: int = 10,
347
+ ) -> list[ClassificationPrediction]:
268
348
  pass
269
349
 
270
350
  @overload
271
- def predict(self, value: str, expected_labels: int | None = None, tags: set[str] = set()) -> LabelPrediction:
351
+ def predict(
352
+ self,
353
+ value: str,
354
+ expected_labels: int | None = None,
355
+ filters: list[FilterItemTuple] = [],
356
+ tags: set[str] | None = None,
357
+ save_telemetry: TelemetryMode = "on",
358
+ prompt: str | None = None,
359
+ use_lookup_cache: bool = True,
360
+ timeout_seconds: int = 10,
361
+ ) -> ClassificationPrediction:
272
362
  pass
273
363
 
274
364
  def predict(
275
- self, value: list[str] | str, expected_labels: list[int] | int | None = None, tags: set[str] = set()
276
- ) -> list[LabelPrediction] | LabelPrediction:
365
+ self,
366
+ value: list[str] | str,
367
+ expected_labels: list[int] | list[str] | int | str | None = None,
368
+ filters: list[FilterItemTuple] = [],
369
+ tags: set[str] | None = None,
370
+ save_telemetry: TelemetryMode = "on",
371
+ prompt: str | None = None,
372
+ use_lookup_cache: bool = True,
373
+ timeout_seconds: int = 10,
374
+ ) -> list[ClassificationPrediction] | ClassificationPrediction:
277
375
  """
278
376
  Predict label(s) for the given input value(s) grounded in similar memories
279
377
 
280
378
  Params:
281
379
  value: Value(s) to get predict the labels of
282
380
  expected_labels: Expected label(s) for the given input to record for model evaluation
381
+ filters: Optional filters to apply during memory lookup
283
382
  tags: Tags to add to the prediction(s)
383
+ save_telemetry: Whether to save telemetry for the prediction(s). One of
384
+ * `"off"`: Do not save telemetry
385
+ * `"on"`: Save telemetry asynchronously unless the `ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY`
386
+ environment variable is set.
387
+ * `"sync"`: Save telemetry synchronously
388
+ * `"async"`: Save telemetry asynchronously
389
+ prompt: Optional prompt to use for instruction-tuned embedding models
390
+ use_lookup_cache: Whether to use cached lookup results for faster predictions
391
+ timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
284
392
 
285
393
  Returns:
286
394
  Label prediction or list of label predictions
287
395
 
396
+ Raises:
397
+ ValueError: If timeout_seconds is not a positive integer
398
+ TimeoutError: If the request times out after the specified duration
399
+
288
400
  Examples:
289
401
  Predict the label for a single value:
290
402
  >>> prediction = model.predict("I am happy", tags={"test"})
291
- LabelPrediction({label: <positive: 1>, confidence: 0.95, input_value: 'I am happy' })
403
+ ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
292
404
 
293
405
  Predict the labels for a list of values:
294
406
  >>> predictions = model.predict(["I am happy", "I am sad"], expected_labels=[1, 0])
295
407
  [
296
- LabelPrediction({label: <positive: 1>, confidence: 0.95, input_value: 'I am happy'}),
297
- LabelPrediction({label: <negative: 0>, confidence: 0.05, input_value: 'I am sad'}),
408
+ ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
409
+ ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
298
410
  ]
411
+
412
+ Using a prompt with an instruction-tuned embedding model:
413
+ >>> prediction = model.predict("I am happy", prompt="Represent this text for sentiment classification:")
414
+ ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
299
415
  """
300
- response = predict_gpu(
301
- self.id,
302
- body=PredictionRequest(
303
- input_values=value if isinstance(value, list) else [value],
304
- memoryset_override_id=self._memoryset_override_id,
305
- expected_labels=(
306
- expected_labels
307
- if isinstance(expected_labels, list)
308
- else [expected_labels]
309
- if expected_labels is not None
310
- else None
311
- ),
312
- tags=list(tags),
313
- ),
416
+
417
+ if timeout_seconds <= 0:
418
+ raise ValueError("timeout_seconds must be a positive integer")
419
+
420
+ parsed_filters = [
421
+ _parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
422
+ ]
423
+
424
+ if any(_is_metric_column(filter[0]) for filter in filters):
425
+ raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")
426
+
427
+ if isinstance(expected_labels, int):
428
+ expected_labels = [expected_labels]
429
+ elif isinstance(expected_labels, str):
430
+ expected_labels = [self.memoryset.label_names.index(expected_labels)]
431
+ elif isinstance(expected_labels, list):
432
+ expected_labels = [
433
+ self.memoryset.label_names.index(label) if isinstance(label, str) else label
434
+ for label in expected_labels
435
+ ]
436
+
437
+ telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
438
+ response = orca_api.POST(
439
+ "/gpu/classification_model/{name_or_id}/prediction",
440
+ params={"name_or_id": self.id},
441
+ json={
442
+ "input_values": value if isinstance(value, list) else [value],
443
+ "memoryset_override_name_or_id": self._memoryset_override_id,
444
+ "expected_labels": expected_labels,
445
+ "tags": list(tags or set()),
446
+ "save_telemetry": telemetry_on,
447
+ "save_telemetry_synchronously": telemetry_sync,
448
+ "filters": cast(list[FilterItem], parsed_filters),
449
+ "prompt": prompt,
450
+ "use_lookup_cache": use_lookup_cache,
451
+ },
452
+ timeout=timeout_seconds,
314
453
  )
454
+
455
+ if telemetry_on and any(p["prediction_id"] is None for p in response):
456
+ raise RuntimeError("Failed to save prediction to database.")
457
+
315
458
  predictions = [
316
- LabelPrediction(
317
- prediction_id=prediction.prediction_id,
318
- label=prediction.label,
319
- label_name=prediction.label_name,
320
- confidence=prediction.confidence,
459
+ ClassificationPrediction(
460
+ prediction_id=prediction["prediction_id"],
461
+ label=prediction["label"],
462
+ label_name=prediction["label_name"],
463
+ score=None,
464
+ confidence=prediction["confidence"],
465
+ anomaly_score=prediction["anomaly_score"],
321
466
  memoryset=self.memoryset,
322
467
  model=self,
468
+ logits=prediction["logits"],
469
+ input_value=input_value,
323
470
  )
324
- for prediction in response
471
+ for prediction, input_value in zip(response, value if isinstance(value, list) else [value])
325
472
  ]
326
473
  self._last_prediction_was_batch = isinstance(value, list)
327
474
  self._last_prediction = predictions[-1]
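A sketch of the expanded `predict` call above. The `(field, op, value)` filter tuple shape is inferred from `_parse_filter_item_from_tuple`, and the metadata field name is made up for illustration:

    prediction = model.predict(
        "I am happy",
        filters=[("source", "==", "twitter")],  # hypothetical memory field
        tags={"smoke-test"},
        save_telemetry="sync",  # "off" | "on" | "sync" | "async"
        prompt="Represent this text for sentiment classification:",  # instruction-tuned models only
        use_lookup_cache=False,
        timeout_seconds=30,  # must be a positive integer
    )
    print(prediction.label_name, prediction.confidence, prediction.anomaly_score)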
@@ -332,8 +479,9 @@ class ClassificationModel:
         limit: int = 100,
         offset: int = 0,
         tag: str | None = None,
-        sort: list[tuple[PredictionSortColumns, PredictionSortDirection]] = [],
-    ) -> list[LabelPrediction]:
+        sort: list[tuple[Literal["anomaly_score", "confidence", "timestamp"], Literal["asc", "desc"]]] = [],
+        expected_label_match: bool | None = None,
+    ) -> list[ClassificationPrediction]:
         """
         Get a list of predictions made by this model
 
@@ -343,6 +491,8 @@ class ClassificationModel:
             tag: Optional tag to filter predictions by
             sort: Optional list of columns and directions to sort the predictions by.
                 Predictions can be sorted by `timestamp` or `confidence`.
+            expected_label_match: Optional filter to only include predictions where the expected
+                label does (`True`) or doesn't (`False`) match the predicted label
 
         Returns:
             List of label predictions
@@ -351,78 +501,209 @@ class ClassificationModel:
             Get the last 3 predictions:
             >>> predictions = model.predictions(limit=3, sort=[("timestamp", "desc")])
             [
-                LabeledPrediction({label: <positive: 1>, confidence: 0.95, input_value: 'I am happy'}),
-                LabeledPrediction({label: <negative: 0>, confidence: 0.05, input_value: 'I am sad'}),
-                LabeledPrediction({label: <positive: 1>, confidence: 0.90, input_value: 'I am ecstatic'}),
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
+                ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am ecstatic'}),
             ]
 
 
             Get second most confident prediction:
             >>> predictions = model.predictions(sort=[("confidence", "desc")], offset=1, limit=1)
-            [LabeledPrediction({label: <positive: 1>, confidence: 0.90, input_value: 'I am having a good day'})]
+            [ClassificationPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am having a good day'})]
+
+            Get predictions where the expected label doesn't match the predicted label:
+            >>> predictions = model.predictions(expected_label_match=False)
+            [ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy', expected_label: 0})]
         """
-        predictions = list_predictions(
-            body=ListPredictionsRequest(
-                model_id=self.id,
-                limit=limit,
-                offset=offset,
-                sort=cast(list[list[PredictionSortColumns | PredictionSortDirection]], sort),
-                tag=tag,
-            ),
+        predictions = orca_api.POST(
+            "/telemetry/prediction",
+            json={
+                "model_id": self.id,
+                "limit": limit,
+                "offset": offset,
+                "sort": [list(sort_item) for sort_item in sort],
+                "tag": tag,
+                "expected_label_match": expected_label_match,
+            },
         )
         return [
-            LabelPrediction(
-                prediction_id=prediction.prediction_id,
-                label=prediction.label,
-                label_name=prediction.label_name,
-                confidence=prediction.confidence,
+            ClassificationPrediction(
+                prediction_id=prediction["prediction_id"],
+                label=prediction["label"],
+                label_name=prediction["label_name"],
+                score=None,
+                confidence=prediction["confidence"],
+                anomaly_score=prediction["anomaly_score"],
                 memoryset=self.memoryset,
                 model=self,
                 telemetry=prediction,
             )
             for prediction in predictions
+            if "label" in prediction
         ]
 
-    def evaluate(
+    def _evaluate_datasource(
         self,
         datasource: Datasource,
+        value_column: str,
+        label_column: str,
+        record_predictions: bool,
+        tags: set[str] | None,
+        background: bool = False,
+    ) -> ClassificationMetrics | Job[ClassificationMetrics]:
+        response = orca_api.POST(
+            "/classification_model/{model_name_or_id}/evaluation",
+            params={"model_name_or_id": self.id},
+            json={
+                "datasource_name_or_id": datasource.id,
+                "datasource_label_column": label_column,
+                "datasource_value_column": value_column,
+                "memoryset_override_name_or_id": self._memoryset_override_id,
+                "record_telemetry": record_predictions,
+                "telemetry_tags": list(tags) if tags else None,
+            },
+        )
+
+        def get_value():
+            res = orca_api.GET(
+                "/classification_model/{model_name_or_id}/evaluation/{task_id}",
+                params={"model_name_or_id": self.id, "task_id": response["task_id"]},
+            )
+            assert res["result"] is not None
+            return ClassificationMetrics(
+                coverage=res["result"].get("coverage"),
+                f1_score=res["result"].get("f1_score"),
+                accuracy=res["result"].get("accuracy"),
+                loss=res["result"].get("loss"),
+                anomaly_score_mean=res["result"].get("anomaly_score_mean"),
+                anomaly_score_median=res["result"].get("anomaly_score_median"),
+                anomaly_score_variance=res["result"].get("anomaly_score_variance"),
+                roc_auc=res["result"].get("roc_auc"),
+                pr_auc=res["result"].get("pr_auc"),
+                pr_curve=res["result"].get("pr_curve"),
+                roc_curve=res["result"].get("roc_curve"),
+            )
+
+        job = Job(response["task_id"], get_value)
+        return job if background else job.result()
+
+    def _evaluate_dataset(
+        self,
+        dataset: Dataset,
+        value_column: str,
+        label_column: str,
+        record_predictions: bool,
+        tags: set[str],
+        batch_size: int,
+    ) -> ClassificationMetrics:
+        if len(dataset) == 0:
+            raise ValueError("Evaluation dataset cannot be empty")
+
+        if any(x is None for x in dataset[label_column]):
+            raise ValueError("Evaluation dataset cannot contain None values in the label column")
+
+        predictions = [
+            prediction
+            for i in range(0, len(dataset), batch_size)
+            for prediction in self.predict(
+                dataset[i : i + batch_size][value_column],
+                expected_labels=dataset[i : i + batch_size][label_column],
+                tags=tags,
+                save_telemetry="sync" if record_predictions else "off",
+            )
+        ]
+
+        return calculate_classification_metrics(
+            expected_labels=dataset[label_column],
+            logits=[p.logits for p in predictions],
+            anomaly_scores=[p.anomaly_score for p in predictions],
+            include_curves=True,
+        )
+
+    @overload
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
+        *,
         value_column: str = "value",
         label_column: str = "label",
         record_predictions: bool = False,
-        tags: set[str] | None = None,
-    ) -> dict[str, float]:
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
+        background: Literal[True],
+    ) -> Job[ClassificationMetrics]:
+        pass
+
+    @overload
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
+        *,
+        value_column: str = "value",
+        label_column: str = "label",
+        record_predictions: bool = False,
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
+        background: Literal[False] = False,
+    ) -> ClassificationMetrics:
+        pass
+
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
+        *,
+        value_column: str = "value",
+        label_column: str = "label",
+        record_predictions: bool = False,
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
+        background: bool = False,
+    ) -> ClassificationMetrics | Job[ClassificationMetrics]:
         """
-        Evaluate the classification model on a given datasource
+        Evaluate the classification model on a given dataset or datasource
 
         Params:
-            datasource: Datasource to evaluate the model on
+            data: Dataset or Datasource to evaluate the model on
             value_column: Name of the column that contains the input values to the model
             label_column: Name of the column containing the expected labels
-            record_predictions: Whether to record [`LabelPrediction`][orca_sdk.telemetry.LabelPrediction]s for analysis
-            tags: Optional tags to add to the recorded [`LabelPrediction`][orca_sdk.telemetry.LabelPrediction]s
+            record_predictions: Whether to record [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s for analysis
+            tags: Optional tags to add to the recorded [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s
+            batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)
+            background: Whether to run the operation in the background and return a job handle
 
         Returns:
-            Dictionary with evaluation metrics
+            EvaluationResult containing metrics including accuracy, F1 score, ROC AUC, PR AUC, and anomaly score statistics
 
         Examples:
             >>> model.evaluate(datasource, value_column="text", label_column="airline_sentiment")
-            { "f1_score": 0.85, "roc_auc": 0.85, "pr_auc": 0.85, "accuracy": 0.85, "loss": 0.35 }
+            ClassificationMetrics({
+                accuracy: 0.8500,
+                f1_score: 0.8500,
+                roc_auc: 0.8500,
+                pr_auc: 0.8500,
+                anomaly_score: 0.3500 ± 0.0500,
+            })
         """
-        response = create_evaluation(
-            self.id,
-            body=EvaluationRequest(
-                datasource_id=datasource.id,
-                datasource_label_column=label_column,
-                datasource_value_column=value_column,
-                memoryset_override_id=self._memoryset_override_id,
-                record_telemetry=record_predictions,
-                telemetry_tags=list(tags) if tags else None,
-            ),
-        )
-        wait_for_task(response.task_id, description="Running evaluation")
-        response = get_evaluation(self.id, UUID(response.task_id))
-        assert response.result is not None
-        return response.result.to_dict()
+        if isinstance(data, Datasource):
+            return self._evaluate_datasource(
+                datasource=data,
+                value_column=value_column,
+                label_column=label_column,
+                record_predictions=record_predictions,
+                tags=tags,
+                background=background,
+            )
+        elif isinstance(data, Dataset):
+            return self._evaluate_dataset(
+                dataset=data,
+                value_column=value_column,
+                label_column=label_column,
+                record_predictions=record_predictions,
+                tags=tags,
+                batch_size=batch_size,
+            )
+        else:
+            raise ValueError(f"Invalid data type: {type(data)}")
 
     def finetune(self, datasource: Datasource):
         # do not document until implemented
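Per the dispatch at the end of `evaluate`, a `Datasource` is evaluated server-side (optionally as a background `Job`), while a Hugging Face `Dataset` is batched through `predict` client-side. A sketch assuming an existing datasource handle and dataclass-style attribute access on `ClassificationMetrics`:

    from datasets import Dataset

    # server-side evaluation, returned as a job handle
    job = model.evaluate(datasource, value_column="text", label_column="airline_sentiment", background=True)
    metrics = job.result()  # blocks until the evaluation task finishes

    # client-side evaluation of an in-memory dataset
    dataset = Dataset.from_dict({"value": ["I am happy", "I am sad"], "label": [1, 0]})
    metrics = model.evaluate(dataset, batch_size=50)
    print(metrics.accuracy, metrics.f1_score)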
@@ -492,8 +773,37 @@ class ClassificationModel:
         ValueError: If the value does not match previous value types for the category, or is a
             [`float`][float] that is not between `-1.0` and `+1.0`.
         """
-        record_prediction_feedback(
-            body=[
+        orca_api.PUT(
+            "/telemetry/prediction/feedback",
+            json=[
                 _parse_feedback(f) for f in (cast(list[dict], [feedback]) if isinstance(feedback, dict) else feedback)
             ],
         )
+
+    @staticmethod
+    def bootstrap_model(
+        model_description: str,
+        label_names: list[str],
+        initial_examples: list[tuple[str, str]],
+        num_examples_per_label: int,
+        background: bool = False,
+    ) -> Job[BootstrappedClassificationModel] | BootstrappedClassificationModel:
+        response = orca_api.POST(
+            "/agents/bootstrap_classification_model",
+            json={
+                "model_description": model_description,
+                "label_names": label_names,
+                "initial_examples": [{"text": text, "label_name": label_name} for text, label_name in initial_examples],
+                "num_examples_per_label": num_examples_per_label,
+            },
+        )
+
+        def get_result() -> BootstrappedClassificationModel:
+            res = orca_api.GET(
+                "/agents/bootstrap_classification_model/{task_id}", params={"task_id": response["task_id"]}
+            )
+            assert res["result"] is not None
+            return BootstrappedClassificationModel(res["result"])
+
+        job = Job(response["task_id"], get_result)
+        return job if background else job.result()
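Finally, a usage sketch for the new agentic bootstrap flow added above (the labels and seed examples are illustrative; the call blocks unless `background=True` is passed, in which case it returns a `Job`):

    bootstrapped = ClassificationModel.bootstrap_model(
        model_description="Route support tickets by urgency",
        label_names=["low", "high"],
        initial_examples=[
            ("The app crashes on login", "high"),
            ("Please update my billing address", "low"),
        ],
        num_examples_per_label=25,
    )
    # the returned wrapper bundles handles to everything the agent created
    print(bootstrapped.datasource, bootstrapped.memoryset, bootstrapped.classification_model)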