clarifai 11.2.3rc7__py3-none-any.whl → 11.2.3rc8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clarifai/__init__.py +1 -1
- clarifai/cli/base.py +228 -81
- clarifai/cli/compute_cluster.py +28 -18
- clarifai/cli/deployment.py +70 -42
- clarifai/cli/model.py +26 -14
- clarifai/cli/nodepool.py +62 -41
- clarifai/client/app.py +1 -1
- clarifai/client/auth/stub.py +4 -5
- clarifai/client/dataset.py +3 -4
- clarifai/client/model_client.py +35 -6
- clarifai/runners/models/model_builder.py +42 -59
- clarifai/runners/models/model_class.py +6 -5
- clarifai/runners/models/model_run_locally.py +11 -23
- clarifai/runners/utils/const.py +9 -8
- clarifai/runners/utils/data_types.py +0 -4
- clarifai/runners/utils/data_utils.py +1 -1
- clarifai/runners/utils/loader.py +36 -6
- clarifai/runners/utils/method_signatures.py +2 -1
- clarifai/runners/utils/openai_format.py +87 -0
- clarifai/utils/cli.py +132 -34
- clarifai/utils/config.py +105 -0
- clarifai/utils/constants.py +4 -0
- clarifai/utils/logging.py +64 -21
- clarifai/utils/misc.py +2 -0
- {clarifai-11.2.3rc7.dist-info → clarifai-11.2.3rc8.dist-info}/METADATA +1 -1
- {clarifai-11.2.3rc7.dist-info → clarifai-11.2.3rc8.dist-info}/RECORD +30 -28
- {clarifai-11.2.3rc7.dist-info → clarifai-11.2.3rc8.dist-info}/LICENSE +0 -0
- {clarifai-11.2.3rc7.dist-info → clarifai-11.2.3rc8.dist-info}/WHEEL +0 -0
- {clarifai-11.2.3rc7.dist-info → clarifai-11.2.3rc8.dist-info}/entry_points.txt +0 -0
- {clarifai-11.2.3rc7.dist-info → clarifai-11.2.3rc8.dist-info}/top_level.txt +0 -0
clarifai/runners/utils/loader.py
CHANGED
@@ -6,6 +6,7 @@ import shutil
|
|
6
6
|
|
7
7
|
import requests
|
8
8
|
|
9
|
+
from clarifai.runners.utils.const import CONCEPTS_REQUIRED_MODEL_TYPE
|
9
10
|
from clarifai.utils.logging import logger
|
10
11
|
|
11
12
|
|
@@ -13,9 +14,10 @@ class HuggingFaceLoader:
|
|
13
14
|
|
14
15
|
HF_DOWNLOAD_TEXT = "The 'huggingface_hub' package is not installed. Please install it using 'pip install huggingface_hub'."
|
15
16
|
|
16
|
-
def __init__(self, repo_id=None, token=None):
|
17
|
+
def __init__(self, repo_id=None, token=None, model_type_id=None):
|
17
18
|
self.repo_id = repo_id
|
18
19
|
self.token = token
|
20
|
+
self.clarifai_model_type_id = model_type_id
|
19
21
|
if token:
|
20
22
|
if self.validate_hftoken(token):
|
21
23
|
try:
|
@@ -43,13 +45,17 @@ class HuggingFaceLoader:
|
|
43
45
|
f"Error setting up Hugging Face token, please make sure you have the correct token: {e}")
|
44
46
|
return False
|
45
47
|
|
46
|
-
def download_checkpoints(self,
|
48
|
+
def download_checkpoints(self,
|
49
|
+
checkpoint_path: str,
|
50
|
+
allowed_file_patterns=None,
|
51
|
+
ignore_file_patterns=None):
|
47
52
|
# throw error if huggingface_hub wasn't installed
|
48
53
|
try:
|
49
54
|
from huggingface_hub import snapshot_download
|
50
55
|
except ImportError:
|
51
56
|
raise ImportError(self.HF_DOWNLOAD_TEXT)
|
52
|
-
if os.path.exists(checkpoint_path) and self.validate_download(
|
57
|
+
if os.path.exists(checkpoint_path) and self.validate_download(
|
58
|
+
checkpoint_path, allowed_file_patterns, ignore_file_patterns):
|
53
59
|
logger.info("Checkpoints already exist")
|
54
60
|
return True
|
55
61
|
else:
|
@@ -61,10 +67,16 @@ class HuggingFaceLoader:
|
|
61
67
|
return False
|
62
68
|
|
63
69
|
self.ignore_patterns = self._get_ignore_patterns()
|
70
|
+
if ignore_file_patterns:
|
71
|
+
if self.ignore_patterns:
|
72
|
+
self.ignore_patterns.extend(ignore_file_patterns)
|
73
|
+
else:
|
74
|
+
self.ignore_patterns = ignore_file_patterns
|
64
75
|
snapshot_download(
|
65
76
|
repo_id=self.repo_id,
|
66
77
|
local_dir=checkpoint_path,
|
67
78
|
local_dir_use_symlinks=False,
|
79
|
+
allow_patterns=allowed_file_patterns,
|
68
80
|
ignore_patterns=self.ignore_patterns)
|
69
81
|
# Remove the `.cache` folder if it exists
|
70
82
|
cache_path = os.path.join(checkpoint_path, ".cache")
|
@@ -75,7 +87,8 @@ class HuggingFaceLoader:
|
|
75
87
|
logger.error(f"Error downloading model checkpoints {e}")
|
76
88
|
return False
|
77
89
|
finally:
|
78
|
-
is_downloaded = self.validate_download(checkpoint_path
|
90
|
+
is_downloaded = self.validate_download(checkpoint_path, allowed_file_patterns,
|
91
|
+
ignore_file_patterns)
|
79
92
|
if not is_downloaded:
|
80
93
|
logger.error("Error validating downloaded model checkpoints")
|
81
94
|
return False
|
@@ -109,9 +122,13 @@ class HuggingFaceLoader:
|
|
109
122
|
from huggingface_hub import file_exists, repo_exists
|
110
123
|
except ImportError:
|
111
124
|
raise ImportError(self.HF_DOWNLOAD_TEXT)
|
112
|
-
|
125
|
+
if self.clarifai_model_type_id in CONCEPTS_REQUIRED_MODEL_TYPE:
|
126
|
+
return repo_exists(self.repo_id) and file_exists(self.repo_id, 'config.json')
|
127
|
+
else:
|
128
|
+
return repo_exists(self.repo_id)
|
113
129
|
|
114
|
-
def validate_download(self, checkpoint_path: str
|
130
|
+
def validate_download(self, checkpoint_path: str, allowed_file_patterns: list,
|
131
|
+
ignore_file_patterns: list):
|
115
132
|
# check if model exists on HF
|
116
133
|
try:
|
117
134
|
from huggingface_hub import list_repo_files
|
@@ -120,7 +137,20 @@ class HuggingFaceLoader:
|
|
120
137
|
# Get the list of files on the repo
|
121
138
|
repo_files = list_repo_files(self.repo_id, token=self.token)
|
122
139
|
|
140
|
+
# Get the list of files on the repo that are allowed
|
141
|
+
if allowed_file_patterns:
|
142
|
+
|
143
|
+
def should_allow(file_path):
|
144
|
+
return any(fnmatch.fnmatch(file_path, pattern) for pattern in allowed_file_patterns)
|
145
|
+
|
146
|
+
repo_files = [f for f in repo_files if should_allow(f)]
|
147
|
+
|
123
148
|
self.ignore_patterns = self._get_ignore_patterns()
|
149
|
+
if ignore_file_patterns:
|
150
|
+
if self.ignore_patterns:
|
151
|
+
self.ignore_patterns.extend(ignore_file_patterns)
|
152
|
+
else:
|
153
|
+
self.ignore_patterns = ignore_file_patterns
|
124
154
|
# Get the list of files on the repo that are not ignored
|
125
155
|
if getattr(self, "ignore_patterns", None):
|
126
156
|
patterns = self.ignore_patterns
|
clarifai/runners/utils/method_signatures.py
CHANGED
@@ -1,3 +1,4 @@
|
|
1
|
+
import collections.abc as abc
|
1
2
|
import inspect
|
2
3
|
import json
|
3
4
|
from collections import namedtuple
|
@@ -354,7 +355,7 @@ def _normalize_type(tp):
|
|
354
355
|
'''
|
355
356
|
# stream type indicates streaming, not part of the data itself
|
356
357
|
# it can only be used at the top-level of the var type
|
357
|
-
streaming = (get_origin(tp)
|
358
|
+
streaming = (get_origin(tp) in [abc.Iterator, abc.Generator, abc.Iterable])
|
358
359
|
if streaming:
|
359
360
|
tp = get_args(tp)[0]
|
360
361
|
|
clarifai/runners/utils/openai_format.py
ADDED
@@ -0,0 +1,87 @@
|
|
1
|
+
import time
|
2
|
+
import uuid
|
3
|
+
|
4
|
+
def generate_id():
|
5
|
+
return f"chatcmpl-{uuid.uuid4().hex}"
|
6
|
+
|
7
|
+
def format_non_streaming_response(
|
8
|
+
generated_text,
|
9
|
+
model="custom-model",
|
10
|
+
id=None,
|
11
|
+
created=None,
|
12
|
+
usage=None,
|
13
|
+
finish_reason="stop",
|
14
|
+
):
|
15
|
+
if id is None:
|
16
|
+
id = generate_id()
|
17
|
+
if created is None:
|
18
|
+
created = int(time.time())
|
19
|
+
|
20
|
+
response = {
|
21
|
+
"id": id,
|
22
|
+
"object": "chat.completion",
|
23
|
+
"created": created,
|
24
|
+
"model": model,
|
25
|
+
"choices": [
|
26
|
+
{
|
27
|
+
"index": 0,
|
28
|
+
"message": {
|
29
|
+
"role": "assistant",
|
30
|
+
"content": generated_text,
|
31
|
+
},
|
32
|
+
"finish_reason": finish_reason,
|
33
|
+
"logprobs": None,
|
34
|
+
}
|
35
|
+
],
|
36
|
+
}
|
37
|
+
|
38
|
+
if usage is not None:
|
39
|
+
response["usage"] = usage
|
40
|
+
|
41
|
+
return response
|
42
|
+
|
43
|
+
def format_streaming_response(
|
44
|
+
generated_chunks,
|
45
|
+
model="custom-model",
|
46
|
+
id=None,
|
47
|
+
created=None,
|
48
|
+
finish_reason="stop",
|
49
|
+
):
|
50
|
+
if id is None:
|
51
|
+
id = generate_id()
|
52
|
+
if created is None:
|
53
|
+
created = int(time.time())
|
54
|
+
|
55
|
+
for chunk in generated_chunks:
|
56
|
+
yield {
|
57
|
+
"id": id,
|
58
|
+
"object": "chat.completion.chunk",
|
59
|
+
"created": created,
|
60
|
+
"model": model,
|
61
|
+
"choices": [
|
62
|
+
{
|
63
|
+
"index": 0,
|
64
|
+
"delta": {
|
65
|
+
"content": chunk,
|
66
|
+
},
|
67
|
+
"finish_reason": None,
|
68
|
+
"logprobs": None,
|
69
|
+
}
|
70
|
+
],
|
71
|
+
}
|
72
|
+
|
73
|
+
# Final chunk indicating completion
|
74
|
+
yield {
|
75
|
+
"id": id,
|
76
|
+
"object": "chat.completion.chunk",
|
77
|
+
"created": created,
|
78
|
+
"model": model,
|
79
|
+
"choices": [
|
80
|
+
{
|
81
|
+
"index": 0,
|
82
|
+
"delta": {},
|
83
|
+
"finish_reason": finish_reason,
|
84
|
+
"logprobs": None,
|
85
|
+
}
|
86
|
+
],
|
87
|
+
}
|
clarifai/utils/cli.py
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
import importlib
|
2
2
|
import os
|
3
3
|
import pkgutil
|
4
|
+
import sys
|
5
|
+
import typing as t
|
6
|
+
from collections import defaultdict
|
7
|
+
from typing import OrderedDict
|
4
8
|
|
5
9
|
import click
|
6
10
|
import yaml
|
7
|
-
|
8
|
-
from rich.console import Console
|
9
|
-
from rich.panel import Panel
|
10
|
-
from rich.style import Style
|
11
|
-
from rich.text import Text
|
11
|
+
from tabulate import tabulate
|
12
12
|
|
13
13
|
|
14
14
|
def from_yaml(filename: str):
|
@@ -28,19 +28,6 @@ def dump_yaml(data, filename: str):
|
|
28
28
|
click.echo(f"Error writing YAML file: {e}", err=True)
|
29
29
|
|
30
30
|
|
31
|
-
def set_base_url(env):
|
32
|
-
environments = {
|
33
|
-
'prod': 'https://api.clarifai.com',
|
34
|
-
'staging': 'https://api-staging.clarifai.com',
|
35
|
-
'dev': 'https://api-dev.clarifai.com'
|
36
|
-
}
|
37
|
-
|
38
|
-
if env in environments:
|
39
|
-
return environments[env]
|
40
|
-
else:
|
41
|
-
raise ValueError("Invalid environment. Please choose from 'prod', 'staging', 'dev'.")
|
42
|
-
|
43
|
-
|
44
31
|
# Dynamically find and import all command modules from the cli directory
|
45
32
|
def load_command_modules():
|
46
33
|
package_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'cli')
|
@@ -50,21 +37,132 @@ def load_command_modules():
|
|
50
37
|
importlib.import_module(f'clarifai.cli.{module_name}')
|
51
38
|
|
52
39
|
|
53
|
-
def display_co_resources(response,
|
40
|
+
def display_co_resources(response,
|
41
|
+
custom_columns={
|
42
|
+
'ID': lambda c: c.id,
|
43
|
+
'USER_ID': lambda c: c.user_id,
|
44
|
+
'DESCRIPTION': lambda c: c.description,
|
45
|
+
}):
|
54
46
|
"""Display compute orchestration resources listing results using rich."""
|
55
47
|
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
48
|
+
formatter = TableFormatter(custom_columns)
|
49
|
+
print(formatter.format(list(response), fmt="plain"))
|
50
|
+
|
51
|
+
|
52
|
+
class TableFormatter:
|
53
|
+
|
54
|
+
def __init__(self, custom_columns: OrderedDict):
|
55
|
+
"""
|
56
|
+
Initializes the TableFormatter with column headers and custom column mappings.
|
57
|
+
|
58
|
+
:param custom_columns: Ordered mapping of column header to a callable that extracts that column's value from an object.
|
59
|
+
"""
|
60
|
+
self.custom_columns = custom_columns
|
61
|
+
|
62
|
+
def format(self, objects, fmt='plain'):
|
63
|
+
"""
|
64
|
+
Formats a list of objects into a table with custom columns.
|
65
|
+
|
66
|
+
:param objects: List of objects to format into a table.
|
67
|
+
:return: A string representing the table.
|
68
|
+
"""
|
69
|
+
# Prepare the rows by applying the custom column functions to each object
|
70
|
+
rows = []
|
71
|
+
for obj in objects:
|
72
|
+
# row = [self.custom_columns[header](obj) for header in self.headers]
|
73
|
+
row = [f(obj) for f in self.custom_columns.values()]
|
74
|
+
rows.append(row)
|
75
|
+
|
76
|
+
# Create the table
|
77
|
+
table = tabulate(rows, headers=self.custom_columns.keys(), tablefmt=fmt)
|
78
|
+
return table
|
79
|
+
|
80
|
+
|
81
|
+
class AliasedGroup(click.Group):
|
82
|
+
|
83
|
+
def __init__(self,
|
84
|
+
name: t.Optional[str] = None,
|
85
|
+
commands: t.Optional[t.Union[t.MutableMapping[str, click.Command], t.Sequence[
|
86
|
+
click.Command]]] = None,
|
87
|
+
**attrs: t.Any) -> None:
|
88
|
+
super().__init__(name, commands, **attrs)
|
89
|
+
self.alias_map = {}
|
90
|
+
self.command_to_aliases = defaultdict(list)
|
91
|
+
|
92
|
+
def add_alias(self, cmd: click.Command, alias: str) -> None:
|
93
|
+
self.alias_map[alias] = cmd
|
94
|
+
if alias != cmd.name:
|
95
|
+
self.command_to_aliases[cmd].append(alias)
|
96
|
+
|
97
|
+
def command(self, aliases=None, *args,
|
98
|
+
**kwargs) -> t.Callable[[t.Callable[..., t.Any]], click.Command]:
|
99
|
+
cmd_decorator = super().command(*args, **kwargs)
|
100
|
+
if aliases is None:
|
101
|
+
aliases = []
|
102
|
+
|
103
|
+
def aliased_decorator(f):
|
104
|
+
cmd = cmd_decorator(f)
|
105
|
+
if cmd.name:
|
106
|
+
self.add_alias(cmd, cmd.name)
|
107
|
+
for alias in aliases:
|
108
|
+
self.add_alias(cmd, alias)
|
109
|
+
return cmd
|
110
|
+
|
111
|
+
f = None
|
112
|
+
if args and callable(args[0]):
|
113
|
+
(f,) = args
|
114
|
+
if f is not None:
|
115
|
+
return aliased_decorator(f)
|
116
|
+
return aliased_decorator
|
117
|
+
|
118
|
+
def group(self, aliases=None, *args,
|
119
|
+
**kwargs) -> t.Callable[[t.Callable[..., t.Any]], click.Group]:
|
120
|
+
cmd_decorator = super().group(*args, **kwargs)
|
121
|
+
if aliases is None:
|
122
|
+
aliases = []
|
123
|
+
|
124
|
+
def aliased_decorator(f):
|
125
|
+
cmd = cmd_decorator(f)
|
126
|
+
if cmd.name:
|
127
|
+
self.add_alias(cmd, cmd.name)
|
128
|
+
for alias in aliases:
|
129
|
+
self.add_alias(cmd, alias)
|
130
|
+
return cmd
|
131
|
+
|
132
|
+
f = None
|
133
|
+
if args and callable(args[0]):
|
134
|
+
(f,) = args
|
135
|
+
if f is not None:
|
136
|
+
return aliased_decorator(f)
|
137
|
+
return aliased_decorator
|
138
|
+
|
139
|
+
def get_command(self, ctx: click.Context, cmd_name: str) -> t.Optional[click.Command]:
|
140
|
+
rv = click.Group.get_command(self, ctx, cmd_name)
|
141
|
+
if rv is not None:
|
142
|
+
return rv
|
143
|
+
return self.alias_map.get(cmd_name)
|
144
|
+
|
145
|
+
def format_commands(self, ctx, formatter):
|
146
|
+
sub_commands = self.list_commands(ctx)
|
147
|
+
|
148
|
+
rows = []
|
149
|
+
for sub_command in sub_commands:
|
150
|
+
cmd = self.get_command(ctx, sub_command)
|
151
|
+
if cmd is None or getattr(cmd, 'hidden', False):
|
152
|
+
continue
|
153
|
+
if cmd in self.command_to_aliases:
|
154
|
+
aliases = ', '.join(self.command_to_aliases[cmd])
|
155
|
+
sub_command = f'{sub_command} ({aliases})'
|
156
|
+
cmd_help = cmd.help
|
157
|
+
rows.append((sub_command, cmd_help))
|
158
|
+
|
159
|
+
if rows:
|
160
|
+
with formatter.section("Commands"):
|
161
|
+
formatter.write_dl(rows)
|
162
|
+
|
163
|
+
|
164
|
+
def validate_context(ctx):
|
165
|
+
from clarifai.utils.logging import logger
|
166
|
+
if ctx.obj == {}:
|
167
|
+
logger.error("CLI config file missing. Run `clarifai login` to set up the CLI config.")
|
168
|
+
sys.exit(1)
|
clarifai/utils/config.py
ADDED
@@ -0,0 +1,105 @@
|
|
1
|
+
import os
|
2
|
+
from collections import OrderedDict
|
3
|
+
from dataclasses import dataclass, field
|
4
|
+
|
5
|
+
import yaml
|
6
|
+
|
7
|
+
from clarifai.utils.constants import DEFAULT_CONFIG
|
8
|
+
|
9
|
+
|
10
|
+
class Context(OrderedDict):
|
11
|
+
"""
|
12
|
+
A context which has a name and a set of key-values as a dict under env.
|
13
|
+
|
14
|
+
You can access the keys directly.
|
15
|
+
"""
|
16
|
+
|
17
|
+
def __init__(self, name, **kwargs):
|
18
|
+
self['name'] = name
|
19
|
+
# when loading from config we may have the env: section in yaml already so we get it here.
|
20
|
+
if 'env' in kwargs:
|
21
|
+
self['env'] = kwargs['env']
|
22
|
+
else: # when constructing as Context(name, key=value) we set it here.
|
23
|
+
self['env'] = kwargs
|
24
|
+
|
25
|
+
def __getattr__(self, key):
|
26
|
+
try:
|
27
|
+
if key == 'name':
|
28
|
+
return self[key]
|
29
|
+
if key == 'env':
|
30
|
+
raise AttributeError("Don't access .env directly")
|
31
|
+
|
32
|
+
# Allow accessing CLARIFAI_PAT type env var names from config as .pat
|
33
|
+
envvar_name = 'CLARIFAI_' + key.upper()
|
34
|
+
env = self['env']
|
35
|
+
if envvar_name in env:
|
36
|
+
value = env[envvar_name]
|
37
|
+
if value == "ENVVAR":
|
38
|
+
return os.environ[envvar_name]
|
39
|
+
else:
|
40
|
+
value = env[key]
|
41
|
+
|
42
|
+
if isinstance(value, dict):
|
43
|
+
return Context(value)
|
44
|
+
|
45
|
+
return value
|
46
|
+
except KeyError as e:
|
47
|
+
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'") from e
|
48
|
+
|
49
|
+
def __setattr__(self, key, value):
|
50
|
+
if key == "name":
|
51
|
+
self['name'] = value
|
52
|
+
else:
|
53
|
+
self['env'][key] = value
|
54
|
+
|
55
|
+
def __delattr__(self, key):
|
56
|
+
try:
|
57
|
+
del self['env'][key]
|
58
|
+
except KeyError as e:
|
59
|
+
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'") from e
|
60
|
+
|
61
|
+
def to_serializable_dict(self):
|
62
|
+
return dict(self['env'])
|
63
|
+
|
64
|
+
|
65
|
+
@dataclass
|
66
|
+
class Config():
|
67
|
+
current_context: str
|
68
|
+
filename: str
|
69
|
+
contexts: OrderedDict[str, Context] = field(default_factory=OrderedDict)
|
70
|
+
|
71
|
+
def __post_init__(self):
|
72
|
+
for k, v in self.contexts.items():
|
73
|
+
if 'name' not in v:
|
74
|
+
v['name'] = k
|
75
|
+
self.contexts = {k: Context(**v) for k, v in self.contexts.items()}
|
76
|
+
|
77
|
+
@classmethod
|
78
|
+
def from_yaml(cls, filename: str = DEFAULT_CONFIG):
|
79
|
+
with open(filename, 'r') as f:
|
80
|
+
cfg = yaml.safe_load(f)
|
81
|
+
return cls(**cfg, filename=filename)
|
82
|
+
|
83
|
+
def to_dict(self):
|
84
|
+
return {
|
85
|
+
'current_context': self.current_context,
|
86
|
+
'contexts': {k: v.to_serializable_dict()
|
87
|
+
for k, v in self.contexts.items()}
|
88
|
+
}
|
89
|
+
|
90
|
+
def to_yaml(self, filename: str = None):
|
91
|
+
if filename is None:
|
92
|
+
filename = self.filename
|
93
|
+
dir = os.path.dirname(filename)
|
94
|
+
if len(dir):
|
95
|
+
os.makedirs(dir, exist_ok=True)
|
96
|
+
_dict = self.to_dict()
|
97
|
+
for k, v in _dict['contexts'].items():
|
98
|
+
v.pop('name', None)
|
99
|
+
with open(filename, 'w') as f:
|
100
|
+
yaml.safe_dump(_dict, f)
|
101
|
+
|
102
|
+
@property
|
103
|
+
def current(self) -> Context:
|
104
|
+
""" get the current Context """
|
105
|
+
return self.contexts[self.current_context]
|
clarifai/utils/constants.py
CHANGED
clarifai/utils/logging.py
CHANGED
@@ -10,15 +10,6 @@ import traceback
|
|
10
10
|
from collections import defaultdict
|
11
11
|
from typing import Any, Dict, List, Optional, Union
|
12
12
|
|
13
|
-
from rich import print as rprint
|
14
|
-
from rich.console import Console
|
15
|
-
from rich.logging import RichHandler
|
16
|
-
from rich.table import Table
|
17
|
-
from rich.traceback import install
|
18
|
-
from rich.tree import Tree
|
19
|
-
|
20
|
-
install()
|
21
|
-
|
22
13
|
# The default logger to use throughout the SDK is defined at bottom of this file.
|
23
14
|
|
24
15
|
# For the json logger.
|
@@ -29,6 +20,20 @@ FIELD_BLACKLIST = [
|
|
29
20
|
'msg', 'message', 'account', 'levelno', 'created', 'threadName', 'name', 'processName',
|
30
21
|
'module', 'funcName', 'msecs', 'relativeCreated', 'pathname', 'args', 'thread', 'process'
|
31
22
|
]
|
23
|
+
COLORS = {
|
24
|
+
'ARGUMENTS': '\033[90m', # Gray
|
25
|
+
'DEBUG': '\033[90m', # Gray
|
26
|
+
'INFO': '\033[32m', # Green
|
27
|
+
'WARNING': '\033[33m', # Yellow
|
28
|
+
'ERROR': '\033[31m', # Red
|
29
|
+
'CRITICAL': '\033[31m', # Red
|
30
|
+
'TIME': '\033[34m',
|
31
|
+
'RESET': '\033[0m'
|
32
|
+
}
|
33
|
+
LOG_FORMAT = f"[%(levelname)s] {COLORS.get('TIME')}%(asctime)s{COLORS.get('RESET')} %(message)s |" \
|
34
|
+
f"{COLORS.get('ARGUMENTS')} " \
|
35
|
+
f"%(optional_args)s " \
|
36
|
+
f"thread=%(thread)d {COLORS.get('RESET')}"
|
32
37
|
|
33
38
|
# Create thread local storage that the format() call below uses.
|
34
39
|
# This is only used by the json_logger in the appropriate CLARIFAI_DEPLOY levels.
|
@@ -59,6 +64,9 @@ def get_req_id_from_context():
|
|
59
64
|
|
60
65
|
def display_workflow_tree(nodes_data: List[Dict]) -> None:
|
61
66
|
"""Displays a tree of the workflow nodes."""
|
67
|
+
from rich import print as rprint
|
68
|
+
from rich.tree import Tree
|
69
|
+
|
62
70
|
# Create a mapping of node_id to the list of node_ids that are connected to it.
|
63
71
|
node_adj_mapping = defaultdict(list)
|
64
72
|
# Create a mapping of node_id to the node data info.
|
@@ -104,8 +112,10 @@ def display_workflow_tree(nodes_data: List[Dict]) -> None:
|
|
104
112
|
rprint(tree)
|
105
113
|
|
106
114
|
|
107
|
-
def table_from_dict(data: List[Dict], column_names: List[str],
|
115
|
+
def table_from_dict(data: List[Dict], column_names: List[str],
|
116
|
+
title: str = "") -> 'rich.Table': #noqa F821
|
108
117
|
"""Use this function for printing tables from a list of dicts."""
|
118
|
+
from rich.table import Table
|
109
119
|
table = Table(title=title, show_lines=False, show_header=True, header_style="blue")
|
110
120
|
for column_name in column_names:
|
111
121
|
table.add_column(column_name)
|
@@ -134,23 +144,18 @@ def _configure_logger(name: str, logger_level: Union[int, str] = logging.NOTSET)
|
|
134
144
|
# If ENABLE_JSON_LOGGER is not set, then use json logger if in k8s.
|
135
145
|
enabled_json = os.getenv('ENABLE_JSON_LOGGER', None)
|
136
146
|
in_k8s = 'KUBERNETES_SERVICE_HOST' in os.environ
|
147
|
+
handler = logging.StreamHandler()
|
148
|
+
handler.setLevel(logger_level)
|
137
149
|
if enabled_json == 'true' or (in_k8s and enabled_json != 'false'):
|
138
150
|
# Add the json handler and formatter
|
139
|
-
handler = logging.StreamHandler()
|
140
151
|
formatter = JsonFormatter()
|
141
152
|
handler.setFormatter(formatter)
|
142
|
-
logger.addHandler(handler)
|
143
153
|
else:
|
144
|
-
#
|
145
|
-
|
146
|
-
width, _ = os.get_terminal_size()
|
147
|
-
except OSError:
|
148
|
-
width = 255
|
149
|
-
handler = RichHandler(
|
150
|
-
rich_tracebacks=True, log_time_format="%Y-%m-%d %H:%M:%S.%f", console=Console(width=width))
|
151
|
-
formatter = logging.Formatter('%(message)s')
|
154
|
+
# create formatter and add it to the handlers
|
155
|
+
formatter = TerminalFormatter(LOG_FORMAT)
|
152
156
|
handler.setFormatter(formatter)
|
153
|
-
|
157
|
+
# add the handlers to the logger
|
158
|
+
logger.addHandler(handler)
|
154
159
|
|
155
160
|
|
156
161
|
def get_logger(logger_level: Union[int, str] = logging.NOTSET,
|
@@ -207,6 +212,8 @@ def display_concept_relations_tree(relations_dict: Dict[str, Any]) -> None:
|
|
207
212
|
Args:
|
208
213
|
relations_dict (dict): A dict of concept relations info.
|
209
214
|
"""
|
215
|
+
from rich import print as rprint
|
216
|
+
from rich.tree import Tree
|
210
217
|
for parent, children in relations_dict.items():
|
211
218
|
tree = Tree(parent)
|
212
219
|
for child in children:
|
@@ -372,5 +379,41 @@ class JsonFormatter(logging.Formatter):
|
|
372
379
|
)
|
373
380
|
|
374
381
|
|
382
|
+
class TerminalFormatter(logging.Formatter):
|
383
|
+
""" If you have fields in your Formatter (see setup_logger where we setup the format strings) then
|
384
|
+
you can set them on the record using a filter. We do that for req_id here which is a request
|
385
|
+
specific field. This allows us to find requests easily between services.
|
386
|
+
"""
|
387
|
+
|
388
|
+
def format(self, record):
|
389
|
+
record.optional_args = []
|
390
|
+
|
391
|
+
user_id = getattr(thread_log_info, 'user_id', None)
|
392
|
+
if user_id is not None:
|
393
|
+
record.optional_args.append("user_id=" + user_id)
|
394
|
+
|
395
|
+
app_id = getattr(thread_log_info, 'app_id', None)
|
396
|
+
if app_id is not None:
|
397
|
+
record.optional_args.append("app_id=" + app_id)
|
398
|
+
|
399
|
+
req_id = getattr(thread_log_info, 'req_id', None)
|
400
|
+
if req_id is not None:
|
401
|
+
record.optional_args.append("req_id=" + req_id)
|
402
|
+
|
403
|
+
record.optional_args = " ".join(record.optional_args)
|
404
|
+
|
405
|
+
color_code = COLORS.get(record.levelname, '')
|
406
|
+
|
407
|
+
record.levelname = f"{color_code}{record.levelname}{COLORS.get('RESET')}"
|
408
|
+
record.msg = f"{color_code}{str(record.msg)}{COLORS.get('RESET')}"
|
409
|
+
|
410
|
+
return super(TerminalFormatter, self).format(record)
|
411
|
+
|
412
|
+
def formatTime(self, record, datefmt=None):
|
413
|
+
# Note we didn't go with UTC here as it's easier to understand time in your time zone.
|
414
|
+
# The json logger leverages UTC though.
|
415
|
+
return datetime.datetime.fromtimestamp(record.created).strftime('%H:%M:%S.%f')
|
416
|
+
|
417
|
+
|
375
418
|
# the default logger for the SDK.
|
376
419
|
logger = get_logger(logger_level=os.environ.get("LOG_LEVEL", "INFO"), name="clarifai")
|
clarifai/utils/misc.py
CHANGED