clarifai 11.8.1__py3-none-any.whl → 11.8.3__py3-none-any.whl

This diff shows the changes between two publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects those changes as the package versions appear in their respective public registries.
clarifai/__init__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "11.8.1"
1
+ __version__ = "11.8.3"
clarifai/cli/model.py CHANGED
@@ -7,13 +7,18 @@ import click
7
7
 
8
8
  from clarifai.cli.base import cli, pat_display
9
9
  from clarifai.utils.cli import (
10
+ check_lmstudio_installed,
10
11
  check_ollama_installed,
11
12
  check_requirements_installed,
13
+ customize_huggingface_model,
14
+ customize_lmstudio_model,
12
15
  customize_ollama_model,
13
16
  parse_requirements,
14
17
  validate_context,
15
18
  )
16
19
  from clarifai.utils.constants import (
20
+ DEFAULT_HF_MODEL_REPO_BRANCH,
21
+ DEFAULT_LMSTUDIO_MODEL_REPO_BRANCH,
17
22
  DEFAULT_LOCAL_RUNNER_APP_ID,
18
23
  DEFAULT_LOCAL_RUNNER_COMPUTE_CLUSTER_CONFIG,
19
24
  DEFAULT_LOCAL_RUNNER_COMPUTE_CLUSTER_ID,
@@ -22,8 +27,9 @@ from clarifai.utils.constants import (
22
27
  DEFAULT_LOCAL_RUNNER_MODEL_TYPE,
23
28
  DEFAULT_LOCAL_RUNNER_NODEPOOL_CONFIG,
24
29
  DEFAULT_LOCAL_RUNNER_NODEPOOL_ID,
25
- DEFAULT_OLLAMA_MODEL_REPO,
26
30
  DEFAULT_OLLAMA_MODEL_REPO_BRANCH,
31
+ DEFAULT_TOOLKIT_MODEL_REPO,
32
+ DEFAULT_VLLM_MODEL_REPO_BRANCH,
27
33
  )
28
34
  from clarifai.utils.logging import logger
29
35
  from clarifai.utils.misc import (
@@ -68,14 +74,14 @@ def model():
68
74
  )
69
75
  @click.option(
70
76
  '--toolkit',
71
- type=click.Choice(['ollama'], case_sensitive=False),
77
+ type=click.Choice(['ollama', 'huggingface', 'lmstudio', 'vllm'], case_sensitive=False),
72
78
  required=False,
73
- help='Toolkit to use for model initialization. Currently supports "ollama".',
79
+ help='Toolkit to use for model initialization. Currently supports "ollama", "huggingface", "lmstudio" and "vllm".',
74
80
  )
75
81
  @click.option(
76
82
  '--model-name',
77
83
  required=False,
78
- help='Model name to configure when using --toolkit. For ollama toolkit, this sets the Ollama model to use (e.g., "llama3.1", "mistral", etc.).',
84
+ help='Model name to configure when using --toolkit. For ollama toolkit, this sets the Ollama model to use (e.g., "llama3.1", "mistral", etc.). For vllm & huggingface toolkit, this sets the Hugging Face model repo_id (e.g., "unsloth/Llama-3.2-1B-Instruct").\n For lmstudio toolkit, this sets the LM Studio model name (e.g., "qwen/qwen3-4b-thinking-2507").\n',
79
85
  )
80
86
  @click.option(
81
87
  '--port',
@@ -112,14 +118,16 @@ def init(
112
118
  when cloning private repositories. The --branch option can be used to specify a specific
113
119
  branch to clone from.
114
120
 
115
- MODEL_PATH: Path where to create the model directory structure. If not specified, the current directory is used by default.
116
- MODEL_TYPE_ID: Type of model to create. If not specified, defaults to "text-to-text" for text models.
117
- GITHUB_PAT: GitHub Personal Access Token for authentication when cloning private repositories.
118
- GITHUB_URL: GitHub repository URL or "repo" format to clone a repository from. If provided, the entire repository contents will be copied to the target directory instead of using default templates.
119
- TOOLKIT: Toolkit to use for model initialization. Currently supports "ollama".
120
- MODEL_NAME: Model name to configure when using --toolkit. For ollama toolkit, this sets the Ollama model to use (e.g., "llama3.1", "mistral", etc.).
121
- PORT: Port to run the Ollama server on. Defaults to 23333.
122
- CONTEXT_LENGTH: Context length for the Ollama model. Defaults to 8192.
121
+ MODEL_PATH: Path where to create the model directory structure. If not specified, the current directory is used by default.\n
122
+
123
+ OPTIONS:\n
124
+ MODEL_TYPE_ID: Type of model to create. If not specified, defaults to "text-to-text" for text models.\n
125
+ GITHUB_PAT: GitHub Personal Access Token for authentication when cloning private repositories.\n
126
+ GITHUB_URL: GitHub repository URL or "repo" format to clone a repository from. If provided, the entire repository contents will be copied to the target directory instead of using default templates.\n
127
+ TOOLKIT: Toolkit to use for model initialization. Currently supports "ollama", "huggingface", "lmstudio" and "vllm".\n
128
+ MODEL_NAME: Model name to configure when using --toolkit. For ollama toolkit, this sets the Ollama model to use (e.g., "llama3.1", "mistral", etc.). For vllm & huggingface toolkit, this sets the Hugging Face model repo_id (e.g., "Qwen/Qwen3-4B-Instruct-2507"). For lmstudio toolkit, this sets the LM Studio model name (e.g., "qwen/qwen3-4b-thinking-2507").\n
129
+ PORT: Port to run the (Ollama/lmstudio) server on. Defaults to 23333.\n
130
+ CONTEXT_LENGTH: Context length for the (Ollama/lmstudio) model. Defaults to 8192.\n
123
131
  """
124
132
  # Resolve the absolute path
125
133
  model_path = os.path.abspath(model_path)
@@ -152,8 +160,22 @@ def init(
152
160
  "Ollama is not installed. Please install it from `https://ollama.com/` to use the Ollama toolkit."
153
161
  )
154
162
  raise click.Abort()
155
- github_url = DEFAULT_OLLAMA_MODEL_REPO
163
+ github_url = DEFAULT_TOOLKIT_MODEL_REPO
156
164
  branch = DEFAULT_OLLAMA_MODEL_REPO_BRANCH
165
+ elif toolkit == 'huggingface':
166
+ github_url = DEFAULT_TOOLKIT_MODEL_REPO
167
+ branch = DEFAULT_HF_MODEL_REPO_BRANCH
168
+ elif toolkit == 'lmstudio':
169
+ if not check_lmstudio_installed():
170
+ logger.error(
171
+ "LM Studio is not installed. Please install it from `https://lmstudio.com/` to use the LM Studio toolkit."
172
+ )
173
+ raise click.Abort()
174
+ github_url = DEFAULT_TOOLKIT_MODEL_REPO
175
+ branch = DEFAULT_LMSTUDIO_MODEL_REPO_BRANCH
176
+ elif toolkit == 'vllm':
177
+ github_url = DEFAULT_TOOLKIT_MODEL_REPO
178
+ branch = DEFAULT_VLLM_MODEL_REPO_BRANCH
157
179
 
158
180
  if github_url:
159
181
  downloader = GitHubDownloader(
@@ -209,6 +231,44 @@ def init(
209
231
  repo_url = format_github_repo_url(github_url)
210
232
  repo_url = f"https://github.com/{owner}/{repo}"
211
233
 
234
+ try:
235
+ # Create a temporary directory for cloning
236
+ with tempfile.TemporaryDirectory(prefix="clarifai_model_") as clone_dir:
237
+ # Clone the repository with explicit branch parameter
238
+ if not clone_github_repo(repo_url, clone_dir, github_pat, branch):
239
+ logger.error(f"Failed to clone repository from {repo_url}")
240
+ github_url = None # Fall back to template mode
241
+
242
+ else:
243
+ # Copy the entire repository content to target directory (excluding .git)
244
+ for item in os.listdir(clone_dir):
245
+ if item == '.git':
246
+ continue
247
+
248
+ source_path = os.path.join(clone_dir, item)
249
+ target_path = os.path.join(model_path, item)
250
+
251
+ if os.path.isdir(source_path):
252
+ shutil.copytree(source_path, target_path, dirs_exist_ok=True)
253
+ else:
254
+ shutil.copy2(source_path, target_path)
255
+
256
+ logger.info(f"Successfully cloned repository to {model_path}")
257
+ logger.info(
258
+ "Model initialization complete with GitHub repository clone"
259
+ )
260
+ logger.info("Next steps:")
261
+ logger.info("1. Review the model configuration")
262
+ logger.info("2. Install any required dependencies manually")
263
+ logger.info(
264
+ "3. Test the model locally using 'clarifai model local-test'"
265
+ )
266
+ return
267
+
268
+ except Exception as e:
269
+ logger.error(f"Failed to clone GitHub repository: {e}")
270
+ github_url = None # Fall back to template mode
271
+
212
272
  if toolkit:
213
273
  logger.info(f"Initializing model from GitHub repository: {github_url}")
214
274
 
@@ -218,35 +278,42 @@ def init(
218
278
  else:
219
279
  repo_url = format_github_repo_url(github_url)
220
280
 
221
- try:
222
- # Create a temporary directory for cloning
223
- with tempfile.TemporaryDirectory(prefix="clarifai_model_") as clone_dir:
224
- # Clone the repository with explicit branch parameter
225
- if not clone_github_repo(repo_url, clone_dir, github_pat, branch):
226
- logger.error(f"Failed to clone repository from {repo_url}")
227
- github_url = None # Fall back to template mode
228
-
229
- else:
230
- # Copy the entire repository content to target directory (excluding .git)
231
- for item in os.listdir(clone_dir):
232
- if item == '.git':
233
- continue
234
-
235
- source_path = os.path.join(clone_dir, item)
236
- target_path = os.path.join(model_path, item)
237
-
238
- if os.path.isdir(source_path):
239
- shutil.copytree(source_path, target_path, dirs_exist_ok=True)
240
- else:
241
- shutil.copy2(source_path, target_path)
281
+ try:
282
+ # Create a temporary directory for cloning
283
+ with tempfile.TemporaryDirectory(prefix="clarifai_model_") as clone_dir:
284
+ # Clone the repository with explicit branch parameter
285
+ if not clone_github_repo(repo_url, clone_dir, github_pat, branch):
286
+ logger.error(f"Failed to clone repository from {repo_url}")
287
+ github_url = None # Fall back to template mode
242
288
 
243
- except Exception as e:
244
- logger.error(f"Failed to clone GitHub repository: {e}")
245
- github_url = None
289
+ else:
290
+ # Copy the entire repository content to target directory (excluding .git)
291
+ for item in os.listdir(clone_dir):
292
+ if item == '.git':
293
+ continue
294
+
295
+ source_path = os.path.join(clone_dir, item)
296
+ target_path = os.path.join(model_path, item)
297
+
298
+ if os.path.isdir(source_path):
299
+ shutil.copytree(source_path, target_path, dirs_exist_ok=True)
300
+ else:
301
+ shutil.copy2(source_path, target_path)
302
+
303
+ except Exception as e:
304
+ logger.error(f"Failed to clone GitHub repository: {e}")
305
+ github_url = None
246
306
 
247
307
  if (model_name or port or context_length) and (toolkit == 'ollama'):
248
308
  customize_ollama_model(model_path, model_name, port, context_length)
249
309
 
310
+ if (model_name or port or context_length) and (toolkit == 'lmstudio'):
311
+ customize_lmstudio_model(model_path, model_name, port, context_length)
312
+
313
+ if model_name and (toolkit == 'huggingface' or toolkit == 'vllm'):
314
+ # Update the config.yaml file with the provided model name
315
+ customize_huggingface_model(model_path, model_name)
316
+
250
317
  if github_url:
251
318
  logger.info("Model initialization complete with GitHub repository")
252
319
  logger.info("Next steps:")
@@ -294,7 +361,7 @@ def init(
294
361
  if os.path.exists(config_path):
295
362
  logger.warning(f"File {config_path} already exists, skipping...")
296
363
  else:
297
- config_model_type_id = "text-to-text" # default
364
+ config_model_type_id = DEFAULT_LOCAL_RUNNER_MODEL_TYPE # default
298
365
 
299
366
  config_template = get_config_template(config_model_type_id)
300
367
  with open(config_path, 'w') as f:
clarifai/cli/pipeline.py CHANGED
@@ -26,14 +26,19 @@ def pipeline():
26
26
 
27
27
  @pipeline.command()
28
28
  @click.argument("path", type=click.Path(exists=True), required=False, default=".")
29
- def upload(path):
29
+ @click.option(
30
+ '--no-lockfile',
31
+ is_flag=True,
32
+ help='Skip creating config-lock.yaml file.',
33
+ )
34
+ def upload(path, no_lockfile):
30
35
  """Upload a pipeline with associated pipeline steps to Clarifai.
31
36
 
32
37
  PATH: Path to the pipeline configuration file or directory containing config.yaml. If not specified, the current directory is used by default.
33
38
  """
34
39
  from clarifai.runners.pipelines.pipeline_builder import upload_pipeline
35
40
 
36
- upload_pipeline(path)
41
+ upload_pipeline(path, no_lockfile=no_lockfile)
37
42
 
38
43
 
39
44
  @pipeline.command()
@@ -106,15 +111,32 @@ def run(
106
111
 
107
112
  validate_context(ctx)
108
113
 
114
+ # Try to load from config-lock.yaml first if no config is specified
115
+ lockfile_path = os.path.join(os.getcwd(), "config-lock.yaml")
116
+ if not config and os.path.exists(lockfile_path):
117
+ logger.info("Found config-lock.yaml, using it as default config source")
118
+ config = lockfile_path
119
+
109
120
  if config:
110
121
  config_data = from_yaml(config)
111
- pipeline_id = config_data.get('pipeline_id', pipeline_id)
112
- pipeline_version_id = config_data.get('pipeline_version_id', pipeline_version_id)
122
+
123
+ # Handle both regular config format and lockfile format
124
+ if 'pipeline' in config_data and isinstance(config_data['pipeline'], dict):
125
+ pipeline_config = config_data['pipeline']
126
+ pipeline_id = pipeline_config.get('id', pipeline_id)
127
+ pipeline_version_id = pipeline_config.get('version_id', pipeline_version_id)
128
+ user_id = pipeline_config.get('user_id', user_id)
129
+ app_id = pipeline_config.get('app_id', app_id)
130
+ else:
131
+ # Fallback to flat config structure
132
+ pipeline_id = config_data.get('pipeline_id', pipeline_id)
133
+ pipeline_version_id = config_data.get('pipeline_version_id', pipeline_version_id)
134
+ user_id = config_data.get('user_id', user_id)
135
+ app_id = config_data.get('app_id', app_id)
136
+
113
137
  pipeline_version_run_id = config_data.get(
114
138
  'pipeline_version_run_id', pipeline_version_run_id
115
139
  )
116
- user_id = config_data.get('user_id', user_id)
117
- app_id = config_data.get('app_id', app_id)
118
140
  nodepool_id = config_data.get('nodepool_id', nodepool_id)
119
141
  compute_cluster_id = config_data.get('compute_cluster_id', compute_cluster_id)
120
142
  pipeline_url = config_data.get('pipeline_url', pipeline_url)
@@ -319,6 +341,62 @@ def init(pipeline_path):
319
341
  logger.info("3. Run 'clarifai pipeline upload config.yaml' to upload your pipeline")
320
342
 
321
343
 
344
+ @pipeline.command()
345
+ @click.argument(
346
+ "lockfile_path", type=click.Path(exists=True), required=False, default="config-lock.yaml"
347
+ )
348
+ def validate_lock(lockfile_path):
349
+ """Validate a config-lock.yaml file for schema and reference consistency.
350
+
351
+ LOCKFILE_PATH: Path to the config-lock.yaml file. If not specified, looks for config-lock.yaml in current directory.
352
+ """
353
+ from clarifai.runners.utils.pipeline_validation import PipelineConfigValidator
354
+ from clarifai.utils.cli import from_yaml
355
+
356
+ try:
357
+ # Load the lockfile
358
+ lockfile_data = from_yaml(lockfile_path)
359
+
360
+ # Validate required fields
361
+ if "pipeline" not in lockfile_data:
362
+ raise ValueError("'pipeline' section not found in lockfile")
363
+
364
+ pipeline = lockfile_data["pipeline"]
365
+ required_fields = ["id", "user_id", "app_id", "version_id"]
366
+
367
+ for field in required_fields:
368
+ if field not in pipeline:
369
+ raise ValueError(f"Required field '{field}' not found in pipeline section")
370
+ if not pipeline[field]:
371
+ raise ValueError(f"Required field '{field}' cannot be empty")
372
+
373
+ # Validate orchestration spec if present
374
+ if "orchestration_spec" in pipeline:
375
+ # Create a temporary config structure for validation
376
+ temp_config = {
377
+ "pipeline": {
378
+ "id": pipeline["id"],
379
+ "user_id": pipeline["user_id"],
380
+ "app_id": pipeline["app_id"],
381
+ "orchestration_spec": pipeline["orchestration_spec"],
382
+ }
383
+ }
384
+
385
+ # Use existing validator to check orchestration spec
386
+ validator = PipelineConfigValidator()
387
+ validator._validate_orchestration_spec(temp_config)
388
+
389
+ logger.info(f"✅ Lockfile {lockfile_path} is valid")
390
+ logger.info(f"Pipeline: {pipeline['id']}")
391
+ logger.info(f"User: {pipeline['user_id']}")
392
+ logger.info(f"App: {pipeline['app_id']}")
393
+ logger.info(f"Version: {pipeline['version_id']}")
394
+
395
+ except Exception as e:
396
+ logger.error(f"❌ Lockfile validation failed: {e}")
397
+ raise click.Abort()
398
+
399
+
322
400
  @pipeline.command(['ls'])
323
401
  @click.option('--page_no', required=False, help='Page number to list.', default=1)
324
402
  @click.option('--per_page', required=False, help='Number of items per page.', default=16)
@@ -178,7 +178,7 @@ class MyModel(OpenAIModelClass):
178
178
  '''
179
179
 
180
180
 
181
- def get_config_template(model_type_id: str = "text-to-text") -> str:
181
+ def get_config_template(model_type_id: str = "any-to-any") -> str:
182
182
  """Return the template for config.yaml."""
183
183
  return f'''# Configuration file for your Clarifai model
184
184
 
clarifai/client/base.py CHANGED
@@ -1,5 +1,5 @@
1
1
  from datetime import datetime
2
- from typing import Any, Callable
2
+ from typing import Any, Callable, Dict, Optional
3
3
 
4
4
  from clarifai_grpc.grpc.api import resources_pb2
5
5
  from google.protobuf import struct_pb2
@@ -88,20 +88,46 @@ class BaseClient:
88
88
  self.root_certificates_path = self.auth_helper._root_certificates_path
89
89
 
90
90
  @property
91
- def async_stub(self):
91
+ def async_stub(self) -> Any:
92
92
  """Returns the asynchronous gRPC stub for the API interaction.
93
- Lazy initialization of async stub"""
93
+
94
+ Returns:
95
+ Any: The async gRPC stub object for making asynchronous API calls.
96
+
97
+ Note:
98
+ Uses lazy initialization - stub is created on first access.
99
+ """
94
100
  if self._async_stub is None:
95
101
  self._async_stub = create_stub(self.auth_helper, is_async=True)
96
102
  return self._async_stub
97
103
 
98
104
  @classmethod
99
- def from_env(cls, validate: bool = False):
105
+ def from_env(cls, validate: bool = False) -> 'BaseClient':
106
+ """Creates a BaseClient instance from environment variables.
107
+
108
+ Args:
109
+ validate (bool): Whether to validate the authentication credentials.
110
+ Defaults to False.
111
+
112
+ Returns:
113
+ BaseClient: A new BaseClient instance configured from environment variables.
114
+ """
100
115
  auth = ClarifaiAuthHelper.from_env(validate=validate)
101
116
  return cls.from_auth_helper(auth)
102
117
 
103
118
  @classmethod
104
- def from_auth_helper(cls, auth: ClarifaiAuthHelper, **kwargs):
119
+ def from_auth_helper(cls, auth: ClarifaiAuthHelper, **kwargs) -> 'BaseClient':
120
+ """Creates a BaseClient instance from a ClarifaiAuthHelper.
121
+
122
+ Args:
123
+ auth (ClarifaiAuthHelper): The authentication helper containing credentials.
124
+ **kwargs: Additional keyword arguments to override auth helper values.
125
+ Supported keys: user_id, app_id, pat, token, root_certificates_path,
126
+ base, ui, url.
127
+
128
+ Returns:
129
+ BaseClient: A new BaseClient instance configured from the auth helper.
130
+ """
105
131
  default_kwargs = {
106
132
  "user_id": kwargs.get("user_id", None) or auth.user_id,
107
133
  "app_id": kwargs.get("app_id", None) or auth.app_id,
@@ -132,15 +158,18 @@ class BaseClient:
132
158
 
133
159
  return cls(**kwargs)
134
160
 
135
- def _grpc_request(self, method: Callable, argument: Any):
161
+ def _grpc_request(self, method: Callable[..., Any], argument: Any) -> Any:
136
162
  """Makes a gRPC request to the API.
137
163
 
138
164
  Args:
139
- method (Callable): The gRPC method to call.
140
- argument (Any): The argument to pass to the gRPC method.
165
+ method (Callable[..., Any]): The gRPC stub method to call.
166
+ argument (Any): The protobuf request object to pass to the gRPC method.
141
167
 
142
168
  Returns:
143
- res (Any): The result of the gRPC method call.
169
+ Any: The protobuf response object from the gRPC method call.
170
+
171
+ Raises:
172
+ Exception: If the API request fails.
144
173
  """
145
174
 
146
175
  try:
@@ -150,14 +179,16 @@ class BaseClient:
150
179
  except ApiError:
151
180
  raise Exception("ApiError")
152
181
 
153
- def convert_string_to_timestamp(self, date_str) -> Timestamp:
182
+ def convert_string_to_timestamp(self, date_str: str) -> Timestamp:
154
183
  """Converts a string to a Timestamp object.
155
184
 
156
185
  Args:
157
- date_str (str): The string to convert.
186
+ date_str (str): The date string to convert. Accepts formats like
187
+ '%Y-%m-%dT%H:%M:%S.%fZ' or '%Y-%m-%dT%H:%M:%SZ'.
158
188
 
159
189
  Returns:
160
- Timestamp: The converted Timestamp object.
190
+ Timestamp: The converted protobuf Timestamp object. Returns empty Timestamp
191
+ if the date string format is invalid.
161
192
  """
162
193
  # Parse the string into a Python datetime object
163
194
  try:
@@ -174,15 +205,22 @@ class BaseClient:
174
205
 
175
206
  return timestamp_obj
176
207
 
177
- def process_response_keys(self, old_dict, listing_resource=None):
208
+ def process_response_keys(
209
+ self, old_dict: Dict[str, Any], listing_resource: Optional[str] = None
210
+ ) -> Dict[str, Any]:
178
211
  """Converts keys in a response dictionary to resource proto format.
179
212
 
213
+ This method processes dictionary keys to match protobuf resource naming conventions,
214
+ particularly for converting 'id' to '{resource}_id' format.
215
+
180
216
  Args:
181
- old_dict (dict): The dictionary to convert.
182
- listing_resource (str, optional): The resource type for which the keys are being processed.
217
+ old_dict (Dict[str, Any]): The dictionary to convert with keys to process.
218
+ listing_resource (Optional[str]): The resource type for which the keys are being
219
+ processed (e.g., 'model', 'workflow'). If provided,
220
+ renames 'id' to '{listing_resource}_id'.
183
221
 
184
222
  Returns:
185
- new_dict (dict): The dictionary with processed keys.
223
+ Dict[str, Any]: The dictionary with processed keys and converted values.
186
224
  """
187
225
  if listing_resource:
188
226
  old_dict[f'{listing_resource}_id'] = old_dict['id']
@@ -4,7 +4,16 @@ import uuid
4
4
  from concurrent.futures import ThreadPoolExecutor, as_completed
5
5
  from datetime import datetime
6
6
  from multiprocessing import cpu_count
7
- from typing import Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union
7
+ from typing import (
8
+ Dict,
9
+ Generator,
10
+ List,
11
+ Optional,
12
+ Tuple,
13
+ Type,
14
+ TypeVar,
15
+ Union,
16
+ )
8
17
 
9
18
  import requests
10
19
  from clarifai_grpc.grpc.api import resources_pb2, service_pb2
@@ -122,10 +131,13 @@ class Dataset(Lister, BaseClient):
122
131
  Args:
123
132
  **kwargs: Additional keyword arguments to be passed to Dataset Version.
124
133
  - description (str): The description of the dataset version.
125
- - metadata (dict): The metadata of the dataset version.
134
+ - metadata (Dict[str, Any]): The metadata dictionary for the dataset version.
126
135
 
127
136
  Returns:
128
- Dataset: A Dataset object for the specified dataset ID.
137
+ Dataset: A Dataset object for the newly created dataset version.
138
+
139
+ Raises:
140
+ Exception: If the dataset version creation fails.
129
141
 
130
142
  Example:
131
143
  >>> from clarifai.client.dataset import Dataset
@@ -172,13 +184,13 @@ class Dataset(Lister, BaseClient):
172
184
  self.logger.info("\nDataset Version Deleted\n%s", response.status)
173
185
 
174
186
  def list_versions(
175
- self, page_no: int = None, per_page: int = None
187
+ self, page_no: Optional[int] = None, per_page: Optional[int] = None
176
188
  ) -> Generator['Dataset', None, None]:
177
189
  """Lists all the versions for the dataset.
178
190
 
179
191
  Args:
180
- page_no (int): The page number to list.
181
- per_page (int): The number of items per page.
192
+ page_no (Optional[int]): The page number to list. If None, lists all pages.
193
+ per_page (Optional[int]): The number of items per page. If None, uses default.
182
194
 
183
195
  Yields:
184
196
  Dataset: Dataset objects for the versions of the dataset.
clarifai/client/model.py CHANGED
@@ -2,7 +2,7 @@ import itertools
2
2
  import json
3
3
  import os
4
4
  import time
5
- from typing import Any, Dict, Generator, Iterable, Iterator, List, Tuple, Union
5
+ from typing import Any, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union
6
6
 
7
7
  import numpy as np
8
8
  import requests
@@ -178,20 +178,26 @@ class Model(Lister, BaseClient):
178
178
 
179
179
  return templates
180
180
 
181
- def get_params(self, template: str = None, save_to: str = 'params.yaml') -> Dict[str, Any]:
182
- """Returns the model params for the model type and yaml file.
181
+ def get_params(
182
+ self, template: Optional[str] = None, save_to: str = 'params.yaml'
183
+ ) -> Dict[str, Any]:
184
+ """Returns the model params for the model type and saves them to a yaml file.
183
185
 
184
186
  Args:
185
- template (str): The template to use for the model type.
186
- yaml_file (str): The yaml file to save the model params.
187
+ template (Optional[str]): The template to use for the model type. Required for most
188
+ model types except 'clusterer' and 'embedding-classifier'.
189
+ save_to (str): The yaml file path to save the model params. Defaults to 'params.yaml'.
187
190
 
188
191
  Returns:
189
- params (Dict): Dictionary of model params for the model type.
192
+ Dict[str, Any]: Dictionary of model params for the model type.
193
+
194
+ Raises:
195
+ UserError: If the model type is not trainable, or if template is required but not provided.
190
196
 
191
197
  Example:
192
198
  >>> from clarifai.client.model import Model
193
199
  >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
194
- >>> model_params = model.get_params(template='template', yaml_file='model_params.yaml')
200
+ >>> model_params = model.get_params(template='template', save_to='model_params.yaml')
195
201
  """
196
202
  if not self.model_info.model_type_id:
197
203
  self.load_info()
@@ -260,19 +266,22 @@ class Model(Lister, BaseClient):
260
266
  find_and_replace_key(self.training_params, key, value)
261
267
 
262
268
  def get_param_info(self, param: str) -> Dict[str, Any]:
263
- """Returns the param info for the param.
269
+ """Returns the parameter info for the specified parameter.
264
270
 
265
271
  Args:
266
- param (str): The param to get the info for.
272
+ param (str): The parameter name to get information for.
267
273
 
268
274
  Returns:
269
- param_info (Dict): Dictionary of model param info for the param.
275
+ Dict[str, Any]: Dictionary containing model parameter info for the specified param.
276
+
277
+ Raises:
278
+ UserError: If the model type is not trainable or if training params are not loaded.
270
279
 
271
280
  Example:
272
281
  >>> from clarifai.client.model import Model
273
282
  >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
274
- >>> model_params = model.get_params(template='template', yaml_file='model_params.yaml')
275
- >>> model.get_param_info('param')
283
+ >>> model_params = model.get_params(template='template', save_to='model_params.yaml')
284
+ >>> param_info = model.get_param_info('learning_rate')
276
285
  """
277
286
  if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
278
287
  raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
@@ -513,8 +522,9 @@ class Model(Lister, BaseClient):
513
522
  model=self.model_info,
514
523
  runner_selector=self._runner_selector,
515
524
  )
525
+ # Pass in None for async stub will create it.
516
526
  self._client = ModelClient(
517
- stub=self.STUB, async_stub=self.async_stub, request_template=request_template
527
+ stub=self.STUB, async_stub=None, request_template=request_template
518
528
  )
519
529
  return self._client
520
530
 
@@ -314,6 +314,7 @@ class ModelClient:
314
314
  self,
315
315
  base_url: str = None,
316
316
  use_ctx: bool = False,
317
+ colorize: bool = False,
317
318
  ) -> str:
318
319
  """Generate a client script for this model.
319
320
 
@@ -335,6 +336,7 @@ class ModelClient:
335
336
  compute_cluster_id=self.request_template.runner_selector.nodepool.compute_cluster.id,
336
337
  nodepool_id=self.request_template.runner_selector.nodepool.id,
337
338
  use_ctx=use_ctx,
339
+ colorize=colorize,
338
340
  )
339
341
 
340
342
  def _define_compatability_functions(self):