clarifai 11.2.2__py3-none-any.whl → 11.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
clarifai/cli/nodepool.py CHANGED
@@ -1,31 +1,27 @@
 import click
+
 from clarifai.cli.base import cli
-from clarifai.client.compute_cluster import ComputeCluster
-from clarifai.utils.cli import display_co_resources, dump_yaml, from_yaml, validate_context
+from clarifai.utils.cli import (AliasedGroup, display_co_resources, dump_yaml, from_yaml,
+                                validate_context)
 
 
-@cli.group(['nodepool', 'np'])
+@cli.group(['nodepool', 'np'], cls=AliasedGroup)
 def nodepool():
   """Manage Nodepools: create, delete, list"""
-  pass
 
 
-@nodepool.command()
-@click.option(
-    '-cc_id',
-    '--compute_cluster_id',
-    required=False,
-    help='Compute Cluster ID for the compute cluster to interact with.')
+@nodepool.command(['c'])
+@click.argument('compute_cluster_id')
+@click.argument('nodepool_id')
 @click.option(
     '--config',
     type=click.Path(exists=True),
     required=True,
     help='Path to the nodepool config file.')
-@click.option(
-    '-np_id', '--nodepool_id', required=False, help='New Nodepool ID for the nodepool to create.')
 @click.pass_context
-def create(ctx, compute_cluster_id, config, nodepool_id):
+def create(ctx, compute_cluster_id, nodepool_id, config):
   """Create a new Nodepool with the given config file."""
+  from clarifai.client.compute_cluster import ComputeCluster
 
   validate_context(ctx)
   nodepool_config = from_yaml(config)
@@ -43,52 +39,74 @@ def create(ctx, compute_cluster_id, config, nodepool_id):
 
   compute_cluster = ComputeCluster(
       compute_cluster_id=compute_cluster_id,
-      user_id=ctx.obj['user_id'],
-      pat=ctx.obj['pat'],
-      base_url=ctx.obj['base_url'])
+      user_id=ctx.obj.current.user_id,
+      pat=ctx.obj.current.pat,
+      base_url=ctx.obj.current.api_base)
   if nodepool_id:
     compute_cluster.create_nodepool(config, nodepool_id=nodepool_id)
   else:
     compute_cluster.create_nodepool(config)
 
 
-@nodepool.command()
-@click.option(
-    '-cc_id',
-    '--compute_cluster_id',
-    required=True,
-    help='Compute Cluster ID for the compute cluster to interact with.')
+@nodepool.command(['ls'])
+@click.argument('compute_cluster_id', default="")
 @click.option('--page_no', required=False, help='Page number to list.', default=1)
-@click.option('--per_page', required=False, help='Number of items per page.', default=16)
+@click.option('--per_page', required=False, help='Number of items per page.', default=128)
 @click.pass_context
 def list(ctx, compute_cluster_id, page_no, per_page):
-  """List all nodepools for the user."""
+  """List all nodepools for the user across all compute clusters. If compute_cluster_id is provided
+  it will list only within that compute cluster. """
+  from clarifai.client.compute_cluster import ComputeCluster
+  from clarifai.client.user import User
 
   validate_context(ctx)
-  compute_cluster = ComputeCluster(
-      compute_cluster_id=compute_cluster_id,
-      user_id=ctx.obj['user_id'],
-      pat=ctx.obj['pat'],
-      base_url=ctx.obj['base_url'])
-  response = compute_cluster.list_nodepools(page_no, per_page)
-  display_co_resources(response, "Nodepool")
 
+  cc_id = compute_cluster_id
 
-@nodepool.command()
-@click.option(
-    '-cc_id',
-    '--compute_cluster_id',
-    required=True,
-    help='Compute Cluster ID for the compute cluster to interact with.')
-@click.option('-np_id', '--nodepool_id', help='Nodepool ID of the user to delete.')
+  if cc_id:
+    compute_cluster = ComputeCluster(
+        compute_cluster_id=cc_id,
+        user_id=ctx.obj.current.user_id,
+        pat=ctx.obj.current.pat,
+        base_url=ctx.obj.current.api_base)
+    response = compute_cluster.list_nodepools(page_no, per_page)
+  else:
+    user = User(
+        user_id=ctx.obj.current.user_id,
+        pat=ctx.obj.current.pat,
+        base_url=ctx.obj.current.api_base)
+    ccs = user.list_compute_clusters(page_no, per_page)
+    response = []
+    for cc in ccs:
+      compute_cluster = ComputeCluster(
+          compute_cluster_id=cc.id,
+          user_id=ctx.obj.current.user_id,
+          pat=ctx.obj.current.pat,
+          base_url=ctx.obj.current.api_base)
+      response.extend([i for i in compute_cluster.list_nodepools(page_no, per_page)])
+
+  display_co_resources(
+      response,
+      custom_columns={
+          'ID': lambda c: c.id,
+          'USER_ID': lambda c: c.compute_cluster.user_id,
+          'COMPUTE_CLUSTER_ID': lambda c: c.compute_cluster.id,
+          'DESCRIPTION': lambda c: c.description,
+      })
+
+
+@nodepool.command(['rm'])
+@click.argument('compute_cluster_id')
+@click.argument('nodepool_id')
 @click.pass_context
 def delete(ctx, compute_cluster_id, nodepool_id):
   """Deletes a nodepool for the user."""
+  from clarifai.client.compute_cluster import ComputeCluster
 
   validate_context(ctx)
   compute_cluster = ComputeCluster(
       compute_cluster_id=compute_cluster_id,
-      user_id=ctx.obj['user_id'],
-      pat=ctx.obj['pat'],
-      base_url=ctx.obj['base_url'])
+      user_id=ctx.obj.current.user_id,
+      pat=ctx.obj.current.pat,
+      base_url=ctx.obj.current.api_base)
   compute_cluster.delete_nodepools([nodepool_id])
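For orientation (not part of the diff): the nodepool commands now take positional arguments instead of the old -cc_id/-np_id options, and the AliasedGroup registration gives them short aliases. A minimal sketch of the new invocations using click's own test runner; the cluster and nodepool IDs are made up:

from click.testing import CliRunner

from clarifai.cli.base import cli

runner = CliRunner()
# `create` now takes COMPUTE_CLUSTER_ID and NODEPOOL_ID as arguments ('c' is its alias):
runner.invoke(cli, ['nodepool', 'create', 'my-cluster', 'my-nodepool',
                    '--config', 'nodepool_config.yaml'])
# `list`/`ls` works with or without a compute cluster id; `np` aliases the group:
runner.invoke(cli, ['np', 'ls'])
# `delete`/`rm` also takes both IDs positionally:
runner.invoke(cli, ['np', 'rm', 'my-cluster', 'my-nodepool'])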
@@ -1,5 +1,4 @@
 import itertools
-import logging
 import time
 from concurrent.futures import ThreadPoolExecutor
 
@@ -8,7 +7,7 @@ from clarifai_grpc.grpc.api.status import status_code_pb2
 
 from clarifai.client.auth.helper import ClarifaiAuthHelper
 from clarifai.client.auth.register import RpcCallable, V2Stub
-
+from clarifai.utils.logging import logger
 throttle_status_codes = {
     status_code_pb2.CONN_THROTTLED,
     status_code_pb2.CONN_EXCEED_HOURLY_LIMIT,
@@ -26,7 +25,7 @@ def validate_response(response, attempt, max_attempts):
   def handle_simple_response(response):
     if hasattr(response, 'status') and hasattr(response.status, 'code'):
       if (response.status.code in throttle_status_codes) and attempt < max_attempts:
-        logging.debug('Retrying with status %s' % str(response.status))
+        logger.debug('Retrying with status %s' % str(response.status))
         return None  # Indicates a retry is needed
       else:
         return response
@@ -42,7 +41,7 @@ def validate_response(response, attempt, max_attempts):
         return itertools.chain([validated_response], response)
       return None  # Indicates a retry is needed
     except grpc.RpcError as e:
-      logging.error('Error processing streaming response: %s' % str(e))
+      logger.error('Error processing streaming response: %s' % str(e))
       return None  # Indicates an error
   else:
     # Handle simple response validation
@@ -143,7 +142,7 @@ class _RetryRpcCallable(RpcCallable):
       return v
     except grpc.RpcError as e:
      if (e.code() in retry_codes_grpc) and attempt < self.max_attempts:
-        logging.debug('Retrying with status %s' % e.code())
+        logger.debug('Retrying with status %s' % e.code())
      else:
        raise
 
@@ -1,4 +1,3 @@
-import logging
 import os
 import time
 import uuid
@@ -354,7 +353,7 @@ class Dataset(Lister, BaseClient):
        break
      if failed_input_ids:
        retry_input_ids = [dataset_obj.all_input_ids[id] for id in failed_input_ids]
-        logging.warning(
+        logger.warning(
            f"Retrying upload for {len(failed_input_ids)} inputs in current batch: {retry_input_ids}\n"
        )
        failed_retrying_inputs, _, retry_response = self._upload_inputs_annotations(
@@ -494,7 +493,7 @@ class Dataset(Lister, BaseClient):
      add_file_handler(self.logger, f"Dataset_Upload{str(int(datetime.now().timestamp()))}.log")
 
    if retry_duplicates and duplicate_input_ids:
-      logging.warning(f"Retrying upload for {len(duplicate_input_ids)} duplicate inputs...\n")
+      logger.warning(f"Retrying upload for {len(duplicate_input_ids)} duplicate inputs...\n")
      duplicate_inputs_indexes = [input["Index"] for input in duplicate_input_ids]
      self.upload_dataset(
          dataloader=dataloader,
@@ -505,7 +504,7 @@ class Dataset(Lister, BaseClient):
 
    if failed_input_ids:
      #failed_inputs= ([input["Input_ID"] for input in failed_input_ids])
-      logging.warning(f"Retrying upload for {len(failed_input_ids)} failed inputs...\n")
+      logger.warning(f"Retrying upload for {len(failed_input_ids)} failed inputs...\n")
      failed_input_indexes = [input["Index"] for input in failed_input_ids]
      self.upload_dataset(
          dataloader=dataloader, log_retry_ids=failed_input_indexes, is_log_retry=True, **kwargs)
@@ -142,17 +142,21 @@ class ModelBuilder:
   def _validate_config_checkpoints(self):
     """
     Validates the checkpoints section in the config file.
+    return loader_type, repo_id, hf_token, when, allowed_file_patterns, ignore_file_patterns
     :return: loader_type the type of loader or None if no checkpoints.
     :return: repo_id location of checkpoint.
     :return: hf_token token to access checkpoint.
+    :return: when one of ['upload', 'build', 'runtime'] to download checkpoint
+    :return: allowed_file_patterns patterns to allow in downloaded checkpoint
+    :return: ignore_file_patterns patterns to ignore in downloaded checkpoint
     """
     if "checkpoints" not in self.config:
-      return None, None, None, DEFAULT_DOWNLOAD_CHECKPOINT_WHEN
+      return None, None, None, DEFAULT_DOWNLOAD_CHECKPOINT_WHEN, None, None
     assert "type" in self.config.get("checkpoints"), "No loader type specified in the config file"
     loader_type = self.config.get("checkpoints").get("type")
     if not loader_type:
       logger.info("No loader type specified in the config file for checkpoints")
-      return None, None, None
+      return None, None, None, DEFAULT_DOWNLOAD_CHECKPOINT_WHEN, None, None
     checkpoints = self.config.get("checkpoints")
     if 'when' not in checkpoints:
       logger.warn(
@@ -184,9 +188,17 @@ class ModelBuilder:
     resp = self.client.STUB.GetApp(service_pb2.GetAppRequest(user_app_id=self.client.user_app_id))
     if resp.status.code == status_code_pb2.SUCCESS:
       return True
+    if resp.status.code == status_code_pb2.CONN_KEY_INVALID:
+      logger.error(
+          f"Invalid PAT provided for user {self.client.user_app_id.user_id}. Please check your PAT and try again."
+      )
+      return False
     logger.error(
         f"Error checking API {self._base_api} for user app {self.client.user_app_id.user_id}/{self.client.user_app_id.app_id}. Error code: {resp.status.code}"
     )
+    logger.error(
+        f"App {self.client.user_app_id.app_id} not found for user {self.client.user_app_id.user_id}. Please create the app first and try again."
+    )
     return False
 
   def _validate_config_model(self):
@@ -207,9 +219,6 @@ class ModelBuilder:
     assert model.get('id') != "", "model_id cannot be empty in the config file"
 
     if not self._check_app_exists():
-      logger.error(
-          f"App {self.client.user_app_id.app_id} not found for user {self.client.user_app_id.user_id}"
-      )
       sys.exit(1)
 
   def _validate_config(self):
@@ -406,11 +415,12 @@ class ModelBuilder:
     # Sort in reverse so that newer cuda versions come first and are preferred.
     for image in sorted(AVAILABLE_TORCH_IMAGES, reverse=True):
       if torch_version in image and f'py{python_version}' in image:
-        cuda_version = image.split('-')[-1].replace('cuda', '')
+        # like cu124, rocm6.3, etc.
+        gpu_version = image.split('-')[-1]
         final_image = TORCH_BASE_IMAGE.format(
             torch_version=torch_version,
            python_version=python_version,
-            cuda_version=cuda_version,
+            gpu_version=gpu_version,
        )
        logger.info(f"Using Torch version {torch_version} base image to build the Docker image")
        break
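Note the contract fix above: _validate_config_checkpoints previously returned four values from one early exit and three from the other; both branches now return the same 6-tuple. A hedged sketch of how a caller would unpack it (the builder variable is illustrative, not shown in this diff):

# Hypothetical caller; the diff only establishes the 6-tuple shape.
loader_type, repo_id, hf_token, when, allowed_file_patterns, ignore_file_patterns = (
    builder._validate_config_checkpoints())
if loader_type is None:
  # No checkpoints section (or no loader type): `when` falls back to
  # DEFAULT_DOWNLOAD_CHECKPOINT_WHEN and both pattern lists are None.
  pass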
@@ -108,6 +108,7 @@ class ModelRunLocally:
 
   def _build_stream_request(self):
     request = self._build_request()
+    ensure_urls_downloaded(request)
     for i in range(1):
       yield request
 
@@ -2,10 +2,10 @@ import os
 
 registry = os.environ.get('CLARIFAI_BASE_IMAGE_REGISTRY', 'public.ecr.aws/clarifai-models')
 
-GIT_SHA = "df565436eea93efb3e8d1eb558a0a46df29523ec"
+GIT_SHA = "b8ae56bf3b7c95e686ca002b07ca83d259c716eb"
 
 PYTHON_BASE_IMAGE = registry + '/python-base:{python_version}-' + GIT_SHA
-TORCH_BASE_IMAGE = registry + '/torch:{torch_version}-py{python_version}-cuda{cuda_version}-' + GIT_SHA
+TORCH_BASE_IMAGE = registry + '/torch:{torch_version}-py{python_version}-{gpu_version}-' + GIT_SHA
 
 # List of available python base images
 AVAILABLE_PYTHON_IMAGES = ['3.11', '3.12']
@@ -21,12 +21,13 @@ DEFAULT_RUNTIME_DOWNLOAD_PATH = os.path.join(os.sep, "tmp", ".cache")
 # List of available torch images
 # Keep sorted by most recent cuda version.
 AVAILABLE_TORCH_IMAGES = [
-    '2.4.1-py3.11-cuda124',
-    '2.5.1-py3.11-cuda124',
-    '2.4.1-py3.12-cuda124',
-    '2.5.1-py3.12-cuda124',
-    # '2.4.1-py3.13-cuda124',
-    # '2.5.1-py3.13-cuda124',
+    '2.4.1-py3.11-cu124',
+    '2.5.1-py3.11-cu124',
+    '2.4.1-py3.12-cu124',
+    '2.5.1-py3.12-cu124',
+    '2.6.0-py3.12-cu126',
+    '2.7.0-py3.12-cu128',
+    '2.7.0-py3.12-rocm6.3',
 ]
 CONCEPTS_REQUIRED_MODEL_TYPE = [
    'visual-classifier', 'visual-detector', 'visual-segmenter', 'text-classifier'
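The tag template change above generalizes the suffix from cuda{cuda_version} to {gpu_version} so ROCm builds fit the same scheme as CUDA builds. A small sketch of how the new constants compose into an image reference, using values copied from this diff:

registry = 'public.ecr.aws/clarifai-models'
GIT_SHA = "b8ae56bf3b7c95e686ca002b07ca83d259c716eb"
TORCH_BASE_IMAGE = registry + '/torch:{torch_version}-py{python_version}-{gpu_version}-' + GIT_SHA

# The last '-'-separated field of an AVAILABLE_TORCH_IMAGES entry becomes gpu_version wholesale:
print(TORCH_BASE_IMAGE.format(torch_version='2.7.0', python_version='3.12', gpu_version='rocm6.3'))
# -> public.ecr.aws/clarifai-models/torch:2.7.0-py3.12-rocm6.3-b8ae56bf3b7c95e686ca002b07ca83d259c716eb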
clarifai/utils/cli.py CHANGED
@@ -2,16 +2,13 @@ import importlib
 import os
 import pkgutil
 import sys
+import typing as t
+from collections import defaultdict
+from typing import OrderedDict
 
 import click
 import yaml
-
-from rich.console import Console
-from rich.panel import Panel
-from rich.style import Style
-from rich.text import Text
-
-from clarifai.utils.logging import logger
+from tabulate import tabulate
 
 
 def from_yaml(filename: str):
@@ -31,19 +28,6 @@ def dump_yaml(data, filename: str):
     click.echo(f"Error writing YAML file: {e}", err=True)
 
 
-def set_base_url(env):
-  environments = {
-      'prod': 'https://api.clarifai.com',
-      'staging': 'https://api-staging.clarifai.com',
-      'dev': 'https://api-dev.clarifai.com'
-  }
-
-  if env in environments:
-    return environments[env]
-  else:
-    raise ValueError("Invalid environment. Please choose from 'prod', 'staging', 'dev'.")
-
-
 # Dynamically find and import all command modules from the cli directory
 def load_command_modules():
   package_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'cli')
@@ -53,27 +37,132 @@ def load_command_modules():
     importlib.import_module(f'clarifai.cli.{module_name}')
 
 
-def display_co_resources(response, resource_type):
+def display_co_resources(response,
+                         custom_columns={
+                             'ID': lambda c: c.id,
+                             'USER_ID': lambda c: c.user_id,
+                             'DESCRIPTION': lambda c: c.description,
+                         }):
   """Display compute orchestration resources listing results using rich."""
 
-  console = Console()
-  panel = Panel(
-      Text(f"List of {resource_type}s", justify="center"),
-      title="",
-      style=Style(color="blue", bold=True),
-      border_style="green",
-      width=60)
-  console.print(panel)
-  for indx, item in enumerate(list(response)):
-    panel = Panel(
-        "\n".join([f"{'ID'}: {item.id}", f"{'Description'}: {item.description}"]),
-        title=f"{resource_type} {(indx + 1)}",
-        border_style="green",
-        width=60)
-    console.print(panel)
+  formatter = TableFormatter(custom_columns)
+  print(formatter.format(list(response), fmt="plain"))
+
+
+class TableFormatter:
+
+  def __init__(self, custom_columns: OrderedDict):
+    """
+    Initializes the TableFormatter with column headers and custom column mappings.
+
+    :param headers: List of column headers for the table.
+    """
+    self.custom_columns = custom_columns
+
+  def format(self, objects, fmt='plain'):
+    """
+    Formats a list of objects into a table with custom columns.
+
+    :param objects: List of objects to format into a table.
+    :return: A string representing the table.
+    """
+    # Prepare the rows by applying the custom column functions to each object
+    rows = []
+    for obj in objects:
+      # row = [self.custom_columns[header](obj) for header in self.headers]
+      row = [f(obj) for f in self.custom_columns.values()]
+      rows.append(row)
+
+    # Create the table
+    table = tabulate(rows, headers=self.custom_columns.keys(), tablefmt=fmt)
+    return table
+
+
+class AliasedGroup(click.Group):
+
+  def __init__(self,
+               name: t.Optional[str] = None,
+               commands: t.Optional[t.Union[t.MutableMapping[str, click.Command], t.Sequence[
+                   click.Command]]] = None,
+               **attrs: t.Any) -> None:
+    super().__init__(name, commands, **attrs)
+    self.alias_map = {}
+    self.command_to_aliases = defaultdict(list)
+
+  def add_alias(self, cmd: click.Command, alias: str) -> None:
+    self.alias_map[alias] = cmd
+    if alias != cmd.name:
+      self.command_to_aliases[cmd].append(alias)
+
+  def command(self, aliases=None, *args,
+              **kwargs) -> t.Callable[[t.Callable[..., t.Any]], click.Command]:
+    cmd_decorator = super().command(*args, **kwargs)
+    if aliases is None:
+      aliases = []
+
+    def aliased_decorator(f):
+      cmd = cmd_decorator(f)
+      if cmd.name:
+        self.add_alias(cmd, cmd.name)
+      for alias in aliases:
+        self.add_alias(cmd, alias)
+      return cmd
+
+    f = None
+    if args and callable(args[0]):
+      (f,) = args
+    if f is not None:
+      return aliased_decorator(f)
+    return aliased_decorator
+
+  def group(self, aliases=None, *args,
+            **kwargs) -> t.Callable[[t.Callable[..., t.Any]], click.Group]:
+    cmd_decorator = super().group(*args, **kwargs)
+    if aliases is None:
+      aliases = []
+
+    def aliased_decorator(f):
+      cmd = cmd_decorator(f)
+      if cmd.name:
+        self.add_alias(cmd, cmd.name)
+      for alias in aliases:
+        self.add_alias(cmd, alias)
+      return cmd
+
+    f = None
+    if args and callable(args[0]):
+      (f,) = args
+    if f is not None:
+      return aliased_decorator(f)
+    return aliased_decorator
+
+  def get_command(self, ctx: click.Context, cmd_name: str) -> t.Optional[click.Command]:
+    rv = click.Group.get_command(self, ctx, cmd_name)
+    if rv is not None:
+      return rv
+    return self.alias_map.get(cmd_name)
+
+  def format_commands(self, ctx, formatter):
+    sub_commands = self.list_commands(ctx)
+
+    rows = []
+    for sub_command in sub_commands:
+      cmd = self.get_command(ctx, sub_command)
+      if cmd is None or getattr(cmd, 'hidden', False):
+        continue
+      if cmd in self.command_to_aliases:
+        aliases = ', '.join(self.command_to_aliases[cmd])
+        sub_command = f'{sub_command} ({aliases})'
+      cmd_help = cmd.help
+      rows.append((sub_command, cmd_help))
+
+    if rows:
+      with formatter.section("Commands"):
+        formatter.write_dl(rows)
 
 
 def validate_context(ctx):
+  from clarifai.utils.logging import logger
   if ctx.obj == {}:
     logger.error("CLI config file missing. Run `clarifai login` to set up the CLI config.")
    sys.exit(1)
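A self-contained sketch of the new AliasedGroup in use; the group and command names here are illustrative (nodepool.py above shows the real usage):

import click

from clarifai.utils.cli import AliasedGroup


@click.group(cls=AliasedGroup)
def cli():
  """Example root group."""


@cli.group(['np'], cls=AliasedGroup)
def nodepool():
  """Manage Nodepools."""


@nodepool.command(['ls'])
def list():
  """List nodepools."""
  click.echo('listing...')


# `cli np ls` and `cli nodepool list` resolve to the same command via the
# alias_map fallback in get_command, and format_commands renders "list (ls)"
# in --help output.
if __name__ == '__main__':
  cli()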
@@ -0,0 +1,105 @@
+import os
+from collections import OrderedDict
+from dataclasses import dataclass, field
+
+import yaml
+
+from clarifai.utils.constants import DEFAULT_CONFIG
+
+
+class Context(OrderedDict):
+  """
+  A context which has a name and a set of key-values as a dict under env.
+
+  You can access the keys directly.
+  """
+
+  def __init__(self, name, **kwargs):
+    self['name'] = name
+    # when loading from config we may have the env: section in yaml already so we get it here.
+    if 'env' in kwargs:
+      self['env'] = kwargs['env']
+    else:  # when consructing as Context(name, key=value) we set it here.
+      self['env'] = kwargs
+
+  def __getattr__(self, key):
+    try:
+      if key == 'name':
+        return self[key]
+      if key == 'env':
+        raise AttributeError("Don't access .env directly")
+
+      # Allow accessing CLARIFAI_PAT type env var names from config as .pat
+      envvar_name = 'CLARIFAI_' + key.upper()
+      env = self['env']
+      if envvar_name in env:
+        value = env[envvar_name]
+        if value == "ENVVAR":
+          return os.environ[envvar_name]
+      else:
+        value = env[key]
+
+      if isinstance(value, dict):
+        return Context(value)
+
+      return value
+    except KeyError as e:
+      raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'") from e
+
+  def __setattr__(self, key, value):
+    if key == "name":
+      self['name'] = value
+    else:
+      self['env'][key] = value
+
+  def __delattr__(self, key):
+    try:
+      del self['env'][key]
+    except KeyError as e:
+      raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'") from e
+
+  def to_serializable_dict(self):
+    return dict(self['env'])
+
+
+@dataclass
+class Config():
+  current_context: str
+  filename: str
+  contexts: OrderedDict[str, Context] = field(default_factory=OrderedDict)
+
+  def __post_init__(self):
+    for k, v in self.contexts.items():
+      if 'name' not in v:
+        v['name'] = k
+    self.contexts = {k: Context(**v) for k, v in self.contexts.items()}
+
+  @classmethod
+  def from_yaml(cls, filename: str = DEFAULT_CONFIG):
+    with open(filename, 'r') as f:
+      cfg = yaml.safe_load(f)
+    return cls(**cfg, filename=filename)
+
+  def to_dict(self):
+    return {
+        'current_context': self.current_context,
+        'contexts': {k: v.to_serializable_dict()
+                     for k, v in self.contexts.items()}
+    }
+
+  def to_yaml(self, filename: str = None):
+    if filename is None:
+      filename = self.filename
+    dir = os.path.dirname(filename)
+    if len(dir):
+      os.makedirs(dir, exist_ok=True)
+    _dict = self.to_dict()
+    for k, v in _dict['contexts'].items():
+      v.pop('name', None)
+    with open(filename, 'w') as f:
+      yaml.safe_dump(_dict, f)
+
+  @property
+  def current(self) -> Context:
+    """ get the current Context """
+    return self.contexts[self.current_context]
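A hedged sketch of the new context-based config in action; the module path clarifai.utils.config is assumed (this diff does not show the new file's name), and all names and values below are made up:

from clarifai.utils.config import Config  # module path assumed, not shown in this diff

cfg = Config(
    current_context='prod',
    filename='/tmp/clarifai_config_example.yaml',
    contexts={
        'prod': {
            'env': {
                'CLARIFAI_USER_ID': 'my-user',
                'CLARIFAI_PAT': 'ENVVAR',  # sentinel: resolved from os.environ['CLARIFAI_PAT'] on access
                'CLARIFAI_API_BASE': 'https://api.clarifai.com',
            }
        },
    })

# Context.__getattr__ maps .user_id -> env['CLARIFAI_USER_ID'], which is how the
# CLI commands above read ctx.obj.current.user_id / .pat / .api_base.
print(cfg.current.user_id)  # my-user
cfg.to_yaml()  # round-trips current_context + contexts back to the file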
@@ -1,3 +1,7 @@
+import os
+
 CLARIFAI_PAT_ENV_VAR = "CLARIFAI_PAT"
 CLARIFAI_SESSION_TOKEN_ENV_VAR = "CLARIFAI_SESSION_TOKEN"
 CLARIFAI_USER_ID_ENV_VAR = "CLARIFAI_USER_ID"
+
+DEFAULT_CONFIG = f'{os.environ["HOME"]}/.config/clarifai/config'
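One caveat worth noting: DEFAULT_CONFIG reads os.environ["HOME"] at import time, so importing this constants module raises KeyError where HOME is unset (some containers and Windows shells). On a typical setup it resolves to ~/.config/clarifai/config. An equivalent, HOME-safe spelling (illustrative only, not what the package uses):

import os

DEFAULT_CONFIG = os.path.join(os.path.expanduser('~'), '.config', 'clarifai', 'config')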