orca-sdk 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registries.
- orca_sdk/__init__.py +10 -4
- orca_sdk/_shared/__init__.py +10 -0
- orca_sdk/_shared/metrics.py +393 -0
- orca_sdk/_shared/metrics_test.py +273 -0
- orca_sdk/_utils/analysis_ui.py +12 -10
- orca_sdk/_utils/analysis_ui_style.css +0 -3
- orca_sdk/_utils/auth.py +27 -29
- orca_sdk/_utils/data_parsing.py +28 -2
- orca_sdk/_utils/data_parsing_test.py +15 -15
- orca_sdk/_utils/pagination.py +126 -0
- orca_sdk/_utils/pagination_test.py +132 -0
- orca_sdk/_utils/prediction_result_ui.py +67 -21
- orca_sdk/_utils/tqdm_file_reader.py +12 -0
- orca_sdk/_utils/value_parser.py +45 -0
- orca_sdk/_utils/value_parser_test.py +39 -0
- orca_sdk/classification_model.py +439 -129
- orca_sdk/classification_model_test.py +334 -104
- orca_sdk/client.py +3747 -0
- orca_sdk/conftest.py +164 -19
- orca_sdk/credentials.py +120 -18
- orca_sdk/credentials_test.py +20 -0
- orca_sdk/datasource.py +259 -68
- orca_sdk/datasource_test.py +242 -0
- orca_sdk/embedding_model.py +425 -82
- orca_sdk/embedding_model_test.py +39 -13
- orca_sdk/job.py +337 -0
- orca_sdk/job_test.py +108 -0
- orca_sdk/memoryset.py +1341 -305
- orca_sdk/memoryset_test.py +350 -111
- orca_sdk/regression_model.py +684 -0
- orca_sdk/regression_model_test.py +369 -0
- orca_sdk/telemetry.py +449 -143
- orca_sdk/telemetry_test.py +43 -24
- {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.2.dist-info}/METADATA +34 -16
- orca_sdk-0.1.2.dist-info/RECORD +40 -0
- {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.2.dist-info}/WHEEL +1 -1
- orca_sdk/_generated_api_client/__init__.py +0 -3
- orca_sdk/_generated_api_client/api/__init__.py +0 -193
- orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
- orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
- orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
- orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
- orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
- orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
- orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
- orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
- orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
- orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
- orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
- orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
- orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
- orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
- orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
- orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
- orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
- orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
- orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
- orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
- orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
- orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
- orca_sdk/_generated_api_client/client.py +0 -216
- orca_sdk/_generated_api_client/errors.py +0 -38
- orca_sdk/_generated_api_client/models/__init__.py +0 -159
- orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
- orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
- orca_sdk/_generated_api_client/models/base_model.py +0 -55
- orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
- orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
- orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
- orca_sdk/_generated_api_client/models/column_info.py +0 -114
- orca_sdk/_generated_api_client/models/column_type.py +0 -14
- orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
- orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
- orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
- orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
- orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
- orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/embed_request.py +0 -127
- orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
- orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
- orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
- orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
- orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
- orca_sdk/_generated_api_client/models/filter_item.py +0 -231
- orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
- orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
- orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
- orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
- orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
- orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
- orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
- orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
- orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
- orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
- orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
- orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
- orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
- orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
- orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
- orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
- orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
- orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
- orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
- orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
- orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
- orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
- orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
- orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
- orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
- orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
- orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
- orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/task.py +0 -198
- orca_sdk/_generated_api_client/models/task_status.py +0 -14
- orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
- orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
- orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
- orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
- orca_sdk/_generated_api_client/py.typed +0 -1
- orca_sdk/_generated_api_client/types.py +0 -56
- orca_sdk/_utils/task.py +0 -73
- orca_sdk-0.1.1.dist-info/RECORD +0 -175
--- orca_sdk/classification_model_test.py (orca-sdk 0.1.1)
+++ orca_sdk/classification_model_test.py (orca-sdk 0.1.2)
@@ -1,63 +1,71 @@
+import logging
 from uuid import uuid4

+import numpy as np
 import pytest
-from datasets
+from datasets import Dataset

-from .classification_model import ClassificationModel
+from .classification_model import ClassificationMetrics, ClassificationModel
+from .conftest import skip_in_ci
 from .datasource import Datasource
 from .embedding_model import PretrainedEmbeddingModel
 from .memoryset import LabeledMemoryset
+from .telemetry import ClassificationPrediction


-def test_create_model(
-    assert
-    assert
-    assert
-    assert
-    assert
+def test_create_model(classification_model: ClassificationModel, readonly_memoryset: LabeledMemoryset):
+    assert classification_model is not None
+    assert classification_model.name == "test_classification_model"
+    assert classification_model.memoryset == readonly_memoryset
+    assert classification_model.num_classes == 2
+    assert classification_model.memory_lookup_count == 3


-def test_create_model_already_exists_error(
+def test_create_model_already_exists_error(readonly_memoryset, classification_model):
     with pytest.raises(ValueError):
-        ClassificationModel.create("
+        ClassificationModel.create("test_classification_model", readonly_memoryset)
     with pytest.raises(ValueError):
-        ClassificationModel.create("
+        ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="error")


-def test_create_model_already_exists_return(
+def test_create_model_already_exists_return(readonly_memoryset, classification_model):
     with pytest.raises(ValueError):
-        ClassificationModel.create("
+        ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="open", head_type="MMOE")

     with pytest.raises(ValueError):
-        ClassificationModel.create(
+        ClassificationModel.create(
+            "test_classification_model", readonly_memoryset, if_exists="open", memory_lookup_count=37
+        )

     with pytest.raises(ValueError):
-        ClassificationModel.create("
+        ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="open", num_classes=19)

     with pytest.raises(ValueError):
-        ClassificationModel.create(
+        ClassificationModel.create(
+            "test_classification_model", readonly_memoryset, if_exists="open", min_memory_weight=0.77
+        )

-    new_model = ClassificationModel.create("
+    new_model = ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="open")
     assert new_model is not None
-    assert new_model.name == "
-    assert new_model.memoryset ==
+    assert new_model.name == "test_classification_model"
+    assert new_model.memoryset == readonly_memoryset
     assert new_model.num_classes == 2
     assert new_model.memory_lookup_count == 3


-def test_create_model_unauthenticated(unauthenticated,
+def test_create_model_unauthenticated(unauthenticated, readonly_memoryset: LabeledMemoryset):
     with pytest.raises(ValueError, match="Invalid API key"):
-        ClassificationModel.create("test_model",
+        ClassificationModel.create("test_model", readonly_memoryset)


-def test_get_model(
-    fetched_model = ClassificationModel.open(
+def test_get_model(classification_model: ClassificationModel):
+    fetched_model = ClassificationModel.open(classification_model.name)
     assert fetched_model is not None
-    assert fetched_model.id ==
-    assert fetched_model.name ==
+    assert fetched_model.id == classification_model.id
+    assert fetched_model.name == classification_model.name
     assert fetched_model.num_classes == 2
     assert fetched_model.memory_lookup_count == 3
-    assert fetched_model ==
+    assert fetched_model == classification_model


 def test_get_model_unauthenticated(unauthenticated):
@@ -75,12 +83,12 @@ def test_get_model_not_found():
         ClassificationModel.open(str(uuid4()))


-def test_get_model_unauthorized(unauthorized,
+def test_get_model_unauthorized(unauthorized, classification_model: ClassificationModel):
     with pytest.raises(LookupError):
-        ClassificationModel.open(
+        ClassificationModel.open(classification_model.name)


-def test_list_models(
+def test_list_models(classification_model: ClassificationModel):
     models = ClassificationModel.all()
     assert len(models) > 0
     assert any(model.name == model.name for model in models)
@@ -91,21 +99,41 @@ def test_list_models_unauthenticated(unauthenticated):
         ClassificationModel.all()


-def test_list_models_unauthorized(unauthorized,
+def test_list_models_unauthorized(unauthorized, classification_model: ClassificationModel):
     assert ClassificationModel.all() == []


-def
-
+def test_update_model_attributes(classification_model: ClassificationModel):
+    classification_model.description = "New description"
+    assert classification_model.description == "New description"
+
+    classification_model.set(description=None)
+    assert classification_model.description is None
+
+    classification_model.set(locked=True)
+    assert classification_model.locked is True
+
+    classification_model.set(locked=False)
+    assert classification_model.locked is False
+
+    classification_model.lock()
+    assert classification_model.locked is True
+
+    classification_model.unlock()
+    assert classification_model.locked is False
+
+
+def test_delete_model(readonly_memoryset: LabeledMemoryset):
+    ClassificationModel.create("model_to_delete", LabeledMemoryset.open(readonly_memoryset.name))
     assert ClassificationModel.open("model_to_delete")
     ClassificationModel.drop("model_to_delete")
     with pytest.raises(LookupError):
         ClassificationModel.open("model_to_delete")


-def test_delete_model_unauthenticated(unauthenticated,
+def test_delete_model_unauthenticated(unauthenticated, classification_model: ClassificationModel):
     with pytest.raises(ValueError, match="Invalid API key"):
-        ClassificationModel.drop(
+        ClassificationModel.drop(classification_model.name)


 def test_delete_model_not_found():
@@ -115,53 +143,82 @@ def test_delete_model_not_found():
     ClassificationModel.drop(str(uuid4()), if_not_exists="ignore")


-def test_delete_model_unauthorized(unauthorized,
+def test_delete_model_unauthorized(unauthorized, classification_model: ClassificationModel):
     with pytest.raises(LookupError):
-        ClassificationModel.drop(
+        ClassificationModel.drop(classification_model.name)


-@pytest.mark.flaky
 def test_delete_memoryset_before_model_constraint_violation(hf_dataset):
-    memoryset = LabeledMemoryset.from_hf_dataset("test_memoryset_delete_before_model", hf_dataset
+    memoryset = LabeledMemoryset.from_hf_dataset("test_memoryset_delete_before_model", hf_dataset)
     ClassificationModel.create("test_model_delete_before_memoryset", memoryset)
     with pytest.raises(RuntimeError):
         LabeledMemoryset.drop(memoryset.id)


-
-
-
-
-
-
-            {"text": "soup is great for the winter", "label": 0},
-            {"text": "i love cats", "label": 1},
-        ],
+@pytest.mark.parametrize("data_type", ["dataset", "datasource"])
+def test_evaluate(classification_model, eval_datasource: Datasource, eval_dataset: Dataset, data_type):
+    result = (
+        classification_model.evaluate(eval_dataset)
+        if data_type == "dataset"
+        else classification_model.evaluate(eval_datasource)
     )
-
+
     assert result is not None
-    assert isinstance(result
-
-    assert isinstance(result
-
-
-
-
-
-
-
-
-
+    assert isinstance(result, ClassificationMetrics)
+
+    assert isinstance(result.accuracy, float)
+    assert np.allclose(result.accuracy, 0.5)
+    assert isinstance(result.f1_score, float)
+    assert np.allclose(result.f1_score, 0.5)
+    assert isinstance(result.loss, float)
+
+    assert isinstance(result.anomaly_score_mean, float)
+    assert isinstance(result.anomaly_score_median, float)
+    assert isinstance(result.anomaly_score_variance, float)
+    assert -1.0 <= result.anomaly_score_mean <= 1.0
+    assert -1.0 <= result.anomaly_score_median <= 1.0
+    assert -1.0 <= result.anomaly_score_variance <= 1.0
+
+    assert result.pr_auc is not None
+    assert np.allclose(result.pr_auc, 0.75)
+    assert result.pr_curve is not None
+    assert np.allclose(result.pr_curve["thresholds"], [0.0, 0.0, 0.8155114054679871, 0.834095299243927])
+    assert np.allclose(result.pr_curve["precisions"], [0.5, 0.5, 1.0, 1.0])
+    assert np.allclose(result.pr_curve["recalls"], [1.0, 0.5, 0.5, 0.0])
+
+    assert result.roc_auc is not None
+    assert np.allclose(result.roc_auc, 0.625)
+    assert result.roc_curve is not None
+    assert np.allclose(result.roc_curve["thresholds"], [0.0, 0.8155114054679871, 0.834095299243927, 1.0])
+    assert np.allclose(result.roc_curve["false_positive_rates"], [1.0, 0.5, 0.0, 0.0])
+    assert np.allclose(result.roc_curve["true_positive_rates"], [1.0, 0.5, 0.5, 0.0])
+
+
+def test_evaluate_datasource_with_nones_raises_error(classification_model: ClassificationModel, datasource: Datasource):
+    with pytest.raises(ValueError):
+        classification_model.evaluate(datasource, record_predictions=True, tags={"test"})
+
+
+def test_evaluate_dataset_with_nones_raises_error(classification_model: ClassificationModel, hf_dataset: Dataset):
+    with pytest.raises(ValueError):
+        classification_model.evaluate(hf_dataset, record_predictions=True, tags={"test"})
+
+
+def test_evaluate_with_telemetry(classification_model: ClassificationModel, eval_dataset: Dataset):
+    result = classification_model.evaluate(eval_dataset, record_predictions=True, tags={"test"})
     assert result is not None
-
-
+    assert isinstance(result, ClassificationMetrics)
+    predictions = classification_model.predictions(tag="test")
+    assert len(predictions) == 4
     assert all(p.tags == {"test"} for p in predictions)
-    assert all(p.expected_label ==
+    assert all(p.expected_label == l for p, l in zip(predictions, eval_dataset["label"]))


-def test_predict(
-    predictions =
+def test_predict(classification_model: ClassificationModel, label_names: list[str]):
+    predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
     assert len(predictions) == 2
+    assert predictions[0].prediction_id is not None
+    assert predictions[1].prediction_id is not None
     assert predictions[0].label == 0
     assert predictions[0].label_name == label_names[0]
     assert 0 <= predictions[0].confidence <= 1
@@ -169,29 +226,59 @@ def test_predict(model: ClassificationModel, label_names: list[str]):
     assert predictions[1].label_name == label_names[1]
     assert 0 <= predictions[1].confidence <= 1

+    assert predictions[0].logits is not None
+    assert predictions[1].logits is not None
+    assert len(predictions[0].logits) == 2
+    assert len(predictions[1].logits) == 2
+    assert predictions[0].logits[0] > predictions[0].logits[1]
+    assert predictions[1].logits[0] < predictions[1].logits[1]
+
+
+def test_classification_prediction_has_no_label(classification_model: ClassificationModel):
+    """Ensure optional score is None for classification predictions."""
+    prediction = classification_model.predict("Do you want to go to the beach?")
+    assert isinstance(prediction, ClassificationPrediction)
+    assert prediction.label is None
+

-def
+def test_predict_disable_telemetry(classification_model: ClassificationModel, label_names: list[str]):
+    predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"], save_telemetry="off")
+    assert len(predictions) == 2
+    assert predictions[0].prediction_id is None
+    assert predictions[1].prediction_id is None
+    assert predictions[0].label == 0
+    assert predictions[0].label_name == label_names[0]
+    assert 0 <= predictions[0].confidence <= 1
+    assert predictions[1].label == 1
+    assert predictions[1].label_name == label_names[1]
+    assert 0 <= predictions[1].confidence <= 1
+
+
+def test_predict_unauthenticated(unauthenticated, classification_model: ClassificationModel):
     with pytest.raises(ValueError, match="Invalid API key"):
-
+        classification_model.predict(["Do you love soup?", "Are cats cute?"])


-def test_predict_unauthorized(unauthorized,
+def test_predict_unauthorized(unauthorized, classification_model: ClassificationModel):
     with pytest.raises(LookupError):
-
+        classification_model.predict(["Do you love soup?", "Are cats cute?"])


-def test_predict_constraint_violation(
+def test_predict_constraint_violation(readonly_memoryset: LabeledMemoryset):
     model = ClassificationModel.create(
-        "test_model_lookup_count_too_high",
+        "test_model_lookup_count_too_high",
+        readonly_memoryset,
+        num_classes=2,
+        memory_lookup_count=readonly_memoryset.length + 2,
     )
     with pytest.raises(RuntimeError):
         model.predict("test")


-def test_record_prediction_feedback(
-    predictions =
+def test_record_prediction_feedback(classification_model: ClassificationModel):
+    predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
     expected_labels = [0, 1]
-
+    classification_model.record_feedback(
         {
             "prediction_id": p.prediction_id,
             "category": "correct",
@@ -201,66 +288,209 @@ def test_record_prediction_feedback(model: ClassificationModel):
     )


-def test_record_prediction_feedback_missing_category(
-    prediction =
+def test_record_prediction_feedback_missing_category(classification_model: ClassificationModel):
+    prediction = classification_model.predict("Do you love soup?")
     with pytest.raises(ValueError):
-
+        classification_model.record_feedback({"prediction_id": prediction.prediction_id, "value": True})


-def test_record_prediction_feedback_invalid_value(
-    prediction =
+def test_record_prediction_feedback_invalid_value(classification_model: ClassificationModel):
+    prediction = classification_model.predict("Do you love soup?")
     with pytest.raises(ValueError, match=r"Invalid input.*"):
-
+        classification_model.record_feedback(
+            {"prediction_id": prediction.prediction_id, "category": "correct", "value": "invalid"}
+        )


-def test_record_prediction_feedback_invalid_prediction_id(
+def test_record_prediction_feedback_invalid_prediction_id(classification_model: ClassificationModel):
     with pytest.raises(ValueError, match=r"Invalid input.*"):
-
+        classification_model.record_feedback({"prediction_id": "invalid", "category": "correct", "value": True})


-def test_predict_with_memoryset_override(
+def test_predict_with_memoryset_override(classification_model: ClassificationModel, hf_dataset: Dataset):
     inverted_labeled_memoryset = LabeledMemoryset.from_hf_dataset(
         "test_memoryset_inverted_labels",
         hf_dataset.map(lambda x: {"label": 1 if x["label"] == 0 else 0}),
-        value_column="text",
         embedding_model=PretrainedEmbeddingModel.GTE_BASE,
     )
-    with
-        predictions =
+    with classification_model.use_memoryset(inverted_labeled_memoryset):
+        predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
         assert predictions[0].label == 1
         assert predictions[1].label == 0

-    predictions =
+    predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
     assert predictions[0].label == 0
     assert predictions[1].label == 1


-def test_predict_with_expected_labels(
-    prediction =
+def test_predict_with_expected_labels(classification_model: ClassificationModel):
+    prediction = classification_model.predict("Do you love soup?", expected_labels=1)
     assert prediction.expected_label == 1


-def test_predict_with_expected_labels_invalid_input(
+def test_predict_with_expected_labels_invalid_input(classification_model: ClassificationModel):
     # invalid number of expected labels for batch prediction
     with pytest.raises(ValueError, match=r"Invalid input.*"):
-
+        classification_model.predict(["Do you love soup?", "Are cats cute?"], expected_labels=[0])
     # invalid label value
     with pytest.raises(ValueError):
-
+        classification_model.predict("Do you love soup?", expected_labels=5)


-def
-
-
-    assert
-    assert
-    assert model._last_prediction_was_batch is True
+def test_predict_with_filters(classification_model: ClassificationModel):
+    # there are no memories with label 0 and key g1, so we force a wrong prediction
+    filtered_prediction = classification_model.predict("I love soup", filters=[("key", "==", "g2")])
+    assert filtered_prediction.label == 1
+    assert filtered_prediction.label_name == "cats"


-def
-
+def test_predict_with_memoryset_update(writable_memoryset: LabeledMemoryset):
+    model = ClassificationModel.create(
+        "test_predict_with_memoryset_update",
+        writable_memoryset,
+        num_classes=2,
+        memory_lookup_count=3,
+    )
+
     prediction = model.predict("Do you love soup?")
-    assert
-    assert
-
-
+    assert prediction.label == 0
+    assert prediction.label_name == "soup"
+
+    # insert new memories
+    writable_memoryset.insert(
+        [
+            {"value": "Do you love soup?", "label": 1, "key": "g1"},
+            {"value": "Do you love soup for dinner?", "label": 1, "key": "g2"},
+            {"value": "Do you love crackers?", "label": 1, "key": "g2"},
+            {"value": "Do you love broth?", "label": 1, "key": "g2"},
+            {"value": "Do you love chicken soup?", "label": 1, "key": "g2"},
+            {"value": "Do you love chicken soup for dinner?", "label": 1, "key": "g2"},
+            {"value": "Do you love chicken soup for dinner?", "label": 1, "key": "g2"},
+        ],
+    )
+    prediction = model.predict("Do you love soup?")
+    assert prediction.label == 1
+    assert prediction.label_name == "cats"
+
+    ClassificationModel.drop("test_predict_with_memoryset_update")
+
+
+def test_last_prediction_with_batch(classification_model: ClassificationModel):
+    predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
+    assert classification_model.last_prediction is not None
+    assert classification_model.last_prediction.prediction_id == predictions[-1].prediction_id
+    assert classification_model.last_prediction.input_value == "Are cats cute?"
+    assert classification_model._last_prediction_was_batch is True
+
+
+def test_last_prediction_with_single(classification_model: ClassificationModel):
+    # Test that last_prediction is updated correctly with single prediction
+    prediction = classification_model.predict("Do you love soup?")
+    assert classification_model.last_prediction is not None
+    assert classification_model.last_prediction.prediction_id == prediction.prediction_id
+    assert classification_model.last_prediction.input_value == "Do you love soup?"
+    assert classification_model._last_prediction_was_batch is False
+
+
+@skip_in_ci("We don't have Anthropic API key in CI")
+def test_explain(writable_memoryset: LabeledMemoryset):
+
+    writable_memoryset.analyze(
+        {"name": "neighbor", "neighbor_counts": [1, 3]},
+        lookup_count=3,
+    )
+
+    model = ClassificationModel.create(
+        "test_model_for_explain",
+        writable_memoryset,
+        num_classes=2,
+        memory_lookup_count=3,
+        description="This is a test model for explain",
+    )
+
+    predictions = model.predict(["Do you love soup?", "Are cats cute?"])
+    assert len(predictions) == 2
+
+    try:
+        explanation = predictions[0].explanation
+        assert explanation is not None
+        assert len(explanation) > 10
+        assert "soup" in explanation.lower()
+    except Exception as e:
+        if "ANTHROPIC_API_KEY" in str(e):
+            logging.info("Skipping explanation test because ANTHROPIC_API_KEY is not set")
+        else:
+            raise e
+    finally:
+        ClassificationModel.drop("test_model_for_explain")
+
+
+@skip_in_ci("We don't have Anthropic API key in CI")
+def test_action_recommendation(writable_memoryset: LabeledMemoryset):
+    """Test getting action recommendations for predictions"""
+
+    writable_memoryset.analyze(
+        {"name": "neighbor", "neighbor_counts": [1, 3]},
+        lookup_count=3,
+    )
+
+    model = ClassificationModel.create(
+        "test_model_for_action",
+        writable_memoryset,
+        num_classes=2,
+        memory_lookup_count=3,
+        description="This is a test model for action recommendations",
+    )
+
+    # Make a prediction with expected label to simulate incorrect prediction
+    prediction = model.predict("Do you love soup?", expected_labels=1)
+
+    memoryset_length = model.memoryset.length
+
+    try:
+        # Get action recommendation
+        action, rationale = prediction.recommend_action()
+
+        assert action is not None
+        assert rationale is not None
+        assert action in ["remove_duplicates", "detect_mislabels", "add_memories", "finetuning"]
+        assert len(rationale) > 10
+
+        # Test memory suggestions
+        suggestions_response = prediction.generate_memory_suggestions(num_memories=2)
+        memory_suggestions = suggestions_response.suggestions
+
+        assert memory_suggestions is not None
+        assert len(memory_suggestions) == 2
+
+        for suggestion in memory_suggestions:
+            assert isinstance(suggestion[0], str)
+            assert len(suggestion[0]) > 0
+            assert isinstance(suggestion[1], str)
+            assert suggestion[1] in model.memoryset.label_names
+
+        suggestions_response.apply()
+
+        model.memoryset.refresh()
+        assert model.memoryset.length == memoryset_length + 2
+
+    except Exception as e:
+        if "ANTHROPIC_API_KEY" in str(e):
+            logging.info("Skipping agent tests because ANTHROPIC_API_KEY is not set")
+        else:
+            raise e
+    finally:
+        ClassificationModel.drop("test_model_for_action")
+
+
+def test_predict_with_prompt(classification_model: ClassificationModel):
+    """Test that prompt parameter is properly passed through to predictions"""
+    # Test with an instruction-supporting embedding model if available
+    prediction_with_prompt = classification_model.predict(
+        "I love this product!", prompt="Represent this text for sentiment classification:"
+    )
+    prediction_without_prompt = classification_model.predict("I love this product!")
+
+    # Both should work and return valid predictions
+    assert prediction_with_prompt.label is not None
+    assert prediction_without_prompt.label is not None