orca-sdk 0.0.94__py3-none-any.whl → 0.0.96__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. orca_sdk/__init__.py +13 -4
  2. orca_sdk/_generated_api_client/api/__init__.py +80 -34
  3. orca_sdk/_generated_api_client/api/classification_model/create_classification_model_gpu_classification_model_post.py +170 -0
  4. orca_sdk/_generated_api_client/api/classification_model/{get_model_classification_model_name_or_id_get.py → delete_classification_model_classification_model_name_or_id_delete.py} +20 -20
  5. orca_sdk/_generated_api_client/api/classification_model/{delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py → delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py} +4 -4
  6. orca_sdk/_generated_api_client/api/classification_model/{create_evaluation_classification_model_model_name_or_id_evaluation_post.py → evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py} +14 -14
  7. orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +156 -0
  8. orca_sdk/_generated_api_client/api/classification_model/{get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py → get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py} +16 -16
  9. orca_sdk/_generated_api_client/api/classification_model/{list_evaluations_classification_model_model_name_or_id_evaluation_get.py → list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py} +16 -16
  10. orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +127 -0
  11. orca_sdk/_generated_api_client/api/classification_model/{predict_gpu_classification_model_name_or_id_prediction_post.py → predict_label_gpu_classification_model_name_or_id_prediction_post.py} +14 -14
  12. orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +183 -0
  13. orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +24 -0
  14. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +22 -22
  15. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +22 -22
  16. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +38 -16
  17. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +29 -12
  18. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +12 -12
  19. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +17 -14
  20. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +72 -19
  21. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +31 -12
  22. orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +49 -20
  23. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +38 -16
  24. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +54 -29
  25. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +44 -26
  26. orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +22 -22
  27. orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
  28. orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +150 -0
  29. orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
  30. orca_sdk/_generated_api_client/api/{classification_model/create_model_classification_model_post.py → regression_model/create_regression_model_gpu_regression_model_post.py} +27 -27
  31. orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
  32. orca_sdk/_generated_api_client/api/{classification_model/delete_model_classification_model_name_or_id_delete.py → regression_model/delete_regression_model_regression_model_name_or_id_delete.py} +5 -5
  33. orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +183 -0
  34. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +170 -0
  35. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +156 -0
  36. orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +161 -0
  37. orca_sdk/_generated_api_client/api/{classification_model/list_models_classification_model_get.py → regression_model/list_regression_models_regression_model_get.py} +17 -17
  38. orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +190 -0
  39. orca_sdk/_generated_api_client/api/{classification_model/update_model_classification_model_name_or_id_patch.py → regression_model/update_regression_model_regression_model_name_or_id_patch.py} +27 -27
  40. orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +156 -0
  41. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +35 -12
  42. orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +20 -12
  43. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +35 -12
  44. orca_sdk/_generated_api_client/models/__init__.py +84 -24
  45. orca_sdk/_generated_api_client/models/base_score_prediction_result.py +108 -0
  46. orca_sdk/_generated_api_client/models/{evaluation_request.py → classification_evaluation_request.py} +13 -45
  47. orca_sdk/_generated_api_client/models/{classification_evaluation_result.py → classification_metrics.py} +106 -56
  48. orca_sdk/_generated_api_client/models/{rac_model_metadata.py → classification_model_metadata.py} +51 -43
  49. orca_sdk/_generated_api_client/models/{prediction_request.py → classification_prediction_request.py} +31 -6
  50. orca_sdk/_generated_api_client/models/{clone_labeled_memoryset_request.py → clone_memoryset_request.py} +5 -5
  51. orca_sdk/_generated_api_client/models/column_info.py +31 -0
  52. orca_sdk/_generated_api_client/models/{create_rac_model_request.py → create_classification_model_request.py} +25 -57
  53. orca_sdk/_generated_api_client/models/{create_labeled_memoryset_request.py → create_memoryset_request.py} +73 -56
  54. orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +66 -0
  55. orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +13 -0
  56. orca_sdk/_generated_api_client/models/create_regression_model_request.py +137 -0
  57. orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +187 -0
  58. orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +10 -0
  59. orca_sdk/_generated_api_client/models/evaluation_response.py +22 -9
  60. orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +140 -0
  61. orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +140 -0
  62. orca_sdk/_generated_api_client/models/memory_type.py +9 -0
  63. orca_sdk/_generated_api_client/models/{labeled_memoryset_metadata.py → memoryset_metadata.py} +73 -13
  64. orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +55 -0
  65. orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +13 -0
  66. orca_sdk/_generated_api_client/models/{labeled_memoryset_update.py → memoryset_update.py} +19 -31
  67. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +1 -0
  68. orca_sdk/_generated_api_client/models/{paginated_labeled_memory_with_feedback_metrics.py → paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py} +37 -10
  69. orca_sdk/_generated_api_client/models/{precision_recall_curve.py → pr_curve.py} +5 -13
  70. orca_sdk/_generated_api_client/models/{rac_model_update.py → predictive_model_update.py} +14 -5
  71. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +11 -1
  72. orca_sdk/_generated_api_client/models/rar_head_type.py +8 -0
  73. orca_sdk/_generated_api_client/models/regression_evaluation_request.py +148 -0
  74. orca_sdk/_generated_api_client/models/regression_metrics.py +172 -0
  75. orca_sdk/_generated_api_client/models/regression_model_metadata.py +177 -0
  76. orca_sdk/_generated_api_client/models/regression_prediction_request.py +195 -0
  77. orca_sdk/_generated_api_client/models/roc_curve.py +0 -8
  78. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +196 -0
  79. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +68 -0
  80. orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +252 -0
  81. orca_sdk/_generated_api_client/models/scored_memory.py +172 -0
  82. orca_sdk/_generated_api_client/models/scored_memory_insert.py +128 -0
  83. orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +68 -0
  84. orca_sdk/_generated_api_client/models/scored_memory_lookup.py +180 -0
  85. orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +68 -0
  86. orca_sdk/_generated_api_client/models/scored_memory_metadata.py +68 -0
  87. orca_sdk/_generated_api_client/models/scored_memory_update.py +171 -0
  88. orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +68 -0
  89. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +193 -0
  90. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +68 -0
  91. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +68 -0
  92. orca_sdk/_generated_api_client/models/update_prediction_request.py +20 -0
  93. orca_sdk/_shared/__init__.py +9 -1
  94. orca_sdk/_shared/metrics.py +257 -87
  95. orca_sdk/_shared/metrics_test.py +136 -77
  96. orca_sdk/_utils/data_parsing.py +0 -3
  97. orca_sdk/_utils/data_parsing_test.py +0 -3
  98. orca_sdk/_utils/prediction_result_ui.py +55 -23
  99. orca_sdk/classification_model.py +183 -172
  100. orca_sdk/classification_model_test.py +147 -157
  101. orca_sdk/conftest.py +76 -26
  102. orca_sdk/datasource_test.py +0 -1
  103. orca_sdk/embedding_model.py +136 -14
  104. orca_sdk/embedding_model_test.py +10 -6
  105. orca_sdk/job.py +329 -0
  106. orca_sdk/job_test.py +48 -0
  107. orca_sdk/memoryset.py +882 -161
  108. orca_sdk/memoryset_test.py +56 -23
  109. orca_sdk/regression_model.py +647 -0
  110. orca_sdk/regression_model_test.py +337 -0
  111. orca_sdk/telemetry.py +223 -106
  112. orca_sdk/telemetry_test.py +34 -30
  113. {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.96.dist-info}/METADATA +2 -4
  114. {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.96.dist-info}/RECORD +115 -69
  115. orca_sdk/_utils/task.py +0 -73
  116. {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.96.dist-info}/WHEEL +0 -0
orca_sdk/telemetry.py CHANGED
@@ -1,8 +1,9 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
+ from abc import ABC
4
5
  from datetime import datetime
5
- from typing import TYPE_CHECKING, Any, Generator, Iterable, overload
6
+ from typing import TYPE_CHECKING, Any, Generator, Iterable, Self, overload
6
7
  from uuid import UUID
7
8
 
8
9
  from orca_sdk._utils.common import UNSET
@@ -23,14 +24,21 @@ from ._generated_api_client.models import (
23
24
  ListPredictionsRequest,
24
25
  PredictionFeedbackCategory,
25
26
  PredictionFeedbackRequest,
27
+ ScorePredictionWithMemoriesAndFeedback,
26
28
  UpdatePredictionRequest,
27
29
  )
28
30
  from ._generated_api_client.types import UNSET as CLIENT_UNSET
29
31
  from ._utils.prediction_result_ui import inspect_prediction_result
30
- from .memoryset import LabeledMemoryLookup, LabeledMemoryset
32
+ from .memoryset import (
33
+ LabeledMemoryLookup,
34
+ LabeledMemoryset,
35
+ ScoredMemoryLookup,
36
+ ScoredMemoryset,
37
+ )
31
38
 
32
39
  if TYPE_CHECKING:
33
40
  from .classification_model import ClassificationModel
41
+ from .regression_model import RegressionModel
34
42
 
35
43
 
36
44
  def _parse_feedback(feedback: dict[str, Any]) -> PredictionFeedbackRequest:
@@ -107,91 +115,62 @@ class FeedbackCategory:
107
115
  return "FeedbackCategory({" + f"name: {self.name}, " + f"value_type: {self.value_type}" + "})"
108
116
 
109
117
 
110
- class LabelPrediction:
111
- """
112
- A prediction made by a model
113
-
114
- Attributes:
115
- prediction_id: Unique identifier for the prediction
116
- label: Predicted label for the input value
117
- label_name: Name of the predicted label
118
- confidence: Confidence of the prediction
119
- anomaly_score: The score for how anomalous the input is relative to the memories
120
- memory_lookups: List of memories used to ground the prediction
121
- input_value: Input value that this prediction was for
122
- model: Model that was used to make the prediction
123
- memoryset: Memoryset that was used to lookup memories to ground the prediction
124
- expected_label: Optional expected label that was set for the prediction
125
- expected_label_name: Name of the expected label
126
- tags: tags that were set for the prediction
127
- feedback: Feedback recorded, mapping from category name to value
128
- explanation: Explanation why the model made the prediction generated by a reasoning agent
129
- """
130
-
118
+ class _Prediction(ABC):
131
119
  prediction_id: str | None
132
- label: int
133
- label_name: str | None
134
120
  confidence: float
135
121
  anomaly_score: float | None
136
- memoryset: LabeledMemoryset
137
- model: ClassificationModel
138
- logits: list[float] | None
139
122
 
140
123
  def __init__(
141
124
  self,
142
125
  prediction_id: str | None,
143
126
  *,
144
- label: int,
127
+ label: int | None,
145
128
  label_name: str | None,
129
+ score: float | None,
146
130
  confidence: float,
147
131
  anomaly_score: float | None,
148
- memoryset: LabeledMemoryset | str,
149
- model: ClassificationModel | str,
150
- telemetry: LabelPredictionWithMemoriesAndFeedback | None = None,
132
+ memoryset: LabeledMemoryset | ScoredMemoryset,
133
+ model: ClassificationModel | RegressionModel,
134
+ telemetry: LabelPredictionWithMemoriesAndFeedback | ScorePredictionWithMemoriesAndFeedback | None = None,
151
135
  logits: list[float] | None = None,
152
- input_value: str | list[list[float]] | None = None,
136
+ input_value: str | None = None,
153
137
  ):
154
- # for internal use only, do not document
155
- from .classification_model import ClassificationModel
156
-
157
138
  self.prediction_id = prediction_id
158
139
  self.label = label
159
140
  self.label_name = label_name
141
+ self.score = score
160
142
  self.confidence = confidence
161
143
  self.anomaly_score = anomaly_score
162
- self.memoryset = LabeledMemoryset.open(memoryset) if isinstance(memoryset, str) else memoryset
163
- self.model = ClassificationModel.open(model) if isinstance(model, str) else model
144
+ self.memoryset = memoryset
145
+ self.model = model
164
146
  self.__telemetry = telemetry if telemetry else None
165
147
  self.logits = logits
166
148
  self._input_value = input_value
167
149
 
168
- def __repr__(self):
169
- return (
170
- "LabelPrediction({"
171
- + f"label: <{self.label_name}: {self.label}>, "
172
- + f"confidence: {self.confidence:.2f}, "
173
- + (f"anomaly_score: {self.anomaly_score:.2f}, " if self.anomaly_score is not None else "")
174
- + f"input_value: '{str(self.input_value)[:100] + '...' if len(str(self.input_value)) > 100 else self.input_value}'"
175
- + "})"
176
- )
177
-
178
150
  @property
179
- def _telemetry(self) -> LabelPredictionWithMemoriesAndFeedback:
151
+ def _telemetry(self) -> LabelPredictionWithMemoriesAndFeedback | ScorePredictionWithMemoriesAndFeedback:
180
152
  # for internal use only, do not document
181
153
  if self.__telemetry is None:
182
154
  self.__telemetry = get_prediction(prediction_id=UUID(self.prediction_id))
183
155
  return self.__telemetry
184
156
 
185
157
  @property
186
- def memory_lookups(self) -> list[LabeledMemoryLookup]:
187
- return [LabeledMemoryLookup(self.memoryset.id, lookup) for lookup in self._telemetry.memories]
188
-
189
- @property
190
- def input_value(self) -> str | list[list[float]] | None:
158
+ def input_value(self) -> str:
191
159
  if self._input_value is not None:
192
160
  return self._input_value
161
+ assert isinstance(self._telemetry.input_value, str)
193
162
  return self._telemetry.input_value
194
163
 
164
+ @property
165
+ def memory_lookups(self) -> list[LabeledMemoryLookup] | list[ScoredMemoryLookup]:
166
+ match self._telemetry:
167
+ case LabelPredictionWithMemoriesAndFeedback():
168
+ return [
169
+ LabeledMemoryLookup(self._telemetry.memoryset_id, lookup) for lookup in self._telemetry.memories
170
+ ]
171
+ case ScorePredictionWithMemoriesAndFeedback():
172
+ return [ScoredMemoryLookup(self._telemetry.memoryset_id, lookup) for lookup in self._telemetry.memories]
173
+
195
174
  @property
196
175
  def feedback(self) -> dict[str, bool | float]:
197
176
  return {
@@ -201,14 +180,6 @@ class LabelPrediction:
201
180
  for f in self._telemetry.feedbacks
202
181
  }
203
182
 
204
- @property
205
- def expected_label(self) -> int | None:
206
- return self._telemetry.expected_label
207
-
208
- @property
209
- def expected_label_name(self) -> str | None:
210
- return self._telemetry.expected_label_name
211
-
212
183
  @property
213
184
  def tags(self) -> set[str]:
214
185
  return set(self._telemetry.tags)
@@ -246,16 +217,16 @@ class LabelPrediction:
246
217
 
247
218
  @overload
248
219
  @classmethod
249
- def get(cls, prediction_id: str) -> LabelPrediction: # type: ignore -- this takes precedence
220
+ def get(cls, prediction_id: str) -> Self: # type: ignore -- this takes precedence
250
221
  pass
251
222
 
252
223
  @overload
253
224
  @classmethod
254
- def get(cls, prediction_id: Iterable[str]) -> list[LabelPrediction]:
225
+ def get(cls, prediction_id: Iterable[str]) -> list[Self]:
255
226
  pass
256
227
 
257
228
  @classmethod
258
- def get(cls, prediction_id: str | Iterable[str]) -> LabelPrediction | list[LabelPrediction]:
229
+ def get(cls, prediction_id: str | Iterable[str]) -> Self | list[Self]:
259
230
  """
260
231
  Fetch a prediction or predictions
261
232
 
@@ -303,30 +274,40 @@ class LabelPrediction:
303
274
  }),
304
275
  ]
305
276
  """
306
- if isinstance(prediction_id, str):
307
- prediction = get_prediction(prediction_id=UUID(prediction_id))
277
+ from .classification_model import ClassificationModel
278
+ from .regression_model import RegressionModel
279
+
280
+ def create_prediction(
281
+ prediction: LabelPredictionWithMemoriesAndFeedback | ScorePredictionWithMemoriesAndFeedback,
282
+ ) -> Self:
283
+
284
+ if isinstance(prediction, LabelPredictionWithMemoriesAndFeedback):
285
+ memoryset = LabeledMemoryset.open(prediction.memoryset_id)
286
+ model = ClassificationModel.open(prediction.model_id)
287
+ elif isinstance(prediction, ScorePredictionWithMemoriesAndFeedback):
288
+ memoryset = ScoredMemoryset.open(prediction.memoryset_id)
289
+ model = RegressionModel.open(prediction.model_id)
290
+ else:
291
+ raise ValueError(f"Invalid prediction type: {type(prediction)}")
292
+
308
293
  return cls(
309
294
  prediction_id=prediction.prediction_id,
310
- label=prediction.label,
311
- label_name=prediction.label_name,
295
+ label=getattr(prediction, "label", None),
296
+ label_name=getattr(prediction, "label_name", None),
297
+ score=getattr(prediction, "score", None),
312
298
  confidence=prediction.confidence,
313
299
  anomaly_score=prediction.anomaly_score,
314
- memoryset=prediction.memoryset_id,
315
- model=prediction.model_id,
300
+ memoryset=memoryset,
301
+ model=model,
316
302
  telemetry=prediction,
317
303
  )
304
+
305
+ if isinstance(prediction_id, str):
306
+ prediction = get_prediction(prediction_id=UUID(prediction_id))
307
+ return create_prediction(prediction)
318
308
  else:
319
309
  return [
320
- cls(
321
- prediction_id=prediction.prediction_id,
322
- label=prediction.label,
323
- label_name=prediction.label_name,
324
- confidence=prediction.confidence,
325
- anomaly_score=prediction.anomaly_score,
326
- memoryset=prediction.memoryset_id,
327
- model=prediction.model_id,
328
- telemetry=prediction,
329
- )
310
+ create_prediction(prediction)
330
311
  for prediction in list_predictions(body=ListPredictionsRequest(prediction_ids=list(prediction_id)))
331
312
  ]
332
313
 
@@ -334,30 +315,15 @@ class LabelPrediction:
334
315
  """Refresh the prediction data from the OrcaCloud"""
335
316
  if self.prediction_id is None:
336
317
  raise ValueError("Cannot refresh prediction with no prediction ID")
337
- self.__dict__.update(LabelPrediction.get(self.prediction_id).__dict__)
338
-
339
- def inspect(self):
340
- """Open a UI to inspect the memories used by this prediction"""
341
- inspect_prediction_result(self)
342
-
343
- def update(self, *, expected_label: int | None = UNSET, tags: set[str] | None = UNSET) -> None:
344
- """
345
- Update editable prediction properties.
346
-
347
- Params:
348
- expected_label: Value to set for the expected label, defaults to `[UNSET]` if not provided.
349
- tags: Value to replace existing tags with, defaults to `[UNSET]` if not provided.
318
+ self.__dict__.update(self.get(self.prediction_id).__dict__)
350
319
 
351
- Examples:
352
- Update the expected label:
353
- >>> prediction.update(expected_label=1)
354
-
355
- Add a new tag:
356
- >>> prediction.update(tags=prediction.tags | {"new_tag"})
357
-
358
- Remove expected label and tags:
359
- >>> prediction.update(expected_label=None, tags=None)
360
- """
320
+ def _update(
321
+ self,
322
+ *,
323
+ tags: set[str] | None = UNSET,
324
+ expected_label: int | None = UNSET,
325
+ expected_score: float | None = UNSET,
326
+ ) -> None:
361
327
  if self.prediction_id is None:
362
328
  raise ValueError("Cannot update prediction with no prediction ID")
363
329
 
@@ -365,6 +331,7 @@ class LabelPrediction:
365
331
  prediction_id=self.prediction_id,
366
332
  body=UpdatePredictionRequest(
367
333
  expected_label=expected_label if expected_label is not UNSET else CLIENT_UNSET,
334
+ expected_score=expected_score if expected_score is not UNSET else CLIENT_UNSET,
368
335
  tags=[] if tags is None else list(tags) if tags is not UNSET else CLIENT_UNSET,
369
336
  ),
370
337
  )
@@ -377,7 +344,7 @@ class LabelPrediction:
377
344
  Params:
378
345
  tag: Tag to add to the prediction
379
346
  """
380
- self.update(tags=self.tags | {tag})
347
+ self._update(tags=self.tags | {tag})
381
348
 
382
349
  def remove_tag(self, tag: str) -> None:
383
350
  """
@@ -386,7 +353,7 @@ class LabelPrediction:
386
353
  Params:
387
354
  tag: Tag to remove from the prediction
388
355
  """
389
- self.update(tags=self.tags - {tag})
356
+ self._update(tags=self.tags - {tag})
390
357
 
391
358
  def record_feedback(
392
359
  self,
@@ -448,3 +415,153 @@ class LabelPrediction:
448
415
  body=[PredictionFeedbackRequest(prediction_id=self.prediction_id, category_name=category, value=None)]
449
416
  )
450
417
  self.refresh()
418
+
419
+ def inspect(self) -> None:
420
+ """
421
+ Display an interactive UI with the details about this prediction
422
+
423
+ Params:
424
+ **kwargs: Additional keyword arguments to pass to the display function
425
+
426
+ Note:
427
+ This method is only available in Jupyter notebooks.
428
+ """
429
+ inspect_prediction_result(self) # type: ignore
430
+
431
+
432
+ class ClassificationPrediction(_Prediction):
433
+ """
434
+ Labeled prediction result from a [`ClassificationModel`][orca_sdk.ClassificationModel]
435
+
436
+ Attributes:
437
+ prediction_id: Unique identifier of this prediction used for feedback
438
+ label: Label predicted by the model
439
+ label_name: Human-readable name of the label
440
+ confidence: Confidence of the prediction
441
+ anomaly_score: Anomaly score of the input
442
+ input_value: The input value used for the prediction
443
+ expected_label: Expected label for the prediction, useful when evaluating the model
444
+ expected_label_name: Human-readable name of the expected label
445
+ memory_lookups: Memories used by the model to make the prediction
446
+ explanation: Natural language explanation of the prediction, only available if the model
447
+ has the Explain API enabled
448
+ tags: Tags for the prediction, useful for filtering and grouping predictions
449
+ model: Model used to make the prediction
450
+ memoryset: Memoryset that was used to lookup memories to ground the prediction
451
+ """
452
+
453
+ label: int
454
+ label_name: str
455
+ logits: list[float] | None
456
+ model: ClassificationModel
457
+ memoryset: LabeledMemoryset
458
+
459
+ def __repr__(self):
460
+ return (
461
+ "ClassificationPrediction({"
462
+ + f"label: <{self.label_name}: {self.label}>, "
463
+ + f"confidence: {self.confidence:.2f}, "
464
+ + (f"anomaly_score: {self.anomaly_score:.2f}, " if self.anomaly_score is not None else "")
465
+ + f"input_value: '{str(self.input_value)[:100] + '...' if len(str(self.input_value)) > 100 else self.input_value}'"
466
+ + "})"
467
+ )
468
+
469
+ @property
470
+ def memory_lookups(self) -> list[LabeledMemoryLookup]:
471
+ assert isinstance(self._telemetry, LabelPredictionWithMemoriesAndFeedback)
472
+ return [LabeledMemoryLookup(self._telemetry.memoryset_id, lookup) for lookup in self._telemetry.memories]
473
+
474
+ @property
475
+ def expected_label(self) -> int | None:
476
+ assert isinstance(self._telemetry, LabelPredictionWithMemoriesAndFeedback)
477
+ return self._telemetry.expected_label
478
+
479
+ @property
480
+ def expected_label_name(self) -> str | None:
481
+ assert isinstance(self._telemetry, LabelPredictionWithMemoriesAndFeedback)
482
+ return self._telemetry.expected_label_name
483
+
484
+ def update(
485
+ self,
486
+ *,
487
+ tags: set[str] | None = UNSET,
488
+ expected_label: int | None = UNSET,
489
+ ) -> None:
490
+ """
491
+ Update the prediction.
492
+
493
+ Note:
494
+ If a field is not provided, it will default to [UNSET][orca_sdk.UNSET] and not be updated.
495
+
496
+ Params:
497
+ tags: New tags to set for the prediction. Set to `None` to remove all tags.
498
+ expected_label: New expected label to set for the prediction. Set to `None` to remove.
499
+ """
500
+ self._update(tags=tags, expected_label=expected_label)
501
+
502
+
503
+ class RegressionPrediction(_Prediction):
504
+ """
505
+ Score-based prediction result from a [`RegressionModel`][orca_sdk.RegressionModel]
506
+
507
+ Attributes:
508
+ prediction_id: Unique identifier of this prediction used for feedback
509
+ score: Score predicted by the model
510
+ confidence: Confidence of the prediction
511
+ anomaly_score: Anomaly score of the input
512
+ input_value: The input value used for the prediction
513
+ expected_score: Expected score for the prediction, useful when evaluating the model
514
+ memory_lookups: Memories used by the model to make the prediction
515
+ explanation: Natural language explanation of the prediction, only available if the model
516
+ has the Explain API enabled
517
+ tags: Tags for the prediction, useful for filtering and grouping predictions
518
+ model: Model used to make the prediction
519
+ memoryset: Memoryset that was used to lookup memories to ground the prediction
520
+ """
521
+
522
+ score: float
523
+ model: RegressionModel
524
+ memoryset: ScoredMemoryset
525
+
526
+ def __repr__(self):
527
+ return (
528
+ "RegressionPrediction({"
529
+ + f"score: {self.score:.2f}, "
530
+ + f"confidence: {self.confidence:.2f}, "
531
+ + (f"anomaly_score: {self.anomaly_score:.2f}, " if self.anomaly_score is not None else "")
532
+ + f"input_value: '{str(self.input_value)[:100] + '...' if len(str(self.input_value)) > 100 else self.input_value}'"
533
+ + "})"
534
+ )
535
+
536
+ @property
537
+ def memory_lookups(self) -> list[ScoredMemoryLookup]:
538
+ assert isinstance(self._telemetry, ScorePredictionWithMemoriesAndFeedback)
539
+ return [ScoredMemoryLookup(self._telemetry.memoryset_id, lookup) for lookup in self._telemetry.memories]
540
+
541
+ @property
542
+ def explanation(self) -> str:
543
+ """The explanation for this prediction. Requires `lighthouse_client_api_key` to be set."""
544
+ raise NotImplementedError("Explanation is not supported for regression predictions")
545
+
546
+ @property
547
+ def expected_score(self) -> float | None:
548
+ assert isinstance(self._telemetry, ScorePredictionWithMemoriesAndFeedback)
549
+ return self._telemetry.expected_score
550
+
551
+ def update(
552
+ self,
553
+ *,
554
+ tags: set[str] | None = UNSET,
555
+ expected_score: float | None = UNSET,
556
+ ) -> None:
557
+ """
558
+ Update the prediction.
559
+
560
+ Note:
561
+ If a field is not provided, it will default to [UNSET][orca_sdk.UNSET] and not be updated.
562
+
563
+ Params:
564
+ tags: New tags to set for the prediction. Set to `None` to remove all tags.
565
+ expected_score: New expected score to set for the prediction. Set to `None` to remove.
566
+ """
567
+ self._update(tags=tags, expected_score=expected_score)
@@ -2,75 +2,79 @@ import pytest
2
2
 
3
3
  from .classification_model import ClassificationModel
4
4
  from .memoryset import LabeledMemoryLookup
5
- from .telemetry import FeedbackCategory, LabelPrediction
5
+ from .telemetry import ClassificationPrediction, FeedbackCategory
6
6
 
7
7
 
8
- def test_get_prediction(model: ClassificationModel):
9
- predictions = model.predict(["Do you love soup?", "Are cats cute?"])
8
+ def test_get_prediction(classification_model: ClassificationModel):
9
+ predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
10
10
  assert len(predictions) == 2
11
11
  assert predictions[0].prediction_id is not None
12
12
  assert predictions[1].prediction_id is not None
13
- prediction_with_telemetry = LabelPrediction.get(predictions[0].prediction_id)
13
+ prediction_with_telemetry = ClassificationPrediction.get(predictions[0].prediction_id)
14
14
  assert prediction_with_telemetry is not None
15
15
  assert prediction_with_telemetry.label == 0
16
16
  assert prediction_with_telemetry.input_value == "Do you love soup?"
17
17
 
18
18
 
19
- def test_get_predictions(model: ClassificationModel):
20
- predictions = model.predict(["Do you love soup?", "Are cats cute?"])
19
+ def test_get_predictions(classification_model: ClassificationModel):
20
+ predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
21
21
  assert len(predictions) == 2
22
22
  assert predictions[0].prediction_id is not None
23
23
  assert predictions[1].prediction_id is not None
24
- prediction_with_telemetry = LabelPrediction.get([predictions[0].prediction_id, predictions[1].prediction_id])
24
+ prediction_with_telemetry = ClassificationPrediction.get(
25
+ [predictions[0].prediction_id, predictions[1].prediction_id]
26
+ )
25
27
  assert len(prediction_with_telemetry) == 2
26
28
  assert prediction_with_telemetry[0].label == 0
27
29
  assert prediction_with_telemetry[0].input_value == "Do you love soup?"
28
30
  assert prediction_with_telemetry[1].label == 1
29
31
 
30
32
 
31
- def test_get_predictions_with_expected_label_match(model: ClassificationModel):
32
- model.predict(["Do you love soup?", "Are cats cute?"], expected_labels=[0, 0], tags={"expected_label_match"})
33
- model.predict("no expectations", tags={"expected_label_match"})
34
- assert len(model.predictions(tag="expected_label_match")) == 3
35
- assert len(model.predictions(expected_label_match=True, tag="expected_label_match")) == 1
36
- assert len(model.predictions(expected_label_match=False, tag="expected_label_match")) == 1
33
+ def test_get_predictions_with_expected_label_match(classification_model: ClassificationModel):
34
+ classification_model.predict(
35
+ ["Do you love soup?", "Are cats cute?"], expected_labels=[0, 0], tags={"expected_label_match"}
36
+ )
37
+ classification_model.predict("no expectations", tags={"expected_label_match"})
38
+ assert len(classification_model.predictions(tag="expected_label_match")) == 3
39
+ assert len(classification_model.predictions(expected_label_match=True, tag="expected_label_match")) == 1
40
+ assert len(classification_model.predictions(expected_label_match=False, tag="expected_label_match")) == 1
37
41
 
38
42
 
39
- def test_get_prediction_memory_lookups(model: ClassificationModel):
40
- prediction = model.predict("Do you love soup?")
43
+ def test_get_prediction_memory_lookups(classification_model: ClassificationModel):
44
+ prediction = classification_model.predict("Do you love soup?")
41
45
  assert isinstance(prediction.memory_lookups, list)
42
46
  assert len(prediction.memory_lookups) > 0
43
47
  assert all(isinstance(lookup, LabeledMemoryLookup) for lookup in prediction.memory_lookups)
44
48
 
45
49
 
46
- def test_record_feedback(model: ClassificationModel):
47
- prediction = model.predict("Do you love soup?")
50
+ def test_record_feedback(classification_model: ClassificationModel):
51
+ prediction = classification_model.predict("Do you love soup?")
48
52
  assert "correct" not in prediction.feedback
49
53
  prediction.record_feedback(category="correct", value=prediction.label == 0)
50
54
  assert prediction.feedback["correct"] is True
51
55
 
52
56
 
53
- def test_record_feedback_with_invalid_value(model: ClassificationModel):
57
+ def test_record_feedback_with_invalid_value(classification_model: ClassificationModel):
54
58
  with pytest.raises(ValueError, match=r"Invalid input.*"):
55
- model.predict("Do you love soup?").record_feedback(category="correct", value="not a bool") # type: ignore
59
+ classification_model.predict("Do you love soup?").record_feedback(category="correct", value="not a bool") # type: ignore
56
60
 
57
61
 
58
- def test_record_feedback_with_inconsistent_value_for_category(model: ClassificationModel):
59
- model.predict("Do you love soup?").record_feedback(category="correct", value=True)
62
+ def test_record_feedback_with_inconsistent_value_for_category(classification_model: ClassificationModel):
63
+ classification_model.predict("Do you love soup?").record_feedback(category="correct", value=True)
60
64
  with pytest.raises(ValueError, match=r"Invalid input.*"):
61
- model.predict("Do you love soup?").record_feedback(category="correct", value=-1.0)
65
+ classification_model.predict("Do you love soup?").record_feedback(category="correct", value=-1.0)
62
66
 
63
67
 
64
- def test_delete_feedback(model: ClassificationModel):
65
- prediction = model.predict("Do you love soup?")
68
+ def test_delete_feedback(classification_model: ClassificationModel):
69
+ prediction = classification_model.predict("Do you love soup?")
66
70
  prediction.record_feedback(category="test_delete", value=True)
67
71
  assert "test_delete" in prediction.feedback
68
72
  prediction.delete_feedback("test_delete")
69
73
  assert "test_delete" not in prediction.feedback
70
74
 
71
75
 
72
- def test_list_feedback_categories(model: ClassificationModel):
73
- prediction = model.predict("Do you love soup?")
76
+ def test_list_feedback_categories(classification_model: ClassificationModel):
77
+ prediction = classification_model.predict("Do you love soup?")
74
78
  prediction.record_feedback(category="correct", value=True)
75
79
  prediction.record_feedback(category="confidence", value=0.8)
76
80
  categories = FeedbackCategory.all()
@@ -79,8 +83,8 @@ def test_list_feedback_categories(model: ClassificationModel):
79
83
  assert any(c.name == "confidence" and c.value_type == float for c in categories)
80
84
 
81
85
 
82
- def test_drop_feedback_category(model: ClassificationModel):
83
- prediction = model.predict("Do you love soup?")
86
+ def test_drop_feedback_category(classification_model: ClassificationModel):
87
+ prediction = classification_model.predict("Do you love soup?")
84
88
  prediction.record_feedback(category="test_category", value=True)
85
89
  assert any(c.name == "test_category" for c in FeedbackCategory.all())
86
90
  FeedbackCategory.drop("test_category")
@@ -89,8 +93,8 @@ def test_drop_feedback_category(model: ClassificationModel):
89
93
  assert "test_category" not in prediction.feedback
90
94
 
91
95
 
92
- def test_update_prediction(model: ClassificationModel):
93
- prediction = model.predict("Do you love soup?")
96
+ def test_update_prediction(classification_model: ClassificationModel):
97
+ prediction = classification_model.predict("Do you love soup?")
94
98
  assert prediction.expected_label is None
95
99
  assert prediction.tags == set()
96
100
  # update expected label
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: orca_sdk
3
- Version: 0.0.94
3
+ Version: 0.0.96
4
4
  Summary: SDK for interacting with Orca Services
5
5
  License: Apache-2.0
6
6
  Author: Orca DB Inc.
@@ -15,14 +15,12 @@ Requires-Dist: attrs (>=22.2.0)
15
15
  Requires-Dist: datasets (>=3.1.0,<4.0.0)
16
16
  Requires-Dist: gradio (==5.13.0)
17
17
  Requires-Dist: httpx (>=0.20.0,<0.29.0)
18
- Requires-Dist: networkx (>=3.4.2,<4.0.0)
19
18
  Requires-Dist: pandas (>=2.2.3,<3.0.0)
20
19
  Requires-Dist: pyarrow (>=18.0.0,<19.0.0)
21
20
  Requires-Dist: python-dateutil (>=2.8.0,<3.0.0)
22
21
  Requires-Dist: python-dotenv (>=1.1.0,<2.0.0)
23
22
  Requires-Dist: scikit-learn (>=1.6.1,<2.0.0)
24
23
  Requires-Dist: torch (>=2.5.1,<3.0.0)
25
- Requires-Dist: transformers (>=4.51.3,<5.0.0)
26
24
  Description-Content-Type: text/markdown
27
25
 
28
26
  <!--
@@ -75,7 +73,7 @@ model = ClassificationModel("my_model", memoryset)
75
73
  prediction = model.predict("my input")
76
74
  ```
77
75
 
78
- For a more detailed walkthrough, check out our [Quick Start Guide](https://docs.orcadb.ai/quickstart/).
76
+ For a more detailed walkthrough, check out our [Quick Start Guide](https://docs.orcadb.ai/quickstart-sdk/).
79
77
 
80
78
  ## Support
81
79