clarifai 10.1.0__py3-none-any.whl → 10.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
clarifai/client/app.py CHANGED
@@ -32,6 +32,7 @@ class App(Lister, BaseClient):
32
32
  app_id: str = None,
33
33
  base_url: str = "https://api.clarifai.com",
34
34
  pat: str = None,
35
+ token: str = None,
35
36
  **kwargs):
36
37
  """Initializes an App object.
37
38
 
@@ -40,6 +41,7 @@ class App(Lister, BaseClient):
40
41
  app_id (str): The App ID for the App to interact with.
41
42
  base_url (str): Base API url. Default "https://api.clarifai.com"
42
43
  pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
44
+ token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
43
45
  **kwargs: Additional keyword arguments to be passed to the App.
44
46
  - name (str): The name of the app.
45
47
  - description (str): The description of the app.
@@ -52,7 +54,8 @@ class App(Lister, BaseClient):
52
54
  self.kwargs = {**kwargs, 'id': app_id}
53
55
  self.app_info = resources_pb2.App(**self.kwargs)
54
56
  self.logger = get_logger(logger_level="INFO", name=__name__)
55
- BaseClient.__init__(self, user_id=self.user_id, app_id=self.id, base=base_url, pat=pat)
57
+ BaseClient.__init__(
58
+ self, user_id=self.user_id, app_id=self.id, base=base_url, pat=pat, token=token)
56
59
  Lister.__init__(self)
57
60
 
58
61
  def list_datasets(self, page_no: int = None,
@@ -85,7 +88,7 @@ class App(Lister, BaseClient):
85
88
  for dataset_info in all_datasets_info:
86
89
  if 'version' in list(dataset_info.keys()):
87
90
  del dataset_info['version']['metrics']
88
- yield Dataset(base_url=self.base, pat=self.pat, **dataset_info)
91
+ yield Dataset.from_auth_helper(auth=self.auth_helper, **dataset_info)
89
92
 
90
93
  def list_models(self,
91
94
  filter_by: Dict[str, Any] = {},
@@ -126,7 +129,7 @@ class App(Lister, BaseClient):
126
129
  if only_in_app:
127
130
  if model_info['app_id'] != self.id:
128
131
  continue
129
- yield Model(base_url=self.base, pat=self.pat, **model_info)
132
+ yield Model.from_auth_helper(auth=self.auth_helper, **model_info)
130
133
 
131
134
  def list_workflows(self,
132
135
  filter_by: Dict[str, Any] = {},
@@ -165,7 +168,7 @@ class App(Lister, BaseClient):
165
168
  if only_in_app:
166
169
  if workflow_info['app_id'] != self.id:
167
170
  continue
168
- yield Workflow(base_url=self.base, pat=self.pat, **workflow_info)
171
+ yield Workflow.from_auth_helper(auth=self.auth_helper, **workflow_info)
169
172
 
170
173
  def list_modules(self,
171
174
  filter_by: Dict[str, Any] = {},
@@ -204,7 +207,7 @@ class App(Lister, BaseClient):
204
207
  if only_in_app:
205
208
  if module_info['app_id'] != self.id:
206
209
  continue
207
- yield Module(base_url=self.base, pat=self.pat, **module_info)
210
+ yield Module.from_auth_helper(auth=self.auth_helper, **module_info)
208
211
 
209
212
  def list_installed_module_versions(self,
210
213
  filter_by: Dict[str, Any] = {},
@@ -239,11 +242,8 @@ class App(Lister, BaseClient):
239
242
  for imv_info in all_imv_infos:
240
243
  del imv_info['deploy_url']
241
244
  del imv_info['installed_module_version_id'] # TODO: remove this after the backend fix
242
- yield Module(
243
- module_id=imv_info['module_version']['module_id'],
244
- base_url=self.base,
245
- pat=self.pat,
246
- **imv_info)
245
+ yield Module.from_auth_helper(
246
+ auth=self.auth_helper, module_id=imv_info['module_version']['module_id'], **imv_info)
247
247
 
248
248
  def list_concepts(self, page_no: int = None,
249
249
  per_page: int = None) -> Generator[Concept, None, None]:
@@ -308,14 +308,8 @@ class App(Lister, BaseClient):
308
308
  if response.status.code != status_code_pb2.SUCCESS:
309
309
  raise Exception(response.status)
310
310
  self.logger.info("\nDataset created\n%s", response.status)
311
- kwargs.update({
312
- 'app_id': self.id,
313
- 'user_id': self.user_id,
314
- 'base_url': self.base,
315
- 'pat': self.pat
316
- })
317
311
 
318
- return Dataset(dataset_id=dataset_id, **kwargs)
312
+ return Dataset.from_auth_helper(self.auth_helper, dataset_id=dataset_id, **kwargs)
319
313
 
320
314
  def create_model(self, model_id: str, **kwargs) -> Model:
321
315
  """Creates a model for the app.
@@ -339,14 +333,11 @@ class App(Lister, BaseClient):
339
333
  raise Exception(response.status)
340
334
  self.logger.info("\nModel created\n%s", response.status)
341
335
  kwargs.update({
342
- 'app_id': self.id,
343
- 'user_id': self.user_id,
336
+ 'model_id': model_id,
344
337
  'model_type_id': response.model.model_type_id,
345
- 'base_url': self.base,
346
- 'pat': self.pat
347
338
  })
348
339
 
349
- return Model(model_id=model_id, **kwargs)
340
+ return Model.from_auth_helper(auth=self.auth_helper, **kwargs)
350
341
 
351
342
  def create_workflow(self,
352
343
  config_filepath: str,
@@ -436,9 +427,8 @@ class App(Lister, BaseClient):
436
427
  display_workflow_tree(dict_response["workflows"][0]["nodes"])
437
428
  kwargs = self.process_response_keys(dict_response[list(dict_response.keys())[1]][0],
438
429
  "workflow")
439
- kwargs.update({'base_url': self.base, 'pat': self.pat})
440
430
 
441
- return Workflow(**kwargs)
431
+ return Workflow.from_auth_helper(auth=self.auth_helper, **kwargs)
442
432
 
443
433
  def create_module(self, module_id: str, description: str, **kwargs) -> Module:
444
434
  """Creates a module for the app.
@@ -464,14 +454,8 @@ class App(Lister, BaseClient):
464
454
  if response.status.code != status_code_pb2.SUCCESS:
465
455
  raise Exception(response.status)
466
456
  self.logger.info("\nModule created\n%s", response.status)
467
- kwargs.update({
468
- 'app_id': self.id,
469
- 'user_id': self.user_id,
470
- 'base_url': self.base,
471
- 'pat': self.pat
472
- })
473
457
 
474
- return Module(module_id=module_id, **kwargs)
458
+ return Module.from_auth_helper(auth=self.auth_helper, module_id=module_id, **kwargs)
475
459
 
476
460
  def dataset(self, dataset_id: str, **kwargs) -> Dataset:
477
461
  """Returns a Dataset object for the existing dataset ID.
@@ -496,8 +480,7 @@ class App(Lister, BaseClient):
496
480
  kwargs = self.process_response_keys(dict_response[list(dict_response.keys())[1]],
497
481
  list(dict_response.keys())[1])
498
482
  kwargs['version'] = response.dataset.version if response.dataset.version else None
499
- kwargs.update({'base_url': self.base, 'pat': self.pat})
500
- return Dataset(**kwargs)
483
+ return Dataset.from_auth_helper(auth=self.auth_helper, **kwargs)
501
484
 
502
485
  def model(self, model_id: str, model_version_id: str = "", **kwargs) -> Model:
503
486
  """Returns a Model object for the existing model ID.
@@ -532,9 +515,8 @@ class App(Lister, BaseClient):
532
515
  kwargs = self.process_response_keys(dict_response['model'], 'model')
533
516
  kwargs[
534
517
  'model_version'] = response.model.model_version if response.model.model_version else None
535
- kwargs.update({'base_url': self.base, 'pat': self.pat})
536
518
 
537
- return Model(**kwargs)
519
+ return Model.from_auth_helper(self.auth_helper, **kwargs)
538
520
 
539
521
  def workflow(self, workflow_id: str, **kwargs) -> Workflow:
540
522
  """Returns a workflow object for the existing workflow ID.
@@ -558,9 +540,8 @@ class App(Lister, BaseClient):
558
540
  dict_response = MessageToDict(response, preserving_proto_field_name=True)
559
541
  kwargs = self.process_response_keys(dict_response[list(dict_response.keys())[1]],
560
542
  list(dict_response.keys())[1])
561
- kwargs.update({'base_url': self.base, 'pat': self.pat})
562
543
 
563
- return Workflow(**kwargs)
544
+ return Workflow.from_auth_helper(auth=self.auth_helper, **kwargs)
564
545
 
565
546
  def module(self, module_id: str, module_version_id: str = "", **kwargs) -> Module:
566
547
  """Returns a Module object for the existing module ID.
@@ -585,9 +566,8 @@ class App(Lister, BaseClient):
585
566
  raise Exception(response.status)
586
567
  dict_response = MessageToDict(response, preserving_proto_field_name=True)
587
568
  kwargs = self.process_response_keys(dict_response['module'], 'module')
588
- kwargs.update({'base_url': self.base, 'pat': self.pat})
589
569
 
590
- return Module(**kwargs)
570
+ return Module.from_auth_helper(auth=self.auth_helper, **kwargs)
591
571
 
592
572
  def inputs(self,):
593
573
  """Returns an Input object.
@@ -595,7 +575,7 @@ class App(Lister, BaseClient):
595
575
  Returns:
596
576
  Inputs: An input object.
597
577
  """
598
- return Inputs(self.user_id, self.id, base_url=self.base, pat=self.pat)
578
+ return Inputs.from_auth_helper(self.auth_helper)
599
579
 
600
580
  def delete_dataset(self, dataset_id: str) -> None:
601
581
  """Deletes an dataset for the user.
@@ -684,9 +664,9 @@ class App(Lister, BaseClient):
684
664
  >>> app = App(app_id="app_id", user_id="user_id")
685
665
  >>> search_client = app.search(top_k=12, metric="euclidean")
686
666
  """
687
- user_id = kwargs.get("user_id", self.user_app_id.user_id)
688
- app_id = kwargs.get("app_id", self.user_app_id.app_id)
689
- return Search(user_id=user_id, app_id=app_id, base_url=self.base, pat=self.pat, **kwargs)
667
+ kwargs.get("user_id", self.user_app_id.user_id)
668
+ kwargs.get("app_id", self.user_app_id.app_id)
669
+ return Search.from_auth_helper(auth=self.auth_helper, **kwargs)
690
670
 
691
671
  def __getattr__(self, name):
692
672
  return getattr(self.app_info, name)
clarifai/client/base.py CHANGED
@@ -7,7 +7,7 @@ from google.protobuf.wrappers_pb2 import BoolValue
7
7
 
8
8
  from clarifai.client.auth import create_stub
9
9
  from clarifai.client.auth.helper import ClarifaiAuthHelper
10
- from clarifai.errors import ApiError
10
+ from clarifai.errors import ApiError, UserError
11
11
  from clarifai.utils.misc import get_from_dict_or_env
12
12
 
13
13
 
@@ -19,9 +19,11 @@ class BaseClient:
19
19
  - user_id (str): A user ID for authentication.
20
20
  - app_id (str): An app ID for the application to interact with.
21
21
  - pat (str): A personal access token for authentication.
22
+ - token (str): A session token for authentication. Accepts either a session token or a pat.
22
23
  - base (str): The base URL for the API endpoint. Defaults to 'https://api.clarifai.com'.
23
24
  - ui (str): The URL for the UI. Defaults to 'https://clarifai.com'.
24
25
 
26
+
25
27
  Attributes:
26
28
  auth_helper (ClarifaiAuthHelper): An instance of ClarifaiAuthHelper for authentication.
27
29
  STUB (Stub): The gRPC Stub object for API interaction.
@@ -31,15 +33,53 @@ class BaseClient:
31
33
  """
32
34
 
33
35
  def __init__(self, **kwargs):
34
- pat = get_from_dict_or_env(key="pat", env_key="CLARIFAI_PAT", **kwargs)
35
- kwargs.update({'pat': pat})
36
+ token, pat = "", ""
37
+ try:
38
+ pat = get_from_dict_or_env(key="pat", env_key="CLARIFAI_PAT", **kwargs)
39
+ except UserError:
40
+ token = get_from_dict_or_env(key="token", env_key="CLARIFAI_SESSION_TOKEN", **kwargs)
41
+ finally:
42
+ assert token or pat, Exception(
43
+ "Need 'pat' or 'token' in args or use one of the CLARIFAI_PAT or CLARIFAI_SESSION_TOKEN env vars"
44
+ )
45
+ kwargs.update({'token': token, 'pat': pat})
46
+
36
47
  self.auth_helper = ClarifaiAuthHelper(**kwargs, validate=False)
37
48
  self.STUB = create_stub(self.auth_helper)
38
49
  self.metadata = self.auth_helper.metadata
39
50
  self.pat = self.auth_helper.pat
51
+ self.token = self.auth_helper._token
40
52
  self.user_app_id = self.auth_helper.get_user_app_id_proto()
41
53
  self.base = self.auth_helper.base
42
54
 
55
+ @classmethod
56
+ def from_auth_helper(cls, auth: ClarifaiAuthHelper, **kwargs):
57
+ default_kwargs = {
58
+ "user_id": kwargs.get("user_id", None) or auth.user_id,
59
+ "app_id": kwargs.get("app_id", None) or auth.app_id,
60
+ "pat": kwargs.get("pat", None) or auth.pat,
61
+ "token": kwargs.get("token", None) or auth._token,
62
+ }
63
+ _base = kwargs.get("base", None) or auth.base
64
+ _clss = cls.__mro__[0]
65
+ if _clss == BaseClient:
66
+ kwargs = {
67
+ **default_kwargs,
68
+ "base": _base, # Baseclient uses `base`
69
+ "ui": kwargs.get("ui", None) or auth.ui
70
+ }
71
+ else:
72
+ # Remove user_id and app_id if a custom URL is provided
73
+ if kwargs.get("url"):
74
+ default_kwargs.pop("user_id", "")
75
+ default_kwargs.pop("app_id", "")
76
+ # Remove app_id if the class name contains "Runner"
77
+ if 'Runner' in _clss.__name__:
78
+ default_kwargs.pop("app_id", "")
79
+ kwargs.update({**default_kwargs, "base_url": _base})
80
+
81
+ return cls(**kwargs)
82
+
43
83
  def _grpc_request(self, method: Callable, argument: Any):
44
84
  """Makes a gRPC request to the API.
45
85
 
@@ -52,7 +92,7 @@ class BaseClient:
52
92
  """
53
93
 
54
94
  try:
55
- res = method(argument)
95
+ res = method(argument, metadata=self.auth_helper.metadata)
56
96
  # MessageToDict(res) TODO global debug logger
57
97
  return res
58
98
  except ApiError:
clarifai/client/dataset.py CHANGED
@@ -43,6 +43,7 @@ class Dataset(Lister, BaseClient):
43
43
  dataset_id: str = None,
44
44
  base_url: str = "https://api.clarifai.com",
45
45
  pat: str = None,
46
+ token: str = None,
46
47
  **kwargs):
47
48
  """Initializes a Dataset object.
48
49
 
@@ -51,6 +52,7 @@ class Dataset(Lister, BaseClient):
51
52
  dataset_id (str): The Dataset ID within the App to interact with.
52
53
  base_url (str): Base API url. Default "https://api.clarifai.com"
53
54
  pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
55
+ token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
54
56
  **kwargs: Additional keyword arguments to be passed to the Dataset.
55
57
  """
56
58
  if url and dataset_id:
@@ -66,9 +68,10 @@ class Dataset(Lister, BaseClient):
66
68
  self.max_retires = 10
67
69
  self.batch_size = 128 # limit max protos in a req
68
70
  self.task = None # Upload dataset type
69
- self.input_object = Inputs(user_id=self.user_id, app_id=self.app_id, pat=pat)
71
+ self.input_object = Inputs(user_id=self.user_id, app_id=self.app_id, pat=pat, token=token)
70
72
  self.logger = get_logger(logger_level="INFO", name=__name__)
71
- BaseClient.__init__(self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat)
73
+ BaseClient.__init__(
74
+ self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
72
75
  Lister.__init__(self)
73
76
 
74
77
  def create_version(self, **kwargs) -> 'Dataset':
@@ -98,13 +101,10 @@ class Dataset(Lister, BaseClient):
98
101
  self.logger.info("\nDataset Version created\n%s", response.status)
99
102
  kwargs.update({
100
103
  'dataset_id': self.id,
101
- 'app_id': self.app_id,
102
- 'user_id': self.user_id,
103
104
  'version': response.dataset_versions[0],
104
- 'base_url': self.base,
105
- 'pat': self.pat
106
105
  })
107
- return Dataset(**kwargs)
106
+
107
+ return Dataset.from_auth_helper(self.auth_helper, **kwargs)
108
108
 
109
109
  def delete_version(self, version_id: str) -> None:
110
110
  """Deletes a dataset version for the Dataset.
@@ -162,13 +162,9 @@ class Dataset(Lister, BaseClient):
162
162
  del dataset_version_info['metrics']
163
163
  kwargs = {
164
164
  'dataset_id': self.id,
165
- 'app_id': self.app_id,
166
- 'user_id': self.user_id,
167
165
  'version': resources_pb2.DatasetVersion(**dataset_version_info),
168
- 'base_url': self.base,
169
- 'pat': self.pat
170
166
  }
171
- yield Dataset(**kwargs)
167
+ yield Dataset.from_auth_helper(self.auth_helper, **kwargs)
172
168
 
173
169
  def _concurrent_annot_upload(self, annots: List[List[resources_pb2.Annotation]]
174
170
  ) -> Union[List[resources_pb2.Annotation], List[None]]:
clarifai/client/input.py CHANGED
@@ -32,6 +32,7 @@ class Inputs(Lister, BaseClient):
32
32
  logger_level: str = "INFO",
33
33
  base_url: str = "https://api.clarifai.com",
34
34
  pat: str = None,
35
+ token: str = None,
35
36
  **kwargs):
36
37
  """Initializes an Input object.
37
38
 
@@ -39,6 +40,8 @@ class Inputs(Lister, BaseClient):
39
40
  user_id (str): A user ID for authentication.
40
41
  app_id (str): An app ID for the application to interact with.
41
42
  base_url (str): Base API url. Default "https://api.clarifai.com"
43
+ pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
44
+ token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
42
45
  **kwargs: Additional keyword arguments to be passed to the Input
43
46
  """
44
47
  self.user_id = user_id
@@ -46,7 +49,8 @@ class Inputs(Lister, BaseClient):
46
49
  self.kwargs = {**kwargs}
47
50
  self.input_info = resources_pb2.Input(**self.kwargs)
48
51
  self.logger = get_logger(logger_level=logger_level, name=__name__)
49
- BaseClient.__init__(self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat)
52
+ BaseClient.__init__(
53
+ self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
50
54
  Lister.__init__(self)
51
55
 
52
56
  @staticmethod
@@ -670,6 +674,30 @@ class Inputs(Lister, BaseClient):
670
674
 
671
675
  return input_job_id, response
672
676
 
677
+ def patch_inputs(self, inputs: List[Input], action: str = 'merge') -> str:
678
+ """Patch list of input objects to the app.
679
+
680
+ Args:
681
+ inputs (list): List of input objects to upload.
682
+ action (str): Action to perform on the input. Options: 'merge', 'overwrite', 'remove'.
683
+
684
+ Returns:
685
+ response: Response from the grpc request.
686
+ """
687
+ if not isinstance(inputs, list):
688
+ raise UserError("inputs must be a list of Input objects")
689
+ uuid.uuid4().hex # generate a unique id for this job
690
+ request = service_pb2.PatchInputsRequest(
691
+ user_app_id=self.user_app_id, inputs=inputs, action=action)
692
+ response = self._grpc_request(self.STUB.PatchInputs, request)
693
+ if response.status.code != status_code_pb2.SUCCESS:
694
+ try:
695
+ self.logger.warning(f"Patch inputs failed, status: {response.annotations[0].status}")
696
+ except Exception:
697
+ self.logger.warning(f"Patch inputs failed, status: {response.status.details}")
698
+
699
+ self.logger.info("\nPatch Inputs Successful\n%s", response.status)
700
+
673
701
  def upload_annotations(self, batch_annot: List[resources_pb2.Annotation], show_log: bool = True
674
702
  ) -> Union[List[resources_pb2.Annotation], List[None]]:
675
703
  """Upload image annotations to app.
clarifai/client/model.py CHANGED
@@ -1,6 +1,6 @@
1
1
  import os
2
2
  import time
3
- from typing import Any, Dict, Generator, List
3
+ from typing import Any, Dict, Generator, List, Union
4
4
 
5
5
  import requests
6
6
  import yaml
@@ -32,6 +32,7 @@ class Model(Lister, BaseClient):
32
32
  model_version: Dict = {'id': ""},
33
33
  base_url: str = "https://api.clarifai.com",
34
34
  pat: str = None,
35
+ token: str = None,
35
36
  **kwargs):
36
37
  """Initializes a Model object.
37
38
 
@@ -41,6 +42,7 @@ class Model(Lister, BaseClient):
41
42
  model_version (dict): The Model Version to interact with.
42
43
  base_url (str): Base API url. Default "https://api.clarifai.com"
43
44
  pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
45
+ token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
44
46
  **kwargs: Additional keyword arguments to be passed to the Model.
45
47
  """
46
48
  if url and model_id:
@@ -55,7 +57,8 @@ class Model(Lister, BaseClient):
55
57
  self.model_info = resources_pb2.Model(**self.kwargs)
56
58
  self.logger = get_logger(logger_level="INFO", name=__name__)
57
59
  self.training_params = {}
58
- BaseClient.__init__(self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat)
60
+ BaseClient.__init__(
61
+ self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
59
62
  Lister.__init__(self)
60
63
 
61
64
  def list_training_templates(self) -> List[str]:
@@ -212,6 +215,8 @@ class Model(Lister, BaseClient):
212
215
  >>> model_params = model.get_params(template='template', yaml_file='model_params.yaml')
213
216
  >>> model.train('model_params.yaml')
214
217
  """
218
+ if not self.model_info.model_type_id:
219
+ self.load_info()
215
220
  if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
216
221
  raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
217
222
  if not yaml_file and len(self.training_params) == 0:
@@ -222,8 +227,10 @@ class Model(Lister, BaseClient):
222
227
  params_dict = yaml.safe_load(file)
223
228
  else:
224
229
  params_dict = self.training_params
225
-
226
- train_dict = params_parser(params_dict)
230
+ #getting all the concepts for the model type
231
+ if self.model_info.model_type_id not in ["clusterer", "text-to-text"]:
232
+ concepts = self._list_concepts()
233
+ train_dict = params_parser(params_dict, concepts)
227
234
  request = service_pb2.PostModelVersionsRequest(
228
235
  user_app_id=self.user_app_id,
229
236
  model_id=self.id,
@@ -331,7 +338,7 @@ class Model(Lister, BaseClient):
331
338
  dict_response = MessageToDict(response, preserving_proto_field_name=True)
332
339
  kwargs = self.process_response_keys(dict_response['model'], 'model')
333
340
 
334
- return Model(base_url=self.base, pat=self.pat, **kwargs)
341
+ return Model(base_url=self.base, pat=self.pat, token=self.token, **kwargs)
335
342
 
336
343
  def list_versions(self, page_no: int = None,
337
344
  per_page: int = None) -> Generator['Model', None, None]:
@@ -373,11 +380,8 @@ class Model(Lister, BaseClient):
373
380
  del model_version_info['train_info']['dataset']['version']['metrics']
374
381
  except KeyError:
375
382
  pass
376
- yield Model(
377
- model_id=self.id,
378
- base_url=self.base,
379
- pat=self.pat,
380
- **dict(self.kwargs, model_version=model_version_info))
383
+ yield Model.from_auth_helper(
384
+ model_id=self.id, **dict(self.kwargs, model_version=model_version_info))
381
385
 
382
386
  def predict(self, inputs: List[Input], inference_params: Dict = {}, output_config: Dict = {}):
383
387
  """Predicts the model based on the given inputs.
@@ -548,6 +552,17 @@ class Model(Lister, BaseClient):
548
552
  resources_pb2.OutputInfo(
549
553
  output_config=resources_pb2.OutputConfig(**output_config), params=params))
550
554
 
555
+ def _list_concepts(self) -> List[str]:
556
+ """Lists all the concepts for the model type.
557
+
558
+ Returns:
559
+ concepts (List): List of concepts for the model type.
560
+ """
561
+ request_data = dict(user_app_id=self.user_app_id)
562
+ all_concepts_infos = self.list_pages_generator(self.STUB.ListConcepts,
563
+ service_pb2.ListConceptsRequest, request_data)
564
+ return [concept_info['concept_id'] for concept_info in all_concepts_infos]
565
+
551
566
  def load_info(self) -> None:
552
567
  """Loads the model info."""
553
568
  request = service_pb2.GetModelRequest(
@@ -576,3 +591,169 @@ class Model(Lister, BaseClient):
576
591
  if hasattr(self.model_info, param)
577
592
  ]
578
593
  return f"Model Details: \n{', '.join(attribute_strings)}\n"
594
+
595
+ def list_evaluations(self) -> resources_pb2.EvalMetrics:
596
+ """List all eval_metrics of current model version
597
+
598
+ Raises:
599
+ Exception: Failed to call API
600
+
601
+ Returns:
602
+ resources_pb2.EvalMetrics
603
+ """
604
+ assert self.model_info.model_version.id, "Model version is empty. Please provide `model_version` as arguments or with a URL as the format '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
605
+ request = service_pb2.ListModelVersionEvaluationsRequest(
606
+ user_app_id=self.user_app_id,
607
+ model_id=self.id,
608
+ model_version_id=self.model_info.model_version.id)
609
+ response = self._grpc_request(self.STUB.ListModelVersionEvaluations, request)
610
+
611
+ if response.status.code != status_code_pb2.SUCCESS:
612
+ raise Exception(response.status)
613
+
614
+ return response.eval_metrics
615
+
616
+ def evaluate(self,
617
+ dataset_id: str,
618
+ dataset_app_id: str = None,
619
+ dataset_user_id: str = None,
620
+ eval_id: str = None,
621
+ extended_metrics: dict = None,
622
+ eval_info: dict = None) -> resources_pb2.EvalMetrics:
623
+ """ Run evaluation
624
+
625
+ Args:
626
+ dataset_id (str): Dataset Id.
627
+ dataset_app_id (str): App ID for cross app evaluation, leave it as None to use Model App ID. Default is None.
628
+ dataset_user_id (str): User ID for cross app evaluation, leave it as None to use Model User ID. Default is None.
629
+ eval_id (str): Specific ID for the evaluation. You must specify this parameter to either overwrite the result with the dataset ID or format your evaluation in an informative manner. If you don't, it will use random ID from system. Default is None.
630
+ extended_metrics (dict): user custom metrics result. Default is None.
631
+ eval_info (dict): custom eval info. Default is empty dict.
632
+
633
+ Return
634
+ eval_metrics
635
+
636
+ """
637
+ assert self.model_info.model_version.id, "Model version is empty. Please provide `model_version` as arguments or with a URL as the format '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
638
+ metrics = None
639
+ if isinstance(extended_metrics, dict):
640
+ metrics = Struct()
641
+ metrics.update(extended_metrics)
642
+ metrics = resources_pb2.ExtendedMetrics(user_metrics=metrics)
643
+
644
+ eval_info_params = None
645
+ if isinstance(eval_info, dict):
646
+ eval_info_params = Struct()
647
+ eval_info_params.update(eval_info)
648
+ eval_info_params = resources_pb2.EvalInfo(params=eval_info_params)
649
+
650
+ eval_metric = resources_pb2.EvalMetrics(
651
+ id=eval_id,
652
+ model=resources_pb2.Model(
653
+ id=self.id,
654
+ app_id=self.auth_helper.app_id,
655
+ user_id=self.auth_helper.user_id,
656
+ model_version=resources_pb2.ModelVersion(id=self.model_info.model_version.id),
657
+ ),
658
+ extended_metrics=metrics,
659
+ ground_truth_dataset=resources_pb2.Dataset(
660
+ id=dataset_id,
661
+ app_id=dataset_app_id or self.auth_helper.app_id,
662
+ user_id=dataset_user_id or self.auth_helper.user_id,
663
+ ),
664
+ eval_info=eval_info_params,
665
+ )
666
+ request = service_pb2.PostEvaluationsRequest(
667
+ user_app_id=self.user_app_id,
668
+ eval_metrics=[eval_metric],
669
+ )
670
+ response = self._grpc_request(self.STUB.PostEvaluations, request)
671
+ if response.status.code != status_code_pb2.SUCCESS:
672
+ raise Exception(response.status)
673
+ self.logger.info(
674
+ "\nModel evaluation in progress. Kindly allow a few minutes for completion. Processing time may vary based on the model and dataset sizes."
675
+ )
676
+
677
+ return response.eval_metrics
678
+
679
+ def get_eval_by_id(
680
+ self,
681
+ eval_id: str,
682
+ label_counts=False,
683
+ test_set=False,
684
+ binary_metrics=False,
685
+ confusion_matrix=False,
686
+ metrics_by_class=False,
687
+ metrics_by_area=False,
688
+ ) -> resources_pb2.EvalMetrics:
689
+ """Get detail eval_metrics by eval_id with extra metric fields
690
+
691
+ Args:
692
+ eval_id (str): eval id
693
+ label_counts (bool, optional): Set True to get label counts. Defaults to False.
694
+ test_set (bool, optional): Set True to get test set. Defaults to False.
695
+ binary_metrics (bool, optional): Set True to get binary metric. Defaults to False.
696
+ confusion_matrix (bool, optional): Set True to get confusion matrix. Defaults to False.
697
+ metrics_by_class (bool, optional): Set True to get metrics by class. Defaults to False.
698
+ metrics_by_area (bool, optional): Set True to get metrics by area. Defaults to False.
699
+
700
+ Raises:
701
+ Exception: Failed to call API
702
+
703
+ Returns:
704
+ resources_pb2.EvalMetrics: eval_metrics
705
+ """
706
+ request = service_pb2.GetEvaluationRequest(
707
+ user_app_id=self.user_app_id,
708
+ evaluation_id=eval_id,
709
+ fields=resources_pb2.FieldsValue(
710
+ label_counts=label_counts,
711
+ test_set=test_set,
712
+ binary_metrics=binary_metrics,
713
+ confusion_matrix=confusion_matrix,
714
+ metrics_by_class=metrics_by_class,
715
+ metrics_by_area=metrics_by_area,
716
+ ))
717
+ response = self._grpc_request(self.STUB.GetEvaluation, request)
718
+
719
+ if response.status.code != status_code_pb2.SUCCESS:
720
+ raise Exception(response.status)
721
+
722
+ return response.eval_metrics
723
+
724
+ def get_latest_eval(self,
725
+ label_counts=False,
726
+ test_set=False,
727
+ binary_metrics=False,
728
+ confusion_matrix=False,
729
+ metrics_by_class=False,
730
+ metrics_by_area=False) -> Union[resources_pb2.EvalMetrics, None]:
731
+ """
732
+ Run `get_eval_by_id` method with latest `eval_id`
733
+
734
+ Args:
735
+ label_counts (bool, optional): Set True to get label counts. Defaults to False.
736
+ test_set (bool, optional): Set True to get test set. Defaults to False.
737
+ binary_metrics (bool, optional): Set True to get binary metric. Defaults to False.
738
+ confusion_matrix (bool, optional): Set True to get confusion matrix. Defaults to False.
739
+ metrics_by_class (bool, optional): Set True to get metrics by class. Defaults to False.
740
+ metrics_by_area (bool, optional): Set True to get metrics by area. Defaults to False.
741
+
742
+ Returns:
743
+ eval_metric if model is evaluated otherwise None.
744
+
745
+ """
746
+
747
+ _latest = self.list_evaluations()[0]
748
+ result = None
749
+ if _latest.status.code == status_code_pb2.MODEL_EVALUATED:
750
+ result = self.get_eval_by_id(
751
+ eval_id=_latest.id,
752
+ label_counts=label_counts,
753
+ test_set=test_set,
754
+ binary_metrics=binary_metrics,
755
+ confusion_matrix=confusion_matrix,
756
+ metrics_by_class=metrics_by_class,
757
+ metrics_by_area=metrics_by_area)
758
+
759
+ return result