orca-sdk 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186)
  1. orca_sdk/__init__.py +10 -4
  2. orca_sdk/_shared/__init__.py +10 -0
  3. orca_sdk/_shared/metrics.py +393 -0
  4. orca_sdk/_shared/metrics_test.py +273 -0
  5. orca_sdk/_utils/analysis_ui.py +12 -10
  6. orca_sdk/_utils/analysis_ui_style.css +0 -3
  7. orca_sdk/_utils/auth.py +31 -29
  8. orca_sdk/_utils/data_parsing.py +28 -2
  9. orca_sdk/_utils/data_parsing_test.py +15 -15
  10. orca_sdk/_utils/pagination.py +126 -0
  11. orca_sdk/_utils/pagination_test.py +132 -0
  12. orca_sdk/_utils/prediction_result_ui.py +67 -21
  13. orca_sdk/_utils/tqdm_file_reader.py +12 -0
  14. orca_sdk/_utils/value_parser.py +45 -0
  15. orca_sdk/_utils/value_parser_test.py +39 -0
  16. orca_sdk/async_client.py +3795 -0
  17. orca_sdk/classification_model.py +601 -129
  18. orca_sdk/classification_model_test.py +415 -117
  19. orca_sdk/client.py +3787 -0
  20. orca_sdk/conftest.py +184 -38
  21. orca_sdk/credentials.py +162 -20
  22. orca_sdk/credentials_test.py +100 -16
  23. orca_sdk/datasource.py +268 -68
  24. orca_sdk/datasource_test.py +266 -18
  25. orca_sdk/embedding_model.py +434 -82
  26. orca_sdk/embedding_model_test.py +66 -33
  27. orca_sdk/job.py +343 -0
  28. orca_sdk/job_test.py +108 -0
  29. orca_sdk/memoryset.py +1690 -324
  30. orca_sdk/memoryset_test.py +456 -119
  31. orca_sdk/regression_model.py +694 -0
  32. orca_sdk/regression_model_test.py +378 -0
  33. orca_sdk/telemetry.py +460 -143
  34. orca_sdk/telemetry_test.py +43 -24
  35. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.3.dist-info}/METADATA +34 -16
  36. orca_sdk-0.1.3.dist-info/RECORD +41 -0
  37. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.3.dist-info}/WHEEL +1 -1
  38. orca_sdk/_generated_api_client/__init__.py +0 -3
  39. orca_sdk/_generated_api_client/api/__init__.py +0 -193
  40. orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
  41. orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
  42. orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
  43. orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
  44. orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
  45. orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
  46. orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
  47. orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
  48. orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
  49. orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
  50. orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
  51. orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
  52. orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
  53. orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
  54. orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
  55. orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
  56. orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
  57. orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
  58. orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
  59. orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
  60. orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
  61. orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
  62. orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
  63. orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
  64. orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
  65. orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
  66. orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
  67. orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
  68. orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
  69. orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
  70. orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
  71. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
  72. orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
  73. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
  74. orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
  75. orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
  76. orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
  77. orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
  78. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
  79. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
  80. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
  81. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
  82. orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
  83. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
  84. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
  85. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
  86. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
  87. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
  88. orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
  89. orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
  90. orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
  91. orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
  92. orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
  93. orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
  94. orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
  95. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
  96. orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
  97. orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
  98. orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
  99. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
  100. orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
  101. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
  102. orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
  103. orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
  104. orca_sdk/_generated_api_client/client.py +0 -216
  105. orca_sdk/_generated_api_client/errors.py +0 -38
  106. orca_sdk/_generated_api_client/models/__init__.py +0 -159
  107. orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
  108. orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
  109. orca_sdk/_generated_api_client/models/base_model.py +0 -55
  110. orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
  111. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
  112. orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
  113. orca_sdk/_generated_api_client/models/column_info.py +0 -114
  114. orca_sdk/_generated_api_client/models/column_type.py +0 -14
  115. orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
  116. orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
  117. orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
  118. orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
  119. orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
  120. orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
  121. orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
  122. orca_sdk/_generated_api_client/models/embed_request.py +0 -127
  123. orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
  124. orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
  125. orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
  126. orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
  127. orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
  128. orca_sdk/_generated_api_client/models/filter_item.py +0 -231
  129. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
  130. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
  131. orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
  132. orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
  133. orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
  134. orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
  135. orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
  136. orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
  137. orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
  138. orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
  139. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
  140. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
  141. orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
  142. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
  143. orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
  144. orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
  145. orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
  146. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
  147. orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
  148. orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
  149. orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
  150. orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
  151. orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
  152. orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
  153. orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
  154. orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
  155. orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
  156. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
  157. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
  158. orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
  159. orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
  160. orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
  161. orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
  162. orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
  163. orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
  164. orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
  165. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
  166. orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
  167. orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
  168. orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
  169. orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
  170. orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
  171. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
  172. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
  173. orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
  174. orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
  175. orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
  176. orca_sdk/_generated_api_client/models/task.py +0 -198
  177. orca_sdk/_generated_api_client/models/task_status.py +0 -14
  178. orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
  179. orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
  180. orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
  181. orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
  182. orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
  183. orca_sdk/_generated_api_client/py.typed +0 -1
  184. orca_sdk/_generated_api_client/types.py +0 -56
  185. orca_sdk/_utils/task.py +0 -73
  186. orca_sdk-0.1.1.dist-info/RECORD +0 -175
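Note: the largest addition in this release is a new regression surface (orca_sdk/regression_model.py, orca_sdk/_shared/metrics.py, and orca_sdk/job.py), while the generated API client under orca_sdk/_generated_api_client/ is removed in favor of the hand-written client.py and async_client.py. The evaluation result type used by the new module, RegressionMetrics, carries fields such as mse, rmse, mae, and r2 (see the diffed regression_model.py below). The snippet that follows is a minimal sketch of the standard formulas behind those field names; it is an illustration only, not the implementation shipped in orca_sdk/_shared/metrics.py.

# Sketch of the conventional regression metrics (mse, rmse, mae, r2).
# Illustrative only; not the code in orca_sdk/_shared/metrics.py.
import math


def sketch_regression_metrics(expected: list[float], predicted: list[float]) -> dict[str, float]:
    n = len(expected)
    errors = [p - e for p, e in zip(predicted, expected)]
    mse = sum(err**2 for err in errors) / n  # mean squared error
    mae = sum(abs(err) for err in errors) / n  # mean absolute error
    mean_expected = sum(expected) / n
    total_variance = sum((e - mean_expected) ** 2 for e in expected)
    # R^2 = 1 - SS_res / SS_tot, only defined when the targets are not constant
    r2 = 1 - (n * mse) / total_variance if total_variance else float("nan")
    return {"mse": mse, "rmse": math.sqrt(mse), "mae": mae, "r2": r2}


# Perfect predictions give mse = mae = 0 and r2 = 1
assert sketch_regression_metrics([1.0, 2.0, 3.0], [1.0, 2.0, 3.0])["r2"] == 1.0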
orca_sdk/regression_model.py (new file)
@@ -0,0 +1,694 @@
+from __future__ import annotations
+
+import logging
+from contextlib import contextmanager
+from datetime import datetime
+from typing import Any, Generator, Iterable, Literal, cast, overload
+
+from datasets import Dataset
+
+from ._shared.metrics import RegressionMetrics, calculate_regression_metrics
+from ._utils.common import UNSET, CreateMode, DropMode
+from .client import (
+    OrcaClient,
+    PredictiveModelUpdate,
+    RARHeadType,
+    RegressionModelMetadata,
+)
+from .datasource import Datasource
+from .job import Job
+from .memoryset import ScoredMemoryset
+from .telemetry import (
+    RegressionPrediction,
+    TelemetryMode,
+    _get_telemetry_config,
+    _parse_feedback,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class RegressionModel:
+    """
+    A handle to a regression model in OrcaCloud
+
+    Attributes:
+        id: Unique identifier for the model
+        name: Unique name of the model
+        description: Optional description of the model
+        memoryset: Memoryset that the model uses
+        head_type: Regression head type of the model
+        memory_lookup_count: Number of memories the model uses for each prediction
+        locked: Whether the model is locked to prevent accidental deletion
+        created_at: When the model was created
+        updated_at: When the model was last updated
+    """
+
+    id: str
+    name: str
+    description: str | None
+    memoryset: ScoredMemoryset
+    head_type: RARHeadType
+    memory_lookup_count: int
+    version: int
+    locked: bool
+    created_at: datetime
+    updated_at: datetime
+    memoryset_id: str
+
+    _last_prediction: RegressionPrediction | None
+    _last_prediction_was_batch: bool
+    _memoryset_override_id: str | None
+
+    def __init__(self, metadata: RegressionModelMetadata):
+        # for internal use only, do not document
+        self.id = metadata["id"]
+        self.name = metadata["name"]
+        self.description = metadata["description"]
+        self.memoryset = ScoredMemoryset.open(metadata["memoryset_id"])
+        self.head_type = metadata["head_type"]
+        self.memory_lookup_count = metadata["memory_lookup_count"]
+        self.version = metadata["version"]
+        self.locked = metadata["locked"]
+        self.created_at = datetime.fromisoformat(metadata["created_at"])
+        self.updated_at = datetime.fromisoformat(metadata["updated_at"])
+        self.memoryset_id = metadata["memoryset_id"]
+
+        self._memoryset_override_id = None
+        self._last_prediction = None
+        self._last_prediction_was_batch = False
+
+    def __eq__(self, other) -> bool:
+        return isinstance(other, RegressionModel) and self.id == other.id
+
+    def __repr__(self):
+        return (
+            "RegressionModel({\n"
+            f" name: '{self.name}',\n"
+            f" head_type: {self.head_type},\n"
+            f" memory_lookup_count: {self.memory_lookup_count},\n"
+            f" memoryset: ScoredMemoryset.open('{self.memoryset.name}'),\n"
+            "})"
+        )
+
+    @property
+    def last_prediction(self) -> RegressionPrediction:
+        """
+        Last prediction made by the model
+
+        Note:
+            If the last prediction was part of a batch prediction, the last prediction from the
+            batch is returned. If no prediction has been made yet, a [`LookupError`][LookupError]
+            is raised.
+        """
+        if self._last_prediction_was_batch:
+            logging.warning(
+                "Last prediction was part of a batch prediction, returning the last prediction from the batch"
+            )
+        if self._last_prediction is None:
+            raise LookupError("No prediction has been made yet")
+        return self._last_prediction
+
+    @classmethod
+    def create(
+        cls,
+        name: str,
+        memoryset: ScoredMemoryset,
+        memory_lookup_count: int | None = None,
+        description: str | None = None,
+        if_exists: CreateMode = "error",
+    ) -> RegressionModel:
+        """
+        Create a regression model.
+
+        Params:
+            name: Name of the model
+            memoryset: The scored memoryset to use for prediction
+            memory_lookup_count: Number of memories to retrieve for prediction. Defaults to 10.
+            description: Description of the model
+            if_exists: How to handle existing models with the same name
+
+        Returns:
+            RegressionModel instance
+
+        Raises:
+            ValueError: If a model with the same name already exists and if_exists is "error"
+            ValueError: If the memoryset is empty
+            ValueError: If memory_lookup_count exceeds the number of memories in the memoryset
+        """
+        existing = cls.exists(name)
+        if existing:
+            if if_exists == "error":
+                raise ValueError(f"RegressionModel with name '{name}' already exists")
+            elif if_exists == "open":
+                existing = cls.open(name)
+                for attribute in {"memory_lookup_count"}:
+                    local_attribute = locals()[attribute]
+                    existing_attribute = getattr(existing, attribute)
+                    if local_attribute is not None and local_attribute != existing_attribute:
+                        raise ValueError(f"Model with name {name} already exists with different {attribute}")
+
+                # special case for memoryset
+                if existing.memoryset_id != memoryset.id:
+                    raise ValueError(f"Model with name {name} already exists with different memoryset")
+
+                return existing
+
+        client = OrcaClient._resolve_client()
+        metadata = client.POST(
+            "/regression_model",
+            json={
+                "name": name,
+                "memoryset_name_or_id": memoryset.id,
+                "memory_lookup_count": memory_lookup_count,
+                "description": description,
+            },
+        )
+        return cls(metadata)
+
+    @classmethod
+    def open(cls, name: str) -> RegressionModel:
+        """
+        Get a handle to a regression model in the OrcaCloud
+
+        Params:
+            name: Name or unique identifier of the regression model
+
+        Returns:
+            Handle to the existing regression model in the OrcaCloud
+
+        Raises:
+            LookupError: If the regression model does not exist
+        """
+        client = OrcaClient._resolve_client()
+        return cls(client.GET("/regression_model/{name_or_id}", params={"name_or_id": name}))
+
+    @classmethod
+    def exists(cls, name_or_id: str) -> bool:
+        """
+        Check if a regression model exists in the OrcaCloud
+
+        Params:
+            name_or_id: Name or id of the regression model
+
+        Returns:
+            `True` if the regression model exists, `False` otherwise
+        """
+        try:
+            cls.open(name_or_id)
+            return True
+        except LookupError:
+            return False
+
+    @classmethod
+    def all(cls) -> list[RegressionModel]:
+        """
+        Get a list of handles to all regression models in the OrcaCloud
+
+        Returns:
+            List of handles to all regression models in the OrcaCloud
+        """
+        client = OrcaClient._resolve_client()
+        return [cls(metadata) for metadata in client.GET("/regression_model")]
+
+    @classmethod
+    def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
+        """
+        Delete a regression model from the OrcaCloud
+
+        Warning:
+            This will delete the model and all associated data, including predictions, evaluations, and feedback.
+
+        Params:
+            name_or_id: Name or id of the regression model
+            if_not_exists: What to do if the regression model does not exist, defaults to `"error"`.
+                Other option is `"ignore"` to do nothing if the regression model does not exist.
+
+        Raises:
+            LookupError: If the regression model does not exist and if_not_exists is `"error"`
+        """
+        try:
+            client = OrcaClient._resolve_client()
+            client.DELETE("/regression_model/{name_or_id}", params={"name_or_id": name_or_id})
+            logging.info(f"Deleted model {name_or_id}")
+        except LookupError:
+            if if_not_exists == "error":
+                raise
+
+    def refresh(self):
+        """Refresh the model data from the OrcaCloud"""
+        self.__dict__.update(self.open(self.name).__dict__)
+
+    def set(self, *, description: str | None = UNSET, locked: bool = UNSET) -> None:
+        """
+        Update editable attributes of the model.
+
+        Note:
+            If a field is not provided, it will default to [UNSET][orca_sdk.UNSET] and not be updated.
+
+        Params:
+            description: Value to set for the description
+            locked: Value to set for the locked status
+
+        Examples:
+            Update the description:
+            >>> model.set(description="New description")
+
+            Remove description:
+            >>> model.set(description=None)
+
+            Lock the model:
+            >>> model.set(locked=True)
+        """
+        update: PredictiveModelUpdate = {}
+        if description is not UNSET:
+            update["description"] = description
+        if locked is not UNSET:
+            update["locked"] = locked
+        client = OrcaClient._resolve_client()
+        client.PATCH("/regression_model/{name_or_id}", params={"name_or_id": self.id}, json=update)
+        self.refresh()
+
+    def lock(self) -> None:
+        """Lock the model to prevent accidental deletion"""
+        self.set(locked=True)
+
+    def unlock(self) -> None:
+        """Unlock the model to allow deletion"""
+        self.set(locked=False)
+
+    @overload
+    def predict(
+        self,
+        value: str,
+        expected_scores: float | None = None,
+        tags: set[str] | None = None,
+        save_telemetry: TelemetryMode = "on",
+        prompt: str | None = None,
+        use_lookup_cache: bool = True,
+        timeout_seconds: int = 10,
+    ) -> RegressionPrediction: ...
+
+    @overload
+    def predict(
+        self,
+        value: list[str],
+        expected_scores: list[float] | None = None,
+        tags: set[str] | None = None,
+        save_telemetry: TelemetryMode = "on",
+        prompt: str | None = None,
+        use_lookup_cache: bool = True,
+        timeout_seconds: int = 10,
+    ) -> list[RegressionPrediction]: ...
+
+    # TODO: add filter support
+    def predict(
+        self,
+        value: str | list[str],
+        expected_scores: float | list[float] | None = None,
+        tags: set[str] | None = None,
+        save_telemetry: TelemetryMode = "on",
+        prompt: str | None = None,
+        use_lookup_cache: bool = True,
+        timeout_seconds: int = 10,
+    ) -> RegressionPrediction | list[RegressionPrediction]:
+        """
+        Make predictions using the regression model.
+
+        Params:
+            value: Input text(s) to predict scores for
+            expected_scores: Expected score(s) for telemetry tracking
+            tags: Tags to associate with the prediction(s)
+            save_telemetry: Whether to save telemetry for the prediction(s), defaults to `True`,
+                which will save telemetry asynchronously unless the `ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY`
+                environment variable is set to `"1"`. You can also pass `"sync"` or `"async"` to
+                explicitly set the save mode.
+            prompt: Optional prompt for instruction-tuned embedding models
+            use_lookup_cache: Whether to use cached lookup results for faster predictions
+            timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
+
+        Returns:
+            Single RegressionPrediction or list of RegressionPrediction objects
+
+        Raises:
+            ValueError: If expected_scores length doesn't match value length for batch predictions
+            ValueError: If timeout_seconds is not a positive integer
+            TimeoutError: If the request times out after the specified duration
+        """
+        if timeout_seconds <= 0:
+            raise ValueError("timeout_seconds must be a positive integer")
+
+        telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
+        client = OrcaClient._resolve_client()
+        response = client.POST(
+            "/gpu/regression_model/{name_or_id}/prediction",
+            params={"name_or_id": self.id},
+            json={
+                "input_values": value if isinstance(value, list) else [value],
+                "memoryset_override_name_or_id": self._memoryset_override_id,
+                "expected_scores": (
+                    expected_scores
+                    if isinstance(expected_scores, list)
+                    else [expected_scores] if expected_scores is not None else None
+                ),
+                "tags": list(tags or set()),
+                "save_telemetry": telemetry_on,
+                "save_telemetry_synchronously": telemetry_sync,
+                "prompt": prompt,
+                "use_lookup_cache": use_lookup_cache,
+            },
+            timeout=timeout_seconds,
+        )
+
+        if telemetry_on and any(p["prediction_id"] is None for p in response):
+            raise RuntimeError("Failed to save prediction to database.")
+
+        predictions = [
+            RegressionPrediction(
+                prediction_id=prediction["prediction_id"],
+                label=None,
+                label_name=None,
+                score=prediction["score"],
+                confidence=prediction["confidence"],
+                anomaly_score=prediction["anomaly_score"],
+                memoryset=self.memoryset,
+                model=self,
+                logits=None,
+                input_value=input_value,
+            )
+            for prediction, input_value in zip(response, value if isinstance(value, list) else [value])
+        ]
+        self._last_prediction_was_batch = isinstance(value, list)
+        self._last_prediction = predictions[-1]
+        return predictions if isinstance(value, list) else predictions[0]
+
+    def predictions(
+        self,
+        limit: int = 100,
+        offset: int = 0,
+        tag: str | None = None,
+        sort: list[tuple[Literal["anomaly_score", "confidence", "timestamp"], Literal["asc", "desc"]]] = [],
+    ) -> list[RegressionPrediction]:
+        """
+        Get a list of predictions made by this model
+
+        Params:
+            limit: Optional maximum number of predictions to return
+            offset: Optional offset of the first prediction to return
+            tag: Optional tag to filter predictions by
+            sort: Optional list of columns and directions to sort the predictions by.
+                Predictions can be sorted by `created_at`, `confidence`, `anomaly_score`, or `score`.
+
+        Returns:
+            List of score predictions
+
+        Examples:
+            Get the last 3 predictions:
+            >>> predictions = model.predictions(limit=3, sort=[("created_at", "desc")])
+            [
+                RegressionPrediction({score: 4.5, confidence: 0.95, anomaly_score: 0.1, input_value: 'Great service'}),
+                RegressionPrediction({score: 2.0, confidence: 0.90, anomaly_score: 0.1, input_value: 'Poor experience'}),
+                RegressionPrediction({score: 3.5, confidence: 0.85, anomaly_score: 0.1, input_value: 'Average'}),
+            ]
+
+            Get second most confident prediction:
+            >>> predictions = model.predictions(sort=[("confidence", "desc")], offset=1, limit=1)
+            [RegressionPrediction({score: 4.2, confidence: 0.90, anomaly_score: 0.1, input_value: 'Good service'})]
+        """
+        client = OrcaClient._resolve_client()
+        predictions = client.POST(
+            "/telemetry/prediction",
+            json={
+                "model_id": self.id,
+                "limit": limit,
+                "offset": offset,
+                "sort": [list(sort_item) for sort_item in sort],
+                "tag": tag,
+            },
+        )
+        return [
+            RegressionPrediction(
+                prediction_id=prediction["prediction_id"],
+                label=None,
+                label_name=None,
+                score=prediction["score"],
+                confidence=prediction["confidence"],
+                anomaly_score=prediction["anomaly_score"],
+                memoryset=self.memoryset,
+                model=self,
+                telemetry=prediction,
+                logits=None,
+                input_value=None,
+            )
+            for prediction in predictions
+            if "score" in prediction
+        ]
+
+    def _evaluate_datasource(
+        self,
+        datasource: Datasource,
+        value_column: str,
+        score_column: str,
+        record_predictions: bool,
+        tags: set[str] | None,
+        background: bool = False,
+    ) -> RegressionMetrics | Job[RegressionMetrics]:
+        client = OrcaClient._resolve_client()
+        response = client.POST(
+            "/regression_model/{model_name_or_id}/evaluation",
+            params={"model_name_or_id": self.id},
+            json={
+                "datasource_name_or_id": datasource.id,
+                "datasource_score_column": score_column,
+                "datasource_value_column": value_column,
+                "memoryset_override_name_or_id": self._memoryset_override_id,
+                "record_telemetry": record_predictions,
+                "telemetry_tags": list(tags) if tags else None,
+            },
+        )
+
+        def get_value():
+            client = OrcaClient._resolve_client()
+            res = client.GET(
+                "/regression_model/{model_name_or_id}/evaluation/{task_id}",
+                params={"model_name_or_id": self.id, "task_id": response["task_id"]},
+            )
+            assert res["result"] is not None
+            return RegressionMetrics(
+                coverage=res["result"].get("coverage"),
+                mse=res["result"].get("mse"),
+                rmse=res["result"].get("rmse"),
+                mae=res["result"].get("mae"),
+                r2=res["result"].get("r2"),
+                explained_variance=res["result"].get("explained_variance"),
+                loss=res["result"].get("loss"),
+                anomaly_score_mean=res["result"].get("anomaly_score_mean"),
+                anomaly_score_median=res["result"].get("anomaly_score_median"),
+                anomaly_score_variance=res["result"].get("anomaly_score_variance"),
+            )
+
+        job = Job(response["task_id"], get_value)
+        return job if background else job.result()
+
+    def _evaluate_dataset(
+        self,
+        dataset: Dataset,
+        value_column: str,
+        score_column: str,
+        record_predictions: bool,
+        tags: set[str],
+        batch_size: int,
+        prompt: str | None = None,
+    ) -> RegressionMetrics:
+        if len(dataset) == 0:
+            raise ValueError("Evaluation dataset cannot be empty")
+
+        if any(x is None for x in dataset[score_column]):
+            raise ValueError("Evaluation dataset cannot contain None values in the score column")
+
+        predictions = [
+            prediction
+            for i in range(0, len(dataset), batch_size)
+            for prediction in self.predict(
+                dataset[i : i + batch_size][value_column],
+                expected_scores=dataset[i : i + batch_size][score_column],
+                tags=tags,
+                save_telemetry="sync" if record_predictions else "off",
+                prompt=prompt,
+            )
+        ]
+
+        return calculate_regression_metrics(
+            expected_scores=dataset[score_column],
+            predicted_scores=[p.score for p in predictions],
+            anomaly_scores=[p.anomaly_score for p in predictions],
+        )
+
+    @overload
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
+        *,
+        value_column: str = "value",
+        score_column: str = "score",
+        record_predictions: bool = False,
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
+        prompt: str | None = None,
+        background: Literal[True],
+    ) -> Job[RegressionMetrics]:
+        pass
+
+    @overload
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
+        *,
+        value_column: str = "value",
+        score_column: str = "score",
+        record_predictions: bool = False,
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
+        prompt: str | None = None,
+        background: Literal[False] = False,
+    ) -> RegressionMetrics:
+        pass
+
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
+        *,
+        value_column: str = "value",
+        score_column: str = "score",
+        record_predictions: bool = False,
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
+        prompt: str | None = None,
+        background: bool = False,
+    ) -> RegressionMetrics | Job[RegressionMetrics]:
+        """
+        Evaluate the regression model on a given dataset or datasource
+
+        Params:
+            data: Dataset or Datasource to evaluate the model on
+            value_column: Name of the column that contains the input values to the model
+            score_column: Name of the column containing the expected scores
+            record_predictions: Whether to record [`RegressionPrediction`][orca_sdk.telemetry.RegressionPrediction]s for analysis
+            tags: Optional tags to add to the recorded [`RegressionPrediction`][orca_sdk.telemetry.RegressionPrediction]s
+            batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)
+            prompt: Optional prompt for instruction-tuned embedding models
+            background: Whether to run the operation in the background and return a job handle
+
+        Returns:
+            RegressionMetrics containing metrics including MAE, MSE, RMSE, R2, and anomaly score statistics
+
+        Examples:
+            >>> model.evaluate(datasource, value_column="text", score_column="rating")
+            RegressionMetrics({
+                mae: 0.2500,
+                rmse: 0.3536,
+                r2: 0.8500,
+                anomaly_score: 0.3500 ± 0.0500,
+            })
+
+            >>> # Using with an instruction-tuned embedding model
+            >>> model.evaluate(dataset,prompt="Represent this review for rating prediction:")
+            RegressionMetrics({
+                mae: 0.2000,
+                rmse: 0.3000,
+                r2: 0.9000,
+                anomaly_score: 0.3000 ± 0.0400})
+        """
+        if isinstance(data, Datasource):
+            return self._evaluate_datasource(
+                datasource=data,
+                value_column=value_column,
+                score_column=score_column,
+                record_predictions=record_predictions,
+                tags=tags,
+                background=background,
+            )
+        elif isinstance(data, Dataset):
+            return self._evaluate_dataset(
+                dataset=data,
+                value_column=value_column,
+                score_column=score_column,
+                record_predictions=record_predictions,
+                tags=tags,
+                batch_size=batch_size,
+                prompt=prompt,
+            )
+        else:
+            raise ValueError(f"Invalid data type: {type(data)}")
+
+    @contextmanager
+    def use_memoryset(self, memoryset_override: ScoredMemoryset) -> Generator[None, None, None]:
+        """
+        Temporarily override the memoryset used by the model for predictions
+
+        Params:
+            memoryset_override: Memoryset to override the default memoryset with
+
+        Examples:
+            >>> with model.use_memoryset(ScoredMemoryset.open("my_other_memoryset")):
+            ...     predictions = model.predict("Rate your experience")
+        """
+        self._memoryset_override_id = memoryset_override.id
+        yield
+        self._memoryset_override_id = None
+
+    @overload
+    def record_feedback(self, feedback: dict[str, Any]) -> None:
+        pass
+
+    @overload
+    def record_feedback(self, feedback: Iterable[dict[str, Any]]) -> None:
+        pass
+
+    def record_feedback(self, feedback: Iterable[dict[str, Any]] | dict[str, Any]):
+        """
+        Record feedback for a list of predictions.
+
+        We support recording feedback in several categories for each prediction. A
+        [`FeedbackCategory`][orca_sdk.telemetry.FeedbackCategory] is created automatically,
+        the first time feedback with a new name is recorded. Categories are global across models.
+        The value type of the category is inferred from the first recorded value. Subsequent
+        feedback for the same category must be of the same type.
+
+        Params:
+            feedback: Feedback to record, this should be dictionaries with the following keys:
+
+                - `category`: Name of the category under which to record the feedback.
+                - `value`: Feedback value to record, should be `True` for positive feedback and
+                  `False` for negative feedback or a [`float`][float] between `-1.0` and `+1.0`
+                  where negative values indicate negative feedback and positive values indicate
+                  positive feedback.
+                - `comment`: Optional comment to record with the feedback.
+
+        Examples:
+            Record whether predictions were accurate:
+            >>> model.record_feedback({
+            ...     "prediction": p.prediction_id,
+            ...     "category": "accurate",
+            ...     "value": abs(p.score - p.expected_score) < 0.5,
+            ... } for p in predictions)
+
+            Record star rating as normalized continuous score between `-1.0` and `+1.0`:
+            >>> model.record_feedback({
+            ...     "prediction": "123e4567-e89b-12d3-a456-426614174000",
+            ...     "category": "rating",
+            ...     "value": -0.5,
+            ...     "comment": "2 stars"
+            ... })
+
+        Raises:
+            ValueError: If the value does not match previous value types for the category, or is a
+                [`float`][float] that is not between `-1.0` and `+1.0`.
+        """
+        client = OrcaClient._resolve_client()
+        client.PUT(
+            "/telemetry/prediction/feedback",
+            json=[
+                _parse_feedback(f) for f in (cast(list[dict], [feedback]) if isinstance(feedback, dict) else feedback)
+            ],
+        )
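The sketch below pieces together an end-to-end workflow from the method signatures in the hunk above (create, predict, record_feedback, evaluate). It assumes valid OrcaCloud credentials are already configured; the memoryset name, sample inputs, and expected scores are hypothetical placeholders, and the import paths follow the module layout shown in the file list.

# Hedged usage sketch of the new RegressionModel surface in 0.1.3.
# Names like "my-scored-memoryset" and the sample data are hypothetical.
from datasets import Dataset

from orca_sdk.memoryset import ScoredMemoryset
from orca_sdk.regression_model import RegressionModel

memoryset = ScoredMemoryset.open("my-scored-memoryset")  # hypothetical memoryset name

# Create the model, or reopen it if one with this name already exists
model = RegressionModel.create("review-rating", memoryset, if_exists="open")

# Single and batch predictions
single = model.predict("Great service", tags={"demo"})
batch = model.predict(["Great service", "Poor experience"], expected_scores=[4.5, 2.0])
print(single.score, single.confidence, single.anomaly_score)

# Record boolean feedback against the returned prediction ids
model.record_feedback(
    {"prediction": p.prediction_id, "category": "accurate", "value": True} for p in batch
)

# Evaluate on a Hugging Face Dataset (predictions run client-side in batches)
eval_data = Dataset.from_dict({"value": ["Great service", "Average"], "score": [4.5, 3.0]})
metrics = model.evaluate(eval_data, value_column="value", score_column="score")
print(metrics)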