orca-sdk 0.0.94__py3-none-any.whl → 0.0.96__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. orca_sdk/__init__.py +13 -4
  2. orca_sdk/_generated_api_client/api/__init__.py +80 -34
  3. orca_sdk/_generated_api_client/api/classification_model/create_classification_model_gpu_classification_model_post.py +170 -0
  4. orca_sdk/_generated_api_client/api/classification_model/{get_model_classification_model_name_or_id_get.py → delete_classification_model_classification_model_name_or_id_delete.py} +20 -20
  5. orca_sdk/_generated_api_client/api/classification_model/{delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py → delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py} +4 -4
  6. orca_sdk/_generated_api_client/api/classification_model/{create_evaluation_classification_model_model_name_or_id_evaluation_post.py → evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py} +14 -14
  7. orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +156 -0
  8. orca_sdk/_generated_api_client/api/classification_model/{get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py → get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py} +16 -16
  9. orca_sdk/_generated_api_client/api/classification_model/{list_evaluations_classification_model_model_name_or_id_evaluation_get.py → list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py} +16 -16
  10. orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +127 -0
  11. orca_sdk/_generated_api_client/api/classification_model/{predict_gpu_classification_model_name_or_id_prediction_post.py → predict_label_gpu_classification_model_name_or_id_prediction_post.py} +14 -14
  12. orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +183 -0
  13. orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +24 -0
  14. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +22 -22
  15. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +22 -22
  16. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +38 -16
  17. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +29 -12
  18. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +12 -12
  19. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +17 -14
  20. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +72 -19
  21. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +31 -12
  22. orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +49 -20
  23. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +38 -16
  24. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +54 -29
  25. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +44 -26
  26. orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +22 -22
  27. orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
  28. orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +150 -0
  29. orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
  30. orca_sdk/_generated_api_client/api/{classification_model/create_model_classification_model_post.py → regression_model/create_regression_model_gpu_regression_model_post.py} +27 -27
  31. orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
  32. orca_sdk/_generated_api_client/api/{classification_model/delete_model_classification_model_name_or_id_delete.py → regression_model/delete_regression_model_regression_model_name_or_id_delete.py} +5 -5
  33. orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +183 -0
  34. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +170 -0
  35. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +156 -0
  36. orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +161 -0
  37. orca_sdk/_generated_api_client/api/{classification_model/list_models_classification_model_get.py → regression_model/list_regression_models_regression_model_get.py} +17 -17
  38. orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +190 -0
  39. orca_sdk/_generated_api_client/api/{classification_model/update_model_classification_model_name_or_id_patch.py → regression_model/update_regression_model_regression_model_name_or_id_patch.py} +27 -27
  40. orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +156 -0
  41. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +35 -12
  42. orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +20 -12
  43. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +35 -12
  44. orca_sdk/_generated_api_client/models/__init__.py +84 -24
  45. orca_sdk/_generated_api_client/models/base_score_prediction_result.py +108 -0
  46. orca_sdk/_generated_api_client/models/{evaluation_request.py → classification_evaluation_request.py} +13 -45
  47. orca_sdk/_generated_api_client/models/{classification_evaluation_result.py → classification_metrics.py} +106 -56
  48. orca_sdk/_generated_api_client/models/{rac_model_metadata.py → classification_model_metadata.py} +51 -43
  49. orca_sdk/_generated_api_client/models/{prediction_request.py → classification_prediction_request.py} +31 -6
  50. orca_sdk/_generated_api_client/models/{clone_labeled_memoryset_request.py → clone_memoryset_request.py} +5 -5
  51. orca_sdk/_generated_api_client/models/column_info.py +31 -0
  52. orca_sdk/_generated_api_client/models/{create_rac_model_request.py → create_classification_model_request.py} +25 -57
  53. orca_sdk/_generated_api_client/models/{create_labeled_memoryset_request.py → create_memoryset_request.py} +73 -56
  54. orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +66 -0
  55. orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +13 -0
  56. orca_sdk/_generated_api_client/models/create_regression_model_request.py +137 -0
  57. orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +187 -0
  58. orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +10 -0
  59. orca_sdk/_generated_api_client/models/evaluation_response.py +22 -9
  60. orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +140 -0
  61. orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +140 -0
  62. orca_sdk/_generated_api_client/models/memory_type.py +9 -0
  63. orca_sdk/_generated_api_client/models/{labeled_memoryset_metadata.py → memoryset_metadata.py} +73 -13
  64. orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +55 -0
  65. orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +13 -0
  66. orca_sdk/_generated_api_client/models/{labeled_memoryset_update.py → memoryset_update.py} +19 -31
  67. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +1 -0
  68. orca_sdk/_generated_api_client/models/{paginated_labeled_memory_with_feedback_metrics.py → paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py} +37 -10
  69. orca_sdk/_generated_api_client/models/{precision_recall_curve.py → pr_curve.py} +5 -13
  70. orca_sdk/_generated_api_client/models/{rac_model_update.py → predictive_model_update.py} +14 -5
  71. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +11 -1
  72. orca_sdk/_generated_api_client/models/rar_head_type.py +8 -0
  73. orca_sdk/_generated_api_client/models/regression_evaluation_request.py +148 -0
  74. orca_sdk/_generated_api_client/models/regression_metrics.py +172 -0
  75. orca_sdk/_generated_api_client/models/regression_model_metadata.py +177 -0
  76. orca_sdk/_generated_api_client/models/regression_prediction_request.py +195 -0
  77. orca_sdk/_generated_api_client/models/roc_curve.py +0 -8
  78. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +196 -0
  79. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +68 -0
  80. orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +252 -0
  81. orca_sdk/_generated_api_client/models/scored_memory.py +172 -0
  82. orca_sdk/_generated_api_client/models/scored_memory_insert.py +128 -0
  83. orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +68 -0
  84. orca_sdk/_generated_api_client/models/scored_memory_lookup.py +180 -0
  85. orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +68 -0
  86. orca_sdk/_generated_api_client/models/scored_memory_metadata.py +68 -0
  87. orca_sdk/_generated_api_client/models/scored_memory_update.py +171 -0
  88. orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +68 -0
  89. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +193 -0
  90. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +68 -0
  91. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +68 -0
  92. orca_sdk/_generated_api_client/models/update_prediction_request.py +20 -0
  93. orca_sdk/_shared/__init__.py +9 -1
  94. orca_sdk/_shared/metrics.py +257 -87
  95. orca_sdk/_shared/metrics_test.py +136 -77
  96. orca_sdk/_utils/data_parsing.py +0 -3
  97. orca_sdk/_utils/data_parsing_test.py +0 -3
  98. orca_sdk/_utils/prediction_result_ui.py +55 -23
  99. orca_sdk/classification_model.py +183 -172
  100. orca_sdk/classification_model_test.py +147 -157
  101. orca_sdk/conftest.py +76 -26
  102. orca_sdk/datasource_test.py +0 -1
  103. orca_sdk/embedding_model.py +136 -14
  104. orca_sdk/embedding_model_test.py +10 -6
  105. orca_sdk/job.py +329 -0
  106. orca_sdk/job_test.py +48 -0
  107. orca_sdk/memoryset.py +882 -161
  108. orca_sdk/memoryset_test.py +56 -23
  109. orca_sdk/regression_model.py +647 -0
  110. orca_sdk/regression_model_test.py +337 -0
  111. orca_sdk/telemetry.py +223 -106
  112. orca_sdk/telemetry_test.py +34 -30
  113. {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.96.dist-info}/METADATA +2 -4
  114. {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.96.dist-info}/RECORD +115 -69
  115. orca_sdk/_utils/task.py +0 -73
  116. {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.96.dist-info}/WHEEL +0 -0
orca_sdk/classification_model_test.py CHANGED
@@ -1,53 +1,52 @@
  import logging
- import os
  from uuid import uuid4
 
  import numpy as np
  import pytest
  from datasets.arrow_dataset import Dataset
 
- from .classification_model import ClassificationModel
+ from .classification_model import ClassificationMetrics, ClassificationModel
+ from .conftest import skip_in_ci
  from .datasource import Datasource
  from .embedding_model import PretrainedEmbeddingModel
  from .memoryset import LabeledMemoryset
 
- logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
 
+ def test_create_model(classification_model: ClassificationModel, readonly_memoryset: LabeledMemoryset):
+     assert classification_model is not None
+     assert classification_model.name == "test_classification_model"
+     assert classification_model.memoryset == readonly_memoryset
+     assert classification_model.num_classes == 2
+     assert classification_model.memory_lookup_count == 3
 
- SKIP_IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
 
-
- def test_create_model(model: ClassificationModel, readonly_memoryset: LabeledMemoryset):
-     assert model is not None
-     assert model.name == "test_model"
-     assert model.memoryset == readonly_memoryset
-     assert model.num_classes == 2
-     assert model.memory_lookup_count == 3
-
-
- def test_create_model_already_exists_error(readonly_memoryset, model: ClassificationModel):
+ def test_create_model_already_exists_error(readonly_memoryset, classification_model):
      with pytest.raises(ValueError):
-         ClassificationModel.create("test_model", readonly_memoryset)
+         ClassificationModel.create("test_classification_model", readonly_memoryset)
      with pytest.raises(ValueError):
-         ClassificationModel.create("test_model", readonly_memoryset, if_exists="error")
+         ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="error")
 
 
- def test_create_model_already_exists_return(readonly_memoryset, model: ClassificationModel):
+ def test_create_model_already_exists_return(readonly_memoryset, classification_model):
      with pytest.raises(ValueError):
-         ClassificationModel.create("test_model", readonly_memoryset, if_exists="open", head_type="MMOE")
+         ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="open", head_type="MMOE")
 
      with pytest.raises(ValueError):
-         ClassificationModel.create("test_model", readonly_memoryset, if_exists="open", memory_lookup_count=37)
+         ClassificationModel.create(
+             "test_classification_model", readonly_memoryset, if_exists="open", memory_lookup_count=37
+         )
 
      with pytest.raises(ValueError):
-         ClassificationModel.create("test_model", readonly_memoryset, if_exists="open", num_classes=19)
+         ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="open", num_classes=19)
 
      with pytest.raises(ValueError):
-         ClassificationModel.create("test_model", readonly_memoryset, if_exists="open", min_memory_weight=0.77)
+         ClassificationModel.create(
+             "test_classification_model", readonly_memoryset, if_exists="open", min_memory_weight=0.77
+         )
 
-     new_model = ClassificationModel.create("test_model", readonly_memoryset, if_exists="open")
+     new_model = ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="open")
      assert new_model is not None
-     assert new_model.name == "test_model"
+     assert new_model.name == "test_classification_model"
      assert new_model.memoryset == readonly_memoryset
      assert new_model.num_classes == 2
      assert new_model.memory_lookup_count == 3
@@ -58,14 +57,14 @@ def test_create_model_unauthenticated(unauthenticated, readonly_memoryset: Label
          ClassificationModel.create("test_model", readonly_memoryset)
 
 
- def test_get_model(model: ClassificationModel):
-     fetched_model = ClassificationModel.open(model.name)
+ def test_get_model(classification_model: ClassificationModel):
+     fetched_model = ClassificationModel.open(classification_model.name)
      assert fetched_model is not None
-     assert fetched_model.id == model.id
-     assert fetched_model.name == model.name
+     assert fetched_model.id == classification_model.id
+     assert fetched_model.name == classification_model.name
      assert fetched_model.num_classes == 2
      assert fetched_model.memory_lookup_count == 3
-     assert fetched_model == model
+     assert fetched_model == classification_model
 
 
  def test_get_model_unauthenticated(unauthenticated):
@@ -83,12 +82,12 @@ def test_get_model_not_found():
          ClassificationModel.open(str(uuid4()))
 
 
- def test_get_model_unauthorized(unauthorized, model: ClassificationModel):
+ def test_get_model_unauthorized(unauthorized, classification_model: ClassificationModel):
      with pytest.raises(LookupError):
-         ClassificationModel.open(model.name)
+         ClassificationModel.open(classification_model.name)
 
 
- def test_list_models(model: ClassificationModel):
+ def test_list_models(classification_model: ClassificationModel):
      models = ClassificationModel.all()
      assert len(models) > 0
      assert any(model.name == model.name for model in models)
@@ -99,19 +98,28 @@ def test_list_models_unauthenticated(unauthenticated):
          ClassificationModel.all()
 
 
- def test_list_models_unauthorized(unauthorized, model: ClassificationModel):
+ def test_list_models_unauthorized(unauthorized, classification_model: ClassificationModel):
      assert ClassificationModel.all() == []
 
 
- def test_update_model(model: ClassificationModel):
-     model.update_metadata(description="New description")
-     assert model.description == "New description"
+ def test_update_model_attributes(classification_model: ClassificationModel):
+     classification_model.description = "New description"
+     assert classification_model.description == "New description"
+
+     classification_model.set(description=None)
+     assert classification_model.description is None
+
+     classification_model.set(locked=True)
+     assert classification_model.locked is True
 
+     classification_model.set(locked=False)
+     assert classification_model.locked is False
 
- def test_update_model_no_description(model: ClassificationModel):
-     assert model.description is not None
-     model.update_metadata(description=None)
-     assert model.description is None
+     classification_model.lock()
+     assert classification_model.locked is True
+
+     classification_model.unlock()
+     assert classification_model.locked is False
 
 
  def test_delete_model(readonly_memoryset: LabeledMemoryset):
@@ -122,9 +130,9 @@ def test_delete_model(readonly_memoryset: LabeledMemoryset):
          ClassificationModel.open("model_to_delete")
 
 
- def test_delete_model_unauthenticated(unauthenticated, model: ClassificationModel):
+ def test_delete_model_unauthenticated(unauthenticated, classification_model: ClassificationModel):
      with pytest.raises(ValueError, match="Invalid API key"):
-         ClassificationModel.drop(model.name)
+         ClassificationModel.drop(classification_model.name)
 
 
  def test_delete_model_not_found():
@@ -134,9 +142,9 @@ def test_delete_model_not_found():
      ClassificationModel.drop(str(uuid4()), if_not_exists="ignore")
 
 
- def test_delete_model_unauthorized(unauthorized, model: ClassificationModel):
+ def test_delete_model_unauthorized(unauthorized, classification_model: ClassificationModel):
      with pytest.raises(LookupError):
-         ClassificationModel.drop(model.name)
+         ClassificationModel.drop(classification_model.name)
 
 
  def test_delete_memoryset_before_model_constraint_violation(hf_dataset):
@@ -146,78 +154,57 @@ def test_delete_memoryset_before_model_constraint_violation(hf_dataset):
          LabeledMemoryset.drop(memoryset.id)
 
 
- def test_evaluate(model, eval_datasource: Datasource):
-     result = model.evaluate(eval_datasource)
+ @pytest.mark.parametrize("data_type", ["dataset", "datasource"])
+ def test_evaluate(classification_model, eval_datasource: Datasource, eval_dataset: Dataset, data_type):
+     result = (
+         classification_model.evaluate(eval_dataset)
+         if data_type == "dataset"
+         else classification_model.evaluate(eval_datasource)
+     )
+
      assert result is not None
-     assert isinstance(result, dict)
-     # And anomaly score statistics are present and valid
-     assert isinstance(result["anomaly_score_mean"], float)
-     assert isinstance(result["anomaly_score_median"], float)
-     assert isinstance(result["anomaly_score_variance"], float)
-     assert -1.0 <= result["anomaly_score_mean"] <= 1.0
-     assert -1.0 <= result["anomaly_score_median"] <= 1.0
-     assert -1.0 <= result["anomaly_score_variance"] <= 1.0
-     assert isinstance(result["accuracy"], float)
-     assert isinstance(result["f1_score"], float)
-     assert isinstance(result["loss"], float)
-     assert len(result["precision_recall_curve"]["thresholds"]) == 4
-     assert len(result["precision_recall_curve"]["precisions"]) == 4
-     assert len(result["precision_recall_curve"]["recalls"]) == 4
-     assert len(result["roc_curve"]["thresholds"]) == 4
-     assert len(result["roc_curve"]["false_positive_rates"]) == 4
-     assert len(result["roc_curve"]["true_positive_rates"]) == 4
-
-
- def test_evaluate_combined(model, eval_datasource: Datasource, eval_dataset: Dataset):
-     result_datasource = model.evaluate(eval_datasource)
-
-     result_dataset = model.evaluate(eval_dataset)
-
-     for result in [result_datasource, result_dataset]:
-         assert result is not None
-         assert isinstance(result, dict)
-         assert isinstance(result["accuracy"], float)
-         assert isinstance(result["f1_score"], float)
-         assert isinstance(result["loss"], float)
-         assert np.allclose(result["accuracy"], 0.5)
-         assert np.allclose(result["f1_score"], 0.5)
-
-         assert isinstance(result["precision_recall_curve"]["thresholds"], list)
-         assert isinstance(result["precision_recall_curve"]["precisions"], list)
-         assert isinstance(result["precision_recall_curve"]["recalls"], list)
-         assert isinstance(result["roc_curve"]["thresholds"], list)
-         assert isinstance(result["roc_curve"]["false_positive_rates"], list)
-         assert isinstance(result["roc_curve"]["true_positive_rates"], list)
-
-         assert np.allclose(result["roc_curve"]["thresholds"], [0.0, 0.8155114054679871, 0.834095299243927, 1.0])
-         assert np.allclose(result["roc_curve"]["false_positive_rates"], [1.0, 0.5, 0.0, 0.0])
-         assert np.allclose(result["roc_curve"]["true_positive_rates"], [1.0, 0.5, 0.5, 0.0])
-         assert np.allclose(result["roc_curve"]["auc"], 0.625)
-
-         assert np.allclose(
-             result["precision_recall_curve"]["thresholds"], [0.0, 0.0, 0.8155114054679871, 0.834095299243927]
-         )
-         assert np.allclose(result["precision_recall_curve"]["precisions"], [0.5, 0.5, 1.0, 1.0])
-         assert np.allclose(result["precision_recall_curve"]["recalls"], [1.0, 0.5, 0.5, 0.0])
-         assert np.allclose(result["precision_recall_curve"]["auc"], 0.75)
-
-
- def test_evaluate_with_telemetry(model):
-     samples = [
-         {"text": "chicken noodle soup is the best", "label": 1},
-         {"text": "cats are cute", "label": 0},
-     ]
-     eval_datasource = Datasource.from_list("eval_datasource_2", samples)
-     result = model.evaluate(eval_datasource, value_column="text", record_predictions=True, tags={"test"})
+     assert isinstance(result, ClassificationMetrics)
+
+     assert isinstance(result.accuracy, float)
+     assert np.allclose(result.accuracy, 0.5)
+     assert isinstance(result.f1_score, float)
+     assert np.allclose(result.f1_score, 0.5)
+     assert isinstance(result.loss, float)
+
+     assert isinstance(result.anomaly_score_mean, float)
+     assert isinstance(result.anomaly_score_median, float)
+     assert isinstance(result.anomaly_score_variance, float)
+     assert -1.0 <= result.anomaly_score_mean <= 1.0
+     assert -1.0 <= result.anomaly_score_median <= 1.0
+     assert -1.0 <= result.anomaly_score_variance <= 1.0
+
+     assert result.pr_auc is not None
+     assert np.allclose(result.pr_auc, 0.75)
+     assert result.pr_curve is not None
+     assert np.allclose(result.pr_curve["thresholds"], [0.0, 0.0, 0.8155114054679871, 0.834095299243927])
+     assert np.allclose(result.pr_curve["precisions"], [0.5, 0.5, 1.0, 1.0])
+     assert np.allclose(result.pr_curve["recalls"], [1.0, 0.5, 0.5, 0.0])
+
+     assert result.roc_auc is not None
+     assert np.allclose(result.roc_auc, 0.625)
+     assert result.roc_curve is not None
+     assert np.allclose(result.roc_curve["thresholds"], [0.0, 0.8155114054679871, 0.834095299243927, 1.0])
+     assert np.allclose(result.roc_curve["false_positive_rates"], [1.0, 0.5, 0.0, 0.0])
+     assert np.allclose(result.roc_curve["true_positive_rates"], [1.0, 0.5, 0.5, 0.0])
+
+
+ def test_evaluate_with_telemetry(classification_model: ClassificationModel, eval_dataset: Dataset):
+     result = classification_model.evaluate(eval_dataset, record_predictions=True, tags={"test"})
      assert result is not None
-     predictions = model.predictions(tag="test")
-     assert len(predictions) == 2
+     assert isinstance(result, ClassificationMetrics)
+     predictions = classification_model.predictions(tag="test")
+     assert len(predictions) == 4
      assert all(p.tags == {"test"} for p in predictions)
-     assert all(p.expected_label == s["label"] for p, s in zip(predictions, samples))
+     assert all(p.expected_label == l for p, l in zip(predictions, eval_dataset["label"]))
 
 
- def test_predict(model: ClassificationModel, label_names: list[str]):
-     predictions = model.predict(["Do you love soup?", "Are cats cute?"])
+ def test_predict(classification_model: ClassificationModel, label_names: list[str]):
+     predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
      assert len(predictions) == 2
      assert predictions[0].prediction_id is not None
      assert predictions[1].prediction_id is not None
@@ -236,8 +223,8 @@ def test_predict(model: ClassificationModel, label_names: list[str]):
      assert predictions[1].logits[0] < predictions[1].logits[1]
 
 
- def test_predict_disable_telemetry(model: ClassificationModel, label_names: list[str]):
-     predictions = model.predict(["Do you love soup?", "Are cats cute?"], save_telemetry=False)
+ def test_predict_disable_telemetry(classification_model: ClassificationModel, label_names: list[str]):
+     predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"], save_telemetry="off")
      assert len(predictions) == 2
      assert predictions[0].prediction_id is None
      assert predictions[1].prediction_id is None
@@ -249,14 +236,14 @@ def test_predict_disable_telemetry(model: ClassificationModel, label_names: list
      assert 0 <= predictions[1].confidence <= 1
 
 
- def test_predict_unauthenticated(unauthenticated, model: ClassificationModel):
+ def test_predict_unauthenticated(unauthenticated, classification_model: ClassificationModel):
      with pytest.raises(ValueError, match="Invalid API key"):
-         model.predict(["Do you love soup?", "Are cats cute?"])
+         classification_model.predict(["Do you love soup?", "Are cats cute?"])
 
 
- def test_predict_unauthorized(unauthorized, model: ClassificationModel):
+ def test_predict_unauthorized(unauthorized, classification_model: ClassificationModel):
      with pytest.raises(LookupError):
-         model.predict(["Do you love soup?", "Are cats cute?"])
+         classification_model.predict(["Do you love soup?", "Are cats cute?"])
 
 
  def test_predict_constraint_violation(readonly_memoryset: LabeledMemoryset):
@@ -270,10 +257,10 @@ def test_predict_constraint_violation(readonly_memoryset: LabeledMemoryset):
          model.predict("test")
 
 
- def test_record_prediction_feedback(model: ClassificationModel):
-     predictions = model.predict(["Do you love soup?", "Are cats cute?"])
+ def test_record_prediction_feedback(classification_model: ClassificationModel):
+     predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
      expected_labels = [0, 1]
-     model.record_feedback(
+     classification_model.record_feedback(
          {
              "prediction_id": p.prediction_id,
              "category": "correct",
@@ -283,73 +270,80 @@ def test_record_prediction_feedback(model: ClassificationModel):
      )
 
 
- def test_record_prediction_feedback_missing_category(model: ClassificationModel):
-     prediction = model.predict("Do you love soup?")
+ def test_record_prediction_feedback_missing_category(classification_model: ClassificationModel):
+     prediction = classification_model.predict("Do you love soup?")
      with pytest.raises(ValueError):
-         model.record_feedback({"prediction_id": prediction.prediction_id, "value": True})
+         classification_model.record_feedback({"prediction_id": prediction.prediction_id, "value": True})
 
 
- def test_record_prediction_feedback_invalid_value(model: ClassificationModel):
-     prediction = model.predict("Do you love soup?")
+ def test_record_prediction_feedback_invalid_value(classification_model: ClassificationModel):
+     prediction = classification_model.predict("Do you love soup?")
      with pytest.raises(ValueError, match=r"Invalid input.*"):
-         model.record_feedback({"prediction_id": prediction.prediction_id, "category": "correct", "value": "invalid"})
+         classification_model.record_feedback(
+             {"prediction_id": prediction.prediction_id, "category": "correct", "value": "invalid"}
+         )
 
 
- def test_record_prediction_feedback_invalid_prediction_id(model: ClassificationModel):
+ def test_record_prediction_feedback_invalid_prediction_id(classification_model: ClassificationModel):
      with pytest.raises(ValueError, match=r"Invalid input.*"):
-         model.record_feedback({"prediction_id": "invalid", "category": "correct", "value": True})
+         classification_model.record_feedback({"prediction_id": "invalid", "category": "correct", "value": True})
 
 
- def test_predict_with_memoryset_override(model: ClassificationModel, hf_dataset: Dataset):
+ def test_predict_with_memoryset_override(classification_model: ClassificationModel, hf_dataset: Dataset):
      inverted_labeled_memoryset = LabeledMemoryset.from_hf_dataset(
          "test_memoryset_inverted_labels",
          hf_dataset.map(lambda x: {"label": 1 if x["label"] == 0 else 0}),
          embedding_model=PretrainedEmbeddingModel.GTE_BASE,
      )
-     with model.use_memoryset(inverted_labeled_memoryset):
-         predictions = model.predict(["Do you love soup?", "Are cats cute?"])
+     with classification_model.use_memoryset(inverted_labeled_memoryset):
+         predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
          assert predictions[0].label == 1
          assert predictions[1].label == 0
 
-     predictions = model.predict(["Do you love soup?", "Are cats cute?"])
+     predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
      assert predictions[0].label == 0
      assert predictions[1].label == 1
 
 
- def test_predict_with_expected_labels(model: ClassificationModel):
-     prediction = model.predict("Do you love soup?", expected_labels=1)
+ def test_predict_with_expected_labels(classification_model: ClassificationModel):
+     prediction = classification_model.predict("Do you love soup?", expected_labels=1)
      assert prediction.expected_label == 1
 
 
- def test_predict_with_expected_labels_invalid_input(model: ClassificationModel):
+ def test_predict_with_expected_labels_invalid_input(classification_model: ClassificationModel):
      # invalid number of expected labels for batch prediction
      with pytest.raises(ValueError, match=r"Invalid input.*"):
-         model.predict(["Do you love soup?", "Are cats cute?"], expected_labels=[0])
+         classification_model.predict(["Do you love soup?", "Are cats cute?"], expected_labels=[0])
      # invalid label value
      with pytest.raises(ValueError):
-         model.predict("Do you love soup?", expected_labels=5)
+         classification_model.predict("Do you love soup?", expected_labels=5)
 
 
- def test_last_prediction_with_batch(model: ClassificationModel):
-     predictions = model.predict(["Do you love soup?", "Are cats cute?"])
-     assert model.last_prediction is not None
-     assert model.last_prediction.prediction_id == predictions[-1].prediction_id
-     assert model.last_prediction.input_value == "Are cats cute?"
-     assert model._last_prediction_was_batch is True
+ def test_predict_with_filters(classification_model: ClassificationModel):
+     # there are no memories with label 0 and key g1, so we force a wrong prediction
+     filtered_prediction = classification_model.predict("I love soup", filters=[("key", "==", "g2")])
+     assert filtered_prediction.label == 1
+     assert filtered_prediction.label_name == "cats"
+
+
+ def test_last_prediction_with_batch(classification_model: ClassificationModel):
+     predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
+     assert classification_model.last_prediction is not None
+     assert classification_model.last_prediction.prediction_id == predictions[-1].prediction_id
+     assert classification_model.last_prediction.input_value == "Are cats cute?"
+     assert classification_model._last_prediction_was_batch is True
 
 
- def test_last_prediction_with_single(model: ClassificationModel):
+ def test_last_prediction_with_single(classification_model: ClassificationModel):
      # Test that last_prediction is updated correctly with single prediction
-     prediction = model.predict("Do you love soup?")
-     assert model.last_prediction is not None
-     assert model.last_prediction.prediction_id == prediction.prediction_id
-     assert model.last_prediction.input_value == "Do you love soup?"
-     assert model._last_prediction_was_batch is False
+     prediction = classification_model.predict("Do you love soup?")
+     assert classification_model.last_prediction is not None
+     assert classification_model.last_prediction.prediction_id == prediction.prediction_id
+     assert classification_model.last_prediction.input_value == "Do you love soup?"
+     assert classification_model._last_prediction_was_batch is False
 
 
- @pytest.mark.skipif(
-     SKIP_IN_GITHUB_ACTIONS, reason="Skipping explanation test because in CI we don't have Anthropic API key"
- )
+ @skip_in_ci("We don't have Anthropic API key in CI")
  def test_explain(writable_memoryset: LabeledMemoryset):
 
      writable_memoryset.analyze(
370
364
 
371
365
  try:
372
366
  explanation = predictions[0].explanation
373
- print(explanation)
374
367
  assert explanation is not None
375
368
  assert len(explanation) > 10
376
369
  assert "soup" in explanation.lower()
377
370
  except Exception as e:
378
371
  if "ANTHROPIC_API_KEY" in str(e):
379
- logging.info("Skipping explanation test because ANTHROPIC_API_KEY is not set on server")
372
+ logging.info("Skipping explanation test because ANTHROPIC_API_KEY is not set")
380
373
  else:
381
374
  raise e
382
375
  finally:
383
- try:
384
- ClassificationModel.drop("test_model_for_explain")
385
- except Exception as e:
386
- logging.info(f"Failed to drop test model for explain: {e}")
376
+ ClassificationModel.drop("test_model_for_explain")
orca_sdk/conftest.py CHANGED
@@ -11,15 +11,33 @@ from .classification_model import ClassificationModel
  from .credentials import OrcaCredentials
  from .datasource import Datasource
  from .embedding_model import PretrainedEmbeddingModel
- from .memoryset import LabeledMemoryset
+ from .memoryset import LabeledMemoryset, ScoredMemoryset
+ from .regression_model import RegressionModel
 
- logging.basicConfig(level=logging.INFO)
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
 
  os.environ["ORCA_API_URL"] = os.environ.get("ORCA_API_URL", "http://localhost:1584/")
 
  os.environ["ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY"] = "true"
 
 
+ def skip_in_prod(reason: str):
+     """Custom decorator to skip tests when running against production API"""
+     PROD_API_URLs = ["https://api.orcadb.ai", "https://api.dev.orcadb.ai"]
+     return pytest.mark.skipif(
+         os.environ["ORCA_API_URL"] in PROD_API_URLs,
+         reason=reason,
+     )
+
+
+ def skip_in_ci(reason: str):
+     """Custom decorator to skip tests when running in CI"""
+     return pytest.mark.skipif(
+         os.environ.get("GITHUB_ACTIONS", "false") == "true",
+         reason=reason,
+     )
+
+
  def _create_org_id():
      # UUID start to identify test data (0xtest...)
      return "10e50000-0000-4000-a000-" + str(uuid4())[24:]
@@ -71,27 +89,27 @@ def label_names():
 
 
  SAMPLE_DATA = [
-     {"value": "i love soup", "label": 0, "key": "val1", "score": 0.1, "source_id": "s1"},
-     {"value": "cats are cute", "label": 1, "key": "val2", "score": 0.2, "source_id": "s2"},
-     {"value": "soup is good", "label": 0, "key": "val3", "score": 0.3, "source_id": "s3"},
-     {"value": "i love cats", "label": 1, "key": "val4", "score": 0.4, "source_id": "s4"},
-     {"value": "everyone loves cats", "label": 1, "key": "val5", "score": 0.5, "source_id": "s5"},
-     {"value": "soup is great for the winter", "label": 0, "key": "val6", "score": 0.6, "source_id": "s6"},
-     {"value": "hot soup on a rainy day!", "label": 0, "key": "val7", "score": 0.7, "source_id": "s7"},
-     {"value": "cats sleep all day", "label": 1, "key": "val8", "score": 0.8, "source_id": "s8"},
-     {"value": "homemade soup recipes", "label": 0, "key": "val9", "score": 0.9, "source_id": "s9"},
-     {"value": "cats purr when happy", "label": 1, "key": "val10", "score": 1.0, "source_id": "s10"},
-     {"value": "chicken noodle soup is classic", "label": 0, "key": "val11", "score": 1.1, "source_id": "s11"},
-     {"value": "kittens are baby cats", "label": 1, "key": "val12", "score": 1.2, "source_id": "s12"},
-     {"value": "soup can be served cold too", "label": 0, "key": "val13", "score": 1.3, "source_id": "s13"},
-     {"value": "cats have nine lives", "label": 1, "key": "val14", "score": 1.4, "source_id": "s14"},
-     {"value": "tomato soup with grilled cheese", "label": 0, "key": "val15", "score": 1.5, "source_id": "s15"},
-     {"value": "cats are independent animals", "label": 1, "key": "val16", "score": 1.6, "source_id": "s16"},
+     {"value": "i love soup", "label": 0, "key": "g1", "score": 0.1, "source_id": "s1"},
+     {"value": "cats are cute", "label": 1, "key": "g1", "score": 0.9, "source_id": "s2"},
+     {"value": "soup is good", "label": 0, "key": "g1", "score": 0.1, "source_id": "s3"},
+     {"value": "i love cats", "label": 1, "key": "g1", "score": 0.9, "source_id": "s4"},
+     {"value": "everyone loves cats", "label": 1, "key": "g1", "score": 0.9, "source_id": "s5"},
+     {"value": "soup is great for the winter", "label": 0, "key": "g1", "score": 0.1, "source_id": "s6"},
+     {"value": "hot soup on a rainy day!", "label": 0, "key": "g1", "score": 0.1, "source_id": "s7"},
+     {"value": "cats sleep all day", "label": 1, "key": "g1", "score": 0.9, "source_id": "s8"},
+     {"value": "homemade soup recipes", "label": 0, "key": "g1", "score": 0.1, "source_id": "s9"},
+     {"value": "cats purr when happy", "label": 1, "key": "g2", "score": 0.9, "source_id": "s10"},
+     {"value": "chicken noodle soup is classic", "label": 0, "key": "g1", "score": 0.1, "source_id": "s11"},
+     {"value": "kittens are baby cats", "label": 1, "key": "g2", "score": 0.9, "source_id": "s12"},
+     {"value": "soup can be served cold too", "label": 0, "key": "g1", "score": 0.1, "source_id": "s13"},
+     {"value": "cats have nine lives", "label": 1, "key": "g2", "score": 0.9, "source_id": "s14"},
+     {"value": "tomato soup with grilled cheese", "label": 0, "key": "g1", "score": 0.1, "source_id": "s15"},
+     {"value": "cats are independent animals", "label": 1, "key": "g2", "score": 0.9, "source_id": "s16"},
  ]
 
 
  @pytest.fixture(scope="session")
- def hf_dataset(label_names):
+ def hf_dataset(label_names: list[str]) -> Dataset:
      return Dataset.from_list(
          SAMPLE_DATA,
          features=Features(
@@ -107,16 +125,16 @@
 
 
  @pytest.fixture(scope="session")
- def datasource(hf_dataset) -> Datasource:
+ def datasource(hf_dataset: Dataset) -> Datasource:
      datasource = Datasource.from_hf_dataset("test_datasource", hf_dataset)
      return datasource
 
 
  EVAL_DATASET = [
-     {"value": "chicken noodle soup is the best", "label": 1},
-     {"value": "cats are cute", "label": 0},
-     {"value": "soup is great for the winter", "label": 0},
-     {"value": "i love cats", "label": 1},
+     {"value": "chicken noodle soup is the best", "label": 1, "score": 0.9},  # mislabeled
+     {"value": "cats are cute", "label": 0, "score": 0.1},  # mislabeled
+     {"value": "soup is great for the winter", "label": 0, "score": 0.1},
+     {"value": "i love cats", "label": 1, "score": 0.9},
  ]
 
 
@@ -140,6 +158,8 @@ def readonly_memoryset(datasource: Datasource) -> LabeledMemoryset:
          embedding_model=PretrainedEmbeddingModel.GTE_BASE,
          source_id_column="source_id",
          max_seq_length_override=32,
+         index_type="IVF_FLAT",
+         index_params={"n_lists": 100},
      )
      return memoryset
 
@@ -183,8 +203,38 @@ def writable_memoryset(datasource: Datasource, api_key: str) -> Generator[Labele
 
 
  @pytest.fixture(scope="session")
- def model(readonly_memoryset: LabeledMemoryset) -> ClassificationModel:
+ def classification_model(readonly_memoryset: LabeledMemoryset) -> ClassificationModel:
      model = ClassificationModel.create(
-         "test_model", readonly_memoryset, num_classes=2, memory_lookup_count=3, description="test_description"
+         "test_classification_model",
+         readonly_memoryset,
+         num_classes=2,
+         memory_lookup_count=3,
+         description="test_description",
+     )
+     return model
+
+
+ # Add scored memoryset and regression model fixtures
+ @pytest.fixture(scope="session")
+ def scored_memoryset(datasource: Datasource) -> ScoredMemoryset:
+     memoryset = ScoredMemoryset.create(
+         "test_scored_memoryset",
+         datasource=datasource,
+         embedding_model=PretrainedEmbeddingModel.GTE_BASE,
+         source_id_column="source_id",
+         max_seq_length_override=32,
+         index_type="IVF_FLAT",
+         index_params={"n_lists": 100},
+     )
+     return memoryset
+
+
+ @pytest.fixture(scope="session")
+ def regression_model(scored_memoryset: ScoredMemoryset) -> RegressionModel:
+     model = RegressionModel.create(
+         "test_regression_model",
+         scored_memoryset,
+         memory_lookup_count=3,
+         description="test_regression_description",
      )
      return model
orca_sdk/datasource_test.py CHANGED
@@ -4,7 +4,6 @@ from uuid import uuid4
 
  import pytest
 
- from ._generated_api_client.models import EmbeddingEvaluationResponse, TaskStatus
  from .datasource import Datasource
 