orca-sdk 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185) hide show
  1. orca_sdk/__init__.py +10 -4
  2. orca_sdk/_shared/__init__.py +10 -0
  3. orca_sdk/_shared/metrics.py +393 -0
  4. orca_sdk/_shared/metrics_test.py +273 -0
  5. orca_sdk/_utils/analysis_ui.py +12 -10
  6. orca_sdk/_utils/analysis_ui_style.css +0 -3
  7. orca_sdk/_utils/auth.py +27 -29
  8. orca_sdk/_utils/data_parsing.py +28 -2
  9. orca_sdk/_utils/data_parsing_test.py +15 -15
  10. orca_sdk/_utils/pagination.py +126 -0
  11. orca_sdk/_utils/pagination_test.py +132 -0
  12. orca_sdk/_utils/prediction_result_ui.py +67 -21
  13. orca_sdk/_utils/tqdm_file_reader.py +12 -0
  14. orca_sdk/_utils/value_parser.py +45 -0
  15. orca_sdk/_utils/value_parser_test.py +39 -0
  16. orca_sdk/classification_model.py +439 -129
  17. orca_sdk/classification_model_test.py +334 -104
  18. orca_sdk/client.py +3747 -0
  19. orca_sdk/conftest.py +164 -19
  20. orca_sdk/credentials.py +120 -18
  21. orca_sdk/credentials_test.py +20 -0
  22. orca_sdk/datasource.py +259 -68
  23. orca_sdk/datasource_test.py +242 -0
  24. orca_sdk/embedding_model.py +425 -82
  25. orca_sdk/embedding_model_test.py +39 -13
  26. orca_sdk/job.py +337 -0
  27. orca_sdk/job_test.py +108 -0
  28. orca_sdk/memoryset.py +1341 -305
  29. orca_sdk/memoryset_test.py +350 -111
  30. orca_sdk/regression_model.py +684 -0
  31. orca_sdk/regression_model_test.py +369 -0
  32. orca_sdk/telemetry.py +449 -143
  33. orca_sdk/telemetry_test.py +43 -24
  34. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.2.dist-info}/METADATA +34 -16
  35. orca_sdk-0.1.2.dist-info/RECORD +40 -0
  36. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.2.dist-info}/WHEEL +1 -1
  37. orca_sdk/_generated_api_client/__init__.py +0 -3
  38. orca_sdk/_generated_api_client/api/__init__.py +0 -193
  39. orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
  40. orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
  41. orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
  42. orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
  43. orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
  44. orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
  45. orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
  46. orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
  47. orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
  48. orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
  49. orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
  50. orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
  51. orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
  52. orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
  53. orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
  54. orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
  55. orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
  56. orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
  57. orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
  58. orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
  59. orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
  60. orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
  61. orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
  62. orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
  63. orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
  64. orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
  65. orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
  66. orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
  67. orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
  68. orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
  69. orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
  70. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
  71. orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
  72. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
  73. orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
  74. orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
  75. orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
  76. orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
  77. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
  78. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
  79. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
  80. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
  81. orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
  82. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
  83. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
  84. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
  85. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
  86. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
  87. orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
  88. orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
  89. orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
  90. orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
  91. orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
  92. orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
  93. orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
  94. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
  95. orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
  96. orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
  97. orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
  98. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
  99. orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
  100. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
  101. orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
  102. orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
  103. orca_sdk/_generated_api_client/client.py +0 -216
  104. orca_sdk/_generated_api_client/errors.py +0 -38
  105. orca_sdk/_generated_api_client/models/__init__.py +0 -159
  106. orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
  107. orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
  108. orca_sdk/_generated_api_client/models/base_model.py +0 -55
  109. orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
  110. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
  111. orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
  112. orca_sdk/_generated_api_client/models/column_info.py +0 -114
  113. orca_sdk/_generated_api_client/models/column_type.py +0 -14
  114. orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
  115. orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
  116. orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
  117. orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
  118. orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
  119. orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
  120. orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
  121. orca_sdk/_generated_api_client/models/embed_request.py +0 -127
  122. orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
  123. orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
  124. orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
  125. orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
  126. orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
  127. orca_sdk/_generated_api_client/models/filter_item.py +0 -231
  128. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
  129. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
  130. orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
  131. orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
  132. orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
  133. orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
  134. orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
  135. orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
  136. orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
  137. orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
  138. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
  139. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
  140. orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
  141. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
  142. orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
  143. orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
  144. orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
  145. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
  146. orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
  147. orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
  148. orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
  149. orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
  150. orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
  151. orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
  152. orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
  153. orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
  154. orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
  155. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
  156. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
  157. orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
  158. orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
  159. orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
  160. orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
  161. orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
  162. orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
  163. orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
  164. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
  165. orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
  166. orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
  167. orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
  168. orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
  169. orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
  170. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
  171. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
  172. orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
  173. orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
  174. orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
  175. orca_sdk/_generated_api_client/models/task.py +0 -198
  176. orca_sdk/_generated_api_client/models/task_status.py +0 -14
  177. orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
  178. orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
  179. orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
  180. orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
  181. orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
  182. orca_sdk/_generated_api_client/py.typed +0 -1
  183. orca_sdk/_generated_api_client/types.py +0 -56
  184. orca_sdk/_utils/task.py +0 -73
  185. orca_sdk-0.1.1.dist-info/RECORD +0 -175
@@ -0,0 +1,684 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from contextlib import contextmanager
5
+ from datetime import datetime
6
+ from typing import Any, Generator, Iterable, Literal, cast, overload
7
+
8
+ from datasets import Dataset
9
+
10
+ from ._shared.metrics import RegressionMetrics, calculate_regression_metrics
11
+ from ._utils.common import UNSET, CreateMode, DropMode
12
+ from .client import (
13
+ PredictiveModelUpdate,
14
+ RARHeadType,
15
+ RegressionModelMetadata,
16
+ orca_api,
17
+ )
18
+ from .datasource import Datasource
19
+ from .job import Job
20
+ from .memoryset import ScoredMemoryset
21
+ from .telemetry import (
22
+ RegressionPrediction,
23
+ TelemetryMode,
24
+ _get_telemetry_config,
25
+ _parse_feedback,
26
+ )
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
class RegressionModel:
    """
    A handle to a regression model in OrcaCloud

    Attributes:
        id: Unique identifier for the model
        name: Unique name of the model
        description: Optional description of the model
        memoryset: Memoryset that the model uses
        head_type: Regression head type of the model
        memory_lookup_count: Number of memories the model uses for each prediction
        version: Version of the model
        locked: Whether the model is locked to prevent accidental deletion
        created_at: When the model was created
        updated_at: When the model was last updated
        memoryset_id: Unique identifier of the memoryset that the model uses
    """

    id: str
    name: str
    description: str | None
    memoryset: ScoredMemoryset
    head_type: RARHeadType
    memory_lookup_count: int
    version: int
    locked: bool
    created_at: datetime
    updated_at: datetime
    memoryset_id: str

    # Internal state, not part of the public API:
    # _last_prediction: most recent prediction (last of the batch for batch calls), None before any call
    # _last_prediction_was_batch: whether the most recent predict() call was a batch prediction
    # _memoryset_override_id: when set, predictions/evaluations use this memoryset instead of the model's own
    _last_prediction: RegressionPrediction | None
    _last_prediction_was_batch: bool
    _memoryset_override_id: str | None
+
63
+ def __init__(self, metadata: RegressionModelMetadata):
64
+ # for internal use only, do not document
65
+ self.id = metadata["id"]
66
+ self.name = metadata["name"]
67
+ self.description = metadata["description"]
68
+ self.memoryset = ScoredMemoryset.open(metadata["memoryset_id"])
69
+ self.head_type = metadata["head_type"]
70
+ self.memory_lookup_count = metadata["memory_lookup_count"]
71
+ self.version = metadata["version"]
72
+ self.locked = metadata["locked"]
73
+ self.created_at = datetime.fromisoformat(metadata["created_at"])
74
+ self.updated_at = datetime.fromisoformat(metadata["updated_at"])
75
+ self.memoryset_id = metadata["memoryset_id"]
76
+
77
+ self._memoryset_override_id = None
78
+ self._last_prediction = None
79
+ self._last_prediction_was_batch = False
80
+
81
+ def __eq__(self, other) -> bool:
82
+ return isinstance(other, RegressionModel) and self.id == other.id
83
+
84
    def __repr__(self):
        # Multi-line repr showing the key configuration and how to re-open the
        # attached memoryset.
        # NOTE(review): the leading spaces inside these f-strings may have been
        # lost in transit — confirm the exact literal indentation against the
        # released wheel before relying on the repr format.
        return (
            "RegressionModel({\n"
            f" name: '{self.name}',\n"
            f" head_type: {self.head_type},\n"
            f" memory_lookup_count: {self.memory_lookup_count},\n"
            f" memoryset: ScoredMemoryset.open('{self.memoryset.name}'),\n"
            "})"
        )
+
94
+ @property
95
+ def last_prediction(self) -> RegressionPrediction:
96
+ """
97
+ Last prediction made by the model
98
+
99
+ Note:
100
+ If the last prediction was part of a batch prediction, the last prediction from the
101
+ batch is returned. If no prediction has been made yet, a [`LookupError`][LookupError]
102
+ is raised.
103
+ """
104
+ if self._last_prediction_was_batch:
105
+ logging.warning(
106
+ "Last prediction was part of a batch prediction, returning the last prediction from the batch"
107
+ )
108
+ if self._last_prediction is None:
109
+ raise LookupError("No prediction has been made yet")
110
+ return self._last_prediction
111
+
112
    @classmethod
    def create(
        cls,
        name: str,
        memoryset: ScoredMemoryset,
        memory_lookup_count: int | None = None,
        description: str | None = None,
        if_exists: CreateMode = "error",
    ) -> RegressionModel:
        """
        Create a regression model.

        Params:
            name: Name of the model
            memoryset: The scored memoryset to use for prediction
            memory_lookup_count: Number of memories to retrieve for prediction. Defaults to 10.
            description: Description of the model
            if_exists: How to handle existing models with the same name

        Returns:
            RegressionModel instance

        Raises:
            ValueError: If a model with the same name already exists and if_exists is "error"
            ValueError: If the memoryset is empty
            ValueError: If memory_lookup_count exceeds the number of memories in the memoryset
        """
        existing = cls.exists(name)
        if existing:
            if if_exists == "error":
                raise ValueError(f"RegressionModel with name '{name}' already exists")
            elif if_exists == "open":
                existing = cls.open(name)
                # Reuse the existing model only when the caller's configuration matches.
                # Attribute names are looked up in locals() so this set can be extended
                # without duplicating the comparison logic; a caller value of None means
                # "use the existing value" and is never treated as a mismatch.
                for attribute in {"memory_lookup_count"}:
                    local_attribute = locals()[attribute]
                    existing_attribute = getattr(existing, attribute)
                    if local_attribute is not None and local_attribute != existing_attribute:
                        raise ValueError(f"Model with name {name} already exists with different {attribute}")

                # special case for memoryset: compared by id rather than via the locals() loop
                if existing.memoryset_id != memoryset.id:
                    raise ValueError(f"Model with name {name} already exists with different memoryset")

                return existing

        # No matching model (or if_exists allowed falling through): create a new one.
        metadata = orca_api.POST(
            "/regression_model",
            json={
                "name": name,
                "memoryset_name_or_id": memoryset.id,
                "memory_lookup_count": memory_lookup_count,
                "description": description,
            },
        )
        return cls(metadata)
+
168
+ @classmethod
169
+ def open(cls, name: str) -> RegressionModel:
170
+ """
171
+ Get a handle to a regression model in the OrcaCloud
172
+
173
+ Params:
174
+ name: Name or unique identifier of the regression model
175
+
176
+ Returns:
177
+ Handle to the existing regression model in the OrcaCloud
178
+
179
+ Raises:
180
+ LookupError: If the regression model does not exist
181
+ """
182
+ return cls(orca_api.GET("/regression_model/{name_or_id}", params={"name_or_id": name}))
183
+
184
+ @classmethod
185
+ def exists(cls, name_or_id: str) -> bool:
186
+ """
187
+ Check if a regression model exists in the OrcaCloud
188
+
189
+ Params:
190
+ name_or_id: Name or id of the regression model
191
+
192
+ Returns:
193
+ `True` if the regression model exists, `False` otherwise
194
+ """
195
+ try:
196
+ cls.open(name_or_id)
197
+ return True
198
+ except LookupError:
199
+ return False
200
+
201
+ @classmethod
202
+ def all(cls) -> list[RegressionModel]:
203
+ """
204
+ Get a list of handles to all regression models in the OrcaCloud
205
+
206
+ Returns:
207
+ List of handles to all regression models in the OrcaCloud
208
+ """
209
+ return [cls(metadata) for metadata in orca_api.GET("/regression_model")]
210
+
211
+ @classmethod
212
+ def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
213
+ """
214
+ Delete a regression model from the OrcaCloud
215
+
216
+ Warning:
217
+ This will delete the model and all associated data, including predictions, evaluations, and feedback.
218
+
219
+ Params:
220
+ name_or_id: Name or id of the regression model
221
+ if_not_exists: What to do if the regression model does not exist, defaults to `"error"`.
222
+ Other option is `"ignore"` to do nothing if the regression model does not exist.
223
+
224
+ Raises:
225
+ LookupError: If the regression model does not exist and if_not_exists is `"error"`
226
+ """
227
+ try:
228
+ orca_api.DELETE("/regression_model/{name_or_id}", params={"name_or_id": name_or_id})
229
+ logging.info(f"Deleted model {name_or_id}")
230
+ except LookupError:
231
+ if if_not_exists == "error":
232
+ raise
233
+
234
+ def refresh(self):
235
+ """Refresh the model data from the OrcaCloud"""
236
+ self.__dict__.update(self.open(self.name).__dict__)
237
+
238
+ def set(self, *, description: str | None = UNSET, locked: bool = UNSET) -> None:
239
+ """
240
+ Update editable attributes of the model.
241
+
242
+ Note:
243
+ If a field is not provided, it will default to [UNSET][orca_sdk.UNSET] and not be updated.
244
+
245
+ Params:
246
+ description: Value to set for the description
247
+ locked: Value to set for the locked status
248
+
249
+ Examples:
250
+ Update the description:
251
+ >>> model.set(description="New description")
252
+
253
+ Remove description:
254
+ >>> model.set(description=None)
255
+
256
+ Lock the model:
257
+ >>> model.set(locked=True)
258
+ """
259
+ update: PredictiveModelUpdate = {}
260
+ if description is not UNSET:
261
+ update["description"] = description
262
+ if locked is not UNSET:
263
+ update["locked"] = locked
264
+ orca_api.PATCH("/regression_model/{name_or_id}", params={"name_or_id": self.id}, json=update)
265
+ self.refresh()
266
+
267
    def lock(self) -> None:
        """Lock the model to prevent accidental deletion"""
        # Delegates to set(), which PATCHes the model and refreshes this handle.
        self.set(locked=True)
+
271
    def unlock(self) -> None:
        """Unlock the model to allow deletion"""
        # Delegates to set(), which PATCHes the model and refreshes this handle.
        self.set(locked=False)
+
275
    # Overload: a single input string yields a single RegressionPrediction.
    @overload
    def predict(
        self,
        value: str,
        expected_scores: float | None = None,
        tags: set[str] | None = None,
        save_telemetry: TelemetryMode = "on",
        prompt: str | None = None,
        use_lookup_cache: bool = True,
        timeout_seconds: int = 10,
    ) -> RegressionPrediction: ...

    # Overload: a list of input strings yields a list of RegressionPredictions
    # (with expected_scores, if given, as a parallel list).
    @overload
    def predict(
        self,
        value: list[str],
        expected_scores: list[float] | None = None,
        tags: set[str] | None = None,
        save_telemetry: TelemetryMode = "on",
        prompt: str | None = None,
        use_lookup_cache: bool = True,
        timeout_seconds: int = 10,
    ) -> list[RegressionPrediction]: ...
+
299
+ # TODO: add filter support
300
+ def predict(
301
+ self,
302
+ value: str | list[str],
303
+ expected_scores: float | list[float] | None = None,
304
+ tags: set[str] | None = None,
305
+ save_telemetry: TelemetryMode = "on",
306
+ prompt: str | None = None,
307
+ use_lookup_cache: bool = True,
308
+ timeout_seconds: int = 10,
309
+ ) -> RegressionPrediction | list[RegressionPrediction]:
310
+ """
311
+ Make predictions using the regression model.
312
+
313
+ Params:
314
+ value: Input text(s) to predict scores for
315
+ expected_scores: Expected score(s) for telemetry tracking
316
+ tags: Tags to associate with the prediction(s)
317
+ save_telemetry: Whether to save telemetry for the prediction(s), defaults to `True`,
318
+ which will save telemetry asynchronously unless the `ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY`
319
+ environment variable is set to `"1"`. You can also pass `"sync"` or `"async"` to
320
+ explicitly set the save mode.
321
+ prompt: Optional prompt for instruction-tuned embedding models
322
+ use_lookup_cache: Whether to use cached lookup results for faster predictions
323
+ timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
324
+
325
+ Returns:
326
+ Single RegressionPrediction or list of RegressionPrediction objects
327
+
328
+ Raises:
329
+ ValueError: If expected_scores length doesn't match value length for batch predictions
330
+ ValueError: If timeout_seconds is not a positive integer
331
+ TimeoutError: If the request times out after the specified duration
332
+ """
333
+ if timeout_seconds <= 0:
334
+ raise ValueError("timeout_seconds must be a positive integer")
335
+
336
+ telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
337
+ response = orca_api.POST(
338
+ "/gpu/regression_model/{name_or_id}/prediction",
339
+ params={"name_or_id": self.id},
340
+ json={
341
+ "input_values": value if isinstance(value, list) else [value],
342
+ "memoryset_override_name_or_id": self._memoryset_override_id,
343
+ "expected_scores": (
344
+ expected_scores
345
+ if isinstance(expected_scores, list)
346
+ else [expected_scores] if expected_scores is not None else None
347
+ ),
348
+ "tags": list(tags or set()),
349
+ "save_telemetry": telemetry_on,
350
+ "save_telemetry_synchronously": telemetry_sync,
351
+ "prompt": prompt,
352
+ "use_lookup_cache": use_lookup_cache,
353
+ },
354
+ timeout=timeout_seconds,
355
+ )
356
+
357
+ if telemetry_on and any(p["prediction_id"] is None for p in response):
358
+ raise RuntimeError("Failed to save prediction to database.")
359
+
360
+ predictions = [
361
+ RegressionPrediction(
362
+ prediction_id=prediction["prediction_id"],
363
+ label=None,
364
+ label_name=None,
365
+ score=prediction["score"],
366
+ confidence=prediction["confidence"],
367
+ anomaly_score=prediction["anomaly_score"],
368
+ memoryset=self.memoryset,
369
+ model=self,
370
+ logits=None,
371
+ input_value=input_value,
372
+ )
373
+ for prediction, input_value in zip(response, value if isinstance(value, list) else [value])
374
+ ]
375
+ self._last_prediction_was_batch = isinstance(value, list)
376
+ self._last_prediction = predictions[-1]
377
+ return predictions if isinstance(value, list) else predictions[0]
378
+
379
+ def predictions(
380
+ self,
381
+ limit: int = 100,
382
+ offset: int = 0,
383
+ tag: str | None = None,
384
+ sort: list[tuple[Literal["anomaly_score", "confidence", "timestamp"], Literal["asc", "desc"]]] = [],
385
+ ) -> list[RegressionPrediction]:
386
+ """
387
+ Get a list of predictions made by this model
388
+
389
+ Params:
390
+ limit: Optional maximum number of predictions to return
391
+ offset: Optional offset of the first prediction to return
392
+ tag: Optional tag to filter predictions by
393
+ sort: Optional list of columns and directions to sort the predictions by.
394
+ Predictions can be sorted by `created_at`, `confidence`, `anomaly_score`, or `score`.
395
+
396
+ Returns:
397
+ List of score predictions
398
+
399
+ Examples:
400
+ Get the last 3 predictions:
401
+ >>> predictions = model.predictions(limit=3, sort=[("created_at", "desc")])
402
+ [
403
+ RegressionPrediction({score: 4.5, confidence: 0.95, anomaly_score: 0.1, input_value: 'Great service'}),
404
+ RegressionPrediction({score: 2.0, confidence: 0.90, anomaly_score: 0.1, input_value: 'Poor experience'}),
405
+ RegressionPrediction({score: 3.5, confidence: 0.85, anomaly_score: 0.1, input_value: 'Average'}),
406
+ ]
407
+
408
+ Get second most confident prediction:
409
+ >>> predictions = model.predictions(sort=[("confidence", "desc")], offset=1, limit=1)
410
+ [RegressionPrediction({score: 4.2, confidence: 0.90, anomaly_score: 0.1, input_value: 'Good service'})]
411
+ """
412
+ predictions = orca_api.POST(
413
+ "/telemetry/prediction",
414
+ json={
415
+ "model_id": self.id,
416
+ "limit": limit,
417
+ "offset": offset,
418
+ "sort": [list(sort_item) for sort_item in sort],
419
+ "tag": tag,
420
+ },
421
+ )
422
+ return [
423
+ RegressionPrediction(
424
+ prediction_id=prediction["prediction_id"],
425
+ label=None,
426
+ label_name=None,
427
+ score=prediction["score"],
428
+ confidence=prediction["confidence"],
429
+ anomaly_score=prediction["anomaly_score"],
430
+ memoryset=self.memoryset,
431
+ model=self,
432
+ telemetry=prediction,
433
+ logits=None,
434
+ input_value=None,
435
+ )
436
+ for prediction in predictions
437
+ if "score" in prediction
438
+ ]
439
+
440
def _evaluate_datasource(
    self,
    datasource: Datasource,
    value_column: str,
    score_column: str,
    record_predictions: bool,
    tags: set[str] | None,
    background: bool = False,
) -> RegressionMetrics | Job[RegressionMetrics]:
    """
    Start a server-side evaluation of this model against a datasource.

    Params:
        datasource: Datasource containing the evaluation data
        value_column: Name of the column with the model input values
        score_column: Name of the column with the expected scores
        record_predictions: Whether the server should record telemetry for the predictions
        tags: Optional telemetry tags to attach to recorded predictions
        background: If True, return a Job handle instead of blocking for the result

    Returns:
        RegressionMetrics, or a Job resolving to RegressionMetrics when background is True
    """
    request_body = {
        "datasource_name_or_id": datasource.id,
        "datasource_score_column": score_column,
        "datasource_value_column": value_column,
        "memoryset_override_name_or_id": self._memoryset_override_id,
        "record_telemetry": record_predictions,
        "telemetry_tags": list(tags) if tags else None,
    }
    start_response = orca_api.POST(
        "/regression_model/{model_name_or_id}/evaluation",
        params={"model_name_or_id": self.id},
        json=request_body,
    )
    task_id = start_response["task_id"]

    # Metric fields returned by the evaluation endpoint; missing keys become None.
    metric_fields = (
        "coverage",
        "mse",
        "rmse",
        "mae",
        "r2",
        "explained_variance",
        "loss",
        "anomaly_score_mean",
        "anomaly_score_median",
        "anomaly_score_variance",
    )

    def fetch_metrics():
        # Fetch the finished task's result and unpack it into RegressionMetrics.
        status = orca_api.GET(
            "/regression_model/{model_name_or_id}/evaluation/{task_id}",
            params={"model_name_or_id": self.id, "task_id": task_id},
        )
        result = status["result"]
        assert result is not None
        return RegressionMetrics(**{field: result.get(field) for field in metric_fields})

    job = Job(task_id, fetch_metrics)
    return job.result() if not background else job
483
+
484
def _evaluate_dataset(
    self,
    dataset: Dataset,
    value_column: str,
    score_column: str,
    record_predictions: bool,
    tags: set[str],
    batch_size: int,
    prompt: str | None = None,
) -> RegressionMetrics:
    """
    Evaluate the model locally by running batched predictions over a Dataset.

    Params:
        dataset: Dataset containing the evaluation data
        value_column: Name of the column with the model input values
        score_column: Name of the column with the expected scores
        record_predictions: Whether to record telemetry for the predictions
        tags: Telemetry tags to attach to recorded predictions
        batch_size: Number of rows sent to predict() per call
        prompt: Optional prompt for instruction-tuned embedding models

    Returns:
        RegressionMetrics computed from expected vs. predicted scores

    Raises:
        ValueError: If the dataset is empty or the score column contains None values
    """
    if len(dataset) == 0:
        raise ValueError("Evaluation dataset cannot be empty")

    # Materialize the score column once: used for the None check and for the
    # final metrics computation below.
    expected_scores = dataset[score_column]
    if any(x is None for x in expected_scores):
        raise ValueError("Evaluation dataset cannot contain None values in the score column")

    predictions = []
    for i in range(0, len(dataset), batch_size):
        # Slice each batch once instead of once per column access.
        batch = dataset[i : i + batch_size]
        predictions.extend(
            self.predict(
                batch[value_column],
                expected_scores=batch[score_column],
                tags=tags,
                save_telemetry="sync" if record_predictions else "off",
                prompt=prompt,
            )
        )

    return calculate_regression_metrics(
        expected_scores=expected_scores,
        predicted_scores=[p.score for p in predictions],
        anomaly_scores=[p.anomaly_score for p in predictions],
    )
517
+
518
# Overload: when background=True, evaluate returns a Job handle that resolves
# to RegressionMetrics instead of blocking until the evaluation finishes.
@overload
def evaluate(
    self,
    data: Datasource | Dataset,
    *,
    value_column: str = "value",
    score_column: str = "score",
    record_predictions: bool = False,
    tags: set[str] = {"evaluation"},
    batch_size: int = 100,
    prompt: str | None = None,
    background: Literal[True],
) -> Job[RegressionMetrics]:
    pass
532
+
533
# Overload: when background is False (the default), evaluate blocks and
# returns the computed RegressionMetrics directly.
@overload
def evaluate(
    self,
    data: Datasource | Dataset,
    *,
    value_column: str = "value",
    score_column: str = "score",
    record_predictions: bool = False,
    tags: set[str] = {"evaluation"},
    batch_size: int = 100,
    prompt: str | None = None,
    background: Literal[False] = False,
) -> RegressionMetrics:
    pass
547
+
548
def evaluate(
    self,
    data: Datasource | Dataset,
    *,
    value_column: str = "value",
    score_column: str = "score",
    record_predictions: bool = False,
    tags: set[str] | None = None,
    batch_size: int = 100,
    prompt: str | None = None,
    background: bool = False,
) -> RegressionMetrics | Job[RegressionMetrics]:
    """
    Evaluate the regression model on a given dataset or datasource

    Params:
        data: Dataset or Datasource to evaluate the model on
        value_column: Name of the column that contains the input values to the model
        score_column: Name of the column containing the expected scores
        record_predictions: Whether to record [`RegressionPrediction`][orca_sdk.telemetry.RegressionPrediction]s for analysis
        tags: Optional tags to add to the recorded [`RegressionPrediction`][orca_sdk.telemetry.RegressionPrediction]s,
            defaults to `{"evaluation"}` when not given
        batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)
        prompt: Optional prompt for instruction-tuned embedding models
        background: Whether to run the operation in the background and return a job handle
            (only consumed for Datasource inputs; Dataset evaluation always runs synchronously)

    Returns:
        RegressionMetrics containing metrics including MAE, MSE, RMSE, R2, and anomaly score statistics

    Raises:
        ValueError: If data is neither a Datasource nor a Dataset

    Examples:
        >>> model.evaluate(datasource, value_column="text", score_column="rating")
        RegressionMetrics({
            mae: 0.2500,
            rmse: 0.3536,
            r2: 0.8500,
            anomaly_score: 0.3500 ± 0.0500,
        })

        >>> # Using with an instruction-tuned embedding model
        >>> model.evaluate(dataset, prompt="Represent this review for rating prediction:")
        RegressionMetrics({
            mae: 0.2000,
            rmse: 0.3000,
            r2: 0.9000,
            anomaly_score: 0.3000 ± 0.0400})
    """
    # None sentinel instead of a mutable default set ({"evaluation"} shared
    # across calls); behavior is unchanged for callers.
    tags = {"evaluation"} if tags is None else tags
    if isinstance(data, Datasource):
        return self._evaluate_datasource(
            datasource=data,
            value_column=value_column,
            score_column=score_column,
            record_predictions=record_predictions,
            tags=tags,
            background=background,
        )
    elif isinstance(data, Dataset):
        return self._evaluate_dataset(
            dataset=data,
            value_column=value_column,
            score_column=score_column,
            record_predictions=record_predictions,
            tags=tags,
            batch_size=batch_size,
            prompt=prompt,
        )
    else:
        raise ValueError(f"Invalid data type: {type(data)}")
614
+
615
@contextmanager
def use_memoryset(self, memoryset_override: ScoredMemoryset) -> Generator[None, None, None]:
    """
    Temporarily override the memoryset used by the model for predictions

    The override is cleared when the `with` block exits, including when the
    block raises an exception.

    Params:
        memoryset_override: Memoryset to override the default memoryset with

    Examples:
        >>> with model.use_memoryset(ScoredMemoryset.open("my_other_memoryset")):
        ...     predictions = model.predict("Rate your experience")
    """
    self._memoryset_override_id = memoryset_override.id
    try:
        yield
    finally:
        # Without try/finally, an exception raised inside the with-block would
        # leave the override permanently active on the model.
        self._memoryset_override_id = None
630
+
631
# Overload: a single feedback entry may be passed as a bare dict.
@overload
def record_feedback(self, feedback: dict[str, Any]) -> None:
    pass
634
+
635
# Overload: multiple feedback entries may be passed as any iterable of dicts.
@overload
def record_feedback(self, feedback: Iterable[dict[str, Any]]) -> None:
    pass
638
+
639
def record_feedback(self, feedback: Iterable[dict[str, Any]] | dict[str, Any]):
    """
    Record feedback for one or more predictions.

    Feedback is recorded under named categories. A
    [`FeedbackCategory`][orca_sdk.telemetry.FeedbackCategory] is created
    automatically the first time feedback with a new name is recorded, and
    categories are shared across models. The category's value type is inferred
    from the first value recorded for it; all later feedback for that category
    must use the same type.

    Params:
        feedback: A single dict or an iterable of dicts, each with the keys:

            - `category`: Name of the category under which to record the feedback.
            - `value`: `True`/`False` for positive/negative feedback, or a
              [`float`][float] between `-1.0` and `+1.0` where the sign encodes
              negative vs. positive feedback.
            - `comment`: Optional comment to record with the feedback.

    Examples:
        Record whether predictions were accurate:
        >>> model.record_feedback({
        ...     "prediction": p.prediction_id,
        ...     "category": "accurate",
        ...     "value": abs(p.score - p.expected_score) < 0.5,
        ... } for p in predictions)

        Record star rating as normalized continuous score between `-1.0` and `+1.0`:
        >>> model.record_feedback({
        ...     "prediction": "123e4567-e89b-12d3-a456-426614174000",
        ...     "category": "rating",
        ...     "value": -0.5,
        ...     "comment": "2 stars"
        ... })

    Raises:
        ValueError: If the value does not match previous value types for the category, or is a
            [`float`][float] that is not between `-1.0` and `+1.0`.
    """
    # Normalize the single-dict form to an iterable before parsing.
    if isinstance(feedback, dict):
        feedback_items: Iterable[dict[str, Any]] = [feedback]
    else:
        feedback_items = feedback
    orca_api.PUT(
        "/telemetry/prediction/feedback",
        json=[_parse_feedback(item) for item in feedback_items],
    )