orca-sdk 0.0.78__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188) hide show
  1. orca_sdk/__init__.py +24 -0
  2. orca_sdk/_generated_api_client/__init__.py +3 -0
  3. orca_sdk/_generated_api_client/api/__init__.py +205 -0
  4. orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
  5. orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +130 -0
  6. orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +172 -0
  7. orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +158 -0
  8. orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +132 -0
  9. orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +129 -0
  10. orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
  11. orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +185 -0
  12. orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +172 -0
  13. orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +170 -0
  14. orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +156 -0
  15. orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +172 -0
  16. orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +158 -0
  17. orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +163 -0
  18. orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +129 -0
  19. orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +192 -0
  20. orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
  21. orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +169 -0
  22. orca_sdk/_generated_api_client/api/datasource/create_embedding_evaluation_datasource_name_or_id_embedding_evaluation_post.py +185 -0
  23. orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +158 -0
  24. orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +158 -0
  25. orca_sdk/_generated_api_client/api/datasource/get_embedding_evaluation_datasource_name_or_id_embedding_evaluation_task_id_get.py +171 -0
  26. orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +129 -0
  27. orca_sdk/_generated_api_client/api/datasource/list_embedding_evaluations_datasource_name_or_id_embedding_evaluation_get.py +237 -0
  28. orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
  29. orca_sdk/_generated_api_client/api/default/healthcheck_get.py +120 -0
  30. orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +120 -0
  31. orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
  32. orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +170 -0
  33. orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +158 -0
  34. orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +191 -0
  35. orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +158 -0
  36. orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +129 -0
  37. orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
  38. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +183 -0
  39. orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +185 -0
  40. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +170 -0
  41. orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +183 -0
  42. orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +169 -0
  43. orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +158 -0
  44. orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +171 -0
  45. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +190 -0
  46. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +171 -0
  47. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +158 -0
  48. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +186 -0
  49. orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +262 -0
  50. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +129 -0
  51. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +195 -0
  52. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +190 -0
  53. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +193 -0
  54. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +189 -0
  55. orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
  56. orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +194 -0
  57. orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +163 -0
  58. orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +129 -0
  59. orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
  60. orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +156 -0
  61. orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +158 -0
  62. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +245 -0
  63. orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
  64. orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +164 -0
  65. orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +158 -0
  66. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +159 -0
  67. orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +129 -0
  68. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +177 -0
  69. orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +173 -0
  70. orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +183 -0
  71. orca_sdk/_generated_api_client/client.py +216 -0
  72. orca_sdk/_generated_api_client/errors.py +38 -0
  73. orca_sdk/_generated_api_client/models/__init__.py +179 -0
  74. orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +116 -0
  75. orca_sdk/_generated_api_client/models/api_key_metadata.py +137 -0
  76. orca_sdk/_generated_api_client/models/api_key_metadata_scope_item.py +9 -0
  77. orca_sdk/_generated_api_client/models/base_model.py +55 -0
  78. orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +176 -0
  79. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +147 -0
  80. orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +150 -0
  81. orca_sdk/_generated_api_client/models/column_info.py +114 -0
  82. orca_sdk/_generated_api_client/models/column_type.py +14 -0
  83. orca_sdk/_generated_api_client/models/conflict_error_response.py +80 -0
  84. orca_sdk/_generated_api_client/models/create_api_key_request.py +120 -0
  85. orca_sdk/_generated_api_client/models/create_api_key_request_scope_item.py +9 -0
  86. orca_sdk/_generated_api_client/models/create_api_key_response.py +145 -0
  87. orca_sdk/_generated_api_client/models/create_api_key_response_scope_item.py +9 -0
  88. orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +279 -0
  89. orca_sdk/_generated_api_client/models/create_rac_model_request.py +209 -0
  90. orca_sdk/_generated_api_client/models/datasource_metadata.py +142 -0
  91. orca_sdk/_generated_api_client/models/delete_memories_request.py +70 -0
  92. orca_sdk/_generated_api_client/models/embed_request.py +127 -0
  93. orca_sdk/_generated_api_client/models/embedding_evaluation_request.py +179 -0
  94. orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +148 -0
  95. orca_sdk/_generated_api_client/models/embedding_evaluation_result.py +86 -0
  96. orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +9 -0
  97. orca_sdk/_generated_api_client/models/embedding_model_result.py +114 -0
  98. orca_sdk/_generated_api_client/models/evaluation_request.py +180 -0
  99. orca_sdk/_generated_api_client/models/evaluation_response.py +140 -0
  100. orca_sdk/_generated_api_client/models/feedback_type.py +9 -0
  101. orca_sdk/_generated_api_client/models/field_validation_error.py +103 -0
  102. orca_sdk/_generated_api_client/models/filter_item.py +231 -0
  103. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +15 -0
  104. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +20 -0
  105. orca_sdk/_generated_api_client/models/filter_item_op.py +16 -0
  106. orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +70 -0
  107. orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +259 -0
  108. orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +66 -0
  109. orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +166 -0
  110. orca_sdk/_generated_api_client/models/get_memories_request.py +70 -0
  111. orca_sdk/_generated_api_client/models/internal_server_error_response.py +80 -0
  112. orca_sdk/_generated_api_client/models/label_class_metrics.py +108 -0
  113. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +274 -0
  114. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +68 -0
  115. orca_sdk/_generated_api_client/models/label_prediction_result.py +115 -0
  116. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +246 -0
  117. orca_sdk/_generated_api_client/models/labeled_memory.py +197 -0
  118. orca_sdk/_generated_api_client/models/labeled_memory_insert.py +128 -0
  119. orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +68 -0
  120. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +258 -0
  121. orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +68 -0
  122. orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +68 -0
  123. orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +237 -0
  124. orca_sdk/_generated_api_client/models/labeled_memory_update.py +171 -0
  125. orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +68 -0
  126. orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +195 -0
  127. orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +9 -0
  128. orca_sdk/_generated_api_client/models/list_memories_request.py +104 -0
  129. orca_sdk/_generated_api_client/models/list_predictions_request.py +257 -0
  130. orca_sdk/_generated_api_client/models/lookup_request.py +81 -0
  131. orca_sdk/_generated_api_client/models/memory_metrics.py +156 -0
  132. orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +83 -0
  133. orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +9 -0
  134. orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +180 -0
  135. orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +66 -0
  136. orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +9 -0
  137. orca_sdk/_generated_api_client/models/not_found_error_response.py +100 -0
  138. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +21 -0
  139. orca_sdk/_generated_api_client/models/precision_recall_curve.py +94 -0
  140. orca_sdk/_generated_api_client/models/prediction_feedback.py +157 -0
  141. orca_sdk/_generated_api_client/models/prediction_feedback_category.py +115 -0
  142. orca_sdk/_generated_api_client/models/prediction_feedback_request.py +122 -0
  143. orca_sdk/_generated_api_client/models/prediction_feedback_result.py +102 -0
  144. orca_sdk/_generated_api_client/models/prediction_request.py +169 -0
  145. orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_0.py +10 -0
  146. orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_1.py +9 -0
  147. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +97 -0
  148. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +12 -0
  149. orca_sdk/_generated_api_client/models/rac_head_type.py +11 -0
  150. orca_sdk/_generated_api_client/models/rac_model_metadata.py +191 -0
  151. orca_sdk/_generated_api_client/models/roc_curve.py +94 -0
  152. orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +80 -0
  153. orca_sdk/_generated_api_client/models/task.py +198 -0
  154. orca_sdk/_generated_api_client/models/task_status.py +14 -0
  155. orca_sdk/_generated_api_client/models/task_status_info.py +133 -0
  156. orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +72 -0
  157. orca_sdk/_generated_api_client/models/unauthorized_error_response.py +80 -0
  158. orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +94 -0
  159. orca_sdk/_generated_api_client/models/update_prediction_request.py +93 -0
  160. orca_sdk/_generated_api_client/py.typed +1 -0
  161. orca_sdk/_generated_api_client/types.py +56 -0
  162. orca_sdk/_utils/__init__.py +0 -0
  163. orca_sdk/_utils/analysis_ui.py +192 -0
  164. orca_sdk/_utils/analysis_ui_style.css +54 -0
  165. orca_sdk/_utils/auth.py +68 -0
  166. orca_sdk/_utils/auth_test.py +31 -0
  167. orca_sdk/_utils/common.py +37 -0
  168. orca_sdk/_utils/data_parsing.py +99 -0
  169. orca_sdk/_utils/data_parsing_test.py +244 -0
  170. orca_sdk/_utils/prediction_result_ui.css +18 -0
  171. orca_sdk/_utils/prediction_result_ui.py +64 -0
  172. orca_sdk/_utils/task.py +73 -0
  173. orca_sdk/classification_model.py +508 -0
  174. orca_sdk/classification_model_test.py +272 -0
  175. orca_sdk/conftest.py +116 -0
  176. orca_sdk/credentials.py +126 -0
  177. orca_sdk/credentials_test.py +37 -0
  178. orca_sdk/datasource.py +333 -0
  179. orca_sdk/datasource_test.py +96 -0
  180. orca_sdk/embedding_model.py +347 -0
  181. orca_sdk/embedding_model_test.py +176 -0
  182. orca_sdk/memoryset.py +1209 -0
  183. orca_sdk/memoryset_test.py +287 -0
  184. orca_sdk/telemetry.py +398 -0
  185. orca_sdk/telemetry_test.py +109 -0
  186. orca_sdk-0.0.78.dist-info/METADATA +79 -0
  187. orca_sdk-0.0.78.dist-info/RECORD +188 -0
  188. orca_sdk-0.0.78.dist-info/WHEEL +4 -0
@@ -0,0 +1,287 @@
1
+ from uuid import uuid4
2
+
3
+ import pytest
4
+ from datasets.arrow_dataset import Dataset
5
+
6
+ from .datasource import Datasource
7
+ from .embedding_model import PretrainedEmbeddingModel
8
+ from .memoryset import LabeledMemoryset, TaskStatus
9
+
10
+
11
def test_create_memoryset(memoryset: LabeledMemoryset, hf_dataset: Dataset, label_names: list[str]):
    """The ``memoryset`` fixture has the expected name, embedding model, labels, status and size."""
    assert memoryset is not None
    assert memoryset.name == "test_memoryset"
    assert memoryset.embedding_model == PretrainedEmbeddingModel.GTE_BASE
    assert memoryset.label_names == label_names
    assert memoryset.insertion_status == TaskStatus.COMPLETED
    # length must be a plain int that matches the source dataset row count
    expected_length = len(hf_dataset)
    assert isinstance(memoryset.length, int)
    assert memoryset.length == expected_length
19
+
20
+
21
def test_create_memoryset_unauthenticated(unauthenticated, datasource):
    """With the ``unauthenticated`` fixture active, creation fails with an API-key error."""
    expect_auth_error = pytest.raises(ValueError, match="Invalid API key")
    with expect_auth_error:
        LabeledMemoryset.create("test_memoryset", datasource)
24
+
25
+
26
def test_create_memoryset_invalid_input(datasource):
    """Creation is rejected as invalid input for a malformed name or an unknown datasource."""
    import copy

    # invalid name: a name containing a space is rejected
    with pytest.raises(ValueError, match=r"Invalid input:.*"):
        LabeledMemoryset.create("test memoryset", datasource)
    # invalid datasource: point a shallow copy at a random id instead of mutating the
    # fixture itself, so other tests sharing the `datasource` fixture are unaffected
    unknown_datasource = copy.copy(datasource)
    unknown_datasource.id = str(uuid4())
    with pytest.raises(ValueError, match=r"Invalid input:.*"):
        LabeledMemoryset.create("test_memoryset_invalid_datasource", unknown_datasource)
34
+
35
+
36
def test_create_memoryset_already_exists_error(hf_dataset, label_names, memoryset):
    """Re-creating an existing memoryset fails both by default and with ``if_exists="error"``."""
    shared_kwargs = dict(label_names=label_names, value_column="text")
    # default behaviour when the name is taken
    with pytest.raises(ValueError):
        LabeledMemoryset.from_hf_dataset("test_memoryset", hf_dataset, **shared_kwargs)
    # explicit "error" mode behaves the same way
    with pytest.raises(ValueError):
        LabeledMemoryset.from_hf_dataset("test_memoryset", hf_dataset, if_exists="error", **shared_kwargs)
43
+
44
+
45
def test_create_memoryset_already_exists_open(hf_dataset, label_names, memoryset):
    """With ``if_exists="open"`` the existing memoryset is reused only when its configuration matches."""
    existing_name = memoryset.name

    # mismatching label names are rejected
    with pytest.raises(ValueError):
        LabeledMemoryset.from_hf_dataset(
            existing_name, hf_dataset, label_names=["turtles", "frogs"], value_column="text", if_exists="open"
        )

    # a different embedding model is rejected
    with pytest.raises(ValueError):
        LabeledMemoryset.from_hf_dataset(
            existing_name,
            hf_dataset,
            label_names=label_names,
            embedding_model=PretrainedEmbeddingModel.DISTILBERT,
            if_exists="open",
        )

    # the matching embedding model opens the existing memoryset
    reopened = LabeledMemoryset.from_hf_dataset(
        existing_name, hf_dataset, embedding_model=PretrainedEmbeddingModel.GTE_BASE, if_exists="open"
    )
    assert reopened is not None
    assert reopened.name == existing_name
    assert reopened.length == len(hf_dataset)
73
+
74
+
75
def test_open_memoryset(memoryset, hf_dataset):
    """A memoryset can be re-opened by name and reports the dataset's row count."""
    reopened = LabeledMemoryset.open(memoryset.name)
    assert reopened is not None
    assert reopened.name == memoryset.name
    expected_rows = len(hf_dataset)
    assert reopened.length == expected_rows
80
+
81
+
82
def test_open_memoryset_unauthenticated(unauthenticated, memoryset):
    """With the ``unauthenticated`` fixture active, opening fails with an API-key error."""
    target_name = memoryset.name
    with pytest.raises(ValueError, match="Invalid API key"):
        LabeledMemoryset.open(target_name)
85
+
86
+
87
def test_open_memoryset_not_found():
    """Opening a random, never-created id raises LookupError."""
    missing_id = str(uuid4())
    with pytest.raises(LookupError):
        LabeledMemoryset.open(missing_id)
90
+
91
+
92
def test_open_memoryset_invalid_input():
    """An identifier containing spaces is rejected as invalid input."""
    malformed_identifier = "not valid id"
    with pytest.raises(ValueError, match=r"Invalid input:.*"):
        LabeledMemoryset.open(malformed_identifier)
95
+
96
+
97
def test_open_memoryset_unauthorized(unauthorized, memoryset):
    """With the ``unauthorized`` fixture active, the fixture memoryset is reported as not found."""
    target_name = memoryset.name
    with pytest.raises(LookupError):
        LabeledMemoryset.open(target_name)
100
+
101
+
102
def test_all_memorysets(memoryset):
    """Listing all memorysets returns a non-empty list that includes the fixture memoryset."""
    memorysets = LabeledMemoryset.all()
    assert len(memorysets) > 0
    # use a distinct loop variable: reusing `memoryset` shadowed the fixture inside the
    # generator, so the old comparison was `x.name == x.name` and always true
    assert any(listed.name == memoryset.name for listed in memorysets)
106
+
107
+
108
def test_all_memorysets_unauthenticated(unauthenticated):
    """With the ``unauthenticated`` fixture active, listing fails with an API-key error."""
    expect_auth_error = pytest.raises(ValueError, match="Invalid API key")
    with expect_auth_error:
        LabeledMemoryset.all()
111
+
112
+
113
def test_all_memorysets_unauthorized(unauthorized, memoryset):
    """With the ``unauthorized`` fixture active, the fixture memoryset is not listed."""
    # compare by name: `memoryset not in LabeledMemoryset.all()` relies on an __eq__ that is
    # not visible here; if equality falls back to identity the check passes trivially, since
    # all() presumably builds fresh objects
    visible_names = {listed.name for listed in LabeledMemoryset.all()}
    assert memoryset.name not in visible_names
115
+
116
+
117
def test_drop_memoryset(hf_dataset):
    """A freshly created single-row memoryset exists until it is dropped."""
    single_row = hf_dataset.select(range(1))
    doomed = LabeledMemoryset.from_hf_dataset("test_memoryset_delete", single_row, value_column="text")
    assert LabeledMemoryset.exists(doomed.name)
    LabeledMemoryset.drop(doomed.name)
    assert not LabeledMemoryset.exists(doomed.name)
126
+
127
+
128
def test_drop_memoryset_unauthenticated(unauthenticated, memoryset):
    """With the ``unauthenticated`` fixture active, dropping fails with an API-key error."""
    target_name = memoryset.name
    with pytest.raises(ValueError, match="Invalid API key"):
        LabeledMemoryset.drop(target_name)
131
+
132
+
133
def test_drop_memoryset_not_found(memoryset):
    """Dropping an unknown memoryset raises LookupError unless ``if_not_exists="ignore"`` is given."""
    with pytest.raises(LookupError):
        LabeledMemoryset.drop(str(uuid4()))
    # the same situation is silently accepted when missing memorysets are explicitly ignored
    another_missing_id = str(uuid4())
    LabeledMemoryset.drop(another_missing_id, if_not_exists="ignore")
138
+
139
+
140
def test_drop_memoryset_unauthorized(unauthorized, memoryset):
    """With the ``unauthorized`` fixture active, dropping the fixture memoryset raises LookupError."""
    target_name = memoryset.name
    with pytest.raises(LookupError):
        LabeledMemoryset.drop(target_name)
143
+
144
+
145
def test_search(memoryset: LabeledMemoryset):
    """A batch search without ``count`` returns one single-element lookup list per query."""
    memory_lookups = memoryset.search(["i love soup", "cats are cute"])
    assert len(memory_lookups) == 2
    expected_labels = (0, 1)
    assert all(len(lookups) == 1 for lookups in memory_lookups)
    for lookups, expected_label in zip(memory_lookups, expected_labels):
        assert lookups[0].label == expected_label
152
+
153
+
154
def test_search_count(memoryset: LabeledMemoryset):
    """A single query with ``count=3`` returns a flat list of three lookups, all labelled 0."""
    memory_lookups = memoryset.search("i love soup", count=3)
    assert len(memory_lookups) == 3
    for lookup in memory_lookups:
        assert lookup.label == 0
160
+
161
+
162
def test_get_memory_at_index(memoryset: LabeledMemoryset, hf_dataset: Dataset, label_names: list[str]):
    """Integer indexing (including negative) returns memories mirroring the source dataset rows."""
    first = memoryset[0]
    first_row = hf_dataset[0]
    assert first.value == first_row["text"]
    assert first.label == first_row["label"]
    assert first.label_name == label_names[first_row["label"]]
    # these columns are copied verbatim from the dataset row
    for field in ("source_id", "score", "key"):
        assert getattr(first, field) == first_row[field]

    last = memoryset[-1]
    last_row = hf_dataset[-1]
    assert last.value == last_row["text"]
    assert last.label == last_row["label"]
173
+
174
+
175
def test_get_range_of_memories(memoryset: LabeledMemoryset, hf_dataset: Dataset):
    """Slicing returns the memories for the corresponding dataset rows, in order."""
    memories = memoryset[1:3]
    assert len(memories) == 2
    texts = hf_dataset["text"]
    for offset, memory in enumerate(memories, start=1):
        assert memory.value == texts[offset]
180
+
181
+
182
def test_get_memory_by_id(memoryset: LabeledMemoryset, hf_dataset: Dataset):
    """``get`` with a single id and ``memoryset[id]`` return equal memories."""
    first_id = memoryset[0].memory_id
    fetched = memoryset.get(first_id)
    assert fetched.value == hf_dataset[0]["text"]
    assert fetched == memoryset[fetched.memory_id]
186
+
187
+
188
def test_get_memories_by_id(memoryset: LabeledMemoryset, hf_dataset: Dataset):
    """``get`` with a list of ids returns the memories in the requested order."""
    requested_ids = [memoryset[0].memory_id, memoryset[1].memory_id]
    fetched = memoryset.get(requested_ids)
    assert len(fetched) == 2
    for row_index, memory in enumerate(fetched):
        assert memory.value == hf_dataset[row_index]["text"]
193
+
194
+
195
def test_query_memoryset(memoryset: LabeledMemoryset):
    """``query`` supports label filters, a result limit and metadata filters."""
    label_one = memoryset.query(filters=[("label", "==", 1)])
    assert len(label_one) == 3
    assert all(memory.label == 1 for memory in label_one)

    limited = memoryset.query(limit=2)
    assert len(limited) == 2

    by_metadata = memoryset.query(filters=[("metadata.key", "==", "val1")])
    assert len(by_metadata) == 1
201
+
202
+
203
def test_insert_memories(memoryset: LabeledMemoryset):
    """``insert`` accepts a list or a single dict; extra keys such as ``key`` land in metadata."""
    prev_length = memoryset.length

    batch = [
        {"value": "tomato soup is my favorite", "label": 0},
        {"value": "cats are fun to play with", "label": 1},
    ]
    memoryset.insert(batch)
    assert memoryset.length == prev_length + 2

    single = {"value": "tomato soup is my favorite", "label": 0, "key": "test", "source_id": "test"}
    memoryset.insert(single)
    assert memoryset.length == prev_length + 3

    newest = memoryset[-1]
    assert newest.value == "tomato soup is my favorite"
    assert newest.label == 0
    assert newest.metadata
    assert newest.metadata["key"] == "test"
    assert newest.source_id == "test"
220
+
221
+
222
def test_update_memory(memoryset: LabeledMemoryset, hf_dataset: Dataset):
    """Updating only the value keeps the label and is persisted server-side."""
    new_value = "i love soup so much"
    memory_id = memoryset[0].memory_id
    updated = memoryset.update({"memory_id": memory_id, "value": new_value})
    assert updated.value == new_value
    assert updated.label == hf_dataset[0]["label"]
    refetched = memoryset.get(memory_id)
    assert refetched.value == new_value
228
+
229
+
230
def test_update_memory_instance(memoryset: LabeledMemoryset, hf_dataset: Dataset):
    """``memory.update`` changes the instance in place and returns that same instance."""
    memory = memoryset[0]
    new_value = "i love soup even more"
    returned = memory.update(value=new_value)
    assert returned is memory
    assert memory.value == new_value
    assert memory.label == hf_dataset[0]["label"]
236
+
237
+
238
def test_update_memories(memoryset: LabeledMemoryset):
    """A list passed to ``update`` updates several memories and returns them in order."""
    first_id, second_id = (memory.memory_id for memory in memoryset[:2])
    changes = [
        {"memory_id": first_id, "value": "i love soup so much"},
        {"memory_id": second_id, "value": "cats are so cute"},
    ]
    updated = memoryset.update(changes)
    assert updated[0].value == "i love soup so much"
    assert updated[1].value == "cats are so cute"
248
+
249
+
250
def test_delete_memory(memoryset: LabeledMemoryset):
    """A deleted memory can no longer be fetched and the memoryset shrinks by one."""
    prev_length = memoryset.length
    doomed_id = memoryset[0].memory_id
    memoryset.delete(doomed_id)
    with pytest.raises(LookupError):
        memoryset.get(doomed_id)
    assert memoryset.length == prev_length - 1
257
+
258
+
259
def test_delete_memories(memoryset: LabeledMemoryset):
    """Deleting a list of ids removes that many memories."""
    prev_length = memoryset.length
    doomed_ids = [memoryset[index].memory_id for index in (0, 1)]
    memoryset.delete(doomed_ids)
    assert memoryset.length == prev_length - 2
263
+
264
+
265
def test_clone_memoryset(memoryset: LabeledMemoryset):
    """Cloning with another embedding model yields a completed memoryset of the same length."""
    clone_name = "test_cloned_memoryset"
    target_model = PretrainedEmbeddingModel.DISTILBERT
    cloned = memoryset.clone(clone_name, embedding_model=target_model)
    assert cloned is not None
    assert cloned.name == clone_name
    assert cloned.length == memoryset.length
    assert cloned.embedding_model == target_model
    assert cloned.insertion_status == TaskStatus.COMPLETED
272
+
273
+
274
def test_embedding_evaluation(hf_dataset):
    """Run an embedding evaluation on a temporary datasource and check the single result."""
    datasource = Datasource.from_hf_dataset("eval_datasource", hf_dataset, if_exists="open")
    try:
        response = LabeledMemoryset.run_embedding_evaluation(
            datasource, embedding_models=["CDE_SMALL"], neighbor_count=2, value_column="text"
        )
        assert response is not None
        assert isinstance(response, dict)
        results = response["evaluation_results"]
        assert isinstance(results, list)
        assert len(results) == 1
        result = results[0]
        assert result is not None
        assert result["embedding_model_name"] == "CDE_SMALL"
        assert result["embedding_model_path"] == "OrcaDB/cde-small-v1"
    finally:
        # drop even when an assertion above fails, so the datasource does not leak
        Datasource.drop("eval_datasource")
orca_sdk/telemetry.py ADDED
@@ -0,0 +1,398 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from datetime import datetime
5
+ from typing import TYPE_CHECKING, Any, Iterable, overload
6
+ from uuid import UUID
7
+
8
+ from orca_sdk._utils.common import UNSET
9
+
10
+ from ._generated_api_client.api import (
11
+ drop_feedback_category_with_data,
12
+ get_prediction,
13
+ list_feedback_categories,
14
+ list_predictions,
15
+ record_prediction_feedback,
16
+ update_prediction,
17
+ )
18
+ from ._generated_api_client.models import (
19
+ FeedbackType,
20
+ LabelPredictionWithMemoriesAndFeedback,
21
+ ListPredictionsRequest,
22
+ PredictionFeedbackCategory,
23
+ PredictionFeedbackRequest,
24
+ UpdatePredictionRequest,
25
+ )
26
+ from ._generated_api_client.types import UNSET as CLIENT_UNSET
27
+ from ._utils.prediction_result_ui import inspect_prediction_result
28
+ from .memoryset import LabeledMemoryLookup, LabeledMemoryset
29
+
30
+ if TYPE_CHECKING:
31
+ from .classification_model import ClassificationModel
32
+
33
+
34
+ def _parse_feedback(feedback: dict[str, Any]) -> PredictionFeedbackRequest:
35
+ category = feedback.get("category", None)
36
+ if category is None:
37
+ raise ValueError("`category` must be specified")
38
+ prediction_id = feedback.get("prediction_id", None)
39
+ if prediction_id is None:
40
+ raise ValueError("`prediction_id` must be specified")
41
+ return PredictionFeedbackRequest(
42
+ prediction_id=prediction_id,
43
+ category_name=category,
44
+ value=feedback.get("value", CLIENT_UNSET),
45
+ comment=feedback.get("comment", CLIENT_UNSET),
46
+ )
47
+
48
+
49
class FeedbackCategory:
    """
    A category of feedback for predictions.

    Categories are created automatically, the first time feedback with a new name is recorded.
    The value type of the category is inferred from the first recorded value. Subsequent feedback
    for the same category must be of the same type. Categories are not model specific.

    Attributes:
        id: Unique identifier for the category.
        name: Name of the category.
        value_type: Type that values for this category must have.
        created_at: When the category was created.
    """

    id: str
    name: str
    value_type: type[bool] | type[float]
    created_at: datetime

    def __init__(self, category: PredictionFeedbackCategory):
        # for internal use only, do not document
        self.id = category.id
        self.name = category.name
        # BINARY categories hold bools; any other category type is exposed as float
        self.value_type = bool if category.type == FeedbackType.BINARY else float
        self.created_at = category.created_at

    @classmethod
    def all(cls) -> list[FeedbackCategory]:
        """
        Get a list of all existing feedback categories.

        Returns:
            List with information about all existing feedback categories.
        """
        return [cls(category) for category in list_feedback_categories()]

    @classmethod
    def drop(cls, name: str) -> None:
        """
        Drop all feedback for this category and drop the category itself, allowing it to be
        recreated with a different value type.

        Warning:
            This will delete all feedback in this category across all models.

        Params:
            name: Name of the category to drop.

        Raises:
            LookupError: If the category is not found.
        """
        drop_feedback_category_with_data(name)
        # Use a named logger for this module instead of the root-level `logging.info` helper:
        # that helper calls `logging.basicConfig()` when the root logger has no handlers,
        # which would configure logging on behalf of the application using this SDK.
        logging.getLogger(__name__).info("Deleted feedback category %s with all associated feedback", name)

    def __repr__(self):
        return "FeedbackCategory({" + f"name: {self.name}, " + f"value_type: {self.value_type}" + "})"
106
+
107
+
108
class LabelPrediction:
    """
    A prediction made by a model

    Attributes:
        prediction_id: Unique identifier for the prediction
        label: Predicted label for the input value
        label_name: Name of the predicted label
        confidence: Confidence of the prediction
        anomaly_score: The score for how anomalous the input is relative to the memories
        memory_lookups: List of memories used to ground the prediction
        input_value: Input value that this prediction was for
        model: Model that was used to make the prediction
        memoryset: Memoryset that was used to lookup memories to ground the prediction
        expected_label: Optional expected label that was set for the prediction
        tags: tags that were set for the prediction
        feedback: Feedback recorded, mapping from category name to value
    """

    prediction_id: str
    label: int
    label_name: str | None
    confidence: float
    anomaly_score: float | None
    memoryset: LabeledMemoryset
    model: ClassificationModel

    def __init__(
        self,
        prediction_id: str,
        *,
        label: int,
        label_name: str | None,
        confidence: float,
        anomaly_score: float | None,
        memoryset: LabeledMemoryset | str,
        model: ClassificationModel | str,
        telemetry: LabelPredictionWithMemoriesAndFeedback | None = None,
    ):
        # for internal use only, do not document
        # Imported locally because the module-level import only exists under TYPE_CHECKING
        # (presumably to avoid a circular import).
        from .classification_model import ClassificationModel

        self.prediction_id = prediction_id
        self.label = label
        self.label_name = label_name
        self.confidence = confidence
        self.anomaly_score = anomaly_score
        # ids/names are resolved to full objects eagerly via `open`
        self.memoryset = LabeledMemoryset.open(memoryset) if isinstance(memoryset, str) else memoryset
        self.model = ClassificationModel.open(model) if isinstance(model, str) else model
        # When not supplied, telemetry is fetched lazily by the `_telemetry` property.
        self.__telemetry = telemetry if telemetry else None

    def __repr__(self):
        input_value = str(self.input_value)
        if len(input_value) > 100:
            input_value = input_value[:100] + "..."
        # Build the optional part separately: an unparenthesized conditional expression inside
        # the `+` chain would bind to the whole concatenation and drop the surrounding fields.
        anomaly_score = f"anomaly_score: {self.anomaly_score:.2f}, " if self.anomaly_score is not None else ""
        return (
            "LabelPrediction({"
            + f"label: <{self.label_name}: {self.label}>, "
            + f"confidence: {self.confidence:.2f}, "
            + anomaly_score
            + f"input_value: '{input_value}'"
            + "})"
        )

    @property
    def _telemetry(self) -> LabelPredictionWithMemoriesAndFeedback:
        # for internal use only, do not document
        if self.__telemetry is None:
            self.__telemetry = get_prediction(prediction_id=UUID(self.prediction_id))
        return self.__telemetry

    @property
    def memory_lookups(self) -> list[LabeledMemoryLookup]:
        """Memories that were looked up to ground this prediction."""
        return [LabeledMemoryLookup(self.memoryset.id, lookup) for lookup in self._telemetry.memories]

    @property
    def input_value(self) -> str | None:
        """Input value that this prediction was made for."""
        return self._telemetry.input_value

    @property
    def feedback(self) -> dict[str, bool | float]:
        """Recorded feedback, mapping category name to value (bool for binary, float for continuous)."""
        return {
            # non-continuous (binary) feedback is converted to a bool: a stored value of 1 means True
            f.category_name: (f.value if f.category_type == FeedbackType.CONTINUOUS else f.value == 1)
            for f in self._telemetry.feedbacks
        }

    @property
    def expected_label(self) -> int | None:
        """Expected label set for this prediction, if any."""
        return self._telemetry.expected_label

    @property
    def tags(self) -> set[str]:
        """Tags set for this prediction."""
        return set(self._telemetry.tags)

    @classmethod
    def _from_telemetry(cls, prediction: LabelPredictionWithMemoriesAndFeedback) -> LabelPrediction:
        # for internal use only, do not document
        # Builds an instance from an API prediction record, reusing the record as cached telemetry.
        return cls(
            prediction_id=prediction.prediction_id,
            label=prediction.label,
            label_name=prediction.label_name,
            confidence=prediction.confidence,
            anomaly_score=prediction.anomaly_score,
            memoryset=prediction.memoryset_id,
            model=prediction.model_id,
            telemetry=prediction,
        )

    @overload
    @classmethod
    def get(cls, prediction_id: str) -> LabelPrediction:  # type: ignore -- this takes precedence
        pass

    @overload
    @classmethod
    def get(cls, prediction_id: Iterable[str]) -> list[LabelPrediction]:
        pass

    @classmethod
    def get(cls, prediction_id: str | Iterable[str]) -> LabelPrediction | list[LabelPrediction]:
        """
        Fetch a prediction or predictions

        Params:
            prediction_id: Unique identifier of the prediction or predictions to fetch

        Returns:
            Prediction or list of predictions

        Raises:
            LookupError: If no prediction with the given id is found

        Examples:
            Fetch a single prediction:
            >>> LabelPrediction.get("0195019a-5bc7-7afb-b902-5945ee1fb766")
            LabelPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.10, input_value: 'I am happy'})

            Fetch multiple predictions:
            >>> LabelPrediction.get([
            ...     "0195019a-5bc7-7afb-b902-5945ee1fb766",
            ...     "019501a1-ea08-76b2-9f62-95e4800b4841",
            ... ])
            [LabelPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.10, input_value: 'I am happy'}),
             LabelPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.20, input_value: 'I am sad'})]
        """
        if isinstance(prediction_id, str):
            return cls._from_telemetry(get_prediction(prediction_id=UUID(prediction_id)))
        request = ListPredictionsRequest(prediction_ids=list(prediction_id))
        return [cls._from_telemetry(prediction) for prediction in list_predictions(body=request)]

    def refresh(self):
        """Refresh the prediction data from the OrcaCloud"""
        # Only the prediction record is refetched. The already-resolved memoryset and model are
        # kept: they are presumably fixed for a prediction, and the mutating methods of this class
        # (`update`, `record_feedback`, `delete_feedback`) only touch expected label, tags and
        # feedback. Rebuilding via `get` would re-open both with extra API calls.
        prediction = get_prediction(prediction_id=UUID(self.prediction_id))
        self.label = prediction.label
        self.label_name = prediction.label_name
        self.confidence = prediction.confidence
        self.anomaly_score = prediction.anomaly_score
        self.__telemetry = prediction

    def inspect(self):
        """Open a UI to inspect the memories used by this prediction"""
        inspect_prediction_result(self)

    def update(self, *, expected_label: int | None = UNSET, tags: set[str] | None = UNSET) -> None:
        """
        Update editable prediction properties, then refresh the local data.

        Params:
            expected_label: Value to set for the expected label. Pass `None` to remove it; if
                omitted, the expected label is left unchanged.
            tags: Tags to replace the existing tags with. Pass `None` to remove all tags; if
                omitted, the tags are left unchanged.

        Examples:
            Update the expected label:
            >>> prediction.update(expected_label=1)

            Add a new tag:
            >>> prediction.update(tags=prediction.tags | {"new_tag"})

            Remove expected label and tags:
            >>> prediction.update(expected_label=None, tags=None)
        """
        update_prediction(
            prediction_id=self.prediction_id,
            body=UpdatePredictionRequest(
                # the SDK's UNSET sentinel is translated to the API client's UNSET ("don't change")
                expected_label=expected_label if expected_label is not UNSET else CLIENT_UNSET,
                # `None` clears all tags by sending an empty list
                tags=[] if tags is None else list(tags) if tags is not UNSET else CLIENT_UNSET,
            ),
        )
        self.refresh()

    def add_tag(self, tag: str) -> None:
        """
        Add a tag to the prediction

        Params:
            tag: Tag to add to the prediction
        """
        self.update(tags=self.tags | {tag})

    def remove_tag(self, tag: str) -> None:
        """
        Remove a tag from the prediction

        Params:
            tag: Tag to remove from the prediction
        """
        self.update(tags=self.tags - {tag})

    def record_feedback(
        self,
        category: str,
        value: bool | float,
        *,
        comment: str | None = None,
    ):
        """
        Record feedback for the prediction.

        We support recording feedback in several categories for each prediction. A
        [`FeedbackCategory`][orca_sdk.telemetry.FeedbackCategory] is created automatically,
        the first time feedback with a new name is recorded. Categories are global across models.
        The value type of the category is inferred from the first recorded value. Subsequent
        feedback for the same category must be of the same type.

        Params:
            category: Name of the category under which to record the feedback.
            value: Feedback value to record, should be `True` for positive feedback and `False` for
                negative feedback or a [`float`][float] between `-1.0` and `+1.0` where negative
                values indicate negative feedback and positive values indicate positive feedback.
            comment: Optional comment to record with the feedback.

        Examples:
            Record whether a suggestion was accepted or rejected:
            >>> prediction.record_feedback("accepted", True)

            Record star rating as normalized continuous score between `-1.0` and `+1.0`:
            >>> prediction.record_feedback("rating", -0.5, comment="2 stars")

        Raises:
            ValueError: If the value does not match previous value types for the category, or is a
                [`float`][float] that is not between `-1.0` and `+1.0`.
        """
        record_prediction_feedback(
            body=[
                _parse_feedback(
                    {"prediction_id": self.prediction_id, "category": category, "value": value, "comment": comment}
                )
            ]
        )
        self.refresh()

    def delete_feedback(self, category: str) -> None:
        """
        Delete prediction feedback for a specific category.

        Params:
            category: Name of the category of the feedback to delete.

        Raises:
            ValueError: If the category is not found.
        """
        # deletion is expressed as recording feedback with a `None` value
        record_prediction_feedback(
            body=[PredictionFeedbackRequest(prediction_id=self.prediction_id, category_name=category, value=None)]
        )
        self.refresh()