orca-sdk 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
@@ -3,14 +3,7 @@ from __future__ import annotations
3
3
  import logging
4
4
  from contextlib import contextmanager
5
5
  from datetime import datetime
6
- from typing import (
7
- Any,
8
- Generator,
9
- Iterable,
10
- Literal,
11
- cast,
12
- overload,
13
- )
6
+ from typing import Any, Generator, Iterable, Literal, cast, overload
14
7
 
15
8
  from datasets import Dataset
16
9
 
@@ -20,8 +13,10 @@ from .async_client import OrcaAsyncClient
20
13
  from .client import (
21
14
  BootstrapClassificationModelMeta,
22
15
  BootstrapClassificationModelResult,
16
+ ClassificationEvaluationRequest,
23
17
  ClassificationModelMetadata,
24
18
  OrcaClient,
19
+ PostClassificationModelByModelNameOrIdEvaluationParams,
25
20
  PredictiveModelUpdate,
26
21
  RACHeadType,
27
22
  )
@@ -207,7 +202,12 @@ class ClassificationModel:
207
202
  raise ValueError(f"Model with name {name} already exists")
208
203
  elif if_exists == "open":
209
204
  existing = cls.open(name)
210
- for attribute in {"head_type", "memory_lookup_count", "num_classes", "min_memory_weight"}:
205
+ for attribute in {
206
+ "head_type",
207
+ "memory_lookup_count",
208
+ "num_classes",
209
+ "min_memory_weight",
210
+ }:
211
211
  local_attribute = locals()[attribute]
212
212
  existing_attribute = getattr(existing, attribute)
213
213
  if local_attribute is not None and local_attribute != existing_attribute:
@@ -357,6 +357,8 @@ class ClassificationModel:
357
357
  prompt: str | None = None,
358
358
  use_lookup_cache: bool = True,
359
359
  timeout_seconds: int = 10,
360
+ ignore_unlabeled: bool = False,
361
+ use_gpu: bool = True,
360
362
  ) -> list[ClassificationPrediction]:
361
363
  pass
362
364
 
@@ -371,6 +373,8 @@ class ClassificationModel:
371
373
  prompt: str | None = None,
372
374
  use_lookup_cache: bool = True,
373
375
  timeout_seconds: int = 10,
376
+ ignore_unlabeled: bool = False,
377
+ use_gpu: bool = True,
374
378
  ) -> ClassificationPrediction:
375
379
  pass
376
380
 
@@ -384,6 +388,8 @@ class ClassificationModel:
384
388
  prompt: str | None = None,
385
389
  use_lookup_cache: bool = True,
386
390
  timeout_seconds: int = 10,
391
+ ignore_unlabeled: bool = False,
392
+ use_gpu: bool = True,
387
393
  ) -> list[ClassificationPrediction] | ClassificationPrediction:
388
394
  """
389
395
  Predict label(s) for the given input value(s) grounded in similar memories
@@ -402,6 +408,9 @@ class ClassificationModel:
402
408
  prompt: Optional prompt to use for instruction-tuned embedding models
403
409
  use_lookup_cache: Whether to use cached lookup results for faster predictions
404
410
  timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
411
+ ignore_unlabeled: If True, only use labeled memories during lookup.
412
+ If False (default), allow unlabeled memories when necessary.
413
+ use_gpu: Whether to use GPU for the prediction (defaults to True)
405
414
 
406
415
  Returns:
407
416
  Label prediction or list of label predictions
@@ -447,10 +456,15 @@ class ClassificationModel:
447
456
  for label in expected_labels
448
457
  ]
449
458
 
459
+ if use_gpu:
460
+ endpoint = "/gpu/classification_model/{name_or_id}/prediction"
461
+ else:
462
+ endpoint = "/classification_model/{name_or_id}/prediction"
463
+
450
464
  telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
451
465
  client = OrcaClient._resolve_client()
452
466
  response = client.POST(
453
- "/gpu/classification_model/{name_or_id}/prediction",
467
+ endpoint,
454
468
  params={"name_or_id": self.id},
455
469
  json={
456
470
  "input_values": value if isinstance(value, list) else [value],
@@ -462,6 +476,7 @@ class ClassificationModel:
462
476
  "filters": cast(list[FilterItem], parsed_filters),
463
477
  "prompt": prompt,
464
478
  "use_lookup_cache": use_lookup_cache,
479
+ "ignore_unlabeled": ignore_unlabeled,
465
480
  },
466
481
  timeout=timeout_seconds,
467
482
  )
@@ -499,6 +514,7 @@ class ClassificationModel:
499
514
  prompt: str | None = None,
500
515
  use_lookup_cache: bool = True,
501
516
  timeout_seconds: int = 10,
517
+ ignore_unlabeled: bool = False,
502
518
  ) -> list[ClassificationPrediction]:
503
519
  pass
504
520
 
@@ -513,6 +529,7 @@ class ClassificationModel:
513
529
  prompt: str | None = None,
514
530
  use_lookup_cache: bool = True,
515
531
  timeout_seconds: int = 10,
532
+ ignore_unlabeled: bool = False,
516
533
  ) -> ClassificationPrediction:
517
534
  pass
518
535
 
@@ -526,6 +543,7 @@ class ClassificationModel:
526
543
  prompt: str | None = None,
527
544
  use_lookup_cache: bool = True,
528
545
  timeout_seconds: int = 10,
546
+ ignore_unlabeled: bool = False,
529
547
  ) -> list[ClassificationPrediction] | ClassificationPrediction:
530
548
  """
531
549
  Asynchronously predict label(s) for the given input value(s) grounded in similar memories
@@ -544,6 +562,8 @@ class ClassificationModel:
544
562
  prompt: Optional prompt to use for instruction-tuned embedding models
545
563
  use_lookup_cache: Whether to use cached lookup results for faster predictions
546
564
  timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
565
+ ignore_unlabeled: If True, only use labeled memories during lookup.
566
+ If False (default), allow unlabeled memories when necessary.
547
567
 
548
568
  Returns:
549
569
  Label prediction or list of label predictions.
@@ -604,6 +624,7 @@ class ClassificationModel:
604
624
  "filters": cast(list[FilterItem], parsed_filters),
605
625
  "prompt": prompt,
606
626
  "use_lookup_cache": use_lookup_cache,
627
+ "ignore_unlabeled": ignore_unlabeled,
607
628
  },
608
629
  timeout=timeout_seconds,
609
630
  )
@@ -706,7 +727,9 @@ class ClassificationModel:
706
727
  label_column: str,
707
728
  record_predictions: bool,
708
729
  tags: set[str] | None,
730
+ subsample: int | float | None,
709
731
  background: bool = False,
732
+ ignore_unlabeled: bool = False,
710
733
  ) -> ClassificationMetrics | Job[ClassificationMetrics]:
711
734
  client = OrcaClient._resolve_client()
712
735
  response = client.POST(
@@ -719,14 +742,16 @@ class ClassificationModel:
719
742
  "memoryset_override_name_or_id": self._memoryset_override_id,
720
743
  "record_telemetry": record_predictions,
721
744
  "telemetry_tags": list(tags) if tags else None,
745
+ "subsample": subsample,
746
+ "ignore_unlabeled": ignore_unlabeled,
722
747
  },
723
748
  )
724
749
 
725
750
  def get_value():
726
751
  client = OrcaClient._resolve_client()
727
752
  res = client.GET(
728
- "/classification_model/{model_name_or_id}/evaluation/{task_id}",
729
- params={"model_name_or_id": self.id, "task_id": response["task_id"]},
753
+ "/classification_model/{model_name_or_id}/evaluation/{job_id}",
754
+ params={"model_name_or_id": self.id, "job_id": response["job_id"]},
730
755
  )
731
756
  assert res["result"] is not None
732
757
  return ClassificationMetrics(
@@ -743,7 +768,7 @@ class ClassificationModel:
743
768
  roc_curve=res["result"].get("roc_curve"),
744
769
  )
745
770
 
746
- job = Job(response["task_id"], get_value)
771
+ job = Job(response["job_id"], get_value)
747
772
  return job if background else job.result()
748
773
 
749
774
  def _evaluate_dataset(
@@ -754,6 +779,7 @@ class ClassificationModel:
754
779
  record_predictions: bool,
755
780
  tags: set[str],
756
781
  batch_size: int,
782
+ ignore_unlabeled: bool,
757
783
  ) -> ClassificationMetrics:
758
784
  if len(dataset) == 0:
759
785
  raise ValueError("Evaluation dataset cannot be empty")
@@ -769,6 +795,7 @@ class ClassificationModel:
769
795
  expected_labels=dataset[i : i + batch_size][label_column],
770
796
  tags=tags,
771
797
  save_telemetry="sync" if record_predictions else "off",
798
+ ignore_unlabeled=ignore_unlabeled,
772
799
  )
773
800
  ]
774
801
 
@@ -789,7 +816,9 @@ class ClassificationModel:
789
816
  record_predictions: bool = False,
790
817
  tags: set[str] = {"evaluation"},
791
818
  batch_size: int = 100,
819
+ subsample: int | float | None = None,
792
820
  background: Literal[True],
821
+ ignore_unlabeled: bool = False,
793
822
  ) -> Job[ClassificationMetrics]:
794
823
  pass
795
824
 
@@ -803,7 +832,9 @@ class ClassificationModel:
803
832
  record_predictions: bool = False,
804
833
  tags: set[str] = {"evaluation"},
805
834
  batch_size: int = 100,
835
+ subsample: int | float | None = None,
806
836
  background: Literal[False] = False,
837
+ ignore_unlabeled: bool = False,
807
838
  ) -> ClassificationMetrics:
808
839
  pass
809
840
 
@@ -816,7 +847,9 @@ class ClassificationModel:
816
847
  record_predictions: bool = False,
817
848
  tags: set[str] = {"evaluation"},
818
849
  batch_size: int = 100,
850
+ subsample: int | float | None = None,
819
851
  background: bool = False,
852
+ ignore_unlabeled: bool = False,
820
853
  ) -> ClassificationMetrics | Job[ClassificationMetrics]:
821
854
  """
822
855
  Evaluate the classification model on a given dataset or datasource
@@ -828,7 +861,9 @@ class ClassificationModel:
828
861
  record_predictions: Whether to record [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s for analysis
829
862
  tags: Optional tags to add to the recorded [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s
830
863
  batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)
864
+ subsample: Optional number (int) of rows to sample or fraction (float in (0, 1]) of data to sample for evaluation.
831
865
  background: Whether to run the operation in the background and return a job handle
866
+ ignore_unlabeled: If True, only use labeled memories during lookup. If False (default), allow unlabeled memories
832
867
 
833
868
  Returns:
834
869
  EvaluationResult containing metrics including accuracy, F1 score, ROC AUC, PR AUC, and anomaly score statistics
@@ -850,7 +885,9 @@ class ClassificationModel:
850
885
  label_column=label_column,
851
886
  record_predictions=record_predictions,
852
887
  tags=tags,
888
+ subsample=subsample,
853
889
  background=background,
890
+ ignore_unlabeled=ignore_unlabeled,
854
891
  )
855
892
  elif isinstance(data, Dataset):
856
893
  return self._evaluate_dataset(
@@ -860,6 +897,7 @@ class ClassificationModel:
860
897
  record_predictions=record_predictions,
861
898
  tags=tags,
862
899
  batch_size=batch_size,
900
+ ignore_unlabeled=ignore_unlabeled,
863
901
  )
864
902
  else:
865
903
  raise ValueError(f"Invalid data type: {type(data)}")
@@ -961,11 +999,9 @@ class ClassificationModel:
961
999
 
962
1000
  def get_result() -> BootstrappedClassificationModel:
963
1001
  client = OrcaClient._resolve_client()
964
- res = client.GET(
965
- "/agents/bootstrap_classification_model/{task_id}", params={"task_id": response["task_id"]}
966
- )
1002
+ res = client.GET("/agents/bootstrap_classification_model/{job_id}", params={"job_id": response["job_id"]})
967
1003
  assert res["result"] is not None
968
1004
  return BootstrappedClassificationModel(res["result"])
969
1005
 
970
- job = Job(response["task_id"], get_result)
1006
+ job = Job(response["job_id"], get_result)
971
1007
  return job if background else job.result()