clarifai 11.3.0rc2__py3-none-any.whl → 11.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clarifai/__init__.py +1 -1
- clarifai/cli/__main__.py +1 -1
- clarifai/cli/base.py +144 -136
- clarifai/cli/compute_cluster.py +45 -31
- clarifai/cli/deployment.py +93 -76
- clarifai/cli/model.py +578 -180
- clarifai/cli/nodepool.py +100 -82
- clarifai/client/__init__.py +12 -2
- clarifai/client/app.py +973 -911
- clarifai/client/auth/helper.py +345 -342
- clarifai/client/auth/register.py +7 -7
- clarifai/client/auth/stub.py +107 -106
- clarifai/client/base.py +185 -178
- clarifai/client/compute_cluster.py +214 -180
- clarifai/client/dataset.py +793 -698
- clarifai/client/deployment.py +55 -50
- clarifai/client/input.py +1223 -1088
- clarifai/client/lister.py +47 -45
- clarifai/client/model.py +1939 -1717
- clarifai/client/model_client.py +525 -502
- clarifai/client/module.py +82 -73
- clarifai/client/nodepool.py +358 -213
- clarifai/client/runner.py +58 -0
- clarifai/client/search.py +342 -309
- clarifai/client/user.py +419 -414
- clarifai/client/workflow.py +294 -274
- clarifai/constants/dataset.py +11 -17
- clarifai/constants/model.py +8 -2
- clarifai/datasets/export/inputs_annotations.py +233 -217
- clarifai/datasets/upload/base.py +63 -51
- clarifai/datasets/upload/features.py +43 -38
- clarifai/datasets/upload/image.py +237 -207
- clarifai/datasets/upload/loaders/coco_captions.py +34 -32
- clarifai/datasets/upload/loaders/coco_detection.py +72 -65
- clarifai/datasets/upload/loaders/imagenet_classification.py +57 -53
- clarifai/datasets/upload/loaders/xview_detection.py +274 -132
- clarifai/datasets/upload/multimodal.py +55 -46
- clarifai/datasets/upload/text.py +55 -47
- clarifai/datasets/upload/utils.py +250 -234
- clarifai/errors.py +51 -50
- clarifai/models/api.py +260 -238
- clarifai/modules/css.py +50 -50
- clarifai/modules/pages.py +33 -33
- clarifai/rag/rag.py +312 -288
- clarifai/rag/utils.py +91 -84
- clarifai/runners/models/model_builder.py +906 -802
- clarifai/runners/models/model_class.py +370 -331
- clarifai/runners/models/model_run_locally.py +459 -419
- clarifai/runners/models/model_runner.py +170 -162
- clarifai/runners/models/model_servicer.py +78 -70
- clarifai/runners/server.py +111 -101
- clarifai/runners/utils/code_script.py +225 -187
- clarifai/runners/utils/const.py +4 -1
- clarifai/runners/utils/data_types/__init__.py +12 -0
- clarifai/runners/utils/data_types/data_types.py +598 -0
- clarifai/runners/utils/data_utils.py +387 -440
- clarifai/runners/utils/loader.py +247 -227
- clarifai/runners/utils/method_signatures.py +411 -386
- clarifai/runners/utils/openai_convertor.py +108 -109
- clarifai/runners/utils/serializers.py +175 -179
- clarifai/runners/utils/url_fetcher.py +35 -35
- clarifai/schema/search.py +56 -63
- clarifai/urls/helper.py +125 -102
- clarifai/utils/cli.py +129 -123
- clarifai/utils/config.py +127 -87
- clarifai/utils/constants.py +49 -0
- clarifai/utils/evaluation/helpers.py +503 -466
- clarifai/utils/evaluation/main.py +431 -393
- clarifai/utils/evaluation/testset_annotation_parser.py +154 -144
- clarifai/utils/logging.py +324 -306
- clarifai/utils/misc.py +60 -56
- clarifai/utils/model_train.py +165 -146
- clarifai/utils/protobuf.py +126 -103
- clarifai/versions.py +3 -1
- clarifai/workflows/export.py +48 -50
- clarifai/workflows/utils.py +39 -36
- clarifai/workflows/validate.py +55 -43
- {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/METADATA +16 -6
- clarifai-11.4.0.dist-info/RECORD +109 -0
- {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/WHEEL +1 -1
- clarifai/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/__pycache__/errors.cpython-310.pyc +0 -0
- clarifai/__pycache__/errors.cpython-311.pyc +0 -0
- clarifai/__pycache__/versions.cpython-310.pyc +0 -0
- clarifai/__pycache__/versions.cpython-311.pyc +0 -0
- clarifai/cli/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/cli/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/cli/__pycache__/base.cpython-310.pyc +0 -0
- clarifai/cli/__pycache__/base.cpython-311.pyc +0 -0
- clarifai/cli/__pycache__/base_cli.cpython-310.pyc +0 -0
- clarifai/cli/__pycache__/compute_cluster.cpython-310.pyc +0 -0
- clarifai/cli/__pycache__/compute_cluster.cpython-311.pyc +0 -0
- clarifai/cli/__pycache__/deployment.cpython-310.pyc +0 -0
- clarifai/cli/__pycache__/deployment.cpython-311.pyc +0 -0
- clarifai/cli/__pycache__/model.cpython-310.pyc +0 -0
- clarifai/cli/__pycache__/model.cpython-311.pyc +0 -0
- clarifai/cli/__pycache__/model_cli.cpython-310.pyc +0 -0
- clarifai/cli/__pycache__/nodepool.cpython-310.pyc +0 -0
- clarifai/cli/__pycache__/nodepool.cpython-311.pyc +0 -0
- clarifai/client/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/client/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/client/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/client/__pycache__/app.cpython-310.pyc +0 -0
- clarifai/client/__pycache__/app.cpython-311.pyc +0 -0
- clarifai/client/__pycache__/app.cpython-39.pyc +0 -0
- clarifai/client/__pycache__/base.cpython-310.pyc +0 -0
- clarifai/client/__pycache__/base.cpython-311.pyc +0 -0
- clarifai/client/__pycache__/compute_cluster.cpython-310.pyc +0 -0
- clarifai/client/__pycache__/compute_cluster.cpython-311.pyc +0 -0
- clarifai/client/__pycache__/dataset.cpython-310.pyc +0 -0
- clarifai/client/__pycache__/dataset.cpython-311.pyc +0 -0
- clarifai/client/__pycache__/deployment.cpython-310.pyc +0 -0
- clarifai/client/__pycache__/deployment.cpython-311.pyc +0 -0
- clarifai/client/__pycache__/input.cpython-310.pyc +0 -0
- clarifai/client/__pycache__/input.cpython-311.pyc +0 -0
- clarifai/client/__pycache__/lister.cpython-310.pyc +0 -0
- clarifai/client/__pycache__/lister.cpython-311.pyc +0 -0
- clarifai/client/__pycache__/model.cpython-310.pyc +0 -0
- clarifai/client/__pycache__/model.cpython-311.pyc +0 -0
- clarifai/client/__pycache__/module.cpython-310.pyc +0 -0
- clarifai/client/__pycache__/module.cpython-311.pyc +0 -0
- clarifai/client/__pycache__/nodepool.cpython-310.pyc +0 -0
- clarifai/client/__pycache__/nodepool.cpython-311.pyc +0 -0
- clarifai/client/__pycache__/search.cpython-310.pyc +0 -0
- clarifai/client/__pycache__/search.cpython-311.pyc +0 -0
- clarifai/client/__pycache__/user.cpython-310.pyc +0 -0
- clarifai/client/__pycache__/user.cpython-311.pyc +0 -0
- clarifai/client/__pycache__/workflow.cpython-310.pyc +0 -0
- clarifai/client/__pycache__/workflow.cpython-311.pyc +0 -0
- clarifai/client/auth/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/client/auth/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/client/auth/__pycache__/helper.cpython-310.pyc +0 -0
- clarifai/client/auth/__pycache__/helper.cpython-311.pyc +0 -0
- clarifai/client/auth/__pycache__/register.cpython-310.pyc +0 -0
- clarifai/client/auth/__pycache__/register.cpython-311.pyc +0 -0
- clarifai/client/auth/__pycache__/stub.cpython-310.pyc +0 -0
- clarifai/client/auth/__pycache__/stub.cpython-311.pyc +0 -0
- clarifai/client/cli/__init__.py +0 -0
- clarifai/client/cli/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/client/cli/__pycache__/base_cli.cpython-310.pyc +0 -0
- clarifai/client/cli/__pycache__/model_cli.cpython-310.pyc +0 -0
- clarifai/client/cli/base_cli.py +0 -88
- clarifai/client/cli/model_cli.py +0 -29
- clarifai/constants/__pycache__/base.cpython-310.pyc +0 -0
- clarifai/constants/__pycache__/base.cpython-311.pyc +0 -0
- clarifai/constants/__pycache__/dataset.cpython-310.pyc +0 -0
- clarifai/constants/__pycache__/dataset.cpython-311.pyc +0 -0
- clarifai/constants/__pycache__/input.cpython-310.pyc +0 -0
- clarifai/constants/__pycache__/input.cpython-311.pyc +0 -0
- clarifai/constants/__pycache__/model.cpython-310.pyc +0 -0
- clarifai/constants/__pycache__/model.cpython-311.pyc +0 -0
- clarifai/constants/__pycache__/rag.cpython-310.pyc +0 -0
- clarifai/constants/__pycache__/rag.cpython-311.pyc +0 -0
- clarifai/constants/__pycache__/search.cpython-310.pyc +0 -0
- clarifai/constants/__pycache__/search.cpython-311.pyc +0 -0
- clarifai/constants/__pycache__/workflow.cpython-310.pyc +0 -0
- clarifai/constants/__pycache__/workflow.cpython-311.pyc +0 -0
- clarifai/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/datasets/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/datasets/export/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/datasets/export/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/datasets/export/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/datasets/export/__pycache__/inputs_annotations.cpython-310.pyc +0 -0
- clarifai/datasets/export/__pycache__/inputs_annotations.cpython-311.pyc +0 -0
- clarifai/datasets/upload/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/datasets/upload/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/datasets/upload/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/datasets/upload/__pycache__/base.cpython-310.pyc +0 -0
- clarifai/datasets/upload/__pycache__/base.cpython-311.pyc +0 -0
- clarifai/datasets/upload/__pycache__/features.cpython-310.pyc +0 -0
- clarifai/datasets/upload/__pycache__/features.cpython-311.pyc +0 -0
- clarifai/datasets/upload/__pycache__/image.cpython-310.pyc +0 -0
- clarifai/datasets/upload/__pycache__/image.cpython-311.pyc +0 -0
- clarifai/datasets/upload/__pycache__/multimodal.cpython-310.pyc +0 -0
- clarifai/datasets/upload/__pycache__/multimodal.cpython-311.pyc +0 -0
- clarifai/datasets/upload/__pycache__/text.cpython-310.pyc +0 -0
- clarifai/datasets/upload/__pycache__/text.cpython-311.pyc +0 -0
- clarifai/datasets/upload/__pycache__/utils.cpython-310.pyc +0 -0
- clarifai/datasets/upload/__pycache__/utils.cpython-311.pyc +0 -0
- clarifai/datasets/upload/loaders/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/datasets/upload/loaders/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/datasets/upload/loaders/__pycache__/coco_detection.cpython-311.pyc +0 -0
- clarifai/datasets/upload/loaders/__pycache__/imagenet_classification.cpython-311.pyc +0 -0
- clarifai/models/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/modules/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/rag/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/rag/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/rag/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/rag/__pycache__/rag.cpython-310.pyc +0 -0
- clarifai/rag/__pycache__/rag.cpython-311.pyc +0 -0
- clarifai/rag/__pycache__/rag.cpython-39.pyc +0 -0
- clarifai/rag/__pycache__/utils.cpython-310.pyc +0 -0
- clarifai/rag/__pycache__/utils.cpython-311.pyc +0 -0
- clarifai/runners/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/runners/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/runners/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/runners/dockerfile_template/Dockerfile.cpu.template +0 -31
- clarifai/runners/dockerfile_template/Dockerfile.cuda.template +0 -42
- clarifai/runners/dockerfile_template/Dockerfile.nim +0 -71
- clarifai/runners/models/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/runners/models/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/runners/models/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/runners/models/__pycache__/base_typed_model.cpython-310.pyc +0 -0
- clarifai/runners/models/__pycache__/base_typed_model.cpython-311.pyc +0 -0
- clarifai/runners/models/__pycache__/base_typed_model.cpython-39.pyc +0 -0
- clarifai/runners/models/__pycache__/model_builder.cpython-311.pyc +0 -0
- clarifai/runners/models/__pycache__/model_class.cpython-310.pyc +0 -0
- clarifai/runners/models/__pycache__/model_class.cpython-311.pyc +0 -0
- clarifai/runners/models/__pycache__/model_run_locally.cpython-310-pytest-7.1.2.pyc +0 -0
- clarifai/runners/models/__pycache__/model_run_locally.cpython-310.pyc +0 -0
- clarifai/runners/models/__pycache__/model_run_locally.cpython-311.pyc +0 -0
- clarifai/runners/models/__pycache__/model_runner.cpython-310.pyc +0 -0
- clarifai/runners/models/__pycache__/model_runner.cpython-311.pyc +0 -0
- clarifai/runners/models/__pycache__/model_upload.cpython-310.pyc +0 -0
- clarifai/runners/models/base_typed_model.py +0 -238
- clarifai/runners/models/model_class_refract.py +0 -80
- clarifai/runners/models/model_upload.py +0 -607
- clarifai/runners/models/temp.py +0 -25
- clarifai/runners/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/runners/utils/__pycache__/__init__.cpython-38.pyc +0 -0
- clarifai/runners/utils/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/runners/utils/__pycache__/buffered_stream.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/buffered_stream.cpython-38.pyc +0 -0
- clarifai/runners/utils/__pycache__/buffered_stream.cpython-39.pyc +0 -0
- clarifai/runners/utils/__pycache__/const.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/const.cpython-311.pyc +0 -0
- clarifai/runners/utils/__pycache__/constants.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/constants.cpython-38.pyc +0 -0
- clarifai/runners/utils/__pycache__/constants.cpython-39.pyc +0 -0
- clarifai/runners/utils/__pycache__/data_handler.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/data_handler.cpython-311.pyc +0 -0
- clarifai/runners/utils/__pycache__/data_handler.cpython-38.pyc +0 -0
- clarifai/runners/utils/__pycache__/data_handler.cpython-39.pyc +0 -0
- clarifai/runners/utils/__pycache__/data_utils.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/data_utils.cpython-311.pyc +0 -0
- clarifai/runners/utils/__pycache__/data_utils.cpython-38.pyc +0 -0
- clarifai/runners/utils/__pycache__/data_utils.cpython-39.pyc +0 -0
- clarifai/runners/utils/__pycache__/grpc_server.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/grpc_server.cpython-38.pyc +0 -0
- clarifai/runners/utils/__pycache__/grpc_server.cpython-39.pyc +0 -0
- clarifai/runners/utils/__pycache__/health.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/health.cpython-38.pyc +0 -0
- clarifai/runners/utils/__pycache__/health.cpython-39.pyc +0 -0
- clarifai/runners/utils/__pycache__/loader.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/loader.cpython-311.pyc +0 -0
- clarifai/runners/utils/__pycache__/logging.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/logging.cpython-38.pyc +0 -0
- clarifai/runners/utils/__pycache__/logging.cpython-39.pyc +0 -0
- clarifai/runners/utils/__pycache__/stream_source.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/stream_source.cpython-39.pyc +0 -0
- clarifai/runners/utils/__pycache__/url_fetcher.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/url_fetcher.cpython-311.pyc +0 -0
- clarifai/runners/utils/__pycache__/url_fetcher.cpython-38.pyc +0 -0
- clarifai/runners/utils/__pycache__/url_fetcher.cpython-39.pyc +0 -0
- clarifai/runners/utils/data_handler.py +0 -231
- clarifai/runners/utils/data_handler_refract.py +0 -213
- clarifai/runners/utils/data_types.py +0 -469
- clarifai/runners/utils/logger.py +0 -0
- clarifai/runners/utils/openai_format.py +0 -87
- clarifai/schema/__pycache__/search.cpython-310.pyc +0 -0
- clarifai/schema/__pycache__/search.cpython-311.pyc +0 -0
- clarifai/urls/__pycache__/helper.cpython-310.pyc +0 -0
- clarifai/urls/__pycache__/helper.cpython-311.pyc +0 -0
- clarifai/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/utils/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/utils/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/utils/__pycache__/cli.cpython-310.pyc +0 -0
- clarifai/utils/__pycache__/cli.cpython-311.pyc +0 -0
- clarifai/utils/__pycache__/config.cpython-311.pyc +0 -0
- clarifai/utils/__pycache__/constants.cpython-310.pyc +0 -0
- clarifai/utils/__pycache__/constants.cpython-311.pyc +0 -0
- clarifai/utils/__pycache__/logging.cpython-310.pyc +0 -0
- clarifai/utils/__pycache__/logging.cpython-311.pyc +0 -0
- clarifai/utils/__pycache__/misc.cpython-310.pyc +0 -0
- clarifai/utils/__pycache__/misc.cpython-311.pyc +0 -0
- clarifai/utils/__pycache__/model_train.cpython-310.pyc +0 -0
- clarifai/utils/__pycache__/model_train.cpython-311.pyc +0 -0
- clarifai/utils/__pycache__/protobuf.cpython-311.pyc +0 -0
- clarifai/utils/evaluation/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/utils/evaluation/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/utils/evaluation/__pycache__/helpers.cpython-311.pyc +0 -0
- clarifai/utils/evaluation/__pycache__/main.cpython-311.pyc +0 -0
- clarifai/utils/evaluation/__pycache__/main.cpython-39.pyc +0 -0
- clarifai/workflows/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/workflows/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/workflows/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/workflows/__pycache__/export.cpython-310.pyc +0 -0
- clarifai/workflows/__pycache__/export.cpython-311.pyc +0 -0
- clarifai/workflows/__pycache__/utils.cpython-310.pyc +0 -0
- clarifai/workflows/__pycache__/utils.cpython-311.pyc +0 -0
- clarifai/workflows/__pycache__/validate.cpython-310.pyc +0 -0
- clarifai/workflows/__pycache__/validate.cpython-311.pyc +0 -0
- clarifai-11.3.0rc2.dist-info/RECORD +0 -322
- {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/entry_points.txt +0 -0
- {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info/licenses}/LICENSE +0 -0
- {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/top_level.txt +0 -0
clarifai/client/model.py
CHANGED
@@ -22,1751 +22,1973 @@ from clarifai.client.input import Inputs
|
|
22
22
|
from clarifai.client.lister import Lister
|
23
23
|
from clarifai.client.model_client import ModelClient
|
24
24
|
from clarifai.client.nodepool import Nodepool
|
25
|
-
from clarifai.constants.model import (
|
26
|
-
|
27
|
-
|
25
|
+
from clarifai.constants.model import (
|
26
|
+
CHUNK_SIZE,
|
27
|
+
MAX_CHUNK_SIZE,
|
28
|
+
MAX_RANGE_SIZE,
|
29
|
+
MIN_CHUNK_SIZE,
|
30
|
+
MIN_RANGE_SIZE,
|
31
|
+
MODEL_EXPORT_TIMEOUT,
|
32
|
+
RANGE_SIZE,
|
33
|
+
TRAINABLE_MODEL_TYPES,
|
34
|
+
)
|
28
35
|
from clarifai.errors import UserError
|
29
36
|
from clarifai.urls.helper import ClarifaiUrlHelper
|
30
37
|
from clarifai.utils.logging import logger
|
31
38
|
from clarifai.utils.misc import BackoffIterator
|
32
|
-
from clarifai.utils.model_train import (
|
33
|
-
|
34
|
-
|
39
|
+
from clarifai.utils.model_train import (
|
40
|
+
find_and_replace_key,
|
41
|
+
params_parser,
|
42
|
+
response_to_model_params,
|
43
|
+
response_to_param_info,
|
44
|
+
response_to_templates,
|
45
|
+
)
|
35
46
|
from clarifai.utils.protobuf import dict_to_protobuf
|
47
|
+
|
36
48
|
MAX_SIZE_PER_STREAM = int(89_128_960) # 85GiB
|
37
49
|
MIN_CHUNK_FOR_UPLOAD_FILE = int(5_242_880) # 5MiB
|
38
50
|
MAX_CHUNK_FOR_UPLOAD_FILE = int(5_242_880_000) # 5GiB
|
39
51
|
|
40
52
|
|
41
53
|
class Model(Lister, BaseClient):
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
url: str = None,
|
46
|
-
model_id: str = None,
|
47
|
-
model_version: Dict = {'id': ""},
|
48
|
-
base_url: str = "https://api.clarifai.com",
|
49
|
-
pat: str = None,
|
50
|
-
token: str = None,
|
51
|
-
root_certificates_path: str = None,
|
52
|
-
compute_cluster_id: str = None,
|
53
|
-
nodepool_id: str = None,
|
54
|
-
deployment_id: str = None,
|
55
|
-
**kwargs):
|
56
|
-
"""Initializes a Model object.
|
57
|
-
|
58
|
-
Args:
|
59
|
-
url (str): The URL to initialize the model object.
|
60
|
-
model_id (str): The Model ID to interact with.
|
61
|
-
model_version (dict): The Model Version to interact with.
|
62
|
-
base_url (str): Base API url. Default "https://api.clarifai.com"
|
63
|
-
pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
|
64
|
-
token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
|
65
|
-
root_certificates_path (str): Path to the SSL root certificates file, used to establish secure gRPC connections.
|
66
|
-
**kwargs: Additional keyword arguments to be passed to the Model.
|
67
|
-
"""
|
68
|
-
if url and model_id:
|
69
|
-
raise UserError("You can only specify one of url or model_id.")
|
70
|
-
if not url and not model_id:
|
71
|
-
raise UserError("You must specify one of url or model_id.")
|
72
|
-
if url:
|
73
|
-
user_id, app_id, _, model_id, model_version_id = ClarifaiUrlHelper.split_clarifai_url(url)
|
74
|
-
model_version = {'id': model_version_id}
|
75
|
-
kwargs = {'user_id': user_id, 'app_id': app_id}
|
76
|
-
|
77
|
-
self.kwargs = {**kwargs, 'id': model_id, 'model_version': model_version, }
|
78
|
-
self.model_info = resources_pb2.Model()
|
79
|
-
dict_to_protobuf(self.model_info, self.kwargs)
|
80
|
-
|
81
|
-
self.logger = logger
|
82
|
-
self.training_params = {}
|
83
|
-
self.input_types = None
|
84
|
-
self._client = None
|
85
|
-
self._added_methods = False
|
86
|
-
self._set_runner_selector(
|
87
|
-
compute_cluster_id=compute_cluster_id,
|
88
|
-
nodepool_id=nodepool_id,
|
89
|
-
deployment_id=deployment_id,
|
90
|
-
user_id=self.user_id, # FIXME the deployment's user_id can be different than the model's.
|
91
|
-
)
|
92
|
-
BaseClient.__init__(
|
54
|
+
"""Model is a class that provides access to Clarifai API endpoints related to Model information."""
|
55
|
+
|
56
|
+
def __init__(
|
93
57
|
self,
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
]:
|
148
|
-
raise UserError(
|
149
|
-
f"Template should be provided for {self.model_info.model_type_id} model type")
|
150
|
-
if template is not None and self.model_info.model_type_id in [
|
151
|
-
"clusterer", "embedding-classifier"
|
152
|
-
]:
|
153
|
-
raise UserError(
|
154
|
-
f"Template should not be provided for {self.model_info.model_type_id} model type")
|
155
|
-
|
156
|
-
request = service_pb2.ListModelTypesRequest(user_app_id=self.user_app_id,)
|
157
|
-
response = self._grpc_request(self.STUB.ListModelTypes, request)
|
158
|
-
if response.status.code != status_code_pb2.SUCCESS:
|
159
|
-
raise Exception(response.status)
|
160
|
-
params = response_to_model_params(
|
161
|
-
response=response, model_type_id=self.model_info.model_type_id, template=template)
|
162
|
-
# yaml file
|
163
|
-
assert save_to.endswith('.yaml'), "File extension should be .yaml"
|
164
|
-
with open(save_to, 'w') as f:
|
165
|
-
yaml.dump(params, f, default_flow_style=False, sort_keys=False)
|
166
|
-
# updating the global model params
|
167
|
-
self.training_params.update(params)
|
168
|
-
|
169
|
-
return params
|
170
|
-
|
171
|
-
def update_params(self, **kwargs) -> None:
|
172
|
-
"""Updates the model params for the model.
|
173
|
-
|
174
|
-
Args:
|
175
|
-
**kwargs: model params to update.
|
176
|
-
|
177
|
-
Example:
|
178
|
-
>>> from clarifai.client.model import Model
|
179
|
-
>>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
|
180
|
-
>>> model_params = model.get_params(template='template', yaml_file='model_params.yaml')
|
181
|
-
>>> model.update_params(batch_size = 8, dataset_version = 'dataset_version_id')
|
182
|
-
"""
|
183
|
-
if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
|
184
|
-
raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
|
185
|
-
if len(self.training_params) == 0:
|
186
|
-
raise UserError(
|
187
|
-
f"Run 'model.get_params' to get the params for the {self.model_info.model_type_id} model type"
|
188
|
-
)
|
189
|
-
# getting all the keys in nested dictionary
|
190
|
-
all_keys = [key for key in self.training_params.keys()] + [
|
191
|
-
key for key in self.training_params.values() if isinstance(key, dict) for key in key
|
192
|
-
]
|
193
|
-
# checking if the given params are valid
|
194
|
-
if not set(kwargs.keys()).issubset(all_keys):
|
195
|
-
raise UserError("Invalid params")
|
196
|
-
# updating the global model params
|
197
|
-
for key, value in kwargs.items():
|
198
|
-
find_and_replace_key(self.training_params, key, value)
|
199
|
-
|
200
|
-
def get_param_info(self, param: str) -> Dict[str, Any]:
|
201
|
-
"""Returns the param info for the param.
|
202
|
-
|
203
|
-
Args:
|
204
|
-
param (str): The param to get the info for.
|
205
|
-
|
206
|
-
Returns:
|
207
|
-
param_info (Dict): Dictionary of model param info for the param.
|
208
|
-
|
209
|
-
Example:
|
210
|
-
>>> from clarifai.client.model import Model
|
211
|
-
>>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
|
212
|
-
>>> model_params = model.get_params(template='template', yaml_file='model_params.yaml')
|
213
|
-
>>> model.get_param_info('param')
|
214
|
-
"""
|
215
|
-
if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
|
216
|
-
raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
|
217
|
-
if len(self.training_params) == 0:
|
218
|
-
raise UserError(
|
219
|
-
f"Run 'model.get_params' to get the params for the {self.model_info.model_type_id} model type"
|
220
|
-
)
|
221
|
-
|
222
|
-
all_keys = [key for key in self.training_params.keys()] + [
|
223
|
-
key for key in self.training_params.values() if isinstance(key, dict) for key in key
|
224
|
-
]
|
225
|
-
if param not in all_keys:
|
226
|
-
raise UserError(f"Invalid param: '{param}' for model type '{self.model_info.model_type_id}'")
|
227
|
-
template = self.training_params['train_params']['template'] if 'template' in all_keys else None
|
228
|
-
|
229
|
-
request = service_pb2.ListModelTypesRequest(user_app_id=self.user_app_id,)
|
230
|
-
response = self._grpc_request(self.STUB.ListModelTypes, request)
|
231
|
-
if response.status.code != status_code_pb2.SUCCESS:
|
232
|
-
raise Exception(response.status)
|
233
|
-
param_info = response_to_param_info(
|
234
|
-
response=response,
|
235
|
-
model_type_id=self.model_info.model_type_id,
|
236
|
-
param=param,
|
237
|
-
template=template)
|
238
|
-
|
239
|
-
return param_info
|
240
|
-
|
241
|
-
def train(self, yaml_file: str = None) -> str:
|
242
|
-
"""Trains the model based on the given yaml file or model params.
|
243
|
-
|
244
|
-
Args:
|
245
|
-
yaml_file (str): The yaml file for the model params.
|
246
|
-
|
247
|
-
Returns:
|
248
|
-
model_version_id (str): The model version ID for the model.
|
249
|
-
|
250
|
-
Example:
|
251
|
-
>>> from clarifai.client.model import Model
|
252
|
-
>>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
|
253
|
-
>>> model_params = model.get_params(template='template', yaml_file='model_params.yaml')
|
254
|
-
>>> model.train('model_params.yaml')
|
255
|
-
"""
|
256
|
-
if not self.model_info.model_type_id:
|
257
|
-
self.load_info()
|
258
|
-
if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
|
259
|
-
raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
|
260
|
-
if not yaml_file and len(self.training_params) == 0:
|
261
|
-
raise UserError("Provide yaml file or run 'model.get_params()'")
|
262
|
-
|
263
|
-
if yaml_file:
|
264
|
-
with open(yaml_file, 'r') as file:
|
265
|
-
params_dict = yaml.safe_load(file)
|
266
|
-
else:
|
267
|
-
params_dict = self.training_params
|
268
|
-
# getting all the concepts for the model type
|
269
|
-
if self.model_info.model_type_id not in ["clusterer", "text-to-text"]:
|
270
|
-
concepts = self._list_concepts()
|
271
|
-
train_dict = params_parser(params_dict, concepts)
|
272
|
-
request = service_pb2.PostModelVersionsRequest(
|
273
|
-
user_app_id=self.user_app_id,
|
274
|
-
model_id=self.id,
|
275
|
-
model_versions=[resources_pb2.ModelVersion(**train_dict)])
|
276
|
-
response = self._grpc_request(self.STUB.PostModelVersions, request)
|
277
|
-
if response.status.code != status_code_pb2.SUCCESS:
|
278
|
-
raise Exception(response.status)
|
279
|
-
self.logger.info("\nModel Training Started\n%s", response.status)
|
280
|
-
|
281
|
-
return response.model.model_version.id
|
282
|
-
|
283
|
-
def training_status(self, version_id: str = None, training_logs: bool = False) -> Dict[str, str]:
|
284
|
-
"""Get the training status for the model version. Also stores training logs
|
285
|
-
|
286
|
-
Args:
|
287
|
-
version_id (str): The version ID to get the training status for.
|
288
|
-
training_logs (bool): Whether to save the training logs in a file.
|
289
|
-
|
290
|
-
Returns:
|
291
|
-
training_status (Dict): Dictionary of training status for the model version.
|
292
|
-
|
293
|
-
Example:
|
294
|
-
>>> from clarifai.client.model import Model
|
295
|
-
>>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
|
296
|
-
>>> model.training_status(version_id='version_id',training_logs=True)
|
297
|
-
"""
|
298
|
-
if not version_id and not self.model_info.model_version.id:
|
299
|
-
raise UserError(
|
300
|
-
"Model version ID is missing. Please provide a `model_version` with a valid `id` as an argument or as a URL in the following format: '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
|
301
|
-
)
|
302
|
-
|
303
|
-
self.load_info()
|
304
|
-
if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
|
305
|
-
raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
|
306
|
-
|
307
|
-
if training_logs:
|
308
|
-
try:
|
309
|
-
if self.model_info.model_version.train_log:
|
310
|
-
log_response = requests.get(self.model_info.model_version.train_log)
|
311
|
-
log_response.raise_for_status() # Check for any HTTP errors
|
312
|
-
with open(version_id + '.log', 'wb') as file:
|
313
|
-
for chunk in log_response.iter_content(chunk_size=4096): # 4KB
|
314
|
-
file.write(chunk)
|
315
|
-
self.logger.info(f"\nTraining logs are saving in '{version_id+'.log'}' file")
|
316
|
-
|
317
|
-
except requests.exceptions.RequestException as e:
|
318
|
-
raise Exception(f"An error occurred while getting training logs: {e}")
|
319
|
-
|
320
|
-
return self.model_info.model_version.status
|
321
|
-
|
322
|
-
def delete_version(self, version_id: str) -> None:
|
323
|
-
"""Deletes a model version for the Model.
|
324
|
-
|
325
|
-
Args:
|
326
|
-
version_id (str): The version ID to delete.
|
327
|
-
|
328
|
-
Example:
|
329
|
-
>>> from clarifai.client.model import Model
|
330
|
-
>>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
|
331
|
-
>>> model.delete_version(version_id='version_id')
|
332
|
-
"""
|
333
|
-
request = service_pb2.DeleteModelVersionRequest(
|
334
|
-
user_app_id=self.user_app_id, model_id=self.id, version_id=version_id)
|
335
|
-
|
336
|
-
response = self._grpc_request(self.STUB.DeleteModelVersion, request)
|
337
|
-
if response.status.code != status_code_pb2.SUCCESS:
|
338
|
-
raise Exception(response.status)
|
339
|
-
self.logger.info("\nModel Version Deleted\n%s", response.status)
|
340
|
-
|
341
|
-
def create_version(self, **kwargs) -> 'Model':
|
342
|
-
"""Creates a model version for the Model.
|
343
|
-
|
344
|
-
Args:
|
345
|
-
**kwargs: Additional keyword arguments to be passed to Model Version.
|
346
|
-
- description (str): The description of the model version.
|
347
|
-
- concepts (list[Concept]): The concepts to associate with the model version.
|
348
|
-
- output_info (resources_pb2.OutputInfo(): The output info to associate with the model version.
|
349
|
-
|
350
|
-
Returns:
|
351
|
-
Model: A Model object for the specified model ID.
|
352
|
-
|
353
|
-
Example:
|
354
|
-
>>> from clarifai.client.model import Model
|
355
|
-
>>> model = Model("url")
|
356
|
-
or
|
357
|
-
>>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
|
358
|
-
>>> model_version = model.create_version(description='model_version_description')
|
359
|
-
"""
|
360
|
-
if self.model_info.model_type_id in TRAINABLE_MODEL_TYPES:
|
361
|
-
raise UserError(
|
362
|
-
f"{self.model_info.model_type_id} is a trainable model type. Use 'model.train()' to train the model"
|
363
|
-
)
|
364
|
-
|
365
|
-
request = service_pb2.PostModelVersionsRequest(
|
366
|
-
user_app_id=self.user_app_id,
|
367
|
-
model_id=self.id,
|
368
|
-
model_versions=[resources_pb2.ModelVersion(**kwargs)])
|
369
|
-
|
370
|
-
response = self._grpc_request(self.STUB.PostModelVersions, request)
|
371
|
-
if response.status.code != status_code_pb2.SUCCESS:
|
372
|
-
raise Exception(response.status)
|
373
|
-
self.logger.info("\nModel Version created\n%s", response.status)
|
374
|
-
|
375
|
-
kwargs.update({'app_id': self.app_id, 'user_id': self.user_id})
|
376
|
-
dict_response = MessageToDict(response, preserving_proto_field_name=True)
|
377
|
-
kwargs = self.process_response_keys(dict_response['model'], 'model')
|
378
|
-
|
379
|
-
return Model(base_url=self.base, pat=self.pat, token=self.token, **kwargs)
|
380
|
-
|
381
|
-
def list_versions(self, page_no: int = None,
|
382
|
-
per_page: int = None) -> Generator['Model', None, None]:
|
383
|
-
"""Lists all the versions for the model.
|
384
|
-
|
385
|
-
Args:
|
386
|
-
page_no (int): The page number to list.
|
387
|
-
per_page (int): The number of items per page.
|
388
|
-
|
389
|
-
Yields:
|
390
|
-
Model: Model objects for the versions of the model.
|
391
|
-
|
392
|
-
Example:
|
393
|
-
>>> from clarifai.client.model import Model
|
394
|
-
>>> model = Model("url") # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
|
395
|
-
or
|
396
|
-
>>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
|
397
|
-
>>> all_model_versions = list(model.list_versions())
|
398
|
-
|
399
|
-
Note:
|
400
|
-
Defaults to 16 per page if page_no is specified and per_page is not specified.
|
401
|
-
If both page_no and per_page are None, then lists all the resources.
|
402
|
-
"""
|
403
|
-
request_data = dict(
|
404
|
-
user_app_id=self.user_app_id,
|
405
|
-
model_id=self.id,
|
406
|
-
)
|
407
|
-
all_model_versions_info = self.list_pages_generator(
|
408
|
-
self.STUB.ListModelVersions,
|
409
|
-
service_pb2.ListModelVersionsRequest,
|
410
|
-
request_data,
|
411
|
-
per_page=per_page,
|
412
|
-
page_no=page_no)
|
413
|
-
|
414
|
-
for model_version_info in all_model_versions_info:
|
415
|
-
model_version_info['id'] = model_version_info['model_version_id']
|
416
|
-
del model_version_info['model_version_id']
|
417
|
-
try:
|
418
|
-
del model_version_info['train_info']['dataset']['version']['metrics']
|
419
|
-
except KeyError:
|
420
|
-
pass
|
421
|
-
yield Model.from_auth_helper(
|
422
|
-
auth=self.auth_helper,
|
423
|
-
model_id=self.id,
|
424
|
-
**dict(self.kwargs, model_version=model_version_info))
|
425
|
-
|
426
|
-
@property
|
427
|
-
def client(self):
|
428
|
-
if self._client is None:
|
429
|
-
request_template = service_pb2.PostModelOutputsRequest(
|
430
|
-
user_app_id=self.user_app_id,
|
431
|
-
model_id=self.id,
|
432
|
-
version_id=self.model_version.id,
|
433
|
-
model=self.model_info,
|
434
|
-
runner_selector=self._runner_selector,
|
435
|
-
)
|
436
|
-
self._client = ModelClient(self.STUB, request_template=request_template)
|
437
|
-
return self._client
|
438
|
-
|
439
|
-
def predict(self, *args, **kwargs):
|
440
|
-
"""
|
441
|
-
Calls the model's predict() method with the given arguments.
|
442
|
-
|
443
|
-
If passed in request_pb2.PostModelOutputsRequest values, will send the model the raw
|
444
|
-
protos directly for compatibility with previous versions of the SDK.
|
445
|
-
"""
|
446
|
-
|
447
|
-
inputs = None
|
448
|
-
if 'inputs' in kwargs:
|
449
|
-
inputs = kwargs['inputs']
|
450
|
-
elif args:
|
451
|
-
inputs = args[0]
|
452
|
-
if inputs and isinstance(inputs, list) and isinstance(inputs[0], resources_pb2.Input):
|
453
|
-
assert len(args) <= 1, "Cannot pass in raw protos and additional arguments at the same time."
|
454
|
-
inference_params = kwargs.get('inference_params', {})
|
455
|
-
output_config = kwargs.get('output_config', {})
|
456
|
-
return self.client._predict_by_proto(
|
457
|
-
inputs=inputs, inference_params=inference_params, output_config=output_config)
|
458
|
-
|
459
|
-
return self.client.predict(*args, **kwargs)
|
460
|
-
|
461
|
-
def __getattr__(self, name):
|
462
|
-
try:
|
463
|
-
return getattr(self.model_info, name)
|
464
|
-
except AttributeError:
|
465
|
-
pass
|
466
|
-
if not self._added_methods:
|
467
|
-
# fetch and set all the model methods
|
468
|
-
self._added_methods = True
|
469
|
-
self.client.fetch()
|
470
|
-
for method_name in self.client._method_signatures.keys():
|
471
|
-
if not hasattr(self, method_name):
|
472
|
-
setattr(self, method_name, getattr(self.client, method_name))
|
473
|
-
if hasattr(self.client, name):
|
474
|
-
return getattr(self.client, name)
|
475
|
-
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
|
476
|
-
|
477
|
-
def _check_predict_input_type(self, input_type: str) -> None:
|
478
|
-
"""Checks if the input type is valid for the model.
|
479
|
-
|
480
|
-
Args:
|
481
|
-
input_type (str): The input type to check.
|
482
|
-
Returns:
|
483
|
-
None
|
484
|
-
"""
|
485
|
-
if not input_type:
|
486
|
-
self.load_input_types()
|
487
|
-
if len(self.input_types) > 1:
|
488
|
-
raise UserError(
|
489
|
-
"Model has multiple input types. Please use model.predict() for this multi-modal model."
|
58
|
+
url: str = None,
|
59
|
+
model_id: str = None,
|
60
|
+
model_version: Dict = {'id': ""},
|
61
|
+
base_url: str = "https://api.clarifai.com",
|
62
|
+
pat: str = None,
|
63
|
+
token: str = None,
|
64
|
+
root_certificates_path: str = None,
|
65
|
+
compute_cluster_id: str = None,
|
66
|
+
nodepool_id: str = None,
|
67
|
+
deployment_id: str = None,
|
68
|
+
**kwargs,
|
69
|
+
):
|
70
|
+
"""Initializes a Model object.
|
71
|
+
|
72
|
+
Args:
|
73
|
+
url (str): The URL to initialize the model object.
|
74
|
+
model_id (str): The Model ID to interact with.
|
75
|
+
model_version (dict): The Model Version to interact with.
|
76
|
+
base_url (str): Base API url. Default "https://api.clarifai.com"
|
77
|
+
pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
|
78
|
+
token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
|
79
|
+
root_certificates_path (str): Path to the SSL root certificates file, used to establish secure gRPC connections.
|
80
|
+
**kwargs: Additional keyword arguments to be passed to the Model.
|
81
|
+
"""
|
82
|
+
if url and model_id:
|
83
|
+
raise UserError("You can only specify one of url or model_id.")
|
84
|
+
if not url and not model_id:
|
85
|
+
raise UserError("You must specify one of url or model_id.")
|
86
|
+
if url:
|
87
|
+
user_id, app_id, _, model_id, model_version_id = ClarifaiUrlHelper.split_clarifai_url(
|
88
|
+
url
|
89
|
+
)
|
90
|
+
model_version = {'id': model_version_id}
|
91
|
+
kwargs = {'user_id': user_id, 'app_id': app_id}
|
92
|
+
|
93
|
+
self.kwargs = {
|
94
|
+
**kwargs,
|
95
|
+
'id': model_id,
|
96
|
+
'model_version': model_version,
|
97
|
+
}
|
98
|
+
self.model_info = resources_pb2.Model()
|
99
|
+
dict_to_protobuf(self.model_info, self.kwargs)
|
100
|
+
|
101
|
+
self.logger = logger
|
102
|
+
self.training_params = {}
|
103
|
+
self.input_types = None
|
104
|
+
self._client = None
|
105
|
+
self._added_methods = False
|
106
|
+
self._set_runner_selector(
|
107
|
+
compute_cluster_id=compute_cluster_id,
|
108
|
+
nodepool_id=nodepool_id,
|
109
|
+
deployment_id=deployment_id,
|
110
|
+
user_id=self.user_id, # FIXME the deployment's user_id can be different than the model's.
|
490
111
|
)
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
112
|
+
BaseClient.__init__(
|
113
|
+
self,
|
114
|
+
user_id=self.user_id,
|
115
|
+
app_id=self.app_id,
|
116
|
+
base=base_url,
|
117
|
+
pat=pat,
|
118
|
+
token=token,
|
119
|
+
root_certificates_path=root_certificates_path,
|
120
|
+
)
|
121
|
+
Lister.__init__(self)
|
496
122
|
|
497
|
-
|
498
|
-
|
123
|
+
@classmethod
|
124
|
+
def from_current_context(cls, **kwargs) -> 'Model':
|
125
|
+
from clarifai.utils.config import Config
|
499
126
|
|
500
|
-
|
501
|
-
None
|
127
|
+
current = Config.from_yaml().current
|
502
128
|
|
503
|
-
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
self
|
522
|
-
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
536
|
-
|
129
|
+
# set the current context to env vars.
|
130
|
+
current.set_to_env()
|
131
|
+
|
132
|
+
url = f"https://clarifai.com/{current.user_id}/{current.app_id}/models/{current.model_id}"
|
133
|
+
|
134
|
+
# construct the Model object.
|
135
|
+
kwargs = {}
|
136
|
+
try:
|
137
|
+
kwargs['deployment_id'] = current.deployment_id
|
138
|
+
except AttributeError:
|
139
|
+
try:
|
140
|
+
kwargs['compute_cluster_id'] = current.compute_cluster_id
|
141
|
+
kwargs['nodepool_id'] = current.nodepool_id
|
142
|
+
except AttributeError:
|
143
|
+
pass
|
144
|
+
|
145
|
+
return Model(url, base_url=current.api_base, pat=current.pat, **kwargs)
|
146
|
+
|
147
|
+
def list_training_templates(self) -> List[str]:
|
148
|
+
"""Lists all the training templates for the model type.
|
149
|
+
|
150
|
+
Returns:
|
151
|
+
templates (List): List of training templates for the model type.
|
152
|
+
|
153
|
+
Example:
|
154
|
+
>>> from clarifai.client.model import Model
|
155
|
+
>>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
|
156
|
+
>>> print(model.list_training_templates())
|
157
|
+
"""
|
158
|
+
if not self.model_info.model_type_id:
|
159
|
+
self.load_info()
|
160
|
+
if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
|
161
|
+
raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
|
162
|
+
request = service_pb2.ListModelTypesRequest(
|
163
|
+
user_app_id=self.user_app_id,
|
537
164
|
)
|
538
|
-
|
539
|
-
|
540
|
-
|
541
|
-
|
542
|
-
|
543
|
-
if not user_id and not os.environ.get('CLARIFAI_USER_ID'):
|
544
|
-
raise UserError(
|
545
|
-
"User ID is required for model prediction with compute cluster ID and nodepool ID, please provide user_id in the method call."
|
165
|
+
response = self._grpc_request(self.STUB.ListModelTypes, request)
|
166
|
+
if response.status.code != status_code_pb2.SUCCESS:
|
167
|
+
raise Exception(response.status)
|
168
|
+
templates = response_to_templates(
|
169
|
+
response=response, model_type_id=self.model_info.model_type_id
|
546
170
|
)
|
547
|
-
|
548
|
-
|
549
|
-
|
550
|
-
|
551
|
-
|
552
|
-
|
553
|
-
|
554
|
-
|
555
|
-
|
556
|
-
|
557
|
-
|
558
|
-
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
616
|
-
|
617
|
-
|
618
|
-
|
619
|
-
|
620
|
-
|
621
|
-
|
622
|
-
|
623
|
-
|
624
|
-
|
625
|
-
|
626
|
-
|
627
|
-
|
628
|
-
|
629
|
-
|
630
|
-
|
631
|
-
|
632
|
-
|
633
|
-
|
634
|
-
|
635
|
-
|
636
|
-
|
637
|
-
|
638
|
-
|
639
|
-
|
640
|
-
|
641
|
-
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
|
647
|
-
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
652
|
-
|
653
|
-
|
654
|
-
|
655
|
-
|
656
|
-
|
657
|
-
|
658
|
-
|
659
|
-
|
660
|
-
|
661
|
-
|
662
|
-
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
|
667
|
-
|
668
|
-
|
669
|
-
|
670
|
-
|
671
|
-
|
672
|
-
|
673
|
-
|
674
|
-
|
675
|
-
|
676
|
-
|
677
|
-
|
678
|
-
|
679
|
-
|
680
|
-
|
681
|
-
|
682
|
-
|
683
|
-
|
684
|
-
|
685
|
-
|
686
|
-
|
687
|
-
|
688
|
-
|
689
|
-
|
690
|
-
|
691
|
-
|
692
|
-
|
693
|
-
|
694
|
-
|
695
|
-
|
696
|
-
|
697
|
-
|
698
|
-
|
699
|
-
|
700
|
-
|
701
|
-
|
702
|
-
|
703
|
-
|
704
|
-
|
705
|
-
|
706
|
-
|
707
|
-
|
708
|
-
|
709
|
-
|
710
|
-
|
711
|
-
|
712
|
-
|
713
|
-
|
714
|
-
|
715
|
-
|
716
|
-
|
717
|
-
|
718
|
-
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
|
724
|
-
|
725
|
-
|
726
|
-
|
727
|
-
|
728
|
-
|
729
|
-
|
730
|
-
|
731
|
-
|
732
|
-
|
733
|
-
|
734
|
-
|
735
|
-
|
736
|
-
|
737
|
-
|
738
|
-
|
739
|
-
|
740
|
-
|
741
|
-
|
742
|
-
|
743
|
-
|
744
|
-
|
745
|
-
|
746
|
-
|
747
|
-
|
748
|
-
|
749
|
-
|
750
|
-
|
751
|
-
|
752
|
-
|
753
|
-
|
754
|
-
|
755
|
-
|
756
|
-
|
757
|
-
|
758
|
-
|
759
|
-
|
760
|
-
|
761
|
-
|
762
|
-
|
763
|
-
|
764
|
-
|
765
|
-
|
766
|
-
|
767
|
-
|
768
|
-
|
769
|
-
|
770
|
-
|
771
|
-
|
772
|
-
|
773
|
-
|
774
|
-
|
775
|
-
|
776
|
-
|
777
|
-
|
778
|
-
|
779
|
-
|
780
|
-
|
781
|
-
|
782
|
-
|
783
|
-
|
784
|
-
|
785
|
-
|
786
|
-
|
787
|
-
|
788
|
-
|
789
|
-
|
790
|
-
|
791
|
-
|
792
|
-
|
793
|
-
|
794
|
-
|
795
|
-
|
796
|
-
|
797
|
-
|
798
|
-
|
799
|
-
|
800
|
-
|
801
|
-
|
802
|
-
|
803
|
-
|
804
|
-
|
805
|
-
|
806
|
-
|
807
|
-
|
808
|
-
|
809
|
-
|
810
|
-
|
811
|
-
|
812
|
-
|
813
|
-
|
814
|
-
|
815
|
-
|
816
|
-
|
817
|
-
|
818
|
-
|
819
|
-
|
820
|
-
|
821
|
-
|
822
|
-
|
823
|
-
|
824
|
-
|
825
|
-
|
826
|
-
|
827
|
-
|
828
|
-
|
829
|
-
|
830
|
-
|
831
|
-
|
832
|
-
|
833
|
-
|
834
|
-
|
835
|
-
|
836
|
-
|
837
|
-
|
838
|
-
|
839
|
-
|
840
|
-
|
841
|
-
|
842
|
-
|
843
|
-
|
844
|
-
|
845
|
-
|
846
|
-
|
847
|
-
|
848
|
-
|
849
|
-
|
850
|
-
|
851
|
-
|
852
|
-
|
853
|
-
|
854
|
-
|
855
|
-
|
856
|
-
|
857
|
-
|
858
|
-
|
859
|
-
|
860
|
-
|
861
|
-
|
862
|
-
|
863
|
-
|
864
|
-
|
865
|
-
|
866
|
-
|
867
|
-
|
868
|
-
|
869
|
-
|
870
|
-
|
871
|
-
|
872
|
-
|
873
|
-
|
874
|
-
|
875
|
-
|
876
|
-
|
877
|
-
|
878
|
-
|
879
|
-
|
880
|
-
|
881
|
-
|
882
|
-
|
883
|
-
|
884
|
-
|
885
|
-
|
886
|
-
|
887
|
-
|
888
|
-
|
889
|
-
|
890
|
-
|
891
|
-
|
892
|
-
|
893
|
-
|
894
|
-
|
895
|
-
|
171
|
+
|
172
|
+
return templates
|
173
|
+
|
174
|
+
def get_params(self, template: str = None, save_to: str = 'params.yaml') -> Dict[str, Any]:
|
175
|
+
"""Returns the model params for the model type and yaml file.
|
176
|
+
|
177
|
+
Args:
|
178
|
+
template (str): The template to use for the model type.
|
179
|
+
yaml_file (str): The yaml file to save the model params.
|
180
|
+
|
181
|
+
Returns:
|
182
|
+
params (Dict): Dictionary of model params for the model type.
|
183
|
+
|
184
|
+
Example:
|
185
|
+
>>> from clarifai.client.model import Model
|
186
|
+
>>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
|
187
|
+
>>> model_params = model.get_params(template='template', yaml_file='model_params.yaml')
|
188
|
+
"""
|
189
|
+
if not self.model_info.model_type_id:
|
190
|
+
self.load_info()
|
191
|
+
if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
|
192
|
+
raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
|
193
|
+
if template is None and self.model_info.model_type_id not in [
|
194
|
+
"clusterer",
|
195
|
+
"embedding-classifier",
|
196
|
+
]:
|
197
|
+
raise UserError(
|
198
|
+
f"Template should be provided for {self.model_info.model_type_id} model type"
|
199
|
+
)
|
200
|
+
if template is not None and self.model_info.model_type_id in [
|
201
|
+
"clusterer",
|
202
|
+
"embedding-classifier",
|
203
|
+
]:
|
204
|
+
raise UserError(
|
205
|
+
f"Template should not be provided for {self.model_info.model_type_id} model type"
|
206
|
+
)
|
207
|
+
|
208
|
+
request = service_pb2.ListModelTypesRequest(
|
209
|
+
user_app_id=self.user_app_id,
|
210
|
+
)
|
211
|
+
response = self._grpc_request(self.STUB.ListModelTypes, request)
|
212
|
+
if response.status.code != status_code_pb2.SUCCESS:
|
213
|
+
raise Exception(response.status)
|
214
|
+
params = response_to_model_params(
|
215
|
+
response=response, model_type_id=self.model_info.model_type_id, template=template
|
216
|
+
)
|
217
|
+
# yaml file
|
218
|
+
assert save_to.endswith('.yaml'), "File extension should be .yaml"
|
219
|
+
with open(save_to, 'w') as f:
|
220
|
+
yaml.dump(params, f, default_flow_style=False, sort_keys=False)
|
221
|
+
# updating the global model params
|
222
|
+
self.training_params.update(params)
|
223
|
+
|
224
|
+
return params
|
225
|
+
|
226
|
+
def update_params(self, **kwargs) -> None:
|
227
|
+
"""Updates the model params for the model.
|
228
|
+
|
229
|
+
Args:
|
230
|
+
**kwargs: model params to update.
|
231
|
+
|
232
|
+
Example:
|
233
|
+
>>> from clarifai.client.model import Model
|
234
|
+
>>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
|
235
|
+
>>> model_params = model.get_params(template='template', yaml_file='model_params.yaml')
|
236
|
+
>>> model.update_params(batch_size = 8, dataset_version = 'dataset_version_id')
|
237
|
+
"""
|
238
|
+
if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
|
239
|
+
raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
|
240
|
+
if len(self.training_params) == 0:
|
241
|
+
raise UserError(
|
242
|
+
f"Run 'model.get_params' to get the params for the {self.model_info.model_type_id} model type"
|
243
|
+
)
|
244
|
+
# getting all the keys in nested dictionary
|
245
|
+
all_keys = [key for key in self.training_params.keys()] + [
|
246
|
+
key for key in self.training_params.values() if isinstance(key, dict) for key in key
|
247
|
+
]
|
248
|
+
# checking if the given params are valid
|
249
|
+
if not set(kwargs.keys()).issubset(all_keys):
|
250
|
+
raise UserError("Invalid params")
|
251
|
+
# updating the global model params
|
252
|
+
for key, value in kwargs.items():
|
253
|
+
find_and_replace_key(self.training_params, key, value)
|
254
|
+
|
255
|
+
def get_param_info(self, param: str) -> Dict[str, Any]:
|
256
|
+
"""Returns the param info for the param.
|
257
|
+
|
258
|
+
Args:
|
259
|
+
param (str): The param to get the info for.
|
260
|
+
|
261
|
+
Returns:
|
262
|
+
param_info (Dict): Dictionary of model param info for the param.
|
263
|
+
|
264
|
+
Example:
|
265
|
+
>>> from clarifai.client.model import Model
|
266
|
+
>>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
|
267
|
+
>>> model_params = model.get_params(template='template', yaml_file='model_params.yaml')
|
268
|
+
>>> model.get_param_info('param')
|
269
|
+
"""
|
270
|
+
if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
|
271
|
+
raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
|
272
|
+
if len(self.training_params) == 0:
|
273
|
+
raise UserError(
|
274
|
+
f"Run 'model.get_params' to get the params for the {self.model_info.model_type_id} model type"
|
275
|
+
)
|
276
|
+
|
277
|
+
all_keys = [key for key in self.training_params.keys()] + [
|
278
|
+
key for key in self.training_params.values() if isinstance(key, dict) for key in key
|
279
|
+
]
|
280
|
+
if param not in all_keys:
|
281
|
+
raise UserError(
|
282
|
+
f"Invalid param: '{param}' for model type '{self.model_info.model_type_id}'"
|
283
|
+
)
|
284
|
+
template = (
|
285
|
+
self.training_params['train_params']['template'] if 'template' in all_keys else None
|
286
|
+
)
|
287
|
+
|
288
|
+
request = service_pb2.ListModelTypesRequest(
|
289
|
+
user_app_id=self.user_app_id,
|
290
|
+
)
|
291
|
+
response = self._grpc_request(self.STUB.ListModelTypes, request)
|
292
|
+
if response.status.code != status_code_pb2.SUCCESS:
|
293
|
+
raise Exception(response.status)
|
294
|
+
param_info = response_to_param_info(
|
295
|
+
response=response,
|
296
|
+
model_type_id=self.model_info.model_type_id,
|
297
|
+
param=param,
|
298
|
+
template=template,
|
299
|
+
)
|
300
|
+
|
301
|
+
return param_info
|
302
|
+
|
303
|
+
def train(self, yaml_file: str = None) -> str:
|
304
|
+
"""Trains the model based on the given yaml file or model params.
|
305
|
+
|
306
|
+
Args:
|
307
|
+
yaml_file (str): The yaml file for the model params.
|
308
|
+
|
309
|
+
Returns:
|
310
|
+
model_version_id (str): The model version ID for the model.
|
311
|
+
|
312
|
+
Example:
|
313
|
+
>>> from clarifai.client.model import Model
|
314
|
+
>>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
|
315
|
+
>>> model_params = model.get_params(template='template', yaml_file='model_params.yaml')
|
316
|
+
>>> model.train('model_params.yaml')
|
317
|
+
"""
|
318
|
+
if not self.model_info.model_type_id:
|
319
|
+
self.load_info()
|
320
|
+
if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
|
321
|
+
raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
|
322
|
+
if not yaml_file and len(self.training_params) == 0:
|
323
|
+
raise UserError("Provide yaml file or run 'model.get_params()'")
|
324
|
+
|
325
|
+
if yaml_file:
|
326
|
+
with open(yaml_file, 'r') as file:
|
327
|
+
params_dict = yaml.safe_load(file)
|
328
|
+
else:
|
329
|
+
params_dict = self.training_params
|
330
|
+
# getting all the concepts for the model type
|
331
|
+
if self.model_info.model_type_id not in ["clusterer", "text-to-text"]:
|
332
|
+
concepts = self._list_concepts()
|
333
|
+
train_dict = params_parser(params_dict, concepts)
|
334
|
+
request = service_pb2.PostModelVersionsRequest(
|
335
|
+
user_app_id=self.user_app_id,
|
336
|
+
model_id=self.id,
|
337
|
+
model_versions=[resources_pb2.ModelVersion(**train_dict)],
|
338
|
+
)
|
339
|
+
response = self._grpc_request(self.STUB.PostModelVersions, request)
|
340
|
+
if response.status.code != status_code_pb2.SUCCESS:
|
341
|
+
raise Exception(response.status)
|
342
|
+
self.logger.info("\nModel Training Started\n%s", response.status)
|
343
|
+
|
344
|
+
return response.model.model_version.id
|
345
|
+
|
346
|
+
def training_status(
|
347
|
+
self, version_id: str = None, training_logs: bool = False
|
348
|
+
) -> Dict[str, str]:
|
349
|
+
"""Get the training status for the model version. Also stores training logs
|
350
|
+
|
351
|
+
Args:
|
352
|
+
version_id (str): The version ID to get the training status for.
|
353
|
+
training_logs (bool): Whether to save the training logs in a file.
|
354
|
+
|
355
|
+
Returns:
|
356
|
+
training_status (Dict): Dictionary of training status for the model version.
|
357
|
+
|
358
|
+
Example:
|
359
|
+
>>> from clarifai.client.model import Model
|
360
|
+
>>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
|
361
|
+
>>> model.training_status(version_id='version_id',training_logs=True)
|
362
|
+
"""
|
363
|
+
if not version_id and not self.model_info.model_version.id:
|
364
|
+
raise UserError(
|
365
|
+
"Model version ID is missing. Please provide a `model_version` with a valid `id` as an argument or as a URL in the following format: '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
|
366
|
+
)
|
367
|
+
|
368
|
+
self.load_info()
|
369
|
+
if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
|
370
|
+
raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
|
371
|
+
|
372
|
+
if training_logs:
|
373
|
+
try:
|
374
|
+
if self.model_info.model_version.train_log:
|
375
|
+
log_response = requests.get(self.model_info.model_version.train_log)
|
376
|
+
log_response.raise_for_status() # Check for any HTTP errors
|
377
|
+
with open(version_id + '.log', 'wb') as file:
|
378
|
+
for chunk in log_response.iter_content(chunk_size=4096): # 4KB
|
379
|
+
file.write(chunk)
|
380
|
+
self.logger.info(f"\nTraining logs are saving in '{version_id + '.log'}' file")
|
381
|
+
|
382
|
+
except requests.exceptions.RequestException as e:
|
383
|
+
raise Exception(f"An error occurred while getting training logs: {e}")
|
384
|
+
|
385
|
+
return self.model_info.model_version.status
|
386
|
+
|
387
|
+
    def delete_version(self, version_id: str) -> None:
        """Deletes a model version for the Model.

        Args:
            version_id (str): The version ID to delete.

        Example:
            >>> from clarifai.client.model import Model
            >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
            >>> model.delete_version(version_id='version_id')
        """
        request = service_pb2.DeleteModelVersionRequest(
            user_app_id=self.user_app_id, model_id=self.id, version_id=version_id
        )

        response = self._grpc_request(self.STUB.DeleteModelVersion, request)
        if response.status.code != status_code_pb2.SUCCESS:
            raise Exception(response.status)
        self.logger.info("\nModel Version Deleted\n%s", response.status)
    def create_version(self, **kwargs) -> 'Model':
        """Creates a model version for the Model.

        Args:
            **kwargs: Additional keyword arguments to be passed to Model Version.
              - description (str): The description of the model version.
              - concepts (list[Concept]): The concepts to associate with the model version.
              - output_info (resources_pb2.OutputInfo): The output info to associate with the model version.

        Returns:
            Model: A Model object for the specified model ID.

        Example:
            >>> from clarifai.client.model import Model
            >>> model = Model("url")
                        or
            >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
            >>> model_version = model.create_version(description='model_version_description')
        """
        if self.model_info.model_type_id in TRAINABLE_MODEL_TYPES:
            if 'pretrained_model_config' not in kwargs:
                raise UserError(
                    f"{self.model_info.model_type_id} is a trainable model type. Use 'model.train()' to train the model"
                )

        request = service_pb2.PostModelVersionsRequest(
            user_app_id=self.user_app_id,
            model_id=self.id,
            model_versions=[resources_pb2.ModelVersion(**kwargs)],
        )

        response = self._grpc_request(self.STUB.PostModelVersions, request)
        if response.status.code != status_code_pb2.SUCCESS:
            raise Exception(response.status)
        self.logger.info("\nModel Version created\n%s", response.status)

        kwargs.update({'app_id': self.app_id, 'user_id': self.user_id})
        dict_response = MessageToDict(response, preserving_proto_field_name=True)
        kwargs = self.process_response_keys(dict_response['model'], 'model')

        return Model(base_url=self.base, pat=self.pat, token=self.token, **kwargs)
    def list_versions(
        self, page_no: int = None, per_page: int = None
    ) -> Generator['Model', None, None]:
        """Lists all the versions for the model.

        Args:
            page_no (int): The page number to list.
            per_page (int): The number of items per page.

        Yields:
            Model: Model objects for the versions of the model.

        Example:
            >>> from clarifai.client.model import Model
            >>> model = Model("url")  # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
                        or
            >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
            >>> all_model_versions = list(model.list_versions())

        Note:
            Defaults to 16 per page if page_no is specified and per_page is not specified.
            If both page_no and per_page are None, then lists all the resources.
        """
        request_data = dict(
            user_app_id=self.user_app_id,
            model_id=self.id,
        )
        all_model_versions_info = self.list_pages_generator(
            self.STUB.ListModelVersions,
            service_pb2.ListModelVersionsRequest,
            request_data,
            per_page=per_page,
            page_no=page_no,
        )

        for model_version_info in all_model_versions_info:
            model_version_info['id'] = model_version_info['model_version_id']
            del model_version_info['model_version_id']
            try:
                del model_version_info['train_info']['dataset']['version']['metrics']
            except KeyError:
                pass
            yield Model.from_auth_helper(
                auth=self.auth_helper,
                model_id=self.id,
                **dict(self.kwargs, model_version=model_version_info),
            )
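
A short sketch of walking the versions; each yielded object is a full `Model` bound to one version (IDs are placeholders, and the server-side ordering is an assumption to verify against your app):

```python
from clarifai.client.model import Model

model = Model(model_id='my-classifier', user_id='me', app_id='my-app')
for version in model.list_versions(page_no=1, per_page=5):
    print(version.model_version.id, version.model_version.status.code)
```
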
    @property
    def client(self):
        if self._client is None:
            request_template = service_pb2.PostModelOutputsRequest(
                user_app_id=self.user_app_id,
                model_id=self.id,
                version_id=self.model_version.id,
                model=self.model_info,
                runner_selector=self._runner_selector,
            )
            self._client = ModelClient(self.STUB, request_template=request_template)
        return self._client
    def predict(self, *args, **kwargs):
        """
        Calls the model's predict() method with the given arguments.

        If passed a list of resources_pb2.Input protos, sends the model the raw
        protos directly for compatibility with previous versions of the SDK.
        """

        inputs = None
        if 'inputs' in kwargs:
            inputs = kwargs['inputs']
        elif args:
            inputs = args[0]
        if inputs and isinstance(inputs, list) and isinstance(inputs[0], resources_pb2.Input):
            assert len(args) <= 1, (
                "Cannot pass in raw protos and additional arguments at the same time."
            )
            inference_params = kwargs.get('inference_params', {})
            output_config = kwargs.get('output_config', {})
            return self.client._predict_by_proto(
                inputs=inputs, inference_params=inference_params, output_config=output_config
            )

        return self.client.predict(*args, **kwargs)
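
A sketch of the raw-proto path above: a list of `resources_pb2.Input` protos routes through `_predict_by_proto`. The model URL, sample image URL, and `output_config` values are placeholders:

```python
from clarifai.client.model import Model
from clarifai.client.input import Inputs

model = Model("https://clarifai.com/clarifai/main/models/general-image-recognition")
proto = Inputs.get_input_from_url("", image_url="https://samples.clarifai.com/metro-north.jpg")
response = model.predict(inputs=[proto], output_config={'max_concepts': 3})
```
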
    def __getattr__(self, name):
        try:
            return getattr(self.model_info, name)
        except AttributeError:
            pass
        if not self._added_methods:
            # Fetch the model's method signatures once and bind them onto this instance.
            self._added_methods = True
            self.client.fetch()
            for method_name in self.client._method_signatures.keys():
                if not hasattr(self, method_name):
                    setattr(self, method_name, getattr(self.client, method_name))
        if hasattr(self.client, name):
            return getattr(self.client, name)
        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
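
A sketch of what this dynamic binding enables. `summarize` below is a hypothetical method name that a model might publish in its signatures; the first attribute miss on `Model`/`model_info` falls through to `client.fetch()` and binds whatever methods the model actually exposes:

```python
from clarifai.client.model import Model

model = Model("https://clarifai.com/openai/chat-completion/models/GPT-4")
# Hypothetical model-published method, resolved lazily via __getattr__.
result = model.summarize(text="Clarifai model methods are fetched lazily.")
```
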
    def _check_predict_input_type(self, input_type: str) -> None:
        """Checks if the input type is valid for the model.

        Args:
            input_type (str): The input type to check.
        Returns:
            None
        """
        if not input_type:
            self.load_input_types()
            if len(self.input_types) > 1:
                raise UserError(
                    "Model has multiple input types. Please use model.predict() for this multi-modal model."
                )
        else:
            self.input_types = [input_type]
            if self.input_types[0] not in {'image', 'text', 'video', 'audio'}:
                raise UserError(
                    f"Got input type {input_type} but expected one of image, text, video, audio."
                )
    def load_input_types(self) -> None:
        """Loads the input types for the model.

        Returns:
            None

        Example:
            >>> from clarifai.client.model import Model
            >>> model = Model("url")  # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
                        or
            >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
            >>> model.load_input_types()
        """
        if self.input_types:
            return self.input_types
        if self.model_info.model_type_id == "":
            self.load_info()
        request = service_pb2.GetModelTypeRequest(
            user_app_id=self.user_app_id,
            model_type_id=self.model_info.model_type_id,
        )
        response = self._grpc_request(self.STUB.GetModelType, request)
        if response.status.code != status_code_pb2.SUCCESS:
            raise Exception(response.status)
        self.input_types = response.model_type.input_fields
    def _set_runner_selector(
        self,
        compute_cluster_id: str = None,
        nodepool_id: str = None,
        deployment_id: str = None,
        user_id: str = None,
    ):
        runner_selector = None
        if deployment_id and (compute_cluster_id or nodepool_id):
            raise UserError(
                "You can only specify one of deployment_id or compute_cluster_id and nodepool_id."
            )

        if deployment_id:
            if not user_id and not os.environ.get('CLARIFAI_USER_ID'):
                raise UserError(
                    "User ID is required for model prediction with deployment ID, please provide user_id in the method call."
                )
            if not user_id:
                user_id = os.environ.get('CLARIFAI_USER_ID')
            runner_selector = Deployment.get_runner_selector(
                user_id=user_id, deployment_id=deployment_id
            )
        elif compute_cluster_id and nodepool_id:
            if not user_id and not os.environ.get('CLARIFAI_USER_ID'):
                raise UserError(
                    "User ID is required for model prediction with compute cluster ID and nodepool ID, please provide user_id in the method call."
                )
            if not user_id:
                user_id = os.environ.get('CLARIFAI_USER_ID')
            runner_selector = Nodepool.get_runner_selector(
                user_id=user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id
            )

        # set the runner selector
        self._runner_selector = runner_selector
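
A sketch of the two mutually exclusive routing options this helper builds. Whether `Model(...)` exposes these as constructor kwargs is an assumption here, so treat the call shape as illustrative; the IDs are placeholders:

```python
from clarifai.client.model import Model

# Option 1: route inference through a dedicated deployment.
m1 = Model("https://clarifai.com/me/my-app/models/my-model",
           deployment_id='my-deployment', user_id='me')

# Option 2: pin inference to a compute cluster + nodepool pair.
m2 = Model("https://clarifai.com/me/my-app/models/my-model",
           compute_cluster_id='my-cluster', nodepool_id='my-nodepool', user_id='me')

# Passing deployment_id together with cluster/nodepool IDs raises UserError.
```
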
    def predict_by_filepath(
        self,
        filepath: str,
        input_type: str = None,
        inference_params: Dict = {},
        output_config: Dict = {},
    ):
        """Predicts the model based on the given filepath.

        Args:
            filepath (str): The filepath to predict.
            input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio'.
            inference_params (dict): The inference params to override.
            output_config (dict): The output config to override.
              min_value (float): The minimum value of the prediction confidence to filter.
              max_concepts (int): The maximum number of concepts to return.
              select_concepts (list[Concept]): The concepts to select.

        Example:
            >>> from clarifai.client.model import Model
            >>> model = Model("url")  # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
                        or
            >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
            >>> model_prediction = model.predict_by_filepath('/path/to/image.jpg')
            >>> model_prediction = model.predict_by_filepath('/path/to/text.txt')
        """
        if not os.path.isfile(filepath):
            raise UserError('Invalid filepath.')

        with open(filepath, "rb") as f:
            file_bytes = f.read()

        return self.predict_by_bytes(file_bytes, input_type, inference_params, output_config)
    def predict_by_bytes(
        self,
        input_bytes: bytes,
        input_type: str = None,
        inference_params: Dict = {},
        output_config: Dict = {},
    ):
        """Predicts the model based on the given bytes.

        Args:
            input_bytes (bytes): File Bytes to predict on.
            input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio'.
            inference_params (dict): The inference params to override.
            output_config (dict): The output config to override.
              min_value (float): The minimum value of the prediction confidence to filter.
              max_concepts (int): The maximum number of concepts to return.
              select_concepts (list[Concept]): The concepts to select.

        Example:
            >>> from clarifai.client.model import Model
            >>> model = Model("https://clarifai.com/openai/chat-completion/models/GPT-4")
            >>> model_prediction = model.predict_by_bytes(b'Write a tweet on future of AI',
                                                          inference_params=dict(temperature=str(0.7), max_tokens=30))
        """
        self._check_predict_input_type(input_type)

        if self.input_types[0] == "image":
            input_proto = Inputs.get_input_from_bytes("", image_bytes=input_bytes)
        elif self.input_types[0] == "text":
            input_proto = Inputs.get_input_from_bytes("", text_bytes=input_bytes)
        elif self.input_types[0] == "video":
            input_proto = Inputs.get_input_from_bytes("", video_bytes=input_bytes)
        elif self.input_types[0] == "audio":
            input_proto = Inputs.get_input_from_bytes("", audio_bytes=input_bytes)

        return self.predict(
            inputs=[input_proto], inference_params=inference_params, output_config=output_config
        )
    def predict_by_url(
        self,
        url: str,
        input_type: str = None,
        inference_params: Dict = {},
        output_config: Dict = {},
    ):
        """Predicts the model based on the given URL.

        Args:
            url (str): The URL to predict.
            input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio'.
            inference_params (dict): The inference params to override.
            output_config (dict): The output config to override.
              min_value (float): The minimum value of the prediction confidence to filter.
              max_concepts (int): The maximum number of concepts to return.
              select_concepts (list[Concept]): The concepts to select.

        Example:
            >>> from clarifai.client.model import Model
            >>> model = Model("url")  # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
                        or
            >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
            >>> model_prediction = model.predict_by_url('url')
        """
        self._check_predict_input_type(input_type)

        if self.input_types[0] == "image":
            input_proto = Inputs.get_input_from_url("", image_url=url)
        elif self.input_types[0] == "text":
            input_proto = Inputs.get_input_from_url("", text_url=url)
        elif self.input_types[0] == "video":
            input_proto = Inputs.get_input_from_url("", video_url=url)
        elif self.input_types[0] == "audio":
            input_proto = Inputs.get_input_from_url("", audio_url=url)

        return self.predict(
            inputs=[input_proto], inference_params=inference_params, output_config=output_config
        )
    def generate(self, *args, **kwargs):
        """
        Calls the model's generate() method with the given arguments.

        If passed a list of resources_pb2.Input protos, sends the model the raw
        protos directly for compatibility with previous versions of the SDK.
        """

        inputs = None
        if 'inputs' in kwargs:
            inputs = kwargs['inputs']
        elif args:
            inputs = args[0]
        if inputs and isinstance(inputs, list) and isinstance(inputs[0], resources_pb2.Input):
            assert len(args) <= 1, (
                "Cannot pass in raw protos and additional arguments at the same time."
            )
            inference_params = kwargs.get('inference_params', {})
            output_config = kwargs.get('output_config', {})
            return self.client._generate_by_proto(
                inputs=inputs, inference_params=inference_params, output_config=output_config
            )

        return self.client.generate(*args, **kwargs)
    def generate_by_filepath(
        self,
        filepath: str,
        input_type: str = None,
        inference_params: Dict = {},
        output_config: Dict = {},
    ):
        """Generate stream output from the model based on the given filepath.

        Args:
            filepath (str): The filepath to predict.
            input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio'.
            inference_params (dict): The inference params to override.
            output_config (dict): The output config to override.
              min_value (float): The minimum value of the prediction confidence to filter.
              max_concepts (int): The maximum number of concepts to return.
              select_concepts (list[Concept]): The concepts to select.

        Example:
            >>> from clarifai.client.model import Model
            >>> model = Model("url")  # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
                        or
            >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
            >>> stream_response = model.generate_by_filepath('/path/to/image.jpg', 'image')
            >>> list_stream_response = [response for response in stream_response]
        """
        if not os.path.isfile(filepath):
            raise UserError('Invalid filepath.')

        with open(filepath, "rb") as f:
            file_bytes = f.read()

        return self.generate_by_bytes(
            input_bytes=file_bytes,
            input_type=input_type,
            inference_params=inference_params,
            output_config=output_config,
        )
    def generate_by_bytes(
        self,
        input_bytes: bytes,
        input_type: str = None,
        inference_params: Dict = {},
        output_config: Dict = {},
    ):
        """Generate stream output from the model based on the given bytes.

        Args:
            input_bytes (bytes): File Bytes to predict on.
            input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio'.
            inference_params (dict): The inference params to override.
            output_config (dict): The output config to override.
              min_value (float): The minimum value of the prediction confidence to filter.
              max_concepts (int): The maximum number of concepts to return.
              select_concepts (list[Concept]): The concepts to select.

        Example:
            >>> from clarifai.client.model import Model
            >>> model = Model("https://clarifai.com/openai/chat-completion/models/GPT-4")
            >>> stream_response = model.generate_by_bytes(b'Write a tweet on future of AI',
                                                          inference_params=dict(temperature=str(0.7), max_tokens=30))
            >>> list_stream_response = [response for response in stream_response]
        """
        self._check_predict_input_type(input_type)

        if self.input_types[0] == "image":
            input_proto = Inputs.get_input_from_bytes("", image_bytes=input_bytes)
        elif self.input_types[0] == "text":
            input_proto = Inputs.get_input_from_bytes("", text_bytes=input_bytes)
        elif self.input_types[0] == "video":
            input_proto = Inputs.get_input_from_bytes("", video_bytes=input_bytes)
        elif self.input_types[0] == "audio":
            input_proto = Inputs.get_input_from_bytes("", audio_bytes=input_bytes)

        return self.generate(
            inputs=[input_proto], inference_params=inference_params, output_config=output_config
        )
    def generate_by_url(
        self,
        url: str,
        input_type: str = None,
        inference_params: Dict = {},
        output_config: Dict = {},
    ):
        """Generate stream output from the model based on the given URL.

        Args:
            url (str): The URL to predict.
            input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio'.
            inference_params (dict): The inference params to override.
            output_config (dict): The output config to override.
              min_value (float): The minimum value of the prediction confidence to filter.
              max_concepts (int): The maximum number of concepts to return.
              select_concepts (list[Concept]): The concepts to select.

        Example:
            >>> from clarifai.client.model import Model
            >>> model = Model("url")  # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
                        or
            >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
            >>> stream_response = model.generate_by_url('url')
            >>> list_stream_response = [response for response in stream_response]
        """
        self._check_predict_input_type(input_type)

        if self.input_types[0] == "image":
            input_proto = Inputs.get_input_from_url("", image_url=url)
        elif self.input_types[0] == "text":
            input_proto = Inputs.get_input_from_url("", text_url=url)
        elif self.input_types[0] == "video":
            input_proto = Inputs.get_input_from_url("", video_url=url)
        elif self.input_types[0] == "audio":
            input_proto = Inputs.get_input_from_url("", audio_url=url)

        return self.generate(
            inputs=[input_proto], inference_params=inference_params, output_config=output_config
        )
    def stream(self, *args, **kwargs):
        """
        Calls the model's stream() method with the given arguments.

        If passed an iterable of resources_pb2.Input lists, sends the model the raw
        protos directly for compatibility with previous versions of the SDK.
        """

        use_proto_call = False
        inputs = None
        if 'inputs' in kwargs:
            inputs = kwargs['inputs']
        elif args:
            inputs = args[0]
        if inputs and isinstance(inputs, Iterable):
            # iter() so that non-iterator iterables (e.g. lists) can be peeked safely;
            # for iterators, iter() returns the same object, which the check below relies on.
            inputs_iter = iter(inputs)
            try:
                peek = next(inputs_iter)
            except StopIteration:
                pass
            else:
                use_proto_call = (
                    peek and isinstance(peek, list) and isinstance(peek[0], resources_pb2.Input)
                )
                # put back the peeked value
                if inputs_iter is inputs:
                    inputs = itertools.chain([peek], inputs_iter)
                    if 'inputs' in kwargs:
                        kwargs['inputs'] = inputs
                    else:
                        args = (inputs,) + args[1:]

        if use_proto_call:
            assert len(args) <= 1, (
                "Cannot pass in raw protos and additional arguments at the same time."
            )
            inference_params = kwargs.get('inference_params', {})
            output_config = kwargs.get('output_config', {})
            return self.client._stream_by_proto(
                inputs=inputs, inference_params=inference_params, output_config=output_config
            )

        return self.client.stream(*args, **kwargs)
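
The peek-then-chain idiom above is easy to get wrong, so here is a standalone sketch of it (names are illustrative, not part of the SDK):

```python
import itertools

def first_item_and_rest(stream):
    """Peek one item from any iterable without losing it."""
    it = iter(stream)
    try:
        first = next(it)
    except StopIteration:
        return None, iter(())
    # Re-attach the consumed item so callers still see the full stream.
    return first, itertools.chain([first], it)

first, stream = first_item_and_rest(x * x for x in range(5))
print(first)         # 0
print(list(stream))  # [0, 1, 4, 9, 16]
```
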
    def stream_by_filepath(
        self,
        filepath: str,
        input_type: str = None,
        inference_params: Dict = {},
        output_config: Dict = {},
    ):
        """Stream the model output based on the given filepath.

        Args:
            filepath (str): The filepath to predict.
            input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio'.
            inference_params (dict): The inference params to override.
            output_config (dict): The output config to override.
              min_value (float): The minimum value of the prediction confidence to filter.
              max_concepts (int): The maximum number of concepts to return.
              select_concepts (list[Concept]): The concepts to select.

        Example:
            >>> from clarifai.client.model import Model
            >>> model = Model("url")
            >>> stream_response = model.stream_by_filepath('/path/to/image.jpg')
            >>> list_stream_response = [response for response in stream_response]
        """
        if not os.path.isfile(filepath):
            raise UserError('Invalid filepath.')

        with open(filepath, "rb") as f:
            file_bytes = f.read()

        return self.stream_by_bytes(
            input_bytes_iterator=iter([file_bytes]),
            input_type=input_type,
            inference_params=inference_params,
            output_config=output_config,
        )
    def stream_by_bytes(
        self,
        input_bytes_iterator: Iterator[bytes],
        input_type: str = None,
        inference_params: Dict = {},
        output_config: Dict = {},
    ):
        """Stream the model output based on the given bytes.

        Args:
            input_bytes_iterator (Iterator[bytes]): Iterator of file bytes to predict on.
            input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio'.
            inference_params (dict): The inference params to override.
            output_config (dict): The output config to override.
              min_value (float): The minimum value of the prediction confidence to filter.
              max_concepts (int): The maximum number of concepts to return.
              select_concepts (list[Concept]): The concepts to select.

        Example:
            >>> from clarifai.client.model import Model
            >>> model = Model("https://clarifai.com/openai/chat-completion/models/GPT-4")
            >>> stream_response = model.stream_by_bytes(iter([b'Write a tweet on future of AI']),
                                                        inference_params=dict(temperature=str(0.7), max_tokens=30))
            >>> list_stream_response = [response for response in stream_response]
        """
        self._check_predict_input_type(input_type)

        def input_generator():
            for input_bytes in input_bytes_iterator:
                if self.input_types[0] == "image":
                    yield [Inputs.get_input_from_bytes("", image_bytes=input_bytes)]
                elif self.input_types[0] == "text":
                    yield [Inputs.get_input_from_bytes("", text_bytes=input_bytes)]
                elif self.input_types[0] == "video":
                    yield [Inputs.get_input_from_bytes("", video_bytes=input_bytes)]
                elif self.input_types[0] == "audio":
                    yield [Inputs.get_input_from_bytes("", audio_bytes=input_bytes)]

        return self.stream(
            inputs=input_generator(),
            inference_params=inference_params,
            output_config=output_config,
        )
    def stream_by_url(
        self,
        url_iterator: Iterator[str],
        input_type: str = None,
        inference_params: Dict = {},
        output_config: Dict = {},
    ):
        """Stream the model output based on the given URL.

        Args:
            url_iterator (Iterator[str]): Iterator of URLs to predict.
            input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio'.
            inference_params (dict): The inference params to override.
            output_config (dict): The output config to override.
              min_value (float): The minimum value of the prediction confidence to filter.
              max_concepts (int): The maximum number of concepts to return.
              select_concepts (list[Concept]): The concepts to select.

        Example:
            >>> from clarifai.client.model import Model
            >>> model = Model("url")
            >>> stream_response = model.stream_by_url(iter(['url']))
            >>> list_stream_response = [response for response in stream_response]
        """
        self._check_predict_input_type(input_type)

        def input_generator():
            for url in url_iterator:
                if self.input_types[0] == "image":
                    yield [Inputs.get_input_from_url("", image_url=url)]
                elif self.input_types[0] == "text":
                    yield [Inputs.get_input_from_url("", text_url=url)]
                elif self.input_types[0] == "video":
                    yield [Inputs.get_input_from_url("", video_url=url)]
                elif self.input_types[0] == "audio":
                    yield [Inputs.get_input_from_url("", audio_url=url)]

        return self.stream(
            inputs=input_generator(),
            inference_params=inference_params,
            output_config=output_config,
        )
    def _override_model_version(
        self, inference_params: Dict = {}, output_config: Dict = {}
    ) -> None:
        """Overrides the model version.

        Args:
            inference_params (dict): The inference params to override.
            output_config (dict): The output config to override.
              min_value (float): The minimum value of the prediction confidence to filter.
              max_concepts (int): The maximum number of concepts to return.
              select_concepts (list[Concept]): The concepts to select.
              sample_ms (int): The number of milliseconds to sample.
        """
        params = Struct()
        if inference_params is not None:
            params.update(inference_params)

        self.model_info.model_version.output_info.CopyFrom(
            resources_pb2.OutputInfo(
                output_config=resources_pb2.OutputConfig(**output_config), params=params
            )
        )
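
A sketch of how the overrides above are encoded: inference params go into a protobuf `Struct`, output settings into `OutputConfig` (the values are placeholders):

```python
from google.protobuf.struct_pb2 import Struct
from clarifai_grpc.grpc.api import resources_pb2

params = Struct()
params.update({'temperature': 0.7, 'max_tokens': 30})
output_info = resources_pb2.OutputInfo(
    output_config=resources_pb2.OutputConfig(max_concepts=5, min_value=0.5),
    params=params,
)
```
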
    def _list_concepts(self) -> List[str]:
        """Lists all the concepts for the model type.

        Returns:
            concepts (List): List of concepts for the model type.
        """
        request_data = dict(user_app_id=self.user_app_id)
        all_concepts_infos = self.list_pages_generator(
            self.STUB.ListConcepts, service_pb2.ListConceptsRequest, request_data
        )
        return [concept_info['concept_id'] for concept_info in all_concepts_infos]
    def load_info(self) -> None:
        """Loads the model info."""
        request = service_pb2.GetModelRequest(
            user_app_id=self.user_app_id,
            model_id=self.id,
            version_id=self.model_info.model_version.id,
        )
        response = self._grpc_request(self.STUB.GetModel, request)

        if response.status.code != status_code_pb2.SUCCESS:
            raise Exception(response.status)

        dict_response = MessageToDict(response, preserving_proto_field_name=True)
        self.kwargs = self.process_response_keys(dict_response['model'])
        self.model_info = resources_pb2.Model()
        dict_to_protobuf(self.model_info, self.kwargs)
    def __str__(self):
        if len(self.kwargs) < 10:
            self.load_info()

        init_params = list(self.kwargs.keys())
        attribute_strings = [
            f"{param}={getattr(self.model_info, param)}"
            for param in init_params
            if hasattr(self.model_info, param)
        ]
        return f"Model Details: \n{', '.join(attribute_strings)}\n"
    def list_evaluations(self) -> resources_pb2.EvalMetrics:
        """List all eval_metrics of the current model version.

        Raises:
            Exception: Failed to call API

        Returns:
            resources_pb2.EvalMetrics
        """
        assert self.model_info.model_version.id, (
            "Model version ID is missing. Please provide a `model_version` with a valid `id` as an argument or as a URL in the following format: '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
        )
        request = service_pb2.ListModelVersionEvaluationsRequest(
            user_app_id=self.user_app_id,
            model_id=self.id,
            model_version_id=self.model_info.model_version.id,
        )
        response = self._grpc_request(self.STUB.ListModelVersionEvaluations, request)

        if response.status.code != status_code_pb2.SUCCESS:
            raise Exception(response.status)

        return response.eval_metrics
    def evaluate(
        self,
        dataset: Dataset = None,
        dataset_id: str = None,
        dataset_app_id: str = None,
        dataset_user_id: str = None,
        dataset_version_id: str = None,
        eval_id: str = None,
        extended_metrics: dict = None,
        eval_info: dict = None,
    ) -> resources_pb2.EvalMetrics:
        """Run evaluation.

        Args:
            dataset (Dataset): If a Clarifai Dataset is set, other arguments prefixed with 'dataset_' are ignored.
            dataset_id (str): Dataset ID. Default is None.
            dataset_app_id (str): App ID for cross-app evaluation; leave as None to use the Model App ID. Default is None.
            dataset_user_id (str): User ID for cross-app evaluation; leave as None to use the Model User ID. Default is None.
            dataset_version_id (str): Dataset version ID. Default is None.
            eval_id (str): Specific ID for the evaluation. Specify this parameter either to overwrite the result for a dataset or to name your evaluation informatively; otherwise a random system-generated ID is used. Default is None.
            extended_metrics (dict): User-defined custom metrics. Default is None.
            eval_info (dict): Custom eval info. Default is None.

        Returns:
            eval_metrics
        """
        assert self.model_info.model_version.id, (
            "Model version ID is missing. Please provide a `model_version` with a valid `id` as an argument or as a URL in the following format: '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
        )

        if dataset:
            self.logger.info("Using dataset, ignoring other arguments prefixed with 'dataset_'")
            dataset_id = dataset.id
            dataset_app_id = dataset.app_id
            dataset_user_id = dataset.user_id
            dataset_version_id = dataset.version.id
        else:
            self.logger.warning(
                "Arguments prefixed with `dataset_` will be removed soon, please use dataset"
            )

        gt_dataset = resources_pb2.Dataset(
            id=dataset_id,
            app_id=dataset_app_id or self.auth_helper.app_id,
            user_id=dataset_user_id or self.auth_helper.user_id,
            version=resources_pb2.DatasetVersion(id=dataset_version_id),
        )

        metrics = None
        if isinstance(extended_metrics, dict):
            metrics = Struct()
            metrics.update(extended_metrics)
            metrics = resources_pb2.ExtendedMetrics(user_metrics=metrics)

        eval_info_params = None
        if isinstance(eval_info, dict):
            eval_info_params = Struct()
            eval_info_params.update(eval_info)
            eval_info_params = resources_pb2.EvalInfo(params=eval_info_params)

        eval_metric = resources_pb2.EvalMetrics(
            id=eval_id,
            model=resources_pb2.Model(
                id=self.id,
                app_id=self.auth_helper.app_id,
                user_id=self.auth_helper.user_id,
                model_version=resources_pb2.ModelVersion(id=self.model_info.model_version.id),
            ),
            extended_metrics=metrics,
            ground_truth_dataset=gt_dataset,
            eval_info=eval_info_params,
        )
        request = service_pb2.PostEvaluationsRequest(
            user_app_id=self.user_app_id,
            eval_metrics=[eval_metric],
        )
        response = self._grpc_request(self.STUB.PostEvaluations, request)
        if response.status.code != status_code_pb2.SUCCESS:
            raise Exception(response.status)
        self.logger.info(
            "\nModel evaluation in progress. Kindly allow a few minutes for completion. Processing time may vary based on the model and dataset sizes."
        )

        return response.eval_metrics
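
A short sketch of kicking off an evaluation against a dataset and reading back the finished metrics later. The `Dataset(...)` constructor shape mirrors `Model(...)` and is an assumption here; IDs are placeholders:

```python
from clarifai.client.dataset import Dataset
from clarifai.client.model import Model

model = Model("https://clarifai.com/me/my-app/models/my-model/model_version_id/abc123")
dataset = Dataset(dataset_id='eval-set', user_id='me', app_id='my-app')
model.evaluate(dataset=dataset, eval_id='eval-set-run-1')
# Later, once the asynchronous evaluation has completed:
metrics = model.get_latest_eval(confusion_matrix=True)
```
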
    def get_eval_by_id(
        self,
        eval_id: str,
        label_counts=False,
        test_set=False,
        binary_metrics=False,
        confusion_matrix=False,
        metrics_by_class=False,
        metrics_by_area=False,
    ) -> resources_pb2.EvalMetrics:
        """Get detailed eval_metrics by eval_id with extra metric fields.

        Args:
            eval_id (str): eval id
            label_counts (bool, optional): Set True to get label counts. Defaults to False.
            test_set (bool, optional): Set True to get test set. Defaults to False.
            binary_metrics (bool, optional): Set True to get binary metrics. Defaults to False.
            confusion_matrix (bool, optional): Set True to get confusion matrix. Defaults to False.
            metrics_by_class (bool, optional): Set True to get metrics by class. Defaults to False.
            metrics_by_area (bool, optional): Set True to get metrics by area. Defaults to False.

        Raises:
            Exception: Failed to call API

        Returns:
            resources_pb2.EvalMetrics: eval_metrics
        """
        request = service_pb2.GetEvaluationRequest(
            user_app_id=self.user_app_id,
            evaluation_id=eval_id,
            fields=resources_pb2.FieldsValue(
                label_counts=label_counts,
                test_set=test_set,
                binary_metrics=binary_metrics,
                confusion_matrix=confusion_matrix,
                metrics_by_class=metrics_by_class,
                metrics_by_area=metrics_by_area,
            ),
        )
        response = self._grpc_request(self.STUB.GetEvaluation, request)

        if response.status.code != status_code_pb2.SUCCESS:
            raise Exception(response.status)

        return response.eval_metrics
    def get_latest_eval(
        self,
        label_counts=False,
        test_set=False,
        binary_metrics=False,
        confusion_matrix=False,
        metrics_by_class=False,
        metrics_by_area=False,
    ) -> Union[resources_pb2.EvalMetrics, None]:
        """
        Run `get_eval_by_id` method with the latest `eval_id`.

        Args:
            label_counts (bool, optional): Set True to get label counts. Defaults to False.
            test_set (bool, optional): Set True to get test set. Defaults to False.
            binary_metrics (bool, optional): Set True to get binary metrics. Defaults to False.
            confusion_matrix (bool, optional): Set True to get confusion matrix. Defaults to False.
            metrics_by_class (bool, optional): Set True to get metrics by class. Defaults to False.
            metrics_by_area (bool, optional): Set True to get metrics by area. Defaults to False.

        Returns:
            eval_metric if the model has been evaluated, otherwise None.
        """
        evaluations = self.list_evaluations()
        if not evaluations:
            # No evaluation has ever been run for this version; match the
            # documented behavior instead of raising IndexError.
            return None
        _latest = evaluations[0]
        result = None
        if _latest.status.code == status_code_pb2.MODEL_EVALUATED:
            result = self.get_eval_by_id(
                eval_id=_latest.id,
                label_counts=label_counts,
                test_set=test_set,
                binary_metrics=binary_metrics,
                confusion_matrix=confusion_matrix,
                metrics_by_class=metrics_by_class,
                metrics_by_area=metrics_by_area,
            )

        return result
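
A sketch of consuming the result; the exact field paths inside the returned `EvalMetrics` proto (e.g. `summary`) are assumptions to verify against the Clarifai protos:

```python
from clarifai.client.model import Model

model = Model("https://clarifai.com/me/my-app/models/my-model/model_version_id/abc123")
metrics = model.get_latest_eval(confusion_matrix=True, metrics_by_class=True)
if metrics is None:
    print("Model version has not been evaluated yet.")
else:
    print(metrics.summary)  # aggregate metrics on the EvalMetrics proto
```
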
    def get_eval_by_dataset(self, dataset: Dataset) -> List[resources_pb2.EvalMetrics]:
        """Get all eval data for a dataset.

        Args:
            dataset (Dataset): Clarifai dataset

        Returns:
            List[resources_pb2.EvalMetrics]
        """
        _id = dataset.id
        app = dataset.app_id or self.app_id
        user_id = dataset.user_id or self.user_id
        version = dataset.version.id

        list_eval: resources_pb2.EvalMetrics = self.list_evaluations()
        outputs = []
        for _eval in list_eval:
            if _eval.status.code == status_code_pb2.MODEL_EVALUATED:
                gt_ds = _eval.ground_truth_dataset
                if _id == gt_ds.id and user_id == gt_ds.user_id and app == gt_ds.app_id:
                    if not version or version == gt_ds.version.id:
                        outputs.append(_eval)

        return outputs
    def get_raw_eval(
        self, dataset: Dataset = None, eval_id: str = None, return_format: str = 'array'
    ) -> Union[
        resources_pb2.EvalTestSetEntry,
        Tuple[np.array, np.array, list, List[Input]],
        Tuple[List[dict], List[dict]],
    ]:
        """Get ground truths, predictions and input information. Do not pass dataset and eval_id at the same time.

        Args:
            dataset (Dataset): Clarifai dataset; gets eval data of the latest eval result for the dataset.
            eval_id (str): Evaluation ID; gets eval data for a specific eval id.
            return_format (str, optional): Choice of {proto, array, coco}. Note that `coco` is only applicable for 'visual-detector'. Defaults to 'array'.

        Returns:

            Depends on `return_format`.

            * if return_format == proto
              `resources_pb2.EvalTestSetEntry`

            * if return_format == array
              `Tuple(np.array, np.array, List[str], List[Input])`: Tuple has 4 elements (y, y_pred, concept_ids, inputs).
              y, y_pred and concept_ids can be used to compute metrics; 'inputs' can be used to download the input data.
              - if model is 'classifier': 'y' and 'y_pred' are both arrays with a shape of (num_inputs,)
              - if model is 'visual-detector': 'y' and 'y_pred' are arrays with a shape of (num_inputs,), where each element is an array of shape (num_annotations, 6) consisting of [x_min, y_min, x_max, y_max, concept_index, score]. The score is always 1 for 'y'.

            * if return_format == coco: Applicable only for 'visual-detector'
              `Tuple[List[Dict], List[Dict]]`: Tuple has 2 elements, where the first element is the COCO Ground Truth and the last one is the COCO Prediction Annotation.

        Example Usages:
        ------
        * Evaluate `visual-classifier` using sklearn

        ```python
        import os
        from sklearn.metrics import accuracy_score
        from sklearn.metrics import classification_report
        import numpy as np
        from clarifai.client.model import Model
        from clarifai.client.dataset import Dataset
        os.environ["CLARIFAI_PAT"] = "???"
        model = Model(url="url/of/model/includes/version-id")
        dataset = Dataset(dataset_id="dataset-id")
        y, y_pred, clss, input_protos = model.get_raw_eval(dataset, return_format="array")
        y = np.argmax(y, axis=1)
        y_pred = np.argmax(y_pred, axis=1)
        report = classification_report(y, y_pred, target_names=clss)
        print(report)
        acc = accuracy_score(y, y_pred)
        print("acc ", acc)
        ```

        * Evaluate `visual-detector` using COCOeval

        ```python
        import os
        import json
        from pycocotools.coco import COCO
        from pycocotools.cocoeval import COCOeval
        from clarifai.client.model import Model
        from clarifai.client.dataset import Dataset
        os.environ["CLARIFAI_PAT"] = "???"  # Insert your PAT
        model = Model(url=model_url)
        dataset = Dataset(url=dataset_url)
        y, y_pred = model.get_raw_eval(dataset, return_format="coco")
        # save as files to load in COCO API
        def save_annot(d, path):
            with open(path, "w") as fp:
                json.dump(d, fp, indent=2)
        gt_path = os.path.join("gt.json")
        pred_path = os.path.join("pred.json")
        save_annot(y, gt_path)
        save_annot(y_pred, pred_path)

        cocoGt = COCO(gt_path)
        cocoPred = COCO(pred_path)
        cocoEval = COCOeval(cocoGt, cocoPred, "bbox")
        cocoEval.evaluate()
        cocoEval.accumulate()
        cocoEval.summarize()  # Print out results for all classes with all area types
        # Example:
        # Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.863
        # Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.973
        # Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.939
        # ...
        ```

        """
        from clarifai.utils.evaluation.testset_annotation_parser import (
            parse_eval_annotation_classifier,
            parse_eval_annotation_detector,
            parse_eval_annotation_detector_coco,
        )

        valid_model_types = ["visual-classifier", "text-classifier", "visual-detector"]
        supported_format = ['proto', 'array', 'coco']
        assert return_format in supported_format, ValueError(
            f"Expected return_format in {supported_format}, got {return_format}"
        )
        self.load_info()
        model_type_id = self.model_info.model_type_id
        assert model_type_id in valid_model_types, (
            f"This method only supports model types {valid_model_types}, but your model type is {self.model_info.model_type_id}."
        )
        assert not (dataset and eval_id), (
            "Using both `dataset` and `eval_id`, but only one should be passed."
        )
        assert dataset or eval_id, (
            "Please provide either `dataset` or `eval_id`, but nothing was passed."
        )
        if model_type_id.endswith("-classifier") and return_format == "coco":
            raise ValueError(
                f"return_format coco only applies to `visual-detector`, however your model is `{model_type_id}`"
            )

        if dataset:
            eval_by_ds = self.get_eval_by_dataset(dataset)
            if len(eval_by_ds) == 0:
                raise Exception(f"Model is not evaluated with dataset: {dataset}")
            eval_id = eval_by_ds[0].id

        detail_eval_data = self.get_eval_by_id(
            eval_id=eval_id, test_set=True, metrics_by_class=True
        )

        if return_format == "proto":
            return detail_eval_data.test_set
        elif model_type_id.endswith("-classifier"):
            return parse_eval_annotation_classifier(detail_eval_data)
        elif model_type_id == "visual-detector":
            if return_format == "array":
                return parse_eval_annotation_detector(detail_eval_data)
            elif return_format == "coco":
                return parse_eval_annotation_detector_coco(detail_eval_data)
    def export(self, export_dir: str = None) -> None:
        """Export the model; stores the exported model as a model.tar file.

        Args:
            export_dir (str, optional): If provided, the exported model will be saved in the specified directory; otherwise only the export status will be shown. Defaults to None.

        Example:
            >>> from clarifai.client.model import Model
            >>> model = Model("url")
            >>> model.export()
                        or
            >>> model.export('/path/to/export_model_dir')
        """
        assert self.model_info.model_version.id, (
            "Model version ID is missing. Please provide a `model_version` with a valid `id` as an argument or as a URL in the following format: '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
        )
        if export_dir:
            try:
                if not os.path.exists(export_dir):
                    os.makedirs(export_dir)
            except OSError as e:
                raise Exception(f"An error occurred while creating the directory: {e}")

        def _get_export_response():
            get_export_request = service_pb2.GetModelVersionExportRequest(
                user_app_id=self.user_app_id,
                model_id=self.id,
                version_id=self.model_info.model_version.id,
            )
            response = self._grpc_request(self.STUB.GetModelVersionExport, get_export_request)

            if (
                response.status.code != status_code_pb2.SUCCESS
                and response.status.code != status_code_pb2.CONN_DOES_NOT_EXIST
            ):
                raise Exception(response.status)

            return response

        def _download_exported_model(
            get_model_export_response: service_pb2.SingleModelVersionExportResponse,
            local_filepath: str,
        ):
            model_export_url = get_model_export_response.export.url
            model_export_file_size = get_model_export_response.export.size

            with open(local_filepath, 'wb') as f:
                progress = tqdm(
                    total=model_export_file_size, unit='B', unit_scale=True, desc="Exporting model"
                )
                downloaded_size = 0
                range_size = RANGE_SIZE
                chunk_size = CHUNK_SIZE
                retry = False
                retry_count = 0
                while downloaded_size < model_export_file_size:
                    if downloaded_size + range_size >= model_export_file_size:
                        range_header = f"bytes={downloaded_size}-"
                    else:
                        range_header = f"bytes={downloaded_size}-{(downloaded_size + range_size - 1)}"
                    try:
                        session = requests.Session()
                        retries = Retry(
                            total=5, backoff_factor=0.1, status_forcelist=[500, 502, 503, 504]
                        )
                        session.mount('https://', HTTPAdapter(max_retries=retries))
                        session.headers.update(
                            {'Authorization': self.metadata[0][1], 'Range': range_header}
                        )
                        response = session.get(model_export_url, stream=True)
                        response.raise_for_status()

                        for chunk in response.iter_content(chunk_size=chunk_size):
                            f.write(chunk)
                            progress.update(len(chunk))
                        f.flush()
                        os.fsync(f.fileno())
                        downloaded_size += range_size
                        if not retry:
                            # Grow the request window after a clean fetch, capped at the maximums.
                            range_size = (range_size * 2) if (range_size * 2) < MAX_RANGE_SIZE else MAX_RANGE_SIZE
                            chunk_size = (chunk_size * 2) if (chunk_size * 2) < MAX_CHUNK_SIZE else MAX_CHUNK_SIZE
                    except Exception as e:
                        self.logger.error(f"Error downloading model: {e}")
                        # Shrink the window and retry from the last good offset.
                        range_size = (range_size // 2) if (range_size // 2) > MIN_RANGE_SIZE else MIN_RANGE_SIZE
                        chunk_size = (chunk_size // 2) if (chunk_size // 2) > MIN_CHUNK_SIZE else MIN_CHUNK_SIZE
                        retry = True
                        retry_count += 1
                        f.seek(downloaded_size)
                        progress.reset(total=model_export_file_size)
                        progress.update(downloaded_size)
                        if retry_count > 5:
                            break
                progress.close()

            self.logger.info(
                f"Model ID {self.id} with version {self.model_info.model_version.id} exported successfully to {export_dir}/model.tar"
            )

        get_export_response = _get_export_response()
        if get_export_response.status.code == status_code_pb2.CONN_DOES_NOT_EXIST:
            put_export_request = service_pb2.PutModelVersionExportsRequest(
                user_app_id=self.user_app_id,
                model_id=self.id,
                version_id=self.model_info.model_version.id,
            )

            response = self._grpc_request(self.STUB.PutModelVersionExports, put_export_request)
            if response.status.code != status_code_pb2.SUCCESS:
                raise Exception(response.status)

            self.logger.info(
                f"Export process has started for Model ID {self.id}, Version {self.model_info.model_version.id}"
            )
            if export_dir:
                start_time = time.time()
                backoff_iterator = BackoffIterator(10)
                while True:
                    get_export_response = _get_export_response()
                    if (
                        get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTING
                        or get_export_response.export.status.code == status_code_pb2.MODEL_EXPORT_PENDING
                    ) and time.time() - start_time < MODEL_EXPORT_TIMEOUT:
                        self.logger.info(
                            f"Export process is ongoing for Model ID {self.id}, Version {self.model_info.model_version.id}. Please wait..."
                        )
                        time.sleep(next(backoff_iterator))
                    elif get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTED:
                        _download_exported_model(
                            get_export_response, os.path.join(export_dir, "model.tar")
                        )
                        break
                    elif time.time() - start_time > MODEL_EXPORT_TIMEOUT:
                        raise Exception(
                            f"""Model Export took too long. Please try again or contact support@clarifai.com
                            Req ID: {get_export_response.status.req_id}"""
                        )
        elif get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTED:
            if export_dir:
                _download_exported_model(
|
1642
|
+
get_export_response, os.path.join(export_dir, "model.tar")
|
1643
|
+
)
|
1644
|
+
else:
|
1645
|
+
self.logger.info(
|
1646
|
+
f"Model ID {self.id} with version {self.model_info.model_version.id} is already exported, you can download it from the following URL: {get_export_response.export.url}"
|
1647
|
+
)
|
1648
|
+
elif (
|
1649
|
+
get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTING
|
1650
|
+
or get_export_response.export.status.code == status_code_pb2.MODEL_EXPORT_PENDING
|
1651
|
+
):
|
1451
1652
|
self.logger.info(
|
1452
1653
|
f"Export process is ongoing for Model ID {self.id}, Version {self.model_info.model_version.id}. Please wait..."
|
1453
1654
|
)
|
1454
|
-
|
1455
|
-
|
1456
|
-
|
1457
|
-
|
1458
|
-
|
1459
|
-
|
1460
|
-
|
1461
|
-
|
1462
|
-
|
1463
|
-
|
1464
|
-
|
1465
|
-
|
1655
|
+
|
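The `_download_exported_model` helper above pulls the exported archive in HTTP `Range` windows: it doubles `range_size`/`chunk_size` after each successful window, and on failure halves them, re-seeks the file, and retries (giving up after five retries). A standalone sketch of that adaptive pattern, with illustrative window bounds rather than the SDK's actual `RANGE_SIZE`/`CHUNK_SIZE` constants:

```python
# Adaptive ranged download, in the spirit of _download_exported_model above.
# MIN_RANGE/MAX_RANGE here are illustrative, not the SDK's values.
import requests

MIN_RANGE = 1 * 1024 * 1024   # 1 MiB
MAX_RANGE = 64 * 1024 * 1024  # 64 MiB


def download_ranged(url: str, total_size: int, out_path: str, headers: dict = None):
    range_size, downloaded = MIN_RANGE, 0
    with open(out_path, 'wb') as f:
        while downloaded < total_size:
            end = min(downloaded + range_size, total_size) - 1
            rng = {'Range': f'bytes={downloaded}-{end}'}
            try:
                resp = requests.get(
                    url, headers={**(headers or {}), **rng}, stream=True, timeout=60
                )
                resp.raise_for_status()
                for chunk in resp.iter_content(chunk_size=64 * 1024):
                    f.write(chunk)
                downloaded = end + 1
                range_size = min(range_size * 2, MAX_RANGE)  # grow window on success
            except requests.RequestException:
                range_size = max(range_size // 2, MIN_RANGE)  # shrink window and retry
                f.seek(downloaded)  # drop any partial write of this window
```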
+    @staticmethod
+    def _make_pretrained_config_proto(
+        input_field_maps: dict, output_field_maps: dict, url: str = None
+    ):
+        """Make PretrainedModelConfig for uploading new version
+
+        Args:
+            input_field_maps (dict): dict
+            output_field_maps (dict): dict
+            url (str, optional): direct download url. Defaults to None.
+        """
+
+        def _parse_fields_map(x):
+            """parse input, outputs to Struct"""
+            _fields_map = Struct()
+            _fields_map.update(x)
+            return _fields_map
+
+        input_fields_map = _parse_fields_map(input_field_maps)
+        output_fields_map = _parse_fields_map(output_field_maps)
+
+        return resources_pb2.PretrainedModelConfig(
+            input_fields_map=input_fields_map,
+            output_fields_map=output_fields_map,
+            model_zip_url=url,
+        )
+
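`_make_pretrained_config_proto` simply wraps the two field maps in protobuf `Struct`s and attaches an optional `model_zip_url`. A hedged example of the shapes it expects; the field names are illustrative, not required by the API:

```python
# Field names here are placeholders for whatever the Triton model defines.
from clarifai.client.model import Model

config = Model._make_pretrained_config_proto(
    input_field_maps={"image": "input_tensor"},
    output_field_maps={"concepts": "softmax_output"},
    url="https://example.com/model.zip",  # omitted when uploading a local file
)
```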
+    @staticmethod
+    def _make_inference_params_proto(
+        inference_parameters: List[Dict],
+    ) -> List[resources_pb2.ModelTypeField]:
+        """Convert list of Clarifai inference parameters to proto for uploading new version
+
+        Args:
+            inference_parameters (List[Dict]): Each dict has keys {field_type, path, default_value, description}
+
+        Returns:
+            List[resources_pb2.ModelTypeField]
+        """
+
+        def _make_default_value_proto(dtype, value):
+            if dtype == 1:
+                return Value(bool_value=value)
+            elif dtype == 2 or dtype == 21:
+                return Value(string_value=value)
+            elif dtype == 3:
+                return Value(number_value=value)
+
+        iterative_proto_params = []
+        for param in inference_parameters:
+            dtype = param.get("field_type")
+            proto_param = resources_pb2.ModelTypeField(
+                path=param.get("path"),
+                field_type=dtype,
+                default_value=_make_default_value_proto(
+                    dtype=dtype, value=param.get("default_value")
+                ),
+                description=param.get("description"),
+            )
+            iterative_proto_params.append(proto_param)
+        return iterative_proto_params
+
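Per `_make_default_value_proto` above, the `field_type` codes map onto protobuf `Value` kinds: 1 is a boolean, 2 and 21 are strings, 3 is a number. An illustrative parameter list (the paths and defaults are made up for the sketch):

```python
from clarifai.client.model import Model

inference_parameter_configs = [
    {
        "path": "temperature",  # illustrative parameter path
        "field_type": 3,        # 3 -> number_value
        "default_value": 0.7,
        "description": "Sampling temperature",
    },
    {
        "path": "system_prompt",
        "field_type": 2,        # 2 (or 21) -> string_value
        "default_value": "You are a helpful assistant.",
        "description": "Prompt prepended to every request",
    },
]
params_specs = Model._make_inference_params_proto(inference_parameter_configs)
```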
+    def create_version_by_file(
+        self,
+        file_path: str,
+        input_field_maps: dict,
+        output_field_maps: dict,
+        inference_parameter_configs: dict = None,
+        model_version: str = None,
+        part_id: int = 1,
+        range_start: int = 0,
+        no_cache: bool = False,
+        no_resume: bool = False,
+        description: str = "",
+    ) -> 'Model':
+        """Create model version by uploading local file
+
+        Args:
+            file_path (str): path to built file.
+            input_field_maps (dict): a dict where the key is clarifai input field and the value is triton model input,
+                {clarifai_input_field: triton_input_filed}.
+            output_field_maps (dict): a dict where the keys are clarifai output fields and the values are triton model outputs,
+                {clarifai_output_field1: triton_output_filed1, clarifai_output_field2: triton_output_filed2,...}.
+            inference_parameter_configs (List[dict]): list of dicts - keys are path, field_type, default_value, description. Default is None
+            model_version (str, optional): Custom model version. Defaults to None.
+            part_id (int, optional): part id of file. Defaults to 1.
+            range_start (int, optional): range of uploaded size. Defaults to 0.
+            no_cache (bool, optional): not saving uploading cache that is used to resume uploading. Defaults to False.
+            no_resume (bool, optional): disable auto resume upload. Defaults to False.
+            description (str): Model description.
+
+        Return:
+            Model: instance of Model with new created version
+
+        """
+        file_size = os.path.getsize(file_path)
+        assert MIN_CHUNK_FOR_UPLOAD_FILE <= file_size <= MAX_CHUNK_FOR_UPLOAD_FILE, (
+            "The file size exceeds the allowable limit, which ranges from 5MiB to 5GiB."
+        )
+
+        pretrained_proto = Model._make_pretrained_config_proto(
+            input_field_maps=input_field_maps, output_field_maps=output_field_maps
+        )
+        inference_param_proto = (
+            Model._make_inference_params_proto(inference_parameter_configs)
+            if inference_parameter_configs
+            else None
+        )
+
+        if file_size >= 1e9:
+            chunk_size = 1024 * 50_000  # 50MB
+        else:
+            chunk_size = 1024 * 10_000  # 10MB
+
+        # self.logger.info(f"Chunk {chunk_size/1e6}MB, {file_size/chunk_size} steps")
+        # self.logger.info(f" Max bytes per stream {MAX_SIZE_PER_STREAM}")
+
+        cache_dir = os.path.join(file_path, '..', '.cache')
+        cache_upload_file = os.path.join(cache_dir, "upload.json")
+        last_percent = 0
+        if os.path.exists(cache_upload_file) and not no_resume:
+            with open(cache_upload_file, "r") as fp:
+                try:
+                    cache_info = json.load(fp)
+                    if isinstance(cache_info, dict):
+                        part_id = cache_info.get("part_id", part_id)
+                        chunk_size = cache_info.get("chunk_size", chunk_size)
+                        range_start = cache_info.get("range_start", range_start)
+                        model_version = cache_info.get("model_version", model_version)
+                        last_percent = cache_info.get("last_percent", last_percent)
+                except Exception as e:
+                    self.logger.error(f"Skipping loading the upload cache due to error {e}.")
+
+        def init_model_version_upload(model_version):
+            return service_pb2.PostModelVersionsUploadRequest(
+                upload_config=service_pb2.PostModelVersionsUploadConfig(
+                    user_app_id=self.user_app_id,
+                    model_id=self.id,
+                    total_size=file_size,
+                    model_version=resources_pb2.ModelVersion(
+                        id=model_version,
+                        pretrained_model_config=pretrained_proto,
+                        description=description,
+                        output_info=resources_pb2.OutputInfo(params_specs=inference_param_proto),
+                    ),
+                )
+            )
+
+        def _uploading(chunk, part_id, range_start, model_version):
+            return service_pb2.PostModelVersionsUploadRequest(
+                content_part=resources_pb2.UploadContentPart(
+                    data=chunk, part_number=part_id, range_start=range_start
+                )
+            )
+
+        finished_status = [status_code_pb2.SUCCESS, status_code_pb2.UPLOAD_DONE]
+        uploading_in_progress_status = [
+            status_code_pb2.UPLOAD_IN_PROGRESS,
+            status_code_pb2.MODEL_UPLOADING,
+        ]
+
+        def _save_cache(cache: dict):
+            if not no_cache:
+                os.makedirs(cache_dir, exist_ok=True)
+                with open(cache_upload_file, "w") as fp:
+                    json.dump(cache, fp, indent=2)
+
+        def stream_request(fp, part_id, end_part_id, chunk_size, version):
+            yield init_model_version_upload(version)
+            for iter_part_id in range(part_id, end_part_id):
+                chunk = fp.read(chunk_size)
+                if not chunk:
+                    return
+                yield _uploading(
+                    chunk=chunk,
+                    part_id=iter_part_id,
+                    range_start=chunk_size * (iter_part_id - 1),
+                    model_version=version,
+                )
+
+        tqdm_loader = tqdm(total=100)
+        if model_version:
+            desc = f"Uploading model `{self.id}` version `{model_version}` ..."
+        else:
+            desc = f"Uploading model `{self.id}` ..."
+        tqdm_loader.set_description(desc)
+
+        cache_uploading_info = {}
+        cache_uploading_info["part_id"] = part_id
+        cache_uploading_info["model_version"] = model_version
+        cache_uploading_info["range_start"] = range_start
+        cache_uploading_info["chunk_size"] = chunk_size
+        cache_uploading_info["last_percent"] = last_percent
+        tqdm_loader.update(last_percent)
+        last_part_id = part_id
+        n_chunks = file_size // chunk_size
+        n_chunk_per_stream = MAX_SIZE_PER_STREAM // chunk_size or 1
+
+        def stream_and_logging(
+            request, tqdm_loader, cache_uploading_info, expected_steps: int = None
+        ):
+            for st_step, st_response in enumerate(
+                self.auth_helper.get_stub().PostModelVersionsUpload(
+                    request, metadata=self.auth_helper.metadata
+                )
+            ):
+                if st_response.status.code in uploading_in_progress_status:
+                    if cache_uploading_info["model_version"]:
+                        assert (
+                            st_response.model_version_id == cache_uploading_info["model_version"]
+                        ), RuntimeError
+                    else:
+                        cache_uploading_info["model_version"] = st_response.model_version_id
+                    if st_step > 0:
+                        cache_uploading_info["part_id"] += 1
+                        cache_uploading_info["range_start"] += chunk_size
+                        _save_cache(cache_uploading_info)
+
+                    if st_response.status.percent_completed:
+                        step_percent = (
+                            st_response.status.percent_completed
+                            - cache_uploading_info["last_percent"]
+                        )
+                        cache_uploading_info["last_percent"] += step_percent
+                        tqdm_loader.set_description(
+                            f"{st_response.status.description}, {st_response.status.details}, version id {cache_uploading_info.get('model_version')}"
+                        )
+                        tqdm_loader.update(step_percent)
+                elif st_response.status.code not in finished_status + uploading_in_progress_status:
+                    # TODO: Find better way to handle error
+                    if expected_steps and st_step < expected_steps:
+                        raise Exception(f"Failed to upload model, error: {st_response.status}")
+
+        with open(file_path, 'rb') as fp:
+            # seeking
+            for _ in range(1, last_part_id):
+                fp.read(chunk_size)
+            # Stream even part
+            end_part_id = n_chunks or 1
+            for iter_part_id in range(int(last_part_id), int(n_chunks), int(n_chunk_per_stream)):
+                end_part_id = iter_part_id + n_chunk_per_stream
+                end_part_id = min(n_chunks, end_part_id)
+                expected_steps = end_part_id - iter_part_id + 1  # init step
+                st_reqs = stream_request(
+                    fp,
+                    iter_part_id,
+                    end_part_id=end_part_id,
+                    chunk_size=chunk_size,
+                    version=cache_uploading_info["model_version"],
+                )
+                stream_and_logging(st_reqs, tqdm_loader, cache_uploading_info, expected_steps)
+            # Stream last part
+            accum_size = (end_part_id - 1) * chunk_size
+            remained_size = file_size - accum_size if accum_size >= 0 else file_size
+            st_reqs = stream_request(
+                fp,
+                end_part_id,
+                end_part_id=end_part_id + 1,
+                chunk_size=remained_size,
+                version=cache_uploading_info["model_version"],
+            )
+            stream_and_logging(st_reqs, tqdm_loader, cache_uploading_info, 2)
+
+        # clean up cache
+        if not no_cache:
+            try:
+                os.remove(cache_upload_file)
+            except Exception:
+                _save_cache({})
+
+        if cache_uploading_info["last_percent"] <= 100:
+            tqdm_loader.update(100 - cache_uploading_info["last_percent"])
+            tqdm_loader.set_description("Upload done")
+
+        tqdm_loader.set_description(
+            f"Success uploading model {self.id}, new version {cache_uploading_info.get('model_version')}"
+        )
+
+        return Model.from_auth_helper(
+            auth=self.auth_helper,
+            model_id=self.id,
+            model_version=dict(id=cache_uploading_info.get('model_version')),
+        )
+
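`create_version_by_file` streams the archive as chunked `PostModelVersionsUpload` requests and checkpoints progress to `.cache/upload.json` alongside the file, so an interrupted upload resumes automatically unless `no_resume=True` (or `no_cache=True`). A usage sketch; the IDs, credentials, paths, and field maps are placeholders, the constructor keywords are one of the SDK's init forms, and the file must fall within the 5MiB to 5GiB bounds asserted above:

```python
from clarifai.client.model import Model

model = Model(model_id="my-model", user_id="me", app_id="my-app", pat="MY_PAT")
new_model = model.create_version_by_file(
    file_path="./triton_model.zip",
    input_field_maps={"image": "input_tensor"},
    output_field_maps={"concepts": "softmax_output"},
    description="retrained backbone",
)
print(new_model.model_info.model_version.id)  # version id assigned during the stream
```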
+    def create_version_by_url(
+        self,
+        url: str,
+        input_field_maps: dict,
+        output_field_maps: dict,
+        inference_parameter_configs: List[dict] = None,
+        description: str = "",
+    ) -> 'Model':
+        """Upload a new version of an existing model in the Clarifai platform using direct download url.
+
+        Args:
+            url (str]): url of zip of model
+            input_field_maps (dict): a dict where the key is clarifai input field and the value is triton model input,
+                {clarifai_input_field: triton_input_filed}.
+            output_field_maps (dict): a dict where the keys are clarifai output fields and the values are triton model outputs,
+                {clarifai_output_field1: triton_output_filed1, clarifai_output_field2: triton_output_filed2,...}.
+            inference_parameter_configs (List[dict]): list of dicts - keys are path, field_type, default_value, description. Default is None
+            description (str): Model description.
+
+        Return:
+            Model: instance of Model with new created version
+        """
+
+        pretrained_proto = Model._make_pretrained_config_proto(
+            input_field_maps=input_field_maps, output_field_maps=output_field_maps, url=url
+        )
+        inference_param_proto = (
+            Model._make_inference_params_proto(inference_parameter_configs)
+            if inference_parameter_configs
+            else None
+        )
+        request = service_pb2.PostModelVersionsRequest(
+            user_app_id=self.user_app_id,
+            model_id=self.id,
+            model_versions=[
+                resources_pb2.ModelVersion(
+                    pretrained_model_config=pretrained_proto,
+                    description=description,
+                    output_info=resources_pb2.OutputInfo(params_specs=inference_param_proto),
+                )
+            ],
+        )
+        response = self._grpc_request(self.STUB.PostModelVersions, request)
+
+        if response.status.code != status_code_pb2.SUCCESS:
+            raise Exception(f"Failed to upload model, error: {response.status}")
         self.logger.info(
-            f"
+            f"Success uploading model {self.id}, new version {response.model.model_version.id}"
+        )
+
+        return Model.from_auth_helper(
+            auth=self.auth_helper,
+            model_id=self.id,
+            model_version=dict(id=response.model.model_version.id),
        )
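`create_version_by_url` is the non-streaming variant: a single `PostModelVersions` call with `model_zip_url` set, so the platform pulls the archive itself. A sketch under the same placeholder assumptions as above:

```python
from clarifai.client.model import Model

model = Model(model_id="my-model", user_id="me", app_id="my-app", pat="MY_PAT")
new_model = model.create_version_by_url(
    url="https://example.com/artifacts/model.zip",
    input_field_maps={"text": "input_ids"},
    output_field_maps={"text": "generated_text"},
    description="served from a prebuilt zip",
)
```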
-    elif get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTING or \
-        get_export_response.export.status.code == status_code_pb2.MODEL_EXPORT_PENDING:
-      self.logger.info(
-          f"Export process is ongoing for Model ID {self.id}, Version {self.model_info.model_version.id}. Please wait..."
-      )
-
-  @staticmethod
-  def _make_pretrained_config_proto(input_field_maps: dict,
-                                    output_field_maps: dict,
-                                    url: str = None):
-    """Make PretrainedModelConfig for uploading new version
-
-    Args:
-        input_field_maps (dict): dict
-        output_field_maps (dict): dict
-        url (str, optional): direct download url. Defaults to None.
-    """
-
-    def _parse_fields_map(x):
-      """parse input, outputs to Struct"""
-      _fields_map = Struct()
-      _fields_map.update(x)
-      return _fields_map
-
-    input_fields_map = _parse_fields_map(input_field_maps)
-    output_fields_map = _parse_fields_map(output_field_maps)
-
-    return resources_pb2.PretrainedModelConfig(
-        input_fields_map=input_fields_map, output_fields_map=output_fields_map, model_zip_url=url)
-
-  @staticmethod
-  def _make_inference_params_proto(
-      inference_parameters: List[Dict]) -> List[resources_pb2.ModelTypeField]:
-    """Convert list of Clarifai inference parameters to proto for uploading new version
-
-    Args:
-        inference_parameters (List[Dict]): Each dict has keys {field_type, path, default_value, description}
-
-    Returns:
-        List[resources_pb2.ModelTypeField]
-    """
-
-    def _make_default_value_proto(dtype, value):
-      if dtype == 1:
-        return Value(bool_value=value)
-      elif dtype == 2 or dtype == 21:
-        return Value(string_value=value)
-      elif dtype == 3:
-        return Value(number_value=value)
-
-    iterative_proto_params = []
-    for param in inference_parameters:
-      dtype = param.get("field_type")
-      proto_param = resources_pb2.ModelTypeField(
-          path=param.get("path"),
-          field_type=dtype,
-          default_value=_make_default_value_proto(dtype=dtype, value=param.get("default_value")),
-          description=param.get("description"),
-      )
-      iterative_proto_params.append(proto_param)
-    return iterative_proto_params
-
-  def create_version_by_file(self,
-                             file_path: str,
-                             input_field_maps: dict,
-                             output_field_maps: dict,
-                             inference_parameter_configs: dict = None,
-                             model_version: str = None,
-                             part_id: int = 1,
-                             range_start: int = 0,
-                             no_cache: bool = False,
-                             no_resume: bool = False,
-                             description: str = "") -> 'Model':
-    """Create model version by uploading local file
-
-    Args:
-        file_path (str): path to built file.
-        input_field_maps (dict): a dict where the key is clarifai input field and the value is triton model input,
-            {clarifai_input_field: triton_input_filed}.
-        output_field_maps (dict): a dict where the keys are clarifai output fields and the values are triton model outputs,
-            {clarifai_output_field1: triton_output_filed1, clarifai_output_field2: triton_output_filed2,...}.
-        inference_parameter_configs (List[dict]): list of dicts - keys are path, field_type, default_value, description. Default is None
-        model_version (str, optional): Custom model version. Defaults to None.
-        part_id (int, optional): part id of file. Defaults to 1.
-        range_start (int, optional): range of uploaded size. Defaults to 0.
-        no_cache (bool, optional): not saving uploading cache that is used to resume uploading. Defaults to False.
-        no_resume (bool, optional): disable auto resume upload. Defaults to False.
-        description (str): Model description.
-
-    Return:
-        Model: instance of Model with new created version
-
-    """
-    file_size = os.path.getsize(file_path)
-    assert MIN_CHUNK_FOR_UPLOAD_FILE <= file_size <= MAX_CHUNK_FOR_UPLOAD_FILE, "The file size exceeds the allowable limit, which ranges from 5MiB to 5GiB."
-
-    pretrained_proto = Model._make_pretrained_config_proto(
-        input_field_maps=input_field_maps, output_field_maps=output_field_maps)
-    inference_param_proto = Model._make_inference_params_proto(
-        inference_parameter_configs) if inference_parameter_configs else None
-
-    if file_size >= 1e9:
-      chunk_size = 1024 * 50_000  # 50MB
-    else:
-      chunk_size = 1024 * 10_000  # 10MB
-
-    #self.logger.info(f"Chunk {chunk_size/1e6}MB, {file_size/chunk_size} steps")
-    #self.logger.info(f" Max bytes per stream {MAX_SIZE_PER_STREAM}")
-
-    cache_dir = os.path.join(file_path, '..', '.cache')
-    cache_upload_file = os.path.join(cache_dir, "upload.json")
-    last_percent = 0
-    if os.path.exists(cache_upload_file) and not no_resume:
-      with open(cache_upload_file, "r") as fp:
-        try:
-          cache_info = json.load(fp)
-          if isinstance(cache_info, dict):
-            part_id = cache_info.get("part_id", part_id)
-            chunk_size = cache_info.get("chunk_size", chunk_size)
-            range_start = cache_info.get("range_start", range_start)
-            model_version = cache_info.get("model_version", model_version)
-            last_percent = cache_info.get("last_percent", last_percent)
-        except Exception as e:
-          self.logger.error(f"Skipping loading the upload cache due to error {e}.")
-
-    def init_model_version_upload(model_version):
-      return service_pb2.PostModelVersionsUploadRequest(
-          upload_config=service_pb2.PostModelVersionsUploadConfig(
-              user_app_id=self.user_app_id,
-              model_id=self.id,
-              total_size=file_size,
-              model_version=resources_pb2.ModelVersion(
-                  id=model_version,
-                  pretrained_model_config=pretrained_proto,
-                  description=description,
-                  output_info=resources_pb2.OutputInfo(params_specs=inference_param_proto)),
-          ))
-
-    def _uploading(chunk, part_id, range_start, model_version):
-      return service_pb2.PostModelVersionsUploadRequest(
-          content_part=resources_pb2.UploadContentPart(
-              data=chunk, part_number=part_id, range_start=range_start))
-
-    finished_status = [status_code_pb2.SUCCESS, status_code_pb2.UPLOAD_DONE]
-    uploading_in_progress_status = [
-        status_code_pb2.UPLOAD_IN_PROGRESS, status_code_pb2.MODEL_UPLOADING
-    ]
-
-    def _save_cache(cache: dict):
-      if not no_cache:
-        os.makedirs(cache_dir, exist_ok=True)
-        with open(cache_upload_file, "w") as fp:
-          json.dump(cache, fp, indent=2)
-
-    def stream_request(fp, part_id, end_part_id, chunk_size, version):
-      yield init_model_version_upload(version)
-      for iter_part_id in range(part_id, end_part_id):
-        chunk = fp.read(chunk_size)
-        if not chunk:
-          return
-        yield _uploading(
-            chunk=chunk,
-            part_id=iter_part_id,
-            range_start=chunk_size * (iter_part_id - 1),
-            model_version=version)
-
-    tqdm_loader = tqdm(total=100)
-    if model_version:
-      desc = f"Uploading model `{self.id}` version `{model_version}` ..."
-    else:
-      desc = f"Uploading model `{self.id}` ..."
-    tqdm_loader.set_description(desc)
-
-    cache_uploading_info = {}
-    cache_uploading_info["part_id"] = part_id
-    cache_uploading_info["model_version"] = model_version
-    cache_uploading_info["range_start"] = range_start
-    cache_uploading_info["chunk_size"] = chunk_size
-    cache_uploading_info["last_percent"] = last_percent
-    tqdm_loader.update(last_percent)
-    last_part_id = part_id
-    n_chunks = file_size // chunk_size
-    n_chunk_per_stream = MAX_SIZE_PER_STREAM // chunk_size or 1
-
-    def stream_and_logging(request, tqdm_loader, cache_uploading_info, expected_steps: int = None):
-      for st_step, st_response in enumerate(self.auth_helper.get_stub().PostModelVersionsUpload(
-          request, metadata=self.auth_helper.metadata)):
-        if st_response.status.code in uploading_in_progress_status:
-          if cache_uploading_info["model_version"]:
-            assert st_response.model_version_id == cache_uploading_info[
-                "model_version"], RuntimeError
-          else:
-            cache_uploading_info["model_version"] = st_response.model_version_id
-          if st_step > 0:
-            cache_uploading_info["part_id"] += 1
-            cache_uploading_info["range_start"] += chunk_size
-            _save_cache(cache_uploading_info)
-
-          if st_response.status.percent_completed:
-            step_percent = st_response.status.percent_completed - cache_uploading_info["last_percent"]
-            cache_uploading_info["last_percent"] += step_percent
-            tqdm_loader.set_description(
-                f"{st_response.status.description}, {st_response.status.details}, version id {cache_uploading_info.get('model_version')}"
-            )
-            tqdm_loader.update(step_percent)
-        elif st_response.status.code not in finished_status + uploading_in_progress_status:
-          # TODO: Find better way to handle error
-          if expected_steps and st_step < expected_steps:
-            raise Exception(f"Failed to upload model, error: {st_response.status}")
-
-    with open(file_path, 'rb') as fp:
-      # seeking
-      for _ in range(1, last_part_id):
-        fp.read(chunk_size)
-      # Stream even part
-      end_part_id = n_chunks or 1
-      for iter_part_id in range(int(last_part_id), int(n_chunks), int(n_chunk_per_stream)):
-        end_part_id = iter_part_id + n_chunk_per_stream
-        if end_part_id >= n_chunks:
-          end_part_id = n_chunks
-        expected_steps = end_part_id - iter_part_id + 1  # init step
-        st_reqs = stream_request(
-            fp,
-            iter_part_id,
-            end_part_id=end_part_id,
-            chunk_size=chunk_size,
-            version=cache_uploading_info["model_version"])
-        stream_and_logging(st_reqs, tqdm_loader, cache_uploading_info, expected_steps)
-      # Stream last part
-      accum_size = (end_part_id - 1) * chunk_size
-      remained_size = file_size - accum_size if accum_size >= 0 else file_size
-      st_reqs = stream_request(
-          fp,
-          end_part_id,
-          end_part_id=end_part_id + 1,
-          chunk_size=remained_size,
-          version=cache_uploading_info["model_version"])
-      stream_and_logging(st_reqs, tqdm_loader, cache_uploading_info, 2)
-
-    # clean up cache
-    if not no_cache:
-      try:
-        os.remove(cache_upload_file)
-      except Exception:
-        _save_cache({})
-
-    if cache_uploading_info["last_percent"] <= 100:
-      tqdm_loader.update(100 - cache_uploading_info["last_percent"])
-      tqdm_loader.set_description("Upload done")
-
-    tqdm_loader.set_description(
-        f"Success uploading model {self.id}, new version {cache_uploading_info.get('model_version')}"
-    )
-
-    return Model.from_auth_helper(
-        auth=self.auth_helper,
-        model_id=self.id,
-        model_version=dict(id=cache_uploading_info.get('model_version')))
-
-  def create_version_by_url(self,
-                            url: str,
-                            input_field_maps: dict,
-                            output_field_maps: dict,
-                            inference_parameter_configs: List[dict] = None,
-                            description: str = "") -> 'Model':
-    """Upload a new version of an existing model in the Clarifai platform using direct download url.
-
-    Args:
-        url (str]): url of zip of model
-        input_field_maps (dict): a dict where the key is clarifai input field and the value is triton model input,
-            {clarifai_input_field: triton_input_filed}.
-        output_field_maps (dict): a dict where the keys are clarifai output fields and the values are triton model outputs,
-            {clarifai_output_field1: triton_output_filed1, clarifai_output_field2: triton_output_filed2,...}.
-        inference_parameter_configs (List[dict]): list of dicts - keys are path, field_type, default_value, description. Default is None
-        description (str): Model description.
-
-    Return:
-        Model: instance of Model with new created version
-    """
-
-    pretrained_proto = Model._make_pretrained_config_proto(
-        input_field_maps=input_field_maps, output_field_maps=output_field_maps, url=url)
-    inference_param_proto = Model._make_inference_params_proto(
-        inference_parameter_configs) if inference_parameter_configs else None
-    request = service_pb2.PostModelVersionsRequest(
-        user_app_id=self.user_app_id,
-        model_id=self.id,
-        model_versions=[
-            resources_pb2.ModelVersion(
-                pretrained_model_config=pretrained_proto,
-                description=description,
-                output_info=resources_pb2.OutputInfo(params_specs=inference_param_proto))
-        ])
-    response = self._grpc_request(self.STUB.PostModelVersions, request)
-
-    if response.status.code != status_code_pb2.SUCCESS:
-      raise Exception(f"Failed to upload model, error: {response.status}")
-    self.logger.info(
-        f"Success uploading model {self.id}, new version {response.model.model_version.id}")
-
-    return Model.from_auth_helper(
-        auth=self.auth_helper,
-        model_id=self.id,
-        model_version=dict(id=response.model.model_version.id))