orca-sdk 0.0.78__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188) hide show
  1. orca_sdk/__init__.py +24 -0
  2. orca_sdk/_generated_api_client/__init__.py +3 -0
  3. orca_sdk/_generated_api_client/api/__init__.py +205 -0
  4. orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
  5. orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +130 -0
  6. orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +172 -0
  7. orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +158 -0
  8. orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +132 -0
  9. orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +129 -0
  10. orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
  11. orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +185 -0
  12. orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +172 -0
  13. orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +170 -0
  14. orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +156 -0
  15. orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +172 -0
  16. orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +158 -0
  17. orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +163 -0
  18. orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +129 -0
  19. orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +192 -0
  20. orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
  21. orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +169 -0
  22. orca_sdk/_generated_api_client/api/datasource/create_embedding_evaluation_datasource_name_or_id_embedding_evaluation_post.py +185 -0
  23. orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +158 -0
  24. orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +158 -0
  25. orca_sdk/_generated_api_client/api/datasource/get_embedding_evaluation_datasource_name_or_id_embedding_evaluation_task_id_get.py +171 -0
  26. orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +129 -0
  27. orca_sdk/_generated_api_client/api/datasource/list_embedding_evaluations_datasource_name_or_id_embedding_evaluation_get.py +237 -0
  28. orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
  29. orca_sdk/_generated_api_client/api/default/healthcheck_get.py +120 -0
  30. orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +120 -0
  31. orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
  32. orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +170 -0
  33. orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +158 -0
  34. orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +191 -0
  35. orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +158 -0
  36. orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +129 -0
  37. orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
  38. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +183 -0
  39. orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +185 -0
  40. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +170 -0
  41. orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +183 -0
  42. orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +169 -0
  43. orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +158 -0
  44. orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +171 -0
  45. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +190 -0
  46. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +171 -0
  47. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +158 -0
  48. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +186 -0
  49. orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +262 -0
  50. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +129 -0
  51. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +195 -0
  52. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +190 -0
  53. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +193 -0
  54. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +189 -0
  55. orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
  56. orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +194 -0
  57. orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +163 -0
  58. orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +129 -0
  59. orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
  60. orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +156 -0
  61. orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +158 -0
  62. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +245 -0
  63. orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
  64. orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +164 -0
  65. orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +158 -0
  66. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +159 -0
  67. orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +129 -0
  68. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +177 -0
  69. orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +173 -0
  70. orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +183 -0
  71. orca_sdk/_generated_api_client/client.py +216 -0
  72. orca_sdk/_generated_api_client/errors.py +38 -0
  73. orca_sdk/_generated_api_client/models/__init__.py +179 -0
  74. orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +116 -0
  75. orca_sdk/_generated_api_client/models/api_key_metadata.py +137 -0
  76. orca_sdk/_generated_api_client/models/api_key_metadata_scope_item.py +9 -0
  77. orca_sdk/_generated_api_client/models/base_model.py +55 -0
  78. orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +176 -0
  79. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +147 -0
  80. orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +150 -0
  81. orca_sdk/_generated_api_client/models/column_info.py +114 -0
  82. orca_sdk/_generated_api_client/models/column_type.py +14 -0
  83. orca_sdk/_generated_api_client/models/conflict_error_response.py +80 -0
  84. orca_sdk/_generated_api_client/models/create_api_key_request.py +120 -0
  85. orca_sdk/_generated_api_client/models/create_api_key_request_scope_item.py +9 -0
  86. orca_sdk/_generated_api_client/models/create_api_key_response.py +145 -0
  87. orca_sdk/_generated_api_client/models/create_api_key_response_scope_item.py +9 -0
  88. orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +279 -0
  89. orca_sdk/_generated_api_client/models/create_rac_model_request.py +209 -0
  90. orca_sdk/_generated_api_client/models/datasource_metadata.py +142 -0
  91. orca_sdk/_generated_api_client/models/delete_memories_request.py +70 -0
  92. orca_sdk/_generated_api_client/models/embed_request.py +127 -0
  93. orca_sdk/_generated_api_client/models/embedding_evaluation_request.py +179 -0
  94. orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +148 -0
  95. orca_sdk/_generated_api_client/models/embedding_evaluation_result.py +86 -0
  96. orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +9 -0
  97. orca_sdk/_generated_api_client/models/embedding_model_result.py +114 -0
  98. orca_sdk/_generated_api_client/models/evaluation_request.py +180 -0
  99. orca_sdk/_generated_api_client/models/evaluation_response.py +140 -0
  100. orca_sdk/_generated_api_client/models/feedback_type.py +9 -0
  101. orca_sdk/_generated_api_client/models/field_validation_error.py +103 -0
  102. orca_sdk/_generated_api_client/models/filter_item.py +231 -0
  103. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +15 -0
  104. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +20 -0
  105. orca_sdk/_generated_api_client/models/filter_item_op.py +16 -0
  106. orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +70 -0
  107. orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +259 -0
  108. orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +66 -0
  109. orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +166 -0
  110. orca_sdk/_generated_api_client/models/get_memories_request.py +70 -0
  111. orca_sdk/_generated_api_client/models/internal_server_error_response.py +80 -0
  112. orca_sdk/_generated_api_client/models/label_class_metrics.py +108 -0
  113. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +274 -0
  114. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +68 -0
  115. orca_sdk/_generated_api_client/models/label_prediction_result.py +115 -0
  116. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +246 -0
  117. orca_sdk/_generated_api_client/models/labeled_memory.py +197 -0
  118. orca_sdk/_generated_api_client/models/labeled_memory_insert.py +128 -0
  119. orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +68 -0
  120. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +258 -0
  121. orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +68 -0
  122. orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +68 -0
  123. orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +237 -0
  124. orca_sdk/_generated_api_client/models/labeled_memory_update.py +171 -0
  125. orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +68 -0
  126. orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +195 -0
  127. orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +9 -0
  128. orca_sdk/_generated_api_client/models/list_memories_request.py +104 -0
  129. orca_sdk/_generated_api_client/models/list_predictions_request.py +257 -0
  130. orca_sdk/_generated_api_client/models/lookup_request.py +81 -0
  131. orca_sdk/_generated_api_client/models/memory_metrics.py +156 -0
  132. orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +83 -0
  133. orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +9 -0
  134. orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +180 -0
  135. orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +66 -0
  136. orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +9 -0
  137. orca_sdk/_generated_api_client/models/not_found_error_response.py +100 -0
  138. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +21 -0
  139. orca_sdk/_generated_api_client/models/precision_recall_curve.py +94 -0
  140. orca_sdk/_generated_api_client/models/prediction_feedback.py +157 -0
  141. orca_sdk/_generated_api_client/models/prediction_feedback_category.py +115 -0
  142. orca_sdk/_generated_api_client/models/prediction_feedback_request.py +122 -0
  143. orca_sdk/_generated_api_client/models/prediction_feedback_result.py +102 -0
  144. orca_sdk/_generated_api_client/models/prediction_request.py +169 -0
  145. orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_0.py +10 -0
  146. orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_1.py +9 -0
  147. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +97 -0
  148. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +12 -0
  149. orca_sdk/_generated_api_client/models/rac_head_type.py +11 -0
  150. orca_sdk/_generated_api_client/models/rac_model_metadata.py +191 -0
  151. orca_sdk/_generated_api_client/models/roc_curve.py +94 -0
  152. orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +80 -0
  153. orca_sdk/_generated_api_client/models/task.py +198 -0
  154. orca_sdk/_generated_api_client/models/task_status.py +14 -0
  155. orca_sdk/_generated_api_client/models/task_status_info.py +133 -0
  156. orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +72 -0
  157. orca_sdk/_generated_api_client/models/unauthorized_error_response.py +80 -0
  158. orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +94 -0
  159. orca_sdk/_generated_api_client/models/update_prediction_request.py +93 -0
  160. orca_sdk/_generated_api_client/py.typed +1 -0
  161. orca_sdk/_generated_api_client/types.py +56 -0
  162. orca_sdk/_utils/__init__.py +0 -0
  163. orca_sdk/_utils/analysis_ui.py +192 -0
  164. orca_sdk/_utils/analysis_ui_style.css +54 -0
  165. orca_sdk/_utils/auth.py +68 -0
  166. orca_sdk/_utils/auth_test.py +31 -0
  167. orca_sdk/_utils/common.py +37 -0
  168. orca_sdk/_utils/data_parsing.py +99 -0
  169. orca_sdk/_utils/data_parsing_test.py +244 -0
  170. orca_sdk/_utils/prediction_result_ui.css +18 -0
  171. orca_sdk/_utils/prediction_result_ui.py +64 -0
  172. orca_sdk/_utils/task.py +73 -0
  173. orca_sdk/classification_model.py +508 -0
  174. orca_sdk/classification_model_test.py +272 -0
  175. orca_sdk/conftest.py +116 -0
  176. orca_sdk/credentials.py +126 -0
  177. orca_sdk/credentials_test.py +37 -0
  178. orca_sdk/datasource.py +333 -0
  179. orca_sdk/datasource_test.py +96 -0
  180. orca_sdk/embedding_model.py +347 -0
  181. orca_sdk/embedding_model_test.py +176 -0
  182. orca_sdk/memoryset.py +1209 -0
  183. orca_sdk/memoryset_test.py +287 -0
  184. orca_sdk/telemetry.py +398 -0
  185. orca_sdk/telemetry_test.py +109 -0
  186. orca_sdk-0.0.78.dist-info/METADATA +79 -0
  187. orca_sdk-0.0.78.dist-info/RECORD +188 -0
  188. orca_sdk-0.0.78.dist-info/WHEEL +4 -0
@@ -0,0 +1,508 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from contextlib import contextmanager
5
+ from datetime import datetime
6
+ from typing import Any, Generator, Iterable, Literal, cast, overload
7
+ from uuid import UUID
8
+
9
+ from ._generated_api_client.api import (
10
+ create_evaluation,
11
+ create_model,
12
+ delete_model,
13
+ get_evaluation,
14
+ get_model,
15
+ list_models,
16
+ list_predictions,
17
+ predict_gpu,
18
+ record_prediction_feedback,
19
+ )
20
+ from ._generated_api_client.models import (
21
+ CreateRACModelRequest,
22
+ EvaluationRequest,
23
+ ListPredictionsRequest,
24
+ )
25
+ from ._generated_api_client.models import (
26
+ PredictionSortItemItemType0 as PredictionSortColumns,
27
+ )
28
+ from ._generated_api_client.models import (
29
+ PredictionSortItemItemType1 as PredictionSortDirection,
30
+ )
31
+ from ._generated_api_client.models import RACHeadType, RACModelMetadata
32
+ from ._generated_api_client.models.prediction_request import PredictionRequest
33
+ from ._utils.common import CreateMode, DropMode
34
+ from ._utils.task import wait_for_task
35
+ from .datasource import Datasource
36
+ from .memoryset import LabeledMemoryset
37
+ from .telemetry import LabelPrediction, _parse_feedback
38
+
39
+
40
class ClassificationModel:
    """
    A handle to a classification model in OrcaCloud

    Attributes:
        id: Unique identifier for the model
        name: Unique name of the model
        memoryset: Memoryset that the model uses
        head_type: Classification head type of the model
        num_classes: Number of distinct classes the model can predict
        memory_lookup_count: Number of memories the model uses for each prediction
        weigh_memories: If using a KNN head, whether the model weighs memories by their lookup score
        min_memory_weight: If using a KNN head, minimum lookup score memories have to be over to not be ignored
        version: Version of the model as reported in its metadata
        created_at: When the model was created
    """

    id: str
    name: str
    memoryset: LabeledMemoryset
    head_type: RACHeadType
    num_classes: int
    memory_lookup_count: int
    weigh_memories: bool | None
    min_memory_weight: float | None
    version: int
    created_at: datetime

    def __init__(self, metadata: RACModelMetadata):
        # for internal use only, do not document
        self.id = metadata.id
        self.name = metadata.name
        # NOTE: opening the memoryset handle happens eagerly for every model handle
        self.memoryset = LabeledMemoryset.open(metadata.memoryset_id)
        self.head_type = metadata.head_type
        self.num_classes = metadata.num_classes
        self.memory_lookup_count = metadata.memory_lookup_count
        self.weigh_memories = metadata.weigh_memories
        self.min_memory_weight = metadata.min_memory_weight
        self.version = metadata.version
        self.created_at = metadata.created_at

        # id of the memoryset set by `use_memoryset`, `None` when no override is active
        self._memoryset_override_id: str | None = None
        self._last_prediction: LabelPrediction | None = None
        self._last_prediction_was_batch: bool = False

    def __eq__(self, other) -> bool:
        # two handles are equal when they point at the same model id
        return isinstance(other, ClassificationModel) and self.id == other.id

    def __hash__(self) -> int:
        # defining __eq__ alone would make instances unhashable; hash on the same key as __eq__
        return hash(self.id)

    def __repr__(self):
        return (
            "ClassificationModel({\n"
            f"    name: '{self.name}',\n"
            f"    head_type: {self.head_type},\n"
            f"    num_classes: {self.num_classes},\n"
            f"    memory_lookup_count: {self.memory_lookup_count},\n"
            f"    memoryset: LabeledMemoryset.open('{self.memoryset.name}'),\n"
            "})"
        )

    @property
    def last_prediction(self) -> LabelPrediction:
        """
        Last prediction made by the model

        Note:
            If the last prediction was part of a batch prediction, the last prediction from the
            batch is returned. If no prediction has been made yet, a [`LookupError`][LookupError]
            is raised.
        """
        if self._last_prediction_was_batch:
            logging.warning(
                "Last prediction was part of a batch prediction, returning the last prediction from the batch"
            )
        if self._last_prediction is None:
            raise LookupError("No prediction has been made yet")
        return self._last_prediction

    @classmethod
    def create(
        cls,
        name: str,
        memoryset: LabeledMemoryset,
        head_type: Literal["BMMOE", "FF", "KNN", "MMOE"] = "KNN",
        *,
        num_classes: int | None = None,
        memory_lookup_count: int | None = None,
        weigh_memories: bool = True,
        min_memory_weight: float | None = None,
        if_exists: CreateMode = "error",
    ) -> ClassificationModel:
        """
        Create a new classification model

        Params:
            name: Name for the new model (must be unique)
            memoryset: Memoryset to attach the model to
            head_type: Type of model head to use
            num_classes: Number of classes this model can predict, will be inferred from memoryset if not specified
            memory_lookup_count: Number of memories to lookup for each prediction,
                by default the system uses a simple heuristic to choose a number of memories that works well in most cases
            weigh_memories: If using a KNN head, whether the model weighs memories by their lookup score
            min_memory_weight: If using a KNN head, minimum lookup score memories have to be over to not be ignored
            if_exists: What to do if a model with the same name already exists, defaults to
                `"error"`. Other option is `"open"` to open the existing model.

        Returns:
            Handle to the new model in the OrcaCloud

        Raises:
            ValueError: If the model already exists and if_exists is `"error"` or if it is
                `"open"` and the existing model has different attributes.

        Examples:
            Create a new model using default options:
            >>> model = ClassificationModel.create(
            ...    "my_model",
            ...    LabeledMemoryset.open("my_memoryset"),
            ... )

            Create a new model with non-default model head and options:
            >>> model = ClassificationModel.create(
            ...     name="my_model",
            ...     memoryset=LabeledMemoryset.open("my_memoryset"),
            ...     head_type=RACHeadType.MMOE,
            ...     num_classes=5,
            ...     memory_lookup_count=20,
            ... )
        """
        requested_head_type = RACHeadType(head_type)

        # a single lookup tells us both whether the model exists and what it looks like
        try:
            existing = cls.open(name)
        except LookupError:
            existing = None

        if existing is not None:
            if if_exists == "error":
                raise ValueError(f"Model with name {name} already exists")
            elif if_exists == "open":
                # only explicitly requested (non-None) options are checked against the existing model
                requested = {
                    "head_type": requested_head_type,
                    "memory_lookup_count": memory_lookup_count,
                    "num_classes": num_classes,
                    "min_memory_weight": min_memory_weight,
                }
                for attribute, requested_value in requested.items():
                    if requested_value is not None and requested_value != getattr(existing, attribute):
                        raise ValueError(f"Model with name {name} already exists with different {attribute}")

                # special case for memoryset
                if existing.memoryset.id != memoryset.id:
                    raise ValueError(f"Model with name {name} already exists with different memoryset")

                return existing

        metadata = create_model(
            body=CreateRACModelRequest(
                name=name,
                memoryset_id=memoryset.id,
                head_type=requested_head_type,
                memory_lookup_count=memory_lookup_count,
                num_classes=num_classes,
                weigh_memories=weigh_memories,
                min_memory_weight=min_memory_weight,
            ),
        )
        return cls(metadata)

    @classmethod
    def open(cls, name: str) -> ClassificationModel:
        """
        Get a handle to a classification model in the OrcaCloud

        Params:
            name: Name or unique identifier of the classification model

        Returns:
            Handle to the existing classification model in the OrcaCloud

        Raises:
            LookupError: If the classification model does not exist
        """
        return cls(get_model(name))

    @classmethod
    def exists(cls, name_or_id: str) -> bool:
        """
        Check if a classification model exists in the OrcaCloud

        Params:
            name_or_id: Name or id of the classification model

        Returns:
            `True` if the classification model exists, `False` otherwise
        """
        try:
            cls.open(name_or_id)
        except LookupError:
            return False
        return True

    @classmethod
    def all(cls) -> list[ClassificationModel]:
        """
        Get a list of handles to all classification models in the OrcaCloud

        Returns:
            List of handles to all classification models in the OrcaCloud
        """
        return [cls(metadata) for metadata in list_models()]

    @classmethod
    def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
        """
        Delete a classification model from the OrcaCloud

        Warning:
            This will delete the model and all associated data, including predictions, evaluations, and feedback.

        Params:
            name_or_id: Name or id of the classification model
            if_not_exists: What to do if the classification model does not exist, defaults to `"error"`.
                Other option is `"ignore"` to do nothing if the classification model does not exist.

        Raises:
            LookupError: If the classification model does not exist and if_not_exists is `"error"`
        """
        try:
            delete_model(name_or_id)
            logging.info("Deleted model %s", name_or_id)
        except LookupError:
            if if_not_exists == "error":
                raise

    @overload
    def predict(
        self, value: list[str], expected_labels: list[int] | None = None, tags: set[str] = set()
    ) -> list[LabelPrediction]:
        pass

    @overload
    def predict(self, value: str, expected_labels: int | None = None, tags: set[str] = set()) -> LabelPrediction:
        pass

    def predict(
        self, value: list[str] | str, expected_labels: list[int] | int | None = None, tags: set[str] = set()
    ) -> list[LabelPrediction] | LabelPrediction:
        """
        Predict label(s) for the given input value(s) grounded in similar memories

        Params:
            value: Value(s) to predict the labels of
            expected_labels: Expected label(s) for the given input to record for model evaluation
            tags: Tags to add to the prediction(s)

        Returns:
            Label prediction or list of label predictions

        Examples:
            Predict the label for a single value:
            >>> prediction = model.predict("I am happy", tags={"test"})
            LabelPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })

            Predict the labels for a list of values:
            >>> predictions = model.predict(["I am happy", "I am sad"], expected_labels=[1, 0])
            [
                LabelPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
                LabelPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
            ]
        """
        is_batch = isinstance(value, list)

        # the request always carries lists, so wrap scalar inputs
        if isinstance(expected_labels, list):
            expected = expected_labels
        elif expected_labels is not None:
            expected = [expected_labels]
        else:
            expected = None

        response = predict_gpu(
            self.id,
            body=PredictionRequest(
                input_values=value if is_batch else [value],
                memoryset_override_id=self._memoryset_override_id,
                expected_labels=expected,
                # `tags` default is shared between calls but only ever read, never mutated
                tags=list(tags),
            ),
        )
        predictions = [
            LabelPrediction(
                prediction_id=prediction.prediction_id,
                label=prediction.label,
                label_name=prediction.label_name,
                confidence=prediction.confidence,
                anomaly_score=prediction.anomaly_score,
                memoryset=self.memoryset,
                model=self,
            )
            for prediction in response
        ]
        # an empty batch yields no predictions; keep the previous last prediction in that case
        if predictions:
            self._last_prediction_was_batch = is_batch
            self._last_prediction = predictions[-1]
        return predictions if is_batch else predictions[0]

    def predictions(
        self,
        limit: int = 100,
        offset: int = 0,
        tag: str | None = None,
        sort: list[tuple[PredictionSortColumns, PredictionSortDirection]] = [],
        expected_label_match: bool | None = None,
    ) -> list[LabelPrediction]:
        """
        Get a list of predictions made by this model

        Params:
            limit: Optional maximum number of predictions to return
            offset: Optional offset of the first prediction to return
            tag: Optional tag to filter predictions by
            sort: Optional list of columns and directions to sort the predictions by.
                Predictions can be sorted by `timestamp` or `confidence`.
            expected_label_match: Optional filter to only include predictions where the expected
                label does (`True`) or doesn't (`False`) match the predicted label

        Returns:
            List of label predictions

        Examples:
            Get the last 3 predictions:
            >>> predictions = model.predictions(limit=3, sort=[("timestamp", "desc")])
            [
                LabelPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
                LabelPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
                LabelPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am ecstatic'}),
            ]


            Get second most confident prediction:
            >>> predictions = model.predictions(sort=[("confidence", "desc")], offset=1, limit=1)
            [LabelPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am having a good day'})]

            Get predictions where the expected label doesn't match the predicted label:
            >>> predictions = model.predictions(expected_label_match=False)
            [LabelPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy', expected_label: 0})]
        """
        recorded = list_predictions(
            body=ListPredictionsRequest(
                model_id=self.id,
                limit=limit,
                offset=offset,
                # `sort` default is shared between calls but only ever read, never mutated
                sort=cast(list[list[PredictionSortColumns | PredictionSortDirection]], sort),
                tag=tag,
                expected_label_match=expected_label_match,
            ),
        )
        return [
            LabelPrediction(
                prediction_id=prediction.prediction_id,
                label=prediction.label,
                label_name=prediction.label_name,
                confidence=prediction.confidence,
                anomaly_score=prediction.anomaly_score,
                memoryset=self.memoryset,
                model=self,
                telemetry=prediction,
            )
            for prediction in recorded
        ]

    def evaluate(
        self,
        datasource: Datasource,
        value_column: str = "value",
        label_column: str = "label",
        record_predictions: bool = False,
        tags: set[str] | None = None,
    ) -> dict[str, Any]:
        """
        Evaluate the classification model on a given datasource

        Params:
            datasource: Datasource to evaluate the model on
            value_column: Name of the column that contains the input values to the model
            label_column: Name of the column containing the expected labels
            record_predictions: Whether to record [`LabelPrediction`][orca_sdk.telemetry.LabelPrediction]s for analysis
            tags: Optional tags to add to the recorded [`LabelPrediction`][orca_sdk.telemetry.LabelPrediction]s

        Returns:
            Dictionary with evaluation metrics

        Raises:
            RuntimeError: If the finished evaluation task does not contain a result

        Examples:
            >>> model.evaluate(datasource, value_column="text", label_column="airline_sentiment")
            { "f1_score": 0.85, "roc_auc": 0.85, "pr_auc": 0.85, "accuracy": 0.85, "loss": 0.35, ... }
        """
        task = create_evaluation(
            self.id,
            body=EvaluationRequest(
                datasource_id=datasource.id,
                datasource_label_column=label_column,
                datasource_value_column=value_column,
                memoryset_override_id=self._memoryset_override_id,
                record_telemetry=record_predictions,
                telemetry_tags=list(tags) if tags else None,
            ),
        )
        wait_for_task(task.task_id, description="Running evaluation")
        evaluation = get_evaluation(self.id, UUID(task.task_id))
        # explicit check instead of `assert`, which is stripped when running with -O
        if evaluation.result is None:
            raise RuntimeError(f"Evaluation task {task.task_id} finished without a result")
        return evaluation.result.to_dict()

    def finetune(self, datasource: Datasource):
        # do not document until implemented
        raise NotImplementedError("Finetuning is not supported yet")

    @contextmanager
    def use_memoryset(self, memoryset_override: LabeledMemoryset) -> Generator[None, None, None]:
        """
        Temporarily override the memoryset used by the model for predictions

        The previous override (or no override) is restored when the `with` block exits,
        including when it exits because of an exception.

        Params:
            memoryset_override: Memoryset to override the default memoryset with

        Examples:
            >>> with model.use_memoryset(LabeledMemoryset.open("my_other_memoryset")):
            ...     predictions = model.predict("I am happy")
        """
        previous_override_id = self._memoryset_override_id
        self._memoryset_override_id = memoryset_override.id
        try:
            yield
        finally:
            # without the finally an exception in the with-body would leave the override active
            self._memoryset_override_id = previous_override_id

    @overload
    def record_feedback(self, feedback: dict[str, Any]) -> None:
        pass

    @overload
    def record_feedback(self, feedback: Iterable[dict[str, Any]]) -> None:
        pass

    def record_feedback(self, feedback: Iterable[dict[str, Any]] | dict[str, Any]):
        """
        Record feedback for a list of predictions.

        We support recording feedback in several categories for each prediction. A
        [`FeedbackCategory`][orca_sdk.telemetry.FeedbackCategory] is created automatically,
        the first time feedback with a new name is recorded. Categories are global across models.
        The value type of the category is inferred from the first recorded value. Subsequent
        feedback for the same category must be of the same type.

        Params:
            feedback: Feedback to record, this should be dictionaries with the following keys:

                - `prediction`: Id of the prediction the feedback is for (see examples below).
                - `category`: Name of the category under which to record the feedback.
                - `value`: Feedback value to record, should be `True` for positive feedback and
                    `False` for negative feedback or a [`float`][float] between `-1.0` and `+1.0`
                    where negative values indicate negative feedback and positive values indicate
                    positive feedback.
                - `comment`: Optional comment to record with the feedback.

        Examples:
            Record whether predictions were correct or incorrect:
            >>> model.record_feedback({
            ...     "prediction": p.prediction_id,
            ...     "category": "correct",
            ...     "value": p.label == p.expected_label,
            ... } for p in predictions)

            Record star rating as normalized continuous score between `-1.0` and `+1.0`:
            >>> model.record_feedback({
            ...     "prediction": "123e4567-e89b-12d3-a456-426614174000",
            ...     "category": "rating",
            ...     "value": -0.5,
            ...     "comment": "2 stars"
            ... })

        Raises:
            ValueError: If the value does not match previous value types for the category, or is a
                [`float`][float] that is not between `-1.0` and `+1.0`.
        """
        # a single feedback dict is wrapped so both call styles share one code path
        items = cast(list[dict], [feedback]) if isinstance(feedback, dict) else feedback
        record_prediction_feedback(
            body=[_parse_feedback(item) for item in items],
        )