orca-sdk 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. orca_sdk/__init__.py +10 -4
  2. orca_sdk/_shared/__init__.py +10 -0
  3. orca_sdk/_shared/metrics.py +393 -0
  4. orca_sdk/_shared/metrics_test.py +273 -0
  5. orca_sdk/_utils/analysis_ui.py +12 -10
  6. orca_sdk/_utils/analysis_ui_style.css +0 -3
  7. orca_sdk/_utils/auth.py +27 -29
  8. orca_sdk/_utils/data_parsing.py +28 -2
  9. orca_sdk/_utils/data_parsing_test.py +15 -15
  10. orca_sdk/_utils/pagination.py +126 -0
  11. orca_sdk/_utils/pagination_test.py +132 -0
  12. orca_sdk/_utils/prediction_result_ui.py +67 -21
  13. orca_sdk/_utils/tqdm_file_reader.py +12 -0
  14. orca_sdk/_utils/value_parser.py +45 -0
  15. orca_sdk/_utils/value_parser_test.py +39 -0
  16. orca_sdk/classification_model.py +439 -129
  17. orca_sdk/classification_model_test.py +334 -104
  18. orca_sdk/client.py +3747 -0
  19. orca_sdk/conftest.py +164 -19
  20. orca_sdk/credentials.py +120 -18
  21. orca_sdk/credentials_test.py +20 -0
  22. orca_sdk/datasource.py +259 -68
  23. orca_sdk/datasource_test.py +242 -0
  24. orca_sdk/embedding_model.py +425 -82
  25. orca_sdk/embedding_model_test.py +39 -13
  26. orca_sdk/job.py +337 -0
  27. orca_sdk/job_test.py +108 -0
  28. orca_sdk/memoryset.py +1341 -305
  29. orca_sdk/memoryset_test.py +350 -111
  30. orca_sdk/regression_model.py +684 -0
  31. orca_sdk/regression_model_test.py +369 -0
  32. orca_sdk/telemetry.py +449 -143
  33. orca_sdk/telemetry_test.py +43 -24
  34. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.2.dist-info}/METADATA +34 -16
  35. orca_sdk-0.1.2.dist-info/RECORD +40 -0
  36. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.2.dist-info}/WHEEL +1 -1
  37. orca_sdk/_generated_api_client/__init__.py +0 -3
  38. orca_sdk/_generated_api_client/api/__init__.py +0 -193
  39. orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
  40. orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
  41. orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
  42. orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
  43. orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
  44. orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
  45. orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
  46. orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
  47. orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
  48. orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
  49. orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
  50. orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
  51. orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
  52. orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
  53. orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
  54. orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
  55. orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
  56. orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
  57. orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
  58. orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
  59. orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
  60. orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
  61. orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
  62. orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
  63. orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
  64. orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
  65. orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
  66. orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
  67. orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
  68. orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
  69. orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
  70. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
  71. orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
  72. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
  73. orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
  74. orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
  75. orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
  76. orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
  77. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
  78. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
  79. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
  80. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
  81. orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
  82. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
  83. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
  84. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
  85. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
  86. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
  87. orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
  88. orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
  89. orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
  90. orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
  91. orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
  92. orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
  93. orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
  94. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
  95. orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
  96. orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
  97. orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
  98. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
  99. orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
  100. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
  101. orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
  102. orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
  103. orca_sdk/_generated_api_client/client.py +0 -216
  104. orca_sdk/_generated_api_client/errors.py +0 -38
  105. orca_sdk/_generated_api_client/models/__init__.py +0 -159
  106. orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
  107. orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
  108. orca_sdk/_generated_api_client/models/base_model.py +0 -55
  109. orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
  110. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
  111. orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
  112. orca_sdk/_generated_api_client/models/column_info.py +0 -114
  113. orca_sdk/_generated_api_client/models/column_type.py +0 -14
  114. orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
  115. orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
  116. orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
  117. orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
  118. orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
  119. orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
  120. orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
  121. orca_sdk/_generated_api_client/models/embed_request.py +0 -127
  122. orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
  123. orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
  124. orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
  125. orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
  126. orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
  127. orca_sdk/_generated_api_client/models/filter_item.py +0 -231
  128. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
  129. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
  130. orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
  131. orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
  132. orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
  133. orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
  134. orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
  135. orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
  136. orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
  137. orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
  138. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
  139. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
  140. orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
  141. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
  142. orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
  143. orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
  144. orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
  145. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
  146. orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
  147. orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
  148. orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
  149. orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
  150. orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
  151. orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
  152. orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
  153. orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
  154. orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
  155. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
  156. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
  157. orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
  158. orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
  159. orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
  160. orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
  161. orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
  162. orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
  163. orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
  164. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
  165. orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
  166. orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
  167. orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
  168. orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
  169. orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
  170. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
  171. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
  172. orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
  173. orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
  174. orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
  175. orca_sdk/_generated_api_client/models/task.py +0 -198
  176. orca_sdk/_generated_api_client/models/task_status.py +0 -14
  177. orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
  178. orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
  179. orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
  180. orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
  181. orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
  182. orca_sdk/_generated_api_client/py.typed +0 -1
  183. orca_sdk/_generated_api_client/types.py +0 -56
  184. orca_sdk/_utils/task.py +0 -73
  185. orca_sdk-0.1.1.dist-info/RECORD +0 -175
@@ -0,0 +1,369 @@
1
+ from uuid import uuid4
2
+
3
+ import numpy as np
4
+ import pytest
5
+ from datasets.arrow_dataset import Dataset
6
+
7
+ from .datasource import Datasource
8
+ from .embedding_model import PretrainedEmbeddingModel
9
+ from .memoryset import ScoredMemoryset
10
+ from .regression_model import RegressionMetrics, RegressionModel
11
+ from .telemetry import RegressionPrediction
12
+
13
+
14
def test_create_model(regression_model: RegressionModel, scored_memoryset: ScoredMemoryset):
    """Check the name, memoryset and lookup count of the ``regression_model`` fixture."""
    model = regression_model
    assert model is not None
    assert model.name == "test_regression_model"
    assert model.memoryset == scored_memoryset
    assert model.memory_lookup_count == 3
19
+
20
+
21
def test_create_model_already_exists_error(scored_memoryset, regression_model: RegressionModel):
    """Re-creating an existing model name raises, both with no ``if_exists`` and with ``if_exists="error"``."""
    name = "test_regression_model"
    with pytest.raises(ValueError):
        RegressionModel.create(name, scored_memoryset)
    with pytest.raises(ValueError):
        RegressionModel.create(name, scored_memoryset, if_exists="error")
26
+
27
+
28
def test_create_model_already_exists_return(scored_memoryset, regression_model: RegressionModel):
    """With ``if_exists="open"`` the existing model is returned, unless the requested config conflicts."""
    name = "test_regression_model"

    # The existing model uses a lookup count of 3, so asking for 37 is a conflict.
    with pytest.raises(ValueError):
        RegressionModel.create(name, scored_memoryset, if_exists="open", memory_lookup_count=37)

    reopened = RegressionModel.create(name, scored_memoryset, if_exists="open")
    assert reopened is not None
    assert reopened.name == name
    assert reopened.memoryset == scored_memoryset
    assert reopened.memory_lookup_count == 3
37
+
38
+
39
def test_create_model_unauthenticated(unauthenticated, scored_memoryset: ScoredMemoryset):
    """Creating a model without valid credentials fails with an API-key error."""
    expected_error = "Invalid API key"
    with pytest.raises(ValueError, match=expected_error):
        RegressionModel.create("test_regression_model", scored_memoryset)
42
+
43
+
44
def test_get_model(regression_model: RegressionModel):
    """Opening the fixture model by name yields an equal object with matching id and name."""
    opened = RegressionModel.open(regression_model.name)
    assert opened is not None
    for attr in ("id", "name"):
        assert getattr(opened, attr) == getattr(regression_model, attr)
    assert opened.memory_lookup_count == 3
    assert opened == regression_model
51
+
52
+
53
def test_get_model_unauthenticated(unauthenticated):
    """Opening a model without valid credentials fails with an API-key error."""
    model_name = "test_regression_model"
    with pytest.raises(ValueError, match="Invalid API key"):
        RegressionModel.open(model_name)
56
+
57
+
58
def test_get_model_invalid_input():
    """A malformed identifier is rejected as invalid input rather than looked up."""
    malformed = "not valid id"
    with pytest.raises(ValueError, match="Invalid input"):
        RegressionModel.open(malformed)
61
+
62
+
63
def test_get_model_not_found():
    """Opening a random, never-created id raises LookupError."""
    missing_id = str(uuid4())
    with pytest.raises(LookupError):
        RegressionModel.open(missing_id)
66
+
67
+
68
def test_get_model_unauthorized(unauthorized, regression_model: RegressionModel):
    """A caller without access to the model gets LookupError, as if the model did not exist."""
    with pytest.raises(LookupError):
        RegressionModel.open(regression_model.name)
71
+
72
+
73
def test_list_models(regression_model: RegressionModel):
    """The fixture model appears in the non-empty listing of all models."""
    all_models = RegressionModel.all()
    assert len(all_models) > 0
    assert regression_model.name in {m.name for m in all_models}
77
+
78
+
79
def test_list_models_unauthenticated(unauthenticated):
    """Listing models without valid credentials fails with an API-key error."""
    expected_error = "Invalid API key"
    with pytest.raises(ValueError, match=expected_error):
        RegressionModel.all()
82
+
83
+
84
def test_list_models_unauthorized(unauthorized, regression_model: RegressionModel):
    """A caller without access to the fixture model sees an empty model list."""
    visible_models = RegressionModel.all()
    assert visible_models == []
86
+
87
+
88
def test_update_model_attributes(regression_model: RegressionModel):
    """Description and lock state can be changed via assignment, ``set()``, ``lock()`` and ``unlock()``."""
    model = regression_model

    # Description: plain attribute assignment, then clearing through set().
    model.description = "New description"
    assert model.description == "New description"
    model.set(description=None)
    assert model.description is None

    # Lock state via set(): first True, then False.
    for flag in (True, False):
        model.set(locked=flag)
        assert model.locked is flag

    # Lock state via the dedicated helpers.
    model.lock()
    assert model.locked is True
    model.unlock()
    assert model.locked is False
106
+
107
+
108
def test_delete_model(scored_memoryset: ScoredMemoryset):
    """A freshly created model can be dropped and can no longer be opened afterwards."""
    name = "regression_model_to_delete"
    memoryset = ScoredMemoryset.open(scored_memoryset.name)
    RegressionModel.create(name, memoryset)
    assert RegressionModel.open(name)

    RegressionModel.drop(name)
    with pytest.raises(LookupError):
        RegressionModel.open(name)
114
+
115
+
116
def test_delete_model_unauthenticated(unauthenticated, regression_model: RegressionModel):
    """Dropping a model without valid credentials fails with an API-key error."""
    expected_error = "Invalid API key"
    with pytest.raises(ValueError, match=expected_error):
        RegressionModel.drop(regression_model.name)
119
+
120
+
121
def test_delete_model_not_found():
    """Dropping an unknown model raises LookupError unless ``if_not_exists="ignore"`` is passed."""
    with pytest.raises(LookupError):
        RegressionModel.drop(str(uuid4()))

    # With if_not_exists="ignore" the same situation must not raise.
    RegressionModel.drop(str(uuid4()), if_not_exists="ignore")
126
+
127
+
128
def test_delete_model_unauthorized(unauthorized, regression_model: RegressionModel):
    """A caller without access to the model cannot drop it and gets LookupError."""
    with pytest.raises(LookupError):
        RegressionModel.drop(regression_model.name)
131
+
132
+
133
def test_delete_memoryset_before_model_constraint_violation(hf_dataset):
    """A memoryset that a model was built on cannot be dropped while that model exists."""
    backing_memoryset = ScoredMemoryset.from_hf_dataset("test_memoryset_delete_before_regression_model", hf_dataset)
    RegressionModel.create("test_regression_model_delete_before_memoryset", backing_memoryset)

    # The model above still references the memoryset, so the drop is refused.
    with pytest.raises(RuntimeError):
        ScoredMemoryset.drop(backing_memoryset.id)
    # NOTE(review): neither the model nor the memoryset is removed here; presumably a
    # conftest teardown cleans them up — confirm.
138
+
139
+
140
@pytest.mark.parametrize("data_type", ["dataset", "datasource"])
def test_evaluate(
    regression_model: RegressionModel,
    eval_datasource: Datasource,
    eval_dataset: Dataset,
    data_type,
):
    """Evaluating on a dataset or a datasource yields consistent regression metrics.

    Runs once per ``data_type``: ``"dataset"`` evaluates ``eval_dataset`` and
    ``"datasource"`` evaluates ``eval_datasource``. Both are expected to give an MAE of 0.4.
    """
    result = (
        regression_model.evaluate(eval_dataset)
        if data_type == "dataset"
        else regression_model.evaluate(eval_datasource)
    )

    assert isinstance(result, RegressionMetrics)
    assert np.allclose(result.mae, 0.4)
    assert 0.0 <= result.mse <= 1.0
    assert 0.0 <= result.rmse <= 1.0
    # RMSE is by definition the square root of MSE; a mismatch means they were computed inconsistently.
    assert np.isclose(result.rmse, np.sqrt(result.mse))
    assert result.r2 is not None

    assert isinstance(result.anomaly_score_mean, float)
    assert isinstance(result.anomaly_score_median, float)
    assert isinstance(result.anomaly_score_variance, float)
    assert -1.0 <= result.anomaly_score_mean <= 1.0
    assert -1.0 <= result.anomaly_score_median <= 1.0
    # A variance can never be negative, so the lower bound is 0 rather than -1.
    assert 0.0 <= result.anomaly_score_variance <= 1.0
166
+
167
+
168
def test_evaluate_datasource_with_nones_raises_error(regression_model: RegressionModel, datasource: Datasource):
    """Evaluating the ``datasource`` fixture raises ValueError.

    Presumably this is because it contains None values, per the test name.
    """
    tags = {"test"}
    with pytest.raises(ValueError):
        regression_model.evaluate(datasource, record_predictions=True, tags=tags)
171
+
172
+
173
def test_evaluate_dataset_with_nones_raises_error(regression_model: RegressionModel, hf_dataset: Dataset):
    """Evaluating ``hf_dataset`` raises ValueError.

    Presumably this is because some of its scores are None, per the test name.
    """
    tags = {"test"}
    with pytest.raises(ValueError):
        regression_model.evaluate(hf_dataset, record_predictions=True, tags=tags)
176
+
177
+
178
def test_evaluate_with_telemetry(regression_model, eval_dataset: Dataset):
    """With ``record_predictions=True``, evaluation stores one tagged prediction per eval row.

    Stored expected scores are compared with the dataset's scores as sorted multisets rather
    than pairwise. Nothing here guarantees that ``predictions()`` returns rows in dataset order.
    """
    result = regression_model.evaluate(eval_dataset, record_predictions=True, tags={"test"})
    assert result is not None
    assert isinstance(result, RegressionMetrics)

    expected_scores = list(eval_dataset["score"])
    predictions = regression_model.predictions(tag="test")
    assert len(predictions) == len(expected_scores)
    assert all(p.tags == {"test"} for p in predictions)
    # Check for None first: sorting a list that contains None would raise TypeError.
    assert all(p.expected_score is not None for p in predictions)
    assert np.allclose(sorted(p.expected_score for p in predictions), sorted(expected_scores))
187
+
188
+
189
def test_predict(regression_model: RegressionModel):
    """Batch prediction returns one scored prediction per input, each with an id and confidence in [0, 1]."""
    results = regression_model.predict(["Do you love soup?", "Are cats cute?"])
    assert len(results) == 2
    first, second = results[0], results[1]
    assert first.prediction_id is not None
    assert second.prediction_id is not None
    assert np.allclose(first.score, 0.1)
    assert np.allclose(second.score, 0.9)
    assert 0 <= first.confidence <= 1
    assert 0 <= second.confidence <= 1
198
+
199
+
200
def test_regression_prediction_has_no_score(regression_model: RegressionModel):
    """A single regression prediction is a RegressionPrediction carrying a numeric score.

    NOTE(review): this test previously asserted ``prediction.score is None``. That contradicts
    ``test_predict`` and ``test_predict_with_prompt``, which require a numeric score from the
    same model, so the suite could never fully pass. The score is the regression output and
    must be present. What a regression prediction is expected to lack is a classification label.
    """
    prediction = regression_model.predict("This beach is amazing!")
    assert isinstance(prediction, RegressionPrediction)
    assert prediction.score is not None
    # Presumably RegressionPrediction shares an optional ``label`` field with classification
    # predictions. getattr keeps the check valid if it has no such attribute — confirm in telemetry.py.
    assert getattr(prediction, "label", None) is None
205
+
206
+
207
def test_predict_unauthenticated(unauthenticated, regression_model: RegressionModel):
    """Predicting without valid credentials fails with an API-key error."""
    inputs = ["This is excellent!", "This is terrible!"]
    with pytest.raises(ValueError, match="Invalid API key"):
        regression_model.predict(inputs)
210
+
211
+
212
def test_predict_unauthorized(unauthorized, regression_model: RegressionModel):
    """Predicting with a model the caller cannot access raises LookupError."""
    inputs = ["This is excellent!", "This is terrible!"]
    with pytest.raises(LookupError):
        regression_model.predict(inputs)
215
+
216
+
217
def test_predict_constraint_violation(scored_memoryset: ScoredMemoryset):
    """Predicting with a lookup count larger than the memoryset raises RuntimeError."""
    too_many_lookups = scored_memoryset.length + 2
    model = RegressionModel.create(
        "test_regression_model_lookup_count_too_high",
        scored_memoryset,
        memory_lookup_count=too_many_lookups,
    )
    with pytest.raises(RuntimeError):
        model.predict("test")
225
+
226
+
227
def test_predict_with_prompt(regression_model: RegressionModel):
    """``predict()`` accepts a ``prompt`` and still returns a valid prediction, as it does without one.

    NOTE(review): this only checks that both calls succeed. Whether the prompt changes the result
    depends on the embedding model and is not asserted here.
    """
    text = "This product is amazing!"
    with_prompt = regression_model.predict(text, prompt="Represent this text for rating prediction:")
    without_prompt = regression_model.predict(text)

    for prediction in (with_prompt, without_prompt):
        assert prediction.score is not None
    for prediction in (with_prompt, without_prompt):
        assert 0 <= prediction.confidence <= 1
240
+
241
+
242
def test_record_prediction_feedback(regression_model: RegressionModel):
    """``record_feedback`` accepts a lazily generated batch of feedback entries without raising."""
    predictions = regression_model.predict(["This is excellent!", "This is terrible!"])
    targets = [0.9, 0.1]

    def _feedback(target, prediction):
        # Mark a prediction "accurate" when it lands within 0.2 of its target score.
        return {
            "prediction_id": prediction.prediction_id,
            "category": "accurate",
            "value": abs(prediction.score - target) < 0.2,
        }

    regression_model.record_feedback(_feedback(t, p) for t, p in zip(targets, predictions))
253
+
254
+
255
def test_record_prediction_feedback_missing_category(regression_model: RegressionModel):
    """Feedback without a ``"category"`` key is rejected with ValueError."""
    prediction = regression_model.predict("This is excellent!")
    with pytest.raises(ValueError):
        incomplete = {"prediction_id": prediction.prediction_id, "value": True}
        regression_model.record_feedback(incomplete)
259
+
260
+
261
def test_record_prediction_feedback_invalid_value(regression_model: RegressionModel):
    """A string value for the ``"accurate"`` category is rejected as invalid input."""
    prediction = regression_model.predict("This is excellent!")
    with pytest.raises(ValueError, match=r"Invalid input.*"):
        bad_entry = {"prediction_id": prediction.prediction_id, "category": "accurate", "value": "invalid"}
        regression_model.record_feedback(bad_entry)
267
+
268
+
269
def test_record_prediction_feedback_invalid_prediction_id(regression_model: RegressionModel):
    """Feedback referencing a malformed prediction id is rejected as invalid input."""
    with pytest.raises(ValueError, match=r"Invalid input.*"):
        bad_entry = {"prediction_id": "invalid", "category": "accurate", "value": True}
        regression_model.record_feedback(bad_entry)
272
+
273
+
274
def test_predict_with_memoryset_override(regression_model: RegressionModel, hf_dataset: Dataset):
    """Temporarily swapping the memoryset changes predictions; leaving the context restores them."""

    def _mirror_score(row):
        # Reflect each score around 1.0 (score -> 2.0 - score); missing scores stay missing.
        score = row["score"]
        return {"score": None if score is None else 2.0 - score}

    mirrored_memoryset = ScoredMemoryset.from_hf_dataset(
        "test_memoryset_inverted_scores",
        hf_dataset.map(_mirror_score),
        embedding_model=PretrainedEmbeddingModel.GTE_BASE,
    )
    baseline = regression_model.predict(["This is excellent!", "This is terrible!"])

    with regression_model.use_memoryset(mirrored_memoryset):
        overridden = regression_model.predict(["This is excellent!", "This is terrible!"])
        # Memories with different scores must move both predictions noticeably.
        for i in (0, 1):
            assert abs(overridden[i].score - baseline[i].score) > 0.1

    # Outside the context the model is back on its own memoryset.
    restored = regression_model.predict(["This is excellent!", "This is terrible!"])
    for i in (0, 1):
        assert abs(restored[i].score - baseline[i].score) < 0.1
293
+
294
+
295
def test_predict_with_expected_scores(regression_model: RegressionModel):
    """An expected score passed to ``predict()`` is stored on the returned prediction.

    The comparison uses a tolerance, as ``test_evaluate_with_telemetry`` does for expected scores.
    0.9 has no exact binary representation and may not survive an API round trip bit-for-bit.
    """
    prediction = regression_model.predict("This is excellent!", expected_scores=0.9)
    assert prediction.expected_score is not None
    assert np.isclose(prediction.expected_score, 0.9)
298
+
299
+
300
def test_regression_prediction_update(regression_model: RegressionModel):
    """``update()`` can set and clear the expected score and the tags of a stored prediction."""
    prediction = regression_model.predict("Test input", expected_scores=3.5)
    assert prediction.expected_score == 3.5
    assert prediction.tags == set()

    # Change only the expected score.
    prediction.update(expected_score=4.5)
    assert prediction.expected_score == 4.5

    # Change only the tags.
    prediction.update(tags={"test", "updated"})
    assert prediction.tags == {"test", "updated"}

    # Passing None for both resets the score to None and the tags to an empty set.
    prediction.update(expected_score=None, tags=None)
    assert prediction.expected_score is None
    assert prediction.tags == set()
317
+
318
+
319
def test_last_prediction_with_batch(regression_model: RegressionModel):
    """After a batch predict, ``last_prediction`` is the final item and the batch flag is set."""
    batch = regression_model.predict(["This is excellent!", "This is terrible!"])
    final_item = batch[-1]
    assert regression_model.last_prediction is not None
    assert regression_model.last_prediction.prediction_id == final_item.prediction_id
    assert regression_model.last_prediction.input_value == "This is terrible!"
    assert regression_model._last_prediction_was_batch is True
325
+
326
+
327
def test_last_prediction_with_single(regression_model: RegressionModel):
    """After a single-input predict, ``last_prediction`` is that prediction and the batch flag is cleared."""
    single = regression_model.predict("This is excellent!")
    assert regression_model.last_prediction is not None
    assert regression_model.last_prediction.prediction_id == single.prediction_id
    assert regression_model.last_prediction.input_value == "This is excellent!"
    assert regression_model._last_prediction_was_batch is False
334
+
335
+
336
def test_batch_predict(regression_model: RegressionModel):
    """A three-input batch yields three RegressionPrediction objects."""
    inputs = ["test input 1", "test input 2", "test input 3"]
    results = regression_model.predict(inputs)
    assert len(results) == 3
    assert all(isinstance(item, RegressionPrediction) for item in results)
341
+
342
+
343
def test_batch_predict_with_expected_scores(regression_model: RegressionModel):
    """Batch predictions accept per-input expected scores and attach each one to its prediction.

    This relies on batch results coming back in input order, as ``test_last_prediction_with_batch``
    also assumes.
    """
    expected = [0.5, 0.8]
    predictions = regression_model.predict(["input 1", "input 2"], expected_scores=expected)
    assert len(predictions) == 2
    assert all(isinstance(pred, RegressionPrediction) for pred in predictions)
    # Previously the expected scores were passed but never checked; verify they were stored.
    assert all(pred.expected_score is not None for pred in predictions)
    # Tolerance rather than ==, since the scores round-trip through the API as floats.
    assert np.allclose([pred.expected_score for pred in predictions], expected)
348
+
349
+
350
def test_use_memoryset(regression_model: RegressionModel, scored_memoryset: ScoredMemoryset):
    """Predictions are well-formed with the default memoryset and inside ``use_memoryset()``."""

    def _assert_valid(preds):
        # Expect two RegressionPrediction objects, each with a confidence in [0, 1].
        assert len(preds) == 2
        assert all(isinstance(p, RegressionPrediction) for p in preds)
        assert all(0 <= p.confidence <= 1 for p in preds)

    _assert_valid(regression_model.predict(["This is excellent!", "This is terrible!"]))

    with regression_model.use_memoryset(scored_memoryset):
        _assert_valid(regression_model.predict(["This is excellent!", "This is terrible!"]))
363
+
364
+
365
def test_drop(regression_model):
    """Dropping the fixture model by name makes ``exists()`` report it as gone.

    NOTE(review): if the ``regression_model`` fixture is shared beyond this test, any test that
    runs after this one will find the model missing. Keep this test last in the file.
    """
    model_name = regression_model.name
    RegressionModel.drop(model_name)
    assert not RegressionModel.exists(model_name)