orca-sdk 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/_shared/metrics.py +186 -43
- orca_sdk/_shared/metrics_test.py +99 -6
- orca_sdk/_utils/data_parsing_test.py +1 -1
- orca_sdk/async_client.py +52 -14
- orca_sdk/classification_model.py +107 -30
- orca_sdk/classification_model_test.py +327 -8
- orca_sdk/client.py +52 -14
- orca_sdk/conftest.py +140 -21
- orca_sdk/embedding_model.py +0 -2
- orca_sdk/memoryset.py +141 -26
- orca_sdk/memoryset_test.py +253 -4
- orca_sdk/regression_model.py +73 -16
- orca_sdk/regression_model_test.py +213 -0
- {orca_sdk-0.1.4.dist-info → orca_sdk-0.1.6.dist-info}/METADATA +1 -1
- {orca_sdk-0.1.4.dist-info → orca_sdk-0.1.6.dist-info}/RECORD +16 -16
- {orca_sdk-0.1.4.dist-info → orca_sdk-0.1.6.dist-info}/WHEEL +0 -0
orca_sdk/memoryset_test.py
CHANGED
|
@@ -8,7 +8,13 @@ from .classification_model import ClassificationModel
|
|
|
8
8
|
from .conftest import skip_in_ci, skip_in_prod
|
|
9
9
|
from .datasource import Datasource
|
|
10
10
|
from .embedding_model import PretrainedEmbeddingModel
|
|
11
|
-
from .memoryset import
|
|
11
|
+
from .memoryset import (
|
|
12
|
+
LabeledMemory,
|
|
13
|
+
LabeledMemoryset,
|
|
14
|
+
ScoredMemory,
|
|
15
|
+
ScoredMemoryset,
|
|
16
|
+
Status,
|
|
17
|
+
)
|
|
12
18
|
from .regression_model import RegressionModel
|
|
13
19
|
|
|
14
20
|
"""
|
|
@@ -154,8 +160,8 @@ def test_create_memoryset_null_labels():
|
|
|
154
160
|
assert memoryset is not None
|
|
155
161
|
assert memoryset.length == 2
|
|
156
162
|
assert memoryset.label_names == ["negative", "positive"]
|
|
157
|
-
assert memoryset[0].label
|
|
158
|
-
assert memoryset[1].label
|
|
163
|
+
assert memoryset[0].label is None
|
|
164
|
+
assert memoryset[1].label is None
|
|
159
165
|
|
|
160
166
|
|
|
161
167
|
def test_open_memoryset(readonly_memoryset, hf_dataset):
|
|
@@ -285,6 +291,100 @@ def test_search_count(readonly_memoryset: LabeledMemoryset):
|
|
|
285
291
|
assert memory_lookups[2].label == 0
|
|
286
292
|
|
|
287
293
|
|
|
294
|
+
def test_search_with_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
295
|
+
# Search within a specific partition - use "soup" which appears in both p1 and p2
|
|
296
|
+
# Use exclude_global to ensure we only get results from the specified partition
|
|
297
|
+
memory_lookups = readonly_partitioned_memoryset.search(
|
|
298
|
+
"soup", partition_id="p1", partition_filter_mode="exclude_global", count=5
|
|
299
|
+
)
|
|
300
|
+
assert len(memory_lookups) > 0
|
|
301
|
+
# All results should be from partition p1 when partition_id is specified
|
|
302
|
+
assert all(
|
|
303
|
+
memory.partition_id == "p1" for memory in memory_lookups
|
|
304
|
+
), f"Expected all results from partition p1, but got: {[m.partition_id for m in memory_lookups]}"
|
|
305
|
+
|
|
306
|
+
# Search in a different partition - use "cats" which appears in both p1 and p2
|
|
307
|
+
memory_lookups_p2 = readonly_partitioned_memoryset.search(
|
|
308
|
+
"cats", partition_id="p2", partition_filter_mode="exclude_global", count=5
|
|
309
|
+
)
|
|
310
|
+
assert len(memory_lookups_p2) > 0
|
|
311
|
+
# All results should be from partition p2 when partition_id is specified
|
|
312
|
+
assert all(
|
|
313
|
+
memory.partition_id == "p2" for memory in memory_lookups_p2
|
|
314
|
+
), f"Expected all results from partition p2, but got: {[m.partition_id for m in memory_lookups_p2]}"
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def test_search_with_partition_filter_mode_exclude_global(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
318
|
+
# Search excluding global memories - need to specify a partition_id when using exclude_global
|
|
319
|
+
# This tests that exclude_global works with a specific partition
|
|
320
|
+
memory_lookups = readonly_partitioned_memoryset.search(
|
|
321
|
+
"soup", partition_id="p1", partition_filter_mode="exclude_global", count=5
|
|
322
|
+
)
|
|
323
|
+
assert len(memory_lookups) > 0
|
|
324
|
+
# All results should have a partition_id (not None) and be from p1
|
|
325
|
+
assert all(memory.partition_id == "p1" for memory in memory_lookups)
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def test_search_with_partition_filter_mode_only_global(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
329
|
+
# Search only in global memories (partition_id=None in the data)
|
|
330
|
+
# Use a query that matches global memories and a reasonable count
|
|
331
|
+
memory_lookups = readonly_partitioned_memoryset.search("beach", partition_filter_mode="only_global", count=3)
|
|
332
|
+
# Should get at least some results (may be fewer than requested if not enough global memories match)
|
|
333
|
+
assert len(memory_lookups) > 0
|
|
334
|
+
# All results should be global (partition_id is None)
|
|
335
|
+
partition_ids = {memory.partition_id for memory in memory_lookups}
|
|
336
|
+
# When using only_global, every result should be global (partition_id is None)
|
|
337
|
+
assert all(
|
|
338
|
+
memory.partition_id is None for memory in memory_lookups
|
|
339
|
+
), f"Expected all results to be global (partition_id=None), but got partition_ids: {partition_ids}"
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
def test_search_with_partition_filter_mode_include_global(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
343
|
+
# Search including global memories (default behavior)
|
|
344
|
+
# Use a reasonable count that won't exceed available memories
|
|
345
|
+
memory_lookups = readonly_partitioned_memoryset.search(
|
|
346
|
+
"i love soup", partition_filter_mode="include_global", count=5
|
|
347
|
+
)
|
|
348
|
+
assert len(memory_lookups) > 0
|
|
349
|
+
# Results can include both partitioned and global memories
|
|
350
|
+
partition_ids = {memory.partition_id for memory in memory_lookups}
|
|
351
|
+
# Should have at least one partition or global memory
|
|
352
|
+
assert len(partition_ids) > 0
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
def test_search_with_partition_filter_mode_ignore_partitions(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
356
|
+
# Search ignoring partition filtering entirely
|
|
357
|
+
memory_lookups = readonly_partitioned_memoryset.search(
|
|
358
|
+
"i love soup", partition_filter_mode="ignore_partitions", count=10
|
|
359
|
+
)
|
|
360
|
+
assert len(memory_lookups) > 0
|
|
361
|
+
# Results can come from any partition or global
|
|
362
|
+
partition_ids = {memory.partition_id for memory in memory_lookups}
|
|
363
|
+
# Should have results from at least one partition (or global)
|
|
364
|
+
assert len(partition_ids) >= 1
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
def test_search_multiple_queries_with_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
368
|
+
# Search multiple queries within a specific partition
|
|
369
|
+
memory_lookups = readonly_partitioned_memoryset.search(["i love soup", "cats are cute"], partition_id="p1", count=3)
|
|
370
|
+
assert len(memory_lookups) == 2
|
|
371
|
+
assert len(memory_lookups[0]) > 0
|
|
372
|
+
assert len(memory_lookups[1]) > 0
|
|
373
|
+
# All results should be from partition p1
|
|
374
|
+
assert all(memory.partition_id == "p1" for memory in memory_lookups[0])
|
|
375
|
+
assert all(memory.partition_id == "p1" for memory in memory_lookups[1])
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
def test_search_with_partition_id_and_filter_mode(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
379
|
+
# When partition_id is specified, partition_filter_mode should still work
|
|
380
|
+
# Search in p1 with exclude_global (should only return p1 results)
|
|
381
|
+
memory_lookups = readonly_partitioned_memoryset.search(
|
|
382
|
+
"i love soup", partition_id="p1", partition_filter_mode="exclude_global", count=5
|
|
383
|
+
)
|
|
384
|
+
assert len(memory_lookups) > 0
|
|
385
|
+
assert all(memory.partition_id == "p1" for memory in memory_lookups)
|
|
386
|
+
|
|
387
|
+
|
|
288
388
|
def test_get_memory_at_index(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset, label_names: list[str]):
|
|
289
389
|
memory = readonly_memoryset[0]
|
|
290
390
|
assert memory.value == hf_dataset[0]["value"]
|
|
@@ -381,6 +481,155 @@ def test_query_memoryset_with_feedback_metrics_sort(classification_model: Classi
|
|
|
381
481
|
assert memories[-1].feedback_metrics["positive"]["avg"] == -1.0
|
|
382
482
|
|
|
383
483
|
|
|
484
|
+
def test_query_memoryset_with_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
485
|
+
# Query with partition_id and include_global (default) - includes both p1 and global memories
|
|
486
|
+
memories = readonly_partitioned_memoryset.query(partition_id="p1")
|
|
487
|
+
assert len(memories) == 15 # 8 p1 + 7 global = 15
|
|
488
|
+
# Results should include both p1 and global memories
|
|
489
|
+
partition_ids = {memory.partition_id for memory in memories}
|
|
490
|
+
assert "p1" in partition_ids
|
|
491
|
+
assert None in partition_ids
|
|
492
|
+
|
|
493
|
+
|
|
494
|
+
def test_query_memoryset_with_partition_id_and_exclude_global(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
495
|
+
# Query with partition_id and exclude_global mode - only returns p1 memories
|
|
496
|
+
memories = readonly_partitioned_memoryset.query(partition_id="p1", partition_filter_mode="exclude_global")
|
|
497
|
+
assert len(memories) == 8 # Only 8 p1 memories (no global)
|
|
498
|
+
# All results should be from partition p1 (no global memories)
|
|
499
|
+
assert all(memory.partition_id == "p1" for memory in memories)
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
def test_query_memoryset_with_partition_id_and_include_global(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
503
|
+
# Query with partition_id and include_global mode (default) - includes both p1 and global
|
|
504
|
+
memories = readonly_partitioned_memoryset.query(partition_id="p1", partition_filter_mode="include_global")
|
|
505
|
+
assert len(memories) == 15 # 8 p1 + 7 global = 15
|
|
506
|
+
# Results should include both p1 and global memories
|
|
507
|
+
partition_ids = {memory.partition_id for memory in memories}
|
|
508
|
+
assert "p1" in partition_ids
|
|
509
|
+
assert None in partition_ids
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
def test_query_memoryset_with_partition_filter_mode_exclude_global(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
513
|
+
# Query excluding global memories requires a partition_id
|
|
514
|
+
# Test with a specific partition_id
|
|
515
|
+
memories = readonly_partitioned_memoryset.query(partition_id="p1", partition_filter_mode="exclude_global")
|
|
516
|
+
assert len(memories) == 8 # Only p1 memories
|
|
517
|
+
# All results should have a partition_id (not global)
|
|
518
|
+
assert all(memory.partition_id == "p1" for memory in memories)
|
|
519
|
+
|
|
520
|
+
|
|
521
|
+
def test_query_memoryset_with_partition_filter_mode_only_global(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
522
|
+
# Query only in global memories
|
|
523
|
+
memories = readonly_partitioned_memoryset.query(partition_filter_mode="only_global")
|
|
524
|
+
assert len(memories) == 7 # There are 7 global memories in SAMPLE_DATA
|
|
525
|
+
# All results should be global (partition_id is None)
|
|
526
|
+
assert all(memory.partition_id is None for memory in memories)
|
|
527
|
+
|
|
528
|
+
|
|
529
|
+
def test_query_memoryset_with_partition_filter_mode_include_global(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
530
|
+
# Query including global memories - when no partition_id is specified,
|
|
531
|
+
# include_global seems to only return global memories
|
|
532
|
+
memories = readonly_partitioned_memoryset.query(partition_filter_mode="include_global")
|
|
533
|
+
# Based on actual behavior, this returns only global memories
|
|
534
|
+
assert len(memories) == 7
|
|
535
|
+
# All results should be global
|
|
536
|
+
assert all(memory.partition_id is None for memory in memories)
|
|
537
|
+
|
|
538
|
+
|
|
539
|
+
def test_query_memoryset_with_partition_filter_mode_ignore_partitions(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
540
|
+
# Query ignoring partition filtering entirely - returns all memories
|
|
541
|
+
memories = readonly_partitioned_memoryset.query(partition_filter_mode="ignore_partitions", limit=100)
|
|
542
|
+
assert len(memories) == 22 # All 22 memories
|
|
543
|
+
# Results can come from any partition or global
|
|
544
|
+
partition_ids = {memory.partition_id for memory in memories}
|
|
545
|
+
# Should have results from multiple partitions/global
|
|
546
|
+
assert len(partition_ids) >= 1
|
|
547
|
+
# Verify we have p1, p2, and global
|
|
548
|
+
assert "p1" in partition_ids
|
|
549
|
+
assert "p2" in partition_ids
|
|
550
|
+
assert None in partition_ids
|
|
551
|
+
|
|
552
|
+
|
|
553
|
+
def test_query_memoryset_with_filters_and_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
554
|
+
# Query with filters and partition_id
|
|
555
|
+
memories = readonly_partitioned_memoryset.query(filters=[("label", "==", 0)], partition_id="p1")
|
|
556
|
+
assert len(memories) > 0
|
|
557
|
+
# All results should match the filter and be from partition p1
|
|
558
|
+
assert all(memory.label == 0 for memory in memories)
|
|
559
|
+
assert all(memory.partition_id == "p1" for memory in memories)
|
|
560
|
+
|
|
561
|
+
|
|
562
|
+
def test_query_memoryset_with_filters_and_partition_filter_mode(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
563
|
+
# Query with filters and partition_filter_mode - exclude_global requires partition_id
|
|
564
|
+
memories = readonly_partitioned_memoryset.query(
|
|
565
|
+
filters=[("label", "==", 1)], partition_id="p1", partition_filter_mode="exclude_global"
|
|
566
|
+
)
|
|
567
|
+
assert len(memories) > 0
|
|
568
|
+
# All results should match the filter and be from p1 (not global)
|
|
569
|
+
assert all(memory.label == 1 for memory in memories)
|
|
570
|
+
assert all(memory.partition_id == "p1" for memory in memories)
|
|
571
|
+
|
|
572
|
+
|
|
573
|
+
def test_query_memoryset_with_limit_and_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
574
|
+
# Query with limit and partition_id
|
|
575
|
+
memories = readonly_partitioned_memoryset.query(partition_id="p2", limit=3)
|
|
576
|
+
assert len(memories) == 3
|
|
577
|
+
# All results should be from partition p2
|
|
578
|
+
assert all(memory.partition_id == "p2" for memory in memories)
|
|
579
|
+
|
|
580
|
+
|
|
581
|
+
def test_query_memoryset_with_offset_and_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
582
|
+
# Query with offset and partition_id - use exclude_global to get only p1 memories
|
|
583
|
+
memories_page1 = readonly_partitioned_memoryset.query(
|
|
584
|
+
partition_id="p1", partition_filter_mode="exclude_global", limit=5
|
|
585
|
+
)
|
|
586
|
+
memories_page2 = readonly_partitioned_memoryset.query(
|
|
587
|
+
partition_id="p1", partition_filter_mode="exclude_global", offset=5, limit=5
|
|
588
|
+
)
|
|
589
|
+
assert len(memories_page1) == 5
|
|
590
|
+
assert len(memories_page2) == 3 # Only 3 remaining p1 memories (8 total - 5 = 3)
|
|
591
|
+
# All results should be from partition p1
|
|
592
|
+
assert all(memory.partition_id == "p1" for memory in memories_page1)
|
|
593
|
+
assert all(memory.partition_id == "p1" for memory in memories_page2)
|
|
594
|
+
# Results should be different (pagination works)
|
|
595
|
+
memory_ids_page1 = {memory.memory_id for memory in memories_page1}
|
|
596
|
+
memory_ids_page2 = {memory.memory_id for memory in memories_page2}
|
|
597
|
+
assert memory_ids_page1.isdisjoint(memory_ids_page2)
|
|
598
|
+
|
|
599
|
+
|
|
600
|
+
def test_query_memoryset_with_partition_id_p2(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
601
|
+
# Query a different partition to verify it works
|
|
602
|
+
# With include_global (default), it includes both p2 and global memories
|
|
603
|
+
memories = readonly_partitioned_memoryset.query(partition_id="p2")
|
|
604
|
+
assert len(memories) == 14 # 7 p2 + 7 global = 14
|
|
605
|
+
# Results should include both p2 and global memories
|
|
606
|
+
partition_ids = {memory.partition_id for memory in memories}
|
|
607
|
+
assert "p2" in partition_ids
|
|
608
|
+
assert None in partition_ids
|
|
609
|
+
|
|
610
|
+
|
|
611
|
+
def test_query_memoryset_with_metadata_filter_and_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
|
|
612
|
+
# Query with metadata filter and partition_id
|
|
613
|
+
memories = readonly_partitioned_memoryset.query(filters=[("metadata.key", "==", "g1")], partition_id="p1")
|
|
614
|
+
assert len(memories) > 0
|
|
615
|
+
# All results should match the metadata filter and be from partition p1
|
|
616
|
+
assert all(memory.metadata.get("key") == "g1" for memory in memories)
|
|
617
|
+
assert all(memory.partition_id == "p1" for memory in memories)
|
|
618
|
+
|
|
619
|
+
|
|
620
|
+
def test_query_memoryset_with_partition_filter_mode_only_global_and_filters(
|
|
621
|
+
readonly_partitioned_memoryset: LabeledMemoryset,
|
|
622
|
+
):
|
|
623
|
+
# Query only global memories with filters
|
|
624
|
+
memories = readonly_partitioned_memoryset.query(
|
|
625
|
+
filters=[("metadata.key", "==", "g3")], partition_filter_mode="only_global"
|
|
626
|
+
)
|
|
627
|
+
assert len(memories) > 0
|
|
628
|
+
# All results should match the filter and be global
|
|
629
|
+
assert all(memory.metadata.get("key") == "g3" for memory in memories)
|
|
630
|
+
assert all(memory.partition_id is None for memory in memories)
|
|
631
|
+
|
|
632
|
+
|
|
384
633
|
def test_labeled_memory_predictions_property(classification_model: ClassificationModel):
|
|
385
634
|
"""Test that LabeledMemory.predictions() only returns classification predictions."""
|
|
386
635
|
# Given: A classification model with memories
|
|
@@ -696,7 +945,7 @@ def test_scored_memoryset(scored_memoryset: ScoredMemoryset):
|
|
|
696
945
|
assert isinstance(scored_memoryset[0], ScoredMemory)
|
|
697
946
|
assert scored_memoryset[0].value == "i love soup"
|
|
698
947
|
assert scored_memoryset[0].score is not None
|
|
699
|
-
assert scored_memoryset[0].metadata == {"key": "g1", "label": 0}
|
|
948
|
+
assert scored_memoryset[0].metadata == {"key": "g1", "label": 0, "partition_id": "p1"}
|
|
700
949
|
assert scored_memoryset[0].source_id == "s1"
|
|
701
950
|
lookup = scored_memoryset.search("i love soup", count=1)
|
|
702
951
|
assert len(lookup) == 1
|
orca_sdk/regression_model.py
CHANGED
|
@@ -16,6 +16,7 @@ from .client import (
|
|
|
16
16
|
RARHeadType,
|
|
17
17
|
RegressionEvaluationRequest,
|
|
18
18
|
RegressionModelMetadata,
|
|
19
|
+
RegressionPredictionRequest,
|
|
19
20
|
)
|
|
20
21
|
from .datasource import Datasource
|
|
21
22
|
from .job import Job
|
|
@@ -290,6 +291,10 @@ class RegressionModel:
|
|
|
290
291
|
use_lookup_cache: bool = True,
|
|
291
292
|
timeout_seconds: int = 10,
|
|
292
293
|
ignore_unlabeled: bool = False,
|
|
294
|
+
partition_id: str | None = None,
|
|
295
|
+
partition_filter_mode: Literal[
|
|
296
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
297
|
+
] = "include_global",
|
|
293
298
|
use_gpu: bool = True,
|
|
294
299
|
) -> RegressionPrediction: ...
|
|
295
300
|
|
|
@@ -304,6 +309,10 @@ class RegressionModel:
|
|
|
304
309
|
use_lookup_cache: bool = True,
|
|
305
310
|
timeout_seconds: int = 10,
|
|
306
311
|
ignore_unlabeled: bool = False,
|
|
312
|
+
partition_id: str | list[str | None] | None = None,
|
|
313
|
+
partition_filter_mode: Literal[
|
|
314
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
315
|
+
] = "include_global",
|
|
307
316
|
use_gpu: bool = True,
|
|
308
317
|
) -> list[RegressionPrediction]: ...
|
|
309
318
|
|
|
@@ -318,6 +327,10 @@ class RegressionModel:
|
|
|
318
327
|
use_lookup_cache: bool = True,
|
|
319
328
|
timeout_seconds: int = 10,
|
|
320
329
|
ignore_unlabeled: bool = False,
|
|
330
|
+
partition_id: str | list[str | None] | None = None,
|
|
331
|
+
partition_filter_mode: Literal[
|
|
332
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
333
|
+
] = "include_global",
|
|
321
334
|
use_gpu: bool = True,
|
|
322
335
|
) -> RegressionPrediction | list[RegressionPrediction]:
|
|
323
336
|
"""
|
|
@@ -336,6 +349,12 @@ class RegressionModel:
|
|
|
336
349
|
timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
|
|
337
350
|
ignore_unlabeled: If True, only use memories with scores during lookup.
|
|
338
351
|
If False (default), allow memories without scores when necessary.
|
|
352
|
+
partition_id: Optional partition ID(s) to use during memory lookup
|
|
353
|
+
partition_filter_mode: Optional partition filter mode to use for the prediction(s). One of
|
|
354
|
+
* `"ignore_partitions"`: Ignore partitions
|
|
355
|
+
* `"include_global"`: Include global memories
|
|
356
|
+
* `"exclude_global"`: Exclude global memories
|
|
357
|
+
* `"only_global"`: Only include global memories
|
|
339
358
|
use_gpu: Whether to use GPU for the prediction (defaults to True)
|
|
340
359
|
|
|
341
360
|
Returns:
|
|
@@ -356,24 +375,29 @@ class RegressionModel:
|
|
|
356
375
|
|
|
357
376
|
telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
|
|
358
377
|
client = OrcaClient._resolve_client()
|
|
378
|
+
request_json: RegressionPredictionRequest = {
|
|
379
|
+
"input_values": value if isinstance(value, list) else [value],
|
|
380
|
+
"memoryset_override_name_or_id": self._memoryset_override_id,
|
|
381
|
+
"expected_scores": (
|
|
382
|
+
expected_scores
|
|
383
|
+
if isinstance(expected_scores, list)
|
|
384
|
+
else [expected_scores] if expected_scores is not None else None
|
|
385
|
+
),
|
|
386
|
+
"tags": list(tags or set()),
|
|
387
|
+
"save_telemetry": telemetry_on,
|
|
388
|
+
"save_telemetry_synchronously": telemetry_sync,
|
|
389
|
+
"prompt": prompt,
|
|
390
|
+
"use_lookup_cache": use_lookup_cache,
|
|
391
|
+
"ignore_unlabeled": ignore_unlabeled,
|
|
392
|
+
"partition_filter_mode": partition_filter_mode,
|
|
393
|
+
}
|
|
394
|
+
# Don't send partition_ids when partition_filter_mode is "ignore_partitions"
|
|
395
|
+
if partition_filter_mode != "ignore_partitions":
|
|
396
|
+
request_json["partition_ids"] = partition_id
|
|
359
397
|
response = client.POST(
|
|
360
398
|
endpoint,
|
|
361
399
|
params={"name_or_id": self.id},
|
|
362
|
-
json=
|
|
363
|
-
"input_values": value if isinstance(value, list) else [value],
|
|
364
|
-
"memoryset_override_name_or_id": self._memoryset_override_id,
|
|
365
|
-
"expected_scores": (
|
|
366
|
-
expected_scores
|
|
367
|
-
if isinstance(expected_scores, list)
|
|
368
|
-
else [expected_scores] if expected_scores is not None else None
|
|
369
|
-
),
|
|
370
|
-
"tags": list(tags or set()),
|
|
371
|
-
"save_telemetry": telemetry_on,
|
|
372
|
-
"save_telemetry_synchronously": telemetry_sync,
|
|
373
|
-
"prompt": prompt,
|
|
374
|
-
"use_lookup_cache": use_lookup_cache,
|
|
375
|
-
"ignore_unlabeled": ignore_unlabeled,
|
|
376
|
-
},
|
|
400
|
+
json=request_json,
|
|
377
401
|
timeout=timeout_seconds,
|
|
378
402
|
)
|
|
379
403
|
|
|
@@ -471,6 +495,10 @@ class RegressionModel:
|
|
|
471
495
|
subsample: int | float | None,
|
|
472
496
|
background: bool = False,
|
|
473
497
|
ignore_unlabeled: bool = False,
|
|
498
|
+
partition_column: str | None = None,
|
|
499
|
+
partition_filter_mode: Literal[
|
|
500
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
501
|
+
] = "include_global",
|
|
474
502
|
) -> RegressionMetrics | Job[RegressionMetrics]:
|
|
475
503
|
client = OrcaClient._resolve_client()
|
|
476
504
|
response = client.POST(
|
|
@@ -485,6 +513,8 @@ class RegressionModel:
|
|
|
485
513
|
"telemetry_tags": list(tags) if tags else None,
|
|
486
514
|
"subsample": subsample,
|
|
487
515
|
"ignore_unlabeled": ignore_unlabeled,
|
|
516
|
+
"datasource_partition_column": partition_column,
|
|
517
|
+
"partition_filter_mode": partition_filter_mode,
|
|
488
518
|
},
|
|
489
519
|
)
|
|
490
520
|
|
|
@@ -521,6 +551,10 @@ class RegressionModel:
|
|
|
521
551
|
batch_size: int,
|
|
522
552
|
prompt: str | None = None,
|
|
523
553
|
ignore_unlabeled: bool = False,
|
|
554
|
+
partition_column: str | None = None,
|
|
555
|
+
partition_filter_mode: Literal[
|
|
556
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
557
|
+
] = "include_global",
|
|
524
558
|
) -> RegressionMetrics:
|
|
525
559
|
if len(dataset) == 0:
|
|
526
560
|
raise ValueError("Evaluation dataset cannot be empty")
|
|
@@ -538,6 +572,8 @@ class RegressionModel:
|
|
|
538
572
|
save_telemetry="sync" if record_predictions else "off",
|
|
539
573
|
prompt=prompt,
|
|
540
574
|
ignore_unlabeled=ignore_unlabeled,
|
|
575
|
+
partition_id=dataset[i : i + batch_size][partition_column] if partition_column else None,
|
|
576
|
+
partition_filter_mode=partition_filter_mode,
|
|
541
577
|
)
|
|
542
578
|
]
|
|
543
579
|
|
|
@@ -561,6 +597,10 @@ class RegressionModel:
|
|
|
561
597
|
subsample: int | float | None = None,
|
|
562
598
|
background: Literal[True],
|
|
563
599
|
ignore_unlabeled: bool = False,
|
|
600
|
+
partition_column: str | None = None,
|
|
601
|
+
partition_filter_mode: Literal[
|
|
602
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
603
|
+
] = "include_global",
|
|
564
604
|
) -> Job[RegressionMetrics]:
|
|
565
605
|
pass
|
|
566
606
|
|
|
@@ -578,6 +618,10 @@ class RegressionModel:
|
|
|
578
618
|
subsample: int | float | None = None,
|
|
579
619
|
background: Literal[False] = False,
|
|
580
620
|
ignore_unlabeled: bool = False,
|
|
621
|
+
partition_column: str | None = None,
|
|
622
|
+
partition_filter_mode: Literal[
|
|
623
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
624
|
+
] = "include_global",
|
|
581
625
|
) -> RegressionMetrics:
|
|
582
626
|
pass
|
|
583
627
|
|
|
@@ -594,6 +638,10 @@ class RegressionModel:
|
|
|
594
638
|
subsample: int | float | None = None,
|
|
595
639
|
background: bool = False,
|
|
596
640
|
ignore_unlabeled: bool = False,
|
|
641
|
+
partition_column: str | None = None,
|
|
642
|
+
partition_filter_mode: Literal[
|
|
643
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
644
|
+
] = "include_global",
|
|
597
645
|
) -> RegressionMetrics | Job[RegressionMetrics]:
|
|
598
646
|
"""
|
|
599
647
|
Evaluate the regression model on a given dataset or datasource
|
|
@@ -609,7 +657,12 @@ class RegressionModel:
|
|
|
609
657
|
subsample: Optional number (int) of rows to sample or fraction (float in (0, 1]) of data to sample for evaluation.
|
|
610
658
|
background: Whether to run the operation in the background and return a job handle
|
|
611
659
|
ignore_unlabeled: If True, only use memories with scores during lookup. If False (default), allow memories without scores
|
|
612
|
-
|
|
660
|
+
partition_column: Optional name of the column that contains the partition IDs
|
|
661
|
+
partition_filter_mode: Optional partition filter mode to use for the evaluation. One of
|
|
662
|
+
* `"ignore_partitions"`: Ignore partitions
|
|
663
|
+
* `"include_global"`: Include global memories
|
|
664
|
+
* `"exclude_global"`: Exclude global memories
|
|
665
|
+
* `"only_global"`: Only include global memories
|
|
613
666
|
Returns:
|
|
614
667
|
RegressionMetrics containing metrics including MAE, MSE, RMSE, R2, and anomaly score statistics
|
|
615
668
|
|
|
@@ -640,6 +693,8 @@ class RegressionModel:
|
|
|
640
693
|
subsample=subsample,
|
|
641
694
|
background=background,
|
|
642
695
|
ignore_unlabeled=ignore_unlabeled,
|
|
696
|
+
partition_column=partition_column,
|
|
697
|
+
partition_filter_mode=partition_filter_mode,
|
|
643
698
|
)
|
|
644
699
|
elif isinstance(data, Dataset):
|
|
645
700
|
return self._evaluate_dataset(
|
|
@@ -651,6 +706,8 @@ class RegressionModel:
|
|
|
651
706
|
batch_size=batch_size,
|
|
652
707
|
prompt=prompt,
|
|
653
708
|
ignore_unlabeled=ignore_unlabeled,
|
|
709
|
+
partition_column=partition_column,
|
|
710
|
+
partition_filter_mode=partition_filter_mode,
|
|
654
711
|
)
|
|
655
712
|
else:
|
|
656
713
|
raise ValueError(f"Invalid data type: {type(data)}")
|