orca-sdk 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/_shared/metrics.py +179 -40
- orca_sdk/_shared/metrics_test.py +99 -6
- orca_sdk/_utils/data_parsing_test.py +1 -1
- orca_sdk/async_client.py +462 -301
- orca_sdk/classification_model.py +156 -41
- orca_sdk/classification_model_test.py +327 -8
- orca_sdk/client.py +462 -301
- orca_sdk/conftest.py +140 -21
- orca_sdk/datasource.py +45 -2
- orca_sdk/datasource_test.py +120 -0
- orca_sdk/embedding_model.py +32 -24
- orca_sdk/job.py +17 -17
- orca_sdk/memoryset.py +459 -56
- orca_sdk/memoryset_test.py +435 -2
- orca_sdk/regression_model.py +110 -19
- orca_sdk/regression_model_test.py +213 -0
- orca_sdk/telemetry.py +52 -13
- {orca_sdk-0.1.3.dist-info → orca_sdk-0.1.5.dist-info}/METADATA +1 -1
- {orca_sdk-0.1.3.dist-info → orca_sdk-0.1.5.dist-info}/RECORD +20 -20
- {orca_sdk-0.1.3.dist-info → orca_sdk-0.1.5.dist-info}/WHEEL +0 -0
orca_sdk/classification_model.py
CHANGED
|
@@ -3,14 +3,7 @@ from __future__ import annotations
|
|
|
3
3
|
import logging
|
|
4
4
|
from contextlib import contextmanager
|
|
5
5
|
from datetime import datetime
|
|
6
|
-
from typing import
|
|
7
|
-
Any,
|
|
8
|
-
Generator,
|
|
9
|
-
Iterable,
|
|
10
|
-
Literal,
|
|
11
|
-
cast,
|
|
12
|
-
overload,
|
|
13
|
-
)
|
|
6
|
+
from typing import Any, Generator, Iterable, Literal, cast, overload
|
|
14
7
|
|
|
15
8
|
from datasets import Dataset
|
|
16
9
|
|
|
@@ -20,8 +13,11 @@ from .async_client import OrcaAsyncClient
|
|
|
20
13
|
from .client import (
|
|
21
14
|
BootstrapClassificationModelMeta,
|
|
22
15
|
BootstrapClassificationModelResult,
|
|
16
|
+
ClassificationEvaluationRequest,
|
|
23
17
|
ClassificationModelMetadata,
|
|
18
|
+
ClassificationPredictionRequest,
|
|
24
19
|
OrcaClient,
|
|
20
|
+
PostClassificationModelByModelNameOrIdEvaluationParams,
|
|
25
21
|
PredictiveModelUpdate,
|
|
26
22
|
RACHeadType,
|
|
27
23
|
)
|
|
@@ -207,7 +203,12 @@ class ClassificationModel:
|
|
|
207
203
|
raise ValueError(f"Model with name {name} already exists")
|
|
208
204
|
elif if_exists == "open":
|
|
209
205
|
existing = cls.open(name)
|
|
210
|
-
for attribute in {
|
|
206
|
+
for attribute in {
|
|
207
|
+
"head_type",
|
|
208
|
+
"memory_lookup_count",
|
|
209
|
+
"num_classes",
|
|
210
|
+
"min_memory_weight",
|
|
211
|
+
}:
|
|
211
212
|
local_attribute = locals()[attribute]
|
|
212
213
|
existing_attribute = getattr(existing, attribute)
|
|
213
214
|
if local_attribute is not None and local_attribute != existing_attribute:
|
|
@@ -357,6 +358,12 @@ class ClassificationModel:
|
|
|
357
358
|
prompt: str | None = None,
|
|
358
359
|
use_lookup_cache: bool = True,
|
|
359
360
|
timeout_seconds: int = 10,
|
|
361
|
+
ignore_unlabeled: bool = False,
|
|
362
|
+
partition_id: str | list[str | None] | None = None,
|
|
363
|
+
partition_filter_mode: Literal[
|
|
364
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
365
|
+
] = "include_global",
|
|
366
|
+
use_gpu: bool = True,
|
|
360
367
|
) -> list[ClassificationPrediction]:
|
|
361
368
|
pass
|
|
362
369
|
|
|
@@ -371,6 +378,12 @@ class ClassificationModel:
|
|
|
371
378
|
prompt: str | None = None,
|
|
372
379
|
use_lookup_cache: bool = True,
|
|
373
380
|
timeout_seconds: int = 10,
|
|
381
|
+
ignore_unlabeled: bool = False,
|
|
382
|
+
partition_id: str | None = None,
|
|
383
|
+
partition_filter_mode: Literal[
|
|
384
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
385
|
+
] = "include_global",
|
|
386
|
+
use_gpu: bool = True,
|
|
374
387
|
) -> ClassificationPrediction:
|
|
375
388
|
pass
|
|
376
389
|
|
|
@@ -384,6 +397,12 @@ class ClassificationModel:
|
|
|
384
397
|
prompt: str | None = None,
|
|
385
398
|
use_lookup_cache: bool = True,
|
|
386
399
|
timeout_seconds: int = 10,
|
|
400
|
+
ignore_unlabeled: bool = False,
|
|
401
|
+
partition_id: str | None | list[str | None] = None,
|
|
402
|
+
partition_filter_mode: Literal[
|
|
403
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
404
|
+
] = "include_global",
|
|
405
|
+
use_gpu: bool = True,
|
|
387
406
|
) -> list[ClassificationPrediction] | ClassificationPrediction:
|
|
388
407
|
"""
|
|
389
408
|
Predict label(s) for the given input value(s) grounded in similar memories
|
|
@@ -402,6 +421,15 @@ class ClassificationModel:
|
|
|
402
421
|
prompt: Optional prompt to use for instruction-tuned embedding models
|
|
403
422
|
use_lookup_cache: Whether to use cached lookup results for faster predictions
|
|
404
423
|
timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
|
|
424
|
+
ignore_unlabeled: If True, only use labeled memories during lookup.
|
|
425
|
+
If False (default), allow unlabeled memories when necessary.
|
|
426
|
+
partition_id: Optional partition ID(s) to use during memory lookup
|
|
427
|
+
partition_filter_mode: Optional partition filter mode to use for the prediction(s). One of
|
|
428
|
+
* `"ignore_partitions"`: Ignore partitions
|
|
429
|
+
* `"include_global"`: Include global memories
|
|
430
|
+
* `"exclude_global"`: Exclude global memories
|
|
431
|
+
* `"only_global"`: Only include global memories
|
|
432
|
+
use_gpu: Whether to use GPU for the prediction (defaults to True)
|
|
405
433
|
|
|
406
434
|
Returns:
|
|
407
435
|
Label prediction or list of label predictions
|
|
@@ -447,22 +475,33 @@ class ClassificationModel:
|
|
|
447
475
|
for label in expected_labels
|
|
448
476
|
]
|
|
449
477
|
|
|
478
|
+
if use_gpu:
|
|
479
|
+
endpoint = "/gpu/classification_model/{name_or_id}/prediction"
|
|
480
|
+
else:
|
|
481
|
+
endpoint = "/classification_model/{name_or_id}/prediction"
|
|
482
|
+
|
|
450
483
|
telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
|
|
451
484
|
client = OrcaClient._resolve_client()
|
|
485
|
+
request_json: ClassificationPredictionRequest = {
|
|
486
|
+
"input_values": value if isinstance(value, list) else [value],
|
|
487
|
+
"memoryset_override_name_or_id": self._memoryset_override_id,
|
|
488
|
+
"expected_labels": expected_labels,
|
|
489
|
+
"tags": list(tags or set()),
|
|
490
|
+
"save_telemetry": telemetry_on,
|
|
491
|
+
"save_telemetry_synchronously": telemetry_sync,
|
|
492
|
+
"filters": cast(list[FilterItem], parsed_filters),
|
|
493
|
+
"prompt": prompt,
|
|
494
|
+
"use_lookup_cache": use_lookup_cache,
|
|
495
|
+
"ignore_unlabeled": ignore_unlabeled,
|
|
496
|
+
"partition_filter_mode": partition_filter_mode,
|
|
497
|
+
}
|
|
498
|
+
# Don't send partition_ids when partition_filter_mode is "ignore_partitions"
|
|
499
|
+
if partition_filter_mode != "ignore_partitions":
|
|
500
|
+
request_json["partition_ids"] = partition_id
|
|
452
501
|
response = client.POST(
|
|
453
|
-
|
|
502
|
+
endpoint,
|
|
454
503
|
params={"name_or_id": self.id},
|
|
455
|
-
json=
|
|
456
|
-
"input_values": value if isinstance(value, list) else [value],
|
|
457
|
-
"memoryset_override_name_or_id": self._memoryset_override_id,
|
|
458
|
-
"expected_labels": expected_labels,
|
|
459
|
-
"tags": list(tags or set()),
|
|
460
|
-
"save_telemetry": telemetry_on,
|
|
461
|
-
"save_telemetry_synchronously": telemetry_sync,
|
|
462
|
-
"filters": cast(list[FilterItem], parsed_filters),
|
|
463
|
-
"prompt": prompt,
|
|
464
|
-
"use_lookup_cache": use_lookup_cache,
|
|
465
|
-
},
|
|
504
|
+
json=request_json,
|
|
466
505
|
timeout=timeout_seconds,
|
|
467
506
|
)
|
|
468
507
|
|
|
@@ -499,6 +538,11 @@ class ClassificationModel:
|
|
|
499
538
|
prompt: str | None = None,
|
|
500
539
|
use_lookup_cache: bool = True,
|
|
501
540
|
timeout_seconds: int = 10,
|
|
541
|
+
ignore_unlabeled: bool = False,
|
|
542
|
+
partition_id: str | list[str | None] | None = None,
|
|
543
|
+
partition_filter_mode: Literal[
|
|
544
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
545
|
+
] = "include_global",
|
|
502
546
|
) -> list[ClassificationPrediction]:
|
|
503
547
|
pass
|
|
504
548
|
|
|
@@ -513,6 +557,11 @@ class ClassificationModel:
|
|
|
513
557
|
prompt: str | None = None,
|
|
514
558
|
use_lookup_cache: bool = True,
|
|
515
559
|
timeout_seconds: int = 10,
|
|
560
|
+
ignore_unlabeled: bool = False,
|
|
561
|
+
partition_id: str | None = None,
|
|
562
|
+
partition_filter_mode: Literal[
|
|
563
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
564
|
+
] = "include_global",
|
|
516
565
|
) -> ClassificationPrediction:
|
|
517
566
|
pass
|
|
518
567
|
|
|
@@ -526,6 +575,11 @@ class ClassificationModel:
|
|
|
526
575
|
prompt: str | None = None,
|
|
527
576
|
use_lookup_cache: bool = True,
|
|
528
577
|
timeout_seconds: int = 10,
|
|
578
|
+
ignore_unlabeled: bool = False,
|
|
579
|
+
partition_id: str | None | list[str | None] = None,
|
|
580
|
+
partition_filter_mode: Literal[
|
|
581
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
582
|
+
] = "include_global",
|
|
529
583
|
) -> list[ClassificationPrediction] | ClassificationPrediction:
|
|
530
584
|
"""
|
|
531
585
|
Asynchronously predict label(s) for the given input value(s) grounded in similar memories
|
|
@@ -544,7 +598,14 @@ class ClassificationModel:
|
|
|
544
598
|
prompt: Optional prompt to use for instruction-tuned embedding models
|
|
545
599
|
use_lookup_cache: Whether to use cached lookup results for faster predictions
|
|
546
600
|
timeout_seconds: Timeout in seconds for the request, defaults to 10 seconds
|
|
547
|
-
|
|
601
|
+
ignore_unlabeled: If True, only use labeled memories during lookup.
|
|
602
|
+
If False (default), allow unlabeled memories when necessary.
|
|
603
|
+
partition_id: Optional partition ID(s) to use during memory lookup
|
|
604
|
+
partition_filter_mode: Optional partition filter mode to use for the prediction(s). One of
|
|
605
|
+
* `"ignore_partitions"`: Ignore partitions
|
|
606
|
+
* `"include_global"`: Include global memories
|
|
607
|
+
* `"exclude_global"`: Exclude global memories
|
|
608
|
+
* `"only_global"`: Only include global memories
|
|
548
609
|
Returns:
|
|
549
610
|
Label prediction or list of label predictions.
|
|
550
611
|
|
|
@@ -591,20 +652,26 @@ class ClassificationModel:
|
|
|
591
652
|
|
|
592
653
|
telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
|
|
593
654
|
client = OrcaAsyncClient._resolve_client()
|
|
655
|
+
request_json: ClassificationPredictionRequest = {
|
|
656
|
+
"input_values": value if isinstance(value, list) else [value],
|
|
657
|
+
"memoryset_override_name_or_id": self._memoryset_override_id,
|
|
658
|
+
"expected_labels": expected_labels,
|
|
659
|
+
"tags": list(tags or set()),
|
|
660
|
+
"save_telemetry": telemetry_on,
|
|
661
|
+
"save_telemetry_synchronously": telemetry_sync,
|
|
662
|
+
"filters": cast(list[FilterItem], parsed_filters),
|
|
663
|
+
"prompt": prompt,
|
|
664
|
+
"use_lookup_cache": use_lookup_cache,
|
|
665
|
+
"ignore_unlabeled": ignore_unlabeled,
|
|
666
|
+
"partition_filter_mode": partition_filter_mode,
|
|
667
|
+
}
|
|
668
|
+
# Don't send partition_ids when partition_filter_mode is "ignore_partitions"
|
|
669
|
+
if partition_filter_mode != "ignore_partitions":
|
|
670
|
+
request_json["partition_ids"] = partition_id
|
|
594
671
|
response = await client.POST(
|
|
595
672
|
"/gpu/classification_model/{name_or_id}/prediction",
|
|
596
673
|
params={"name_or_id": self.id},
|
|
597
|
-
json=
|
|
598
|
-
"input_values": value if isinstance(value, list) else [value],
|
|
599
|
-
"memoryset_override_name_or_id": self._memoryset_override_id,
|
|
600
|
-
"expected_labels": expected_labels,
|
|
601
|
-
"tags": list(tags or set()),
|
|
602
|
-
"save_telemetry": telemetry_on,
|
|
603
|
-
"save_telemetry_synchronously": telemetry_sync,
|
|
604
|
-
"filters": cast(list[FilterItem], parsed_filters),
|
|
605
|
-
"prompt": prompt,
|
|
606
|
-
"use_lookup_cache": use_lookup_cache,
|
|
607
|
-
},
|
|
674
|
+
json=request_json,
|
|
608
675
|
timeout=timeout_seconds,
|
|
609
676
|
)
|
|
610
677
|
|
|
@@ -706,7 +773,13 @@ class ClassificationModel:
|
|
|
706
773
|
label_column: str,
|
|
707
774
|
record_predictions: bool,
|
|
708
775
|
tags: set[str] | None,
|
|
776
|
+
subsample: int | float | None,
|
|
709
777
|
background: bool = False,
|
|
778
|
+
ignore_unlabeled: bool = False,
|
|
779
|
+
partition_column: str | None = None,
|
|
780
|
+
partition_filter_mode: Literal[
|
|
781
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
782
|
+
] = "include_global",
|
|
710
783
|
) -> ClassificationMetrics | Job[ClassificationMetrics]:
|
|
711
784
|
client = OrcaClient._resolve_client()
|
|
712
785
|
response = client.POST(
|
|
@@ -719,14 +792,18 @@ class ClassificationModel:
|
|
|
719
792
|
"memoryset_override_name_or_id": self._memoryset_override_id,
|
|
720
793
|
"record_telemetry": record_predictions,
|
|
721
794
|
"telemetry_tags": list(tags) if tags else None,
|
|
795
|
+
"subsample": subsample,
|
|
796
|
+
"ignore_unlabeled": ignore_unlabeled,
|
|
797
|
+
"datasource_partition_column": partition_column,
|
|
798
|
+
"partition_filter_mode": partition_filter_mode,
|
|
722
799
|
},
|
|
723
800
|
)
|
|
724
801
|
|
|
725
802
|
def get_value():
|
|
726
803
|
client = OrcaClient._resolve_client()
|
|
727
804
|
res = client.GET(
|
|
728
|
-
"/classification_model/{model_name_or_id}/evaluation/{
|
|
729
|
-
params={"model_name_or_id": self.id, "
|
|
805
|
+
"/classification_model/{model_name_or_id}/evaluation/{job_id}",
|
|
806
|
+
params={"model_name_or_id": self.id, "job_id": response["job_id"]},
|
|
730
807
|
)
|
|
731
808
|
assert res["result"] is not None
|
|
732
809
|
return ClassificationMetrics(
|
|
@@ -743,7 +820,7 @@ class ClassificationModel:
|
|
|
743
820
|
roc_curve=res["result"].get("roc_curve"),
|
|
744
821
|
)
|
|
745
822
|
|
|
746
|
-
job = Job(response["
|
|
823
|
+
job = Job(response["job_id"], get_value)
|
|
747
824
|
return job if background else job.result()
|
|
748
825
|
|
|
749
826
|
def _evaluate_dataset(
|
|
@@ -754,6 +831,11 @@ class ClassificationModel:
|
|
|
754
831
|
record_predictions: bool,
|
|
755
832
|
tags: set[str],
|
|
756
833
|
batch_size: int,
|
|
834
|
+
ignore_unlabeled: bool,
|
|
835
|
+
partition_column: str | None = None,
|
|
836
|
+
partition_filter_mode: Literal[
|
|
837
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
838
|
+
] = "include_global",
|
|
757
839
|
) -> ClassificationMetrics:
|
|
758
840
|
if len(dataset) == 0:
|
|
759
841
|
raise ValueError("Evaluation dataset cannot be empty")
|
|
@@ -769,6 +851,9 @@ class ClassificationModel:
|
|
|
769
851
|
expected_labels=dataset[i : i + batch_size][label_column],
|
|
770
852
|
tags=tags,
|
|
771
853
|
save_telemetry="sync" if record_predictions else "off",
|
|
854
|
+
ignore_unlabeled=ignore_unlabeled,
|
|
855
|
+
partition_id=dataset[i : i + batch_size][partition_column] if partition_column else None,
|
|
856
|
+
partition_filter_mode=partition_filter_mode,
|
|
772
857
|
)
|
|
773
858
|
]
|
|
774
859
|
|
|
@@ -786,10 +871,16 @@ class ClassificationModel:
|
|
|
786
871
|
*,
|
|
787
872
|
value_column: str = "value",
|
|
788
873
|
label_column: str = "label",
|
|
874
|
+
partition_column: str | None = None,
|
|
789
875
|
record_predictions: bool = False,
|
|
790
876
|
tags: set[str] = {"evaluation"},
|
|
791
877
|
batch_size: int = 100,
|
|
878
|
+
subsample: int | float | None = None,
|
|
792
879
|
background: Literal[True],
|
|
880
|
+
ignore_unlabeled: bool = False,
|
|
881
|
+
partition_filter_mode: Literal[
|
|
882
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
883
|
+
] = "include_global",
|
|
793
884
|
) -> Job[ClassificationMetrics]:
|
|
794
885
|
pass
|
|
795
886
|
|
|
@@ -800,10 +891,16 @@ class ClassificationModel:
|
|
|
800
891
|
*,
|
|
801
892
|
value_column: str = "value",
|
|
802
893
|
label_column: str = "label",
|
|
894
|
+
partition_column: str | None = None,
|
|
803
895
|
record_predictions: bool = False,
|
|
804
896
|
tags: set[str] = {"evaluation"},
|
|
805
897
|
batch_size: int = 100,
|
|
898
|
+
subsample: int | float | None = None,
|
|
806
899
|
background: Literal[False] = False,
|
|
900
|
+
ignore_unlabeled: bool = False,
|
|
901
|
+
partition_filter_mode: Literal[
|
|
902
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
903
|
+
] = "include_global",
|
|
807
904
|
) -> ClassificationMetrics:
|
|
808
905
|
pass
|
|
809
906
|
|
|
@@ -813,10 +910,16 @@ class ClassificationModel:
|
|
|
813
910
|
*,
|
|
814
911
|
value_column: str = "value",
|
|
815
912
|
label_column: str = "label",
|
|
913
|
+
partition_column: str | None = None,
|
|
816
914
|
record_predictions: bool = False,
|
|
817
915
|
tags: set[str] = {"evaluation"},
|
|
818
916
|
batch_size: int = 100,
|
|
917
|
+
subsample: int | float | None = None,
|
|
819
918
|
background: bool = False,
|
|
919
|
+
ignore_unlabeled: bool = False,
|
|
920
|
+
partition_filter_mode: Literal[
|
|
921
|
+
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
922
|
+
] = "include_global",
|
|
820
923
|
) -> ClassificationMetrics | Job[ClassificationMetrics]:
|
|
821
924
|
"""
|
|
822
925
|
Evaluate the classification model on a given dataset or datasource
|
|
@@ -825,11 +928,18 @@ class ClassificationModel:
|
|
|
825
928
|
data: Dataset or Datasource to evaluate the model on
|
|
826
929
|
value_column: Name of the column that contains the input values to the model
|
|
827
930
|
label_column: Name of the column containing the expected labels
|
|
931
|
+
partition_column: Optional name of the column that contains the partition IDs
|
|
828
932
|
record_predictions: Whether to record [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s for analysis
|
|
829
933
|
tags: Optional tags to add to the recorded [`ClassificationPrediction`][orca_sdk.telemetry.ClassificationPrediction]s
|
|
830
934
|
batch_size: Batch size for processing Dataset inputs (only used when input is a Dataset)
|
|
935
|
+
subsample: Optional number (int) of rows to sample or fraction (float in (0, 1]) of data to sample for evaluation.
|
|
831
936
|
background: Whether to run the operation in the background and return a job handle
|
|
832
|
-
|
|
937
|
+
ignore_unlabeled: If True, only use labeled memories during lookup. If False (default), allow unlabeled memories
|
|
938
|
+
partition_filter_mode: Optional partition filter mode to use for the evaluation. One of
|
|
939
|
+
* `"ignore_partitions"`: Ignore partitions
|
|
940
|
+
* `"include_global"`: Include global memories
|
|
941
|
+
* `"exclude_global"`: Exclude global memories
|
|
942
|
+
* `"only_global"`: Only include global memories
|
|
833
943
|
Returns:
|
|
834
944
|
EvaluationResult containing metrics including accuracy, F1 score, ROC AUC, PR AUC, and anomaly score statistics
|
|
835
945
|
|
|
@@ -850,7 +960,11 @@ class ClassificationModel:
|
|
|
850
960
|
label_column=label_column,
|
|
851
961
|
record_predictions=record_predictions,
|
|
852
962
|
tags=tags,
|
|
963
|
+
subsample=subsample,
|
|
853
964
|
background=background,
|
|
965
|
+
ignore_unlabeled=ignore_unlabeled,
|
|
966
|
+
partition_column=partition_column,
|
|
967
|
+
partition_filter_mode=partition_filter_mode,
|
|
854
968
|
)
|
|
855
969
|
elif isinstance(data, Dataset):
|
|
856
970
|
return self._evaluate_dataset(
|
|
@@ -860,6 +974,9 @@ class ClassificationModel:
|
|
|
860
974
|
record_predictions=record_predictions,
|
|
861
975
|
tags=tags,
|
|
862
976
|
batch_size=batch_size,
|
|
977
|
+
ignore_unlabeled=ignore_unlabeled,
|
|
978
|
+
partition_column=partition_column,
|
|
979
|
+
partition_filter_mode=partition_filter_mode,
|
|
863
980
|
)
|
|
864
981
|
else:
|
|
865
982
|
raise ValueError(f"Invalid data type: {type(data)}")
|
|
@@ -961,11 +1078,9 @@ class ClassificationModel:
|
|
|
961
1078
|
|
|
962
1079
|
def get_result() -> BootstrappedClassificationModel:
|
|
963
1080
|
client = OrcaClient._resolve_client()
|
|
964
|
-
res = client.GET(
|
|
965
|
-
"/agents/bootstrap_classification_model/{task_id}", params={"task_id": response["task_id"]}
|
|
966
|
-
)
|
|
1081
|
+
res = client.GET("/agents/bootstrap_classification_model/{job_id}", params={"job_id": response["job_id"]})
|
|
967
1082
|
assert res["result"] is not None
|
|
968
1083
|
return BootstrappedClassificationModel(res["result"])
|
|
969
1084
|
|
|
970
|
-
job = Job(response["
|
|
1085
|
+
job = Job(response["job_id"], get_result)
|
|
971
1086
|
return job if background else job.result()
|