orca-sdk 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186) hide show
  1. orca_sdk/__init__.py +10 -4
  2. orca_sdk/_shared/__init__.py +10 -0
  3. orca_sdk/_shared/metrics.py +393 -0
  4. orca_sdk/_shared/metrics_test.py +273 -0
  5. orca_sdk/_utils/analysis_ui.py +12 -10
  6. orca_sdk/_utils/analysis_ui_style.css +0 -3
  7. orca_sdk/_utils/auth.py +31 -29
  8. orca_sdk/_utils/data_parsing.py +28 -2
  9. orca_sdk/_utils/data_parsing_test.py +15 -15
  10. orca_sdk/_utils/pagination.py +126 -0
  11. orca_sdk/_utils/pagination_test.py +132 -0
  12. orca_sdk/_utils/prediction_result_ui.py +67 -21
  13. orca_sdk/_utils/tqdm_file_reader.py +12 -0
  14. orca_sdk/_utils/value_parser.py +45 -0
  15. orca_sdk/_utils/value_parser_test.py +39 -0
  16. orca_sdk/async_client.py +3795 -0
  17. orca_sdk/classification_model.py +601 -129
  18. orca_sdk/classification_model_test.py +415 -117
  19. orca_sdk/client.py +3787 -0
  20. orca_sdk/conftest.py +184 -38
  21. orca_sdk/credentials.py +162 -20
  22. orca_sdk/credentials_test.py +100 -16
  23. orca_sdk/datasource.py +268 -68
  24. orca_sdk/datasource_test.py +266 -18
  25. orca_sdk/embedding_model.py +434 -82
  26. orca_sdk/embedding_model_test.py +66 -33
  27. orca_sdk/job.py +343 -0
  28. orca_sdk/job_test.py +108 -0
  29. orca_sdk/memoryset.py +1690 -324
  30. orca_sdk/memoryset_test.py +456 -119
  31. orca_sdk/regression_model.py +694 -0
  32. orca_sdk/regression_model_test.py +378 -0
  33. orca_sdk/telemetry.py +460 -143
  34. orca_sdk/telemetry_test.py +43 -24
  35. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.3.dist-info}/METADATA +34 -16
  36. orca_sdk-0.1.3.dist-info/RECORD +41 -0
  37. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.3.dist-info}/WHEEL +1 -1
  38. orca_sdk/_generated_api_client/__init__.py +0 -3
  39. orca_sdk/_generated_api_client/api/__init__.py +0 -193
  40. orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
  41. orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
  42. orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
  43. orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
  44. orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
  45. orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
  46. orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
  47. orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
  48. orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
  49. orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
  50. orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
  51. orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
  52. orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
  53. orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
  54. orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
  55. orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
  56. orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
  57. orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
  58. orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
  59. orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
  60. orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
  61. orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
  62. orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
  63. orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
  64. orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
  65. orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
  66. orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
  67. orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
  68. orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
  69. orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
  70. orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
  71. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
  72. orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
  73. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
  74. orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
  75. orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
  76. orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
  77. orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
  78. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
  79. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
  80. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
  81. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
  82. orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
  83. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
  84. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
  85. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
  86. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
  87. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
  88. orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
  89. orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
  90. orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
  91. orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
  92. orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
  93. orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
  94. orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
  95. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
  96. orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
  97. orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
  98. orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
  99. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
  100. orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
  101. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
  102. orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
  103. orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
  104. orca_sdk/_generated_api_client/client.py +0 -216
  105. orca_sdk/_generated_api_client/errors.py +0 -38
  106. orca_sdk/_generated_api_client/models/__init__.py +0 -159
  107. orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
  108. orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
  109. orca_sdk/_generated_api_client/models/base_model.py +0 -55
  110. orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
  111. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
  112. orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
  113. orca_sdk/_generated_api_client/models/column_info.py +0 -114
  114. orca_sdk/_generated_api_client/models/column_type.py +0 -14
  115. orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
  116. orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
  117. orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
  118. orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
  119. orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
  120. orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
  121. orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
  122. orca_sdk/_generated_api_client/models/embed_request.py +0 -127
  123. orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
  124. orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
  125. orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
  126. orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
  127. orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
  128. orca_sdk/_generated_api_client/models/filter_item.py +0 -231
  129. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
  130. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
  131. orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
  132. orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
  133. orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
  134. orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
  135. orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
  136. orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
  137. orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
  138. orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
  139. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
  140. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
  141. orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
  142. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
  143. orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
  144. orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
  145. orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
  146. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
  147. orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
  148. orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
  149. orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
  150. orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
  151. orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
  152. orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
  153. orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
  154. orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
  155. orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
  156. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
  157. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
  158. orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
  159. orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
  160. orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
  161. orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
  162. orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
  163. orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
  164. orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
  165. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
  166. orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
  167. orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
  168. orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
  169. orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
  170. orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
  171. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
  172. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
  173. orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
  174. orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
  175. orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
  176. orca_sdk/_generated_api_client/models/task.py +0 -198
  177. orca_sdk/_generated_api_client/models/task_status.py +0 -14
  178. orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
  179. orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
  180. orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
  181. orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
  182. orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
  183. orca_sdk/_generated_api_client/py.typed +0 -1
  184. orca_sdk/_generated_api_client/types.py +0 -56
  185. orca_sdk/_utils/task.py +0 -73
  186. orca_sdk-0.1.1.dist-info/RECORD +0 -175
@@ -1,68 +1,78 @@
1
+ import logging
1
2
  from uuid import uuid4
2
3
 
4
+ import numpy as np
3
5
  import pytest
4
- from datasets.arrow_dataset import Dataset
6
+ from datasets import Dataset
5
7
 
6
- from .classification_model import ClassificationModel
8
+ from .classification_model import ClassificationMetrics, ClassificationModel
9
+ from .conftest import skip_in_ci
7
10
  from .datasource import Datasource
8
11
  from .embedding_model import PretrainedEmbeddingModel
9
12
  from .memoryset import LabeledMemoryset
13
+ from .telemetry import ClassificationPrediction
10
14
 
11
15
 
12
- def test_create_model(model: ClassificationModel, memoryset: LabeledMemoryset):
13
- assert model is not None
14
- assert model.name == "test_model"
15
- assert model.memoryset == memoryset
16
- assert model.num_classes == 2
17
- assert model.memory_lookup_count == 3
16
+ def test_create_model(classification_model: ClassificationModel, readonly_memoryset: LabeledMemoryset):
17
+ assert classification_model is not None
18
+ assert classification_model.name == "test_classification_model"
19
+ assert classification_model.memoryset == readonly_memoryset
20
+ assert classification_model.num_classes == 2
21
+ assert classification_model.memory_lookup_count == 3
18
22
 
19
23
 
20
- def test_create_model_already_exists_error(memoryset, model: ClassificationModel):
24
+ def test_create_model_already_exists_error(readonly_memoryset, classification_model):
21
25
  with pytest.raises(ValueError):
22
- ClassificationModel.create("test_model", memoryset)
26
+ ClassificationModel.create("test_classification_model", readonly_memoryset)
23
27
  with pytest.raises(ValueError):
24
- ClassificationModel.create("test_model", memoryset, if_exists="error")
28
+ ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="error")
25
29
 
26
30
 
27
- def test_create_model_already_exists_return(memoryset, model: ClassificationModel):
31
+ def test_create_model_already_exists_return(readonly_memoryset, classification_model):
28
32
  with pytest.raises(ValueError):
29
- ClassificationModel.create("test_model", memoryset, if_exists="open", head_type="MMOE")
33
+ ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="open", head_type="MMOE")
30
34
 
31
35
  with pytest.raises(ValueError):
32
- ClassificationModel.create("test_model", memoryset, if_exists="open", memory_lookup_count=37)
36
+ ClassificationModel.create(
37
+ "test_classification_model", readonly_memoryset, if_exists="open", memory_lookup_count=37
38
+ )
33
39
 
34
40
  with pytest.raises(ValueError):
35
- ClassificationModel.create("test_model", memoryset, if_exists="open", num_classes=19)
41
+ ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="open", num_classes=19)
36
42
 
37
43
  with pytest.raises(ValueError):
38
- ClassificationModel.create("test_model", memoryset, if_exists="open", min_memory_weight=0.77)
44
+ ClassificationModel.create(
45
+ "test_classification_model", readonly_memoryset, if_exists="open", min_memory_weight=0.77
46
+ )
39
47
 
40
- new_model = ClassificationModel.create("test_model", memoryset, if_exists="open")
48
+ new_model = ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="open")
41
49
  assert new_model is not None
42
- assert new_model.name == "test_model"
43
- assert new_model.memoryset == memoryset
50
+ assert new_model.name == "test_classification_model"
51
+ assert new_model.memoryset == readonly_memoryset
44
52
  assert new_model.num_classes == 2
45
53
  assert new_model.memory_lookup_count == 3
46
54
 
47
55
 
48
- def test_create_model_unauthenticated(unauthenticated, memoryset: LabeledMemoryset):
49
- with pytest.raises(ValueError, match="Invalid API key"):
50
- ClassificationModel.create("test_model", memoryset)
56
+ def test_create_model_unauthenticated(unauthenticated_client, readonly_memoryset: LabeledMemoryset):
57
+ with unauthenticated_client.use():
58
+ with pytest.raises(ValueError, match="Invalid API key"):
59
+ ClassificationModel.create("test_model", readonly_memoryset)
51
60
 
52
61
 
53
- def test_get_model(model: ClassificationModel):
54
- fetched_model = ClassificationModel.open(model.name)
62
+ def test_get_model(classification_model: ClassificationModel):
63
+ fetched_model = ClassificationModel.open(classification_model.name)
55
64
  assert fetched_model is not None
56
- assert fetched_model.id == model.id
57
- assert fetched_model.name == model.name
65
+ assert fetched_model.id == classification_model.id
66
+ assert fetched_model.name == classification_model.name
58
67
  assert fetched_model.num_classes == 2
59
68
  assert fetched_model.memory_lookup_count == 3
60
- assert fetched_model == model
69
+ assert fetched_model == classification_model
61
70
 
62
71
 
63
- def test_get_model_unauthenticated(unauthenticated):
64
- with pytest.raises(ValueError, match="Invalid API key"):
65
- ClassificationModel.open("test_model")
72
+ def test_get_model_unauthenticated(unauthenticated_client):
73
+ with unauthenticated_client.use():
74
+ with pytest.raises(ValueError, match="Invalid API key"):
75
+ ClassificationModel.open("test_model")
66
76
 
67
77
 
68
78
  def test_get_model_invalid_input():
@@ -75,37 +85,61 @@ def test_get_model_not_found():
75
85
  ClassificationModel.open(str(uuid4()))
76
86
 
77
87
 
78
- def test_get_model_unauthorized(unauthorized, model: ClassificationModel):
79
- with pytest.raises(LookupError):
80
- ClassificationModel.open(model.name)
88
+ def test_get_model_unauthorized(unauthorized_client, classification_model: ClassificationModel):
89
+ with unauthorized_client.use():
90
+ with pytest.raises(LookupError):
91
+ ClassificationModel.open(classification_model.name)
81
92
 
82
93
 
83
- def test_list_models(model: ClassificationModel):
94
+ def test_list_models(classification_model: ClassificationModel):
84
95
  models = ClassificationModel.all()
85
96
  assert len(models) > 0
86
97
  assert any(model.name == model.name for model in models)
87
98
 
88
99
 
89
- def test_list_models_unauthenticated(unauthenticated):
90
- with pytest.raises(ValueError, match="Invalid API key"):
91
- ClassificationModel.all()
100
+ def test_list_models_unauthenticated(unauthenticated_client):
101
+ with unauthenticated_client.use():
102
+ with pytest.raises(ValueError, match="Invalid API key"):
103
+ ClassificationModel.all()
104
+
105
+
106
+ def test_list_models_unauthorized(unauthorized_client, classification_model: ClassificationModel):
107
+ with unauthorized_client.use():
108
+ assert ClassificationModel.all() == []
109
+
110
+
111
+ def test_update_model_attributes(classification_model: ClassificationModel):
112
+ classification_model.description = "New description"
113
+ assert classification_model.description == "New description"
114
+
115
+ classification_model.set(description=None)
116
+ assert classification_model.description is None
117
+
118
+ classification_model.set(locked=True)
119
+ assert classification_model.locked is True
92
120
 
121
+ classification_model.set(locked=False)
122
+ assert classification_model.locked is False
93
123
 
94
- def test_list_models_unauthorized(unauthorized, model: ClassificationModel):
95
- assert ClassificationModel.all() == []
124
+ classification_model.lock()
125
+ assert classification_model.locked is True
96
126
 
127
+ classification_model.unlock()
128
+ assert classification_model.locked is False
97
129
 
98
- def test_delete_model(memoryset: LabeledMemoryset):
99
- ClassificationModel.create("model_to_delete", LabeledMemoryset.open(memoryset.name))
130
+
131
+ def test_delete_model(readonly_memoryset: LabeledMemoryset):
132
+ ClassificationModel.create("model_to_delete", LabeledMemoryset.open(readonly_memoryset.name))
100
133
  assert ClassificationModel.open("model_to_delete")
101
134
  ClassificationModel.drop("model_to_delete")
102
135
  with pytest.raises(LookupError):
103
136
  ClassificationModel.open("model_to_delete")
104
137
 
105
138
 
106
- def test_delete_model_unauthenticated(unauthenticated, model: ClassificationModel):
107
- with pytest.raises(ValueError, match="Invalid API key"):
108
- ClassificationModel.drop(model.name)
139
+ def test_delete_model_unauthenticated(unauthenticated_client, classification_model: ClassificationModel):
140
+ with unauthenticated_client.use():
141
+ with pytest.raises(ValueError, match="Invalid API key"):
142
+ ClassificationModel.drop(classification_model.name)
109
143
 
110
144
 
111
145
  def test_delete_model_not_found():
@@ -115,53 +149,83 @@ def test_delete_model_not_found():
115
149
  ClassificationModel.drop(str(uuid4()), if_not_exists="ignore")
116
150
 
117
151
 
118
- def test_delete_model_unauthorized(unauthorized, model: ClassificationModel):
119
- with pytest.raises(LookupError):
120
- ClassificationModel.drop(model.name)
152
+ def test_delete_model_unauthorized(unauthorized_client, classification_model: ClassificationModel):
153
+ with unauthorized_client.use():
154
+ with pytest.raises(LookupError):
155
+ ClassificationModel.drop(classification_model.name)
121
156
 
122
157
 
123
- @pytest.mark.flaky
124
158
  def test_delete_memoryset_before_model_constraint_violation(hf_dataset):
125
- memoryset = LabeledMemoryset.from_hf_dataset("test_memoryset_delete_before_model", hf_dataset, value_column="text")
159
+ memoryset = LabeledMemoryset.from_hf_dataset("test_memoryset_delete_before_model", hf_dataset)
126
160
  ClassificationModel.create("test_model_delete_before_memoryset", memoryset)
127
161
  with pytest.raises(RuntimeError):
128
162
  LabeledMemoryset.drop(memoryset.id)
129
163
 
130
164
 
131
- def test_evaluate(model):
132
- eval_datasource = Datasource.from_list(
133
- "eval_datasource",
134
- [
135
- {"text": "chicken noodle soup is the best", "label": 1},
136
- {"text": "cats are cute", "label": 0},
137
- {"text": "soup is great for the winter", "label": 0},
138
- {"text": "i love cats", "label": 1},
139
- ],
165
+ @pytest.mark.parametrize("data_type", ["dataset", "datasource"])
166
+ def test_evaluate(classification_model, eval_datasource: Datasource, eval_dataset: Dataset, data_type):
167
+ result = (
168
+ classification_model.evaluate(eval_dataset)
169
+ if data_type == "dataset"
170
+ else classification_model.evaluate(eval_datasource)
140
171
  )
141
- result = model.evaluate(eval_datasource, value_column="text")
172
+
142
173
  assert result is not None
143
- assert isinstance(result["accuracy"], float)
144
- assert isinstance(result["f1_score"], float)
145
- assert isinstance(result["loss"], float)
146
-
147
-
148
- def test_evaluate_with_telemetry(model):
149
- samples = [
150
- {"text": "chicken noodle soup is the best", "label": 1},
151
- {"text": "cats are cute", "label": 0},
152
- ]
153
- eval_datasource = Datasource.from_list("eval_datasource_2", samples)
154
- result = model.evaluate(eval_datasource, value_column="text", record_predictions=True, tags={"test"})
174
+ assert isinstance(result, ClassificationMetrics)
175
+
176
+ assert isinstance(result.accuracy, float)
177
+ assert np.allclose(result.accuracy, 0.5)
178
+ assert isinstance(result.f1_score, float)
179
+ assert np.allclose(result.f1_score, 0.5)
180
+ assert isinstance(result.loss, float)
181
+
182
+ assert isinstance(result.anomaly_score_mean, float)
183
+ assert isinstance(result.anomaly_score_median, float)
184
+ assert isinstance(result.anomaly_score_variance, float)
185
+ assert -1.0 <= result.anomaly_score_mean <= 1.0
186
+ assert -1.0 <= result.anomaly_score_median <= 1.0
187
+ assert -1.0 <= result.anomaly_score_variance <= 1.0
188
+
189
+ assert result.pr_auc is not None
190
+ assert np.allclose(result.pr_auc, 0.75)
191
+ assert result.pr_curve is not None
192
+ assert np.allclose(result.pr_curve["thresholds"], [0.0, 0.0, 0.8155114054679871, 0.834095299243927])
193
+ assert np.allclose(result.pr_curve["precisions"], [0.5, 0.5, 1.0, 1.0])
194
+ assert np.allclose(result.pr_curve["recalls"], [1.0, 0.5, 0.5, 0.0])
195
+
196
+ assert result.roc_auc is not None
197
+ assert np.allclose(result.roc_auc, 0.625)
198
+ assert result.roc_curve is not None
199
+ assert np.allclose(result.roc_curve["thresholds"], [0.0, 0.8155114054679871, 0.834095299243927, 1.0])
200
+ assert np.allclose(result.roc_curve["false_positive_rates"], [1.0, 0.5, 0.0, 0.0])
201
+ assert np.allclose(result.roc_curve["true_positive_rates"], [1.0, 0.5, 0.5, 0.0])
202
+
203
+
204
+ def test_evaluate_datasource_with_nones_raises_error(classification_model: ClassificationModel, datasource: Datasource):
205
+ with pytest.raises(ValueError):
206
+ classification_model.evaluate(datasource, record_predictions=True, tags={"test"})
207
+
208
+
209
+ def test_evaluate_dataset_with_nones_raises_error(classification_model: ClassificationModel, hf_dataset: Dataset):
210
+ with pytest.raises(ValueError):
211
+ classification_model.evaluate(hf_dataset, record_predictions=True, tags={"test"})
212
+
213
+
214
+ def test_evaluate_with_telemetry(classification_model: ClassificationModel, eval_dataset: Dataset):
215
+ result = classification_model.evaluate(eval_dataset, record_predictions=True, tags={"test"})
155
216
  assert result is not None
156
- predictions = model.predictions(tag="test")
157
- assert len(predictions) == 2
217
+ assert isinstance(result, ClassificationMetrics)
218
+ predictions = classification_model.predictions(tag="test")
219
+ assert len(predictions) == 4
158
220
  assert all(p.tags == {"test"} for p in predictions)
159
- assert all(p.expected_label == s["label"] for p, s in zip(predictions, samples))
221
+ assert all(p.expected_label == l for p, l in zip(predictions, eval_dataset["label"]))
160
222
 
161
223
 
162
- def test_predict(model: ClassificationModel, label_names: list[str]):
163
- predictions = model.predict(["Do you love soup?", "Are cats cute?"])
224
+ def test_predict(classification_model: ClassificationModel, label_names: list[str]):
225
+ predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
164
226
  assert len(predictions) == 2
227
+ assert predictions[0].prediction_id is not None
228
+ assert predictions[1].prediction_id is not None
165
229
  assert predictions[0].label == 0
166
230
  assert predictions[0].label_name == label_names[0]
167
231
  assert 0 <= predictions[0].confidence <= 1
@@ -169,29 +233,61 @@ def test_predict(model: ClassificationModel, label_names: list[str]):
169
233
  assert predictions[1].label_name == label_names[1]
170
234
  assert 0 <= predictions[1].confidence <= 1
171
235
 
236
+ assert predictions[0].logits is not None
237
+ assert predictions[1].logits is not None
238
+ assert len(predictions[0].logits) == 2
239
+ assert len(predictions[1].logits) == 2
240
+ assert predictions[0].logits[0] > predictions[0].logits[1]
241
+ assert predictions[1].logits[0] < predictions[1].logits[1]
172
242
 
173
- def test_predict_unauthenticated(unauthenticated, model: ClassificationModel):
174
- with pytest.raises(ValueError, match="Invalid API key"):
175
- model.predict(["Do you love soup?", "Are cats cute?"])
176
243
 
244
+ def test_classification_prediction_has_no_label(classification_model: ClassificationModel):
245
+ """Ensure the optional regression score is unset for classification predictions (NOTE(review): body asserts `label is None`, which contradicts test_predict — likely should assert `score`; verify)."""
246
+ prediction = classification_model.predict("Do you want to go to the beach?")
247
+ assert isinstance(prediction, ClassificationPrediction)
248
+ assert prediction.label is None
177
249
 
178
- def test_predict_unauthorized(unauthorized, model: ClassificationModel):
179
- with pytest.raises(LookupError):
180
- model.predict(["Do you love soup?", "Are cats cute?"])
250
+
251
+ def test_predict_disable_telemetry(classification_model: ClassificationModel, label_names: list[str]):
252
+ predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"], save_telemetry="off")
253
+ assert len(predictions) == 2
254
+ assert predictions[0].prediction_id is None
255
+ assert predictions[1].prediction_id is None
256
+ assert predictions[0].label == 0
257
+ assert predictions[0].label_name == label_names[0]
258
+ assert 0 <= predictions[0].confidence <= 1
259
+ assert predictions[1].label == 1
260
+ assert predictions[1].label_name == label_names[1]
261
+ assert 0 <= predictions[1].confidence <= 1
262
+
263
+
264
+ def test_predict_unauthenticated(unauthenticated_client, classification_model: ClassificationModel):
265
+ with unauthenticated_client.use():
266
+ with pytest.raises(ValueError, match="Invalid API key"):
267
+ classification_model.predict(["Do you love soup?", "Are cats cute?"])
268
+
269
+
270
+ def test_predict_unauthorized(unauthorized_client, classification_model: ClassificationModel):
271
+ with unauthorized_client.use():
272
+ with pytest.raises(LookupError):
273
+ classification_model.predict(["Do you love soup?", "Are cats cute?"])
181
274
 
182
275
 
183
- def test_predict_constraint_violation(memoryset: LabeledMemoryset):
276
+ def test_predict_constraint_violation(readonly_memoryset: LabeledMemoryset):
184
277
  model = ClassificationModel.create(
185
- "test_model_lookup_count_too_high", memoryset, num_classes=2, memory_lookup_count=memoryset.length + 2
278
+ "test_model_lookup_count_too_high",
279
+ readonly_memoryset,
280
+ num_classes=2,
281
+ memory_lookup_count=readonly_memoryset.length + 2,
186
282
  )
187
283
  with pytest.raises(RuntimeError):
188
284
  model.predict("test")
189
285
 
190
286
 
191
- def test_record_prediction_feedback(model: ClassificationModel):
192
- predictions = model.predict(["Do you love soup?", "Are cats cute?"])
287
+ def test_record_prediction_feedback(classification_model: ClassificationModel):
288
+ predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
193
289
  expected_labels = [0, 1]
194
- model.record_feedback(
290
+ classification_model.record_feedback(
195
291
  {
196
292
  "prediction_id": p.prediction_id,
197
293
  "category": "correct",
@@ -201,66 +297,268 @@ def test_record_prediction_feedback(model: ClassificationModel):
201
297
  )
202
298
 
203
299
 
204
- def test_record_prediction_feedback_missing_category(model: ClassificationModel):
205
- prediction = model.predict("Do you love soup?")
300
+ def test_record_prediction_feedback_missing_category(classification_model: ClassificationModel):
301
+ prediction = classification_model.predict("Do you love soup?")
206
302
  with pytest.raises(ValueError):
207
- model.record_feedback({"prediction_id": prediction.prediction_id, "value": True})
303
+ classification_model.record_feedback({"prediction_id": prediction.prediction_id, "value": True})
208
304
 
209
305
 
210
- def test_record_prediction_feedback_invalid_value(model: ClassificationModel):
211
- prediction = model.predict("Do you love soup?")
306
+ def test_record_prediction_feedback_invalid_value(classification_model: ClassificationModel):
307
+ prediction = classification_model.predict("Do you love soup?")
212
308
  with pytest.raises(ValueError, match=r"Invalid input.*"):
213
- model.record_feedback({"prediction_id": prediction.prediction_id, "category": "correct", "value": "invalid"})
309
+ classification_model.record_feedback(
310
+ {"prediction_id": prediction.prediction_id, "category": "correct", "value": "invalid"}
311
+ )
214
312
 
215
313
 
216
- def test_record_prediction_feedback_invalid_prediction_id(model: ClassificationModel):
314
+ def test_record_prediction_feedback_invalid_prediction_id(classification_model: ClassificationModel):
217
315
  with pytest.raises(ValueError, match=r"Invalid input.*"):
218
- model.record_feedback({"prediction_id": "invalid", "category": "correct", "value": True})
316
+ classification_model.record_feedback({"prediction_id": "invalid", "category": "correct", "value": True})
219
317
 
220
318
 
221
- def test_predict_with_memoryset_override(model: ClassificationModel, hf_dataset: Dataset):
319
+ def test_predict_with_memoryset_override(classification_model: ClassificationModel, hf_dataset: Dataset):
222
320
  inverted_labeled_memoryset = LabeledMemoryset.from_hf_dataset(
223
321
  "test_memoryset_inverted_labels",
224
322
  hf_dataset.map(lambda x: {"label": 1 if x["label"] == 0 else 0}),
225
- value_column="text",
226
323
  embedding_model=PretrainedEmbeddingModel.GTE_BASE,
227
324
  )
228
- with model.use_memoryset(inverted_labeled_memoryset):
229
- predictions = model.predict(["Do you love soup?", "Are cats cute?"])
325
+ with classification_model.use_memoryset(inverted_labeled_memoryset):
326
+ predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
230
327
  assert predictions[0].label == 1
231
328
  assert predictions[1].label == 0
232
329
 
233
- predictions = model.predict(["Do you love soup?", "Are cats cute?"])
330
+ predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
234
331
  assert predictions[0].label == 0
235
332
  assert predictions[1].label == 1
236
333
 
237
334
 
238
- def test_predict_with_expected_labels(model: ClassificationModel):
239
- prediction = model.predict("Do you love soup?", expected_labels=1)
335
+ def test_predict_with_expected_labels(classification_model: ClassificationModel):
336
+ prediction = classification_model.predict("Do you love soup?", expected_labels=1)
240
337
  assert prediction.expected_label == 1
241
338
 
242
339
 
243
- def test_predict_with_expected_labels_invalid_input(model: ClassificationModel):
340
+ def test_predict_with_expected_labels_invalid_input(classification_model: ClassificationModel):
244
341
  # invalid number of expected labels for batch prediction
245
342
  with pytest.raises(ValueError, match=r"Invalid input.*"):
246
- model.predict(["Do you love soup?", "Are cats cute?"], expected_labels=[0])
343
+ classification_model.predict(["Do you love soup?", "Are cats cute?"], expected_labels=[0])
247
344
  # invalid label value
248
345
  with pytest.raises(ValueError):
249
- model.predict("Do you love soup?", expected_labels=5)
346
+ classification_model.predict("Do you love soup?", expected_labels=5)
250
347
 
251
348
 
252
- def test_last_prediction_with_batch(model: ClassificationModel):
253
- predictions = model.predict(["Do you love soup?", "Are cats cute?"])
254
- assert model.last_prediction is not None
255
- assert model.last_prediction.prediction_id == predictions[-1].prediction_id
256
- assert model.last_prediction.input_value == "Are cats cute?"
257
- assert model._last_prediction_was_batch is True
349
+ def test_predict_with_filters(classification_model: ClassificationModel):
350
+ # there are no memories with label 0 and key g2, so we force a wrong prediction
351
+ filtered_prediction = classification_model.predict("I love soup", filters=[("key", "==", "g2")])
352
+ assert filtered_prediction.label == 1
353
+ assert filtered_prediction.label_name == "cats"
258
354
 
259
355
 
260
- def test_last_prediction_with_single(model: ClassificationModel):
261
- # Test that last_prediction is updated correctly with single prediction
356
+ def test_predict_with_memoryset_update(writable_memoryset: LabeledMemoryset):
357
+ model = ClassificationModel.create(
358
+ "test_predict_with_memoryset_update",
359
+ writable_memoryset,
360
+ num_classes=2,
361
+ memory_lookup_count=3,
362
+ )
363
+
262
364
  prediction = model.predict("Do you love soup?")
263
- assert model.last_prediction is not None
264
- assert model.last_prediction.prediction_id == prediction.prediction_id
265
- assert model.last_prediction.input_value == "Do you love soup?"
266
- assert model._last_prediction_was_batch is False
365
+ assert prediction.label == 0
366
+ assert prediction.label_name == "soup"
367
+
368
+ # insert new memories
369
+ writable_memoryset.insert(
370
+ [
371
+ {"value": "Do you love soup?", "label": 1, "key": "g1"},
372
+ {"value": "Do you love soup for dinner?", "label": 1, "key": "g2"},
373
+ {"value": "Do you love crackers?", "label": 1, "key": "g2"},
374
+ {"value": "Do you love broth?", "label": 1, "key": "g2"},
375
+ {"value": "Do you love chicken soup?", "label": 1, "key": "g2"},
376
+ {"value": "Do you love chicken soup for dinner?", "label": 1, "key": "g2"},
377
+ {"value": "Do you love chicken soup for dinner?", "label": 1, "key": "g2"},
378
+ ],
379
+ )
380
+ prediction = model.predict("Do you love soup?")
381
+ assert prediction.label == 1
382
+ assert prediction.label_name == "cats"
383
+
384
+ ClassificationModel.drop("test_predict_with_memoryset_update")
385
+
386
+
387
+ def test_last_prediction_with_batch(classification_model: ClassificationModel):
388
+ predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
389
+ assert classification_model.last_prediction is not None
390
+ assert classification_model.last_prediction.prediction_id == predictions[-1].prediction_id
391
+ assert classification_model.last_prediction.input_value == "Are cats cute?"
392
+ assert classification_model._last_prediction_was_batch is True
393
+
394
+
395
+ def test_last_prediction_with_single(classification_model: ClassificationModel):
396
+ # Test that last_prediction is updated correctly with single prediction
397
+ prediction = classification_model.predict("Do you love soup?")
398
+ assert classification_model.last_prediction is not None
399
+ assert classification_model.last_prediction.prediction_id == prediction.prediction_id
400
+ assert classification_model.last_prediction.input_value == "Do you love soup?"
401
+ assert classification_model._last_prediction_was_batch is False
402
+
403
+
404
+ @skip_in_ci("We don't have Anthropic API key in CI")
405
+ def test_explain(writable_memoryset: LabeledMemoryset):
406
+
407
+ writable_memoryset.analyze(
408
+ {"name": "distribution", "neighbor_counts": [1, 3]},
409
+ lookup_count=3,
410
+ )
411
+
412
+ model = ClassificationModel.create(
413
+ "test_model_for_explain",
414
+ writable_memoryset,
415
+ num_classes=2,
416
+ memory_lookup_count=3,
417
+ description="This is a test model for explain",
418
+ )
419
+
420
+ predictions = model.predict(["Do you love soup?", "Are cats cute?"])
421
+ assert len(predictions) == 2
422
+
423
+ try:
424
+ explanation = predictions[0].explanation
425
+ assert explanation is not None
426
+ assert len(explanation) > 10
427
+ assert "soup" in explanation.lower()
428
+ except Exception as e:
429
+ if "ANTHROPIC_API_KEY" in str(e):
430
+ logging.info("Skipping explanation test because ANTHROPIC_API_KEY is not set")
431
+ else:
432
+ raise e
433
+ finally:
434
+ ClassificationModel.drop("test_model_for_explain")
435
+
436
+
437
+ @skip_in_ci("We don't have Anthropic API key in CI")
438
+ def test_action_recommendation(writable_memoryset: LabeledMemoryset):
439
+ """Test getting action recommendations for predictions"""
440
+
441
+ writable_memoryset.analyze(
442
+ {"name": "distribution", "neighbor_counts": [1, 3]},
443
+ lookup_count=3,
444
+ )
445
+
446
+ model = ClassificationModel.create(
447
+ "test_model_for_action",
448
+ writable_memoryset,
449
+ num_classes=2,
450
+ memory_lookup_count=3,
451
+ description="This is a test model for action recommendations",
452
+ )
453
+
454
+ # Make a prediction with expected label to simulate incorrect prediction
455
+ prediction = model.predict("Do you love soup?", expected_labels=1)
456
+
457
+ memoryset_length = model.memoryset.length
458
+
459
+ try:
460
+ # Get action recommendation
461
+ action, rationale = prediction.recommend_action()
462
+
463
+ assert action is not None
464
+ assert rationale is not None
465
+ assert action in ["remove_duplicates", "detect_mislabels", "add_memories", "finetuning"]
466
+ assert len(rationale) > 10
467
+
468
+ # Test memory suggestions
469
+ suggestions_response = prediction.generate_memory_suggestions(num_memories=2)
470
+ memory_suggestions = suggestions_response.suggestions
471
+
472
+ assert memory_suggestions is not None
473
+ assert len(memory_suggestions) == 2
474
+
475
+ for suggestion in memory_suggestions:
476
+ assert isinstance(suggestion[0], str)
477
+ assert len(suggestion[0]) > 0
478
+ assert isinstance(suggestion[1], str)
479
+ assert suggestion[1] in model.memoryset.label_names
480
+
481
+ suggestions_response.apply()
482
+
483
+ model.memoryset.refresh()
484
+ assert model.memoryset.length == memoryset_length + 2
485
+
486
+ except Exception as e:
487
+ if "ANTHROPIC_API_KEY" in str(e):
488
+ logging.info("Skipping agent tests because ANTHROPIC_API_KEY is not set")
489
+ else:
490
+ raise e
491
+ finally:
492
+ ClassificationModel.drop("test_model_for_action")
493
+
494
+
495
+ def test_predict_with_prompt(classification_model: ClassificationModel):
496
+ """Test that prompt parameter is properly passed through to predictions"""
497
+ # Test with an instruction-supporting embedding model if available
498
+ prediction_with_prompt = classification_model.predict(
499
+ "I love this product!", prompt="Represent this text for sentiment classification:"
500
+ )
501
+ prediction_without_prompt = classification_model.predict("I love this product!")
502
+
503
+ # Both should work and return valid predictions
504
+ assert prediction_with_prompt.label is not None
505
+ assert prediction_without_prompt.label is not None
506
+
507
+
508
+ @pytest.mark.asyncio
509
+ async def test_predict_async_single(classification_model: ClassificationModel, label_names: list[str]):
510
+ """Test async prediction with a single value"""
511
+ prediction = await classification_model.apredict("Do you love soup?")
512
+ assert isinstance(prediction, ClassificationPrediction)
513
+ assert prediction.prediction_id is not None
514
+ assert prediction.label == 0
515
+ assert prediction.label_name == label_names[0]
516
+ assert 0 <= prediction.confidence <= 1
517
+ assert prediction.logits is not None
518
+ assert len(prediction.logits) == 2
519
+
520
+
521
+ @pytest.mark.asyncio
522
+ async def test_predict_async_batch(classification_model: ClassificationModel, label_names: list[str]):
523
+ """Test async prediction with a batch of values"""
524
+ predictions = await classification_model.apredict(["Do you love soup?", "Are cats cute?"])
525
+ assert len(predictions) == 2
526
+ assert predictions[0].prediction_id is not None
527
+ assert predictions[1].prediction_id is not None
528
+ assert predictions[0].label == 0
529
+ assert predictions[0].label_name == label_names[0]
530
+ assert 0 <= predictions[0].confidence <= 1
531
+ assert predictions[1].label == 1
532
+ assert predictions[1].label_name == label_names[1]
533
+ assert 0 <= predictions[1].confidence <= 1
534
+
535
+
536
+ @pytest.mark.asyncio
537
+ async def test_predict_async_with_expected_labels(classification_model: ClassificationModel):
538
+ """Test async prediction with expected labels"""
539
+ prediction = await classification_model.apredict("Do you love soup?", expected_labels=1)
540
+ assert prediction.expected_label == 1
541
+
542
+
543
+ @pytest.mark.asyncio
544
+ async def test_predict_async_disable_telemetry(classification_model: ClassificationModel, label_names: list[str]):
545
+ """Test async prediction with telemetry disabled"""
546
+ predictions = await classification_model.apredict(["Do you love soup?", "Are cats cute?"], save_telemetry="off")
547
+ assert len(predictions) == 2
548
+ assert predictions[0].prediction_id is None
549
+ assert predictions[1].prediction_id is None
550
+ assert predictions[0].label == 0
551
+ assert predictions[0].label_name == label_names[0]
552
+ assert 0 <= predictions[0].confidence <= 1
553
+ assert predictions[1].label == 1
554
+ assert predictions[1].label_name == label_names[1]
555
+ assert 0 <= predictions[1].confidence <= 1
556
+
557
+
558
+ @pytest.mark.asyncio
559
+ async def test_predict_async_with_filters(classification_model: ClassificationModel):
560
+ """Test async prediction with filters"""
561
+ # there are no memories with label 0 and key g2, so we force a wrong prediction
562
+ filtered_prediction = await classification_model.apredict("I love soup", filters=[("key", "==", "g2")])
563
+ assert filtered_prediction.label == 1
564
+ assert filtered_prediction.label_name == "cats"