orca-sdk 0.1.8__py3-none-any.whl → 0.1.10__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only.
@@ -15,6 +15,7 @@ from .client import (
     BootstrapLabeledMemoryDataResult,
     ClassificationModelMetadata,
     ClassificationPredictionRequest,
+    ListPredictionsRequest,
     OrcaClient,
     PredictiveModelUpdate,
     RACHeadType,
@@ -363,6 +364,7 @@ class ClassificationModel:
             "ignore_partitions", "include_global", "exclude_global", "only_global"
         ] = "include_global",
         use_gpu: bool = True,
+        batch_size: int = 100,
     ) -> list[ClassificationPrediction]:
         pass
 
@@ -383,6 +385,7 @@ class ClassificationModel:
             "ignore_partitions", "include_global", "exclude_global", "only_global"
         ] = "include_global",
         use_gpu: bool = True,
+        batch_size: int = 100,
     ) -> ClassificationPrediction:
         pass
 
@@ -402,6 +405,7 @@ class ClassificationModel:
             "ignore_partitions", "include_global", "exclude_global", "only_global"
         ] = "include_global",
         use_gpu: bool = True,
+        batch_size: int = 100,
     ) -> list[ClassificationPrediction] | ClassificationPrediction:
         """
         Predict label(s) for the given input value(s) grounded in similar memories
@@ -429,6 +433,7 @@ class ClassificationModel:
                 * `"exclude_global"`: Exclude global memories
                 * `"only_global"`: Only include global memories
             use_gpu: Whether to use GPU for the prediction (defaults to True)
+            batch_size: Number of values to process in a single API call
 
         Returns:
             Label prediction or list of label predictions
@@ -456,6 +461,8 @@ class ClassificationModel:
 
         if timeout_seconds <= 0:
            raise ValueError("timeout_seconds must be a positive integer")
+        if batch_size <= 0 or batch_size > 500:
+            raise ValueError("batch_size must be between 1 and 500")
 
         parsed_filters = [
             _parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
@@ -464,10 +471,17 @@ class ClassificationModel:
         if any(_is_metric_column(filter[0]) for filter in filters):
             raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")
 
+        # Convert to list for batching
+        values = value if isinstance(value, list) else [value]
+        if isinstance(expected_labels, list) and len(expected_labels) != len(values):
+            raise ValueError("Invalid input: \n\texpected_labels must be the same length as values")
+        if isinstance(partition_id, list) and len(partition_id) != len(values):
+            raise ValueError("Invalid input: \n\tpartition_id must be the same length as values")
+
         if isinstance(expected_labels, int):
-            expected_labels = [expected_labels]
+            expected_labels = [expected_labels] * len(values)
         elif isinstance(expected_labels, str):
-            expected_labels = [self.memoryset.label_names.index(expected_labels)]
+            expected_labels = [self.memoryset.label_names.index(expected_labels)] * len(values)
         elif isinstance(expected_labels, list):
             expected_labels = [
                 self.memoryset.label_names.index(label) if isinstance(label, str) else label
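Note on the hunk above: a scalar `expected_labels` (an `int` label index or a `str` label name) is now broadcast to one entry per input value before batching, while a list must already match the input length. A minimal standalone sketch of that broadcast, with hypothetical label names and inputs (`label_names` stands in for `self.memoryset.label_names`):

    label_names = ["negative", "positive"]
    values = ["I am happy", "This is bad", "Great stuff"]
    expected_labels = "positive"  # one label supplied for three inputs

    if isinstance(expected_labels, int):
        expected_labels = [expected_labels] * len(values)
    elif isinstance(expected_labels, str):
        expected_labels = [label_names.index(expected_labels)] * len(values)

    assert expected_labels == [1, 1, 1]  # one label index per input value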
@@ -481,47 +495,56 @@ class ClassificationModel:
 
         telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
         client = OrcaClient._resolve_client()
-        request_json: ClassificationPredictionRequest = {
-            "input_values": value if isinstance(value, list) else [value],
-            "memoryset_override_name_or_id": self._memoryset_override_id,
-            "expected_labels": expected_labels,
-            "tags": list(tags or set()),
-            "save_telemetry": telemetry_on,
-            "save_telemetry_synchronously": telemetry_sync,
-            "filters": cast(list[FilterItem], parsed_filters),
-            "prompt": prompt,
-            "use_lookup_cache": use_lookup_cache,
-            "ignore_unlabeled": ignore_unlabeled,
-            "partition_filter_mode": partition_filter_mode,
-        }
-        # Don't send partition_ids when partition_filter_mode is "ignore_partitions"
-        if partition_filter_mode != "ignore_partitions":
-            request_json["partition_ids"] = partition_id
-        response = client.POST(
-            endpoint,
-            params={"name_or_id": self.id},
-            json=request_json,
-            timeout=timeout_seconds,
-        )
 
-        if telemetry_on and any(p["prediction_id"] is None for p in response):
-            raise RuntimeError("Failed to save prediction to database.")
+        predictions: list[ClassificationPrediction] = []
+        for i in range(0, len(values), batch_size):
+            batch_values = values[i : i + batch_size]
+            batch_expected_labels = expected_labels[i : i + batch_size] if expected_labels else None
 
-        predictions = [
-            ClassificationPrediction(
-                prediction_id=prediction["prediction_id"],
-                label=prediction["label"],
-                label_name=prediction["label_name"],
-                score=None,
-                confidence=prediction["confidence"],
-                anomaly_score=prediction["anomaly_score"],
-                memoryset=self.memoryset,
-                model=self,
-                logits=prediction["logits"],
-                input_value=input_value,
+            request_json: ClassificationPredictionRequest = {
+                "input_values": batch_values,
+                "memoryset_override_name_or_id": self._memoryset_override_id,
+                "expected_labels": batch_expected_labels,
+                "tags": list(tags or set()),
+                "save_telemetry": telemetry_on,
+                "save_telemetry_synchronously": telemetry_sync,
+                "filters": cast(list[FilterItem], parsed_filters),
+                "prompt": prompt,
+                "use_lookup_cache": use_lookup_cache,
+                "ignore_unlabeled": ignore_unlabeled,
+                "partition_filter_mode": partition_filter_mode,
+            }
+            if partition_filter_mode != "ignore_partitions":
+                request_json["partition_ids"] = (
+                    partition_id[i : i + batch_size] if isinstance(partition_id, list) else partition_id
+                )
+
+            response = client.POST(
+                endpoint,
+                params={"name_or_id": self.id},
+                json=request_json,
+                timeout=timeout_seconds,
             )
-            for prediction, input_value in zip(response, value if isinstance(value, list) else [value])
-        ]
+
+            if telemetry_on and any(p["prediction_id"] is None for p in response):
+                raise RuntimeError("Failed to save some prediction to database.")
+
+            predictions.extend(
+                ClassificationPrediction(
+                    prediction_id=prediction["prediction_id"],
+                    label=prediction["label"],
+                    label_name=prediction["label_name"],
+                    score=None,
+                    confidence=prediction["confidence"],
+                    anomaly_score=prediction["anomaly_score"],
+                    memoryset=self.memoryset,
+                    model=self,
+                    logits=prediction["logits"],
+                    input_value=input_value,
+                )
+                for prediction, input_value in zip(response, batch_values)
+            )
+
         self._last_prediction_was_batch = isinstance(value, list)
         self._last_prediction = predictions[-1]
         return predictions if isinstance(value, list) else predictions[0]
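With this change, one `predict` call is split client-side into API requests of at most `batch_size` values, and the per-batch responses are stitched back into a single result list. A hedged usage sketch (the `model` handle and inputs are hypothetical):

    # 1,000 inputs become four API calls of 250 values each.
    inputs = [f"sample text {i}" for i in range(1000)]
    predictions = model.predict(inputs, batch_size=250)
    assert len(predictions) == len(inputs)

    # batch_size is validated to the range 1..500, so this would raise ValueError:
    # model.predict(inputs, batch_size=600)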
@@ -542,6 +565,7 @@ class ClassificationModel:
         partition_filter_mode: Literal[
             "ignore_partitions", "include_global", "exclude_global", "only_global"
         ] = "include_global",
+        batch_size: int = 100,
     ) -> list[ClassificationPrediction]:
         pass
 
@@ -561,6 +585,7 @@ class ClassificationModel:
         partition_filter_mode: Literal[
             "ignore_partitions", "include_global", "exclude_global", "only_global"
         ] = "include_global",
+        batch_size: int = 100,
     ) -> ClassificationPrediction:
         pass
 
@@ -579,6 +604,7 @@ class ClassificationModel:
         partition_filter_mode: Literal[
             "ignore_partitions", "include_global", "exclude_global", "only_global"
         ] = "include_global",
+        batch_size: int = 100,
     ) -> list[ClassificationPrediction] | ClassificationPrediction:
         """
         Asynchronously predict label(s) for the given input value(s) grounded in similar memories
@@ -605,6 +631,8 @@ class ClassificationModel:
                 * `"include_global"`: Include global memories
                 * `"exclude_global"`: Exclude global memories
                 * `"only_global"`: Only include global memories
+            batch_size: Number of values to process in a single API call
+
         Returns:
             Label prediction or list of label predictions.
 
@@ -631,6 +659,8 @@ class ClassificationModel:
 
         if timeout_seconds <= 0:
             raise ValueError("timeout_seconds must be a positive integer")
+        if batch_size <= 0 or batch_size > 500:
+            raise ValueError("batch_size must be between 1 and 500")
 
         parsed_filters = [
             _parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
@@ -639,10 +669,17 @@ class ClassificationModel:
         if any(_is_metric_column(filter[0]) for filter in filters):
             raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")
 
+        # Convert to list for batching
+        values = value if isinstance(value, list) else [value]
+        if isinstance(expected_labels, list) and len(expected_labels) != len(values):
+            raise ValueError("Invalid input: \n\texpected_labels must be the same length as values")
+        if isinstance(partition_id, list) and len(partition_id) != len(values):
+            raise ValueError("Invalid input: \n\tpartition_id must be the same length as values")
+
         if isinstance(expected_labels, int):
-            expected_labels = [expected_labels]
+            expected_labels = [expected_labels] * len(values)
         elif isinstance(expected_labels, str):
-            expected_labels = [self.memoryset.label_names.index(expected_labels)]
+            expected_labels = [self.memoryset.label_names.index(expected_labels)] * len(values)
         elif isinstance(expected_labels, list):
             expected_labels = [
                 self.memoryset.label_names.index(label) if isinstance(label, str) else label
@@ -651,75 +688,89 @@ class ClassificationModel:
 
         telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
         client = OrcaAsyncClient._resolve_client()
-        request_json: ClassificationPredictionRequest = {
-            "input_values": value if isinstance(value, list) else [value],
-            "memoryset_override_name_or_id": self._memoryset_override_id,
-            "expected_labels": expected_labels,
-            "tags": list(tags or set()),
-            "save_telemetry": telemetry_on,
-            "save_telemetry_synchronously": telemetry_sync,
-            "filters": cast(list[FilterItem], parsed_filters),
-            "prompt": prompt,
-            "use_lookup_cache": use_lookup_cache,
-            "ignore_unlabeled": ignore_unlabeled,
-            "partition_filter_mode": partition_filter_mode,
-        }
-        # Don't send partition_ids when partition_filter_mode is "ignore_partitions"
-        if partition_filter_mode != "ignore_partitions":
-            request_json["partition_ids"] = partition_id
-        response = await client.POST(
-            "/gpu/classification_model/{name_or_id}/prediction",
-            params={"name_or_id": self.id},
-            json=request_json,
-            timeout=timeout_seconds,
-        )
 
-        if telemetry_on and any(p["prediction_id"] is None for p in response):
-            raise RuntimeError("Failed to save prediction to database.")
+        predictions: list[ClassificationPrediction] = []
+        for i in range(0, len(values), batch_size):
+            batch_values = values[i : i + batch_size]
+            batch_expected_labels = expected_labels[i : i + batch_size] if expected_labels else None
 
-        predictions = [
-            ClassificationPrediction(
-                prediction_id=prediction["prediction_id"],
-                label=prediction["label"],
-                label_name=prediction["label_name"],
-                score=None,
-                confidence=prediction["confidence"],
-                anomaly_score=prediction["anomaly_score"],
-                memoryset=self.memoryset,
-                model=self,
-                logits=prediction["logits"],
-                input_value=input_value,
+            request_json: ClassificationPredictionRequest = {
+                "input_values": batch_values,
+                "memoryset_override_name_or_id": self._memoryset_override_id,
+                "expected_labels": batch_expected_labels,
+                "tags": list(tags or set()),
+                "save_telemetry": telemetry_on,
+                "save_telemetry_synchronously": telemetry_sync,
+                "filters": cast(list[FilterItem], parsed_filters),
+                "prompt": prompt,
+                "use_lookup_cache": use_lookup_cache,
+                "ignore_unlabeled": ignore_unlabeled,
+                "partition_filter_mode": partition_filter_mode,
+            }
+            if partition_filter_mode != "ignore_partitions":
+                request_json["partition_ids"] = (
+                    partition_id[i : i + batch_size] if isinstance(partition_id, list) else partition_id
+                )
+            response = await client.POST(
+                "/gpu/classification_model/{name_or_id}/prediction",
+                params={"name_or_id": self.id},
+                json=request_json,
+                timeout=timeout_seconds,
             )
-            for prediction, input_value in zip(response, value if isinstance(value, list) else [value])
-        ]
+
+            if telemetry_on and any(p["prediction_id"] is None for p in response):
+                raise RuntimeError("Failed to save some prediction to database.")
+
+            predictions.extend(
+                ClassificationPrediction(
+                    prediction_id=prediction["prediction_id"],
+                    label=prediction["label"],
+                    label_name=prediction["label_name"],
+                    score=None,
+                    confidence=prediction["confidence"],
+                    anomaly_score=prediction["anomaly_score"],
+                    memoryset=self.memoryset,
+                    model=self,
+                    logits=prediction["logits"],
+                    input_value=input_value,
+                )
+                for prediction, input_value in zip(response, batch_values)
+            )
+
         self._last_prediction_was_batch = isinstance(value, list)
         self._last_prediction = predictions[-1]
         return predictions if isinstance(value, list) else predictions[0]
 
     def predictions(
         self,
-        limit: int = 100,
+        limit: int | None = None,
         offset: int = 0,
         tag: str | None = None,
         sort: list[tuple[Literal["anomaly_score", "confidence", "timestamp"], Literal["asc", "desc"]]] = [],
         expected_label_match: bool | None = None,
+        batch_size: int = 100,
     ) -> list[ClassificationPrediction]:
         """
         Get a list of predictions made by this model
 
         Params:
-            limit: Optional maximum number of predictions to return
+            limit: Maximum number of predictions to return. If `None`, returns all predictions
+                by automatically paginating through results.
             offset: Optional offset of the first prediction to return
             tag: Optional tag to filter predictions by
             sort: Optional list of columns and directions to sort the predictions by.
                 Predictions can be sorted by `timestamp` or `confidence`.
             expected_label_match: Optional filter to only include predictions where the expected
                 label does (`True`) or doesn't (`False`) match the predicted label
+            batch_size: Number of predictions to fetch in a single API call
 
         Returns:
             List of label predictions
 
         Examples:
+            Get all predictions with a specific tag:
+            >>> predictions = model.predictions(tag="evaluation")
+
             Get the last 3 predictions:
             >>> predictions = model.predictions(limit=3, sort=[("timestamp", "desc")])
             [
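The asynchronous variant above mirrors the synchronous batching loop exactly; each slice is awaited in turn rather than dispatched concurrently. A hedged sketch of calling it (the method name `predict_async` is assumed for illustration, since the diff does not show the `def` line, and `model` is hypothetical):

    import asyncio

    async def classify_all(model, texts: list[str]):
        # Slices of 100 values are awaited one after another, not in parallel.
        return await model.predict_async(texts, batch_size=100)

    # predictions = asyncio.run(classify_all(model, ["Do you love soup?", "Are cats cute?"]))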
@@ -737,33 +788,61 @@ class ClassificationModel:
             >>> predictions = model.predictions(expected_label_match=False)
             [ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy', expected_label: 0})]
         """
+        if batch_size <= 0 or batch_size > 500:
+            raise ValueError("batch_size must be between 1 and 500")
+        if limit == 0:
+            return []
+
         client = OrcaClient._resolve_client()
-        predictions = client.POST(
-            "/telemetry/prediction",
-            json={
+        all_predictions: list[ClassificationPrediction] = []
+
+        if limit is not None and limit < batch_size:
+            pages = [(offset, limit)]
+        else:
+            # automatically paginate the requests if necessary
+            total = client.POST(
+                "/telemetry/prediction/count",
+                json={
+                    "model_id": self.id,
+                    "tag": tag,
+                    "expected_label_match": expected_label_match,
+                },
+            )
+            max_limit = max(total - offset, 0)
+            limit = min(limit, max_limit) if limit is not None else max_limit
+            pages = [(o, min(batch_size, limit - (o - offset))) for o in range(offset, offset + limit, batch_size)]
+
+        for current_offset, current_limit in pages:
+            request_json: ListPredictionsRequest = {
                 "model_id": self.id,
-                "limit": limit,
-                "offset": offset,
-                "sort": [list(sort_item) for sort_item in sort],
+                "limit": current_limit,
+                "offset": current_offset,
                 "tag": tag,
                 "expected_label_match": expected_label_match,
-            },
-        )
-        return [
-            ClassificationPrediction(
-                prediction_id=prediction["prediction_id"],
-                label=prediction["label"],
-                label_name=prediction["label_name"],
-                score=None,
-                confidence=prediction["confidence"],
-                anomaly_score=prediction["anomaly_score"],
-                memoryset=self.memoryset,
-                model=self,
-                telemetry=prediction,
+            }
+            if sort:
+                request_json["sort"] = sort
+            response = client.POST(
+                "/telemetry/prediction",
+                json=request_json,
             )
-            for prediction in predictions
-            if "label" in prediction
-        ]
+            all_predictions.extend(
+                ClassificationPrediction(
+                    prediction_id=prediction["prediction_id"],
+                    label=prediction["label"],
+                    label_name=prediction["label_name"],
+                    score=None,
+                    confidence=prediction["confidence"],
+                    anomaly_score=prediction["anomaly_score"],
+                    memoryset=self.memoryset,
+                    model=self,
+                    telemetry=prediction,
+                )
+                for prediction in response
+                if "label" in prediction
+            )
+
+        return all_predictions
 
     def _evaluate_datasource(
         self,
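The pagination above first fetches a matching-row count, clamps `limit` to what is actually available past `offset`, and then splits the window into `(offset, limit)` pages of at most `batch_size` rows. A worked sketch of just that page arithmetic, extracted from the list comprehension in the hunk:

    def plan_pages(offset: int, limit: int, batch_size: int) -> list[tuple[int, int]]:
        # One (offset, limit) pair per API call, mirroring the comprehension above.
        return [(o, min(batch_size, limit - (o - offset))) for o in range(offset, offset + limit, batch_size)]

    # 230 predictions starting at offset 10, fetched 100 at a time:
    assert plan_pages(offset=10, limit=230, batch_size=100) == [(10, 100), (110, 100), (210, 30)]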
@@ -861,6 +940,7 @@ class ClassificationModel:
             logits=[p.logits for p in predictions],
             anomaly_scores=[p.anomaly_score for p in predictions],
             include_curves=True,
+            include_confusion_matrix=True,
         )
 
     @overload
@@ -218,13 +218,17 @@ def test_evaluate_dataset_with_nones_raises_error(classification_model: Classifi
 
 
 def test_evaluate_with_telemetry(classification_model: ClassificationModel, eval_dataset: Dataset):
-    result = classification_model.evaluate(eval_dataset, record_predictions=True, tags={"test"})
+    result = classification_model.evaluate(eval_dataset, record_predictions=True, tags={"test"}, batch_size=2)
     assert result is not None
     assert isinstance(result, ClassificationMetrics)
-    predictions = classification_model.predictions(tag="test")
+    predictions = classification_model.predictions(tag="test", batch_size=100, sort=[("timestamp", "asc")])
     assert len(predictions) == 4
     assert all(p.tags == {"test"} for p in predictions)
-    assert all(p.expected_label == l for p, l in zip(predictions, eval_dataset["label"]))
+    prediction_expected_labels = [p.expected_label if p.expected_label is not None else -1 for p in predictions]
+    eval_expected_labels = list(eval_dataset["label"])
+    assert all(
+        p == l for p, l in zip(prediction_expected_labels, eval_expected_labels)
+    ), f"Prediction expected labels: {prediction_expected_labels} do not match eval expected labels: {eval_expected_labels}"
 
 
 def test_evaluate_with_partition_column_dataset(partitioned_classification_model: ClassificationModel):
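The updated test above exercises the new parameter end to end: it forwards `batch_size=2` to `evaluate`, which presumably routes through the batched prediction path, and then pages the recorded predictions back out with `predictions`. A condensed sketch of the same round trip (the `classification_model` and `eval_dataset` fixtures are hypothetical; the dataset here is assumed to have four rows):

    result = classification_model.evaluate(
        eval_dataset,
        record_predictions=True,  # persist per-row predictions as telemetry
        tags={"test"},            # tag them so they can be fetched back
        batch_size=2,             # force multiple API calls even on a tiny dataset
    )
    predictions = classification_model.predictions(tag="test", sort=[("timestamp", "asc")])
    assert len(predictions) == 4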
@@ -361,7 +365,7 @@ def test_evaluate_with_partition_column_datasource(partitioned_classification_mo
 
 
 def test_predict(classification_model: ClassificationModel, label_names: list[str]):
-    predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"])
+    predictions = classification_model.predict(["Do you love soup?", "Are cats cute?"], batch_size=1)
     assert len(predictions) == 2
     assert predictions[0].prediction_id is not None
     assert predictions[1].prediction_id is not None