orca-sdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175) hide show
  1. orca_sdk/__init__.py +19 -0
  2. orca_sdk/_generated_api_client/__init__.py +3 -0
  3. orca_sdk/_generated_api_client/api/__init__.py +193 -0
  4. orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
  5. orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +128 -0
  6. orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +170 -0
  7. orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +156 -0
  8. orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +130 -0
  9. orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +127 -0
  10. orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
  11. orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +183 -0
  12. orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +170 -0
  13. orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
  14. orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +154 -0
  15. orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +170 -0
  16. orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +156 -0
  17. orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +161 -0
  18. orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +127 -0
  19. orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +190 -0
  20. orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
  21. orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +167 -0
  22. orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +156 -0
  23. orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +156 -0
  24. orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +127 -0
  25. orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
  26. orca_sdk/_generated_api_client/api/default/healthcheck_get.py +118 -0
  27. orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +118 -0
  28. orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
  29. orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +168 -0
  30. orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +156 -0
  31. orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +189 -0
  32. orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +156 -0
  33. orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +127 -0
  34. orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
  35. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +181 -0
  36. orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +183 -0
  37. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +168 -0
  38. orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +181 -0
  39. orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +167 -0
  40. orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +156 -0
  41. orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +169 -0
  42. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +188 -0
  43. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +169 -0
  44. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +156 -0
  45. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +184 -0
  46. orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +260 -0
  47. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +127 -0
  48. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +193 -0
  49. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +188 -0
  50. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +191 -0
  51. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +187 -0
  52. orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
  53. orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +188 -0
  54. orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +157 -0
  55. orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +127 -0
  56. orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
  57. orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +154 -0
  58. orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +156 -0
  59. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +243 -0
  60. orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
  61. orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +162 -0
  62. orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +156 -0
  63. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +157 -0
  64. orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +127 -0
  65. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +175 -0
  66. orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +171 -0
  67. orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +181 -0
  68. orca_sdk/_generated_api_client/client.py +216 -0
  69. orca_sdk/_generated_api_client/errors.py +38 -0
  70. orca_sdk/_generated_api_client/models/__init__.py +159 -0
  71. orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +84 -0
  72. orca_sdk/_generated_api_client/models/api_key_metadata.py +118 -0
  73. orca_sdk/_generated_api_client/models/base_model.py +55 -0
  74. orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +176 -0
  75. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +114 -0
  76. orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +150 -0
  77. orca_sdk/_generated_api_client/models/column_info.py +114 -0
  78. orca_sdk/_generated_api_client/models/column_type.py +14 -0
  79. orca_sdk/_generated_api_client/models/conflict_error_response.py +80 -0
  80. orca_sdk/_generated_api_client/models/create_api_key_request.py +99 -0
  81. orca_sdk/_generated_api_client/models/create_api_key_response.py +126 -0
  82. orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +259 -0
  83. orca_sdk/_generated_api_client/models/create_rac_model_request.py +209 -0
  84. orca_sdk/_generated_api_client/models/datasource_metadata.py +142 -0
  85. orca_sdk/_generated_api_client/models/delete_memories_request.py +70 -0
  86. orca_sdk/_generated_api_client/models/embed_request.py +127 -0
  87. orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +9 -0
  88. orca_sdk/_generated_api_client/models/evaluation_request.py +180 -0
  89. orca_sdk/_generated_api_client/models/evaluation_response.py +140 -0
  90. orca_sdk/_generated_api_client/models/feedback_type.py +9 -0
  91. orca_sdk/_generated_api_client/models/field_validation_error.py +103 -0
  92. orca_sdk/_generated_api_client/models/filter_item.py +231 -0
  93. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +15 -0
  94. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +16 -0
  95. orca_sdk/_generated_api_client/models/filter_item_op.py +16 -0
  96. orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +70 -0
  97. orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +259 -0
  98. orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +66 -0
  99. orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +166 -0
  100. orca_sdk/_generated_api_client/models/get_memories_request.py +70 -0
  101. orca_sdk/_generated_api_client/models/internal_server_error_response.py +80 -0
  102. orca_sdk/_generated_api_client/models/label_class_metrics.py +108 -0
  103. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +274 -0
  104. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +68 -0
  105. orca_sdk/_generated_api_client/models/label_prediction_result.py +101 -0
  106. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +232 -0
  107. orca_sdk/_generated_api_client/models/labeled_memory.py +197 -0
  108. orca_sdk/_generated_api_client/models/labeled_memory_insert.py +108 -0
  109. orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +68 -0
  110. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +258 -0
  111. orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +68 -0
  112. orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +68 -0
  113. orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +277 -0
  114. orca_sdk/_generated_api_client/models/labeled_memory_update.py +171 -0
  115. orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +68 -0
  116. orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +195 -0
  117. orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +9 -0
  118. orca_sdk/_generated_api_client/models/list_memories_request.py +104 -0
  119. orca_sdk/_generated_api_client/models/list_predictions_request.py +234 -0
  120. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +9 -0
  121. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +9 -0
  122. orca_sdk/_generated_api_client/models/lookup_request.py +81 -0
  123. orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +83 -0
  124. orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +9 -0
  125. orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +180 -0
  126. orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +66 -0
  127. orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +9 -0
  128. orca_sdk/_generated_api_client/models/not_found_error_response.py +100 -0
  129. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +20 -0
  130. orca_sdk/_generated_api_client/models/prediction_feedback.py +157 -0
  131. orca_sdk/_generated_api_client/models/prediction_feedback_category.py +115 -0
  132. orca_sdk/_generated_api_client/models/prediction_feedback_request.py +122 -0
  133. orca_sdk/_generated_api_client/models/prediction_feedback_result.py +102 -0
  134. orca_sdk/_generated_api_client/models/prediction_request.py +169 -0
  135. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +97 -0
  136. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +11 -0
  137. orca_sdk/_generated_api_client/models/rac_head_type.py +11 -0
  138. orca_sdk/_generated_api_client/models/rac_model_metadata.py +191 -0
  139. orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +80 -0
  140. orca_sdk/_generated_api_client/models/task.py +198 -0
  141. orca_sdk/_generated_api_client/models/task_status.py +14 -0
  142. orca_sdk/_generated_api_client/models/task_status_info.py +133 -0
  143. orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +72 -0
  144. orca_sdk/_generated_api_client/models/unauthorized_error_response.py +80 -0
  145. orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +94 -0
  146. orca_sdk/_generated_api_client/models/update_prediction_request.py +93 -0
  147. orca_sdk/_generated_api_client/py.typed +1 -0
  148. orca_sdk/_generated_api_client/types.py +56 -0
  149. orca_sdk/_utils/__init__.py +0 -0
  150. orca_sdk/_utils/analysis_ui.py +194 -0
  151. orca_sdk/_utils/analysis_ui_style.css +54 -0
  152. orca_sdk/_utils/auth.py +63 -0
  153. orca_sdk/_utils/auth_test.py +31 -0
  154. orca_sdk/_utils/common.py +37 -0
  155. orca_sdk/_utils/data_parsing.py +99 -0
  156. orca_sdk/_utils/data_parsing_test.py +244 -0
  157. orca_sdk/_utils/prediction_result_ui.css +18 -0
  158. orca_sdk/_utils/prediction_result_ui.py +64 -0
  159. orca_sdk/_utils/task.py +73 -0
  160. orca_sdk/classification_model.py +499 -0
  161. orca_sdk/classification_model_test.py +266 -0
  162. orca_sdk/conftest.py +117 -0
  163. orca_sdk/datasource.py +333 -0
  164. orca_sdk/datasource_test.py +95 -0
  165. orca_sdk/embedding_model.py +336 -0
  166. orca_sdk/embedding_model_test.py +173 -0
  167. orca_sdk/labeled_memoryset.py +1154 -0
  168. orca_sdk/labeled_memoryset_test.py +271 -0
  169. orca_sdk/orca_credentials.py +75 -0
  170. orca_sdk/orca_credentials_test.py +37 -0
  171. orca_sdk/telemetry.py +386 -0
  172. orca_sdk/telemetry_test.py +100 -0
  173. orca_sdk-0.1.0.dist-info/METADATA +39 -0
  174. orca_sdk-0.1.0.dist-info/RECORD +175 -0
  175. orca_sdk-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,499 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from contextlib import contextmanager
5
+ from datetime import datetime
6
+ from typing import Any, Generator, Iterable, Literal, cast, overload
7
+ from uuid import UUID
8
+
9
+ from ._generated_api_client.api import (
10
+ create_evaluation,
11
+ create_model,
12
+ delete_model,
13
+ get_evaluation,
14
+ get_model,
15
+ list_models,
16
+ list_predictions,
17
+ predict_gpu,
18
+ record_prediction_feedback,
19
+ )
20
+ from ._generated_api_client.models import (
21
+ ClassificationEvaluationResult,
22
+ CreateRACModelRequest,
23
+ EvaluationRequest,
24
+ ListPredictionsRequest,
25
+ )
26
+ from ._generated_api_client.models import (
27
+ ListPredictionsRequestSortItemItemType0 as PredictionSortColumns,
28
+ )
29
+ from ._generated_api_client.models import (
30
+ ListPredictionsRequestSortItemItemType1 as PredictionSortDirection,
31
+ )
32
+ from ._generated_api_client.models import RACHeadType, RACModelMetadata
33
+ from ._generated_api_client.models.prediction_request import PredictionRequest
34
+ from ._utils.common import CreateMode, DropMode
35
+ from ._utils.task import wait_for_task
36
+ from .datasource import Datasource
37
+ from .labeled_memoryset import LabeledMemoryset
38
+ from .telemetry import LabelPrediction, _parse_feedback
39
+
40
+
41
class ClassificationModel:
    """
    A handle to a classification model in OrcaCloud

    Attributes:
        id: Unique identifier for the model
        name: Unique name of the model
        memoryset: Memoryset that the model uses
        head_type: Classification head type of the model
        num_classes: Number of distinct classes the model can predict
        memory_lookup_count: Number of memories the model uses for each prediction
        weigh_memories: If using a KNN head, whether the model weighs memories by their lookup score
        min_memory_weight: If using a KNN head, minimum lookup score memories have to be over to not be ignored
        version: Version of the model as reported in its metadata
        created_at: When the model was created
    """

    id: str
    name: str
    memoryset: LabeledMemoryset
    head_type: RACHeadType
    num_classes: int
    memory_lookup_count: int
    weigh_memories: bool | None
    min_memory_weight: float | None
    version: int
    created_at: datetime

    def __init__(self, metadata: RACModelMetadata):
        # for internal use only, do not document
        self.id = metadata.id
        self.name = metadata.name
        self.memoryset = LabeledMemoryset.open(metadata.memoryset_id)
        self.head_type = metadata.head_type
        self.num_classes = metadata.num_classes
        self.memory_lookup_count = metadata.memory_lookup_count
        self.weigh_memories = metadata.weigh_memories
        self.min_memory_weight = metadata.min_memory_weight
        self.version = metadata.version
        self.created_at = metadata.created_at

        # Id of the memoryset temporarily used instead of `self.memoryset`, see `use_memoryset`
        self._memoryset_override_id: str | None = None
        # Bookkeeping for the `last_prediction` property, updated by `predict`
        self._last_prediction: LabelPrediction | None = None
        self._last_prediction_was_batch: bool = False

    def __eq__(self, other: object) -> bool:
        # Two handles are equal when they point at the same model id
        return isinstance(other, ClassificationModel) and self.id == other.id

    def __hash__(self) -> int:
        # Defining __eq__ alone would make instances unhashable; hash by the same
        # key that __eq__ compares so handles can be used in sets and as dict keys.
        return hash(self.id)

    def __repr__(self) -> str:
        return (
            "ClassificationModel({\n"
            f" name: '{self.name}',\n"
            f" head_type: {self.head_type},\n"
            f" num_classes: {self.num_classes},\n"
            f" memory_lookup_count: {self.memory_lookup_count},\n"
            f" memoryset: LabeledMemoryset.open('{self.memoryset.name}'),\n"
            "})"
        )

    @property
    def last_prediction(self) -> LabelPrediction:
        """
        Last prediction made by the model

        Note:
            If the last prediction was part of a batch prediction, the last prediction from the
            batch is returned. If no prediction has been made yet, a [`LookupError`][LookupError]
            is raised.
        """
        # Fail first so that no misleading batch warning is logged when there is nothing to return
        if self._last_prediction is None:
            raise LookupError("No prediction has been made yet")
        if self._last_prediction_was_batch:
            logging.warning(
                "Last prediction was part of a batch prediction, returning the last prediction from the batch"
            )
        return self._last_prediction

    @classmethod
    def create(
        cls,
        name: str,
        memoryset: LabeledMemoryset,
        head_type: Literal["BMMOE", "FF", "KNN", "MMOE"] = "KNN",
        *,
        num_classes: int | None = None,
        memory_lookup_count: int | None = None,
        weigh_memories: bool = True,
        min_memory_weight: float | None = None,
        if_exists: CreateMode = "error",
    ) -> ClassificationModel:
        """
        Create a new classification model

        Params:
            name: Name for the new model (must be unique)
            memoryset: Memoryset to attach the model to
            head_type: Type of model head to use
            num_classes: Number of classes this model can predict, will be inferred from memoryset if not specified
            memory_lookup_count: Number of memories to lookup for each prediction,
                by default the system uses a simple heuristic to choose a number of memories that works well in most cases
            weigh_memories: If using a KNN head, whether the model weighs memories by their lookup score
            min_memory_weight: If using a KNN head, minimum lookup score memories have to be over to not be ignored
            if_exists: What to do if a model with the same name already exists, defaults to
                `"error"`. Other option is `"open"` to open the existing model.

        Returns:
            Handle to the new model in the OrcaCloud

        Raises:
            ValueError: If the model already exists and if_exists is `"error"` or if it is
                `"open"` and the existing model has different attributes.

        Examples:
            Create a new model using default options:
            >>> model = ClassificationModel.create(
            ...     "my_model",
            ...     LabeledMemoryset.open("my_memoryset"),
            ... )

            Create a new model with non-default model head and options:
            >>> model = ClassificationModel.create(
            ...     name="my_model",
            ...     memoryset=LabeledMemoryset.open("my_memoryset"),
            ...     head_type="MMOE",
            ...     num_classes=5,
            ...     memory_lookup_count=20,
            ... )
        """
        # A single lookup tells us both whether the model exists and what its attributes are
        try:
            existing = cls.open(name)
        except LookupError:
            existing = None

        if existing is not None:
            if if_exists == "error":
                raise ValueError(f"Model with name {name} already exists")
            elif if_exists == "open":
                # Only arguments that were explicitly given (not None) have to match the
                # existing model. `weigh_memories` is not compared because its default is
                # not None, so an unspecified value cannot be told apart from an explicit one.
                requested_attributes = (
                    ("head_type", RACHeadType(head_type)),
                    ("memory_lookup_count", memory_lookup_count),
                    ("num_classes", num_classes),
                    ("min_memory_weight", min_memory_weight),
                )
                for attribute, requested_value in requested_attributes:
                    if requested_value is not None and requested_value != getattr(existing, attribute):
                        raise ValueError(f"Model with name {name} already exists with different {attribute}")

                # special case for memoryset, which is compared by id
                if existing.memoryset.id != memoryset.id:
                    raise ValueError(f"Model with name {name} already exists with different memoryset")

                return existing

        metadata = create_model(
            body=CreateRACModelRequest(
                name=name,
                memoryset_id=memoryset.id,
                head_type=RACHeadType(head_type),
                memory_lookup_count=memory_lookup_count,
                num_classes=num_classes,
                weigh_memories=weigh_memories,
                min_memory_weight=min_memory_weight,
            ),
        )
        return cls(metadata)

    @classmethod
    def open(cls, name: str) -> ClassificationModel:
        """
        Get a handle to a classification model in the OrcaCloud

        Params:
            name: Name or unique identifier of the classification model

        Returns:
            Handle to the existing classification model in the OrcaCloud

        Raises:
            LookupError: If the classification model does not exist
        """
        return cls(get_model(name))

    @classmethod
    def exists(cls, name_or_id: str) -> bool:
        """
        Check if a classification model exists in the OrcaCloud

        Params:
            name_or_id: Name or id of the classification model

        Returns:
            `True` if the classification model exists, `False` otherwise
        """
        try:
            cls.open(name_or_id)
        except LookupError:
            return False
        return True

    @classmethod
    def all(cls) -> list[ClassificationModel]:
        """
        Get a list of handles to all classification models in the OrcaCloud

        Returns:
            List of handles to all classification models in the OrcaCloud
        """
        return [cls(metadata) for metadata in list_models()]

    @classmethod
    def drop(cls, name_or_id: str, if_not_exists: DropMode = "error") -> None:
        """
        Delete a classification model from the OrcaCloud

        Warning:
            This will delete the model and all associated data, including predictions, evaluations, and feedback.

        Params:
            name_or_id: Name or id of the classification model
            if_not_exists: What to do if the classification model does not exist, defaults to `"error"`.
                Other option is `"ignore"` to do nothing if the classification model does not exist.

        Raises:
            LookupError: If the classification model does not exist and if_not_exists is `"error"`
        """
        try:
            delete_model(name_or_id)
            logging.info(f"Deleted model {name_or_id}")
        except LookupError:
            if if_not_exists == "error":
                raise

    @overload
    def predict(
        self, value: list[str], expected_labels: list[int] | None = None, tags: set[str] = set()
    ) -> list[LabelPrediction]: ...

    @overload
    def predict(self, value: str, expected_labels: int | None = None, tags: set[str] = set()) -> LabelPrediction: ...

    def predict(
        self, value: list[str] | str, expected_labels: list[int] | int | None = None, tags: set[str] = set()
    ) -> list[LabelPrediction] | LabelPrediction:
        """
        Predict label(s) for the given input value(s) grounded in similar memories

        Params:
            value: Value(s) to predict the labels of
            expected_labels: Expected label(s) for the given input to record for model evaluation
            tags: Tags to add to the prediction(s)

        Returns:
            Label prediction or list of label predictions

        Examples:
            Predict the label for a single value:
            >>> prediction = model.predict("I am happy", tags={"test"})
            LabelPrediction({label: <positive: 1>, confidence: 0.95, input_value: 'I am happy' })

            Predict the labels for a list of values:
            >>> predictions = model.predict(["I am happy", "I am sad"], expected_labels=[1, 0])
            [
                LabelPrediction({label: <positive: 1>, confidence: 0.95, input_value: 'I am happy'}),
                LabelPrediction({label: <negative: 0>, confidence: 0.05, input_value: 'I am sad'}),
            ]
        """
        # The request always carries lists, so wrap a single expected label
        if isinstance(expected_labels, list):
            expected_label_list = expected_labels
        elif expected_labels is not None:
            expected_label_list = [expected_labels]
        else:
            expected_label_list = None

        response = predict_gpu(
            self.id,
            body=PredictionRequest(
                input_values=value if isinstance(value, list) else [value],
                memoryset_override_id=self._memoryset_override_id,
                expected_labels=expected_label_list,
                # the shared default set is only read here, never mutated
                tags=list(tags),
            ),
        )
        predictions = [
            LabelPrediction(
                prediction_id=prediction.prediction_id,
                label=prediction.label,
                label_name=prediction.label_name,
                confidence=prediction.confidence,
                memoryset=self.memoryset,
                model=self,
            )
            for prediction in response
        ]
        # An empty batch yields no predictions; keep the previous `last_prediction`
        # in that case instead of failing on `predictions[-1]`.
        if predictions:
            self._last_prediction_was_batch = isinstance(value, list)
            self._last_prediction = predictions[-1]
        return predictions if isinstance(value, list) else predictions[0]

    def predictions(
        self,
        limit: int = 100,
        offset: int = 0,
        tag: str | None = None,
        sort: list[tuple[PredictionSortColumns, PredictionSortDirection]] = [],
    ) -> list[LabelPrediction]:
        """
        Get a list of predictions made by this model

        Params:
            limit: Optional maximum number of predictions to return
            offset: Optional offset of the first prediction to return
            tag: Optional tag to filter predictions by
            sort: Optional list of columns and directions to sort the predictions by.
                Predictions can be sorted by `timestamp` or `confidence`.

        Returns:
            List of label predictions

        Examples:
            Get the last 3 predictions:
            >>> predictions = model.predictions(limit=3, sort=[("timestamp", "desc")])
            [
                LabelPrediction({label: <positive: 1>, confidence: 0.95, input_value: 'I am happy'}),
                LabelPrediction({label: <negative: 0>, confidence: 0.05, input_value: 'I am sad'}),
                LabelPrediction({label: <positive: 1>, confidence: 0.90, input_value: 'I am ecstatic'}),
            ]


            Get second most confident prediction:
            >>> predictions = model.predictions(sort=[("confidence", "desc")], offset=1, limit=1)
            [LabelPrediction({label: <positive: 1>, confidence: 0.90, input_value: 'I am having a good day'})]
        """
        predictions = list_predictions(
            body=ListPredictionsRequest(
                model_id=self.id,
                limit=limit,
                offset=offset,
                # the shared default list is only passed through, never mutated
                sort=cast(list[list[PredictionSortColumns | PredictionSortDirection]], sort),
                tag=tag,
            ),
        )
        return [
            LabelPrediction(
                prediction_id=prediction.prediction_id,
                label=prediction.label,
                label_name=prediction.label_name,
                confidence=prediction.confidence,
                memoryset=self.memoryset,
                model=self,
                telemetry=prediction,
            )
            for prediction in predictions
        ]

    def evaluate(
        self,
        datasource: Datasource,
        value_column: str = "value",
        label_column: str = "label",
        record_predictions: bool = False,
        tags: set[str] | None = None,
    ) -> dict[str, float]:
        """
        Evaluate the classification model on a given datasource

        Params:
            datasource: Datasource to evaluate the model on
            value_column: Name of the column that contains the input values to the model
            label_column: Name of the column containing the expected labels
            record_predictions: Whether to record [`LabelPrediction`][orca_sdk.telemetry.LabelPrediction]s for analysis
            tags: Optional tags to add to the recorded [`LabelPrediction`][orca_sdk.telemetry.LabelPrediction]s

        Returns:
            Dictionary with evaluation metrics

        Raises:
            RuntimeError: If the evaluation task completed but no result is available

        Examples:
            >>> model.evaluate(datasource, value_column="text", label_column="airline_sentiment")
            { "f1_score": 0.85, "roc_auc": 0.85, "pr_auc": 0.85, "accuracy": 0.85, "loss": 0.35 }
        """
        response = create_evaluation(
            self.id,
            body=EvaluationRequest(
                datasource_id=datasource.id,
                datasource_label_column=label_column,
                datasource_value_column=value_column,
                memoryset_override_id=self._memoryset_override_id,
                record_telemetry=record_predictions,
                telemetry_tags=list(tags) if tags else None,
            ),
        )
        wait_for_task(response.task_id, description="Running evaluation")
        evaluation = get_evaluation(self.id, UUID(response.task_id))
        # Explicit check instead of `assert`, which is stripped when running with -O
        if evaluation.result is None:
            raise RuntimeError(f"Evaluation task {response.task_id} finished without a result")
        return evaluation.result.to_dict()

    def finetune(self, datasource: Datasource):
        # do not document until implemented
        raise NotImplementedError("Finetuning is not supported yet")

    @contextmanager
    def use_memoryset(self, memoryset_override: LabeledMemoryset) -> Generator[None, None, None]:
        """
        Temporarily override the memoryset used by the model for predictions

        The override is undone when the `with` block exits, also when it exits with an
        exception. Nested overrides restore the enclosing override rather than clearing it.

        Params:
            memoryset_override: Memoryset to override the default memoryset with

        Examples:
            >>> with model.use_memoryset(LabeledMemoryset.open("my_other_memoryset")):
            ...     predictions = model.predict("I am happy")
        """
        previous_override_id = self._memoryset_override_id
        self._memoryset_override_id = memoryset_override.id
        try:
            yield
        finally:
            # always restore, otherwise an error inside the block would leave the override active
            self._memoryset_override_id = previous_override_id

    @overload
    def record_feedback(self, feedback: dict[str, Any]) -> None: ...

    @overload
    def record_feedback(self, feedback: Iterable[dict[str, Any]]) -> None: ...

    def record_feedback(self, feedback: Iterable[dict[str, Any]] | dict[str, Any]) -> None:
        """
        Record feedback for a list of predictions.

        We support recording feedback in several categories for each prediction. A
        [`FeedbackCategory`][orca_sdk.telemetry.FeedbackCategory] is created automatically,
        the first time feedback with a new name is recorded. Categories are global across models.
        The value type of the category is inferred from the first recorded value. Subsequent
        feedback for the same category must be of the same type.

        Params:
            feedback: Feedback to record, this should be dictionaries with the following keys:

                - `prediction`: Id of the prediction the feedback is for (see the examples below).
                - `category`: Name of the category under which to record the feedback.
                - `value`: Feedback value to record, should be `True` for positive feedback and
                    `False` for negative feedback or a [`float`][float] between `-1.0` and `+1.0`
                    where negative values indicate negative feedback and positive values indicate
                    positive feedback.
                - `comment`: Optional comment to record with the feedback.

        Examples:
            Record whether predictions were correct or incorrect:
            >>> model.record_feedback({
            ...     "prediction": p.prediction_id,
            ...     "category": "correct",
            ...     "value": p.label == p.expected_label,
            ... } for p in predictions)

            Record star rating as normalized continuous score between `-1.0` and `+1.0`:
            >>> model.record_feedback({
            ...     "prediction": "123e4567-e89b-12d3-a456-426614174000",
            ...     "category": "rating",
            ...     "value": -0.5,
            ...     "comment": "2 stars"
            ... })

        Raises:
            ValueError: If the value does not match previous value types for the category, or is a
                [`float`][float] that is not between `-1.0` and `+1.0`.
        """
        # A single feedback dict is wrapped so that the request body is always a list
        feedback_items = cast(list[dict], [feedback]) if isinstance(feedback, dict) else feedback
        record_prediction_feedback(
            body=[_parse_feedback(item) for item in feedback_items],
        )