orca-sdk 0.0.94__py3-none-any.whl → 0.0.96__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +13 -4
- orca_sdk/_generated_api_client/api/__init__.py +80 -34
- orca_sdk/_generated_api_client/api/classification_model/create_classification_model_gpu_classification_model_post.py +170 -0
- orca_sdk/_generated_api_client/api/classification_model/{get_model_classification_model_name_or_id_get.py → delete_classification_model_classification_model_name_or_id_delete.py} +20 -20
- orca_sdk/_generated_api_client/api/classification_model/{delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py → delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py} +4 -4
- orca_sdk/_generated_api_client/api/classification_model/{create_evaluation_classification_model_model_name_or_id_evaluation_post.py → evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py} +14 -14
- orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/classification_model/{get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py → get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py} +16 -16
- orca_sdk/_generated_api_client/api/classification_model/{list_evaluations_classification_model_model_name_or_id_evaluation_get.py → list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py} +16 -16
- orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +127 -0
- orca_sdk/_generated_api_client/api/classification_model/{predict_gpu_classification_model_name_or_id_prediction_post.py → predict_label_gpu_classification_model_name_or_id_prediction_post.py} +14 -14
- orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +183 -0
- orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +24 -0
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +22 -22
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +22 -22
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +38 -16
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +29 -12
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +12 -12
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +17 -14
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +72 -19
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +31 -12
- orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +49 -20
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +38 -16
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +54 -29
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +44 -26
- orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +22 -22
- orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +150 -0
- orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/{classification_model/create_model_classification_model_post.py → regression_model/create_regression_model_gpu_regression_model_post.py} +27 -27
- orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
- orca_sdk/_generated_api_client/api/{classification_model/delete_model_classification_model_name_or_id_delete.py → regression_model/delete_regression_model_regression_model_name_or_id_delete.py} +5 -5
- orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +183 -0
- orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +170 -0
- orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +161 -0
- orca_sdk/_generated_api_client/api/{classification_model/list_models_classification_model_get.py → regression_model/list_regression_models_regression_model_get.py} +17 -17
- orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +190 -0
- orca_sdk/_generated_api_client/api/{classification_model/update_model_classification_model_name_or_id_patch.py → regression_model/update_regression_model_regression_model_name_or_id_patch.py} +27 -27
- orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +35 -12
- orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +20 -12
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +35 -12
- orca_sdk/_generated_api_client/models/__init__.py +84 -24
- orca_sdk/_generated_api_client/models/base_score_prediction_result.py +108 -0
- orca_sdk/_generated_api_client/models/{evaluation_request.py → classification_evaluation_request.py} +13 -45
- orca_sdk/_generated_api_client/models/{classification_evaluation_result.py → classification_metrics.py} +106 -56
- orca_sdk/_generated_api_client/models/{rac_model_metadata.py → classification_model_metadata.py} +51 -43
- orca_sdk/_generated_api_client/models/{prediction_request.py → classification_prediction_request.py} +31 -6
- orca_sdk/_generated_api_client/models/{clone_labeled_memoryset_request.py → clone_memoryset_request.py} +5 -5
- orca_sdk/_generated_api_client/models/column_info.py +31 -0
- orca_sdk/_generated_api_client/models/{create_rac_model_request.py → create_classification_model_request.py} +25 -57
- orca_sdk/_generated_api_client/models/{create_labeled_memoryset_request.py → create_memoryset_request.py} +73 -56
- orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +66 -0
- orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +13 -0
- orca_sdk/_generated_api_client/models/create_regression_model_request.py +137 -0
- orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +187 -0
- orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +10 -0
- orca_sdk/_generated_api_client/models/evaluation_response.py +22 -9
- orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +140 -0
- orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +140 -0
- orca_sdk/_generated_api_client/models/memory_type.py +9 -0
- orca_sdk/_generated_api_client/models/{labeled_memoryset_metadata.py → memoryset_metadata.py} +73 -13
- orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +55 -0
- orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +13 -0
- orca_sdk/_generated_api_client/models/{labeled_memoryset_update.py → memoryset_update.py} +19 -31
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +1 -0
- orca_sdk/_generated_api_client/models/{paginated_labeled_memory_with_feedback_metrics.py → paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py} +37 -10
- orca_sdk/_generated_api_client/models/{precision_recall_curve.py → pr_curve.py} +5 -13
- orca_sdk/_generated_api_client/models/{rac_model_update.py → predictive_model_update.py} +14 -5
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +11 -1
- orca_sdk/_generated_api_client/models/rar_head_type.py +8 -0
- orca_sdk/_generated_api_client/models/regression_evaluation_request.py +148 -0
- orca_sdk/_generated_api_client/models/regression_metrics.py +172 -0
- orca_sdk/_generated_api_client/models/regression_model_metadata.py +177 -0
- orca_sdk/_generated_api_client/models/regression_prediction_request.py +195 -0
- orca_sdk/_generated_api_client/models/roc_curve.py +0 -8
- orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +196 -0
- orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +252 -0
- orca_sdk/_generated_api_client/models/scored_memory.py +172 -0
- orca_sdk/_generated_api_client/models/scored_memory_insert.py +128 -0
- orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_lookup.py +180 -0
- orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_update.py +171 -0
- orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +193 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/update_prediction_request.py +20 -0
- orca_sdk/_shared/__init__.py +9 -1
- orca_sdk/_shared/metrics.py +257 -87
- orca_sdk/_shared/metrics_test.py +136 -77
- orca_sdk/_utils/data_parsing.py +0 -3
- orca_sdk/_utils/data_parsing_test.py +0 -3
- orca_sdk/_utils/prediction_result_ui.py +55 -23
- orca_sdk/classification_model.py +183 -172
- orca_sdk/classification_model_test.py +147 -157
- orca_sdk/conftest.py +76 -26
- orca_sdk/datasource_test.py +0 -1
- orca_sdk/embedding_model.py +136 -14
- orca_sdk/embedding_model_test.py +10 -6
- orca_sdk/job.py +329 -0
- orca_sdk/job_test.py +48 -0
- orca_sdk/memoryset.py +882 -161
- orca_sdk/memoryset_test.py +56 -23
- orca_sdk/regression_model.py +647 -0
- orca_sdk/regression_model_test.py +337 -0
- orca_sdk/telemetry.py +223 -106
- orca_sdk/telemetry_test.py +34 -30
- {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.96.dist-info}/METADATA +2 -4
- {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.96.dist-info}/RECORD +115 -69
- orca_sdk/_utils/task.py +0 -73
- {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.96.dist-info}/WHEEL +0 -0
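
The file list above boils down to two headline changes: the RAC-era classification naming is retired (`create_rac_model_request` → `create_classification_model_request`, `rac_model_metadata` → `classification_model_metadata`, `precision_recall_curve` → `pr_curve`), and a parallel regression surface is added (`regression_model.py`, scored memorysets, regression metrics, plus a new `job.py`). A minimal sketch of the split, pieced together from the test fixtures shown later in this diff; any argument not exercised by those fixtures is an assumption, not documented API:

```python
# Sketch only: assembled from the test fixtures in this diff. Assumption:
# LabeledMemoryset.create mirrors the ScoredMemoryset.create call in conftest.py.
from orca_sdk.classification_model import ClassificationModel
from orca_sdk.datasource import Datasource
from orca_sdk.memoryset import LabeledMemoryset, ScoredMemoryset
from orca_sdk.regression_model import RegressionModel

datasource = Datasource.from_list(
    "quickstart_datasource",
    [
        {"value": "i love soup", "label": 0, "score": 0.1},
        {"value": "cats are cute", "label": 1, "score": 0.9},
    ],
)

# Classification: a labeled memoryset backs a ClassificationModel (formerly "RAC").
labeled = LabeledMemoryset.create("quickstart_labeled", datasource=datasource)
classifier = ClassificationModel.create("quickstart_classifier", labeled, num_classes=2)

# Regression (new in 0.0.96): a scored memoryset backs a RegressionModel.
scored = ScoredMemoryset.create("quickstart_scored", datasource=datasource)
regressor = RegressionModel.create("quickstart_regressor", scored, memory_lookup_count=3)
```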
orca_sdk/classification_model_test.py CHANGED

@@ -1,53 +1,52 @@
 import logging
-import os
 from uuid import uuid4
 
 import numpy as np
 import pytest
 from datasets.arrow_dataset import Dataset
 
-from .classification_model import ClassificationModel
+from .classification_model import ClassificationMetrics, ClassificationModel
+from .conftest import skip_in_ci
 from .datasource import Datasource
 from .embedding_model import PretrainedEmbeddingModel
 from .memoryset import LabeledMemoryset
 
-logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
 
+def test_create_model(classification_model: ClassificationModel, readonly_memoryset: LabeledMemoryset):
+    assert classification_model is not None
+    assert classification_model.name == "test_classification_model"
+    assert classification_model.memoryset == readonly_memoryset
+    assert classification_model.num_classes == 2
+    assert classification_model.memory_lookup_count == 3
 
-SKIP_IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
 
-
-def test_create_model(model: ClassificationModel, readonly_memoryset: LabeledMemoryset):
-    assert model is not None
-    assert model.name == "test_model"
-    assert model.memoryset == readonly_memoryset
-    assert model.num_classes == 2
-    assert model.memory_lookup_count == 3
-
-
-def test_create_model_already_exists_error(readonly_memoryset, model: ClassificationModel):
+def test_create_model_already_exists_error(readonly_memoryset, classification_model):
     with pytest.raises(ValueError):
-        ClassificationModel.create("
+        ClassificationModel.create("test_classification_model", readonly_memoryset)
     with pytest.raises(ValueError):
-        ClassificationModel.create("
+        ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="error")
 
 
-def test_create_model_already_exists_return(readonly_memoryset,
+def test_create_model_already_exists_return(readonly_memoryset, classification_model):
     with pytest.raises(ValueError):
-        ClassificationModel.create("
+        ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="open", head_type="MMOE")
 
     with pytest.raises(ValueError):
-        ClassificationModel.create(
+        ClassificationModel.create(
+            "test_classification_model", readonly_memoryset, if_exists="open", memory_lookup_count=37
+        )
 
     with pytest.raises(ValueError):
-        ClassificationModel.create("
+        ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="open", num_classes=19)
 
     with pytest.raises(ValueError):
-        ClassificationModel.create(
+        ClassificationModel.create(
+            "test_classification_model", readonly_memoryset, if_exists="open", min_memory_weight=0.77
+        )
 
-    new_model = ClassificationModel.create("
+    new_model = ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="open")
     assert new_model is not None
-    assert new_model.name == "
+    assert new_model.name == "test_classification_model"
     assert new_model.memoryset == readonly_memoryset
     assert new_model.num_classes == 2
    assert new_model.memory_lookup_count == 3
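
The hunk above pins down the `if_exists` contract of `ClassificationModel.create`: both the default and `if_exists="error"` raise on a name collision, while `if_exists="open"` returns the existing model, but only when the requested configuration matches the stored one. A condensed sketch:

```python
from orca_sdk.classification_model import ClassificationModel
from orca_sdk.memoryset import LabeledMemoryset


def demo_if_exists(memoryset: LabeledMemoryset) -> None:
    # First create succeeds; repeating it raises ValueError, both by default
    # and with if_exists="error".
    model = ClassificationModel.create("demo_model", memoryset)

    # if_exists="open" returns the existing model instead of raising...
    same = ClassificationModel.create("demo_model", memoryset, if_exists="open")
    assert same.name == model.name

    # ...but still raises ValueError when the requested configuration conflicts
    # with the stored one (head_type, memory_lookup_count, num_classes, min_memory_weight).
    ClassificationModel.create("demo_model", memoryset, if_exists="open", memory_lookup_count=37)
```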
@@ -58,14 +57,14 @@ def test_create_model_unauthenticated(unauthenticated, readonly_memoryset: Label
         ClassificationModel.create("test_model", readonly_memoryset)
 
 
-def test_get_model(
-    fetched_model = ClassificationModel.open(
+def test_get_model(classification_model: ClassificationModel):
+    fetched_model = ClassificationModel.open(classification_model.name)
     assert fetched_model is not None
-    assert fetched_model.id ==
-    assert fetched_model.name ==
+    assert fetched_model.id == classification_model.id
+    assert fetched_model.name == classification_model.name
     assert fetched_model.num_classes == 2
     assert fetched_model.memory_lookup_count == 3
-    assert fetched_model ==
+    assert fetched_model == classification_model
 
 
 def test_get_model_unauthenticated(unauthenticated):

@@ -83,12 +82,12 @@ def test_get_model_not_found():
         ClassificationModel.open(str(uuid4()))
 
 
-def test_get_model_unauthorized(unauthorized,
+def test_get_model_unauthorized(unauthorized, classification_model: ClassificationModel):
     with pytest.raises(LookupError):
-        ClassificationModel.open(
+        ClassificationModel.open(classification_model.name)
 
 
-def test_list_models(
+def test_list_models(classification_model: ClassificationModel):
     models = ClassificationModel.all()
     assert len(models) > 0
     assert any(model.name == model.name for model in models)

@@ -99,19 +98,28 @@ def test_list_models_unauthenticated(unauthenticated):
         ClassificationModel.all()
 
 
-def test_list_models_unauthorized(unauthorized,
+def test_list_models_unauthorized(unauthorized, classification_model: ClassificationModel):
     assert ClassificationModel.all() == []
 
 
-def
-
-    assert
+def test_update_model_attributes(classification_model: ClassificationModel):
+    classification_model.description = "New description"
+    assert classification_model.description == "New description"
+
+    classification_model.set(description=None)
+    assert classification_model.description is None
+
+    classification_model.set(locked=True)
+    assert classification_model.locked is True
 
+    classification_model.set(locked=False)
+    assert classification_model.locked is False
 
-
-    assert
-
-
+    classification_model.lock()
+    assert classification_model.locked is True
+
+    classification_model.unlock()
+    assert classification_model.locked is False
 
 
 def test_delete_model(readonly_memoryset: LabeledMemoryset):
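
The new `test_update_model_attributes` exercises the mutation surface on models: direct attribute assignment, a `set()` method that updates one or more fields (including clearing with `None`), and `lock()`/`unlock()` shorthands for the `locked` flag. Condensed from the hunk above:

```python
from orca_sdk.classification_model import ClassificationModel


def demo_update_attributes(model: ClassificationModel) -> None:
    model.description = "New description"  # plain assignment persists the change
    model.set(description=None)            # set() updates or clears fields by keyword

    model.set(locked=True)                 # locking via set()...
    model.unlock()                         # ...or via the lock()/unlock() shorthands
    model.lock()
    assert model.locked is True
```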
@@ -122,9 +130,9 @@ def test_delete_model(readonly_memoryset: LabeledMemoryset):
         ClassificationModel.open("model_to_delete")
 
 
-def test_delete_model_unauthenticated(unauthenticated,
+def test_delete_model_unauthenticated(unauthenticated, classification_model: ClassificationModel):
     with pytest.raises(ValueError, match="Invalid API key"):
-        ClassificationModel.drop(
+        ClassificationModel.drop(classification_model.name)
 
 
 def test_delete_model_not_found():

@@ -134,9 +142,9 @@ def test_delete_model_not_found():
         ClassificationModel.drop(str(uuid4()), if_not_exists="ignore")
 
 
-def test_delete_model_unauthorized(unauthorized,
+def test_delete_model_unauthorized(unauthorized, classification_model: ClassificationModel):
     with pytest.raises(LookupError):
-        ClassificationModel.drop(
+        ClassificationModel.drop(classification_model.name)
 
 
 def test_delete_memoryset_before_model_constraint_violation(hf_dataset):
@@ -146,78 +154,57 @@ def test_delete_memoryset_before_model_constraint_violation(hf_dataset):
     LabeledMemoryset.drop(memoryset.id)
 
 
-
-
+@pytest.mark.parametrize("data_type", ["dataset", "datasource"])
+def test_evaluate(classification_model, eval_datasource: Datasource, eval_dataset: Dataset, data_type):
+    result = (
+        classification_model.evaluate(eval_dataset)
+        if data_type == "dataset"
+        else classification_model.evaluate(eval_datasource)
+    )
+
     assert result is not None
-    assert isinstance(result,
-
-    assert isinstance(result
-    assert
-    assert isinstance(result
-    assert
-    assert
-
-    assert isinstance(result
-    assert isinstance(result
-    assert isinstance(result
-    assert
-    assert
-    assert
-
-    assert
-    assert
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    assert isinstance(result["precision_recall_curve"]["thresholds"], list)
-    assert isinstance(result["precision_recall_curve"]["precisions"], list)
-    assert isinstance(result["precision_recall_curve"]["recalls"], list)
-    assert isinstance(result["roc_curve"]["thresholds"], list)
-    assert isinstance(result["roc_curve"]["false_positive_rates"], list)
-    assert isinstance(result["roc_curve"]["true_positive_rates"], list)
-
-    assert np.allclose(result["roc_curve"]["thresholds"], [0.0, 0.8155114054679871, 0.834095299243927, 1.0])
-    assert np.allclose(result["roc_curve"]["false_positive_rates"], [1.0, 0.5, 0.0, 0.0])
-    assert np.allclose(result["roc_curve"]["true_positive_rates"], [1.0, 0.5, 0.5, 0.0])
-    assert np.allclose(result["roc_curve"]["auc"], 0.625)
-
-    assert np.allclose(
-        result["precision_recall_curve"]["thresholds"], [0.0, 0.0, 0.8155114054679871, 0.834095299243927]
-    )
-    assert np.allclose(result["precision_recall_curve"]["precisions"], [0.5, 0.5, 1.0, 1.0])
-    assert np.allclose(result["precision_recall_curve"]["recalls"], [1.0, 0.5, 0.5, 0.0])
-    assert np.allclose(result["precision_recall_curve"]["auc"], 0.75)
-
-
-def test_evaluate_with_telemetry(model):
-    samples = [
-        {"text": "chicken noodle soup is the best", "label": 1},
-        {"text": "cats are cute", "label": 0},
-    ]
-    eval_datasource = Datasource.from_list("eval_datasource_2", samples)
-    result = model.evaluate(eval_datasource, value_column="text", record_predictions=True, tags={"test"})
+    assert isinstance(result, ClassificationMetrics)
+
+    assert isinstance(result.accuracy, float)
+    assert np.allclose(result.accuracy, 0.5)
+    assert isinstance(result.f1_score, float)
+    assert np.allclose(result.f1_score, 0.5)
+    assert isinstance(result.loss, float)
+
+    assert isinstance(result.anomaly_score_mean, float)
+    assert isinstance(result.anomaly_score_median, float)
+    assert isinstance(result.anomaly_score_variance, float)
+    assert -1.0 <= result.anomaly_score_mean <= 1.0
+    assert -1.0 <= result.anomaly_score_median <= 1.0
+    assert -1.0 <= result.anomaly_score_variance <= 1.0
+
+    assert result.pr_auc is not None
+    assert np.allclose(result.pr_auc, 0.75)
+    assert result.pr_curve is not None
+    assert np.allclose(result.pr_curve["thresholds"], [0.0, 0.0, 0.8155114054679871, 0.834095299243927])
+    assert np.allclose(result.pr_curve["precisions"], [0.5, 0.5, 1.0, 1.0])
+    assert np.allclose(result.pr_curve["recalls"], [1.0, 0.5, 0.5, 0.0])
+
+    assert result.roc_auc is not None
+    assert np.allclose(result.roc_auc, 0.625)
+    assert result.roc_curve is not None
+    assert np.allclose(result.roc_curve["thresholds"], [0.0, 0.8155114054679871, 0.834095299243927, 1.0])
+    assert np.allclose(result.roc_curve["false_positive_rates"], [1.0, 0.5, 0.0, 0.0])
+    assert np.allclose(result.roc_curve["true_positive_rates"], [1.0, 0.5, 0.5, 0.0])
+
+
+def test_evaluate_with_telemetry(classification_model: ClassificationModel, eval_dataset: Dataset):
+    result = classification_model.evaluate(eval_dataset, record_predictions=True, tags={"test"})
     assert result is not None
-
-
+    assert isinstance(result, ClassificationMetrics)
+    predictions = classification_model.predictions(tag="test")
+    assert len(predictions) == 4
     assert all(p.tags == {"test"} for p in predictions)
-    assert all(p.expected_label ==
+    assert all(p.expected_label == l for p, l in zip(predictions, eval_dataset["label"]))
 
 
-def test_predict(
-    predictions =
+def test_predict(classification_model: ClassificationModel, label_names: list[str]):
+    predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
     assert len(predictions) == 2
     assert predictions[0].prediction_id is not None
    assert predictions[1].prediction_id is not None
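
As the hunk above shows, `evaluate()` now returns a `ClassificationMetrics` object instead of a dict: `precision_recall_curve` becomes `pr_curve` (matching the `precision_recall_curve.py → pr_curve.py` model rename in the file list), the AUCs move to top-level `pr_auc`/`roc_auc` fields, and anomaly-score statistics are new. A sketch that reads only fields the test asserts on:

```python
from datasets.arrow_dataset import Dataset

from orca_sdk.classification_model import ClassificationMetrics, ClassificationModel


def demo_metrics(model: ClassificationModel, eval_dataset: Dataset) -> None:
    result = model.evaluate(eval_dataset)
    assert isinstance(result, ClassificationMetrics)

    print(result.accuracy, result.f1_score, result.loss)  # scalar metrics as floats
    print(result.anomaly_score_mean)                      # new anomaly stats, bounded by [-1, 1]

    if result.pr_curve is not None:  # curves stay dict-like; AUCs are top-level now
        print(result.pr_auc, result.pr_curve["precisions"], result.pr_curve["recalls"])
    if result.roc_curve is not None:
        print(result.roc_auc, result.roc_curve["false_positive_rates"])
```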
@@ -236,8 +223,8 @@ def test_predict(model: ClassificationModel, label_names: list[str]):
     assert predictions[1].logits[0] < predictions[1].logits[1]
 
 
-def test_predict_disable_telemetry(
-    predictions =
+def test_predict_disable_telemetry(classification_model: ClassificationModel, label_names: list[str]):
+    predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"], save_telemetry="off")
     assert len(predictions) == 2
     assert predictions[0].prediction_id is None
     assert predictions[1].prediction_id is None

@@ -249,14 +236,14 @@ def test_predict_disable_telemetry(model: ClassificationModel, label_names: list
     assert 0 <= predictions[1].confidence <= 1
 
 
-def test_predict_unauthenticated(unauthenticated,
+def test_predict_unauthenticated(unauthenticated, classification_model: ClassificationModel):
     with pytest.raises(ValueError, match="Invalid API key"):
-
+        classification_model.predict(["Do you love soup?", "Are cats cute?"])
 
 
-def test_predict_unauthorized(unauthorized,
+def test_predict_unauthorized(unauthorized, classification_model: ClassificationModel):
     with pytest.raises(LookupError):
-
+        classification_model.predict(["Do you love soup?", "Are cats cute?"])
 
 
 def test_predict_constraint_violation(readonly_memoryset: LabeledMemoryset):

@@ -270,10 +257,10 @@ def test_predict_constraint_violation(readonly_memoryset: LabeledMemoryset):
         model.predict("test")
 
 
-def test_record_prediction_feedback(
-    predictions =
+def test_record_prediction_feedback(classification_model: ClassificationModel):
+    predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
     expected_labels = [0, 1]
-
+    classification_model.record_feedback(
         {
             "prediction_id": p.prediction_id,
             "category": "correct",

@@ -283,73 +270,80 @@ def test_record_prediction_feedback(model: ClassificationModel):
     )
 
 
-def test_record_prediction_feedback_missing_category(
-    prediction =
+def test_record_prediction_feedback_missing_category(classification_model: ClassificationModel):
+    prediction = classification_model.predict("Do you love soup?")
     with pytest.raises(ValueError):
-
+        classification_model.record_feedback({"prediction_id": prediction.prediction_id, "value": True})
 
 
-def test_record_prediction_feedback_invalid_value(
-    prediction =
+def test_record_prediction_feedback_invalid_value(classification_model: ClassificationModel):
+    prediction = classification_model.predict("Do you love soup?")
     with pytest.raises(ValueError, match=r"Invalid input.*"):
-
+        classification_model.record_feedback(
+            {"prediction_id": prediction.prediction_id, "category": "correct", "value": "invalid"}
+        )
 
 
-def test_record_prediction_feedback_invalid_prediction_id(
+def test_record_prediction_feedback_invalid_prediction_id(classification_model: ClassificationModel):
     with pytest.raises(ValueError, match=r"Invalid input.*"):
-
+        classification_model.record_feedback({"prediction_id": "invalid", "category": "correct", "value": True})
 
 
-def test_predict_with_memoryset_override(
+def test_predict_with_memoryset_override(classification_model: ClassificationModel, hf_dataset: Dataset):
     inverted_labeled_memoryset = LabeledMemoryset.from_hf_dataset(
         "test_memoryset_inverted_labels",
         hf_dataset.map(lambda x: {"label": 1 if x["label"] == 0 else 0}),
         embedding_model=PretrainedEmbeddingModel.GTE_BASE,
     )
-    with
-        predictions =
+    with classification_model.use_memoryset(inverted_labeled_memoryset):
+        predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
     assert predictions[0].label == 1
     assert predictions[1].label == 0
 
-    predictions =
+    predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
     assert predictions[0].label == 0
     assert predictions[1].label == 1
 
 
-def test_predict_with_expected_labels(
-    prediction =
+def test_predict_with_expected_labels(classification_model: ClassificationModel):
+    prediction = classification_model.predict("Do you love soup?", expected_labels=1)
     assert prediction.expected_label == 1
 
 
-def test_predict_with_expected_labels_invalid_input(
+def test_predict_with_expected_labels_invalid_input(classification_model: ClassificationModel):
     # invalid number of expected labels for batch prediction
     with pytest.raises(ValueError, match=r"Invalid input.*"):
-
+        classification_model.predict(["Do you love soup?", "Are cats cute?"], expected_labels=[0])
     # invalid label value
     with pytest.raises(ValueError):
-
+        classification_model.predict("Do you love soup?", expected_labels=5)
 
 
-def
-
-
-    assert
-    assert
-
+def test_predict_with_filters(classification_model: ClassificationModel):
+    # there are no memories with label 0 and key g1, so we force a wrong prediction
+    filtered_prediction = classification_model.predict("I love soup", filters=[("key", "==", "g2")])
+    assert filtered_prediction.label == 1
+    assert filtered_prediction.label_name == "cats"
+
+
+def test_last_prediction_with_batch(classification_model: ClassificationModel):
+    predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
+    assert classification_model.last_prediction is not None
+    assert classification_model.last_prediction.prediction_id == predictions[-1].prediction_id
+    assert classification_model.last_prediction.input_value == "Are cats cute?"
+    assert classification_model._last_prediction_was_batch is True
 
 
-def test_last_prediction_with_single(
+def test_last_prediction_with_single(classification_model: ClassificationModel):
     # Test that last_prediction is updated correctly with single prediction
-    prediction =
-    assert
-    assert
-    assert
-    assert
+    prediction = classification_model.predict("Do you love soup?")
+    assert classification_model.last_prediction is not None
+    assert classification_model.last_prediction.prediction_id == prediction.prediction_id
+    assert classification_model.last_prediction.input_value == "Do you love soup?"
+    assert classification_model._last_prediction_was_batch is False
 
 
-@pytest.mark.skipif(
-    SKIP_IN_GITHUB_ACTIONS, reason="Skipping explanation test because in CI we don't have Anthropic API key"
-)
+@skip_in_ci("We don't have Anthropic API key in CI")
 def test_explain(writable_memoryset: LabeledMemoryset):
 
    writable_memoryset.analyze(
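
The feedback tests above fix the payload shape for `record_feedback`: each entry carries a `prediction_id`, a `category`, and a `value`; omitting the category, or passing an invalid value or prediction id, raises `ValueError`. Condensed:

```python
from orca_sdk.classification_model import ClassificationModel


def demo_feedback(model: ClassificationModel) -> None:
    prediction = model.predict("Do you love soup?")
    model.record_feedback(
        {"prediction_id": prediction.prediction_id, "category": "correct", "value": True}
    )
    # {"prediction_id": ..., "value": True} without a category raises ValueError,
    # as do "value": "invalid" and an unknown prediction_id.
```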
@@ -370,17 +364,13 @@ def test_explain(writable_memoryset: LabeledMemoryset):
 
     try:
         explanation = predictions[0].explanation
-        print(explanation)
         assert explanation is not None
         assert len(explanation) > 10
         assert "soup" in explanation.lower()
     except Exception as e:
         if "ANTHROPIC_API_KEY" in str(e):
-            logging.info("Skipping explanation test because ANTHROPIC_API_KEY is not set
+            logging.info("Skipping explanation test because ANTHROPIC_API_KEY is not set")
         else:
             raise e
     finally:
-
-            ClassificationModel.drop("test_model_for_explain")
-        except Exception as e:
-            logging.info(f"Failed to drop test model for explain: {e}")
+        ClassificationModel.drop("test_model_for_explain")
orca_sdk/conftest.py CHANGED

@@ -11,15 +11,33 @@ from .classification_model import ClassificationModel
 from .credentials import OrcaCredentials
 from .datasource import Datasource
 from .embedding_model import PretrainedEmbeddingModel
-from .memoryset import LabeledMemoryset
+from .memoryset import LabeledMemoryset, ScoredMemoryset
+from .regression_model import RegressionModel
 
-logging.basicConfig(level=logging.INFO)
+logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
 
 os.environ["ORCA_API_URL"] = os.environ.get("ORCA_API_URL", "http://localhost:1584/")
 
 os.environ["ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY"] = "true"
 
 
+def skip_in_prod(reason: str):
+    """Custom decorator to skip tests when running against production API"""
+    PROD_API_URLs = ["https://api.orcadb.ai", "https://api.dev.orcadb.ai"]
+    return pytest.mark.skipif(
+        os.environ["ORCA_API_URL"] in PROD_API_URLs,
+        reason=reason,
+    )
+
+
+def skip_in_ci(reason: str):
+    """Custom decorator to skip tests when running in CI"""
+    return pytest.mark.skipif(
+        os.environ.get("GITHUB_ACTIONS", "false") == "true",
+        reason=reason,
+    )
+
+
 def _create_org_id():
     # UUID start to identify test data (0xtest...)
    return "10e50000-0000-4000-a000-" + str(uuid4())[24:]
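
The module-level `SKIP_IN_GITHUB_ACTIONS` flag is replaced by two reusable markers, `skip_in_ci` and `skip_in_prod`, both thin wrappers around `pytest.mark.skipif`. Usage as exercised by `test_explain` in classification_model_test.py; the second test is a hypothetical example, since `skip_in_prod` is defined but not used in the hunks shown here:

```python
from orca_sdk.conftest import skip_in_ci, skip_in_prod


@skip_in_ci("We don't have Anthropic API key in CI")
def test_needs_anthropic_key():
    ...


@skip_in_prod("hypothetical: don't mutate shared state on api.orcadb.ai")
def test_mutates_shared_state():
    ...
```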
@@ -71,27 +89,27 @@ def label_names():
 
 
 SAMPLE_DATA = [
-    {"value": "i love soup", "label": 0, "key": "
-    {"value": "cats are cute", "label": 1, "key": "
-    {"value": "soup is good", "label": 0, "key": "
-    {"value": "i love cats", "label": 1, "key": "
-    {"value": "everyone loves cats", "label": 1, "key": "
-    {"value": "soup is great for the winter", "label": 0, "key": "
-    {"value": "hot soup on a rainy day!", "label": 0, "key": "
-    {"value": "cats sleep all day", "label": 1, "key": "
-    {"value": "homemade soup recipes", "label": 0, "key": "
-    {"value": "cats purr when happy", "label": 1, "key": "
-    {"value": "chicken noodle soup is classic", "label": 0, "key": "
-    {"value": "kittens are baby cats", "label": 1, "key": "
-    {"value": "soup can be served cold too", "label": 0, "key": "
-    {"value": "cats have nine lives", "label": 1, "key": "
-    {"value": "tomato soup with grilled cheese", "label": 0, "key": "
-    {"value": "cats are independent animals", "label": 1, "key": "
+    {"value": "i love soup", "label": 0, "key": "g1", "score": 0.1, "source_id": "s1"},
+    {"value": "cats are cute", "label": 1, "key": "g1", "score": 0.9, "source_id": "s2"},
+    {"value": "soup is good", "label": 0, "key": "g1", "score": 0.1, "source_id": "s3"},
+    {"value": "i love cats", "label": 1, "key": "g1", "score": 0.9, "source_id": "s4"},
+    {"value": "everyone loves cats", "label": 1, "key": "g1", "score": 0.9, "source_id": "s5"},
+    {"value": "soup is great for the winter", "label": 0, "key": "g1", "score": 0.1, "source_id": "s6"},
+    {"value": "hot soup on a rainy day!", "label": 0, "key": "g1", "score": 0.1, "source_id": "s7"},
+    {"value": "cats sleep all day", "label": 1, "key": "g1", "score": 0.9, "source_id": "s8"},
+    {"value": "homemade soup recipes", "label": 0, "key": "g1", "score": 0.1, "source_id": "s9"},
+    {"value": "cats purr when happy", "label": 1, "key": "g2", "score": 0.9, "source_id": "s10"},
+    {"value": "chicken noodle soup is classic", "label": 0, "key": "g1", "score": 0.1, "source_id": "s11"},
+    {"value": "kittens are baby cats", "label": 1, "key": "g2", "score": 0.9, "source_id": "s12"},
+    {"value": "soup can be served cold too", "label": 0, "key": "g1", "score": 0.1, "source_id": "s13"},
+    {"value": "cats have nine lives", "label": 1, "key": "g2", "score": 0.9, "source_id": "s14"},
+    {"value": "tomato soup with grilled cheese", "label": 0, "key": "g1", "score": 0.1, "source_id": "s15"},
+    {"value": "cats are independent animals", "label": 1, "key": "g2", "score": 0.9, "source_id": "s16"},
 ]
 
 
 @pytest.fixture(scope="session")
-def hf_dataset(label_names):
+def hf_dataset(label_names: list[str]) -> Dataset:
     return Dataset.from_list(
         SAMPLE_DATA,
         features=Features(

@@ -107,16 +125,16 @@ def hf_dataset(label_names):
 
 
 @pytest.fixture(scope="session")
-def datasource(hf_dataset) -> Datasource:
+def datasource(hf_dataset: Dataset) -> Datasource:
     datasource = Datasource.from_hf_dataset("test_datasource", hf_dataset)
     return datasource
 
 
 EVAL_DATASET = [
-    {"value": "chicken noodle soup is the best", "label": 1},
-    {"value": "cats are cute", "label": 0},
-    {"value": "soup is great for the winter", "label": 0},
-    {"value": "i love cats", "label": 1},
+    {"value": "chicken noodle soup is the best", "label": 1, "score": 0.9},  # mislabeled
+    {"value": "cats are cute", "label": 0, "score": 0.1},  # mislabeled
+    {"value": "soup is great for the winter", "label": 0, "score": 0.1},
+    {"value": "i love cats", "label": 1, "score": 0.9},
 ]
 
 

@@ -140,6 +158,8 @@ def readonly_memoryset(datasource: Datasource) -> LabeledMemoryset:
         embedding_model=PretrainedEmbeddingModel.GTE_BASE,
         source_id_column="source_id",
         max_seq_length_override=32,
+        index_type="IVF_FLAT",
+        index_params={"n_lists": 100},
     )
     return memoryset
 
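
Both memoryset fixtures now pin the vector index explicitly, matching the new `create_memoryset_request_index_type`/`index_params` models in the file list. `IVF_FLAT` is a standard inverted-file index and `n_lists` its partition count; that interpretation is general IVF knowledge, not something this diff documents. A sketch, assuming `LabeledMemoryset.create` mirrors the `ScoredMemoryset.create` call shown below:

```python
from orca_sdk.datasource import Datasource
from orca_sdk.embedding_model import PretrainedEmbeddingModel
from orca_sdk.memoryset import LabeledMemoryset


def demo_index_config(datasource: Datasource) -> LabeledMemoryset:
    # n_lists sets how many IVF partitions the embeddings are bucketed into;
    # lookups then probe a subset of partitions for approximate nearest neighbors.
    return LabeledMemoryset.create(
        "demo_memoryset",
        datasource=datasource,
        embedding_model=PretrainedEmbeddingModel.GTE_BASE,
        index_type="IVF_FLAT",
        index_params={"n_lists": 100},
    )
```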
@@ -183,8 +203,38 @@ def writable_memoryset(datasource: Datasource, api_key: str) -> Generator[Labele
 
 
 @pytest.fixture(scope="session")
-def
+def classification_model(readonly_memoryset: LabeledMemoryset) -> ClassificationModel:
     model = ClassificationModel.create(
-        "
+        "test_classification_model",
+        readonly_memoryset,
+        num_classes=2,
+        memory_lookup_count=3,
+        description="test_description",
+    )
+    return model
+
+
+# Add scored memoryset and regression model fixtures
+@pytest.fixture(scope="session")
+def scored_memoryset(datasource: Datasource) -> ScoredMemoryset:
+    memoryset = ScoredMemoryset.create(
+        "test_scored_memoryset",
+        datasource=datasource,
+        embedding_model=PretrainedEmbeddingModel.GTE_BASE,
+        source_id_column="source_id",
+        max_seq_length_override=32,
+        index_type="IVF_FLAT",
+        index_params={"n_lists": 100},
+    )
+    return memoryset
+
+
+@pytest.fixture(scope="session")
+def regression_model(scored_memoryset: ScoredMemoryset) -> RegressionModel:
+    model = RegressionModel.create(
+        "test_regression_model",
+        scored_memoryset,
+        memory_lookup_count=3,
+        description="test_regression_description",
     )
     return model