orca-sdk 0.0.94__py3-none-any.whl → 0.0.96__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116):
  1. orca_sdk/__init__.py +13 -4
  2. orca_sdk/_generated_api_client/api/__init__.py +80 -34
  3. orca_sdk/_generated_api_client/api/classification_model/create_classification_model_gpu_classification_model_post.py +170 -0
  4. orca_sdk/_generated_api_client/api/classification_model/{get_model_classification_model_name_or_id_get.py → delete_classification_model_classification_model_name_or_id_delete.py} +20 -20
  5. orca_sdk/_generated_api_client/api/classification_model/{delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py → delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py} +4 -4
  6. orca_sdk/_generated_api_client/api/classification_model/{create_evaluation_classification_model_model_name_or_id_evaluation_post.py → evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py} +14 -14
  7. orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +156 -0
  8. orca_sdk/_generated_api_client/api/classification_model/{get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py → get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py} +16 -16
  9. orca_sdk/_generated_api_client/api/classification_model/{list_evaluations_classification_model_model_name_or_id_evaluation_get.py → list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py} +16 -16
  10. orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +127 -0
  11. orca_sdk/_generated_api_client/api/classification_model/{predict_gpu_classification_model_name_or_id_prediction_post.py → predict_label_gpu_classification_model_name_or_id_prediction_post.py} +14 -14
  12. orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +183 -0
  13. orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +24 -0
  14. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +22 -22
  15. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +22 -22
  16. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +38 -16
  17. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +29 -12
  18. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +12 -12
  19. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +17 -14
  20. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +72 -19
  21. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +31 -12
  22. orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +49 -20
  23. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +38 -16
  24. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +54 -29
  25. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +44 -26
  26. orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +22 -22
  27. orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
  28. orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +150 -0
  29. orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
  30. orca_sdk/_generated_api_client/api/{classification_model/create_model_classification_model_post.py → regression_model/create_regression_model_gpu_regression_model_post.py} +27 -27
  31. orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
  32. orca_sdk/_generated_api_client/api/{classification_model/delete_model_classification_model_name_or_id_delete.py → regression_model/delete_regression_model_regression_model_name_or_id_delete.py} +5 -5
  33. orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +183 -0
  34. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +170 -0
  35. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +156 -0
  36. orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +161 -0
  37. orca_sdk/_generated_api_client/api/{classification_model/list_models_classification_model_get.py → regression_model/list_regression_models_regression_model_get.py} +17 -17
  38. orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +190 -0
  39. orca_sdk/_generated_api_client/api/{classification_model/update_model_classification_model_name_or_id_patch.py → regression_model/update_regression_model_regression_model_name_or_id_patch.py} +27 -27
  40. orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +156 -0
  41. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +35 -12
  42. orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +20 -12
  43. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +35 -12
  44. orca_sdk/_generated_api_client/models/__init__.py +84 -24
  45. orca_sdk/_generated_api_client/models/base_score_prediction_result.py +108 -0
  46. orca_sdk/_generated_api_client/models/{evaluation_request.py → classification_evaluation_request.py} +13 -45
  47. orca_sdk/_generated_api_client/models/{classification_evaluation_result.py → classification_metrics.py} +106 -56
  48. orca_sdk/_generated_api_client/models/{rac_model_metadata.py → classification_model_metadata.py} +51 -43
  49. orca_sdk/_generated_api_client/models/{prediction_request.py → classification_prediction_request.py} +31 -6
  50. orca_sdk/_generated_api_client/models/{clone_labeled_memoryset_request.py → clone_memoryset_request.py} +5 -5
  51. orca_sdk/_generated_api_client/models/column_info.py +31 -0
  52. orca_sdk/_generated_api_client/models/{create_rac_model_request.py → create_classification_model_request.py} +25 -57
  53. orca_sdk/_generated_api_client/models/{create_labeled_memoryset_request.py → create_memoryset_request.py} +73 -56
  54. orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +66 -0
  55. orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +13 -0
  56. orca_sdk/_generated_api_client/models/create_regression_model_request.py +137 -0
  57. orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +187 -0
  58. orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +10 -0
  59. orca_sdk/_generated_api_client/models/evaluation_response.py +22 -9
  60. orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +140 -0
  61. orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +140 -0
  62. orca_sdk/_generated_api_client/models/memory_type.py +9 -0
  63. orca_sdk/_generated_api_client/models/{labeled_memoryset_metadata.py → memoryset_metadata.py} +73 -13
  64. orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +55 -0
  65. orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +13 -0
  66. orca_sdk/_generated_api_client/models/{labeled_memoryset_update.py → memoryset_update.py} +19 -31
  67. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +1 -0
  68. orca_sdk/_generated_api_client/models/{paginated_labeled_memory_with_feedback_metrics.py → paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py} +37 -10
  69. orca_sdk/_generated_api_client/models/{precision_recall_curve.py → pr_curve.py} +5 -13
  70. orca_sdk/_generated_api_client/models/{rac_model_update.py → predictive_model_update.py} +14 -5
  71. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +11 -1
  72. orca_sdk/_generated_api_client/models/rar_head_type.py +8 -0
  73. orca_sdk/_generated_api_client/models/regression_evaluation_request.py +148 -0
  74. orca_sdk/_generated_api_client/models/regression_metrics.py +172 -0
  75. orca_sdk/_generated_api_client/models/regression_model_metadata.py +177 -0
  76. orca_sdk/_generated_api_client/models/regression_prediction_request.py +195 -0
  77. orca_sdk/_generated_api_client/models/roc_curve.py +0 -8
  78. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +196 -0
  79. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +68 -0
  80. orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +252 -0
  81. orca_sdk/_generated_api_client/models/scored_memory.py +172 -0
  82. orca_sdk/_generated_api_client/models/scored_memory_insert.py +128 -0
  83. orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +68 -0
  84. orca_sdk/_generated_api_client/models/scored_memory_lookup.py +180 -0
  85. orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +68 -0
  86. orca_sdk/_generated_api_client/models/scored_memory_metadata.py +68 -0
  87. orca_sdk/_generated_api_client/models/scored_memory_update.py +171 -0
  88. orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +68 -0
  89. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +193 -0
  90. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +68 -0
  91. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +68 -0
  92. orca_sdk/_generated_api_client/models/update_prediction_request.py +20 -0
  93. orca_sdk/_shared/__init__.py +9 -1
  94. orca_sdk/_shared/metrics.py +257 -87
  95. orca_sdk/_shared/metrics_test.py +136 -77
  96. orca_sdk/_utils/data_parsing.py +0 -3
  97. orca_sdk/_utils/data_parsing_test.py +0 -3
  98. orca_sdk/_utils/prediction_result_ui.py +55 -23
  99. orca_sdk/classification_model.py +183 -172
  100. orca_sdk/classification_model_test.py +147 -157
  101. orca_sdk/conftest.py +76 -26
  102. orca_sdk/datasource_test.py +0 -1
  103. orca_sdk/embedding_model.py +136 -14
  104. orca_sdk/embedding_model_test.py +10 -6
  105. orca_sdk/job.py +329 -0
  106. orca_sdk/job_test.py +48 -0
  107. orca_sdk/memoryset.py +882 -161
  108. orca_sdk/memoryset_test.py +56 -23
  109. orca_sdk/regression_model.py +647 -0
  110. orca_sdk/regression_model_test.py +337 -0
  111. orca_sdk/telemetry.py +223 -106
  112. orca_sdk/telemetry_test.py +34 -30
  113. {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.96.dist-info}/METADATA +2 -4
  114. {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.96.dist-info}/RECORD +115 -69
  115. orca_sdk/_utils/task.py +0 -73
  116. {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.96.dist-info}/WHEEL +0 -0
@@ -1,19 +1,15 @@
1
+ import os
1
2
  import random
2
- import time
3
- from typing import Generator
4
3
  from uuid import uuid4
5
4
 
6
5
  import pytest
7
- from datasets import ClassLabel, Features, Value
8
6
  from datasets.arrow_dataset import Dataset
9
7
 
10
- from orca_sdk.conftest import SAMPLE_DATA
11
-
12
- from ._generated_api_client.models import CascadingEditSuggestion
13
8
  from .classification_model import ClassificationModel
9
+ from .conftest import skip_in_prod
14
10
  from .datasource import Datasource
15
11
  from .embedding_model import PretrainedEmbeddingModel
16
- from .memoryset import LabeledMemoryset, TaskStatus
12
+ from .memoryset import LabeledMemoryset, ScoredMemory, ScoredMemoryset, Status
17
13
 
18
14
  """
19
15
  Test Performance Note:
@@ -37,9 +33,11 @@ def test_create_memoryset(readonly_memoryset: LabeledMemoryset, hf_dataset: Data
37
33
  assert readonly_memoryset.name == "test_readonly_memoryset"
38
34
  assert readonly_memoryset.embedding_model == PretrainedEmbeddingModel.GTE_BASE
39
35
  assert readonly_memoryset.label_names == label_names
40
- assert readonly_memoryset.insertion_status == TaskStatus.COMPLETED
36
+ assert readonly_memoryset.insertion_status == Status.COMPLETED
41
37
  assert isinstance(readonly_memoryset.length, int)
42
38
  assert readonly_memoryset.length == len(hf_dataset)
39
+ assert readonly_memoryset.index_type == "IVF_FLAT"
40
+ assert readonly_memoryset.index_params == {"n_lists": 100}
43
41
 
44
42
 
45
43
  def test_create_memoryset_unauthenticated(unauthenticated, datasource):
@@ -95,6 +93,8 @@ def test_open_memoryset(readonly_memoryset, hf_dataset):
95
93
  assert fetched_memoryset is not None
96
94
  assert fetched_memoryset.name == readonly_memoryset.name
97
95
  assert fetched_memoryset.length == len(hf_dataset)
96
+ assert fetched_memoryset.index_type == "IVF_FLAT"
97
+ assert fetched_memoryset.index_params == {"n_lists": 100}
98
98
 
99
99
 
100
100
  def test_open_memoryset_unauthenticated(unauthenticated, readonly_memoryset):
@@ -149,15 +149,25 @@ def test_drop_memoryset_unauthorized(unauthorized, readonly_memoryset):
149
149
  LabeledMemoryset.drop(readonly_memoryset.name)
150
150
 
151
151
 
152
- def test_update_memoryset_metadata(writable_memoryset: LabeledMemoryset):
153
- # NOTE: We're combining multiple tests into one here to avoid multiple API calls
154
-
155
- writable_memoryset.update_metadata(description="New description")
152
+ def test_update_memoryset_attributes(writable_memoryset: LabeledMemoryset):
153
+ original_label_names = writable_memoryset.label_names
154
+ writable_memoryset.set(description="New description")
156
155
  assert writable_memoryset.description == "New description"
157
156
 
158
- writable_memoryset.update_metadata(description=None)
157
+ writable_memoryset.set(description=None)
159
158
  assert writable_memoryset.description is None
160
159
 
160
+ writable_memoryset.set(name="New_name")
161
+ assert writable_memoryset.name == "New_name"
162
+
163
+ writable_memoryset.set(name="test_writable_memoryset")
164
+ assert writable_memoryset.name == "test_writable_memoryset"
165
+
166
+ assert writable_memoryset.label_names == original_label_names
167
+
168
+ writable_memoryset.set(label_names=["New label 1", "New label 2"])
169
+ assert writable_memoryset.label_names == ["New label 1", "New label 2"]
170
+
161
171
 
162
172
  def test_search(readonly_memoryset: LabeledMemoryset):
163
173
  memory_lookups = readonly_memoryset.search(["i love soup", "cats are cute"])
@@ -214,11 +224,11 @@ def test_query_memoryset(readonly_memoryset: LabeledMemoryset):
214
224
  assert len(memories) == 8
215
225
  assert all(memory.label == 1 for memory in memories)
216
226
  assert len(readonly_memoryset.query(limit=2)) == 2
217
- assert len(readonly_memoryset.query(filters=[("metadata.key", "==", "val1")])) == 1
227
+ assert len(readonly_memoryset.query(filters=[("metadata.key", "==", "g2")])) == 4
218
228
 
219
229
 
220
- def test_query_memoryset_with_feedback_metrics(model: ClassificationModel):
221
- prediction = model.predict("Do you love soup?")
230
+ def test_query_memoryset_with_feedback_metrics(classification_model: ClassificationModel):
231
+ prediction = classification_model.predict("Do you love soup?")
222
232
  feedback_name = f"correct_{random.randint(0, 1000000)}"
223
233
  prediction.record_feedback(category=feedback_name, value=prediction.label == 0)
224
234
  memories = prediction.memoryset.query(filters=[("label", "==", 0)], with_feedback_metrics=True)
@@ -239,8 +249,8 @@ def test_query_memoryset_with_feedback_metrics(model: ClassificationModel):
239
249
  assert isinstance(memory.lookup_count, int)
240
250
 
241
251
 
242
- def test_query_memoryset_with_feedback_metrics_filter(model: ClassificationModel):
243
- prediction = model.predict("Do you love soup?")
252
+ def test_query_memoryset_with_feedback_metrics_filter(classification_model: ClassificationModel):
253
+ prediction = classification_model.predict("Do you love soup?")
244
254
  prediction.record_feedback(category="accurate", value=prediction.label == 0)
245
255
  memories = prediction.memoryset.query(
246
256
  filters=[("feedback_metrics.accurate.avg", ">", 0.5)], with_feedback_metrics=True
@@ -254,10 +264,10 @@ def test_query_memoryset_with_feedback_metrics_filter(model: ClassificationModel
254
264
  assert memory.feedback_metrics["accurate"]["count"] == 1
255
265
 
256
266
 
257
- def test_query_memoryset_with_feedback_metrics_sort(model: ClassificationModel):
258
- prediction = model.predict("Do you love soup?")
267
+ def test_query_memoryset_with_feedback_metrics_sort(classification_model: ClassificationModel):
268
+ prediction = classification_model.predict("Do you love soup?")
259
269
  prediction.record_feedback(category="positive", value=1.0)
260
- prediction2 = model.predict("Do you like cats?")
270
+ prediction2 = classification_model.predict("Do you like cats?")
261
271
  prediction2.record_feedback(category="positive", value=-1.0)
262
272
 
263
273
  memories = prediction.memoryset.query(
@@ -294,6 +304,7 @@ def test_insert_memories(writable_memoryset: LabeledMemoryset):
294
304
  assert last_memory.source_id == "test"
295
305
 
296
306
 
307
+ @skip_in_prod("Production memorysets do not have session consistency guarantees")
297
308
  def test_update_memories(writable_memoryset: LabeledMemoryset, hf_dataset: Dataset):
298
309
  # We've combined the update tests into one to avoid multiple expensive requests for a writable_memoryset
299
310
 
@@ -302,6 +313,7 @@ def test_update_memories(writable_memoryset: LabeledMemoryset, hf_dataset: Datas
302
313
  updated_memory = writable_memoryset.update(dict(memory_id=memory_id, value="i love soup so much"))
303
314
  assert updated_memory.value == "i love soup so much"
304
315
  assert updated_memory.label == hf_dataset[0]["label"]
316
+ writable_memoryset.refresh() # Refresh to ensure consistency after update
305
317
  assert writable_memoryset.get(memory_id).value == "i love soup so much"
306
318
 
307
319
  # test updating a memory instance
@@ -348,7 +360,7 @@ def test_clone_memoryset(readonly_memoryset: LabeledMemoryset):
348
360
  assert cloned_memoryset.name == "test_cloned_memoryset"
349
361
  assert cloned_memoryset.length == readonly_memoryset.length
350
362
  assert cloned_memoryset.embedding_model == PretrainedEmbeddingModel.DISTILBERT
351
- assert cloned_memoryset.insertion_status == TaskStatus.COMPLETED
363
+ assert cloned_memoryset.insertion_status == Status.COMPLETED
352
364
 
353
365
 
354
366
  def test_embedding_evaluation(eval_datasource: Datasource):
@@ -363,7 +375,6 @@ def test_embedding_evaluation(eval_datasource: Datasource):
363
375
  assert response["evaluation_results"][0] is not None
364
376
  assert response["evaluation_results"][0]["embedding_model_name"] == "CDE_SMALL"
365
377
  assert response["evaluation_results"][0]["embedding_model_path"] == "OrcaDB/cde-small-v1"
366
- Datasource.drop("eval_datasource")
367
378
 
368
379
 
369
380
  @pytest.fixture(scope="function")
@@ -455,3 +466,25 @@ def test_drop_memoryset(writable_memoryset: LabeledMemoryset):
455
466
  assert LabeledMemoryset.exists(writable_memoryset.name)
456
467
  LabeledMemoryset.drop(writable_memoryset.name)
457
468
  assert not LabeledMemoryset.exists(writable_memoryset.name)
469
+
470
+
471
+ def test_scored_memoryset(scored_memoryset: ScoredMemoryset):
472
+ assert scored_memoryset.length == 16
473
+ assert isinstance(scored_memoryset[0], ScoredMemory)
474
+ assert scored_memoryset[0].value == "i love soup"
475
+ assert scored_memoryset[0].score is not None
476
+ assert scored_memoryset[0].metadata == {"key": "g1", "source_id": "s1", "label": 0}
477
+ lookup = scored_memoryset.search("i love soup", count=1)
478
+ assert len(lookup) == 1
479
+ assert lookup[0].score < 0.11
480
+
481
+
482
+ @skip_in_prod("Production memorysets do not have session consistency guarantees")
483
+ def test_update_scored_memory(scored_memoryset: ScoredMemoryset):
484
+ # we are only updating an inconsequential metadata field so that we don't affect other tests
485
+ memory = scored_memoryset[0]
486
+ assert memory.label == 0
487
+ scored_memoryset.update(dict(memory_id=memory.memory_id, label=3))
488
+ assert scored_memoryset[0].label == 3
489
+ memory.update(label=4)
490
+ assert scored_memoryset[0].label == 4