orca-sdk 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186)
  1. orca_sdk/__init__.py +10 -4
  2. orca_sdk/_shared/__init__.py +10 -0
  3. orca_sdk/_shared/metrics.py +393 -0
  4. orca_sdk/_shared/metrics_test.py +273 -0
  5. orca_sdk/_utils/analysis_ui.py +12 -10
  6. orca_sdk/_utils/analysis_ui_style.css +0 -3
  7. orca_sdk/_utils/auth.py +31 -29
  8. orca_sdk/_utils/data_parsing.py +28 -2
  9. orca_sdk/_utils/data_parsing_test.py +15 -15
  10. orca_sdk/_utils/pagination.py +126 -0
  11. orca_sdk/_utils/pagination_test.py +132 -0
  12. orca_sdk/_utils/prediction_result_ui.py +67 -21
  13. orca_sdk/_utils/tqdm_file_reader.py +12 -0
  14. orca_sdk/_utils/value_parser.py +45 -0
  15. orca_sdk/_utils/value_parser_test.py +39 -0
  16. orca_sdk/async_client.py +3795 -0
  17. orca_sdk/classification_model.py +601 -129
  18. orca_sdk/classification_model_test.py +415 -117
  19. orca_sdk/client.py +3787 -0
  20. orca_sdk/conftest.py +184 -38
  21. orca_sdk/credentials.py +162 -20
  22. orca_sdk/credentials_test.py +100 -16
  23. orca_sdk/datasource.py +268 -68
  24. orca_sdk/datasource_test.py +266 -18
  25. orca_sdk/embedding_model.py +434 -82
  26. orca_sdk/embedding_model_test.py +66 -33
  27. orca_sdk/job.py +343 -0
  28. orca_sdk/job_test.py +108 -0
  29. orca_sdk/memoryset.py +1690 -324
  30. orca_sdk/memoryset_test.py +456 -119
  31. orca_sdk/regression_model.py +694 -0
  32. orca_sdk/regression_model_test.py +378 -0
  33. orca_sdk/telemetry.py +460 -143
  34. orca_sdk/telemetry_test.py +43 -24
  35. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.3.dist-info}/METADATA +34 -16
  36. orca_sdk-0.1.3.dist-info/RECORD +41 -0
  37. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.3.dist-info}/WHEEL +1 -1
  38. orca_sdk/_generated_api_client/__init__.py +0 -3
  39. orca_sdk/_generated_api_client/api/__init__.py +0 -193
  40. orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
  41. orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
  42. orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
  43. orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
  44. orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
  45. orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
  46. orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
  47. orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
  48. orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
  49. orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
  50. orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
  51. orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
  52. orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
  53. orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
  54. orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
  55. orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
  56. orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
  57. orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
  58. orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
  59. orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
  60. orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
  61. orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
  62. orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
  63. orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
  64. orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
  65. orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
  66. orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
  67. orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
  68. orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
  69. orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
  70. orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
  71. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
  72. orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
  73. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
  74. orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
  75. orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
  76. orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
  77. orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
  78. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
  79. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
  80. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
  81. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
  82. orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
  83. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
  84. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
  85. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
  86. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
  87. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
  88. orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
  89. orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
  90. orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
  91. orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
  92. orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
  93. orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
  94. orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
  95. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
  96. orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
  97. orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
  98. orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
  99. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
  100. orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
  101. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
  102. orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
  103. orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
  104. orca_sdk/_generated_api_client/client.py +0 -216
  105. orca_sdk/_generated_api_client/errors.py +0 -38
  106. orca_sdk/_generated_api_client/models/__init__.py +0 -159
  107. orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
  108. orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
  109. orca_sdk/_generated_api_client/models/base_model.py +0 -55
  110. orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
  111. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
  112. orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
  113. orca_sdk/_generated_api_client/models/column_info.py +0 -114
  114. orca_sdk/_generated_api_client/models/column_type.py +0 -14
  115. orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
  116. orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
  117. orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
  118. orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
  119. orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
  120. orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
  121. orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
  122. orca_sdk/_generated_api_client/models/embed_request.py +0 -127
  123. orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
  124. orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
  125. orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
  126. orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
  127. orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
  128. orca_sdk/_generated_api_client/models/filter_item.py +0 -231
  129. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
  130. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
  131. orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
  132. orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
  133. orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
  134. orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
  135. orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
  136. orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
  137. orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
  138. orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
  139. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
  140. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
  141. orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
  142. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
  143. orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
  144. orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
  145. orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
  146. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
  147. orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
  148. orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
  149. orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
  150. orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
  151. orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
  152. orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
  153. orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
  154. orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
  155. orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
  156. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
  157. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
  158. orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
  159. orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
  160. orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
  161. orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
  162. orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
  163. orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
  164. orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
  165. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
  166. orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
  167. orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
  168. orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
  169. orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
  170. orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
  171. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
  172. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
  173. orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
  174. orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
  175. orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
  176. orca_sdk/_generated_api_client/models/task.py +0 -198
  177. orca_sdk/_generated_api_client/models/task_status.py +0 -14
  178. orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
  179. orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
  180. orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
  181. orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
  182. orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
  183. orca_sdk/_generated_api_client/py.typed +0 -1
  184. orca_sdk/_generated_api_client/types.py +0 -56
  185. orca_sdk/_utils/task.py +0 -73
  186. orca_sdk-0.1.1.dist-info/RECORD +0 -175
@@ -0,0 +1,378 @@
1
+ from uuid import uuid4
2
+
3
+ import numpy as np
4
+ import pytest
5
+ from datasets.arrow_dataset import Dataset
6
+
7
+ from .datasource import Datasource
8
+ from .embedding_model import PretrainedEmbeddingModel
9
+ from .memoryset import ScoredMemoryset
10
+ from .regression_model import RegressionMetrics, RegressionModel
11
+ from .telemetry import RegressionPrediction
12
+
13
+
14
def test_create_model(regression_model: RegressionModel, scored_memoryset: ScoredMemoryset):
    """Check the fixture model's name, attached memoryset and lookup count."""
    model = regression_model
    assert model is not None
    # One assertion per attribute so a failure names the field that is wrong.
    assert model.name == "test_regression_model"
    assert model.memoryset == scored_memoryset
    assert model.memory_lookup_count == 3
19
+
20
+
21
def test_create_model_already_exists_error(scored_memoryset, regression_model: RegressionModel):
    """Creating under an existing name raises ``ValueError``, implicitly and with ``if_exists="error"``."""
    taken_name = "test_regression_model"

    # No if_exists argument: the default must already refuse the duplicate.
    with pytest.raises(ValueError):
        RegressionModel.create(taken_name, scored_memoryset)

    # Same outcome when the policy is spelled out.
    with pytest.raises(ValueError):
        RegressionModel.create(taken_name, scored_memoryset, if_exists="error")
26
+
27
+
28
def test_create_model_already_exists_return(scored_memoryset, regression_model: RegressionModel):
    """``if_exists="open"`` hands back the existing model, unless a conflicting lookup count is requested."""
    existing_name = "test_regression_model"

    # Requesting a lookup count of 37 while opening the existing model is rejected.
    with pytest.raises(ValueError):
        RegressionModel.create(
            existing_name,
            scored_memoryset,
            if_exists="open",
            memory_lookup_count=37,
        )

    reopened = RegressionModel.create(existing_name, scored_memoryset, if_exists="open")
    assert reopened is not None
    assert reopened.name == "test_regression_model"
    assert reopened.memoryset == scored_memoryset
    assert reopened.memory_lookup_count == 3
37
+
38
+
39
def test_create_model_unauthenticated(unauthenticated_client, scored_memoryset: ScoredMemoryset):
    """Under the ``unauthenticated_client`` fixture, creating a model raises ``ValueError`` ("Invalid API key")."""
    with unauthenticated_client.use(), pytest.raises(ValueError, match="Invalid API key"):
        RegressionModel.create("test_regression_model", scored_memoryset)
43
+
44
+
45
def test_get_model(regression_model: RegressionModel):
    """Opening the fixture model by name returns an equal object with matching id, name and lookup count."""
    reopened = RegressionModel.open(regression_model.name)
    assert reopened is not None
    assert reopened.id == regression_model.id
    assert reopened.name == regression_model.name
    assert reopened.memory_lookup_count == 3
    # Whole-object equality on top of the field-by-field checks above.
    assert reopened == regression_model
52
+
53
+
54
def test_get_model_unauthenticated(unauthenticated_client):
    """Under the ``unauthenticated_client`` fixture, opening a model raises ``ValueError`` ("Invalid API key")."""
    with unauthenticated_client.use(), pytest.raises(ValueError, match="Invalid API key"):
        RegressionModel.open("test_regression_model")
58
+
59
+
60
def test_get_model_invalid_input():
    """Opening with the string ``"not valid id"`` raises ``ValueError`` matching "Invalid input"."""
    malformed_identifier = "not valid id"
    with pytest.raises(ValueError, match="Invalid input"):
        RegressionModel.open(malformed_identifier)
63
+
64
+
65
def test_get_model_not_found():
    """Opening a freshly generated random UUID raises ``LookupError``."""
    unknown_id = str(uuid4())
    with pytest.raises(LookupError):
        RegressionModel.open(unknown_id)
68
+
69
+
70
def test_get_model_unauthorized(unauthorized_client, regression_model: RegressionModel):
    """Under the ``unauthorized_client`` fixture, opening the fixture model by name raises ``LookupError``."""
    with unauthorized_client.use(), pytest.raises(LookupError):
        RegressionModel.open(regression_model.name)
74
+
75
+
76
def test_list_models(regression_model: RegressionModel):
    """``RegressionModel.all()`` returns a non-empty result that includes the fixture model's name."""
    listed = RegressionModel.all()
    assert len(listed) > 0
    assert any(entry.name == regression_model.name for entry in listed)
80
+
81
+
82
def test_list_models_unauthenticated(unauthenticated_client):
    """Under the ``unauthenticated_client`` fixture, listing models raises ``ValueError`` ("Invalid API key")."""
    with unauthenticated_client.use(), pytest.raises(ValueError, match="Invalid API key"):
        RegressionModel.all()
86
+
87
+
88
def test_list_models_unauthorized(unauthorized_client, regression_model: RegressionModel):
    """Under the ``unauthorized_client`` fixture, listing models yields an empty list even though a model exists."""
    with unauthorized_client.use():
        visible_models = RegressionModel.all()
        assert visible_models == []
91
+
92
+
93
def test_update_model_attributes(regression_model: RegressionModel):
    """Description and lock state can be changed via assignment, ``set()``, ``lock()`` and ``unlock()``."""
    model = regression_model

    # Description: assign directly, then clear it through set().
    model.description = "New description"
    assert model.description == "New description"
    model.set(description=None)
    assert model.description is None

    # Lock flag toggled through set() ...
    model.set(locked=True)
    assert model.locked is True
    model.set(locked=False)
    assert model.locked is False

    # ... and through the dedicated lock()/unlock() helpers.
    model.lock()
    assert model.locked is True
    model.unlock()
    assert model.locked is False
111
+
112
+
113
def test_delete_model(scored_memoryset: ScoredMemoryset):
    """A model created for this test can be opened, dropped, and is then no longer found."""
    doomed_name = "regression_model_to_delete"
    memoryset_handle = ScoredMemoryset.open(scored_memoryset.name)

    RegressionModel.create(doomed_name, memoryset_handle)
    assert RegressionModel.open(doomed_name)

    RegressionModel.drop(doomed_name)
    # After the drop, opening by the same name must fail.
    with pytest.raises(LookupError):
        RegressionModel.open(doomed_name)
119
+
120
+
121
def test_delete_model_unauthenticated(unauthenticated_client, regression_model: RegressionModel):
    """Under the ``unauthenticated_client`` fixture, dropping a model raises ``ValueError`` ("Invalid API key")."""
    with unauthenticated_client.use(), pytest.raises(ValueError, match="Invalid API key"):
        RegressionModel.drop(regression_model.name)
125
+
126
+
127
def test_delete_model_not_found():
    """Dropping an unknown id raises ``LookupError`` unless ``if_not_exists="ignore"`` is passed."""
    with pytest.raises(LookupError):
        RegressionModel.drop(str(uuid4()))

    # With the "ignore" policy the same situation must not raise.
    another_unknown_id = str(uuid4())
    RegressionModel.drop(another_unknown_id, if_not_exists="ignore")
132
+
133
+
134
def test_delete_model_unauthorized(unauthorized_client, regression_model: RegressionModel):
    """Under the ``unauthorized_client`` fixture, dropping the fixture model raises ``LookupError``."""
    with unauthorized_client.use(), pytest.raises(LookupError):
        RegressionModel.drop(regression_model.name)
138
+
139
+
140
def test_delete_memoryset_before_model_constraint_violation(hf_dataset):
    """Dropping a memoryset that a regression model was created on raises ``RuntimeError``."""
    backing_memoryset = ScoredMemoryset.from_hf_dataset(
        "test_memoryset_delete_before_regression_model",
        hf_dataset,
    )
    RegressionModel.create("test_regression_model_delete_before_memoryset", backing_memoryset)

    # The model still references the memoryset, so the drop is expected to be refused.
    with pytest.raises(RuntimeError):
        ScoredMemoryset.drop(backing_memoryset.id)
145
+
146
+
147
@pytest.mark.parametrize("data_type", ["dataset", "datasource"])
def test_evaluate(
    regression_model: RegressionModel,
    eval_datasource: Datasource,
    eval_dataset: Dataset,
    data_type,
):
    """Evaluate the model on either a dataset or a datasource and sanity-check the metrics."""
    # The parametrized flag selects which of the two evaluation inputs is used.
    eval_input = eval_dataset if data_type == "dataset" else eval_datasource
    metrics = regression_model.evaluate(eval_input)

    assert isinstance(metrics, RegressionMetrics)
    assert np.allclose(metrics.mae, 0.4)
    assert 0.0 <= metrics.mse <= 1.0
    assert 0.0 <= metrics.rmse <= 1.0
    assert metrics.r2 is not None

    # Anomaly-score statistics: first their type ...
    assert isinstance(metrics.anomaly_score_mean, float)
    assert isinstance(metrics.anomaly_score_median, float)
    assert isinstance(metrics.anomaly_score_variance, float)
    # ... then their range.
    assert -1.0 <= metrics.anomaly_score_mean <= 1.0
    assert -1.0 <= metrics.anomaly_score_median <= 1.0
    assert -1.0 <= metrics.anomaly_score_variance <= 1.0
173
+
174
+
175
def test_evaluate_datasource_with_nones_raises_error(regression_model: RegressionModel, datasource: Datasource):
    """Evaluating on the ``datasource`` fixture raises ``ValueError``.

    NOTE(review): per the test name the fixture presumably contains rows with
    missing values; confirm against the fixture definition.
    """
    with pytest.raises(ValueError):
        regression_model.evaluate(datasource, record_predictions=True, tags={"test"})
178
+
179
+
180
def test_evaluate_dataset_with_nones_raises_error(regression_model: RegressionModel, hf_dataset: Dataset):
    """Evaluating on the ``hf_dataset`` fixture raises ``ValueError``.

    NOTE(review): per the test name the fixture presumably contains rows with
    missing values; confirm against the fixture definition.
    """
    with pytest.raises(ValueError):
        regression_model.evaluate(hf_dataset, record_predictions=True, tags={"test"})
183
+
184
+
185
def test_evaluate_with_telemetry(regression_model, eval_dataset: Dataset):
    """Evaluation with ``record_predictions=True`` stores tagged predictions carrying the expected scores."""
    metrics = regression_model.evaluate(eval_dataset, record_predictions=True, tags={"test"})
    assert metrics is not None
    assert isinstance(metrics, RegressionMetrics)

    # Fetch what was recorded under the tag used above.
    recorded = regression_model.predictions(tag="test")
    assert len(recorded) == 4
    assert all(entry.tags == {"test"} for entry in recorded)
    assert all(entry.expected_score is not None for entry in recorded)
    # Recorded expected scores are compared pairwise with the dataset's "score" column.
    assert all(
        np.allclose(entry.expected_score, reference)
        for entry, reference in zip(recorded, eval_dataset["score"])
    )
194
+
195
+
196
def test_predict(regression_model: RegressionModel):
    """A two-item batch returns two predictions with ids, the expected scores and confidences in [0, 1]."""
    results = regression_model.predict(["Do you love soup?", "Are cats cute?"])
    assert len(results) == 2

    soup_result = results[0]
    cats_result = results[1]
    assert soup_result.prediction_id is not None
    assert cats_result.prediction_id is not None
    assert np.allclose(soup_result.score, 0.1)
    assert np.allclose(cats_result.score, 0.9)
    assert 0 <= soup_result.confidence <= 1
    assert 0 <= cats_result.confidence <= 1
205
+
206
+
207
def test_regression_prediction_has_no_score(regression_model: RegressionModel):
    """A single regression prediction carries a score and no classification label.

    Regression predictions expose ``score``: other tests in this module compare
    it numerically and assert it is not ``None`` for the same kind of
    single-input call. Asserting ``score is None`` here therefore cannot hold.
    The optional field expected to be empty is the classification-only one.
    """
    prediction = regression_model.predict("This beach is amazing!")
    assert isinstance(prediction, RegressionPrediction)
    assert prediction.score is not None
    # NOTE(review): `label` is assumed to be the classification-only field name;
    # getattr keeps the check valid whether the attribute is absent or present
    # but None. Confirm against the RegressionPrediction definition.
    assert getattr(prediction, "label", None) is None
212
+
213
+
214
def test_predict_unauthenticated(unauthenticated_client, regression_model: RegressionModel):
    """Under the ``unauthenticated_client`` fixture, predicting raises ``ValueError`` ("Invalid API key")."""
    with unauthenticated_client.use(), pytest.raises(ValueError, match="Invalid API key"):
        regression_model.predict(["This is excellent!", "This is terrible!"])
218
+
219
+
220
def test_predict_unauthorized(unauthorized_client, regression_model: RegressionModel):
    """Under the ``unauthorized_client`` fixture, predicting with the fixture model raises ``LookupError``."""
    with unauthorized_client.use(), pytest.raises(LookupError):
        regression_model.predict(["This is excellent!", "This is terrible!"])
224
+
225
+
226
def test_predict_constraint_violation(scored_memoryset: ScoredMemoryset):
    """Predicting with a lookup count larger than the memoryset's length raises ``RuntimeError``."""
    # Two more lookups than the memoryset has entries.
    excessive_lookup_count = scored_memoryset.length + 2
    greedy_model = RegressionModel.create(
        "test_regression_model_lookup_count_too_high",
        scored_memoryset,
        memory_lookup_count=excessive_lookup_count,
    )
    # Creation succeeds; the failure is expected only at prediction time.
    with pytest.raises(RuntimeError):
        greedy_model.predict("test")
234
+
235
+
236
def test_predict_with_prompt(regression_model: RegressionModel):
    """Predicting with and without a ``prompt`` argument both yield a score and a confidence in [0, 1]."""
    review_text = "This product is amazing!"

    prompted = regression_model.predict(
        review_text,
        prompt="Represent this text for rating prediction:",
    )
    unprompted = regression_model.predict(review_text)

    # Only validity is checked; the two results are not compared with each other.
    assert prompted.score is not None
    assert unprompted.score is not None
    assert 0 <= prompted.confidence <= 1
    assert 0 <= unprompted.confidence <= 1
249
+
250
+
251
def test_record_prediction_feedback(regression_model: RegressionModel):
    """Feedback in the "accurate" category can be recorded for each prediction of a batch."""
    results = regression_model.predict(["This is excellent!", "This is terrible!"])
    target_scores = [0.9, 0.1]

    # Lazily build one feedback entry per prediction; "value" is True when the
    # predicted score lies within 0.2 of its target.
    feedback_entries = (
        {
            "prediction_id": result.prediction_id,
            "category": "accurate",
            "value": abs(result.score - target) < 0.2,
        }
        for target, result in zip(target_scores, results)
    )
    regression_model.record_feedback(feedback_entries)
262
+
263
+
264
def test_record_prediction_feedback_missing_category(regression_model: RegressionModel):
    """Feedback submitted without a "category" key raises ``ValueError``."""
    result = regression_model.predict("This is excellent!")
    incomplete_feedback = {"prediction_id": result.prediction_id, "value": True}
    with pytest.raises(ValueError):
        regression_model.record_feedback(incomplete_feedback)
268
+
269
+
270
def test_record_prediction_feedback_invalid_value(regression_model: RegressionModel):
    """Feedback whose "value" is the string "invalid" raises ``ValueError`` matching "Invalid input"."""
    result = regression_model.predict("This is excellent!")
    bad_feedback = {
        "prediction_id": result.prediction_id,
        "category": "accurate",
        "value": "invalid",
    }
    with pytest.raises(ValueError, match=r"Invalid input.*"):
        regression_model.record_feedback(bad_feedback)
276
+
277
+
278
def test_record_prediction_feedback_invalid_prediction_id(regression_model: RegressionModel):
    """Feedback with the prediction id "invalid" raises ``ValueError`` matching "Invalid input"."""
    bad_feedback = {"prediction_id": "invalid", "category": "accurate", "value": True}
    with pytest.raises(ValueError, match=r"Invalid input.*"):
        regression_model.record_feedback(bad_feedback)
281
+
282
+
283
def test_predict_with_memoryset_override(regression_model: RegressionModel, hf_dataset: Dataset):
    """``use_memoryset`` changes predictions inside the context and restores them afterwards."""
    # Build a second memoryset from the same rows with each score replaced by
    # 2.0 - score; rows without a score keep None.
    flipped_rows = hf_dataset.map(
        lambda row: {"score": (2.0 - row["score"]) if row["score"] is not None else None}
    )
    flipped_memoryset = ScoredMemoryset.from_hf_dataset(
        "test_memoryset_inverted_scores",
        flipped_rows,
        embedding_model=PretrainedEmbeddingModel.GTE_BASE,
    )

    baseline = regression_model.predict(["This is excellent!", "This is terrible!"])

    with regression_model.use_memoryset(flipped_memoryset):
        overridden = regression_model.predict(["This is excellent!", "This is terrible!"])
        # Inside the context both scores must move by more than 0.1.
        assert abs(overridden[0].score - baseline[0].score) > 0.1
        assert abs(overridden[1].score - baseline[1].score) > 0.1

    # Outside the context the scores must be back within 0.1 of the baseline.
    restored = regression_model.predict(["This is excellent!", "This is terrible!"])
    assert abs(restored[0].score - baseline[0].score) < 0.1
    assert abs(restored[1].score - baseline[1].score) < 0.1
302
+
303
+
304
def test_predict_with_expected_scores(regression_model: RegressionModel):
    """An ``expected_scores`` value passed to a single prediction is exposed as ``expected_score``."""
    result = regression_model.predict("This is excellent!", expected_scores=0.9)
    assert result.expected_score == 0.9
307
+
308
+
309
def test_regression_prediction_update(regression_model: RegressionModel):
    """``update()`` can change the expected score, set tags, and clear both again."""
    result = regression_model.predict("Test input", expected_scores=3.5)
    # Starting state: the supplied expected score and an empty tag set.
    assert result.expected_score == 3.5
    assert result.tags == set()

    # Replace the expected score.
    result.update(expected_score=4.5)
    assert result.expected_score == 4.5

    # Attach tags.
    result.update(tags={"test", "updated"})
    assert result.tags == {"test", "updated"}

    # Passing None for both resets the score to None and the tags to an empty set.
    result.update(expected_score=None, tags=None)
    assert result.expected_score is None
    assert result.tags == set()
326
+
327
+
328
def test_last_prediction_with_batch(regression_model: RegressionModel):
    """After a batch call, ``last_prediction`` is the final item and the batch flag is set."""
    batch = regression_model.predict(["This is excellent!", "This is terrible!"])
    assert regression_model.last_prediction is not None
    final_item = batch[-1]
    assert regression_model.last_prediction.prediction_id == final_item.prediction_id
    assert regression_model.last_prediction.input_value == "This is terrible!"
    # Private flag recording whether the most recent call was a batch.
    assert regression_model._last_prediction_was_batch is True
334
+
335
+
336
def test_last_prediction_with_single(regression_model: RegressionModel):
    """After a single-input call, ``last_prediction`` is that prediction and the batch flag is cleared."""
    single = regression_model.predict("This is excellent!")
    assert regression_model.last_prediction is not None
    assert regression_model.last_prediction.prediction_id == single.prediction_id
    assert regression_model.last_prediction.input_value == "This is excellent!"
    # Private flag recording whether the most recent call was a batch.
    assert regression_model._last_prediction_was_batch is False
343
+
344
+
345
def test_batch_predict(regression_model: RegressionModel):
    """A three-item batch returns three ``RegressionPrediction`` objects."""
    batch_inputs = ["test input 1", "test input 2", "test input 3"]
    results = regression_model.predict(batch_inputs)
    assert len(results) == 3
    assert all(isinstance(item, RegressionPrediction) for item in results)
350
+
351
+
352
def test_batch_predict_with_expected_scores(regression_model: RegressionModel):
    """A batch call that also passes ``expected_scores`` returns one ``RegressionPrediction`` per input."""
    results = regression_model.predict(
        ["input 1", "input 2"],
        expected_scores=[0.5, 0.8],
    )
    # Only count and type are checked; the expected scores themselves are not read back.
    assert len(results) == 2
    assert all(isinstance(item, RegressionPrediction) for item in results)
357
+
358
+
359
def test_use_memoryset(regression_model: RegressionModel, scored_memoryset: ScoredMemoryset):
    """Predictions are well-formed both normally and inside a ``use_memoryset`` context."""
    # Baseline: predict with whatever memoryset the model is configured with.
    default_results = regression_model.predict(["This is excellent!", "This is terrible!"])
    assert len(default_results) == 2
    assert all(isinstance(item, RegressionPrediction) for item in default_results)
    assert all(0 <= item.confidence <= 1 for item in default_results)

    # Same checks with the memoryset supplied explicitly through the context manager.
    with regression_model.use_memoryset(scored_memoryset):
        scoped_results = regression_model.predict(["This is excellent!", "This is terrible!"])
        assert len(scoped_results) == 2
        assert all(isinstance(item, RegressionPrediction) for item in scoped_results)
        assert all(0 <= item.confidence <= 1 for item in scoped_results)
372
+
373
+
374
def test_drop(regression_model):
    """Dropping the fixture model by name makes ``RegressionModel.exists`` report it as gone."""
    model_name = regression_model.name
    RegressionModel.drop(model_name)
    still_there = RegressionModel.exists(model_name)
    assert not still_there