clarifai 10.2.0__py3-none-any.whl → 10.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
clarifai/client/model.py CHANGED
@@ -1,7 +1,8 @@
1
1
  import os
2
2
  import time
3
- from typing import Any, Dict, Generator, List, Union
3
+ from typing import Any, Dict, Generator, List, Tuple, Union
4
4
 
5
+ import numpy as np
5
6
  import requests
6
7
  import yaml
7
8
  from clarifai_grpc.grpc.api import resources_pb2, service_pb2
@@ -12,6 +13,7 @@ from google.protobuf.struct_pb2 import Struct
12
13
  from tqdm import tqdm
13
14
 
14
15
  from clarifai.client.base import BaseClient
16
+ from clarifai.client.dataset import Dataset
15
17
  from clarifai.client.input import Inputs
16
18
  from clarifai.client.lister import Lister
17
19
  from clarifai.constants.model import MAX_MODEL_PREDICT_INPUTS, TRAINABLE_MODEL_TYPES
@@ -34,6 +36,7 @@ class Model(Lister, BaseClient):
34
36
  base_url: str = "https://api.clarifai.com",
35
37
  pat: str = None,
36
38
  token: str = None,
39
+ root_certificates_path: str = None,
37
40
  **kwargs):
38
41
  """Initializes a Model object.
39
42
 
@@ -44,6 +47,7 @@ class Model(Lister, BaseClient):
44
47
  base_url (str): Base API url. Default "https://api.clarifai.com"
45
48
  pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
46
49
  token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
50
+ root_certificates_path (str): Path to the SSL root certificates file, used to establish secure gRPC connections.
47
51
  **kwargs: Additional keyword arguments to be passed to the Model.
48
52
  """
49
53
  if url and model_id:
@@ -59,7 +63,13 @@ class Model(Lister, BaseClient):
59
63
  self.logger = get_logger(logger_level="INFO", name=__name__)
60
64
  self.training_params = {}
61
65
  BaseClient.__init__(
62
- self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
66
+ self,
67
+ user_id=self.user_id,
68
+ app_id=self.app_id,
69
+ base=base_url,
70
+ pat=pat,
71
+ token=token,
72
+ root_certificates_path=root_certificates_path)
63
73
  Lister.__init__(self)
64
74
 
65
75
  def list_training_templates(self) -> List[str]:
@@ -243,7 +253,7 @@ class Model(Lister, BaseClient):
243
253
 
244
254
  return response.model.model_version.id
245
255
 
246
- def training_status(self, version_id: str, training_logs: bool = False) -> Dict[str, str]:
256
+ def training_status(self, version_id: str = None, training_logs: bool = False) -> Dict[str, str]:
247
257
  """Get the training status for the model version. Also stores training logs
248
258
 
249
259
  Args:
@@ -258,19 +268,20 @@ class Model(Lister, BaseClient):
258
268
  >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
259
269
  >>> model.training_status(version_id='version_id',training_logs=True)
260
270
  """
271
+ if not version_id and not self.model_info.model_version.id:
272
+ raise UserError(
273
+ "Model version ID is missing. Please provide a `model_version` with a valid `id` as an argument or as a URL in the following format: '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
274
+ )
275
+
276
+ if not self.model_info.model_type_id or not self.model_info.model_version.train_log:
277
+ self.load_info()
261
278
  if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
262
279
  raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
263
280
 
264
- request = service_pb2.GetModelVersionRequest(
265
- user_app_id=self.user_app_id, model_id=self.id, version_id=version_id)
266
- response = self._grpc_request(self.STUB.GetModelVersion, request)
267
- if response.status.code != status_code_pb2.SUCCESS:
268
- raise Exception(response.status)
269
-
270
281
  if training_logs:
271
282
  try:
272
- if response.model_version.train_log:
273
- log_response = requests.get(response.model_version.train_log)
283
+ if self.model_info.model_version.train_log:
284
+ log_response = requests.get(self.model_info.model_version.train_log)
274
285
  log_response.raise_for_status() # Check for any HTTP errors
275
286
  with open(version_id + '.log', 'wb') as file:
276
287
  for chunk in log_response.iter_content(chunk_size=4096): # 4KB
@@ -280,7 +291,7 @@ class Model(Lister, BaseClient):
280
291
  except requests.exceptions.RequestException as e:
281
292
  raise Exception(f"An error occurred while getting training logs: {e}")
282
293
 
283
- return response.model_version.status
294
+ return self.model_info.model_version.status
284
295
 
285
296
  def delete_version(self, version_id: str) -> None:
286
297
  """Deletes a model version for the Model.
@@ -407,7 +418,7 @@ class Model(Lister, BaseClient):
407
418
  model=self.model_info)
408
419
 
409
420
  start_time = time.time()
410
- backoff_iterator = BackoffIterator()
421
+ backoff_iterator = BackoffIterator(10)
411
422
  while True:
412
423
  response = self._grpc_request(self.STUB.PostModelOutputs, request)
413
424
 
@@ -617,18 +628,22 @@ class Model(Lister, BaseClient):
617
628
  return response.eval_metrics
618
629
 
619
630
  def evaluate(self,
620
- dataset_id: str,
631
+ dataset: Dataset = None,
632
+ dataset_id: str = None,
621
633
  dataset_app_id: str = None,
622
634
  dataset_user_id: str = None,
635
+ dataset_version_id: str = None,
623
636
  eval_id: str = None,
624
637
  extended_metrics: dict = None,
625
638
  eval_info: dict = None) -> resources_pb2.EvalMetrics:
626
639
  """ Run evaluation
627
640
 
628
641
  Args:
629
- dataset_id (str): Dataset Id.
642
+ dataset (Dataset): If Clarifai Dataset is set, it will ignore other arguments prefixed with 'dataset_'.
643
+ dataset_id (str): Dataset Id. Default is None.
630
644
  dataset_app_id (str): App ID for cross app evaluation, leave it as None to use Model App ID. Default is None.
631
645
  dataset_user_id (str): User ID for cross app evaluation, leave it as None to use Model User ID. Default is None.
646
+ dataset_version_id (str): Dataset version Id. Default is None.
632
647
  eval_id (str): Specific ID for the evaluation. You must specify this parameter to either overwrite the result with the dataset ID or format your evaluation in an informative manner. If you don't, it will use random ID from system. Default is None.
633
648
  extended_metrics (dict): user custom metrics result. Default is None.
634
649
  eval_info (dict): custom eval info. Default is empty dict.
@@ -638,6 +653,23 @@ class Model(Lister, BaseClient):
638
653
 
639
654
  """
640
655
  assert self.model_info.model_version.id, "Model version is empty. Please provide `model_version` as arguments or with a URL as the format '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
656
+
657
+ if dataset:
658
+ self.logger.info("Using dataset, ignore other arguments prefixed with 'dataset_'")
659
+ dataset_id = dataset.id
660
+ dataset_app_id = dataset.app_id
661
+ dataset_user_id = dataset.user_id
662
+ dataset_version_id = dataset.version.id
663
+ else:
664
+ self.logger.warning(
665
+ "Arguments prefixed with `dataset_` will be removed soon, please use dataset")
666
+
667
+ gt_dataset = resources_pb2.Dataset(
668
+ id=dataset_id,
669
+ app_id=dataset_app_id or self.auth_helper.app_id,
670
+ user_id=dataset_user_id or self.auth_helper.user_id,
671
+ version=resources_pb2.DatasetVersion(id=dataset_version_id))
672
+
641
673
  metrics = None
642
674
  if isinstance(extended_metrics, dict):
643
675
  metrics = Struct()
@@ -659,11 +691,7 @@ class Model(Lister, BaseClient):
659
691
  model_version=resources_pb2.ModelVersion(id=self.model_info.model_version.id),
660
692
  ),
661
693
  extended_metrics=metrics,
662
- ground_truth_dataset=resources_pb2.Dataset(
663
- id=dataset_id,
664
- app_id=dataset_app_id or self.auth_helper.app_id,
665
- user_id=dataset_user_id or self.auth_helper.user_id,
666
- ),
694
+ ground_truth_dataset=gt_dataset,
667
695
  eval_info=eval_info_params,
668
696
  )
669
697
  request = service_pb2.PostEvaluationsRequest(
@@ -761,6 +789,157 @@ class Model(Lister, BaseClient):
761
789
 
762
790
  return result
763
791
 
792
+ def get_eval_by_dataset(self, dataset: Dataset) -> List[resources_pb2.EvalMetrics]:
793
+ """Get all eval data of dataset
794
+
795
+ Args:
796
+ dataset (Dataset): Clarifai dataset
797
+
798
+ Returns:
799
+ List[resources_pb2.EvalMetrics]
800
+ """
801
+ _id = dataset.id
802
+ app = dataset.app_id or self.app_id
803
+ user_id = dataset.user_id or self.user_id
804
+ version = dataset.version.id
805
+
806
+ list_eval: resources_pb2.EvalMetrics = self.list_evaluations()
807
+ outputs = []
808
+ for _eval in list_eval:
809
+ if _eval.status.code == status_code_pb2.MODEL_EVALUATED:
810
+ gt_ds = _eval.ground_truth_dataset
811
+ if (_id == gt_ds.id and user_id == gt_ds.user_id and app == gt_ds.app_id):
812
+ if not version or version == gt_ds.version.id:
813
+ outputs.append(_eval)
814
+
815
+ return outputs
816
+
817
+ def get_raw_eval(self,
818
+ dataset: Dataset = None,
819
+ eval_id: str = None,
820
+ return_format: str = 'array') -> Union[resources_pb2.EvalTestSetEntry, Tuple[
821
+ np.array, np.array, list, List[Input]], Tuple[List[dict], List[dict]]]:
822
+ """Get ground truths, predictions and input information. Do not pass dataset and eval_id at same time
823
+
824
+ Args:
825
+ dataset (Dataset): Clarifai dataset, get eval data of latest eval result of dataset.
826
+ eval_id (str): Evaluation ID, get eval data of specific eval id.
827
+ return_format (str, optional): Choice {proto, array, coco}. !Note that `coco` is only applicable for 'visual-detector'. Defaults to 'array'.
828
+
829
+ Returns:
830
+
831
+ Depends on `return_format`.
832
+
833
+ * if return_format == proto
834
+ `resources_pb2.EvalTestSetEntry`
835
+
836
+ * if return_format == array
837
+ `Tuple(np.array, np.array, List[str], List[Input])`: Tuple has 4 elements (y, y_pred, concept_ids, inputs).
838
+ y, y_pred, concept_ids can be used to compute metrics. 'inputs' can be use to download
839
+ - if model is 'classifier': 'y' and 'y_pred' are both arrays with a shape of (num_inputs,)
840
+ - if model is 'visual-detector': 'y' and 'y_pred' are arrays with a shape of (num_inputs,), where each element is array has shape (num_annotation, 6) consists of [x_min, y_min, x_max, y_max, concept_index, score]. The score is always 1 for 'y'
841
+
842
+ * if return_format == coco: Applicable only for 'visual-detector'
843
+ `Tuple[List[Dict], List[Dict]]`: Tuple has 2 elemnts where first element is COCO Ground Truth and last one is COCO Prediction Annotation
844
+
845
+ Example Usages:
846
+ ------
847
+ * Evaluate `visual-classifier` using sklearn
848
+
849
+ ```python
850
+ import os
851
+ from sklearn.metrics import accuracy_score
852
+ from sklearn.metrics import classification_report
853
+ import numpy as np
854
+ from clarifai.client.model import Model
855
+ from clarifai.client.dataset import Dataset
856
+ os.environ["CLARIFAI_PAT"] = "???"
857
+ model = Model(url="url/of/model/includes/version-id")
858
+ dataset = Dataset(dataset_id="dataset-id")
859
+ y, y_pred, clss, input_protos = model.get_raw_eval(dataset, return_format="array")
860
+ y = np.argmax(y, axis=1)
861
+ y_pred = np.argmax(y_pred, axis=1)
862
+ report = classification_report(y, y_pred, target_names=clss)
863
+ print(report)
864
+ acc = accuracy_score(y, y_pred)
865
+ print("acc ", acc)
866
+ ```
867
+
868
+ * Evaluate `visual-detector` using COCOeval
869
+
870
+ ```python
871
+ import os
872
+ import json
873
+ from pycocotools.coco import COCO
874
+ from pycocotools.cocoeval import COCOeval
875
+ from clarifai.client.model import Model
876
+ from clarifai.client.dataset import Dataset
877
+ os.environ["CLARIFAI_PAT"] = "???" # Insert your PAT
878
+ model = Model(url=model_url)
879
+ dataset = Dataset(url=dataset_url)
880
+ y, y_pred = model.get_raw_eval(dataset, return_format="coco")
881
+ # save as files to load in COCO API
882
+ def save_annot(d, path):
883
+ with open(path, "w") as fp:
884
+ json.dump(d, fp, indent=2)
885
+ gt_path = os.path.join("gt.json")
886
+ pred_path = os.path.join("pred.json")
887
+ save_annot(y, gt_path)
888
+ save_annot(y_pred, pred_path)
889
+
890
+ cocoGt = COCO(gt_path)
891
+ cocoPred = COCO(pred_path)
892
+ cocoEval = COCOeval(cocoGt, cocoPred, "bbox")
893
+ cocoEval.evaluate()
894
+ cocoEval.accumulate()
895
+ cocoEval.summarize() # Print out result of all classes with all area type
896
+ # Example:
897
+ # Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.863
898
+ # Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.973
899
+ # Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.939
900
+ # ...
901
+ ```
902
+
903
+ """
904
+ from clarifai.utils.evaluation.testset_annotation_parser import (
905
+ parse_eval_annotation_classifier, parse_eval_annotation_detector,
906
+ parse_eval_annotation_detector_coco)
907
+
908
+ valid_model_types = ["visual-classifier", "text-classifier", "visual-detector"]
909
+ supported_format = ['proto', 'array', 'coco']
910
+ assert return_format in supported_format, ValueError(
911
+ f"Expected return_format in {supported_format}, got {return_format}")
912
+ self.load_info()
913
+ model_type_id = self.model_info.model_type_id
914
+ assert model_type_id in valid_model_types, \
915
+ f"This method only supports model types {valid_model_types}, but your model type is {self.model_info.model_type_id}."
916
+ assert not (dataset and
917
+ eval_id), "Using both `dataset` and `eval_id`, but only one should be passed."
918
+ assert not dataset or not eval_id, "Please provide either `dataset` or `eval_id`, but nothing was passed."
919
+ if model_type_id.endswith("-classifier") and return_format == "coco":
920
+ raise ValueError(
921
+ f"return_format coco only applies for `visual-detector`, however your model is `{model_type_id}`"
922
+ )
923
+
924
+ if dataset:
925
+ eval_by_ds = self.get_eval_by_dataset(dataset)
926
+ if len(eval_by_ds) == 0:
927
+ raise Exception(f"Model is not valuated with dataset: {dataset}")
928
+ eval_id = eval_by_ds[0].id
929
+
930
+ detail_eval_data = self.get_eval_by_id(eval_id=eval_id, test_set=True, metrics_by_class=True)
931
+
932
+ if return_format == "proto":
933
+ return detail_eval_data.test_set
934
+ else:
935
+ if model_type_id.endswith("-classifier"):
936
+ return parse_eval_annotation_classifier(detail_eval_data)
937
+ elif model_type_id == "visual-detector":
938
+ if return_format == "array":
939
+ return parse_eval_annotation_detector(detail_eval_data)
940
+ elif return_format == "coco":
941
+ return parse_eval_annotation_detector_coco(detail_eval_data)
942
+
764
943
  def export(self, export_dir: str = None) -> None:
765
944
  """Export the model, stores the exported model as model.tar file
766
945
 
@@ -830,7 +1009,7 @@ class Model(Lister, BaseClient):
830
1009
  )
831
1010
  time.sleep(5)
832
1011
  start_time = time.time()
833
- backoff_iterator = BackoffIterator()
1012
+ backoff_iterator = BackoffIterator(10)
834
1013
  while True:
835
1014
  get_export_response = _get_export_response()
836
1015
  if get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTING and \
@@ -841,6 +1020,7 @@ class Model(Lister, BaseClient):
841
1020
  time.sleep(next(backoff_iterator))
842
1021
  elif get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTED:
843
1022
  _download_exported_model(get_export_response, os.path.join(export_dir, "model.tar"))
1023
+ break
844
1024
  elif time.time() - start_time > 60 * 30:
845
1025
  raise Exception(
846
1026
  f"""Model Export took too long. Please try again or contact support@clarifai.com
clarifai/client/module.py CHANGED
@@ -19,6 +19,7 @@ class Module(Lister, BaseClient):
19
19
  base_url: str = "https://api.clarifai.com",
20
20
  pat: str = None,
21
21
  token: str = None,
22
+ root_certificates_path: str = None,
22
23
  **kwargs):
23
24
  """Initializes a Module object.
24
25
 
@@ -29,6 +30,7 @@ class Module(Lister, BaseClient):
29
30
  base_url (str): Base API url. Default "https://api.clarifai.com"
30
31
  pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT.
31
32
  token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN.
33
+ root_certificates_path (str): Path to the SSL root certificates file, used to establish secure gRPC connections.
32
34
  **kwargs: Additional keyword arguments to be passed to the Module.
33
35
  """
34
36
  if url and module_id:
@@ -44,7 +46,13 @@ class Module(Lister, BaseClient):
44
46
  self.module_info = resources_pb2.Module(**self.kwargs)
45
47
  self.logger = get_logger(logger_level="INFO", name=__name__)
46
48
  BaseClient.__init__(
47
- self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
49
+ self,
50
+ user_id=self.user_id,
51
+ app_id=self.app_id,
52
+ base=base_url,
53
+ pat=pat,
54
+ token=token,
55
+ root_certificates_path=root_certificates_path)
48
56
  Lister.__init__(self)
49
57
 
50
58
  def list_versions(self, page_no: int = None,
clarifai/client/search.py CHANGED
@@ -24,7 +24,8 @@ class Search(Lister, BaseClient):
24
24
  metric: str = DEFAULT_SEARCH_METRIC,
25
25
  base_url: str = "https://api.clarifai.com",
26
26
  pat: str = None,
27
- token: str = None):
27
+ token: str = None,
28
+ root_certificates_path: str = None):
28
29
  """Initialize the Search object.
29
30
 
30
31
  Args:
@@ -35,6 +36,7 @@ class Search(Lister, BaseClient):
35
36
  base_url (str, optional): Base API url. Defaults to "https://api.clarifai.com".
36
37
  pat (str, optional): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
37
38
  token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
39
+ root_certificates_path (str): Path to the SSL root certificates file, used to establish secure gRPC connections.
38
40
 
39
41
  Raises:
40
42
  UserError: If the metric is not 'cosine' or 'euclidean'.
@@ -52,7 +54,13 @@ class Search(Lister, BaseClient):
52
54
  user_id=self.user_id, app_id=self.app_id, pat=pat, token=token, base_url=base_url)
53
55
  self.rank_filter_schema = get_schema()
54
56
  BaseClient.__init__(
55
- self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
57
+ self,
58
+ user_id=self.user_id,
59
+ app_id=self.app_id,
60
+ base=base_url,
61
+ pat=pat,
62
+ token=token,
63
+ root_certificates_path=root_certificates_path)
56
64
  Lister.__init__(self, page_size=1000)
57
65
 
58
66
  def _get_annot_proto(self, **kwargs):
clarifai/client/user.py CHANGED
@@ -7,7 +7,6 @@ from google.protobuf.json_format import MessageToDict
7
7
  from clarifai.client.app import App
8
8
  from clarifai.client.base import BaseClient
9
9
  from clarifai.client.lister import Lister
10
- from clarifai.client.runner import Runner
11
10
  from clarifai.errors import UserError
12
11
  from clarifai.utils.logging import get_logger
13
12
 
@@ -20,6 +19,7 @@ class User(Lister, BaseClient):
20
19
  base_url: str = "https://api.clarifai.com",
21
20
  pat: str = None,
22
21
  token: str = None,
22
+ root_certificates_path: str = None,
23
23
  **kwargs):
24
24
  """Initializes an User object.
25
25
 
@@ -28,12 +28,20 @@ class User(Lister, BaseClient):
28
28
  base_url (str): Base API url. Default "https://api.clarifai.com"
29
29
  pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
30
30
  token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
31
+ root_certificates_path (str): Path to the SSL root certificates file, used to establish secure gRPC connections.
31
32
  **kwargs: Additional keyword arguments to be passed to the User.
32
33
  """
33
34
  self.kwargs = {**kwargs, 'id': user_id}
34
35
  self.user_info = resources_pb2.User(**self.kwargs)
35
36
  self.logger = get_logger(logger_level="INFO", name=__name__)
36
- BaseClient.__init__(self, user_id=self.id, app_id="", base=base_url, pat=pat, token=token)
37
+ BaseClient.__init__(
38
+ self,
39
+ user_id=self.id,
40
+ app_id="",
41
+ base=base_url,
42
+ pat=pat,
43
+ token=token,
44
+ root_certificates_path=root_certificates_path)
37
45
  Lister.__init__(self)
38
46
 
39
47
  def list_apps(self, filter_by: Dict[str, Any] = {}, page_no: int = None,
@@ -69,7 +77,7 @@ class User(Lister, BaseClient):
69
77
  **app_info) #(base_url=self.base, pat=self.pat, token=self.token, **app_info)
70
78
 
71
79
  def list_runners(self, filter_by: Dict[str, Any] = {}, page_no: int = None,
72
- per_page: int = None) -> Generator[Runner, None, None]:
80
+ per_page: int = None) -> Generator[dict, None, None]:
73
81
  """List all runners for the user
74
82
 
75
83
  Args:
@@ -78,7 +86,7 @@ class User(Lister, BaseClient):
78
86
  per_page (int): The number of items per page.
79
87
 
80
88
  Yields:
81
- Runner: Runner objects for the runners.
89
+ Dict: Dictionaries containing information about the runners.
82
90
 
83
91
  Example:
84
92
  >>> from clarifai.client.user import User
@@ -98,8 +106,7 @@ class User(Lister, BaseClient):
98
106
  page_no=page_no)
99
107
 
100
108
  for runner_info in all_runners_info:
101
- yield Runner.from_auth_helper(
102
- auth=self.auth_helper, check_runner_exists=False, **runner_info)
109
+ yield dict(auth=self.auth_helper, check_runner_exists=False, **runner_info)
103
110
 
104
111
  def create_app(self, app_id: str, base_workflow: str = 'Empty', **kwargs) -> App:
105
112
  """Creates an app for the user.
@@ -127,7 +134,7 @@ class User(Lister, BaseClient):
127
134
  self.logger.info("\nApp created\n%s", response.status)
128
135
  return App.from_auth_helper(auth=self.auth_helper, app_id=app_id)
129
136
 
130
- def create_runner(self, runner_id: str, labels: List[str], description: str) -> Runner:
137
+ def create_runner(self, runner_id: str, labels: List[str], description: str) -> dict:
131
138
  """Create a runner
132
139
 
133
140
  Args:
@@ -136,13 +143,14 @@ class User(Lister, BaseClient):
136
143
  description (str): Description of Runner
137
144
 
138
145
  Returns:
139
- Runner: A runner object for the specified Runner ID
146
+ Dict: A dictionary containing information about the specified Runner ID.
140
147
 
141
148
  Example:
142
149
  >>> from clarifai.client.user import User
143
150
  >>> client = User(user_id="user_id")
144
- >>> runner = client.create_runner(runner_id="runner_id", labels=["label to link runner"], description="laptop runner")
151
+ >>> runner_info = client.create_runner(runner_id="runner_id", labels=["label to link runner"], description="laptop runner")
145
152
  """
153
+
146
154
  if not isinstance(labels, List):
147
155
  raise UserError("Labels must be a List of strings")
148
156
 
@@ -155,7 +163,7 @@ class User(Lister, BaseClient):
155
163
  raise Exception(response.status)
156
164
  self.logger.info("\nRunner created\n%s", response.status)
157
165
 
158
- return Runner.from_auth_helper(
166
+ return dict(
159
167
  auth=self.auth_helper,
160
168
  runner_id=runner_id,
161
169
  user_id=self.id,
@@ -186,19 +194,19 @@ class User(Lister, BaseClient):
186
194
  kwargs['user_id'] = self.id
187
195
  return App.from_auth_helper(auth=self.auth_helper, app_id=app_id, **kwargs)
188
196
 
189
- def runner(self, runner_id: str) -> Runner:
197
+ def runner(self, runner_id: str) -> dict:
190
198
  """Returns a Runner object if exists.
191
199
 
192
200
  Args:
193
201
  runner_id (str): The runner ID to interact with
194
202
 
195
203
  Returns:
196
- Runner: A Runner object for the existing runner ID.
204
+ Dict: A dictionary containing information about the existing runner ID.
197
205
 
198
206
  Example:
199
207
  >>> from clarifai.client.user import User
200
208
  >>> client = User(user_id="user_id")
201
- >>> runner = client.runner(runner_id="runner_id")
209
+ >>> runner_info = client.runner(runner_id="runner_id")
202
210
  """
203
211
  request = service_pb2.GetRunnerRequest(user_app_id=self.user_app_id, runner_id=runner_id)
204
212
  response = self._grpc_request(self.STUB.GetRunner, request)
@@ -212,7 +220,7 @@ class User(Lister, BaseClient):
212
220
  kwargs = self.process_response_keys(dict_response[list(dict_response.keys())[1]],
213
221
  list(dict_response.keys())[1])
214
222
 
215
- return Runner.from_auth_helper(self.auth_helper, check_runner_exists=False, **kwargs)
223
+ return dict(self.auth_helper, check_runner_exists=False, **kwargs)
216
224
 
217
225
  def delete_app(self, app_id: str) -> None:
218
226
  """Deletes an app for the user.
@@ -28,6 +28,7 @@ class Workflow(Lister, BaseClient):
28
28
  base_url: str = "https://api.clarifai.com",
29
29
  pat: str = None,
30
30
  token: str = None,
31
+ root_certificates_path: str = None,
31
32
  **kwargs):
32
33
  """Initializes a Workflow object.
33
34
 
@@ -43,6 +44,7 @@ class Workflow(Lister, BaseClient):
43
44
  base_url (str): Base API url. Default "https://api.clarifai.com"
44
45
  pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
45
46
  token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
47
+ root_certificates_path (str): Path to the SSL root certificates file, used to establish secure gRPC connections.
46
48
  **kwargs: Additional keyword arguments to be passed to the Workflow.
47
49
  """
48
50
  if url and workflow_id:
@@ -59,7 +61,13 @@ class Workflow(Lister, BaseClient):
59
61
  self.workflow_info = resources_pb2.Workflow(**self.kwargs)
60
62
  self.logger = get_logger(logger_level="INFO", name=__name__)
61
63
  BaseClient.__init__(
62
- self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
64
+ self,
65
+ user_id=self.user_id,
66
+ app_id=self.app_id,
67
+ base=base_url,
68
+ pat=pat,
69
+ token=token,
70
+ root_certificates_path=root_certificates_path)
63
71
  Lister.__init__(self)
64
72
 
65
73
  def predict(self, inputs: List[Input], workflow_state_id: str = None):
@@ -83,7 +91,7 @@ class Workflow(Lister, BaseClient):
83
91
  request.workflow_state.id = workflow_state_id
84
92
 
85
93
  start_time = time.time()
86
- backoff_iterator = BackoffIterator()
94
+ backoff_iterator = BackoffIterator(10)
87
95
 
88
96
  while True:
89
97
  response = self._grpc_request(self.STUB.PostWorkflowResults, request)
@@ -0,0 +1 @@
1
+ MAX_UPLOAD_BATCH_SIZE = 128
@@ -21,7 +21,7 @@ logger = get_logger("INFO", __name__)
21
21
  class DatasetExportReader:
22
22
 
23
23
  def __init__(self,
24
- session: requests.Session,
24
+ session: requests.Session = None,
25
25
  archive_url: Optional[str] = None,
26
26
  local_archive_path: Optional[str] = None):
27
27
  """Download/Reads the zipfile archive and yields every api.Input object.
@@ -31,9 +31,11 @@ class DatasetExportReader:
31
31
  archive_url: URL of the DatasetVersionExport archive
32
32
  local_archive_path: Path to the DatasetVersionExport archive
33
33
  """
34
- self.input_count = 0
34
+ self.input_count = None
35
35
  self.temp_file = None
36
36
  self.session = session
37
+ if not self.session:
38
+ self.session = requests.Session()
37
39
 
38
40
  assert archive_url or local_archive_path, UserError(
39
41
  "Either archive_url or local_archive_path must be provided.")
@@ -59,7 +61,8 @@ class DatasetExportReader:
59
61
  def _download_temp_archive(self, archive_url: str,
60
62
  chunk_size: int = 128) -> tempfile.TemporaryFile:
61
63
  """Downloads the temp archive of InputBatches."""
62
- r = self.session.get(archive_url, stream=True)
64
+ session = requests.Session()
65
+ r = session.get(archive_url, stream=True)
63
66
  temp_file = tempfile.TemporaryFile()
64
67
  for chunk in r.iter_content(chunk_size=chunk_size):
65
68
  temp_file.write(chunk)
@@ -67,10 +70,12 @@ class DatasetExportReader:
67
70
  return temp_file
68
71
 
69
72
  def __len__(self) -> int:
70
- if not self.input_count:
73
+ if self.input_count is None:
74
+ input_count = 0
71
75
  if self.file_name_list is not None:
72
76
  for filename in self.file_name_list:
73
- self.input_count += int(filename.split('_n')[-1])
77
+ input_count += int(filename.split('_n')[-1])
78
+ self.input_count = input_count
74
79
 
75
80
  return self.input_count
76
81
 
@@ -111,7 +116,8 @@ class InputAnnotationDownloader:
111
116
  """
112
117
  self.input_iterator = input_iterator
113
118
  self.num_workers = min(num_workers, 10) # Max 10 threads
114
- self.num_inputs_annotations = 0
119
+ self.num_inputs = 0
120
+ self.num_annotations = 0
115
121
  self.split_prefix = None
116
122
  self.session = session
117
123
  self.input_ext = dict(image=".png", text=".txt", audio=".mp3", video=".mp4")
@@ -182,14 +188,14 @@ class InputAnnotationDownloader:
182
188
  self._save_audio_to_archive(new_archive, hosted_url, file_name)
183
189
  elif input_type == "video":
184
190
  self._save_video_to_archive(new_archive, hosted_url, file_name)
185
- self.num_inputs_annotations += 1
191
+ self.num_inputs += 1
186
192
 
187
193
  if data_dict.get("concepts") or data_dict.get("regions"):
188
194
  file_name = os.path.join(split, "annotations", input_.id + ".json")
189
195
  annot_data = data_dict.get("concepts") or data_dict.get("regions")
190
196
 
191
197
  self._save_annotation_to_archive(new_archive, annot_data, file_name)
192
- self.num_inputs_annotations += 1
198
+ self.num_annotations += 1
193
199
 
194
200
  def _check_output_archive(self, save_path: str) -> None:
195
201
  try:
@@ -198,8 +204,8 @@ class InputAnnotationDownloader:
198
204
  raise e
199
205
  assert len(
200
206
  archive.namelist()
201
- ) == self.num_inputs_annotations, "Archive has %d inputs+annotations | expecting %d inputs+annotations" % (
202
- len(archive.namelist()), self.num_inputs_annotations)
207
+ ) == self.num_inputs + self.num_annotations, "Archive has %d inputs+annotations | expecting %d inputs+annotations" % (
208
+ len(archive.namelist()), self.num_inputs + self.num_annotations)
203
209
 
204
210
  def download_archive(self, save_path: str, split: Optional[str] = None) -> None:
205
211
  """Downloads the archive from the URL into an archive of inputs, annotations in the directory format
@@ -218,5 +224,5 @@ class InputAnnotationDownloader:
218
224
  progress.update()
219
225
 
220
226
  self._check_output_archive(save_path)
221
- logger.info("Downloaded %d inputs+annotations to %s" % (self.num_inputs_annotations,
222
- save_path))
227
+ logger.info("Downloaded %d inputs and %d annotations to %s" %
228
+ (self.num_inputs, self.num_annotations, save_path))