orca-sdk 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,7 +8,13 @@ from .classification_model import ClassificationModel
8
8
  from .conftest import skip_in_ci, skip_in_prod
9
9
  from .datasource import Datasource
10
10
  from .embedding_model import PretrainedEmbeddingModel
11
- from .memoryset import LabeledMemoryset, ScoredMemory, ScoredMemoryset, Status
11
+ from .memoryset import (
12
+ LabeledMemory,
13
+ LabeledMemoryset,
14
+ ScoredMemory,
15
+ ScoredMemoryset,
16
+ Status,
17
+ )
12
18
  from .regression_model import RegressionModel
13
19
 
14
20
  """
@@ -154,8 +160,8 @@ def test_create_memoryset_null_labels():
154
160
  assert memoryset is not None
155
161
  assert memoryset.length == 2
156
162
  assert memoryset.label_names == ["negative", "positive"]
157
- assert memoryset[0].label == None
158
- assert memoryset[1].label == None
163
+ assert memoryset[0].label is None
164
+ assert memoryset[1].label is None
159
165
 
160
166
 
161
167
  def test_open_memoryset(readonly_memoryset, hf_dataset):
@@ -285,6 +291,100 @@ def test_search_count(readonly_memoryset: LabeledMemoryset):
285
291
  assert memory_lookups[2].label == 0
286
292
 
287
293
 
294
def test_search_with_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
    """Searching with a partition_id and exclude_global returns only that partition's memories."""
    # Each query term appears in both p1 and p2, so the partition filter (not the query text)
    # decides which memories come back. exclude_global keeps global memories out of the results.
    for query, partition in (("soup", "p1"), ("cats", "p2")):
        hits = readonly_partitioned_memoryset.search(
            query,
            partition_id=partition,
            partition_filter_mode="exclude_global",
            count=5,
        )
        assert len(hits) > 0
        assert all(
            hit.partition_id == partition for hit in hits
        ), f"Expected all results from partition {partition}, but got: {[h.partition_id for h in hits]}"
315
+
316
+
317
def test_search_with_partition_filter_mode_exclude_global(readonly_partitioned_memoryset: LabeledMemoryset):
    """exclude_global scoped to a partition yields only memories from that partition."""
    # exclude_global has to be paired with a partition_id, so the search is scoped to p1.
    results = readonly_partitioned_memoryset.search(
        "soup",
        partition_id="p1",
        partition_filter_mode="exclude_global",
        count=5,
    )
    assert len(results) > 0
    # No global (None) or foreign-partition memories may slip through.
    seen_partitions = [result.partition_id for result in results]
    assert all(pid == "p1" for pid in seen_partitions)
326
+
327
+
328
def test_search_with_partition_filter_mode_only_global(readonly_partitioned_memoryset: LabeledMemoryset):
    """only_global restricts search results to global memories (partition_id is None)."""
    lookups = readonly_partitioned_memoryset.search("beach", partition_filter_mode="only_global", count=3)
    # Fewer than `count` results is acceptable if not enough global memories match the query.
    assert len(lookups) > 0
    seen_ids = {lookup.partition_id for lookup in lookups}
    non_global = [lookup for lookup in lookups if lookup.partition_id is not None]
    assert not non_global, f"Expected all results to be global (partition_id=None), but got partition_ids: {seen_ids}"
340
+
341
+
342
def test_search_with_partition_filter_mode_include_global(readonly_partitioned_memoryset: LabeledMemoryset):
    """An include_global search (described as the default mode) returns at least one memory."""
    hits = readonly_partitioned_memoryset.search(
        "i love soup",
        partition_filter_mode="include_global",
        count=5,
    )
    assert len(hits) > 0
    # Hits may mix partitioned and global memories; this only checks that some partition_id
    # value (a partition name or None for global) is represented.
    represented = {hit.partition_id for hit in hits}
    assert len(represented) > 0
353
+
354
+
355
def test_search_with_partition_filter_mode_ignore_partitions(readonly_partitioned_memoryset: LabeledMemoryset):
    """An ignore_partitions search applies no partition filtering and still returns results."""
    results = readonly_partitioned_memoryset.search(
        "i love soup",
        partition_filter_mode="ignore_partitions",
        count=10,
    )
    assert len(results) > 0
    # Results may come from any partition or from global memories. This only checks that at
    # least one partition_id value (possibly None) is present, not that several are.
    distinct_partitions = {result.partition_id for result in results}
    assert len(distinct_partitions) >= 1
365
+
366
+
367
def test_search_multiple_queries_with_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
    """A batched search scoped to p1 returns one result list per query, each holding only p1 memories."""
    # exclude_global is required for the "all from p1" assertions below. Under the default
    # include_global mode, a partition-scoped lookup may also return global memories
    # (partition_id None), which would make those assertions depend on ranking.
    memory_lookups = readonly_partitioned_memoryset.search(
        ["i love soup", "cats are cute"],
        partition_id="p1",
        partition_filter_mode="exclude_global",
        count=3,
    )
    # One result list per input query.
    assert len(memory_lookups) == 2
    for lookups in memory_lookups:
        assert len(lookups) > 0
        assert all(memory.partition_id == "p1" for memory in lookups)
376
+
377
+
378
def test_search_with_partition_id_and_filter_mode(readonly_partitioned_memoryset: LabeledMemoryset):
    """An explicit partition_id combined with exclude_global returns only that partition's memories."""
    hits = readonly_partitioned_memoryset.search(
        "i love soup",
        partition_id="p1",
        partition_filter_mode="exclude_global",
        count=5,
    )
    assert len(hits) > 0
    off_partition = [hit for hit in hits if hit.partition_id != "p1"]
    assert off_partition == []
386
+
387
+
288
388
  def test_get_memory_at_index(readonly_memoryset: LabeledMemoryset, hf_dataset: Dataset, label_names: list[str]):
289
389
  memory = readonly_memoryset[0]
290
390
  assert memory.value == hf_dataset[0]["value"]
@@ -381,6 +481,155 @@ def test_query_memoryset_with_feedback_metrics_sort(classification_model: Classi
381
481
  assert memories[-1].feedback_metrics["positive"]["avg"] == -1.0
382
482
 
383
483
 
484
def test_query_memoryset_with_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
    """A p1 query in the default mode (include_global) returns the p1 memories plus the global ones."""
    results = readonly_partitioned_memoryset.query(partition_id="p1")
    p1_count, global_count = 8, 7
    assert len(results) == p1_count + global_count
    # Both the requested partition and the global pool (None) must be represented.
    represented = {result.partition_id for result in results}
    assert {"p1", None} <= represented
492
+
493
+
494
def test_query_memoryset_with_partition_id_and_exclude_global(readonly_partitioned_memoryset: LabeledMemoryset):
    """Querying p1 with exclude_global returns only the p1 memories and no global ones."""
    p1_only = readonly_partitioned_memoryset.query(
        partition_id="p1",
        partition_filter_mode="exclude_global",
    )
    assert len(p1_only) == 8  # the fixture holds 8 memories in partition p1
    assert not any(memory.partition_id != "p1" for memory in p1_only)
500
+
501
+
502
def test_query_memoryset_with_partition_id_and_include_global(readonly_partitioned_memoryset: LabeledMemoryset):
    """Passing include_global explicitly on a p1 query returns p1 memories together with global ones."""
    results = readonly_partitioned_memoryset.query(
        partition_id="p1",
        partition_filter_mode="include_global",
    )
    # 8 memories in p1 plus 7 global memories.
    assert len(results) == 15
    partitions_present = {result.partition_id for result in results}
    for expected in ("p1", None):
        assert expected in partitions_present
510
+
511
+
512
def test_query_memoryset_with_partition_filter_mode_exclude_global(readonly_partitioned_memoryset: LabeledMemoryset):
    """exclude_global (which needs a partition_id) drops all global memories from a query."""
    memories = readonly_partitioned_memoryset.query(
        partition_id="p1",
        partition_filter_mode="exclude_global",
    )
    assert len(memories) == 8
    strays = [memory for memory in memories if memory.partition_id != "p1"]
    assert strays == []
519
+
520
+
521
def test_query_memoryset_with_partition_filter_mode_only_global(readonly_partitioned_memoryset: LabeledMemoryset):
    """An only_global query returns exactly the global memories of the fixture."""
    global_memories = readonly_partitioned_memoryset.query(partition_filter_mode="only_global")
    # SAMPLE_DATA contains 7 memories without a partition (global).
    assert len(global_memories) == 7
    assert not any(memory.partition_id is not None for memory in global_memories)
527
+
528
+
529
def test_query_memoryset_with_partition_filter_mode_include_global(readonly_partitioned_memoryset: LabeledMemoryset):
    """An include_global query without a partition_id returns only the global memories."""
    # Pins observed behavior: with no partition_id supplied, include_global yields just the
    # 7 global memories and no partitioned ones.
    memories = readonly_partitioned_memoryset.query(partition_filter_mode="include_global")
    assert len(memories) == 7
    partitioned = [memory for memory in memories if memory.partition_id is not None]
    assert not partitioned
537
+
538
+
539
def test_query_memoryset_with_partition_filter_mode_ignore_partitions(readonly_partitioned_memoryset: LabeledMemoryset):
    """An ignore_partitions query returns every memory, across p1, p2 and the global pool."""
    # limit=100 is well above the 22 memories in the fixture, so nothing is truncated.
    everything = readonly_partitioned_memoryset.query(partition_filter_mode="ignore_partitions", limit=100)
    assert len(everything) == 22
    represented = {memory.partition_id for memory in everything}
    assert {"p1", "p2", None} <= represented
551
+
552
+
553
def test_query_memoryset_with_filters_and_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
    """A label filter combined with a partition_id returns only matching memories from that partition."""
    # exclude_global is needed for the partition assertion below. In the default include_global
    # mode a p1 query also returns global memories (see test_query_memoryset_with_partition_id),
    # and any global memory with label 0 would pass the filter and break the p1 check.
    memories = readonly_partitioned_memoryset.query(
        filters=[("label", "==", 0)],
        partition_id="p1",
        partition_filter_mode="exclude_global",
    )
    assert len(memories) > 0
    # Every result must satisfy the filter and belong to p1.
    assert all(memory.label == 0 for memory in memories)
    assert all(memory.partition_id == "p1" for memory in memories)
560
+
561
+
562
def test_query_memoryset_with_filters_and_partition_filter_mode(readonly_partitioned_memoryset: LabeledMemoryset):
    """A label filter plus exclude_global on p1 returns only label-1 memories from p1."""
    # exclude_global has to be paired with a partition_id.
    label_one_in_p1 = readonly_partitioned_memoryset.query(
        filters=[("label", "==", 1)],
        partition_id="p1",
        partition_filter_mode="exclude_global",
    )
    assert len(label_one_in_p1) > 0
    labels = [memory.label for memory in label_one_in_p1]
    assert all(label == 1 for label in labels)
    partitions = [memory.partition_id for memory in label_one_in_p1]
    assert all(pid == "p1" for pid in partitions)
571
+
572
+
573
def test_query_memoryset_with_limit_and_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
    """limit caps the number of memories returned by a partition-scoped query."""
    # exclude_global keeps global memories out of the page. In the default include_global mode,
    # the first 3 results could include global memories (partition_id None) and fail the p2 check.
    # p2 holds 7 memories, so limit=3 still returns a full page.
    memories = readonly_partitioned_memoryset.query(
        partition_id="p2",
        partition_filter_mode="exclude_global",
        limit=3,
    )
    assert len(memories) == 3
    assert all(memory.partition_id == "p2" for memory in memories)
579
+
580
+
581
def test_query_memoryset_with_offset_and_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
    """offset/limit paginate through the p1-only memories without overlapping pages."""

    def fetch_page(**page_args):
        # exclude_global keeps each page to p1 memories only.
        return readonly_partitioned_memoryset.query(
            partition_id="p1", partition_filter_mode="exclude_global", **page_args
        )

    first_page = fetch_page(limit=5)
    second_page = fetch_page(offset=5, limit=5)
    assert len(first_page) == 5
    # p1 holds 8 memories, so the second page only has the remaining 3.
    assert len(second_page) == 3
    for page in (first_page, second_page):
        assert all(memory.partition_id == "p1" for memory in page)
    first_ids = {memory.memory_id for memory in first_page}
    second_ids = {memory.memory_id for memory in second_page}
    assert not (first_ids & second_ids)
598
+
599
+
600
def test_query_memoryset_with_partition_id_p2(readonly_partitioned_memoryset: LabeledMemoryset):
    """A default-mode (include_global) query on p2 returns the p2 memories plus the global ones."""
    memories = readonly_partitioned_memoryset.query(partition_id="p2")
    assert len(memories) == 14  # 7 in p2 + 7 global
    partition_ids = {memory.partition_id for memory in memories}
    for expected in ("p2", None):
        assert expected in partition_ids
609
+
610
+
611
def test_query_memoryset_with_metadata_filter_and_partition_id(readonly_partitioned_memoryset: LabeledMemoryset):
    """A metadata filter combined with a partition_id returns only matching memories from that partition."""
    # exclude_global makes the partition assertion deterministic. In the default include_global
    # mode, any global memory whose metadata key is "g1" would also be returned and fail the p1 check.
    memories = readonly_partitioned_memoryset.query(
        filters=[("metadata.key", "==", "g1")],
        partition_id="p1",
        partition_filter_mode="exclude_global",
    )
    assert len(memories) > 0
    assert all(memory.metadata.get("key") == "g1" for memory in memories)
    assert all(memory.partition_id == "p1" for memory in memories)
618
+
619
+
620
def test_query_memoryset_with_partition_filter_mode_only_global_and_filters(
    readonly_partitioned_memoryset: LabeledMemoryset,
):
    """only_global combined with a metadata filter returns only matching global memories."""
    matches = readonly_partitioned_memoryset.query(
        partition_filter_mode="only_global",
        filters=[("metadata.key", "==", "g3")],
    )
    assert len(matches) > 0
    keys = [match.metadata.get("key") for match in matches]
    assert all(key == "g3" for key in keys)
    # Global memories are the ones without a partition.
    assert not any(match.partition_id is not None for match in matches)
631
+
632
+
384
633
  def test_labeled_memory_predictions_property(classification_model: ClassificationModel):
385
634
  """Test that LabeledMemory.predictions() only returns classification predictions."""
386
635
  # Given: A classification model with memories
@@ -696,7 +945,7 @@ def test_scored_memoryset(scored_memoryset: ScoredMemoryset):
696
945
  assert isinstance(scored_memoryset[0], ScoredMemory)
697
946
  assert scored_memoryset[0].value == "i love soup"
698
947
  assert scored_memoryset[0].score is not None
699
- assert scored_memoryset[0].metadata == {"key": "g1", "label": 0}
948
+ assert scored_memoryset[0].metadata == {"key": "g1", "label": 0, "partition_id": "p1"}
700
949
  assert scored_memoryset[0].source_id == "s1"
701
950
  lookup = scored_memoryset.search("i love soup", count=1)
702
951
  assert len(lookup) == 1
@@ -16,6 +16,7 @@ from .client import (
16
16
  RARHeadType,
17
17
  RegressionEvaluationRequest,
18
18
  RegressionModelMetadata,
19
+ RegressionPredictionRequest,
19
20
  )
20
21
  from .datasource import Datasource
21
22
  from .job import Job
@@ -290,6 +291,10 @@ class RegressionModel:
290
291
  use_lookup_cache: bool = True,
291
292
  timeout_seconds: int = 10,
292
293
  ignore_unlabeled: bool = False,
294
+ partition_id: str | None = None,
295
+ partition_filter_mode: Literal[
296
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
297
+ ] = "include_global",
293
298
  use_gpu: bool = True,
294
299
  ) -> RegressionPrediction: ...
295
300
 
@@ -304,6 +309,10 @@ class RegressionModel:
304
309
  use_lookup_cache: bool = True,
305
310
  timeout_seconds: int = 10,
306
311
  ignore_unlabeled: bool = False,
312
+ partition_id: str | list[str | None] | None = None,
313
+ partition_filter_mode: Literal[
314
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
315
+ ] = "include_global",
307
316
  use_gpu: bool = True,
308
317
  ) -> list[RegressionPrediction]: ...
309
318
 
@@ -318,6 +327,10 @@ class RegressionModel:
318
327
  use_lookup_cache: bool = True,
319
328
  timeout_seconds: int = 10,
320
329
  ignore_unlabeled: bool = False,
330
+ partition_id: str | list[str | None] | None = None,
331
+ partition_filter_mode: Literal[
332
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
333
+ ] = "include_global",
321
334
  use_gpu: bool = True,
322
335
  ) -> RegressionPrediction | list[RegressionPrediction]:
323
336
  """
@@ -336,6 +349,12 @@ class RegressionModel:
336
349
  timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
337
350
  ignore_unlabeled: If True, only use memories with scores during lookup.
338
351
  If False (default), allow memories without scores when necessary.
352
+ partition_id: Optional partition ID(s) to use during memory lookup
353
+ partition_filter_mode: Optional partition filter mode to use for the prediction(s). One of
354
+ * `"ignore_partitions"`: Ignore partitions
355
+ * `"include_global"`: Include global memories
356
+ * `"exclude_global"`: Exclude global memories
357
+ * `"only_global"`: Only include global memories
339
358
  use_gpu: Whether to use GPU for the prediction (defaults to True)
340
359
 
341
360
  Returns:
@@ -356,24 +375,29 @@ class RegressionModel:
356
375
 
357
376
  telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
358
377
  client = OrcaClient._resolve_client()
378
+ request_json: RegressionPredictionRequest = {
379
+ "input_values": value if isinstance(value, list) else [value],
380
+ "memoryset_override_name_or_id": self._memoryset_override_id,
381
+ "expected_scores": (
382
+ expected_scores
383
+ if isinstance(expected_scores, list)
384
+ else [expected_scores] if expected_scores is not None else None
385
+ ),
386
+ "tags": list(tags or set()),
387
+ "save_telemetry": telemetry_on,
388
+ "save_telemetry_synchronously": telemetry_sync,
389
+ "prompt": prompt,
390
+ "use_lookup_cache": use_lookup_cache,
391
+ "ignore_unlabeled": ignore_unlabeled,
392
+ "partition_filter_mode": partition_filter_mode,
393
+ }
394
+ # Don't send partition_ids when partition_filter_mode is "ignore_partitions"
395
+ if partition_filter_mode != "ignore_partitions":
396
+ request_json["partition_ids"] = partition_id
359
397
  response = client.POST(
360
398
  endpoint,
361
399
  params={"name_or_id": self.id},
362
- json={
363
- "input_values": value if isinstance(value, list) else [value],
364
- "memoryset_override_name_or_id": self._memoryset_override_id,
365
- "expected_scores": (
366
- expected_scores
367
- if isinstance(expected_scores, list)
368
- else [expected_scores] if expected_scores is not None else None
369
- ),
370
- "tags": list(tags or set()),
371
- "save_telemetry": telemetry_on,
372
- "save_telemetry_synchronously": telemetry_sync,
373
- "prompt": prompt,
374
- "use_lookup_cache": use_lookup_cache,
375
- "ignore_unlabeled": ignore_unlabeled,
376
- },
400
+ json=request_json,
377
401
  timeout=timeout_seconds,
378
402
  )
379
403
 
@@ -471,6 +495,10 @@ class RegressionModel:
471
495
  subsample: int | float | None,
472
496
  background: bool = False,
473
497
  ignore_unlabeled: bool = False,
498
+ partition_column: str | None = None,
499
+ partition_filter_mode: Literal[
500
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
501
+ ] = "include_global",
474
502
  ) -> RegressionMetrics | Job[RegressionMetrics]:
475
503
  client = OrcaClient._resolve_client()
476
504
  response = client.POST(
@@ -485,6 +513,8 @@ class RegressionModel:
485
513
  "telemetry_tags": list(tags) if tags else None,
486
514
  "subsample": subsample,
487
515
  "ignore_unlabeled": ignore_unlabeled,
516
+ "datasource_partition_column": partition_column,
517
+ "partition_filter_mode": partition_filter_mode,
488
518
  },
489
519
  )
490
520
 
@@ -521,6 +551,10 @@ class RegressionModel:
521
551
  batch_size: int,
522
552
  prompt: str | None = None,
523
553
  ignore_unlabeled: bool = False,
554
+ partition_column: str | None = None,
555
+ partition_filter_mode: Literal[
556
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
557
+ ] = "include_global",
524
558
  ) -> RegressionMetrics:
525
559
  if len(dataset) == 0:
526
560
  raise ValueError("Evaluation dataset cannot be empty")
@@ -538,6 +572,8 @@ class RegressionModel:
538
572
  save_telemetry="sync" if record_predictions else "off",
539
573
  prompt=prompt,
540
574
  ignore_unlabeled=ignore_unlabeled,
575
+ partition_id=dataset[i : i + batch_size][partition_column] if partition_column else None,
576
+ partition_filter_mode=partition_filter_mode,
541
577
  )
542
578
  ]
543
579
 
@@ -561,6 +597,10 @@ class RegressionModel:
561
597
  subsample: int | float | None = None,
562
598
  background: Literal[True],
563
599
  ignore_unlabeled: bool = False,
600
+ partition_column: str | None = None,
601
+ partition_filter_mode: Literal[
602
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
603
+ ] = "include_global",
564
604
  ) -> Job[RegressionMetrics]:
565
605
  pass
566
606
 
@@ -578,6 +618,10 @@ class RegressionModel:
578
618
  subsample: int | float | None = None,
579
619
  background: Literal[False] = False,
580
620
  ignore_unlabeled: bool = False,
621
+ partition_column: str | None = None,
622
+ partition_filter_mode: Literal[
623
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
624
+ ] = "include_global",
581
625
  ) -> RegressionMetrics:
582
626
  pass
583
627
 
@@ -594,6 +638,10 @@ class RegressionModel:
594
638
  subsample: int | float | None = None,
595
639
  background: bool = False,
596
640
  ignore_unlabeled: bool = False,
641
+ partition_column: str | None = None,
642
+ partition_filter_mode: Literal[
643
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
644
+ ] = "include_global",
597
645
  ) -> RegressionMetrics | Job[RegressionMetrics]:
598
646
  """
599
647
  Evaluate the regression model on a given dataset or datasource
@@ -609,7 +657,12 @@ class RegressionModel:
609
657
  subsample: Optional number (int) of rows to sample or fraction (float in (0, 1]) of data to sample for evaluation.
610
658
  background: Whether to run the operation in the background and return a job handle
611
659
  ignore_unlabeled: If True, only use memories with scores during lookup. If False (default), allow memories without scores
612
-
660
+ partition_column: Optional name of the column that contains the partition IDs
661
+ partition_filter_mode: Optional partition filter mode to use for the evaluation. One of
662
+ * `"ignore_partitions"`: Ignore partitions
663
+ * `"include_global"`: Include global memories
664
+ * `"exclude_global"`: Exclude global memories
665
+ * `"only_global"`: Only include global memories
613
666
  Returns:
614
667
  RegressionMetrics containing metrics including MAE, MSE, RMSE, R2, and anomaly score statistics
615
668
 
@@ -640,6 +693,8 @@ class RegressionModel:
640
693
  subsample=subsample,
641
694
  background=background,
642
695
  ignore_unlabeled=ignore_unlabeled,
696
+ partition_column=partition_column,
697
+ partition_filter_mode=partition_filter_mode,
643
698
  )
644
699
  elif isinstance(data, Dataset):
645
700
  return self._evaluate_dataset(
@@ -651,6 +706,8 @@ class RegressionModel:
651
706
  batch_size=batch_size,
652
707
  prompt=prompt,
653
708
  ignore_unlabeled=ignore_unlabeled,
709
+ partition_column=partition_column,
710
+ partition_filter_mode=partition_filter_mode,
654
711
  )
655
712
  else:
656
713
  raise ValueError(f"Invalid data type: {type(data)}")