orca-sdk 0.0.94__py3-none-any.whl → 0.0.96__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only.
Files changed (116)
  1. orca_sdk/__init__.py +13 -4
  2. orca_sdk/_generated_api_client/api/__init__.py +80 -34
  3. orca_sdk/_generated_api_client/api/classification_model/create_classification_model_gpu_classification_model_post.py +170 -0
  4. orca_sdk/_generated_api_client/api/classification_model/{get_model_classification_model_name_or_id_get.py → delete_classification_model_classification_model_name_or_id_delete.py} +20 -20
  5. orca_sdk/_generated_api_client/api/classification_model/{delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py → delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py} +4 -4
  6. orca_sdk/_generated_api_client/api/classification_model/{create_evaluation_classification_model_model_name_or_id_evaluation_post.py → evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py} +14 -14
  7. orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +156 -0
  8. orca_sdk/_generated_api_client/api/classification_model/{get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py → get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py} +16 -16
  9. orca_sdk/_generated_api_client/api/classification_model/{list_evaluations_classification_model_model_name_or_id_evaluation_get.py → list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py} +16 -16
  10. orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +127 -0
  11. orca_sdk/_generated_api_client/api/classification_model/{predict_gpu_classification_model_name_or_id_prediction_post.py → predict_label_gpu_classification_model_name_or_id_prediction_post.py} +14 -14
  12. orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +183 -0
  13. orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +24 -0
  14. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +22 -22
  15. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +22 -22
  16. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +38 -16
  17. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +29 -12
  18. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +12 -12
  19. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +17 -14
  20. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +72 -19
  21. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +31 -12
  22. orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +49 -20
  23. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +38 -16
  24. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +54 -29
  25. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +44 -26
  26. orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +22 -22
  27. orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
  28. orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +150 -0
  29. orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
  30. orca_sdk/_generated_api_client/api/{classification_model/create_model_classification_model_post.py → regression_model/create_regression_model_gpu_regression_model_post.py} +27 -27
  31. orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
  32. orca_sdk/_generated_api_client/api/{classification_model/delete_model_classification_model_name_or_id_delete.py → regression_model/delete_regression_model_regression_model_name_or_id_delete.py} +5 -5
  33. orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +183 -0
  34. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +170 -0
  35. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +156 -0
  36. orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +161 -0
  37. orca_sdk/_generated_api_client/api/{classification_model/list_models_classification_model_get.py → regression_model/list_regression_models_regression_model_get.py} +17 -17
  38. orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +190 -0
  39. orca_sdk/_generated_api_client/api/{classification_model/update_model_classification_model_name_or_id_patch.py → regression_model/update_regression_model_regression_model_name_or_id_patch.py} +27 -27
  40. orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +156 -0
  41. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +35 -12
  42. orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +20 -12
  43. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +35 -12
  44. orca_sdk/_generated_api_client/models/__init__.py +84 -24
  45. orca_sdk/_generated_api_client/models/base_score_prediction_result.py +108 -0
  46. orca_sdk/_generated_api_client/models/{evaluation_request.py → classification_evaluation_request.py} +13 -45
  47. orca_sdk/_generated_api_client/models/{classification_evaluation_result.py → classification_metrics.py} +106 -56
  48. orca_sdk/_generated_api_client/models/{rac_model_metadata.py → classification_model_metadata.py} +51 -43
  49. orca_sdk/_generated_api_client/models/{prediction_request.py → classification_prediction_request.py} +31 -6
  50. orca_sdk/_generated_api_client/models/{clone_labeled_memoryset_request.py → clone_memoryset_request.py} +5 -5
  51. orca_sdk/_generated_api_client/models/column_info.py +31 -0
  52. orca_sdk/_generated_api_client/models/{create_rac_model_request.py → create_classification_model_request.py} +25 -57
  53. orca_sdk/_generated_api_client/models/{create_labeled_memoryset_request.py → create_memoryset_request.py} +73 -56
  54. orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +66 -0
  55. orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +13 -0
  56. orca_sdk/_generated_api_client/models/create_regression_model_request.py +137 -0
  57. orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +187 -0
  58. orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +10 -0
  59. orca_sdk/_generated_api_client/models/evaluation_response.py +22 -9
  60. orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +140 -0
  61. orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +140 -0
  62. orca_sdk/_generated_api_client/models/memory_type.py +9 -0
  63. orca_sdk/_generated_api_client/models/{labeled_memoryset_metadata.py → memoryset_metadata.py} +73 -13
  64. orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +55 -0
  65. orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +13 -0
  66. orca_sdk/_generated_api_client/models/{labeled_memoryset_update.py → memoryset_update.py} +19 -31
  67. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +1 -0
  68. orca_sdk/_generated_api_client/models/{paginated_labeled_memory_with_feedback_metrics.py → paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py} +37 -10
  69. orca_sdk/_generated_api_client/models/{precision_recall_curve.py → pr_curve.py} +5 -13
  70. orca_sdk/_generated_api_client/models/{rac_model_update.py → predictive_model_update.py} +14 -5
  71. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +11 -1
  72. orca_sdk/_generated_api_client/models/rar_head_type.py +8 -0
  73. orca_sdk/_generated_api_client/models/regression_evaluation_request.py +148 -0
  74. orca_sdk/_generated_api_client/models/regression_metrics.py +172 -0
  75. orca_sdk/_generated_api_client/models/regression_model_metadata.py +177 -0
  76. orca_sdk/_generated_api_client/models/regression_prediction_request.py +195 -0
  77. orca_sdk/_generated_api_client/models/roc_curve.py +0 -8
  78. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +196 -0
  79. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +68 -0
  80. orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +252 -0
  81. orca_sdk/_generated_api_client/models/scored_memory.py +172 -0
  82. orca_sdk/_generated_api_client/models/scored_memory_insert.py +128 -0
  83. orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +68 -0
  84. orca_sdk/_generated_api_client/models/scored_memory_lookup.py +180 -0
  85. orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +68 -0
  86. orca_sdk/_generated_api_client/models/scored_memory_metadata.py +68 -0
  87. orca_sdk/_generated_api_client/models/scored_memory_update.py +171 -0
  88. orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +68 -0
  89. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +193 -0
  90. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +68 -0
  91. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +68 -0
  92. orca_sdk/_generated_api_client/models/update_prediction_request.py +20 -0
  93. orca_sdk/_shared/__init__.py +9 -1
  94. orca_sdk/_shared/metrics.py +257 -87
  95. orca_sdk/_shared/metrics_test.py +136 -77
  96. orca_sdk/_utils/data_parsing.py +0 -3
  97. orca_sdk/_utils/data_parsing_test.py +0 -3
  98. orca_sdk/_utils/prediction_result_ui.py +55 -23
  99. orca_sdk/classification_model.py +183 -172
  100. orca_sdk/classification_model_test.py +147 -157
  101. orca_sdk/conftest.py +76 -26
  102. orca_sdk/datasource_test.py +0 -1
  103. orca_sdk/embedding_model.py +136 -14
  104. orca_sdk/embedding_model_test.py +10 -6
  105. orca_sdk/job.py +329 -0
  106. orca_sdk/job_test.py +48 -0
  107. orca_sdk/memoryset.py +882 -161
  108. orca_sdk/memoryset_test.py +56 -23
  109. orca_sdk/regression_model.py +647 -0
  110. orca_sdk/regression_model_test.py +337 -0
  111. orca_sdk/telemetry.py +223 -106
  112. orca_sdk/telemetry_test.py +34 -30
  113. {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.96.dist-info}/METADATA +2 -4
  114. {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.96.dist-info}/RECORD +115 -69
  115. orca_sdk/_utils/task.py +0 -73
  116. {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.96.dist-info}/WHEEL +0 -0
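The headline addition in 0.0.96 is a regression surface that mirrors the existing classification API: `orca_sdk/regression_model.py` introduces a `RegressionModel` backed by a `ScoredMemoryset`, with prediction, evaluation, and feedback methods (the full diff of the new module follows below). A minimal usage sketch, assuming a scored memoryset named "my_memoryset" already exists in your OrcaCloud account; the memoryset name, model name, and input string are hypothetical, and the submodule import paths are taken from the diff (the package may also re-export these names at the top level):

    from orca_sdk.memoryset import ScoredMemoryset
    from orca_sdk.regression_model import RegressionModel

    # open a handle to an existing scored memoryset (hypothetical name)
    memoryset = ScoredMemoryset.open("my_memoryset")

    # create the model, or reuse an existing model with the same name
    model = RegressionModel.create("review_scorer", memoryset, if_exists="open")

    # a single prediction returns a RegressionPrediction with score, confidence, and anomaly_score
    prediction = model.predict("Great service, would come again")
    print(prediction.score, prediction.confidence, prediction.anomaly_score)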
orca_sdk/regression_model.py (new file)
@@ -0,0 +1,647 @@
+from __future__ import annotations
+
+import logging
+import os
+from contextlib import contextmanager
+from datetime import datetime
+from typing import Any, Generator, Iterable, Literal, cast, overload
+from uuid import UUID
+
+import numpy as np
+from datasets import Dataset
+
+from ._generated_api_client.api import (
+    create_regression_model_gpu,
+    delete_regression_model,
+    evaluate_regression_model,
+    get_regression_model,
+    get_regression_model_evaluation,
+    list_predictions,
+    list_regression_models,
+    predict_score_gpu,
+    record_prediction_feedback,
+    update_regression_model,
+)
+from ._generated_api_client.models import (
+    CreateRegressionModelRequest,
+    ListPredictionsRequest,
+)
+from ._generated_api_client.models import (
+    PredictionSortItemItemType0 as PredictionSortColumns,
+)
+from ._generated_api_client.models import (
+    PredictionSortItemItemType1 as PredictionSortDirection,
+)
+from ._generated_api_client.models import (
+    PredictiveModelUpdate,
+    RARHeadType,
+    RegressionEvaluationRequest,
+    RegressionModelMetadata,
+    RegressionPredictionRequest,
+    ScorePredictionWithMemoriesAndFeedback,
+)
+from ._generated_api_client.types import UNSET as CLIENT_UNSET
+from ._generated_api_client.types import Response
+from ._shared.metrics import RegressionMetrics, calculate_regression_metrics
+from ._utils.common import UNSET, CreateMode, DropMode
+from .datasource import Datasource
+from .job import Job
+from .memoryset import ScoredMemoryset
+from .telemetry import FeedbackCategory, RegressionPrediction, _parse_feedback
+
+logger = logging.getLogger(__name__)
+
+
+class RegressionModel:
+    """
+    A handle to a regression model in the OrcaCloud
+
+    Attributes:
+        id: Unique identifier for the model
+        name: Unique name of the model
+        description: Optional description of the model
+        memoryset: Memoryset that the model uses
+        head_type: Regression head type of the model
+        memory_lookup_count: Number of memories the model uses for each prediction
+        locked: Whether the model is locked to prevent accidental deletion
+        created_at: When the model was created
+        updated_at: When the model was last updated
+    """
+
+    id: str
+    name: str
+    description: str | None
+    memoryset: ScoredMemoryset
+    head_type: RARHeadType
+    memory_lookup_count: int
+    version: int
+    locked: bool
+    created_at: datetime
+    updated_at: datetime
+    memoryset_id: str
+
+    _last_prediction: RegressionPrediction | None
+    _last_prediction_was_batch: bool
+    _memoryset_override_id: str | None
+
+    def __init__(self, metadata: RegressionModelMetadata):
+        # for internal use only, do not document
+        self.id = metadata.id
+        self.name = metadata.name
+        self.description = metadata.description
+        self.memoryset = ScoredMemoryset.open(metadata.memoryset_id)
+        self.head_type = metadata.head_type
+        self.memory_lookup_count = metadata.memory_lookup_count
+        self.version = metadata.version
+        self.locked = metadata.locked
+        self.created_at = metadata.created_at
+        self.updated_at = metadata.updated_at
+        self.memoryset_id = metadata.memoryset_id
+
+        self._memoryset_override_id = None
+        self._last_prediction = None
+        self._last_prediction_was_batch = False
+
+    def __eq__(self, other) -> bool:
+        return isinstance(other, RegressionModel) and self.id == other.id
+
+    def __repr__(self):
+        return (
+            "RegressionModel({\n"
+            f"    name: '{self.name}',\n"
+            f"    head_type: {self.head_type},\n"
+            f"    memory_lookup_count: {self.memory_lookup_count},\n"
+            f"    memoryset: ScoredMemoryset.open('{self.memoryset.name}'),\n"
+            "})"
+        )
+
+    @property
+    def last_prediction(self) -> RegressionPrediction:
+        """
+        Last prediction made by the model
+
+        Note:
+            If the last prediction was part of a batch prediction, the last prediction from the
+            batch is returned. If no prediction has been made yet, a [`LookupError`][LookupError]
+            is raised.
+        """
+        if self._last_prediction_was_batch:
+            logger.warning(
+                "Last prediction was part of a batch prediction, returning the last prediction from the batch"
+            )
+        if self._last_prediction is None:
+            raise LookupError("No prediction has been made yet")
+        return self._last_prediction
+
+    @classmethod
+    def create(
+        cls,
+        name: str,
+        memoryset: ScoredMemoryset,
+        memory_lookup_count: int | None = None,
+        description: str | None = None,
+        if_exists: CreateMode = "error",
+    ) -> RegressionModel:
+        """
+        Create a regression model.
+
+        Params:
+            name: Name of the model
+            memoryset: The scored memoryset to use for prediction
+            memory_lookup_count: Number of memories to retrieve for each prediction. Defaults to 10.
+            description: Description of the model
+            if_exists: How to handle existing models with the same name
+
+        Returns:
+            RegressionModel instance
+
+        Raises:
+            ValueError: If a model with the same name already exists and if_exists is "error"
+            ValueError: If the memoryset is empty
+            ValueError: If memory_lookup_count exceeds the number of memories in the memoryset
+        """
+        existing = cls.exists(name)
+        if existing:
+            if if_exists == "error":
+                raise ValueError(f"RegressionModel with name '{name}' already exists")
+            elif if_exists == "open":
+                existing = cls.open(name)
+                for attribute in {"memory_lookup_count"}:
+                    local_attribute = locals()[attribute]
+                    existing_attribute = getattr(existing, attribute)
+                    if local_attribute is not None and local_attribute != existing_attribute:
+                        raise ValueError(f"Model with name {name} already exists with different {attribute}")
+
+                # special case for memoryset
+                if existing.memoryset_id != memoryset.id:
+                    raise ValueError(f"Model with name {name} already exists with different memoryset")
+
+                return existing
+
+        metadata = create_regression_model_gpu(
+            body=CreateRegressionModelRequest(
+                name=name,
+                memoryset_id=memoryset.id,
+                memory_lookup_count=memory_lookup_count,
+                description=description,
+            )
+        )
+        return cls(metadata)
+
+    @classmethod
+    def open(cls, name: str) -> RegressionModel:
+        """
+        Get a handle to a regression model in the OrcaCloud
+
+        Params:
+            name: Name or unique identifier of the regression model
+
+        Returns:
+            Handle to the existing regression model in the OrcaCloud
+
+        Raises:
+            LookupError: If the regression model does not exist
+        """
+        return cls(get_regression_model(name))
+
+    @classmethod
+    def exists(cls, name_or_id: str) -> bool:
+        """
+        Check if a regression model exists in the OrcaCloud
+
+        Params:
+            name_or_id: Name or id of the regression model
+
+        Returns:
+            `True` if the regression model exists, `False` otherwise
+        """
+        try:
+            cls.open(name_or_id)
+            return True
+        except LookupError:
+            return False
+
+    @classmethod
+    def all(cls) -> list[RegressionModel]:
+        """
+        Get a list of handles to all regression models in the OrcaCloud
+
+        Returns:
+            List of handles to all regression models in the OrcaCloud
+        """
+        return [cls(metadata) for metadata in list_regression_models()]
+
+    @classmethod
+    def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
+        """
+        Delete a regression model from the OrcaCloud
+
+        Warning:
+            This will delete the model and all associated data, including predictions, evaluations, and feedback.
+
+        Params:
+            name_or_id: Name or id of the regression model
+            if_not_exists: What to do if the regression model does not exist, defaults to `"error"`.
+                The other option is `"ignore"`, which does nothing if the regression model does not exist.
+
+        Raises:
+            LookupError: If the regression model does not exist and if_not_exists is `"error"`
+        """
+        try:
+            delete_regression_model(name_or_id)
+            logger.info(f"Deleted model {name_or_id}")
+        except LookupError:
+            if if_not_exists == "error":
+                raise
+
+    def refresh(self):
+        """Refresh the model data from the OrcaCloud"""
+        self.__dict__.update(self.open(self.name).__dict__)
+
+    def set(self, *, description: str | None = UNSET, locked: bool = UNSET) -> None:
+        """
+        Update editable attributes of the model.
+
+        Note:
+            If a field is not provided, it will default to [UNSET][orca_sdk.UNSET] and not be updated.
+
+        Params:
+            description: Value to set for the description
+            locked: Value to set for the locked status
+
+        Examples:
+            Update the description:
+            >>> model.set(description="New description")
+
+            Remove the description:
+            >>> model.set(description=None)
+
+            Lock the model:
+            >>> model.set(locked=True)
+        """
+        update_data = PredictiveModelUpdate(
+            description=CLIENT_UNSET if description is UNSET else description,
+            locked=CLIENT_UNSET if locked is UNSET else locked,
+        )
+        update_regression_model(self.id, body=update_data)
+        self.refresh()
+
+    def lock(self) -> None:
+        """Lock the model to prevent accidental deletion"""
+        self.set(locked=True)
+
+    def unlock(self) -> None:
+        """Unlock the model to allow deletion"""
+        self.set(locked=False)
+
+    @overload
+    def predict(
+        self,
+        value: str,
+        expected_scores: float | None = None,
+        tags: set[str] | None = None,
+        save_telemetry: Literal["off", "on", "sync", "async"] = "on",
+    ) -> RegressionPrediction: ...
+
+    @overload
+    def predict(
+        self,
+        value: list[str],
+        expected_scores: list[float] | None = None,
+        tags: set[str] | None = None,
+        save_telemetry: Literal["off", "on", "sync", "async"] = "on",
+    ) -> list[RegressionPrediction]: ...
+
+    # TODO: add filter support
+    def predict(
+        self,
+        value: str | list[str],
+        expected_scores: float | list[float] | None = None,
+        tags: set[str] | None = None,
+        save_telemetry: Literal["off", "on", "sync", "async"] = "on",
+    ) -> RegressionPrediction | list[RegressionPrediction]:
+        """
+        Make predictions using the regression model.
+
+        Params:
+            value: Input text(s) to predict scores for
+            expected_scores: Expected score(s) for telemetry tracking
+            tags: Tags to associate with the prediction(s)
+            save_telemetry: Whether to save telemetry for the prediction(s). Defaults to `"on"`,
+                which saves telemetry asynchronously unless the `ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY`
+                environment variable is set. You can also pass `"sync"` or `"async"` to
+                explicitly set the save mode, or `"off"` to disable telemetry.
+
+        Returns:
+            Single RegressionPrediction or list of RegressionPrediction objects
+
+        Raises:
+            ValueError: If expected_scores length doesn't match value length for batch predictions
+        """
+        response = predict_score_gpu(
+            name_or_id=self.name,
+            body=RegressionPredictionRequest(
+                input_values=value if isinstance(value, list) else [value],
+                expected_scores=(
+                    expected_scores
+                    if isinstance(expected_scores, list)
+                    else [expected_scores] if expected_scores is not None else None
+                ),
+                memoryset_override_id=self._memoryset_override_id,
+                tags=list(tags or set()),
+                save_telemetry=save_telemetry != "off",
+                save_telemetry_synchronously=(
+                    os.getenv("ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY", "0") != "0" or save_telemetry == "sync"
+                ),
+            ),
+        )
+
+        if save_telemetry != "off" and any(p.prediction_id is None for p in response):
+            raise RuntimeError("Failed to save prediction to database.")
+
+        predictions = [
+            RegressionPrediction(
+                prediction_id=prediction.prediction_id,
+                label=None,
+                label_name=None,
+                score=prediction.score,
+                confidence=prediction.confidence,
+                anomaly_score=prediction.anomaly_score,
+                memoryset=self.memoryset,
+                model=self,
+                logits=None,
+                input_value=input_value,
+            )
+            for prediction, input_value in zip(response, value if isinstance(value, list) else [value])
+        ]
+        self._last_prediction_was_batch = isinstance(value, list)
+        self._last_prediction = predictions[-1]
+        return predictions if isinstance(value, list) else predictions[0]
+
+    def predictions(
+        self,
+        limit: int = 100,
+        offset: int = 0,
+        tag: str | None = None,
+        sort: list[tuple[PredictionSortColumns, PredictionSortDirection]] = [],
+    ) -> list[RegressionPrediction]:
+        """
+        Get a list of predictions made by this model
+
+        Params:
+            limit: Optional maximum number of predictions to return
+            offset: Optional offset of the first prediction to return
+            tag: Optional tag to filter predictions by
+            sort: Optional list of columns and directions to sort the predictions by.
+                Predictions can be sorted by `created_at`, `confidence`, `anomaly_score`, or `score`.
+
+        Returns:
+            List of score predictions
+
+        Examples:
+            Get the last 3 predictions:
+            >>> predictions = model.predictions(limit=3, sort=[("created_at", "desc")])
+            [
+                RegressionPrediction({score: 4.5, confidence: 0.95, anomaly_score: 0.1, input_value: 'Great service'}),
+                RegressionPrediction({score: 2.0, confidence: 0.90, anomaly_score: 0.1, input_value: 'Poor experience'}),
+                RegressionPrediction({score: 3.5, confidence: 0.85, anomaly_score: 0.1, input_value: 'Average'}),
+            ]
+
+            Get the second most confident prediction:
+            >>> predictions = model.predictions(sort=[("confidence", "desc")], offset=1, limit=1)
+            [RegressionPrediction({score: 4.2, confidence: 0.90, anomaly_score: 0.1, input_value: 'Good service'})]
+        """
+        predictions = list_predictions(
+            body=ListPredictionsRequest(
+                model_id=self.id,
+                limit=limit,
+                offset=offset,
+                sort=cast(list[list[PredictionSortColumns | PredictionSortDirection]], sort),
+                tag=tag,
+            ),
+        )
+        return [
+            RegressionPrediction(
+                prediction_id=prediction.prediction_id,
+                label=None,
+                label_name=None,
+                score=prediction.score,
+                confidence=prediction.confidence,
+                anomaly_score=prediction.anomaly_score,
+                memoryset=self.memoryset,
+                model=self,
+                telemetry=prediction,
+                logits=None,
+                input_value=None,
+            )
+            for prediction in predictions
+            if isinstance(prediction, ScorePredictionWithMemoriesAndFeedback)
+        ]
+
+    def _evaluate_datasource(
+        self,
+        datasource: Datasource,
+        value_column: str,
+        score_column: str,
+        record_predictions: bool,
+        tags: set[str] | None,
+        background: bool = False,
+    ) -> RegressionMetrics | Job[RegressionMetrics]:
+        response = evaluate_regression_model(
+            self.id,
+            body=RegressionEvaluationRequest(
+                datasource_id=datasource.id,
+                datasource_score_column=score_column,
+                datasource_value_column=value_column,
+                memoryset_override_id=self._memoryset_override_id,
+                record_telemetry=record_predictions,
+                telemetry_tags=list(tags) if tags else None,
+            ),
+        )
+
+        job = Job(
+            response.task_id,
+            lambda: (r := get_regression_model_evaluation(self.id, UUID(response.task_id)).result)
+            and RegressionMetrics(**r.to_dict()),
+        )
+        return job if background else job.result()
+
+    def _evaluate_dataset(
+        self,
+        dataset: Dataset,
+        value_column: str,
+        score_column: str,
+        record_predictions: bool,
+        tags: set[str],
+        batch_size: int,
+    ) -> RegressionMetrics:
+        predictions = [
+            prediction
+            for i in range(0, len(dataset), batch_size)
+            for prediction in self.predict(
+                dataset[i : i + batch_size][value_column],
+                expected_scores=dataset[i : i + batch_size][score_column],
+                tags=tags,
+                save_telemetry="sync" if record_predictions else "off",
+            )
+        ]
+
+        return calculate_regression_metrics(
+            expected_scores=dataset[score_column],
+            predicted_scores=[p.score for p in predictions],
+            anomaly_scores=[p.anomaly_score for p in predictions],
+        )
+
+    @overload
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
+        *,
+        value_column: str = "value",
+        score_column: str = "score",
+        record_predictions: bool = False,
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
+        background: Literal[True],
+    ) -> Job[RegressionMetrics]:
+        pass
+
+    @overload
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
+        *,
+        value_column: str = "value",
+        score_column: str = "score",
+        record_predictions: bool = False,
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
+        background: Literal[False] = False,
+    ) -> RegressionMetrics:
+        pass
+
+    def evaluate(
+        self,
+        data: Datasource | Dataset,
+        *,
+        value_column: str = "value",
+        score_column: str = "score",
+        record_predictions: bool = False,
+        tags: set[str] = {"evaluation"},
+        batch_size: int = 100,
+        background: bool = False,
+    ) -> RegressionMetrics | Job[RegressionMetrics]:
+        """
+        Evaluate the regression model on a given dataset or datasource
+
+        Params:
+            data: Dataset or Datasource to evaluate the model on
+            value_column: Name of the column that contains the input values to the model
+            score_column: Name of the column containing the expected scores
+            record_predictions: Whether to record [`RegressionPrediction`][orca_sdk.telemetry.RegressionPrediction]s for analysis
+            tags: Optional tags to add to the recorded [`RegressionPrediction`][orca_sdk.telemetry.RegressionPrediction]s
+            batch_size: Batch size for processing Dataset inputs (only used when the input is a Dataset)
+            background: Whether to run the operation in the background and return a job handle
+
+        Returns:
+            RegressionMetrics containing metrics including MAE, MSE, RMSE, R2, and anomaly score statistics
+
+        Examples:
+            >>> model.evaluate(datasource, value_column="text", score_column="rating")
+            RegressionMetrics({
+                mae: 0.2500,
+                rmse: 0.3536,
+                r2: 0.8500,
+                anomaly_score: 0.3500 ± 0.0500,
+            })
+        """
+        if isinstance(data, Datasource):
+            return self._evaluate_datasource(
+                datasource=data,
+                value_column=value_column,
+                score_column=score_column,
+                record_predictions=record_predictions,
+                tags=tags,
+                background=background,
+            )
+        elif isinstance(data, Dataset):
+            return self._evaluate_dataset(
+                dataset=data,
+                value_column=value_column,
+                score_column=score_column,
+                record_predictions=record_predictions,
+                tags=tags,
+                batch_size=batch_size,
+            )
+        else:
+            raise ValueError(f"Invalid data type: {type(data)}")
+
+    @contextmanager
+    def use_memoryset(self, memoryset_override: ScoredMemoryset) -> Generator[None, None, None]:
+        """
+        Temporarily override the memoryset used by the model for predictions
+
+        Params:
+            memoryset_override: Memoryset to override the default memoryset with
+
+        Examples:
+            >>> with model.use_memoryset(ScoredMemoryset.open("my_other_memoryset")):
+            ...     predictions = model.predict("Rate your experience")
+        """
+        self._memoryset_override_id = memoryset_override.id
+        yield
+        self._memoryset_override_id = None
+
+    @overload
+    def record_feedback(self, feedback: dict[str, Any]) -> None:
+        pass
+
+    @overload
+    def record_feedback(self, feedback: Iterable[dict[str, Any]]) -> None:
+        pass
+
+    def record_feedback(self, feedback: Iterable[dict[str, Any]] | dict[str, Any]):
+        """
+        Record feedback for one or more predictions.
+
+        We support recording feedback in several categories for each prediction. A
+        [`FeedbackCategory`][orca_sdk.telemetry.FeedbackCategory] is created automatically
+        the first time feedback with a new name is recorded. Categories are global across models.
+        The value type of the category is inferred from the first recorded value. Subsequent
+        feedback for the same category must be of the same type.
+
+        Params:
+            feedback: Feedback to record, as a dictionary or iterable of dictionaries with the following keys:
+
+                - `prediction`: Unique identifier of the prediction to record the feedback for.
+                - `category`: Name of the category under which to record the feedback.
+                - `value`: Feedback value to record, `True` for positive and `False` for negative
+                    feedback, or a [`float`][float] between `-1.0` and `+1.0` where negative values
+                    indicate negative feedback and positive values indicate positive feedback.
+                - `comment`: Optional comment to record with the feedback.
+
+        Examples:
+            Record whether predictions were accurate:
+            >>> model.record_feedback({
+            ...     "prediction": p.prediction_id,
+            ...     "category": "accurate",
+            ...     "value": abs(p.score - p.expected_score) < 0.5,
+            ... } for p in predictions)
+
+            Record a star rating as a normalized continuous score between `-1.0` and `+1.0`:
+            >>> model.record_feedback({
+            ...     "prediction": "123e4567-e89b-12d3-a456-426614174000",
+            ...     "category": "rating",
+            ...     "value": -0.5,
+            ...     "comment": "2 stars"
+            ... })
+
+        Raises:
+            ValueError: If the value does not match previous value types for the category, or is a
+                [`float`][float] that is not between `-1.0` and `+1.0`.
+        """
+        record_prediction_feedback(
+            body=[
+                _parse_feedback(f) for f in (cast(list[dict], [feedback]) if isinstance(feedback, dict) else feedback)
+            ],
+        )