orca-sdk 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +1 -1
- orca_sdk/_utils/auth.py +12 -8
- orca_sdk/async_client.py +3942 -0
- orca_sdk/classification_model.py +218 -20
- orca_sdk/classification_model_test.py +96 -28
- orca_sdk/client.py +899 -712
- orca_sdk/conftest.py +37 -36
- orca_sdk/credentials.py +54 -14
- orca_sdk/credentials_test.py +92 -28
- orca_sdk/datasource.py +64 -12
- orca_sdk/datasource_test.py +144 -18
- orca_sdk/embedding_model.py +54 -37
- orca_sdk/embedding_model_test.py +27 -20
- orca_sdk/job.py +27 -21
- orca_sdk/memoryset.py +823 -205
- orca_sdk/memoryset_test.py +315 -33
- orca_sdk/regression_model.py +59 -15
- orca_sdk/regression_model_test.py +35 -26
- orca_sdk/telemetry.py +76 -26
- {orca_sdk-0.1.2.dist-info → orca_sdk-0.1.4.dist-info}/METADATA +1 -1
- orca_sdk-0.1.4.dist-info/RECORD +41 -0
- orca_sdk-0.1.2.dist-info/RECORD +0 -40
- {orca_sdk-0.1.2.dist-info → orca_sdk-0.1.4.dist-info}/WHEEL +0 -0
orca_sdk/classification_model.py
CHANGED
@@ -9,13 +9,16 @@ from datasets import Dataset
 
 from ._shared.metrics import ClassificationMetrics, calculate_classification_metrics
 from ._utils.common import UNSET, CreateMode, DropMode
+from .async_client import OrcaAsyncClient
 from .client import (
     BootstrapClassificationModelMeta,
     BootstrapClassificationModelResult,
+    ClassificationEvaluationRequest,
     ClassificationModelMetadata,
+    OrcaClient,
+    PostClassificationModelByModelNameOrIdEvaluationParams,
     PredictiveModelUpdate,
     RACHeadType,
-    orca_api,
 )
 from .datasource import Datasource
 from .job import Job
@@ -199,7 +202,12 @@ class ClassificationModel:
             raise ValueError(f"Model with name {name} already exists")
         elif if_exists == "open":
             existing = cls.open(name)
-            for attribute in {
+            for attribute in {
+                "head_type",
+                "memory_lookup_count",
+                "num_classes",
+                "min_memory_weight",
+            }:
                 local_attribute = locals()[attribute]
                 existing_attribute = getattr(existing, attribute)
                 if local_attribute is not None and local_attribute != existing_attribute:
@@ -211,7 +219,8 @@ class ClassificationModel:
 
             return existing
 
-
+        client = OrcaClient._resolve_client()
+        metadata = client.POST(
             "/classification_model",
             json={
                 "name": name,
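
The hunk above is the pattern repeated throughout this file: the module-level `orca_api` singleton (removed from the imports) is replaced by a client resolved at call time via `OrcaClient._resolve_client()`. A minimal sketch of what that enables, assuming an `OrcaClient` constructor that takes an API key (an assumption; only `use()` is confirmed by the updated tests later in this diff):

# Sketch only: call-time client resolution introduced in 0.1.4.
# OrcaClient(api_key=...) is a hypothetical signature; use() appears in the updated tests.
from orca_sdk.classification_model import ClassificationModel
from orca_sdk.client import OrcaClient

client = OrcaClient(api_key="org-scoped-key")  # hypothetical constructor
with client.use():
    # SDK calls inside this scope resolve to `client` via OrcaClient._resolve_client().
    model = ClassificationModel.open("my_model")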
@@ -240,7 +249,8 @@ class ClassificationModel:
         Raises:
             LookupError: If the classification model does not exist
         """
-
+        client = OrcaClient._resolve_client()
+        return cls(client.GET("/classification_model/{name_or_id}", params={"name_or_id": name}))
 
     @classmethod
     def exists(cls, name_or_id: str) -> bool:

@@ -267,7 +277,8 @@ class ClassificationModel:
         Returns:
             List of handles to all classification models in the OrcaCloud
         """
-
+        client = OrcaClient._resolve_client()
+        return [cls(metadata) for metadata in client.GET("/classification_model")]
 
     @classmethod
     def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):

@@ -286,7 +297,8 @@ class ClassificationModel:
             LookupError: If the classification model does not exist and if_not_exists is `"error"`
         """
         try:
-
+            client = OrcaClient._resolve_client()
+            client.DELETE("/classification_model/{name_or_id}", params={"name_or_id": name_or_id})
             logging.info(f"Deleted model {name_or_id}")
         except LookupError:
             if if_not_exists == "error":

@@ -322,7 +334,8 @@ class ClassificationModel:
             update["description"] = description
         if locked is not UNSET:
             update["locked"] = locked
-
+        client = OrcaClient._resolve_client()
+        client.PATCH("/classification_model/{name_or_id}", params={"name_or_id": self.id}, json=update)
         self.refresh()
 
     def lock(self) -> None:
@@ -344,6 +357,8 @@ class ClassificationModel:
         prompt: str | None = None,
         use_lookup_cache: bool = True,
         timeout_seconds: int = 10,
+        ignore_unlabeled: bool = False,
+        use_gpu: bool = True,
     ) -> list[ClassificationPrediction]:
         pass
 

@@ -358,6 +373,8 @@ class ClassificationModel:
         prompt: str | None = None,
         use_lookup_cache: bool = True,
         timeout_seconds: int = 10,
+        ignore_unlabeled: bool = False,
+        use_gpu: bool = True,
     ) -> ClassificationPrediction:
         pass
 

@@ -371,6 +388,8 @@ class ClassificationModel:
         prompt: str | None = None,
         use_lookup_cache: bool = True,
         timeout_seconds: int = 10,
+        ignore_unlabeled: bool = False,
+        use_gpu: bool = True,
     ) -> list[ClassificationPrediction] | ClassificationPrediction:
         """
         Predict label(s) for the given input value(s) grounded in similar memories
@@ -389,6 +408,9 @@ class ClassificationModel:
             prompt: Optional prompt to use for instruction-tuned embedding models
             use_lookup_cache: Whether to use cached lookup results for faster predictions
             timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
+            ignore_unlabeled: If True, only use labeled memories during lookup.
+                If False (default), allow unlabeled memories when necessary.
+            use_gpu: Whether to use GPU for the prediction (defaults to True)
 
         Returns:
             Label prediction or list of label predictions
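
Both new `predict` parameters default to the previous behavior (`ignore_unlabeled=False`, `use_gpu=True`); `use_gpu=False` routes the request to the non-GPU endpoint, as the method body in the next hunk shows. A usage sketch, assuming a model named `my_model` already exists:

# Illustrative call using the new 0.1.4 parameters; the model name is made up.
from orca_sdk.classification_model import ClassificationModel

model = ClassificationModel.open("my_model")
prediction = model.predict(
    "I am happy",
    ignore_unlabeled=True,  # restrict memory lookup to labeled memories
    use_gpu=False,          # use /classification_model/... instead of /gpu/...
)
print(prediction.label_name, prediction.confidence)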
@@ -421,6 +443,159 @@ class ClassificationModel:
             _parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
         ]
 
+        if any(_is_metric_column(filter[0]) for filter in filters):
+            raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")
+
+        if isinstance(expected_labels, int):
+            expected_labels = [expected_labels]
+        elif isinstance(expected_labels, str):
+            expected_labels = [self.memoryset.label_names.index(expected_labels)]
+        elif isinstance(expected_labels, list):
+            expected_labels = [
+                self.memoryset.label_names.index(label) if isinstance(label, str) else label
+                for label in expected_labels
+            ]
+
+        if use_gpu:
+            endpoint = "/gpu/classification_model/{name_or_id}/prediction"
+        else:
+            endpoint = "/classification_model/{name_or_id}/prediction"
+
+        telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
+        client = OrcaClient._resolve_client()
+        response = client.POST(
+            endpoint,
+            params={"name_or_id": self.id},
+            json={
+                "input_values": value if isinstance(value, list) else [value],
+                "memoryset_override_name_or_id": self._memoryset_override_id,
+                "expected_labels": expected_labels,
+                "tags": list(tags or set()),
+                "save_telemetry": telemetry_on,
+                "save_telemetry_synchronously": telemetry_sync,
+                "filters": cast(list[FilterItem], parsed_filters),
+                "prompt": prompt,
+                "use_lookup_cache": use_lookup_cache,
+                "ignore_unlabeled": ignore_unlabeled,
+            },
+            timeout=timeout_seconds,
+        )
+
+        if telemetry_on and any(p["prediction_id"] is None for p in response):
+            raise RuntimeError("Failed to save prediction to database.")
+
+        predictions = [
+            ClassificationPrediction(
+                prediction_id=prediction["prediction_id"],
+                label=prediction["label"],
+                label_name=prediction["label_name"],
+                score=None,
+                confidence=prediction["confidence"],
+                anomaly_score=prediction["anomaly_score"],
+                memoryset=self.memoryset,
+                model=self,
+                logits=prediction["logits"],
+                input_value=input_value,
+            )
+            for prediction, input_value in zip(response, value if isinstance(value, list) else [value])
+        ]
+        self._last_prediction_was_batch = isinstance(value, list)
+        self._last_prediction = predictions[-1]
+        return predictions if isinstance(value, list) else predictions[0]
+
+    @overload
+    async def apredict(
+        self,
+        value: list[str],
+        expected_labels: list[int] | None = None,
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: TelemetryMode = "on",
+        prompt: str | None = None,
+        use_lookup_cache: bool = True,
+        timeout_seconds: int = 10,
+        ignore_unlabeled: bool = False,
+    ) -> list[ClassificationPrediction]:
+        pass
+
+    @overload
+    async def apredict(
+        self,
+        value: str,
+        expected_labels: int | None = None,
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: TelemetryMode = "on",
+        prompt: str | None = None,
+        use_lookup_cache: bool = True,
+        timeout_seconds: int = 10,
+        ignore_unlabeled: bool = False,
+    ) -> ClassificationPrediction:
+        pass
+
+    async def apredict(
+        self,
+        value: list[str] | str,
+        expected_labels: list[int] | list[str] | int | str | None = None,
+        filters: list[FilterItemTuple] = [],
+        tags: set[str] | None = None,
+        save_telemetry: TelemetryMode = "on",
+        prompt: str | None = None,
+        use_lookup_cache: bool = True,
+        timeout_seconds: int = 10,
+        ignore_unlabeled: bool = False,
+    ) -> list[ClassificationPrediction] | ClassificationPrediction:
+        """
+        Asynchronously predict label(s) for the given input value(s) grounded in similar memories
+
+        Params:
+            value: Value(s) to predict the labels of
+            expected_labels: Expected label(s) for the given input to record for model evaluation
+            filters: Optional filters to apply during memory lookup
+            tags: Tags to add to the prediction(s)
+            save_telemetry: Whether to save telemetry for the prediction(s). One of
+                * `"off"`: Do not save telemetry
+                * `"on"`: Save telemetry asynchronously unless the `ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY`
+                    environment variable is set.
+                * `"sync"`: Save telemetry synchronously
+                * `"async"`: Save telemetry asynchronously
+            prompt: Optional prompt to use for instruction-tuned embedding models
+            use_lookup_cache: Whether to use cached lookup results for faster predictions
+            timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
+            ignore_unlabeled: If True, only use labeled memories during lookup.
+                If False (default), allow unlabeled memories when necessary.
+
+        Returns:
+            Label prediction or list of label predictions.
+
+        Raises:
+            ValueError: If timeout_seconds is not a positive integer
+            TimeoutError: If the request times out after the specified duration
+
+        Examples:
+            Predict the label for a single value:
+            >>> prediction = await model.apredict("I am happy", tags={"test"})
+            ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
+
+            Predict the labels for a list of values:
+            >>> predictions = await model.apredict(["I am happy", "I am sad"], expected_labels=[1, 0])
+            [
+                ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy'}),
+                ClassificationPrediction({label: <negative: 0>, confidence: 0.05, anomaly_score: 0.1, input_value: 'I am sad'}),
+            ]
+
+            Using a prompt with an instruction-tuned embedding model:
+            >>> prediction = await model.apredict("I am happy", prompt="Represent this text for sentiment classification:")
+            ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy' })
+        """
+
+        if timeout_seconds <= 0:
+            raise ValueError("timeout_seconds must be a positive integer")
+
+        parsed_filters = [
+            _parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
+        ]
+
         if any(_is_metric_column(filter[0]) for filter in filters):
             raise ValueError(f"Cannot filter on {filters} - telemetry filters are not supported for predictions")
 

@@ -435,7 +610,8 @@ class ClassificationModel:
         ]
 
         telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
-
+        client = OrcaAsyncClient._resolve_client()
+        response = await client.POST(
             "/gpu/classification_model/{name_or_id}/prediction",
             params={"name_or_id": self.id},
             json={

@@ -448,6 +624,7 @@ class ClassificationModel:
                 "filters": cast(list[FilterItem], parsed_filters),
                 "prompt": prompt,
                 "use_lookup_cache": use_lookup_cache,
+                "ignore_unlabeled": ignore_unlabeled,
             },
             timeout=timeout_seconds,
         )
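
The new `apredict` mirrors `predict` but awaits an `OrcaAsyncClient`; note that it always calls the GPU endpoint and takes no `use_gpu` parameter. A minimal driver sketch (the model name is illustrative):

import asyncio

from orca_sdk.classification_model import ClassificationModel

async def main() -> None:
    model = ClassificationModel.open("my_model")  # illustrative name; open() itself is sync
    # A list input returns a list of predictions; a single string returns one prediction.
    predictions = await model.apredict(
        ["I am happy", "I am sad"],
        expected_labels=[1, 0],
        ignore_unlabeled=True,
    )
    for p in predictions:
        print(p.label_name, p.confidence)

asyncio.run(main())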
@@ -515,7 +692,8 @@ class ClassificationModel:
            >>> predictions = model.predictions(expected_label_match=False)
            [ClassificationPrediction({label: <positive: 1>, confidence: 0.95, anomaly_score: 0.1, input_value: 'I am happy', expected_label: 0})]
        """
-
+        client = OrcaClient._resolve_client()
+        predictions = client.POST(
            "/telemetry/prediction",
            json={
                "model_id": self.id,
@@ -549,9 +727,12 @@ class ClassificationModel:
         label_column: str,
         record_predictions: bool,
         tags: set[str] | None,
+        subsample: int | float | None,
         background: bool = False,
+        ignore_unlabeled: bool = False,
     ) -> ClassificationMetrics | Job[ClassificationMetrics]:
-
+        client = OrcaClient._resolve_client()
+        response = client.POST(
             "/classification_model/{model_name_or_id}/evaluation",
             params={"model_name_or_id": self.id},
             json={

@@ -561,13 +742,16 @@ class ClassificationModel:
                 "memoryset_override_name_or_id": self._memoryset_override_id,
                 "record_telemetry": record_predictions,
                 "telemetry_tags": list(tags) if tags else None,
+                "subsample": subsample,
+                "ignore_unlabeled": ignore_unlabeled,
             },
         )
 
         def get_value():
-
-
-
+            client = OrcaClient._resolve_client()
+            res = client.GET(
+                "/classification_model/{model_name_or_id}/evaluation/{job_id}",
+                params={"model_name_or_id": self.id, "job_id": response["job_id"]},
             )
             assert res["result"] is not None
             return ClassificationMetrics(

@@ -584,7 +768,7 @@ class ClassificationModel:
                 roc_curve=res["result"].get("roc_curve"),
             )
 
-        job = Job(response["
+        job = Job(response["job_id"], get_value)
         return job if background else job.result()
 
     def _evaluate_dataset(

@@ -595,6 +779,7 @@ class ClassificationModel:
         record_predictions: bool,
         tags: set[str],
         batch_size: int,
+        ignore_unlabeled: bool,
     ) -> ClassificationMetrics:
         if len(dataset) == 0:
             raise ValueError("Evaluation dataset cannot be empty")

@@ -610,6 +795,7 @@ class ClassificationModel:
                 expected_labels=dataset[i : i + batch_size][label_column],
                 tags=tags,
                 save_telemetry="sync" if record_predictions else "off",
+                ignore_unlabeled=ignore_unlabeled,
             )
         ]
 
@@ -630,7 +816,9 @@ class ClassificationModel:
         record_predictions: bool = False,
         tags: set[str] = {"evaluation"},
         batch_size: int = 100,
+        subsample: int | float | None = None,
         background: Literal[True],
+        ignore_unlabeled: bool = False,
     ) -> Job[ClassificationMetrics]:
         pass
 

@@ -644,7 +832,9 @@ class ClassificationModel:
         record_predictions: bool = False,
         tags: set[str] = {"evaluation"},
         batch_size: int = 100,
+        subsample: int | float | None = None,
         background: Literal[False] = False,
+        ignore_unlabeled: bool = False,
     ) -> ClassificationMetrics:
         pass
 

@@ -657,7 +847,9 @@ class ClassificationModel:
         record_predictions: bool = False,
         tags: set[str] = {"evaluation"},
         batch_size: int = 100,
+        subsample: int | float | None = None,
         background: bool = False,
+        ignore_unlabeled: bool = False,
     ) -> ClassificationMetrics | Job[ClassificationMetrics]:
         """
         Evaluate the classification model on a given dataset or datasource
@@ -669,7 +861,9 @@ class ClassificationModel:
             record_predictions: Whether to record [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s for analysis
             tags: Optional tags to add to the recorded [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s
             batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)
+            subsample: Optional number (int) of rows or fraction (float in (0, 1]) of the data to sample for evaluation
             background: Whether to run the operation in the background and return a job handle
+            ignore_unlabeled: If True, only use labeled memories during lookup. If False (default), allow unlabeled memories
 
         Returns:
             EvaluationResult containing metrics including accuracy, F1 score, ROC AUC, PR AUC, and anomaly score statistics
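
`subsample` accepts either an absolute row count (int) or a fraction in (0, 1] (float), and `ignore_unlabeled` is threaded through both the datasource and dataset paths below. A usage sketch, assuming an existing model and datasource (the `Datasource.open` helper is an assumption by analogy with `ClassificationModel.open`, not shown in this diff):

# Illustrative: background evaluation on a 20% sample using only labeled memories.
from orca_sdk.classification_model import ClassificationModel
from orca_sdk.datasource import Datasource

model = ClassificationModel.open("my_model")   # illustrative name
datasource = Datasource.open("my_eval_data")   # assumed helper
job = model.evaluate(
    datasource,
    subsample=0.2,          # float in (0, 1] = fraction; int = absolute row count
    ignore_unlabeled=True,
    background=True,        # returns a Job[ClassificationMetrics] handle
)
metrics = job.result()      # blocks until the evaluation job finishes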
@@ -691,7 +885,9 @@ class ClassificationModel:
                 label_column=label_column,
                 record_predictions=record_predictions,
                 tags=tags,
+                subsample=subsample,
                 background=background,
+                ignore_unlabeled=ignore_unlabeled,
             )
         elif isinstance(data, Dataset):
             return self._evaluate_dataset(

@@ -701,6 +897,7 @@ class ClassificationModel:
                 record_predictions=record_predictions,
                 tags=tags,
                 batch_size=batch_size,
+                ignore_unlabeled=ignore_unlabeled,
             )
         else:
             raise ValueError(f"Invalid data type: {type(data)}")
@@ -773,7 +970,8 @@ class ClassificationModel:
             ValueError: If the value does not match previous value types for the category, or is a
                 [`float`][float] that is not between `-1.0` and `+1.0`.
         """
-
+        client = OrcaClient._resolve_client()
+        client.PUT(
             "/telemetry/prediction/feedback",
             json=[
                 _parse_feedback(f) for f in (cast(list[dict], [feedback]) if isinstance(feedback, dict) else feedback)
@@ -788,7 +986,8 @@ class ClassificationModel:
         num_examples_per_label: int,
         background: bool = False,
     ) -> Job[BootstrappedClassificationModel] | BootstrappedClassificationModel:
-
+        client = OrcaClient._resolve_client()
+        response = client.POST(
             "/agents/bootstrap_classification_model",
             json={
                 "model_description": model_description,

@@ -799,11 +998,10 @@ class ClassificationModel:
         )
 
         def get_result() -> BootstrappedClassificationModel:
-
-
-            )
+            client = OrcaClient._resolve_client()
+            res = client.GET("/agents/bootstrap_classification_model/{job_id}", params={"job_id": response["job_id"]})
             assert res["result"] is not None
             return BootstrappedClassificationModel(res["result"])
 
-        job = Job(response["
+        job = Job(response["job_id"], get_result)
         return job if background else job.result()
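
Evaluation and bootstrap now share one pattern: the POST returns a `job_id`, and a `Job` wraps a closure (`get_value` / `get_result`) that re-resolves the client and parses the result when it is fetched. A self-contained sketch of that control flow (an illustration of the pattern, not the orca_sdk `Job` implementation; the real class presumably also waits on job status):

from typing import Callable, Generic, TypeVar

T = TypeVar("T")

class SketchJob(Generic[T]):
    """Toy stand-in for orca_sdk.job.Job, showing the job_id-plus-closure pattern."""

    def __init__(self, job_id: str, get_value: Callable[[], T]):
        self.job_id = job_id
        self._get_value = get_value

    def result(self) -> T:
        # The real Job presumably polls until the job completes; this sketch skips that.
        return self._get_value()

job = SketchJob("123e4567-e89b-12d3-a456-426614174000", lambda: {"accuracy": 0.97})
print(job.result())  # {'accuracy': 0.97}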
orca_sdk/classification_model_test.py
CHANGED

@@ -53,9 +53,10 @@ def test_create_model_already_exists_return(readonly_memoryset, classification_m
     assert new_model.memory_lookup_count == 3
 
 
-def test_create_model_unauthenticated(
-    with
-
+def test_create_model_unauthenticated(unauthenticated_client, readonly_memoryset: LabeledMemoryset):
+    with unauthenticated_client.use():
+        with pytest.raises(ValueError, match="Invalid API key"):
+            ClassificationModel.create("test_model", readonly_memoryset)
 
 
 def test_get_model(classification_model: ClassificationModel):
@@ -68,9 +69,10 @@ def test_get_model(classification_model: ClassificationModel):
     assert fetched_model == classification_model
 
 
-def test_get_model_unauthenticated(
-    with
-
+def test_get_model_unauthenticated(unauthenticated_client):
+    with unauthenticated_client.use():
+        with pytest.raises(ValueError, match="Invalid API key"):
+            ClassificationModel.open("test_model")
 
 
 def test_get_model_invalid_input():

@@ -83,9 +85,10 @@ def test_get_model_not_found():
         ClassificationModel.open(str(uuid4()))
 
 
-def test_get_model_unauthorized(
-    with
-
+def test_get_model_unauthorized(unauthorized_client, classification_model: ClassificationModel):
+    with unauthorized_client.use():
+        with pytest.raises(LookupError):
+            ClassificationModel.open(classification_model.name)
 
 
 def test_list_models(classification_model: ClassificationModel):

@@ -94,13 +97,15 @@ def test_list_models(classification_model: ClassificationModel):
     assert any(model.name == model.name for model in models)
 
 
-def test_list_models_unauthenticated(
-    with
-
+def test_list_models_unauthenticated(unauthenticated_client):
+    with unauthenticated_client.use():
+        with pytest.raises(ValueError, match="Invalid API key"):
+            ClassificationModel.all()
 
 
-def test_list_models_unauthorized(
-
+def test_list_models_unauthorized(unauthorized_client, classification_model: ClassificationModel):
+    with unauthorized_client.use():
+        assert ClassificationModel.all() == []
 
 
 def test_update_model_attributes(classification_model: ClassificationModel):
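
The rewritten tests take explicit `unauthenticated_client` / `unauthorized_client` fixtures and scope them with `use()` instead of mutating global credentials. A hypothetical fixture shape consistent with that usage (the real fixtures live in the reworked `conftest.py`, which this section does not show; constructor arguments are assumptions):

# Hypothetical conftest.py fixtures implied by the tests; not the actual orca_sdk code.
import pytest

from orca_sdk.client import OrcaClient

@pytest.fixture
def unauthenticated_client() -> OrcaClient:
    return OrcaClient(api_key="not-a-real-key")  # assumed constructor

@pytest.fixture
def unauthorized_client() -> OrcaClient:
    return OrcaClient(api_key="valid-key-for-another-org")  # assumed constructor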
@@ -131,9 +136,10 @@ def test_delete_model(readonly_memoryset: LabeledMemoryset):
         ClassificationModel.open("model_to_delete")
 
 
-def test_delete_model_unauthenticated(
-    with
-
+def test_delete_model_unauthenticated(unauthenticated_client, classification_model: ClassificationModel):
+    with unauthenticated_client.use():
+        with pytest.raises(ValueError, match="Invalid API key"):
+            ClassificationModel.drop(classification_model.name)
 
 
 def test_delete_model_not_found():

@@ -143,9 +149,10 @@ def test_delete_model_not_found():
     ClassificationModel.drop(str(uuid4()), if_not_exists="ignore")
 
 
-def test_delete_model_unauthorized(
-    with
-
+def test_delete_model_unauthorized(unauthorized_client, classification_model: ClassificationModel):
+    with unauthorized_client.use():
+        with pytest.raises(LookupError):
+            ClassificationModel.drop(classification_model.name)
 
 
 def test_delete_memoryset_before_model_constraint_violation(hf_dataset):
@@ -254,14 +261,16 @@ def test_predict_disable_telemetry(classification_model: ClassificationModel, la
     assert 0 <= predictions[1].confidence <= 1
 
 
-def test_predict_unauthenticated(
-    with
-
+def test_predict_unauthenticated(unauthenticated_client, classification_model: ClassificationModel):
+    with unauthenticated_client.use():
+        with pytest.raises(ValueError, match="Invalid API key"):
+            classification_model.predict(["Do you love soup?", "Are cats cute?"])
 
 
-def test_predict_unauthorized(
-    with
-
+def test_predict_unauthorized(unauthorized_client, classification_model: ClassificationModel):
+    with unauthorized_client.use():
+        with pytest.raises(LookupError):
+            classification_model.predict(["Do you love soup?", "Are cats cute?"])
 
 
 def test_predict_constraint_violation(readonly_memoryset: LabeledMemoryset):
@@ -396,7 +405,7 @@ def test_last_prediction_with_single(classification_model: ClassificationModel):
 def test_explain(writable_memoryset: LabeledMemoryset):
 
     writable_memoryset.analyze(
-        {"name": "
+        {"name": "distribution", "neighbor_counts": [1, 3]},
         lookup_count=3,
     )
 

@@ -430,7 +439,7 @@ def test_action_recommendation(writable_memoryset: LabeledMemoryset):
     """Test getting action recommendations for predictions"""
 
     writable_memoryset.analyze(
-        {"name": "
+        {"name": "distribution", "neighbor_counts": [1, 3]},
         lookup_count=3,
     )
 
@@ -494,3 +503,62 @@ def test_predict_with_prompt(classification_model: ClassificationModel):
     # Both should work and return valid predictions
     assert prediction_with_prompt.label is not None
     assert prediction_without_prompt.label is not None
+
+
+@pytest.mark.asyncio
+async def test_predict_async_single(classification_model: ClassificationModel, label_names: list[str]):
+    """Test async prediction with a single value"""
+    prediction = await classification_model.apredict("Do you love soup?")
+    assert isinstance(prediction, ClassificationPrediction)
+    assert prediction.prediction_id is not None
+    assert prediction.label == 0
+    assert prediction.label_name == label_names[0]
+    assert 0 <= prediction.confidence <= 1
+    assert prediction.logits is not None
+    assert len(prediction.logits) == 2
+
+
+@pytest.mark.asyncio
+async def test_predict_async_batch(classification_model: ClassificationModel, label_names: list[str]):
+    """Test async prediction with a batch of values"""
+    predictions = await classification_model.apredict(["Do you love soup?", "Are cats cute?"])
+    assert len(predictions) == 2
+    assert predictions[0].prediction_id is not None
+    assert predictions[1].prediction_id is not None
+    assert predictions[0].label == 0
+    assert predictions[0].label_name == label_names[0]
+    assert 0 <= predictions[0].confidence <= 1
+    assert predictions[1].label == 1
+    assert predictions[1].label_name == label_names[1]
+    assert 0 <= predictions[1].confidence <= 1
+
+
+@pytest.mark.asyncio
+async def test_predict_async_with_expected_labels(classification_model: ClassificationModel):
+    """Test async prediction with expected labels"""
+    prediction = await classification_model.apredict("Do you love soup?", expected_labels=1)
+    assert prediction.expected_label == 1
+
+
+@pytest.mark.asyncio
+async def test_predict_async_disable_telemetry(classification_model: ClassificationModel, label_names: list[str]):
+    """Test async prediction with telemetry disabled"""
+    predictions = await classification_model.apredict(["Do you love soup?", "Are cats cute?"], save_telemetry="off")
+    assert len(predictions) == 2
+    assert predictions[0].prediction_id is None
+    assert predictions[1].prediction_id is None
+    assert predictions[0].label == 0
+    assert predictions[0].label_name == label_names[0]
+    assert 0 <= predictions[0].confidence <= 1
+    assert predictions[1].label == 1
+    assert predictions[1].label_name == label_names[1]
+    assert 0 <= predictions[1].confidence <= 1
+
+
+@pytest.mark.asyncio
+async def test_predict_async_with_filters(classification_model: ClassificationModel):
+    """Test async prediction with filters"""
+    # there are no memories with label 0 and key g2, so we force a wrong prediction
+    filtered_prediction = await classification_model.apredict("I love soup", filters=[("key", "==", "g2")])
+    assert filtered_prediction.label == 1
+    assert filtered_prediction.label_name == "cats"