orca-sdk 0.1.9 (orca_sdk-0.1.9-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +30 -0
- orca_sdk/_shared/__init__.py +10 -0
- orca_sdk/_shared/metrics.py +634 -0
- orca_sdk/_shared/metrics_test.py +570 -0
- orca_sdk/_utils/__init__.py +0 -0
- orca_sdk/_utils/analysis_ui.py +196 -0
- orca_sdk/_utils/analysis_ui_style.css +51 -0
- orca_sdk/_utils/auth.py +65 -0
- orca_sdk/_utils/auth_test.py +31 -0
- orca_sdk/_utils/common.py +37 -0
- orca_sdk/_utils/data_parsing.py +129 -0
- orca_sdk/_utils/data_parsing_test.py +244 -0
- orca_sdk/_utils/pagination.py +126 -0
- orca_sdk/_utils/pagination_test.py +132 -0
- orca_sdk/_utils/prediction_result_ui.css +18 -0
- orca_sdk/_utils/prediction_result_ui.py +110 -0
- orca_sdk/_utils/tqdm_file_reader.py +12 -0
- orca_sdk/_utils/value_parser.py +45 -0
- orca_sdk/_utils/value_parser_test.py +39 -0
- orca_sdk/async_client.py +4104 -0
- orca_sdk/classification_model.py +1165 -0
- orca_sdk/classification_model_test.py +887 -0
- orca_sdk/client.py +4096 -0
- orca_sdk/conftest.py +382 -0
- orca_sdk/credentials.py +217 -0
- orca_sdk/credentials_test.py +121 -0
- orca_sdk/datasource.py +576 -0
- orca_sdk/datasource_test.py +463 -0
- orca_sdk/embedding_model.py +712 -0
- orca_sdk/embedding_model_test.py +206 -0
- orca_sdk/job.py +343 -0
- orca_sdk/job_test.py +108 -0
- orca_sdk/memoryset.py +3811 -0
- orca_sdk/memoryset_test.py +1150 -0
- orca_sdk/regression_model.py +841 -0
- orca_sdk/regression_model_test.py +595 -0
- orca_sdk/telemetry.py +742 -0
- orca_sdk/telemetry_test.py +119 -0
- orca_sdk-0.1.9.dist-info/METADATA +98 -0
- orca_sdk-0.1.9.dist-info/RECORD +41 -0
- orca_sdk-0.1.9.dist-info/WHEEL +4 -0
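For orientation before the diff: the new test module below, orca_sdk/classification_model_test.py, exercises the SDK's memoryset-backed classification workflow end to end. The sketch that follows is distilled only from the calls those tests make (LabeledMemoryset.from_hf_dataset, ClassificationModel.create/predict/evaluate/drop); the dataset rows and the "example_*" names are illustrative assumptions, and API-key/credential setup (handled by test fixtures in the package) is omitted.

```python
from datasets import Dataset

from orca_sdk.classification_model import ClassificationModel
from orca_sdk.memoryset import LabeledMemoryset

# Illustrative toy data; a real memoryset would hold far more labeled examples.
train = Dataset.from_list(
    [
        {"value": "soup is good", "label": 0},
        {"value": "cats are cute", "label": 1},
    ]
)

# Build a labeled memoryset, then a classification model on top of it
# (hypothetical names; the tests use fixtures like readonly_memoryset instead).
memoryset = LabeledMemoryset.from_hf_dataset("example_memoryset", train)
model = ClassificationModel.create("example_model", memoryset, num_classes=2, memory_lookup_count=3)

# Single or batch prediction; results carry label, label_name, confidence, and logits.
prediction = model.predict("Do you love soup?")
print(prediction.label, prediction.label_name, prediction.confidence)

# Evaluation over a dataset (ideally held-out data) returns ClassificationMetrics.
metrics = model.evaluate(train)
print(metrics.accuracy, metrics.f1_score, metrics.loss)

# Clean up: the model must be dropped before the memoryset it depends on.
ClassificationModel.drop("example_model")
LabeledMemoryset.drop(memoryset.id)
```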
orca_sdk/classification_model_test.py
@@ -0,0 +1,887 @@
+import logging
+from uuid import uuid4
+
+import numpy as np
+import pytest
+from datasets import Dataset
+
+from .classification_model import ClassificationMetrics, ClassificationModel
+from .conftest import skip_in_ci
+from .datasource import Datasource
+from .embedding_model import PretrainedEmbeddingModel
+from .memoryset import LabeledMemoryset
+from .telemetry import ClassificationPrediction
+
+
+def test_create_model(classification_model: ClassificationModel, readonly_memoryset: LabeledMemoryset):
+    assert classification_model is not None
+    assert classification_model.name == "test_classification_model"
+    assert classification_model.memoryset == readonly_memoryset
+    assert classification_model.num_classes == 2
+    assert classification_model.memory_lookup_count == 3
+
+
+def test_create_model_already_exists_error(readonly_memoryset, classification_model):
+    with pytest.raises(ValueError):
+        ClassificationModel.create("test_classification_model", readonly_memoryset)
+    with pytest.raises(ValueError):
+        ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="error")
+
+
+def test_create_model_already_exists_return(readonly_memoryset, classification_model):
+    with pytest.raises(ValueError):
+        ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="open", head_type="MMOE")
+
+    with pytest.raises(ValueError):
+        ClassificationModel.create(
+            "test_classification_model", readonly_memoryset, if_exists="open", memory_lookup_count=37
+        )
+
+    with pytest.raises(ValueError):
+        ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="open", num_classes=19)
+
+    with pytest.raises(ValueError):
+        ClassificationModel.create(
+            "test_classification_model", readonly_memoryset, if_exists="open", min_memory_weight=0.77
+        )
+
+    new_model = ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="open")
+    assert new_model is not None
+    assert new_model.name == "test_classification_model"
+    assert new_model.memoryset == readonly_memoryset
+    assert new_model.num_classes == 2
+    assert new_model.memory_lookup_count == 3
+
+
+def test_create_model_unauthenticated(unauthenticated_client, readonly_memoryset: LabeledMemoryset):
+    with unauthenticated_client.use():
+        with pytest.raises(ValueError, match="Invalid API key"):
+            ClassificationModel.create("test_model", readonly_memoryset)
+
+
+def test_get_model(classification_model: ClassificationModel):
+    fetched_model = ClassificationModel.open(classification_model.name)
+    assert fetched_model is not None
+    assert fetched_model.id == classification_model.id
+    assert fetched_model.name == classification_model.name
+    assert fetched_model.num_classes == 2
+    assert fetched_model.memory_lookup_count == 3
+    assert fetched_model == classification_model
+
+
+def test_get_model_unauthenticated(unauthenticated_client):
+    with unauthenticated_client.use():
+        with pytest.raises(ValueError, match="Invalid API key"):
+            ClassificationModel.open("test_model")
+
+
+def test_get_model_invalid_input():
+    with pytest.raises(ValueError, match="Invalid input"):
+        ClassificationModel.open("not valid id")
+
+
+def test_get_model_not_found():
+    with pytest.raises(LookupError):
+        ClassificationModel.open(str(uuid4()))
+
+
+def test_get_model_unauthorized(unauthorized_client, classification_model: ClassificationModel):
+    with unauthorized_client.use():
+        with pytest.raises(LookupError):
+            ClassificationModel.open(classification_model.name)
+
+
+def test_list_models(classification_model: ClassificationModel):
+    models = ClassificationModel.all()
+    assert len(models) > 0
+    assert any(model.name == model.name for model in models)
+
+
+def test_list_models_unauthenticated(unauthenticated_client):
+    with unauthenticated_client.use():
+        with pytest.raises(ValueError, match="Invalid API key"):
+            ClassificationModel.all()
+
+
+def test_list_models_unauthorized(unauthorized_client, classification_model: ClassificationModel):
+    with unauthorized_client.use():
+        assert ClassificationModel.all() == []
+
+
+def test_update_model_attributes(classification_model: ClassificationModel):
+    classification_model.description = "New description"
+    assert classification_model.description == "New description"
+
+    classification_model.set(description=None)
+    assert classification_model.description is None
+
+    classification_model.set(locked=True)
+    assert classification_model.locked is True
+
+    classification_model.set(locked=False)
+    assert classification_model.locked is False
+
+    classification_model.lock()
+    assert classification_model.locked is True
+
+    classification_model.unlock()
+    assert classification_model.locked is False
+
+
+def test_delete_model(readonly_memoryset: LabeledMemoryset):
+    ClassificationModel.create("model_to_delete", LabeledMemoryset.open(readonly_memoryset.name))
+    assert ClassificationModel.open("model_to_delete")
+    ClassificationModel.drop("model_to_delete")
+    with pytest.raises(LookupError):
+        ClassificationModel.open("model_to_delete")
+
+
+def test_delete_model_unauthenticated(unauthenticated_client, classification_model: ClassificationModel):
+    with unauthenticated_client.use():
+        with pytest.raises(ValueError, match="Invalid API key"):
+            ClassificationModel.drop(classification_model.name)
+
+
+def test_delete_model_not_found():
+    with pytest.raises(LookupError):
+        ClassificationModel.drop(str(uuid4()))
+    # ignores error if specified
+    ClassificationModel.drop(str(uuid4()), if_not_exists="ignore")
+
+
+def test_delete_model_unauthorized(unauthorized_client, classification_model: ClassificationModel):
+    with unauthorized_client.use():
+        with pytest.raises(LookupError):
+            ClassificationModel.drop(classification_model.name)
+
+
+def test_delete_memoryset_before_model_constraint_violation(hf_dataset):
+    memoryset = LabeledMemoryset.from_hf_dataset("test_memoryset_delete_before_model", hf_dataset)
+    ClassificationModel.create("test_model_delete_before_memoryset", memoryset)
+    with pytest.raises(RuntimeError):
+        LabeledMemoryset.drop(memoryset.id)
+
+
+@pytest.mark.parametrize("data_type", ["dataset", "datasource"])
+def test_evaluate(classification_model, eval_datasource: Datasource, eval_dataset: Dataset, data_type):
+    result = (
+        classification_model.evaluate(eval_dataset)
+        if data_type == "dataset"
+        else classification_model.evaluate(eval_datasource)
+    )
+
+    assert result is not None
+    assert isinstance(result, ClassificationMetrics)
+
+    assert isinstance(result.accuracy, float)
+    assert np.allclose(result.accuracy, 0.5)
+    assert isinstance(result.f1_score, float)
+    assert np.allclose(result.f1_score, 0.5)
+    assert isinstance(result.loss, float)
+
+    assert isinstance(result.anomaly_score_mean, float)
+    assert isinstance(result.anomaly_score_median, float)
+    assert isinstance(result.anomaly_score_variance, float)
+    assert -1.0 <= result.anomaly_score_mean <= 1.0
+    assert -1.0 <= result.anomaly_score_median <= 1.0
+    assert -1.0 <= result.anomaly_score_variance <= 1.0
+
+    assert result.pr_auc is not None
+    assert np.allclose(result.pr_auc, 0.83333)
+    assert result.pr_curve is not None
+    assert np.allclose(
+        result.pr_curve["thresholds"],
+        [0.0, 0.3021204173564911, 0.30852025747299194, 0.6932827234268188, 0.6972201466560364],
+    )
+    assert np.allclose(result.pr_curve["precisions"], [0.5, 0.666666, 0.5, 1.0, 1.0])
+    assert np.allclose(result.pr_curve["recalls"], [1.0, 1.0, 0.5, 0.5, 0.0])
+
+    assert result.roc_auc is not None
+    assert np.allclose(result.roc_auc, 0.75)
+    assert result.roc_curve is not None
+    assert np.allclose(
+        result.roc_curve["thresholds"],
+        [0.3021204173564911, 0.30852025747299194, 0.6932827234268188, 0.6972201466560364, 1.0],
+    )
+    assert np.allclose(result.roc_curve["false_positive_rates"], [1.0, 0.5, 0.5, 0.0, 0.0])
+    assert np.allclose(result.roc_curve["true_positive_rates"], [1.0, 1.0, 0.5, 0.5, 0.0])
+
+
+def test_evaluate_datasource_with_nones_raises_error(classification_model: ClassificationModel, datasource: Datasource):
+    with pytest.raises(ValueError):
+        classification_model.evaluate(datasource, record_predictions=True, tags={"test"})
+
+
+def test_evaluate_dataset_with_nones_raises_error(classification_model: ClassificationModel, hf_dataset: Dataset):
+    with pytest.raises(ValueError):
+        classification_model.evaluate(hf_dataset, record_predictions=True, tags={"test"})
+
+
+def test_evaluate_with_telemetry(classification_model: ClassificationModel, eval_dataset: Dataset):
+    result = classification_model.evaluate(eval_dataset, record_predictions=True, tags={"test"}, batch_size=2)
+    assert result is not None
+    assert isinstance(result, ClassificationMetrics)
+    predictions = classification_model.predictions(tag="test", batch_size=100, sort=[("timestamp", "asc")])
+    assert len(predictions) == 4
+    assert all(p.tags == {"test"} for p in predictions)
+    prediction_expected_labels = [p.expected_label if p.expected_label is not None else -1 for p in predictions]
+    eval_expected_labels = list(eval_dataset["label"])
+    assert all(
+        p == l for p, l in zip(prediction_expected_labels, eval_expected_labels)
+    ), f"Prediction expected labels: {prediction_expected_labels} do not match eval expected labels: {eval_expected_labels}"
+
+
+def test_evaluate_with_partition_column_dataset(partitioned_classification_model: ClassificationModel):
+    """Test evaluate with partition_column on a Dataset"""
+    # Create a test dataset with partition_id column
+    eval_dataset_with_partition = Dataset.from_list(
+        [
+            {"value": "soup is good", "label": 0, "partition_id": "p1"},
+            {"value": "cats are cute", "label": 1, "partition_id": "p1"},
+            {"value": "homemade soup recipes", "label": 0, "partition_id": "p2"},
+            {"value": "cats purr when happy", "label": 1, "partition_id": "p2"},
+        ]
+    )
+
+    # Evaluate with partition_column
+    result = partitioned_classification_model.evaluate(
+        eval_dataset_with_partition,
+        partition_column="partition_id",
+        partition_filter_mode="exclude_global",
+    )
+    assert result is not None
+    assert isinstance(result, ClassificationMetrics)
+    assert isinstance(result.accuracy, float)
+    assert isinstance(result.f1_score, float)
+    assert isinstance(result.loss, float)
+
+
+def test_evaluate_with_partition_column_include_global(partitioned_classification_model: ClassificationModel):
+    """Test evaluate with partition_column and include_global mode"""
+    eval_dataset_with_partition = Dataset.from_list(
+        [
+            {"value": "soup is good", "label": 0, "partition_id": "p1"},
+            {"value": "cats are cute", "label": 1, "partition_id": "p1"},
+        ]
+    )
+
+    # Evaluate with partition_column and include_global (default)
+    result = partitioned_classification_model.evaluate(
+        eval_dataset_with_partition,
+        partition_column="partition_id",
+        partition_filter_mode="include_global",
+    )
+    assert result is not None
+    assert isinstance(result, ClassificationMetrics)
+
+
+def test_evaluate_with_partition_column_exclude_global(partitioned_classification_model: ClassificationModel):
+    """Test evaluate with partition_column and exclude_global mode"""
+    eval_dataset_with_partition = Dataset.from_list(
+        [
+            {"value": "soup is good", "label": 0, "partition_id": "p1"},
+            {"value": "cats are cute", "label": 1, "partition_id": "p1"},
+        ]
+    )
+
+    # Evaluate with partition_column and exclude_global
+    result = partitioned_classification_model.evaluate(
+        eval_dataset_with_partition,
+        partition_column="partition_id",
+        partition_filter_mode="exclude_global",
+    )
+    assert result is not None
+    assert isinstance(result, ClassificationMetrics)
+
+
+def test_evaluate_with_partition_column_only_global(partitioned_classification_model: ClassificationModel):
+    """Test evaluate with partition_filter_mode only_global"""
+    eval_dataset_with_partition = Dataset.from_list(
+        [
+            {"value": "cats are independent animals", "label": 1, "partition_id": None},
+            {"value": "i love the beach", "label": 1, "partition_id": None},
+        ]
+    )
+
+    # Evaluate with only_global mode
+    result = partitioned_classification_model.evaluate(
+        eval_dataset_with_partition,
+        partition_column="partition_id",
+        partition_filter_mode="only_global",
+    )
+    assert result is not None
+    assert isinstance(result, ClassificationMetrics)
+
+
+def test_evaluate_with_partition_column_ignore_partitions(partitioned_classification_model: ClassificationModel):
+    """Test evaluate with partition_filter_mode ignore_partitions"""
+    eval_dataset_with_partition = Dataset.from_list(
+        [
+            {"value": "soup is good", "label": 0, "partition_id": "p1"},
+            {"value": "cats are cute", "label": 1, "partition_id": "p2"},
+        ]
+    )
+
+    # Evaluate with ignore_partitions mode
+    result = partitioned_classification_model.evaluate(
+        eval_dataset_with_partition,
+        partition_column="partition_id",
+        partition_filter_mode="ignore_partitions",
+    )
+    assert result is not None
+    assert isinstance(result, ClassificationMetrics)
+
+
+@pytest.mark.parametrize("data_type", ["dataset", "datasource"])
+def test_evaluate_with_partition_column_datasource(partitioned_classification_model: ClassificationModel, data_type):
+    """Test evaluate with partition_column on a Datasource"""
+    # Create a test datasource with partition_id column
+    eval_data_with_partition = [
+        {"value": "soup is good", "label": 0, "partition_id": "p1"},
+        {"value": "cats are cute", "label": 1, "partition_id": "p1"},
+        {"value": "homemade soup recipes", "label": 0, "partition_id": "p2"},
+        {"value": "cats purr when happy", "label": 1, "partition_id": "p2"},
+    ]
+
+    if data_type == "dataset":
+        eval_data = Dataset.from_list(eval_data_with_partition)
+        result = partitioned_classification_model.evaluate(
+            eval_data,
+            partition_column="partition_id",
+            partition_filter_mode="exclude_global",
+        )
+    else:
+        eval_datasource = Datasource.from_list("eval_datasource_with_partition", eval_data_with_partition)
+        result = partitioned_classification_model.evaluate(
+            eval_datasource,
+            partition_column="partition_id",
+            partition_filter_mode="exclude_global",
+        )
+
+    assert result is not None
+    assert isinstance(result, ClassificationMetrics)
+    assert isinstance(result.accuracy, float)
+    assert isinstance(result.f1_score, float)
+
+
+def test_predict(classification_model: ClassificationModel, label_names: list[str]):
+    predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"], batch_size=1)
+    assert len(predictions) == 2
+    assert predictions[0].prediction_id is not None
+    assert predictions[1].prediction_id is not None
+    assert predictions[0].label == 0
+    assert predictions[0].label_name == label_names[0]
+    assert 0 <= predictions[0].confidence <= 1
+    assert predictions[1].label == 1
+    assert predictions[1].label_name == label_names[1]
+    assert 0 <= predictions[1].confidence <= 1
+
+    assert predictions[0].logits is not None
+    assert predictions[1].logits is not None
+    assert len(predictions[0].logits) == 2
+    assert len(predictions[1].logits) == 2
+    assert predictions[0].logits[0] > predictions[0].logits[1]
+    assert predictions[1].logits[0] < predictions[1].logits[1]
+
+
+def test_classification_prediction_has_no_label(classification_model: ClassificationModel):
+    """Ensure optional score is None for classification predictions."""
+    prediction = classification_model.predict("Do you want to go to the beach?")
+    assert isinstance(prediction, ClassificationPrediction)
+    assert prediction.label is None
+
+
+def test_predict_disable_telemetry(classification_model: ClassificationModel, label_names: list[str]):
+    predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"], save_telemetry="off")
+    assert len(predictions) == 2
+    assert predictions[0].prediction_id is None
+    assert predictions[1].prediction_id is None
+    assert predictions[0].label == 0
+    assert predictions[0].label_name == label_names[0]
+    assert 0 <= predictions[0].confidence <= 1
+    assert predictions[1].label == 1
+    assert predictions[1].label_name == label_names[1]
+    assert 0 <= predictions[1].confidence <= 1
+
+
+def test_predict_unauthenticated(unauthenticated_client, classification_model: ClassificationModel):
+    with unauthenticated_client.use():
+        with pytest.raises(ValueError, match="Invalid API key"):
+            classification_model.predict(["Do you love soup?", "Are cats cute?"])
+
+
+def test_predict_unauthorized(unauthorized_client, classification_model: ClassificationModel):
+    with unauthorized_client.use():
+        with pytest.raises(LookupError):
+            classification_model.predict(["Do you love soup?", "Are cats cute?"])
+
+
+def test_predict_constraint_violation(readonly_memoryset: LabeledMemoryset):
+    model = ClassificationModel.create(
+        "test_model_lookup_count_too_high",
+        readonly_memoryset,
+        num_classes=2,
+        memory_lookup_count=readonly_memoryset.length + 2,
+    )
+    with pytest.raises(RuntimeError):
+        model.predict("test")
+
+
+def test_predict_with_partition_id(partitioned_classification_model: ClassificationModel, label_names: list[str]):
+    """Test predict with a specific partition_id"""
+    # Predict with partition_id p1 - should use memories from p1
+    prediction = partitioned_classification_model.predict(
+        "soup", partition_id="p1", partition_filter_mode="exclude_global"
+    )
+    assert prediction.label is not None
+    assert prediction.label_name in label_names
+    assert 0 <= prediction.confidence <= 1
+    assert prediction.logits is not None
+    assert len(prediction.logits) == 2
+
+    # Predict with partition_id p2 - should use memories from p2
+    prediction_p2 = partitioned_classification_model.predict(
+        "cats", partition_id="p2", partition_filter_mode="exclude_global"
+    )
+    assert prediction_p2.label is not None
+    assert prediction_p2.label_name in label_names
+    assert 0 <= prediction_p2.confidence <= 1
+
+
+def test_predict_with_partition_id_include_global(
+    partitioned_classification_model: ClassificationModel, label_names: list[str]
+):
+    """Test predict with partition_id and include_global mode (default)"""
+    # Predict with partition_id p1 and include_global (default) - should include both p1 and global memories
+    prediction = partitioned_classification_model.predict(
+        "soup", partition_id="p1", partition_filter_mode="include_global"
+    )
+    assert prediction.label is not None
+    assert prediction.label_name in label_names
+    assert 0 <= prediction.confidence <= 1
+
+
+def test_predict_with_partition_id_exclude_global(
+    partitioned_classification_model: ClassificationModel, label_names: list[str]
+):
+    """Test predict with partition_id and exclude_global mode"""
+    # Predict with partition_id p1 and exclude_global - should only use p1 memories
+    prediction = partitioned_classification_model.predict(
+        "soup", partition_id="p1", partition_filter_mode="exclude_global"
+    )
+    assert prediction.label is not None
+    assert prediction.label_name in label_names
+    assert 0 <= prediction.confidence <= 1
+
+
+def test_predict_with_partition_id_only_global(
+    partitioned_classification_model: ClassificationModel, label_names: list[str]
+):
+    """Test predict with partition_filter_mode only_global"""
+    # Predict with only_global mode - should only use global memories
+    prediction = partitioned_classification_model.predict("cats", partition_filter_mode="only_global")
+    assert prediction.label is not None
+    assert prediction.label_name in label_names
+    assert 0 <= prediction.confidence <= 1
+
+
+def test_predict_with_partition_id_ignore_partitions(
+    partitioned_classification_model: ClassificationModel, label_names: list[str]
+):
+    """Test predict with partition_filter_mode ignore_partitions"""
+    # Predict with ignore_partitions mode - should ignore partition filtering
+    prediction = partitioned_classification_model.predict("soup", partition_filter_mode="ignore_partitions")
+    assert prediction.label is not None
+    assert prediction.label_name in label_names
+    assert 0 <= prediction.confidence <= 1
+
+
+def test_predict_batch_with_partition_id(partitioned_classification_model: ClassificationModel, label_names: list[str]):
+    """Test batch predict with partition_id"""
+    # Batch predict with partition_id p1
+    predictions = partitioned_classification_model.predict(
+        ["soup is good", "cats are cute"],
+        partition_id="p1",
+        partition_filter_mode="exclude_global",
+    )
+    assert len(predictions) == 2
+    assert all(p.label is not None for p in predictions)
+    assert all(p.label_name in label_names for p in predictions)
+    assert all(0 <= p.confidence <= 1 for p in predictions)
+    assert all(p.logits is not None and len(p.logits) == 2 for p in predictions)
+
+
+def test_predict_with_partition_id_and_filters(
+    partitioned_classification_model: ClassificationModel, label_names: list[str]
+):
+    """Test predict with partition_id and filters"""
+    # Predict with partition_id and filters
+    prediction = partitioned_classification_model.predict(
+        "soup",
+        partition_id="p1",
+        partition_filter_mode="exclude_global",
+        filters=[("key", "==", "g1")],
+    )
+    assert prediction.label is not None
+    assert prediction.label_name in label_names
+    assert 0 <= prediction.confidence <= 1
+
+
+def test_predict_batch_with_list_of_partition_ids(
+    partitioned_classification_model: ClassificationModel, label_names: list[str]
+):
+    """Test batch predict with a list of partition_ids (one for each query input)"""
+    # Batch predict with a list of partition_ids - one for each input
+    # First input uses p1, second input uses p2
+    predictions = partitioned_classification_model.predict(
+        ["soup is good", "cats are cute"],
+        partition_id=["p1", "p2"],
+        partition_filter_mode="exclude_global",
+    )
+    assert len(predictions) == 2
+    assert all(p.label is not None for p in predictions)
+    assert all(p.label_name in label_names for p in predictions)
+    assert all(0 <= p.confidence <= 1 for p in predictions)
+    assert all(p.logits is not None and len(p.logits) == 2 for p in predictions)
+
+    # Verify that predictions were made using the correct partitions
+    # Each prediction should use memories from its respective partition
+    assert predictions[0].input_value == "soup is good"
+    assert predictions[1].input_value == "cats are cute"
+
+
+@pytest.mark.asyncio
+async def test_predict_async_with_partition_id(
+    partitioned_classification_model: ClassificationModel, label_names: list[str]
+):
+    """Test async predict with partition_id"""
+    # Async predict with partition_id p1
+    prediction = await partitioned_classification_model.apredict(
+        "soup", partition_id="p1", partition_filter_mode="exclude_global"
+    )
+    assert prediction.label is not None
+    assert prediction.label_name in label_names
+    assert 0 <= prediction.confidence <= 1
+    assert prediction.logits is not None
+    assert len(prediction.logits) == 2
+
+
+@pytest.mark.asyncio
+async def test_predict_async_batch_with_partition_id(
+    partitioned_classification_model: ClassificationModel, label_names: list[str]
+):
+    """Test async batch predict with partition_id"""
+    # Async batch predict with partition_id p1
+    predictions = await partitioned_classification_model.apredict(
+        ["soup is good", "cats are cute"],
+        partition_id="p1",
+        partition_filter_mode="exclude_global",
+    )
+    assert len(predictions) == 2
+    assert all(p.label is not None for p in predictions)
+    assert all(p.label_name in label_names for p in predictions)
+    assert all(0 <= p.confidence <= 1 for p in predictions)
+
+
+@pytest.mark.asyncio
+async def test_predict_async_batch_with_list_of_partition_ids(
+    partitioned_classification_model: ClassificationModel, label_names: list[str]
+):
+    """Test async batch predict with a list of partition_ids (one for each query input)"""
+    # Async batch predict with a list of partition_ids - one for each input
+    # First input uses p1, second input uses p2
+    predictions = await partitioned_classification_model.apredict(
+        ["soup is good", "cats are cute"],
+        partition_id=["p1", "p2"],
+        partition_filter_mode="exclude_global",
+    )
+    assert len(predictions) == 2
+    assert all(p.label is not None for p in predictions)
+    assert all(p.label_name in label_names for p in predictions)
+    assert all(0 <= p.confidence <= 1 for p in predictions)
+    assert all(p.logits is not None and len(p.logits) == 2 for p in predictions)
+
+    # Verify that predictions were made using the correct partitions
+    # Each prediction should use memories from its respective partition
+    assert predictions[0].input_value == "soup is good"
+    assert predictions[1].input_value == "cats are cute"
+
+
+def test_record_prediction_feedback(classification_model: ClassificationModel):
+    predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
+    expected_labels = [0, 1]
+    classification_model.record_feedback(
+        {
+            "prediction_id": p.prediction_id,
+            "category": "correct",
+            "value": p.label == expected_label,
+        }
+        for expected_label, p in zip(expected_labels, predictions)
+    )
+
+
+def test_record_prediction_feedback_missing_category(classification_model: ClassificationModel):
+    prediction = classification_model.predict("Do you love soup?")
+    with pytest.raises(ValueError):
+        classification_model.record_feedback({"prediction_id": prediction.prediction_id, "value": True})
+
+
+def test_record_prediction_feedback_invalid_value(classification_model: ClassificationModel):
+    prediction = classification_model.predict("Do you love soup?")
+    with pytest.raises(ValueError, match=r"Invalid input.*"):
+        classification_model.record_feedback(
+            {"prediction_id": prediction.prediction_id, "category": "correct", "value": "invalid"}
+        )
+
+
+def test_record_prediction_feedback_invalid_prediction_id(classification_model: ClassificationModel):
+    with pytest.raises(ValueError, match=r"Invalid input.*"):
+        classification_model.record_feedback({"prediction_id": "invalid", "category": "correct", "value": True})
+
+
+def test_predict_with_memoryset_override(classification_model: ClassificationModel, hf_dataset: Dataset):
+    inverted_labeled_memoryset = LabeledMemoryset.from_hf_dataset(
+        "test_memoryset_inverted_labels",
+        hf_dataset.map(lambda x: {"label": 1 if x["label"] == 0 else 0}),
+        embedding_model=PretrainedEmbeddingModel.GTE_BASE,
+    )
+    with classification_model.use_memoryset(inverted_labeled_memoryset):
+        predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
+        assert predictions[0].label == 1
+        assert predictions[1].label == 0
+
+    predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
+    assert predictions[0].label == 0
+    assert predictions[1].label == 1
+
+
+def test_predict_with_expected_labels(classification_model: ClassificationModel):
+    prediction = classification_model.predict("Do you love soup?", expected_labels=1)
+    assert prediction.expected_label == 1
+
+
+def test_predict_with_expected_labels_invalid_input(classification_model: ClassificationModel):
+    # invalid number of expected labels for batch prediction
+    with pytest.raises(ValueError, match=r"Invalid input.*"):
+        classification_model.predict(["Do you love soup?", "Are cats cute?"], expected_labels=[0])
+    # invalid label value
+    with pytest.raises(ValueError):
+        classification_model.predict("Do you love soup?", expected_labels=5)
+
+
+def test_predict_with_filters(classification_model: ClassificationModel):
+    # there are no memories with label 0 and key g1, so we force a wrong prediction
+    filtered_prediction = classification_model.predict("I love soup", filters=[("key", "==", "g2")])
+    assert filtered_prediction.label == 1
+    assert filtered_prediction.label_name == "cats"
+
+
+def test_predict_with_memoryset_update(writable_memoryset: LabeledMemoryset):
+    model = ClassificationModel.create(
+        "test_predict_with_memoryset_update",
+        writable_memoryset,
+        num_classes=2,
+        memory_lookup_count=3,
+    )
+
+    prediction = model.predict("Do you love soup?")
+    assert prediction.label == 0
+    assert prediction.label_name == "soup"
+
+    # insert new memories
+    writable_memoryset.insert(
+        [
+            {"value": "Do you love soup?", "label": 1, "key": "g1"},
+            {"value": "Do you love soup for dinner?", "label": 1, "key": "g2"},
+            {"value": "Do you love crackers?", "label": 1, "key": "g2"},
+            {"value": "Do you love broth?", "label": 1, "key": "g2"},
+            {"value": "Do you love chicken soup?", "label": 1, "key": "g2"},
+            {"value": "Do you love chicken soup for dinner?", "label": 1, "key": "g2"},
+            {"value": "Do you love chicken soup for dinner?", "label": 1, "key": "g2"},
+        ],
+    )
+    prediction = model.predict("Do you love soup?")
+    assert prediction.label == 1
+    assert prediction.label_name == "cats"
+
+    ClassificationModel.drop("test_predict_with_memoryset_update")
+
+
+def test_last_prediction_with_batch(classification_model: ClassificationModel):
+    predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
+    assert classification_model.last_prediction is not None
+    assert classification_model.last_prediction.prediction_id == predictions[-1].prediction_id
+    assert classification_model.last_prediction.input_value == "Are cats cute?"
+    assert classification_model._last_prediction_was_batch is True
+
+
+def test_last_prediction_with_single(classification_model: ClassificationModel):
+    # Test that last_prediction is updated correctly with single prediction
+    prediction = classification_model.predict("Do you love soup?")
+    assert classification_model.last_prediction is not None
+    assert classification_model.last_prediction.prediction_id == prediction.prediction_id
+    assert classification_model.last_prediction.input_value == "Do you love soup?"
+    assert classification_model._last_prediction_was_batch is False
+
+
+@skip_in_ci("We don't have Anthropic API key in CI")
+def test_explain(writable_memoryset: LabeledMemoryset):
+
+    writable_memoryset.analyze(
+        {"name": "distribution", "neighbor_counts": [1, 3]},
+        lookup_count=3,
+    )
+
+    model = ClassificationModel.create(
+        "test_model_for_explain",
+        writable_memoryset,
+        num_classes=2,
+        memory_lookup_count=3,
+        description="This is a test model for explain",
+    )
+
+    predictions = model.predict(["Do you love soup?", "Are cats cute?"])
+    assert len(predictions) == 2
+
+    try:
+        explanation = predictions[0].explanation
+        assert explanation is not None
+        assert len(explanation) > 10
+        assert "soup" in explanation.lower()
+    except Exception as e:
+        if "ANTHROPIC_API_KEY" in str(e):
+            logging.info("Skipping explanation test because ANTHROPIC_API_KEY is not set")
+        else:
+            raise e
+    finally:
+        ClassificationModel.drop("test_model_for_explain")
+
+
+@skip_in_ci("We don't have Anthropic API key in CI")
+def test_action_recommendation(writable_memoryset: LabeledMemoryset):
+    """Test getting action recommendations for predictions"""
+
+    writable_memoryset.analyze(
+        {"name": "distribution", "neighbor_counts": [1, 3]},
+        lookup_count=3,
+    )
+
+    model = ClassificationModel.create(
+        "test_model_for_action",
+        writable_memoryset,
+        num_classes=2,
+        memory_lookup_count=3,
+        description="This is a test model for action recommendations",
+    )
+
+    # Make a prediction with expected label to simulate incorrect prediction
+    prediction = model.predict("Do you love soup?", expected_labels=1)
+
+    memoryset_length = model.memoryset.length
+
+    try:
+        # Get action recommendation
+        action, rationale = prediction.recommend_action()
+
+        assert action is not None
+        assert rationale is not None
+        assert action in ["remove_duplicates", "detect_mislabels", "add_memories", "finetuning"]
+        assert len(rationale) > 10
+
+        # Test memory suggestions
+        suggestions_response = prediction.generate_memory_suggestions(num_memories=2)
+        memory_suggestions = suggestions_response.suggestions
+
+        assert memory_suggestions is not None
+        assert len(memory_suggestions) == 2
+
+        for suggestion in memory_suggestions:
+            assert isinstance(suggestion[0], str)
+            assert len(suggestion[0]) > 0
+            assert isinstance(suggestion[1], str)
+            assert suggestion[1] in model.memoryset.label_names
+
+        suggestions_response.apply()
+
+        model.memoryset.refresh()
+        assert model.memoryset.length == memoryset_length + 2
+
+    except Exception as e:
+        if "ANTHROPIC_API_KEY" in str(e):
+            logging.info("Skipping agent tests because ANTHROPIC_API_KEY is not set")
+        else:
+            raise e
+    finally:
+        ClassificationModel.drop("test_model_for_action")
+
+
+def test_predict_with_prompt(classification_model: ClassificationModel):
+    """Test that prompt parameter is properly passed through to predictions"""
+    # Test with an instruction-supporting embedding model if available
+    prediction_with_prompt = classification_model.predict(
+        "I love this product!", prompt="Represent this text for sentiment classification:"
+    )
+    prediction_without_prompt = classification_model.predict("I love this product!")
+
+    # Both should work and return valid predictions
+    assert prediction_with_prompt.label is not None
+    assert prediction_without_prompt.label is not None
+
+
+@pytest.mark.asyncio
+async def test_predict_async_single(classification_model: ClassificationModel, label_names: list[str]):
+    """Test async prediction with a single value"""
+    prediction = await classification_model.apredict("Do you love soup?")
+    assert isinstance(prediction, ClassificationPrediction)
+    assert prediction.prediction_id is not None
+    assert prediction.label == 0
+    assert prediction.label_name == label_names[0]
+    assert 0 <= prediction.confidence <= 1
+    assert prediction.logits is not None
+    assert len(prediction.logits) == 2
+
+
+@pytest.mark.asyncio
+async def test_predict_async_batch(classification_model: ClassificationModel, label_names: list[str]):
+    """Test async prediction with a batch of values"""
+    predictions = await classification_model.apredict(["Do you love soup?", "Are cats cute?"])
+    assert len(predictions) == 2
+    assert predictions[0].prediction_id is not None
+    assert predictions[1].prediction_id is not None
+    assert predictions[0].label == 0
+    assert predictions[0].label_name == label_names[0]
+    assert 0 <= predictions[0].confidence <= 1
+    assert predictions[1].label == 1
+    assert predictions[1].label_name == label_names[1]
+    assert 0 <= predictions[1].confidence <= 1
+
+
+@pytest.mark.asyncio
+async def test_predict_async_with_expected_labels(classification_model: ClassificationModel):
+    """Test async prediction with expected labels"""
+    prediction = await classification_model.apredict("Do you love soup?", expected_labels=1)
+    assert prediction.expected_label == 1
+
+
+@pytest.mark.asyncio
+async def test_predict_async_disable_telemetry(classification_model: ClassificationModel, label_names: list[str]):
+    """Test async prediction with telemetry disabled"""
+    predictions = await classification_model.apredict(["Do you love soup?", "Are cats cute?"], save_telemetry="off")
+    assert len(predictions) == 2
+    assert predictions[0].prediction_id is None
+    assert predictions[1].prediction_id is None
+    assert predictions[0].label == 0
+    assert predictions[0].label_name == label_names[0]
+    assert 0 <= predictions[0].confidence <= 1
+    assert predictions[1].label == 1
+    assert predictions[1].label_name == label_names[1]
+    assert 0 <= predictions[1].confidence <= 1
+
+
+@pytest.mark.asyncio
+async def test_predict_async_with_filters(classification_model: ClassificationModel):
+    """Test async prediction with filters"""
+    # there are no memories with label 0 and key g2, so we force a wrong prediction
+    filtered_prediction = await classification_model.apredict("I love soup", filters=[("key", "==", "g2")])
+    assert filtered_prediction.label == 1
+    assert filtered_prediction.label_name == "cats"