orca-sdk 0.1.10__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +3 -3
- orca_sdk/_utils/analysis_ui.py +4 -1
- orca_sdk/_utils/auth.py +2 -3
- orca_sdk/_utils/common.py +24 -1
- orca_sdk/_utils/prediction_result_ui.py +4 -1
- orca_sdk/_utils/torch_parsing.py +77 -0
- orca_sdk/_utils/torch_parsing_test.py +142 -0
- orca_sdk/_utils/value_parser.py +44 -17
- orca_sdk/_utils/value_parser_test.py +6 -5
- orca_sdk/async_client.py +234 -22
- orca_sdk/classification_model.py +203 -66
- orca_sdk/classification_model_test.py +85 -25
- orca_sdk/client.py +234 -20
- orca_sdk/conftest.py +97 -16
- orca_sdk/credentials_test.py +5 -8
- orca_sdk/datasource.py +44 -21
- orca_sdk/datasource_test.py +8 -2
- orca_sdk/embedding_model.py +15 -33
- orca_sdk/embedding_model_test.py +30 -1
- orca_sdk/memoryset.py +558 -425
- orca_sdk/memoryset_test.py +120 -185
- orca_sdk/regression_model.py +186 -65
- orca_sdk/regression_model_test.py +62 -3
- orca_sdk/telemetry.py +16 -7
- {orca_sdk-0.1.10.dist-info → orca_sdk-0.1.12.dist-info}/METADATA +4 -8
- orca_sdk-0.1.12.dist-info/RECORD +38 -0
- orca_sdk/_shared/__init__.py +0 -10
- orca_sdk/_shared/metrics.py +0 -634
- orca_sdk/_shared/metrics_test.py +0 -570
- orca_sdk/_utils/data_parsing.py +0 -129
- orca_sdk/_utils/data_parsing_test.py +0 -244
- orca_sdk-0.1.10.dist-info/RECORD +0 -41
- {orca_sdk-0.1.10.dist-info → orca_sdk-0.1.12.dist-info}/WHEEL +0 -0
orca_sdk/memoryset_test.py
CHANGED
|
@@ -5,7 +5,6 @@ import pytest
|
|
|
5
5
|
from datasets.arrow_dataset import Dataset
|
|
6
6
|
|
|
7
7
|
from .classification_model import ClassificationModel
|
|
8
|
-
from .conftest import skip_in_ci, skip_in_prod
|
|
9
8
|
from .datasource import Datasource
|
|
10
9
|
from .embedding_model import PretrainedEmbeddingModel
|
|
11
10
|
from .memoryset import (
|
|
@@ -59,7 +58,6 @@ def test_create_empty_labeled_memoryset():
|
|
|
59
58
|
|
|
60
59
|
# inserting should work on an empty memoryset
|
|
61
60
|
memoryset.insert(dict(value="i love soup", label=1, key="k1"))
|
|
62
|
-
memoryset.refresh()
|
|
63
61
|
assert memoryset.length == 1
|
|
64
62
|
m = memoryset[0]
|
|
65
63
|
assert isinstance(m, LabeledMemory)
|
|
@@ -104,7 +102,6 @@ def test_create_empty_scored_memoryset():
|
|
|
104
102
|
|
|
105
103
|
# inserting should work on an empty memoryset
|
|
106
104
|
memoryset.insert(dict(value="i love soup", score=0.25, key="k1", label=0))
|
|
107
|
-
memoryset.refresh()
|
|
108
105
|
assert memoryset.length == 1
|
|
109
106
|
m = memoryset[0]
|
|
110
107
|
assert isinstance(m, ScoredMemory)
|
|
@@ -128,6 +125,33 @@ def test_create_empty_scored_memoryset():
|
|
|
128
125
|
ScoredMemoryset.drop(name, if_not_exists="ignore")
|
|
129
126
|
|
|
130
127
|
|
|
128
|
+
def test_create_empty_partitioned_labeled_memoryset():
|
|
129
|
+
name = f"test_empty_partitioned_labeled_{uuid4()}"
|
|
130
|
+
label_names = ["negative", "positive"]
|
|
131
|
+
try:
|
|
132
|
+
memoryset = LabeledMemoryset.create(
|
|
133
|
+
name, label_names=label_names, partitioned=True, description="empty partitioned labeled test"
|
|
134
|
+
)
|
|
135
|
+
assert memoryset is not None
|
|
136
|
+
assert memoryset.name == name
|
|
137
|
+
assert memoryset.length == 0
|
|
138
|
+
assert memoryset.partitioned is True
|
|
139
|
+
|
|
140
|
+
# inserting with partition_id should work
|
|
141
|
+
memoryset.insert(dict(value="i love soup", label=1, partition_id="p1"))
|
|
142
|
+
memoryset.insert(dict(value="cats are cute", label=0, partition_id="p2"))
|
|
143
|
+
assert memoryset.length == 2
|
|
144
|
+
finally:
|
|
145
|
+
LabeledMemoryset.drop(name, if_not_exists="ignore")
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def test_non_partitioned_memoryset_has_partitioned_false(
|
|
149
|
+
readonly_partitioned_memoryset: LabeledMemoryset, readonly_memoryset: LabeledMemoryset
|
|
150
|
+
):
|
|
151
|
+
assert readonly_partitioned_memoryset.partitioned is True
|
|
152
|
+
assert readonly_memoryset.partitioned is False
|
|
153
|
+
|
|
154
|
+
|
|
131
155
|
def test_create_memoryset_unauthenticated(unauthenticated_client, datasource):
|
|
132
156
|
with unauthenticated_client.use():
|
|
133
157
|
with pytest.raises(ValueError, match="Invalid API key"):
|
|
@@ -563,155 +587,6 @@ def test_query_memoryset_with_feedback_metrics_sort(classification_model: Classi
|
|
|
563
587
|
assert memories[-1].feedback_metrics["positive"]["avg"] == -1.0
|
|
564
588
|
|
|
565
589
|
|
|
566
|
-
def test_query_memoryset_with_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
567
|
-
# Query with partition_id and include_global (default) - includes both p1 and global memories
|
|
568
|
-
memories = readonly_partitioned_memoryset.query(partition_id="p1")
|
|
569
|
-
assert len(memories) == 15 # 8 p1 + 7 global = 15
|
|
570
|
-
# Results should include both p1 and global memories
|
|
571
|
-
partition_ids = {memory.partition_id for memory in memories}
|
|
572
|
-
assert "p1" in partition_ids
|
|
573
|
-
assert None in partition_ids
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
def test_query_memoryset_with_partition_id_and_exclude_global(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
577
|
-
# Query with partition_id and exclude_global mode - only returns p1 memories
|
|
578
|
-
memories = readonly_partitioned_memoryset.query(partition_id="p1", partition_filter_mode="exclude_global")
|
|
579
|
-
assert len(memories) == 8 # Only 8 p1 memories (no global)
|
|
580
|
-
# All results should be from partition p1 (no global memories)
|
|
581
|
-
assert all(memory.partition_id == "p1" for memory in memories)
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
def test_query_memoryset_with_partition_id_and_include_global(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
585
|
-
# Query with partition_id and include_global mode (default) - includes both p1 and global
|
|
586
|
-
memories = readonly_partitioned_memoryset.query(partition_id="p1", partition_filter_mode="include_global")
|
|
587
|
-
assert len(memories) == 15 # 8 p1 + 7 global = 15
|
|
588
|
-
# Results should include both p1 and global memories
|
|
589
|
-
partition_ids = {memory.partition_id for memory in memories}
|
|
590
|
-
assert "p1" in partition_ids
|
|
591
|
-
assert None in partition_ids
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
def test_query_memoryset_with_partition_filter_mode_exclude_global(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
595
|
-
# Query excluding global memories requires a partition_id
|
|
596
|
-
# Test with a specific partition_id
|
|
597
|
-
memories = readonly_partitioned_memoryset.query(partition_id="p1", partition_filter_mode="exclude_global")
|
|
598
|
-
assert len(memories) == 8 # Only p1 memories
|
|
599
|
-
# All results should have a partition_id (not global)
|
|
600
|
-
assert all(memory.partition_id == "p1" for memory in memories)
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
def test_query_memoryset_with_partition_filter_mode_only_global(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
604
|
-
# Query only in global memories
|
|
605
|
-
memories = readonly_partitioned_memoryset.query(partition_filter_mode="only_global")
|
|
606
|
-
assert len(memories) == 7 # There are 7 global memories in SAMPLE_DATA
|
|
607
|
-
# All results should be global (partition_id is None)
|
|
608
|
-
assert all(memory.partition_id is None for memory in memories)
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
def test_query_memoryset_with_partition_filter_mode_include_global(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
612
|
-
# Query including global memories - when no partition_id is specified,
|
|
613
|
-
# include_global seems to only return global memories
|
|
614
|
-
memories = readonly_partitioned_memoryset.query(partition_filter_mode="include_global")
|
|
615
|
-
# Based on actual behavior, this returns only global memories
|
|
616
|
-
assert len(memories) == 7
|
|
617
|
-
# All results should be global
|
|
618
|
-
assert all(memory.partition_id is None for memory in memories)
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
def test_query_memoryset_with_partition_filter_mode_ignore_partitions(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
622
|
-
# Query ignoring partition filtering entirely - returns all memories
|
|
623
|
-
memories = readonly_partitioned_memoryset.query(partition_filter_mode="ignore_partitions", limit=100)
|
|
624
|
-
assert len(memories) == 22 # All 22 memories
|
|
625
|
-
# Results can come from any partition or global
|
|
626
|
-
partition_ids = {memory.partition_id for memory in memories}
|
|
627
|
-
# Should have results from multiple partitions/global
|
|
628
|
-
assert len(partition_ids) >= 1
|
|
629
|
-
# Verify we have p1, p2, and global
|
|
630
|
-
assert "p1" in partition_ids
|
|
631
|
-
assert "p2" in partition_ids
|
|
632
|
-
assert None in partition_ids
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
def test_query_memoryset_with_filters_and_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
636
|
-
# Query with filters and partition_id
|
|
637
|
-
memories = readonly_partitioned_memoryset.query(filters=[("label", "==", 0)], partition_id="p1")
|
|
638
|
-
assert len(memories) > 0
|
|
639
|
-
# All results should match the filter and be from partition p1
|
|
640
|
-
assert all(memory.label == 0 for memory in memories)
|
|
641
|
-
assert all(memory.partition_id == "p1" for memory in memories)
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
def test_query_memoryset_with_filters_and_partition_filter_mode(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
645
|
-
# Query with filters and partition_filter_mode - exclude_global requires partition_id
|
|
646
|
-
memories = readonly_partitioned_memoryset.query(
|
|
647
|
-
filters=[("label", "==", 1)], partition_id="p1", partition_filter_mode="exclude_global"
|
|
648
|
-
)
|
|
649
|
-
assert len(memories) > 0
|
|
650
|
-
# All results should match the filter and be from p1 (not global)
|
|
651
|
-
assert all(memory.label == 1 for memory in memories)
|
|
652
|
-
assert all(memory.partition_id == "p1" for memory in memories)
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
def test_query_memoryset_with_limit_and_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
656
|
-
# Query with limit and partition_id
|
|
657
|
-
memories = readonly_partitioned_memoryset.query(partition_id="p2", limit=3)
|
|
658
|
-
assert len(memories) == 3
|
|
659
|
-
# All results should be from partition p2
|
|
660
|
-
assert all(memory.partition_id == "p2" for memory in memories)
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
def test_query_memoryset_with_offset_and_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
664
|
-
# Query with offset and partition_id - use exclude_global to get only p1 memories
|
|
665
|
-
memories_page1 = readonly_partitioned_memoryset.query(
|
|
666
|
-
partition_id="p1", partition_filter_mode="exclude_global", limit=5
|
|
667
|
-
)
|
|
668
|
-
memories_page2 = readonly_partitioned_memoryset.query(
|
|
669
|
-
partition_id="p1", partition_filter_mode="exclude_global", offset=5, limit=5
|
|
670
|
-
)
|
|
671
|
-
assert len(memories_page1) == 5
|
|
672
|
-
assert len(memories_page2) == 3 # Only 3 remaining p1 memories (8 total - 5 = 3)
|
|
673
|
-
# All results should be from partition p1
|
|
674
|
-
assert all(memory.partition_id == "p1" for memory in memories_page1)
|
|
675
|
-
assert all(memory.partition_id == "p1" for memory in memories_page2)
|
|
676
|
-
# Results should be different (pagination works)
|
|
677
|
-
memory_ids_page1 = {memory.memory_id for memory in memories_page1}
|
|
678
|
-
memory_ids_page2 = {memory.memory_id for memory in memories_page2}
|
|
679
|
-
assert memory_ids_page1.isdisjoint(memory_ids_page2)
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
def test_query_memoryset_with_partition_id_p2(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
683
|
-
# Query a different partition to verify it works
|
|
684
|
-
# With include_global (default), it includes both p2 and global memories
|
|
685
|
-
memories = readonly_partitioned_memoryset.query(partition_id="p2")
|
|
686
|
-
assert len(memories) == 14 # 7 p2 + 7 global = 14
|
|
687
|
-
# Results should include both p2 and global memories
|
|
688
|
-
partition_ids = {memory.partition_id for memory in memories}
|
|
689
|
-
assert "p2" in partition_ids
|
|
690
|
-
assert None in partition_ids
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
def test_query_memoryset_with_metadata_filter_and_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
694
|
-
# Query with metadata filter and partition_id
|
|
695
|
-
memories = readonly_partitioned_memoryset.query(filters=[("metadata.key", "==", "g1")], partition_id="p1")
|
|
696
|
-
assert len(memories) > 0
|
|
697
|
-
# All results should match the metadata filter and be from partition p1
|
|
698
|
-
assert all(memory.metadata.get("key") == "g1" for memory in memories)
|
|
699
|
-
assert all(memory.partition_id == "p1" for memory in memories)
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
def test_query_memoryset_with_partition_filter_mode_only_global_and_filters(
|
|
703
|
-
readonly_partitioned_memoryset: LabeledMemoryset,
|
|
704
|
-
):
|
|
705
|
-
# Query only global memories with filters
|
|
706
|
-
memories = readonly_partitioned_memoryset.query(
|
|
707
|
-
filters=[("metadata.key", "==", "g3")], partition_filter_mode="only_global"
|
|
708
|
-
)
|
|
709
|
-
assert len(memories) > 0
|
|
710
|
-
# All results should match the filter and be global
|
|
711
|
-
assert all(memory.metadata.get("key") == "g3" for memory in memories)
|
|
712
|
-
assert all(memory.partition_id is None for memory in memories)
|
|
713
|
-
|
|
714
|
-
|
|
715
590
|
def test_labeled_memory_predictions_property(classification_model: ClassificationModel):
|
|
716
591
|
"""Test that LabeledMemory.predictions() only returns classification predictions."""
|
|
717
592
|
# Given: A classification model with memories
|
|
@@ -850,7 +725,6 @@ def test_memory_predictions_expected_label_filter(classification_model: Classifi
|
|
|
850
725
|
|
|
851
726
|
|
|
852
727
|
def test_insert_memories(writable_memoryset: LabeledMemoryset):
|
|
853
|
-
writable_memoryset.refresh()
|
|
854
728
|
prev_length = writable_memoryset.length
|
|
855
729
|
writable_memoryset.insert(
|
|
856
730
|
[
|
|
@@ -859,10 +733,8 @@ def test_insert_memories(writable_memoryset: LabeledMemoryset):
|
|
|
859
733
|
],
|
|
860
734
|
batch_size=1,
|
|
861
735
|
)
|
|
862
|
-
writable_memoryset.refresh()
|
|
863
736
|
assert writable_memoryset.length == prev_length + 2
|
|
864
737
|
writable_memoryset.insert(dict(value="tomato soup is my favorite", label=0, key="test", source_id="test"))
|
|
865
|
-
writable_memoryset.refresh()
|
|
866
738
|
assert writable_memoryset.length == prev_length + 3
|
|
867
739
|
last_memory = writable_memoryset[-1]
|
|
868
740
|
assert last_memory.value == "tomato soup is my favorite"
|
|
@@ -872,18 +744,16 @@ def test_insert_memories(writable_memoryset: LabeledMemoryset):
|
|
|
872
744
|
assert last_memory.source_id == "test"
|
|
873
745
|
|
|
874
746
|
|
|
875
|
-
@skip_in_prod("Production memorysets do not have session consistency guarantees")
|
|
876
|
-
@skip_in_ci("CI environment may not have session consistency guarantees")
|
|
877
747
|
def test_update_memories(writable_memoryset: LabeledMemoryset, hf_dataset: Dataset):
|
|
878
748
|
# We've combined the update tests into one to avoid multiple expensive requests for a writable_memoryset
|
|
879
749
|
|
|
880
750
|
# test updating a single memory
|
|
881
751
|
memory_id = writable_memoryset[0].memory_id
|
|
882
|
-
|
|
752
|
+
updated_count = writable_memoryset.update(dict(memory_id=memory_id, value="i love soup so much"))
|
|
753
|
+
assert updated_count == 1
|
|
754
|
+
updated_memory = writable_memoryset.get(memory_id, consistency_level="Strong")
|
|
883
755
|
assert updated_memory.value == "i love soup so much"
|
|
884
756
|
assert updated_memory.label == hf_dataset[0]["label"]
|
|
885
|
-
writable_memoryset.refresh() # Refresh to ensure consistency after update
|
|
886
|
-
assert writable_memoryset.get(memory_id).value == "i love soup so much"
|
|
887
757
|
|
|
888
758
|
# test updating a memory instance
|
|
889
759
|
memory = writable_memoryset[0]
|
|
@@ -894,15 +764,52 @@ def test_update_memories(writable_memoryset: LabeledMemoryset, hf_dataset: Datas
|
|
|
894
764
|
|
|
895
765
|
# test updating multiple memories
|
|
896
766
|
memory_ids = [memory.memory_id for memory in writable_memoryset[:2]]
|
|
897
|
-
|
|
767
|
+
updated_count = writable_memoryset.update(
|
|
898
768
|
[
|
|
899
769
|
dict(memory_id=memory_ids[0], value="i love soup so much"),
|
|
900
770
|
dict(memory_id=memory_ids[1], value="cats are so cute"),
|
|
901
771
|
],
|
|
902
772
|
batch_size=1,
|
|
903
773
|
)
|
|
904
|
-
assert
|
|
905
|
-
assert
|
|
774
|
+
assert updated_count == 2
|
|
775
|
+
assert writable_memoryset.get(memory_ids[0], consistency_level="Strong").value == "i love soup so much"
|
|
776
|
+
assert writable_memoryset.get(memory_ids[1], consistency_level="Strong").value == "cats are so cute"
|
|
777
|
+
|
|
778
|
+
|
|
779
|
+
def test_update_memory_metadata(writable_memoryset: LabeledMemoryset):
|
|
780
|
+
memory = writable_memoryset[0]
|
|
781
|
+
assert memory.metadata["key"] == "g1"
|
|
782
|
+
|
|
783
|
+
# Updating label without metadata should preserve existing metadata
|
|
784
|
+
updated = memory.update(label=1)
|
|
785
|
+
assert updated.label == 1
|
|
786
|
+
assert updated.metadata["key"] == "g1", "Metadata should be preserved when not specified"
|
|
787
|
+
|
|
788
|
+
# Updating metadata via top-level keys should update only specified keys
|
|
789
|
+
updated = memory.update(key="updated", new_key="added")
|
|
790
|
+
assert updated.metadata["key"] == "updated", "Existing metadata key should be preserved"
|
|
791
|
+
assert updated.metadata["new_key"] == "added", "New metadata key should be added"
|
|
792
|
+
|
|
793
|
+
# Can explicitly clear metadata by passing metadata={}
|
|
794
|
+
writable_memoryset.update(dict(memory_id=memory.memory_id, metadata={}))
|
|
795
|
+
updated = writable_memoryset.get(memory.memory_id, consistency_level="Strong")
|
|
796
|
+
assert updated.metadata == {}, "Metadata should be cleared when explicitly set to {}"
|
|
797
|
+
|
|
798
|
+
|
|
799
|
+
def test_update_memories_by_filter(writable_memoryset: LabeledMemoryset):
|
|
800
|
+
source_ids_to_update = ["s1", "s3"]
|
|
801
|
+
initial_length = len(writable_memoryset)
|
|
802
|
+
updated_count = writable_memoryset.update(
|
|
803
|
+
filters=[("source_id", "in", source_ids_to_update)],
|
|
804
|
+
patch={"label": 1},
|
|
805
|
+
)
|
|
806
|
+
assert updated_count == 2
|
|
807
|
+
assert len(writable_memoryset) == initial_length
|
|
808
|
+
updated_memories = writable_memoryset.query(
|
|
809
|
+
filters=[("source_id", "in", source_ids_to_update)], consistency_level="Strong"
|
|
810
|
+
)
|
|
811
|
+
assert len(updated_memories) == 2
|
|
812
|
+
assert all(memory.label == 1 for memory in updated_memories)
|
|
906
813
|
|
|
907
814
|
|
|
908
815
|
def test_delete_memories(writable_memoryset: LabeledMemoryset):
|
|
@@ -911,17 +818,60 @@ def test_delete_memories(writable_memoryset: LabeledMemoryset):
|
|
|
911
818
|
# test deleting a single memory
|
|
912
819
|
prev_length = writable_memoryset.length
|
|
913
820
|
memory_id = writable_memoryset[0].memory_id
|
|
914
|
-
writable_memoryset.delete(memory_id)
|
|
821
|
+
deleted_count = writable_memoryset.delete(memory_id)
|
|
822
|
+
assert deleted_count == 1
|
|
915
823
|
with pytest.raises(LookupError):
|
|
916
824
|
writable_memoryset.get(memory_id)
|
|
917
825
|
assert writable_memoryset.length == prev_length - 1
|
|
918
826
|
|
|
919
827
|
# test deleting multiple memories
|
|
920
828
|
prev_length = writable_memoryset.length
|
|
921
|
-
writable_memoryset.delete(
|
|
829
|
+
deleted_count = writable_memoryset.delete(
|
|
830
|
+
[writable_memoryset[0].memory_id, writable_memoryset[1].memory_id], batch_size=1
|
|
831
|
+
)
|
|
832
|
+
assert deleted_count == 2
|
|
922
833
|
assert writable_memoryset.length == prev_length - 2
|
|
923
834
|
|
|
924
835
|
|
|
836
|
+
def test_delete_memories_by_filter(writable_memoryset: LabeledMemoryset):
|
|
837
|
+
source_ids_to_delete = ["s1", "s3"]
|
|
838
|
+
initial_length = len(writable_memoryset)
|
|
839
|
+
memories_before = writable_memoryset.query(filters=[("source_id", "in", source_ids_to_delete)])
|
|
840
|
+
assert len(memories_before) == 2
|
|
841
|
+
deleted_count = writable_memoryset.delete(filters=[("source_id", "in", source_ids_to_delete)])
|
|
842
|
+
assert deleted_count == 2
|
|
843
|
+
assert len(writable_memoryset) == initial_length - 2
|
|
844
|
+
memories_after = writable_memoryset.query(filters=[("source_id", "in", source_ids_to_delete)])
|
|
845
|
+
assert len(memories_after) == 0
|
|
846
|
+
|
|
847
|
+
|
|
848
|
+
def test_delete_all_memories(writable_memoryset: LabeledMemoryset):
|
|
849
|
+
initial_count = writable_memoryset.length
|
|
850
|
+
deleted_count = writable_memoryset.truncate()
|
|
851
|
+
assert deleted_count == initial_count
|
|
852
|
+
assert writable_memoryset.length == 0
|
|
853
|
+
|
|
854
|
+
|
|
855
|
+
def test_delete_all_memories_from_partition(writable_memoryset: LabeledMemoryset):
|
|
856
|
+
memories_in_partition = len(writable_memoryset.query(filters=[("partition_id", "==", "p1")]))
|
|
857
|
+
assert memories_in_partition > 0
|
|
858
|
+
deleted_count = writable_memoryset.truncate(partition_id="p1")
|
|
859
|
+
assert deleted_count == memories_in_partition
|
|
860
|
+
memories_in_partition_after = len(writable_memoryset.query(filters=[("partition_id", "==", "p1")]))
|
|
861
|
+
assert memories_in_partition_after == 0
|
|
862
|
+
assert writable_memoryset.length > 0
|
|
863
|
+
|
|
864
|
+
|
|
865
|
+
def test_delete_all_memories_from_global_partition(writable_memoryset: LabeledMemoryset):
|
|
866
|
+
memories_in_global_partition = len(writable_memoryset.query(filters=[("partition_id", "==", None)]))
|
|
867
|
+
assert memories_in_global_partition > 0
|
|
868
|
+
deleted_count = writable_memoryset.truncate(partition_id=None)
|
|
869
|
+
assert deleted_count == memories_in_global_partition
|
|
870
|
+
memories_in_global_partition_after = len(writable_memoryset.query(filters=[("partition_id", "==", None)]))
|
|
871
|
+
assert memories_in_global_partition_after == 0
|
|
872
|
+
assert writable_memoryset.length > 0
|
|
873
|
+
|
|
874
|
+
|
|
925
875
|
def test_clone_memoryset(readonly_memoryset: LabeledMemoryset):
|
|
926
876
|
cloned_memoryset = readonly_memoryset.clone(
|
|
927
877
|
"test_cloned_memoryset", embedding_model=PretrainedEmbeddingModel.DISTILBERT
|
|
@@ -983,7 +933,6 @@ async def test_group_potential_duplicates(writable_memoryset: LabeledMemoryset):
|
|
|
983
933
|
|
|
984
934
|
|
|
985
935
|
def test_get_cascading_edits_suggestions(writable_memoryset: LabeledMemoryset):
|
|
986
|
-
# Insert a memory to test cascading edits
|
|
987
936
|
SOUP = 0
|
|
988
937
|
CATS = 1
|
|
989
938
|
query_text = "i love soup" # from SAMPLE_DATA in conftest.py
|
|
@@ -993,11 +942,7 @@ def test_get_cascading_edits_suggestions(writable_memoryset: LabeledMemoryset):
|
|
|
993
942
|
dict(value=mislabeled_soup_text, label=CATS), # mislabeled soup memory
|
|
994
943
|
]
|
|
995
944
|
)
|
|
996
|
-
|
|
997
|
-
# Fetch the memory to update
|
|
998
945
|
memory = writable_memoryset.query(filters=[("value", "==", query_text)])[0]
|
|
999
|
-
|
|
1000
|
-
# Update the label and get cascading edit suggestions
|
|
1001
946
|
suggestions = writable_memoryset.get_cascading_edits_suggestions(
|
|
1002
947
|
memory=memory,
|
|
1003
948
|
old_label=CATS,
|
|
@@ -1005,8 +950,6 @@ def test_get_cascading_edits_suggestions(writable_memoryset: LabeledMemoryset):
|
|
|
1005
950
|
max_neighbors=10,
|
|
1006
951
|
max_validation_neighbors=5,
|
|
1007
952
|
)
|
|
1008
|
-
|
|
1009
|
-
# Validate the suggestions
|
|
1010
953
|
assert len(suggestions) == 1
|
|
1011
954
|
assert suggestions[0]["neighbor"]["value"] == mislabeled_soup_text
|
|
1012
955
|
|
|
@@ -1062,26 +1005,24 @@ def test_scored_memoryset(scored_memoryset: ScoredMemoryset):
|
|
|
1062
1005
|
assert lookup[0].score < 0.11
|
|
1063
1006
|
|
|
1064
1007
|
|
|
1065
|
-
@skip_in_prod("Production memorysets do not have session consistency guarantees")
|
|
1066
1008
|
def test_update_scored_memory(scored_memoryset: ScoredMemoryset):
|
|
1067
1009
|
# we are only updating an inconsequential metadata field so that we don't affect other tests
|
|
1068
1010
|
memory = scored_memoryset[0]
|
|
1069
1011
|
assert memory.label == 0
|
|
1070
1012
|
scored_memoryset.update(dict(memory_id=memory.memory_id, label=3))
|
|
1071
|
-
|
|
1072
|
-
memory.
|
|
1073
|
-
|
|
1013
|
+
memory = scored_memoryset.get(memory.memory_id, consistency_level="Strong")
|
|
1014
|
+
assert memory.label == 3
|
|
1015
|
+
memory = memory.update(label=4)
|
|
1016
|
+
memory = scored_memoryset.get(memory.memory_id, consistency_level="Strong")
|
|
1017
|
+
assert memory.label == 4
|
|
1074
1018
|
|
|
1075
1019
|
|
|
1076
1020
|
@pytest.mark.asyncio
|
|
1077
1021
|
async def test_insert_memories_async_single(writable_memoryset: LabeledMemoryset):
|
|
1078
|
-
"""Test async insertion of a single memory"""
|
|
1079
|
-
await writable_memoryset.arefresh()
|
|
1080
1022
|
prev_length = writable_memoryset.length
|
|
1081
1023
|
|
|
1082
1024
|
await writable_memoryset.ainsert(dict(value="async tomato soup is my favorite", label=0, key="async_test"))
|
|
1083
1025
|
|
|
1084
|
-
await writable_memoryset.arefresh()
|
|
1085
1026
|
assert writable_memoryset.length == prev_length + 1
|
|
1086
1027
|
last_memory = writable_memoryset[-1]
|
|
1087
1028
|
assert last_memory.value == "async tomato soup is my favorite"
|
|
@@ -1091,8 +1032,6 @@ async def test_insert_memories_async_single(writable_memoryset: LabeledMemoryset
|
|
|
1091
1032
|
|
|
1092
1033
|
@pytest.mark.asyncio
|
|
1093
1034
|
async def test_insert_memories_async_batch(writable_memoryset: LabeledMemoryset):
|
|
1094
|
-
"""Test async insertion of multiple memories"""
|
|
1095
|
-
await writable_memoryset.arefresh()
|
|
1096
1035
|
prev_length = writable_memoryset.length
|
|
1097
1036
|
|
|
1098
1037
|
await writable_memoryset.ainsert(
|
|
@@ -1102,7 +1041,6 @@ async def test_insert_memories_async_batch(writable_memoryset: LabeledMemoryset)
|
|
|
1102
1041
|
]
|
|
1103
1042
|
)
|
|
1104
1043
|
|
|
1105
|
-
await writable_memoryset.arefresh()
|
|
1106
1044
|
assert writable_memoryset.length == prev_length + 2
|
|
1107
1045
|
|
|
1108
1046
|
# Check the inserted memories
|
|
@@ -1121,8 +1059,6 @@ async def test_insert_memories_async_batch(writable_memoryset: LabeledMemoryset)
|
|
|
1121
1059
|
|
|
1122
1060
|
@pytest.mark.asyncio
|
|
1123
1061
|
async def test_insert_memories_async_with_source_id(writable_memoryset: LabeledMemoryset):
|
|
1124
|
-
"""Test async insertion with source_id and metadata"""
|
|
1125
|
-
await writable_memoryset.arefresh()
|
|
1126
1062
|
prev_length = writable_memoryset.length
|
|
1127
1063
|
|
|
1128
1064
|
await writable_memoryset.ainsert(
|
|
@@ -1131,7 +1067,6 @@ async def test_insert_memories_async_with_source_id(writable_memoryset: LabeledM
|
|
|
1131
1067
|
)
|
|
1132
1068
|
)
|
|
1133
1069
|
|
|
1134
|
-
await writable_memoryset.arefresh()
|
|
1135
1070
|
assert writable_memoryset.length == prev_length + 1
|
|
1136
1071
|
last_memory = writable_memoryset[-1]
|
|
1137
1072
|
assert last_memory.value == "async soup with source id"
|