orca-sdk 0.0.93__py3-none-any.whl → 0.0.95__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (125)
  1. orca_sdk/__init__.py +13 -4
  2. orca_sdk/_generated_api_client/api/__init__.py +84 -34
  3. orca_sdk/_generated_api_client/api/classification_model/create_classification_model_classification_model_post.py +170 -0
  4. orca_sdk/_generated_api_client/api/classification_model/{get_model_classification_model_name_or_id_get.py → delete_classification_model_classification_model_name_or_id_delete.py} +20 -20
  5. orca_sdk/_generated_api_client/api/classification_model/{delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py → delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py} +4 -4
  6. orca_sdk/_generated_api_client/api/classification_model/{create_evaluation_classification_model_model_name_or_id_evaluation_post.py → evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py} +14 -14
  7. orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +156 -0
  8. orca_sdk/_generated_api_client/api/classification_model/{get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py → get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py} +16 -16
  9. orca_sdk/_generated_api_client/api/classification_model/{list_evaluations_classification_model_model_name_or_id_evaluation_get.py → list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py} +16 -16
  10. orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +127 -0
  11. orca_sdk/_generated_api_client/api/classification_model/{predict_gpu_classification_model_name_or_id_prediction_post.py → predict_label_gpu_classification_model_name_or_id_prediction_post.py} +14 -14
  12. orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +183 -0
  13. orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +172 -0
  14. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +22 -22
  15. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +22 -22
  16. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +38 -16
  17. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +29 -12
  18. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +12 -12
  19. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +17 -14
  20. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +72 -19
  21. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +31 -12
  22. orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +49 -20
  23. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +38 -16
  24. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +54 -29
  25. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +44 -26
  26. orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +22 -22
  27. orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
  28. orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +150 -0
  29. orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
  30. orca_sdk/_generated_api_client/api/{classification_model/create_model_classification_model_post.py → regression_model/create_regression_model_regression_model_post.py} +27 -27
  31. orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
  32. orca_sdk/_generated_api_client/api/{classification_model/delete_model_classification_model_name_or_id_delete.py → regression_model/delete_regression_model_regression_model_name_or_id_delete.py} +5 -5
  33. orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +183 -0
  34. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +170 -0
  35. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +156 -0
  36. orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +161 -0
  37. orca_sdk/_generated_api_client/api/{classification_model/list_models_classification_model_get.py → regression_model/list_regression_models_regression_model_get.py} +17 -17
  38. orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +190 -0
  39. orca_sdk/_generated_api_client/api/{classification_model/update_model_classification_model_name_or_id_patch.py → regression_model/update_regression_model_regression_model_name_or_id_patch.py} +27 -27
  40. orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +156 -0
  41. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +60 -10
  42. orca_sdk/_generated_api_client/api/telemetry/count_predictions_telemetry_prediction_count_post.py +10 -10
  43. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +35 -12
  44. orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +20 -12
  45. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +35 -12
  46. orca_sdk/_generated_api_client/models/__init__.py +90 -24
  47. orca_sdk/_generated_api_client/models/base_score_prediction_result.py +108 -0
  48. orca_sdk/_generated_api_client/models/{evaluation_request.py → classification_evaluation_request.py} +13 -45
  49. orca_sdk/_generated_api_client/models/{classification_evaluation_result.py → classification_metrics.py} +106 -56
  50. orca_sdk/_generated_api_client/models/{rac_model_metadata.py → classification_model_metadata.py} +51 -43
  51. orca_sdk/_generated_api_client/models/{prediction_request.py → classification_prediction_request.py} +31 -6
  52. orca_sdk/_generated_api_client/models/{clone_labeled_memoryset_request.py → clone_memoryset_request.py} +5 -5
  53. orca_sdk/_generated_api_client/models/column_info.py +31 -0
  54. orca_sdk/_generated_api_client/models/count_predictions_request.py +195 -0
  55. orca_sdk/_generated_api_client/models/{create_rac_model_request.py → create_classification_model_request.py} +25 -57
  56. orca_sdk/_generated_api_client/models/{create_labeled_memoryset_request.py → create_memoryset_request.py} +73 -56
  57. orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +66 -0
  58. orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +13 -0
  59. orca_sdk/_generated_api_client/models/create_regression_model_request.py +137 -0
  60. orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +187 -0
  61. orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +10 -0
  62. orca_sdk/_generated_api_client/models/evaluation_response.py +22 -9
  63. orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +140 -0
  64. orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +140 -0
  65. orca_sdk/_generated_api_client/models/http_validation_error.py +86 -0
  66. orca_sdk/_generated_api_client/models/list_predictions_request.py +62 -0
  67. orca_sdk/_generated_api_client/models/memory_type.py +9 -0
  68. orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py +0 -20
  69. orca_sdk/_generated_api_client/models/{labeled_memoryset_metadata.py → memoryset_metadata.py} +73 -13
  70. orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +55 -0
  71. orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +13 -0
  72. orca_sdk/_generated_api_client/models/{labeled_memoryset_update.py → memoryset_update.py} +19 -31
  73. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +1 -0
  74. orca_sdk/_generated_api_client/models/{paginated_labeled_memory_with_feedback_metrics.py → paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py} +37 -10
  75. orca_sdk/_generated_api_client/models/{precision_recall_curve.py → pr_curve.py} +5 -13
  76. orca_sdk/_generated_api_client/models/{rac_model_update.py → predictive_model_update.py} +14 -5
  77. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +11 -1
  78. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +5 -0
  79. orca_sdk/_generated_api_client/models/rar_head_type.py +8 -0
  80. orca_sdk/_generated_api_client/models/regression_evaluation_request.py +148 -0
  81. orca_sdk/_generated_api_client/models/regression_metrics.py +172 -0
  82. orca_sdk/_generated_api_client/models/regression_model_metadata.py +177 -0
  83. orca_sdk/_generated_api_client/models/regression_prediction_request.py +195 -0
  84. orca_sdk/_generated_api_client/models/roc_curve.py +0 -8
  85. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +196 -0
  86. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +68 -0
  87. orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +252 -0
  88. orca_sdk/_generated_api_client/models/scored_memory.py +172 -0
  89. orca_sdk/_generated_api_client/models/scored_memory_insert.py +128 -0
  90. orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +68 -0
  91. orca_sdk/_generated_api_client/models/scored_memory_lookup.py +180 -0
  92. orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +68 -0
  93. orca_sdk/_generated_api_client/models/scored_memory_metadata.py +68 -0
  94. orca_sdk/_generated_api_client/models/scored_memory_update.py +171 -0
  95. orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +68 -0
  96. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +193 -0
  97. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +68 -0
  98. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +68 -0
  99. orca_sdk/_generated_api_client/models/update_prediction_request.py +20 -0
  100. orca_sdk/_generated_api_client/models/validation_error.py +99 -0
  101. orca_sdk/_shared/__init__.py +9 -1
  102. orca_sdk/_shared/metrics.py +257 -87
  103. orca_sdk/_shared/metrics_test.py +136 -77
  104. orca_sdk/_utils/data_parsing.py +0 -3
  105. orca_sdk/_utils/data_parsing_test.py +0 -3
  106. orca_sdk/_utils/prediction_result_ui.py +55 -23
  107. orca_sdk/classification_model.py +184 -174
  108. orca_sdk/classification_model_test.py +178 -142
  109. orca_sdk/conftest.py +77 -26
  110. orca_sdk/datasource.py +34 -0
  111. orca_sdk/datasource_test.py +9 -1
  112. orca_sdk/embedding_model.py +136 -14
  113. orca_sdk/embedding_model_test.py +10 -6
  114. orca_sdk/job.py +329 -0
  115. orca_sdk/job_test.py +48 -0
  116. orca_sdk/memoryset.py +882 -161
  117. orca_sdk/memoryset_test.py +58 -23
  118. orca_sdk/regression_model.py +647 -0
  119. orca_sdk/regression_model_test.py +338 -0
  120. orca_sdk/telemetry.py +225 -106
  121. orca_sdk/telemetry_test.py +34 -30
  122. {orca_sdk-0.0.93.dist-info → orca_sdk-0.0.95.dist-info}/METADATA +2 -4
  123. {orca_sdk-0.0.93.dist-info → orca_sdk-0.0.95.dist-info}/RECORD +124 -74
  124. orca_sdk/_utils/task.py +0 -73
  125. {orca_sdk-0.0.93.dist-info → orca_sdk-0.0.95.dist-info}/WHEEL +0 -0
orca_sdk/classification_model_test.py CHANGED
@@ -1,46 +1,52 @@
+ import logging
  from uuid import uuid4

  import numpy as np
  import pytest
  from datasets.arrow_dataset import Dataset

- from .classification_model import ClassificationModel
+ from .classification_model import ClassificationMetrics, ClassificationModel
+ from .conftest import skip_in_ci
  from .datasource import Datasource
  from .embedding_model import PretrainedEmbeddingModel
  from .memoryset import LabeledMemoryset


- def test_create_model(model: ClassificationModel, readonly_memoryset: LabeledMemoryset):
-     assert model is not None
-     assert model.name == "test_model"
-     assert model.memoryset == readonly_memoryset
-     assert model.num_classes == 2
-     assert model.memory_lookup_count == 3
+ def test_create_model(classification_model: ClassificationModel, readonly_memoryset: LabeledMemoryset):
+     assert classification_model is not None
+     assert classification_model.name == "test_classification_model"
+     assert classification_model.memoryset == readonly_memoryset
+     assert classification_model.num_classes == 2
+     assert classification_model.memory_lookup_count == 3


- def test_create_model_already_exists_error(readonly_memoryset, model: ClassificationModel):
+ def test_create_model_already_exists_error(readonly_memoryset, classification_model):
      with pytest.raises(ValueError):
-         ClassificationModel.create("test_model", readonly_memoryset)
+         ClassificationModel.create("test_classification_model", readonly_memoryset)
      with pytest.raises(ValueError):
-         ClassificationModel.create("test_model", readonly_memoryset, if_exists="error")
+         ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="error")


- def test_create_model_already_exists_return(readonly_memoryset, model: ClassificationModel):
+ def test_create_model_already_exists_return(readonly_memoryset, classification_model):
      with pytest.raises(ValueError):
-         ClassificationModel.create("test_model", readonly_memoryset, if_exists="open", head_type="MMOE")
+         ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="open", head_type="MMOE")

      with pytest.raises(ValueError):
-         ClassificationModel.create("test_model", readonly_memoryset, if_exists="open", memory_lookup_count=37)
+         ClassificationModel.create(
+             "test_classification_model", readonly_memoryset, if_exists="open", memory_lookup_count=37
+         )

      with pytest.raises(ValueError):
-         ClassificationModel.create("test_model", readonly_memoryset, if_exists="open", num_classes=19)
+         ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="open", num_classes=19)

      with pytest.raises(ValueError):
-         ClassificationModel.create("test_model", readonly_memoryset, if_exists="open", min_memory_weight=0.77)
+         ClassificationModel.create(
+             "test_classification_model", readonly_memoryset, if_exists="open", min_memory_weight=0.77
+         )

-     new_model = ClassificationModel.create("test_model", readonly_memoryset, if_exists="open")
+     new_model = ClassificationModel.create("test_classification_model", readonly_memoryset, if_exists="open")
      assert new_model is not None
-     assert new_model.name == "test_model"
+     assert new_model.name == "test_classification_model"
      assert new_model.memoryset == readonly_memoryset
      assert new_model.num_classes == 2
      assert new_model.memory_lookup_count == 3
@@ -51,14 +57,14 @@ def test_create_model_unauthenticated(unauthenticated, readonly_memoryset: Label
          ClassificationModel.create("test_model", readonly_memoryset)


- def test_get_model(model: ClassificationModel):
-     fetched_model = ClassificationModel.open(model.name)
+ def test_get_model(classification_model: ClassificationModel):
+     fetched_model = ClassificationModel.open(classification_model.name)
      assert fetched_model is not None
-     assert fetched_model.id == model.id
-     assert fetched_model.name == model.name
+     assert fetched_model.id == classification_model.id
+     assert fetched_model.name == classification_model.name
      assert fetched_model.num_classes == 2
      assert fetched_model.memory_lookup_count == 3
-     assert fetched_model == model
+     assert fetched_model == classification_model


  def test_get_model_unauthenticated(unauthenticated):
@@ -76,12 +82,12 @@ def test_get_model_not_found():
          ClassificationModel.open(str(uuid4()))


- def test_get_model_unauthorized(unauthorized, model: ClassificationModel):
+ def test_get_model_unauthorized(unauthorized, classification_model: ClassificationModel):
      with pytest.raises(LookupError):
-         ClassificationModel.open(model.name)
+         ClassificationModel.open(classification_model.name)


- def test_list_models(model: ClassificationModel):
+ def test_list_models(classification_model: ClassificationModel):
      models = ClassificationModel.all()
      assert len(models) > 0
      assert any(model.name == model.name for model in models)
@@ -92,19 +98,28 @@ def test_list_models_unauthenticated(unauthenticated):
          ClassificationModel.all()


- def test_list_models_unauthorized(unauthorized, model: ClassificationModel):
+ def test_list_models_unauthorized(unauthorized, classification_model: ClassificationModel):
      assert ClassificationModel.all() == []


- def test_update_model(model: ClassificationModel):
-     model.update_metadata(description="New description")
-     assert model.description == "New description"
+ def test_update_model_attributes(classification_model: ClassificationModel):
+     classification_model.description = "New description"
+     assert classification_model.description == "New description"
+
+     classification_model.set(description=None)
+     assert classification_model.description is None

+     classification_model.set(locked=True)
+     assert classification_model.locked is True

- def test_update_model_no_description(model: ClassificationModel):
-     assert model.description is not None
-     model.update_metadata(description=None)
-     assert model.description is None
+     classification_model.set(locked=False)
+     assert classification_model.locked is False
+
+     classification_model.lock()
+     assert classification_model.locked is True
+
+     classification_model.unlock()
+     assert classification_model.locked is False


  def test_delete_model(readonly_memoryset: LabeledMemoryset):
@@ -115,9 +130,9 @@ def test_delete_model(readonly_memoryset: LabeledMemoryset):
          ClassificationModel.open("model_to_delete")


- def test_delete_model_unauthenticated(unauthenticated, model: ClassificationModel):
+ def test_delete_model_unauthenticated(unauthenticated, classification_model: ClassificationModel):
      with pytest.raises(ValueError, match="Invalid API key"):
-         ClassificationModel.drop(model.name)
+         ClassificationModel.drop(classification_model.name)


  def test_delete_model_not_found():
@@ -127,9 +142,9 @@ def test_delete_model_not_found():
      ClassificationModel.drop(str(uuid4()), if_not_exists="ignore")


- def test_delete_model_unauthorized(unauthorized, model: ClassificationModel):
+ def test_delete_model_unauthorized(unauthorized, classification_model: ClassificationModel):
      with pytest.raises(LookupError):
-         ClassificationModel.drop(model.name)
+         ClassificationModel.drop(classification_model.name)


  def test_delete_memoryset_before_model_constraint_violation(hf_dataset):
@@ -139,78 +154,57 @@ def test_delete_memoryset_before_model_constraint_violation(hf_dataset):
          LabeledMemoryset.drop(memoryset.id)


- def test_evaluate(model, eval_datasource: Datasource):
-     result = model.evaluate(eval_datasource)
+ @pytest.mark.parametrize("data_type", ["dataset", "datasource"])
+ def test_evaluate(classification_model, eval_datasource: Datasource, eval_dataset: Dataset, data_type):
+     result = (
+         classification_model.evaluate(eval_dataset)
+         if data_type == "dataset"
+         else classification_model.evaluate(eval_datasource)
+     )
+
      assert result is not None
-     assert isinstance(result, dict)
-     # And anomaly score statistics are present and valid
-     assert isinstance(result["anomaly_score_mean"], float)
-     assert isinstance(result["anomaly_score_median"], float)
-     assert isinstance(result["anomaly_score_variance"], float)
-     assert -1.0 <= result["anomaly_score_mean"] <= 1.0
-     assert -1.0 <= result["anomaly_score_median"] <= 1.0
-     assert -1.0 <= result["anomaly_score_variance"] <= 1.0
-     assert isinstance(result["accuracy"], float)
-     assert isinstance(result["f1_score"], float)
-     assert isinstance(result["loss"], float)
-     assert len(result["precision_recall_curve"]["thresholds"]) == 4
-     assert len(result["precision_recall_curve"]["precisions"]) == 4
-     assert len(result["precision_recall_curve"]["recalls"]) == 4
-     assert len(result["roc_curve"]["thresholds"]) == 4
-     assert len(result["roc_curve"]["false_positive_rates"]) == 4
-     assert len(result["roc_curve"]["true_positive_rates"]) == 4
-
-
- def test_evaluate_combined(model, eval_datasource: Datasource, eval_dataset: Dataset):
-     result_datasource = model.evaluate(eval_datasource)
-
-     result_dataset = model.evaluate(eval_dataset)
-
-     for result in [result_datasource, result_dataset]:
-         assert result is not None
-         assert isinstance(result, dict)
-         assert isinstance(result["accuracy"], float)
-         assert isinstance(result["f1_score"], float)
-         assert isinstance(result["loss"], float)
-         assert np.allclose(result["accuracy"], 0.5)
-         assert np.allclose(result["f1_score"], 0.5)
-
-         assert isinstance(result["precision_recall_curve"]["thresholds"], list)
-         assert isinstance(result["precision_recall_curve"]["precisions"], list)
-         assert isinstance(result["precision_recall_curve"]["recalls"], list)
-         assert isinstance(result["roc_curve"]["thresholds"], list)
-         assert isinstance(result["roc_curve"]["false_positive_rates"], list)
-         assert isinstance(result["roc_curve"]["true_positive_rates"], list)
-
-         assert np.allclose(result["roc_curve"]["thresholds"], [0.0, 0.8155114054679871, 0.834095299243927, 1.0])
-         assert np.allclose(result["roc_curve"]["false_positive_rates"], [1.0, 0.5, 0.0, 0.0])
-         assert np.allclose(result["roc_curve"]["true_positive_rates"], [1.0, 0.5, 0.5, 0.0])
-         assert np.allclose(result["roc_curve"]["auc"], 0.625)
-
-         assert np.allclose(
-             result["precision_recall_curve"]["thresholds"], [0.0, 0.0, 0.8155114054679871, 0.834095299243927]
-         )
-         assert np.allclose(result["precision_recall_curve"]["precisions"], [0.5, 0.5, 1.0, 1.0])
-         assert np.allclose(result["precision_recall_curve"]["recalls"], [1.0, 0.5, 0.5, 0.0])
-         assert np.allclose(result["precision_recall_curve"]["auc"], 0.75)
-
-
- def test_evaluate_with_telemetry(model):
-     samples = [
-         {"text": "chicken noodle soup is the best", "label": 1},
-         {"text": "cats are cute", "label": 0},
-     ]
-     eval_datasource = Datasource.from_list("eval_datasource_2", samples)
-     result = model.evaluate(eval_datasource, value_column="text", record_predictions=True, tags={"test"})
+     assert isinstance(result, ClassificationMetrics)
+
+     assert isinstance(result.accuracy, float)
+     assert np.allclose(result.accuracy, 0.5)
+     assert isinstance(result.f1_score, float)
+     assert np.allclose(result.f1_score, 0.5)
+     assert isinstance(result.loss, float)
+
+     assert isinstance(result.anomaly_score_mean, float)
+     assert isinstance(result.anomaly_score_median, float)
+     assert isinstance(result.anomaly_score_variance, float)
+     assert -1.0 <= result.anomaly_score_mean <= 1.0
+     assert -1.0 <= result.anomaly_score_median <= 1.0
+     assert -1.0 <= result.anomaly_score_variance <= 1.0
+
+     assert result.pr_auc is not None
+     assert np.allclose(result.pr_auc, 0.75)
+     assert result.pr_curve is not None
+     assert np.allclose(result.pr_curve["thresholds"], [0.0, 0.0, 0.8155114054679871, 0.834095299243927])
+     assert np.allclose(result.pr_curve["precisions"], [0.5, 0.5, 1.0, 1.0])
+     assert np.allclose(result.pr_curve["recalls"], [1.0, 0.5, 0.5, 0.0])
+
+     assert result.roc_auc is not None
+     assert np.allclose(result.roc_auc, 0.625)
+     assert result.roc_curve is not None
+     assert np.allclose(result.roc_curve["thresholds"], [0.0, 0.8155114054679871, 0.834095299243927, 1.0])
+     assert np.allclose(result.roc_curve["false_positive_rates"], [1.0, 0.5, 0.0, 0.0])
+     assert np.allclose(result.roc_curve["true_positive_rates"], [1.0, 0.5, 0.5, 0.0])
+
+
+ def test_evaluate_with_telemetry(classification_model: ClassificationModel, eval_dataset: Dataset):
+     result = classification_model.evaluate(eval_dataset, record_predictions=True, tags={"test"})
      assert result is not None
-     predictions = model.predictions(tag="test")
-     assert len(predictions) == 2
+     assert isinstance(result, ClassificationMetrics)
+     predictions = classification_model.predictions(tag="test")
+     assert len(predictions) == 4
      assert all(p.tags == {"test"} for p in predictions)
-     assert all(p.expected_label == s["label"] for p, s in zip(predictions, samples))
+     assert all(p.expected_label == l for p, l in zip(predictions, eval_dataset["label"]))


- def test_predict(model: ClassificationModel, label_names: list[str]):
-     predictions = model.predict(["Do you love soup?", "Are cats cute?"])
+ def test_predict(classification_model: ClassificationModel, label_names: list[str]):
+     predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
      assert len(predictions) == 2
      assert predictions[0].prediction_id is not None
      assert predictions[1].prediction_id is not None
@@ -229,8 +223,8 @@ def test_predict(model: ClassificationModel, label_names: list[str]):
      assert predictions[1].logits[0] < predictions[1].logits[1]


- def test_predict_disable_telemetry(model: ClassificationModel, label_names: list[str]):
-     predictions = model.predict(["Do you love soup?", "Are cats cute?"], save_telemetry=False)
+ def test_predict_disable_telemetry(classification_model: ClassificationModel, label_names: list[str]):
+     predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"], save_telemetry="off")
      assert len(predictions) == 2
      assert predictions[0].prediction_id is None
      assert predictions[1].prediction_id is None
@@ -242,14 +236,14 @@ def test_predict_disable_telemetry(model: ClassificationModel, label_names: list
      assert 0 <= predictions[1].confidence <= 1


- def test_predict_unauthenticated(unauthenticated, model: ClassificationModel):
+ def test_predict_unauthenticated(unauthenticated, classification_model: ClassificationModel):
      with pytest.raises(ValueError, match="Invalid API key"):
-         model.predict(["Do you love soup?", "Are cats cute?"])
+         classification_model.predict(["Do you love soup?", "Are cats cute?"])


- def test_predict_unauthorized(unauthorized, model: ClassificationModel):
+ def test_predict_unauthorized(unauthorized, classification_model: ClassificationModel):
      with pytest.raises(LookupError):
-         model.predict(["Do you love soup?", "Are cats cute?"])
+         classification_model.predict(["Do you love soup?", "Are cats cute?"])


  def test_predict_constraint_violation(readonly_memoryset: LabeledMemoryset):
@@ -263,10 +257,10 @@ def test_predict_constraint_violation(readonly_memoryset: LabeledMemoryset):
          model.predict("test")


- def test_record_prediction_feedback(model: ClassificationModel):
-     predictions = model.predict(["Do you love soup?", "Are cats cute?"])
+ def test_record_prediction_feedback(classification_model: ClassificationModel):
+     predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
      expected_labels = [0, 1]
-     model.record_feedback(
+     classification_model.record_feedback(
          {
              "prediction_id": p.prediction_id,
              "category": "correct",
@@ -276,65 +270,107 @@ def test_record_prediction_feedback(model: ClassificationModel):
      )


- def test_record_prediction_feedback_missing_category(model: ClassificationModel):
-     prediction = model.predict("Do you love soup?")
+ def test_record_prediction_feedback_missing_category(classification_model: ClassificationModel):
+     prediction = classification_model.predict("Do you love soup?")
      with pytest.raises(ValueError):
-         model.record_feedback({"prediction_id": prediction.prediction_id, "value": True})
+         classification_model.record_feedback({"prediction_id": prediction.prediction_id, "value": True})


- def test_record_prediction_feedback_invalid_value(model: ClassificationModel):
-     prediction = model.predict("Do you love soup?")
+ def test_record_prediction_feedback_invalid_value(classification_model: ClassificationModel):
+     prediction = classification_model.predict("Do you love soup?")
      with pytest.raises(ValueError, match=r"Invalid input.*"):
-         model.record_feedback({"prediction_id": prediction.prediction_id, "category": "correct", "value": "invalid"})
+         classification_model.record_feedback(
+             {"prediction_id": prediction.prediction_id, "category": "correct", "value": "invalid"}
+         )


- def test_record_prediction_feedback_invalid_prediction_id(model: ClassificationModel):
+ def test_record_prediction_feedback_invalid_prediction_id(classification_model: ClassificationModel):
      with pytest.raises(ValueError, match=r"Invalid input.*"):
-         model.record_feedback({"prediction_id": "invalid", "category": "correct", "value": True})
+         classification_model.record_feedback({"prediction_id": "invalid", "category": "correct", "value": True})


- def test_predict_with_memoryset_override(model: ClassificationModel, hf_dataset: Dataset):
+ def test_predict_with_memoryset_override(classification_model: ClassificationModel, hf_dataset: Dataset):
      inverted_labeled_memoryset = LabeledMemoryset.from_hf_dataset(
          "test_memoryset_inverted_labels",
          hf_dataset.map(lambda x: {"label": 1 if x["label"] == 0 else 0}),
          embedding_model=PretrainedEmbeddingModel.GTE_BASE,
      )
-     with model.use_memoryset(inverted_labeled_memoryset):
-         predictions = model.predict(["Do you love soup?", "Are cats cute?"])
+     with classification_model.use_memoryset(inverted_labeled_memoryset):
+         predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
          assert predictions[0].label == 1
          assert predictions[1].label == 0

-     predictions = model.predict(["Do you love soup?", "Are cats cute?"])
+     predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
      assert predictions[0].label == 0
      assert predictions[1].label == 1


- def test_predict_with_expected_labels(model: ClassificationModel):
-     prediction = model.predict("Do you love soup?", expected_labels=1)
+ def test_predict_with_expected_labels(classification_model: ClassificationModel):
+     prediction = classification_model.predict("Do you love soup?", expected_labels=1)
      assert prediction.expected_label == 1


- def test_predict_with_expected_labels_invalid_input(model: ClassificationModel):
+ def test_predict_with_expected_labels_invalid_input(classification_model: ClassificationModel):
      # invalid number of expected labels for batch prediction
      with pytest.raises(ValueError, match=r"Invalid input.*"):
-         model.predict(["Do you love soup?", "Are cats cute?"], expected_labels=[0])
+         classification_model.predict(["Do you love soup?", "Are cats cute?"], expected_labels=[0])
      # invalid label value
      with pytest.raises(ValueError):
-         model.predict("Do you love soup?", expected_labels=5)
+         classification_model.predict("Do you love soup?", expected_labels=5)


- def test_last_prediction_with_batch(model: ClassificationModel):
-     predictions = model.predict(["Do you love soup?", "Are cats cute?"])
-     assert model.last_prediction is not None
-     assert model.last_prediction.prediction_id == predictions[-1].prediction_id
-     assert model.last_prediction.input_value == "Are cats cute?"
-     assert model._last_prediction_was_batch is True
+ def test_predict_with_filters(classification_model: ClassificationModel):
+     # there are no memories with label 0 and key g1, so we force a wrong prediction
+     filtered_prediction = classification_model.predict("I love soup", filters=[("key", "==", "g2")])
+     assert filtered_prediction.label == 1
+     assert filtered_prediction.label_name == "cats"


- def test_last_prediction_with_single(model: ClassificationModel):
+ def test_last_prediction_with_batch(classification_model: ClassificationModel):
+     predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
+     assert classification_model.last_prediction is not None
+     assert classification_model.last_prediction.prediction_id == predictions[-1].prediction_id
+     assert classification_model.last_prediction.input_value == "Are cats cute?"
+     assert classification_model._last_prediction_was_batch is True
+
+
+ def test_last_prediction_with_single(classification_model: ClassificationModel):
      # Test that last_prediction is updated correctly with single prediction
-     prediction = model.predict("Do you love soup?")
-     assert model.last_prediction is not None
-     assert model.last_prediction.prediction_id == prediction.prediction_id
-     assert model.last_prediction.input_value == "Do you love soup?"
-     assert model._last_prediction_was_batch is False
+     prediction = classification_model.predict("Do you love soup?")
+     assert classification_model.last_prediction is not None
+     assert classification_model.last_prediction.prediction_id == prediction.prediction_id
+     assert classification_model.last_prediction.input_value == "Do you love soup?"
+     assert classification_model._last_prediction_was_batch is False
+
+
+ @skip_in_ci("We don't have Anthropic API key in CI")
+ def test_explain(writable_memoryset: LabeledMemoryset):
+
+     writable_memoryset.analyze(
+         {"name": "neighbor", "neighbor_counts": [1, 3]},
+         lookup_count=3,
+     )
+
+     model = ClassificationModel.create(
+         "test_model_for_explain",
+         writable_memoryset,
+         num_classes=2,
+         memory_lookup_count=3,
+         description="This is a test model for explain",
+     )
+
+     predictions = model.predict(["Do you love soup?", "Are cats cute?"])
+     assert len(predictions) == 2
+
+     try:
+         explanation = predictions[0].explanation
+         assert explanation is not None
+         assert len(explanation) > 10
+         assert "soup" in explanation.lower()
+     except Exception as e:
+         if "ANTHROPIC_API_KEY" in str(e):
+             logging.info("Skipping explanation test because ANTHROPIC_API_KEY is not set")
+         else:
+             raise e
+     finally:
+         ClassificationModel.drop("test_model_for_explain")
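The rewritten tests above capture the user-facing API changes in this release: the shared `model` fixture becomes `classification_model`, `evaluate()` now returns a `ClassificationMetrics` object instead of a dict, `predict()` accepts `save_telemetry="off"` and memory `filters`, and models gain `set()`, `lock()`, and `unlock()`. The sketch below restates that call pattern outside the test suite; the resource names and dataset rows are placeholders, and the import paths simply mirror the test modules, so treat it as an illustration rather than official documentation.

# Illustrative sketch based on the 0.0.95 tests above; "demo_memoryset" and
# "demo_classifier" are placeholder names, not part of the SDK.
from datasets import Dataset

from orca_sdk.classification_model import ClassificationMetrics, ClassificationModel
from orca_sdk.embedding_model import PretrainedEmbeddingModel
from orca_sdk.memoryset import LabeledMemoryset

train = Dataset.from_list([
    {"value": "i love soup", "label": 0, "key": "g1"},
    {"value": "cats are cute", "label": 1, "key": "g2"},
])
memoryset = LabeledMemoryset.from_hf_dataset(
    "demo_memoryset", train, embedding_model=PretrainedEmbeddingModel.GTE_BASE
)
model = ClassificationModel.create(
    "demo_classifier", memoryset, num_classes=2, memory_lookup_count=3, if_exists="open"
)

# evaluate() returns a ClassificationMetrics object rather than a plain dict
metrics = model.evaluate(train, record_predictions=True, tags={"smoke"})
assert isinstance(metrics, ClassificationMetrics)
print(metrics.accuracy, metrics.f1_score, metrics.roc_auc, metrics.pr_auc)

# predict() accepts save_telemetry="off" (previously save_telemetry=False) and
# memory filters expressed as (column, operator, value) tuples
model.predict(["Do you love soup?"], save_telemetry="off")
model.predict("I love soup", filters=[("key", "==", "g2")])

# metadata updates and locking go through set(), lock(), and unlock()
model.set(description=None)
model.lock()
model.unlock()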
orca_sdk/conftest.py CHANGED
@@ -11,15 +11,33 @@ from .classification_model import ClassificationModel
  from .credentials import OrcaCredentials
  from .datasource import Datasource
  from .embedding_model import PretrainedEmbeddingModel
- from .memoryset import LabeledMemoryset
+ from .memoryset import LabeledMemoryset, ScoredMemoryset
+ from .regression_model import RegressionModel

- logging.basicConfig(level=logging.INFO)
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")

  os.environ["ORCA_API_URL"] = os.environ.get("ORCA_API_URL", "http://localhost:1584/")

  os.environ["ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY"] = "true"


+ def skip_in_prod(reason: str):
+     """Custom decorator to skip tests when running against production API"""
+     PROD_API_URLs = ["https://api.orcadb.ai", "https://api.dev.orcadb.ai"]
+     return pytest.mark.skipif(
+         os.environ["ORCA_API_URL"] in PROD_API_URLs,
+         reason=reason,
+     )
+
+
+ def skip_in_ci(reason: str):
+     """Custom decorator to skip tests when running in CI"""
+     return pytest.mark.skipif(
+         os.environ.get("GITHUB_ACTIONS", "false") == "true",
+         reason=reason,
+     )
+
+
  def _create_org_id():
      # UUID start to identify test data (0xtest...)
      return "10e50000-0000-4000-a000-" + str(uuid4())[24:]
@@ -71,27 +89,27 @@ def label_names():


  SAMPLE_DATA = [
-     {"value": "i love soup", "label": 0, "key": "val1", "score": 0.1, "source_id": "s1"},
-     {"value": "cats are cute", "label": 1, "key": "val2", "score": 0.2, "source_id": "s2"},
-     {"value": "soup is good", "label": 0, "key": "val3", "score": 0.3, "source_id": "s3"},
-     {"value": "i love cats", "label": 1, "key": "val4", "score": 0.4, "source_id": "s4"},
-     {"value": "everyone loves cats", "label": 1, "key": "val5", "score": 0.5, "source_id": "s5"},
-     {"value": "soup is great for the winter", "label": 0, "key": "val6", "score": 0.6, "source_id": "s6"},
-     {"value": "hot soup on a rainy day!", "label": 0, "key": "val7", "score": 0.7, "source_id": "s7"},
-     {"value": "cats sleep all day", "label": 1, "key": "val8", "score": 0.8, "source_id": "s8"},
-     {"value": "homemade soup recipes", "label": 0, "key": "val9", "score": 0.9, "source_id": "s9"},
-     {"value": "cats purr when happy", "label": 1, "key": "val10", "score": 1.0, "source_id": "s10"},
-     {"value": "chicken noodle soup is classic", "label": 0, "key": "val11", "score": 1.1, "source_id": "s11"},
-     {"value": "kittens are baby cats", "label": 1, "key": "val12", "score": 1.2, "source_id": "s12"},
-     {"value": "soup can be served cold too", "label": 0, "key": "val13", "score": 1.3, "source_id": "s13"},
-     {"value": "cats have nine lives", "label": 1, "key": "val14", "score": 1.4, "source_id": "s14"},
-     {"value": "tomato soup with grilled cheese", "label": 0, "key": "val15", "score": 1.5, "source_id": "s15"},
-     {"value": "cats are independent animals", "label": 1, "key": "val16", "score": 1.6, "source_id": "s16"},
+     {"value": "i love soup", "label": 0, "key": "g1", "score": 0.1, "source_id": "s1"},
+     {"value": "cats are cute", "label": 1, "key": "g1", "score": 0.9, "source_id": "s2"},
+     {"value": "soup is good", "label": 0, "key": "g1", "score": 0.1, "source_id": "s3"},
+     {"value": "i love cats", "label": 1, "key": "g1", "score": 0.9, "source_id": "s4"},
+     {"value": "everyone loves cats", "label": 1, "key": "g1", "score": 0.9, "source_id": "s5"},
+     {"value": "soup is great for the winter", "label": 0, "key": "g1", "score": 0.1, "source_id": "s6"},
+     {"value": "hot soup on a rainy day!", "label": 0, "key": "g1", "score": 0.1, "source_id": "s7"},
+     {"value": "cats sleep all day", "label": 1, "key": "g1", "score": 0.9, "source_id": "s8"},
+     {"value": "homemade soup recipes", "label": 0, "key": "g1", "score": 0.1, "source_id": "s9"},
+     {"value": "cats purr when happy", "label": 1, "key": "g2", "score": 0.9, "source_id": "s10"},
+     {"value": "chicken noodle soup is classic", "label": 0, "key": "g1", "score": 0.1, "source_id": "s11"},
+     {"value": "kittens are baby cats", "label": 1, "key": "g2", "score": 0.9, "source_id": "s12"},
+     {"value": "soup can be served cold too", "label": 0, "key": "g1", "score": 0.1, "source_id": "s13"},
+     {"value": "cats have nine lives", "label": 1, "key": "g2", "score": 0.9, "source_id": "s14"},
+     {"value": "tomato soup with grilled cheese", "label": 0, "key": "g1", "score": 0.1, "source_id": "s15"},
+     {"value": "cats are independent animals", "label": 1, "key": "g2", "score": 0.9, "source_id": "s16"},
  ]


  @pytest.fixture(scope="session")
- def hf_dataset(label_names):
+ def hf_dataset(label_names: list[str]) -> Dataset:
      return Dataset.from_list(
          SAMPLE_DATA,
          features=Features(
@@ -107,16 +125,16 @@


  @pytest.fixture(scope="session")
- def datasource(hf_dataset) -> Datasource:
+ def datasource(hf_dataset: Dataset) -> Datasource:
      datasource = Datasource.from_hf_dataset("test_datasource", hf_dataset)
      return datasource


  EVAL_DATASET = [
-     {"value": "chicken noodle soup is the best", "label": 1},
-     {"value": "cats are cute", "label": 0},
-     {"value": "soup is great for the winter", "label": 0},
-     {"value": "i love cats", "label": 1},
+     {"value": "chicken noodle soup is the best", "label": 1, "score": 0.9},  # mislabeled
+     {"value": "cats are cute", "label": 0, "score": 0.1},  # mislabeled
+     {"value": "soup is great for the winter", "label": 0, "score": 0.1},
+     {"value": "i love cats", "label": 1, "score": 0.9},
  ]


@@ -140,6 +158,8 @@ def readonly_memoryset(datasource: Datasource) -> LabeledMemoryset:
          embedding_model=PretrainedEmbeddingModel.GTE_BASE,
          source_id_column="source_id",
          max_seq_length_override=32,
+         index_type="IVF_FLAT",
+         index_params={"n_lists": 100},
      )
      return memoryset

@@ -176,14 +196,45 @@ def writable_memoryset(datasource: Datasource, api_key: str) -> Generator[Labele

      if memory_ids:
          memoryset.delete(memory_ids)
+         memoryset.refresh()
      assert len(memoryset) == 0
      memoryset.insert(SAMPLE_DATA)
      # If the test dropped the memoryset, do nothing — it will be recreated on the next use.


  @pytest.fixture(scope="session")
- def model(readonly_memoryset: LabeledMemoryset) -> ClassificationModel:
+ def classification_model(readonly_memoryset: LabeledMemoryset) -> ClassificationModel:
      model = ClassificationModel.create(
-         "test_model", readonly_memoryset, num_classes=2, memory_lookup_count=3, description="test_description"
+         "test_classification_model",
+         readonly_memoryset,
+         num_classes=2,
+         memory_lookup_count=3,
+         description="test_description",
+     )
+     return model
+
+
+ # Add scored memoryset and regression model fixtures
+ @pytest.fixture(scope="session")
+ def scored_memoryset(datasource: Datasource) -> ScoredMemoryset:
+     memoryset = ScoredMemoryset.create(
+         "test_scored_memoryset",
+         datasource=datasource,
+         embedding_model=PretrainedEmbeddingModel.GTE_BASE,
+         source_id_column="source_id",
+         max_seq_length_override=32,
+         index_type="IVF_FLAT",
+         index_params={"n_lists": 100},
+     )
+     return memoryset
+
+
+ @pytest.fixture(scope="session")
+ def regression_model(scored_memoryset: ScoredMemoryset) -> RegressionModel:
+     model = RegressionModel.create(
+         "test_regression_model",
+         scored_memoryset,
+         memory_lookup_count=3,
+         description="test_regression_description",
      )
      return model