orca-sdk 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,13 +9,16 @@ from datasets import Dataset

  from ._shared.metrics import ClassificationMetrics, calculate_classification_metrics
  from ._utils.common import UNSET, CreateMode, DropMode
+ from .async_client import OrcaAsyncClient
  from .client import (
  BootstrapClassificationModelMeta,
  BootstrapClassificationModelResult,
+ ClassificationEvaluationRequest,
  ClassificationModelMetadata,
+ OrcaClient,
+ PostClassificationModelByModelNameOrIdEvaluationParams,
  PredictiveModelUpdate,
  RACHeadType,
- orca_api,
  )
  from .datasource import Datasource
  from .job import Job
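The import changes above are the visible surface of the 0.1.4 client refactor: the module-level orca_api singleton is no longer imported, and the module now pulls in the OrcaClient and OrcaAsyncClient classes plus the generated ClassificationEvaluationRequest and PostClassificationModelByModelNameOrIdEvaluationParams request types. A minimal sketch of what this suggests for downstream imports, assuming the classes live where the relative imports point (orca_sdk.client and orca_sdk.async_client) and are not necessarily re-exported from the package root:

    from orca_sdk.client import OrcaClient              # sync client, resolved per call in 0.1.4
    from orca_sdk.async_client import OrcaAsyncClient   # async counterpart used by the new apredict()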
@@ -199,7 +202,12 @@ class ClassificationModel:
  raise ValueError(f"Model with name {name} already exists")
  elif if_exists == "open":
  existing = cls.open(name)
- for attribute in {"head_type", "memory_lookup_count", "num_classes", "min_memory_weight"}:
+ for attribute in {
+ "head_type",
+ "memory_lookup_count",
+ "num_classes",
+ "min_memory_weight",
+ }:
  local_attribute = locals()[attribute]
  existing_attribute = getattr(existing, attribute)
  if local_attribute is not None and local_attribute != existing_attribute:
@@ -211,7 +219,8 @@ class ClassificationModel:

  return existing

- metadata = orca_api.POST(
+ client = OrcaClient._resolve_client()
+ metadata = client.POST(
  "/classification_model",
  json={
  "name": name,
@@ -240,7 +249,8 @@ class ClassificationModel:
  Raises:
  LookupError: If the classification model does not exist
  """
- return cls(orca_api.GET("/classification_model/{name_or_id}", params={"name_or_id": name}))
+ client = OrcaClient._resolve_client()
+ return cls(client.GET("/classification_model/{name_or_id}", params={"name_or_id": name}))

  @classmethod
  def exists(cls, name_or_id: str) -> bool:
@@ -267,7 +277,8 @@ class ClassificationModel:
  Returns:
  List of handles to all classification models in the OrcaCloud
  """
- return [cls(metadata) for metadata in orca_api.GET("/classification_model")]
+ client = OrcaClient._resolve_client()
+ return [cls(metadata) for metadata in client.GET("/classification_model")]

  @classmethod
  def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
@@ -286,7 +297,8 @@ class ClassificationModel:
  LookupError: If the classification model does not exist and if_not_exists is `"error"`
  """
  try:
- orca_api.DELETE("/classification_model/{name_or_id}", params={"name_or_id": name_or_id})
+ client = OrcaClient._resolve_client()
+ client.DELETE("/classification_model/{name_or_id}", params={"name_or_id": name_or_id})
  logging.info(f"Deleted model {name_or_id}")
  except LookupError:
  if if_not_exists == "error":
@@ -322,7 +334,8 @@ class ClassificationModel:
  update["description"] = description
  if locked is not UNSET:
  update["locked"] = locked
- orca_api.PATCH("/classification_model/{name_or_id}", params={"name_or_id": self.id}, json=update)
+ client = OrcaClient._resolve_client()
+ client.PATCH("/classification_model/{name_or_id}", params={"name_or_id": self.id}, json=update)
  self.refresh()

  def lock(self) -> None:
@@ -344,6 +357,8 @@ class ClassificationModel:
  prompt: str | None = None,
  use_lookup_cache: bool = True,
  timeout_seconds: int = 10,
+ ignore_unlabeled: bool = False,
+ use_gpu: bool = True,
  ) -> list[ClassificationPrediction]:
  pass

@@ -358,6 +373,8 @@ class ClassificationModel:
  prompt: str | None = None,
  use_lookup_cache: bool = True,
  timeout_seconds: int = 10,
+ ignore_unlabeled: bool = False,
+ use_gpu: bool = True,
  ) -> ClassificationPrediction:
  pass

@@ -371,6 +388,8 @@ class ClassificationModel:
  prompt: str | None = None,
  use_lookup_cache: bool = True,
  timeout_seconds: int = 10,
+ ignore_unlabeled: bool = False,
+ use_gpu: bool = True,
  ) -> list[ClassificationPrediction] | ClassificationPrediction:
  """
  Predict label(s) for the given input value(s) grounded in similar memories
@@ -389,6 +408,9 @@ class ClassificationModel:
  prompt: Optional prompt to use for instruction-tuned embedding models
  use_lookup_cache: Whether to use cached lookup results for faster predictions
  timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
+ ignore_unlabeled: If True, only use labeled memories during lookup.
+ If False (default), allow unlabeled memories when necessary.
+ use_gpu: Whether to use GPU for the prediction (defaults to True)

  Returns:
  Label prediction or list of label predictions
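Both new keyword arguments are additive and default to the previous behavior (ignore_unlabeled=False, use_gpu=True), so existing call sites are unaffected. A small usage sketch based only on the signature and docstring above; `model` is an existing ClassificationModel handle:

    predictions = model.predict(
        ["I am happy", "I am sad"],
        ignore_unlabeled=True,  # restrict memory lookup to labeled memories
        use_gpu=False,          # route to the non-GPU prediction endpoint (see the hunk below)
    )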
@@ -421,6 +443,159 @@ class ClassificationModel:
  _parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
  ]

+ if any(_is_metric_column(filter[0]) for filter in filters):
+ raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")
+
+ if isinstance(expected_labels, int):
+ expected_labels = [expected_labels]
+ elif isinstance(expected_labels, str):
+ expected_labels = [self.memoryset.label_names.index(expected_labels)]
+ elif isinstance(expected_labels, list):
+ expected_labels = [
+ self.memoryset.label_names.index(label) if isinstance(label, str) else label
+ for label in expected_labels
+ ]
+
+ if use_gpu:
+ endpoint = "/gpu/classification_model/{name_or_id}/prediction"
+ else:
+ endpoint = "/classification_model/{name_or_id}/prediction"
+
+ telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
+ client = OrcaClient._resolve_client()
+ response = client.POST(
+ endpoint,
+ params={"name_or_id": self.id},
+ json={
+ "input_values": value if isinstance(value, list) else [value],
+ "memoryset_override_name_or_id": self._memoryset_override_id,
+ "expected_labels": expected_labels,
+ "tags": list(tags or set()),
+ "save_telemetry": telemetry_on,
+ "save_telemetry_synchronously": telemetry_sync,
+ "filters": cast(list[FilterItem], parsed_filters),
+ "prompt": prompt,
+ "use_lookup_cache": use_lookup_cache,
+ "ignore_unlabeled": ignore_unlabeled,
+ },
+ timeout=timeout_seconds,
+ )
+
+ if telemetry_on and any(p["prediction_id"] is None for p in response):
+ raise RuntimeError("Failed to save prediction to database.")
+
+ predictions = [
+ ClassificationPrediction(
+ prediction_id=prediction["prediction_id"],
+ label=prediction["label"],
+ label_name=prediction["label_name"],
+ score=None,
+ confidence=prediction["confidence"],
+ anomaly_score=prediction["anomaly_score"],
+ memoryset=self.memoryset,
+ model=self,
+ logits=prediction["logits"],
+ input_value=input_value,
+ )
+ for prediction, input_value in zip(response, value if isinstance(value, list) else [value])
+ ]
+ self._last_prediction_was_batch = isinstance(value, list)
+ self._last_prediction = predictions[-1]
+ return predictions if isinstance(value, list) else predictions[0]
+
+ @overload
+ async def apredict(
+ self,
+ value: list[str],
+ expected_labels: list[int] | None = None,
+ filters: list[FilterItemTuple] = [],
+ tags: set[str] | None = None,
+ save_telemetry: TelemetryMode = "on",
+ prompt: str | None = None,
+ use_lookup_cache: bool = True,
+ timeout_seconds: int = 10,
+ ignore_unlabeled: bool = False,
+ ) -> list[ClassificationPrediction]:
+ pass
+
+ @overload
+ async def apredict(
+ self,
+ value: str,
+ expected_labels: int | None = None,
+ filters: list[FilterItemTuple] = [],
+ tags: set[str] | None = None,
+ save_telemetry: TelemetryMode = "on",
+ prompt: str | None = None,
+ use_lookup_cache: bool = True,
+ timeout_seconds: int = 10,
+ ignore_unlabeled: bool = False,
+ ) -> ClassificationPrediction:
+ pass
+
+ async def apredict(
+ self,
+ value: list[str] | str,
+ expected_labels: list[int] | list[str] | int | str | None = None,
+ filters: list[FilterItemTuple] = [],
+ tags: set[str] | None = None,
+ save_telemetry: TelemetryMode = "on",
+ prompt: str | None = None,
+ use_lookup_cache: bool = True,
+ timeout_seconds: int = 10,
+ ignore_unlabeled: bool = False,
+ ) -> list[ClassificationPrediction] | ClassificationPrediction:
+ """
+ Asynchronously predict label(s) for the given input value(s) grounded in similar memories
+
+ Params:
+ value: Value(s) to get predict the labels of
+ expected_labels: Expected label(s) for the given input to record for model evaluation
+ filters: Optional filters to apply during memory lookup
+ tags: Tags to add to the prediction(s)
+ save_telemetry: Whether to save telemetry for the prediction(s). One of
+ * `"off"`: Do not save telemetry
+ * `"on"`: Save telemetry asynchronously unless the `ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY`
+ environment variable is set.
+ * `"sync"`: Save telemetry synchronously
+ * `"async"`: Save telemetry asynchronously
+ prompt: Optional prompt to use for instruction-tuned embedding models
+ use_lookup_cache: Whether to use cached lookup results for faster predictions
+ timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
+ ignore_unlabeled: If True, only use labeled memories during lookup.
+ If False (default), allow unlabeled memories when necessary.
+
+ Returns:
+ Label prediction or list of label predictions.
+
+ Raises:
+ ValueError: If timeout_seconds is not a positive integer
+ TimeoutError: If the request times out after the specified duration
+
+ Examples:
+ Predict the label for a single value:
+ >>> prediction = await model.apredict("I am happy", tags={"test"})
+ ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
+
+ Predict the labels for a list of values:
+ >>> predictions = await model.apredict(["I am happy", "I am sad"], expected_labels=[1, 0])
+ [
+ ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
+ ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
+ ]
+
+ Using a prompt with an instruction-tuned embedding model:
+ >>> prediction = await model.apredict("I am happy", prompt="Represent this text for sentiment classification:")
+ ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
+ """
+
+ if timeout_seconds <= 0:
+ raise ValueError("timeout_seconds must be a positive integer")
+
+ parsed_filters = [
+ _parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
+ ]
+
  if any(_is_metric_column(filter[0]) for filter in filters):
  raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")

@@ -435,7 +610,8 @@ class ClassificationModel:
  ]

  telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
- response = orca_api.POST(
+ client = OrcaAsyncClient._resolve_client()
+ response = await client.POST(
  "/gpu/classification_model/{name_or_id}/prediction",
  params={"name_or_id": self.id},
  json={
@@ -448,6 +624,7 @@ class ClassificationModel:
  "filters": cast(list[FilterItem], parsed_filters),
  "prompt": prompt,
  "use_lookup_cache": use_lookup_cache,
+ "ignore_unlabeled": ignore_unlabeled,
  },
  timeout=timeout_seconds,
  )
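The new apredict() coroutine mirrors predict() but resolves an OrcaAsyncClient and awaits the request; unlike the sync path, it has no use_gpu flag and, in this version, always posts to the /gpu/... prediction endpoint shown above. A short sketch of driving it from a script, following the docstring examples; `model` is an existing ClassificationModel handle:

    import asyncio

    async def main() -> None:
        single = await model.apredict("I am happy", tags={"demo"})
        batch = await model.apredict(["I am happy", "I am sad"], expected_labels=[1, 0])
        print(single.label_name, [p.label_name for p in batch])

    asyncio.run(main())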
@@ -515,7 +692,8 @@ class ClassificationModel:
  >>> predictions = model.predictions(expected_label_match=False)
  [ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy', expected_label: 0})]
  """
- predictions = orca_api.POST(
+ client = OrcaClient._resolve_client()
+ predictions = client.POST(
  "/telemetry/prediction",
  json={
  "model_id": self.id,
@@ -549,9 +727,12 @@ class ClassificationModel:
  label_column: str,
  record_predictions: bool,
  tags: set[str] | None,
+ subsample: int | float | None,
  background: bool = False,
+ ignore_unlabeled: bool = False,
  ) -> ClassificationMetrics | Job[ClassificationMetrics]:
- response = orca_api.POST(
+ client = OrcaClient._resolve_client()
+ response = client.POST(
  "/classification_model/{model_name_or_id}/evaluation",
  params={"model_name_or_id": self.id},
  json={
@@ -561,13 +742,16 @@ class ClassificationModel:
  "memoryset_override_name_or_id": self._memoryset_override_id,
  "record_telemetry": record_predictions,
  "telemetry_tags": list(tags) if tags else None,
+ "subsample": subsample,
+ "ignore_unlabeled": ignore_unlabeled,
  },
  )

  def get_value():
- res = orca_api.GET(
- "/classification_model/{model_name_or_id}/evaluation/{task_id}",
- params={"model_name_or_id": self.id, "task_id": response["task_id"]},
+ client = OrcaClient._resolve_client()
+ res = client.GET(
+ "/classification_model/{model_name_or_id}/evaluation/{job_id}",
+ params={"model_name_or_id": self.id, "job_id": response["job_id"]},
  )
  assert res["result"] is not None
  return ClassificationMetrics(
@@ -584,7 +768,7 @@ class ClassificationModel:
  roc_curve=res["result"].get("roc_curve"),
  )

- job = Job(response["task_id"], get_value)
+ job = Job(response["job_id"], get_value)
  return job if background else job.result()

  def _evaluate_dataset(
@@ -595,6 +779,7 @@ class ClassificationModel:
  record_predictions: bool,
  tags: set[str],
  batch_size: int,
+ ignore_unlabeled: bool,
  ) -> ClassificationMetrics:
  if len(dataset) == 0:
  raise ValueError("Evaluation dataset cannot be empty")
@@ -610,6 +795,7 @@ class ClassificationModel:
  expected_labels=dataset[i : i + batch_size][label_column],
  tags=tags,
  save_telemetry="sync" if record_predictions else "off",
+ ignore_unlabeled=ignore_unlabeled,
  )
  ]

@@ -630,7 +816,9 @@ class ClassificationModel:
  record_predictions: bool = False,
  tags: set[str] = {"evaluation"},
  batch_size: int = 100,
+ subsample: int | float | None = None,
  background: Literal[True],
+ ignore_unlabeled: bool = False,
  ) -> Job[ClassificationMetrics]:
  pass

@@ -644,7 +832,9 @@ class ClassificationModel:
  record_predictions: bool = False,
  tags: set[str] = {"evaluation"},
  batch_size: int = 100,
+ subsample: int | float | None = None,
  background: Literal[False] = False,
+ ignore_unlabeled: bool = False,
  ) -> ClassificationMetrics:
  pass

@@ -657,7 +847,9 @@ class ClassificationModel:
  record_predictions: bool = False,
  tags: set[str] = {"evaluation"},
  batch_size: int = 100,
+ subsample: int | float | None = None,
  background: bool = False,
+ ignore_unlabeled: bool = False,
  ) -> ClassificationMetrics | Job[ClassificationMetrics]:
  """
  Evaluate the classification model on a given dataset or datasource
@@ -669,7 +861,9 @@ class ClassificationModel:
  record_predictions: Whether to record [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s for analysis
  tags: Optional tags to add to the recorded [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s
  batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)
+ subsample: Optional number (int) of rows to sample or fraction (float in (0, 1]) of data to sample for evaluation.
  background: Whether to run the operation in the background and return a job handle
+ ignore_unlabeled: If True, only use labeled memories during lookup. If False (default), allow unlabeled memories

  Returns:
  EvaluationResult containing metrics including accuracy, F1 score, ROC AUC, PR AUC, and anomaly score statistics
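subsample and ignore_unlabeled are forwarded to the evaluation endpoint for Datasource inputs; note that the Dataset code path (_evaluate_dataset, earlier in this diff) only receives ignore_unlabeled, not subsample. A usage sketch based on the signatures above; `datasource` is an existing Datasource handle:

    job = model.evaluate(
        datasource,
        subsample=0.2,          # evaluate on ~20% of rows (an int would request an absolute row count)
        ignore_unlabeled=True,
        background=True,        # returns a Job[ClassificationMetrics] handle
    )
    metrics = job.result()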
@@ -691,7 +885,9 @@ class ClassificationModel:
  label_column=label_column,
  record_predictions=record_predictions,
  tags=tags,
+ subsample=subsample,
  background=background,
+ ignore_unlabeled=ignore_unlabeled,
  )
  elif isinstance(data, Dataset):
  return self._evaluate_dataset(
@@ -701,6 +897,7 @@ class ClassificationModel:
  record_predictions=record_predictions,
  tags=tags,
  batch_size=batch_size,
+ ignore_unlabeled=ignore_unlabeled,
  )
  else:
  raise ValueError(f"Invalid data type: {type(data)}")
@@ -773,7 +970,8 @@ class ClassificationModel:
  ValueError: If the value does not match previous value types for the category, or is a
  [`float`][float] that is not between `-1.0` and `+1.0`.
  """
- orca_api.PUT(
+ client = OrcaClient._resolve_client()
+ client.PUT(
  "/telemetry/prediction/feedback",
  json=[
  _parse_feedback(f) for f in (cast(list[dict], [feedback]) if isinstance(feedback, dict) else feedback)
@@ -788,7 +986,8 @@ class ClassificationModel:
  num_examples_per_label: int,
  background: bool = False,
  ) -> Job[BootstrappedClassificationModel] | BootstrappedClassificationModel:
- response = orca_api.POST(
+ client = OrcaClient._resolve_client()
+ response = client.POST(
  "/agents/bootstrap_classification_model",
  json={
  "model_description": model_description,
@@ -799,11 +998,10 @@ class ClassificationModel:
  )

  def get_result() -> BootstrappedClassificationModel:
- res = orca_api.GET(
- "/agents/bootstrap_classification_model/{task_id}", params={"task_id": response["task_id"]}
- )
+ client = OrcaClient._resolve_client()
+ res = client.GET("/agents/bootstrap_classification_model/{job_id}", params={"job_id": response["job_id"]})
  assert res["result"] is not None
  return BootstrappedClassificationModel(res["result"])

- job = Job(response["task_id"], get_result)
+ job = Job(response["job_id"], get_result)
  return job if background else job.result()
@@ -53,9 +53,10 @@ def test_create_model_already_exists_return(readonly_memoryset, classification_m
  assert new_model.memory_lookup_count == 3


- def test_create_model_unauthenticated(unauthenticated, readonly_memoryset: LabeledMemoryset):
- with pytest.raises(ValueError, match="Invalid API key"):
- ClassificationModel.create("test_model", readonly_memoryset)
+ def test_create_model_unauthenticated(unauthenticated_client, readonly_memoryset: LabeledMemoryset):
+ with unauthenticated_client.use():
+ with pytest.raises(ValueError, match="Invalid API key"):
+ ClassificationModel.create("test_model", readonly_memoryset)


  def test_get_model(classification_model: ClassificationModel):
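The remaining hunks update the test module: the ambient `unauthenticated`/`unauthorized` fixtures are replaced by `*_client` fixtures whose use() context manager scopes the SDK calls under test to that client. The fixture definitions are not part of this diff; a purely hypothetical sketch of what such a fixture could look like under the assumed OrcaClient(api_key=...) constructor:

    import pytest
    from orca_sdk.client import OrcaClient

    @pytest.fixture
    def unauthenticated_client():
        # Hypothetical: a client whose key the backend rejects with "Invalid API key".
        return OrcaClient(api_key="invalid")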
@@ -68,9 +69,10 @@ def test_get_model(classification_model: ClassificationModel):
  assert fetched_model == classification_model


- def test_get_model_unauthenticated(unauthenticated):
- with pytest.raises(ValueError, match="Invalid API key"):
- ClassificationModel.open("test_model")
+ def test_get_model_unauthenticated(unauthenticated_client):
+ with unauthenticated_client.use():
+ with pytest.raises(ValueError, match="Invalid API key"):
+ ClassificationModel.open("test_model")


  def test_get_model_invalid_input():
@@ -83,9 +85,10 @@ def test_get_model_not_found():
  ClassificationModel.open(str(uuid4()))


- def test_get_model_unauthorized(unauthorized, classification_model: ClassificationModel):
- with pytest.raises(LookupError):
- ClassificationModel.open(classification_model.name)
+ def test_get_model_unauthorized(unauthorized_client, classification_model: ClassificationModel):
+ with unauthorized_client.use():
+ with pytest.raises(LookupError):
+ ClassificationModel.open(classification_model.name)


  def test_list_models(classification_model: ClassificationModel):
@@ -94,13 +97,15 @@ def test_list_models(classification_model: ClassificationModel):
  assert any(model.name == model.name for model in models)


- def test_list_models_unauthenticated(unauthenticated):
- with pytest.raises(ValueError, match="Invalid API key"):
- ClassificationModel.all()
+ def test_list_models_unauthenticated(unauthenticated_client):
+ with unauthenticated_client.use():
+ with pytest.raises(ValueError, match="Invalid API key"):
+ ClassificationModel.all()


- def test_list_models_unauthorized(unauthorized, classification_model: ClassificationModel):
- assert ClassificationModel.all() == []
+ def test_list_models_unauthorized(unauthorized_client, classification_model: ClassificationModel):
+ with unauthorized_client.use():
+ assert ClassificationModel.all() == []


  def test_update_model_attributes(classification_model: ClassificationModel):
@@ -131,9 +136,10 @@ def test_delete_model(readonly_memoryset: LabeledMemoryset):
  ClassificationModel.open("model_to_delete")


- def test_delete_model_unauthenticated(unauthenticated, classification_model: ClassificationModel):
- with pytest.raises(ValueError, match="Invalid API key"):
- ClassificationModel.drop(classification_model.name)
+ def test_delete_model_unauthenticated(unauthenticated_client, classification_model: ClassificationModel):
+ with unauthenticated_client.use():
+ with pytest.raises(ValueError, match="Invalid API key"):
+ ClassificationModel.drop(classification_model.name)


  def test_delete_model_not_found():
@@ -143,9 +149,10 @@ def test_delete_model_not_found():
  ClassificationModel.drop(str(uuid4()), if_not_exists="ignore")


- def test_delete_model_unauthorized(unauthorized, classification_model: ClassificationModel):
- with pytest.raises(LookupError):
- ClassificationModel.drop(classification_model.name)
+ def test_delete_model_unauthorized(unauthorized_client, classification_model: ClassificationModel):
+ with unauthorized_client.use():
+ with pytest.raises(LookupError):
+ ClassificationModel.drop(classification_model.name)


  def test_delete_memoryset_before_model_constraint_violation(hf_dataset):
@@ -254,14 +261,16 @@ def test_predict_disable_telemetry(classification_model: ClassificationModel, la
  assert 0 <= predictions[1].confidence <= 1


- def test_predict_unauthenticated(unauthenticated, classification_model: ClassificationModel):
- with pytest.raises(ValueError, match="Invalid API key"):
- classification_model.predict(["Do you love soup?", "Are cats cute?"])
+ def test_predict_unauthenticated(unauthenticated_client, classification_model: ClassificationModel):
+ with unauthenticated_client.use():
+ with pytest.raises(ValueError, match="Invalid API key"):
+ classification_model.predict(["Do you love soup?", "Are cats cute?"])


- def test_predict_unauthorized(unauthorized, classification_model: ClassificationModel):
- with pytest.raises(LookupError):
- classification_model.predict(["Do you love soup?", "Are cats cute?"])
+ def test_predict_unauthorized(unauthorized_client, classification_model: ClassificationModel):
+ with unauthorized_client.use():
+ with pytest.raises(LookupError):
+ classification_model.predict(["Do you love soup?", "Are cats cute?"])


  def test_predict_constraint_violation(readonly_memoryset: LabeledMemoryset):
@@ -396,7 +405,7 @@ def test_last_prediction_with_single(classification_model: ClassificationModel):
  def test_explain(writable_memoryset: LabeledMemoryset):

  writable_memoryset.analyze(
- {"name": "neighbor", "neighbor_counts": [1, 3]},
+ {"name": "distribution", "neighbor_counts": [1, 3]},
  lookup_count=3,
  )

@@ -430,7 +439,7 @@ def test_action_recommendation(writable_memoryset: LabeledMemoryset):
  """Test getting action recommendations for predictions"""

  writable_memoryset.analyze(
- {"name": "neighbor", "neighbor_counts": [1, 3]},
+ {"name": "distribution", "neighbor_counts": [1, 3]},
  lookup_count=3,
  )

@@ -494,3 +503,62 @@ def test_predict_with_prompt(classification_model: ClassificationModel):
  # Both should work and return valid predictions
  assert prediction_with_prompt.label is not None
  assert prediction_without_prompt.label is not None
+
+
+ @pytest.mark.asyncio
+ async def test_predict_async_single(classification_model: ClassificationModel, label_names: list[str]):
+ """Test async prediction with a single value"""
+ prediction = await classification_model.apredict("Do you love soup?")
+ assert isinstance(prediction, ClassificationPrediction)
+ assert prediction.prediction_id is not None
+ assert prediction.label == 0
+ assert prediction.label_name == label_names[0]
+ assert 0 <= prediction.confidence <= 1
+ assert prediction.logits is not None
+ assert len(prediction.logits) == 2
+
+
+ @pytest.mark.asyncio
+ async def test_predict_async_batch(classification_model: ClassificationModel, label_names: list[str]):
+ """Test async prediction with a batch of values"""
+ predictions = await classification_model.apredict(["Do you love soup?", "Are cats cute?"])
+ assert len(predictions) == 2
+ assert predictions[0].prediction_id is not None
+ assert predictions[1].prediction_id is not None
+ assert predictions[0].label == 0
+ assert predictions[0].label_name == label_names[0]
+ assert 0 <= predictions[0].confidence <= 1
+ assert predictions[1].label == 1
+ assert predictions[1].label_name == label_names[1]
+ assert 0 <= predictions[1].confidence <= 1
+
+
+ @pytest.mark.asyncio
+ async def test_predict_async_with_expected_labels(classification_model: ClassificationModel):
+ """Test async prediction with expected labels"""
+ prediction = await classification_model.apredict("Do you love soup?", expected_labels=1)
+ assert prediction.expected_label == 1
+
+
+ @pytest.mark.asyncio
+ async def test_predict_async_disable_telemetry(classification_model: ClassificationModel, label_names: list[str]):
+ """Test async prediction with telemetry disabled"""
+ predictions = await classification_model.apredict(["Do you love soup?", "Are cats cute?"], save_telemetry="off")
+ assert len(predictions) == 2
+ assert predictions[0].prediction_id is None
+ assert predictions[1].prediction_id is None
+ assert predictions[0].label == 0
+ assert predictions[0].label_name == label_names[0]
+ assert 0 <= predictions[0].confidence <= 1
+ assert predictions[1].label == 1
+ assert predictions[1].label_name == label_names[1]
+ assert 0 <= predictions[1].confidence <= 1
+
+
+ @pytest.mark.asyncio
+ async def test_predict_async_with_filters(classification_model: ClassificationModel):
+ """Test async prediction with filters"""
+ # there are no memories with label 0 and key g2, so we force a wrong prediction
+ filtered_prediction = await classification_model.apredict("I love soup", filters=[("key", "==", "g2")])
+ assert filtered_prediction.label == 1
+ assert filtered_prediction.label_name == "cats"