orca-sdk 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/_shared/metrics.py +179 -40
- orca_sdk/_shared/metrics_test.py +99 -6
- orca_sdk/_utils/data_parsing_test.py +1 -1
- orca_sdk/async_client.py +462 -301
- orca_sdk/classification_model.py +156 -41
- orca_sdk/classification_model_test.py +327 -8
- orca_sdk/client.py +462 -301
- orca_sdk/conftest.py +140 -21
- orca_sdk/datasource.py +45 -2
- orca_sdk/datasource_test.py +120 -0
- orca_sdk/embedding_model.py +32 -24
- orca_sdk/job.py +17 -17
- orca_sdk/memoryset.py +459 -56
- orca_sdk/memoryset_test.py +435 -2
- orca_sdk/regression_model.py +110 -19
- orca_sdk/regression_model_test.py +213 -0
- orca_sdk/telemetry.py +52 -13
- {orca_sdk-0.1.3.dist-info → orca_sdk-0.1.5.dist-info}/METADATA +1 -1
- {orca_sdk-0.1.3.dist-info → orca_sdk-0.1.5.dist-info}/RECORD +20 -20
- {orca_sdk-0.1.3.dist-info → orca_sdk-0.1.5.dist-info}/WHEEL +0 -0
orca_sdk/memoryset_test.py
CHANGED
|
@@ -8,7 +8,14 @@ from .classification_model import ClassificationModel
|
|
|
8
8
|
from .conftest import skip_in_ci, skip_in_prod
|
|
9
9
|
from .datasource import Datasource
|
|
10
10
|
from .embedding_model import PretrainedEmbeddingModel
|
|
11
|
-
from .memoryset import
|
|
11
|
+
from .memoryset import (
|
|
12
|
+
LabeledMemory,
|
|
13
|
+
LabeledMemoryset,
|
|
14
|
+
ScoredMemory,
|
|
15
|
+
ScoredMemoryset,
|
|
16
|
+
Status,
|
|
17
|
+
)
|
|
18
|
+
from .regression_model import RegressionModel
|
|
12
19
|
|
|
13
20
|
"""
|
|
14
21
|
Test Performance Note:
|
|
@@ -112,6 +119,51 @@ def test_if_exists_open_reuses_existing_datasource(
|
|
|
112
119
|
assert not Datasource.exists(datasource_name)
|
|
113
120
|
|
|
114
121
|
|
|
122
|
+
def test_create_memoryset_string_label():
|
|
123
|
+
assert not LabeledMemoryset.exists("test_string_label")
|
|
124
|
+
memoryset = LabeledMemoryset.from_hf_dataset(
|
|
125
|
+
"test_string_label",
|
|
126
|
+
Dataset.from_dict({"value": ["terrible", "great"], "label": ["negative", "positive"]}),
|
|
127
|
+
)
|
|
128
|
+
assert memoryset is not None
|
|
129
|
+
assert memoryset.length == 2
|
|
130
|
+
assert memoryset.label_names == ["negative", "positive"]
|
|
131
|
+
assert memoryset[0].label == 0
|
|
132
|
+
assert memoryset[1].label == 1
|
|
133
|
+
assert memoryset[0].label_name == "negative"
|
|
134
|
+
assert memoryset[1].label_name == "positive"
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def test_create_memoryset_integer_label():
|
|
138
|
+
assert not LabeledMemoryset.exists("test_integer_label")
|
|
139
|
+
memoryset = LabeledMemoryset.from_hf_dataset(
|
|
140
|
+
"test_integer_label",
|
|
141
|
+
Dataset.from_dict({"value": ["terrible", "great"], "label": [0, 1]}),
|
|
142
|
+
label_names=["negative", "positive"],
|
|
143
|
+
)
|
|
144
|
+
assert memoryset is not None
|
|
145
|
+
assert memoryset.length == 2
|
|
146
|
+
assert memoryset.label_names == ["negative", "positive"]
|
|
147
|
+
assert memoryset[0].label == 0
|
|
148
|
+
assert memoryset[1].label == 1
|
|
149
|
+
assert memoryset[0].label_name == "negative"
|
|
150
|
+
assert memoryset[1].label_name == "positive"
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def test_create_memoryset_null_labels():
|
|
154
|
+
memoryset = LabeledMemoryset.from_hf_dataset(
|
|
155
|
+
"test_null_labels",
|
|
156
|
+
Dataset.from_dict({"value": ["terrible", "great"]}),
|
|
157
|
+
label_names=["negative", "positive"],
|
|
158
|
+
label_column=None,
|
|
159
|
+
)
|
|
160
|
+
assert memoryset is not None
|
|
161
|
+
assert memoryset.length == 2
|
|
162
|
+
assert memoryset.label_names == ["negative", "positive"]
|
|
163
|
+
assert memoryset[0].label is None
|
|
164
|
+
assert memoryset[1].label is None
|
|
165
|
+
|
|
166
|
+
|
|
115
167
|
def test_open_memoryset(readonly_memoryset, hf_dataset):
|
|
116
168
|
fetched_memoryset = LabeledMemoryset.open(readonly_memoryset.name)
|
|
117
169
|
assert fetched_memoryset is not None
|
|
@@ -239,6 +291,100 @@ def test_search_count(readonly_memoryset: LabeledMemoryset):
|
|
|
239
291
|
assert memory_lookups[2].label == 0
|
|
240
292
|
|
|
241
293
|
|
|
294
|
+
def test_search_with_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
295
|
+
# Search within a specific partition - use "soup" which appears in both p1 and p2
|
|
296
|
+
# Use exclude_global to ensure we only get results from the specified partition
|
|
297
|
+
memory_lookups = readonly_partitioned_memoryset.search(
|
|
298
|
+
"soup", partition_id="p1", partition_filter_mode="exclude_global", count=5
|
|
299
|
+
)
|
|
300
|
+
assert len(memory_lookups) > 0
|
|
301
|
+
# All results should be from partition p1 when partition_id is specified
|
|
302
|
+
assert all(
|
|
303
|
+
memory.partition_id == "p1" for memory in memory_lookups
|
|
304
|
+
), f"Expected all results from partition p1, but got: {[m.partition_id for m in memory_lookups]}"
|
|
305
|
+
|
|
306
|
+
# Search in a different partition - use "cats" which appears in both p1 and p2
|
|
307
|
+
memory_lookups_p2 = readonly_partitioned_memoryset.search(
|
|
308
|
+
"cats", partition_id="p2", partition_filter_mode="exclude_global", count=5
|
|
309
|
+
)
|
|
310
|
+
assert len(memory_lookups_p2) > 0
|
|
311
|
+
# All results should be from partition p2 when partition_id is specified
|
|
312
|
+
assert all(
|
|
313
|
+
memory.partition_id == "p2" for memory in memory_lookups_p2
|
|
314
|
+
), f"Expected all results from partition p2, but got: {[m.partition_id for m in memory_lookups_p2]}"
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def test_search_with_partition_filter_mode_exclude_global(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
318
|
+
# Search excluding global memories - need to specify a partition_id when using exclude_global
|
|
319
|
+
# This tests that exclude_global works with a specific partition
|
|
320
|
+
memory_lookups = readonly_partitioned_memoryset.search(
|
|
321
|
+
"soup", partition_id="p1", partition_filter_mode="exclude_global", count=5
|
|
322
|
+
)
|
|
323
|
+
assert len(memory_lookups) > 0
|
|
324
|
+
# All results should have a partition_id (not None) and be from p1
|
|
325
|
+
assert all(memory.partition_id == "p1" for memory in memory_lookups)
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def test_search_with_partition_filter_mode_only_global(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
329
|
+
# Search only in global memories (partition_id=None in the data)
|
|
330
|
+
# Use a query that matches global memories and a reasonable count
|
|
331
|
+
memory_lookups = readonly_partitioned_memoryset.search("beach", partition_filter_mode="only_global", count=3)
|
|
332
|
+
# Should get at least some results (may be fewer than requested if not enough global memories match)
|
|
333
|
+
assert len(memory_lookups) > 0
|
|
334
|
+
# All results should be global (partition_id is None)
|
|
335
|
+
partition_ids = {memory.partition_id for memory in memory_lookups}
|
|
336
|
+
# When using only_global, every result should be global (partition_id is None)
|
|
337
|
+
assert all(
|
|
338
|
+
memory.partition_id is None for memory in memory_lookups
|
|
339
|
+
), f"Expected all results to be global (partition_id=None), but got partition_ids: {partition_ids}"
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
def test_search_with_partition_filter_mode_include_global(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
343
|
+
# Search including global memories (default behavior)
|
|
344
|
+
# Use a reasonable count that won't exceed available memories
|
|
345
|
+
memory_lookups = readonly_partitioned_memoryset.search(
|
|
346
|
+
"i love soup", partition_filter_mode="include_global", count=5
|
|
347
|
+
)
|
|
348
|
+
assert len(memory_lookups) > 0
|
|
349
|
+
# Results can include both partitioned and global memories
|
|
350
|
+
partition_ids = {memory.partition_id for memory in memory_lookups}
|
|
351
|
+
# Should have at least one partition or global memory
|
|
352
|
+
assert len(partition_ids) > 0
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
def test_search_with_partition_filter_mode_ignore_partitions(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
356
|
+
# Search ignoring partition filtering entirely
|
|
357
|
+
memory_lookups = readonly_partitioned_memoryset.search(
|
|
358
|
+
"i love soup", partition_filter_mode="ignore_partitions", count=10
|
|
359
|
+
)
|
|
360
|
+
assert len(memory_lookups) > 0
|
|
361
|
+
# Results can come from any partition or global
|
|
362
|
+
partition_ids = {memory.partition_id for memory in memory_lookups}
|
|
363
|
+
# Should have results from at least one partition (or global)
|
|
364
|
+
assert len(partition_ids) >= 1
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
def test_search_multiple_queries_with_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
368
|
+
# Search multiple queries within a specific partition
|
|
369
|
+
memory_lookups = readonly_partitioned_memoryset.search(["i love soup", "cats are cute"], partition_id="p1", count=3)
|
|
370
|
+
assert len(memory_lookups) == 2
|
|
371
|
+
assert len(memory_lookups[0]) > 0
|
|
372
|
+
assert len(memory_lookups[1]) > 0
|
|
373
|
+
# All results should be from partition p1
|
|
374
|
+
assert all(memory.partition_id == "p1" for memory in memory_lookups[0])
|
|
375
|
+
assert all(memory.partition_id == "p1" for memory in memory_lookups[1])
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
def test_search_with_partition_id_and_filter_mode(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
379
|
+
# When partition_id is specified, partition_filter_mode should still work
|
|
380
|
+
# Search in p1 with exclude_global (should only return p1 results)
|
|
381
|
+
memory_lookups = readonly_partitioned_memoryset.search(
|
|
382
|
+
"i love soup", partition_id="p1", partition_filter_mode="exclude_global", count=5
|
|
383
|
+
)
|
|
384
|
+
assert len(memory_lookups) > 0
|
|
385
|
+
assert all(memory.partition_id == "p1" for memory in memory_lookups)
|
|
386
|
+
|
|
387
|
+
|
|
242
388
|
def test_get_memory_at_index(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset, label_names: list[str]):
|
|
243
389
|
memory = readonly_memoryset[0]
|
|
244
390
|
assert memory.value == hf_dataset[0]["value"]
|
|
@@ -335,6 +481,292 @@ def test_query_memoryset_with_feedback_metrics_sort(classification_model: Classi
|
|
|
335
481
|
assert memories[-1].feedback_metrics["positive"]["avg"] == -1.0
|
|
336
482
|
|
|
337
483
|
|
|
484
|
+
def test_query_memoryset_with_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
485
|
+
# Query with partition_id and include_global (default) - includes both p1 and global memories
|
|
486
|
+
memories = readonly_partitioned_memoryset.query(partition_id="p1")
|
|
487
|
+
assert len(memories) == 15 # 8 p1 + 7 global = 15
|
|
488
|
+
# Results should include both p1 and global memories
|
|
489
|
+
partition_ids = {memory.partition_id for memory in memories}
|
|
490
|
+
assert "p1" in partition_ids
|
|
491
|
+
assert None in partition_ids
|
|
492
|
+
|
|
493
|
+
|
|
494
|
+
def test_query_memoryset_with_partition_id_and_exclude_global(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
495
|
+
# Query with partition_id and exclude_global mode - only returns p1 memories
|
|
496
|
+
memories = readonly_partitioned_memoryset.query(partition_id="p1", partition_filter_mode="exclude_global")
|
|
497
|
+
assert len(memories) == 8 # Only 8 p1 memories (no global)
|
|
498
|
+
# All results should be from partition p1 (no global memories)
|
|
499
|
+
assert all(memory.partition_id == "p1" for memory in memories)
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
def test_query_memoryset_with_partition_id_and_include_global(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
503
|
+
# Query with partition_id and include_global mode (default) - includes both p1 and global
|
|
504
|
+
memories = readonly_partitioned_memoryset.query(partition_id="p1", partition_filter_mode="include_global")
|
|
505
|
+
assert len(memories) == 15 # 8 p1 + 7 global = 15
|
|
506
|
+
# Results should include both p1 and global memories
|
|
507
|
+
partition_ids = {memory.partition_id for memory in memories}
|
|
508
|
+
assert "p1" in partition_ids
|
|
509
|
+
assert None in partition_ids
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
def test_query_memoryset_with_partition_filter_mode_exclude_global(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
513
|
+
# Query excluding global memories requires a partition_id
|
|
514
|
+
# Test with a specific partition_id
|
|
515
|
+
memories = readonly_partitioned_memoryset.query(partition_id="p1", partition_filter_mode="exclude_global")
|
|
516
|
+
assert len(memories) == 8 # Only p1 memories
|
|
517
|
+
# All results should have a partition_id (not global)
|
|
518
|
+
assert all(memory.partition_id == "p1" for memory in memories)
|
|
519
|
+
|
|
520
|
+
|
|
521
|
+
def test_query_memoryset_with_partition_filter_mode_only_global(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
522
|
+
# Query only in global memories
|
|
523
|
+
memories = readonly_partitioned_memoryset.query(partition_filter_mode="only_global")
|
|
524
|
+
assert len(memories) == 7 # There are 7 global memories in SAMPLE_DATA
|
|
525
|
+
# All results should be global (partition_id is None)
|
|
526
|
+
assert all(memory.partition_id is None for memory in memories)
|
|
527
|
+
|
|
528
|
+
|
|
529
|
+
def test_query_memoryset_with_partition_filter_mode_include_global(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
530
|
+
# Query including global memories - when no partition_id is specified,
|
|
531
|
+
# include_global seems to only return global memories
|
|
532
|
+
memories = readonly_partitioned_memoryset.query(partition_filter_mode="include_global")
|
|
533
|
+
# Based on actual behavior, this returns only global memories
|
|
534
|
+
assert len(memories) == 7
|
|
535
|
+
# All results should be global
|
|
536
|
+
assert all(memory.partition_id is None for memory in memories)
|
|
537
|
+
|
|
538
|
+
|
|
539
|
+
def test_query_memoryset_with_partition_filter_mode_ignore_partitions(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
540
|
+
# Query ignoring partition filtering entirely - returns all memories
|
|
541
|
+
memories = readonly_partitioned_memoryset.query(partition_filter_mode="ignore_partitions", limit=100)
|
|
542
|
+
assert len(memories) == 22 # All 22 memories
|
|
543
|
+
# Results can come from any partition or global
|
|
544
|
+
partition_ids = {memory.partition_id for memory in memories}
|
|
545
|
+
# Should have results from multiple partitions/global
|
|
546
|
+
assert len(partition_ids) >= 1
|
|
547
|
+
# Verify we have p1, p2, and global
|
|
548
|
+
assert "p1" in partition_ids
|
|
549
|
+
assert "p2" in partition_ids
|
|
550
|
+
assert None in partition_ids
|
|
551
|
+
|
|
552
|
+
|
|
553
|
+
def test_query_memoryset_with_filters_and_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
554
|
+
# Query with filters and partition_id
|
|
555
|
+
memories = readonly_partitioned_memoryset.query(filters=[("label", "==", 0)], partition_id="p1")
|
|
556
|
+
assert len(memories) > 0
|
|
557
|
+
# All results should match the filter and be from partition p1
|
|
558
|
+
assert all(memory.label == 0 for memory in memories)
|
|
559
|
+
assert all(memory.partition_id == "p1" for memory in memories)
|
|
560
|
+
|
|
561
|
+
|
|
562
|
+
def test_query_memoryset_with_filters_and_partition_filter_mode(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
563
|
+
# Query with filters and partition_filter_mode - exclude_global requires partition_id
|
|
564
|
+
memories = readonly_partitioned_memoryset.query(
|
|
565
|
+
filters=[("label", "==", 1)], partition_id="p1", partition_filter_mode="exclude_global"
|
|
566
|
+
)
|
|
567
|
+
assert len(memories) > 0
|
|
568
|
+
# All results should match the filter and be from p1 (not global)
|
|
569
|
+
assert all(memory.label == 1 for memory in memories)
|
|
570
|
+
assert all(memory.partition_id == "p1" for memory in memories)
|
|
571
|
+
|
|
572
|
+
|
|
573
|
+
def test_query_memoryset_with_limit_and_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
574
|
+
# Query with limit and partition_id
|
|
575
|
+
memories = readonly_partitioned_memoryset.query(partition_id="p2", limit=3)
|
|
576
|
+
assert len(memories) == 3
|
|
577
|
+
# All results should be from partition p2
|
|
578
|
+
assert all(memory.partition_id == "p2" for memory in memories)
|
|
579
|
+
|
|
580
|
+
|
|
581
|
+
def test_query_memoryset_with_offset_and_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
582
|
+
# Query with offset and partition_id - use exclude_global to get only p1 memories
|
|
583
|
+
memories_page1 = readonly_partitioned_memoryset.query(
|
|
584
|
+
partition_id="p1", partition_filter_mode="exclude_global", limit=5
|
|
585
|
+
)
|
|
586
|
+
memories_page2 = readonly_partitioned_memoryset.query(
|
|
587
|
+
partition_id="p1", partition_filter_mode="exclude_global", offset=5, limit=5
|
|
588
|
+
)
|
|
589
|
+
assert len(memories_page1) == 5
|
|
590
|
+
assert len(memories_page2) == 3 # Only 3 remaining p1 memories (8 total - 5 = 3)
|
|
591
|
+
# All results should be from partition p1
|
|
592
|
+
assert all(memory.partition_id == "p1" for memory in memories_page1)
|
|
593
|
+
assert all(memory.partition_id == "p1" for memory in memories_page2)
|
|
594
|
+
# Results should be different (pagination works)
|
|
595
|
+
memory_ids_page1 = {memory.memory_id for memory in memories_page1}
|
|
596
|
+
memory_ids_page2 = {memory.memory_id for memory in memories_page2}
|
|
597
|
+
assert memory_ids_page1.isdisjoint(memory_ids_page2)
|
|
598
|
+
|
|
599
|
+
|
|
600
|
+
def test_query_memoryset_with_partition_id_p2(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
601
|
+
# Query a different partition to verify it works
|
|
602
|
+
# With include_global (default), it includes both p2 and global memories
|
|
603
|
+
memories = readonly_partitioned_memoryset.query(partition_id="p2")
|
|
604
|
+
assert len(memories) == 14 # 7 p2 + 7 global = 14
|
|
605
|
+
# Results should include both p2 and global memories
|
|
606
|
+
partition_ids = {memory.partition_id for memory in memories}
|
|
607
|
+
assert "p2" in partition_ids
|
|
608
|
+
assert None in partition_ids
|
|
609
|
+
|
|
610
|
+
|
|
611
|
+
def test_query_memoryset_with_metadata_filter_and_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
612
|
+
# Query with metadata filter and partition_id
|
|
613
|
+
memories = readonly_partitioned_memoryset.query(filters=[("metadata.key", "==", "g1")], partition_id="p1")
|
|
614
|
+
assert len(memories) > 0
|
|
615
|
+
# All results should match the metadata filter and be from partition p1
|
|
616
|
+
assert all(memory.metadata.get("key") == "g1" for memory in memories)
|
|
617
|
+
assert all(memory.partition_id == "p1" for memory in memories)
|
|
618
|
+
|
|
619
|
+
|
|
620
|
+
def test_query_memoryset_with_partition_filter_mode_only_global_and_filters(
|
|
621
|
+
readonly_partitioned_memoryset: LabeledMemoryset,
|
|
622
|
+
):
|
|
623
|
+
# Query only global memories with filters
|
|
624
|
+
memories = readonly_partitioned_memoryset.query(
|
|
625
|
+
filters=[("metadata.key", "==", "g3")], partition_filter_mode="only_global"
|
|
626
|
+
)
|
|
627
|
+
assert len(memories) > 0
|
|
628
|
+
# All results should match the filter and be global
|
|
629
|
+
assert all(memory.metadata.get("key") == "g3" for memory in memories)
|
|
630
|
+
assert all(memory.partition_id is None for memory in memories)
|
|
631
|
+
|
|
632
|
+
|
|
633
|
+
def test_labeled_memory_predictions_property(classification_model: ClassificationModel):
|
|
634
|
+
"""Test that LabeledMemory.predictions() only returns classification predictions."""
|
|
635
|
+
# Given: A classification model with memories
|
|
636
|
+
memories = classification_model.memoryset.query(limit=1)
|
|
637
|
+
assert len(memories) > 0
|
|
638
|
+
memory = memories[0]
|
|
639
|
+
|
|
640
|
+
# When: I call the predictions method
|
|
641
|
+
predictions = memory.predictions()
|
|
642
|
+
|
|
643
|
+
# Then: It should return a list of ClassificationPrediction objects
|
|
644
|
+
assert isinstance(predictions, list)
|
|
645
|
+
for prediction in predictions:
|
|
646
|
+
assert prediction.__class__.__name__ == "ClassificationPrediction"
|
|
647
|
+
assert hasattr(prediction, "label")
|
|
648
|
+
assert not hasattr(prediction, "score") or prediction.score is None
|
|
649
|
+
|
|
650
|
+
|
|
651
|
+
def test_scored_memory_predictions_property(regression_model: RegressionModel):
|
|
652
|
+
"""Test that ScoredMemory.predictions() only returns regression predictions."""
|
|
653
|
+
# Given: A regression model with memories
|
|
654
|
+
memories = regression_model.memoryset.query(limit=1)
|
|
655
|
+
assert len(memories) > 0
|
|
656
|
+
memory = memories[0]
|
|
657
|
+
|
|
658
|
+
# When: I call the predictions method
|
|
659
|
+
predictions = memory.predictions()
|
|
660
|
+
|
|
661
|
+
# Then: It should return a list of RegressionPrediction objects
|
|
662
|
+
assert isinstance(predictions, list)
|
|
663
|
+
for prediction in predictions:
|
|
664
|
+
assert prediction.__class__.__name__ == "RegressionPrediction"
|
|
665
|
+
assert hasattr(prediction, "score")
|
|
666
|
+
assert not hasattr(prediction, "label") or prediction.label is None
|
|
667
|
+
|
|
668
|
+
|
|
669
|
+
def test_memory_feedback_property(classification_model: ClassificationModel):
|
|
670
|
+
"""Test that memory.feedback() returns feedback from relevant predictions."""
|
|
671
|
+
# Given: A prediction with recorded feedback
|
|
672
|
+
prediction = classification_model.predict("Test feedback")
|
|
673
|
+
feedback_category = f"test_feedback_{random.randint(0, 1000000)}"
|
|
674
|
+
prediction.record_feedback(category=feedback_category, value=True)
|
|
675
|
+
|
|
676
|
+
# And: A memory that was used in the prediction
|
|
677
|
+
memory_lookups = prediction.memory_lookups
|
|
678
|
+
assert len(memory_lookups) > 0
|
|
679
|
+
memory = memory_lookups[0]
|
|
680
|
+
|
|
681
|
+
# When: I call the feedback method
|
|
682
|
+
feedback = memory.feedback()
|
|
683
|
+
|
|
684
|
+
# Then: It should return feedback aggregated by category as a dict
|
|
685
|
+
assert isinstance(feedback, dict)
|
|
686
|
+
assert feedback_category in feedback
|
|
687
|
+
# Feedback values are lists (you may want to look at mean on the raw data)
|
|
688
|
+
assert isinstance(feedback[feedback_category], list)
|
|
689
|
+
assert len(feedback[feedback_category]) > 0
|
|
690
|
+
# For binary feedback, values should be booleans
|
|
691
|
+
assert isinstance(feedback[feedback_category][0], bool)
|
|
692
|
+
|
|
693
|
+
|
|
694
|
+
def test_memory_predictions_method_parameters(classification_model: ClassificationModel):
|
|
695
|
+
"""Test that memory.predictions() method supports pagination, sorting, and filtering."""
|
|
696
|
+
# Given: A classification model with memories
|
|
697
|
+
memories = classification_model.memoryset.query(limit=1)
|
|
698
|
+
assert len(memories) > 0
|
|
699
|
+
memory = memories[0]
|
|
700
|
+
|
|
701
|
+
# When: I call predictions with limit parameter
|
|
702
|
+
predictions_limited = memory.predictions(limit=2)
|
|
703
|
+
|
|
704
|
+
# Then: It should respect the limit
|
|
705
|
+
assert isinstance(predictions_limited, list)
|
|
706
|
+
assert len(predictions_limited) <= 2
|
|
707
|
+
|
|
708
|
+
# When: I call predictions with offset parameter
|
|
709
|
+
all_predictions = memory.predictions(limit=100)
|
|
710
|
+
if len(all_predictions) > 1:
|
|
711
|
+
predictions_offset = memory.predictions(limit=1, offset=1)
|
|
712
|
+
# Then: offset should skip the first prediction
|
|
713
|
+
assert predictions_offset[0].prediction_id != all_predictions[0].prediction_id
|
|
714
|
+
|
|
715
|
+
# When: I call predictions with sort parameter
|
|
716
|
+
predictions_sorted = memory.predictions(limit=10, sort=[("timestamp", "desc")])
|
|
717
|
+
# Then: It should return predictions (sorting verified by API)
|
|
718
|
+
assert isinstance(predictions_sorted, list)
|
|
719
|
+
|
|
720
|
+
# When: I call predictions with expected_label_match parameter
|
|
721
|
+
correct_predictions = memory.predictions(expected_label_match=True)
|
|
722
|
+
incorrect_predictions = memory.predictions(expected_label_match=False)
|
|
723
|
+
# Then: Both should return lists (correctness verified by API filtering)
|
|
724
|
+
assert isinstance(correct_predictions, list)
|
|
725
|
+
assert isinstance(incorrect_predictions, list)
|
|
726
|
+
|
|
727
|
+
|
|
728
|
+
def test_memory_predictions_expected_label_filter(classification_model: ClassificationModel):
|
|
729
|
+
"""Test that memory.predictions(expected_label_match=...) filters predictions by correctness."""
|
|
730
|
+
# Given: Make an initial prediction to learn the model's label for a known input
|
|
731
|
+
baseline_prediction = classification_model.predict("Filter test input", save_telemetry="sync")
|
|
732
|
+
original_label = baseline_prediction.label
|
|
733
|
+
alternate_label = 0 if original_label else 1
|
|
734
|
+
|
|
735
|
+
# When: Make a second prediction with an intentionally incorrect expected label
|
|
736
|
+
mismatched_prediction = classification_model.predict(
|
|
737
|
+
"Filter test input",
|
|
738
|
+
expected_labels=alternate_label,
|
|
739
|
+
save_telemetry="sync",
|
|
740
|
+
)
|
|
741
|
+
mismatched_memory = mismatched_prediction.memory_lookups[0]
|
|
742
|
+
|
|
743
|
+
# Then: The prediction should show up when filtering for incorrect predictions
|
|
744
|
+
incorrect_predictions = mismatched_memory.predictions(expected_label_match=False)
|
|
745
|
+
assert any(pred.prediction_id == mismatched_prediction.prediction_id for pred in incorrect_predictions)
|
|
746
|
+
|
|
747
|
+
# Produce a correct prediction (predicted label matches expected label)
|
|
748
|
+
correct_prediction = classification_model.predict(
|
|
749
|
+
"Filter test input",
|
|
750
|
+
expected_labels=original_label,
|
|
751
|
+
save_telemetry="sync",
|
|
752
|
+
)
|
|
753
|
+
|
|
754
|
+
# Ensure we are inspecting a memory used by both correct and incorrect predictions
|
|
755
|
+
correct_lookup_ids = {lookup.memory_id for lookup in correct_prediction.memory_lookups}
|
|
756
|
+
if mismatched_memory.memory_id not in correct_lookup_ids:
|
|
757
|
+
shared_lookup = next(
|
|
758
|
+
(lookup for lookup in mismatched_prediction.memory_lookups if lookup.memory_id in correct_lookup_ids),
|
|
759
|
+
None,
|
|
760
|
+
)
|
|
761
|
+
assert shared_lookup is not None, "No shared memory lookup between correct and incorrect predictions"
|
|
762
|
+
mismatched_memory = shared_lookup
|
|
763
|
+
|
|
764
|
+
# And: The correct prediction should appear when filtering for correct predictions
|
|
765
|
+
correct_predictions = mismatched_memory.predictions(expected_label_match=True)
|
|
766
|
+
assert any(pred.prediction_id == correct_prediction.prediction_id for pred in correct_predictions)
|
|
767
|
+
assert all(pred.prediction_id != mismatched_prediction.prediction_id for pred in correct_predictions)
|
|
768
|
+
|
|
769
|
+
|
|
338
770
|
def test_insert_memories(writable_memoryset: LabeledMemoryset):
|
|
339
771
|
writable_memoryset.refresh()
|
|
340
772
|
prev_length = writable_memoryset.length
|
|
@@ -513,7 +945,8 @@ def test_scored_memoryset(scored_memoryset: ScoredMemoryset):
|
|
|
513
945
|
assert isinstance(scored_memoryset[0], ScoredMemory)
|
|
514
946
|
assert scored_memoryset[0].value == "i love soup"
|
|
515
947
|
assert scored_memoryset[0].score is not None
|
|
516
|
-
assert scored_memoryset[0].metadata == {"key": "g1", "
|
|
948
|
+
assert scored_memoryset[0].metadata == {"key": "g1", "label": 0, "partition_id": "p1"}
|
|
949
|
+
assert scored_memoryset[0].source_id == "s1"
|
|
517
950
|
lookup = scored_memoryset.search("i love soup", count=1)
|
|
518
951
|
assert len(lookup) == 1
|
|
519
952
|
assert lookup[0].score is not None
|