orca-sdk 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +30 -0
- orca_sdk/_shared/__init__.py +10 -0
- orca_sdk/_shared/metrics.py +634 -0
- orca_sdk/_shared/metrics_test.py +570 -0
- orca_sdk/_utils/__init__.py +0 -0
- orca_sdk/_utils/analysis_ui.py +196 -0
- orca_sdk/_utils/analysis_ui_style.css +51 -0
- orca_sdk/_utils/auth.py +65 -0
- orca_sdk/_utils/auth_test.py +31 -0
- orca_sdk/_utils/common.py +37 -0
- orca_sdk/_utils/data_parsing.py +129 -0
- orca_sdk/_utils/data_parsing_test.py +244 -0
- orca_sdk/_utils/pagination.py +126 -0
- orca_sdk/_utils/pagination_test.py +132 -0
- orca_sdk/_utils/prediction_result_ui.css +18 -0
- orca_sdk/_utils/prediction_result_ui.py +110 -0
- orca_sdk/_utils/tqdm_file_reader.py +12 -0
- orca_sdk/_utils/value_parser.py +45 -0
- orca_sdk/_utils/value_parser_test.py +39 -0
- orca_sdk/async_client.py +4104 -0
- orca_sdk/classification_model.py +1165 -0
- orca_sdk/classification_model_test.py +887 -0
- orca_sdk/client.py +4096 -0
- orca_sdk/conftest.py +382 -0
- orca_sdk/credentials.py +217 -0
- orca_sdk/credentials_test.py +121 -0
- orca_sdk/datasource.py +576 -0
- orca_sdk/datasource_test.py +463 -0
- orca_sdk/embedding_model.py +712 -0
- orca_sdk/embedding_model_test.py +206 -0
- orca_sdk/job.py +343 -0
- orca_sdk/job_test.py +108 -0
- orca_sdk/memoryset.py +3811 -0
- orca_sdk/memoryset_test.py +1150 -0
- orca_sdk/regression_model.py +841 -0
- orca_sdk/regression_model_test.py +595 -0
- orca_sdk/telemetry.py +742 -0
- orca_sdk/telemetry_test.py +119 -0
- orca_sdk-0.1.9.dist-info/METADATA +98 -0
- orca_sdk-0.1.9.dist-info/RECORD +41 -0
- orca_sdk-0.1.9.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,1165 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from contextlib import contextmanager
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from typing import Any, Generator, Iterable, Literal, cast, overload
|
|
7
|
+
|
|
8
|
+
from datasets import Dataset
|
|
9
|
+
|
|
10
|
+
from ._shared.metrics import ClassificationMetrics, calculate_classification_metrics
|
|
11
|
+
from ._utils.common import UNSET, CreateMode, DropMode
|
|
12
|
+
from .async_client import OrcaAsyncClient
|
|
13
|
+
from .client import (
|
|
14
|
+
BootstrapClassificationModelMeta,
|
|
15
|
+
BootstrapLabeledMemoryDataResult,
|
|
16
|
+
ClassificationModelMetadata,
|
|
17
|
+
ClassificationPredictionRequest,
|
|
18
|
+
ListPredictionsRequest,
|
|
19
|
+
OrcaClient,
|
|
20
|
+
PredictiveModelUpdate,
|
|
21
|
+
RACHeadType,
|
|
22
|
+
)
|
|
23
|
+
from .datasource import Datasource
|
|
24
|
+
from .job import Job
|
|
25
|
+
from .memoryset import (
|
|
26
|
+
FilterItem,
|
|
27
|
+
FilterItemTuple,
|
|
28
|
+
LabeledMemoryset,
|
|
29
|
+
_is_metric_column,
|
|
30
|
+
_parse_filter_item_from_tuple,
|
|
31
|
+
)
|
|
32
|
+
from .telemetry import (
|
|
33
|
+
ClassificationPrediction,
|
|
34
|
+
TelemetryMode,
|
|
35
|
+
_get_telemetry_config,
|
|
36
|
+
_parse_feedback,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class BootstrappedClassificationModel:
    """
    Handles to the resources referenced by a bootstrap classification model result

    Attributes:
        datasource: Datasource opened from the id in the `datasource_meta` entry of the metadata
        memoryset: Memoryset opened from the id in the `memoryset_meta` entry of the metadata
        classification_model: Model opened from the id in the `model_meta` entry of the metadata
        agent_output: The `agent_output` entry of the metadata, stored as received
    """

    datasource: Datasource | None
    memoryset: LabeledMemoryset | None
    classification_model: ClassificationModel | None
    agent_output: BootstrapLabeledMemoryDataResult | None

    def __init__(self, metadata: BootstrapClassificationModelMeta):
        # Resources are opened one after another: datasource, memoryset, model.
        datasource_id = metadata["datasource_meta"]["id"]
        self.datasource = Datasource.open(datasource_id)

        memoryset_id = metadata["memoryset_meta"]["id"]
        self.memoryset = LabeledMemoryset.open(memoryset_id)

        model_id = metadata["model_meta"]["id"]
        self.classification_model = ClassificationModel.open(model_id)

        self.agent_output = metadata["agent_output"]

    def __repr__(self):
        # One "<label>: <value>," line per attribute, in declaration order.
        fields = (
            ("datasource", self.datasource),
            ("memoryset", self.memoryset),
            ("classification_model", self.classification_model),
            ("agent_output", self.agent_output),
        )
        body = "".join(f" {label}: {value},\n" for label, value in fields)
        return "BootstrappedClassificationModel({\n" + body + "})"
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class ClassificationModel:
|
|
65
|
+
"""
|
|
66
|
+
A handle to a classification model in OrcaCloud
|
|
67
|
+
|
|
68
|
+
Attributes:
|
|
69
|
+
id: Unique identifier for the model
|
|
70
|
+
name: Unique name of the model
|
|
71
|
+
description: Optional description of the model
|
|
72
|
+
memoryset: Memoryset that the model uses
|
|
73
|
+
head_type: Classification head type of the model
|
|
74
|
+
num_classes: Number of distinct classes the model can predict
|
|
75
|
+
memory_lookup_count: Number of memories the model uses for each prediction
|
|
76
|
+
weigh_memories: If using a KNN head, whether the model weighs memories by their lookup score
|
|
77
|
+
min_memory_weight: If using a KNN head, minimum lookup score memories have to be over to not be ignored
|
|
78
|
+
locked: Whether the model is locked to prevent accidental deletion
|
|
79
|
+
created_at: When the model was created
|
|
80
|
+
"""
|
|
81
|
+
|
|
82
|
+
id: str
|
|
83
|
+
name: str
|
|
84
|
+
description: str | None
|
|
85
|
+
memoryset: LabeledMemoryset
|
|
86
|
+
head_type: RACHeadType
|
|
87
|
+
num_classes: int
|
|
88
|
+
memory_lookup_count: int
|
|
89
|
+
weigh_memories: bool | None
|
|
90
|
+
min_memory_weight: float | None
|
|
91
|
+
version: int
|
|
92
|
+
locked: bool
|
|
93
|
+
created_at: datetime
|
|
94
|
+
|
|
95
|
+
def __init__(self, metadata: ClassificationModelMetadata):
    # for internal use only, do not document

    # Identity and descriptive fields are copied from the metadata unchanged.
    self.id = metadata["id"]
    self.name = metadata["name"]
    self.description = metadata["description"]

    # The memoryset handle is opened from the id stored in the metadata.
    attached_memoryset_id = metadata["memoryset_id"]
    self.memoryset = LabeledMemoryset.open(attached_memoryset_id)

    # Model configuration, copied from the metadata unchanged.
    self.head_type = metadata["head_type"]
    self.num_classes = metadata["num_classes"]
    self.memory_lookup_count = metadata["memory_lookup_count"]
    self.weigh_memories = metadata["weigh_memories"]
    self.min_memory_weight = metadata["min_memory_weight"]
    self.version = metadata["version"]
    self.locked = metadata["locked"]

    # The creation timestamp arrives as an ISO-format string.
    created_at_iso = metadata["created_at"]
    self.created_at = datetime.fromisoformat(created_at_iso)

    # Local per-handle state; none of it is read from the metadata.
    self._memoryset_override_id: str | None = None
    self._last_prediction: ClassificationPrediction | None = None
    self._last_prediction_was_batch: bool = False
|
|
113
|
+
|
|
114
|
+
def __eq__(self, other) -> bool:
|
|
115
|
+
return isinstance(other, ClassificationModel) and self.id == other.id
|
|
116
|
+
|
|
117
|
+
def __repr__(self):
|
|
118
|
+
memoryset_repr = self.memoryset.__repr__().replace("\n", "\n ")
|
|
119
|
+
return (
|
|
120
|
+
"ClassificationModel({\n"
|
|
121
|
+
f" name: '{self.name}',\n"
|
|
122
|
+
f" head_type: {self.head_type},\n"
|
|
123
|
+
f" num_classes: {self.num_classes},\n"
|
|
124
|
+
f" memory_lookup_count: {self.memory_lookup_count},\n"
|
|
125
|
+
f" memoryset: {memoryset_repr},\n"
|
|
126
|
+
"})"
|
|
127
|
+
)
|
|
128
|
+
|
|
129
|
+
@property
|
|
130
|
+
def last_prediction(self) -> ClassificationPrediction:
|
|
131
|
+
"""
|
|
132
|
+
Last prediction made by the model
|
|
133
|
+
|
|
134
|
+
Note:
|
|
135
|
+
If the last prediction was part of a batch prediction, the last prediction from the
|
|
136
|
+
batch is returned. If no prediction has been made yet, a [`LookupError`][LookupError]
|
|
137
|
+
is raised.
|
|
138
|
+
"""
|
|
139
|
+
if self._last_prediction_was_batch:
|
|
140
|
+
logging.warning(
|
|
141
|
+
"Last prediction was part of a batch prediction, returning the last prediction from the batch"
|
|
142
|
+
)
|
|
143
|
+
if self._last_prediction is None:
|
|
144
|
+
raise LookupError("No prediction has been made yet")
|
|
145
|
+
return self._last_prediction
|
|
146
|
+
|
|
147
|
+
@classmethod
|
|
148
|
+
def create(
|
|
149
|
+
cls,
|
|
150
|
+
name: str,
|
|
151
|
+
memoryset: LabeledMemoryset,
|
|
152
|
+
head_type: RACHeadType = "KNN",
|
|
153
|
+
*,
|
|
154
|
+
description: str | None = None,
|
|
155
|
+
num_classes: int | None = None,
|
|
156
|
+
memory_lookup_count: int | None = None,
|
|
157
|
+
weigh_memories: bool = True,
|
|
158
|
+
min_memory_weight: float | None = None,
|
|
159
|
+
if_exists: CreateMode = "error",
|
|
160
|
+
) -> ClassificationModel:
|
|
161
|
+
"""
|
|
162
|
+
Create a new classification model
|
|
163
|
+
|
|
164
|
+
Params:
|
|
165
|
+
name: Name for the new model (must be unique)
|
|
166
|
+
memoryset: Memoryset to attach the model to
|
|
167
|
+
head_type: Type of model head to use
|
|
168
|
+
num_classes: Number of classes this model can predict, will be inferred from memoryset if not specified
|
|
169
|
+
memory_lookup_count: Number of memories to lookup for each prediction,
|
|
170
|
+
by default the system uses a simple heuristic to choose a number of memories that works well in most cases
|
|
171
|
+
weigh_memories: If using a KNN head, whether the model weighs memories by their lookup score
|
|
172
|
+
min_memory_weight: If using a KNN head, minimum lookup score memories have to be over to not be ignored
|
|
173
|
+
if_exists: What to do if a model with the same name already exists, defaults to
|
|
174
|
+
`"error"`. Other option is `"open"` to open the existing model.
|
|
175
|
+
description: Optional description for the model, this will be used in agentic flows,
|
|
176
|
+
so make sure it is concise and describes the purpose of your model.
|
|
177
|
+
|
|
178
|
+
Returns:
|
|
179
|
+
Handle to the new model in the OrcaCloud
|
|
180
|
+
|
|
181
|
+
Raises:
|
|
182
|
+
ValueError: If the model already exists and if_exists is `"error"` or if it is
|
|
183
|
+
`"open"` and the existing model has different attributes.
|
|
184
|
+
|
|
185
|
+
Examples:
|
|
186
|
+
Create a new model using default options:
|
|
187
|
+
>>> model = ClassificationModel.create(
|
|
188
|
+
... "my_model",
|
|
189
|
+
... LabeledMemoryset.open("my_memoryset"),
|
|
190
|
+
... )
|
|
191
|
+
|
|
192
|
+
Create a new model with non-default model head and options:
|
|
193
|
+
>>> model = ClassificationModel.create(
|
|
194
|
+
... name="my_model",
|
|
195
|
+
... memoryset=LabeledMemoryset.open("my_memoryset"),
|
|
196
|
+
... head_type=RACHeadType.MMOE,
|
|
197
|
+
... num_classes=5,
|
|
198
|
+
... memory_lookup_count=20,
|
|
199
|
+
... )
|
|
200
|
+
"""
|
|
201
|
+
if cls.exists(name):
|
|
202
|
+
if if_exists == "error":
|
|
203
|
+
raise ValueError(f"Model with name {name} already exists")
|
|
204
|
+
elif if_exists == "open":
|
|
205
|
+
existing = cls.open(name)
|
|
206
|
+
for attribute in {
|
|
207
|
+
"head_type",
|
|
208
|
+
"memory_lookup_count",
|
|
209
|
+
"num_classes",
|
|
210
|
+
"min_memory_weight",
|
|
211
|
+
}:
|
|
212
|
+
local_attribute = locals()[attribute]
|
|
213
|
+
existing_attribute = getattr(existing, attribute)
|
|
214
|
+
if local_attribute is not None and local_attribute != existing_attribute:
|
|
215
|
+
raise ValueError(f"Model with name {name} already exists with different {attribute}")
|
|
216
|
+
|
|
217
|
+
# special case for memoryset
|
|
218
|
+
if existing.memoryset.id != memoryset.id:
|
|
219
|
+
raise ValueError(f"Model with name {name} already exists with different memoryset")
|
|
220
|
+
|
|
221
|
+
return existing
|
|
222
|
+
|
|
223
|
+
client = OrcaClient._resolve_client()
|
|
224
|
+
metadata = client.POST(
|
|
225
|
+
"/classification_model",
|
|
226
|
+
json={
|
|
227
|
+
"name": name,
|
|
228
|
+
"memoryset_name_or_id": memoryset.id,
|
|
229
|
+
"head_type": head_type,
|
|
230
|
+
"memory_lookup_count": memory_lookup_count,
|
|
231
|
+
"num_classes": num_classes,
|
|
232
|
+
"weigh_memories": weigh_memories,
|
|
233
|
+
"min_memory_weight": min_memory_weight,
|
|
234
|
+
"description": description,
|
|
235
|
+
},
|
|
236
|
+
)
|
|
237
|
+
return cls(metadata)
|
|
238
|
+
|
|
239
|
+
@classmethod
def open(cls, name: str) -> ClassificationModel:
    """
    Open an existing classification model in the OrcaCloud

    Params:
        name: Name or unique identifier of the classification model

    Returns:
        Handle to the existing classification model in the OrcaCloud

    Raises:
        LookupError: If the classification model does not exist
    """
    metadata = OrcaClient._resolve_client().GET(
        "/classification_model/{name_or_id}",
        params={"name_or_id": name},
    )
    return cls(metadata)
|
|
255
|
+
|
|
256
|
+
@classmethod
|
|
257
|
+
def exists(cls, name_or_id: str) -> bool:
|
|
258
|
+
"""
|
|
259
|
+
Check if a classification model exists in the OrcaCloud
|
|
260
|
+
|
|
261
|
+
Params:
|
|
262
|
+
name_or_id: Name or id of the classification model
|
|
263
|
+
|
|
264
|
+
Returns:
|
|
265
|
+
`True` if the classification model exists, `False` otherwise
|
|
266
|
+
"""
|
|
267
|
+
try:
|
|
268
|
+
cls.open(name_or_id)
|
|
269
|
+
return True
|
|
270
|
+
except LookupError:
|
|
271
|
+
return False
|
|
272
|
+
|
|
273
|
+
@classmethod
def all(cls) -> list[ClassificationModel]:
    """
    List every classification model in the OrcaCloud

    Returns:
        List of handles to all classification models in the OrcaCloud
    """
    listing = OrcaClient._resolve_client().GET("/classification_model")
    # Each metadata entry of the listing is wrapped in its own handle.
    return [cls(entry) for entry in listing]
|
|
283
|
+
|
|
284
|
+
@classmethod
def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
    """
    Delete a classification model from the OrcaCloud

    Warning:
        This will delete the model and all associated data, including predictions, evaluations, and feedback.

    Params:
        name_or_id: Name or id of the classification model
        if_not_exists: What to do if the classification model does not exist, defaults to `"error"`.
            Other option is `"ignore"` to do nothing if the classification model does not exist.

    Raises:
        LookupError: If the classification model does not exist and if_not_exists is `"error"`
    """
    try:
        OrcaClient._resolve_client().DELETE(
            "/classification_model/{name_or_id}",
            params={"name_or_id": name_or_id},
        )
    except LookupError:
        # Any mode other than "error" swallows the lookup failure.
        if if_not_exists == "error":
            raise
    else:
        logging.info(f"Deleted model {name_or_id}")
|
|
307
|
+
|
|
308
|
+
def refresh(self):
|
|
309
|
+
"""Refresh the model data from the OrcaCloud"""
|
|
310
|
+
self.__dict__.update(self.open(self.name).__dict__)
|
|
311
|
+
|
|
312
|
+
def set(self, *, description: str | None = UNSET, locked: bool = UNSET) -> None:
    """
    Change the editable attributes of the model.

    Note:
        If a field is not provided, it will default to [UNSET][orca_sdk.UNSET] and not be updated.

    Params:
        description: Value to set for the description
        locked: Value to set for the locked status

    Examples:
        Update the description:
        >>> model.set(description="New description")

        Remove description:
        >>> model.set(description=None)

        Lock the model:
        >>> model.set(locked=True)
    """
    client = OrcaClient._resolve_client()

    # Only fields that were actually passed end up in the PATCH payload;
    # `None` is a real value for the description, hence the UNSET sentinel.
    changes: PredictiveModelUpdate = {}
    if locked is not UNSET:
        changes["locked"] = locked
    if description is not UNSET:
        changes["description"] = description

    client.PATCH(
        "/classification_model/{name_or_id}",
        params={"name_or_id": self.id},
        json=changes,
    )
    # Re-read the model so this handle reflects the stored values.
    self.refresh()
|
|
341
|
+
|
|
342
|
+
def lock(self) -> None:
|
|
343
|
+
"""Lock the model to prevent accidental deletion"""
|
|
344
|
+
self.set(locked=True)
|
|
345
|
+
|
|
346
|
+
def unlock(self) -> None:
|
|
347
|
+
"""Unlock the model to allow deletion"""
|
|
348
|
+
self.set(locked=False)
|
|
349
|
+
|
|
350
|
+
@overload
|
|
351
|
+
def predict(
|
|
352
|
+
self,
|
|
353
|
+
value: list[str],
|
|
354
|
+
expected_labels: list[int] | None = None,
|
|
355
|
+
filters: list[FilterItemTuple] = [],
|
|
356
|
+
tags: set[str] | None = None,
|
|
357
|
+
save_telemetry: TelemetryMode = "on",
|
|
358
|
+
prompt: str | None = None,
|
|
359
|
+
use_lookup_cache: bool = True,
|
|
360
|
+
timeout_seconds: int = 10,
|
|
361
|
+
ignore_unlabeled: bool = False,
|
|
362
|
+
partition_id: str | list[str | None] | None = None,
|
|
363
|
+
partition_filter_mode: Literal[
|
|
364
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
365
|
+
] = "include_global",
|
|
366
|
+
use_gpu: bool = True,
|
|
367
|
+
batch_size: int = 100,
|
|
368
|
+
) -> list[ClassificationPrediction]:
|
|
369
|
+
pass
|
|
370
|
+
|
|
371
|
+
@overload
|
|
372
|
+
def predict(
|
|
373
|
+
self,
|
|
374
|
+
value: str,
|
|
375
|
+
expected_labels: int | None = None,
|
|
376
|
+
filters: list[FilterItemTuple] = [],
|
|
377
|
+
tags: set[str] | None = None,
|
|
378
|
+
save_telemetry: TelemetryMode = "on",
|
|
379
|
+
prompt: str | None = None,
|
|
380
|
+
use_lookup_cache: bool = True,
|
|
381
|
+
timeout_seconds: int = 10,
|
|
382
|
+
ignore_unlabeled: bool = False,
|
|
383
|
+
partition_id: str | None = None,
|
|
384
|
+
partition_filter_mode: Literal[
|
|
385
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
386
|
+
] = "include_global",
|
|
387
|
+
use_gpu: bool = True,
|
|
388
|
+
batch_size: int = 100,
|
|
389
|
+
) -> ClassificationPrediction:
|
|
390
|
+
pass
|
|
391
|
+
|
|
392
|
+
def predict(
|
|
393
|
+
self,
|
|
394
|
+
value: list[str] | str,
|
|
395
|
+
expected_labels: list[int] | list[str] | int | str | None = None,
|
|
396
|
+
filters: list[FilterItemTuple] = [],
|
|
397
|
+
tags: set[str] | None = None,
|
|
398
|
+
save_telemetry: TelemetryMode = "on",
|
|
399
|
+
prompt: str | None = None,
|
|
400
|
+
use_lookup_cache: bool = True,
|
|
401
|
+
timeout_seconds: int = 10,
|
|
402
|
+
ignore_unlabeled: bool = False,
|
|
403
|
+
partition_id: str | None | list[str | None] = None,
|
|
404
|
+
partition_filter_mode: Literal[
|
|
405
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
406
|
+
] = "include_global",
|
|
407
|
+
use_gpu: bool = True,
|
|
408
|
+
batch_size: int = 100,
|
|
409
|
+
) -> list[ClassificationPrediction] | ClassificationPrediction:
|
|
410
|
+
"""
|
|
411
|
+
Predict label(s) for the given input value(s) grounded in similar memories
|
|
412
|
+
|
|
413
|
+
Params:
|
|
414
|
+
value: Value(s) to get predict the labels of
|
|
415
|
+
expected_labels: Expected label(s) for the given input to record for model evaluation
|
|
416
|
+
filters: Optional filters to apply during memory lookup
|
|
417
|
+
tags: Tags to add to the prediction(s)
|
|
418
|
+
save_telemetry: Whether to save telemetry for the prediction(s). One of
|
|
419
|
+
* `"off"`: Do not save telemetry
|
|
420
|
+
* `"on"`: Save telemetry asynchronously unless the `ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY`
|
|
421
|
+
environment variable is set.
|
|
422
|
+
* `"sync"`: Save telemetry synchronously
|
|
423
|
+
* `"async"`: Save telemetry asynchronously
|
|
424
|
+
prompt: Optional prompt to use for instruction-tuned embedding models
|
|
425
|
+
use_lookup_cache: Whether to use cached lookup results for faster predictions
|
|
426
|
+
timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
|
|
427
|
+
ignore_unlabeled: If True, only use labeled memories during lookup.
|
|
428
|
+
If False (default), allow unlabeled memories when necessary.
|
|
429
|
+
partition_id: Optional partition ID(s) to use during memory lookup
|
|
430
|
+
partition_filter_mode: Optional partition filter mode to use for the prediction(s). One of
|
|
431
|
+
* `"ignore_partitions"`: Ignore partitions
|
|
432
|
+
* `"include_global"`: Include global memories
|
|
433
|
+
* `"exclude_global"`: Exclude global memories
|
|
434
|
+
* `"only_global"`: Only include global memories
|
|
435
|
+
use_gpu: Whether to use GPU for the prediction (defaults to True)
|
|
436
|
+
batch_size: Number of values to process in a single API call
|
|
437
|
+
|
|
438
|
+
Returns:
|
|
439
|
+
Label prediction or list of label predictions
|
|
440
|
+
|
|
441
|
+
Raises:
|
|
442
|
+
ValueError: If timeout_seconds is not a positive integer
|
|
443
|
+
TimeoutError: If the request times out after the specified duration
|
|
444
|
+
|
|
445
|
+
Examples:
|
|
446
|
+
Predict the label for a single value:
|
|
447
|
+
>>> prediction = model.predict("I am happy", tags={"test"})
|
|
448
|
+
ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
|
|
449
|
+
|
|
450
|
+
Predict the labels for a list of values:
|
|
451
|
+
>>> predictions = model.predict(["I am happy", "I am sad"], expected_labels=[1, 0])
|
|
452
|
+
[
|
|
453
|
+
ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
|
|
454
|
+
ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
|
|
455
|
+
]
|
|
456
|
+
|
|
457
|
+
Using a prompt with an instruction-tuned embedding model:
|
|
458
|
+
>>> prediction = model.predict("I am happy", prompt="Represent this text for sentiment classification:")
|
|
459
|
+
ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
|
|
460
|
+
"""
|
|
461
|
+
|
|
462
|
+
if timeout_seconds <= 0:
|
|
463
|
+
raise ValueError("timeout_seconds must be a positive integer")
|
|
464
|
+
if batch_size <= 0 or batch_size > 500:
|
|
465
|
+
raise ValueError("batch_size must be between 1 and 500")
|
|
466
|
+
|
|
467
|
+
parsed_filters = [
|
|
468
|
+
_parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
|
|
469
|
+
]
|
|
470
|
+
|
|
471
|
+
if any(_is_metric_column(filter[0]) for filter in filters):
|
|
472
|
+
raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")
|
|
473
|
+
|
|
474
|
+
# Convert to list for batching
|
|
475
|
+
values = value if isinstance(value, list) else [value]
|
|
476
|
+
if isinstance(expected_labels, list) and len(expected_labels) != len(values):
|
|
477
|
+
raise ValueError("Invalid input: \n\texpected_labels must be the same length as values")
|
|
478
|
+
if isinstance(partition_id, list) and len(partition_id) != len(values):
|
|
479
|
+
raise ValueError("Invalid input: \n\tpartition_id must be the same length as values")
|
|
480
|
+
|
|
481
|
+
if isinstance(expected_labels, int):
|
|
482
|
+
expected_labels = [expected_labels] * len(values)
|
|
483
|
+
elif isinstance(expected_labels, str):
|
|
484
|
+
expected_labels = [self.memoryset.label_names.index(expected_labels)] * len(values)
|
|
485
|
+
elif isinstance(expected_labels, list):
|
|
486
|
+
expected_labels = [
|
|
487
|
+
self.memoryset.label_names.index(label) if isinstance(label, str) else label
|
|
488
|
+
for label in expected_labels
|
|
489
|
+
]
|
|
490
|
+
|
|
491
|
+
if use_gpu:
|
|
492
|
+
endpoint = "/gpu/classification_model/{name_or_id}/prediction"
|
|
493
|
+
else:
|
|
494
|
+
endpoint = "/classification_model/{name_or_id}/prediction"
|
|
495
|
+
|
|
496
|
+
telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
|
|
497
|
+
client = OrcaClient._resolve_client()
|
|
498
|
+
|
|
499
|
+
predictions: list[ClassificationPrediction] = []
|
|
500
|
+
for i in range(0, len(values), batch_size):
|
|
501
|
+
batch_values = values[i : i + batch_size]
|
|
502
|
+
batch_expected_labels = expected_labels[i : i + batch_size] if expected_labels else None
|
|
503
|
+
|
|
504
|
+
request_json: ClassificationPredictionRequest = {
|
|
505
|
+
"input_values": batch_values,
|
|
506
|
+
"memoryset_override_name_or_id": self._memoryset_override_id,
|
|
507
|
+
"expected_labels": batch_expected_labels,
|
|
508
|
+
"tags": list(tags or set()),
|
|
509
|
+
"save_telemetry": telemetry_on,
|
|
510
|
+
"save_telemetry_synchronously": telemetry_sync,
|
|
511
|
+
"filters": cast(list[FilterItem], parsed_filters),
|
|
512
|
+
"prompt": prompt,
|
|
513
|
+
"use_lookup_cache": use_lookup_cache,
|
|
514
|
+
"ignore_unlabeled": ignore_unlabeled,
|
|
515
|
+
"partition_filter_mode": partition_filter_mode,
|
|
516
|
+
}
|
|
517
|
+
if partition_filter_mode != "ignore_partitions":
|
|
518
|
+
request_json["partition_ids"] = (
|
|
519
|
+
partition_id[i : i + batch_size] if isinstance(partition_id, list) else partition_id
|
|
520
|
+
)
|
|
521
|
+
|
|
522
|
+
response = client.POST(
|
|
523
|
+
endpoint,
|
|
524
|
+
params={"name_or_id": self.id},
|
|
525
|
+
json=request_json,
|
|
526
|
+
timeout=timeout_seconds,
|
|
527
|
+
)
|
|
528
|
+
|
|
529
|
+
if telemetry_on and any(p["prediction_id"] is None for p in response):
|
|
530
|
+
raise RuntimeError("Failed to save some prediction to database.")
|
|
531
|
+
|
|
532
|
+
predictions.extend(
|
|
533
|
+
ClassificationPrediction(
|
|
534
|
+
prediction_id=prediction["prediction_id"],
|
|
535
|
+
label=prediction["label"],
|
|
536
|
+
label_name=prediction["label_name"],
|
|
537
|
+
score=None,
|
|
538
|
+
confidence=prediction["confidence"],
|
|
539
|
+
anomaly_score=prediction["anomaly_score"],
|
|
540
|
+
memoryset=self.memoryset,
|
|
541
|
+
model=self,
|
|
542
|
+
logits=prediction["logits"],
|
|
543
|
+
input_value=input_value,
|
|
544
|
+
)
|
|
545
|
+
for prediction, input_value in zip(response, batch_values)
|
|
546
|
+
)
|
|
547
|
+
|
|
548
|
+
self._last_prediction_was_batch = isinstance(value, list)
|
|
549
|
+
self._last_prediction = predictions[-1]
|
|
550
|
+
return predictions if isinstance(value, list) else predictions[0]
|
|
551
|
+
|
|
552
|
+
@overload
async def apredict(
    self,
    value: list[str],
    expected_labels: list[int] | None = None,
    filters: list[FilterItemTuple] = [],
    tags: set[str] | None = None,
    save_telemetry: TelemetryMode = "on",
    prompt: str | None = None,
    use_lookup_cache: bool = True,
    timeout_seconds: int = 10,
    ignore_unlabeled: bool = False,
    partition_id: str | list[str | None] | None = None,
    partition_filter_mode: Literal[
        "ignore_partitions", "include_global", "exclude_global", "only_global"
    ] = "include_global",
    batch_size: int = 100,
) -> list[ClassificationPrediction]:
    pass

@overload
async def apredict(
    self,
    value: str,
    expected_labels: int | None = None,
    filters: list[FilterItemTuple] = [],
    tags: set[str] | None = None,
    save_telemetry: TelemetryMode = "on",
    prompt: str | None = None,
    use_lookup_cache: bool = True,
    timeout_seconds: int = 10,
    ignore_unlabeled: bool = False,
    partition_id: str | None = None,
    partition_filter_mode: Literal[
        "ignore_partitions", "include_global", "exclude_global", "only_global"
    ] = "include_global",
    batch_size: int = 100,
) -> ClassificationPrediction:
    pass

async def apredict(
    self,
    value: list[str] | str,
    expected_labels: list[int] | list[str] | int | str | None = None,
    filters: list[FilterItemTuple] = [],
    tags: set[str] | None = None,
    save_telemetry: TelemetryMode = "on",
    prompt: str | None = None,
    use_lookup_cache: bool = True,
    timeout_seconds: int = 10,
    ignore_unlabeled: bool = False,
    partition_id: str | None | list[str | None] = None,
    partition_filter_mode: Literal[
        "ignore_partitions", "include_global", "exclude_global", "only_global"
    ] = "include_global",
    batch_size: int = 100,
) -> list[ClassificationPrediction] | ClassificationPrediction:
    """
    Asynchronously predict label(s) for the given input value(s) grounded in similar memories

    Params:
        value: Value(s) to predict the labels of. An empty list yields an empty list of
            predictions without contacting the API.
        expected_labels: Expected label(s) for the given input to record for model evaluation.
            Label names are translated to their index in the memoryset's label names.
        filters: Optional filters to apply during memory lookup
        tags: Tags to add to the prediction(s)
        save_telemetry: Whether to save telemetry for the prediction(s). One of
            * `"off"`: Do not save telemetry
            * `"on"`: Save telemetry asynchronously unless the `ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY`
                environment variable is set.
            * `"sync"`: Save telemetry synchronously
            * `"async"`: Save telemetry asynchronously
        prompt: Optional prompt to use for instruction-tuned embedding models
        use_lookup_cache: Whether to use cached lookup results for faster predictions
        timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
        ignore_unlabeled: If True, only use labeled memories during lookup.
            If False (default), allow unlabeled memories when necessary.
        partition_id: Optional partition ID(s) to use during memory lookup
        partition_filter_mode: Optional partition filter mode to use for the prediction(s). One of
            * `"ignore_partitions"`: Ignore partitions
            * `"include_global"`: Include global memories
            * `"exclude_global"`: Exclude global memories
            * `"only_global"`: Only include global memories
        batch_size: Number of values to process in a single API call

    Returns:
        Label prediction or list of label predictions.

    Raises:
        ValueError: If timeout_seconds is not a positive integer, batch_size is not between
            1 and 500, a telemetry column is used as a filter, or expected_labels / partition_id
            are lists whose length differs from the number of values
        RuntimeError: If telemetry is enabled and a returned prediction has no prediction id
        TimeoutError: If the request times out after the specified duration

    Examples:
        Predict the label for a single value:
        >>> prediction = await model.apredict("I am happy", tags={"test"})
        ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })

        Predict the labels for a list of values:
        >>> predictions = await model.apredict(["I am happy", "I am sad"], expected_labels=[1, 0])
        [
            ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
            ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
        ]

        Using a prompt with an instruction-tuned embedding model:
        >>> prediction = await model.apredict("I am happy", prompt="Represent this text for sentiment classification:")
        ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
    """

    if timeout_seconds <= 0:
        raise ValueError("timeout_seconds must be a positive integer")
    if batch_size <= 0 or batch_size > 500:
        raise ValueError("batch_size must be between 1 and 500")

    parsed_filters = [
        _parse_filter_item_from_tuple(item) if isinstance(item, tuple) else item for item in filters
    ]

    if any(_is_metric_column(item[0]) for item in filters):
        raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")

    # Convert to list for batching
    values = value if isinstance(value, list) else [value]
    if isinstance(expected_labels, list) and len(expected_labels) != len(values):
        raise ValueError("Invalid input: \n\texpected_labels must be the same length as values")
    if isinstance(partition_id, list) and len(partition_id) != len(values):
        raise ValueError("Invalid input: \n\tpartition_id must be the same length as values")

    if not values:
        # Only reachable for an empty input list: there is nothing to send, and indexing
        # the (empty) result list below would raise IndexError. The last-prediction
        # bookkeeping is deliberately left untouched because no prediction was made.
        return []

    # Normalize expected labels to one integer label per value
    if isinstance(expected_labels, int):
        expected_labels = [expected_labels] * len(values)
    elif isinstance(expected_labels, str):
        expected_labels = [self.memoryset.label_names.index(expected_labels)] * len(values)
    elif isinstance(expected_labels, list):
        expected_labels = [
            self.memoryset.label_names.index(label) if isinstance(label, str) else label
            for label in expected_labels
        ]

    telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
    client = OrcaAsyncClient._resolve_client()

    predictions: list[ClassificationPrediction] = []
    for i in range(0, len(values), batch_size):
        batch_values = values[i : i + batch_size]
        batch_expected_labels = expected_labels[i : i + batch_size] if expected_labels else None

        request_json: ClassificationPredictionRequest = {
            "input_values": batch_values,
            "memoryset_override_name_or_id": self._memoryset_override_id,
            "expected_labels": batch_expected_labels,
            "tags": list(tags or set()),
            "save_telemetry": telemetry_on,
            "save_telemetry_synchronously": telemetry_sync,
            "filters": cast(list[FilterItem], parsed_filters),
            "prompt": prompt,
            "use_lookup_cache": use_lookup_cache,
            "ignore_unlabeled": ignore_unlabeled,
            "partition_filter_mode": partition_filter_mode,
        }
        # Partition ids are only sent when partitions are taken into account; a list of
        # ids is sliced in step with the values of this batch.
        if partition_filter_mode != "ignore_partitions":
            request_json["partition_ids"] = (
                partition_id[i : i + batch_size] if isinstance(partition_id, list) else partition_id
            )
        response = await client.POST(
            "/gpu/classification_model/{name_or_id}/prediction",
            params={"name_or_id": self.id},
            json=request_json,
            timeout=timeout_seconds,
        )

        if telemetry_on and any(p["prediction_id"] is None for p in response):
            raise RuntimeError("Failed to save some prediction to database.")

        predictions.extend(
            ClassificationPrediction(
                prediction_id=prediction["prediction_id"],
                label=prediction["label"],
                label_name=prediction["label_name"],
                score=None,
                confidence=prediction["confidence"],
                anomaly_score=prediction["anomaly_score"],
                memoryset=self.memoryset,
                model=self,
                logits=prediction["logits"],
                input_value=input_value,
            )
            for prediction, input_value in zip(response, batch_values)
        )

    self._last_prediction_was_batch = isinstance(value, list)
    self._last_prediction = predictions[-1]
    return predictions if isinstance(value, list) else predictions[0]
|
|
743
|
+
|
|
744
|
+
def predictions(
    self,
    limit: int | None = None,
    offset: int = 0,
    tag: str | None = None,
    sort: list[tuple[Literal["anomaly_score", "confidence", "timestamp"], Literal["asc", "desc"]]] = [],
    expected_label_match: bool | None = None,
    batch_size: int = 100,
) -> list[ClassificationPrediction]:
    """
    Fetch predictions that were recorded for this model

    When `limit` is `None` or at least `batch_size`, the number of matching predictions is
    requested first and the results are then fetched page by page.

    Params:
        limit: Maximum number of predictions to return. If `None`, all predictions are
            returned by automatically paginating through the results.
        offset: Optional offset of the first prediction to return
        tag: Optional tag to filter predictions by
        sort: Optional list of `(column, direction)` pairs to sort the predictions by.
            The accepted columns are `anomaly_score`, `confidence` and `timestamp`.
        expected_label_match: Optional filter to only include predictions where the expected
            label does (`True`) or doesn't (`False`) match the predicted label
        batch_size: Number of predictions to fetch in a single API call

    Returns:
        List of label predictions. Returned entries without a `label` key are skipped.

    Raises:
        ValueError: If batch_size is not between 1 and 500

    Examples:
        Get all predictions with a specific tag:
        >>> predictions = model.predictions(tag="evaluation")

        Get the last 3 predictions:
        >>> predictions = model.predictions(limit=3, sort=[("timestamp", "desc")])
        [
            ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
            ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
            ClassificationPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am ecstatic'}),
        ]

        Get second most confident prediction:
        >>> predictions = model.predictions(sort=[("confidence", "desc")], offset=1, limit=1)
        [ClassificationPrediction({label: <positive: 1>, confidence: 0.90, anomaly_score: 0.1, input_value: 'I am having a good day'})]

        Get predictions where the expected label doesn't match the predicted label:
        >>> predictions = model.predictions(expected_label_match=False)
        [ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy', expected_label: 0})]
    """
    if batch_size > 500 or batch_size <= 0:
        raise ValueError("batch_size must be between 1 and 500")
    if limit == 0:
        return []

    client = OrcaClient._resolve_client()

    # Each window is an (offset, limit) pair for one list request
    windows: list[tuple[int, int]] = []
    if limit is None or limit >= batch_size:
        # More than one page may be needed: clamp the limit to what actually exists
        matching = client.POST(
            "/telemetry/prediction/count",
            json={
                "model_id": self.id,
                "tag": tag,
                "expected_label_match": expected_label_match,
            },
        )
        available = max(matching - offset, 0)
        limit = available if limit is None else min(limit, available)
        cursor, remaining = offset, limit
        while remaining > 0:
            size = min(batch_size, remaining)
            windows.append((cursor, size))
            cursor += batch_size
            remaining -= size
    else:
        # The whole request fits in a single page
        windows.append((offset, limit))

    collected: list[ClassificationPrediction] = []
    for page_offset, page_limit in windows:
        query: ListPredictionsRequest = {
            "model_id": self.id,
            "limit": page_limit,
            "offset": page_offset,
            "tag": tag,
            "expected_label_match": expected_label_match,
        }
        if sort:
            query["sort"] = sort
        page = client.POST(
            "/telemetry/prediction",
            json=query,
        )
        for row in page:
            if "label" not in row:
                continue
            collected.append(
                ClassificationPrediction(
                    prediction_id=row["prediction_id"],
                    label=row["label"],
                    label_name=row["label_name"],
                    score=None,
                    confidence=row["confidence"],
                    anomaly_score=row["anomaly_score"],
                    memoryset=self.memoryset,
                    model=self,
                    telemetry=row,
                )
            )

    return collected
|
|
846
|
+
|
|
847
|
+
def _evaluate_datasource(
    self,
    datasource: Datasource,
    value_column: str,
    label_column: str,
    record_predictions: bool,
    tags: set[str] | None,
    subsample: int | float | None,
    background: bool = False,
    ignore_unlabeled: bool = False,
    partition_column: str | None = None,
    partition_filter_mode: Literal[
        "ignore_partitions", "include_global", "exclude_global", "only_global"
    ] = "include_global",
) -> ClassificationMetrics | Job[ClassificationMetrics]:
    """
    Submit an evaluation of this model on a datasource and expose it as a job

    Params:
        datasource: Datasource whose id is sent as the evaluation data
        value_column: Name of the datasource column holding the input values
        label_column: Name of the datasource column holding the expected labels
        record_predictions: Sent as the `record_telemetry` flag of the request
        tags: Tags sent as `telemetry_tags`; `None` is sent when empty or not given
        subsample: Optional subsample setting forwarded unchanged in the request
        background: If True the job handle is returned, otherwise its result
        ignore_unlabeled: Forwarded unchanged in the request
        partition_column: Optional name of the datasource column with partition ids
        partition_filter_mode: Partition filter mode forwarded unchanged in the request

    Returns:
        The job handle when `background` is True, otherwise the evaluation metrics
    """
    submission = OrcaClient._resolve_client().POST(
        "/classification_model/{model_name_or_id}/evaluation",
        params={"model_name_or_id": self.id},
        json={
            "datasource_name_or_id": datasource.id,
            "datasource_label_column": label_column,
            "datasource_value_column": value_column,
            "memoryset_override_name_or_id": self._memoryset_override_id,
            "record_telemetry": record_predictions,
            "telemetry_tags": list(tags) if tags else None,
            "subsample": subsample,
            "ignore_unlabeled": ignore_unlabeled,
            "datasource_partition_column": partition_column,
            "partition_filter_mode": partition_filter_mode,
        },
    )
    job_id = submission["job_id"]

    # Fields copied from the evaluation result; missing ones become None
    metric_names = (
        "coverage",
        "f1_score",
        "accuracy",
        "loss",
        "anomaly_score_mean",
        "anomaly_score_median",
        "anomaly_score_variance",
        "roc_auc",
        "pr_auc",
        "pr_curve",
        "roc_curve",
    )

    def load_metrics():
        # The client is resolved again here because this callback is invoked by the job
        finished = OrcaClient._resolve_client().GET(
            "/classification_model/{model_name_or_id}/evaluation/{job_id}",
            params={"model_name_or_id": self.id, "job_id": job_id},
        )
        result = finished["result"]
        assert result is not None
        return ClassificationMetrics(**{name: result.get(name) for name in metric_names})

    job = Job(job_id, load_metrics)
    if background:
        return job
    return job.result()
|
|
903
|
+
|
|
904
|
+
def _evaluate_dataset(
|
|
905
|
+
self,
|
|
906
|
+
dataset: Dataset,
|
|
907
|
+
value_column: str,
|
|
908
|
+
label_column: str,
|
|
909
|
+
record_predictions: bool,
|
|
910
|
+
tags: set[str],
|
|
911
|
+
batch_size: int,
|
|
912
|
+
ignore_unlabeled: bool,
|
|
913
|
+
partition_column: str | None = None,
|
|
914
|
+
partition_filter_mode: Literal[
|
|
915
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
916
|
+
] = "include_global",
|
|
917
|
+
) -> ClassificationMetrics:
|
|
918
|
+
if len(dataset) == 0:
|
|
919
|
+
raise ValueError("Evaluation dataset cannot be empty")
|
|
920
|
+
|
|
921
|
+
if any(x is None for x in dataset[label_column]):
|
|
922
|
+
raise ValueError("Evaluation dataset cannot contain None values in the label column")
|
|
923
|
+
|
|
924
|
+
predictions = [
|
|
925
|
+
prediction
|
|
926
|
+
for i in range(0, len(dataset), batch_size)
|
|
927
|
+
for prediction in self.predict(
|
|
928
|
+
dataset[i : i + batch_size][value_column],
|
|
929
|
+
expected_labels=dataset[i : i + batch_size][label_column],
|
|
930
|
+
tags=tags,
|
|
931
|
+
save_telemetry="sync" if record_predictions else "off",
|
|
932
|
+
ignore_unlabeled=ignore_unlabeled,
|
|
933
|
+
partition_id=dataset[i : i + batch_size][partition_column] if partition_column else None,
|
|
934
|
+
partition_filter_mode=partition_filter_mode,
|
|
935
|
+
)
|
|
936
|
+
]
|
|
937
|
+
|
|
938
|
+
return calculate_classification_metrics(
|
|
939
|
+
expected_labels=dataset[label_column],
|
|
940
|
+
logits=[p.logits for p in predictions],
|
|
941
|
+
anomaly_scores=[p.anomaly_score for p in predictions],
|
|
942
|
+
include_curves=True,
|
|
943
|
+
include_confusion_matrix=True,
|
|
944
|
+
)
|
|
945
|
+
|
|
946
|
+
@overload
def evaluate(
    self,
    data: Datasource | Dataset,
    *,
    value_column: str = "value",
    label_column: str = "label",
    partition_column: str | None = None,
    record_predictions: bool = False,
    tags: set[str] = {"evaluation"},
    batch_size: int = 100,
    subsample: int | float | None = None,
    background: Literal[True],
    ignore_unlabeled: bool = False,
    partition_filter_mode: Literal[
        "ignore_partitions", "include_global", "exclude_global", "only_global"
    ] = "include_global",
) -> Job[ClassificationMetrics]:
    pass

@overload
def evaluate(
    self,
    data: Datasource | Dataset,
    *,
    value_column: str = "value",
    label_column: str = "label",
    partition_column: str | None = None,
    record_predictions: bool = False,
    tags: set[str] = {"evaluation"},
    batch_size: int = 100,
    subsample: int | float | None = None,
    background: Literal[False] = False,
    ignore_unlabeled: bool = False,
    partition_filter_mode: Literal[
        "ignore_partitions", "include_global", "exclude_global", "only_global"
    ] = "include_global",
) -> ClassificationMetrics:
    pass

def evaluate(
    self,
    data: Datasource | Dataset,
    *,
    value_column: str = "value",
    label_column: str = "label",
    partition_column: str | None = None,
    record_predictions: bool = False,
    tags: set[str] = {"evaluation"},
    batch_size: int = 100,
    subsample: int | float | None = None,
    background: bool = False,
    ignore_unlabeled: bool = False,
    partition_filter_mode: Literal[
        "ignore_partitions", "include_global", "exclude_global", "only_global"
    ] = "include_global",
) -> ClassificationMetrics | Job[ClassificationMetrics]:
    """
    Evaluate the classification model on a given dataset or datasource

    A Datasource is handed to `_evaluate_datasource` and a Dataset to `_evaluate_dataset`.

    Params:
        data: Dataset or Datasource to evaluate the model on
        value_column: Name of the column that contains the input values to the model
        label_column: Name of the column containing the expected labels
        partition_column: Optional name of the column that contains the partition IDs
        record_predictions: Whether to record [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s for analysis
        tags: Optional tags to add to the recorded [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s
        batch_size: Batch size for processing Dataset inputs (only forwarded when input is a Dataset)
        subsample: Optional number (int) of rows to sample or fraction (float in (0, 1]) of data to sample for evaluation.
            Only forwarded when input is a Datasource.
        background: Whether to run the operation in the background and return a job handle.
            Only forwarded when input is a Datasource.
        ignore_unlabeled: If True, only use labeled memories during lookup. If False (default), allow unlabeled memories
        partition_filter_mode: Optional partition filter mode to use for the evaluation. One of
            * `"ignore_partitions"`: Ignore partitions
            * `"include_global"`: Include global memories
            * `"exclude_global"`: Exclude global memories
            * `"only_global"`: Only include global memories

    Returns:
        Whatever the selected helper returns: the evaluation metrics, or for a Datasource
        evaluated with `background=True` a job handle for them

    Raises:
        ValueError: If data is neither a Datasource nor a Dataset

    Examples:
        >>> model.evaluate(datasource, value_column="text", label_column="airline_sentiment")
        ClassificationMetrics({
            accuracy: 0.8500,
            f1_score: 0.8500,
            roc_auc: 0.8500,
            pr_auc: 0.8500,
            anomaly_score: 0.3500 ± 0.0500,
        })
    """
    # Datasource inputs go through the datasource helper
    if isinstance(data, Datasource):
        return self._evaluate_datasource(
            datasource=data,
            value_column=value_column,
            label_column=label_column,
            partition_column=partition_column,
            partition_filter_mode=partition_filter_mode,
            record_predictions=record_predictions,
            tags=tags,
            ignore_unlabeled=ignore_unlabeled,
            subsample=subsample,
            background=background,
        )

    # Dataset inputs go through the dataset helper
    if isinstance(data, Dataset):
        return self._evaluate_dataset(
            dataset=data,
            value_column=value_column,
            label_column=label_column,
            partition_column=partition_column,
            partition_filter_mode=partition_filter_mode,
            record_predictions=record_predictions,
            tags=tags,
            ignore_unlabeled=ignore_unlabeled,
            batch_size=batch_size,
        )

    raise ValueError(f"Invalid data type: {type(data)}")
|
|
1062
|
+
|
|
1063
|
+
def finetune(self, datasource: Datasource):
    # Placeholder: intentionally left without a docstring until finetuning is implemented
    message = "Finetuning is not supported yet"
    raise NotImplementedError(message)
|
|
1066
|
+
|
|
1067
|
+
@contextmanager
def use_memoryset(self, memoryset_override: LabeledMemoryset) -> Generator[None, None, None]:
    """
    Temporarily override the memoryset used by the model for predictions

    The override that was active before entering the block is restored when the block
    exits, including when it exits with an exception, so nested overrides unwind correctly.

    Params:
        memoryset_override: Memoryset to override the default memoryset with

    Examples:
        >>> with model.use_memoryset(LabeledMemoryset.open("my_other_memoryset")):
        ...     predictions = model.predict("I am happy")
    """
    previous_override_id = self._memoryset_override_id
    self._memoryset_override_id = memoryset_override.id
    try:
        yield
    finally:
        # Runs on normal exit and on exceptions raised inside the with-block
        self._memoryset_override_id = previous_override_id
|
|
1082
|
+
|
|
1083
|
+
@overload
def record_feedback(self, feedback: dict[str, Any]) -> None:
    pass

@overload
def record_feedback(self, feedback: Iterable[dict[str, Any]]) -> None:
    pass

def record_feedback(self, feedback: Iterable[dict[str, Any]] | dict[str, Any]):
    """
    Record feedback for one prediction or for several predictions at once.

    Feedback can be recorded in several categories for each prediction. A
    [`FeedbackCategory`][orca_sdk.telemetry.FeedbackCategory] is created automatically
    the first time feedback with a new name is recorded. Categories are global across models.
    The value type of the category is inferred from the first recorded value. Subsequent
    feedback for the same category must be of the same type.

    Params:
        feedback: A single feedback dictionary or an iterable of them. Each dictionary
            should have the following keys:

            - `category`: Name of the category under which to record the feedback.
            - `value`: Feedback value to record, should be `True` for positive feedback and
                `False` for negative feedback or a [`float`][float] between `-1.0` and `+1.0`
                where negative values indicate negative feedback and positive values indicate
                positive feedback.
            - `comment`: Optional comment to record with the feedback.

    Examples:
        Record whether predictions were correct or incorrect:
        >>> model.record_feedback({
        ...     "prediction": p.prediction_id,
        ...     "category": "correct",
        ...     "value": p.label == p.expected_label,
        ... } for p in predictions)

        Record star rating as normalized continuous score between `-1.0` and `+1.0`:
        >>> model.record_feedback({
        ...     "prediction": "123e4567-e89b-12d3-a456-426614174000",
        ...     "category": "rating",
        ...     "value": -0.5,
        ...     "comment": "2 stars"
        ... })

    Raises:
        ValueError: If the value does not match previous value types for the category, or is a
            [`float`][float] that is not between `-1.0` and `+1.0`.
    """
    client = OrcaClient._resolve_client()
    # A single dictionary is wrapped so both input shapes are sent as one list
    if isinstance(feedback, dict):
        entries = cast(list[dict], [feedback])
    else:
        entries = feedback
    payload = [_parse_feedback(entry) for entry in entries]
    client.PUT("/telemetry/prediction/feedback", json=payload)
|
|
1138
|
+
|
|
1139
|
+
@staticmethod
def bootstrap_model(
    model_description: str,
    label_names: list[str],
    initial_examples: list[tuple[str, str]],
    num_examples_per_label: int,
    background: bool = False,
) -> Job[BootstrappedClassificationModel] | BootstrappedClassificationModel:
    """
    Submit a request to bootstrap a classification model and expose it as a job

    Params:
        model_description: Description sent as `model_description`
        label_names: Label names sent as `label_names`
        initial_examples: `(text, label_name)` pairs sent as the initial examples
        num_examples_per_label: Sent as `num_examples_per_label`; presumably the number of
            examples to produce for each label (the server behavior is not visible here)
        background: If True the job handle is returned, otherwise its result

    Returns:
        The job handle when `background` is True, otherwise the bootstrapped model
    """
    seed_examples = []
    for text, label_name in initial_examples:
        seed_examples.append({"text": text, "label_name": label_name})

    submission = OrcaClient._resolve_client().POST(
        "/agents/bootstrap_classification_model",
        json={
            "model_description": model_description,
            "label_names": label_names,
            "initial_examples": seed_examples,
            "num_examples_per_label": num_examples_per_label,
        },
    )
    job_id = submission["job_id"]

    def load_result() -> BootstrappedClassificationModel:
        # The client is resolved again here because this callback is invoked by the job
        finished = OrcaClient._resolve_client().GET(
            "/agents/bootstrap_classification_model/{job_id}", params={"job_id": job_id}
        )
        result = finished["result"]
        assert result is not None
        return BootstrappedClassificationModel(result)

    job = Job(job_id, load_result)
    if background:
        return job
    return job.result()
|