orca-sdk 0.0.93__py3-none-any.whl → 0.0.95__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +13 -4
- orca_sdk/_generated_api_client/api/__init__.py +84 -34
- orca_sdk/_generated_api_client/api/classification_model/create_classification_model_classification_model_post.py +170 -0
- orca_sdk/_generated_api_client/api/classification_model/{get_model_classification_model_name_or_id_get.py → delete_classification_model_classification_model_name_or_id_delete.py} +20 -20
- orca_sdk/_generated_api_client/api/classification_model/{delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py → delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py} +4 -4
- orca_sdk/_generated_api_client/api/classification_model/{create_evaluation_classification_model_model_name_or_id_evaluation_post.py → evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py} +14 -14
- orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/classification_model/{get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py → get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py} +16 -16
- orca_sdk/_generated_api_client/api/classification_model/{list_evaluations_classification_model_model_name_or_id_evaluation_get.py → list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py} +16 -16
- orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +127 -0
- orca_sdk/_generated_api_client/api/classification_model/{predict_gpu_classification_model_name_or_id_prediction_post.py → predict_label_gpu_classification_model_name_or_id_prediction_post.py} +14 -14
- orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +183 -0
- orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +172 -0
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +22 -22
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +22 -22
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +38 -16
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +29 -12
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +12 -12
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +17 -14
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +72 -19
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +31 -12
- orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +49 -20
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +38 -16
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +54 -29
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +44 -26
- orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +22 -22
- orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +150 -0
- orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/{classification_model/create_model_classification_model_post.py → regression_model/create_regression_model_regression_model_post.py} +27 -27
- orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
- orca_sdk/_generated_api_client/api/{classification_model/delete_model_classification_model_name_or_id_delete.py → regression_model/delete_regression_model_regression_model_name_or_id_delete.py} +5 -5
- orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +183 -0
- orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +170 -0
- orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +161 -0
- orca_sdk/_generated_api_client/api/{classification_model/list_models_classification_model_get.py → regression_model/list_regression_models_regression_model_get.py} +17 -17
- orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +190 -0
- orca_sdk/_generated_api_client/api/{classification_model/update_model_classification_model_name_or_id_patch.py → regression_model/update_regression_model_regression_model_name_or_id_patch.py} +27 -27
- orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +60 -10
- orca_sdk/_generated_api_client/api/telemetry/count_predictions_telemetry_prediction_count_post.py +10 -10
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +35 -12
- orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +20 -12
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +35 -12
- orca_sdk/_generated_api_client/models/__init__.py +90 -24
- orca_sdk/_generated_api_client/models/base_score_prediction_result.py +108 -0
- orca_sdk/_generated_api_client/models/{evaluation_request.py → classification_evaluation_request.py} +13 -45
- orca_sdk/_generated_api_client/models/{classification_evaluation_result.py → classification_metrics.py} +106 -56
- orca_sdk/_generated_api_client/models/{rac_model_metadata.py → classification_model_metadata.py} +51 -43
- orca_sdk/_generated_api_client/models/{prediction_request.py → classification_prediction_request.py} +31 -6
- orca_sdk/_generated_api_client/models/{clone_labeled_memoryset_request.py → clone_memoryset_request.py} +5 -5
- orca_sdk/_generated_api_client/models/column_info.py +31 -0
- orca_sdk/_generated_api_client/models/count_predictions_request.py +195 -0
- orca_sdk/_generated_api_client/models/{create_rac_model_request.py → create_classification_model_request.py} +25 -57
- orca_sdk/_generated_api_client/models/{create_labeled_memoryset_request.py → create_memoryset_request.py} +73 -56
- orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +66 -0
- orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +13 -0
- orca_sdk/_generated_api_client/models/create_regression_model_request.py +137 -0
- orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +187 -0
- orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +10 -0
- orca_sdk/_generated_api_client/models/evaluation_response.py +22 -9
- orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +140 -0
- orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +140 -0
- orca_sdk/_generated_api_client/models/http_validation_error.py +86 -0
- orca_sdk/_generated_api_client/models/list_predictions_request.py +62 -0
- orca_sdk/_generated_api_client/models/memory_type.py +9 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py +0 -20
- orca_sdk/_generated_api_client/models/{labeled_memoryset_metadata.py → memoryset_metadata.py} +73 -13
- orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +55 -0
- orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +13 -0
- orca_sdk/_generated_api_client/models/{labeled_memoryset_update.py → memoryset_update.py} +19 -31
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +1 -0
- orca_sdk/_generated_api_client/models/{paginated_labeled_memory_with_feedback_metrics.py → paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py} +37 -10
- orca_sdk/_generated_api_client/models/{precision_recall_curve.py → pr_curve.py} +5 -13
- orca_sdk/_generated_api_client/models/{rac_model_update.py → predictive_model_update.py} +14 -5
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +11 -1
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +5 -0
- orca_sdk/_generated_api_client/models/rar_head_type.py +8 -0
- orca_sdk/_generated_api_client/models/regression_evaluation_request.py +148 -0
- orca_sdk/_generated_api_client/models/regression_metrics.py +172 -0
- orca_sdk/_generated_api_client/models/regression_model_metadata.py +177 -0
- orca_sdk/_generated_api_client/models/regression_prediction_request.py +195 -0
- orca_sdk/_generated_api_client/models/roc_curve.py +0 -8
- orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +196 -0
- orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +252 -0
- orca_sdk/_generated_api_client/models/scored_memory.py +172 -0
- orca_sdk/_generated_api_client/models/scored_memory_insert.py +128 -0
- orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_lookup.py +180 -0
- orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_update.py +171 -0
- orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +193 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/update_prediction_request.py +20 -0
- orca_sdk/_generated_api_client/models/validation_error.py +99 -0
- orca_sdk/_shared/__init__.py +9 -1
- orca_sdk/_shared/metrics.py +257 -87
- orca_sdk/_shared/metrics_test.py +136 -77
- orca_sdk/_utils/data_parsing.py +0 -3
- orca_sdk/_utils/data_parsing_test.py +0 -3
- orca_sdk/_utils/prediction_result_ui.py +55 -23
- orca_sdk/classification_model.py +184 -174
- orca_sdk/classification_model_test.py +178 -142
- orca_sdk/conftest.py +77 -26
- orca_sdk/datasource.py +34 -0
- orca_sdk/datasource_test.py +9 -1
- orca_sdk/embedding_model.py +136 -14
- orca_sdk/embedding_model_test.py +10 -6
- orca_sdk/job.py +329 -0
- orca_sdk/job_test.py +48 -0
- orca_sdk/memoryset.py +882 -161
- orca_sdk/memoryset_test.py +58 -23
- orca_sdk/regression_model.py +647 -0
- orca_sdk/regression_model_test.py +338 -0
- orca_sdk/telemetry.py +225 -106
- orca_sdk/telemetry_test.py +34 -30
- {orca_sdk-0.0.93.dist-info → orca_sdk-0.0.95.dist-info}/METADATA +2 -4
- {orca_sdk-0.0.93.dist-info → orca_sdk-0.0.95.dist-info}/RECORD +124 -74
- orca_sdk/_utils/task.py +0 -73
- {orca_sdk-0.0.93.dist-info → orca_sdk-0.0.95.dist-info}/WHEEL +0 -0
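The headline change in 0.0.95 is a regression surface alongside the existing classification one: a ScoredMemoryset (memories carry a float score instead of a class label), a RegressionModel with its own evaluation and score-prediction endpoints, and a shared predictive_model API group. A minimal sketch of how the pieces appear to fit together, with the create() arguments taken from the conftest.py fixtures in this diff; the final predict call and the .score attribute are assumptions inferred from the classification API and the new predict_score_gpu_regression_model_name_or_id_prediction_post endpoint:

from orca_sdk.datasource import Datasource
from orca_sdk.embedding_model import PretrainedEmbeddingModel
from orca_sdk.memoryset import ScoredMemoryset
from orca_sdk.regression_model import RegressionModel

# Datasource rows carry a float "score" column instead of an integer "label".
datasource = Datasource.from_list(
    "example_scored_datasource",
    [
        {"value": "i love soup", "score": 0.1, "source_id": "s1"},
        {"value": "cats are cute", "score": 0.9, "source_id": "s2"},
    ],
)

# ScoredMemoryset.create mirrors LabeledMemoryset.create (kwargs as in conftest.py).
memoryset = ScoredMemoryset.create(
    "example_scored_memoryset",
    datasource=datasource,
    embedding_model=PretrainedEmbeddingModel.GTE_BASE,
    source_id_column="source_id",
)

# RegressionModel.create mirrors ClassificationModel.create (as in conftest.py).
model = RegressionModel.create(
    "example_regression_model",
    memoryset,
    memory_lookup_count=3,
    description="example regression model",
)

# Assumption: prediction works like ClassificationModel.predict, returning a
# score-based result instead of a label-based one.
prediction = model.predict("soup is great for the winter")
print(prediction.score)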
orca_sdk/classification_model_test.py
CHANGED

@@ -1,46 +1,52 @@
+import logging
 from uuid import uuid4
 
 import numpy as np
 import pytest
 from datasets.arrow_dataset import Dataset
 
-from .classification_model import ClassificationModel
+from .classification_model import ClassificationMetrics, ClassificationModel
+from .conftest import skip_in_ci
 from .datasource import Datasource
 from .embedding_model import PretrainedEmbeddingModel
 from .memoryset import LabeledMemoryset
 
 
-def test_create_model(
-    assert
-    assert
-    assert
-    assert
-    assert
+def test_create_model(classification_model: ClassificationModel, readonly_memoryset: LabeledMemoryset):
+    assert classification_model is not None
+    assert classification_model.name == "test_classification_model"
+    assert classification_model.memoryset == readonly_memoryset
+    assert classification_model.num_classes == 2
+    assert classification_model.memory_lookup_count == 3
 
 
-def test_create_model_already_exists_error(readonly_memoryset,
+def test_create_model_already_exists_error(readonly_memoryset, classification_model):
     with pytest.raises(ValueError):
-        ClassificationModel.create("
+        ClassificationModel.create("test_classification_model", readonly_memoryset)
     with pytest.raises(ValueError):
-        ClassificationModel.create("
+        ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="error")
 
 
-def test_create_model_already_exists_return(readonly_memoryset,
+def test_create_model_already_exists_return(readonly_memoryset, classification_model):
     with pytest.raises(ValueError):
-        ClassificationModel.create("
+        ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="open", head_type="MMOE")
 
     with pytest.raises(ValueError):
-        ClassificationModel.create(
+        ClassificationModel.create(
+            "test_classification_model", readonly_memoryset, if_exists="open", memory_lookup_count=37
+        )
 
     with pytest.raises(ValueError):
-        ClassificationModel.create("
+        ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="open", num_classes=19)
 
     with pytest.raises(ValueError):
-        ClassificationModel.create(
+        ClassificationModel.create(
+            "test_classification_model", readonly_memoryset, if_exists="open", min_memory_weight=0.77
+        )
 
-    new_model = ClassificationModel.create("
+    new_model = ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="open")
     assert new_model is not None
-    assert new_model.name == "
+    assert new_model.name == "test_classification_model"
     assert new_model.memoryset == readonly_memoryset
     assert new_model.num_classes == 2
     assert new_model.memory_lookup_count == 3

@@ -51,14 +57,14 @@ def test_create_model_unauthenticated(unauthenticated, readonly_memoryset: Label
         ClassificationModel.create("test_model", readonly_memoryset)
 
 
-def test_get_model(
-    fetched_model = ClassificationModel.open(
+def test_get_model(classification_model: ClassificationModel):
+    fetched_model = ClassificationModel.open(classification_model.name)
     assert fetched_model is not None
-    assert fetched_model.id ==
-    assert fetched_model.name ==
+    assert fetched_model.id == classification_model.id
+    assert fetched_model.name == classification_model.name
     assert fetched_model.num_classes == 2
     assert fetched_model.memory_lookup_count == 3
-    assert fetched_model ==
+    assert fetched_model == classification_model
 
 
 def test_get_model_unauthenticated(unauthenticated):

@@ -76,12 +82,12 @@ def test_get_model_not_found():
         ClassificationModel.open(str(uuid4()))
 
 
-def test_get_model_unauthorized(unauthorized,
+def test_get_model_unauthorized(unauthorized, classification_model: ClassificationModel):
     with pytest.raises(LookupError):
-        ClassificationModel.open(
+        ClassificationModel.open(classification_model.name)
 
 
-def test_list_models(
+def test_list_models(classification_model: ClassificationModel):
     models = ClassificationModel.all()
     assert len(models) > 0
     assert any(model.name == model.name for model in models)

@@ -92,19 +98,28 @@ def test_list_models_unauthenticated(unauthenticated):
         ClassificationModel.all()
 
 
-def test_list_models_unauthorized(unauthorized,
+def test_list_models_unauthorized(unauthorized, classification_model: ClassificationModel):
     assert ClassificationModel.all() == []
 
 
-def
-
-    assert
+def test_update_model_attributes(classification_model: ClassificationModel):
+    classification_model.description = "New description"
+    assert classification_model.description == "New description"
+
+    classification_model.set(description=None)
+    assert classification_model.description is None
 
+    classification_model.set(locked=True)
+    assert classification_model.locked is True
 
-
-    assert
-
-
+    classification_model.set(locked=False)
+    assert classification_model.locked is False
+
+    classification_model.lock()
+    assert classification_model.locked is True
+
+    classification_model.unlock()
+    assert classification_model.locked is False
 
 
 def test_delete_model(readonly_memoryset: LabeledMemoryset):

@@ -115,9 +130,9 @@ def test_delete_model(readonly_memoryset: LabeledMemoryset):
         ClassificationModel.open("model_to_delete")
 
 
-def test_delete_model_unauthenticated(unauthenticated,
+def test_delete_model_unauthenticated(unauthenticated, classification_model: ClassificationModel):
     with pytest.raises(ValueError, match="Invalid API key"):
-        ClassificationModel.drop(
+        ClassificationModel.drop(classification_model.name)
 
 
 def test_delete_model_not_found():

@@ -127,9 +142,9 @@ def test_delete_model_not_found():
     ClassificationModel.drop(str(uuid4()), if_not_exists="ignore")
 
 
-def test_delete_model_unauthorized(unauthorized,
+def test_delete_model_unauthorized(unauthorized, classification_model: ClassificationModel):
     with pytest.raises(LookupError):
-        ClassificationModel.drop(
+        ClassificationModel.drop(classification_model.name)
 
 
 def test_delete_memoryset_before_model_constraint_violation(hf_dataset):

@@ -139,78 +154,57 @@ def test_delete_memoryset_before_model_constraint_violation(hf_dataset):
         LabeledMemoryset.drop(memoryset.id)
 
 
-
-
+@pytest.mark.parametrize("data_type", ["dataset", "datasource"])
+def test_evaluate(classification_model, eval_datasource: Datasource, eval_dataset: Dataset, data_type):
+    result = (
+        classification_model.evaluate(eval_dataset)
+        if data_type == "dataset"
+        else classification_model.evaluate(eval_datasource)
+    )
+
     assert result is not None
-    assert isinstance(result,
-
-    assert isinstance(result
-    assert
-    assert isinstance(result
-    assert
-    assert
-
-    assert isinstance(result
-    assert isinstance(result
-    assert isinstance(result
-    assert
-    assert
-    assert
-
-    assert
-    assert
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    assert isinstance(result["precision_recall_curve"]["thresholds"], list)
-    assert isinstance(result["precision_recall_curve"]["precisions"], list)
-    assert isinstance(result["precision_recall_curve"]["recalls"], list)
-    assert isinstance(result["roc_curve"]["thresholds"], list)
-    assert isinstance(result["roc_curve"]["false_positive_rates"], list)
-    assert isinstance(result["roc_curve"]["true_positive_rates"], list)
-
-    assert np.allclose(result["roc_curve"]["thresholds"], [0.0, 0.8155114054679871, 0.834095299243927, 1.0])
-    assert np.allclose(result["roc_curve"]["false_positive_rates"], [1.0, 0.5, 0.0, 0.0])
-    assert np.allclose(result["roc_curve"]["true_positive_rates"], [1.0, 0.5, 0.5, 0.0])
-    assert np.allclose(result["roc_curve"]["auc"], 0.625)
-
-    assert np.allclose(
-        result["precision_recall_curve"]["thresholds"], [0.0, 0.0, 0.8155114054679871, 0.834095299243927]
-    )
-    assert np.allclose(result["precision_recall_curve"]["precisions"], [0.5, 0.5, 1.0, 1.0])
-    assert np.allclose(result["precision_recall_curve"]["recalls"], [1.0, 0.5, 0.5, 0.0])
-    assert np.allclose(result["precision_recall_curve"]["auc"], 0.75)
-
-
-def test_evaluate_with_telemetry(model):
-    samples = [
-        {"text": "chicken noodle soup is the best", "label": 1},
-        {"text": "cats are cute", "label": 0},
-    ]
-    eval_datasource = Datasource.from_list("eval_datasource_2", samples)
-    result = model.evaluate(eval_datasource, value_column="text", record_predictions=True, tags={"test"})
+    assert isinstance(result, ClassificationMetrics)
+
+    assert isinstance(result.accuracy, float)
+    assert np.allclose(result.accuracy, 0.5)
+    assert isinstance(result.f1_score, float)
+    assert np.allclose(result.f1_score, 0.5)
+    assert isinstance(result.loss, float)
+
+    assert isinstance(result.anomaly_score_mean, float)
+    assert isinstance(result.anomaly_score_median, float)
+    assert isinstance(result.anomaly_score_variance, float)
+    assert -1.0 <= result.anomaly_score_mean <= 1.0
+    assert -1.0 <= result.anomaly_score_median <= 1.0
+    assert -1.0 <= result.anomaly_score_variance <= 1.0
+
+    assert result.pr_auc is not None
+    assert np.allclose(result.pr_auc, 0.75)
+    assert result.pr_curve is not None
+    assert np.allclose(result.pr_curve["thresholds"], [0.0, 0.0, 0.8155114054679871, 0.834095299243927])
+    assert np.allclose(result.pr_curve["precisions"], [0.5, 0.5, 1.0, 1.0])
+    assert np.allclose(result.pr_curve["recalls"], [1.0, 0.5, 0.5, 0.0])
+
+    assert result.roc_auc is not None
+    assert np.allclose(result.roc_auc, 0.625)
+    assert result.roc_curve is not None
+    assert np.allclose(result.roc_curve["thresholds"], [0.0, 0.8155114054679871, 0.834095299243927, 1.0])
+    assert np.allclose(result.roc_curve["false_positive_rates"], [1.0, 0.5, 0.0, 0.0])
+    assert np.allclose(result.roc_curve["true_positive_rates"], [1.0, 0.5, 0.5, 0.0])
+
+
+def test_evaluate_with_telemetry(classification_model: ClassificationModel, eval_dataset: Dataset):
+    result = classification_model.evaluate(eval_dataset, record_predictions=True, tags={"test"})
     assert result is not None
-
-
+    assert isinstance(result, ClassificationMetrics)
+    predictions = classification_model.predictions(tag="test")
+    assert len(predictions) == 4
     assert all(p.tags == {"test"} for p in predictions)
-    assert all(p.expected_label ==
+    assert all(p.expected_label == l for p, l in zip(predictions, eval_dataset["label"]))
 
 
-def test_predict(
-    predictions =
+def test_predict(classification_model: ClassificationModel, label_names: list[str]):
+    predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
     assert len(predictions) == 2
     assert predictions[0].prediction_id is not None
     assert predictions[1].prediction_id is not None

@@ -229,8 +223,8 @@ def test_predict(model: ClassificationModel, label_names: list[str]):
     assert predictions[1].logits[0] < predictions[1].logits[1]
 
 
-def test_predict_disable_telemetry(
-    predictions =
+def test_predict_disable_telemetry(classification_model: ClassificationModel, label_names: list[str]):
+    predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"], save_telemetry="off")
     assert len(predictions) == 2
     assert predictions[0].prediction_id is None
     assert predictions[1].prediction_id is None

@@ -242,14 +236,14 @@ def test_predict_disable_telemetry(model: ClassificationModel, label_names: list
     assert 0 <= predictions[1].confidence <= 1
 
 
-def test_predict_unauthenticated(unauthenticated,
+def test_predict_unauthenticated(unauthenticated, classification_model: ClassificationModel):
     with pytest.raises(ValueError, match="Invalid API key"):
-
+        classification_model.predict(["Do you love soup?", "Are cats cute?"])
 
 
-def test_predict_unauthorized(unauthorized,
+def test_predict_unauthorized(unauthorized, classification_model: ClassificationModel):
     with pytest.raises(LookupError):
-
+        classification_model.predict(["Do you love soup?", "Are cats cute?"])
 
 
 def test_predict_constraint_violation(readonly_memoryset: LabeledMemoryset):

@@ -263,10 +257,10 @@ def test_predict_constraint_violation(readonly_memoryset: LabeledMemoryset):
         model.predict("test")
 
 
-def test_record_prediction_feedback(
-    predictions =
+def test_record_prediction_feedback(classification_model: ClassificationModel):
+    predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
     expected_labels = [0, 1]
-
+    classification_model.record_feedback(
         {
             "prediction_id": p.prediction_id,
             "category": "correct",

@@ -276,65 +270,107 @@ def test_record_prediction_feedback(model: ClassificationModel):
     )
 
 
-def test_record_prediction_feedback_missing_category(
-    prediction =
+def test_record_prediction_feedback_missing_category(classification_model: ClassificationModel):
+    prediction = classification_model.predict("Do you love soup?")
     with pytest.raises(ValueError):
-
+        classification_model.record_feedback({"prediction_id": prediction.prediction_id, "value": True})
 
 
-def test_record_prediction_feedback_invalid_value(
-    prediction =
+def test_record_prediction_feedback_invalid_value(classification_model: ClassificationModel):
+    prediction = classification_model.predict("Do you love soup?")
     with pytest.raises(ValueError, match=r"Invalid input.*"):
-
+        classification_model.record_feedback(
+            {"prediction_id": prediction.prediction_id, "category": "correct", "value": "invalid"}
+        )
 
 
-def test_record_prediction_feedback_invalid_prediction_id(
+def test_record_prediction_feedback_invalid_prediction_id(classification_model: ClassificationModel):
     with pytest.raises(ValueError, match=r"Invalid input.*"):
-
+        classification_model.record_feedback({"prediction_id": "invalid", "category": "correct", "value": True})
 
 
-def test_predict_with_memoryset_override(
+def test_predict_with_memoryset_override(classification_model: ClassificationModel, hf_dataset: Dataset):
     inverted_labeled_memoryset = LabeledMemoryset.from_hf_dataset(
         "test_memoryset_inverted_labels",
         hf_dataset.map(lambda x: {"label": 1 if x["label"] == 0 else 0}),
         embedding_model=PretrainedEmbeddingModel.GTE_BASE,
     )
-    with
-        predictions =
+    with classification_model.use_memoryset(inverted_labeled_memoryset):
+        predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
     assert predictions[0].label == 1
     assert predictions[1].label == 0
 
-    predictions =
+    predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
    assert predictions[0].label == 0
     assert predictions[1].label == 1
 
 
-def test_predict_with_expected_labels(
-    prediction =
+def test_predict_with_expected_labels(classification_model: ClassificationModel):
+    prediction = classification_model.predict("Do you love soup?", expected_labels=1)
     assert prediction.expected_label == 1
 
 
-def test_predict_with_expected_labels_invalid_input(
+def test_predict_with_expected_labels_invalid_input(classification_model: ClassificationModel):
     # invalid number of expected labels for batch prediction
     with pytest.raises(ValueError, match=r"Invalid input.*"):
-
+        classification_model.predict(["Do you love soup?", "Are cats cute?"], expected_labels=[0])
     # invalid label value
     with pytest.raises(ValueError):
-
+        classification_model.predict("Do you love soup?", expected_labels=5)
 
 
-def
-
-
-    assert
-    assert
-    assert model._last_prediction_was_batch is True
+def test_predict_with_filters(classification_model: ClassificationModel):
+    # there are no memories with label 0 and key g1, so we force a wrong prediction
+    filtered_prediction = classification_model.predict("I love soup", filters=[("key", "==", "g2")])
+    assert filtered_prediction.label == 1
+    assert filtered_prediction.label_name == "cats"
 
 
-def
+def test_last_prediction_with_batch(classification_model: ClassificationModel):
+    predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
+    assert classification_model.last_prediction is not None
+    assert classification_model.last_prediction.prediction_id == predictions[-1].prediction_id
+    assert classification_model.last_prediction.input_value == "Are cats cute?"
+    assert classification_model._last_prediction_was_batch is True
+
+
+def test_last_prediction_with_single(classification_model: ClassificationModel):
     # Test that last_prediction is updated correctly with single prediction
-    prediction =
-    assert
-    assert
-    assert
-    assert
+    prediction = classification_model.predict("Do you love soup?")
+    assert classification_model.last_prediction is not None
+    assert classification_model.last_prediction.prediction_id == prediction.prediction_id
+    assert classification_model.last_prediction.input_value == "Do you love soup?"
+    assert classification_model._last_prediction_was_batch is False
+
+
+@skip_in_ci("We don't have Anthropic API key in CI")
+def test_explain(writable_memoryset: LabeledMemoryset):
+
+    writable_memoryset.analyze(
+        {"name": "neighbor", "neighbor_counts": [1, 3]},
+        lookup_count=3,
+    )
+
+    model = ClassificationModel.create(
+        "test_model_for_explain",
+        writable_memoryset,
+        num_classes=2,
+        memory_lookup_count=3,
+        description="This is a test model for explain",
+    )
+
+    predictions = model.predict(["Do you love soup?", "Are cats cute?"])
+    assert len(predictions) == 2
+
+    try:
+        explanation = predictions[0].explanation
+        assert explanation is not None
+        assert len(explanation) > 10
+        assert "soup" in explanation.lower()
+    except Exception as e:
+        if "ANTHROPIC_API_KEY" in str(e):
+            logging.info("Skipping explanation test because ANTHROPIC_API_KEY is not set")
+        else:
+            raise e
+    finally:
+        ClassificationModel.drop("test_model_for_explain")
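Two API changes are visible in the test diff above beyond the fixture rename (model → classification_model): evaluate() now returns a ClassificationMetrics object with attribute access instead of a plain dict, and the curve payloads were renamed (precision_recall_curve → pr_curve, with the AUC values split out as pr_auc and roc_auc; compare the precision_recall_curve.py → pr_curve.py rename in the models list). A short sketch of the new access pattern, using only field names that appear in the assertions above and the classification_model/eval_dataset fixtures from conftest.py:

result = classification_model.evaluate(eval_dataset)  # ClassificationMetrics, not a dict

# Scalar metrics are attributes now.
print(result.accuracy, result.f1_score, result.loss)

# Anomaly-score statistics are new in this release.
print(result.anomaly_score_mean, result.anomaly_score_median, result.anomaly_score_variance)

# Curves keep dict-style payloads but hang off renamed attributes.
print(result.pr_auc, result.pr_curve["precisions"], result.pr_curve["recalls"])
print(result.roc_auc, result.roc_curve["false_positive_rates"])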
orca_sdk/conftest.py
CHANGED
@@ -11,15 +11,33 @@ from .classification_model import ClassificationModel
 from .credentials import OrcaCredentials
 from .datasource import Datasource
 from .embedding_model import PretrainedEmbeddingModel
-from .memoryset import LabeledMemoryset
+from .memoryset import LabeledMemoryset, ScoredMemoryset
+from .regression_model import RegressionModel
 
-logging.basicConfig(level=logging.INFO)
+logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
 
 os.environ["ORCA_API_URL"] = os.environ.get("ORCA_API_URL", "http://localhost:1584/")
 
 os.environ["ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY"] = "true"
 
 
+def skip_in_prod(reason: str):
+    """Custom decorator to skip tests when running against production API"""
+    PROD_API_URLs = ["https://api.orcadb.ai", "https://api.dev.orcadb.ai"]
+    return pytest.mark.skipif(
+        os.environ["ORCA_API_URL"] in PROD_API_URLs,
+        reason=reason,
+    )
+
+
+def skip_in_ci(reason: str):
+    """Custom decorator to skip tests when running in CI"""
+    return pytest.mark.skipif(
+        os.environ.get("GITHUB_ACTIONS", "false") == "true",
+        reason=reason,
+    )
+
+
 def _create_org_id():
     # UUID start to identify test data (0xtest...)
     return "10e50000-0000-4000-a000-" + str(uuid4())[24:]

@@ -71,27 +89,27 @@ def label_names():
 
 
 SAMPLE_DATA = [
-    {"value": "i love soup", "label": 0, "key": "
-    {"value": "cats are cute", "label": 1, "key": "
-    {"value": "soup is good", "label": 0, "key": "
-    {"value": "i love cats", "label": 1, "key": "
-    {"value": "everyone loves cats", "label": 1, "key": "
-    {"value": "soup is great for the winter", "label": 0, "key": "
-    {"value": "hot soup on a rainy day!", "label": 0, "key": "
-    {"value": "cats sleep all day", "label": 1, "key": "
-    {"value": "homemade soup recipes", "label": 0, "key": "
-    {"value": "cats purr when happy", "label": 1, "key": "
-    {"value": "chicken noodle soup is classic", "label": 0, "key": "
-    {"value": "kittens are baby cats", "label": 1, "key": "
-    {"value": "soup can be served cold too", "label": 0, "key": "
-    {"value": "cats have nine lives", "label": 1, "key": "
-    {"value": "tomato soup with grilled cheese", "label": 0, "key": "
-    {"value": "cats are independent animals", "label": 1, "key": "
+    {"value": "i love soup", "label": 0, "key": "g1", "score": 0.1, "source_id": "s1"},
+    {"value": "cats are cute", "label": 1, "key": "g1", "score": 0.9, "source_id": "s2"},
+    {"value": "soup is good", "label": 0, "key": "g1", "score": 0.1, "source_id": "s3"},
+    {"value": "i love cats", "label": 1, "key": "g1", "score": 0.9, "source_id": "s4"},
+    {"value": "everyone loves cats", "label": 1, "key": "g1", "score": 0.9, "source_id": "s5"},
+    {"value": "soup is great for the winter", "label": 0, "key": "g1", "score": 0.1, "source_id": "s6"},
+    {"value": "hot soup on a rainy day!", "label": 0, "key": "g1", "score": 0.1, "source_id": "s7"},
+    {"value": "cats sleep all day", "label": 1, "key": "g1", "score": 0.9, "source_id": "s8"},
+    {"value": "homemade soup recipes", "label": 0, "key": "g1", "score": 0.1, "source_id": "s9"},
+    {"value": "cats purr when happy", "label": 1, "key": "g2", "score": 0.9, "source_id": "s10"},
+    {"value": "chicken noodle soup is classic", "label": 0, "key": "g1", "score": 0.1, "source_id": "s11"},
+    {"value": "kittens are baby cats", "label": 1, "key": "g2", "score": 0.9, "source_id": "s12"},
+    {"value": "soup can be served cold too", "label": 0, "key": "g1", "score": 0.1, "source_id": "s13"},
+    {"value": "cats have nine lives", "label": 1, "key": "g2", "score": 0.9, "source_id": "s14"},
+    {"value": "tomato soup with grilled cheese", "label": 0, "key": "g1", "score": 0.1, "source_id": "s15"},
+    {"value": "cats are independent animals", "label": 1, "key": "g2", "score": 0.9, "source_id": "s16"},
 ]
 
 
 @pytest.fixture(scope="session")
-def hf_dataset(label_names):
+def hf_dataset(label_names: list[str]) -> Dataset:
     return Dataset.from_list(
         SAMPLE_DATA,
         features=Features(

@@ -107,16 +125,16 @@ def hf_dataset(label_names):
 
 
 @pytest.fixture(scope="session")
-def datasource(hf_dataset) -> Datasource:
+def datasource(hf_dataset: Dataset) -> Datasource:
     datasource = Datasource.from_hf_dataset("test_datasource", hf_dataset)
     return datasource
 
 
 EVAL_DATASET = [
-    {"value": "chicken noodle soup is the best", "label": 1},
-    {"value": "cats are cute", "label": 0},
-    {"value": "soup is great for the winter", "label": 0},
-    {"value": "i love cats", "label": 1},
+    {"value": "chicken noodle soup is the best", "label": 1, "score": 0.9},  # mislabeled
+    {"value": "cats are cute", "label": 0, "score": 0.1},  # mislabeled
+    {"value": "soup is great for the winter", "label": 0, "score": 0.1},
+    {"value": "i love cats", "label": 1, "score": 0.9},
 ]
 
 

@@ -140,6 +158,8 @@ def readonly_memoryset(datasource: Datasource) -> LabeledMemoryset:
         embedding_model=PretrainedEmbeddingModel.GTE_BASE,
         source_id_column="source_id",
         max_seq_length_override=32,
+        index_type="IVF_FLAT",
+        index_params={"n_lists": 100},
     )
     return memoryset
 

@@ -176,14 +196,45 @@ def writable_memoryset(datasource: Datasource, api_key: str) -> Generator[Labele
 
     if memory_ids:
         memoryset.delete(memory_ids)
+        memoryset.refresh()
     assert len(memoryset) == 0
     memoryset.insert(SAMPLE_DATA)
     # If the test dropped the memoryset, do nothing — it will be recreated on the next use.
 
 
 @pytest.fixture(scope="session")
-def
+def classification_model(readonly_memoryset: LabeledMemoryset) -> ClassificationModel:
     model = ClassificationModel.create(
-        "
+        "test_classification_model",
+        readonly_memoryset,
+        num_classes=2,
+        memory_lookup_count=3,
+        description="test_description",
+    )
+    return model
+
+
+# Add scored memoryset and regression model fixtures
+@pytest.fixture(scope="session")
+def scored_memoryset(datasource: Datasource) -> ScoredMemoryset:
+    memoryset = ScoredMemoryset.create(
+        "test_scored_memoryset",
+        datasource=datasource,
+        embedding_model=PretrainedEmbeddingModel.GTE_BASE,
+        source_id_column="source_id",
+        max_seq_length_override=32,
+        index_type="IVF_FLAT",
+        index_params={"n_lists": 100},
+    )
+    return memoryset
+
+
+@pytest.fixture(scope="session")
+def regression_model(scored_memoryset: ScoredMemoryset) -> RegressionModel:
+    model = RegressionModel.create(
+        "test_regression_model",
+        scored_memoryset,
+        memory_lookup_count=3,
+        description="test_regression_description",
     )
     return model