orca-sdk 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,10 +5,11 @@ import pytest
5
5
  from datasets.arrow_dataset import Dataset
6
6
 
7
7
  from .classification_model import ClassificationModel
8
- from .conftest import skip_in_prod
8
+ from .conftest import skip_in_ci, skip_in_prod
9
9
  from .datasource import Datasource
10
10
  from .embedding_model import PretrainedEmbeddingModel
11
11
  from .memoryset import LabeledMemoryset, ScoredMemory, ScoredMemoryset, Status
12
+ from .regression_model import RegressionModel
12
13
 
13
14
  """
14
15
  Test Performance Note:
@@ -39,9 +40,10 @@ def test_create_memoryset(readonly_memoryset: LabeledMemoryset, hf_dataset: Data
39
40
  assert readonly_memoryset.index_params == {"n_lists": 100}
40
41
 
41
42
 
42
- def test_create_memoryset_unauthenticated(unauthenticated, datasource):
43
- with pytest.raises(ValueError, match="Invalid API key"):
44
- LabeledMemoryset.create("test_memoryset", datasource)
43
+ def test_create_memoryset_unauthenticated(unauthenticated_client, datasource):
44
+ with unauthenticated_client.use():
45
+ with pytest.raises(ValueError, match="Invalid API key"):
46
+ LabeledMemoryset.create("test_memoryset", datasource)
45
47
 
46
48
 
47
49
  def test_create_memoryset_invalid_input(datasource):
@@ -87,6 +89,75 @@ def test_create_memoryset_already_exists_open(hf_dataset, label_names, readonly_
87
89
  assert opened_memoryset.length == len(hf_dataset)
88
90
 
89
91
 
92
+ def test_if_exists_error_no_datasource_creation(
93
+ readonly_memoryset: LabeledMemoryset,
94
+ ):
95
+ memoryset_name = readonly_memoryset.name
96
+ datasource_name = f"{memoryset_name}_datasource"
97
+ Datasource.drop(datasource_name, if_not_exists="ignore")
98
+ assert not Datasource.exists(datasource_name)
99
+ with pytest.raises(ValueError):
100
+ LabeledMemoryset.from_list(memoryset_name, [{"value": "new value", "label": 0}], if_exists="error")
101
+ assert not Datasource.exists(datasource_name)
102
+
103
+
104
+ def test_if_exists_open_reuses_existing_datasource(
105
+ readonly_memoryset: LabeledMemoryset,
106
+ ):
107
+ memoryset_name = readonly_memoryset.name
108
+ datasource_name = f"{memoryset_name}_datasource"
109
+ Datasource.drop(datasource_name, if_not_exists="ignore")
110
+ assert not Datasource.exists(datasource_name)
111
+ reopened = LabeledMemoryset.from_list(memoryset_name, [{"value": "new value", "label": 0}], if_exists="open")
112
+ assert reopened.id == readonly_memoryset.id
113
+ assert not Datasource.exists(datasource_name)
114
+
115
+
116
+ def test_create_memoryset_string_label():
117
+ assert not LabeledMemoryset.exists("test_string_label")
118
+ memoryset = LabeledMemoryset.from_hf_dataset(
119
+ "test_string_label",
120
+ Dataset.from_dict({"value": ["terrible", "great"], "label": ["negative", "positive"]}),
121
+ )
122
+ assert memoryset is not None
123
+ assert memoryset.length == 2
124
+ assert memoryset.label_names == ["negative", "positive"]
125
+ assert memoryset[0].label == 0
126
+ assert memoryset[1].label == 1
127
+ assert memoryset[0].label_name == "negative"
128
+ assert memoryset[1].label_name == "positive"
129
+
130
+
131
+ def test_create_memoryset_integer_label():
132
+ assert not LabeledMemoryset.exists("test_integer_label")
133
+ memoryset = LabeledMemoryset.from_hf_dataset(
134
+ "test_integer_label",
135
+ Dataset.from_dict({"value": ["terrible", "great"], "label": [0, 1]}),
136
+ label_names=["negative", "positive"],
137
+ )
138
+ assert memoryset is not None
139
+ assert memoryset.length == 2
140
+ assert memoryset.label_names == ["negative", "positive"]
141
+ assert memoryset[0].label == 0
142
+ assert memoryset[1].label == 1
143
+ assert memoryset[0].label_name == "negative"
144
+ assert memoryset[1].label_name == "positive"
145
+
146
+
147
+ def test_create_memoryset_null_labels():
148
+ memoryset = LabeledMemoryset.from_hf_dataset(
149
+ "test_null_labels",
150
+ Dataset.from_dict({"value": ["terrible", "great"]}),
151
+ label_names=["negative", "positive"],
152
+ label_column=None,
153
+ )
154
+ assert memoryset is not None
155
+ assert memoryset.length == 2
156
+ assert memoryset.label_names == ["negative", "positive"]
157
+ assert memoryset[0].label == None
158
+ assert memoryset[1].label == None
159
+
160
+
90
161
  def test_open_memoryset(readonly_memoryset, hf_dataset):
91
162
  fetched_memoryset = LabeledMemoryset.open(readonly_memoryset.name)
92
163
  assert fetched_memoryset is not None
@@ -96,9 +167,10 @@ def test_open_memoryset(readonly_memoryset, hf_dataset):
96
167
  assert fetched_memoryset.index_params == {"n_lists": 100}
97
168
 
98
169
 
99
- def test_open_memoryset_unauthenticated(unauthenticated, readonly_memoryset):
100
- with pytest.raises(ValueError, match="Invalid API key"):
101
- LabeledMemoryset.open(readonly_memoryset.name)
170
+ def test_open_memoryset_unauthenticated(unauthenticated_client, readonly_memoryset):
171
+ with unauthenticated_client.use():
172
+ with pytest.raises(ValueError, match="Invalid API key"):
173
+ LabeledMemoryset.open(readonly_memoryset.name)
102
174
 
103
175
 
104
176
  def test_open_memoryset_not_found():
@@ -111,9 +183,10 @@ def test_open_memoryset_invalid_input():
111
183
  LabeledMemoryset.open("not valid id")
112
184
 
113
185
 
114
- def test_open_memoryset_unauthorized(unauthorized, readonly_memoryset):
115
- with pytest.raises(LookupError):
116
- LabeledMemoryset.open(readonly_memoryset.name)
186
+ def test_open_memoryset_unauthorized(unauthorized_client, readonly_memoryset):
187
+ with unauthorized_client.use():
188
+ with pytest.raises(LookupError):
189
+ LabeledMemoryset.open(readonly_memoryset.name)
117
190
 
118
191
 
119
192
  def test_all_memorysets(readonly_memoryset: LabeledMemoryset):
@@ -142,18 +215,21 @@ def test_all_memorysets_hidden(
142
215
  assert hidden_memoryset in all_memorysets
143
216
 
144
217
 
145
- def test_all_memorysets_unauthenticated(unauthenticated):
146
- with pytest.raises(ValueError, match="Invalid API key"):
147
- LabeledMemoryset.all()
218
+ def test_all_memorysets_unauthenticated(unauthenticated_client):
219
+ with unauthenticated_client.use():
220
+ with pytest.raises(ValueError, match="Invalid API key"):
221
+ LabeledMemoryset.all()
148
222
 
149
223
 
150
- def test_all_memorysets_unauthorized(unauthorized, readonly_memoryset):
151
- assert readonly_memoryset not in LabeledMemoryset.all()
224
+ def test_all_memorysets_unauthorized(unauthorized_client, readonly_memoryset):
225
+ with unauthorized_client.use():
226
+ assert readonly_memoryset not in LabeledMemoryset.all()
152
227
 
153
228
 
154
- def test_drop_memoryset_unauthenticated(unauthenticated, readonly_memoryset):
155
- with pytest.raises(ValueError, match="Invalid API key"):
156
- LabeledMemoryset.drop(readonly_memoryset.name)
229
+ def test_drop_memoryset_unauthenticated(unauthenticated_client, readonly_memoryset):
230
+ with unauthenticated_client.use():
231
+ with pytest.raises(ValueError, match="Invalid API key"):
232
+ LabeledMemoryset.drop(readonly_memoryset.name)
157
233
 
158
234
 
159
235
  def test_drop_memoryset_not_found():
@@ -163,9 +239,10 @@ def test_drop_memoryset_not_found():
163
239
  LabeledMemoryset.drop(str(uuid4()), if_not_exists="ignore")
164
240
 
165
241
 
166
- def test_drop_memoryset_unauthorized(unauthorized, readonly_memoryset):
167
- with pytest.raises(LookupError):
168
- LabeledMemoryset.drop(readonly_memoryset.name)
242
+ def test_drop_memoryset_unauthorized(unauthorized_client, readonly_memoryset):
243
+ with unauthorized_client.use():
244
+ with pytest.raises(LookupError):
245
+ LabeledMemoryset.drop(readonly_memoryset.name)
169
246
 
170
247
 
171
248
  def test_update_memoryset_attributes(writable_memoryset: LabeledMemoryset):
@@ -304,6 +381,143 @@ def test_query_memoryset_with_feedback_metrics_sort(classification_model: Classi
304
381
  assert memories[-1].feedback_metrics["positive"]["avg"] == -1.0
305
382
 
306
383
 
384
+ def test_labeled_memory_predictions_property(classification_model: ClassificationModel):
385
+ """Test that LabeledMemory.predictions() only returns classification predictions."""
386
+ # Given: A classification model with memories
387
+ memories = classification_model.memoryset.query(limit=1)
388
+ assert len(memories) > 0
389
+ memory = memories[0]
390
+
391
+ # When: I call the predictions method
392
+ predictions = memory.predictions()
393
+
394
+ # Then: It should return a list of ClassificationPrediction objects
395
+ assert isinstance(predictions, list)
396
+ for prediction in predictions:
397
+ assert prediction.__class__.__name__ == "ClassificationPrediction"
398
+ assert hasattr(prediction, "label")
399
+ assert not hasattr(prediction, "score") or prediction.score is None
400
+
401
+
402
+ def test_scored_memory_predictions_property(regression_model: RegressionModel):
403
+ """Test that ScoredMemory.predictions() only returns regression predictions."""
404
+ # Given: A regression model with memories
405
+ memories = regression_model.memoryset.query(limit=1)
406
+ assert len(memories) > 0
407
+ memory = memories[0]
408
+
409
+ # When: I call the predictions method
410
+ predictions = memory.predictions()
411
+
412
+ # Then: It should return a list of RegressionPrediction objects
413
+ assert isinstance(predictions, list)
414
+ for prediction in predictions:
415
+ assert prediction.__class__.__name__ == "RegressionPrediction"
416
+ assert hasattr(prediction, "score")
417
+ assert not hasattr(prediction, "label") or prediction.label is None
418
+
419
+
420
+ def test_memory_feedback_property(classification_model: ClassificationModel):
421
+ """Test that memory.feedback() returns feedback from relevant predictions."""
422
+ # Given: A prediction with recorded feedback
423
+ prediction = classification_model.predict("Test feedback")
424
+ feedback_category = f"test_feedback_{random.randint(0, 1000000)}"
425
+ prediction.record_feedback(category=feedback_category, value=True)
426
+
427
+ # And: A memory that was used in the prediction
428
+ memory_lookups = prediction.memory_lookups
429
+ assert len(memory_lookups) > 0
430
+ memory = memory_lookups[0]
431
+
432
+ # When: I access the feedback property
433
+ feedback = memory.feedback()
434
+
435
+ # Then: It should return feedback aggregated by category as a dict
436
+ assert isinstance(feedback, dict)
437
+ assert feedback_category in feedback
438
+ # Feedback values are lists (you may want to look at mean on the raw data)
439
+ assert isinstance(feedback[feedback_category], list)
440
+ assert len(feedback[feedback_category]) > 0
441
+ # For binary feedback, values should be booleans
442
+ assert isinstance(feedback[feedback_category][0], bool)
443
+
444
+
445
+ def test_memory_predictions_method_parameters(classification_model: ClassificationModel):
446
+ """Test that memory.predictions() method supports pagination, sorting, and filtering."""
447
+ # Given: A classification model with memories
448
+ memories = classification_model.memoryset.query(limit=1)
449
+ assert len(memories) > 0
450
+ memory = memories[0]
451
+
452
+ # When: I call predictions with limit parameter
453
+ predictions_limited = memory.predictions(limit=2)
454
+
455
+ # Then: It should respect the limit
456
+ assert isinstance(predictions_limited, list)
457
+ assert len(predictions_limited) <= 2
458
+
459
+ # When: I call predictions with offset parameter
460
+ all_predictions = memory.predictions(limit=100)
461
+ if len(all_predictions) > 1:
462
+ predictions_offset = memory.predictions(limit=1, offset=1)
463
+ # Then: offset should skip the first prediction
464
+ assert predictions_offset[0].prediction_id != all_predictions[0].prediction_id
465
+
466
+ # When: I call predictions with sort parameter
467
+ predictions_sorted = memory.predictions(limit=10, sort=[("timestamp", "desc")])
468
+ # Then: It should return predictions (sorting verified by API)
469
+ assert isinstance(predictions_sorted, list)
470
+
471
+ # When: I call predictions with expected_label_match parameter
472
+ correct_predictions = memory.predictions(expected_label_match=True)
473
+ incorrect_predictions = memory.predictions(expected_label_match=False)
474
+ # Then: Both should return lists (correctness verified by API filtering)
475
+ assert isinstance(correct_predictions, list)
476
+ assert isinstance(incorrect_predictions, list)
477
+
478
+
479
+ def test_memory_predictions_expected_label_filter(classification_model: ClassificationModel):
480
+ """Test that memory.predictions(expected_label_match=...) filters predictions by correctness."""
481
+ # Given: Make an initial prediction to learn the model's label for a known input
482
+ baseline_prediction = classification_model.predict("Filter test input", save_telemetry="sync")
483
+ original_label = baseline_prediction.label
484
+ alternate_label = 0 if original_label else 1
485
+
486
+ # When: Make a second prediction with an intentionally incorrect expected label
487
+ mismatched_prediction = classification_model.predict(
488
+ "Filter test input",
489
+ expected_labels=alternate_label,
490
+ save_telemetry="sync",
491
+ )
492
+ mismatched_memory = mismatched_prediction.memory_lookups[0]
493
+
494
+ # Then: The prediction should show up when filtering for incorrect predictions
495
+ incorrect_predictions = mismatched_memory.predictions(expected_label_match=False)
496
+ assert any(pred.prediction_id == mismatched_prediction.prediction_id for pred in incorrect_predictions)
497
+
498
+ # Produce a correct prediction (predicted label matches expected label)
499
+ correct_prediction = classification_model.predict(
500
+ "Filter test input",
501
+ expected_labels=original_label,
502
+ save_telemetry="sync",
503
+ )
504
+
505
+ # Ensure we are inspecting a memory used by both correct and incorrect predictions
506
+ correct_lookup_ids = {lookup.memory_id for lookup in correct_prediction.memory_lookups}
507
+ if mismatched_memory.memory_id not in correct_lookup_ids:
508
+ shared_lookup = next(
509
+ (lookup for lookup in mismatched_prediction.memory_lookups if lookup.memory_id in correct_lookup_ids),
510
+ None,
511
+ )
512
+ assert shared_lookup is not None, "No shared memory lookup between correct and incorrect predictions"
513
+ mismatched_memory = shared_lookup
514
+
515
+ # And: The correct prediction should appear when filtering for correct predictions
516
+ correct_predictions = mismatched_memory.predictions(expected_label_match=True)
517
+ assert any(pred.prediction_id == correct_prediction.prediction_id for pred in correct_predictions)
518
+ assert all(pred.prediction_id != mismatched_prediction.prediction_id for pred in correct_predictions)
519
+
520
+
307
521
  def test_insert_memories(writable_memoryset: LabeledMemoryset):
308
522
  writable_memoryset.refresh()
309
523
  prev_length = writable_memoryset.length
@@ -327,6 +541,7 @@ def test_insert_memories(writable_memoryset: LabeledMemoryset):
327
541
 
328
542
 
329
543
  @skip_in_prod("Production memorysets do not have session consistency guarantees")
544
+ @skip_in_ci("CI environment may not have session consistency guarantees")
330
545
  def test_update_memories(writable_memoryset: LabeledMemoryset, hf_dataset: Dataset):
331
546
  # We've combined the update tests into one to avoid multiple expensive requests for a writable_memoryset
332
547
 
@@ -385,17 +600,6 @@ def test_clone_memoryset(readonly_memoryset: LabeledMemoryset):
385
600
  assert cloned_memoryset.insertion_status == Status.COMPLETED
386
601
 
387
602
 
388
- def test_embedding_evaluation(eval_datasource: Datasource):
389
- results = LabeledMemoryset.run_embedding_evaluation(
390
- eval_datasource, embedding_models=["CDE_SMALL"], neighbor_count=3
391
- )
392
- assert isinstance(results, list)
393
- assert len(results) == 1
394
- assert results[0] is not None
395
- assert results[0]["embedding_model_name"] == "CDE_SMALL"
396
- assert results[0]["embedding_model_path"] == "OrcaDB/cde-small-v1"
397
-
398
-
399
603
  @pytest.fixture(scope="function")
400
604
  async def test_group_potential_duplicates(writable_memoryset: LabeledMemoryset):
401
605
  writable_memoryset.insert(
@@ -492,7 +696,8 @@ def test_scored_memoryset(scored_memoryset: ScoredMemoryset):
492
696
  assert isinstance(scored_memoryset[0], ScoredMemory)
493
697
  assert scored_memoryset[0].value == "i love soup"
494
698
  assert scored_memoryset[0].score is not None
495
- assert scored_memoryset[0].metadata == {"key": "g1", "source_id": "s1", "label": 0}
699
+ assert scored_memoryset[0].metadata == {"key": "g1", "label": 0}
700
+ assert scored_memoryset[0].source_id == "s1"
496
701
  lookup = scored_memoryset.search("i love soup", count=1)
497
702
  assert len(lookup) == 1
498
703
  assert lookup[0].score is not None
@@ -508,3 +713,80 @@ def test_update_scored_memory(scored_memoryset: ScoredMemoryset):
508
713
  assert scored_memoryset[0].label == 3
509
714
  memory.update(label=4)
510
715
  assert scored_memoryset[0].label == 4
716
+
717
+
718
+ @pytest.mark.asyncio
719
+ async def test_insert_memories_async_single(writable_memoryset: LabeledMemoryset):
720
+ """Test async insertion of a single memory"""
721
+ await writable_memoryset.arefresh()
722
+ prev_length = writable_memoryset.length
723
+
724
+ await writable_memoryset.ainsert(dict(value="async tomato soup is my favorite", label=0, key="async_test"))
725
+
726
+ await writable_memoryset.arefresh()
727
+ assert writable_memoryset.length == prev_length + 1
728
+ last_memory = writable_memoryset[-1]
729
+ assert last_memory.value == "async tomato soup is my favorite"
730
+ assert last_memory.label == 0
731
+ assert last_memory.metadata["key"] == "async_test"
732
+
733
+
734
+ @pytest.mark.asyncio
735
+ async def test_insert_memories_async_batch(writable_memoryset: LabeledMemoryset):
736
+ """Test async insertion of multiple memories"""
737
+ await writable_memoryset.arefresh()
738
+ prev_length = writable_memoryset.length
739
+
740
+ await writable_memoryset.ainsert(
741
+ [
742
+ dict(value="async batch soup is delicious", label=0, key="batch_test_1"),
743
+ dict(value="async batch cats are adorable", label=1, key="batch_test_2"),
744
+ ]
745
+ )
746
+
747
+ await writable_memoryset.arefresh()
748
+ assert writable_memoryset.length == prev_length + 2
749
+
750
+ # Check the inserted memories
751
+ last_two_memories = writable_memoryset[-2:]
752
+ values = [memory.value for memory in last_two_memories]
753
+ labels = [memory.label for memory in last_two_memories]
754
+ keys = [memory.metadata.get("key") for memory in last_two_memories]
755
+
756
+ assert "async batch soup is delicious" in values
757
+ assert "async batch cats are adorable" in values
758
+ assert 0 in labels
759
+ assert 1 in labels
760
+ assert "batch_test_1" in keys
761
+ assert "batch_test_2" in keys
762
+
763
+
764
+ @pytest.mark.asyncio
765
+ async def test_insert_memories_async_with_source_id(writable_memoryset: LabeledMemoryset):
766
+ """Test async insertion with source_id and metadata"""
767
+ await writable_memoryset.arefresh()
768
+ prev_length = writable_memoryset.length
769
+
770
+ await writable_memoryset.ainsert(
771
+ dict(
772
+ value="async soup with source id", label=0, source_id="async_source_123", custom_field="async_custom_value"
773
+ )
774
+ )
775
+
776
+ await writable_memoryset.arefresh()
777
+ assert writable_memoryset.length == prev_length + 1
778
+ last_memory = writable_memoryset[-1]
779
+ assert last_memory.value == "async soup with source id"
780
+ assert last_memory.label == 0
781
+ assert last_memory.source_id == "async_source_123"
782
+ assert last_memory.metadata["custom_field"] == "async_custom_value"
783
+
784
+
785
+ @pytest.mark.asyncio
786
+ async def test_insert_memories_async_unauthenticated(
787
+ unauthenticated_async_client, writable_memoryset: LabeledMemoryset
788
+ ):
789
+ """Test async insertion with invalid authentication"""
790
+ with unauthenticated_async_client.use():
791
+ with pytest.raises(ValueError, match="Invalid API key"):
792
+ await writable_memoryset.ainsert(dict(value="this should fail", label=0))