orca-sdk 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -3,19 +3,27 @@ from __future__ import annotations
 import logging
 from contextlib import contextmanager
 from datetime import datetime
-from typing import Any, Generator, Iterable, Literal, cast, overload
+from typing import (
+    Any,
+    Generator,
+    Iterable,
+    Literal,
+    cast,
+    overload,
+)

 from datasets import Dataset

 from ._shared.metrics import ClassificationMetrics, calculate_classification_metrics
 from ._utils.common import UNSET, CreateMode, DropMode
+from .async_client import OrcaAsyncClient
 from .client import (
     BootstrapClassificationModelMeta,
     BootstrapClassificationModelResult,
     ClassificationModelMetadata,
+    OrcaClient,
     PredictiveModelUpdate,
     RACHeadType,
-    orca_api,
 )
 from .datasource import Datasource
 from .job import Job
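The functional change driving most hunks in this file is visible in the imports above: the module-level orca_api singleton is removed, and every call site now resolves a client via OrcaClient._resolve_client() (or OrcaAsyncClient._resolve_client() on the new async path). The .client and .async_client modules themselves are not part of this excerpt, so the following is only a minimal sketch of the pattern, assuming resolution returns an "active" client that a use() context manager can swap in temporarily (as the updated tests further down suggest); every name in the sketch other than those two method names is hypothetical.

    # Hypothetical sketch only -- the real OrcaClient/OrcaAsyncClient internals are not in this diff.
    from __future__ import annotations

    from contextlib import contextmanager
    from typing import Any, Generator


    class SketchClient:
        """Stand-in for OrcaClient: call sites resolve the active client instead of importing a singleton."""

        _active: "SketchClient | None" = None

        def __init__(self, api_key: str | None = None):
            self.api_key = api_key

        @classmethod
        def _resolve_client(cls) -> "SketchClient":
            # Fall back to a default instance when no client has been activated explicitly.
            if cls._active is None:
                cls._active = cls()
            return cls._active

        @contextmanager
        def use(self) -> Generator[None, None, None]:
            # Temporarily make this instance the one _resolve_client() returns,
            # mirroring the `with unauthenticated_client.use():` pattern in the tests below.
            previous, SketchClient._active = SketchClient._active, self
            try:
                yield
            finally:
                SketchClient._active = previous

        def GET(self, path: str, params: dict[str, Any] | None = None) -> Any:
            raise NotImplementedError("HTTP call elided in this sketch")

The practical effect of a design like this is that tests or multi-tenant callers can scope credentials to a block instead of mutating global state.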
@@ -211,7 +219,8 @@ class ClassificationModel:
 
         return existing
 
-        metadata = orca_api.POST(
+        client = OrcaClient._resolve_client()
+        metadata = client.POST(
             "/classification_model",
             json={
                 "name": name,
@@ -240,7 +249,8 @@ class ClassificationModel:
         Raises:
             LookupError: If the classification model does not exist
         """
-        return cls(orca_api.GET("/classification_model/{name_or_id}", params={"name_or_id": name}))
+        client = OrcaClient._resolve_client()
+        return cls(client.GET("/classification_model/{name_or_id}", params={"name_or_id": name}))
 
     @classmethod
     def exists(cls, name_or_id: str) -> bool:
@@ -267,7 +277,8 @@ class ClassificationModel:
         Returns:
             List of handles to all classification models in the OrcaCloud
         """
-        return [cls(metadata) for metadata in orca_api.GET("/classification_model")]
+        client = OrcaClient._resolve_client()
+        return [cls(metadata) for metadata in client.GET("/classification_model")]
 
     @classmethod
     def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
@@ -286,7 +297,8 @@ class ClassificationModel:
             LookupError: If the classification model does not exist and if_not_exists is `"error"`
         """
         try:
-            orca_api.DELETE("/classification_model/{name_or_id}", params={"name_or_id": name_or_id})
+            client = OrcaClient._resolve_client()
+            client.DELETE("/classification_model/{name_or_id}", params={"name_or_id": name_or_id})
             logging.info(f"Deleted model {name_or_id}")
         except LookupError:
             if if_not_exists == "error":
@@ -322,7 +334,8 @@ class ClassificationModel:
             update["description"] = description
         if locked is not UNSET:
             update["locked"] = locked
-        orca_api.PATCH("/classification_model/{name_or_id}", params={"name_or_id": self.id}, json=update)
+        client = OrcaClient._resolve_client()
+        client.PATCH("/classification_model/{name_or_id}", params={"name_or_id": self.id}, json=update)
         self.refresh()
 
     def lock(self) -> None:
@@ -435,7 +448,150 @@ class ClassificationModel:
         ]
 
         telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
-        response = orca_api.POST(
+        client = OrcaClient._resolve_client()
+        response = client.POST(
+            "/gpu/classification_model/{name_or_id}/prediction",
+            params={"name_or_id": self.id},
+            json={
+                "input_values": value if isinstance(value, list) else [value],
+                "memoryset_override_name_or_id": self._memoryset_override_id,
+                "expected_labels": expected_labels,
+                "tags": list(tags or set()),
+                "save_telemetry": telemetry_on,
+                "save_telemetry_synchronously": telemetry_sync,
+                "filters": cast(list[FilterItem], parsed_filters),
+                "prompt": prompt,
+                "use_lookup_cache": use_lookup_cache,
+            },
+            timeout=timeout_seconds,
+        )
+
+        if telemetry_on and any(p["prediction_id"] is None for p in response):
+            raise RuntimeError("Failed to save prediction to database.")
+
+        predictions = [
+            ClassificationPrediction(
+                prediction_id=prediction["prediction_id"],
+                label=prediction["label"],
+                label_name=prediction["label_name"],
+                score=None,
+                confidence=prediction["confidence"],
+                anomaly_score=prediction["anomaly_score"],
+                memoryset=self.memoryset,
+                model=self,
+                logits=prediction["logits"],
+                input_value=input_value,
+            )
+            for prediction, input_value in zip(response, value if isinstance(value, list) else [value])
+        ]
+        self._last_prediction_was_batch = isinstance(value, list)
+        self._last_prediction = predictions[-1]
+        return predictions if isinstance(value, list) else predictions[0]
+
+    @overload
+    async def apredict(
+        self,
+        value: list[str],
+        expected_labels: list[int] | None = None,
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: TelemetryMode = "on",
+        prompt: str | None = None,
+        use_lookup_cache: bool = True,
+        timeout_seconds: int = 10,
+    ) -> list[ClassificationPrediction]:
+        pass
+
+    @overload
+    async def apredict(
+        self,
+        value: str,
+        expected_labels: int | None = None,
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: TelemetryMode = "on",
+        prompt: str | None = None,
+        use_lookup_cache: bool = True,
+        timeout_seconds: int = 10,
+    ) -> ClassificationPrediction:
+        pass
+
+    async def apredict(
+        self,
+        value: list[str] | str,
+        expected_labels: list[int] | list[str] | int | str | None = None,
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: TelemetryMode = "on",
+        prompt: str | None = None,
+        use_lookup_cache: bool = True,
+        timeout_seconds: int = 10,
+    ) -> list[ClassificationPrediction] | ClassificationPrediction:
+        """
+        Asynchronously predict label(s) for the given input value(s) grounded in similar memories
+
+        Params:
+            value: Value(s) to get predict the labels of
+            expected_labels: Expected label(s) for the given input to record for model evaluation
+            filters: Optional filters to apply during memory lookup
+            tags: Tags to add to the prediction(s)
+            save_telemetry: Whether to save telemetry for the prediction(s). One of
+                * `"off"`: Do not save telemetry
+                * `"on"`: Save telemetry asynchronously unless the `ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY`
+                  environment variable is set.
+                * `"sync"`: Save telemetry synchronously
+                * `"async"`: Save telemetry asynchronously
+            prompt: Optional prompt to use for instruction-tuned embedding models
+            use_lookup_cache: Whether to use cached lookup results for faster predictions
+            timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
+
+        Returns:
+            Label prediction or list of label predictions.
+
+        Raises:
+            ValueError: If timeout_seconds is not a positive integer
+            TimeoutError: If the request times out after the specified duration
+
+        Examples:
+            Predict the label for a single value:
+            >>> prediction = await model.apredict("I am happy", tags={"test"})
+            ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
+
+            Predict the labels for a list of values:
+            >>> predictions = await model.apredict(["I am happy", "I am sad"], expected_labels=[1, 0])
+            [
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
+                ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
+            ]
+
+            Using a prompt with an instruction-tuned embedding model:
+            >>> prediction = await model.apredict("I am happy", prompt="Represent this text for sentiment classification:")
+            ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
+        """
+
+        if timeout_seconds <= 0:
+            raise ValueError("timeout_seconds must be a positive integer")
+
+        parsed_filters = [
+            _parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
+        ]
+
+        if any(_is_metric_column(filter[0]) for filter in filters):
+            raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")
+
+        if isinstance(expected_labels, int):
+            expected_labels = [expected_labels]
+        elif isinstance(expected_labels, str):
+            expected_labels = [self.memoryset.label_names.index(expected_labels)]
+        elif isinstance(expected_labels, list):
+            expected_labels = [
+                self.memoryset.label_names.index(label) if isinstance(label, str) else label
+                for label in expected_labels
+            ]
+
+        telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
+        client = OrcaAsyncClient._resolve_client()
+        response = await client.POST(
             "/gpu/classification_model/{name_or_id}/prediction",
             params={"name_or_id": self.id},
             json={
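The trailing context lines of this hunk show that the new async path posts to the same /gpu/classification_model/{name_or_id}/prediction endpoint as the synchronous predict. As a quick orientation for readers of the docstring above, here is a minimal usage sketch of the new coroutine; the top-level import path orca_sdk and the model name "my_classifier" are placeholders/assumptions, everything else follows the signature and docstring shown in this hunk.

    # Usage sketch for the new async prediction path; the model name is a placeholder.
    import asyncio

    from orca_sdk import ClassificationModel


    async def main() -> None:
        model = ClassificationModel.open("my_classifier")
        # A single value returns one ClassificationPrediction.
        single = await model.apredict("I am happy", tags={"demo"})
        # A list of values returns a list; telemetry can be skipped per call.
        batch = await model.apredict(["I am happy", "I am sad"], save_telemetry="off")
        print(single.label_name, [p.label_name for p in batch])


    asyncio.run(main())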
@@ -515,7 +671,8 @@ class ClassificationModel:
             >>> predictions = model.predictions(expected_label_match=False)
             [ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy', expected_label: 0})]
         """
-        predictions = orca_api.POST(
+        client = OrcaClient._resolve_client()
+        predictions = client.POST(
             "/telemetry/prediction",
             json={
                 "model_id": self.id,
@@ -551,7 +708,8 @@ class ClassificationModel:
         tags: set[str] | None,
         background: bool = False,
     ) -> ClassificationMetrics | Job[ClassificationMetrics]:
-        response = orca_api.POST(
+        client = OrcaClient._resolve_client()
+        response = client.POST(
             "/classification_model/{model_name_or_id}/evaluation",
             params={"model_name_or_id": self.id},
             json={
@@ -565,7 +723,8 @@ class ClassificationModel:
         )
 
         def get_value():
-            res = orca_api.GET(
+            client = OrcaClient._resolve_client()
+            res = client.GET(
                 "/classification_model/{model_name_or_id}/evaluation/{task_id}",
                 params={"model_name_or_id": self.id, "task_id": response["task_id"]},
             )
@@ -773,7 +932,8 @@ class ClassificationModel:
             ValueError: If the value does not match previous value types for the category, or is a
                 [`float`][float] that is not between `-1.0` and `+1.0`.
         """
-        orca_api.PUT(
+        client = OrcaClient._resolve_client()
+        client.PUT(
             "/telemetry/prediction/feedback",
             json=[
                 _parse_feedback(f) for f in (cast(list[dict], [feedback]) if isinstance(feedback, dict) else feedback)
@@ -788,7 +948,8 @@ class ClassificationModel:
         num_examples_per_label: int,
         background: bool = False,
     ) -> Job[BootstrappedClassificationModel] | BootstrappedClassificationModel:
-        response = orca_api.POST(
+        client = OrcaClient._resolve_client()
+        response = client.POST(
             "/agents/bootstrap_classification_model",
             json={
                 "model_description": model_description,
@@ -799,7 +960,8 @@ class ClassificationModel:
         )
 
         def get_result() -> BootstrappedClassificationModel:
-            res = orca_api.GET(
+            client = OrcaClient._resolve_client()
+            res = client.GET(
                 "/agents/bootstrap_classification_model/{task_id}", params={"task_id": response["task_id"]}
             )
             assert res["result"] is not None
@@ -53,9 +53,10 @@ def test_create_model_already_exists_return(readonly_memoryset, classification_m
     assert new_model.memory_lookup_count == 3
 
 
-def test_create_model_unauthenticated(unauthenticated, readonly_memoryset: LabeledMemoryset):
-    with pytest.raises(ValueError, match="Invalid API key"):
-        ClassificationModel.create("test_model", readonly_memoryset)
+def test_create_model_unauthenticated(unauthenticated_client, readonly_memoryset: LabeledMemoryset):
+    with unauthenticated_client.use():
+        with pytest.raises(ValueError, match="Invalid API key"):
+            ClassificationModel.create("test_model", readonly_memoryset)
 
 
 def test_get_model(classification_model: ClassificationModel):
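The authentication tests switch from the unauthenticated/unauthorized fixtures to unauthenticated_client/unauthorized_client objects that are activated explicitly with use(). The fixture definitions are not included in this diff; the snippet below is a hypothetical conftest sketch of the shape they imply, assuming the package's import root is orca_sdk and that OrcaClient accepts an api_key argument (both assumptions).

    # Hypothetical conftest.py sketch; the real fixtures live outside this diff.
    import pytest

    from orca_sdk.client import OrcaClient


    @pytest.fixture
    def unauthenticated_client() -> OrcaClient:
        # A client carrying a bad API key; tests only activate it inside
        # `with unauthenticated_client.use():`, leaving the default client untouched elsewhere.
        return OrcaClient(api_key="not-a-real-key")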
@@ -68,9 +69,10 @@ def test_get_model(classification_model: ClassificationModel):
     assert fetched_model == classification_model
 
 
-def test_get_model_unauthenticated(unauthenticated):
-    with pytest.raises(ValueError, match="Invalid API key"):
-        ClassificationModel.open("test_model")
+def test_get_model_unauthenticated(unauthenticated_client):
+    with unauthenticated_client.use():
+        with pytest.raises(ValueError, match="Invalid API key"):
+            ClassificationModel.open("test_model")
 
 
 def test_get_model_invalid_input():
@@ -83,9 +85,10 @@ def test_get_model_not_found():
         ClassificationModel.open(str(uuid4()))
 
 
-def test_get_model_unauthorized(unauthorized, classification_model: ClassificationModel):
-    with pytest.raises(LookupError):
-        ClassificationModel.open(classification_model.name)
+def test_get_model_unauthorized(unauthorized_client, classification_model: ClassificationModel):
+    with unauthorized_client.use():
+        with pytest.raises(LookupError):
+            ClassificationModel.open(classification_model.name)
 
 
 def test_list_models(classification_model: ClassificationModel):
@@ -94,13 +97,15 @@ def test_list_models(classification_model: ClassificationModel):
     assert any(model.name == model.name for model in models)
 
 
-def test_list_models_unauthenticated(unauthenticated):
-    with pytest.raises(ValueError, match="Invalid API key"):
-        ClassificationModel.all()
+def test_list_models_unauthenticated(unauthenticated_client):
+    with unauthenticated_client.use():
+        with pytest.raises(ValueError, match="Invalid API key"):
+            ClassificationModel.all()
 
 
-def test_list_models_unauthorized(unauthorized, classification_model: ClassificationModel):
-    assert ClassificationModel.all() == []
+def test_list_models_unauthorized(unauthorized_client, classification_model: ClassificationModel):
+    with unauthorized_client.use():
+        assert ClassificationModel.all() == []
 
 
 def test_update_model_attributes(classification_model: ClassificationModel):
@@ -131,9 +136,10 @@ def test_delete_model(readonly_memoryset: LabeledMemoryset):
         ClassificationModel.open("model_to_delete")
 
 
-def test_delete_model_unauthenticated(unauthenticated, classification_model: ClassificationModel):
-    with pytest.raises(ValueError, match="Invalid API key"):
-        ClassificationModel.drop(classification_model.name)
+def test_delete_model_unauthenticated(unauthenticated_client, classification_model: ClassificationModel):
+    with unauthenticated_client.use():
+        with pytest.raises(ValueError, match="Invalid API key"):
+            ClassificationModel.drop(classification_model.name)
 
 
 def test_delete_model_not_found():
@@ -143,9 +149,10 @@ def test_delete_model_not_found():
     ClassificationModel.drop(str(uuid4()), if_not_exists="ignore")
 
 
-def test_delete_model_unauthorized(unauthorized, classification_model: ClassificationModel):
-    with pytest.raises(LookupError):
-        ClassificationModel.drop(classification_model.name)
+def test_delete_model_unauthorized(unauthorized_client, classification_model: ClassificationModel):
+    with unauthorized_client.use():
+        with pytest.raises(LookupError):
+            ClassificationModel.drop(classification_model.name)
 
 
 def test_delete_memoryset_before_model_constraint_violation(hf_dataset):
@@ -254,14 +261,16 @@ def test_predict_disable_telemetry(classification_model: ClassificationModel, la
     assert 0 <= predictions[1].confidence <= 1
 
 
-def test_predict_unauthenticated(unauthenticated, classification_model: ClassificationModel):
-    with pytest.raises(ValueError, match="Invalid API key"):
-        classification_model.predict(["Do you love soup?", "Are cats cute?"])
+def test_predict_unauthenticated(unauthenticated_client, classification_model: ClassificationModel):
+    with unauthenticated_client.use():
+        with pytest.raises(ValueError, match="Invalid API key"):
+            classification_model.predict(["Do you love soup?", "Are cats cute?"])
 
 
-def test_predict_unauthorized(unauthorized, classification_model: ClassificationModel):
-    with pytest.raises(LookupError):
-        classification_model.predict(["Do you love soup?", "Are cats cute?"])
+def test_predict_unauthorized(unauthorized_client, classification_model: ClassificationModel):
+    with unauthorized_client.use():
+        with pytest.raises(LookupError):
+            classification_model.predict(["Do you love soup?", "Are cats cute?"])
 
 
 def test_predict_constraint_violation(readonly_memoryset: LabeledMemoryset):
@@ -396,7 +405,7 @@ def test_last_prediction_with_single(classification_model: ClassificationModel):
 def test_explain(writable_memoryset: LabeledMemoryset):
 
     writable_memoryset.analyze(
-        {"name": "neighbor", "neighbor_counts": [1, 3]},
+        {"name": "distribution", "neighbor_counts": [1, 3]},
         lookup_count=3,
     )
 
@@ -430,7 +439,7 @@ def test_action_recommendation(writable_memoryset: LabeledMemoryset):
     """Test getting action recommendations for predictions"""
 
     writable_memoryset.analyze(
-        {"name": "neighbor", "neighbor_counts": [1, 3]},
+        {"name": "distribution", "neighbor_counts": [1, 3]},
         lookup_count=3,
     )
 
@@ -494,3 +503,62 @@ def test_predict_with_prompt(classification_model: ClassificationModel):
     # Both should work and return valid predictions
     assert prediction_with_prompt.label is not None
     assert prediction_without_prompt.label is not None
+
+
+@pytest.mark.asyncio
+async def test_predict_async_single(classification_model: ClassificationModel, label_names: list[str]):
+    """Test async prediction with a single value"""
+    prediction = await classification_model.apredict("Do you love soup?")
+    assert isinstance(prediction, ClassificationPrediction)
+    assert prediction.prediction_id is not None
+    assert prediction.label == 0
+    assert prediction.label_name == label_names[0]
+    assert 0 <= prediction.confidence <= 1
+    assert prediction.logits is not None
+    assert len(prediction.logits) == 2
+
+
+@pytest.mark.asyncio
+async def test_predict_async_batch(classification_model: ClassificationModel, label_names: list[str]):
+    """Test async prediction with a batch of values"""
+    predictions = await classification_model.apredict(["Do you love soup?", "Are cats cute?"])
+    assert len(predictions) == 2
+    assert predictions[0].prediction_id is not None
+    assert predictions[1].prediction_id is not None
+    assert predictions[0].label == 0
+    assert predictions[0].label_name == label_names[0]
+    assert 0 <= predictions[0].confidence <= 1
+    assert predictions[1].label == 1
+    assert predictions[1].label_name == label_names[1]
+    assert 0 <= predictions[1].confidence <= 1
+
+
+@pytest.mark.asyncio
+async def test_predict_async_with_expected_labels(classification_model: ClassificationModel):
+    """Test async prediction with expected labels"""
+    prediction = await classification_model.apredict("Do you love soup?", expected_labels=1)
+    assert prediction.expected_label == 1
+
+
+@pytest.mark.asyncio
+async def test_predict_async_disable_telemetry(classification_model: ClassificationModel, label_names: list[str]):
+    """Test async prediction with telemetry disabled"""
+    predictions = await classification_model.apredict(["Do you love soup?", "Are cats cute?"], save_telemetry="off")
+    assert len(predictions) == 2
+    assert predictions[0].prediction_id is None
+    assert predictions[1].prediction_id is None
+    assert predictions[0].label == 0
+    assert predictions[0].label_name == label_names[0]
+    assert 0 <= predictions[0].confidence <= 1
+    assert predictions[1].label == 1
+    assert predictions[1].label_name == label_names[1]
+    assert 0 <= predictions[1].confidence <= 1
+
+
+@pytest.mark.asyncio
+async def test_predict_async_with_filters(classification_model: ClassificationModel):
+    """Test async prediction with filters"""
+    # there are no memories with label 0 and key g2, so we force a wrong prediction
+    filtered_prediction = await classification_model.apredict("I love soup", filters=[("key", "==", "g2")])
+    assert filtered_prediction.label == 1
+    assert filtered_prediction.label_name == "cats"