orca-sdk 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. orca_sdk/__init__.py +30 -0
  2. orca_sdk/_shared/__init__.py +10 -0
  3. orca_sdk/_shared/metrics.py +634 -0
  4. orca_sdk/_shared/metrics_test.py +570 -0
  5. orca_sdk/_utils/__init__.py +0 -0
  6. orca_sdk/_utils/analysis_ui.py +196 -0
  7. orca_sdk/_utils/analysis_ui_style.css +51 -0
  8. orca_sdk/_utils/auth.py +65 -0
  9. orca_sdk/_utils/auth_test.py +31 -0
  10. orca_sdk/_utils/common.py +37 -0
  11. orca_sdk/_utils/data_parsing.py +129 -0
  12. orca_sdk/_utils/data_parsing_test.py +244 -0
  13. orca_sdk/_utils/pagination.py +126 -0
  14. orca_sdk/_utils/pagination_test.py +132 -0
  15. orca_sdk/_utils/prediction_result_ui.css +18 -0
  16. orca_sdk/_utils/prediction_result_ui.py +110 -0
  17. orca_sdk/_utils/tqdm_file_reader.py +12 -0
  18. orca_sdk/_utils/value_parser.py +45 -0
  19. orca_sdk/_utils/value_parser_test.py +39 -0
  20. orca_sdk/async_client.py +4104 -0
  21. orca_sdk/classification_model.py +1165 -0
  22. orca_sdk/classification_model_test.py +887 -0
  23. orca_sdk/client.py +4096 -0
  24. orca_sdk/conftest.py +382 -0
  25. orca_sdk/credentials.py +217 -0
  26. orca_sdk/credentials_test.py +121 -0
  27. orca_sdk/datasource.py +576 -0
  28. orca_sdk/datasource_test.py +463 -0
  29. orca_sdk/embedding_model.py +712 -0
  30. orca_sdk/embedding_model_test.py +206 -0
  31. orca_sdk/job.py +343 -0
  32. orca_sdk/job_test.py +108 -0
  33. orca_sdk/memoryset.py +3811 -0
  34. orca_sdk/memoryset_test.py +1150 -0
  35. orca_sdk/regression_model.py +841 -0
  36. orca_sdk/regression_model_test.py +595 -0
  37. orca_sdk/telemetry.py +742 -0
  38. orca_sdk/telemetry_test.py +119 -0
  39. orca_sdk-0.1.9.dist-info/METADATA +98 -0
  40. orca_sdk-0.1.9.dist-info/RECORD +41 -0
  41. orca_sdk-0.1.9.dist-info/WHEEL +4 -0
orca_sdk/telemetry.py ADDED
@@ -0,0 +1,742 @@
+ from __future__ import annotations
+
+ import logging
+ import os
+ from abc import ABC
+ from datetime import datetime
+ from typing import TYPE_CHECKING, Any, Iterable, Literal, Self, cast, overload
+
+ from httpx import Timeout
+
+ from ._utils.common import UNSET
+ from .client import (
+     LabelPredictionWithMemoriesAndFeedback,
+     OrcaClient,
+     PredictionFeedbackCategory,
+     PredictionFeedbackRequest,
+     ScorePredictionWithMemoriesAndFeedback,
+     UpdatePredictionRequest,
+ )
+
+ if TYPE_CHECKING:
+     from .classification_model import ClassificationModel
+     from .memoryset import (
+         LabeledMemoryLookup,
+         LabeledMemoryset,
+         ScoredMemoryLookup,
+         ScoredMemoryset,
+     )
+     from .regression_model import RegressionModel
+
+ TelemetryMode = Literal["off", "on", "sync", "async"]
+ """
+ Mode for saving telemetry. One of:
+
+ - `"off"`: Do not save telemetry.
+ - `"on"`: Save telemetry asynchronously unless the `ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY` environment variable is set.
+ - `"sync"`: Save telemetry synchronously.
+ - `"async"`: Save telemetry asynchronously.
+ """
+
+
+ def _get_telemetry_config(override: TelemetryMode | None = None) -> tuple[bool, bool]:
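+     """Resolve `(telemetry_enabled, save_synchronously)` from the override and the environment."""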
+     return (
+         override != "off",
+         os.getenv("ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY", "0") != "0" or override == "sync",
+     )
+
+
+ def _parse_feedback(feedback: dict[str, Any]) -> PredictionFeedbackRequest:
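+     """Convert a loose feedback dict into a `PredictionFeedbackRequest`, validating the required keys."""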
+     category = feedback.get("category", None)
+     if category is None:
+         raise ValueError("`category` must be specified")
+     prediction_id = feedback.get("prediction_id", None)
+     if prediction_id is None:
+         raise ValueError("`prediction_id` must be specified")
+     output: PredictionFeedbackRequest = {
+         "prediction_id": prediction_id,
+         "category_name": category,
+     }
+     if "value" in feedback:
+         output["value"] = feedback["value"]
+     if "comment" in feedback:
+         output["comment"] = feedback["comment"]
+     return output
+
+
+ class FeedbackCategory:
+     """
+     A category of feedback for predictions.
+
+     Categories are created automatically the first time feedback with a new name is recorded.
+     The value type of the category is inferred from the first recorded value. Subsequent feedback
+     for the same category must be of the same type. Categories are not model-specific.
+
+     Attributes:
+         id: Unique identifier for the category.
+         name: Name of the category.
+         value_type: Type that values for this category must have.
+         created_at: When the category was created.
+     """
+
+     id: str
+     name: str
+     value_type: type[bool] | type[float]
+     created_at: datetime
+
+     def __init__(self, category: PredictionFeedbackCategory):
+         self.id = category["id"]
+         self.name = category["name"]
+         self.value_type = bool if category["type"] == "BINARY" else float
+         self.created_at = datetime.fromisoformat(category["created_at"])
+
+     @classmethod
+     def all(cls) -> list[FeedbackCategory]:
+         """
+         Get a list of all existing feedback categories.
+
+         Returns:
+             List with information about all existing feedback categories.
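+
+         Examples:
+             List all categories (the category names shown are illustrative):
+             >>> FeedbackCategory.all()
+             [FeedbackCategory({name: accepted, value_type: <class 'bool'>})]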
+         """
+         client = OrcaClient._resolve_client()
+         return [FeedbackCategory(category) for category in client.GET("/telemetry/feedback_category")]
+
+     @classmethod
+     def drop(cls, name: str) -> None:
+         """
+         Drop all feedback for this category and drop the category itself, allowing it to be
+         recreated with a different value type.
+
+         Warning:
+             This will delete all feedback in this category across all models.
+
+         Params:
+             name: Name of the category to drop.
+
+         Raises:
+             LookupError: If the category is not found.
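+
+         Examples:
+             Drop a category (the category name is illustrative):
+             >>> FeedbackCategory.drop("accepted")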
+         """
+         client = OrcaClient._resolve_client()
+         client.DELETE("/telemetry/feedback_category/{name_or_id}", params={"name_or_id": name})
+         logging.info(f"Deleted feedback category {name} with all associated feedback")
+
+     def __repr__(self):
+         return "FeedbackCategory({" + f"name: {self.name}, " + f"value_type: {self.value_type}" + "})"
+
+
+ class AddMemorySuggestions:
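+     """
+     Suggested memories generated for a prediction, held as `(value, label_name)` tuples.
+
+     Call [`apply`][orca_sdk.telemetry.AddMemorySuggestions.apply] to insert the suggestions into
+     the memoryset that the prediction was made against.
+     """
+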
+     suggestions: list[tuple[str, str]]
+     memoryset_id: str
+     model_id: str
+     prediction_id: str
+
+     def __init__(self, suggestions: list[tuple[str, str]], memoryset_id: str, model_id: str, prediction_id: str):
+         self.suggestions = suggestions
+         self.memoryset_id = memoryset_id
+         self.model_id = model_id
+         self.prediction_id = prediction_id
+
+     def __repr__(self):
+         return (
+             "AddMemorySuggestions({"
+             + f"suggestions: {self.suggestions}, "
+             + f"memoryset_id: {self.memoryset_id}, "
+             + f"model_id: {self.model_id}, "
+             + f"prediction_id: {self.prediction_id}"
+             + "})"
+         )
+
+     def apply(self) -> None:
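+         """
+         Insert the suggested memories into the memoryset the prediction was made against.
+
+         Examples:
+             Generate and apply suggestions (a minimal sketch):
+             >>> suggestions = prediction.generate_memory_suggestions()
+             >>> suggestions.apply()
+         """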
+         from .memoryset import LabeledMemoryset
+
+         memoryset = LabeledMemoryset.open(self.memoryset_id)
+         label_name_to_label = {label_name: label for label, label_name in enumerate(memoryset.label_names)}
+         memoryset.insert(
+             [{"value": suggestion[0], "label": label_name_to_label[suggestion[1]]} for suggestion in self.suggestions]
+         )
+
+
+ class PredictionBase(ABC):
+     prediction_id: str | None
+     confidence: float
+     anomaly_score: float | None
+
+     def __init__(
+         self,
+         prediction_id: str | None,
+         *,
+         label: int | None,
+         label_name: str | None,
+         score: float | None,
+         confidence: float,
+         anomaly_score: float | None,
+         memoryset: LabeledMemoryset | ScoredMemoryset,
+         model: ClassificationModel | RegressionModel,
+         telemetry: LabelPredictionWithMemoriesAndFeedback | ScorePredictionWithMemoriesAndFeedback | None = None,
+         logits: list[float] | None = None,
+         input_value: str | None = None,
+     ):
+         self.prediction_id = prediction_id
+         self.label = label
+         self.label_name = label_name
+         self.score = score
+         self.confidence = confidence
+         self.anomaly_score = anomaly_score
+         self.memoryset = memoryset
+         self.model = model
+         self.__telemetry = telemetry if telemetry else None
+         self.logits = logits
+         self._input_value = input_value
+
+     @property
+     def _telemetry(self) -> LabelPredictionWithMemoriesAndFeedback | ScorePredictionWithMemoriesAndFeedback:
+         # for internal use only, do not document
+         if self.__telemetry is None:
+             if self.prediction_id is None:
+                 raise ValueError("Cannot fetch telemetry with no prediction ID")
+             client = OrcaClient._resolve_client()
+             self.__telemetry = client.GET(
+                 "/telemetry/prediction/{prediction_id}", params={"prediction_id": self.prediction_id}
+             )
+         return self.__telemetry
+
+     @property
+     def input_value(self) -> str:
+         if self._input_value is not None:
+             return self._input_value
+         assert isinstance(self._telemetry["input_value"], str)
+         return self._telemetry["input_value"]
+
+     @property
+     def memory_lookups(self) -> list[LabeledMemoryLookup] | list[ScoredMemoryLookup]:
+         from .memoryset import LabeledMemoryLookup, ScoredMemoryLookup
+
+         if "label" in self._telemetry:
+             return [
+                 LabeledMemoryLookup(self._telemetry["memoryset_id"], lookup) for lookup in self._telemetry["memories"]
+             ]
+         else:
+             return [
+                 ScoredMemoryLookup(self._telemetry["memoryset_id"], lookup) for lookup in self._telemetry["memories"]
+             ]
+
+     @property
+     def feedback(self) -> dict[str, bool | float]:
+         feedbacks = self._telemetry.get("feedbacks", [])
+         if not feedbacks:
+             return {}
+
+         feedback_by_category: dict[str, bool | float] = {}
+         seen_categories: set[str] = set()
+         total_categories = len(set(f["category_name"] for f in feedbacks))
+
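+         # Assumes feedbacks are ordered newest-first, so the first value seen per category is its latest value.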
+         for f in feedbacks:
+             category_name = f["category_name"]
+             if category_name not in seen_categories:
+                 # Convert BINARY (1/0) to boolean, CONTINUOUS to float
+                 value = f["value"]
+                 if f["category_type"] == "BINARY":
+                     value = bool(value)
+                 else:
+                     value = float(value)
+                 feedback_by_category[category_name] = value
+                 seen_categories.add(category_name)
+
+             # Early exit once we've found the most recent value for all categories
+             if len(seen_categories) == total_categories:
+                 break
+
+         return feedback_by_category
+
+     @property
+     def is_correct(self) -> bool:
+         if "label" in self._telemetry:
+             expected_label = self._telemetry.get("expected_label")
+             label = self._telemetry.get("label")
+             return expected_label is not None and label is not None and label == expected_label
+         else:
+             expected_score = self._telemetry.get("expected_score")
+             score = self._telemetry.get("score")
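+             # Scores within 0.001 of the expected score count as correct.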
+             return expected_score is not None and score is not None and abs(score - expected_score) < 0.001
+
+     @property
+     def tags(self) -> set[str]:
+         return set(self._telemetry["tags"])
+
+     @property
+     def explanation(self) -> str:
+         if self._telemetry["explanation"] is None:
+             client = OrcaClient._resolve_client()
+             self._telemetry["explanation"] = client.GET(
+                 "/telemetry/prediction/{prediction_id}/explanation",
+                 params={"prediction_id": self._telemetry["prediction_id"]},
+                 parse_as="text",
+                 timeout=30,
+             )
+         return self._telemetry["explanation"]
+
+     def explain(self, refresh: bool = False) -> None:
+         """
+         Print an explanation of the prediction as a stream of text.
+
+         Params:
+             refresh: Force the explanation agent to re-run even if an explanation already exists.
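+
+         Examples:
+             Print the explanation (the output shown is illustrative):
+             >>> prediction.explain()
+             The input was classified as positive because it closely matches positive memories...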
+         """
+         if not refresh and self._telemetry["explanation"] is not None:
+             print(self._telemetry["explanation"])
+         else:
+             client = OrcaClient._resolve_client()
+             with client.stream(
+                 "GET",
+                 f"/telemetry/prediction/{self.prediction_id}/explanation?refresh={refresh}",
+                 timeout=Timeout(connect=3, read=None),
+             ) as res:
+                 for chunk in res.iter_text():
+                     print(chunk, end="")
+                 print()  # final newline
+
+     @overload
+     @classmethod
+     def get(cls, prediction_id: str) -> Self:  # type: ignore -- this takes precedence
+         pass
+
+     @overload
+     @classmethod
+     def get(cls, prediction_id: Iterable[str]) -> list[Self]:
+         pass
+
+     @classmethod
+     def get(cls, prediction_id: str | Iterable[str]) -> Self | list[Self]:
+         """
+         Fetch one or more predictions.
+
+         Params:
+             prediction_id: Unique identifier of the prediction or predictions to fetch
+
+         Returns:
+             Prediction or list of predictions
+
+         Raises:
+             LookupError: If no prediction with the given id is found
+
+         Examples:
+             Fetch a single prediction:
+             >>> ClassificationPrediction.get("0195019a-5bc7-7afb-b902-5945ee1fb766")
+             ClassificationPrediction({
+                 label: <positive: 1>,
+                 confidence: 0.95,
+                 anomaly_score: 0.1,
+                 input_value: "I am happy",
+                 memoryset: "my_memoryset",
+                 model: "my_model"
+             })
+
+             Fetch multiple predictions:
+             >>> ClassificationPrediction.get([
+             ...     "0195019a-5bc7-7afb-b902-5945ee1fb766",
+             ...     "019501a1-ea08-76b2-9f62-95e4800b4841",
+             ... ])
+             [
+                 ClassificationPrediction({
+                     label: <positive: 1>,
+                     confidence: 0.95,
+                     anomaly_score: 0.1,
+                     input_value: "I am happy",
+                     memoryset: "my_memoryset",
+                     model: "my_model"
+                 }),
+                 ClassificationPrediction({
+                     label: <negative: 0>,
+                     confidence: 0.05,
+                     anomaly_score: 0.2,
+                     input_value: "I am sad",
+                     memoryset: "my_memoryset",
+                     model: "my_model"
+                 }),
+             ]
+         """
+         from .classification_model import ClassificationModel
+         from .regression_model import RegressionModel
+
+         def create_prediction(
+             prediction: LabelPredictionWithMemoriesAndFeedback | ScorePredictionWithMemoriesAndFeedback,
+         ) -> Self:
+             from .memoryset import LabeledMemoryset, ScoredMemoryset
+
+             if "label" in prediction:
+                 memoryset = LabeledMemoryset.open(prediction["memoryset_id"])
+                 model = ClassificationModel.open(prediction["model_id"])
+             else:
+                 memoryset = ScoredMemoryset.open(prediction["memoryset_id"])
+                 model = RegressionModel.open(prediction["model_id"])
+
+             return cls(
+                 prediction_id=prediction["prediction_id"],
+                 label=prediction.get("label", None),
+                 label_name=prediction.get("label_name", None),
+                 score=prediction.get("score", None),
+                 confidence=prediction["confidence"],
+                 anomaly_score=prediction["anomaly_score"],
+                 memoryset=memoryset,
+                 model=model,
+                 telemetry=prediction,
+             )
+
+         client = OrcaClient._resolve_client()
+         if isinstance(prediction_id, str):
+             return create_prediction(
+                 client.GET("/telemetry/prediction/{prediction_id}", params={"prediction_id": prediction_id})
+             )
+         else:
+             return [
+                 create_prediction(prediction)
+                 for prediction in client.POST("/telemetry/prediction", json={"prediction_ids": list(prediction_id)})
+             ]
+
+     def refresh(self):
+         """Refresh the prediction data from the OrcaCloud"""
+         if self.prediction_id is None:
+             raise ValueError("Cannot refresh prediction with no prediction ID")
+         self.__dict__.update(self.get(self.prediction_id).__dict__)
+
+     def _update(
+         self,
+         *,
+         tags: set[str] | None = UNSET,
+         expected_label: int | None = UNSET,
+         expected_score: float | None = UNSET,
+     ) -> None:
+         if self.prediction_id is None:
+             raise ValueError("Cannot update prediction with no prediction ID")
+
+         payload: UpdatePredictionRequest = {}
+         if tags is not UNSET:
+             payload["tags"] = [] if tags is None else list(tags)
+         if expected_label is not UNSET:
+             payload["expected_label"] = expected_label
+         if expected_score is not UNSET:
+             payload["expected_score"] = expected_score
+         client = OrcaClient._resolve_client()
+         client.PATCH(
+             "/telemetry/prediction/{prediction_id}", params={"prediction_id": self.prediction_id}, json=payload
+         )
+         self.refresh()
+
+     def add_tag(self, tag: str) -> None:
+         """
+         Add a tag to the prediction
+
+         Params:
+             tag: Tag to add to the prediction
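+
+         Examples:
+             Tag the prediction (the tag name is illustrative):
+             >>> prediction.add_tag("flagged")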
+         """
+         self._update(tags=self.tags | {tag})
+
+     def remove_tag(self, tag: str) -> None:
+         """
+         Remove a tag from the prediction
+
+         Params:
+             tag: Tag to remove from the prediction
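+
+         Examples:
+             Remove a previously added tag (the tag name is illustrative):
+             >>> prediction.remove_tag("flagged")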
+         """
+         self._update(tags=self.tags - {tag})
+
+     def record_feedback(
+         self,
+         category: str,
+         value: bool | float,
+         *,
+         comment: str | None = None,
+     ):
+         """
+         Record feedback for the prediction.
+
+         We support recording feedback in several categories for each prediction. A
+         [`FeedbackCategory`][orca_sdk.telemetry.FeedbackCategory] is created automatically
+         the first time feedback with a new name is recorded. Categories are global across models.
+         The value type of the category is inferred from the first recorded value. Subsequent
+         feedback for the same category must be of the same type.
+
+         Params:
+             category: Name of the category under which to record the feedback.
+             value: Feedback value to record; should be `True` for positive feedback and `False` for
+                 negative feedback, or a [`float`][float] between `-1.0` and `+1.0` where negative
+                 values indicate negative feedback and positive values indicate positive feedback.
+             comment: Optional comment to record with the feedback.
+
+         Examples:
+             Record whether a suggestion was accepted or rejected:
+             >>> prediction.record_feedback("accepted", True)
+
+             Record a star rating as a normalized continuous score between `-1.0` and `+1.0`:
+             >>> prediction.record_feedback("rating", -0.5, comment="2 stars")
+
+         Raises:
+             ValueError: If the value does not match previous value types for the category, or is a
+                 [`float`][float] that is not between `-1.0` and `+1.0`.
+         """
+         client = OrcaClient._resolve_client()
+         client.PUT(
+             "/telemetry/prediction/feedback",
+             json=[
+                 _parse_feedback(
+                     {"prediction_id": self.prediction_id, "category": category, "value": value, "comment": comment}
+                 )
+             ],
+         )
+         self.refresh()
+
+     def delete_feedback(self, category: str) -> None:
+         """
+         Delete prediction feedback for a specific category.
+
+         Params:
+             category: Name of the category of the feedback to delete.
+
+         Raises:
+             ValueError: If the category is not found.
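+
+         Examples:
+             Delete feedback previously recorded under a category (the name is illustrative):
+             >>> prediction.delete_feedback("accepted")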
+         """
+         if self.prediction_id is None:
+             raise ValueError("Cannot delete feedback with no prediction ID")
+
+         client = OrcaClient._resolve_client()
+         client.PUT(
+             "/telemetry/prediction/feedback",
+             json=[PredictionFeedbackRequest(prediction_id=self.prediction_id, category_name=category, value=None)],
+         )
+         self.refresh()
+
+     def inspect(self) -> None:
+         """
+         Display an interactive UI with the details about this prediction
+
+         Note:
+             This method is only available in Jupyter notebooks.
+         """
+         from ._utils.prediction_result_ui import inspect_prediction_result
+
+         inspect_prediction_result(self)
+
+
+ class ClassificationPrediction(PredictionBase):
+     """
+     Labeled prediction result from a [`ClassificationModel`][orca_sdk.ClassificationModel]
+
+     Attributes:
+         prediction_id: Unique identifier of this prediction used for feedback
+         label: Label predicted by the model
+         label_name: Human-readable name of the label
+         logits: Raw model logits for the prediction, if available
+         confidence: Confidence of the prediction
+         anomaly_score: Anomaly score of the input
+         input_value: The input value used for the prediction
+         expected_label: Expected label for the prediction, useful when evaluating the model
+         expected_label_name: Human-readable name of the expected label
+         memory_lookups: Memories used by the model to make the prediction
+         explanation: Natural language explanation of the prediction, only available if the model
+             has the Explain API enabled
+         tags: Tags for the prediction, useful for filtering and grouping predictions
+         model: Model used to make the prediction
+         memoryset: Memoryset that was used to lookup memories to ground the prediction
+     """
+
+     label: int
+     label_name: str
+     logits: list[float] | None
+     model: ClassificationModel
+     memoryset: LabeledMemoryset
+
+     def __repr__(self):
+         return (
+             "ClassificationPrediction({"
+             + f"label: <{self.label_name}: {self.label}>, "
+             + f"confidence: {self.confidence:.2f}, "
+             + (f"anomaly_score: {self.anomaly_score:.2f}, " if self.anomaly_score is not None else "")
+             + f"input_value: '{str(self.input_value)[:100] + '...' if len(str(self.input_value)) > 100 else self.input_value}'"
+             + "})"
+         )
+
+     @property
+     def memory_lookups(self) -> list[LabeledMemoryLookup]:
+         from .memoryset import LabeledMemoryLookup
+
+         assert "label" in self._telemetry
+         return [LabeledMemoryLookup(self._telemetry["memoryset_id"], lookup) for lookup in self._telemetry["memories"]]
+
+     @property
+     def expected_label(self) -> int | None:
+         assert "label" in self._telemetry
+         return self._telemetry["expected_label"]
+
+     @property
+     def expected_label_name(self) -> str | None:
+         assert "label" in self._telemetry
+         return self._telemetry["expected_label_name"]
+
+     def update(
+         self,
+         *,
+         tags: set[str] | None = UNSET,
+         expected_label: int | None = UNSET,
+     ) -> None:
+         """
+         Update the prediction.
+
+         Note:
+             If a field is not provided, it will default to [UNSET][orca_sdk.UNSET] and not be updated.
+
+         Params:
+             tags: New tags to set for the prediction. Set to `None` to remove all tags.
+             expected_label: New expected label to set for the prediction. Set to `None` to remove.
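+
+         Examples:
+             Set the expected label after reviewing the prediction (the label value is illustrative):
+             >>> prediction.update(expected_label=0)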
+         """
+         self._update(tags=tags, expected_label=expected_label)
+
+     def recommend_action(self, *, refresh: bool = False) -> tuple[str, str]:
+         """
+         Get an action recommendation for improving this prediction.
+
+         Analyzes the prediction and suggests the most effective action to improve model
+         performance, such as adding memories, detecting mislabels, removing duplicates,
+         or finetuning.
+
+         Params:
+             refresh: Force the action recommendation agent to re-run even if a recommendation already exists
+
+         Returns:
+             Tuple of `(action, rationale)` where:
+                 - action: The recommended action ("add_memories", "detect_mislabels", "remove_duplicates", or "finetuning")
+                 - rationale: Explanation of why this action was recommended
+
+         Raises:
+             ValueError: If the prediction has no prediction ID
+             RuntimeError: If the lighthouse API key is not configured
+
+         Examples:
+             Get action recommendation for an incorrect prediction:
+             >>> action, rationale = prediction.recommend_action()
+             >>> print(f"Recommended action: {action}")
+             >>> print(f"Rationale: {rationale}")
+         """
+         if self.prediction_id is None:
+             raise ValueError("Cannot get action recommendation with no prediction ID")
+
+         client = OrcaClient._resolve_client()
+         response = client.GET(
+             "/telemetry/prediction/{prediction_id}/action",
+             params={"prediction_id": self.prediction_id},
+             timeout=30,
+         )
+         return (response["action"], response["rationale"])
+
+     def generate_memory_suggestions(self, *, num_memories: int = 3) -> AddMemorySuggestions:
+         """
+         Generate synthetic memory suggestions to improve this prediction.
+
+         Creates new example memories that are similar to the input but have clearer
+         signals for the expected label. These can be added to the memoryset to improve
+         model performance on similar inputs.
+
+         Params:
+             num_memories: Number of memory suggestions to generate (default: 3)
+
+         Returns:
+             An [`AddMemorySuggestions`][orca_sdk.telemetry.AddMemorySuggestions] object that holds
+             the suggested memories as `(value, label_name)` tuples and can insert them into the
+             model's memoryset via its `apply` method.
+
+         Raises:
+             ValueError: If the prediction has no prediction ID
+             RuntimeError: If the lighthouse API key is not configured
+
+         Examples:
+             Generate memory suggestions for an incorrect prediction and apply them:
+             >>> suggestions = prediction.generate_memory_suggestions(num_memories=3)
+             >>> suggestions.suggestions  # (value, label_name) tuples, contents illustrative
+             [('I am thrilled', 'positive'), ('What a great day', 'positive')]
+             >>> suggestions.apply()
+         """
+         if self.prediction_id is None:
+             raise ValueError("Cannot generate memory suggestions with no prediction ID")
+
+         client = OrcaClient._resolve_client()
+         response = client.GET(
+             "/telemetry/prediction/{prediction_id}/memory_suggestions",
+             params={"prediction_id": self.prediction_id, "num_memories": num_memories},
+             timeout=30,
+         )
+
+         return AddMemorySuggestions(
+             suggestions=[(m["value"], m["label_name"]) for m in response["memories"]],
+             memoryset_id=self.memoryset.id,
+             model_id=self.model.id,
+             prediction_id=self.prediction_id,
+         )
+
+
+ class RegressionPrediction(PredictionBase):
+     """
+     Score-based prediction result from a [`RegressionModel`][orca_sdk.RegressionModel]
+
+     Attributes:
+         prediction_id: Unique identifier of this prediction used for feedback
+         score: Score predicted by the model
+         confidence: Confidence of the prediction
+         anomaly_score: Anomaly score of the input
+         input_value: The input value used for the prediction
+         expected_score: Expected score for the prediction, useful when evaluating the model
+         memory_lookups: Memories used by the model to make the prediction
+         explanation: Natural language explanation of the prediction; not currently supported
+             for regression predictions
+         tags: Tags for the prediction, useful for filtering and grouping predictions
+         model: Model used to make the prediction
+         memoryset: Memoryset that was used to lookup memories to ground the prediction
+     """
+
+     score: float
+     model: RegressionModel
+     memoryset: ScoredMemoryset
+
+     def __repr__(self):
+         return (
+             "RegressionPrediction({"
+             + f"score: {self.score:.2f}, "
+             + f"confidence: {self.confidence:.2f}, "
+             + (f"anomaly_score: {self.anomaly_score:.2f}, " if self.anomaly_score is not None else "")
+             + f"input_value: '{str(self.input_value)[:100] + '...' if len(str(self.input_value)) > 100 else self.input_value}'"
+             + "})"
+         )
+
+     @property
+     def memory_lookups(self) -> list[ScoredMemoryLookup]:
+         from .memoryset import ScoredMemoryLookup
+
+         assert "score" in self._telemetry
+         return [ScoredMemoryLookup(self._telemetry["memoryset_id"], lookup) for lookup in self._telemetry["memories"]]
+
+     @property
+     def explanation(self) -> str:
+         """The explanation for this prediction. Requires `lighthouse_client_api_key` to be set."""
+         raise NotImplementedError("Explanation is not supported for regression predictions")
+
+     @property
+     def expected_score(self) -> float | None:
+         assert "score" in self._telemetry
+         return self._telemetry["expected_score"]
+
+     def update(
+         self,
+         *,
+         tags: set[str] | None = UNSET,
+         expected_score: float | None = UNSET,
+     ) -> None:
+         """
+         Update the prediction.
+
+         Note:
+             If a field is not provided, it will default to [UNSET][orca_sdk.UNSET] and not be updated.
+
+         Params:
+             tags: New tags to set for the prediction. Set to `None` to remove all tags.
+             expected_score: New expected score to set for the prediction. Set to `None` to remove.
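+
+         Examples:
+             Record the observed outcome as the expected score (the value is illustrative):
+             >>> prediction.update(expected_score=0.75)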
+         """
+         self._update(tags=tags, expected_score=expected_score)