orca-sdk 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +10 -4
- orca_sdk/_shared/__init__.py +10 -0
- orca_sdk/_shared/metrics.py +393 -0
- orca_sdk/_shared/metrics_test.py +273 -0
- orca_sdk/_utils/analysis_ui.py +12 -10
- orca_sdk/_utils/analysis_ui_style.css +0 -3
- orca_sdk/_utils/auth.py +31 -29
- orca_sdk/_utils/data_parsing.py +28 -2
- orca_sdk/_utils/data_parsing_test.py +15 -15
- orca_sdk/_utils/pagination.py +126 -0
- orca_sdk/_utils/pagination_test.py +132 -0
- orca_sdk/_utils/prediction_result_ui.py +67 -21
- orca_sdk/_utils/tqdm_file_reader.py +12 -0
- orca_sdk/_utils/value_parser.py +45 -0
- orca_sdk/_utils/value_parser_test.py +39 -0
- orca_sdk/async_client.py +3795 -0
- orca_sdk/classification_model.py +601 -129
- orca_sdk/classification_model_test.py +415 -117
- orca_sdk/client.py +3787 -0
- orca_sdk/conftest.py +184 -38
- orca_sdk/credentials.py +162 -20
- orca_sdk/credentials_test.py +100 -16
- orca_sdk/datasource.py +268 -68
- orca_sdk/datasource_test.py +266 -18
- orca_sdk/embedding_model.py +434 -82
- orca_sdk/embedding_model_test.py +66 -33
- orca_sdk/job.py +343 -0
- orca_sdk/job_test.py +108 -0
- orca_sdk/memoryset.py +1690 -324
- orca_sdk/memoryset_test.py +456 -119
- orca_sdk/regression_model.py +694 -0
- orca_sdk/regression_model_test.py +378 -0
- orca_sdk/telemetry.py +460 -143
- orca_sdk/telemetry_test.py +43 -24
- {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.3.dist-info}/METADATA +34 -16
- orca_sdk-0.1.3.dist-info/RECORD +41 -0
- {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.3.dist-info}/WHEEL +1 -1
- orca_sdk/_generated_api_client/__init__.py +0 -3
- orca_sdk/_generated_api_client/api/__init__.py +0 -193
- orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
- orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
- orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
- orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
- orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
- orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
- orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
- orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
- orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
- orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
- orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
- orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
- orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
- orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
- orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
- orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
- orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
- orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
- orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
- orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
- orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
- orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
- orca_sdk/_generated_api_client/client.py +0 -216
- orca_sdk/_generated_api_client/errors.py +0 -38
- orca_sdk/_generated_api_client/models/__init__.py +0 -159
- orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
- orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
- orca_sdk/_generated_api_client/models/base_model.py +0 -55
- orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
- orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
- orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
- orca_sdk/_generated_api_client/models/column_info.py +0 -114
- orca_sdk/_generated_api_client/models/column_type.py +0 -14
- orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
- orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
- orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
- orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
- orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
- orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/embed_request.py +0 -127
- orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
- orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
- orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
- orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
- orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
- orca_sdk/_generated_api_client/models/filter_item.py +0 -231
- orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
- orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
- orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
- orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
- orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
- orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
- orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
- orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
- orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
- orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
- orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
- orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
- orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
- orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
- orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
- orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
- orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
- orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
- orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
- orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
- orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
- orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
- orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
- orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
- orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
- orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
- orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
- orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/task.py +0 -198
- orca_sdk/_generated_api_client/models/task_status.py +0 -14
- orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
- orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
- orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
- orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
- orca_sdk/_generated_api_client/py.typed +0 -1
- orca_sdk/_generated_api_client/types.py +0 -56
- orca_sdk/_utils/task.py +0 -73
- orca_sdk-0.1.1.dist-info/RECORD +0 -175
|
@@ -1,68 +1,78 @@
|
|
|
1
|
+
import logging
|
|
1
2
|
from uuid import uuid4
|
|
2
3
|
|
|
4
|
+
import numpy as np
|
|
3
5
|
import pytest
|
|
4
|
-
from datasets
|
|
6
|
+
from datasets import Dataset
|
|
5
7
|
|
|
6
|
-
from .classification_model import ClassificationModel
|
|
8
|
+
from .classification_model import ClassificationMetrics, ClassificationModel
|
|
9
|
+
from .conftest import skip_in_ci
|
|
7
10
|
from .datasource import Datasource
|
|
8
11
|
from .embedding_model import PretrainedEmbeddingModel
|
|
9
12
|
from .memoryset import LabeledMemoryset
|
|
13
|
+
from .telemetry import ClassificationPrediction
|
|
10
14
|
|
|
11
15
|
|
|
12
|
-
def test_create_model(
|
|
13
|
-
assert
|
|
14
|
-
assert
|
|
15
|
-
assert
|
|
16
|
-
assert
|
|
17
|
-
assert
|
|
16
|
+
def test_create_model(classification_model: ClassificationModel, readonly_memoryset: LabeledMemoryset):
|
|
17
|
+
assert classification_model is not None
|
|
18
|
+
assert classification_model.name == "test_classification_model"
|
|
19
|
+
assert classification_model.memoryset == readonly_memoryset
|
|
20
|
+
assert classification_model.num_classes == 2
|
|
21
|
+
assert classification_model.memory_lookup_count == 3
|
|
18
22
|
|
|
19
23
|
|
|
20
|
-
def test_create_model_already_exists_error(
|
|
24
|
+
def test_create_model_already_exists_error(readonly_memoryset, classification_model):
|
|
21
25
|
with pytest.raises(ValueError):
|
|
22
|
-
ClassificationModel.create("
|
|
26
|
+
ClassificationModel.create("test_classification_model", readonly_memoryset)
|
|
23
27
|
with pytest.raises(ValueError):
|
|
24
|
-
ClassificationModel.create("
|
|
28
|
+
ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="error")
|
|
25
29
|
|
|
26
30
|
|
|
27
|
-
def test_create_model_already_exists_return(
|
|
31
|
+
def test_create_model_already_exists_return(readonly_memoryset, classification_model):
|
|
28
32
|
with pytest.raises(ValueError):
|
|
29
|
-
ClassificationModel.create("
|
|
33
|
+
ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="open", head_type="MMOE")
|
|
30
34
|
|
|
31
35
|
with pytest.raises(ValueError):
|
|
32
|
-
ClassificationModel.create(
|
|
36
|
+
ClassificationModel.create(
|
|
37
|
+
"test_classification_model", readonly_memoryset, if_exists="open", memory_lookup_count=37
|
|
38
|
+
)
|
|
33
39
|
|
|
34
40
|
with pytest.raises(ValueError):
|
|
35
|
-
ClassificationModel.create("
|
|
41
|
+
ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="open", num_classes=19)
|
|
36
42
|
|
|
37
43
|
with pytest.raises(ValueError):
|
|
38
|
-
ClassificationModel.create(
|
|
44
|
+
ClassificationModel.create(
|
|
45
|
+
"test_classification_model", readonly_memoryset, if_exists="open", min_memory_weight=0.77
|
|
46
|
+
)
|
|
39
47
|
|
|
40
|
-
new_model = ClassificationModel.create("
|
|
48
|
+
new_model = ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="open")
|
|
41
49
|
assert new_model is not None
|
|
42
|
-
assert new_model.name == "
|
|
43
|
-
assert new_model.memoryset ==
|
|
50
|
+
assert new_model.name == "test_classification_model"
|
|
51
|
+
assert new_model.memoryset == readonly_memoryset
|
|
44
52
|
assert new_model.num_classes == 2
|
|
45
53
|
assert new_model.memory_lookup_count == 3
|
|
46
54
|
|
|
47
55
|
|
|
48
|
-
def test_create_model_unauthenticated(
|
|
49
|
-
with
|
|
50
|
-
|
|
56
|
+
def test_create_model_unauthenticated(unauthenticated_client, readonly_memoryset: LabeledMemoryset):
|
|
57
|
+
with unauthenticated_client.use():
|
|
58
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
59
|
+
ClassificationModel.create("test_model", readonly_memoryset)
|
|
51
60
|
|
|
52
61
|
|
|
53
|
-
def test_get_model(
|
|
54
|
-
fetched_model = ClassificationModel.open(
|
|
62
|
+
def test_get_model(classification_model: ClassificationModel):
|
|
63
|
+
fetched_model = ClassificationModel.open(classification_model.name)
|
|
55
64
|
assert fetched_model is not None
|
|
56
|
-
assert fetched_model.id ==
|
|
57
|
-
assert fetched_model.name ==
|
|
65
|
+
assert fetched_model.id == classification_model.id
|
|
66
|
+
assert fetched_model.name == classification_model.name
|
|
58
67
|
assert fetched_model.num_classes == 2
|
|
59
68
|
assert fetched_model.memory_lookup_count == 3
|
|
60
|
-
assert fetched_model ==
|
|
69
|
+
assert fetched_model == classification_model
|
|
61
70
|
|
|
62
71
|
|
|
63
|
-
def test_get_model_unauthenticated(
|
|
64
|
-
with
|
|
65
|
-
|
|
72
|
+
def test_get_model_unauthenticated(unauthenticated_client):
|
|
73
|
+
with unauthenticated_client.use():
|
|
74
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
75
|
+
ClassificationModel.open("test_model")
|
|
66
76
|
|
|
67
77
|
|
|
68
78
|
def test_get_model_invalid_input():
|
|
@@ -75,37 +85,61 @@ def test_get_model_not_found():
|
|
|
75
85
|
ClassificationModel.open(str(uuid4()))
|
|
76
86
|
|
|
77
87
|
|
|
78
|
-
def test_get_model_unauthorized(
|
|
79
|
-
with
|
|
80
|
-
|
|
88
|
+
def test_get_model_unauthorized(unauthorized_client, classification_model: ClassificationModel):
|
|
89
|
+
with unauthorized_client.use():
|
|
90
|
+
with pytest.raises(LookupError):
|
|
91
|
+
ClassificationModel.open(classification_model.name)
|
|
81
92
|
|
|
82
93
|
|
|
83
|
-
def test_list_models(
|
|
94
|
+
def test_list_models(classification_model: ClassificationModel):
|
|
84
95
|
models = ClassificationModel.all()
|
|
85
96
|
assert len(models) > 0
|
|
86
97
|
assert any(model.name == model.name for model in models)
|
|
87
98
|
|
|
88
99
|
|
|
89
|
-
def test_list_models_unauthenticated(
|
|
90
|
-
with
|
|
91
|
-
|
|
100
|
+
def test_list_models_unauthenticated(unauthenticated_client):
|
|
101
|
+
with unauthenticated_client.use():
|
|
102
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
103
|
+
ClassificationModel.all()
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def test_list_models_unauthorized(unauthorized_client, classification_model: ClassificationModel):
|
|
107
|
+
with unauthorized_client.use():
|
|
108
|
+
assert ClassificationModel.all() == []
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def test_update_model_attributes(classification_model: ClassificationModel):
|
|
112
|
+
classification_model.description = "New description"
|
|
113
|
+
assert classification_model.description == "New description"
|
|
114
|
+
|
|
115
|
+
classification_model.set(description=None)
|
|
116
|
+
assert classification_model.description is None
|
|
117
|
+
|
|
118
|
+
classification_model.set(locked=True)
|
|
119
|
+
assert classification_model.locked is True
|
|
92
120
|
|
|
121
|
+
classification_model.set(locked=False)
|
|
122
|
+
assert classification_model.locked is False
|
|
93
123
|
|
|
94
|
-
|
|
95
|
-
assert
|
|
124
|
+
classification_model.lock()
|
|
125
|
+
assert classification_model.locked is True
|
|
96
126
|
|
|
127
|
+
classification_model.unlock()
|
|
128
|
+
assert classification_model.locked is False
|
|
97
129
|
|
|
98
|
-
|
|
99
|
-
|
|
130
|
+
|
|
131
|
+
def test_delete_model(readonly_memoryset: LabeledMemoryset):
|
|
132
|
+
ClassificationModel.create("model_to_delete", LabeledMemoryset.open(readonly_memoryset.name))
|
|
100
133
|
assert ClassificationModel.open("model_to_delete")
|
|
101
134
|
ClassificationModel.drop("model_to_delete")
|
|
102
135
|
with pytest.raises(LookupError):
|
|
103
136
|
ClassificationModel.open("model_to_delete")
|
|
104
137
|
|
|
105
138
|
|
|
106
|
-
def test_delete_model_unauthenticated(
|
|
107
|
-
with
|
|
108
|
-
|
|
139
|
+
def test_delete_model_unauthenticated(unauthenticated_client, classification_model: ClassificationModel):
|
|
140
|
+
with unauthenticated_client.use():
|
|
141
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
142
|
+
ClassificationModel.drop(classification_model.name)
|
|
109
143
|
|
|
110
144
|
|
|
111
145
|
def test_delete_model_not_found():
|
|
@@ -115,53 +149,83 @@ def test_delete_model_not_found():
|
|
|
115
149
|
ClassificationModel.drop(str(uuid4()), if_not_exists="ignore")
|
|
116
150
|
|
|
117
151
|
|
|
118
|
-
def test_delete_model_unauthorized(
|
|
119
|
-
with
|
|
120
|
-
|
|
152
|
+
def test_delete_model_unauthorized(unauthorized_client, classification_model: ClassificationModel):
|
|
153
|
+
with unauthorized_client.use():
|
|
154
|
+
with pytest.raises(LookupError):
|
|
155
|
+
ClassificationModel.drop(classification_model.name)
|
|
121
156
|
|
|
122
157
|
|
|
123
|
-
@pytest.mark.flaky
|
|
124
158
|
def test_delete_memoryset_before_model_constraint_violation(hf_dataset):
|
|
125
|
-
memoryset = LabeledMemoryset.from_hf_dataset("test_memoryset_delete_before_model", hf_dataset
|
|
159
|
+
memoryset = LabeledMemoryset.from_hf_dataset("test_memoryset_delete_before_model", hf_dataset)
|
|
126
160
|
ClassificationModel.create("test_model_delete_before_memoryset", memoryset)
|
|
127
161
|
with pytest.raises(RuntimeError):
|
|
128
162
|
LabeledMemoryset.drop(memoryset.id)
|
|
129
163
|
|
|
130
164
|
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
{"text": "soup is great for the winter", "label": 0},
|
|
138
|
-
{"text": "i love cats", "label": 1},
|
|
139
|
-
],
|
|
165
|
+
@pytest.mark.parametrize("data_type", ["dataset", "datasource"])
|
|
166
|
+
def test_evaluate(classification_model, eval_datasource: Datasource, eval_dataset: Dataset, data_type):
|
|
167
|
+
result = (
|
|
168
|
+
classification_model.evaluate(eval_dataset)
|
|
169
|
+
if data_type == "dataset"
|
|
170
|
+
else classification_model.evaluate(eval_datasource)
|
|
140
171
|
)
|
|
141
|
-
|
|
172
|
+
|
|
142
173
|
assert result is not None
|
|
143
|
-
assert isinstance(result
|
|
144
|
-
|
|
145
|
-
assert isinstance(result
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
174
|
+
assert isinstance(result, ClassificationMetrics)
|
|
175
|
+
|
|
176
|
+
assert isinstance(result.accuracy, float)
|
|
177
|
+
assert np.allclose(result.accuracy, 0.5)
|
|
178
|
+
assert isinstance(result.f1_score, float)
|
|
179
|
+
assert np.allclose(result.f1_score, 0.5)
|
|
180
|
+
assert isinstance(result.loss, float)
|
|
181
|
+
|
|
182
|
+
assert isinstance(result.anomaly_score_mean, float)
|
|
183
|
+
assert isinstance(result.anomaly_score_median, float)
|
|
184
|
+
assert isinstance(result.anomaly_score_variance, float)
|
|
185
|
+
assert -1.0 <= result.anomaly_score_mean <= 1.0
|
|
186
|
+
assert -1.0 <= result.anomaly_score_median <= 1.0
|
|
187
|
+
assert -1.0 <= result.anomaly_score_variance <= 1.0
|
|
188
|
+
|
|
189
|
+
assert result.pr_auc is not None
|
|
190
|
+
assert np.allclose(result.pr_auc, 0.75)
|
|
191
|
+
assert result.pr_curve is not None
|
|
192
|
+
assert np.allclose(result.pr_curve["thresholds"], [0.0, 0.0, 0.8155114054679871, 0.834095299243927])
|
|
193
|
+
assert np.allclose(result.pr_curve["precisions"], [0.5, 0.5, 1.0, 1.0])
|
|
194
|
+
assert np.allclose(result.pr_curve["recalls"], [1.0, 0.5, 0.5, 0.0])
|
|
195
|
+
|
|
196
|
+
assert result.roc_auc is not None
|
|
197
|
+
assert np.allclose(result.roc_auc, 0.625)
|
|
198
|
+
assert result.roc_curve is not None
|
|
199
|
+
assert np.allclose(result.roc_curve["thresholds"], [0.0, 0.8155114054679871, 0.834095299243927, 1.0])
|
|
200
|
+
assert np.allclose(result.roc_curve["false_positive_rates"], [1.0, 0.5, 0.0, 0.0])
|
|
201
|
+
assert np.allclose(result.roc_curve["true_positive_rates"], [1.0, 0.5, 0.5, 0.0])
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def test_evaluate_datasource_with_nones_raises_error(classification_model: ClassificationModel, datasource: Datasource):
|
|
205
|
+
with pytest.raises(ValueError):
|
|
206
|
+
classification_model.evaluate(datasource, record_predictions=True, tags={"test"})
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def test_evaluate_dataset_with_nones_raises_error(classification_model: ClassificationModel, hf_dataset: Dataset):
|
|
210
|
+
with pytest.raises(ValueError):
|
|
211
|
+
classification_model.evaluate(hf_dataset, record_predictions=True, tags={"test"})
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def test_evaluate_with_telemetry(classification_model: ClassificationModel, eval_dataset: Dataset):
|
|
215
|
+
result = classification_model.evaluate(eval_dataset, record_predictions=True, tags={"test"})
|
|
155
216
|
assert result is not None
|
|
156
|
-
|
|
157
|
-
|
|
217
|
+
assert isinstance(result, ClassificationMetrics)
|
|
218
|
+
predictions = classification_model.predictions(tag="test")
|
|
219
|
+
assert len(predictions) == 4
|
|
158
220
|
assert all(p.tags == {"test"} for p in predictions)
|
|
159
|
-
assert all(p.expected_label ==
|
|
221
|
+
assert all(p.expected_label == l for p, l in zip(predictions, eval_dataset["label"]))
|
|
160
222
|
|
|
161
223
|
|
|
162
|
-
def test_predict(
|
|
163
|
-
predictions =
|
|
224
|
+
def test_predict(classification_model: ClassificationModel, label_names: list[str]):
|
|
225
|
+
predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
|
|
164
226
|
assert len(predictions) == 2
|
|
227
|
+
assert predictions[0].prediction_id is not None
|
|
228
|
+
assert predictions[1].prediction_id is not None
|
|
165
229
|
assert predictions[0].label == 0
|
|
166
230
|
assert predictions[0].label_name == label_names[0]
|
|
167
231
|
assert 0 <= predictions[0].confidence <= 1
|
|
@@ -169,29 +233,61 @@ def test_predict(model: ClassificationModel, label_names: list[str]):
|
|
|
169
233
|
assert predictions[1].label_name == label_names[1]
|
|
170
234
|
assert 0 <= predictions[1].confidence <= 1
|
|
171
235
|
|
|
236
|
+
assert predictions[0].logits is not None
|
|
237
|
+
assert predictions[1].logits is not None
|
|
238
|
+
assert len(predictions[0].logits) == 2
|
|
239
|
+
assert len(predictions[1].logits) == 2
|
|
240
|
+
assert predictions[0].logits[0] > predictions[0].logits[1]
|
|
241
|
+
assert predictions[1].logits[0] < predictions[1].logits[1]
|
|
172
242
|
|
|
173
|
-
def test_predict_unauthenticated(unauthenticated, model: ClassificationModel):
|
|
174
|
-
with pytest.raises(ValueError, match="Invalid API key"):
|
|
175
|
-
model.predict(["Do you love soup?", "Are cats cute?"])
|
|
176
243
|
|
|
244
|
+
def test_classification_prediction_has_no_label(classification_model: ClassificationModel):
|
|
245
|
+
"""Ensure optional score is None for classification predictions."""
|
|
246
|
+
prediction = classification_model.predict("Do you want to go to the beach?")
|
|
247
|
+
assert isinstance(prediction, ClassificationPrediction)
|
|
248
|
+
assert prediction.label is None
|
|
177
249
|
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
250
|
+
|
|
251
|
+
def test_predict_disable_telemetry(classification_model: ClassificationModel, label_names: list[str]):
|
|
252
|
+
predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"], save_telemetry="off")
|
|
253
|
+
assert len(predictions) == 2
|
|
254
|
+
assert predictions[0].prediction_id is None
|
|
255
|
+
assert predictions[1].prediction_id is None
|
|
256
|
+
assert predictions[0].label == 0
|
|
257
|
+
assert predictions[0].label_name == label_names[0]
|
|
258
|
+
assert 0 <= predictions[0].confidence <= 1
|
|
259
|
+
assert predictions[1].label == 1
|
|
260
|
+
assert predictions[1].label_name == label_names[1]
|
|
261
|
+
assert 0 <= predictions[1].confidence <= 1
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def test_predict_unauthenticated(unauthenticated_client, classification_model: ClassificationModel):
|
|
265
|
+
with unauthenticated_client.use():
|
|
266
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
267
|
+
classification_model.predict(["Do you love soup?", "Are cats cute?"])
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def test_predict_unauthorized(unauthorized_client, classification_model: ClassificationModel):
|
|
271
|
+
with unauthorized_client.use():
|
|
272
|
+
with pytest.raises(LookupError):
|
|
273
|
+
classification_model.predict(["Do you love soup?", "Are cats cute?"])
|
|
181
274
|
|
|
182
275
|
|
|
183
|
-
def test_predict_constraint_violation(
|
|
276
|
+
def test_predict_constraint_violation(readonly_memoryset: LabeledMemoryset):
|
|
184
277
|
model = ClassificationModel.create(
|
|
185
|
-
"test_model_lookup_count_too_high",
|
|
278
|
+
"test_model_lookup_count_too_high",
|
|
279
|
+
readonly_memoryset,
|
|
280
|
+
num_classes=2,
|
|
281
|
+
memory_lookup_count=readonly_memoryset.length + 2,
|
|
186
282
|
)
|
|
187
283
|
with pytest.raises(RuntimeError):
|
|
188
284
|
model.predict("test")
|
|
189
285
|
|
|
190
286
|
|
|
191
|
-
def test_record_prediction_feedback(
|
|
192
|
-
predictions =
|
|
287
|
+
def test_record_prediction_feedback(classification_model: ClassificationModel):
|
|
288
|
+
predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
|
|
193
289
|
expected_labels = [0, 1]
|
|
194
|
-
|
|
290
|
+
classification_model.record_feedback(
|
|
195
291
|
{
|
|
196
292
|
"prediction_id": p.prediction_id,
|
|
197
293
|
"category": "correct",
|
|
@@ -201,66 +297,268 @@ def test_record_prediction_feedback(model: ClassificationModel):
|
|
|
201
297
|
)
|
|
202
298
|
|
|
203
299
|
|
|
204
|
-
def test_record_prediction_feedback_missing_category(
|
|
205
|
-
prediction =
|
|
300
|
+
def test_record_prediction_feedback_missing_category(classification_model: ClassificationModel):
|
|
301
|
+
prediction = classification_model.predict("Do you love soup?")
|
|
206
302
|
with pytest.raises(ValueError):
|
|
207
|
-
|
|
303
|
+
classification_model.record_feedback({"prediction_id": prediction.prediction_id, "value": True})
|
|
208
304
|
|
|
209
305
|
|
|
210
|
-
def test_record_prediction_feedback_invalid_value(
|
|
211
|
-
prediction =
|
|
306
|
+
def test_record_prediction_feedback_invalid_value(classification_model: ClassificationModel):
|
|
307
|
+
prediction = classification_model.predict("Do you love soup?")
|
|
212
308
|
with pytest.raises(ValueError, match=r"Invalid input.*"):
|
|
213
|
-
|
|
309
|
+
classification_model.record_feedback(
|
|
310
|
+
{"prediction_id": prediction.prediction_id, "category": "correct", "value": "invalid"}
|
|
311
|
+
)
|
|
214
312
|
|
|
215
313
|
|
|
216
|
-
def test_record_prediction_feedback_invalid_prediction_id(
|
|
314
|
+
def test_record_prediction_feedback_invalid_prediction_id(classification_model: ClassificationModel):
|
|
217
315
|
with pytest.raises(ValueError, match=r"Invalid input.*"):
|
|
218
|
-
|
|
316
|
+
classification_model.record_feedback({"prediction_id": "invalid", "category": "correct", "value": True})
|
|
219
317
|
|
|
220
318
|
|
|
221
|
-
def test_predict_with_memoryset_override(
|
|
319
|
+
def test_predict_with_memoryset_override(classification_model: ClassificationModel, hf_dataset: Dataset):
|
|
222
320
|
inverted_labeled_memoryset = LabeledMemoryset.from_hf_dataset(
|
|
223
321
|
"test_memoryset_inverted_labels",
|
|
224
322
|
hf_dataset.map(lambda x: {"label": 1 if x["label"] == 0 else 0}),
|
|
225
|
-
value_column="text",
|
|
226
323
|
embedding_model=PretrainedEmbeddingModel.GTE_BASE,
|
|
227
324
|
)
|
|
228
|
-
with
|
|
229
|
-
predictions =
|
|
325
|
+
with classification_model.use_memoryset(inverted_labeled_memoryset):
|
|
326
|
+
predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
|
|
230
327
|
assert predictions[0].label == 1
|
|
231
328
|
assert predictions[1].label == 0
|
|
232
329
|
|
|
233
|
-
predictions =
|
|
330
|
+
predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
|
|
234
331
|
assert predictions[0].label == 0
|
|
235
332
|
assert predictions[1].label == 1
|
|
236
333
|
|
|
237
334
|
|
|
238
|
-
def test_predict_with_expected_labels(
|
|
239
|
-
prediction =
|
|
335
|
+
def test_predict_with_expected_labels(classification_model: ClassificationModel):
|
|
336
|
+
prediction = classification_model.predict("Do you love soup?", expected_labels=1)
|
|
240
337
|
assert prediction.expected_label == 1
|
|
241
338
|
|
|
242
339
|
|
|
243
|
-
def test_predict_with_expected_labels_invalid_input(
|
|
340
|
+
def test_predict_with_expected_labels_invalid_input(classification_model: ClassificationModel):
|
|
244
341
|
# invalid number of expected labels for batch prediction
|
|
245
342
|
with pytest.raises(ValueError, match=r"Invalid input.*"):
|
|
246
|
-
|
|
343
|
+
classification_model.predict(["Do you love soup?", "Are cats cute?"], expected_labels=[0])
|
|
247
344
|
# invalid label value
|
|
248
345
|
with pytest.raises(ValueError):
|
|
249
|
-
|
|
346
|
+
classification_model.predict("Do you love soup?", expected_labels=5)
|
|
250
347
|
|
|
251
348
|
|
|
252
|
-
def
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
assert
|
|
256
|
-
assert
|
|
257
|
-
assert model._last_prediction_was_batch is True
|
|
349
|
+
def test_predict_with_filters(classification_model: ClassificationModel):
|
|
350
|
+
# there are no memories with label 0 and key g1, so we force a wrong prediction
|
|
351
|
+
filtered_prediction = classification_model.predict("I love soup", filters=[("key", "==", "g2")])
|
|
352
|
+
assert filtered_prediction.label == 1
|
|
353
|
+
assert filtered_prediction.label_name == "cats"
|
|
258
354
|
|
|
259
355
|
|
|
260
|
-
def
|
|
261
|
-
|
|
356
|
+
def test_predict_with_memoryset_update(writable_memoryset: LabeledMemoryset):
|
|
357
|
+
model = ClassificationModel.create(
|
|
358
|
+
"test_predict_with_memoryset_update",
|
|
359
|
+
writable_memoryset,
|
|
360
|
+
num_classes=2,
|
|
361
|
+
memory_lookup_count=3,
|
|
362
|
+
)
|
|
363
|
+
|
|
262
364
|
prediction = model.predict("Do you love soup?")
|
|
263
|
-
assert
|
|
264
|
-
assert
|
|
265
|
-
|
|
266
|
-
|
|
365
|
+
assert prediction.label == 0
|
|
366
|
+
assert prediction.label_name == "soup"
|
|
367
|
+
|
|
368
|
+
# insert new memories
|
|
369
|
+
writable_memoryset.insert(
|
|
370
|
+
[
|
|
371
|
+
{"value": "Do you love soup?", "label": 1, "key": "g1"},
|
|
372
|
+
{"value": "Do you love soup for dinner?", "label": 1, "key": "g2"},
|
|
373
|
+
{"value": "Do you love crackers?", "label": 1, "key": "g2"},
|
|
374
|
+
{"value": "Do you love broth?", "label": 1, "key": "g2"},
|
|
375
|
+
{"value": "Do you love chicken soup?", "label": 1, "key": "g2"},
|
|
376
|
+
{"value": "Do you love chicken soup for dinner?", "label": 1, "key": "g2"},
|
|
377
|
+
{"value": "Do you love chicken soup for dinner?", "label": 1, "key": "g2"},
|
|
378
|
+
],
|
|
379
|
+
)
|
|
380
|
+
prediction = model.predict("Do you love soup?")
|
|
381
|
+
assert prediction.label == 1
|
|
382
|
+
assert prediction.label_name == "cats"
|
|
383
|
+
|
|
384
|
+
ClassificationModel.drop("test_predict_with_memoryset_update")
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
def test_last_prediction_with_batch(classification_model: ClassificationModel):
|
|
388
|
+
predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
|
|
389
|
+
assert classification_model.last_prediction is not None
|
|
390
|
+
assert classification_model.last_prediction.prediction_id == predictions[-1].prediction_id
|
|
391
|
+
assert classification_model.last_prediction.input_value == "Are cats cute?"
|
|
392
|
+
assert classification_model._last_prediction_was_batch is True
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
def test_last_prediction_with_single(classification_model: ClassificationModel):
|
|
396
|
+
# Test that last_prediction is updated correctly with single prediction
|
|
397
|
+
prediction = classification_model.predict("Do you love soup?")
|
|
398
|
+
assert classification_model.last_prediction is not None
|
|
399
|
+
assert classification_model.last_prediction.prediction_id == prediction.prediction_id
|
|
400
|
+
assert classification_model.last_prediction.input_value == "Do you love soup?"
|
|
401
|
+
assert classification_model._last_prediction_was_batch is False
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
@skip_in_ci("We don't have Anthropic API key in CI")
|
|
405
|
+
def test_explain(writable_memoryset: LabeledMemoryset):
|
|
406
|
+
|
|
407
|
+
writable_memoryset.analyze(
|
|
408
|
+
{"name": "distribution", "neighbor_counts": [1, 3]},
|
|
409
|
+
lookup_count=3,
|
|
410
|
+
)
|
|
411
|
+
|
|
412
|
+
model = ClassificationModel.create(
|
|
413
|
+
"test_model_for_explain",
|
|
414
|
+
writable_memoryset,
|
|
415
|
+
num_classes=2,
|
|
416
|
+
memory_lookup_count=3,
|
|
417
|
+
description="This is a test model for explain",
|
|
418
|
+
)
|
|
419
|
+
|
|
420
|
+
predictions = model.predict(["Do you love soup?", "Are cats cute?"])
|
|
421
|
+
assert len(predictions) == 2
|
|
422
|
+
|
|
423
|
+
try:
|
|
424
|
+
explanation = predictions[0].explanation
|
|
425
|
+
assert explanation is not None
|
|
426
|
+
assert len(explanation) > 10
|
|
427
|
+
assert "soup" in explanation.lower()
|
|
428
|
+
except Exception as e:
|
|
429
|
+
if "ANTHROPIC_API_KEY" in str(e):
|
|
430
|
+
logging.info("Skipping explanation test because ANTHROPIC_API_KEY is not set")
|
|
431
|
+
else:
|
|
432
|
+
raise e
|
|
433
|
+
finally:
|
|
434
|
+
ClassificationModel.drop("test_model_for_explain")
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
@skip_in_ci("We don't have Anthropic API key in CI")
|
|
438
|
+
def test_action_recommendation(writable_memoryset: LabeledMemoryset):
|
|
439
|
+
"""Test getting action recommendations for predictions"""
|
|
440
|
+
|
|
441
|
+
writable_memoryset.analyze(
|
|
442
|
+
{"name": "distribution", "neighbor_counts": [1, 3]},
|
|
443
|
+
lookup_count=3,
|
|
444
|
+
)
|
|
445
|
+
|
|
446
|
+
model = ClassificationModel.create(
|
|
447
|
+
"test_model_for_action",
|
|
448
|
+
writable_memoryset,
|
|
449
|
+
num_classes=2,
|
|
450
|
+
memory_lookup_count=3,
|
|
451
|
+
description="This is a test model for action recommendations",
|
|
452
|
+
)
|
|
453
|
+
|
|
454
|
+
# Make a prediction with expected label to simulate incorrect prediction
|
|
455
|
+
prediction = model.predict("Do you love soup?", expected_labels=1)
|
|
456
|
+
|
|
457
|
+
memoryset_length = model.memoryset.length
|
|
458
|
+
|
|
459
|
+
try:
|
|
460
|
+
# Get action recommendation
|
|
461
|
+
action, rationale = prediction.recommend_action()
|
|
462
|
+
|
|
463
|
+
assert action is not None
|
|
464
|
+
assert rationale is not None
|
|
465
|
+
assert action in ["remove_duplicates", "detect_mislabels", "add_memories", "finetuning"]
|
|
466
|
+
assert len(rationale) > 10
|
|
467
|
+
|
|
468
|
+
# Test memory suggestions
|
|
469
|
+
suggestions_response = prediction.generate_memory_suggestions(num_memories=2)
|
|
470
|
+
memory_suggestions = suggestions_response.suggestions
|
|
471
|
+
|
|
472
|
+
assert memory_suggestions is not None
|
|
473
|
+
assert len(memory_suggestions) == 2
|
|
474
|
+
|
|
475
|
+
for suggestion in memory_suggestions:
|
|
476
|
+
assert isinstance(suggestion[0], str)
|
|
477
|
+
assert len(suggestion[0]) > 0
|
|
478
|
+
assert isinstance(suggestion[1], str)
|
|
479
|
+
assert suggestion[1] in model.memoryset.label_names
|
|
480
|
+
|
|
481
|
+
suggestions_response.apply()
|
|
482
|
+
|
|
483
|
+
model.memoryset.refresh()
|
|
484
|
+
assert model.memoryset.length == memoryset_length + 2
|
|
485
|
+
|
|
486
|
+
except Exception as e:
|
|
487
|
+
if "ANTHROPIC_API_KEY" in str(e):
|
|
488
|
+
logging.info("Skipping agent tests because ANTHROPIC_API_KEY is not set")
|
|
489
|
+
else:
|
|
490
|
+
raise e
|
|
491
|
+
finally:
|
|
492
|
+
ClassificationModel.drop("test_model_for_action")
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
def test_predict_with_prompt(classification_model: ClassificationModel):
|
|
496
|
+
"""Test that prompt parameter is properly passed through to predictions"""
|
|
497
|
+
# Test with an instruction-supporting embedding model if available
|
|
498
|
+
prediction_with_prompt = classification_model.predict(
|
|
499
|
+
"I love this product!", prompt="Represent this text for sentiment classification:"
|
|
500
|
+
)
|
|
501
|
+
prediction_without_prompt = classification_model.predict("I love this product!")
|
|
502
|
+
|
|
503
|
+
# Both should work and return valid predictions
|
|
504
|
+
assert prediction_with_prompt.label is not None
|
|
505
|
+
assert prediction_without_prompt.label is not None
|
|
506
|
+
|
|
507
|
+
|
|
508
|
+
@pytest.mark.asyncio
|
|
509
|
+
async def test_predict_async_single(classification_model: ClassificationModel, label_names: list[str]):
|
|
510
|
+
"""Test async prediction with a single value"""
|
|
511
|
+
prediction = await classification_model.apredict("Do you love soup?")
|
|
512
|
+
assert isinstance(prediction, ClassificationPrediction)
|
|
513
|
+
assert prediction.prediction_id is not None
|
|
514
|
+
assert prediction.label == 0
|
|
515
|
+
assert prediction.label_name == label_names[0]
|
|
516
|
+
assert 0 <= prediction.confidence <= 1
|
|
517
|
+
assert prediction.logits is not None
|
|
518
|
+
assert len(prediction.logits) == 2
|
|
519
|
+
|
|
520
|
+
|
|
521
|
+
@pytest.mark.asyncio
|
|
522
|
+
async def test_predict_async_batch(classification_model: ClassificationModel, label_names: list[str]):
|
|
523
|
+
"""Test async prediction with a batch of values"""
|
|
524
|
+
predictions = await classification_model.apredict(["Do you love soup?", "Are cats cute?"])
|
|
525
|
+
assert len(predictions) == 2
|
|
526
|
+
assert predictions[0].prediction_id is not None
|
|
527
|
+
assert predictions[1].prediction_id is not None
|
|
528
|
+
assert predictions[0].label == 0
|
|
529
|
+
assert predictions[0].label_name == label_names[0]
|
|
530
|
+
assert 0 <= predictions[0].confidence <= 1
|
|
531
|
+
assert predictions[1].label == 1
|
|
532
|
+
assert predictions[1].label_name == label_names[1]
|
|
533
|
+
assert 0 <= predictions[1].confidence <= 1
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
@pytest.mark.asyncio
|
|
537
|
+
async def test_predict_async_with_expected_labels(classification_model: ClassificationModel):
|
|
538
|
+
"""Test async prediction with expected labels"""
|
|
539
|
+
prediction = await classification_model.apredict("Do you love soup?", expected_labels=1)
|
|
540
|
+
assert prediction.expected_label == 1
|
|
541
|
+
|
|
542
|
+
|
|
543
|
+
@pytest.mark.asyncio
|
|
544
|
+
async def test_predict_async_disable_telemetry(classification_model: ClassificationModel, label_names: list[str]):
|
|
545
|
+
"""Test async prediction with telemetry disabled"""
|
|
546
|
+
predictions = await classification_model.apredict(["Do you love soup?", "Are cats cute?"], save_telemetry="off")
|
|
547
|
+
assert len(predictions) == 2
|
|
548
|
+
assert predictions[0].prediction_id is None
|
|
549
|
+
assert predictions[1].prediction_id is None
|
|
550
|
+
assert predictions[0].label == 0
|
|
551
|
+
assert predictions[0].label_name == label_names[0]
|
|
552
|
+
assert 0 <= predictions[0].confidence <= 1
|
|
553
|
+
assert predictions[1].label == 1
|
|
554
|
+
assert predictions[1].label_name == label_names[1]
|
|
555
|
+
assert 0 <= predictions[1].confidence <= 1
|
|
556
|
+
|
|
557
|
+
|
|
558
|
+
@pytest.mark.asyncio
|
|
559
|
+
async def test_predict_async_with_filters(classification_model: ClassificationModel):
|
|
560
|
+
"""Test async prediction with filters"""
|
|
561
|
+
# there are no memories with label 0 and key g2, so we force a wrong prediction
|
|
562
|
+
filtered_prediction = await classification_model.apredict("I love soup", filters=[("key", "==", "g2")])
|
|
563
|
+
assert filtered_prediction.label == 1
|
|
564
|
+
assert filtered_prediction.label_name == "cats"
|