clarifai 10.1.0__py3-none-any.whl → 10.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
clarifai/client/input.py CHANGED
@@ -18,6 +18,7 @@ from tqdm import tqdm
18
18
 
19
19
  from clarifai.client.base import BaseClient
20
20
  from clarifai.client.lister import Lister
21
+ from clarifai.constants.dataset import MAX_RETRIES
21
22
  from clarifai.errors import UserError
22
23
  from clarifai.utils.logging import get_logger
23
24
  from clarifai.utils.misc import BackoffIterator, Chunker
@@ -32,6 +33,7 @@ class Inputs(Lister, BaseClient):
32
33
  logger_level: str = "INFO",
33
34
  base_url: str = "https://api.clarifai.com",
34
35
  pat: str = None,
36
+ token: str = None,
35
37
  **kwargs):
36
38
  """Initializes an Input object.
37
39
 
@@ -39,6 +41,8 @@ class Inputs(Lister, BaseClient):
39
41
  user_id (str): A user ID for authentication.
40
42
  app_id (str): An app ID for the application to interact with.
41
43
  base_url (str): Base API url. Default "https://api.clarifai.com"
44
+ pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
45
+ token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
42
46
  **kwargs: Additional keyword arguments to be passed to the Input
43
47
  """
44
48
  self.user_id = user_id
@@ -46,7 +50,8 @@ class Inputs(Lister, BaseClient):
46
50
  self.kwargs = {**kwargs}
47
51
  self.input_info = resources_pb2.Input(**self.kwargs)
48
52
  self.logger = get_logger(logger_level=logger_level, name=__name__)
49
- BaseClient.__init__(self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat)
53
+ BaseClient.__init__(
54
+ self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
50
55
  Lister.__init__(self)
51
56
 
52
57
  @staticmethod
@@ -670,6 +675,30 @@ class Inputs(Lister, BaseClient):
670
675
 
671
676
  return input_job_id, response
672
677
 
678
def patch_inputs(self, inputs: List[Input], action: str = 'merge') -> None:
  """Patch a list of existing input objects in the app.

  Args:
    inputs (list): List of Input protos to patch.
    action (str): How the patch is applied to each input. Options: 'merge', 'overwrite', 'remove'.

  Raises:
    UserError: If `inputs` is not a list.

  Note:
    A non-SUCCESS API status is reported as a logged warning, not raised.
  """
  if not isinstance(inputs, list):
    raise UserError("inputs must be a list of Input objects")
  request = service_pb2.PatchInputsRequest(
      user_app_id=self.user_app_id, inputs=inputs, action=action)
  response = self._grpc_request(self.STUB.PatchInputs, request)
  if response.status.code != status_code_pb2.SUCCESS:
    # Prefer the first per-input status when the response carries inputs;
    # otherwise fall back to the request-level status details.
    try:
      self.logger.warning(f"Patch inputs failed, status: {response.inputs[0].status}")
    except (AttributeError, IndexError):
      self.logger.warning(f"Patch inputs failed, status: {response.status.details}")
  else:
    # Only report success when the API actually returned SUCCESS.
    self.logger.info("\nPatch Inputs Successful\n%s", response.status)
701
+
673
702
  def upload_annotations(self, batch_annot: List[resources_pb2.Annotation], show_log: bool = True
674
703
  ) -> Union[List[resources_pb2.Annotation], List[None]]:
675
704
  """Upload image annotations to app.
@@ -908,10 +937,14 @@ class Inputs(Lister, BaseClient):
908
937
  """Retry failed uploads.
909
938
 
910
939
  Args:
911
- failed_inputs (List[Input]): failed input prots
940
+ failed_inputs (List[Input]): failed input protos
912
941
  """
913
- if failed_inputs:
914
- self._upload_batch(failed_inputs)
942
+ for _retry in range(MAX_RETRIES):
943
+ if failed_inputs:
944
+ self.logger.info(f"Retrying upload for {len(failed_inputs)} Failed inputs..\n")
945
+ failed_inputs = self._upload_batch(failed_inputs)
946
+
947
+ self.logger.warning(f"Failed to upload {len(failed_inputs)} inputs..\n ")
915
948
 
916
949
  def _delete_failed_inputs(self, inputs: List[Input]) -> List[Input]:
917
950
  """Delete failed input ids from clarifai platform dataset.
clarifai/client/model.py CHANGED
@@ -1,6 +1,6 @@
1
1
  import os
2
2
  import time
3
- from typing import Any, Dict, Generator, List
3
+ from typing import Any, Dict, Generator, List, Union
4
4
 
5
5
  import requests
6
6
  import yaml
@@ -9,6 +9,7 @@ from clarifai_grpc.grpc.api.resources_pb2 import Input
9
9
  from clarifai_grpc.grpc.api.status import status_code_pb2
10
10
  from google.protobuf.json_format import MessageToDict
11
11
  from google.protobuf.struct_pb2 import Struct
12
+ from tqdm import tqdm
12
13
 
13
14
  from clarifai.client.base import BaseClient
14
15
  from clarifai.client.input import Inputs
@@ -32,6 +33,7 @@ class Model(Lister, BaseClient):
32
33
  model_version: Dict = {'id': ""},
33
34
  base_url: str = "https://api.clarifai.com",
34
35
  pat: str = None,
36
+ token: str = None,
35
37
  **kwargs):
36
38
  """Initializes a Model object.
37
39
 
@@ -41,6 +43,7 @@ class Model(Lister, BaseClient):
41
43
  model_version (dict): The Model Version to interact with.
42
44
  base_url (str): Base API url. Default "https://api.clarifai.com"
43
45
  pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
46
+ token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
44
47
  **kwargs: Additional keyword arguments to be passed to the Model.
45
48
  """
46
49
  if url and model_id:
@@ -55,7 +58,8 @@ class Model(Lister, BaseClient):
55
58
  self.model_info = resources_pb2.Model(**self.kwargs)
56
59
  self.logger = get_logger(logger_level="INFO", name=__name__)
57
60
  self.training_params = {}
58
- BaseClient.__init__(self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat)
61
+ BaseClient.__init__(
62
+ self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
59
63
  Lister.__init__(self)
60
64
 
61
65
  def list_training_templates(self) -> List[str]:
@@ -212,6 +216,8 @@ class Model(Lister, BaseClient):
212
216
  >>> model_params = model.get_params(template='template', yaml_file='model_params.yaml')
213
217
  >>> model.train('model_params.yaml')
214
218
  """
219
+ if not self.model_info.model_type_id:
220
+ self.load_info()
215
221
  if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
216
222
  raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
217
223
  if not yaml_file and len(self.training_params) == 0:
@@ -222,8 +228,10 @@ class Model(Lister, BaseClient):
222
228
  params_dict = yaml.safe_load(file)
223
229
  else:
224
230
  params_dict = self.training_params
225
-
226
- train_dict = params_parser(params_dict)
231
+ #getting all the concepts for the model type
232
+ if self.model_info.model_type_id not in ["clusterer", "text-to-text"]:
233
+ concepts = self._list_concepts()
234
+ train_dict = params_parser(params_dict, concepts)
227
235
  request = service_pb2.PostModelVersionsRequest(
228
236
  user_app_id=self.user_app_id,
229
237
  model_id=self.id,
@@ -331,7 +339,7 @@ class Model(Lister, BaseClient):
331
339
  dict_response = MessageToDict(response, preserving_proto_field_name=True)
332
340
  kwargs = self.process_response_keys(dict_response['model'], 'model')
333
341
 
334
- return Model(base_url=self.base, pat=self.pat, **kwargs)
342
+ return Model(base_url=self.base, pat=self.pat, token=self.token, **kwargs)
335
343
 
336
344
  def list_versions(self, page_no: int = None,
337
345
  per_page: int = None) -> Generator['Model', None, None]:
@@ -373,10 +381,9 @@ class Model(Lister, BaseClient):
373
381
  del model_version_info['train_info']['dataset']['version']['metrics']
374
382
  except KeyError:
375
383
  pass
376
- yield Model(
384
+ yield Model.from_auth_helper(
385
+ auth=self.auth_helper,
377
386
  model_id=self.id,
378
- base_url=self.base,
379
- pat=self.pat,
380
387
  **dict(self.kwargs, model_version=model_version_info))
381
388
 
382
389
  def predict(self, inputs: List[Input], inference_params: Dict = {}, output_config: Dict = {}):
@@ -548,6 +555,17 @@ class Model(Lister, BaseClient):
548
555
  resources_pb2.OutputInfo(
549
556
  output_config=resources_pb2.OutputConfig(**output_config), params=params))
550
557
 
558
def _list_concepts(self) -> List[str]:
  """Collect the IDs of every concept in this model's app.

  Pages through the ListConcepts endpoint, scoped only by `user_app_id`.

  Returns:
    concepts (List): Concept IDs in the order the pager yields them.
  """
  concept_pages = self.list_pages_generator(
      self.STUB.ListConcepts, service_pb2.ListConceptsRequest, {'user_app_id': self.user_app_id})
  return [info['concept_id'] for info in concept_pages]
568
+
551
569
  def load_info(self) -> None:
552
570
  """Loads the model info."""
553
571
  request = service_pb2.GetModelRequest(
@@ -576,3 +594,256 @@ class Model(Lister, BaseClient):
576
594
  if hasattr(self.model_info, param)
577
595
  ]
578
596
  return f"Model Details: \n{', '.join(attribute_strings)}\n"
597
+
598
def list_evaluations(self) -> resources_pb2.EvalMetrics:
  """Fetch the evaluations recorded for the current model version.

  Raises:
    AssertionError: If this Model has no model version ID set.
    Exception: If the ListModelVersionEvaluations call does not return SUCCESS.

  Returns:
    The `eval_metrics` field of the response (an indexable sequence of
    resources_pb2.EvalMetrics).
  """
  version_id = self.model_info.model_version.id
  assert version_id, "Model version is empty. Please provide `model_version` as arguments or with a URL as the format '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
  list_request = service_pb2.ListModelVersionEvaluationsRequest(
      user_app_id=self.user_app_id, model_id=self.id, model_version_id=version_id)
  list_response = self._grpc_request(self.STUB.ListModelVersionEvaluations, list_request)
  if list_response.status.code == status_code_pb2.SUCCESS:
    return list_response.eval_metrics
  raise Exception(list_response.status)
618
+
619
def evaluate(self,
             dataset_id: str,
             dataset_app_id: str = None,
             dataset_user_id: str = None,
             eval_id: str = None,
             extended_metrics: dict = None,
             eval_info: dict = None) -> resources_pb2.EvalMetrics:
  """Start an evaluation of the current model version against a dataset.

  Args:
    dataset_id (str): ID of the ground-truth dataset.
    dataset_app_id (str): App ID of the dataset for cross-app evaluation; None uses the model's app ID.
    dataset_user_id (str): User ID of the dataset for cross-app evaluation; None uses the model's user ID.
    eval_id (str): Explicit ID for the evaluation; None lets the platform assign one.
    extended_metrics (dict): User-supplied custom metrics, sent as a protobuf Struct. Default None.
    eval_info (dict): Custom evaluation params, sent as a protobuf Struct. Default None.

  Raises:
    AssertionError: If this Model has no model version ID set.
    Exception: If the PostEvaluations call does not return SUCCESS.

  Returns:
    The `eval_metrics` field of the PostEvaluations response.
  """
  assert self.model_info.model_version.id, "Model version is empty. Please provide `model_version` as arguments or with a URL as the format '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."

  def _as_struct(values):
    # Plain dicts have to be wrapped in a protobuf Struct before being attached.
    struct = Struct()
    struct.update(values)
    return struct

  user_metrics = None
  if isinstance(extended_metrics, dict):
    user_metrics = resources_pb2.ExtendedMetrics(user_metrics=_as_struct(extended_metrics))

  eval_params = None
  if isinstance(eval_info, dict):
    eval_params = resources_pb2.EvalInfo(params=_as_struct(eval_info))

  model_app_id = self.auth_helper.app_id
  model_user_id = self.auth_helper.user_id
  evaluated_model = resources_pb2.Model(
      id=self.id,
      app_id=model_app_id,
      user_id=model_user_id,
      model_version=resources_pb2.ModelVersion(id=self.model_info.model_version.id),
  )
  ground_truth = resources_pb2.Dataset(
      id=dataset_id,
      app_id=dataset_app_id or model_app_id,
      user_id=dataset_user_id or model_user_id,
  )
  post_request = service_pb2.PostEvaluationsRequest(
      user_app_id=self.user_app_id,
      eval_metrics=[
          resources_pb2.EvalMetrics(
              id=eval_id,
              model=evaluated_model,
              extended_metrics=user_metrics,
              ground_truth_dataset=ground_truth,
              eval_info=eval_params,
          )
      ],
  )
  post_response = self._grpc_request(self.STUB.PostEvaluations, post_request)
  if post_response.status.code != status_code_pb2.SUCCESS:
    raise Exception(post_response.status)

  self.logger.info(
      "\nModel evaluation in progress. Kindly allow a few minutes for completion. Processing time may vary based on the model and dataset sizes."
  )
  return post_response.eval_metrics
681
+
682
def get_eval_by_id(
    self,
    eval_id: str,
    label_counts=False,
    test_set=False,
    binary_metrics=False,
    confusion_matrix=False,
    metrics_by_class=False,
    metrics_by_area=False,
) -> resources_pb2.EvalMetrics:
  """Retrieve one evaluation by ID, optionally including extra metric fields.

  Args:
    eval_id (str): ID of the evaluation to fetch.
    label_counts (bool, optional): Include label counts. Defaults to False.
    test_set (bool, optional): Include the test set. Defaults to False.
    binary_metrics (bool, optional): Include binary metrics. Defaults to False.
    confusion_matrix (bool, optional): Include the confusion matrix. Defaults to False.
    metrics_by_class (bool, optional): Include per-class metrics. Defaults to False.
    metrics_by_area (bool, optional): Include per-area metrics. Defaults to False.

  Raises:
    Exception: If the GetEvaluation call does not return SUCCESS.

  Returns:
    resources_pb2.EvalMetrics: The `eval_metrics` field of the response.
  """
  extra_fields = resources_pb2.FieldsValue(
      label_counts=label_counts,
      test_set=test_set,
      binary_metrics=binary_metrics,
      confusion_matrix=confusion_matrix,
      metrics_by_class=metrics_by_class,
      metrics_by_area=metrics_by_area,
  )
  get_request = service_pb2.GetEvaluationRequest(
      user_app_id=self.user_app_id, evaluation_id=eval_id, fields=extra_fields)
  get_response = self._grpc_request(self.STUB.GetEvaluation, get_request)
  if get_response.status.code == status_code_pb2.SUCCESS:
    return get_response.eval_metrics
  raise Exception(get_response.status)
726
+
727
def get_latest_eval(self,
                    label_counts=False,
                    test_set=False,
                    binary_metrics=False,
                    confusion_matrix=False,
                    metrics_by_class=False,
                    metrics_by_area=False) -> Union[resources_pb2.EvalMetrics, None]:
  """Fetch detailed metrics of the latest evaluation if it has completed.

  Uses the first entry returned by `list_evaluations` as the latest one
  (NOTE(review): ordering is decided by the API — confirm it is newest-first).
  When that evaluation's status is MODEL_EVALUATED, delegates to
  `get_eval_by_id` with the requested metric fields.

  Args:
    label_counts (bool, optional): Include label counts. Defaults to False.
    test_set (bool, optional): Include the test set. Defaults to False.
    binary_metrics (bool, optional): Include binary metrics. Defaults to False.
    confusion_matrix (bool, optional): Include the confusion matrix. Defaults to False.
    metrics_by_class (bool, optional): Include per-class metrics. Defaults to False.
    metrics_by_area (bool, optional): Include per-area metrics. Defaults to False.

  Returns:
    The eval_metric if the latest evaluation has completed. Returns None when the
    model version has no evaluations yet or the latest one is not finished.
  """
  evaluations = self.list_evaluations()
  # A model version that was never evaluated yields no entries; indexing [0]
  # would raise IndexError instead of returning None as documented.
  if not evaluations:
    return None

  latest = evaluations[0]
  if latest.status.code != status_code_pb2.MODEL_EVALUATED:
    return None

  return self.get_eval_by_id(
      eval_id=latest.id,
      label_counts=label_counts,
      test_set=test_set,
      binary_metrics=binary_metrics,
      confusion_matrix=confusion_matrix,
      metrics_by_class=metrics_by_class,
      metrics_by_area=metrics_by_area)
763
+
764
def export(self, export_dir: str = None) -> None:
  """Export the current model version and download it as `model.tar`.

  If the model version has no export yet (GetModelVersionExport returns
  CONN_DOES_NOT_EXIST), one is requested with PutModelVersionExports. The
  export is then polled, with backoff, for up to 30 minutes. An export that is
  already MODEL_EXPORTING is polled the same way. Once the export reports
  MODEL_EXPORTED, the archive is streamed to `<export_dir>/model.tar`.

  Args:
    export_dir (str): The directory to save the exported model. It is created
      if missing. Despite the `None` default, a directory must be given.

  Raises:
    UserError: If `export_dir` is not provided.
    Exception: If the directory cannot be created, an API call fails, or the
      export does not finish within 30 minutes.

  Example:
    >>> from clarifai.client.model import Model
    >>> model = Model("url")
    >>> model.export('/path/to/export_model_dir')
  """
  assert self.model_info.model_version.id, "Model version ID is missing. Please provide a `model_version` with a valid `id` as an argument or as a URL in the following format: '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
  if export_dir is None:
    raise UserError("`export_dir` must be provided to export the model.")
  try:
    os.makedirs(export_dir, exist_ok=True)
  except OSError as e:
    raise Exception(f"An error occurred while creating the directory: {e}")

  local_filepath = os.path.join(export_dir, "model.tar")
  export_timeout_seconds = 60 * 30  # 30 minutes

  def _get_export_response():
    """Fetch the export state; CONN_DOES_NOT_EXIST means no export was requested yet."""
    get_export_request = service_pb2.GetModelVersionExportRequest(
        user_app_id=self.user_app_id,
        model_id=self.id,
        version_id=self.model_info.model_version.id,
    )
    response = self._grpc_request(self.STUB.GetModelVersionExport, get_export_request)

    if response.status.code not in (status_code_pb2.SUCCESS, status_code_pb2.CONN_DOES_NOT_EXIST):
      raise Exception(response.status)

    return response

  def _download_exported_model(
      get_model_export_response: service_pb2.SingleModelVersionExportResponse,
      local_filepath: str):
    """Stream the exported archive to `local_filepath` with a progress bar."""
    model_export_url = get_model_export_response.export.url
    model_export_file_size = get_model_export_response.export.size

    # Context managers release the HTTP connection, the file handle and the
    # progress bar even if the download fails midway.
    with requests.get(model_export_url, stream=True) as response:
      response.raise_for_status()
      with open(local_filepath, 'wb') as f, tqdm(
          total=model_export_file_size, unit='B', unit_scale=True,
          desc="Exporting model") as progress:
        for chunk in response.iter_content(chunk_size=8192):
          f.write(chunk)
          progress.update(len(chunk))

    self.logger.info(
        f"Model ID {self.id} with version {self.model_info.model_version.id} exported successfully to {export_dir}/model.tar"
    )

  def _wait_for_export():
    """Poll until the export reports MODEL_EXPORTED and return that response."""
    start_time = time.time()
    backoff_iterator = BackoffIterator()
    while True:
      get_export_response = _get_export_response()
      export_status_code = get_export_response.export.status.code
      if export_status_code == status_code_pb2.MODEL_EXPORTED:
        return get_export_response
      if time.time() - start_time > export_timeout_seconds:
        raise Exception(
            "Model Export took too long. Please try again or contact support@clarifai.com\n"
            f"Req ID: {get_export_response.status.req_id}")
      if export_status_code == status_code_pb2.MODEL_EXPORTING:
        self.logger.info(
            f"Model ID {self.id} with version {self.model_info.model_version.id} is still exporting, please wait..."
        )
      # Back off on every non-final poll so that an unexpected status cannot
      # turn this into a busy loop against the API.
      time.sleep(next(backoff_iterator))

  get_export_response = _get_export_response()
  if get_export_response.status.code == status_code_pb2.CONN_DOES_NOT_EXIST:
    put_export_request = service_pb2.PutModelVersionExportsRequest(
        user_app_id=self.user_app_id,
        model_id=self.id,
        version_id=self.model_info.model_version.id,
    )

    response = self._grpc_request(self.STUB.PutModelVersionExports, put_export_request)
    if response.status.code != status_code_pb2.SUCCESS:
      raise Exception(response.status)

    self.logger.info(
        f"Model ID {self.id} with version {self.model_info.model_version.id} export started, please wait..."
    )
    time.sleep(5)
    get_export_response = _wait_for_export()
  elif get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTING:
    # An export requested earlier is still running; wait for it as well.
    get_export_response = _wait_for_export()

  if get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTED:
    _download_exported_model(get_export_response, local_filepath)
  else:
    self.logger.warning(
        f"Model ID {self.id} with version {self.model_info.model_version.id} could not be exported, export status: {get_export_response.export.status}"
    )
clarifai/client/module.py CHANGED
@@ -18,6 +18,7 @@ class Module(Lister, BaseClient):
18
18
  module_version: Dict = {'id': ""},
19
19
  base_url: str = "https://api.clarifai.com",
20
20
  pat: str = None,
21
+ token: str = None,
21
22
  **kwargs):
22
23
  """Initializes a Module object.
23
24
 
@@ -26,7 +27,8 @@ class Module(Lister, BaseClient):
26
27
  module_id (str): The Module ID to interact with.
27
28
  module_version (dict): The Module Version to interact with.
28
29
  base_url (str): Base API url. Default "https://api.clarifai.com"
29
- pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
30
+ pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT.
31
+ token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN.
30
32
  **kwargs: Additional keyword arguments to be passed to the Module.
31
33
  """
32
34
  if url and module_id:
@@ -41,7 +43,8 @@ class Module(Lister, BaseClient):
41
43
  self.kwargs = {**kwargs, 'id': module_id, 'module_version': module_version}
42
44
  self.module_info = resources_pb2.Module(**self.kwargs)
43
45
  self.logger = get_logger(logger_level="INFO", name=__name__)
44
- BaseClient.__init__(self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat)
46
+ BaseClient.__init__(
47
+ self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
45
48
  Lister.__init__(self)
46
49
 
47
50
  def list_versions(self, page_no: int = None,
@@ -78,10 +81,9 @@ class Module(Lister, BaseClient):
78
81
  for module_version_info in all_module_versions_info:
79
82
  module_version_info['id'] = module_version_info['module_version_id']
80
83
  del module_version_info['module_version_id']
81
- yield Module(
84
+ yield Module.from_auth_helper(
85
+ self.auth_helper,
82
86
  module_id=self.id,
83
- base_url=self.base,
84
- pat=self.pat,
85
87
  **dict(self.kwargs, module_version=module_version_info))
86
88
 
87
89
  def __getattr__(self, name):
clarifai/client/runner.py CHANGED
@@ -39,6 +39,7 @@ class Runner(BaseClient):
39
39
  check_runner_exists: bool = True,
40
40
  base_url: str = "https://api.clarifai.com",
41
41
  pat: str = None,
42
+ token: str = None,
42
43
  num_parallel_polls: int = 4,
43
44
  **kwargs) -> None:
44
45
  """
@@ -47,6 +48,7 @@ class Runner(BaseClient):
47
48
  user_id (str): Clarifai User ID
48
49
  base_url (str): Base API url. Default "https://api.clarifai.com"
49
50
  pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
51
+ token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
50
52
  num_parallel_polls (int): the max number of threads for parallel run loops to be fetching work from
51
53
  """
52
54
  user_id = user_id or os.environ.get("CLARIFAI_USER_ID", "")
@@ -60,7 +62,7 @@ class Runner(BaseClient):
60
62
  self.kwargs = {**kwargs, 'id': runner_id, 'user_id': user_id}
61
63
  self.runner_info = resources_pb2.Runner(**self.kwargs)
62
64
  self.num_parallel_polls = min(10, num_parallel_polls)
63
- BaseClient.__init__(self, user_id=self.user_id, app_id="", base=base_url, pat=pat)
65
+ BaseClient.__init__(self, user_id=self.user_id, app_id="", base=base_url, pat=pat, token=token)
64
66
 
65
67
  # Check that the runner exists.
66
68
  if check_runner_exists:
clarifai/client/search.py CHANGED
@@ -23,7 +23,8 @@ class Search(Lister, BaseClient):
23
23
  top_k: int = DEFAULT_TOP_K,
24
24
  metric: str = DEFAULT_SEARCH_METRIC,
25
25
  base_url: str = "https://api.clarifai.com",
26
- pat: str = None):
26
+ pat: str = None,
27
+ token: str = None):
27
28
  """Initialize the Search object.
28
29
 
29
30
  Args:
@@ -33,6 +34,7 @@ class Search(Lister, BaseClient):
33
34
  metric (str, optional): Similarity metric (either 'cosine' or 'euclidean'). Defaults to 'cosine'.
34
35
  base_url (str, optional): Base API url. Defaults to "https://api.clarifai.com".
35
36
  pat (str, optional): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
37
+ token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
36
38
 
37
39
  Raises:
38
40
  UserError: If the metric is not 'cosine' or 'euclidean'.
@@ -46,9 +48,11 @@ class Search(Lister, BaseClient):
46
48
  self.data_proto = resources_pb2.Data()
47
49
  self.top_k = top_k
48
50
 
49
- self.inputs = Inputs(user_id=self.user_id, app_id=self.app_id, pat=pat)
51
+ self.inputs = Inputs(
52
+ user_id=self.user_id, app_id=self.app_id, pat=pat, token=token, base_url=base_url)
50
53
  self.rank_filter_schema = get_schema()
51
- BaseClient.__init__(self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat)
54
+ BaseClient.__init__(
55
+ self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
52
56
  Lister.__init__(self, page_size=1000)
53
57
 
54
58
  def _get_annot_proto(self, **kwargs):
clarifai/client/user.py CHANGED
@@ -19,6 +19,7 @@ class User(Lister, BaseClient):
19
19
  user_id: str = None,
20
20
  base_url: str = "https://api.clarifai.com",
21
21
  pat: str = None,
22
+ token: str = None,
22
23
  **kwargs):
23
24
  """Initializes an User object.
24
25
 
@@ -26,12 +27,13 @@ class User(Lister, BaseClient):
26
27
  user_id (str): The user ID for the user to interact with.
27
28
  base_url (str): Base API url. Default "https://api.clarifai.com"
28
29
  pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
30
+ token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
29
31
  **kwargs: Additional keyword arguments to be passed to the User.
30
32
  """
31
33
  self.kwargs = {**kwargs, 'id': user_id}
32
34
  self.user_info = resources_pb2.User(**self.kwargs)
33
35
  self.logger = get_logger(logger_level="INFO", name=__name__)
34
- BaseClient.__init__(self, user_id=self.id, app_id="", base=base_url, pat=pat)
36
+ BaseClient.__init__(self, user_id=self.id, app_id="", base=base_url, pat=pat, token=token)
35
37
  Lister.__init__(self)
36
38
 
37
39
  def list_apps(self, filter_by: Dict[str, Any] = {}, page_no: int = None,
@@ -62,7 +64,9 @@ class User(Lister, BaseClient):
62
64
  per_page=per_page,
63
65
  page_no=page_no)
64
66
  for app_info in all_apps_info:
65
- yield App(base_url=self.base, pat=self.pat, **app_info)
67
+ yield App.from_auth_helper(
68
+ self.auth_helper,
69
+ **app_info) #(base_url=self.base, pat=self.pat, token=self.token, **app_info)
66
70
 
67
71
  def list_runners(self, filter_by: Dict[str, Any] = {}, page_no: int = None,
68
72
  per_page: int = None) -> Generator[Runner, None, None]:
@@ -94,7 +98,8 @@ class User(Lister, BaseClient):
94
98
  page_no=page_no)
95
99
 
96
100
  for runner_info in all_runners_info:
97
- yield Runner(check_runner_exists=False, base_url=self.base, pat=self.pat, **runner_info)
101
+ yield Runner.from_auth_helper(
102
+ auth=self.auth_helper, check_runner_exists=False, **runner_info)
98
103
 
99
104
  def create_app(self, app_id: str, base_workflow: str = 'Empty', **kwargs) -> App:
100
105
  """Creates an app for the user.
@@ -120,8 +125,7 @@ class User(Lister, BaseClient):
120
125
  if response.status.code != status_code_pb2.SUCCESS:
121
126
  raise Exception(response.status)
122
127
  self.logger.info("\nApp created\n%s", response.status)
123
- kwargs.update({'user_id': self.id, 'base_url': self.base, 'pat': self.pat})
124
- return App(app_id=app_id, **kwargs)
128
+ return App.from_auth_helper(auth=self.auth_helper, app_id=app_id)
125
129
 
126
130
  def create_runner(self, runner_id: str, labels: List[str], description: str) -> Runner:
127
131
  """Create a runner
@@ -151,14 +155,13 @@ class User(Lister, BaseClient):
151
155
  raise Exception(response.status)
152
156
  self.logger.info("\nRunner created\n%s", response.status)
153
157
 
154
- return Runner(
158
+ return Runner.from_auth_helper(
159
+ auth=self.auth_helper,
155
160
  runner_id=runner_id,
156
161
  user_id=self.id,
157
162
  labels=labels,
158
163
  description=description,
159
- check_runner_exists=False,
160
- base_url=self.base,
161
- pat=self.pat)
164
+ check_runner_exists=False)
162
165
 
163
166
  def app(self, app_id: str, **kwargs) -> App:
164
167
  """Returns an App object for the specified app ID.
@@ -181,8 +184,7 @@ class User(Lister, BaseClient):
181
184
  raise Exception(response.status)
182
185
 
183
186
  kwargs['user_id'] = self.id
184
- kwargs.update({'base_url': self.base, 'pat': self.pat})
185
- return App(app_id=app_id, **kwargs)
187
+ return App.from_auth_helper(auth=self.auth_helper, app_id=app_id, **kwargs)
186
188
 
187
189
  def runner(self, runner_id: str) -> Runner:
188
190
  """Returns a Runner object if exists.
@@ -210,7 +212,7 @@ class User(Lister, BaseClient):
210
212
  kwargs = self.process_response_keys(dict_response[list(dict_response.keys())[1]],
211
213
  list(dict_response.keys())[1])
212
214
 
213
- return Runner(check_runner_exists=False, base_url=self.base, pat=self.pat, **kwargs)
215
+ return Runner.from_auth_helper(self.auth_helper, check_runner_exists=False, **kwargs)
214
216
 
215
217
  def delete_app(self, app_id: str) -> None:
216
218
  """Deletes an app for the user.
@@ -27,6 +27,7 @@ class Workflow(Lister, BaseClient):
27
27
  output_config: Dict = {'min_value': 0},
28
28
  base_url: str = "https://api.clarifai.com",
29
29
  pat: str = None,
30
+ token: str = None,
30
31
  **kwargs):
31
32
  """Initializes a Workflow object.
32
33
 
@@ -40,6 +41,8 @@ class Workflow(Lister, BaseClient):
40
41
  select_concepts (list[Concept]): The concepts to select.
41
42
  sample_ms (int): The number of milliseconds to sample.
42
43
  base_url (str): Base API url. Default "https://api.clarifai.com"
44
+ pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
45
+ token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
43
46
  **kwargs: Additional keyword arguments to be passed to the Workflow.
44
47
  """
45
48
  if url and workflow_id:
@@ -55,7 +58,8 @@ class Workflow(Lister, BaseClient):
55
58
  self.output_config = output_config
56
59
  self.workflow_info = resources_pb2.Workflow(**self.kwargs)
57
60
  self.logger = get_logger(logger_level="INFO", name=__name__)
58
- BaseClient.__init__(self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat)
61
+ BaseClient.__init__(
62
+ self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
59
63
  Lister.__init__(self)
60
64
 
61
65
  def predict(self, inputs: List[Input], workflow_state_id: str = None):
@@ -206,10 +210,9 @@ class Workflow(Lister, BaseClient):
206
210
  for workflow_version_info in all_workflow_versions_info:
207
211
  workflow_version_info['id'] = workflow_version_info['workflow_version_id']
208
212
  del workflow_version_info['workflow_version_id']
209
- yield Workflow(
213
+ yield Workflow.from_auth_helper(
214
+ auth=self.auth_helper,
210
215
  workflow_id=self.id,
211
- base_url=self.base,
212
- pat=self.pat,
213
216
  **dict(self.kwargs, version=workflow_version_info))
214
217
 
215
218
  def export(self, out_path: str):
@@ -20,3 +20,5 @@ TASK_TO_ANNOTATION_TYPE = {
20
20
  "polygons": "polygons"
21
21
  },
22
22
  }
23
+
24
# Maximum number of retry rounds used when re-uploading inputs that failed
# during a dataset upload (each round re-submits only the still-failed inputs).
MAX_RETRIES = 2