clarifai 10.2.1__py3-none-any.whl → 10.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
clarifai/client/app.py CHANGED
@@ -33,6 +33,7 @@ class App(Lister, BaseClient):
33
33
  base_url: str = "https://api.clarifai.com",
34
34
  pat: str = None,
35
35
  token: str = None,
36
+ root_certificates_path: str = None,
36
37
  **kwargs):
37
38
  """Initializes an App object.
38
39
 
@@ -42,6 +43,7 @@ class App(Lister, BaseClient):
42
43
  base_url (str): Base API url. Default "https://api.clarifai.com"
43
44
  pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
44
45
  token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
46
+ root_certificates_path (str): Path to the SSL root certificates file, used to establish secure gRPC connections.
45
47
  **kwargs: Additional keyword arguments to be passed to the App.
46
48
  - name (str): The name of the app.
47
49
  - description (str): The description of the app.
@@ -55,7 +57,13 @@ class App(Lister, BaseClient):
55
57
  self.app_info = resources_pb2.App(**self.kwargs)
56
58
  self.logger = get_logger(logger_level="INFO", name=__name__)
57
59
  BaseClient.__init__(
58
- self, user_id=self.user_id, app_id=self.id, base=base_url, pat=pat, token=token)
60
+ self,
61
+ user_id=self.user_id,
62
+ app_id=self.id,
63
+ base=base_url,
64
+ pat=pat,
65
+ token=token,
66
+ root_certificates_path=root_certificates_path)
59
67
  Lister.__init__(self)
60
68
 
61
69
  def list_datasets(self, page_no: int = None,
@@ -374,8 +382,8 @@ class App(Lister, BaseClient):
374
382
  output_info = get_yaml_output_info_proto(node['model'].get('output_info', None))
375
383
  try:
376
384
  model = self.model(
377
- node['model']['model_id'],
378
- node['model'].get('model_version_id', ""),
385
+ model_id=node['model']['model_id'],
386
+ model_version={"id": node['model'].get('model_version_id', "")},
379
387
  user_id=node['model'].get('user_id', ""),
380
388
  app_id=node['model'].get('app_id', ""))
381
389
  except Exception as e:
@@ -485,7 +493,7 @@ class App(Lister, BaseClient):
485
493
  kwargs['dataset_version_id'] = dataset_version_id
486
494
  return Dataset.from_auth_helper(auth=self.auth_helper, **kwargs)
487
495
 
488
- def model(self, model_id: str, model_version_id: str = "", **kwargs) -> Model:
496
+ def model(self, model_id: str, model_version: Dict = {'id': ""}, **kwargs) -> Model:
489
497
  """Returns a Model object for the existing model ID.
490
498
 
491
499
  Args:
@@ -498,7 +506,7 @@ class App(Lister, BaseClient):
498
506
  Example:
499
507
  >>> from clarifai.client.app import App
500
508
  >>> app = App(app_id="app_id", user_id="user_id")
501
- >>> model_v1 = app.model(model_id="model_id", model_version_id="model_version_id")
509
+ >>> model_v1 = app.model(model_id="model_id", model_version={"id": "model_version_id"})
502
510
  """
503
511
  # Change user_app_id based on whether user_id or app_id is specified.
504
512
  if kwargs.get("user_id") or kwargs.get("app_id"):
@@ -506,10 +514,10 @@ class App(Lister, BaseClient):
506
514
  user_app_id=self.auth_helper.get_user_app_id_proto(
507
515
  kwargs.get("user_id"), kwargs.get("app_id")),
508
516
  model_id=model_id,
509
- version_id=model_version_id)
517
+ version_id=model_version["id"])
510
518
  else:
511
519
  request = service_pb2.GetModelRequest(
512
- user_app_id=self.user_app_id, model_id=model_id, version_id=model_version_id)
520
+ user_app_id=self.user_app_id, model_id=model_id, version_id=model_version["id"])
513
521
  response = self._grpc_request(self.STUB.GetModel, request)
514
522
 
515
523
  if response.status.code != status_code_pb2.SUCCESS:
@@ -65,6 +65,7 @@ class ClarifaiAuthHelper:
65
65
  token: str = "",
66
66
  base: str = DEFAULT_BASE,
67
67
  ui: str = DEFAULT_UI,
68
+ root_certificates_path: str = None,
68
69
  validate: bool = True,
69
70
  ):
70
71
  """
@@ -85,6 +86,7 @@ class ClarifaiAuthHelper:
85
86
  https://api.clarifai.com (default), https://host:port, http://host:port, host:port (will be treated as http, not https). It's highly recommended to include the http:// or https:// otherwise we need to check the endpoint to determine if it has SSL during this __init__
86
87
  ui: a url to the UI. Examples include clarifai.com,
87
88
  https://clarifai.com (default), https://host:port, http://host:port, host:port (will be treated as http, not https). It's highly recommended to include the http:// or https:// otherwise we need to check the endpoint to determine if it has SSL during this __init__
89
+ root_certificates_path: path to the root certificates file. This is only used for grpc secure channels.
88
90
  validate: whether to validate the inputs. This is useful for overriding vars then validating
89
91
  """
90
92
 
@@ -92,6 +94,7 @@ class ClarifaiAuthHelper:
92
94
  self.app_id = app_id
93
95
  self._pat = pat
94
96
  self._token = token
97
+ self._root_certificates_path = root_certificates_path
95
98
 
96
99
  self.set_base(base)
97
100
  self.set_ui(ui)
@@ -113,6 +116,8 @@ class ClarifaiAuthHelper:
113
116
  raise Exception(
114
117
  "Need 'pat' or 'token' in the query params or use one of the CLARIFAI_PAT or CLARIFAI_SESSION_TOKEN env vars"
115
118
  )
119
+ if (self._root_certificates_path) and (not os.path.exists(self._root_certificates_path)):
120
+ raise Exception("Root certificates path %s does not exist" % self._root_certificates_path)
116
121
 
117
122
  @classmethod
118
123
  def from_streamlit(cls, st: Any) -> "ClarifaiAuthHelper":
@@ -219,6 +224,8 @@ Additionally, these optional params are supported:
219
224
  self.set_base(query_params["base"][0])
220
225
  if "ui" in query_params:
221
226
  self.set_ui(query_params["ui"][0])
227
+ if "root_certificates_path" in query_params:
228
+ self._root_certificates_path = query_params["root_certificates_path"][0]
222
229
 
223
230
  @classmethod
224
231
  def from_env(cls, validate: bool = True) -> "ClarifaiAuthHelper":
@@ -229,6 +236,7 @@ Additionally, these optional params are supported:
229
236
  token: CLARIFAI_SESSION_TOKEN env var.
230
237
  pat: CLARIFAI_PAT env var.
231
238
  base: CLARIFAI_API_BASE env var.
239
+ root_certificates_path: CLARIFAI_ROOT_CERTIFICATES_PATH env var.
232
240
  """
233
241
  user_id = os.environ.get("CLARIFAI_USER_ID", "")
234
242
  app_id = os.environ.get("CLARIFAI_APP_ID", "")
@@ -236,7 +244,8 @@ Additionally, these optional params are supported:
236
244
  pat = os.environ.get("CLARIFAI_PAT", "")
237
245
  base = os.environ.get("CLARIFAI_API_BASE", DEFAULT_BASE)
238
246
  ui = os.environ.get("CLARIFAI_UI", DEFAULT_UI)
239
- return cls(user_id, app_id, pat, token, base, ui, validate)
247
+ root_certificates_path = os.environ.get("CLARIFAI_ROOT_CERTIFICATES_PATH", None)
248
+ return cls(user_id, app_id, pat, token, base, ui, root_certificates_path, validate)
240
249
 
241
250
  def get_user_app_id_proto(
242
251
  self,
@@ -281,7 +290,8 @@ Additionally, these optional params are supported:
281
290
 
282
291
  https = base_https_cache[self._base]
283
292
  if https:
284
- channel = ClarifaiChannel.get_grpc_channel(base=self._base)
293
+ channel = ClarifaiChannel.get_grpc_channel(
294
+ base=self._base, root_certificates_path=self._root_certificates_path)
285
295
  else:
286
296
  if self._base.find(":") >= 0:
287
297
  host, port = self._base.split(":")
clarifai/client/base.py CHANGED
@@ -22,6 +22,7 @@ class BaseClient:
22
22
  - token (str): A session token for authentication. Accepts either a session token or a pat.
23
23
  - base (str): The base URL for the API endpoint. Defaults to 'https://api.clarifai.com'.
24
24
  - ui (str): The URL for the UI. Defaults to 'https://clarifai.com'.
25
+ - root_certificates_path (str): Path to the SSL root certificates file, used to establish secure gRPC connections.
25
26
 
26
27
 
27
28
  Attributes:
@@ -51,14 +52,21 @@ class BaseClient:
51
52
  self.token = self.auth_helper._token
52
53
  self.user_app_id = self.auth_helper.get_user_app_id_proto()
53
54
  self.base = self.auth_helper.base
55
+ self.root_certificates_path = self.auth_helper._root_certificates_path
54
56
 
55
57
  @classmethod
56
58
  def from_auth_helper(cls, auth: ClarifaiAuthHelper, **kwargs):
57
59
  default_kwargs = {
58
- "user_id": kwargs.get("user_id", None) or auth.user_id,
59
- "app_id": kwargs.get("app_id", None) or auth.app_id,
60
- "pat": kwargs.get("pat", None) or auth.pat,
61
- "token": kwargs.get("token", None) or auth._token,
60
+ "user_id":
61
+ kwargs.get("user_id", None) or auth.user_id,
62
+ "app_id":
63
+ kwargs.get("app_id", None) or auth.app_id,
64
+ "pat":
65
+ kwargs.get("pat", None) or auth.pat,
66
+ "token":
67
+ kwargs.get("token", None) or auth._token,
68
+ "root_certificates_path":
69
+ kwargs.get("root_certificates_path", None) or auth._root_certificates_path
62
70
  }
63
71
  _base = kwargs.get("base", None) or auth.base
64
72
  _clss = cls.__mro__[0]
@@ -160,6 +168,8 @@ class BaseClient:
160
168
  value = value_s
161
169
  elif key == 'metrics':
162
170
  continue
171
+ elif key == 'size':
172
+ value = int(value)
163
173
  elif key in ['metadata']:
164
174
  if isinstance(value, dict) and value != {}:
165
175
  value_s = struct_pb2.Struct()
@@ -47,6 +47,7 @@ class Dataset(Lister, BaseClient):
47
47
  base_url: str = "https://api.clarifai.com",
48
48
  pat: str = None,
49
49
  token: str = None,
50
+ root_certificates_path: str = None,
50
51
  **kwargs):
51
52
  """Initializes a Dataset object.
52
53
 
@@ -57,6 +58,7 @@ class Dataset(Lister, BaseClient):
57
58
  base_url (str): Base API url. Default "https://api.clarifai.com"
58
59
  pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
59
60
  token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
61
+ root_certificates_path (str): Path to the SSL root certificates file, used to establish secure gRPC connections.
60
62
  **kwargs: Additional keyword arguments to be passed to the Dataset.
61
63
  """
62
64
  if url and dataset_id:
@@ -77,10 +79,21 @@ class Dataset(Lister, BaseClient):
77
79
  self.batch_size = 128 # limit max protos in a req
78
80
  self.task = None # Upload dataset type
79
81
  self.input_object = Inputs(
80
- user_id=self.user_id, app_id=self.app_id, pat=pat, token=token, base_url=base_url)
82
+ user_id=self.user_id,
83
+ app_id=self.app_id,
84
+ pat=pat,
85
+ token=token,
86
+ base_url=base_url,
87
+ root_certificates_path=root_certificates_path)
81
88
  self.logger = get_logger(logger_level="INFO", name=__name__)
82
89
  BaseClient.__init__(
83
- self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
90
+ self,
91
+ user_id=self.user_id,
92
+ app_id=self.app_id,
93
+ base=base_url,
94
+ pat=pat,
95
+ token=token,
96
+ root_certificates_path=root_certificates_path)
84
97
  Lister.__init__(self)
85
98
 
86
99
  def create_version(self, **kwargs) -> 'Dataset':
clarifai/client/input.py CHANGED
@@ -19,6 +19,7 @@ from tqdm import tqdm
19
19
  from clarifai.client.base import BaseClient
20
20
  from clarifai.client.lister import Lister
21
21
  from clarifai.constants.dataset import MAX_RETRIES
22
+ from clarifai.constants.input import MAX_UPLOAD_BATCH_SIZE
22
23
  from clarifai.errors import UserError
23
24
  from clarifai.utils.logging import get_logger
24
25
  from clarifai.utils.misc import BackoffIterator, Chunker
@@ -34,6 +35,7 @@ class Inputs(Lister, BaseClient):
34
35
  base_url: str = "https://api.clarifai.com",
35
36
  pat: str = None,
36
37
  token: str = None,
38
+ root_certificates_path: str = None,
37
39
  **kwargs):
38
40
  """Initializes an Input object.
39
41
 
@@ -43,6 +45,7 @@ class Inputs(Lister, BaseClient):
43
45
  base_url (str): Base API url. Default "https://api.clarifai.com"
44
46
  pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
45
47
  token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
48
+ root_certificates_path (str): Path to the SSL root certificates file, used to establish secure gRPC connections.
46
49
  **kwargs: Additional keyword arguments to be passed to the Input
47
50
  """
48
51
  self.user_id = user_id
@@ -51,7 +54,13 @@ class Inputs(Lister, BaseClient):
51
54
  self.input_info = resources_pb2.Input(**self.kwargs)
52
55
  self.logger = get_logger(logger_level=logger_level, name=__name__)
53
56
  BaseClient.__init__(
54
- self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
57
+ self,
58
+ user_id=self.user_id,
59
+ app_id=self.app_id,
60
+ base=base_url,
61
+ pat=pat,
62
+ token=token,
63
+ root_certificates_path=root_certificates_path)
55
64
  Lister.__init__(self)
56
65
 
57
66
  @staticmethod
@@ -660,6 +669,10 @@ class Inputs(Lister, BaseClient):
660
669
  """
661
670
  if not isinstance(inputs, list):
662
671
  raise UserError("inputs must be a list of Input objects")
672
+ if len(inputs) > MAX_UPLOAD_BATCH_SIZE:
673
+ raise UserError(
674
+ f"Number of inputs to upload exceeds the maximum batch size of {MAX_UPLOAD_BATCH_SIZE}. Please reduce batch size."
675
+ )
663
676
  input_job_id = uuid.uuid4().hex # generate a unique id for this job
664
677
  request = service_pb2.PostInputsRequest(
665
678
  user_app_id=self.user_app_id, inputs=inputs, inputs_add_job_id=input_job_id)
clarifai/client/model.py CHANGED
@@ -1,7 +1,8 @@
1
1
  import os
2
2
  import time
3
- from typing import Any, Dict, Generator, List, Union
3
+ from typing import Any, Dict, Generator, List, Tuple, Union
4
4
 
5
+ import numpy as np
5
6
  import requests
6
7
  import yaml
7
8
  from clarifai_grpc.grpc.api import resources_pb2, service_pb2
@@ -12,6 +13,7 @@ from google.protobuf.struct_pb2 import Struct
12
13
  from tqdm import tqdm
13
14
 
14
15
  from clarifai.client.base import BaseClient
16
+ from clarifai.client.dataset import Dataset
15
17
  from clarifai.client.input import Inputs
16
18
  from clarifai.client.lister import Lister
17
19
  from clarifai.constants.model import MAX_MODEL_PREDICT_INPUTS, TRAINABLE_MODEL_TYPES
@@ -34,6 +36,7 @@ class Model(Lister, BaseClient):
34
36
  base_url: str = "https://api.clarifai.com",
35
37
  pat: str = None,
36
38
  token: str = None,
39
+ root_certificates_path: str = None,
37
40
  **kwargs):
38
41
  """Initializes a Model object.
39
42
 
@@ -44,6 +47,7 @@ class Model(Lister, BaseClient):
44
47
  base_url (str): Base API url. Default "https://api.clarifai.com"
45
48
  pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
46
49
  token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
50
+ root_certificates_path (str): Path to the SSL root certificates file, used to establish secure gRPC connections.
47
51
  **kwargs: Additional keyword arguments to be passed to the Model.
48
52
  """
49
53
  if url and model_id:
@@ -59,7 +63,13 @@ class Model(Lister, BaseClient):
59
63
  self.logger = get_logger(logger_level="INFO", name=__name__)
60
64
  self.training_params = {}
61
65
  BaseClient.__init__(
62
- self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
66
+ self,
67
+ user_id=self.user_id,
68
+ app_id=self.app_id,
69
+ base=base_url,
70
+ pat=pat,
71
+ token=token,
72
+ root_certificates_path=root_certificates_path)
63
73
  Lister.__init__(self)
64
74
 
65
75
  def list_training_templates(self) -> List[str]:
@@ -243,7 +253,7 @@ class Model(Lister, BaseClient):
243
253
 
244
254
  return response.model.model_version.id
245
255
 
246
- def training_status(self, version_id: str, training_logs: bool = False) -> Dict[str, str]:
256
+ def training_status(self, version_id: str = None, training_logs: bool = False) -> Dict[str, str]:
247
257
  """Get the training status for the model version. Also stores training logs
248
258
 
249
259
  Args:
@@ -258,19 +268,20 @@ class Model(Lister, BaseClient):
258
268
  >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
259
269
  >>> model.training_status(version_id='version_id',training_logs=True)
260
270
  """
271
+ if not version_id and not self.model_info.model_version.id:
272
+ raise UserError(
273
+ "Model version ID is missing. Please provide a `model_version` with a valid `id` as an argument or as a URL in the following format: '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
274
+ )
275
+
276
+ if not self.model_info.model_type_id or not self.model_info.model_version.train_log:
277
+ self.load_info()
261
278
  if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
262
279
  raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
263
280
 
264
- request = service_pb2.GetModelVersionRequest(
265
- user_app_id=self.user_app_id, model_id=self.id, version_id=version_id)
266
- response = self._grpc_request(self.STUB.GetModelVersion, request)
267
- if response.status.code != status_code_pb2.SUCCESS:
268
- raise Exception(response.status)
269
-
270
281
  if training_logs:
271
282
  try:
272
- if response.model_version.train_log:
273
- log_response = requests.get(response.model_version.train_log)
283
+ if self.model_info.model_version.train_log:
284
+ log_response = requests.get(self.model_info.model_version.train_log)
274
285
  log_response.raise_for_status() # Check for any HTTP errors
275
286
  with open(version_id + '.log', 'wb') as file:
276
287
  for chunk in log_response.iter_content(chunk_size=4096): # 4KB
@@ -280,7 +291,7 @@ class Model(Lister, BaseClient):
280
291
  except requests.exceptions.RequestException as e:
281
292
  raise Exception(f"An error occurred while getting training logs: {e}")
282
293
 
283
- return response.model_version.status
294
+ return self.model_info.model_version.status
284
295
 
285
296
  def delete_version(self, version_id: str) -> None:
286
297
  """Deletes a model version for the Model.
@@ -617,18 +628,22 @@ class Model(Lister, BaseClient):
617
628
  return response.eval_metrics
618
629
 
619
630
  def evaluate(self,
620
- dataset_id: str,
631
+ dataset: Dataset = None,
632
+ dataset_id: str = None,
621
633
  dataset_app_id: str = None,
622
634
  dataset_user_id: str = None,
635
+ dataset_version_id: str = None,
623
636
  eval_id: str = None,
624
637
  extended_metrics: dict = None,
625
638
  eval_info: dict = None) -> resources_pb2.EvalMetrics:
626
639
  """ Run evaluation
627
640
 
628
641
  Args:
629
- dataset_id (str): Dataset Id.
642
+ dataset (Dataset): If Clarifai Dataset is set, it will ignore other arguments prefixed with 'dataset_'.
643
+ dataset_id (str): Dataset Id. Default is None.
630
644
  dataset_app_id (str): App ID for cross app evaluation, leave it as None to use Model App ID. Default is None.
631
645
  dataset_user_id (str): User ID for cross app evaluation, leave it as None to use Model User ID. Default is None.
646
+ dataset_version_id (str): Dataset version Id. Default is None.
632
647
  eval_id (str): Specific ID for the evaluation. You must specify this parameter to either overwrite the result with the dataset ID or format your evaluation in an informative manner. If you don't, it will use random ID from system. Default is None.
633
648
  extended_metrics (dict): user custom metrics result. Default is None.
634
649
  eval_info (dict): custom eval info. Default is empty dict.
@@ -638,6 +653,23 @@ class Model(Lister, BaseClient):
638
653
 
639
654
  """
640
655
  assert self.model_info.model_version.id, "Model version is empty. Please provide `model_version` as arguments or with a URL as the format '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
656
+
657
+ if dataset:
658
+ self.logger.info("Using dataset, ignore other arguments prefixed with 'dataset_'")
659
+ dataset_id = dataset.id
660
+ dataset_app_id = dataset.app_id
661
+ dataset_user_id = dataset.user_id
662
+ dataset_version_id = dataset.version.id
663
+ else:
664
+ self.logger.warning(
665
+ "Arguments prefixed with `dataset_` will be removed soon, please use dataset")
666
+
667
+ gt_dataset = resources_pb2.Dataset(
668
+ id=dataset_id,
669
+ app_id=dataset_app_id or self.auth_helper.app_id,
670
+ user_id=dataset_user_id or self.auth_helper.user_id,
671
+ version=resources_pb2.DatasetVersion(id=dataset_version_id))
672
+
641
673
  metrics = None
642
674
  if isinstance(extended_metrics, dict):
643
675
  metrics = Struct()
@@ -659,11 +691,7 @@ class Model(Lister, BaseClient):
659
691
  model_version=resources_pb2.ModelVersion(id=self.model_info.model_version.id),
660
692
  ),
661
693
  extended_metrics=metrics,
662
- ground_truth_dataset=resources_pb2.Dataset(
663
- id=dataset_id,
664
- app_id=dataset_app_id or self.auth_helper.app_id,
665
- user_id=dataset_user_id or self.auth_helper.user_id,
666
- ),
694
+ ground_truth_dataset=gt_dataset,
667
695
  eval_info=eval_info_params,
668
696
  )
669
697
  request = service_pb2.PostEvaluationsRequest(
@@ -761,6 +789,157 @@ class Model(Lister, BaseClient):
761
789
 
762
790
  return result
763
791
 
792
+ def get_eval_by_dataset(self, dataset: Dataset) -> List[resources_pb2.EvalMetrics]:
793
+ """Get all eval data of dataset
794
+
795
+ Args:
796
+ dataset (Dataset): Clarifai dataset
797
+
798
+ Returns:
799
+ List[resources_pb2.EvalMetrics]
800
+ """
801
+ _id = dataset.id
802
+ app = dataset.app_id or self.app_id
803
+ user_id = dataset.user_id or self.user_id
804
+ version = dataset.version.id
805
+
806
+ list_eval: resources_pb2.EvalMetrics = self.list_evaluations()
807
+ outputs = []
808
+ for _eval in list_eval:
809
+ if _eval.status.code == status_code_pb2.MODEL_EVALUATED:
810
+ gt_ds = _eval.ground_truth_dataset
811
+ if (_id == gt_ds.id and user_id == gt_ds.user_id and app == gt_ds.app_id):
812
+ if not version or version == gt_ds.version.id:
813
+ outputs.append(_eval)
814
+
815
+ return outputs
816
+
817
+ def get_raw_eval(self,
818
+ dataset: Dataset = None,
819
+ eval_id: str = None,
820
+ return_format: str = 'array') -> Union[resources_pb2.EvalTestSetEntry, Tuple[
821
+ np.array, np.array, list, List[Input]], Tuple[List[dict], List[dict]]]:
822
+ """Get ground truths, predictions and input information. Do not pass dataset and eval_id at the same time
823
+
824
+ Args:
825
+ dataset (Dataset): Clarifai dataset, get eval data of latest eval result of dataset.
826
+ eval_id (str): Evaluation ID, get eval data of specific eval id.
827
+ return_format (str, optional): Choice {proto, array, coco}. !Note that `coco` is only applicable for 'visual-detector'. Defaults to 'array'.
828
+
829
+ Returns:
830
+
831
+ Depends on `return_format`.
832
+
833
+ * if return_format == proto
834
+ `resources_pb2.EvalTestSetEntry`
835
+
836
+ * if return_format == array
837
+ `Tuple(np.array, np.array, List[str], List[Input])`: Tuple has 4 elements (y, y_pred, concept_ids, inputs).
838
+ y, y_pred, concept_ids can be used to compute metrics. 'inputs' can be used to download
839
+ - if model is 'classifier': 'y' and 'y_pred' are both arrays with a shape of (num_inputs,)
840
+ - if model is 'visual-detector': 'y' and 'y_pred' are arrays with a shape of (num_inputs,), where each element is an array of shape (num_annotation, 6) consisting of [x_min, y_min, x_max, y_max, concept_index, score]. The score is always 1 for 'y'
841
+
842
+ * if return_format == coco: Applicable only for 'visual-detector'
843
+ `Tuple[List[Dict], List[Dict]]`: Tuple has 2 elements, where the first element is the COCO Ground Truth and the last one is the COCO Prediction Annotation
844
+
845
+ Example Usages:
846
+ ------
847
+ * Evaluate `visual-classifier` using sklearn
848
+
849
+ ```python
850
+ import os
851
+ from sklearn.metrics import accuracy_score
852
+ from sklearn.metrics import classification_report
853
+ import numpy as np
854
+ from clarifai.client.model import Model
855
+ from clarifai.client.dataset import Dataset
856
+ os.environ["CLARIFAI_PAT"] = "???"
857
+ model = Model(url="url/of/model/includes/version-id")
858
+ dataset = Dataset(dataset_id="dataset-id")
859
+ y, y_pred, clss, input_protos = model.get_raw_eval(dataset, return_format="array")
860
+ y = np.argmax(y, axis=1)
861
+ y_pred = np.argmax(y_pred, axis=1)
862
+ report = classification_report(y, y_pred, target_names=clss)
863
+ print(report)
864
+ acc = accuracy_score(y, y_pred)
865
+ print("acc ", acc)
866
+ ```
867
+
868
+ * Evaluate `visual-detector` using COCOeval
869
+
870
+ ```python
871
+ import os
872
+ import json
873
+ from pycocotools.coco import COCO
874
+ from pycocotools.cocoeval import COCOeval
875
+ from clarifai.client.model import Model
876
+ from clarifai.client.dataset import Dataset
877
+ os.environ["CLARIFAI_PAT"] = "???" # Insert your PAT
878
+ model = Model(url=model_url)
879
+ dataset = Dataset(url=dataset_url)
880
+ y, y_pred = model.get_raw_eval(dataset, return_format="coco")
881
+ # save as files to load in COCO API
882
+ def save_annot(d, path):
883
+ with open(path, "w") as fp:
884
+ json.dump(d, fp, indent=2)
885
+ gt_path = os.path.join("gt.json")
886
+ pred_path = os.path.join("pred.json")
887
+ save_annot(y, gt_path)
888
+ save_annot(y_pred, pred_path)
889
+
890
+ cocoGt = COCO(gt_path)
891
+ cocoPred = COCO(pred_path)
892
+ cocoEval = COCOeval(cocoGt, cocoPred, "bbox")
893
+ cocoEval.evaluate()
894
+ cocoEval.accumulate()
895
+ cocoEval.summarize() # Print out result of all classes with all area type
896
+ # Example:
897
+ # Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.863
898
+ # Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.973
899
+ # Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.939
900
+ # ...
901
+ ```
902
+
903
+ """
904
+ from clarifai.utils.evaluation.testset_annotation_parser import (
905
+ parse_eval_annotation_classifier, parse_eval_annotation_detector,
906
+ parse_eval_annotation_detector_coco)
907
+
908
+ valid_model_types = ["visual-classifier", "text-classifier", "visual-detector"]
909
+ supported_format = ['proto', 'array', 'coco']
910
+ assert return_format in supported_format, ValueError(
911
+ f"Expected return_format in {supported_format}, got {return_format}")
912
+ self.load_info()
913
+ model_type_id = self.model_info.model_type_id
914
+ assert model_type_id in valid_model_types, \
915
+ f"This method only supports model types {valid_model_types}, but your model type is {self.model_info.model_type_id}."
916
+ assert not (dataset and
917
+ eval_id), "Using both `dataset` and `eval_id`, but only one should be passed."
918
+ assert not dataset or not eval_id, "Please provide either `dataset` or `eval_id`, but nothing was passed."
919
+ if model_type_id.endswith("-classifier") and return_format == "coco":
920
+ raise ValueError(
921
+ f"return_format coco only applies for `visual-detector`, however your model is `{model_type_id}`"
922
+ )
923
+
924
+ if dataset:
925
+ eval_by_ds = self.get_eval_by_dataset(dataset)
926
+ if len(eval_by_ds) == 0:
927
+ raise Exception(f"Model is not valuated with dataset: {dataset}")
928
+ eval_id = eval_by_ds[0].id
929
+
930
+ detail_eval_data = self.get_eval_by_id(eval_id=eval_id, test_set=True, metrics_by_class=True)
931
+
932
+ if return_format == "proto":
933
+ return detail_eval_data.test_set
934
+ else:
935
+ if model_type_id.endswith("-classifier"):
936
+ return parse_eval_annotation_classifier(detail_eval_data)
937
+ elif model_type_id == "visual-detector":
938
+ if return_format == "array":
939
+ return parse_eval_annotation_detector(detail_eval_data)
940
+ elif return_format == "coco":
941
+ return parse_eval_annotation_detector_coco(detail_eval_data)
942
+
764
943
  def export(self, export_dir: str = None) -> None:
765
944
  """Export the model, stores the exported model as model.tar file
766
945
 
clarifai/client/module.py CHANGED
@@ -19,6 +19,7 @@ class Module(Lister, BaseClient):
19
19
  base_url: str = "https://api.clarifai.com",
20
20
  pat: str = None,
21
21
  token: str = None,
22
+ root_certificates_path: str = None,
22
23
  **kwargs):
23
24
  """Initializes a Module object.
24
25
 
@@ -29,6 +30,7 @@ class Module(Lister, BaseClient):
29
30
  base_url (str): Base API url. Default "https://api.clarifai.com"
30
31
  pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT.
31
32
  token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN.
33
+ root_certificates_path (str): Path to the SSL root certificates file, used to establish secure gRPC connections.
32
34
  **kwargs: Additional keyword arguments to be passed to the Module.
33
35
  """
34
36
  if url and module_id:
@@ -44,7 +46,13 @@ class Module(Lister, BaseClient):
44
46
  self.module_info = resources_pb2.Module(**self.kwargs)
45
47
  self.logger = get_logger(logger_level="INFO", name=__name__)
46
48
  BaseClient.__init__(
47
- self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
49
+ self,
50
+ user_id=self.user_id,
51
+ app_id=self.app_id,
52
+ base=base_url,
53
+ pat=pat,
54
+ token=token,
55
+ root_certificates_path=root_certificates_path)
48
56
  Lister.__init__(self)
49
57
 
50
58
  def list_versions(self, page_no: int = None,
clarifai/client/search.py CHANGED
@@ -24,7 +24,8 @@ class Search(Lister, BaseClient):
24
24
  metric: str = DEFAULT_SEARCH_METRIC,
25
25
  base_url: str = "https://api.clarifai.com",
26
26
  pat: str = None,
27
- token: str = None):
27
+ token: str = None,
28
+ root_certificates_path: str = None):
28
29
  """Initialize the Search object.
29
30
 
30
31
  Args:
@@ -35,6 +36,7 @@ class Search(Lister, BaseClient):
35
36
  base_url (str, optional): Base API url. Defaults to "https://api.clarifai.com".
36
37
  pat (str, optional): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
37
38
  token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
39
+ root_certificates_path (str): Path to the SSL root certificates file, used to establish secure gRPC connections.
38
40
 
39
41
  Raises:
40
42
  UserError: If the metric is not 'cosine' or 'euclidean'.
@@ -52,7 +54,13 @@ class Search(Lister, BaseClient):
52
54
  user_id=self.user_id, app_id=self.app_id, pat=pat, token=token, base_url=base_url)
53
55
  self.rank_filter_schema = get_schema()
54
56
  BaseClient.__init__(
55
- self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
57
+ self,
58
+ user_id=self.user_id,
59
+ app_id=self.app_id,
60
+ base=base_url,
61
+ pat=pat,
62
+ token=token,
63
+ root_certificates_path=root_certificates_path)
56
64
  Lister.__init__(self, page_size=1000)
57
65
 
58
66
  def _get_annot_proto(self, **kwargs):