orca-sdk 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185) hide show
  1. orca_sdk/__init__.py +10 -4
  2. orca_sdk/_shared/__init__.py +10 -0
  3. orca_sdk/_shared/metrics.py +393 -0
  4. orca_sdk/_shared/metrics_test.py +273 -0
  5. orca_sdk/_utils/analysis_ui.py +12 -10
  6. orca_sdk/_utils/analysis_ui_style.css +0 -3
  7. orca_sdk/_utils/auth.py +27 -29
  8. orca_sdk/_utils/data_parsing.py +28 -2
  9. orca_sdk/_utils/data_parsing_test.py +15 -15
  10. orca_sdk/_utils/pagination.py +126 -0
  11. orca_sdk/_utils/pagination_test.py +132 -0
  12. orca_sdk/_utils/prediction_result_ui.py +67 -21
  13. orca_sdk/_utils/tqdm_file_reader.py +12 -0
  14. orca_sdk/_utils/value_parser.py +45 -0
  15. orca_sdk/_utils/value_parser_test.py +39 -0
  16. orca_sdk/classification_model.py +439 -129
  17. orca_sdk/classification_model_test.py +334 -104
  18. orca_sdk/client.py +3747 -0
  19. orca_sdk/conftest.py +164 -19
  20. orca_sdk/credentials.py +120 -18
  21. orca_sdk/credentials_test.py +20 -0
  22. orca_sdk/datasource.py +259 -68
  23. orca_sdk/datasource_test.py +242 -0
  24. orca_sdk/embedding_model.py +425 -82
  25. orca_sdk/embedding_model_test.py +39 -13
  26. orca_sdk/job.py +337 -0
  27. orca_sdk/job_test.py +108 -0
  28. orca_sdk/memoryset.py +1341 -305
  29. orca_sdk/memoryset_test.py +350 -111
  30. orca_sdk/regression_model.py +684 -0
  31. orca_sdk/regression_model_test.py +369 -0
  32. orca_sdk/telemetry.py +449 -143
  33. orca_sdk/telemetry_test.py +43 -24
  34. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.2.dist-info}/METADATA +34 -16
  35. orca_sdk-0.1.2.dist-info/RECORD +40 -0
  36. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.2.dist-info}/WHEEL +1 -1
  37. orca_sdk/_generated_api_client/__init__.py +0 -3
  38. orca_sdk/_generated_api_client/api/__init__.py +0 -193
  39. orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
  40. orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
  41. orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
  42. orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
  43. orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
  44. orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
  45. orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
  46. orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
  47. orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
  48. orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
  49. orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
  50. orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
  51. orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
  52. orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
  53. orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
  54. orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
  55. orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
  56. orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
  57. orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
  58. orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
  59. orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
  60. orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
  61. orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
  62. orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
  63. orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
  64. orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
  65. orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
  66. orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
  67. orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
  68. orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
  69. orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
  70. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
  71. orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
  72. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
  73. orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
  74. orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
  75. orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
  76. orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
  77. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
  78. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
  79. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
  80. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
  81. orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
  82. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
  83. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
  84. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
  85. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
  86. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
  87. orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
  88. orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
  89. orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
  90. orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
  91. orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
  92. orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
  93. orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
  94. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
  95. orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
  96. orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
  97. orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
  98. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
  99. orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
  100. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
  101. orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
  102. orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
  103. orca_sdk/_generated_api_client/client.py +0 -216
  104. orca_sdk/_generated_api_client/errors.py +0 -38
  105. orca_sdk/_generated_api_client/models/__init__.py +0 -159
  106. orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
  107. orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
  108. orca_sdk/_generated_api_client/models/base_model.py +0 -55
  109. orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
  110. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
  111. orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
  112. orca_sdk/_generated_api_client/models/column_info.py +0 -114
  113. orca_sdk/_generated_api_client/models/column_type.py +0 -14
  114. orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
  115. orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
  116. orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
  117. orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
  118. orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
  119. orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
  120. orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
  121. orca_sdk/_generated_api_client/models/embed_request.py +0 -127
  122. orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
  123. orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
  124. orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
  125. orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
  126. orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
  127. orca_sdk/_generated_api_client/models/filter_item.py +0 -231
  128. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
  129. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
  130. orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
  131. orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
  132. orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
  133. orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
  134. orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
  135. orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
  136. orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
  137. orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
  138. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
  139. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
  140. orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
  141. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
  142. orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
  143. orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
  144. orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
  145. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
  146. orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
  147. orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
  148. orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
  149. orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
  150. orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
  151. orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
  152. orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
  153. orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
  154. orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
  155. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
  156. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
  157. orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
  158. orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
  159. orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
  160. orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
  161. orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
  162. orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
  163. orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
  164. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
  165. orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
  166. orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
  167. orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
  168. orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
  169. orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
  170. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
  171. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
  172. orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
  173. orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
  174. orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
  175. orca_sdk/_generated_api_client/models/task.py +0 -198
  176. orca_sdk/_generated_api_client/models/task_status.py +0 -14
  177. orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
  178. orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
  179. orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
  180. orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
  181. orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
  182. orca_sdk/_generated_api_client/py.typed +0 -1
  183. orca_sdk/_generated_api_client/types.py +0 -56
  184. orca_sdk/_utils/task.py +0 -73
  185. orca_sdk-0.1.1.dist-info/RECORD +0 -175
@@ -1,63 +1,71 @@
1
+ import logging
1
2
  from uuid import uuid4
2
3
 
4
+ import numpy as np
3
5
  import pytest
4
- from datasets.arrow_dataset import Dataset
6
+ from datasets import Dataset
5
7
 
6
- from .classification_model import ClassificationModel
8
+ from .classification_model import ClassificationMetrics, ClassificationModel
9
+ from .conftest import skip_in_ci
7
10
  from .datasource import Datasource
8
11
  from .embedding_model import PretrainedEmbeddingModel
9
12
  from .memoryset import LabeledMemoryset
13
+ from .telemetry import ClassificationPrediction
10
14
 
11
15
 
12
- def test_create_model(model: ClassificationModel, memoryset: LabeledMemoryset):
13
- assert model is not None
14
- assert model.name == "test_model"
15
- assert model.memoryset == memoryset
16
- assert model.num_classes == 2
17
- assert model.memory_lookup_count == 3
16
+ def test_create_model(classification_model: ClassificationModel, readonly_memoryset: LabeledMemoryset):
17
+ assert classification_model is not None
18
+ assert classification_model.name == "test_classification_model"
19
+ assert classification_model.memoryset == readonly_memoryset
20
+ assert classification_model.num_classes == 2
21
+ assert classification_model.memory_lookup_count == 3
18
22
 
19
23
 
20
- def test_create_model_already_exists_error(memoryset, model: ClassificationModel):
24
+ def test_create_model_already_exists_error(readonly_memoryset, classification_model):
21
25
  with pytest.raises(ValueError):
22
- ClassificationModel.create("test_model", memoryset)
26
+ ClassificationModel.create("test_classification_model", readonly_memoryset)
23
27
  with pytest.raises(ValueError):
24
- ClassificationModel.create("test_model", memoryset, if_exists="error")
28
+ ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="error")
25
29
 
26
30
 
27
- def test_create_model_already_exists_return(memoryset, model: ClassificationModel):
31
+ def test_create_model_already_exists_return(readonly_memoryset, classification_model):
28
32
  with pytest.raises(ValueError):
29
- ClassificationModel.create("test_model", memoryset, if_exists="open", head_type="MMOE")
33
+ ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="open", head_type="MMOE")
30
34
 
31
35
  with pytest.raises(ValueError):
32
- ClassificationModel.create("test_model", memoryset, if_exists="open", memory_lookup_count=37)
36
+ ClassificationModel.create(
37
+ "test_classification_model", readonly_memoryset, if_exists="open", memory_lookup_count=37
38
+ )
33
39
 
34
40
  with pytest.raises(ValueError):
35
- ClassificationModel.create("test_model", memoryset, if_exists="open", num_classes=19)
41
+ ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="open", num_classes=19)
36
42
 
37
43
  with pytest.raises(ValueError):
38
- ClassificationModel.create("test_model", memoryset, if_exists="open", min_memory_weight=0.77)
44
+ ClassificationModel.create(
45
+ "test_classification_model", readonly_memoryset, if_exists="open", min_memory_weight=0.77
46
+ )
39
47
 
40
- new_model = ClassificationModel.create("test_model", memoryset, if_exists="open")
48
+ new_model = ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="open")
41
49
  assert new_model is not None
42
- assert new_model.name == "test_model"
43
- assert new_model.memoryset == memoryset
50
+ assert new_model.name == "test_classification_model"
51
+ assert new_model.memoryset == readonly_memoryset
44
52
  assert new_model.num_classes == 2
45
53
  assert new_model.memory_lookup_count == 3
46
54
 
47
55
 
48
- def test_create_model_unauthenticated(unauthenticated, memoryset: LabeledMemoryset):
56
+ def test_create_model_unauthenticated(unauthenticated, readonly_memoryset: LabeledMemoryset):
49
57
  with pytest.raises(ValueError, match="Invalid API key"):
50
- ClassificationModel.create("test_model", memoryset)
58
+ ClassificationModel.create("test_model", readonly_memoryset)
51
59
 
52
60
 
53
- def test_get_model(model: ClassificationModel):
54
- fetched_model = ClassificationModel.open(model.name)
61
+ def test_get_model(classification_model: ClassificationModel):
62
+ fetched_model = ClassificationModel.open(classification_model.name)
55
63
  assert fetched_model is not None
56
- assert fetched_model.id == model.id
57
- assert fetched_model.name == model.name
64
+ assert fetched_model.id == classification_model.id
65
+ assert fetched_model.name == classification_model.name
58
66
  assert fetched_model.num_classes == 2
59
67
  assert fetched_model.memory_lookup_count == 3
60
- assert fetched_model == model
68
+ assert fetched_model == classification_model
61
69
 
62
70
 
63
71
  def test_get_model_unauthenticated(unauthenticated):
@@ -75,12 +83,12 @@ def test_get_model_not_found():
75
83
  ClassificationModel.open(str(uuid4()))
76
84
 
77
85
 
78
- def test_get_model_unauthorized(unauthorized, model: ClassificationModel):
86
+ def test_get_model_unauthorized(unauthorized, classification_model: ClassificationModel):
79
87
  with pytest.raises(LookupError):
80
- ClassificationModel.open(model.name)
88
+ ClassificationModel.open(classification_model.name)
81
89
 
82
90
 
83
- def test_list_models(model: ClassificationModel):
91
+ def test_list_models(classification_model: ClassificationModel):
84
92
  models = ClassificationModel.all()
85
93
  assert len(models) > 0
86
94
  assert any(model.name == model.name for model in models)
@@ -91,21 +99,41 @@ def test_list_models_unauthenticated(unauthenticated):
91
99
  ClassificationModel.all()
92
100
 
93
101
 
94
- def test_list_models_unauthorized(unauthorized, model: ClassificationModel):
102
+ def test_list_models_unauthorized(unauthorized, classification_model: ClassificationModel):
95
103
  assert ClassificationModel.all() == []
96
104
 
97
105
 
98
- def test_delete_model(memoryset: LabeledMemoryset):
99
- ClassificationModel.create("model_to_delete", LabeledMemoryset.open(memoryset.name))
106
+ def test_update_model_attributes(classification_model: ClassificationModel):
107
+ classification_model.description = "New description"
108
+ assert classification_model.description == "New description"
109
+
110
+ classification_model.set(description=None)
111
+ assert classification_model.description is None
112
+
113
+ classification_model.set(locked=True)
114
+ assert classification_model.locked is True
115
+
116
+ classification_model.set(locked=False)
117
+ assert classification_model.locked is False
118
+
119
+ classification_model.lock()
120
+ assert classification_model.locked is True
121
+
122
+ classification_model.unlock()
123
+ assert classification_model.locked is False
124
+
125
+
126
+ def test_delete_model(readonly_memoryset: LabeledMemoryset):
127
+ ClassificationModel.create("model_to_delete", LabeledMemoryset.open(readonly_memoryset.name))
100
128
  assert ClassificationModel.open("model_to_delete")
101
129
  ClassificationModel.drop("model_to_delete")
102
130
  with pytest.raises(LookupError):
103
131
  ClassificationModel.open("model_to_delete")
104
132
 
105
133
 
106
- def test_delete_model_unauthenticated(unauthenticated, model: ClassificationModel):
134
+ def test_delete_model_unauthenticated(unauthenticated, classification_model: ClassificationModel):
107
135
  with pytest.raises(ValueError, match="Invalid API key"):
108
- ClassificationModel.drop(model.name)
136
+ ClassificationModel.drop(classification_model.name)
109
137
 
110
138
 
111
139
  def test_delete_model_not_found():
@@ -115,53 +143,82 @@ def test_delete_model_not_found():
115
143
  ClassificationModel.drop(str(uuid4()), if_not_exists="ignore")
116
144
 
117
145
 
118
- def test_delete_model_unauthorized(unauthorized, model: ClassificationModel):
146
+ def test_delete_model_unauthorized(unauthorized, classification_model: ClassificationModel):
119
147
  with pytest.raises(LookupError):
120
- ClassificationModel.drop(model.name)
148
+ ClassificationModel.drop(classification_model.name)
121
149
 
122
150
 
123
- @pytest.mark.flaky
124
151
  def test_delete_memoryset_before_model_constraint_violation(hf_dataset):
125
- memoryset = LabeledMemoryset.from_hf_dataset("test_memoryset_delete_before_model", hf_dataset, value_column="text")
152
+ memoryset = LabeledMemoryset.from_hf_dataset("test_memoryset_delete_before_model", hf_dataset)
126
153
  ClassificationModel.create("test_model_delete_before_memoryset", memoryset)
127
154
  with pytest.raises(RuntimeError):
128
155
  LabeledMemoryset.drop(memoryset.id)
129
156
 
130
157
 
131
- def test_evaluate(model):
132
- eval_datasource = Datasource.from_list(
133
- "eval_datasource",
134
- [
135
- {"text": "chicken noodle soup is the best", "label": 1},
136
- {"text": "cats are cute", "label": 0},
137
- {"text": "soup is great for the winter", "label": 0},
138
- {"text": "i love cats", "label": 1},
139
- ],
158
+ @pytest.mark.parametrize("data_type", ["dataset", "datasource"])
159
+ def test_evaluate(classification_model, eval_datasource: Datasource, eval_dataset: Dataset, data_type):
160
+ result = (
161
+ classification_model.evaluate(eval_dataset)
162
+ if data_type == "dataset"
163
+ else classification_model.evaluate(eval_datasource)
140
164
  )
141
- result = model.evaluate(eval_datasource, value_column="text")
165
+
142
166
  assert result is not None
143
- assert isinstance(result["accuracy"], float)
144
- assert isinstance(result["f1_score"], float)
145
- assert isinstance(result["loss"], float)
146
-
147
-
148
- def test_evaluate_with_telemetry(model):
149
- samples = [
150
- {"text": "chicken noodle soup is the best", "label": 1},
151
- {"text": "cats are cute", "label": 0},
152
- ]
153
- eval_datasource = Datasource.from_list("eval_datasource_2", samples)
154
- result = model.evaluate(eval_datasource, value_column="text", record_predictions=True, tags={"test"})
167
+ assert isinstance(result, ClassificationMetrics)
168
+
169
+ assert isinstance(result.accuracy, float)
170
+ assert np.allclose(result.accuracy, 0.5)
171
+ assert isinstance(result.f1_score, float)
172
+ assert np.allclose(result.f1_score, 0.5)
173
+ assert isinstance(result.loss, float)
174
+
175
+ assert isinstance(result.anomaly_score_mean, float)
176
+ assert isinstance(result.anomaly_score_median, float)
177
+ assert isinstance(result.anomaly_score_variance, float)
178
+ assert -1.0 <= result.anomaly_score_mean <= 1.0
179
+ assert -1.0 <= result.anomaly_score_median <= 1.0
180
+ assert -1.0 <= result.anomaly_score_variance <= 1.0
181
+
182
+ assert result.pr_auc is not None
183
+ assert np.allclose(result.pr_auc, 0.75)
184
+ assert result.pr_curve is not None
185
+ assert np.allclose(result.pr_curve["thresholds"], [0.0, 0.0, 0.8155114054679871, 0.834095299243927])
186
+ assert np.allclose(result.pr_curve["precisions"], [0.5, 0.5, 1.0, 1.0])
187
+ assert np.allclose(result.pr_curve["recalls"], [1.0, 0.5, 0.5, 0.0])
188
+
189
+ assert result.roc_auc is not None
190
+ assert np.allclose(result.roc_auc, 0.625)
191
+ assert result.roc_curve is not None
192
+ assert np.allclose(result.roc_curve["thresholds"], [0.0, 0.8155114054679871, 0.834095299243927, 1.0])
193
+ assert np.allclose(result.roc_curve["false_positive_rates"], [1.0, 0.5, 0.0, 0.0])
194
+ assert np.allclose(result.roc_curve["true_positive_rates"], [1.0, 0.5, 0.5, 0.0])
195
+
196
+
197
+ def test_evaluate_datasource_with_nones_raises_error(classification_model: ClassificationModel, datasource: Datasource):
198
+ with pytest.raises(ValueError):
199
+ classification_model.evaluate(datasource, record_predictions=True, tags={"test"})
200
+
201
+
202
+ def test_evaluate_dataset_with_nones_raises_error(classification_model: ClassificationModel, hf_dataset: Dataset):
203
+ with pytest.raises(ValueError):
204
+ classification_model.evaluate(hf_dataset, record_predictions=True, tags={"test"})
205
+
206
+
207
+ def test_evaluate_with_telemetry(classification_model: ClassificationModel, eval_dataset: Dataset):
208
+ result = classification_model.evaluate(eval_dataset, record_predictions=True, tags={"test"})
155
209
  assert result is not None
156
- predictions = model.predictions(tag="test")
157
- assert len(predictions) == 2
210
+ assert isinstance(result, ClassificationMetrics)
211
+ predictions = classification_model.predictions(tag="test")
212
+ assert len(predictions) == 4
158
213
  assert all(p.tags == {"test"} for p in predictions)
159
- assert all(p.expected_label == s["label"] for p, s in zip(predictions, samples))
214
+ assert all(p.expected_label == l for p, l in zip(predictions, eval_dataset["label"]))
160
215
 
161
216
 
162
- def test_predict(model: ClassificationModel, label_names: list[str]):
163
- predictions = model.predict(["Do you love soup?", "Are cats cute?"])
217
+ def test_predict(classification_model: ClassificationModel, label_names: list[str]):
218
+ predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
164
219
  assert len(predictions) == 2
220
+ assert predictions[0].prediction_id is not None
221
+ assert predictions[1].prediction_id is not None
165
222
  assert predictions[0].label == 0
166
223
  assert predictions[0].label_name == label_names[0]
167
224
  assert 0 <= predictions[0].confidence <= 1
@@ -169,29 +226,59 @@ def test_predict(model: ClassificationModel, label_names: list[str]):
169
226
  assert predictions[1].label_name == label_names[1]
170
227
  assert 0 <= predictions[1].confidence <= 1
171
228
 
229
+ assert predictions[0].logits is not None
230
+ assert predictions[1].logits is not None
231
+ assert len(predictions[0].logits) == 2
232
+ assert len(predictions[1].logits) == 2
233
+ assert predictions[0].logits[0] > predictions[0].logits[1]
234
+ assert predictions[1].logits[0] < predictions[1].logits[1]
235
+
236
+
237
+ def test_classification_prediction_has_no_label(classification_model: ClassificationModel):
238
+ """Ensure optional score is None for classification predictions."""
239
+ prediction = classification_model.predict("Do you want to go to the beach?")
240
+ assert isinstance(prediction, ClassificationPrediction)
241
+ assert prediction.label is None
242
+
172
243
 
173
- def test_predict_unauthenticated(unauthenticated, model: ClassificationModel):
244
+ def test_predict_disable_telemetry(classification_model: ClassificationModel, label_names: list[str]):
245
+ predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"], save_telemetry="off")
246
+ assert len(predictions) == 2
247
+ assert predictions[0].prediction_id is None
248
+ assert predictions[1].prediction_id is None
249
+ assert predictions[0].label == 0
250
+ assert predictions[0].label_name == label_names[0]
251
+ assert 0 <= predictions[0].confidence <= 1
252
+ assert predictions[1].label == 1
253
+ assert predictions[1].label_name == label_names[1]
254
+ assert 0 <= predictions[1].confidence <= 1
255
+
256
+
257
+ def test_predict_unauthenticated(unauthenticated, classification_model: ClassificationModel):
174
258
  with pytest.raises(ValueError, match="Invalid API key"):
175
- model.predict(["Do you love soup?", "Are cats cute?"])
259
+ classification_model.predict(["Do you love soup?", "Are cats cute?"])
176
260
 
177
261
 
178
- def test_predict_unauthorized(unauthorized, model: ClassificationModel):
262
+ def test_predict_unauthorized(unauthorized, classification_model: ClassificationModel):
179
263
  with pytest.raises(LookupError):
180
- model.predict(["Do you love soup?", "Are cats cute?"])
264
+ classification_model.predict(["Do you love soup?", "Are cats cute?"])
181
265
 
182
266
 
183
- def test_predict_constraint_violation(memoryset: LabeledMemoryset):
267
+ def test_predict_constraint_violation(readonly_memoryset: LabeledMemoryset):
184
268
  model = ClassificationModel.create(
185
- "test_model_lookup_count_too_high", memoryset, num_classes=2, memory_lookup_count=memoryset.length + 2
269
+ "test_model_lookup_count_too_high",
270
+ readonly_memoryset,
271
+ num_classes=2,
272
+ memory_lookup_count=readonly_memoryset.length + 2,
186
273
  )
187
274
  with pytest.raises(RuntimeError):
188
275
  model.predict("test")
189
276
 
190
277
 
191
- def test_record_prediction_feedback(model: ClassificationModel):
192
- predictions = model.predict(["Do you love soup?", "Are cats cute?"])
278
+ def test_record_prediction_feedback(classification_model: ClassificationModel):
279
+ predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
193
280
  expected_labels = [0, 1]
194
- model.record_feedback(
281
+ classification_model.record_feedback(
195
282
  {
196
283
  "prediction_id": p.prediction_id,
197
284
  "category": "correct",
@@ -201,66 +288,209 @@ def test_record_prediction_feedback(model: ClassificationModel):
201
288
  )
202
289
 
203
290
 
204
- def test_record_prediction_feedback_missing_category(model: ClassificationModel):
205
- prediction = model.predict("Do you love soup?")
291
+ def test_record_prediction_feedback_missing_category(classification_model: ClassificationModel):
292
+ prediction = classification_model.predict("Do you love soup?")
206
293
  with pytest.raises(ValueError):
207
- model.record_feedback({"prediction_id": prediction.prediction_id, "value": True})
294
+ classification_model.record_feedback({"prediction_id": prediction.prediction_id, "value": True})
208
295
 
209
296
 
210
- def test_record_prediction_feedback_invalid_value(model: ClassificationModel):
211
- prediction = model.predict("Do you love soup?")
297
+ def test_record_prediction_feedback_invalid_value(classification_model: ClassificationModel):
298
+ prediction = classification_model.predict("Do you love soup?")
212
299
  with pytest.raises(ValueError, match=r"Invalid input.*"):
213
- model.record_feedback({"prediction_id": prediction.prediction_id, "category": "correct", "value": "invalid"})
300
+ classification_model.record_feedback(
301
+ {"prediction_id": prediction.prediction_id, "category": "correct", "value": "invalid"}
302
+ )
214
303
 
215
304
 
216
- def test_record_prediction_feedback_invalid_prediction_id(model: ClassificationModel):
305
+ def test_record_prediction_feedback_invalid_prediction_id(classification_model: ClassificationModel):
217
306
  with pytest.raises(ValueError, match=r"Invalid input.*"):
218
- model.record_feedback({"prediction_id": "invalid", "category": "correct", "value": True})
307
+ classification_model.record_feedback({"prediction_id": "invalid", "category": "correct", "value": True})
219
308
 
220
309
 
221
- def test_predict_with_memoryset_override(model: ClassificationModel, hf_dataset: Dataset):
310
+ def test_predict_with_memoryset_override(classification_model: ClassificationModel, hf_dataset: Dataset):
222
311
  inverted_labeled_memoryset = LabeledMemoryset.from_hf_dataset(
223
312
  "test_memoryset_inverted_labels",
224
313
  hf_dataset.map(lambda x: {"label": 1 if x["label"] == 0 else 0}),
225
- value_column="text",
226
314
  embedding_model=PretrainedEmbeddingModel.GTE_BASE,
227
315
  )
228
- with model.use_memoryset(inverted_labeled_memoryset):
229
- predictions = model.predict(["Do you love soup?", "Are cats cute?"])
316
+ with classification_model.use_memoryset(inverted_labeled_memoryset):
317
+ predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
230
318
  assert predictions[0].label == 1
231
319
  assert predictions[1].label == 0
232
320
 
233
- predictions = model.predict(["Do you love soup?", "Are cats cute?"])
321
+ predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
234
322
  assert predictions[0].label == 0
235
323
  assert predictions[1].label == 1
236
324
 
237
325
 
238
- def test_predict_with_expected_labels(model: ClassificationModel):
239
- prediction = model.predict("Do you love soup?", expected_labels=1)
326
+ def test_predict_with_expected_labels(classification_model: ClassificationModel):
327
+ prediction = classification_model.predict("Do you love soup?", expected_labels=1)
240
328
  assert prediction.expected_label == 1
241
329
 
242
330
 
243
- def test_predict_with_expected_labels_invalid_input(model: ClassificationModel):
331
+ def test_predict_with_expected_labels_invalid_input(classification_model: ClassificationModel):
244
332
  # invalid number of expected labels for batch prediction
245
333
  with pytest.raises(ValueError, match=r"Invalid input.*"):
246
- model.predict(["Do you love soup?", "Are cats cute?"], expected_labels=[0])
334
+ classification_model.predict(["Do you love soup?", "Are cats cute?"], expected_labels=[0])
247
335
  # invalid label value
248
336
  with pytest.raises(ValueError):
249
- model.predict("Do you love soup?", expected_labels=5)
337
+ classification_model.predict("Do you love soup?", expected_labels=5)
250
338
 
251
339
 
252
- def test_last_prediction_with_batch(model: ClassificationModel):
253
- predictions = model.predict(["Do you love soup?", "Are cats cute?"])
254
- assert model.last_prediction is not None
255
- assert model.last_prediction.prediction_id == predictions[-1].prediction_id
256
- assert model.last_prediction.input_value == "Are cats cute?"
257
- assert model._last_prediction_was_batch is True
340
+ def test_predict_with_filters(classification_model: ClassificationModel):
341
+ # there are no memories with label 0 and key g1, so we force a wrong prediction
342
+ filtered_prediction = classification_model.predict("I love soup", filters=[("key", "==", "g2")])
343
+ assert filtered_prediction.label == 1
344
+ assert filtered_prediction.label_name == "cats"
258
345
 
259
346
 
260
- def test_last_prediction_with_single(model: ClassificationModel):
261
- # Test that last_prediction is updated correctly with single prediction
347
+ def test_predict_with_memoryset_update(writable_memoryset: LabeledMemoryset):
348
+ model = ClassificationModel.create(
349
+ "test_predict_with_memoryset_update",
350
+ writable_memoryset,
351
+ num_classes=2,
352
+ memory_lookup_count=3,
353
+ )
354
+
262
355
  prediction = model.predict("Do you love soup?")
263
- assert model.last_prediction is not None
264
- assert model.last_prediction.prediction_id == prediction.prediction_id
265
- assert model.last_prediction.input_value == "Do you love soup?"
266
- assert model._last_prediction_was_batch is False
356
+ assert prediction.label == 0
357
+ assert prediction.label_name == "soup"
358
+
359
+ # insert new memories
360
+ writable_memoryset.insert(
361
+ [
362
+ {"value": "Do you love soup?", "label": 1, "key": "g1"},
363
+ {"value": "Do you love soup for dinner?", "label": 1, "key": "g2"},
364
+ {"value": "Do you love crackers?", "label": 1, "key": "g2"},
365
+ {"value": "Do you love broth?", "label": 1, "key": "g2"},
366
+ {"value": "Do you love chicken soup?", "label": 1, "key": "g2"},
367
+ {"value": "Do you love chicken soup for dinner?", "label": 1, "key": "g2"},
368
+ {"value": "Do you love chicken soup for dinner?", "label": 1, "key": "g2"},
369
+ ],
370
+ )
371
+ prediction = model.predict("Do you love soup?")
372
+ assert prediction.label == 1
373
+ assert prediction.label_name == "cats"
374
+
375
+ ClassificationModel.drop("test_predict_with_memoryset_update")
376
+
377
+
378
+ def test_last_prediction_with_batch(classification_model: ClassificationModel):
379
+ predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
380
+ assert classification_model.last_prediction is not None
381
+ assert classification_model.last_prediction.prediction_id == predictions[-1].prediction_id
382
+ assert classification_model.last_prediction.input_value == "Are cats cute?"
383
+ assert classification_model._last_prediction_was_batch is True
384
+
385
+
386
+ def test_last_prediction_with_single(classification_model: ClassificationModel):
387
+ # Test that last_prediction is updated correctly with single prediction
388
+ prediction = classification_model.predict("Do you love soup?")
389
+ assert classification_model.last_prediction is not None
390
+ assert classification_model.last_prediction.prediction_id == prediction.prediction_id
391
+ assert classification_model.last_prediction.input_value == "Do you love soup?"
392
+ assert classification_model._last_prediction_was_batch is False
393
+
394
+
395
+ @skip_in_ci("We don't have Anthropic API key in CI")
396
+ def test_explain(writable_memoryset: LabeledMemoryset):
397
+
398
+ writable_memoryset.analyze(
399
+ {"name": "neighbor", "neighbor_counts": [1, 3]},
400
+ lookup_count=3,
401
+ )
402
+
403
+ model = ClassificationModel.create(
404
+ "test_model_for_explain",
405
+ writable_memoryset,
406
+ num_classes=2,
407
+ memory_lookup_count=3,
408
+ description="This is a test model for explain",
409
+ )
410
+
411
+ predictions = model.predict(["Do you love soup?", "Are cats cute?"])
412
+ assert len(predictions) == 2
413
+
414
+ try:
415
+ explanation = predictions[0].explanation
416
+ assert explanation is not None
417
+ assert len(explanation) > 10
418
+ assert "soup" in explanation.lower()
419
+ except Exception as e:
420
+ if "ANTHROPIC_API_KEY" in str(e):
421
+ logging.info("Skipping explanation test because ANTHROPIC_API_KEY is not set")
422
+ else:
423
+ raise e
424
+ finally:
425
+ ClassificationModel.drop("test_model_for_explain")
426
+
427
+
428
+ @skip_in_ci("We don't have Anthropic API key in CI")
429
+ def test_action_recommendation(writable_memoryset: LabeledMemoryset):
430
+ """Test getting action recommendations for predictions"""
431
+
432
+ writable_memoryset.analyze(
433
+ {"name": "neighbor", "neighbor_counts": [1, 3]},
434
+ lookup_count=3,
435
+ )
436
+
437
+ model = ClassificationModel.create(
438
+ "test_model_for_action",
439
+ writable_memoryset,
440
+ num_classes=2,
441
+ memory_lookup_count=3,
442
+ description="This is a test model for action recommendations",
443
+ )
444
+
445
+ # Make a prediction with expected label to simulate incorrect prediction
446
+ prediction = model.predict("Do you love soup?", expected_labels=1)
447
+
448
+ memoryset_length = model.memoryset.length
449
+
450
+ try:
451
+ # Get action recommendation
452
+ action, rationale = prediction.recommend_action()
453
+
454
+ assert action is not None
455
+ assert rationale is not None
456
+ assert action in ["remove_duplicates", "detect_mislabels", "add_memories", "finetuning"]
457
+ assert len(rationale) > 10
458
+
459
+ # Test memory suggestions
460
+ suggestions_response = prediction.generate_memory_suggestions(num_memories=2)
461
+ memory_suggestions = suggestions_response.suggestions
462
+
463
+ assert memory_suggestions is not None
464
+ assert len(memory_suggestions) == 2
465
+
466
+ for suggestion in memory_suggestions:
467
+ assert isinstance(suggestion[0], str)
468
+ assert len(suggestion[0]) > 0
469
+ assert isinstance(suggestion[1], str)
470
+ assert suggestion[1] in model.memoryset.label_names
471
+
472
+ suggestions_response.apply()
473
+
474
+ model.memoryset.refresh()
475
+ assert model.memoryset.length == memoryset_length + 2
476
+
477
+ except Exception as e:
478
+ if "ANTHROPIC_API_KEY" in str(e):
479
+ logging.info("Skipping agent tests because ANTHROPIC_API_KEY is not set")
480
+ else:
481
+ raise e
482
+ finally:
483
+ ClassificationModel.drop("test_model_for_action")
484
+
485
+
486
+ def test_predict_with_prompt(classification_model: ClassificationModel):
487
+ """Test that prompt parameter is properly passed through to predictions"""
488
+ # Test with an instruction-supporting embedding model if available
489
+ prediction_with_prompt = classification_model.predict(
490
+ "I love this product!", prompt="Represent this text for sentiment classification:"
491
+ )
492
+ prediction_without_prompt = classification_model.predict("I love this product!")
493
+
494
+ # Both should work and return valid predictions
495
+ assert prediction_with_prompt.label is not None
496
+ assert prediction_without_prompt.label is not None