clarifai 10.1.0__py3-none-any.whl → 10.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clarifai/client/app.py +23 -43
- clarifai/client/base.py +44 -4
- clarifai/client/dataset.py +138 -52
- clarifai/client/input.py +37 -4
- clarifai/client/model.py +279 -8
- clarifai/client/module.py +7 -5
- clarifai/client/runner.py +3 -1
- clarifai/client/search.py +7 -3
- clarifai/client/user.py +14 -12
- clarifai/client/workflow.py +7 -4
- clarifai/constants/dataset.py +2 -0
- clarifai/datasets/upload/loaders/README.md +3 -4
- clarifai/datasets/upload/loaders/xview_detection.py +5 -5
- clarifai/models/model_serving/cli/_utils.py +1 -1
- clarifai/models/model_serving/cli/build.py +1 -1
- clarifai/models/model_serving/cli/upload.py +1 -1
- clarifai/models/model_serving/utils.py +3 -1
- clarifai/rag/rag.py +25 -11
- clarifai/rag/utils.py +21 -6
- clarifai/utils/evaluation/__init__.py +427 -0
- clarifai/utils/evaluation/helpers.py +522 -0
- clarifai/utils/logging.py +30 -0
- clarifai/utils/model_train.py +3 -1
- clarifai/versions.py +1 -1
- clarifai/workflows/validate.py +1 -1
- {clarifai-10.1.0.dist-info → clarifai-10.2.0.dist-info}/METADATA +46 -9
- {clarifai-10.1.0.dist-info → clarifai-10.2.0.dist-info}/RECORD +31 -30
- clarifai/datasets/upload/loaders/coco_segmentation.py +0 -98
- {clarifai-10.1.0.dist-info → clarifai-10.2.0.dist-info}/LICENSE +0 -0
- {clarifai-10.1.0.dist-info → clarifai-10.2.0.dist-info}/WHEEL +0 -0
- {clarifai-10.1.0.dist-info → clarifai-10.2.0.dist-info}/entry_points.txt +0 -0
- {clarifai-10.1.0.dist-info → clarifai-10.2.0.dist-info}/top_level.txt +0 -0
clarifai/client/input.py
CHANGED
@@ -18,6 +18,7 @@ from tqdm import tqdm
|
|
18
18
|
|
19
19
|
from clarifai.client.base import BaseClient
|
20
20
|
from clarifai.client.lister import Lister
|
21
|
+
from clarifai.constants.dataset import MAX_RETRIES
|
21
22
|
from clarifai.errors import UserError
|
22
23
|
from clarifai.utils.logging import get_logger
|
23
24
|
from clarifai.utils.misc import BackoffIterator, Chunker
|
@@ -32,6 +33,7 @@ class Inputs(Lister, BaseClient):
|
|
32
33
|
logger_level: str = "INFO",
|
33
34
|
base_url: str = "https://api.clarifai.com",
|
34
35
|
pat: str = None,
|
36
|
+
token: str = None,
|
35
37
|
**kwargs):
|
36
38
|
"""Initializes an Input object.
|
37
39
|
|
@@ -39,6 +41,8 @@ class Inputs(Lister, BaseClient):
|
|
39
41
|
user_id (str): A user ID for authentication.
|
40
42
|
app_id (str): An app ID for the application to interact with.
|
41
43
|
base_url (str): Base API url. Default "https://api.clarifai.com"
|
44
|
+
pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
|
45
|
+
token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
|
42
46
|
**kwargs: Additional keyword arguments to be passed to the Input
|
43
47
|
"""
|
44
48
|
self.user_id = user_id
|
@@ -46,7 +50,8 @@ class Inputs(Lister, BaseClient):
|
|
46
50
|
self.kwargs = {**kwargs}
|
47
51
|
self.input_info = resources_pb2.Input(**self.kwargs)
|
48
52
|
self.logger = get_logger(logger_level=logger_level, name=__name__)
|
49
|
-
BaseClient.__init__(
|
53
|
+
BaseClient.__init__(
|
54
|
+
self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
|
50
55
|
Lister.__init__(self)
|
51
56
|
|
52
57
|
@staticmethod
|
@@ -670,6 +675,30 @@ class Inputs(Lister, BaseClient):
|
|
670
675
|
|
671
676
|
return input_job_id, response
|
672
677
|
|
678
|
+
def patch_inputs(self, inputs: List[Input], action: str = 'merge') -> None:
  """Patch a list of input objects in the app.

  Args:
    inputs (list): List of input protos to patch.
    action (str): Action to perform on the inputs. Options: 'merge', 'overwrite', 'remove'.

  Raises:
    UserError: If ``inputs`` is not a list.

  Returns:
    None. Success or failure of the PatchInputs call is reported through the logger.
  """
  if not isinstance(inputs, list):
    raise UserError("inputs must be a list of Input objects")
  request = service_pb2.PatchInputsRequest(
      user_app_id=self.user_app_id, inputs=inputs, action=action)
  response = self._grpc_request(self.STUB.PatchInputs, request)
  if response.status.code != status_code_pb2.SUCCESS:
    # Return early so a failed patch is not also reported as successful below.
    self.logger.warning(f"Patch inputs failed, status: {response.status.details}")
    return

  self.logger.info("\nPatch Inputs Successful\n%s", response.status)
|
701
|
+
|
673
702
|
def upload_annotations(self, batch_annot: List[resources_pb2.Annotation], show_log: bool = True
|
674
703
|
) -> Union[List[resources_pb2.Annotation], List[None]]:
|
675
704
|
"""Upload image annotations to app.
|
@@ -908,10 +937,14 @@ class Inputs(Lister, BaseClient):
|
|
908
937
|
"""Retry failed uploads.
|
909
938
|
|
910
939
|
Args:
|
911
|
-
failed_inputs (List[Input]): failed input
|
940
|
+
failed_inputs (List[Input]): failed input protos
|
912
941
|
"""
|
913
|
-
|
914
|
-
|
942
|
+
for _retry in range(MAX_RETRIES):
|
943
|
+
if failed_inputs:
|
944
|
+
self.logger.info(f"Retrying upload for {len(failed_inputs)} Failed inputs..\n")
|
945
|
+
failed_inputs = self._upload_batch(failed_inputs)
|
946
|
+
|
947
|
+
self.logger.warning(f"Failed to upload {len(failed_inputs)} inputs..\n ")
|
915
948
|
|
916
949
|
def _delete_failed_inputs(self, inputs: List[Input]) -> List[Input]:
|
917
950
|
"""Delete failed input ids from clarifai platform dataset.
|
clarifai/client/model.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
import os
|
2
2
|
import time
|
3
|
-
from typing import Any, Dict, Generator, List
|
3
|
+
from typing import Any, Dict, Generator, List, Union
|
4
4
|
|
5
5
|
import requests
|
6
6
|
import yaml
|
@@ -9,6 +9,7 @@ from clarifai_grpc.grpc.api.resources_pb2 import Input
|
|
9
9
|
from clarifai_grpc.grpc.api.status import status_code_pb2
|
10
10
|
from google.protobuf.json_format import MessageToDict
|
11
11
|
from google.protobuf.struct_pb2 import Struct
|
12
|
+
from tqdm import tqdm
|
12
13
|
|
13
14
|
from clarifai.client.base import BaseClient
|
14
15
|
from clarifai.client.input import Inputs
|
@@ -32,6 +33,7 @@ class Model(Lister, BaseClient):
|
|
32
33
|
model_version: Dict = {'id': ""},
|
33
34
|
base_url: str = "https://api.clarifai.com",
|
34
35
|
pat: str = None,
|
36
|
+
token: str = None,
|
35
37
|
**kwargs):
|
36
38
|
"""Initializes a Model object.
|
37
39
|
|
@@ -41,6 +43,7 @@ class Model(Lister, BaseClient):
|
|
41
43
|
model_version (dict): The Model Version to interact with.
|
42
44
|
base_url (str): Base API url. Default "https://api.clarifai.com"
|
43
45
|
pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
|
46
|
+
token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
|
44
47
|
**kwargs: Additional keyword arguments to be passed to the Model.
|
45
48
|
"""
|
46
49
|
if url and model_id:
|
@@ -55,7 +58,8 @@ class Model(Lister, BaseClient):
|
|
55
58
|
self.model_info = resources_pb2.Model(**self.kwargs)
|
56
59
|
self.logger = get_logger(logger_level="INFO", name=__name__)
|
57
60
|
self.training_params = {}
|
58
|
-
BaseClient.__init__(
|
61
|
+
BaseClient.__init__(
|
62
|
+
self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
|
59
63
|
Lister.__init__(self)
|
60
64
|
|
61
65
|
def list_training_templates(self) -> List[str]:
|
@@ -212,6 +216,8 @@ class Model(Lister, BaseClient):
|
|
212
216
|
>>> model_params = model.get_params(template='template', yaml_file='model_params.yaml')
|
213
217
|
>>> model.train('model_params.yaml')
|
214
218
|
"""
|
219
|
+
if not self.model_info.model_type_id:
|
220
|
+
self.load_info()
|
215
221
|
if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
|
216
222
|
raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
|
217
223
|
if not yaml_file and len(self.training_params) == 0:
|
@@ -222,8 +228,10 @@ class Model(Lister, BaseClient):
|
|
222
228
|
params_dict = yaml.safe_load(file)
|
223
229
|
else:
|
224
230
|
params_dict = self.training_params
|
225
|
-
|
226
|
-
|
231
|
+
#getting all the concepts for the model type
|
232
|
+
if self.model_info.model_type_id not in ["clusterer", "text-to-text"]:
|
233
|
+
concepts = self._list_concepts()
|
234
|
+
train_dict = params_parser(params_dict, concepts)
|
227
235
|
request = service_pb2.PostModelVersionsRequest(
|
228
236
|
user_app_id=self.user_app_id,
|
229
237
|
model_id=self.id,
|
@@ -331,7 +339,7 @@ class Model(Lister, BaseClient):
|
|
331
339
|
dict_response = MessageToDict(response, preserving_proto_field_name=True)
|
332
340
|
kwargs = self.process_response_keys(dict_response['model'], 'model')
|
333
341
|
|
334
|
-
return Model(base_url=self.base, pat=self.pat, **kwargs)
|
342
|
+
return Model(base_url=self.base, pat=self.pat, token=self.token, **kwargs)
|
335
343
|
|
336
344
|
def list_versions(self, page_no: int = None,
|
337
345
|
per_page: int = None) -> Generator['Model', None, None]:
|
@@ -373,10 +381,9 @@ class Model(Lister, BaseClient):
|
|
373
381
|
del model_version_info['train_info']['dataset']['version']['metrics']
|
374
382
|
except KeyError:
|
375
383
|
pass
|
376
|
-
yield Model(
|
384
|
+
yield Model.from_auth_helper(
|
385
|
+
auth=self.auth_helper,
|
377
386
|
model_id=self.id,
|
378
|
-
base_url=self.base,
|
379
|
-
pat=self.pat,
|
380
387
|
**dict(self.kwargs, model_version=model_version_info))
|
381
388
|
|
382
389
|
def predict(self, inputs: List[Input], inference_params: Dict = {}, output_config: Dict = {}):
|
@@ -548,6 +555,17 @@ class Model(Lister, BaseClient):
|
|
548
555
|
resources_pb2.OutputInfo(
|
549
556
|
output_config=resources_pb2.OutputConfig(**output_config), params=params))
|
550
557
|
|
558
|
+
def _list_concepts(self) -> List[str]:
  """Collect the IDs of every concept in this model's app.

  The ListConcepts request is scoped only by ``self.user_app_id``, and every
  page returned by the lister is consumed.

  Returns:
    concepts (List): Concept IDs, in the order the pages yield them.
  """
  pages = self.list_pages_generator(self.STUB.ListConcepts, service_pb2.ListConceptsRequest,
                                    dict(user_app_id=self.user_app_id))
  concept_ids = []
  for info in pages:
    concept_ids.append(info['concept_id'])
  return concept_ids
|
568
|
+
|
551
569
|
def load_info(self) -> None:
|
552
570
|
"""Loads the model info."""
|
553
571
|
request = service_pb2.GetModelRequest(
|
@@ -576,3 +594,256 @@ class Model(Lister, BaseClient):
|
|
576
594
|
if hasattr(self.model_info, param)
|
577
595
|
]
|
578
596
|
return f"Model Details: \n{', '.join(attribute_strings)}\n"
|
597
|
+
|
598
|
+
def list_evaluations(self) -> resources_pb2.EvalMetrics:
  """List the evaluations recorded for the current model version.

  Raises:
    AssertionError: If no model version ID is set on this Model.
    Exception: If the ListModelVersionEvaluations call does not return SUCCESS.

  Returns:
    The ``eval_metrics`` field of the response: a repeated collection of
    resources_pb2.EvalMetrics (possibly empty).
  """
  version_id = self.model_info.model_version.id
  assert version_id, "Model version is empty. Please provide `model_version` as arguments or with a URL as the format '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
  response = self._grpc_request(
      self.STUB.ListModelVersionEvaluations,
      service_pb2.ListModelVersionEvaluationsRequest(
          user_app_id=self.user_app_id, model_id=self.id, model_version_id=version_id))
  if response.status.code == status_code_pb2.SUCCESS:
    return response.eval_metrics
  raise Exception(response.status)
|
618
|
+
|
619
|
+
def evaluate(self,
             dataset_id: str,
             dataset_app_id: str = None,
             dataset_user_id: str = None,
             eval_id: str = None,
             extended_metrics: dict = None,
             eval_info: dict = None) -> resources_pb2.EvalMetrics:
  """Start an evaluation of the current model version against a dataset.

  Args:
    dataset_id (str): Dataset Id.
    dataset_app_id (str): App ID for cross app evaluation; None means use the model's app ID. Default is None.
    dataset_user_id (str): User ID for cross app evaluation; None means use the model's user ID. Default is None.
    eval_id (str): Specific ID for the evaluation. Set it to overwrite a previous result or to give the evaluation an informative ID; if None, the system assigns a random ID. Default is None.
    extended_metrics (dict): User custom metrics, sent as a Struct. Ignored unless it is a dict. Default is None.
    eval_info (dict): Custom eval params, sent as a Struct. Ignored unless it is a dict. Default is None.

  Raises:
    AssertionError: If no model version ID is set on this Model.
    Exception: If the PostEvaluations call does not return SUCCESS.

  Returns:
    The ``eval_metrics`` field of the PostEvaluations response (a repeated collection).
  """
  version_id = self.model_info.model_version.id
  assert version_id, "Model version is empty. Please provide `model_version` as arguments or with a URL as the format '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."

  def _as_struct(values: dict) -> Struct:
    packed = Struct()
    packed.update(values)
    return packed

  extended = None
  if isinstance(extended_metrics, dict):
    extended = resources_pb2.ExtendedMetrics(user_metrics=_as_struct(extended_metrics))

  info = None
  if isinstance(eval_info, dict):
    info = resources_pb2.EvalInfo(params=_as_struct(eval_info))

  owner_app_id = self.auth_helper.app_id
  owner_user_id = self.auth_helper.user_id
  model_proto = resources_pb2.Model(
      id=self.id,
      app_id=owner_app_id,
      user_id=owner_user_id,
      model_version=resources_pb2.ModelVersion(id=version_id),
  )
  # Dataset defaults to the model's own app/user unless a cross-app target is given.
  dataset_proto = resources_pb2.Dataset(
      id=dataset_id,
      app_id=dataset_app_id or owner_app_id,
      user_id=dataset_user_id or owner_user_id,
  )
  eval_metric = resources_pb2.EvalMetrics(
      id=eval_id,
      model=model_proto,
      extended_metrics=extended,
      ground_truth_dataset=dataset_proto,
      eval_info=info,
  )

  response = self._grpc_request(
      self.STUB.PostEvaluations,
      service_pb2.PostEvaluationsRequest(
          user_app_id=self.user_app_id,
          eval_metrics=[eval_metric],
      ))
  if response.status.code == status_code_pb2.SUCCESS:
    self.logger.info(
        "\nModel evaluation in progress. Kindly allow a few minutes for completion. Processing time may vary based on the model and dataset sizes."
    )
    return response.eval_metrics
  raise Exception(response.status)
|
681
|
+
|
682
|
+
def get_eval_by_id(
    self,
    eval_id: str,
    label_counts=False,
    test_set=False,
    binary_metrics=False,
    confusion_matrix=False,
    metrics_by_class=False,
    metrics_by_area=False,
) -> resources_pb2.EvalMetrics:
  """Fetch one evaluation by ID, optionally including extra metric fields.

  Each boolean flag is forwarded into the request's ``FieldsValue`` to ask the
  API to include that section of the metrics.

  Args:
    eval_id (str): eval id
    label_counts (bool, optional): Set True to get label counts. Defaults to False.
    test_set (bool, optional): Set True to get test set. Defaults to False.
    binary_metrics (bool, optional): Set True to get binary metric. Defaults to False.
    confusion_matrix (bool, optional): Set True to get confusion matrix. Defaults to False.
    metrics_by_class (bool, optional): Set True to get metrics by class. Defaults to False.
    metrics_by_area (bool, optional): Set True to get metrics by area. Defaults to False.

  Raises:
    Exception: If the GetEvaluation call does not return SUCCESS.

  Returns:
    resources_pb2.EvalMetrics: eval_metrics
  """
  requested = dict(
      label_counts=label_counts,
      test_set=test_set,
      binary_metrics=binary_metrics,
      confusion_matrix=confusion_matrix,
      metrics_by_class=metrics_by_class,
      metrics_by_area=metrics_by_area,
  )
  request = service_pb2.GetEvaluationRequest(
      user_app_id=self.user_app_id,
      evaluation_id=eval_id,
      fields=resources_pb2.FieldsValue(**requested),
  )
  response = self._grpc_request(self.STUB.GetEvaluation, request)
  if response.status.code == status_code_pb2.SUCCESS:
    return response.eval_metrics
  raise Exception(response.status)
|
726
|
+
|
727
|
+
def get_latest_eval(self,
                    label_counts=False,
                    test_set=False,
                    binary_metrics=False,
                    confusion_matrix=False,
                    metrics_by_class=False,
                    metrics_by_area=False) -> Union[resources_pb2.EvalMetrics, None]:
  """
  Run `get_eval_by_id` method with latest `eval_id`

  Args:
    label_counts (bool, optional): Set True to get label counts. Defaults to False.
    test_set (bool, optional): Set True to get test set. Defaults to False.
    binary_metrics (bool, optional): Set True to get binary metric. Defaults to False.
    confusion_matrix (bool, optional): Set True to get confusion matrix. Defaults to False.
    metrics_by_class (bool, optional): Set True to get metrics by class. Defaults to False.
    metrics_by_area (bool, optional): Set True to get metrics by area. Defaults to False.

  Returns:
    eval_metric if model is evaluated otherwise None. None is also returned when
    the model version has no evaluations at all.

  """
  evaluations = self.list_evaluations()
  # No evaluations yet means "not evaluated": return None instead of raising IndexError.
  if not evaluations:
    return None

  # NOTE(review): treats the first entry as the latest, i.e. relies on the API
  # returning evaluations newest-first; confirm against the API ordering.
  _latest = evaluations[0]
  if _latest.status.code != status_code_pb2.MODEL_EVALUATED:
    return None

  return self.get_eval_by_id(
      eval_id=_latest.id,
      label_counts=label_counts,
      test_set=test_set,
      binary_metrics=binary_metrics,
      confusion_matrix=confusion_matrix,
      metrics_by_class=metrics_by_class,
      metrics_by_area=metrics_by_area)
|
763
|
+
|
764
|
+
def export(self, export_dir: str = None) -> None:
  """Export the model version and save it locally as ``model.tar`` inside ``export_dir``.

  If no export exists yet for this model version, one is requested with
  PutModelVersionExports. The export status is then polled with backoff until
  it reports MODEL_EXPORTED, and the archive is downloaded exactly once.
  Polling gives up after 30 minutes.

  Args:
    export_dir (str): The directory to save the exported model. It is created
      if missing. Required despite the ``None`` default.

  Raises:
    AssertionError: If no model version ID is set on this Model.
    UserError: If ``export_dir`` is not provided.
    Exception: If the directory cannot be created, an API call fails, or the
      export does not finish within 30 minutes.

  Example:
    >>> from clarifai.client.model import Model
    >>> model = Model("url")
    >>> model.export('/path/to/export_model_dir')
  """
  assert self.model_info.model_version.id, "Model version ID is missing. Please provide a `model_version` with a valid `id` as an argument or as a URL in the following format: '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
  if not export_dir:
    raise UserError("`export_dir` must be provided to export the model.")
  try:
    os.makedirs(export_dir, exist_ok=True)
  except OSError as e:
    raise Exception(f"An error occurred while creating the directory: {e}") from e

  model_version_id = self.model_info.model_version.id
  local_filepath = os.path.join(export_dir, "model.tar")
  timeout_seconds = 60 * 30  # 30 minutes

  def _get_export_response():
    """Fetch the export record; CONN_DOES_NOT_EXIST (no export yet) is not an error."""
    get_export_request = service_pb2.GetModelVersionExportRequest(
        user_app_id=self.user_app_id,
        model_id=self.id,
        version_id=model_version_id,
    )
    response = self._grpc_request(self.STUB.GetModelVersionExport, get_export_request)
    if response.status.code not in (status_code_pb2.SUCCESS,
                                    status_code_pb2.CONN_DOES_NOT_EXIST):
      raise Exception(response.status)
    return response

  def _download_exported_model(
      get_model_export_response: service_pb2.SingleModelVersionExportResponse,
      filepath: str):
    """Stream the exported archive from its URL to ``filepath`` with a progress bar."""
    model_export_url = get_model_export_response.export.url
    model_export_file_size = get_model_export_response.export.size

    # Context managers close the HTTP connection and the progress bar even if
    # the download fails midway.
    with requests.get(model_export_url, stream=True) as response:
      response.raise_for_status()
      with open(filepath, 'wb') as f, tqdm(
          total=model_export_file_size, unit='B', unit_scale=True,
          desc="Exporting model") as progress:
        for chunk in response.iter_content(chunk_size=8192):
          f.write(chunk)
          progress.update(len(chunk))

    self.logger.info(
        f"Model ID {self.id} with version {model_version_id} exported successfully to {filepath}"
    )

  get_export_response = _get_export_response()
  if get_export_response.status.code == status_code_pb2.CONN_DOES_NOT_EXIST:
    put_export_request = service_pb2.PutModelVersionExportsRequest(
        user_app_id=self.user_app_id,
        model_id=self.id,
        version_id=model_version_id,
    )
    response = self._grpc_request(self.STUB.PutModelVersionExports, put_export_request)
    if response.status.code != status_code_pb2.SUCCESS:
      raise Exception(response.status)

    self.logger.info(
        f"Model ID {self.id} with version {model_version_id} export started, please wait...")
    time.sleep(5)
    get_export_response = _get_export_response()

  # Poll until exported. Every non-final iteration sleeps, so there is no busy loop,
  # and an export that was already in progress is also waited for.
  start_time = time.time()
  backoff_iterator = BackoffIterator()
  while get_export_response.export.status.code != status_code_pb2.MODEL_EXPORTED:
    if time.time() - start_time > timeout_seconds:
      raise Exception(
          "Model Export took too long. Please try again or contact support@clarifai.com\n"
          f"Req ID: {get_export_response.status.req_id}")
    export_status = get_export_response.export.status
    if export_status.code == status_code_pb2.MODEL_EXPORTING:
      self.logger.info(
          f"Model ID {self.id} with version {model_version_id} is still exporting, please wait..."
      )
    else:
      # Not exporting and not exported, e.g. the export is not visible yet right
      # after the PUT. Keep waiting but surface the status.
      # NOTE(review): consider failing fast on an explicit export-failure code.
      self.logger.warning(
          f"Model ID {self.id} with version {model_version_id} export status: "
          f"{export_status.description or export_status.code}")
    time.sleep(next(backoff_iterator))
    get_export_response = _get_export_response()

  _download_exported_model(get_export_response, local_filepath)
|
clarifai/client/module.py
CHANGED
@@ -18,6 +18,7 @@ class Module(Lister, BaseClient):
|
|
18
18
|
module_version: Dict = {'id': ""},
|
19
19
|
base_url: str = "https://api.clarifai.com",
|
20
20
|
pat: str = None,
|
21
|
+
token: str = None,
|
21
22
|
**kwargs):
|
22
23
|
"""Initializes a Module object.
|
23
24
|
|
@@ -26,7 +27,8 @@ class Module(Lister, BaseClient):
|
|
26
27
|
module_id (str): The Module ID to interact with.
|
27
28
|
module_version (dict): The Module Version to interact with.
|
28
29
|
base_url (str): Base API url. Default "https://api.clarifai.com"
|
29
|
-
pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
|
30
|
+
pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT.
|
31
|
+
token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN.
|
30
32
|
**kwargs: Additional keyword arguments to be passed to the Module.
|
31
33
|
"""
|
32
34
|
if url and module_id:
|
@@ -41,7 +43,8 @@ class Module(Lister, BaseClient):
|
|
41
43
|
self.kwargs = {**kwargs, 'id': module_id, 'module_version': module_version}
|
42
44
|
self.module_info = resources_pb2.Module(**self.kwargs)
|
43
45
|
self.logger = get_logger(logger_level="INFO", name=__name__)
|
44
|
-
BaseClient.__init__(
|
46
|
+
BaseClient.__init__(
|
47
|
+
self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
|
45
48
|
Lister.__init__(self)
|
46
49
|
|
47
50
|
def list_versions(self, page_no: int = None,
|
@@ -78,10 +81,9 @@ class Module(Lister, BaseClient):
|
|
78
81
|
for module_version_info in all_module_versions_info:
|
79
82
|
module_version_info['id'] = module_version_info['module_version_id']
|
80
83
|
del module_version_info['module_version_id']
|
81
|
-
yield Module(
|
84
|
+
yield Module.from_auth_helper(
|
85
|
+
self.auth_helper,
|
82
86
|
module_id=self.id,
|
83
|
-
base_url=self.base,
|
84
|
-
pat=self.pat,
|
85
87
|
**dict(self.kwargs, module_version=module_version_info))
|
86
88
|
|
87
89
|
def __getattr__(self, name):
|
clarifai/client/runner.py
CHANGED
@@ -39,6 +39,7 @@ class Runner(BaseClient):
|
|
39
39
|
check_runner_exists: bool = True,
|
40
40
|
base_url: str = "https://api.clarifai.com",
|
41
41
|
pat: str = None,
|
42
|
+
token: str = None,
|
42
43
|
num_parallel_polls: int = 4,
|
43
44
|
**kwargs) -> None:
|
44
45
|
"""
|
@@ -47,6 +48,7 @@ class Runner(BaseClient):
|
|
47
48
|
user_id (str): Clarifai User ID
|
48
49
|
base_url (str): Base API url. Default "https://api.clarifai.com"
|
49
50
|
pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
|
51
|
+
token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
|
50
52
|
num_parallel_polls (int): the max number of threads for parallel run loops to be fetching work from
|
51
53
|
"""
|
52
54
|
user_id = user_id or os.environ.get("CLARIFAI_USER_ID", "")
|
@@ -60,7 +62,7 @@ class Runner(BaseClient):
|
|
60
62
|
self.kwargs = {**kwargs, 'id': runner_id, 'user_id': user_id}
|
61
63
|
self.runner_info = resources_pb2.Runner(**self.kwargs)
|
62
64
|
self.num_parallel_polls = min(10, num_parallel_polls)
|
63
|
-
BaseClient.__init__(self, user_id=self.user_id, app_id="", base=base_url, pat=pat)
|
65
|
+
BaseClient.__init__(self, user_id=self.user_id, app_id="", base=base_url, pat=pat, token=token)
|
64
66
|
|
65
67
|
# Check that the runner exists.
|
66
68
|
if check_runner_exists:
|
clarifai/client/search.py
CHANGED
@@ -23,7 +23,8 @@ class Search(Lister, BaseClient):
|
|
23
23
|
top_k: int = DEFAULT_TOP_K,
|
24
24
|
metric: str = DEFAULT_SEARCH_METRIC,
|
25
25
|
base_url: str = "https://api.clarifai.com",
|
26
|
-
pat: str = None
|
26
|
+
pat: str = None,
|
27
|
+
token: str = None):
|
27
28
|
"""Initialize the Search object.
|
28
29
|
|
29
30
|
Args:
|
@@ -33,6 +34,7 @@ class Search(Lister, BaseClient):
|
|
33
34
|
metric (str, optional): Similarity metric (either 'cosine' or 'euclidean'). Defaults to 'cosine'.
|
34
35
|
base_url (str, optional): Base API url. Defaults to "https://api.clarifai.com".
|
35
36
|
pat (str, optional): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
|
37
|
+
token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
|
36
38
|
|
37
39
|
Raises:
|
38
40
|
UserError: If the metric is not 'cosine' or 'euclidean'.
|
@@ -46,9 +48,11 @@ class Search(Lister, BaseClient):
|
|
46
48
|
self.data_proto = resources_pb2.Data()
|
47
49
|
self.top_k = top_k
|
48
50
|
|
49
|
-
self.inputs = Inputs(
|
51
|
+
self.inputs = Inputs(
|
52
|
+
user_id=self.user_id, app_id=self.app_id, pat=pat, token=token, base_url=base_url)
|
50
53
|
self.rank_filter_schema = get_schema()
|
51
|
-
BaseClient.__init__(
|
54
|
+
BaseClient.__init__(
|
55
|
+
self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
|
52
56
|
Lister.__init__(self, page_size=1000)
|
53
57
|
|
54
58
|
def _get_annot_proto(self, **kwargs):
|
clarifai/client/user.py
CHANGED
@@ -19,6 +19,7 @@ class User(Lister, BaseClient):
|
|
19
19
|
user_id: str = None,
|
20
20
|
base_url: str = "https://api.clarifai.com",
|
21
21
|
pat: str = None,
|
22
|
+
token: str = None,
|
22
23
|
**kwargs):
|
23
24
|
"""Initializes an User object.
|
24
25
|
|
@@ -26,12 +27,13 @@ class User(Lister, BaseClient):
|
|
26
27
|
user_id (str): The user ID for the user to interact with.
|
27
28
|
base_url (str): Base API url. Default "https://api.clarifai.com"
|
28
29
|
pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
|
30
|
+
token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
|
29
31
|
**kwargs: Additional keyword arguments to be passed to the User.
|
30
32
|
"""
|
31
33
|
self.kwargs = {**kwargs, 'id': user_id}
|
32
34
|
self.user_info = resources_pb2.User(**self.kwargs)
|
33
35
|
self.logger = get_logger(logger_level="INFO", name=__name__)
|
34
|
-
BaseClient.__init__(self, user_id=self.id, app_id="", base=base_url, pat=pat)
|
36
|
+
BaseClient.__init__(self, user_id=self.id, app_id="", base=base_url, pat=pat, token=token)
|
35
37
|
Lister.__init__(self)
|
36
38
|
|
37
39
|
def list_apps(self, filter_by: Dict[str, Any] = {}, page_no: int = None,
|
@@ -62,7 +64,9 @@ class User(Lister, BaseClient):
|
|
62
64
|
per_page=per_page,
|
63
65
|
page_no=page_no)
|
64
66
|
for app_info in all_apps_info:
|
65
|
-
yield App(
|
67
|
+
yield App.from_auth_helper(
|
68
|
+
self.auth_helper,
|
69
|
+
**app_info) #(base_url=self.base, pat=self.pat, token=self.token, **app_info)
|
66
70
|
|
67
71
|
def list_runners(self, filter_by: Dict[str, Any] = {}, page_no: int = None,
|
68
72
|
per_page: int = None) -> Generator[Runner, None, None]:
|
@@ -94,7 +98,8 @@ class User(Lister, BaseClient):
|
|
94
98
|
page_no=page_no)
|
95
99
|
|
96
100
|
for runner_info in all_runners_info:
|
97
|
-
yield Runner(
|
101
|
+
yield Runner.from_auth_helper(
|
102
|
+
auth=self.auth_helper, check_runner_exists=False, **runner_info)
|
98
103
|
|
99
104
|
def create_app(self, app_id: str, base_workflow: str = 'Empty', **kwargs) -> App:
|
100
105
|
"""Creates an app for the user.
|
@@ -120,8 +125,7 @@ class User(Lister, BaseClient):
|
|
120
125
|
if response.status.code != status_code_pb2.SUCCESS:
|
121
126
|
raise Exception(response.status)
|
122
127
|
self.logger.info("\nApp created\n%s", response.status)
|
123
|
-
|
124
|
-
return App(app_id=app_id, **kwargs)
|
128
|
+
return App.from_auth_helper(auth=self.auth_helper, app_id=app_id)
|
125
129
|
|
126
130
|
def create_runner(self, runner_id: str, labels: List[str], description: str) -> Runner:
|
127
131
|
"""Create a runner
|
@@ -151,14 +155,13 @@ class User(Lister, BaseClient):
|
|
151
155
|
raise Exception(response.status)
|
152
156
|
self.logger.info("\nRunner created\n%s", response.status)
|
153
157
|
|
154
|
-
return Runner(
|
158
|
+
return Runner.from_auth_helper(
|
159
|
+
auth=self.auth_helper,
|
155
160
|
runner_id=runner_id,
|
156
161
|
user_id=self.id,
|
157
162
|
labels=labels,
|
158
163
|
description=description,
|
159
|
-
check_runner_exists=False
|
160
|
-
base_url=self.base,
|
161
|
-
pat=self.pat)
|
164
|
+
check_runner_exists=False)
|
162
165
|
|
163
166
|
def app(self, app_id: str, **kwargs) -> App:
|
164
167
|
"""Returns an App object for the specified app ID.
|
@@ -181,8 +184,7 @@ class User(Lister, BaseClient):
|
|
181
184
|
raise Exception(response.status)
|
182
185
|
|
183
186
|
kwargs['user_id'] = self.id
|
184
|
-
|
185
|
-
return App(app_id=app_id, **kwargs)
|
187
|
+
return App.from_auth_helper(auth=self.auth_helper, app_id=app_id, **kwargs)
|
186
188
|
|
187
189
|
def runner(self, runner_id: str) -> Runner:
|
188
190
|
"""Returns a Runner object if exists.
|
@@ -210,7 +212,7 @@ class User(Lister, BaseClient):
|
|
210
212
|
kwargs = self.process_response_keys(dict_response[list(dict_response.keys())[1]],
|
211
213
|
list(dict_response.keys())[1])
|
212
214
|
|
213
|
-
return Runner(
|
215
|
+
return Runner.from_auth_helper(self.auth_helper, check_runner_exists=False, **kwargs)
|
214
216
|
|
215
217
|
def delete_app(self, app_id: str) -> None:
|
216
218
|
"""Deletes an app for the user.
|
clarifai/client/workflow.py
CHANGED
@@ -27,6 +27,7 @@ class Workflow(Lister, BaseClient):
|
|
27
27
|
output_config: Dict = {'min_value': 0},
|
28
28
|
base_url: str = "https://api.clarifai.com",
|
29
29
|
pat: str = None,
|
30
|
+
token: str = None,
|
30
31
|
**kwargs):
|
31
32
|
"""Initializes a Workflow object.
|
32
33
|
|
@@ -40,6 +41,8 @@ class Workflow(Lister, BaseClient):
|
|
40
41
|
select_concepts (list[Concept]): The concepts to select.
|
41
42
|
sample_ms (int): The number of milliseconds to sample.
|
42
43
|
base_url (str): Base API url. Default "https://api.clarifai.com"
|
44
|
+
pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
|
45
|
+
token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
|
43
46
|
**kwargs: Additional keyword arguments to be passed to the Workflow.
|
44
47
|
"""
|
45
48
|
if url and workflow_id:
|
@@ -55,7 +58,8 @@ class Workflow(Lister, BaseClient):
|
|
55
58
|
self.output_config = output_config
|
56
59
|
self.workflow_info = resources_pb2.Workflow(**self.kwargs)
|
57
60
|
self.logger = get_logger(logger_level="INFO", name=__name__)
|
58
|
-
BaseClient.__init__(
|
61
|
+
BaseClient.__init__(
|
62
|
+
self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
|
59
63
|
Lister.__init__(self)
|
60
64
|
|
61
65
|
def predict(self, inputs: List[Input], workflow_state_id: str = None):
|
@@ -206,10 +210,9 @@ class Workflow(Lister, BaseClient):
|
|
206
210
|
for workflow_version_info in all_workflow_versions_info:
|
207
211
|
workflow_version_info['id'] = workflow_version_info['workflow_version_id']
|
208
212
|
del workflow_version_info['workflow_version_id']
|
209
|
-
yield Workflow(
|
213
|
+
yield Workflow.from_auth_helper(
|
214
|
+
auth=self.auth_helper,
|
210
215
|
workflow_id=self.id,
|
211
|
-
base_url=self.base,
|
212
|
-
pat=self.pat,
|
213
216
|
**dict(self.kwargs, version=workflow_version_info))
|
214
217
|
|
215
218
|
def export(self, out_path: str):
|