orca-sdk 0.0.78 (orca_sdk-0.0.78-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. orca_sdk/__init__.py +24 -0
  2. orca_sdk/_generated_api_client/__init__.py +3 -0
  3. orca_sdk/_generated_api_client/api/__init__.py +205 -0
  4. orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
  5. orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +130 -0
  6. orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +172 -0
  7. orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +158 -0
  8. orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +132 -0
  9. orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +129 -0
  10. orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
  11. orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +185 -0
  12. orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +172 -0
  13. orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +170 -0
  14. orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +156 -0
  15. orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +172 -0
  16. orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +158 -0
  17. orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +163 -0
  18. orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +129 -0
  19. orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +192 -0
  20. orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
  21. orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +169 -0
  22. orca_sdk/_generated_api_client/api/datasource/create_embedding_evaluation_datasource_name_or_id_embedding_evaluation_post.py +185 -0
  23. orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +158 -0
  24. orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +158 -0
  25. orca_sdk/_generated_api_client/api/datasource/get_embedding_evaluation_datasource_name_or_id_embedding_evaluation_task_id_get.py +171 -0
  26. orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +129 -0
  27. orca_sdk/_generated_api_client/api/datasource/list_embedding_evaluations_datasource_name_or_id_embedding_evaluation_get.py +237 -0
  28. orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
  29. orca_sdk/_generated_api_client/api/default/healthcheck_get.py +120 -0
  30. orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +120 -0
  31. orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
  32. orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +170 -0
  33. orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +158 -0
  34. orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +191 -0
  35. orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +158 -0
  36. orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +129 -0
  37. orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
  38. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +183 -0
  39. orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +185 -0
  40. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +170 -0
  41. orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +183 -0
  42. orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +169 -0
  43. orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +158 -0
  44. orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +171 -0
  45. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +190 -0
  46. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +171 -0
  47. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +158 -0
  48. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +186 -0
  49. orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +262 -0
  50. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +129 -0
  51. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +195 -0
  52. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +190 -0
  53. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +193 -0
  54. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +189 -0
  55. orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
  56. orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +194 -0
  57. orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +163 -0
  58. orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +129 -0
  59. orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
  60. orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +156 -0
  61. orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +158 -0
  62. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +245 -0
  63. orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
  64. orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +164 -0
  65. orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +158 -0
  66. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +159 -0
  67. orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +129 -0
  68. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +177 -0
  69. orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +173 -0
  70. orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +183 -0
  71. orca_sdk/_generated_api_client/client.py +216 -0
  72. orca_sdk/_generated_api_client/errors.py +38 -0
  73. orca_sdk/_generated_api_client/models/__init__.py +179 -0
  74. orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +116 -0
  75. orca_sdk/_generated_api_client/models/api_key_metadata.py +137 -0
  76. orca_sdk/_generated_api_client/models/api_key_metadata_scope_item.py +9 -0
  77. orca_sdk/_generated_api_client/models/base_model.py +55 -0
  78. orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +176 -0
  79. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +147 -0
  80. orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +150 -0
  81. orca_sdk/_generated_api_client/models/column_info.py +114 -0
  82. orca_sdk/_generated_api_client/models/column_type.py +14 -0
  83. orca_sdk/_generated_api_client/models/conflict_error_response.py +80 -0
  84. orca_sdk/_generated_api_client/models/create_api_key_request.py +120 -0
  85. orca_sdk/_generated_api_client/models/create_api_key_request_scope_item.py +9 -0
  86. orca_sdk/_generated_api_client/models/create_api_key_response.py +145 -0
  87. orca_sdk/_generated_api_client/models/create_api_key_response_scope_item.py +9 -0
  88. orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +279 -0
  89. orca_sdk/_generated_api_client/models/create_rac_model_request.py +209 -0
  90. orca_sdk/_generated_api_client/models/datasource_metadata.py +142 -0
  91. orca_sdk/_generated_api_client/models/delete_memories_request.py +70 -0
  92. orca_sdk/_generated_api_client/models/embed_request.py +127 -0
  93. orca_sdk/_generated_api_client/models/embedding_evaluation_request.py +179 -0
  94. orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +148 -0
  95. orca_sdk/_generated_api_client/models/embedding_evaluation_result.py +86 -0
  96. orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +9 -0
  97. orca_sdk/_generated_api_client/models/embedding_model_result.py +114 -0
  98. orca_sdk/_generated_api_client/models/evaluation_request.py +180 -0
  99. orca_sdk/_generated_api_client/models/evaluation_response.py +140 -0
  100. orca_sdk/_generated_api_client/models/feedback_type.py +9 -0
  101. orca_sdk/_generated_api_client/models/field_validation_error.py +103 -0
  102. orca_sdk/_generated_api_client/models/filter_item.py +231 -0
  103. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +15 -0
  104. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +20 -0
  105. orca_sdk/_generated_api_client/models/filter_item_op.py +16 -0
  106. orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +70 -0
  107. orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +259 -0
  108. orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +66 -0
  109. orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +166 -0
  110. orca_sdk/_generated_api_client/models/get_memories_request.py +70 -0
  111. orca_sdk/_generated_api_client/models/internal_server_error_response.py +80 -0
  112. orca_sdk/_generated_api_client/models/label_class_metrics.py +108 -0
  113. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +274 -0
  114. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +68 -0
  115. orca_sdk/_generated_api_client/models/label_prediction_result.py +115 -0
  116. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +246 -0
  117. orca_sdk/_generated_api_client/models/labeled_memory.py +197 -0
  118. orca_sdk/_generated_api_client/models/labeled_memory_insert.py +128 -0
  119. orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +68 -0
  120. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +258 -0
  121. orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +68 -0
  122. orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +68 -0
  123. orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +237 -0
  124. orca_sdk/_generated_api_client/models/labeled_memory_update.py +171 -0
  125. orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +68 -0
  126. orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +195 -0
  127. orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +9 -0
  128. orca_sdk/_generated_api_client/models/list_memories_request.py +104 -0
  129. orca_sdk/_generated_api_client/models/list_predictions_request.py +257 -0
  130. orca_sdk/_generated_api_client/models/lookup_request.py +81 -0
  131. orca_sdk/_generated_api_client/models/memory_metrics.py +156 -0
  132. orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +83 -0
  133. orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +9 -0
  134. orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +180 -0
  135. orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +66 -0
  136. orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +9 -0
  137. orca_sdk/_generated_api_client/models/not_found_error_response.py +100 -0
  138. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +21 -0
  139. orca_sdk/_generated_api_client/models/precision_recall_curve.py +94 -0
  140. orca_sdk/_generated_api_client/models/prediction_feedback.py +157 -0
  141. orca_sdk/_generated_api_client/models/prediction_feedback_category.py +115 -0
  142. orca_sdk/_generated_api_client/models/prediction_feedback_request.py +122 -0
  143. orca_sdk/_generated_api_client/models/prediction_feedback_result.py +102 -0
  144. orca_sdk/_generated_api_client/models/prediction_request.py +169 -0
  145. orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_0.py +10 -0
  146. orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_1.py +9 -0
  147. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +97 -0
  148. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +12 -0
  149. orca_sdk/_generated_api_client/models/rac_head_type.py +11 -0
  150. orca_sdk/_generated_api_client/models/rac_model_metadata.py +191 -0
  151. orca_sdk/_generated_api_client/models/roc_curve.py +94 -0
  152. orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +80 -0
  153. orca_sdk/_generated_api_client/models/task.py +198 -0
  154. orca_sdk/_generated_api_client/models/task_status.py +14 -0
  155. orca_sdk/_generated_api_client/models/task_status_info.py +133 -0
  156. orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +72 -0
  157. orca_sdk/_generated_api_client/models/unauthorized_error_response.py +80 -0
  158. orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +94 -0
  159. orca_sdk/_generated_api_client/models/update_prediction_request.py +93 -0
  160. orca_sdk/_generated_api_client/py.typed +1 -0
  161. orca_sdk/_generated_api_client/types.py +56 -0
  162. orca_sdk/_utils/__init__.py +0 -0
  163. orca_sdk/_utils/analysis_ui.py +192 -0
  164. orca_sdk/_utils/analysis_ui_style.css +54 -0
  165. orca_sdk/_utils/auth.py +68 -0
  166. orca_sdk/_utils/auth_test.py +31 -0
  167. orca_sdk/_utils/common.py +37 -0
  168. orca_sdk/_utils/data_parsing.py +99 -0
  169. orca_sdk/_utils/data_parsing_test.py +244 -0
  170. orca_sdk/_utils/prediction_result_ui.css +18 -0
  171. orca_sdk/_utils/prediction_result_ui.py +64 -0
  172. orca_sdk/_utils/task.py +73 -0
  173. orca_sdk/classification_model.py +508 -0
  174. orca_sdk/classification_model_test.py +272 -0
  175. orca_sdk/conftest.py +116 -0
  176. orca_sdk/credentials.py +126 -0
  177. orca_sdk/credentials_test.py +37 -0
  178. orca_sdk/datasource.py +333 -0
  179. orca_sdk/datasource_test.py +96 -0
  180. orca_sdk/embedding_model.py +347 -0
  181. orca_sdk/embedding_model_test.py +176 -0
  182. orca_sdk/memoryset.py +1209 -0
  183. orca_sdk/memoryset_test.py +287 -0
  184. orca_sdk/telemetry.py +398 -0
  185. orca_sdk/telemetry_test.py +109 -0
  186. orca_sdk-0.0.78.dist-info/METADATA +79 -0
  187. orca_sdk-0.0.78.dist-info/RECORD +188 -0
  188. orca_sdk-0.0.78.dist-info/WHEEL +4 -0
orca_sdk/classification_model_test.py ADDED
@@ -0,0 +1,272 @@
1
+ from uuid import uuid4
2
+
3
+ import pytest
4
+ from datasets.arrow_dataset import Dataset
5
+
6
+ from .classification_model import ClassificationModel
7
+ from .datasource import Datasource
8
+ from .embedding_model import PretrainedEmbeddingModel
9
+ from .memoryset import LabeledMemoryset
10
+
11
+
12
def test_create_model(model: ClassificationModel, memoryset: LabeledMemoryset):
    """The session `model` fixture carries the configuration it was created with."""
    assert model is not None
    assert model.name == "test_model"
    assert model.memoryset == memoryset
    # (num_classes, memory_lookup_count) as passed by the `model` fixture in conftest
    assert (model.num_classes, model.memory_lookup_count) == (2, 3)
18
+
19
+
20
def test_create_model_already_exists_error(memoryset, model: ClassificationModel):
    """Creating a model under an existing name raises, by default and with `if_exists="error"`."""
    # First pass: default `if_exists`; second pass: explicit "error".
    for extra_options in ({}, {"if_exists": "error"}):
        with pytest.raises(ValueError):
            ClassificationModel.create("test_model", memoryset, **extra_options)
25
+
26
+
27
def test_create_model_already_exists_return(memoryset, model: ClassificationModel):
    """`if_exists="open"` reopens the existing model, but rejects conflicting settings."""
    conflicting_settings = [
        {"head_type": "MMOE"},
        {"memory_lookup_count": 37},
        {"num_classes": 19},
        {"min_memory_weight": 0.77},
    ]
    for settings in conflicting_settings:
        with pytest.raises(ValueError):
            ClassificationModel.create("test_model", memoryset, if_exists="open", **settings)

    reopened = ClassificationModel.create("test_model", memoryset, if_exists="open")
    assert reopened is not None
    assert reopened.name == "test_model"
    assert reopened.memoryset == memoryset
    assert (reopened.num_classes, reopened.memory_lookup_count) == (2, 3)
46
+
47
+
48
def test_create_model_unauthenticated(unauthenticated, memoryset: LabeledMemoryset):
    """Model creation with an unrecognised API key is rejected."""
    target_name = "test_model"
    with pytest.raises(ValueError, match="Invalid API key"):
        ClassificationModel.create(target_name, memoryset)
51
+
52
+
53
def test_get_model(model: ClassificationModel):
    """Opening a model by name yields an object equal to the created one."""
    reopened = ClassificationModel.open(model.name)
    assert reopened is not None
    assert reopened.id == model.id
    assert reopened.name == model.name
    assert (reopened.num_classes, reopened.memory_lookup_count) == (2, 3)
    assert reopened == model
61
+
62
+
63
def test_get_model_unauthenticated(unauthenticated):
    """Opening a model with an unrecognised API key is rejected."""
    target_name = "test_model"
    with pytest.raises(ValueError, match="Invalid API key"):
        ClassificationModel.open(target_name)
66
+
67
+
68
def test_get_model_invalid_input():
    """A malformed model identifier is reported as invalid input."""
    malformed_identifier = "not valid id"
    with pytest.raises(ValueError, match="Invalid input"):
        ClassificationModel.open(malformed_identifier)
71
+
72
+
73
def test_get_model_not_found():
    """Opening a model by a random UUID raises `LookupError`."""
    random_id = str(uuid4())
    with pytest.raises(LookupError):
        ClassificationModel.open(random_id)
76
+
77
+
78
def test_get_model_unauthorized(unauthorized, model: ClassificationModel):
    """A key from another org cannot open this org's model; it surfaces as `LookupError`."""
    foreign_lookup_name = model.name
    with pytest.raises(LookupError):
        ClassificationModel.open(foreign_lookup_name)
81
+
82
+
83
def test_list_models(model: ClassificationModel):
    """`ClassificationModel.all()` returns a non-empty list that includes the fixture model."""
    models = ClassificationModel.all()
    assert len(models) > 0
    # Use a distinct loop variable. Writing `model.name == model.name for model in models`
    # shadows the `model` fixture inside the generator. Each listed model is then compared
    # with itself, so that assertion could never fail.
    assert any(listed.name == model.name for listed in models)
87
+
88
+
89
def test_list_models_unauthenticated(unauthenticated):
    """Listing models with an unrecognised API key is rejected."""
    expected_message = "Invalid API key"
    with pytest.raises(ValueError, match=expected_message):
        ClassificationModel.all()
92
+
93
+
94
def test_list_models_unauthorized(unauthorized, model: ClassificationModel):
    """A key from another org sees none of this org's models."""
    visible_models = ClassificationModel.all()
    assert visible_models == []
96
+
97
+
98
def test_delete_model(memoryset: LabeledMemoryset):
    """A model can be created, opened, dropped, and is then no longer found."""
    doomed_name = "model_to_delete"
    ClassificationModel.create(doomed_name, LabeledMemoryset.open(memoryset.name))
    assert ClassificationModel.open(doomed_name)

    ClassificationModel.drop(doomed_name)
    with pytest.raises(LookupError):
        ClassificationModel.open(doomed_name)
104
+
105
+
106
def test_delete_model_unauthenticated(unauthenticated, model: ClassificationModel):
    """Dropping a model with an unrecognised API key is rejected."""
    target_name = model.name
    with pytest.raises(ValueError, match="Invalid API key"):
        ClassificationModel.drop(target_name)
109
+
110
+
111
def test_delete_model_not_found():
    """Dropping an unknown model raises `LookupError` unless `if_not_exists="ignore"` is given."""
    with pytest.raises(LookupError):
        ClassificationModel.drop(str(uuid4()))

    # A fresh random id; with if_not_exists="ignore" the missing model is not an error.
    another_missing_id = str(uuid4())
    ClassificationModel.drop(another_missing_id, if_not_exists="ignore")
116
+
117
+
118
def test_delete_model_unauthorized(unauthorized, model: ClassificationModel):
    """A key from another org cannot drop this org's model; it surfaces as `LookupError`."""
    foreign_target = model.name
    with pytest.raises(LookupError):
        ClassificationModel.drop(foreign_target)
121
+
122
+
123
def test_delete_memoryset_before_model_constraint_violation(hf_dataset):
    """A memoryset that still backs a model cannot be dropped."""
    backing_memoryset = LabeledMemoryset.from_hf_dataset(
        "test_memoryset_delete_before_model",
        hf_dataset,
        value_column="text",
    )
    ClassificationModel.create("test_model_delete_before_memoryset", backing_memoryset)

    with pytest.raises(RuntimeError):
        LabeledMemoryset.drop(backing_memoryset.id)
128
+
129
+
130
def test_evaluate(model):
    """`evaluate` returns float metrics plus PR and ROC curves.

    Every curve array has length 4, matching the four evaluation rows.
    """
    eval_rows = [
        {"text": "chicken noodle soup is the best", "label": 1},
        {"text": "cats are cute", "label": 0},
        {"text": "soup is great for the winter", "label": 0},
        {"text": "i love cats", "label": 1},
    ]
    eval_datasource = Datasource.from_list("eval_datasource", eval_rows)

    result = model.evaluate(eval_datasource, value_column="text")
    assert result is not None
    assert isinstance(result, dict)

    for metric_name in ("accuracy", "f1_score", "loss"):
        assert isinstance(result[metric_name], float)

    curve_fields = {
        "precision_recall_curve": ("thresholds", "precisions", "recalls"),
        "roc_curve": ("thresholds", "false_positive_rates", "true_positive_rates"),
    }
    for curve_name, field_names in curve_fields.items():
        for field_name in field_names:
            assert len(result[curve_name][field_name]) == 4
152
+
153
+
154
def test_evaluate_with_telemetry(model):
    """With `record_predictions=True`, evaluation predictions are stored with tags and expected labels."""
    samples = [
        {"text": "chicken noodle soup is the best", "label": 1},
        {"text": "cats are cute", "label": 0},
    ]
    eval_datasource = Datasource.from_list("eval_datasource_2", samples)
    result = model.evaluate(eval_datasource, value_column="text", record_predictions=True, tags={"test"})
    assert result is not None

    recorded = model.predictions(tag="test")
    assert len(recorded) == 2
    assert all(prediction.tags == {"test"} for prediction in recorded)
    # NOTE(review): predictions are paired with samples by position. This assumes
    # `model.predictions` returns them in evaluation order; confirm.
    assert all(
        prediction.expected_label == sample["label"]
        for prediction, sample in zip(recorded, samples)
    )
166
+
167
+
168
def test_predict(model: ClassificationModel, label_names: list[str]):
    """Batch prediction labels the soup question 0 and the cat question 1, with valid confidences."""
    predictions = model.predict(["Do you love soup?", "Are cats cute?"])
    assert len(predictions) == 2
    # The i-th input is expected to receive label i.
    for expected_label, prediction in enumerate(predictions):
        assert prediction.label == expected_label
        assert prediction.label_name == label_names[expected_label]
        assert 0 <= prediction.confidence <= 1
177
+
178
+
179
def test_predict_unauthenticated(unauthenticated, model: ClassificationModel):
    """Predicting with an unrecognised API key is rejected."""
    queries = ["Do you love soup?", "Are cats cute?"]
    with pytest.raises(ValueError, match="Invalid API key"):
        model.predict(queries)
182
+
183
+
184
def test_predict_unauthorized(unauthorized, model: ClassificationModel):
    """Predicting with another org's key surfaces as `LookupError`."""
    queries = ["Do you love soup?", "Are cats cute?"]
    with pytest.raises(LookupError):
        model.predict(queries)
187
+
188
+
189
def test_predict_constraint_violation(memoryset: LabeledMemoryset):
    """Predicting fails when `memory_lookup_count` exceeds the memoryset's length."""
    oversized_lookup_count = memoryset.length + 2
    model = ClassificationModel.create(
        "test_model_lookup_count_too_high",
        memoryset,
        num_classes=2,
        memory_lookup_count=oversized_lookup_count,
    )
    with pytest.raises(RuntimeError):
        model.predict("test")
195
+
196
+
197
def test_record_prediction_feedback(model: ClassificationModel):
    """Feedback for several predictions can be recorded from an iterable of dicts."""
    predictions = model.predict(["Do you love soup?", "Are cats cute?"])
    expected_labels = [0, 1]
    feedback_items = (
        {
            "prediction_id": prediction.prediction_id,
            "category": "correct",
            "value": prediction.label == label,
        }
        for label, prediction in zip(expected_labels, predictions)
    )
    model.record_feedback(feedback_items)
208
+
209
+
210
def test_record_prediction_feedback_missing_category(model: ClassificationModel):
    """Feedback without a `category` key is rejected."""
    prediction = model.predict("Do you love soup?")
    feedback_without_category = {"prediction_id": prediction.prediction_id, "value": True}
    with pytest.raises(ValueError):
        model.record_feedback(feedback_without_category)
214
+
215
+
216
def test_record_prediction_feedback_invalid_value(model: ClassificationModel):
    """A string value for the `correct` category is reported as invalid input."""
    prediction = model.predict("Do you love soup?")
    bad_feedback = {"prediction_id": prediction.prediction_id, "category": "correct", "value": "invalid"}
    with pytest.raises(ValueError, match=r"Invalid input.*"):
        model.record_feedback(bad_feedback)
220
+
221
+
222
def test_record_prediction_feedback_invalid_prediction_id(model: ClassificationModel):
    """A malformed prediction id in feedback is reported as invalid input."""
    bad_feedback = {"prediction_id": "invalid", "category": "correct", "value": True}
    with pytest.raises(ValueError, match=r"Invalid input.*"):
        model.record_feedback(bad_feedback)
225
+
226
+
227
def test_predict_with_memoryset_override(model: ClassificationModel, hf_dataset: Dataset):
    """Inside `use_memoryset` predictions follow the override memoryset; afterwards the original applies."""
    flipped_dataset = hf_dataset.map(lambda row: {"label": 1 if row["label"] == 0 else 0})
    inverted_labeled_memoryset = LabeledMemoryset.from_hf_dataset(
        "test_memoryset_inverted_labels",
        flipped_dataset,
        value_column="text",
        embedding_model=PretrainedEmbeddingModel.GTE_BASE,
    )

    with model.use_memoryset(inverted_labeled_memoryset):
        overridden = model.predict(["Do you love soup?", "Are cats cute?"])
        assert [overridden[0].label, overridden[1].label] == [1, 0]

    restored = model.predict(["Do you love soup?", "Are cats cute?"])
    assert [restored[0].label, restored[1].label] == [0, 1]
242
+
243
+
244
def test_predict_with_expected_labels(model: ClassificationModel):
    """An expected label passed to `predict` is echoed back on the prediction."""
    expected = 1
    prediction = model.predict("Do you love soup?", expected_labels=expected)
    assert prediction.expected_label == expected
247
+
248
+
249
def test_predict_with_expected_labels_invalid_input(model: ClassificationModel):
    """`predict` rejects expected labels that do not fit the inputs or the label space."""
    batch = ["Do you love soup?", "Are cats cute?"]
    # One expected label supplied for a batch of two inputs.
    with pytest.raises(ValueError, match=r"Invalid input.*"):
        model.predict(batch, expected_labels=[0])
    # Label 5 lies outside the fixture model's classes (created with num_classes=2).
    with pytest.raises(ValueError):
        model.predict("Do you love soup?", expected_labels=5)
256
+
257
+
258
def test_last_prediction_with_batch(model: ClassificationModel):
    """After a batch predict, `last_prediction` refers to the final item of the batch."""
    batch = ["Do you love soup?", "Are cats cute?"]
    predictions = model.predict(batch)
    assert model.last_prediction is not None
    assert model.last_prediction.prediction_id == predictions[-1].prediction_id
    assert model.last_prediction.input_value == batch[-1]
    assert model._last_prediction_was_batch is True
264
+
265
+
266
def test_last_prediction_with_single(model: ClassificationModel):
    """After a single-input predict, `last_prediction` is that prediction and is not flagged as batch."""
    query = "Do you love soup?"
    prediction = model.predict(query)
    assert model.last_prediction is not None
    assert model.last_prediction.prediction_id == prediction.prediction_id
    assert model.last_prediction.input_value == query
    assert model._last_prediction_was_batch is False
orca_sdk/conftest.py ADDED
@@ -0,0 +1,116 @@
1
+ import logging
2
+ import os
3
+ from typing import Generator
4
+ from uuid import uuid4
5
+
6
+ import pytest
7
+ from datasets import ClassLabel, Dataset, Features, Value
8
+
9
+ from ._utils.auth import _create_api_key, _delete_org
10
+ from .classification_model import ClassificationModel
11
+ from .credentials import OrcaCredentials
12
+ from .datasource import Datasource
13
+ from .embedding_model import PretrainedEmbeddingModel
14
+ from .memoryset import LabeledMemoryset
15
+
16
# Emit INFO-level log records during the test run.
logging.basicConfig(level=logging.INFO)

# Default ORCA_API_URL to a local server, keeping any value already set in the environment.
os.environ.setdefault("ORCA_API_URL", "http://localhost:1584/")
19
+
20
+
21
+ def _create_org_id():
22
+ # UUID start to identify test data (0xtest...)
23
+ return "10e50000-0000-4000-a000-" + str(uuid4())[24:]
24
+
25
+
26
@pytest.fixture(scope="session")
def org_id():
    """Session-wide test org id; the `api_key` fixture creates its key under this org."""
    session_org = _create_org_id()
    return session_org
29
+
30
+
31
@pytest.fixture(autouse=True, scope="session")
def api_key(org_id) -> Generator[str, None, None]:
    """Create the session's API key for the test org and install it as the active credential.

    The key is validated (`check_validity=True`) on this first install. After the session,
    the test org is deleted.
    """
    session_key = _create_api_key(org_id=org_id, name="orca_sdk_test")
    OrcaCredentials.set_api_key(session_key, check_validity=True)
    yield session_key
    # Session teardown.
    _delete_org(org_id)
37
+
38
+
39
@pytest.fixture(autouse=True)
def authenticated(api_key):
    """Before every test, reinstall the session API key without revalidating it."""
    OrcaCredentials.set_api_key(api_key, check_validity=False)
42
+
43
+
44
@pytest.fixture()
def unauthenticated(api_key):
    """Use a random UUID as the API key for one test, then restore the session key."""
    bogus_key = str(uuid4())
    OrcaCredentials.set_api_key(bogus_key, check_validity=False)
    yield
    # Put the real key back so later tests are not left unauthenticated.
    OrcaCredentials.set_api_key(api_key, check_validity=False)
50
+
51
+
52
@pytest.fixture()
def other_org_id():
    """A second, per-test org id, used by the `unauthorized` fixture."""
    second_org = _create_org_id()
    return second_org
55
+
56
+
57
@pytest.fixture()
def unauthorized(api_key, other_org_id):
    """Authenticate with a key belonging to a different org for one test.

    On teardown, restores the session key and deletes that other org.
    """
    foreign_key = _create_api_key(org_id=other_org_id, name="orca_sdk_test_other_org")
    OrcaCredentials.set_api_key(foreign_key, check_validity=False)
    yield
    OrcaCredentials.set_api_key(api_key, check_validity=False)
    _delete_org(other_org_id)
64
+
65
+
66
@pytest.fixture(scope="session")
def label_names():
    """Class names indexed by label id (0 -> "soup", 1 -> "cats")."""
    names = ["soup", "cats"]
    return names
69
+
70
+
71
# Six labeled rows backing the `hf_dataset` fixture (label 0 = soup texts, 1 = cat texts).
# The "key", "score" and "source_id" columns are derived from each row's 1-based position.
SAMPLE_DATA = [
    {"text": text, "label": label, "key": f"val{position}", "score": position / 10, "source_id": f"s{position}"}
    for position, (text, label) in enumerate(
        [
            ("i love soup", 0),
            ("cats are cute", 1),
            ("soup is good", 0),
            ("i love cats", 1),
            ("everyone loves cats", 1),
            ("soup is great for the winter", 0),
        ],
        start=1,
    )
]
79
+
80
+
81
@pytest.fixture(scope="session")
def hf_dataset(label_names):
    """Hugging Face `Dataset` built from `SAMPLE_DATA`, with `label` typed as a `ClassLabel`."""
    schema = Features(
        {
            "text": Value("string"),
            "label": ClassLabel(names=label_names),
            "key": Value("string"),
            "score": Value("float"),
            "source_id": Value("string"),
        }
    )
    return Dataset.from_list(SAMPLE_DATA, features=schema)
95
+
96
+
97
@pytest.fixture(scope="session")
def datasource(hf_dataset) -> Datasource:
    """Session-wide datasource named "test_datasource", uploaded from `hf_dataset`."""
    uploaded = Datasource.from_hf_dataset("test_datasource", hf_dataset)
    return uploaded
100
+
101
+
102
@pytest.fixture(scope="session")
def memoryset(datasource) -> LabeledMemoryset:
    """Session-wide labeled memoryset built from `datasource` with GTE-base embeddings."""
    created = LabeledMemoryset.create(
        "test_memoryset",
        datasource=datasource,
        embedding_model=PretrainedEmbeddingModel.GTE_BASE,
        value_column="text",
        source_id_column="source_id",
        max_seq_length_override=32,
    )
    return created
112
+
113
+
114
@pytest.fixture(scope="session")
def model(memoryset) -> ClassificationModel:
    """Session-wide two-class model named "test_model" that looks up 3 memories per prediction."""
    created = ClassificationModel.create("test_model", memoryset, num_classes=2, memory_lookup_count=3)
    return created
@@ -0,0 +1,126 @@
1
from collections.abc import Iterable
from datetime import datetime
from typing import Literal, NamedTuple

from ._generated_api_client.api import (
    check_authentication,
    create_api_key,
    delete_api_key,
    list_api_keys,
)
from ._generated_api_client.client import get_base_url, get_headers, set_headers
from ._generated_api_client.models import (
    CreateApiKeyRequest,
    CreateApiKeyRequestScopeItem,
)
15
+
16
# Closed set of permission scopes an API key can carry.
Scope = Literal["ADMINISTER", "PREDICT"]
"""
Permission scopes that can be granted to an API key.

- `ADMINISTER`: full access, including creating and deleting organizations, models, and API keys.
- `PREDICT`: restricted to calling model.predict and performing CRUD operations on predictions.
"""
23
+
24
+
25
class ApiKeyInfo(NamedTuple):
    """
    Named tuple containing information about an API key

    Attributes:
        name: Unique name of the API key
        created_at: When the API key was created
        scopes: The set of scopes (see `Scope`) granted to the API key
    """

    name: str
    created_at: datetime
    scopes: set[Scope]
37
+
38
+
39
class OrcaCredentials:
    """
    Class for managing Orca API credentials

    All methods are static. The API key is stored in the generated client's request
    headers (via `set_headers`), which is what subsequent API calls authenticate with.
    """

    @staticmethod
    def get_api_url() -> str:
        """
        Get the Orca API base URL that is currently being used

        Returns:
            The base URL reported by the generated API client
        """
        return get_base_url()

    @staticmethod
    def list_api_keys() -> list[ApiKeyInfo]:
        """
        List all API keys that have been created for your org

        Returns:
            A list of named tuples with the name, creation date time, and scopes of each API key
        """
        return [
            ApiKeyInfo(
                name=api_key.name,
                created_at=api_key.created_at,
                scopes={scope.value for scope in api_key.scope},
            )
            for api_key in list_api_keys()
        ]

    @staticmethod
    def is_authenticated() -> bool:
        """
        Check if you are authenticated to interact with the Orca API

        Returns:
            True if you are authenticated, False otherwise

        Raises:
            ValueError: if the authentication check fails for a reason other than an invalid API key
        """
        try:
            return check_authentication()
        except ValueError as e:
            # An invalid key is the expected "not authenticated" outcome; any other error is real.
            if "Invalid API key" in str(e):
                return False
            raise

    @staticmethod
    def create_api_key(name: str, scopes: Iterable[Scope] = ("ADMINISTER",)) -> str:
        """
        Create a new API key with the given name and scopes

        Params:
            name: The name of the API key
            scopes: The scopes of the API key (any iterable, e.g. a set); defaults to `ADMINISTER` only

        Returns:
            The secret value of the API key. Make sure to save this value as it will not be shown again.
        """
        # The default is an immutable tuple rather than a shared mutable set literal;
        # set() removes duplicates so each scope is requested at most once.
        scope_items = [CreateApiKeyRequestScopeItem(scope) for scope in set(scopes)]
        res = create_api_key(body=CreateApiKeyRequest(name=name, scope=scope_items))
        return res.api_key

    @staticmethod
    def revoke_api_key(name: str) -> None:
        """
        Delete an API key

        Params:
            name: The name of the API key to delete

        Raises:
            ValueError: if the API key is not found
        """
        delete_api_key(name_or_id=name)

    @staticmethod
    def set_api_key(api_key: str, check_validity: bool = True) -> None:
        """
        Set the API key to use for authenticating with the Orca API

        Note:
            The API key can also be provided by setting the `ORCA_API_KEY` environment variable

        Params:
            api_key: The API key to set
            check_validity: Whether to check if the API key is valid and raise an error otherwise

        Raises:
            ValueError: if the API key is invalid and `check_validity` is True
        """
        # The header is installed before validation, so an invalid key stays set even
        # when the check below raises.
        set_headers(get_headers() | {"Api-Key": api_key})
        if check_validity:
            check_authentication()
@@ -0,0 +1,37 @@
1
+ from uuid import uuid4
2
+
3
+ import pytest
4
+
5
+ from .credentials import OrcaCredentials
6
+
7
+
8
def test_list_api_keys():
    """The key used by the test session shows up when listing the org's API keys."""
    key_names = [info.name for info in OrcaCredentials.list_api_keys()]
    assert key_names
    assert "orca_sdk_test" in key_names
12
+
13
+
14
def test_list_api_keys_unauthenticated(unauthenticated):
    """Listing keys without valid credentials raises a ValueError about the API key."""
    expected_message = "Invalid API key"
    with pytest.raises(ValueError, match=expected_message):
        OrcaCredentials.list_api_keys()
17
+
18
+
19
def test_is_authenticated():
    """With the session's API key installed, the client reports being authenticated."""
    authenticated = OrcaCredentials.is_authenticated()
    assert authenticated
21
+
22
+
23
def test_is_authenticated_false(unauthenticated):
    """Under the `unauthenticated` fixture, the client reports not being authenticated."""
    authenticated = OrcaCredentials.is_authenticated()
    assert not authenticated
25
+
26
+
27
def test_set_api_key(api_key, unauthenticated):
    """Installing a valid key flips the client from unauthenticated to authenticated."""
    is_authenticated = OrcaCredentials.is_authenticated
    assert not is_authenticated()

    OrcaCredentials.set_api_key(api_key)

    assert is_authenticated()
31
+
32
+
33
def test_set_invalid_api_key(api_key):
    """Setting a random key raises, and leaves the client unauthenticated until restored."""
    assert OrcaCredentials.is_authenticated()
    try:
        with pytest.raises(ValueError, match="Invalid API key"):
            OrcaCredentials.set_api_key(str(uuid4()))
        # set_api_key installs the header before validating, so the bad key stays active.
        assert not OrcaCredentials.is_authenticated()
    finally:
        # Restore the valid key so tests that run afterwards in this process are not left
        # with the invalid key installed. (Assumes the `api_key` fixture does not already
        # restore it per test; re-setting is harmless either way.)
        OrcaCredentials.set_api_key(api_key, check_validity=False)