orca-sdk 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
- orca_sdk/__init__.py +1 -1
- orca_sdk/_utils/auth.py +12 -8
- orca_sdk/async_client.py +3795 -0
- orca_sdk/classification_model.py +176 -14
- orca_sdk/classification_model_test.py +96 -28
- orca_sdk/client.py +515 -475
- orca_sdk/conftest.py +37 -36
- orca_sdk/credentials.py +54 -14
- orca_sdk/credentials_test.py +92 -28
- orca_sdk/datasource.py +19 -10
- orca_sdk/datasource_test.py +24 -18
- orca_sdk/embedding_model.py +22 -13
- orca_sdk/embedding_model_test.py +27 -20
- orca_sdk/job.py +14 -8
- orca_sdk/memoryset.py +513 -183
- orca_sdk/memoryset_test.py +130 -32
- orca_sdk/regression_model.py +21 -11
- orca_sdk/regression_model_test.py +35 -26
- orca_sdk/telemetry.py +24 -13
- {orca_sdk-0.1.2.dist-info → orca_sdk-0.1.3.dist-info}/METADATA +1 -1
- orca_sdk-0.1.3.dist-info/RECORD +41 -0
- orca_sdk-0.1.2.dist-info/RECORD +0 -40
- {orca_sdk-0.1.2.dist-info → orca_sdk-0.1.3.dist-info}/WHEEL +0 -0
orca_sdk/classification_model.py
CHANGED
```diff
@@ -3,19 +3,27 @@ from __future__ import annotations
 import logging
 from contextlib import contextmanager
 from datetime import datetime
-from typing import
+from typing import (
+    Any,
+    Generator,
+    Iterable,
+    Literal,
+    cast,
+    overload,
+)
 
 from datasets import Dataset
 
 from ._shared.metrics import ClassificationMetrics, calculate_classification_metrics
 from ._utils.common import UNSET, CreateMode, DropMode
+from .async_client import OrcaAsyncClient
 from .client import (
     BootstrapClassificationModelMeta,
     BootstrapClassificationModelResult,
     ClassificationModelMetadata,
+    OrcaClient,
     PredictiveModelUpdate,
     RACHeadType,
-    orca_api,
 )
 from .datasource import Datasource
 from .job import Job
```
```diff
@@ -211,7 +219,8 @@ class ClassificationModel:
 
         return existing
 
-        metadata = orca_api.POST(
+        client = OrcaClient._resolve_client()
+        metadata = client.POST(
             "/classification_model",
             json={
                 "name": name,
```
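Every call site in this release replaces the removed module-level `orca_api` with a client resolved at call time via `OrcaClient._resolve_client()`; this is what lets the tests further down swap in a different client with `unauthenticated_client.use()`. Below is a minimal sketch of one way such a resolver can work, using a context variable for a scoped override; the class name, attributes, and error message are illustrative assumptions, not the actual internals of `orca_sdk/client.py`.

```python
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Generator, Optional


class Client:
    """Stand-in for OrcaClient; everything here is an assumption for illustration."""

    # Context-local override set by use(); falls back to a process-wide default.
    _active: ContextVar[Optional["Client"]] = ContextVar("active_client", default=None)
    _default: Optional["Client"] = None

    @classmethod
    def _resolve_client(cls) -> "Client":
        # Prefer a client activated via use(); otherwise use the default.
        client = cls._active.get() or cls._default
        if client is None:
            raise ValueError("No client configured")
        return client

    @contextmanager
    def use(self) -> Generator["Client", None, None]:
        # Make this client the active one for the duration of the with-block.
        token = Client._active.set(self)
        try:
            yield self
        finally:
            Client._active.reset(token)
```

Resolving per call rather than importing a singleton means each request picks up whatever client is active in the current context; the `Generator` and `contextmanager` imports in the hunk above would fit this kind of pattern, though their actual use is internal to the SDK.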
```diff
@@ -240,7 +249,8 @@ class ClassificationModel:
         Raises:
             LookupError: If the classification model does not exist
         """
-        return cls(orca_api.GET("/classification_model/{name_or_id}", params={"name_or_id": name}))
+        client = OrcaClient._resolve_client()
+        return cls(client.GET("/classification_model/{name_or_id}", params={"name_or_id": name}))
 
     @classmethod
     def exists(cls, name_or_id: str) -> bool:
```
```diff
@@ -267,7 +277,8 @@ class ClassificationModel:
         Returns:
             List of handles to all classification models in the OrcaCloud
         """
-        return [cls(metadata) for metadata in orca_api.GET("/classification_model")]
+        client = OrcaClient._resolve_client()
+        return [cls(metadata) for metadata in client.GET("/classification_model")]
 
     @classmethod
     def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
```
```diff
@@ -286,7 +297,8 @@ class ClassificationModel:
             LookupError: If the classification model does not exist and if_not_exists is `"error"`
         """
         try:
-            orca_api.DELETE("/classification_model/{name_or_id}", params={"name_or_id": name_or_id})
+            client = OrcaClient._resolve_client()
+            client.DELETE("/classification_model/{name_or_id}", params={"name_or_id": name_or_id})
             logging.info(f"Deleted model {name_or_id}")
         except LookupError:
             if if_not_exists == "error":
```
```diff
@@ -322,7 +334,8 @@ class ClassificationModel:
             update["description"] = description
         if locked is not UNSET:
             update["locked"] = locked
-        orca_api.PATCH("/classification_model/{name_or_id}", params={"name_or_id": self.id}, json=update)
+        client = OrcaClient._resolve_client()
+        client.PATCH("/classification_model/{name_or_id}", params={"name_or_id": self.id}, json=update)
         self.refresh()
 
     def lock(self) -> None:
```
```diff
@@ -435,7 +448,150 @@ class ClassificationModel:
         ]
 
         telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
-        response = orca_api.POST(
+        client = OrcaClient._resolve_client()
+        response = client.POST(
+            "/gpu/classification_model/{name_or_id}/prediction",
+            params={"name_or_id": self.id},
+            json={
+                "input_values": value if isinstance(value, list) else [value],
+                "memoryset_override_name_or_id": self._memoryset_override_id,
+                "expected_labels": expected_labels,
+                "tags": list(tags or set()),
+                "save_telemetry": telemetry_on,
+                "save_telemetry_synchronously": telemetry_sync,
+                "filters": cast(list[FilterItem], parsed_filters),
+                "prompt": prompt,
+                "use_lookup_cache": use_lookup_cache,
+            },
+            timeout=timeout_seconds,
+        )
+
+        if telemetry_on and any(p["prediction_id"] is None for p in response):
+            raise RuntimeError("Failed to save prediction to database.")
+
+        predictions = [
+            ClassificationPrediction(
+                prediction_id=prediction["prediction_id"],
+                label=prediction["label"],
+                label_name=prediction["label_name"],
+                score=None,
+                confidence=prediction["confidence"],
+                anomaly_score=prediction["anomaly_score"],
+                memoryset=self.memoryset,
+                model=self,
+                logits=prediction["logits"],
+                input_value=input_value,
+            )
+            for prediction, input_value in zip(response, value if isinstance(value, list) else [value])
+        ]
+        self._last_prediction_was_batch = isinstance(value, list)
+        self._last_prediction = predictions[-1]
+        return predictions if isinstance(value, list) else predictions[0]
+
+    @overload
+    async def apredict(
+        self,
+        value: list[str],
+        expected_labels: list[int] | None = None,
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: TelemetryMode = "on",
+        prompt: str | None = None,
+        use_lookup_cache: bool = True,
+        timeout_seconds: int = 10,
+    ) -> list[ClassificationPrediction]:
+        pass
+
+    @overload
+    async def apredict(
+        self,
+        value: str,
+        expected_labels: int | None = None,
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: TelemetryMode = "on",
+        prompt: str | None = None,
+        use_lookup_cache: bool = True,
+        timeout_seconds: int = 10,
+    ) -> ClassificationPrediction:
+        pass
+
+    async def apredict(
+        self,
+        value: list[str] | str,
+        expected_labels: list[int] | list[str] | int | str | None = None,
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: TelemetryMode = "on",
+        prompt: str | None = None,
+        use_lookup_cache: bool = True,
+        timeout_seconds: int = 10,
+    ) -> list[ClassificationPrediction] | ClassificationPrediction:
+        """
+        Asynchronously predict label(s) for the given input value(s) grounded in similar memories
+
+        Params:
+            value: Value(s) to predict the labels of
+            expected_labels: Expected label(s) for the given input to record for model evaluation
+            filters: Optional filters to apply during memory lookup
+            tags: Tags to add to the prediction(s)
+            save_telemetry: Whether to save telemetry for the prediction(s). One of
+                * `"off"`: Do not save telemetry
+                * `"on"`: Save telemetry asynchronously unless the `ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY`
+                    environment variable is set.
+                * `"sync"`: Save telemetry synchronously
+                * `"async"`: Save telemetry asynchronously
+            prompt: Optional prompt to use for instruction-tuned embedding models
+            use_lookup_cache: Whether to use cached lookup results for faster predictions
+            timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
+
+        Returns:
+            Label prediction or list of label predictions.
+
+        Raises:
+            ValueError: If timeout_seconds is not a positive integer
+            TimeoutError: If the request times out after the specified duration
+
+        Examples:
+            Predict the label for a single value:
+            >>> prediction = await model.apredict("I am happy", tags={"test"})
+            ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
+
+            Predict the labels for a list of values:
+            >>> predictions = await model.apredict(["I am happy", "I am sad"], expected_labels=[1, 0])
+            [
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
+                ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
+            ]
+
+            Using a prompt with an instruction-tuned embedding model:
+            >>> prediction = await model.apredict("I am happy", prompt="Represent this text for sentiment classification:")
+            ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
+        """
+
+        if timeout_seconds <= 0:
+            raise ValueError("timeout_seconds must be a positive integer")
+
+        parsed_filters = [
+            _parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
+        ]
+
+        if any(_is_metric_column(filter[0]) for filter in filters):
+            raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")
+
+        if isinstance(expected_labels, int):
+            expected_labels = [expected_labels]
+        elif isinstance(expected_labels, str):
+            expected_labels = [self.memoryset.label_names.index(expected_labels)]
+        elif isinstance(expected_labels, list):
+            expected_labels = [
+                self.memoryset.label_names.index(label) if isinstance(label, str) else label
+                for label in expected_labels
+            ]
+
+        telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
+        client = OrcaAsyncClient._resolve_client()
+        response = await client.POST(
             "/gpu/classification_model/{name_or_id}/prediction",
             params={"name_or_id": self.id},
             json={
```
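The headline addition here is `apredict`, an awaitable counterpart to `predict` that goes through the new `OrcaAsyncClient`. Adapting the docstring examples into a runnable script might look like the following; the package-root import and the model name are assumptions for illustration.

```python
import asyncio

from orca_sdk import ClassificationModel  # assumed re-export from the package root


async def main() -> None:
    model = ClassificationModel.open("my_model")  # hypothetical existing model
    # A single value yields a single ClassificationPrediction
    prediction = await model.apredict("I am happy", tags={"demo"})
    print(prediction.label, prediction.confidence)
    # A list of values yields a list, with expected labels recorded for evaluation
    predictions = await model.apredict(["I am happy", "I am sad"], expected_labels=[1, 0])
    print([p.label_name for p in predictions])


asyncio.run(main())
```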
```diff
@@ -515,7 +671,8 @@ class ClassificationModel:
         >>> predictions = model.predictions(expected_label_match=False)
         [ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy', expected_label: 0})]
         """
-        predictions = orca_api.POST(
+        client = OrcaClient._resolve_client()
+        predictions = client.POST(
             "/telemetry/prediction",
             json={
                 "model_id": self.id,
```
```diff
@@ -551,7 +708,8 @@ class ClassificationModel:
         tags: set[str] | None,
         background: bool = False,
     ) -> ClassificationMetrics | Job[ClassificationMetrics]:
-        response = orca_api.POST(
+        client = OrcaClient._resolve_client()
+        response = client.POST(
             "/classification_model/{model_name_or_id}/evaluation",
             params={"model_name_or_id": self.id},
             json={
```
```diff
@@ -565,7 +723,8 @@ class ClassificationModel:
         )
 
         def get_value():
-            res = orca_api.GET(
+            client = OrcaClient._resolve_client()
+            res = client.GET(
                 "/classification_model/{model_name_or_id}/evaluation/{task_id}",
                 params={"model_name_or_id": self.id, "task_id": response["task_id"]},
             )
```
```diff
@@ -773,7 +932,8 @@ class ClassificationModel:
             ValueError: If the value does not match previous value types for the category, or is a
                 [`float`][float] that is not between `-1.0` and `+1.0`.
         """
-        orca_api.PUT(
+        client = OrcaClient._resolve_client()
+        client.PUT(
             "/telemetry/prediction/feedback",
             json=[
                 _parse_feedback(f) for f in (cast(list[dict], [feedback]) if isinstance(feedback, dict) else feedback)
```
```diff
@@ -788,7 +948,8 @@ class ClassificationModel:
         num_examples_per_label: int,
         background: bool = False,
     ) -> Job[BootstrappedClassificationModel] | BootstrappedClassificationModel:
-        response = orca_api.POST(
+        client = OrcaClient._resolve_client()
+        response = client.POST(
             "/agents/bootstrap_classification_model",
             json={
                 "model_description": model_description,
```
```diff
@@ -799,7 +960,8 @@ class ClassificationModel:
         )
 
         def get_result() -> BootstrappedClassificationModel:
-            res = orca_api.GET(
+            client = OrcaClient._resolve_client()
+            res = client.GET(
                 "/agents/bootstrap_classification_model/{task_id}", params={"task_id": response["task_id"]}
             )
             assert res["result"] is not None
```

orca_sdk/classification_model_test.py
CHANGED
```diff
@@ -53,9 +53,10 @@ def test_create_model_already_exists_return(readonly_memoryset, classification_m
     assert new_model.memory_lookup_count == 3
 
 
-def test_create_model_unauthenticated(
-    with
-
+def test_create_model_unauthenticated(unauthenticated_client, readonly_memoryset: LabeledMemoryset):
+    with unauthenticated_client.use():
+        with pytest.raises(ValueError, match="Invalid API key"):
+            ClassificationModel.create("test_model", readonly_memoryset)
 
 
 def test_get_model(classification_model: ClassificationModel):
```
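The rewritten tests lean on `unauthenticated_client` and `unauthorized_client` fixtures whose `.use()` method scopes a client to the block, as seen above. The fixtures themselves are not part of this diff (they presumably live in the updated `conftest.py`), so the following is only a hypothetical sketch of the shape the tests rely on; the constructor signature is an assumption.

```python
import pytest

from orca_sdk.client import OrcaClient


@pytest.fixture
def unauthenticated_client() -> OrcaClient:
    # Hypothetical: a client built with a key the API will reject, so calls made
    # inside its use() block fail with ValueError("Invalid API key").
    return OrcaClient(api_key="not-a-real-key")  # assumed constructor signature
```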
```diff
@@ -68,9 +69,10 @@ def test_get_model(classification_model: ClassificationModel):
     assert fetched_model == classification_model
 
 
-def test_get_model_unauthenticated(
-    with
-
+def test_get_model_unauthenticated(unauthenticated_client):
+    with unauthenticated_client.use():
+        with pytest.raises(ValueError, match="Invalid API key"):
+            ClassificationModel.open("test_model")
 
 
 def test_get_model_invalid_input():
```
```diff
@@ -83,9 +85,10 @@ def test_get_model_not_found():
         ClassificationModel.open(str(uuid4()))
 
 
-def test_get_model_unauthorized(
-    with
-
+def test_get_model_unauthorized(unauthorized_client, classification_model: ClassificationModel):
+    with unauthorized_client.use():
+        with pytest.raises(LookupError):
+            ClassificationModel.open(classification_model.name)
 
 
 def test_list_models(classification_model: ClassificationModel):
```
```diff
@@ -94,13 +97,15 @@ def test_list_models(classification_model: ClassificationModel):
     assert any(model.name == model.name for model in models)
 
 
-def test_list_models_unauthenticated(
-    with
-
+def test_list_models_unauthenticated(unauthenticated_client):
+    with unauthenticated_client.use():
+        with pytest.raises(ValueError, match="Invalid API key"):
+            ClassificationModel.all()
 
 
-def test_list_models_unauthorized(
-
+def test_list_models_unauthorized(unauthorized_client, classification_model: ClassificationModel):
+    with unauthorized_client.use():
+        assert ClassificationModel.all() == []
 
 
 def test_update_model_attributes(classification_model: ClassificationModel):
```
```diff
@@ -131,9 +136,10 @@ def test_delete_model(readonly_memoryset: LabeledMemoryset):
         ClassificationModel.open("model_to_delete")
 
 
-def test_delete_model_unauthenticated(
-    with
-
+def test_delete_model_unauthenticated(unauthenticated_client, classification_model: ClassificationModel):
+    with unauthenticated_client.use():
+        with pytest.raises(ValueError, match="Invalid API key"):
+            ClassificationModel.drop(classification_model.name)
 
 
 def test_delete_model_not_found():
```
```diff
@@ -143,9 +149,10 @@ def test_delete_model_not_found():
         ClassificationModel.drop(str(uuid4()), if_not_exists="ignore")
 
 
-def test_delete_model_unauthorized(
-    with
-
+def test_delete_model_unauthorized(unauthorized_client, classification_model: ClassificationModel):
+    with unauthorized_client.use():
+        with pytest.raises(LookupError):
+            ClassificationModel.drop(classification_model.name)
 
 
 def test_delete_memoryset_before_model_constraint_violation(hf_dataset):
```
```diff
@@ -254,14 +261,16 @@ def test_predict_disable_telemetry(classification_model: ClassificationModel, la
     assert 0 <= predictions[1].confidence <= 1
 
 
-def test_predict_unauthenticated(
-    with
-
+def test_predict_unauthenticated(unauthenticated_client, classification_model: ClassificationModel):
+    with unauthenticated_client.use():
+        with pytest.raises(ValueError, match="Invalid API key"):
+            classification_model.predict(["Do you love soup?", "Are cats cute?"])
 
 
-def test_predict_unauthorized(
-    with
-
+def test_predict_unauthorized(unauthorized_client, classification_model: ClassificationModel):
+    with unauthorized_client.use():
+        with pytest.raises(LookupError):
+            classification_model.predict(["Do you love soup?", "Are cats cute?"])
 
 
 def test_predict_constraint_violation(readonly_memoryset: LabeledMemoryset):
```
```diff
@@ -396,7 +405,7 @@ def test_last_prediction_with_single(classification_model: ClassificationModel):
 def test_explain(writable_memoryset: LabeledMemoryset):
 
     writable_memoryset.analyze(
-        {"name": "
+        {"name": "distribution", "neighbor_counts": [1, 3]},
         lookup_count=3,
     )
 
```
```diff
@@ -430,7 +439,7 @@ def test_action_recommendation(writable_memoryset: LabeledMemoryset):
     """Test getting action recommendations for predictions"""
 
     writable_memoryset.analyze(
-        {"name": "
+        {"name": "distribution", "neighbor_counts": [1, 3]},
         lookup_count=3,
     )
 
```
```diff
@@ -494,3 +503,62 @@ def test_predict_with_prompt(classification_model: ClassificationModel):
     # Both should work and return valid predictions
     assert prediction_with_prompt.label is not None
     assert prediction_without_prompt.label is not None
+
+
+@pytest.mark.asyncio
+async def test_predict_async_single(classification_model: ClassificationModel, label_names: list[str]):
+    """Test async prediction with a single value"""
+    prediction = await classification_model.apredict("Do you love soup?")
+    assert isinstance(prediction, ClassificationPrediction)
+    assert prediction.prediction_id is not None
+    assert prediction.label == 0
+    assert prediction.label_name == label_names[0]
+    assert 0 <= prediction.confidence <= 1
+    assert prediction.logits is not None
+    assert len(prediction.logits) == 2
+
+
+@pytest.mark.asyncio
+async def test_predict_async_batch(classification_model: ClassificationModel, label_names: list[str]):
+    """Test async prediction with a batch of values"""
+    predictions = await classification_model.apredict(["Do you love soup?", "Are cats cute?"])
+    assert len(predictions) == 2
+    assert predictions[0].prediction_id is not None
+    assert predictions[1].prediction_id is not None
+    assert predictions[0].label == 0
+    assert predictions[0].label_name == label_names[0]
+    assert 0 <= predictions[0].confidence <= 1
+    assert predictions[1].label == 1
+    assert predictions[1].label_name == label_names[1]
+    assert 0 <= predictions[1].confidence <= 1
+
+
+@pytest.mark.asyncio
+async def test_predict_async_with_expected_labels(classification_model: ClassificationModel):
+    """Test async prediction with expected labels"""
+    prediction = await classification_model.apredict("Do you love soup?", expected_labels=1)
+    assert prediction.expected_label == 1
+
+
+@pytest.mark.asyncio
+async def test_predict_async_disable_telemetry(classification_model: ClassificationModel, label_names: list[str]):
+    """Test async prediction with telemetry disabled"""
+    predictions = await classification_model.apredict(["Do you love soup?", "Are cats cute?"], save_telemetry="off")
+    assert len(predictions) == 2
+    assert predictions[0].prediction_id is None
+    assert predictions[1].prediction_id is None
+    assert predictions[0].label == 0
+    assert predictions[0].label_name == label_names[0]
+    assert 0 <= predictions[0].confidence <= 1
+    assert predictions[1].label == 1
+    assert predictions[1].label_name == label_names[1]
+    assert 0 <= predictions[1].confidence <= 1
+
+
+@pytest.mark.asyncio
+async def test_predict_async_with_filters(classification_model: ClassificationModel):
+    """Test async prediction with filters"""
+    # there are no memories with label 0 and key g2, so we force a wrong prediction
+    filtered_prediction = await classification_model.apredict("I love soup", filters=[("key", "==", "g2")])
+    assert filtered_prediction.label == 1
+    assert filtered_prediction.label_name == "cats"
```