orca-sdk 0.0.94__py3-none-any.whl → 0.0.96__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +13 -4
- orca_sdk/_generated_api_client/api/__init__.py +80 -34
- orca_sdk/_generated_api_client/api/classification_model/create_classification_model_gpu_classification_model_post.py +170 -0
- orca_sdk/_generated_api_client/api/classification_model/{get_model_classification_model_name_or_id_get.py → delete_classification_model_classification_model_name_or_id_delete.py} +20 -20
- orca_sdk/_generated_api_client/api/classification_model/{delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py → delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py} +4 -4
- orca_sdk/_generated_api_client/api/classification_model/{create_evaluation_classification_model_model_name_or_id_evaluation_post.py → evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py} +14 -14
- orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/classification_model/{get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py → get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py} +16 -16
- orca_sdk/_generated_api_client/api/classification_model/{list_evaluations_classification_model_model_name_or_id_evaluation_get.py → list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py} +16 -16
- orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +127 -0
- orca_sdk/_generated_api_client/api/classification_model/{predict_gpu_classification_model_name_or_id_prediction_post.py → predict_label_gpu_classification_model_name_or_id_prediction_post.py} +14 -14
- orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +183 -0
- orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +24 -0
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +22 -22
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +22 -22
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +38 -16
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +29 -12
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +12 -12
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +17 -14
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +72 -19
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +31 -12
- orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +49 -20
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +38 -16
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +54 -29
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +44 -26
- orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +22 -22
- orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +150 -0
- orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/{classification_model/create_model_classification_model_post.py → regression_model/create_regression_model_gpu_regression_model_post.py} +27 -27
- orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
- orca_sdk/_generated_api_client/api/{classification_model/delete_model_classification_model_name_or_id_delete.py → regression_model/delete_regression_model_regression_model_name_or_id_delete.py} +5 -5
- orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +183 -0
- orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +170 -0
- orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +161 -0
- orca_sdk/_generated_api_client/api/{classification_model/list_models_classification_model_get.py → regression_model/list_regression_models_regression_model_get.py} +17 -17
- orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +190 -0
- orca_sdk/_generated_api_client/api/{classification_model/update_model_classification_model_name_or_id_patch.py → regression_model/update_regression_model_regression_model_name_or_id_patch.py} +27 -27
- orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +35 -12
- orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +20 -12
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +35 -12
- orca_sdk/_generated_api_client/models/__init__.py +84 -24
- orca_sdk/_generated_api_client/models/base_score_prediction_result.py +108 -0
- orca_sdk/_generated_api_client/models/{evaluation_request.py → classification_evaluation_request.py} +13 -45
- orca_sdk/_generated_api_client/models/{classification_evaluation_result.py → classification_metrics.py} +106 -56
- orca_sdk/_generated_api_client/models/{rac_model_metadata.py → classification_model_metadata.py} +51 -43
- orca_sdk/_generated_api_client/models/{prediction_request.py → classification_prediction_request.py} +31 -6
- orca_sdk/_generated_api_client/models/{clone_labeled_memoryset_request.py → clone_memoryset_request.py} +5 -5
- orca_sdk/_generated_api_client/models/column_info.py +31 -0
- orca_sdk/_generated_api_client/models/{create_rac_model_request.py → create_classification_model_request.py} +25 -57
- orca_sdk/_generated_api_client/models/{create_labeled_memoryset_request.py → create_memoryset_request.py} +73 -56
- orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +66 -0
- orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +13 -0
- orca_sdk/_generated_api_client/models/create_regression_model_request.py +137 -0
- orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +187 -0
- orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +10 -0
- orca_sdk/_generated_api_client/models/evaluation_response.py +22 -9
- orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +140 -0
- orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +140 -0
- orca_sdk/_generated_api_client/models/memory_type.py +9 -0
- orca_sdk/_generated_api_client/models/{labeled_memoryset_metadata.py → memoryset_metadata.py} +73 -13
- orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +55 -0
- orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +13 -0
- orca_sdk/_generated_api_client/models/{labeled_memoryset_update.py → memoryset_update.py} +19 -31
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +1 -0
- orca_sdk/_generated_api_client/models/{paginated_labeled_memory_with_feedback_metrics.py → paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py} +37 -10
- orca_sdk/_generated_api_client/models/{precision_recall_curve.py → pr_curve.py} +5 -13
- orca_sdk/_generated_api_client/models/{rac_model_update.py → predictive_model_update.py} +14 -5
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +11 -1
- orca_sdk/_generated_api_client/models/rar_head_type.py +8 -0
- orca_sdk/_generated_api_client/models/regression_evaluation_request.py +148 -0
- orca_sdk/_generated_api_client/models/regression_metrics.py +172 -0
- orca_sdk/_generated_api_client/models/regression_model_metadata.py +177 -0
- orca_sdk/_generated_api_client/models/regression_prediction_request.py +195 -0
- orca_sdk/_generated_api_client/models/roc_curve.py +0 -8
- orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +196 -0
- orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +252 -0
- orca_sdk/_generated_api_client/models/scored_memory.py +172 -0
- orca_sdk/_generated_api_client/models/scored_memory_insert.py +128 -0
- orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_lookup.py +180 -0
- orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_update.py +171 -0
- orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +193 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/update_prediction_request.py +20 -0
- orca_sdk/_shared/__init__.py +9 -1
- orca_sdk/_shared/metrics.py +257 -87
- orca_sdk/_shared/metrics_test.py +136 -77
- orca_sdk/_utils/data_parsing.py +0 -3
- orca_sdk/_utils/data_parsing_test.py +0 -3
- orca_sdk/_utils/prediction_result_ui.py +55 -23
- orca_sdk/classification_model.py +183 -172
- orca_sdk/classification_model_test.py +147 -157
- orca_sdk/conftest.py +76 -26
- orca_sdk/datasource_test.py +0 -1
- orca_sdk/embedding_model.py +136 -14
- orca_sdk/embedding_model_test.py +10 -6
- orca_sdk/job.py +329 -0
- orca_sdk/job_test.py +48 -0
- orca_sdk/memoryset.py +882 -161
- orca_sdk/memoryset_test.py +56 -23
- orca_sdk/regression_model.py +647 -0
- orca_sdk/regression_model_test.py +337 -0
- orca_sdk/telemetry.py +223 -106
- orca_sdk/telemetry_test.py +34 -30
- {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.96.dist-info}/METADATA +2 -4
- {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.96.dist-info}/RECORD +115 -69
- orca_sdk/_utils/task.py +0 -73
- {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.96.dist-info}/WHEEL +0 -0
orca_sdk/memoryset_test.py
CHANGED
|
@@ -1,19 +1,15 @@
|
|
|
1
|
+
import os
|
|
1
2
|
import random
|
|
2
|
-
import time
|
|
3
|
-
from typing import Generator
|
|
4
3
|
from uuid import uuid4
|
|
5
4
|
|
|
6
5
|
import pytest
|
|
7
|
-
from datasets import ClassLabel, Features, Value
|
|
8
6
|
from datasets.arrow_dataset import Dataset
|
|
9
7
|
|
|
10
|
-
from orca_sdk.conftest import SAMPLE_DATA
|
|
11
|
-
|
|
12
|
-
from ._generated_api_client.models import CascadingEditSuggestion
|
|
13
8
|
from .classification_model import ClassificationModel
|
|
9
|
+
from .conftest import skip_in_prod
|
|
14
10
|
from .datasource import Datasource
|
|
15
11
|
from .embedding_model import PretrainedEmbeddingModel
|
|
16
|
-
from .memoryset import LabeledMemoryset,
|
|
12
|
+
from .memoryset import LabeledMemoryset, ScoredMemory, ScoredMemoryset, Status
|
|
17
13
|
|
|
18
14
|
"""
|
|
19
15
|
Test Performance Note:
|
|
@@ -37,9 +33,11 @@ def test_create_memoryset(readonly_memoryset: LabeledMemoryset, hf_dataset: Data
|
|
|
37
33
|
assert readonly_memoryset.name == "test_readonly_memoryset"
|
|
38
34
|
assert readonly_memoryset.embedding_model == PretrainedEmbeddingModel.GTE_BASE
|
|
39
35
|
assert readonly_memoryset.label_names == label_names
|
|
40
|
-
assert readonly_memoryset.insertion_status ==
|
|
36
|
+
assert readonly_memoryset.insertion_status == Status.COMPLETED
|
|
41
37
|
assert isinstance(readonly_memoryset.length, int)
|
|
42
38
|
assert readonly_memoryset.length == len(hf_dataset)
|
|
39
|
+
assert readonly_memoryset.index_type == "IVF_FLAT"
|
|
40
|
+
assert readonly_memoryset.index_params == {"n_lists": 100}
|
|
43
41
|
|
|
44
42
|
|
|
45
43
|
def test_create_memoryset_unauthenticated(unauthenticated, datasource):
|
|
@@ -95,6 +93,8 @@ def test_open_memoryset(readonly_memoryset, hf_dataset):
|
|
|
95
93
|
assert fetched_memoryset is not None
|
|
96
94
|
assert fetched_memoryset.name == readonly_memoryset.name
|
|
97
95
|
assert fetched_memoryset.length == len(hf_dataset)
|
|
96
|
+
assert fetched_memoryset.index_type == "IVF_FLAT"
|
|
97
|
+
assert fetched_memoryset.index_params == {"n_lists": 100}
|
|
98
98
|
|
|
99
99
|
|
|
100
100
|
def test_open_memoryset_unauthenticated(unauthenticated, readonly_memoryset):
|
|
@@ -149,15 +149,25 @@ def test_drop_memoryset_unauthorized(unauthorized, readonly_memoryset):
|
|
|
149
149
|
LabeledMemoryset.drop(readonly_memoryset.name)
|
|
150
150
|
|
|
151
151
|
|
|
152
|
-
def
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
writable_memoryset.update_metadata(description="New description")
|
|
152
|
+
def test_update_memoryset_attributes(writable_memoryset: LabeledMemoryset):
|
|
153
|
+
original_label_names = writable_memoryset.label_names
|
|
154
|
+
writable_memoryset.set(description="New description")
|
|
156
155
|
assert writable_memoryset.description == "New description"
|
|
157
156
|
|
|
158
|
-
writable_memoryset.
|
|
157
|
+
writable_memoryset.set(description=None)
|
|
159
158
|
assert writable_memoryset.description is None
|
|
160
159
|
|
|
160
|
+
writable_memoryset.set(name="New_name")
|
|
161
|
+
assert writable_memoryset.name == "New_name"
|
|
162
|
+
|
|
163
|
+
writable_memoryset.set(name="test_writable_memoryset")
|
|
164
|
+
assert writable_memoryset.name == "test_writable_memoryset"
|
|
165
|
+
|
|
166
|
+
assert writable_memoryset.label_names == original_label_names
|
|
167
|
+
|
|
168
|
+
writable_memoryset.set(label_names=["New label 1", "New label 2"])
|
|
169
|
+
assert writable_memoryset.label_names == ["New label 1", "New label 2"]
|
|
170
|
+
|
|
161
171
|
|
|
162
172
|
def test_search(readonly_memoryset: LabeledMemoryset):
|
|
163
173
|
memory_lookups = readonly_memoryset.search(["i love soup", "cats are cute"])
|
|
@@ -214,11 +224,11 @@ def test_query_memoryset(readonly_memoryset: LabeledMemoryset):
|
|
|
214
224
|
assert len(memories) == 8
|
|
215
225
|
assert all(memory.label == 1 for memory in memories)
|
|
216
226
|
assert len(readonly_memoryset.query(limit=2)) == 2
|
|
217
|
-
assert len(readonly_memoryset.query(filters=[("metadata.key", "==", "
|
|
227
|
+
assert len(readonly_memoryset.query(filters=[("metadata.key", "==", "g2")])) == 4
|
|
218
228
|
|
|
219
229
|
|
|
220
|
-
def test_query_memoryset_with_feedback_metrics(model: ClassificationModel):
|
|
221
|
-
prediction =
|
|
230
|
+
def test_query_memoryset_with_feedback_metrics(classification_model: ClassificationModel):
|
|
231
|
+
prediction = classification_model.predict("Do you love soup?")
|
|
222
232
|
feedback_name = f"correct_{random.randint(0, 1000000)}"
|
|
223
233
|
prediction.record_feedback(category=feedback_name, value=prediction.label == 0)
|
|
224
234
|
memories = prediction.memoryset.query(filters=[("label", "==", 0)], with_feedback_metrics=True)
|
|
@@ -239,8 +249,8 @@ def test_query_memoryset_with_feedback_metrics(model: ClassificationModel):
|
|
|
239
249
|
assert isinstance(memory.lookup_count, int)
|
|
240
250
|
|
|
241
251
|
|
|
242
|
-
def test_query_memoryset_with_feedback_metrics_filter(model: ClassificationModel):
|
|
243
|
-
prediction =
|
|
252
|
+
def test_query_memoryset_with_feedback_metrics_filter(classification_model: ClassificationModel):
|
|
253
|
+
prediction = classification_model.predict("Do you love soup?")
|
|
244
254
|
prediction.record_feedback(category="accurate", value=prediction.label == 0)
|
|
245
255
|
memories = prediction.memoryset.query(
|
|
246
256
|
filters=[("feedback_metrics.accurate.avg", ">", 0.5)], with_feedback_metrics=True
|
|
@@ -254,10 +264,10 @@ def test_query_memoryset_with_feedback_metrics_filter(model: ClassificationModel
|
|
|
254
264
|
assert memory.feedback_metrics["accurate"]["count"] == 1
|
|
255
265
|
|
|
256
266
|
|
|
257
|
-
def test_query_memoryset_with_feedback_metrics_sort(
|
|
258
|
-
prediction =
|
|
267
|
+
def test_query_memoryset_with_feedback_metrics_sort(classification_model: ClassificationModel):
|
|
268
|
+
prediction = classification_model.predict("Do you love soup?")
|
|
259
269
|
prediction.record_feedback(category="positive", value=1.0)
|
|
260
|
-
prediction2 =
|
|
270
|
+
prediction2 = classification_model.predict("Do you like cats?")
|
|
261
271
|
prediction2.record_feedback(category="positive", value=-1.0)
|
|
262
272
|
|
|
263
273
|
memories = prediction.memoryset.query(
|
|
@@ -294,6 +304,7 @@ def test_insert_memories(writable_memoryset: LabeledMemoryset):
|
|
|
294
304
|
assert last_memory.source_id == "test"
|
|
295
305
|
|
|
296
306
|
|
|
307
|
+
@skip_in_prod("Production memorysets do not have session consistency guarantees")
|
|
297
308
|
def test_update_memories(writable_memoryset: LabeledMemoryset, hf_dataset: Dataset):
|
|
298
309
|
# We've combined the update tests into one to avoid multiple expensive requests for a writable_memoryset
|
|
299
310
|
|
|
@@ -302,6 +313,7 @@ def test_update_memories(writable_memoryset: LabeledMemoryset, hf_dataset: Datas
|
|
|
302
313
|
updated_memory = writable_memoryset.update(dict(memory_id=memory_id, value="i love soup so much"))
|
|
303
314
|
assert updated_memory.value == "i love soup so much"
|
|
304
315
|
assert updated_memory.label == hf_dataset[0]["label"]
|
|
316
|
+
writable_memoryset.refresh() # Refresh to ensure consistency after update
|
|
305
317
|
assert writable_memoryset.get(memory_id).value == "i love soup so much"
|
|
306
318
|
|
|
307
319
|
# test updating a memory instance
|
|
@@ -348,7 +360,7 @@ def test_clone_memoryset(readonly_memoryset: LabeledMemoryset):
|
|
|
348
360
|
assert cloned_memoryset.name == "test_cloned_memoryset"
|
|
349
361
|
assert cloned_memoryset.length == readonly_memoryset.length
|
|
350
362
|
assert cloned_memoryset.embedding_model == PretrainedEmbeddingModel.DISTILBERT
|
|
351
|
-
assert cloned_memoryset.insertion_status ==
|
|
363
|
+
assert cloned_memoryset.insertion_status == Status.COMPLETED
|
|
352
364
|
|
|
353
365
|
|
|
354
366
|
def test_embedding_evaluation(eval_datasource: Datasource):
|
|
@@ -363,7 +375,6 @@ def test_embedding_evaluation(eval_datasource: Datasource):
|
|
|
363
375
|
assert response["evaluation_results"][0] is not None
|
|
364
376
|
assert response["evaluation_results"][0]["embedding_model_name"] == "CDE_SMALL"
|
|
365
377
|
assert response["evaluation_results"][0]["embedding_model_path"] == "OrcaDB/cde-small-v1"
|
|
366
|
-
Datasource.drop("eval_datasource")
|
|
367
378
|
|
|
368
379
|
|
|
369
380
|
@pytest.fixture(scope="function")
|
|
@@ -455,3 +466,25 @@ def test_drop_memoryset(writable_memoryset: LabeledMemoryset):
|
|
|
455
466
|
assert LabeledMemoryset.exists(writable_memoryset.name)
|
|
456
467
|
LabeledMemoryset.drop(writable_memoryset.name)
|
|
457
468
|
assert not LabeledMemoryset.exists(writable_memoryset.name)
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
def test_scored_memoryset(scored_memoryset: ScoredMemoryset):
|
|
472
|
+
assert scored_memoryset.length == 16
|
|
473
|
+
assert isinstance(scored_memoryset[0], ScoredMemory)
|
|
474
|
+
assert scored_memoryset[0].value == "i love soup"
|
|
475
|
+
assert scored_memoryset[0].score is not None
|
|
476
|
+
assert scored_memoryset[0].metadata == {"key": "g1", "source_id": "s1", "label": 0}
|
|
477
|
+
lookup = scored_memoryset.search("i love soup", count=1)
|
|
478
|
+
assert len(lookup) == 1
|
|
479
|
+
assert lookup[0].score < 0.11
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
@skip_in_prod("Production memorysets do not have session consistency guarantees")
|
|
483
|
+
def test_update_scored_memory(scored_memoryset: ScoredMemoryset):
|
|
484
|
+
# we are only updating an inconsequential metadata field so that we don't affect other tests
|
|
485
|
+
memory = scored_memoryset[0]
|
|
486
|
+
assert memory.label == 0
|
|
487
|
+
scored_memoryset.update(dict(memory_id=memory.memory_id, label=3))
|
|
488
|
+
assert scored_memoryset[0].label == 3
|
|
489
|
+
memory.update(label=4)
|
|
490
|
+
assert scored_memoryset[0].label == 4
|