orca-sdk 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +10 -4
- orca_sdk/_shared/__init__.py +10 -0
- orca_sdk/_shared/metrics.py +393 -0
- orca_sdk/_shared/metrics_test.py +273 -0
- orca_sdk/_utils/analysis_ui.py +12 -10
- orca_sdk/_utils/analysis_ui_style.css +0 -3
- orca_sdk/_utils/auth.py +31 -29
- orca_sdk/_utils/data_parsing.py +28 -2
- orca_sdk/_utils/data_parsing_test.py +15 -15
- orca_sdk/_utils/pagination.py +126 -0
- orca_sdk/_utils/pagination_test.py +132 -0
- orca_sdk/_utils/prediction_result_ui.py +67 -21
- orca_sdk/_utils/tqdm_file_reader.py +12 -0
- orca_sdk/_utils/value_parser.py +45 -0
- orca_sdk/_utils/value_parser_test.py +39 -0
- orca_sdk/async_client.py +3795 -0
- orca_sdk/classification_model.py +601 -129
- orca_sdk/classification_model_test.py +415 -117
- orca_sdk/client.py +3787 -0
- orca_sdk/conftest.py +184 -38
- orca_sdk/credentials.py +162 -20
- orca_sdk/credentials_test.py +100 -16
- orca_sdk/datasource.py +268 -68
- orca_sdk/datasource_test.py +266 -18
- orca_sdk/embedding_model.py +434 -82
- orca_sdk/embedding_model_test.py +66 -33
- orca_sdk/job.py +343 -0
- orca_sdk/job_test.py +108 -0
- orca_sdk/memoryset.py +1690 -324
- orca_sdk/memoryset_test.py +456 -119
- orca_sdk/regression_model.py +694 -0
- orca_sdk/regression_model_test.py +378 -0
- orca_sdk/telemetry.py +460 -143
- orca_sdk/telemetry_test.py +43 -24
- {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.3.dist-info}/METADATA +34 -16
- orca_sdk-0.1.3.dist-info/RECORD +41 -0
- {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.3.dist-info}/WHEEL +1 -1
- orca_sdk/_generated_api_client/__init__.py +0 -3
- orca_sdk/_generated_api_client/api/__init__.py +0 -193
- orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
- orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
- orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
- orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
- orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
- orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
- orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
- orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
- orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
- orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
- orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
- orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
- orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
- orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
- orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
- orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
- orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
- orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
- orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
- orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
- orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
- orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
- orca_sdk/_generated_api_client/client.py +0 -216
- orca_sdk/_generated_api_client/errors.py +0 -38
- orca_sdk/_generated_api_client/models/__init__.py +0 -159
- orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
- orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
- orca_sdk/_generated_api_client/models/base_model.py +0 -55
- orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
- orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
- orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
- orca_sdk/_generated_api_client/models/column_info.py +0 -114
- orca_sdk/_generated_api_client/models/column_type.py +0 -14
- orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
- orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
- orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
- orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
- orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
- orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/embed_request.py +0 -127
- orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
- orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
- orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
- orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
- orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
- orca_sdk/_generated_api_client/models/filter_item.py +0 -231
- orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
- orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
- orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
- orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
- orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
- orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
- orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
- orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
- orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
- orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
- orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
- orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
- orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
- orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
- orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
- orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
- orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
- orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
- orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
- orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
- orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
- orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
- orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
- orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
- orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
- orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
- orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
- orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/task.py +0 -198
- orca_sdk/_generated_api_client/models/task_status.py +0 -14
- orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
- orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
- orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
- orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
- orca_sdk/_generated_api_client/py.typed +0 -1
- orca_sdk/_generated_api_client/types.py +0 -56
- orca_sdk/_utils/task.py +0 -73
- orca_sdk-0.1.1.dist-info/RECORD +0 -175
orca_sdk/telemetry.py
CHANGED
|
@@ -1,34 +1,49 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
|
+
import os
|
|
5
|
+
from abc import ABC
|
|
4
6
|
from datetime import datetime
|
|
5
|
-
from typing import TYPE_CHECKING, Any, Iterable, overload
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
from .
|
|
11
|
-
drop_feedback_category_with_data,
|
|
12
|
-
get_prediction,
|
|
13
|
-
list_feedback_categories,
|
|
14
|
-
list_predictions,
|
|
15
|
-
record_prediction_feedback,
|
|
16
|
-
update_prediction,
|
|
17
|
-
)
|
|
18
|
-
from ._generated_api_client.models import (
|
|
19
|
-
FeedbackType,
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Iterable, Literal, Self, overload
|
|
8
|
+
|
|
9
|
+
from httpx import Timeout
|
|
10
|
+
|
|
11
|
+
from ._utils.common import UNSET
|
|
12
|
+
from .client import (
|
|
20
13
|
LabelPredictionWithMemoriesAndFeedback,
|
|
21
|
-
|
|
14
|
+
OrcaClient,
|
|
22
15
|
PredictionFeedbackCategory,
|
|
23
16
|
PredictionFeedbackRequest,
|
|
17
|
+
ScorePredictionWithMemoriesAndFeedback,
|
|
24
18
|
UpdatePredictionRequest,
|
|
25
19
|
)
|
|
26
|
-
from .
|
|
27
|
-
|
|
28
|
-
|
|
20
|
+
from .memoryset import (
|
|
21
|
+
LabeledMemoryLookup,
|
|
22
|
+
LabeledMemoryset,
|
|
23
|
+
ScoredMemoryLookup,
|
|
24
|
+
ScoredMemoryset,
|
|
25
|
+
)
|
|
29
26
|
|
|
30
27
|
if TYPE_CHECKING:
|
|
31
28
|
from .classification_model import ClassificationModel
|
|
29
|
+
from .regression_model import RegressionModel
|
|
30
|
+
|
|
31
|
+
TelemetryMode = Literal["off", "on", "sync", "async"]
|
|
32
|
+
"""
|
|
33
|
+
Mode for saving telemetry. One of:
|
|
34
|
+
|
|
35
|
+
- `"off"`: Do not save telemetry
|
|
36
|
+
- `"on"`: Save telemetry asynchronously unless the `ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY` environment variable is set.
|
|
37
|
+
- `"sync"`: Save telemetry synchronously
|
|
38
|
+
- `"async"`: Save telemetry asynchronously
|
|
39
|
+
"""
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _get_telemetry_config(override: TelemetryMode | None = None) -> tuple[bool, bool]:
|
|
43
|
+
return (
|
|
44
|
+
override != "off",
|
|
45
|
+
os.getenv("ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY", "0") != "0" or override == "sync",
|
|
46
|
+
)
|
|
32
47
|
|
|
33
48
|
|
|
34
49
|
def _parse_feedback(feedback: dict[str, Any]) -> PredictionFeedbackRequest:
|
|
@@ -38,12 +53,15 @@ def _parse_feedback(feedback: dict[str, Any]) -> PredictionFeedbackRequest:
|
|
|
38
53
|
prediction_id = feedback.get("prediction_id", None)
|
|
39
54
|
if prediction_id is None:
|
|
40
55
|
raise ValueError("`prediction_id` must be specified")
|
|
41
|
-
|
|
42
|
-
prediction_id
|
|
43
|
-
category_name
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
56
|
+
output: PredictionFeedbackRequest = {
|
|
57
|
+
"prediction_id": prediction_id,
|
|
58
|
+
"category_name": category,
|
|
59
|
+
}
|
|
60
|
+
if "value" in feedback:
|
|
61
|
+
output["value"] = feedback["value"]
|
|
62
|
+
if "comment" in feedback:
|
|
63
|
+
output["comment"] = feedback["comment"]
|
|
64
|
+
return output
|
|
47
65
|
|
|
48
66
|
|
|
49
67
|
class FeedbackCategory:
|
|
@@ -67,11 +85,10 @@ class FeedbackCategory:
|
|
|
67
85
|
created_at: datetime
|
|
68
86
|
|
|
69
87
|
def __init__(self, category: PredictionFeedbackCategory):
|
|
70
|
-
|
|
71
|
-
self.
|
|
72
|
-
self.
|
|
73
|
-
self.
|
|
74
|
-
self.created_at = category.created_at
|
|
88
|
+
self.id = category["id"]
|
|
89
|
+
self.name = category["name"]
|
|
90
|
+
self.value_type = bool if category["type"] == "BINARY" else float
|
|
91
|
+
self.created_at = datetime.fromisoformat(category["created_at"])
|
|
75
92
|
|
|
76
93
|
@classmethod
|
|
77
94
|
def all(cls) -> list[FeedbackCategory]:
|
|
@@ -81,7 +98,8 @@ class FeedbackCategory:
|
|
|
81
98
|
Returns:
|
|
82
99
|
List with information about all existing feedback categories.
|
|
83
100
|
"""
|
|
84
|
-
|
|
101
|
+
client = OrcaClient._resolve_client()
|
|
102
|
+
return [FeedbackCategory(category) for category in client.GET("/telemetry/feedback_category")]
|
|
85
103
|
|
|
86
104
|
@classmethod
|
|
87
105
|
def drop(cls, name: str) -> None:
|
|
@@ -98,113 +116,163 @@ class FeedbackCategory:
|
|
|
98
116
|
Raises:
|
|
99
117
|
LookupError: If the category is not found.
|
|
100
118
|
"""
|
|
101
|
-
|
|
102
|
-
|
|
119
|
+
client = OrcaClient._resolve_client()
|
|
120
|
+
client.DELETE("/telemetry/feedback_category/{name_or_id}", params={"name_or_id": name})
|
|
121
|
+
logging.info(f"Deleted feedback category {name} with all associated feedback")
|
|
103
122
|
|
|
104
123
|
def __repr__(self):
|
|
105
124
|
return "FeedbackCategory({" + f"name: {self.name}, " + f"value_type: {self.value_type}" + "})"
|
|
106
125
|
|
|
107
126
|
|
|
108
|
-
class
|
|
109
|
-
|
|
110
|
-
|
|
127
|
+
class AddMemorySuggestions:
|
|
128
|
+
suggestions: list[tuple[str, str]]
|
|
129
|
+
memoryset_id: str
|
|
130
|
+
model_id: str
|
|
131
|
+
prediction_id: str
|
|
111
132
|
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
memory_lookups: List of memories used to ground the prediction
|
|
118
|
-
input_value: Input value that this prediction was for
|
|
119
|
-
model: Model that was used to make the prediction
|
|
120
|
-
memoryset: Memoryset that was used to lookup memories to ground the prediction
|
|
121
|
-
expected_label: Optional expected label that was set for the prediction
|
|
122
|
-
tags: tags that were set for the prediction
|
|
123
|
-
feedback: Feedback recorded, mapping from category name to value
|
|
124
|
-
"""
|
|
133
|
+
def __init__(self, suggestions: list[tuple[str, str]], memoryset_id: str, model_id: str, prediction_id: str):
|
|
134
|
+
self.suggestions = suggestions
|
|
135
|
+
self.memoryset_id = memoryset_id
|
|
136
|
+
self.model_id = model_id
|
|
137
|
+
self.prediction_id = prediction_id
|
|
125
138
|
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
139
|
+
def __repr__(self):
|
|
140
|
+
return (
|
|
141
|
+
"AddMemorySuggestions({"
|
|
142
|
+
+ f"suggestions: {self.suggestions}, "
|
|
143
|
+
+ f"memoryset_id: {self.memoryset_id}, "
|
|
144
|
+
+ f"model_id: {self.model_id}, "
|
|
145
|
+
+ f"prediction_id: {self.prediction_id}"
|
|
146
|
+
+ "})"
|
|
147
|
+
)
|
|
148
|
+
|
|
149
|
+
def apply(self) -> None:
|
|
150
|
+
memoryset = LabeledMemoryset.open(self.memoryset_id)
|
|
151
|
+
label_name_to_label = {label_name: label for label, label_name in enumerate(memoryset.label_names)}
|
|
152
|
+
memoryset.insert(
|
|
153
|
+
[{"value": suggestion[0], "label": label_name_to_label[suggestion[1]]} for suggestion in self.suggestions]
|
|
154
|
+
)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class PredictionBase(ABC):
|
|
158
|
+
prediction_id: str | None
|
|
129
159
|
confidence: float
|
|
130
|
-
|
|
131
|
-
model: ClassificationModel
|
|
160
|
+
anomaly_score: float | None
|
|
132
161
|
|
|
133
162
|
def __init__(
|
|
134
163
|
self,
|
|
135
|
-
prediction_id: str,
|
|
164
|
+
prediction_id: str | None,
|
|
136
165
|
*,
|
|
137
|
-
label: int,
|
|
166
|
+
label: int | None,
|
|
138
167
|
label_name: str | None,
|
|
168
|
+
score: float | None,
|
|
139
169
|
confidence: float,
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
170
|
+
anomaly_score: float | None,
|
|
171
|
+
memoryset: LabeledMemoryset | ScoredMemoryset,
|
|
172
|
+
model: ClassificationModel | RegressionModel,
|
|
173
|
+
telemetry: LabelPredictionWithMemoriesAndFeedback | ScorePredictionWithMemoriesAndFeedback | None = None,
|
|
174
|
+
logits: list[float] | None = None,
|
|
175
|
+
input_value: str | None = None,
|
|
143
176
|
):
|
|
144
|
-
# for internal use only, do not document
|
|
145
|
-
from .classification_model import ClassificationModel
|
|
146
|
-
|
|
147
177
|
self.prediction_id = prediction_id
|
|
148
178
|
self.label = label
|
|
149
179
|
self.label_name = label_name
|
|
180
|
+
self.score = score
|
|
150
181
|
self.confidence = confidence
|
|
151
|
-
self.
|
|
152
|
-
self.
|
|
182
|
+
self.anomaly_score = anomaly_score
|
|
183
|
+
self.memoryset = memoryset
|
|
184
|
+
self.model = model
|
|
153
185
|
self.__telemetry = telemetry if telemetry else None
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
return (
|
|
157
|
-
"LabelPrediction({"
|
|
158
|
-
+ f"label: <{self.label_name}: {self.label}>, "
|
|
159
|
-
+ f"confidence: {self.confidence:.2f}, "
|
|
160
|
-
+ f"input_value: '{str(self.input_value)[:100] + '...' if len(str(self.input_value)) > 100 else self.input_value}'"
|
|
161
|
-
+ "})"
|
|
162
|
-
)
|
|
186
|
+
self.logits = logits
|
|
187
|
+
self._input_value = input_value
|
|
163
188
|
|
|
164
189
|
@property
|
|
165
|
-
def _telemetry(self) -> LabelPredictionWithMemoriesAndFeedback:
|
|
190
|
+
def _telemetry(self) -> LabelPredictionWithMemoriesAndFeedback | ScorePredictionWithMemoriesAndFeedback:
|
|
166
191
|
# for internal use only, do not document
|
|
167
192
|
if self.__telemetry is None:
|
|
168
|
-
self.
|
|
193
|
+
if self.prediction_id is None:
|
|
194
|
+
raise ValueError("Cannot fetch telemetry with no prediction ID")
|
|
195
|
+
client = OrcaClient._resolve_client()
|
|
196
|
+
self.__telemetry = client.GET(
|
|
197
|
+
"/telemetry/prediction/{prediction_id}", params={"prediction_id": self.prediction_id}
|
|
198
|
+
)
|
|
169
199
|
return self.__telemetry
|
|
170
200
|
|
|
171
201
|
@property
|
|
172
|
-
def
|
|
173
|
-
|
|
202
|
+
def input_value(self) -> str:
|
|
203
|
+
if self._input_value is not None:
|
|
204
|
+
return self._input_value
|
|
205
|
+
assert isinstance(self._telemetry["input_value"], str)
|
|
206
|
+
return self._telemetry["input_value"]
|
|
174
207
|
|
|
175
208
|
@property
|
|
176
|
-
def
|
|
177
|
-
|
|
209
|
+
def memory_lookups(self) -> list[LabeledMemoryLookup] | list[ScoredMemoryLookup]:
|
|
210
|
+
if "label" in self._telemetry:
|
|
211
|
+
return [
|
|
212
|
+
LabeledMemoryLookup(self._telemetry["memoryset_id"], lookup) for lookup in self._telemetry["memories"]
|
|
213
|
+
]
|
|
214
|
+
else:
|
|
215
|
+
return [
|
|
216
|
+
ScoredMemoryLookup(self._telemetry["memoryset_id"], lookup) for lookup in self._telemetry["memories"]
|
|
217
|
+
]
|
|
178
218
|
|
|
179
219
|
@property
|
|
180
220
|
def feedback(self) -> dict[str, bool | float]:
|
|
181
221
|
return {
|
|
182
|
-
f
|
|
183
|
-
f
|
|
222
|
+
f["category_name"]: (
|
|
223
|
+
f["value"] if f["category_type"] == "CONTINUOUS" else True if f["value"] == 1 else False
|
|
184
224
|
)
|
|
185
|
-
for f in self._telemetry
|
|
225
|
+
for f in self._telemetry["feedbacks"]
|
|
186
226
|
}
|
|
187
227
|
|
|
188
228
|
@property
|
|
189
|
-
def
|
|
190
|
-
return self._telemetry
|
|
229
|
+
def tags(self) -> set[str]:
|
|
230
|
+
return set(self._telemetry["tags"])
|
|
191
231
|
|
|
192
232
|
@property
|
|
193
|
-
def
|
|
194
|
-
|
|
233
|
+
def explanation(self) -> str:
|
|
234
|
+
if self._telemetry["explanation"] is None:
|
|
235
|
+
client = OrcaClient._resolve_client()
|
|
236
|
+
self._telemetry["explanation"] = client.GET(
|
|
237
|
+
"/telemetry/prediction/{prediction_id}/explanation",
|
|
238
|
+
params={"prediction_id": self._telemetry["prediction_id"]},
|
|
239
|
+
parse_as="text",
|
|
240
|
+
timeout=30,
|
|
241
|
+
)
|
|
242
|
+
return self._telemetry["explanation"]
|
|
243
|
+
|
|
244
|
+
def explain(self, refresh: bool = False) -> None:
|
|
245
|
+
"""
|
|
246
|
+
Print an explanation of the prediction as a stream of text.
|
|
247
|
+
|
|
248
|
+
Params:
|
|
249
|
+
refresh: Force the explanation agent to re-run even if an explanation already exists.
|
|
250
|
+
"""
|
|
251
|
+
if not refresh and self._telemetry["explanation"] is not None:
|
|
252
|
+
print(self._telemetry["explanation"])
|
|
253
|
+
else:
|
|
254
|
+
client = OrcaClient._resolve_client()
|
|
255
|
+
with client.stream(
|
|
256
|
+
"GET",
|
|
257
|
+
f"/telemetry/prediction/{self.prediction_id}/explanation?refresh={refresh}",
|
|
258
|
+
timeout=Timeout(connect=3, read=None),
|
|
259
|
+
) as res:
|
|
260
|
+
for chunk in res.iter_text():
|
|
261
|
+
print(chunk, end="")
|
|
262
|
+
print() # final newline
|
|
195
263
|
|
|
196
264
|
@overload
|
|
197
265
|
@classmethod
|
|
198
|
-
def get(cls, prediction_id: str) ->
|
|
266
|
+
def get(cls, prediction_id: str) -> Self: # type: ignore -- this takes precedence
|
|
199
267
|
pass
|
|
200
268
|
|
|
201
269
|
@overload
|
|
202
270
|
@classmethod
|
|
203
|
-
def get(cls, prediction_id: Iterable[str]) -> list[
|
|
271
|
+
def get(cls, prediction_id: Iterable[str]) -> list[Self]:
|
|
204
272
|
pass
|
|
205
273
|
|
|
206
274
|
@classmethod
|
|
207
|
-
def get(cls, prediction_id: str | Iterable[str]) ->
|
|
275
|
+
def get(cls, prediction_id: str | Iterable[str]) -> Self | list[Self]:
|
|
208
276
|
"""
|
|
209
277
|
Fetch a prediction or predictions
|
|
210
278
|
|
|
@@ -223,6 +291,7 @@ class LabelPrediction:
|
|
|
223
291
|
LabelPrediction({
|
|
224
292
|
label: <positive: 1>,
|
|
225
293
|
confidence: 0.95,
|
|
294
|
+
anomaly_score: 0.1,
|
|
226
295
|
input_value: "I am happy",
|
|
227
296
|
memoryset: "my_memoryset",
|
|
228
297
|
model: "my_model"
|
|
@@ -237,6 +306,7 @@ class LabelPrediction:
|
|
|
237
306
|
LabelPrediction({
|
|
238
307
|
label: <positive: 1>,
|
|
239
308
|
confidence: 0.95,
|
|
309
|
+
anomaly_score: 0.1,
|
|
240
310
|
input_value: "I am happy",
|
|
241
311
|
memoryset: "my_memoryset",
|
|
242
312
|
model: "my_model"
|
|
@@ -244,68 +314,75 @@ class LabelPrediction:
|
|
|
244
314
|
LabelPrediction({
|
|
245
315
|
label: <negative: 0>,
|
|
246
316
|
confidence: 0.05,
|
|
317
|
+
anomaly_score: 0.2,
|
|
247
318
|
input_value: "I am sad",
|
|
248
319
|
memoryset: "my_memoryset", model: "my_model"
|
|
249
320
|
}),
|
|
250
321
|
]
|
|
251
322
|
"""
|
|
252
|
-
|
|
253
|
-
|
|
323
|
+
from .classification_model import ClassificationModel
|
|
324
|
+
from .regression_model import RegressionModel
|
|
325
|
+
|
|
326
|
+
def create_prediction(
|
|
327
|
+
prediction: LabelPredictionWithMemoriesAndFeedback | ScorePredictionWithMemoriesAndFeedback,
|
|
328
|
+
) -> Self:
|
|
329
|
+
|
|
330
|
+
if "label" in prediction:
|
|
331
|
+
memoryset = LabeledMemoryset.open(prediction["memoryset_id"])
|
|
332
|
+
model = ClassificationModel.open(prediction["model_id"])
|
|
333
|
+
else:
|
|
334
|
+
memoryset = ScoredMemoryset.open(prediction["memoryset_id"])
|
|
335
|
+
model = RegressionModel.open(prediction["model_id"])
|
|
336
|
+
|
|
254
337
|
return cls(
|
|
255
|
-
prediction_id=prediction
|
|
256
|
-
label=prediction.label,
|
|
257
|
-
label_name=prediction.label_name,
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
338
|
+
prediction_id=prediction["prediction_id"],
|
|
339
|
+
label=prediction.get("label", None),
|
|
340
|
+
label_name=prediction.get("label_name", None),
|
|
341
|
+
score=prediction.get("score", None),
|
|
342
|
+
confidence=prediction["confidence"],
|
|
343
|
+
anomaly_score=prediction["anomaly_score"],
|
|
344
|
+
memoryset=memoryset,
|
|
345
|
+
model=model,
|
|
261
346
|
telemetry=prediction,
|
|
262
347
|
)
|
|
348
|
+
|
|
349
|
+
client = OrcaClient._resolve_client()
|
|
350
|
+
if isinstance(prediction_id, str):
|
|
351
|
+
return create_prediction(
|
|
352
|
+
client.GET("/telemetry/prediction/{prediction_id}", params={"prediction_id": prediction_id})
|
|
353
|
+
)
|
|
263
354
|
else:
|
|
264
355
|
return [
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
label=prediction.label,
|
|
268
|
-
label_name=prediction.label_name,
|
|
269
|
-
confidence=prediction.confidence,
|
|
270
|
-
memoryset=prediction.memoryset_id,
|
|
271
|
-
model=prediction.model_id,
|
|
272
|
-
telemetry=prediction,
|
|
273
|
-
)
|
|
274
|
-
for prediction in list_predictions(body=ListPredictionsRequest(prediction_ids=list(prediction_id)))
|
|
356
|
+
create_prediction(prediction)
|
|
357
|
+
for prediction in client.POST("/telemetry/prediction", json={"prediction_ids": list(prediction_id)})
|
|
275
358
|
]
|
|
276
359
|
|
|
277
360
|
def refresh(self):
|
|
278
361
|
"""Refresh the prediction data from the OrcaCloud"""
|
|
279
|
-
self.
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
"""Open a UI to inspect the memories used by this prediction"""
|
|
283
|
-
inspect_prediction_result(self)
|
|
362
|
+
if self.prediction_id is None:
|
|
363
|
+
raise ValueError("Cannot refresh prediction with no prediction ID")
|
|
364
|
+
self.__dict__.update(self.get(self.prediction_id).__dict__)
|
|
284
365
|
|
|
285
|
-
def
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
prediction_id=self.prediction_id,
|
|
305
|
-
body=UpdatePredictionRequest(
|
|
306
|
-
expected_label=expected_label if expected_label is not UNSET else CLIENT_UNSET,
|
|
307
|
-
tags=[] if tags is None else list(tags) if tags is not UNSET else CLIENT_UNSET,
|
|
308
|
-
),
|
|
366
|
+
def _update(
|
|
367
|
+
self,
|
|
368
|
+
*,
|
|
369
|
+
tags: set[str] | None = UNSET,
|
|
370
|
+
expected_label: int | None = UNSET,
|
|
371
|
+
expected_score: float | None = UNSET,
|
|
372
|
+
) -> None:
|
|
373
|
+
if self.prediction_id is None:
|
|
374
|
+
raise ValueError("Cannot update prediction with no prediction ID")
|
|
375
|
+
|
|
376
|
+
payload: UpdatePredictionRequest = {}
|
|
377
|
+
if tags is not UNSET:
|
|
378
|
+
payload["tags"] = [] if tags is None else list(tags)
|
|
379
|
+
if expected_label is not UNSET:
|
|
380
|
+
payload["expected_label"] = expected_label
|
|
381
|
+
if expected_score is not UNSET:
|
|
382
|
+
payload["expected_score"] = expected_score
|
|
383
|
+
client = OrcaClient._resolve_client()
|
|
384
|
+
client.PATCH(
|
|
385
|
+
"/telemetry/prediction/{prediction_id}", params={"prediction_id": self.prediction_id}, json=payload
|
|
309
386
|
)
|
|
310
387
|
self.refresh()
|
|
311
388
|
|
|
@@ -316,7 +393,7 @@ class LabelPrediction:
|
|
|
316
393
|
Params:
|
|
317
394
|
tag: Tag to add to the prediction
|
|
318
395
|
"""
|
|
319
|
-
self.
|
|
396
|
+
self._update(tags=self.tags | {tag})
|
|
320
397
|
|
|
321
398
|
def remove_tag(self, tag: str) -> None:
|
|
322
399
|
"""
|
|
@@ -325,7 +402,7 @@ class LabelPrediction:
|
|
|
325
402
|
Params:
|
|
326
403
|
tag: Tag to remove from the prediction
|
|
327
404
|
"""
|
|
328
|
-
self.
|
|
405
|
+
self._update(tags=self.tags - {tag})
|
|
329
406
|
|
|
330
407
|
def record_feedback(
|
|
331
408
|
self,
|
|
@@ -361,12 +438,14 @@ class LabelPrediction:
|
|
|
361
438
|
ValueError: If the value does not match previous value types for the category, or is a
|
|
362
439
|
[`float`][float] that is not between `-1.0` and `+1.0`.
|
|
363
440
|
"""
|
|
364
|
-
|
|
365
|
-
|
|
441
|
+
client = OrcaClient._resolve_client()
|
|
442
|
+
client.PUT(
|
|
443
|
+
"/telemetry/prediction/feedback",
|
|
444
|
+
json=[
|
|
366
445
|
_parse_feedback(
|
|
367
446
|
{"prediction_id": self.prediction_id, "category": category, "value": value, "comment": comment}
|
|
368
447
|
)
|
|
369
|
-
]
|
|
448
|
+
],
|
|
370
449
|
)
|
|
371
450
|
self.refresh()
|
|
372
451
|
|
|
@@ -380,7 +459,245 @@ class LabelPrediction:
|
|
|
380
459
|
Raises:
|
|
381
460
|
ValueError: If the category is not found.
|
|
382
461
|
"""
|
|
383
|
-
|
|
384
|
-
|
|
462
|
+
if self.prediction_id is None:
|
|
463
|
+
raise ValueError("Cannot delete feedback with no prediction ID")
|
|
464
|
+
|
|
465
|
+
client = OrcaClient._resolve_client()
|
|
466
|
+
client.PUT(
|
|
467
|
+
"/telemetry/prediction/feedback",
|
|
468
|
+
json=[PredictionFeedbackRequest(prediction_id=self.prediction_id, category_name=category, value=None)],
|
|
385
469
|
)
|
|
386
470
|
self.refresh()
|
|
471
|
+
|
|
472
|
+
def inspect(self) -> None:
|
|
473
|
+
"""
|
|
474
|
+
Display an interactive UI with the details about this prediction
|
|
475
|
+
|
|
476
|
+
Note:
|
|
477
|
+
This method is only available in Jupyter notebooks.
|
|
478
|
+
"""
|
|
479
|
+
from ._utils.prediction_result_ui import inspect_prediction_result
|
|
480
|
+
|
|
481
|
+
inspect_prediction_result(self)
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
class ClassificationPrediction(PredictionBase):
|
|
485
|
+
"""
|
|
486
|
+
Labeled prediction result from a [`ClassificationModel`][orca_sdk.ClassificationModel]
|
|
487
|
+
|
|
488
|
+
Attributes:
|
|
489
|
+
prediction_id: Unique identifier of this prediction used for feedback
|
|
490
|
+
label: Label predicted by the model
|
|
491
|
+
label_name: Human-readable name of the label
|
|
492
|
+
confidence: Confidence of the prediction
|
|
493
|
+
anomaly_score: Anomaly score of the input
|
|
494
|
+
input_value: The input value used for the prediction
|
|
495
|
+
expected_label: Expected label for the prediction, useful when evaluating the model
|
|
496
|
+
expected_label_name: Human-readable name of the expected label
|
|
497
|
+
memory_lookups: Memories used by the model to make the prediction
|
|
498
|
+
explanation: Natural language explanation of the prediction, only available if the model
|
|
499
|
+
has the Explain API enabled
|
|
500
|
+
tags: Tags for the prediction, useful for filtering and grouping predictions
|
|
501
|
+
model: Model used to make the prediction
|
|
502
|
+
memoryset: Memoryset that was used to lookup memories to ground the prediction
|
|
503
|
+
"""
|
|
504
|
+
|
|
505
|
+
label: int
|
|
506
|
+
label_name: str
|
|
507
|
+
logits: list[float] | None
|
|
508
|
+
model: ClassificationModel
|
|
509
|
+
memoryset: LabeledMemoryset
|
|
510
|
+
|
|
511
|
+
def __repr__(self):
|
|
512
|
+
return (
|
|
513
|
+
"ClassificationPrediction({"
|
|
514
|
+
+ f"label: <{self.label_name}: {self.label}>, "
|
|
515
|
+
+ f"confidence: {self.confidence:.2f}, "
|
|
516
|
+
+ (f"anomaly_score: {self.anomaly_score:.2f}, " if self.anomaly_score is not None else "")
|
|
517
|
+
+ f"input_value: '{str(self.input_value)[:100] + '...' if len(str(self.input_value)) > 100 else self.input_value}'"
|
|
518
|
+
+ "})"
|
|
519
|
+
)
|
|
520
|
+
|
|
521
|
+
@property
|
|
522
|
+
def memory_lookups(self) -> list[LabeledMemoryLookup]:
|
|
523
|
+
assert "label" in self._telemetry
|
|
524
|
+
return [LabeledMemoryLookup(self._telemetry["memoryset_id"], lookup) for lookup in self._telemetry["memories"]]
|
|
525
|
+
|
|
526
|
+
@property
|
|
527
|
+
def expected_label(self) -> int | None:
|
|
528
|
+
assert "label" in self._telemetry
|
|
529
|
+
return self._telemetry["expected_label"]
|
|
530
|
+
|
|
531
|
+
@property
|
|
532
|
+
def expected_label_name(self) -> str | None:
|
|
533
|
+
assert "label" in self._telemetry
|
|
534
|
+
return self._telemetry["expected_label_name"]
|
|
535
|
+
|
|
536
|
+
def update(
|
|
537
|
+
self,
|
|
538
|
+
*,
|
|
539
|
+
tags: set[str] | None = UNSET,
|
|
540
|
+
expected_label: int | None = UNSET,
|
|
541
|
+
) -> None:
|
|
542
|
+
"""
|
|
543
|
+
Update the prediction.
|
|
544
|
+
|
|
545
|
+
Note:
|
|
546
|
+
If a field is not provided, it will default to [UNSET][orca_sdk.UNSET] and not be updated.
|
|
547
|
+
|
|
548
|
+
Params:
|
|
549
|
+
tags: New tags to set for the prediction. Set to `None` to remove all tags.
|
|
550
|
+
expected_label: New expected label to set for the prediction. Set to `None` to remove.
|
|
551
|
+
"""
|
|
552
|
+
self._update(tags=tags, expected_label=expected_label)
|
|
553
|
+
|
|
554
|
+
def recommend_action(self, *, refresh: bool = False) -> tuple[str, str]:
|
|
555
|
+
"""
|
|
556
|
+
Get an action recommendation for improving this prediction.
|
|
557
|
+
|
|
558
|
+
Analyzes the prediction and suggests the most effective action to improve model
|
|
559
|
+
performance, such as adding memories, detecting mislabels, removing duplicates,
|
|
560
|
+
or finetuning.
|
|
561
|
+
|
|
562
|
+
Params:
|
|
563
|
+
refresh: Force the action recommendation agent to re-run even if a recommendation already exists
|
|
564
|
+
|
|
565
|
+
Returns:
|
|
566
|
+
Tuple of (action, rationale) where:
|
|
567
|
+
- action: The recommended action ("add_memories", "detect_mislabels", "remove_duplicates", or "finetuning") that would resolve the mislabeling
|
|
568
|
+
- rationale: Explanation for why this action was recommended
|
|
569
|
+
|
|
570
|
+
Raises:
|
|
571
|
+
ValueError: If the prediction has no prediction ID
|
|
572
|
+
RuntimeError: If the lighthouse API key is not configured
|
|
573
|
+
|
|
574
|
+
Examples:
|
|
575
|
+
Get action recommendation for an incorrect prediction:
|
|
576
|
+
>>> action, rationale = prediction.recommend_action()
|
|
577
|
+
>>> print(f"Recommended action: {action}")
|
|
578
|
+
>>> print(f"Rationale: {rationale}")
|
|
579
|
+
"""
|
|
580
|
+
if self.prediction_id is None:
|
|
581
|
+
raise ValueError("Cannot get action recommendation with no prediction ID")
|
|
582
|
+
|
|
583
|
+
client = OrcaClient._resolve_client()
|
|
584
|
+
response = client.GET(
|
|
585
|
+
"/telemetry/prediction/{prediction_id}/action",
|
|
586
|
+
params={"prediction_id": self.prediction_id},
|
|
587
|
+
timeout=30,
|
|
588
|
+
)
|
|
589
|
+
return (response["action"], response["rationale"])
|
|
590
|
+
|
|
591
|
+
def generate_memory_suggestions(self, *, num_memories: int = 3) -> AddMemorySuggestions:
|
|
592
|
+
"""
|
|
593
|
+
Generate synthetic memory suggestions to improve this prediction.
|
|
594
|
+
|
|
595
|
+
Creates new example memories that are similar to the input but have clearer
|
|
596
|
+
signals for the expected label. These can be added to the memoryset to improve
|
|
597
|
+
model performance on similar inputs.
|
|
598
|
+
|
|
599
|
+
Params:
|
|
600
|
+
num_memories: Number of memory suggestions to generate (default: 3)
|
|
601
|
+
|
|
602
|
+
Returns:
|
|
603
|
+
List of dictionaries that can be directly passed to memoryset.insert().
|
|
604
|
+
Each dictionary contains:
|
|
605
|
+
- "value": The suggested memory text
|
|
606
|
+
- "label": The suggested label as an integer
|
|
607
|
+
|
|
608
|
+
Raises:
|
|
609
|
+
ValueError: If the prediction has no prediction ID
|
|
610
|
+
RuntimeError: If the lighthouse API key is not configured
|
|
611
|
+
|
|
612
|
+
Examples:
|
|
613
|
+
Generate memory suggestions for an incorrect prediction:
|
|
614
|
+
>>> suggestions = prediction.generate_memory_suggestions(num_memories=3)
|
|
615
|
+
>>> for suggestion in suggestions:
|
|
616
|
+
... print(f"Value: {suggestion['value']}, Label: {suggestion['label']}")
|
|
617
|
+
>>>
|
|
618
|
+
>>> # Add suggestions directly to memoryset
|
|
619
|
+
>>> model.memoryset.insert(suggestions)
|
|
620
|
+
"""
|
|
621
|
+
if self.prediction_id is None:
|
|
622
|
+
raise ValueError("Cannot generate memory suggestions with no prediction ID")
|
|
623
|
+
|
|
624
|
+
client = OrcaClient._resolve_client()
|
|
625
|
+
response = client.GET(
|
|
626
|
+
"/telemetry/prediction/{prediction_id}/memory_suggestions",
|
|
627
|
+
params={"prediction_id": self.prediction_id, "num_memories": num_memories},
|
|
628
|
+
timeout=30,
|
|
629
|
+
)
|
|
630
|
+
|
|
631
|
+
return AddMemorySuggestions(
|
|
632
|
+
suggestions=[(m["value"], m["label_name"]) for m in response["memories"]],
|
|
633
|
+
memoryset_id=self.memoryset.id,
|
|
634
|
+
model_id=self.model.id,
|
|
635
|
+
prediction_id=self.prediction_id,
|
|
636
|
+
)
|
|
637
|
+
|
|
638
|
+
|
|
639
|
+
class RegressionPrediction(PredictionBase):
|
|
640
|
+
"""
|
|
641
|
+
Score-based prediction result from a [`RegressionModel`][orca_sdk.RegressionModel]
|
|
642
|
+
|
|
643
|
+
Attributes:
|
|
644
|
+
prediction_id: Unique identifier of this prediction used for feedback
|
|
645
|
+
score: Score predicted by the model
|
|
646
|
+
confidence: Confidence of the prediction
|
|
647
|
+
anomaly_score: Anomaly score of the input
|
|
648
|
+
input_value: The input value used for the prediction
|
|
649
|
+
expected_score: Expected score for the prediction, useful when evaluating the model
|
|
650
|
+
memory_lookups: Memories used by the model to make the prediction
|
|
651
|
+
explanation: Natural language explanation of the prediction, only available if the model
|
|
652
|
+
has the Explain API enabled
|
|
653
|
+
tags: Tags for the prediction, useful for filtering and grouping predictions
|
|
654
|
+
model: Model used to make the prediction
|
|
655
|
+
memoryset: Memoryset that was used to lookup memories to ground the prediction
|
|
656
|
+
"""
|
|
657
|
+
|
|
658
|
+
score: float
|
|
659
|
+
model: RegressionModel
|
|
660
|
+
memoryset: ScoredMemoryset
|
|
661
|
+
|
|
662
|
+
def __repr__(self):
|
|
663
|
+
return (
|
|
664
|
+
"RegressionPrediction({"
|
|
665
|
+
+ f"score: {self.score:.2f}, "
|
|
666
|
+
+ f"confidence: {self.confidence:.2f}, "
|
|
667
|
+
+ (f"anomaly_score: {self.anomaly_score:.2f}, " if self.anomaly_score is not None else "")
|
|
668
|
+
+ f"input_value: '{str(self.input_value)[:100] + '...' if len(str(self.input_value)) > 100 else self.input_value}'"
|
|
669
|
+
+ "})"
|
|
670
|
+
)
|
|
671
|
+
|
|
672
|
+
@property
|
|
673
|
+
def memory_lookups(self) -> list[ScoredMemoryLookup]:
|
|
674
|
+
assert "score" in self._telemetry
|
|
675
|
+
return [ScoredMemoryLookup(self._telemetry["memoryset_id"], lookup) for lookup in self._telemetry["memories"]]
|
|
676
|
+
|
|
677
|
+
@property
|
|
678
|
+
def explanation(self) -> str:
|
|
679
|
+
"""The explanation for this prediction. Requires `lighthouse_client_api_key` to be set."""
|
|
680
|
+
raise NotImplementedError("Explanation is not supported for regression predictions")
|
|
681
|
+
|
|
682
|
+
@property
|
|
683
|
+
def expected_score(self) -> float | None:
|
|
684
|
+
assert "score" in self._telemetry
|
|
685
|
+
return self._telemetry["expected_score"]
|
|
686
|
+
|
|
687
|
+
def update(
|
|
688
|
+
self,
|
|
689
|
+
*,
|
|
690
|
+
tags: set[str] | None = UNSET,
|
|
691
|
+
expected_score: float | None = UNSET,
|
|
692
|
+
) -> None:
|
|
693
|
+
"""
|
|
694
|
+
Update the prediction.
|
|
695
|
+
|
|
696
|
+
Note:
|
|
697
|
+
If a field is not provided, it will default to [UNSET][orca_sdk.UNSET] and not be updated.
|
|
698
|
+
|
|
699
|
+
Params:
|
|
700
|
+
tags: New tags to set for the prediction. Set to `None` to remove all tags.
|
|
701
|
+
expected_score: New expected score to set for the prediction. Set to `None` to remove.
|
|
702
|
+
"""
|
|
703
|
+
self._update(tags=tags, expected_score=expected_score)
|