clarifai 10.1.0__py3-none-any.whl → 10.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clarifai/client/app.py +23 -43
- clarifai/client/base.py +44 -4
- clarifai/client/dataset.py +8 -12
- clarifai/client/input.py +29 -1
- clarifai/client/model.py +191 -10
- clarifai/client/module.py +7 -5
- clarifai/client/runner.py +3 -1
- clarifai/client/search.py +6 -3
- clarifai/client/user.py +14 -12
- clarifai/client/workflow.py +7 -4
- clarifai/datasets/upload/loaders/README.md +3 -4
- clarifai/datasets/upload/loaders/xview_detection.py +5 -5
- clarifai/rag/rag.py +25 -11
- clarifai/rag/utils.py +21 -6
- clarifai/utils/evaluation/__init__.py +427 -0
- clarifai/utils/evaluation/helpers.py +522 -0
- clarifai/utils/model_train.py +3 -1
- clarifai/versions.py +1 -1
- {clarifai-10.1.0.dist-info → clarifai-10.1.1.dist-info}/METADATA +32 -7
- {clarifai-10.1.0.dist-info → clarifai-10.1.1.dist-info}/RECORD +24 -23
- clarifai/datasets/upload/loaders/coco_segmentation.py +0 -98
- {clarifai-10.1.0.dist-info → clarifai-10.1.1.dist-info}/LICENSE +0 -0
- {clarifai-10.1.0.dist-info → clarifai-10.1.1.dist-info}/WHEEL +0 -0
- {clarifai-10.1.0.dist-info → clarifai-10.1.1.dist-info}/entry_points.txt +0 -0
- {clarifai-10.1.0.dist-info → clarifai-10.1.1.dist-info}/top_level.txt +0 -0
clarifai/client/app.py
CHANGED
```diff
@@ -32,6 +32,7 @@ class App(Lister, BaseClient):
                app_id: str = None,
                base_url: str = "https://api.clarifai.com",
                pat: str = None,
+               token: str = None,
                **kwargs):
     """Initializes an App object.
 
@@ -40,6 +41,7 @@ class App(Lister, BaseClient):
         app_id (str): The App ID for the App to interact with.
         base_url (str): Base API url. Default "https://api.clarifai.com"
         pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
+        token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
         **kwargs: Additional keyword arguments to be passed to the App.
             - name (str): The name of the app.
             - description (str): The description of the app.
@@ -52,7 +54,8 @@ class App(Lister, BaseClient):
     self.kwargs = {**kwargs, 'id': app_id}
     self.app_info = resources_pb2.App(**self.kwargs)
     self.logger = get_logger(logger_level="INFO", name=__name__)
-    BaseClient.__init__(
+    BaseClient.__init__(
+        self, user_id=self.user_id, app_id=self.id, base=base_url, pat=pat, token=token)
     Lister.__init__(self)
 
   def list_datasets(self, page_no: int = None,
@@ -85,7 +88,7 @@ class App(Lister, BaseClient):
     for dataset_info in all_datasets_info:
       if 'version' in list(dataset_info.keys()):
         del dataset_info['version']['metrics']
-      yield Dataset(
+      yield Dataset.from_auth_helper(auth=self.auth_helper, **dataset_info)
 
   def list_models(self,
                   filter_by: Dict[str, Any] = {},
@@ -126,7 +129,7 @@ class App(Lister, BaseClient):
       if only_in_app:
         if model_info['app_id'] != self.id:
           continue
-      yield Model(
+      yield Model.from_auth_helper(auth=self.auth_helper, **model_info)
 
   def list_workflows(self,
                      filter_by: Dict[str, Any] = {},
@@ -165,7 +168,7 @@ class App(Lister, BaseClient):
       if only_in_app:
         if workflow_info['app_id'] != self.id:
           continue
-      yield Workflow(
+      yield Workflow.from_auth_helper(auth=self.auth_helper, **workflow_info)
 
   def list_modules(self,
                    filter_by: Dict[str, Any] = {},
@@ -204,7 +207,7 @@ class App(Lister, BaseClient):
       if only_in_app:
         if module_info['app_id'] != self.id:
           continue
-      yield Module(
+      yield Module.from_auth_helper(auth=self.auth_helper, **module_info)
 
   def list_installed_module_versions(self,
                                      filter_by: Dict[str, Any] = {},
@@ -239,11 +242,8 @@ class App(Lister, BaseClient):
     for imv_info in all_imv_infos:
       del imv_info['deploy_url']
       del imv_info['installed_module_version_id']  # TODO: remove this after the backend fix
-      yield Module(
-          module_id=imv_info['module_version']['module_id'],
-          base_url=self.base,
-          pat=self.pat,
-          **imv_info)
+      yield Module.from_auth_helper(
+          auth=self.auth_helper, module_id=imv_info['module_version']['module_id'], **imv_info)
 
   def list_concepts(self, page_no: int = None,
                     per_page: int = None) -> Generator[Concept, None, None]:
@@ -308,14 +308,8 @@ class App(Lister, BaseClient):
     if response.status.code != status_code_pb2.SUCCESS:
       raise Exception(response.status)
     self.logger.info("\nDataset created\n%s", response.status)
-    kwargs.update({
-        'app_id': self.id,
-        'user_id': self.user_id,
-        'base_url': self.base,
-        'pat': self.pat
-    })
 
-    return Dataset(dataset_id=dataset_id, **kwargs)
+    return Dataset.from_auth_helper(self.auth_helper, dataset_id=dataset_id, **kwargs)
 
   def create_model(self, model_id: str, **kwargs) -> Model:
     """Creates a model for the app.
@@ -339,14 +333,11 @@ class App(Lister, BaseClient):
       raise Exception(response.status)
     self.logger.info("\nModel created\n%s", response.status)
     kwargs.update({
-        '
-        'user_id': self.user_id,
+        'model_id': model_id,
         'model_type_id': response.model.model_type_id,
-        'base_url': self.base,
-        'pat': self.pat
     })
 
-    return Model(
+    return Model.from_auth_helper(auth=self.auth_helper, **kwargs)
 
   def create_workflow(self,
                       config_filepath: str,
@@ -436,9 +427,8 @@ class App(Lister, BaseClient):
     display_workflow_tree(dict_response["workflows"][0]["nodes"])
     kwargs = self.process_response_keys(dict_response[list(dict_response.keys())[1]][0],
                                         "workflow")
-    kwargs.update({'base_url': self.base, 'pat': self.pat})
 
-    return Workflow(**kwargs)
+    return Workflow.from_auth_helper(auth=self.auth_helper, **kwargs)
 
   def create_module(self, module_id: str, description: str, **kwargs) -> Module:
     """Creates a module for the app.
@@ -464,14 +454,8 @@ class App(Lister, BaseClient):
     if response.status.code != status_code_pb2.SUCCESS:
       raise Exception(response.status)
     self.logger.info("\nModule created\n%s", response.status)
-    kwargs.update({
-        'app_id': self.id,
-        'user_id': self.user_id,
-        'base_url': self.base,
-        'pat': self.pat
-    })
 
-    return Module(module_id=module_id, **kwargs)
+    return Module.from_auth_helper(auth=self.auth_helper, module_id=module_id, **kwargs)
 
   def dataset(self, dataset_id: str, **kwargs) -> Dataset:
     """Returns a Dataset object for the existing dataset ID.
@@ -496,8 +480,7 @@ class App(Lister, BaseClient):
     kwargs = self.process_response_keys(dict_response[list(dict_response.keys())[1]],
                                         list(dict_response.keys())[1])
     kwargs['version'] = response.dataset.version if response.dataset.version else None
-
-    return Dataset(**kwargs)
+    return Dataset.from_auth_helper(auth=self.auth_helper, **kwargs)
 
   def model(self, model_id: str, model_version_id: str = "", **kwargs) -> Model:
     """Returns a Model object for the existing model ID.
@@ -532,9 +515,8 @@ class App(Lister, BaseClient):
     kwargs = self.process_response_keys(dict_response['model'], 'model')
     kwargs[
         'model_version'] = response.model.model_version if response.model.model_version else None
-    kwargs.update({'base_url': self.base, 'pat': self.pat})
 
-    return Model(**kwargs)
+    return Model.from_auth_helper(self.auth_helper, **kwargs)
 
   def workflow(self, workflow_id: str, **kwargs) -> Workflow:
     """Returns a workflow object for the existing workflow ID.
@@ -558,9 +540,8 @@ class App(Lister, BaseClient):
     dict_response = MessageToDict(response, preserving_proto_field_name=True)
     kwargs = self.process_response_keys(dict_response[list(dict_response.keys())[1]],
                                         list(dict_response.keys())[1])
-    kwargs.update({'base_url': self.base, 'pat': self.pat})
 
-    return Workflow(**kwargs)
+    return Workflow.from_auth_helper(auth=self.auth_helper, **kwargs)
 
   def module(self, module_id: str, module_version_id: str = "", **kwargs) -> Module:
     """Returns a Module object for the existing module ID.
@@ -585,9 +566,8 @@ class App(Lister, BaseClient):
       raise Exception(response.status)
     dict_response = MessageToDict(response, preserving_proto_field_name=True)
     kwargs = self.process_response_keys(dict_response['module'], 'module')
-    kwargs.update({'base_url': self.base, 'pat': self.pat})
 
-    return Module(**kwargs)
+    return Module.from_auth_helper(auth=self.auth_helper, **kwargs)
 
   def inputs(self,):
     """Returns an Input object.
@@ -595,7 +575,7 @@ class App(Lister, BaseClient):
     Returns:
         Inputs: An input object.
     """
-    return Inputs(self.
+    return Inputs.from_auth_helper(self.auth_helper)
 
   def delete_dataset(self, dataset_id: str) -> None:
     """Deletes an dataset for the user.
@@ -684,9 +664,9 @@ class App(Lister, BaseClient):
         >>> app = App(app_id="app_id", user_id="user_id")
         >>> search_client = app.search(top_k=12, metric="euclidean")
     """
-
-
-    return Search(
+    kwargs.get("user_id", self.user_app_id.user_id)
+    kwargs.get("app_id", self.user_app_id.app_id)
+    return Search.from_auth_helper(auth=self.auth_helper, **kwargs)
 
   def __getattr__(self, name):
     return getattr(self.app_info, name)
```
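The net effect of the `app.py` changes: `App` now accepts a session token alongside a PAT, and the objects it creates or lists are built through `from_auth_helper`, so they reuse the parent client's credentials and base URL instead of having `pat`/`base_url` re-passed at every call site. A minimal usage sketch (the IDs below are placeholders, not taken from the diff):

```python
import os

from clarifai.client.app import App

# Authenticate with either a PAT or the new session-token path; the token
# kwarg mirrors the CLARIFAI_SESSION_TOKEN env var added in this release.
app = App(
    user_id="my-user",                                   # placeholder
    app_id="my-app",                                     # placeholder
    token=os.environ.get("CLARIFAI_SESSION_TOKEN"),
)

# Child objects now inherit app.auth_helper, so no extra pat/base_url kwargs.
dataset = app.create_dataset(dataset_id="my-dataset")    # placeholder ID
model = app.model(model_id="my-model")                   # placeholder ID
```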
clarifai/client/base.py
CHANGED
```diff
@@ -7,7 +7,7 @@ from google.protobuf.wrappers_pb2 import BoolValue
 
 from clarifai.client.auth import create_stub
 from clarifai.client.auth.helper import ClarifaiAuthHelper
-from clarifai.errors import ApiError
+from clarifai.errors import ApiError, UserError
 from clarifai.utils.misc import get_from_dict_or_env
 
 
@@ -19,9 +19,11 @@ class BaseClient:
       - user_id (str): A user ID for authentication.
       - app_id (str): An app ID for the application to interact with.
       - pat (str): A personal access token for authentication.
+      - token (str): A session token for authentication. Accepts either a session token or a pat.
       - base (str): The base URL for the API endpoint. Defaults to 'https://api.clarifai.com'.
       - ui (str): The URL for the UI. Defaults to 'https://clarifai.com'.
 
+
   Attributes:
       auth_helper (ClarifaiAuthHelper): An instance of ClarifaiAuthHelper for authentication.
       STUB (Stub): The gRPC Stub object for API interaction.
@@ -31,15 +33,53 @@
   """
 
   def __init__(self, **kwargs):
-    pat =
-
+    token, pat = "", ""
+    try:
+      pat = get_from_dict_or_env(key="pat", env_key="CLARIFAI_PAT", **kwargs)
+    except UserError:
+      token = get_from_dict_or_env(key="token", env_key="CLARIFAI_SESSION_TOKEN", **kwargs)
+    finally:
+      assert token or pat, Exception(
+          "Need 'pat' or 'token' in args or use one of the CLARIFAI_PAT or CLARIFAI_SESSION_TOKEN env vars"
+      )
+      kwargs.update({'token': token, 'pat': pat})
+
     self.auth_helper = ClarifaiAuthHelper(**kwargs, validate=False)
     self.STUB = create_stub(self.auth_helper)
     self.metadata = self.auth_helper.metadata
     self.pat = self.auth_helper.pat
+    self.token = self.auth_helper._token
     self.user_app_id = self.auth_helper.get_user_app_id_proto()
     self.base = self.auth_helper.base
 
+  @classmethod
+  def from_auth_helper(cls, auth: ClarifaiAuthHelper, **kwargs):
+    default_kwargs = {
+        "user_id": kwargs.get("user_id", None) or auth.user_id,
+        "app_id": kwargs.get("app_id", None) or auth.app_id,
+        "pat": kwargs.get("pat", None) or auth.pat,
+        "token": kwargs.get("token", None) or auth._token,
+    }
+    _base = kwargs.get("base", None) or auth.base
+    _clss = cls.__mro__[0]
+    if _clss == BaseClient:
+      kwargs = {
+          **default_kwargs,
+          "base": _base,  # Baseclient uses `base`
+          "ui": kwargs.get("ui", None) or auth.ui
+      }
+    else:
+      # Remove user_id and app_id if a custom URL is provided
+      if kwargs.get("url"):
+        default_kwargs.pop("user_id", "")
+        default_kwargs.pop("app_id", "")
+      # Remove app_id if the class name contains "Runner"
+      if 'Runner' in _clss.__name__:
+        default_kwargs.pop("app_id", "")
+      kwargs.update({**default_kwargs, "base_url": _base})
+
+    return cls(**kwargs)
+
   def _grpc_request(self, method: Callable, argument: Any):
     """Makes a gRPC request to the API.
 
@@ -52,7 +92,7 @@ class BaseClient:
     """
 
     try:
-      res = method(argument)
+      res = method(argument, metadata=self.auth_helper.metadata)
       # MessageToDict(res)  TODO global debug logger
       return res
     except ApiError:
```
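`from_auth_helper` is the constructor path the refactored call sites above now use: it copies `user_id`, `app_id`, `pat`, and `token` from an existing `ClarifaiAuthHelper` unless they are overridden, and forwards the base URL as `base` for `BaseClient` itself or as `base_url` for subclasses. In `__init__`, the PAT is resolved first; only when neither the `pat` kwarg nor `CLARIFAI_PAT` is present does the client fall back to `token`/`CLARIFAI_SESSION_TOKEN`. A rough sketch of deriving one client from another's helper (placeholder IDs, credentials assumed to come from the env vars):

```python
from clarifai.client.user import User
from clarifai.client.app import App

# A client built directly resolves pat/token from kwargs or the
# CLARIFAI_PAT / CLARIFAI_SESSION_TOKEN env vars and keeps a ClarifaiAuthHelper.
user = User(user_id="my-user")                                        # placeholder ID

# from_auth_helper reuses that helper: user_id/app_id/pat/token come from `auth`
# unless overridden, and the base URL is forwarded as `base_url` for subclasses.
app = App.from_auth_helper(auth=user.auth_helper, app_id="my-app")    # placeholder ID
```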
clarifai/client/dataset.py
CHANGED
```diff
@@ -43,6 +43,7 @@ class Dataset(Lister, BaseClient):
                dataset_id: str = None,
                base_url: str = "https://api.clarifai.com",
                pat: str = None,
+               token: str = None,
                **kwargs):
     """Initializes a Dataset object.
 
@@ -51,6 +52,7 @@ class Dataset(Lister, BaseClient):
         dataset_id (str): The Dataset ID within the App to interact with.
         base_url (str): Base API url. Default "https://api.clarifai.com"
         pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
+        token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
         **kwargs: Additional keyword arguments to be passed to the Dataset.
     """
     if url and dataset_id:
@@ -66,9 +68,10 @@ class Dataset(Lister, BaseClient):
     self.max_retires = 10
     self.batch_size = 128  # limit max protos in a req
     self.task = None  # Upload dataset type
-    self.input_object = Inputs(user_id=self.user_id, app_id=self.app_id, pat=pat)
+    self.input_object = Inputs(user_id=self.user_id, app_id=self.app_id, pat=pat, token=token)
     self.logger = get_logger(logger_level="INFO", name=__name__)
-    BaseClient.__init__(
+    BaseClient.__init__(
+        self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
     Lister.__init__(self)
 
   def create_version(self, **kwargs) -> 'Dataset':
@@ -98,13 +101,10 @@ class Dataset(Lister, BaseClient):
     self.logger.info("\nDataset Version created\n%s", response.status)
     kwargs.update({
         'dataset_id': self.id,
-        'app_id': self.app_id,
-        'user_id': self.user_id,
         'version': response.dataset_versions[0],
-        'base_url': self.base,
-        'pat': self.pat
     })
-
+
+    return Dataset.from_auth_helper(self.auth_helper, **kwargs)
 
   def delete_version(self, version_id: str) -> None:
     """Deletes a dataset version for the Dataset.
@@ -162,13 +162,9 @@ class Dataset(Lister, BaseClient):
       del dataset_version_info['metrics']
       kwargs = {
           'dataset_id': self.id,
-          'app_id': self.app_id,
-          'user_id': self.user_id,
           'version': resources_pb2.DatasetVersion(**dataset_version_info),
-          'base_url': self.base,
-          'pat': self.pat
       }
-      yield Dataset(**kwargs)
+      yield Dataset.from_auth_helper(self.auth_helper, **kwargs)
 
   def _concurrent_annot_upload(self, annots: List[List[resources_pb2.Annotation]]
                                ) -> Union[List[resources_pb2.Annotation], List[None]]:
```
clarifai/client/input.py
CHANGED
```diff
@@ -32,6 +32,7 @@ class Inputs(Lister, BaseClient):
                logger_level: str = "INFO",
                base_url: str = "https://api.clarifai.com",
                pat: str = None,
+               token: str = None,
                **kwargs):
     """Initializes an Input object.
 
@@ -39,6 +40,8 @@ class Inputs(Lister, BaseClient):
         user_id (str): A user ID for authentication.
         app_id (str): An app ID for the application to interact with.
         base_url (str): Base API url. Default "https://api.clarifai.com"
+        pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
+        token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
         **kwargs: Additional keyword arguments to be passed to the Input
     """
     self.user_id = user_id
@@ -46,7 +49,8 @@ class Inputs(Lister, BaseClient):
     self.kwargs = {**kwargs}
     self.input_info = resources_pb2.Input(**self.kwargs)
     self.logger = get_logger(logger_level=logger_level, name=__name__)
-    BaseClient.__init__(
+    BaseClient.__init__(
+        self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
     Lister.__init__(self)
 
   @staticmethod
@@ -670,6 +674,30 @@ class Inputs(Lister, BaseClient):
 
     return input_job_id, response
 
+  def patch_inputs(self, inputs: List[Input], action: str = 'merge') -> str:
+    """Patch list of input objects to the app.
+
+    Args:
+        inputs (list): List of input objects to upload.
+        action (str): Action to perform on the input. Options: 'merge', 'overwrite', 'remove'.
+
+    Returns:
+        response: Response from the grpc request.
+    """
+    if not isinstance(inputs, list):
+      raise UserError("inputs must be a list of Input objects")
+    uuid.uuid4().hex  # generate a unique id for this job
+    request = service_pb2.PatchInputsRequest(
+        user_app_id=self.user_app_id, inputs=inputs, action=action)
+    response = self._grpc_request(self.STUB.PatchInputs, request)
+    if response.status.code != status_code_pb2.SUCCESS:
+      try:
+        self.logger.warning(f"Patch inputs failed, status: {response.annotations[0].status}")
+      except Exception:
+        self.logger.warning(f"Patch inputs failed, status: {response.status.details}")
+
+    self.logger.info("\nPatch Inputs Successful\n%s", response.status)
+
   def upload_annotations(self, batch_annot: List[resources_pb2.Annotation], show_log: bool = True
                          ) -> Union[List[resources_pb2.Annotation], List[None]]:
     """Upload image annotations to app.
```
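The new `patch_inputs` method wraps the `PatchInputs` gRPC endpoint and takes already-constructed `Input` protos plus an `action`. A hedged sketch of patching metadata onto an existing input (the input ID and metadata fields are illustrative only, and credentials are assumed to come from `CLARIFAI_PAT` or `CLARIFAI_SESSION_TOKEN`):

```python
from clarifai_grpc.grpc.api import resources_pb2
from google.protobuf.struct_pb2 import Struct

from clarifai.client.input import Inputs

inputs_client = Inputs(user_id="my-user", app_id="my-app")   # placeholder IDs

# Build a bare Input proto carrying only the fields to change on an input
# that already exists in the app (the ID here is a placeholder).
metadata = Struct()
metadata.update({"reviewed": True})
patched_input = resources_pb2.Input(
    id="existing-input-id",
    data=resources_pb2.Data(metadata=metadata),
)

# 'merge' folds the new fields into the stored input; 'overwrite' and
# 'remove' are the other actions accepted by the new method.
inputs_client.patch_inputs([patched_input], action="merge")
```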
clarifai/client/model.py
CHANGED
```diff
@@ -1,6 +1,6 @@
 import os
 import time
-from typing import Any, Dict, Generator, List
+from typing import Any, Dict, Generator, List, Union
 
 import requests
 import yaml
@@ -32,6 +32,7 @@ class Model(Lister, BaseClient):
                model_version: Dict = {'id': ""},
                base_url: str = "https://api.clarifai.com",
                pat: str = None,
+               token: str = None,
                **kwargs):
     """Initializes a Model object.
 
@@ -41,6 +42,7 @@ class Model(Lister, BaseClient):
         model_version (dict): The Model Version to interact with.
         base_url (str): Base API url. Default "https://api.clarifai.com"
         pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
+        token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
         **kwargs: Additional keyword arguments to be passed to the Model.
     """
     if url and model_id:
@@ -55,7 +57,8 @@ class Model(Lister, BaseClient):
     self.model_info = resources_pb2.Model(**self.kwargs)
     self.logger = get_logger(logger_level="INFO", name=__name__)
     self.training_params = {}
-    BaseClient.__init__(
+    BaseClient.__init__(
+        self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
     Lister.__init__(self)
 
   def list_training_templates(self) -> List[str]:
@@ -212,6 +215,8 @@ class Model(Lister, BaseClient):
         >>> model_params = model.get_params(template='template', yaml_file='model_params.yaml')
         >>> model.train('model_params.yaml')
     """
+    if not self.model_info.model_type_id:
+      self.load_info()
     if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
       raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
     if not yaml_file and len(self.training_params) == 0:
@@ -222,8 +227,10 @@ class Model(Lister, BaseClient):
       params_dict = yaml.safe_load(file)
     else:
       params_dict = self.training_params
-
-
+    #getting all the concepts for the model type
+    if self.model_info.model_type_id not in ["clusterer", "text-to-text"]:
+      concepts = self._list_concepts()
+    train_dict = params_parser(params_dict, concepts)
     request = service_pb2.PostModelVersionsRequest(
         user_app_id=self.user_app_id,
         model_id=self.id,
@@ -331,7 +338,7 @@ class Model(Lister, BaseClient):
     dict_response = MessageToDict(response, preserving_proto_field_name=True)
     kwargs = self.process_response_keys(dict_response['model'], 'model')
 
-    return Model(base_url=self.base, pat=self.pat, **kwargs)
+    return Model(base_url=self.base, pat=self.pat, token=self.token, **kwargs)
 
   def list_versions(self, page_no: int = None,
                     per_page: int = None) -> Generator['Model', None, None]:
@@ -373,11 +380,8 @@ class Model(Lister, BaseClient):
           del model_version_info['train_info']['dataset']['version']['metrics']
         except KeyError:
           pass
-      yield Model(
-          model_id=self.id,
-          base_url=self.base,
-          pat=self.pat,
-          **dict(self.kwargs, model_version=model_version_info))
+      yield Model.from_auth_helper(
+          model_id=self.id, **dict(self.kwargs, model_version=model_version_info))
 
   def predict(self, inputs: List[Input], inference_params: Dict = {}, output_config: Dict = {}):
     """Predicts the model based on the given inputs.
@@ -548,6 +552,17 @@ class Model(Lister, BaseClient):
         resources_pb2.OutputInfo(
             output_config=resources_pb2.OutputConfig(**output_config), params=params))
 
+  def _list_concepts(self) -> List[str]:
+    """Lists all the concepts for the model type.
+
+    Returns:
+        concepts (List): List of concepts for the model type.
+    """
+    request_data = dict(user_app_id=self.user_app_id)
+    all_concepts_infos = self.list_pages_generator(self.STUB.ListConcepts,
+                                                   service_pb2.ListConceptsRequest, request_data)
+    return [concept_info['concept_id'] for concept_info in all_concepts_infos]
+
   def load_info(self) -> None:
     """Loads the model info."""
     request = service_pb2.GetModelRequest(
@@ -576,3 +591,169 @@ class Model(Lister, BaseClient):
       if hasattr(self.model_info, param)
     ]
     return f"Model Details: \n{', '.join(attribute_strings)}\n"
+
+  def list_evaluations(self) -> resources_pb2.EvalMetrics:
+    """List all eval_metrics of current model version
+
+    Raises:
+        Exception: Failed to call API
+
+    Returns:
+        resources_pb2.EvalMetrics
+    """
+    assert self.model_info.model_version.id, "Model version is empty. Please provide `model_version` as arguments or with a URL as the format '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
+    request = service_pb2.ListModelVersionEvaluationsRequest(
+        user_app_id=self.user_app_id,
+        model_id=self.id,
+        model_version_id=self.model_info.model_version.id)
+    response = self._grpc_request(self.STUB.ListModelVersionEvaluations, request)
+
+    if response.status.code != status_code_pb2.SUCCESS:
+      raise Exception(response.status)
+
+    return response.eval_metrics
+
+  def evaluate(self,
+               dataset_id: str,
+               dataset_app_id: str = None,
+               dataset_user_id: str = None,
+               eval_id: str = None,
+               extended_metrics: dict = None,
+               eval_info: dict = None) -> resources_pb2.EvalMetrics:
+    """Run evaluation
+
+    Args:
+      dataset_id (str): Dataset Id.
+      dataset_app_id (str): App ID for cross app evaluation, leave it as None to use Model App ID. Default is None.
+      dataset_user_id (str): User ID for cross app evaluation, leave it as None to use Model User ID. Default is None.
+      eval_id (str): Specific ID for the evaluation. You must specify this parameter to either overwrite the result with the dataset ID or format your evaluation in an informative manner. If you don't, it will use random ID from system. Default is None.
+      extended_metrics (dict): user custom metrics result. Default is None.
+      eval_info (dict): custom eval info. Default is empty dict.
+
+    Return
+      eval_metrics
+
+    """
+    assert self.model_info.model_version.id, "Model version is empty. Please provide `model_version` as arguments or with a URL as the format '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
+    metrics = None
+    if isinstance(extended_metrics, dict):
+      metrics = Struct()
+      metrics.update(extended_metrics)
+      metrics = resources_pb2.ExtendedMetrics(user_metrics=metrics)
+
+    eval_info_params = None
+    if isinstance(eval_info, dict):
+      eval_info_params = Struct()
+      eval_info_params.update(eval_info)
+      eval_info_params = resources_pb2.EvalInfo(params=eval_info_params)
+
+    eval_metric = resources_pb2.EvalMetrics(
+        id=eval_id,
+        model=resources_pb2.Model(
+            id=self.id,
+            app_id=self.auth_helper.app_id,
+            user_id=self.auth_helper.user_id,
+            model_version=resources_pb2.ModelVersion(id=self.model_info.model_version.id),
+        ),
+        extended_metrics=metrics,
+        ground_truth_dataset=resources_pb2.Dataset(
+            id=dataset_id,
+            app_id=dataset_app_id or self.auth_helper.app_id,
+            user_id=dataset_user_id or self.auth_helper.user_id,
+        ),
+        eval_info=eval_info_params,
+    )
+    request = service_pb2.PostEvaluationsRequest(
+        user_app_id=self.user_app_id,
+        eval_metrics=[eval_metric],
+    )
+    response = self._grpc_request(self.STUB.PostEvaluations, request)
+    if response.status.code != status_code_pb2.SUCCESS:
+      raise Exception(response.status)
+    self.logger.info(
+        "\nModel evaluation in progress. Kindly allow a few minutes for completion. Processing time may vary based on the model and dataset sizes."
+    )
+
+    return response.eval_metrics
+
+  def get_eval_by_id(
+      self,
+      eval_id: str,
+      label_counts=False,
+      test_set=False,
+      binary_metrics=False,
+      confusion_matrix=False,
+      metrics_by_class=False,
+      metrics_by_area=False,
+  ) -> resources_pb2.EvalMetrics:
+    """Get detail eval_metrics by eval_id with extra metric fields
+
+    Args:
+      eval_id (str): eval id
+      label_counts (bool, optional): Set True to get label counts. Defaults to False.
+      test_set (bool, optional): Set True to get test set. Defaults to False.
+      binary_metrics (bool, optional): Set True to get binary metric. Defaults to False.
+      confusion_matrix (bool, optional): Set True to get confusion matrix. Defaults to False.
+      metrics_by_class (bool, optional): Set True to get metrics by class. Defaults to False.
+      metrics_by_area (bool, optional): Set True to get metrics by area. Defaults to False.
+
+    Raises:
+      Exception: Failed to call API
+
+    Returns:
+      resources_pb2.EvalMetrics: eval_metrics
+    """
+    request = service_pb2.GetEvaluationRequest(
+        user_app_id=self.user_app_id,
+        evaluation_id=eval_id,
+        fields=resources_pb2.FieldsValue(
+            label_counts=label_counts,
+            test_set=test_set,
+            binary_metrics=binary_metrics,
+            confusion_matrix=confusion_matrix,
+            metrics_by_class=metrics_by_class,
+            metrics_by_area=metrics_by_area,
+        ))
+    response = self._grpc_request(self.STUB.GetEvaluation, request)
+
+    if response.status.code != status_code_pb2.SUCCESS:
+      raise Exception(response.status)
+
+    return response.eval_metrics
+
+  def get_latest_eval(self,
+                      label_counts=False,
+                      test_set=False,
+                      binary_metrics=False,
+                      confusion_matrix=False,
+                      metrics_by_class=False,
+                      metrics_by_area=False) -> Union[resources_pb2.EvalMetrics, None]:
+    """
+    Run `get_eval_by_id` method with latest `eval_id`
+
+    Args:
+      label_counts (bool, optional): Set True to get label counts. Defaults to False.
+      test_set (bool, optional): Set True to get test set. Defaults to False.
+      binary_metrics (bool, optional): Set True to get binary metric. Defaults to False.
+      confusion_matrix (bool, optional): Set True to get confusion matrix. Defaults to False.
+      metrics_by_class (bool, optional): Set True to get metrics by class. Defaults to False.
+      metrics_by_area (bool, optional): Set True to get metrics by area. Defaults to False.
+
+    Returns:
+      eval_metric if model is evaluated otherwise None.
+
+    """
+
+    _latest = self.list_evaluations()[0]
+    result = None
+    if _latest.status.code == status_code_pb2.MODEL_EVALUATED:
+      result = self.get_eval_by_id(
+          eval_id=_latest.id,
+          label_counts=label_counts,
+          test_set=test_set,
+          binary_metrics=binary_metrics,
+          confusion_matrix=confusion_matrix,
+          metrics_by_class=metrics_by_class,
+          metrics_by_area=metrics_by_area)
+
+    return result
```