clarifai 10.2.0__py3-none-any.whl → 10.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clarifai/client/app.py +21 -10
- clarifai/client/auth/helper.py +12 -2
- clarifai/client/base.py +14 -4
- clarifai/client/dataset.py +59 -8
- clarifai/client/input.py +15 -2
- clarifai/client/model.py +201 -21
- clarifai/client/module.py +9 -1
- clarifai/client/search.py +10 -2
- clarifai/client/user.py +22 -14
- clarifai/client/workflow.py +10 -2
- clarifai/constants/input.py +1 -0
- clarifai/datasets/export/inputs_annotations.py +18 -12
- clarifai/utils/evaluation/__init__.py +2 -426
- clarifai/utils/evaluation/main.py +426 -0
- clarifai/utils/evaluation/testset_annotation_parser.py +150 -0
- clarifai/utils/misc.py +4 -10
- clarifai/utils/model_train.py +6 -7
- clarifai/versions.py +1 -1
- {clarifai-10.2.0.dist-info → clarifai-10.3.0.dist-info}/METADATA +23 -15
- {clarifai-10.2.0.dist-info → clarifai-10.3.0.dist-info}/RECORD +24 -25
- {clarifai-10.2.0.dist-info → clarifai-10.3.0.dist-info}/WHEEL +1 -1
- clarifai/client/runner.py +0 -234
- clarifai/runners/__init__.py +0 -0
- clarifai/runners/example.py +0 -40
- clarifai/runners/example_llama2.py +0 -81
- {clarifai-10.2.0.dist-info → clarifai-10.3.0.dist-info}/LICENSE +0 -0
- {clarifai-10.2.0.dist-info → clarifai-10.3.0.dist-info}/entry_points.txt +0 -0
- {clarifai-10.2.0.dist-info → clarifai-10.3.0.dist-info}/top_level.txt +0 -0
clarifai/client/model.py
CHANGED
@@ -1,7 +1,8 @@
 import os
 import time
-from typing import Any, Dict, Generator, List, Union
+from typing import Any, Dict, Generator, List, Tuple, Union
 
+import numpy as np
 import requests
 import yaml
 from clarifai_grpc.grpc.api import resources_pb2, service_pb2
@@ -12,6 +13,7 @@ from google.protobuf.struct_pb2 import Struct
 from tqdm import tqdm
 
 from clarifai.client.base import BaseClient
+from clarifai.client.dataset import Dataset
 from clarifai.client.input import Inputs
 from clarifai.client.lister import Lister
 from clarifai.constants.model import MAX_MODEL_PREDICT_INPUTS, TRAINABLE_MODEL_TYPES
@@ -34,6 +36,7 @@ class Model(Lister, BaseClient):
                base_url: str = "https://api.clarifai.com",
                pat: str = None,
                token: str = None,
+               root_certificates_path: str = None,
                **kwargs):
     """Initializes a Model object.
 
@@ -44,6 +47,7 @@ class Model(Lister, BaseClient):
       base_url (str): Base API url. Default "https://api.clarifai.com"
       pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
       token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
+      root_certificates_path (str): Path to the SSL root certificates file, used to establish secure gRPC connections.
       **kwargs: Additional keyword arguments to be passed to the Model.
     """
     if url and model_id:
@@ -59,7 +63,13 @@ class Model(Lister, BaseClient):
     self.logger = get_logger(logger_level="INFO", name=__name__)
     self.training_params = {}
     BaseClient.__init__(
-        self,
+        self,
+        user_id=self.user_id,
+        app_id=self.app_id,
+        base=base_url,
+        pat=pat,
+        token=token,
+        root_certificates_path=root_certificates_path)
     Lister.__init__(self)
 
   def list_training_templates(self) -> List[str]:
@@ -243,7 +253,7 @@ class Model(Lister, BaseClient):
 
     return response.model.model_version.id
 
-  def training_status(self, version_id: str, training_logs: bool = False) -> Dict[str, str]:
+  def training_status(self, version_id: str = None, training_logs: bool = False) -> Dict[str, str]:
     """Get the training status for the model version. Also stores training logs
 
     Args:
@@ -258,19 +268,20 @@ class Model(Lister, BaseClient):
         >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
         >>> model.training_status(version_id='version_id',training_logs=True)
     """
+    if not version_id and not self.model_info.model_version.id:
+      raise UserError(
+          "Model version ID is missing. Please provide a `model_version` with a valid `id` as an argument or as a URL in the following format: '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
+      )
+
+    if not self.model_info.model_type_id or not self.model_info.model_version.train_log:
+      self.load_info()
     if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
       raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
 
-    request = service_pb2.GetModelVersionRequest(
-        user_app_id=self.user_app_id, model_id=self.id, version_id=version_id)
-    response = self._grpc_request(self.STUB.GetModelVersion, request)
-    if response.status.code != status_code_pb2.SUCCESS:
-      raise Exception(response.status)
-
     if training_logs:
       try:
-        if
-        log_response = requests.get(
+        if self.model_info.model_version.train_log:
+          log_response = requests.get(self.model_info.model_version.train_log)
         log_response.raise_for_status()  # Check for any HTTP errors
         with open(version_id + '.log', 'wb') as file:
           for chunk in log_response.iter_content(chunk_size=4096):  # 4KB
@@ -280,7 +291,7 @@ class Model(Lister, BaseClient):
       except requests.exceptions.RequestException as e:
         raise Exception(f"An error occurred while getting training logs: {e}")
 
-    return
+    return self.model_info.model_version.status
 
   def delete_version(self, version_id: str) -> None:
     """Deletes a model version for the Model.
@@ -407,7 +418,7 @@ class Model(Lister, BaseClient):
         model=self.model_info)
 
     start_time = time.time()
-    backoff_iterator = BackoffIterator()
+    backoff_iterator = BackoffIterator(10)
     while True:
       response = self._grpc_request(self.STUB.PostModelOutputs, request)
 
@@ -617,18 +628,22 @@ class Model(Lister, BaseClient):
     return response.eval_metrics
 
   def evaluate(self,
-
+               dataset: Dataset = None,
+               dataset_id: str = None,
                dataset_app_id: str = None,
                dataset_user_id: str = None,
+               dataset_version_id: str = None,
                eval_id: str = None,
                extended_metrics: dict = None,
                eval_info: dict = None) -> resources_pb2.EvalMetrics:
     """ Run evaluation
 
     Args:
-
+      dataset (Dataset): If Clarifai Dataset is set, it will ignore other arguments prefixed with 'dataset_'.
+      dataset_id (str): Dataset Id. Default is None.
       dataset_app_id (str): App ID for cross app evaluation, leave it as None to use Model App ID. Default is None.
       dataset_user_id (str): User ID for cross app evaluation, leave it as None to use Model User ID. Default is None.
+      dataset_version_id (str): Dataset version Id. Default is None.
       eval_id (str): Specific ID for the evaluation. You must specify this parameter to either overwrite the result with the dataset ID or format your evaluation in an informative manner. If you don't, it will use random ID from system. Default is None.
       extended_metrics (dict): user custom metrics result. Default is None.
       eval_info (dict): custom eval info. Default is empty dict.
@@ -638,6 +653,23 @@ class Model(Lister, BaseClient):
 
     """
     assert self.model_info.model_version.id, "Model version is empty. Please provide `model_version` as arguments or with a URL as the format '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
+
+    if dataset:
+      self.logger.info("Using dataset, ignore other arguments prefixed with 'dataset_'")
+      dataset_id = dataset.id
+      dataset_app_id = dataset.app_id
+      dataset_user_id = dataset.user_id
+      dataset_version_id = dataset.version.id
+    else:
+      self.logger.warning(
+          "Arguments prefixed with `dataset_` will be removed soon, please use dataset")
+
+    gt_dataset = resources_pb2.Dataset(
+        id=dataset_id,
+        app_id=dataset_app_id or self.auth_helper.app_id,
+        user_id=dataset_user_id or self.auth_helper.user_id,
+        version=resources_pb2.DatasetVersion(id=dataset_version_id))
+
     metrics = None
     if isinstance(extended_metrics, dict):
       metrics = Struct()
@@ -659,11 +691,7 @@ class Model(Lister, BaseClient):
             model_version=resources_pb2.ModelVersion(id=self.model_info.model_version.id),
         ),
         extended_metrics=metrics,
-        ground_truth_dataset=
-            id=dataset_id,
-            app_id=dataset_app_id or self.auth_helper.app_id,
-            user_id=dataset_user_id or self.auth_helper.user_id,
-        ),
+        ground_truth_dataset=gt_dataset,
         eval_info=eval_info_params,
     )
     request = service_pb2.PostEvaluationsRequest(
@@ -761,6 +789,157 @@ class Model(Lister, BaseClient):
 
     return result
 
+  def get_eval_by_dataset(self, dataset: Dataset) -> List[resources_pb2.EvalMetrics]:
+    """Get all eval data of dataset
+
+    Args:
+      dataset (Dataset): Clarifai dataset
+
+    Returns:
+      List[resources_pb2.EvalMetrics]
+    """
+    _id = dataset.id
+    app = dataset.app_id or self.app_id
+    user_id = dataset.user_id or self.user_id
+    version = dataset.version.id
+
+    list_eval: resources_pb2.EvalMetrics = self.list_evaluations()
+    outputs = []
+    for _eval in list_eval:
+      if _eval.status.code == status_code_pb2.MODEL_EVALUATED:
+        gt_ds = _eval.ground_truth_dataset
+        if (_id == gt_ds.id and user_id == gt_ds.user_id and app == gt_ds.app_id):
+          if not version or version == gt_ds.version.id:
+            outputs.append(_eval)
+
+    return outputs
+
+  def get_raw_eval(self,
+                   dataset: Dataset = None,
+                   eval_id: str = None,
+                   return_format: str = 'array') -> Union[resources_pb2.EvalTestSetEntry, Tuple[
+                       np.array, np.array, list, List[Input]], Tuple[List[dict], List[dict]]]:
+    """Get ground truths, predictions and input information. Do not pass dataset and eval_id at same time
+
+    Args:
+      dataset (Dataset): Clarifai dataset, get eval data of latest eval result of dataset.
+      eval_id (str): Evaluation ID, get eval data of specific eval id.
+      return_format (str, optional): Choice {proto, array, coco}. !Note that `coco` is only applicable for 'visual-detector'. Defaults to 'array'.
+
+    Returns:
+
+      Depends on `return_format`.
+
+      * if return_format == proto
+        `resources_pb2.EvalTestSetEntry`
+
+      * if return_format == array
+        `Tuple(np.array, np.array, List[str], List[Input])`: Tuple has 4 elements (y, y_pred, concept_ids, inputs).
+        y, y_pred, concept_ids can be used to compute metrics. 'inputs' can be use to download
+        - if model is 'classifier': 'y' and 'y_pred' are both arrays with a shape of (num_inputs,)
+        - if model is 'visual-detector': 'y' and 'y_pred' are arrays with a shape of (num_inputs,), where each element is array has shape (num_annotation, 6) consists of [x_min, y_min, x_max, y_max, concept_index, score]. The score is always 1 for 'y'
+
+      * if return_format == coco: Applicable only for 'visual-detector'
+        `Tuple[List[Dict], List[Dict]]`: Tuple has 2 elemnts where first element is COCO Ground Truth and last one is COCO Prediction Annotation
+
+    Example Usages:
+    ------
+    * Evaluate `visual-classifier` using sklearn
+
+    ```python
+    import os
+    from sklearn.metrics import accuracy_score
+    from sklearn.metrics import classification_report
+    import numpy as np
+    from clarifai.client.model import Model
+    from clarifai.client.dataset import Dataset
+    os.environ["CLARIFAI_PAT"] = "???"
+    model = Model(url="url/of/model/includes/version-id")
+    dataset = Dataset(dataset_id="dataset-id")
+    y, y_pred, clss, input_protos = model.get_raw_eval(dataset, return_format="array")
+    y = np.argmax(y, axis=1)
+    y_pred = np.argmax(y_pred, axis=1)
+    report = classification_report(y, y_pred, target_names=clss)
+    print(report)
+    acc = accuracy_score(y, y_pred)
+    print("acc ", acc)
+    ```
+
+    * Evaluate `visual-detector` using COCOeval
+
+    ```python
+    import os
+    import json
+    from pycocotools.coco import COCO
+    from pycocotools.cocoeval import COCOeval
+    from clarifai.client.model import Model
+    from clarifai.client.dataset import Dataset
+    os.environ["CLARIFAI_PAT"] = "???"  # Insert your PAT
+    model = Model(url=model_url)
+    dataset = Dataset(url=dataset_url)
+    y, y_pred = model.get_raw_eval(dataset, return_format="coco")
+    # save as files to load in COCO API
+    def save_annot(d, path):
+      with open(path, "w") as fp:
+        json.dump(d, fp, indent=2)
+    gt_path = os.path.join("gt.json")
+    pred_path = os.path.join("pred.json")
+    save_annot(y, gt_path)
+    save_annot(y_pred, pred_path)
+
+    cocoGt = COCO(gt_path)
+    cocoPred = COCO(pred_path)
+    cocoEval = COCOeval(cocoGt, cocoPred, "bbox")
+    cocoEval.evaluate()
+    cocoEval.accumulate()
+    cocoEval.summarize()  # Print out result of all classes with all area type
+    # Example:
+    # Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.863
+    # Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.973
+    # Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.939
+    # ...
+    ```
+
+    """
+    from clarifai.utils.evaluation.testset_annotation_parser import (
+        parse_eval_annotation_classifier, parse_eval_annotation_detector,
+        parse_eval_annotation_detector_coco)
+
+    valid_model_types = ["visual-classifier", "text-classifier", "visual-detector"]
+    supported_format = ['proto', 'array', 'coco']
+    assert return_format in supported_format, ValueError(
+        f"Expected return_format in {supported_format}, got {return_format}")
+    self.load_info()
+    model_type_id = self.model_info.model_type_id
+    assert model_type_id in valid_model_types, \
+        f"This method only supports model types {valid_model_types}, but your model type is {self.model_info.model_type_id}."
+    assert not (dataset and
+                eval_id), "Using both `dataset` and `eval_id`, but only one should be passed."
+    assert not dataset or not eval_id, "Please provide either `dataset` or `eval_id`, but nothing was passed."
+    if model_type_id.endswith("-classifier") and return_format == "coco":
+      raise ValueError(
+          f"return_format coco only applies for `visual-detector`, however your model is `{model_type_id}`"
+      )
+
+    if dataset:
+      eval_by_ds = self.get_eval_by_dataset(dataset)
+      if len(eval_by_ds) == 0:
+        raise Exception(f"Model is not valuated with dataset: {dataset}")
+      eval_id = eval_by_ds[0].id
+
+    detail_eval_data = self.get_eval_by_id(eval_id=eval_id, test_set=True, metrics_by_class=True)
+
+    if return_format == "proto":
+      return detail_eval_data.test_set
+    else:
+      if model_type_id.endswith("-classifier"):
+        return parse_eval_annotation_classifier(detail_eval_data)
+      elif model_type_id == "visual-detector":
+        if return_format == "array":
+          return parse_eval_annotation_detector(detail_eval_data)
+        elif return_format == "coco":
+          return parse_eval_annotation_detector_coco(detail_eval_data)
+
   def export(self, export_dir: str = None) -> None:
     """Export the model, stores the exported model as model.tar file
 
@@ -830,7 +1009,7 @@ class Model(Lister, BaseClient):
     )
     time.sleep(5)
     start_time = time.time()
-    backoff_iterator = BackoffIterator()
+    backoff_iterator = BackoffIterator(10)
     while True:
       get_export_response = _get_export_response()
       if get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTING and \
@@ -841,6 +1020,7 @@ class Model(Lister, BaseClient):
         time.sleep(next(backoff_iterator))
       elif get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTED:
         _download_exported_model(get_export_response, os.path.join(export_dir, "model.tar"))
+        break
       elif time.time() - start_time > 60 * 30:
         raise Exception(
             f"""Model Export took too long. Please try again or contact support@clarifai.com
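The hunks above add a Dataset-centric evaluation flow to `Model`: `evaluate()` now accepts a `Dataset` object (the `dataset_*` arguments are kept but deprecated), and the new `get_eval_by_dataset()` / `get_raw_eval()` helpers fetch results afterwards. A minimal sketch of how the pieces might be combined, based only on the signatures and docstring examples shown above; the URLs and IDs are placeholders, and evaluation runs asynchronously on the platform, so results may not exist immediately after `evaluate()` returns:

```python
import os

from clarifai.client.dataset import Dataset
from clarifai.client.model import Model

os.environ["CLARIFAI_PAT"] = "..."  # personal access token (placeholder)

# The model URL must include the version id, as evaluate()/get_raw_eval() require one.
model = Model(url="https://clarifai.com/user_id/app_id/models/model_id/model_version_id/version_id")
dataset = Dataset(dataset_id="my-eval-dataset")  # ground-truth dataset (placeholder id)

# Passing a Dataset makes evaluate() ignore the legacy dataset_* arguments.
model.evaluate(dataset=dataset, eval_id="my-eval-run")

# Once the evaluation has completed server-side, look it up by dataset and pull raw results.
finished = model.get_eval_by_dataset(dataset)
if finished:
    y, y_pred, concept_ids, inputs = model.get_raw_eval(dataset, return_format="array")
```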
clarifai/client/module.py
CHANGED
@@ -19,6 +19,7 @@ class Module(Lister, BaseClient):
                base_url: str = "https://api.clarifai.com",
                pat: str = None,
                token: str = None,
+               root_certificates_path: str = None,
                **kwargs):
     """Initializes a Module object.
 
@@ -29,6 +30,7 @@ class Module(Lister, BaseClient):
       base_url (str): Base API url. Default "https://api.clarifai.com"
       pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT.
       token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN.
+      root_certificates_path (str): Path to the SSL root certificates file, used to establish secure gRPC connections.
       **kwargs: Additional keyword arguments to be passed to the Module.
     """
     if url and module_id:
@@ -44,7 +46,13 @@ class Module(Lister, BaseClient):
     self.module_info = resources_pb2.Module(**self.kwargs)
     self.logger = get_logger(logger_level="INFO", name=__name__)
     BaseClient.__init__(
-        self,
+        self,
+        user_id=self.user_id,
+        app_id=self.app_id,
+        base=base_url,
+        pat=pat,
+        token=token,
+        root_certificates_path=root_certificates_path)
     Lister.__init__(self)
 
   def list_versions(self, page_no: int = None,
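`root_certificates_path` is the same new keyword that Model, Module, Search, User and Workflow now all forward to `BaseClient.__init__`, where it is used to establish the secure gRPC connection. A hedged sketch of passing it when constructing a client; the module URL and the certificate path are placeholders:

```python
from clarifai.client.module import Module

module = Module(
    url="https://clarifai.com/user_id/app_id/modules/module_id",  # placeholder module URL
    root_certificates_path="/path/to/root_certs.pem",  # placeholder CA bundle for the gRPC channel
)
```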
clarifai/client/search.py
CHANGED
@@ -24,7 +24,8 @@ class Search(Lister, BaseClient):
                metric: str = DEFAULT_SEARCH_METRIC,
                base_url: str = "https://api.clarifai.com",
                pat: str = None,
-               token: str = None
+               token: str = None,
+               root_certificates_path: str = None):
     """Initialize the Search object.
 
     Args:
@@ -35,6 +36,7 @@ class Search(Lister, BaseClient):
       base_url (str, optional): Base API url. Defaults to "https://api.clarifai.com".
       pat (str, optional): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
       token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
+      root_certificates_path (str): Path to the SSL root certificates file, used to establish secure gRPC connections.
 
     Raises:
       UserError: If the metric is not 'cosine' or 'euclidean'.
@@ -52,7 +54,13 @@ class Search(Lister, BaseClient):
         user_id=self.user_id, app_id=self.app_id, pat=pat, token=token, base_url=base_url)
     self.rank_filter_schema = get_schema()
     BaseClient.__init__(
-        self,
+        self,
+        user_id=self.user_id,
+        app_id=self.app_id,
+        base=base_url,
+        pat=pat,
+        token=token,
+        root_certificates_path=root_certificates_path)
     Lister.__init__(self, page_size=1000)
 
   def _get_annot_proto(self, **kwargs):
clarifai/client/user.py
CHANGED
@@ -7,7 +7,6 @@ from google.protobuf.json_format import MessageToDict
 from clarifai.client.app import App
 from clarifai.client.base import BaseClient
 from clarifai.client.lister import Lister
-from clarifai.client.runner import Runner
 from clarifai.errors import UserError
 from clarifai.utils.logging import get_logger
 
@@ -20,6 +19,7 @@ class User(Lister, BaseClient):
                base_url: str = "https://api.clarifai.com",
                pat: str = None,
                token: str = None,
+               root_certificates_path: str = None,
                **kwargs):
     """Initializes an User object.
 
@@ -28,12 +28,20 @@ class User(Lister, BaseClient):
       base_url (str): Base API url. Default "https://api.clarifai.com"
       pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
       token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
+      root_certificates_path (str): Path to the SSL root certificates file, used to establish secure gRPC connections.
       **kwargs: Additional keyword arguments to be passed to the User.
     """
     self.kwargs = {**kwargs, 'id': user_id}
     self.user_info = resources_pb2.User(**self.kwargs)
     self.logger = get_logger(logger_level="INFO", name=__name__)
-    BaseClient.__init__(
+    BaseClient.__init__(
+        self,
+        user_id=self.id,
+        app_id="",
+        base=base_url,
+        pat=pat,
+        token=token,
+        root_certificates_path=root_certificates_path)
     Lister.__init__(self)
 
   def list_apps(self, filter_by: Dict[str, Any] = {}, page_no: int = None,
@@ -69,7 +77,7 @@ class User(Lister, BaseClient):
         **app_info)  #(base_url=self.base, pat=self.pat, token=self.token, **app_info)
 
   def list_runners(self, filter_by: Dict[str, Any] = {}, page_no: int = None,
-                   per_page: int = None) -> Generator[
+                   per_page: int = None) -> Generator[dict, None, None]:
     """List all runners for the user
 
     Args:
@@ -78,7 +86,7 @@ class User(Lister, BaseClient):
       per_page (int): The number of items per page.
 
     Yields:
-
+      Dict: Dictionaries containing information about the runners.
 
     Example:
       >>> from clarifai.client.user import User
@@ -98,8 +106,7 @@ class User(Lister, BaseClient):
         page_no=page_no)
 
     for runner_info in all_runners_info:
-      yield
-          auth=self.auth_helper, check_runner_exists=False, **runner_info)
+      yield dict(auth=self.auth_helper, check_runner_exists=False, **runner_info)
 
   def create_app(self, app_id: str, base_workflow: str = 'Empty', **kwargs) -> App:
     """Creates an app for the user.
@@ -127,7 +134,7 @@ class User(Lister, BaseClient):
     self.logger.info("\nApp created\n%s", response.status)
     return App.from_auth_helper(auth=self.auth_helper, app_id=app_id)
 
-  def create_runner(self, runner_id: str, labels: List[str], description: str) ->
+  def create_runner(self, runner_id: str, labels: List[str], description: str) -> dict:
     """Create a runner
 
     Args:
@@ -136,13 +143,14 @@ class User(Lister, BaseClient):
       description (str): Description of Runner
 
     Returns:
-
+      Dict: A dictionary containing information about the specified Runner ID.
 
     Example:
       >>> from clarifai.client.user import User
       >>> client = User(user_id="user_id")
-      >>>
+      >>> runner_info = client.create_runner(runner_id="runner_id", labels=["label to link runner"], description="laptop runner")
     """
+
     if not isinstance(labels, List):
       raise UserError("Labels must be a List of strings")
 
@@ -155,7 +163,7 @@ class User(Lister, BaseClient):
       raise Exception(response.status)
     self.logger.info("\nRunner created\n%s", response.status)
 
-    return
+    return dict(
         auth=self.auth_helper,
         runner_id=runner_id,
         user_id=self.id,
@@ -186,19 +194,19 @@ class User(Lister, BaseClient):
     kwargs['user_id'] = self.id
     return App.from_auth_helper(auth=self.auth_helper, app_id=app_id, **kwargs)
 
-  def runner(self, runner_id: str) ->
+  def runner(self, runner_id: str) -> dict:
     """Returns a Runner object if exists.
 
     Args:
       runner_id (str): The runner ID to interact with
 
     Returns:
-
+      Dict: A dictionary containing information about the existing runner ID.
 
     Example:
       >>> from clarifai.client.user import User
       >>> client = User(user_id="user_id")
-      >>>
+      >>> runner_info = client.runner(runner_id="runner_id")
     """
     request = service_pb2.GetRunnerRequest(user_app_id=self.user_app_id, runner_id=runner_id)
     response = self._grpc_request(self.STUB.GetRunner, request)
@@ -212,7 +220,7 @@ class User(Lister, BaseClient):
     kwargs = self.process_response_keys(dict_response[list(dict_response.keys())[1]],
                                         list(dict_response.keys())[1])
 
-    return
+    return dict(self.auth_helper, check_runner_exists=False, **kwargs)
 
   def delete_app(self, app_id: str) -> None:
     """Deletes an app for the user.
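With `clarifai.client.runner.Runner` dropped from the package (see `clarifai/client/runner.py +0 -234` in the file list), the runner helpers on `User` now return plain dictionaries. A short sketch of the adjusted calling code, adapted from the docstring examples above; the IDs and labels are placeholders, and the exact keys in each dict depend on the API response:

```python
from clarifai.client.user import User

client = User(user_id="user_id")

# create_runner() and runner() now hand back dicts instead of Runner objects.
runner_info = client.create_runner(
    runner_id="runner_id",
    labels=["label to link runner"],
    description="laptop runner")
existing = client.runner(runner_id="runner_id")

# list_runners() likewise yields one dict per runner.
for info in client.list_runners():
    print(info)
```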
clarifai/client/workflow.py
CHANGED
@@ -28,6 +28,7 @@ class Workflow(Lister, BaseClient):
                base_url: str = "https://api.clarifai.com",
                pat: str = None,
                token: str = None,
+               root_certificates_path: str = None,
                **kwargs):
     """Initializes a Workflow object.
 
@@ -43,6 +44,7 @@ class Workflow(Lister, BaseClient):
       base_url (str): Base API url. Default "https://api.clarifai.com"
       pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
       token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
+      root_certificates_path (str): Path to the SSL root certificates file, used to establish secure gRPC connections.
       **kwargs: Additional keyword arguments to be passed to the Workflow.
     """
     if url and workflow_id:
@@ -59,7 +61,13 @@ class Workflow(Lister, BaseClient):
     self.workflow_info = resources_pb2.Workflow(**self.kwargs)
     self.logger = get_logger(logger_level="INFO", name=__name__)
     BaseClient.__init__(
-        self,
+        self,
+        user_id=self.user_id,
+        app_id=self.app_id,
+        base=base_url,
+        pat=pat,
+        token=token,
+        root_certificates_path=root_certificates_path)
     Lister.__init__(self)
 
   def predict(self, inputs: List[Input], workflow_state_id: str = None):
@@ -83,7 +91,7 @@ class Workflow(Lister, BaseClient):
     request.workflow_state.id = workflow_state_id
 
     start_time = time.time()
-    backoff_iterator = BackoffIterator()
+    backoff_iterator = BackoffIterator(10)
 
     while True:
       response = self._grpc_request(self.STUB.PostWorkflowResults, request)
clarifai/constants/input.py
CHANGED
@@ -0,0 +1 @@
+MAX_UPLOAD_BATCH_SIZE = 128
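The new `clarifai/constants/input.py` module only defines `MAX_UPLOAD_BATCH_SIZE = 128`, which the input client (`clarifai/client/input.py`, also touched in this release) can use to cap upload batch sizes. A hypothetical helper showing how calling code might respect the limit; `chunked` is illustrative and not part of the SDK:

```python
from clarifai.constants.input import MAX_UPLOAD_BATCH_SIZE


def chunked(items, size=MAX_UPLOAD_BATCH_SIZE):
    """Yield successive slices of at most `size` items (128 by default)."""
    for start in range(0, len(items), size):
        yield items[start:start + size]


# e.g. split a list of input protos into API-sized upload batches:
# for batch in chunked(all_input_protos):
#     ...upload batch...
```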
clarifai/datasets/export/inputs_annotations.py
CHANGED
@@ -21,7 +21,7 @@ logger = get_logger("INFO", __name__)
 class DatasetExportReader:
 
   def __init__(self,
-               session: requests.Session,
+               session: requests.Session = None,
                archive_url: Optional[str] = None,
                local_archive_path: Optional[str] = None):
     """Download/Reads the zipfile archive and yields every api.Input object.
@@ -31,9 +31,11 @@ class DatasetExportReader:
       archive_url: URL of the DatasetVersionExport archive
       local_archive_path: Path to the DatasetVersionExport archive
     """
-    self.input_count =
+    self.input_count = None
     self.temp_file = None
     self.session = session
+    if not self.session:
+      self.session = requests.Session()
 
     assert archive_url or local_archive_path, UserError(
         "Either archive_url or local_archive_path must be provided.")
@@ -59,7 +61,8 @@ class DatasetExportReader:
   def _download_temp_archive(self, archive_url: str,
                              chunk_size: int = 128) -> tempfile.TemporaryFile:
     """Downloads the temp archive of InputBatches."""
-
+    session = requests.Session()
+    r = session.get(archive_url, stream=True)
     temp_file = tempfile.TemporaryFile()
     for chunk in r.iter_content(chunk_size=chunk_size):
       temp_file.write(chunk)
@@ -67,10 +70,12 @@ class DatasetExportReader:
     return temp_file
 
   def __len__(self) -> int:
-    if
+    if self.input_count is None:
+      input_count = 0
     if self.file_name_list is not None:
       for filename in self.file_name_list:
-
+        input_count += int(filename.split('_n')[-1])
+      self.input_count = input_count
 
     return self.input_count
 
@@ -111,7 +116,8 @@ class InputAnnotationDownloader:
     """
     self.input_iterator = input_iterator
     self.num_workers = min(num_workers, 10)  # Max 10 threads
-    self.
+    self.num_inputs = 0
+    self.num_annotations = 0
     self.split_prefix = None
     self.session = session
     self.input_ext = dict(image=".png", text=".txt", audio=".mp3", video=".mp4")
@@ -182,14 +188,14 @@ class InputAnnotationDownloader:
       self._save_audio_to_archive(new_archive, hosted_url, file_name)
     elif input_type == "video":
       self._save_video_to_archive(new_archive, hosted_url, file_name)
-    self.
+    self.num_inputs += 1
 
     if data_dict.get("concepts") or data_dict.get("regions"):
       file_name = os.path.join(split, "annotations", input_.id + ".json")
       annot_data = data_dict.get("concepts") or data_dict.get("regions")
 
       self._save_annotation_to_archive(new_archive, annot_data, file_name)
-      self.
+      self.num_annotations += 1
 
   def _check_output_archive(self, save_path: str) -> None:
     try:
@@ -198,8 +204,8 @@ class InputAnnotationDownloader:
       raise e
     assert len(
         archive.namelist()
-    ) == self.
-        len(archive.namelist()), self.
+    ) == self.num_inputs + self.num_annotations, "Archive has %d inputs+annotations | expecting %d inputs+annotations" % (
+        len(archive.namelist()), self.num_inputs + self.num_annotations)
 
   def download_archive(self, save_path: str, split: Optional[str] = None) -> None:
     """Downloads the archive from the URL into an archive of inputs, annotations in the directory format
@@ -218,5 +224,5 @@ class InputAnnotationDownloader:
       progress.update()
 
     self._check_output_archive(save_path)
-    logger.info("Downloaded %d inputs
-
+    logger.info("Downloaded %d inputs and %d annotations to %s" %
+                (self.num_inputs, self.num_annotations, save_path))