orca-sdk 0.0.93__py3-none-any.whl → 0.0.95__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (125)
  1. orca_sdk/__init__.py +13 -4
  2. orca_sdk/_generated_api_client/api/__init__.py +84 -34
  3. orca_sdk/_generated_api_client/api/classification_model/create_classification_model_classification_model_post.py +170 -0
  4. orca_sdk/_generated_api_client/api/classification_model/{get_model_classification_model_name_or_id_get.py → delete_classification_model_classification_model_name_or_id_delete.py} +20 -20
  5. orca_sdk/_generated_api_client/api/classification_model/{delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py → delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py} +4 -4
  6. orca_sdk/_generated_api_client/api/classification_model/{create_evaluation_classification_model_model_name_or_id_evaluation_post.py → evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py} +14 -14
  7. orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +156 -0
  8. orca_sdk/_generated_api_client/api/classification_model/{get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py → get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py} +16 -16
  9. orca_sdk/_generated_api_client/api/classification_model/{list_evaluations_classification_model_model_name_or_id_evaluation_get.py → list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py} +16 -16
  10. orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +127 -0
  11. orca_sdk/_generated_api_client/api/classification_model/{predict_gpu_classification_model_name_or_id_prediction_post.py → predict_label_gpu_classification_model_name_or_id_prediction_post.py} +14 -14
  12. orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +183 -0
  13. orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +172 -0
  14. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +22 -22
  15. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +22 -22
  16. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +38 -16
  17. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +29 -12
  18. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +12 -12
  19. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +17 -14
  20. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +72 -19
  21. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +31 -12
  22. orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +49 -20
  23. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +38 -16
  24. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +54 -29
  25. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +44 -26
  26. orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +22 -22
  27. orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
  28. orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +150 -0
  29. orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
  30. orca_sdk/_generated_api_client/api/{classification_model/create_model_classification_model_post.py → regression_model/create_regression_model_regression_model_post.py} +27 -27
  31. orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
  32. orca_sdk/_generated_api_client/api/{classification_model/delete_model_classification_model_name_or_id_delete.py → regression_model/delete_regression_model_regression_model_name_or_id_delete.py} +5 -5
  33. orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +183 -0
  34. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +170 -0
  35. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +156 -0
  36. orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +161 -0
  37. orca_sdk/_generated_api_client/api/{classification_model/list_models_classification_model_get.py → regression_model/list_regression_models_regression_model_get.py} +17 -17
  38. orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +190 -0
  39. orca_sdk/_generated_api_client/api/{classification_model/update_model_classification_model_name_or_id_patch.py → regression_model/update_regression_model_regression_model_name_or_id_patch.py} +27 -27
  40. orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +156 -0
  41. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +60 -10
  42. orca_sdk/_generated_api_client/api/telemetry/count_predictions_telemetry_prediction_count_post.py +10 -10
  43. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +35 -12
  44. orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +20 -12
  45. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +35 -12
  46. orca_sdk/_generated_api_client/models/__init__.py +90 -24
  47. orca_sdk/_generated_api_client/models/base_score_prediction_result.py +108 -0
  48. orca_sdk/_generated_api_client/models/{evaluation_request.py → classification_evaluation_request.py} +13 -45
  49. orca_sdk/_generated_api_client/models/{classification_evaluation_result.py → classification_metrics.py} +106 -56
  50. orca_sdk/_generated_api_client/models/{rac_model_metadata.py → classification_model_metadata.py} +51 -43
  51. orca_sdk/_generated_api_client/models/{prediction_request.py → classification_prediction_request.py} +31 -6
  52. orca_sdk/_generated_api_client/models/{clone_labeled_memoryset_request.py → clone_memoryset_request.py} +5 -5
  53. orca_sdk/_generated_api_client/models/column_info.py +31 -0
  54. orca_sdk/_generated_api_client/models/count_predictions_request.py +195 -0
  55. orca_sdk/_generated_api_client/models/{create_rac_model_request.py → create_classification_model_request.py} +25 -57
  56. orca_sdk/_generated_api_client/models/{create_labeled_memoryset_request.py → create_memoryset_request.py} +73 -56
  57. orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +66 -0
  58. orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +13 -0
  59. orca_sdk/_generated_api_client/models/create_regression_model_request.py +137 -0
  60. orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +187 -0
  61. orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +10 -0
  62. orca_sdk/_generated_api_client/models/evaluation_response.py +22 -9
  63. orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +140 -0
  64. orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +140 -0
  65. orca_sdk/_generated_api_client/models/http_validation_error.py +86 -0
  66. orca_sdk/_generated_api_client/models/list_predictions_request.py +62 -0
  67. orca_sdk/_generated_api_client/models/memory_type.py +9 -0
  68. orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py +0 -20
  69. orca_sdk/_generated_api_client/models/{labeled_memoryset_metadata.py → memoryset_metadata.py} +73 -13
  70. orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +55 -0
  71. orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +13 -0
  72. orca_sdk/_generated_api_client/models/{labeled_memoryset_update.py → memoryset_update.py} +19 -31
  73. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +1 -0
  74. orca_sdk/_generated_api_client/models/{paginated_labeled_memory_with_feedback_metrics.py → paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py} +37 -10
  75. orca_sdk/_generated_api_client/models/{precision_recall_curve.py → pr_curve.py} +5 -13
  76. orca_sdk/_generated_api_client/models/{rac_model_update.py → predictive_model_update.py} +14 -5
  77. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +11 -1
  78. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +5 -0
  79. orca_sdk/_generated_api_client/models/rar_head_type.py +8 -0
  80. orca_sdk/_generated_api_client/models/regression_evaluation_request.py +148 -0
  81. orca_sdk/_generated_api_client/models/regression_metrics.py +172 -0
  82. orca_sdk/_generated_api_client/models/regression_model_metadata.py +177 -0
  83. orca_sdk/_generated_api_client/models/regression_prediction_request.py +195 -0
  84. orca_sdk/_generated_api_client/models/roc_curve.py +0 -8
  85. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +196 -0
  86. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +68 -0
  87. orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +252 -0
  88. orca_sdk/_generated_api_client/models/scored_memory.py +172 -0
  89. orca_sdk/_generated_api_client/models/scored_memory_insert.py +128 -0
  90. orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +68 -0
  91. orca_sdk/_generated_api_client/models/scored_memory_lookup.py +180 -0
  92. orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +68 -0
  93. orca_sdk/_generated_api_client/models/scored_memory_metadata.py +68 -0
  94. orca_sdk/_generated_api_client/models/scored_memory_update.py +171 -0
  95. orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +68 -0
  96. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +193 -0
  97. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +68 -0
  98. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +68 -0
  99. orca_sdk/_generated_api_client/models/update_prediction_request.py +20 -0
  100. orca_sdk/_generated_api_client/models/validation_error.py +99 -0
  101. orca_sdk/_shared/__init__.py +9 -1
  102. orca_sdk/_shared/metrics.py +257 -87
  103. orca_sdk/_shared/metrics_test.py +136 -77
  104. orca_sdk/_utils/data_parsing.py +0 -3
  105. orca_sdk/_utils/data_parsing_test.py +0 -3
  106. orca_sdk/_utils/prediction_result_ui.py +55 -23
  107. orca_sdk/classification_model.py +184 -174
  108. orca_sdk/classification_model_test.py +178 -142
  109. orca_sdk/conftest.py +77 -26
  110. orca_sdk/datasource.py +34 -0
  111. orca_sdk/datasource_test.py +9 -1
  112. orca_sdk/embedding_model.py +136 -14
  113. orca_sdk/embedding_model_test.py +10 -6
  114. orca_sdk/job.py +329 -0
  115. orca_sdk/job_test.py +48 -0
  116. orca_sdk/memoryset.py +882 -161
  117. orca_sdk/memoryset_test.py +58 -23
  118. orca_sdk/regression_model.py +647 -0
  119. orca_sdk/regression_model_test.py +338 -0
  120. orca_sdk/telemetry.py +225 -106
  121. orca_sdk/telemetry_test.py +34 -30
  122. {orca_sdk-0.0.93.dist-info → orca_sdk-0.0.95.dist-info}/METADATA +2 -4
  123. {orca_sdk-0.0.93.dist-info → orca_sdk-0.0.95.dist-info}/RECORD +124 -74
  124. orca_sdk/_utils/task.py +0 -73
  125. {orca_sdk-0.0.93.dist-info → orca_sdk-0.0.95.dist-info}/WHEEL +0 -0
@@ -1,19 +1,15 @@
1
+ import os
1
2
  import random
2
- import time
3
- from typing import Generator
4
3
  from uuid import uuid4
5
4
 
6
5
  import pytest
7
- from datasets import ClassLabel, Features, Value
8
6
  from datasets.arrow_dataset import Dataset
9
7
 
10
- from orca_sdk.conftest import SAMPLE_DATA
11
-
12
- from ._generated_api_client.models import CascadingEditSuggestion
13
8
  from .classification_model import ClassificationModel
9
+ from .conftest import skip_in_prod
14
10
  from .datasource import Datasource
15
11
  from .embedding_model import PretrainedEmbeddingModel
16
- from .memoryset import LabeledMemoryset, TaskStatus
12
+ from .memoryset import LabeledMemoryset, ScoredMemory, ScoredMemoryset, Status
17
13
 
18
14
  """
19
15
  Test Performance Note:
@@ -37,9 +33,11 @@ def test_create_memoryset(readonly_memoryset: LabeledMemoryset, hf_dataset: Data
37
33
  assert readonly_memoryset.name == "test_readonly_memoryset"
38
34
  assert readonly_memoryset.embedding_model == PretrainedEmbeddingModel.GTE_BASE
39
35
  assert readonly_memoryset.label_names == label_names
40
- assert readonly_memoryset.insertion_status == TaskStatus.COMPLETED
36
+ assert readonly_memoryset.insertion_status == Status.COMPLETED
41
37
  assert isinstance(readonly_memoryset.length, int)
42
38
  assert readonly_memoryset.length == len(hf_dataset)
39
+ assert readonly_memoryset.index_type == "IVF_FLAT"
40
+ assert readonly_memoryset.index_params == {"n_lists": 100}
43
41
 
44
42
 
45
43
  def test_create_memoryset_unauthenticated(unauthenticated, datasource):
@@ -95,6 +93,8 @@ def test_open_memoryset(readonly_memoryset, hf_dataset):
95
93
  assert fetched_memoryset is not None
96
94
  assert fetched_memoryset.name == readonly_memoryset.name
97
95
  assert fetched_memoryset.length == len(hf_dataset)
96
+ assert fetched_memoryset.index_type == "IVF_FLAT"
97
+ assert fetched_memoryset.index_params == {"n_lists": 100}
98
98
 
99
99
 
100
100
  def test_open_memoryset_unauthenticated(unauthenticated, readonly_memoryset):
@@ -149,15 +149,25 @@ def test_drop_memoryset_unauthorized(unauthorized, readonly_memoryset):
149
149
  LabeledMemoryset.drop(readonly_memoryset.name)
150
150
 
151
151
 
152
- def test_update_memoryset_metadata(writable_memoryset: LabeledMemoryset):
153
- # NOTE: We're combining multiple tests into one here to avoid multiple API calls
154
-
155
- writable_memoryset.update_metadata(description="New description")
152
+ def test_update_memoryset_attributes(writable_memoryset: LabeledMemoryset):
153
+ original_label_names = writable_memoryset.label_names
154
+ writable_memoryset.set(description="New description")
156
155
  assert writable_memoryset.description == "New description"
157
156
 
158
- writable_memoryset.update_metadata(description=None)
157
+ writable_memoryset.set(description=None)
159
158
  assert writable_memoryset.description is None
160
159
 
160
+ writable_memoryset.set(name="New_name")
161
+ assert writable_memoryset.name == "New_name"
162
+
163
+ writable_memoryset.set(name="test_writable_memoryset")
164
+ assert writable_memoryset.name == "test_writable_memoryset"
165
+
166
+ assert writable_memoryset.label_names == original_label_names
167
+
168
+ writable_memoryset.set(label_names=["New label 1", "New label 2"])
169
+ assert writable_memoryset.label_names == ["New label 1", "New label 2"]
170
+
161
171
 
162
172
  def test_search(readonly_memoryset: LabeledMemoryset):
163
173
  memory_lookups = readonly_memoryset.search(["i love soup", "cats are cute"])
@@ -214,11 +224,11 @@ def test_query_memoryset(readonly_memoryset: LabeledMemoryset):
214
224
  assert len(memories) == 8
215
225
  assert all(memory.label == 1 for memory in memories)
216
226
  assert len(readonly_memoryset.query(limit=2)) == 2
217
- assert len(readonly_memoryset.query(filters=[("metadata.key", "==", "val1")])) == 1
227
+ assert len(readonly_memoryset.query(filters=[("metadata.key", "==", "g2")])) == 4
218
228
 
219
229
 
220
- def test_query_memoryset_with_feedback_metrics(model: ClassificationModel):
221
- prediction = model.predict("Do you love soup?")
230
+ def test_query_memoryset_with_feedback_metrics(classification_model: ClassificationModel):
231
+ prediction = classification_model.predict("Do you love soup?")
222
232
  feedback_name = f"correct_{random.randint(0, 1000000)}"
223
233
  prediction.record_feedback(category=feedback_name, value=prediction.label == 0)
224
234
  memories = prediction.memoryset.query(filters=[("label", "==", 0)], with_feedback_metrics=True)
@@ -239,8 +249,8 @@ def test_query_memoryset_with_feedback_metrics(model: ClassificationModel):
239
249
  assert isinstance(memory.lookup_count, int)
240
250
 
241
251
 
242
- def test_query_memoryset_with_feedback_metrics_filter(model: ClassificationModel):
243
- prediction = model.predict("Do you love soup?")
252
+ def test_query_memoryset_with_feedback_metrics_filter(classification_model: ClassificationModel):
253
+ prediction = classification_model.predict("Do you love soup?")
244
254
  prediction.record_feedback(category="accurate", value=prediction.label == 0)
245
255
  memories = prediction.memoryset.query(
246
256
  filters=[("feedback_metrics.accurate.avg", ">", 0.5)], with_feedback_metrics=True
@@ -254,10 +264,10 @@ def test_query_memoryset_with_feedback_metrics_filter(model: ClassificationModel
254
264
  assert memory.feedback_metrics["accurate"]["count"] == 1
255
265
 
256
266
 
257
- def test_query_memoryset_with_feedback_metrics_sort(model: ClassificationModel):
258
- prediction = model.predict("Do you love soup?")
267
+ def test_query_memoryset_with_feedback_metrics_sort(classification_model: ClassificationModel):
268
+ prediction = classification_model.predict("Do you love soup?")
259
269
  prediction.record_feedback(category="positive", value=1.0)
260
- prediction2 = model.predict("Do you like cats?")
270
+ prediction2 = classification_model.predict("Do you like cats?")
261
271
  prediction2.record_feedback(category="positive", value=-1.0)
262
272
 
263
273
  memories = prediction.memoryset.query(
@@ -281,8 +291,10 @@ def test_insert_memories(writable_memoryset: LabeledMemoryset):
281
291
  dict(value="cats are fun to play with", label=1),
282
292
  ]
283
293
  )
294
+ writable_memoryset.refresh()
284
295
  assert writable_memoryset.length == prev_length + 2
285
296
  writable_memoryset.insert(dict(value="tomato soup is my favorite", label=0, key="test", source_id="test"))
297
+ writable_memoryset.refresh()
286
298
  assert writable_memoryset.length == prev_length + 3
287
299
  last_memory = writable_memoryset[-1]
288
300
  assert last_memory.value == "tomato soup is my favorite"
@@ -292,6 +304,7 @@ def test_insert_memories(writable_memoryset: LabeledMemoryset):
292
304
  assert last_memory.source_id == "test"
293
305
 
294
306
 
307
+ @skip_in_prod("Production memorysets do not have session consistency guarantees")
295
308
  def test_update_memories(writable_memoryset: LabeledMemoryset, hf_dataset: Dataset):
296
309
  # We've combined the update tests into one to avoid multiple expensive requests for a writable_memoryset
297
310
 
@@ -300,6 +313,7 @@ def test_update_memories(writable_memoryset: LabeledMemoryset, hf_dataset: Datas
300
313
  updated_memory = writable_memoryset.update(dict(memory_id=memory_id, value="i love soup so much"))
301
314
  assert updated_memory.value == "i love soup so much"
302
315
  assert updated_memory.label == hf_dataset[0]["label"]
316
+ writable_memoryset.refresh() # Refresh to ensure consistency after update
303
317
  assert writable_memoryset.get(memory_id).value == "i love soup so much"
304
318
 
305
319
  # test updating a memory instance
@@ -346,7 +360,7 @@ def test_clone_memoryset(readonly_memoryset: LabeledMemoryset):
346
360
  assert cloned_memoryset.name == "test_cloned_memoryset"
347
361
  assert cloned_memoryset.length == readonly_memoryset.length
348
362
  assert cloned_memoryset.embedding_model == PretrainedEmbeddingModel.DISTILBERT
349
- assert cloned_memoryset.insertion_status == TaskStatus.COMPLETED
363
+ assert cloned_memoryset.insertion_status == Status.COMPLETED
350
364
 
351
365
 
352
366
  def test_embedding_evaluation(eval_datasource: Datasource):
@@ -361,7 +375,6 @@ def test_embedding_evaluation(eval_datasource: Datasource):
361
375
  assert response["evaluation_results"][0] is not None
362
376
  assert response["evaluation_results"][0]["embedding_model_name"] == "CDE_SMALL"
363
377
  assert response["evaluation_results"][0]["embedding_model_path"] == "OrcaDB/cde-small-v1"
364
- Datasource.drop("eval_datasource")
365
378
 
366
379
 
367
380
  @pytest.fixture(scope="function")
@@ -453,3 +466,25 @@ def test_drop_memoryset(writable_memoryset: LabeledMemoryset):
453
466
  assert LabeledMemoryset.exists(writable_memoryset.name)
454
467
  LabeledMemoryset.drop(writable_memoryset.name)
455
468
  assert not LabeledMemoryset.exists(writable_memoryset.name)
469
+
470
+
471
+ def test_scored_memoryset(scored_memoryset: ScoredMemoryset):
472
+ assert scored_memoryset.length == 16
473
+ assert isinstance(scored_memoryset[0], ScoredMemory)
474
+ assert scored_memoryset[0].value == "i love soup"
475
+ assert scored_memoryset[0].score is not None
476
+ assert scored_memoryset[0].metadata == {"key": "g1", "source_id": "s1", "label": 0}
477
+ lookup = scored_memoryset.search("i love soup", count=1)
478
+ assert len(lookup) == 1
479
+ assert lookup[0].score < 0.11
480
+
481
+
482
+ @skip_in_prod("Production memorysets do not have session consistency guarantees")
483
+ def test_update_scored_memory(scored_memoryset: ScoredMemoryset):
484
+ # we are only updating an inconsequential metadata field so that we don't affect other tests
485
+ memory = scored_memoryset[0]
486
+ assert memory.label == 0
487
+ scored_memoryset.update(dict(memory_id=memory.memory_id, label=3))
488
+ assert scored_memoryset[0].label == 3
489
+ memory.update(label=4)
490
+ assert scored_memoryset[0].label == 4