orca-sdk 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,14 +3,7 @@ from __future__ import annotations
3
3
  import logging
4
4
  from contextlib import contextmanager
5
5
  from datetime import datetime
6
- from typing import (
7
- Any,
8
- Generator,
9
- Iterable,
10
- Literal,
11
- cast,
12
- overload,
13
- )
6
+ from typing import Any, Generator, Iterable, Literal, cast, overload
14
7
 
15
8
  from datasets import Dataset
16
9
 
@@ -20,8 +13,11 @@ from .async_client import OrcaAsyncClient
20
13
  from .client import (
21
14
  BootstrapClassificationModelMeta,
22
15
  BootstrapClassificationModelResult,
16
+ ClassificationEvaluationRequest,
23
17
  ClassificationModelMetadata,
18
+ ClassificationPredictionRequest,
24
19
  OrcaClient,
20
+ PostClassificationModelByModelNameOrIdEvaluationParams,
25
21
  PredictiveModelUpdate,
26
22
  RACHeadType,
27
23
  )
@@ -207,7 +203,12 @@ class ClassificationModel:
207
203
  raise ValueError(f"Model with name {name} already exists")
208
204
  elif if_exists == "open":
209
205
  existing = cls.open(name)
210
- for attribute in {"head_type", "memory_lookup_count", "num_classes", "min_memory_weight"}:
206
+ for attribute in {
207
+ "head_type",
208
+ "memory_lookup_count",
209
+ "num_classes",
210
+ "min_memory_weight",
211
+ }:
211
212
  local_attribute = locals()[attribute]
212
213
  existing_attribute = getattr(existing, attribute)
213
214
  if local_attribute is not None and local_attribute != existing_attribute:
@@ -357,6 +358,12 @@ class ClassificationModel:
357
358
  prompt: str | None = None,
358
359
  use_lookup_cache: bool = True,
359
360
  timeout_seconds: int = 10,
361
+ ignore_unlabeled: bool = False,
362
+ partition_id: str | list[str | None] | None = None,
363
+ partition_filter_mode: Literal[
364
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
365
+ ] = "include_global",
366
+ use_gpu: bool = True,
360
367
  ) -> list[ClassificationPrediction]:
361
368
  pass
362
369
 
@@ -371,6 +378,12 @@ class ClassificationModel:
371
378
  prompt: str | None = None,
372
379
  use_lookup_cache: bool = True,
373
380
  timeout_seconds: int = 10,
381
+ ignore_unlabeled: bool = False,
382
+ partition_id: str | None = None,
383
+ partition_filter_mode: Literal[
384
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
385
+ ] = "include_global",
386
+ use_gpu: bool = True,
374
387
  ) -> ClassificationPrediction:
375
388
  pass
376
389
 
@@ -384,6 +397,12 @@ class ClassificationModel:
384
397
  prompt: str | None = None,
385
398
  use_lookup_cache: bool = True,
386
399
  timeout_seconds: int = 10,
400
+ ignore_unlabeled: bool = False,
401
+ partition_id: str | None | list[str | None] = None,
402
+ partition_filter_mode: Literal[
403
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
404
+ ] = "include_global",
405
+ use_gpu: bool = True,
387
406
  ) -> list[ClassificationPrediction] | ClassificationPrediction:
388
407
  """
389
408
  Predict label(s) for the given input value(s) grounded in similar memories
@@ -402,6 +421,15 @@ class ClassificationModel:
402
421
  prompt: Optional prompt to use for instruction-tuned embedding models
403
422
  use_lookup_cache: Whether to use cached lookup results for faster predictions
404
423
  timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
424
+ ignore_unlabeled: If True, only use labeled memories during lookup.
425
+ If False (default), allow unlabeled memories when necessary.
426
+ partition_id: Optional partition ID(s) to use during memory lookup
427
+ partition_filter_mode: Optional partition filter mode to use for the prediction(s). One of
428
+ * `"ignore_partitions"`: Ignore partitions
429
+ * `"include_global"`: Include global memories
430
+ * `"exclude_global"`: Exclude global memories
431
+ * `"only_global"`: Only include global memories
432
+ use_gpu: Whether to use GPU for the prediction (defaults to True)
405
433
 
406
434
  Returns:
407
435
  Label prediction or list of label predictions
@@ -447,22 +475,33 @@ class ClassificationModel:
447
475
  for label in expected_labels
448
476
  ]
449
477
 
478
+ if use_gpu:
479
+ endpoint = "/gpu/classification_model/{name_or_id}/prediction"
480
+ else:
481
+ endpoint = "/classification_model/{name_or_id}/prediction"
482
+
450
483
  telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
451
484
  client = OrcaClient._resolve_client()
485
+ request_json: ClassificationPredictionRequest = {
486
+ "input_values": value if isinstance(value, list) else [value],
487
+ "memoryset_override_name_or_id": self._memoryset_override_id,
488
+ "expected_labels": expected_labels,
489
+ "tags": list(tags or set()),
490
+ "save_telemetry": telemetry_on,
491
+ "save_telemetry_synchronously": telemetry_sync,
492
+ "filters": cast(list[FilterItem], parsed_filters),
493
+ "prompt": prompt,
494
+ "use_lookup_cache": use_lookup_cache,
495
+ "ignore_unlabeled": ignore_unlabeled,
496
+ "partition_filter_mode": partition_filter_mode,
497
+ }
498
+ # Don't send partition_ids when partition_filter_mode is "ignore_partitions"
499
+ if partition_filter_mode != "ignore_partitions":
500
+ request_json["partition_ids"] = partition_id
452
501
  response = client.POST(
453
- "/gpu/classification_model/{name_or_id}/prediction",
502
+ endpoint,
454
503
  params={"name_or_id": self.id},
455
- json={
456
- "input_values": value if isinstance(value, list) else [value],
457
- "memoryset_override_name_or_id": self._memoryset_override_id,
458
- "expected_labels": expected_labels,
459
- "tags": list(tags or set()),
460
- "save_telemetry": telemetry_on,
461
- "save_telemetry_synchronously": telemetry_sync,
462
- "filters": cast(list[FilterItem], parsed_filters),
463
- "prompt": prompt,
464
- "use_lookup_cache": use_lookup_cache,
465
- },
504
+ json=request_json,
466
505
  timeout=timeout_seconds,
467
506
  )
468
507
 
@@ -499,6 +538,11 @@ class ClassificationModel:
499
538
  prompt: str | None = None,
500
539
  use_lookup_cache: bool = True,
501
540
  timeout_seconds: int = 10,
541
+ ignore_unlabeled: bool = False,
542
+ partition_id: str | list[str | None] | None = None,
543
+ partition_filter_mode: Literal[
544
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
545
+ ] = "include_global",
502
546
  ) -> list[ClassificationPrediction]:
503
547
  pass
504
548
 
@@ -513,6 +557,11 @@ class ClassificationModel:
513
557
  prompt: str | None = None,
514
558
  use_lookup_cache: bool = True,
515
559
  timeout_seconds: int = 10,
560
+ ignore_unlabeled: bool = False,
561
+ partition_id: str | None = None,
562
+ partition_filter_mode: Literal[
563
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
564
+ ] = "include_global",
516
565
  ) -> ClassificationPrediction:
517
566
  pass
518
567
 
@@ -526,6 +575,11 @@ class ClassificationModel:
526
575
  prompt: str | None = None,
527
576
  use_lookup_cache: bool = True,
528
577
  timeout_seconds: int = 10,
578
+ ignore_unlabeled: bool = False,
579
+ partition_id: str | None | list[str | None] = None,
580
+ partition_filter_mode: Literal[
581
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
582
+ ] = "include_global",
529
583
  ) -> list[ClassificationPrediction] | ClassificationPrediction:
530
584
  """
531
585
  Asynchronously predict label(s) for the given input value(s) grounded in similar memories
@@ -544,7 +598,14 @@ class ClassificationModel:
544
598
  prompt: Optional prompt to use for instruction-tuned embedding models
545
599
  use_lookup_cache: Whether to use cached lookup results for faster predictions
546
600
  timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
547
-
601
+ ignore_unlabeled: If True, only use labeled memories during lookup.
602
+ If False (default), allow unlabeled memories when necessary.
603
+ partition_id: Optional partition ID(s) to use during memory lookup
604
+ partition_filter_mode: Optional partition filter mode to use for the prediction(s). One of
605
+ * `"ignore_partitions"`: Ignore partitions
606
+ * `"include_global"`: Include global memories
607
+ * `"exclude_global"`: Exclude global memories
608
+ * `"only_global"`: Only include global memories
548
609
  Returns:
549
610
  Label prediction or list of label predictions.
550
611
 
@@ -591,20 +652,26 @@ class ClassificationModel:
591
652
 
592
653
  telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
593
654
  client = OrcaAsyncClient._resolve_client()
655
+ request_json: ClassificationPredictionRequest = {
656
+ "input_values": value if isinstance(value, list) else [value],
657
+ "memoryset_override_name_or_id": self._memoryset_override_id,
658
+ "expected_labels": expected_labels,
659
+ "tags": list(tags or set()),
660
+ "save_telemetry": telemetry_on,
661
+ "save_telemetry_synchronously": telemetry_sync,
662
+ "filters": cast(list[FilterItem], parsed_filters),
663
+ "prompt": prompt,
664
+ "use_lookup_cache": use_lookup_cache,
665
+ "ignore_unlabeled": ignore_unlabeled,
666
+ "partition_filter_mode": partition_filter_mode,
667
+ }
668
+ # Don't send partition_ids when partition_filter_mode is "ignore_partitions"
669
+ if partition_filter_mode != "ignore_partitions":
670
+ request_json["partition_ids"] = partition_id
594
671
  response = await client.POST(
595
672
  "/gpu/classification_model/{name_or_id}/prediction",
596
673
  params={"name_or_id": self.id},
597
- json={
598
- "input_values": value if isinstance(value, list) else [value],
599
- "memoryset_override_name_or_id": self._memoryset_override_id,
600
- "expected_labels": expected_labels,
601
- "tags": list(tags or set()),
602
- "save_telemetry": telemetry_on,
603
- "save_telemetry_synchronously": telemetry_sync,
604
- "filters": cast(list[FilterItem], parsed_filters),
605
- "prompt": prompt,
606
- "use_lookup_cache": use_lookup_cache,
607
- },
674
+ json=request_json,
608
675
  timeout=timeout_seconds,
609
676
  )
610
677
 
@@ -706,7 +773,13 @@ class ClassificationModel:
706
773
  label_column: str,
707
774
  record_predictions: bool,
708
775
  tags: set[str] | None,
776
+ subsample: int | float | None,
709
777
  background: bool = False,
778
+ ignore_unlabeled: bool = False,
779
+ partition_column: str | None = None,
780
+ partition_filter_mode: Literal[
781
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
782
+ ] = "include_global",
710
783
  ) -> ClassificationMetrics | Job[ClassificationMetrics]:
711
784
  client = OrcaClient._resolve_client()
712
785
  response = client.POST(
@@ -719,14 +792,18 @@ class ClassificationModel:
719
792
  "memoryset_override_name_or_id": self._memoryset_override_id,
720
793
  "record_telemetry": record_predictions,
721
794
  "telemetry_tags": list(tags) if tags else None,
795
+ "subsample": subsample,
796
+ "ignore_unlabeled": ignore_unlabeled,
797
+ "datasource_partition_column": partition_column,
798
+ "partition_filter_mode": partition_filter_mode,
722
799
  },
723
800
  )
724
801
 
725
802
  def get_value():
726
803
  client = OrcaClient._resolve_client()
727
804
  res = client.GET(
728
- "/classification_model/{model_name_or_id}/evaluation/{task_id}",
729
- params={"model_name_or_id": self.id, "task_id": response["task_id"]},
805
+ "/classification_model/{model_name_or_id}/evaluation/{job_id}",
806
+ params={"model_name_or_id": self.id, "job_id": response["job_id"]},
730
807
  )
731
808
  assert res["result"] is not None
732
809
  return ClassificationMetrics(
@@ -743,7 +820,7 @@ class ClassificationModel:
743
820
  roc_curve=res["result"].get("roc_curve"),
744
821
  )
745
822
 
746
- job = Job(response["task_id"], get_value)
823
+ job = Job(response["job_id"], get_value)
747
824
  return job if background else job.result()
748
825
 
749
826
  def _evaluate_dataset(
@@ -754,6 +831,11 @@ class ClassificationModel:
754
831
  record_predictions: bool,
755
832
  tags: set[str],
756
833
  batch_size: int,
834
+ ignore_unlabeled: bool,
835
+ partition_column: str | None = None,
836
+ partition_filter_mode: Literal[
837
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
838
+ ] = "include_global",
757
839
  ) -> ClassificationMetrics:
758
840
  if len(dataset) == 0:
759
841
  raise ValueError("Evaluation dataset cannot be empty")
@@ -769,6 +851,9 @@ class ClassificationModel:
769
851
  expected_labels=dataset[i : i + batch_size][label_column],
770
852
  tags=tags,
771
853
  save_telemetry="sync" if record_predictions else "off",
854
+ ignore_unlabeled=ignore_unlabeled,
855
+ partition_id=dataset[i : i + batch_size][partition_column] if partition_column else None,
856
+ partition_filter_mode=partition_filter_mode,
772
857
  )
773
858
  ]
774
859
 
@@ -786,10 +871,16 @@ class ClassificationModel:
786
871
  *,
787
872
  value_column: str = "value",
788
873
  label_column: str = "label",
874
+ partition_column: str | None = None,
789
875
  record_predictions: bool = False,
790
876
  tags: set[str] = {"evaluation"},
791
877
  batch_size: int = 100,
878
+ subsample: int | float | None = None,
792
879
  background: Literal[True],
880
+ ignore_unlabeled: bool = False,
881
+ partition_filter_mode: Literal[
882
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
883
+ ] = "include_global",
793
884
  ) -> Job[ClassificationMetrics]:
794
885
  pass
795
886
 
@@ -800,10 +891,16 @@ class ClassificationModel:
800
891
  *,
801
892
  value_column: str = "value",
802
893
  label_column: str = "label",
894
+ partition_column: str | None = None,
803
895
  record_predictions: bool = False,
804
896
  tags: set[str] = {"evaluation"},
805
897
  batch_size: int = 100,
898
+ subsample: int | float | None = None,
806
899
  background: Literal[False] = False,
900
+ ignore_unlabeled: bool = False,
901
+ partition_filter_mode: Literal[
902
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
903
+ ] = "include_global",
807
904
  ) -> ClassificationMetrics:
808
905
  pass
809
906
 
@@ -813,10 +910,16 @@ class ClassificationModel:
813
910
  *,
814
911
  value_column: str = "value",
815
912
  label_column: str = "label",
913
+ partition_column: str | None = None,
816
914
  record_predictions: bool = False,
817
915
  tags: set[str] = {"evaluation"},
818
916
  batch_size: int = 100,
917
+ subsample: int | float | None = None,
819
918
  background: bool = False,
919
+ ignore_unlabeled: bool = False,
920
+ partition_filter_mode: Literal[
921
+ "ignore_partitions", "include_global", "exclude_global", "only_global"
922
+ ] = "include_global",
820
923
  ) -> ClassificationMetrics | Job[ClassificationMetrics]:
821
924
  """
822
925
  Evaluate the classification model on a given dataset or datasource
@@ -825,11 +928,18 @@ class ClassificationModel:
825
928
  data: Dataset or Datasource to evaluate the model on
826
929
  value_column: Name of the column that contains the input values to the model
827
930
  label_column: Name of the column containing the expected labels
931
+ partition_column: Optional name of the column that contains the partition IDs
828
932
  record_predictions: Whether to record [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s for analysis
829
933
  tags: Optional tags to add to the recorded [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s
830
934
  batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)
935
+ subsample: Optional number (int) of rows to sample or fraction (float in (0, 1]) of data to sample for evaluation.
831
936
  background: Whether to run the operation in the background and return a job handle
832
-
937
+ ignore_unlabeled: If True, only use labeled memories during lookup. If False (default), allow unlabeled memories
938
+ partition_filter_mode: Optional partition filter mode to use for the evaluation. One of
939
+ * `"ignore_partitions"`: Ignore partitions
940
+ * `"include_global"`: Include global memories
941
+ * `"exclude_global"`: Exclude global memories
942
+ * `"only_global"`: Only include global memories
833
943
  Returns:
834
944
  EvaluationResult containing metrics including accuracy, F1 score, ROC AUC, PR AUC, and anomaly score statistics
835
945
 
@@ -850,7 +960,11 @@ class ClassificationModel:
850
960
  label_column=label_column,
851
961
  record_predictions=record_predictions,
852
962
  tags=tags,
963
+ subsample=subsample,
853
964
  background=background,
965
+ ignore_unlabeled=ignore_unlabeled,
966
+ partition_column=partition_column,
967
+ partition_filter_mode=partition_filter_mode,
854
968
  )
855
969
  elif isinstance(data, Dataset):
856
970
  return self._evaluate_dataset(
@@ -860,6 +974,9 @@ class ClassificationModel:
860
974
  record_predictions=record_predictions,
861
975
  tags=tags,
862
976
  batch_size=batch_size,
977
+ ignore_unlabeled=ignore_unlabeled,
978
+ partition_column=partition_column,
979
+ partition_filter_mode=partition_filter_mode,
863
980
  )
864
981
  else:
865
982
  raise ValueError(f"Invalid data type: {type(data)}")
@@ -961,11 +1078,9 @@ class ClassificationModel:
961
1078
 
962
1079
  def get_result() -> BootstrappedClassificationModel:
963
1080
  client = OrcaClient._resolve_client()
964
- res = client.GET(
965
- "/agents/bootstrap_classification_model/{task_id}", params={"task_id": response["task_id"]}
966
- )
1081
+ res = client.GET("/agents/bootstrap_classification_model/{job_id}", params={"job_id": response["job_id"]})
967
1082
  assert res["result"] is not None
968
1083
  return BootstrappedClassificationModel(res["result"])
969
1084
 
970
- job = Job(response["task_id"], get_result)
1085
+ job = Job(response["job_id"], get_result)
971
1086
  return job if background else job.result()