clarifai 11.2.2__py3-none-any.whl → 11.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clarifai/__init__.py +1 -1
- clarifai/cli/base.py +225 -89
- clarifai/cli/compute_cluster.py +24 -21
- clarifai/cli/deployment.py +66 -42
- clarifai/cli/model.py +1 -1
- clarifai/cli/nodepool.py +59 -41
- clarifai/client/auth/stub.py +4 -5
- clarifai/client/dataset.py +3 -4
- clarifai/runners/models/model_builder.py +17 -7
- clarifai/runners/models/model_run_locally.py +1 -0
- clarifai/runners/utils/const.py +9 -8
- clarifai/utils/cli.py +125 -36
- clarifai/utils/config.py +105 -0
- clarifai/utils/constants.py +4 -0
- clarifai/utils/logging.py +64 -21
- clarifai/utils/misc.py +2 -0
- {clarifai-11.2.2.dist-info → clarifai-11.2.3.dist-info}/METADATA +2 -2
- {clarifai-11.2.2.dist-info → clarifai-11.2.3.dist-info}/RECORD +22 -21
- {clarifai-11.2.2.dist-info → clarifai-11.2.3.dist-info}/WHEEL +0 -0
- {clarifai-11.2.2.dist-info → clarifai-11.2.3.dist-info}/entry_points.txt +0 -0
- {clarifai-11.2.2.dist-info → clarifai-11.2.3.dist-info}/licenses/LICENSE +0 -0
- {clarifai-11.2.2.dist-info → clarifai-11.2.3.dist-info}/top_level.txt +0 -0
clarifai/cli/nodepool.py
CHANGED
@@ -1,31 +1,27 @@
|
|
1
1
|
import click
|
2
|
+
|
2
3
|
from clarifai.cli.base import cli
|
3
|
-
from clarifai.
|
4
|
-
|
4
|
+
from clarifai.utils.cli import (AliasedGroup, display_co_resources, dump_yaml, from_yaml,
|
5
|
+
validate_context)
|
5
6
|
|
6
7
|
|
7
|
-
@cli.group(['nodepool', 'np'])
|
8
|
+
@cli.group(['nodepool', 'np'], cls=AliasedGroup)
|
8
9
|
def nodepool():
|
9
10
|
"""Manage Nodepools: create, delete, list"""
|
10
|
-
pass
|
11
11
|
|
12
12
|
|
13
|
-
@nodepool.command()
|
14
|
-
@click.
|
15
|
-
|
16
|
-
'--compute_cluster_id',
|
17
|
-
required=False,
|
18
|
-
help='Compute Cluster ID for the compute cluster to interact with.')
|
13
|
+
@nodepool.command(['c'])
|
14
|
+
@click.argument('compute_cluster_id')
|
15
|
+
@click.argument('nodepool_id')
|
19
16
|
@click.option(
|
20
17
|
'--config',
|
21
18
|
type=click.Path(exists=True),
|
22
19
|
required=True,
|
23
20
|
help='Path to the nodepool config file.')
|
24
|
-
@click.option(
|
25
|
-
'-np_id', '--nodepool_id', required=False, help='New Nodepool ID for the nodepool to create.')
|
26
21
|
@click.pass_context
|
27
|
-
def create(ctx, compute_cluster_id,
|
22
|
+
def create(ctx, compute_cluster_id, nodepool_id, config):
|
28
23
|
"""Create a new Nodepool with the given config file."""
|
24
|
+
from clarifai.client.compute_cluster import ComputeCluster
|
29
25
|
|
30
26
|
validate_context(ctx)
|
31
27
|
nodepool_config = from_yaml(config)
|
@@ -43,52 +39,74 @@ def create(ctx, compute_cluster_id, config, nodepool_id):
|
|
43
39
|
|
44
40
|
compute_cluster = ComputeCluster(
|
45
41
|
compute_cluster_id=compute_cluster_id,
|
46
|
-
user_id=ctx.obj
|
47
|
-
pat=ctx.obj
|
48
|
-
base_url=ctx.obj
|
42
|
+
user_id=ctx.obj.current.user_id,
|
43
|
+
pat=ctx.obj.current.pat,
|
44
|
+
base_url=ctx.obj.current.api_base)
|
49
45
|
if nodepool_id:
|
50
46
|
compute_cluster.create_nodepool(config, nodepool_id=nodepool_id)
|
51
47
|
else:
|
52
48
|
compute_cluster.create_nodepool(config)
|
53
49
|
|
54
50
|
|
55
|
-
@nodepool.command()
|
56
|
-
@click.
|
57
|
-
'-cc_id',
|
58
|
-
'--compute_cluster_id',
|
59
|
-
required=True,
|
60
|
-
help='Compute Cluster ID for the compute cluster to interact with.')
|
51
|
+
@nodepool.command(['ls'])
|
52
|
+
@click.argument('compute_cluster_id', default="")
|
61
53
|
@click.option('--page_no', required=False, help='Page number to list.', default=1)
|
62
|
-
@click.option('--per_page', required=False, help='Number of items per page.', default=
|
54
|
+
@click.option('--per_page', required=False, help='Number of items per page.', default=128)
|
63
55
|
@click.pass_context
|
64
56
|
def list(ctx, compute_cluster_id, page_no, per_page):
|
65
|
-
"""List all nodepools for the user.
|
57
|
+
"""List all nodepools for the user across all compute clusters. If compute_cluster_id is provided
|
58
|
+
it will list only within that compute cluster. """
|
59
|
+
from clarifai.client.compute_cluster import ComputeCluster
|
60
|
+
from clarifai.client.user import User
|
66
61
|
|
67
62
|
validate_context(ctx)
|
68
|
-
compute_cluster = ComputeCluster(
|
69
|
-
compute_cluster_id=compute_cluster_id,
|
70
|
-
user_id=ctx.obj['user_id'],
|
71
|
-
pat=ctx.obj['pat'],
|
72
|
-
base_url=ctx.obj['base_url'])
|
73
|
-
response = compute_cluster.list_nodepools(page_no, per_page)
|
74
|
-
display_co_resources(response, "Nodepool")
|
75
63
|
|
64
|
+
cc_id = compute_cluster_id
|
76
65
|
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
66
|
+
if cc_id:
|
67
|
+
compute_cluster = ComputeCluster(
|
68
|
+
compute_cluster_id=cc_id,
|
69
|
+
user_id=ctx.obj.current.user_id,
|
70
|
+
pat=ctx.obj.current.pat,
|
71
|
+
base_url=ctx.obj.current.api_base)
|
72
|
+
response = compute_cluster.list_nodepools(page_no, per_page)
|
73
|
+
else:
|
74
|
+
user = User(
|
75
|
+
user_id=ctx.obj.current.user_id,
|
76
|
+
pat=ctx.obj.current.pat,
|
77
|
+
base_url=ctx.obj.current.api_base)
|
78
|
+
ccs = user.list_compute_clusters(page_no, per_page)
|
79
|
+
response = []
|
80
|
+
for cc in ccs:
|
81
|
+
compute_cluster = ComputeCluster(
|
82
|
+
compute_cluster_id=cc.id,
|
83
|
+
user_id=ctx.obj.current.user_id,
|
84
|
+
pat=ctx.obj.current.pat,
|
85
|
+
base_url=ctx.obj.current.api_base)
|
86
|
+
response.extend([i for i in compute_cluster.list_nodepools(page_no, per_page)])
|
87
|
+
|
88
|
+
display_co_resources(
|
89
|
+
response,
|
90
|
+
custom_columns={
|
91
|
+
'ID': lambda c: c.id,
|
92
|
+
'USER_ID': lambda c: c.compute_cluster.user_id,
|
93
|
+
'COMPUTE_CLUSTER_ID': lambda c: c.compute_cluster.id,
|
94
|
+
'DESCRIPTION': lambda c: c.description,
|
95
|
+
})
|
96
|
+
|
97
|
+
|
98
|
+
@nodepool.command(['rm'])
|
99
|
+
@click.argument('compute_cluster_id')
|
100
|
+
@click.argument('nodepool_id')
|
84
101
|
@click.pass_context
|
85
102
|
def delete(ctx, compute_cluster_id, nodepool_id):
|
86
103
|
"""Deletes a nodepool for the user."""
|
104
|
+
from clarifai.client.compute_cluster import ComputeCluster
|
87
105
|
|
88
106
|
validate_context(ctx)
|
89
107
|
compute_cluster = ComputeCluster(
|
90
108
|
compute_cluster_id=compute_cluster_id,
|
91
|
-
user_id=ctx.obj
|
92
|
-
pat=ctx.obj
|
93
|
-
base_url=ctx.obj
|
109
|
+
user_id=ctx.obj.current.user_id,
|
110
|
+
pat=ctx.obj.current.pat,
|
111
|
+
base_url=ctx.obj.current.api_base)
|
94
112
|
compute_cluster.delete_nodepools([nodepool_id])
|
clarifai/client/auth/stub.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1
1
|
import itertools
|
2
|
-
import logging
|
3
2
|
import time
|
4
3
|
from concurrent.futures import ThreadPoolExecutor
|
5
4
|
|
@@ -8,7 +7,7 @@ from clarifai_grpc.grpc.api.status import status_code_pb2
|
|
8
7
|
|
9
8
|
from clarifai.client.auth.helper import ClarifaiAuthHelper
|
10
9
|
from clarifai.client.auth.register import RpcCallable, V2Stub
|
11
|
-
|
10
|
+
from clarifai.utils.logging import logger
|
12
11
|
throttle_status_codes = {
|
13
12
|
status_code_pb2.CONN_THROTTLED,
|
14
13
|
status_code_pb2.CONN_EXCEED_HOURLY_LIMIT,
|
@@ -26,7 +25,7 @@ def validate_response(response, attempt, max_attempts):
|
|
26
25
|
def handle_simple_response(response):
|
27
26
|
if hasattr(response, 'status') and hasattr(response.status, 'code'):
|
28
27
|
if (response.status.code in throttle_status_codes) and attempt < max_attempts:
|
29
|
-
|
28
|
+
logger.debug('Retrying with status %s' % str(response.status))
|
30
29
|
return None # Indicates a retry is needed
|
31
30
|
else:
|
32
31
|
return response
|
@@ -42,7 +41,7 @@ def validate_response(response, attempt, max_attempts):
|
|
42
41
|
return itertools.chain([validated_response], response)
|
43
42
|
return None # Indicates a retry is needed
|
44
43
|
except grpc.RpcError as e:
|
45
|
-
|
44
|
+
logger.error('Error processing streaming response: %s' % str(e))
|
46
45
|
return None # Indicates an error
|
47
46
|
else:
|
48
47
|
# Handle simple response validation
|
@@ -143,7 +142,7 @@ class _RetryRpcCallable(RpcCallable):
|
|
143
142
|
return v
|
144
143
|
except grpc.RpcError as e:
|
145
144
|
if (e.code() in retry_codes_grpc) and attempt < self.max_attempts:
|
146
|
-
|
145
|
+
logger.debug('Retrying with status %s' % e.code())
|
147
146
|
else:
|
148
147
|
raise
|
149
148
|
|
clarifai/client/dataset.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1
|
-
import logging
|
2
1
|
import os
|
3
2
|
import time
|
4
3
|
import uuid
|
@@ -354,7 +353,7 @@ class Dataset(Lister, BaseClient):
|
|
354
353
|
break
|
355
354
|
if failed_input_ids:
|
356
355
|
retry_input_ids = [dataset_obj.all_input_ids[id] for id in failed_input_ids]
|
357
|
-
|
356
|
+
logger.warning(
|
358
357
|
f"Retrying upload for {len(failed_input_ids)} inputs in current batch: {retry_input_ids}\n"
|
359
358
|
)
|
360
359
|
failed_retrying_inputs, _, retry_response = self._upload_inputs_annotations(
|
@@ -494,7 +493,7 @@ class Dataset(Lister, BaseClient):
|
|
494
493
|
add_file_handler(self.logger, f"Dataset_Upload{str(int(datetime.now().timestamp()))}.log")
|
495
494
|
|
496
495
|
if retry_duplicates and duplicate_input_ids:
|
497
|
-
|
496
|
+
logger.warning(f"Retrying upload for {len(duplicate_input_ids)} duplicate inputs...\n")
|
498
497
|
duplicate_inputs_indexes = [input["Index"] for input in duplicate_input_ids]
|
499
498
|
self.upload_dataset(
|
500
499
|
dataloader=dataloader,
|
@@ -505,7 +504,7 @@ class Dataset(Lister, BaseClient):
|
|
505
504
|
|
506
505
|
if failed_input_ids:
|
507
506
|
#failed_inputs= ([input["Input_ID"] for input in failed_input_ids])
|
508
|
-
|
507
|
+
logger.warning(f"Retrying upload for {len(failed_input_ids)} failed inputs...\n")
|
509
508
|
failed_input_indexes = [input["Index"] for input in failed_input_ids]
|
510
509
|
self.upload_dataset(
|
511
510
|
dataloader=dataloader, log_retry_ids=failed_input_indexes, is_log_retry=True, **kwargs)
|
@@ -142,17 +142,21 @@ class ModelBuilder:
|
|
142
142
|
def _validate_config_checkpoints(self):
|
143
143
|
"""
|
144
144
|
Validates the checkpoints section in the config file.
|
145
|
+
return loader_type, repo_id, hf_token, when, allowed_file_patterns, ignore_file_patterns
|
145
146
|
:return: loader_type the type of loader or None if no checkpoints.
|
146
147
|
:return: repo_id location of checkpoint.
|
147
148
|
:return: hf_token token to access checkpoint.
|
149
|
+
:return: when one of ['upload', 'build', 'runtime'] to download checkpoint
|
150
|
+
:return: allowed_file_patterns patterns to allow in downloaded checkpoint
|
151
|
+
:return: ignore_file_patterns patterns to ignore in downloaded checkpoint
|
148
152
|
"""
|
149
153
|
if "checkpoints" not in self.config:
|
150
|
-
return None, None, None, DEFAULT_DOWNLOAD_CHECKPOINT_WHEN
|
154
|
+
return None, None, None, DEFAULT_DOWNLOAD_CHECKPOINT_WHEN, None, None
|
151
155
|
assert "type" in self.config.get("checkpoints"), "No loader type specified in the config file"
|
152
156
|
loader_type = self.config.get("checkpoints").get("type")
|
153
157
|
if not loader_type:
|
154
158
|
logger.info("No loader type specified in the config file for checkpoints")
|
155
|
-
return None, None, None
|
159
|
+
return None, None, None, DEFAULT_DOWNLOAD_CHECKPOINT_WHEN, None, None
|
156
160
|
checkpoints = self.config.get("checkpoints")
|
157
161
|
if 'when' not in checkpoints:
|
158
162
|
logger.warn(
|
@@ -184,9 +188,17 @@ class ModelBuilder:
|
|
184
188
|
resp = self.client.STUB.GetApp(service_pb2.GetAppRequest(user_app_id=self.client.user_app_id))
|
185
189
|
if resp.status.code == status_code_pb2.SUCCESS:
|
186
190
|
return True
|
191
|
+
if resp.status.code == status_code_pb2.CONN_KEY_INVALID:
|
192
|
+
logger.error(
|
193
|
+
f"Invalid PAT provided for user {self.client.user_app_id.user_id}. Please check your PAT and try again."
|
194
|
+
)
|
195
|
+
return False
|
187
196
|
logger.error(
|
188
197
|
f"Error checking API {self._base_api} for user app {self.client.user_app_id.user_id}/{self.client.user_app_id.app_id}. Error code: {resp.status.code}"
|
189
198
|
)
|
199
|
+
logger.error(
|
200
|
+
f"App {self.client.user_app_id.app_id} not found for user {self.client.user_app_id.user_id}. Please create the app first and try again."
|
201
|
+
)
|
190
202
|
return False
|
191
203
|
|
192
204
|
def _validate_config_model(self):
|
@@ -207,9 +219,6 @@ class ModelBuilder:
|
|
207
219
|
assert model.get('id') != "", "model_id cannot be empty in the config file"
|
208
220
|
|
209
221
|
if not self._check_app_exists():
|
210
|
-
logger.error(
|
211
|
-
f"App {self.client.user_app_id.app_id} not found for user {self.client.user_app_id.user_id}"
|
212
|
-
)
|
213
222
|
sys.exit(1)
|
214
223
|
|
215
224
|
def _validate_config(self):
|
@@ -406,11 +415,12 @@ class ModelBuilder:
|
|
406
415
|
# Sort in reverse so that newer cuda versions come first and are preferred.
|
407
416
|
for image in sorted(AVAILABLE_TORCH_IMAGES, reverse=True):
|
408
417
|
if torch_version in image and f'py{python_version}' in image:
|
409
|
-
|
418
|
+
# like cu124, rocm6.3, etc.
|
419
|
+
gpu_version = image.split('-')[-1]
|
410
420
|
final_image = TORCH_BASE_IMAGE.format(
|
411
421
|
torch_version=torch_version,
|
412
422
|
python_version=python_version,
|
413
|
-
|
423
|
+
gpu_version=gpu_version,
|
414
424
|
)
|
415
425
|
logger.info(f"Using Torch version {torch_version} base image to build the Docker image")
|
416
426
|
break
|
clarifai/runners/utils/const.py
CHANGED
@@ -2,10 +2,10 @@ import os
|
|
2
2
|
|
3
3
|
registry = os.environ.get('CLARIFAI_BASE_IMAGE_REGISTRY', 'public.ecr.aws/clarifai-models')
|
4
4
|
|
5
|
-
GIT_SHA = "
|
5
|
+
GIT_SHA = "b8ae56bf3b7c95e686ca002b07ca83d259c716eb"
|
6
6
|
|
7
7
|
PYTHON_BASE_IMAGE = registry + '/python-base:{python_version}-' + GIT_SHA
|
8
|
-
TORCH_BASE_IMAGE = registry + '/torch:{torch_version}-py{python_version}-
|
8
|
+
TORCH_BASE_IMAGE = registry + '/torch:{torch_version}-py{python_version}-{gpu_version}-' + GIT_SHA
|
9
9
|
|
10
10
|
# List of available python base images
|
11
11
|
AVAILABLE_PYTHON_IMAGES = ['3.11', '3.12']
|
@@ -21,12 +21,13 @@ DEFAULT_RUNTIME_DOWNLOAD_PATH = os.path.join(os.sep, "tmp", ".cache")
|
|
21
21
|
# List of available torch images
|
22
22
|
# Keep sorted by most recent cuda version.
|
23
23
|
AVAILABLE_TORCH_IMAGES = [
|
24
|
-
'2.4.1-py3.11-
|
25
|
-
'2.5.1-py3.11-
|
26
|
-
'2.4.1-py3.12-
|
27
|
-
'2.5.1-py3.12-
|
28
|
-
|
29
|
-
|
24
|
+
'2.4.1-py3.11-cu124',
|
25
|
+
'2.5.1-py3.11-cu124',
|
26
|
+
'2.4.1-py3.12-cu124',
|
27
|
+
'2.5.1-py3.12-cu124',
|
28
|
+
'2.6.0-py3.12-cu126',
|
29
|
+
'2.7.0-py3.12-cu128',
|
30
|
+
'2.7.0-py3.12-rocm6.3',
|
30
31
|
]
|
31
32
|
CONCEPTS_REQUIRED_MODEL_TYPE = [
|
32
33
|
'visual-classifier', 'visual-detector', 'visual-segmenter', 'text-classifier'
|
clarifai/utils/cli.py
CHANGED
@@ -2,16 +2,13 @@ import importlib
|
|
2
2
|
import os
|
3
3
|
import pkgutil
|
4
4
|
import sys
|
5
|
+
import typing as t
|
6
|
+
from collections import defaultdict
|
7
|
+
from typing import OrderedDict
|
5
8
|
|
6
9
|
import click
|
7
10
|
import yaml
|
8
|
-
|
9
|
-
from rich.console import Console
|
10
|
-
from rich.panel import Panel
|
11
|
-
from rich.style import Style
|
12
|
-
from rich.text import Text
|
13
|
-
|
14
|
-
from clarifai.utils.logging import logger
|
11
|
+
from tabulate import tabulate
|
15
12
|
|
16
13
|
|
17
14
|
def from_yaml(filename: str):
|
@@ -31,19 +28,6 @@ def dump_yaml(data, filename: str):
|
|
31
28
|
click.echo(f"Error writing YAML file: {e}", err=True)
|
32
29
|
|
33
30
|
|
34
|
-
def set_base_url(env):
|
35
|
-
environments = {
|
36
|
-
'prod': 'https://api.clarifai.com',
|
37
|
-
'staging': 'https://api-staging.clarifai.com',
|
38
|
-
'dev': 'https://api-dev.clarifai.com'
|
39
|
-
}
|
40
|
-
|
41
|
-
if env in environments:
|
42
|
-
return environments[env]
|
43
|
-
else:
|
44
|
-
raise ValueError("Invalid environment. Please choose from 'prod', 'staging', 'dev'.")
|
45
|
-
|
46
|
-
|
47
31
|
# Dynamically find and import all command modules from the cli directory
|
48
32
|
def load_command_modules():
|
49
33
|
package_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'cli')
|
@@ -53,27 +37,132 @@ def load_command_modules():
|
|
53
37
|
importlib.import_module(f'clarifai.cli.{module_name}')
|
54
38
|
|
55
39
|
|
56
|
-
def display_co_resources(response,
|
40
|
+
def display_co_resources(response,
|
41
|
+
custom_columns={
|
42
|
+
'ID': lambda c: c.id,
|
43
|
+
'USER_ID': lambda c: c.user_id,
|
44
|
+
'DESCRIPTION': lambda c: c.description,
|
45
|
+
}):
|
57
46
|
"""Display compute orchestration resources listing results using rich."""
|
58
47
|
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
48
|
+
formatter = TableFormatter(custom_columns)
|
49
|
+
print(formatter.format(list(response), fmt="plain"))
|
50
|
+
|
51
|
+
|
52
|
+
class TableFormatter:
|
53
|
+
|
54
|
+
def __init__(self, custom_columns: OrderedDict):
|
55
|
+
"""
|
56
|
+
Initializes the TableFormatter with column headers and custom column mappings.
|
57
|
+
|
58
|
+
:param headers: List of column headers for the table.
|
59
|
+
"""
|
60
|
+
self.custom_columns = custom_columns
|
61
|
+
|
62
|
+
def format(self, objects, fmt='plain'):
|
63
|
+
"""
|
64
|
+
Formats a list of objects into a table with custom columns.
|
65
|
+
|
66
|
+
:param objects: List of objects to format into a table.
|
67
|
+
:return: A string representing the table.
|
68
|
+
"""
|
69
|
+
# Prepare the rows by applying the custom column functions to each object
|
70
|
+
rows = []
|
71
|
+
for obj in objects:
|
72
|
+
# row = [self.custom_columns[header](obj) for header in self.headers]
|
73
|
+
row = [f(obj) for f in self.custom_columns.values()]
|
74
|
+
rows.append(row)
|
75
|
+
|
76
|
+
# Create the table
|
77
|
+
table = tabulate(rows, headers=self.custom_columns.keys(), tablefmt=fmt)
|
78
|
+
return table
|
79
|
+
|
80
|
+
|
81
|
+
class AliasedGroup(click.Group):
|
82
|
+
|
83
|
+
def __init__(self,
|
84
|
+
name: t.Optional[str] = None,
|
85
|
+
commands: t.Optional[t.Union[t.MutableMapping[str, click.Command], t.Sequence[
|
86
|
+
click.Command]]] = None,
|
87
|
+
**attrs: t.Any) -> None:
|
88
|
+
super().__init__(name, commands, **attrs)
|
89
|
+
self.alias_map = {}
|
90
|
+
self.command_to_aliases = defaultdict(list)
|
91
|
+
|
92
|
+
def add_alias(self, cmd: click.Command, alias: str) -> None:
|
93
|
+
self.alias_map[alias] = cmd
|
94
|
+
if alias != cmd.name:
|
95
|
+
self.command_to_aliases[cmd].append(alias)
|
96
|
+
|
97
|
+
def command(self, aliases=None, *args,
|
98
|
+
**kwargs) -> t.Callable[[t.Callable[..., t.Any]], click.Command]:
|
99
|
+
cmd_decorator = super().command(*args, **kwargs)
|
100
|
+
if aliases is None:
|
101
|
+
aliases = []
|
102
|
+
|
103
|
+
def aliased_decorator(f):
|
104
|
+
cmd = cmd_decorator(f)
|
105
|
+
if cmd.name:
|
106
|
+
self.add_alias(cmd, cmd.name)
|
107
|
+
for alias in aliases:
|
108
|
+
self.add_alias(cmd, alias)
|
109
|
+
return cmd
|
110
|
+
|
111
|
+
f = None
|
112
|
+
if args and callable(args[0]):
|
113
|
+
(f,) = args
|
114
|
+
if f is not None:
|
115
|
+
return aliased_decorator(f)
|
116
|
+
return aliased_decorator
|
117
|
+
|
118
|
+
def group(self, aliases=None, *args,
|
119
|
+
**kwargs) -> t.Callable[[t.Callable[..., t.Any]], click.Group]:
|
120
|
+
cmd_decorator = super().group(*args, **kwargs)
|
121
|
+
if aliases is None:
|
122
|
+
aliases = []
|
123
|
+
|
124
|
+
def aliased_decorator(f):
|
125
|
+
cmd = cmd_decorator(f)
|
126
|
+
if cmd.name:
|
127
|
+
self.add_alias(cmd, cmd.name)
|
128
|
+
for alias in aliases:
|
129
|
+
self.add_alias(cmd, alias)
|
130
|
+
return cmd
|
131
|
+
|
132
|
+
f = None
|
133
|
+
if args and callable(args[0]):
|
134
|
+
(f,) = args
|
135
|
+
if f is not None:
|
136
|
+
return aliased_decorator(f)
|
137
|
+
return aliased_decorator
|
138
|
+
|
139
|
+
def get_command(self, ctx: click.Context, cmd_name: str) -> t.Optional[click.Command]:
|
140
|
+
rv = click.Group.get_command(self, ctx, cmd_name)
|
141
|
+
if rv is not None:
|
142
|
+
return rv
|
143
|
+
return self.alias_map.get(cmd_name)
|
144
|
+
|
145
|
+
def format_commands(self, ctx, formatter):
|
146
|
+
sub_commands = self.list_commands(ctx)
|
147
|
+
|
148
|
+
rows = []
|
149
|
+
for sub_command in sub_commands:
|
150
|
+
cmd = self.get_command(ctx, sub_command)
|
151
|
+
if cmd is None or getattr(cmd, 'hidden', False):
|
152
|
+
continue
|
153
|
+
if cmd in self.command_to_aliases:
|
154
|
+
aliases = ', '.join(self.command_to_aliases[cmd])
|
155
|
+
sub_command = f'{sub_command} ({aliases})'
|
156
|
+
cmd_help = cmd.help
|
157
|
+
rows.append((sub_command, cmd_help))
|
158
|
+
|
159
|
+
if rows:
|
160
|
+
with formatter.section("Commands"):
|
161
|
+
formatter.write_dl(rows)
|
74
162
|
|
75
163
|
|
76
164
|
def validate_context(ctx):
|
165
|
+
from clarifai.utils.logging import logger
|
77
166
|
if ctx.obj == {}:
|
78
167
|
logger.error("CLI config file missing. Run `clarifai login` to set up the CLI config.")
|
79
168
|
sys.exit(1)
|
clarifai/utils/config.py
ADDED
@@ -0,0 +1,105 @@
|
|
1
|
+
import os
|
2
|
+
from collections import OrderedDict
|
3
|
+
from dataclasses import dataclass, field
|
4
|
+
|
5
|
+
import yaml
|
6
|
+
|
7
|
+
from clarifai.utils.constants import DEFAULT_CONFIG
|
8
|
+
|
9
|
+
|
10
|
+
class Context(OrderedDict):
|
11
|
+
"""
|
12
|
+
A context which has a name and a set of key-values as a dict under env.
|
13
|
+
|
14
|
+
You can access the keys directly.
|
15
|
+
"""
|
16
|
+
|
17
|
+
def __init__(self, name, **kwargs):
|
18
|
+
self['name'] = name
|
19
|
+
# when loading from config we may have the env: section in yaml already so we get it here.
|
20
|
+
if 'env' in kwargs:
|
21
|
+
self['env'] = kwargs['env']
|
22
|
+
else: # when consructing as Context(name, key=value) we set it here.
|
23
|
+
self['env'] = kwargs
|
24
|
+
|
25
|
+
def __getattr__(self, key):
|
26
|
+
try:
|
27
|
+
if key == 'name':
|
28
|
+
return self[key]
|
29
|
+
if key == 'env':
|
30
|
+
raise AttributeError("Don't access .env directly")
|
31
|
+
|
32
|
+
# Allow accessing CLARIFAI_PAT type env var names from config as .pat
|
33
|
+
envvar_name = 'CLARIFAI_' + key.upper()
|
34
|
+
env = self['env']
|
35
|
+
if envvar_name in env:
|
36
|
+
value = env[envvar_name]
|
37
|
+
if value == "ENVVAR":
|
38
|
+
return os.environ[envvar_name]
|
39
|
+
else:
|
40
|
+
value = env[key]
|
41
|
+
|
42
|
+
if isinstance(value, dict):
|
43
|
+
return Context(value)
|
44
|
+
|
45
|
+
return value
|
46
|
+
except KeyError as e:
|
47
|
+
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'") from e
|
48
|
+
|
49
|
+
def __setattr__(self, key, value):
|
50
|
+
if key == "name":
|
51
|
+
self['name'] = value
|
52
|
+
else:
|
53
|
+
self['env'][key] = value
|
54
|
+
|
55
|
+
def __delattr__(self, key):
|
56
|
+
try:
|
57
|
+
del self['env'][key]
|
58
|
+
except KeyError as e:
|
59
|
+
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'") from e
|
60
|
+
|
61
|
+
def to_serializable_dict(self):
|
62
|
+
return dict(self['env'])
|
63
|
+
|
64
|
+
|
65
|
+
@dataclass
|
66
|
+
class Config():
|
67
|
+
current_context: str
|
68
|
+
filename: str
|
69
|
+
contexts: OrderedDict[str, Context] = field(default_factory=OrderedDict)
|
70
|
+
|
71
|
+
def __post_init__(self):
|
72
|
+
for k, v in self.contexts.items():
|
73
|
+
if 'name' not in v:
|
74
|
+
v['name'] = k
|
75
|
+
self.contexts = {k: Context(**v) for k, v in self.contexts.items()}
|
76
|
+
|
77
|
+
@classmethod
|
78
|
+
def from_yaml(cls, filename: str = DEFAULT_CONFIG):
|
79
|
+
with open(filename, 'r') as f:
|
80
|
+
cfg = yaml.safe_load(f)
|
81
|
+
return cls(**cfg, filename=filename)
|
82
|
+
|
83
|
+
def to_dict(self):
|
84
|
+
return {
|
85
|
+
'current_context': self.current_context,
|
86
|
+
'contexts': {k: v.to_serializable_dict()
|
87
|
+
for k, v in self.contexts.items()}
|
88
|
+
}
|
89
|
+
|
90
|
+
def to_yaml(self, filename: str = None):
|
91
|
+
if filename is None:
|
92
|
+
filename = self.filename
|
93
|
+
dir = os.path.dirname(filename)
|
94
|
+
if len(dir):
|
95
|
+
os.makedirs(dir, exist_ok=True)
|
96
|
+
_dict = self.to_dict()
|
97
|
+
for k, v in _dict['contexts'].items():
|
98
|
+
v.pop('name', None)
|
99
|
+
with open(filename, 'w') as f:
|
100
|
+
yaml.safe_dump(_dict, f)
|
101
|
+
|
102
|
+
@property
|
103
|
+
def current(self) -> Context:
|
104
|
+
""" get the current Context """
|
105
|
+
return self.contexts[self.current_context]
|