clarifai 11.2.1__py3-none-any.whl → 11.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
clarifai/cli/nodepool.py CHANGED
@@ -1,31 +1,27 @@
1
1
  import click
2
+
2
3
  from clarifai.cli.base import cli
3
- from clarifai.client.compute_cluster import ComputeCluster
4
- from clarifai.utils.cli import display_co_resources, dump_yaml, from_yaml, validate_context
4
+ from clarifai.utils.cli import (AliasedGroup, display_co_resources, dump_yaml, from_yaml,
5
+ validate_context)
5
6
 
6
7
 
7
- @cli.group(['nodepool', 'np'])
8
+ @cli.group(['nodepool', 'np'], cls=AliasedGroup)
8
9
  def nodepool():
9
10
  """Manage Nodepools: create, delete, list"""
10
- pass
11
11
 
12
12
 
13
- @nodepool.command()
14
- @click.option(
15
- '-cc_id',
16
- '--compute_cluster_id',
17
- required=False,
18
- help='Compute Cluster ID for the compute cluster to interact with.')
13
+ @nodepool.command(['c'])
14
+ @click.argument('compute_cluster_id')
15
+ @click.argument('nodepool_id')
19
16
  @click.option(
20
17
  '--config',
21
18
  type=click.Path(exists=True),
22
19
  required=True,
23
20
  help='Path to the nodepool config file.')
24
- @click.option(
25
- '-np_id', '--nodepool_id', required=False, help='New Nodepool ID for the nodepool to create.')
26
21
  @click.pass_context
27
- def create(ctx, compute_cluster_id, config, nodepool_id):
22
+ def create(ctx, compute_cluster_id, nodepool_id, config):
28
23
  """Create a new Nodepool with the given config file."""
24
+ from clarifai.client.compute_cluster import ComputeCluster
29
25
 
30
26
  validate_context(ctx)
31
27
  nodepool_config = from_yaml(config)
@@ -43,52 +39,74 @@ def create(ctx, compute_cluster_id, config, nodepool_id):
43
39
 
44
40
  compute_cluster = ComputeCluster(
45
41
  compute_cluster_id=compute_cluster_id,
46
- user_id=ctx.obj['user_id'],
47
- pat=ctx.obj['pat'],
48
- base_url=ctx.obj['base_url'])
42
+ user_id=ctx.obj.current.user_id,
43
+ pat=ctx.obj.current.pat,
44
+ base_url=ctx.obj.current.api_base)
49
45
  if nodepool_id:
50
46
  compute_cluster.create_nodepool(config, nodepool_id=nodepool_id)
51
47
  else:
52
48
  compute_cluster.create_nodepool(config)
53
49
 
54
50
 
55
- @nodepool.command()
56
- @click.option(
57
- '-cc_id',
58
- '--compute_cluster_id',
59
- required=True,
60
- help='Compute Cluster ID for the compute cluster to interact with.')
51
+ @nodepool.command(['ls'])
52
+ @click.argument('compute_cluster_id', default="")
61
53
  @click.option('--page_no', required=False, help='Page number to list.', default=1)
62
- @click.option('--per_page', required=False, help='Number of items per page.', default=16)
54
+ @click.option('--per_page', required=False, help='Number of items per page.', default=128)
63
55
  @click.pass_context
64
56
  def list(ctx, compute_cluster_id, page_no, per_page):
65
- """List all nodepools for the user."""
57
+ """List all nodepools for the user across all compute clusters. If compute_cluster_id is provided
58
+ it will list only within that compute cluster. """
59
+ from clarifai.client.compute_cluster import ComputeCluster
60
+ from clarifai.client.user import User
66
61
 
67
62
  validate_context(ctx)
68
- compute_cluster = ComputeCluster(
69
- compute_cluster_id=compute_cluster_id,
70
- user_id=ctx.obj['user_id'],
71
- pat=ctx.obj['pat'],
72
- base_url=ctx.obj['base_url'])
73
- response = compute_cluster.list_nodepools(page_no, per_page)
74
- display_co_resources(response, "Nodepool")
75
63
 
64
+ cc_id = compute_cluster_id
76
65
 
77
- @nodepool.command()
78
- @click.option(
79
- '-cc_id',
80
- '--compute_cluster_id',
81
- required=True,
82
- help='Compute Cluster ID for the compute cluster to interact with.')
83
- @click.option('-np_id', '--nodepool_id', help='Nodepool ID of the user to delete.')
66
+ if cc_id:
67
+ compute_cluster = ComputeCluster(
68
+ compute_cluster_id=cc_id,
69
+ user_id=ctx.obj.current.user_id,
70
+ pat=ctx.obj.current.pat,
71
+ base_url=ctx.obj.current.api_base)
72
+ response = compute_cluster.list_nodepools(page_no, per_page)
73
+ else:
74
+ user = User(
75
+ user_id=ctx.obj.current.user_id,
76
+ pat=ctx.obj.current.pat,
77
+ base_url=ctx.obj.current.api_base)
78
+ ccs = user.list_compute_clusters(page_no, per_page)
79
+ response = []
80
+ for cc in ccs:
81
+ compute_cluster = ComputeCluster(
82
+ compute_cluster_id=cc.id,
83
+ user_id=ctx.obj.current.user_id,
84
+ pat=ctx.obj.current.pat,
85
+ base_url=ctx.obj.current.api_base)
86
+ response.extend([i for i in compute_cluster.list_nodepools(page_no, per_page)])
87
+
88
+ display_co_resources(
89
+ response,
90
+ custom_columns={
91
+ 'ID': lambda c: c.id,
92
+ 'USER_ID': lambda c: c.compute_cluster.user_id,
93
+ 'COMPUTE_CLUSTER_ID': lambda c: c.compute_cluster.id,
94
+ 'DESCRIPTION': lambda c: c.description,
95
+ })
96
+
97
+
98
+ @nodepool.command(['rm'])
99
+ @click.argument('compute_cluster_id')
100
+ @click.argument('nodepool_id')
84
101
  @click.pass_context
85
102
  def delete(ctx, compute_cluster_id, nodepool_id):
86
103
  """Deletes a nodepool for the user."""
104
+ from clarifai.client.compute_cluster import ComputeCluster
87
105
 
88
106
  validate_context(ctx)
89
107
  compute_cluster = ComputeCluster(
90
108
  compute_cluster_id=compute_cluster_id,
91
- user_id=ctx.obj['user_id'],
92
- pat=ctx.obj['pat'],
93
- base_url=ctx.obj['base_url'])
109
+ user_id=ctx.obj.current.user_id,
110
+ pat=ctx.obj.current.pat,
111
+ base_url=ctx.obj.current.api_base)
94
112
  compute_cluster.delete_nodepools([nodepool_id])
clarifai/client/app.py CHANGED
@@ -629,7 +629,7 @@ class App(Lister, BaseClient):
629
629
 
630
630
  Args:
631
631
  model_id (str): The model ID for the model to interact with.
632
- model_version_id (str): The model version ID for the model version to interact with.
632
+ model_version (Dict): The model version ID for the model version to interact with.
633
633
 
634
634
  Returns:
635
635
  Model: A Model object for the existing model ID.
@@ -1,5 +1,4 @@
1
1
  import itertools
2
- import logging
3
2
  import time
4
3
  from concurrent.futures import ThreadPoolExecutor
5
4
 
@@ -8,7 +7,7 @@ from clarifai_grpc.grpc.api.status import status_code_pb2
8
7
 
9
8
  from clarifai.client.auth.helper import ClarifaiAuthHelper
10
9
  from clarifai.client.auth.register import RpcCallable, V2Stub
11
-
10
+ from clarifai.utils.logging import logger
12
11
  throttle_status_codes = {
13
12
  status_code_pb2.CONN_THROTTLED,
14
13
  status_code_pb2.CONN_EXCEED_HOURLY_LIMIT,
@@ -26,7 +25,7 @@ def validate_response(response, attempt, max_attempts):
26
25
  def handle_simple_response(response):
27
26
  if hasattr(response, 'status') and hasattr(response.status, 'code'):
28
27
  if (response.status.code in throttle_status_codes) and attempt < max_attempts:
29
- logging.debug('Retrying with status %s' % str(response.status))
28
+ logger.debug('Retrying with status %s' % str(response.status))
30
29
  return None # Indicates a retry is needed
31
30
  else:
32
31
  return response
@@ -42,7 +41,7 @@ def validate_response(response, attempt, max_attempts):
42
41
  return itertools.chain([validated_response], response)
43
42
  return None # Indicates a retry is needed
44
43
  except grpc.RpcError as e:
45
- logging.error('Error processing streaming response: %s' % str(e))
44
+ logger.error('Error processing streaming response: %s' % str(e))
46
45
  return None # Indicates an error
47
46
  else:
48
47
  # Handle simple response validation
@@ -143,7 +142,7 @@ class _RetryRpcCallable(RpcCallable):
143
142
  return v
144
143
  except grpc.RpcError as e:
145
144
  if (e.code() in retry_codes_grpc) and attempt < self.max_attempts:
146
- logging.debug('Retrying with status %s' % e.code())
145
+ logger.debug('Retrying with status %s' % e.code())
147
146
  else:
148
147
  raise
149
148
 
@@ -1,4 +1,3 @@
1
- import logging
2
1
  import os
3
2
  import time
4
3
  import uuid
@@ -354,7 +353,7 @@ class Dataset(Lister, BaseClient):
354
353
  break
355
354
  if failed_input_ids:
356
355
  retry_input_ids = [dataset_obj.all_input_ids[id] for id in failed_input_ids]
357
- logging.warning(
356
+ logger.warning(
358
357
  f"Retrying upload for {len(failed_input_ids)} inputs in current batch: {retry_input_ids}\n"
359
358
  )
360
359
  failed_retrying_inputs, _, retry_response = self._upload_inputs_annotations(
@@ -494,7 +493,7 @@ class Dataset(Lister, BaseClient):
494
493
  add_file_handler(self.logger, f"Dataset_Upload{str(int(datetime.now().timestamp()))}.log")
495
494
 
496
495
  if retry_duplicates and duplicate_input_ids:
497
- logging.warning(f"Retrying upload for {len(duplicate_input_ids)} duplicate inputs...\n")
496
+ logger.warning(f"Retrying upload for {len(duplicate_input_ids)} duplicate inputs...\n")
498
497
  duplicate_inputs_indexes = [input["Index"] for input in duplicate_input_ids]
499
498
  self.upload_dataset(
500
499
  dataloader=dataloader,
@@ -505,7 +504,7 @@ class Dataset(Lister, BaseClient):
505
504
 
506
505
  if failed_input_ids:
507
506
  #failed_inputs= ([input["Input_ID"] for input in failed_input_ids])
508
- logging.warning(f"Retrying upload for {len(failed_input_ids)} failed inputs...\n")
507
+ logger.warning(f"Retrying upload for {len(failed_input_ids)} failed inputs...\n")
509
508
  failed_input_indexes = [input["Index"] for input in failed_input_ids]
510
509
  self.upload_dataset(
511
510
  dataloader=dataloader, log_retry_ids=failed_input_indexes, is_log_retry=True, **kwargs)
@@ -142,17 +142,21 @@ class ModelBuilder:
142
142
  def _validate_config_checkpoints(self):
143
143
  """
144
144
  Validates the checkpoints section in the config file.
145
+ return loader_type, repo_id, hf_token, when, allowed_file_patterns, ignore_file_patterns
145
146
  :return: loader_type the type of loader or None if no checkpoints.
146
147
  :return: repo_id location of checkpoint.
147
148
  :return: hf_token token to access checkpoint.
149
+ :return: when one of ['upload', 'build', 'runtime'] to download checkpoint
150
+ :return: allowed_file_patterns patterns to allow in downloaded checkpoint
151
+ :return: ignore_file_patterns patterns to ignore in downloaded checkpoint
148
152
  """
149
153
  if "checkpoints" not in self.config:
150
- return None, None, None, DEFAULT_DOWNLOAD_CHECKPOINT_WHEN
154
+ return None, None, None, DEFAULT_DOWNLOAD_CHECKPOINT_WHEN, None, None
151
155
  assert "type" in self.config.get("checkpoints"), "No loader type specified in the config file"
152
156
  loader_type = self.config.get("checkpoints").get("type")
153
157
  if not loader_type:
154
158
  logger.info("No loader type specified in the config file for checkpoints")
155
- return None, None, None
159
+ return None, None, None, DEFAULT_DOWNLOAD_CHECKPOINT_WHEN, None, None
156
160
  checkpoints = self.config.get("checkpoints")
157
161
  if 'when' not in checkpoints:
158
162
  logger.warn(
@@ -171,15 +175,30 @@ class ModelBuilder:
171
175
 
172
176
  # get from config.yaml otherwise fall back to HF_TOKEN env var.
173
177
  hf_token = self.config.get("checkpoints").get("hf_token", os.environ.get("HF_TOKEN", None))
174
- return loader_type, repo_id, hf_token, when
178
+
179
+ allowed_file_patterns = self.config.get("checkpoints").get('allowed_file_patterns', None)
180
+ if isinstance(allowed_file_patterns, str):
181
+ allowed_file_patterns = [allowed_file_patterns]
182
+ ignore_file_patterns = self.config.get("checkpoints").get('ignore_file_patterns', None)
183
+ if isinstance(ignore_file_patterns, str):
184
+ ignore_file_patterns = [ignore_file_patterns]
185
+ return loader_type, repo_id, hf_token, when, allowed_file_patterns, ignore_file_patterns
175
186
 
176
187
  def _check_app_exists(self):
177
188
  resp = self.client.STUB.GetApp(service_pb2.GetAppRequest(user_app_id=self.client.user_app_id))
178
189
  if resp.status.code == status_code_pb2.SUCCESS:
179
190
  return True
191
+ if resp.status.code == status_code_pb2.CONN_KEY_INVALID:
192
+ logger.error(
193
+ f"Invalid PAT provided for user {self.client.user_app_id.user_id}. Please check your PAT and try again."
194
+ )
195
+ return False
180
196
  logger.error(
181
197
  f"Error checking API {self._base_api} for user app {self.client.user_app_id.user_id}/{self.client.user_app_id.app_id}. Error code: {resp.status.code}"
182
198
  )
199
+ logger.error(
200
+ f"App {self.client.user_app_id.app_id} not found for user {self.client.user_app_id.user_id}. Please create the app first and try again."
201
+ )
183
202
  return False
184
203
 
185
204
  def _validate_config_model(self):
@@ -200,9 +219,6 @@ class ModelBuilder:
200
219
  assert model.get('id') != "", "model_id cannot be empty in the config file"
201
220
 
202
221
  if not self._check_app_exists():
203
- logger.error(
204
- f"App {self.client.user_app_id.app_id} not found for user {self.client.user_app_id.user_id}"
205
- )
206
222
  sys.exit(1)
207
223
 
208
224
  def _validate_config(self):
@@ -216,7 +232,7 @@ class ModelBuilder:
216
232
  assert model_type_id in CONCEPTS_REQUIRED_MODEL_TYPE, f"Model type {model_type_id} not supported for concepts"
217
233
 
218
234
  if self.config.get("checkpoints"):
219
- loader_type, _, hf_token, _ = self._validate_config_checkpoints()
235
+ loader_type, _, hf_token, _, _, _ = self._validate_config_checkpoints()
220
236
 
221
237
  if loader_type == "huggingface" and hf_token:
222
238
  is_valid_token = HuggingFaceLoader.validate_hftoken(hf_token)
@@ -399,11 +415,12 @@ class ModelBuilder:
399
415
  # Sort in reverse so that newer cuda versions come first and are preferred.
400
416
  for image in sorted(AVAILABLE_TORCH_IMAGES, reverse=True):
401
417
  if torch_version in image and f'py{python_version}' in image:
402
- cuda_version = image.split('-')[-1].replace('cuda', '')
418
+ # like cu124, rocm6.3, etc.
419
+ gpu_version = image.split('-')[-1]
403
420
  final_image = TORCH_BASE_IMAGE.format(
404
421
  torch_version=torch_version,
405
422
  python_version=python_version,
406
- cuda_version=cuda_version,
423
+ gpu_version=gpu_version,
407
424
  )
408
425
  logger.info(f"Using Torch version {torch_version} base image to build the Docker image")
409
426
  break
@@ -480,8 +497,10 @@ class ModelBuilder:
480
497
  if not self.config.get("checkpoints"):
481
498
  logger.info("No checkpoints specified in the config file")
482
499
  return path
500
+ clarifai_model_type_id = self.config.get('model').get('model_type_id')
483
501
 
484
- loader_type, repo_id, hf_token, when = self._validate_config_checkpoints()
502
+ loader_type, repo_id, hf_token, when, allowed_file_patterns, ignore_file_patterns = self._validate_config_checkpoints(
503
+ )
485
504
  if stage not in ["build", "upload", "runtime"]:
486
505
  raise Exception("Invalid stage provided, must be one of ['build', 'upload', 'runtime']")
487
506
  if when != stage:
@@ -490,14 +509,18 @@ class ModelBuilder:
490
509
  )
491
510
  return path
492
511
 
493
- success = True
512
+ success = False
494
513
  if loader_type == "huggingface":
495
- loader = HuggingFaceLoader(repo_id=repo_id, token=hf_token)
514
+ loader = HuggingFaceLoader(
515
+ repo_id=repo_id, token=hf_token, model_type_id=clarifai_model_type_id)
496
516
  # for runtime default to /tmp path
497
517
  if stage == "runtime" and checkpoint_path_override is None:
498
518
  checkpoint_path_override = self.default_runtime_checkpoint_path()
499
519
  path = checkpoint_path_override if checkpoint_path_override else self.checkpoint_path
500
- success = loader.download_checkpoints(path)
520
+ success = loader.download_checkpoints(
521
+ path,
522
+ allowed_file_patterns=allowed_file_patterns,
523
+ ignore_file_patterns=ignore_file_patterns)
501
524
 
502
525
  if loader_type:
503
526
  if not success:
@@ -569,7 +592,7 @@ class ModelBuilder:
569
592
  logger.debug(f"Will tar it into file: {file_path}")
570
593
 
571
594
  model_type_id = self.config.get('model').get('model_type_id')
572
- loader_type, repo_id, hf_token, when = self._validate_config_checkpoints()
595
+ loader_type, repo_id, hf_token, when, _, _ = self._validate_config_checkpoints()
573
596
 
574
597
  if (model_type_id in CONCEPTS_REQUIRED_MODEL_TYPE) and 'concepts' not in self.config:
575
598
  logger.info(
@@ -617,7 +640,7 @@ class ModelBuilder:
617
640
  # First check for the env variable, then try querying huggingface. If all else fails, use the default.
618
641
  checkpoint_size = os.environ.get('CHECKPOINT_SIZE_BYTES', 0)
619
642
  if not checkpoint_size:
620
- _, repo_id, _, _ = self._validate_config_checkpoints()
643
+ _, repo_id, _, _, _, _ = self._validate_config_checkpoints()
621
644
  checkpoint_size = HuggingFaceLoader.get_huggingface_checkpoint_total_size(repo_id)
622
645
  if not checkpoint_size:
623
646
  checkpoint_size = self.DEFAULT_CHECKPOINT_SIZE
@@ -108,6 +108,7 @@ class ModelRunLocally:
108
108
 
109
109
  def _build_stream_request(self):
110
110
  request = self._build_request()
111
+ ensure_urls_downloaded(request)
111
112
  for i in range(1):
112
113
  yield request
113
114
 
@@ -479,7 +480,7 @@ def main(model_path,
479
480
  manager = ModelRunLocally(model_path)
480
481
  # get whatever stage is in config.yaml to force download now
481
482
  # also always write to where upload/build wants to, not the /tmp folder that runtime stage uses
482
- _, _, _, when = manager.builder._validate_config_checkpoints()
483
+ _, _, _, when, _, _ = manager.builder._validate_config_checkpoints()
483
484
  manager.builder.download_checkpoints(
484
485
  stage=when, checkpoint_path_override=manager.builder.checkpoint_path)
485
486
  if inside_container:
@@ -2,10 +2,10 @@ import os
2
2
 
3
3
  registry = os.environ.get('CLARIFAI_BASE_IMAGE_REGISTRY', 'public.ecr.aws/clarifai-models')
4
4
 
5
- GIT_SHA = "df565436eea93efb3e8d1eb558a0a46df29523ec"
5
+ GIT_SHA = "b8ae56bf3b7c95e686ca002b07ca83d259c716eb"
6
6
 
7
7
  PYTHON_BASE_IMAGE = registry + '/python-base:{python_version}-' + GIT_SHA
8
- TORCH_BASE_IMAGE = registry + '/torch:{torch_version}-py{python_version}-cuda{cuda_version}-' + GIT_SHA
8
+ TORCH_BASE_IMAGE = registry + '/torch:{torch_version}-py{python_version}-{gpu_version}-' + GIT_SHA
9
9
 
10
10
  # List of available python base images
11
11
  AVAILABLE_PYTHON_IMAGES = ['3.11', '3.12']
@@ -21,12 +21,13 @@ DEFAULT_RUNTIME_DOWNLOAD_PATH = os.path.join(os.sep, "tmp", ".cache")
21
21
  # List of available torch images
22
22
  # Keep sorted by most recent cuda version.
23
23
  AVAILABLE_TORCH_IMAGES = [
24
- '2.4.1-py3.11-cuda124',
25
- '2.5.1-py3.11-cuda124',
26
- '2.4.1-py3.12-cuda124',
27
- '2.5.1-py3.12-cuda124',
28
- # '2.4.1-py3.13-cuda124',
29
- # '2.5.1-py3.13-cuda124',
24
+ '2.4.1-py3.11-cu124',
25
+ '2.5.1-py3.11-cu124',
26
+ '2.4.1-py3.12-cu124',
27
+ '2.5.1-py3.12-cu124',
28
+ '2.6.0-py3.12-cu126',
29
+ '2.7.0-py3.12-cu128',
30
+ '2.7.0-py3.12-rocm6.3',
30
31
  ]
31
32
  CONCEPTS_REQUIRED_MODEL_TYPE = [
32
33
  'visual-classifier', 'visual-detector', 'visual-segmenter', 'text-classifier'
@@ -6,6 +6,7 @@ import shutil
6
6
 
7
7
  import requests
8
8
 
9
+ from clarifai.runners.utils.const import CONCEPTS_REQUIRED_MODEL_TYPE
9
10
  from clarifai.utils.logging import logger
10
11
 
11
12
 
@@ -13,9 +14,10 @@ class HuggingFaceLoader:
13
14
 
14
15
  HF_DOWNLOAD_TEXT = "The 'huggingface_hub' package is not installed. Please install it using 'pip install huggingface_hub'."
15
16
 
16
- def __init__(self, repo_id=None, token=None):
17
+ def __init__(self, repo_id=None, token=None, model_type_id=None):
17
18
  self.repo_id = repo_id
18
19
  self.token = token
20
+ self.clarifai_model_type_id = model_type_id
19
21
  if token:
20
22
  if self.validate_hftoken(token):
21
23
  try:
@@ -43,13 +45,17 @@ class HuggingFaceLoader:
43
45
  f"Error setting up Hugging Face token, please make sure you have the correct token: {e}")
44
46
  return False
45
47
 
46
- def download_checkpoints(self, checkpoint_path: str):
48
+ def download_checkpoints(self,
49
+ checkpoint_path: str,
50
+ allowed_file_patterns=None,
51
+ ignore_file_patterns=None):
47
52
  # throw error if huggingface_hub wasn't installed
48
53
  try:
49
54
  from huggingface_hub import snapshot_download
50
55
  except ImportError:
51
56
  raise ImportError(self.HF_DOWNLOAD_TEXT)
52
- if os.path.exists(checkpoint_path) and self.validate_download(checkpoint_path):
57
+ if os.path.exists(checkpoint_path) and self.validate_download(
58
+ checkpoint_path, allowed_file_patterns, ignore_file_patterns):
53
59
  logger.info("Checkpoints already exist")
54
60
  return True
55
61
  else:
@@ -61,10 +67,16 @@ class HuggingFaceLoader:
61
67
  return False
62
68
 
63
69
  self.ignore_patterns = self._get_ignore_patterns()
70
+ if ignore_file_patterns:
71
+ if self.ignore_patterns:
72
+ self.ignore_patterns.extend(ignore_file_patterns)
73
+ else:
74
+ self.ignore_patterns = ignore_file_patterns
64
75
  snapshot_download(
65
76
  repo_id=self.repo_id,
66
77
  local_dir=checkpoint_path,
67
78
  local_dir_use_symlinks=False,
79
+ allow_patterns=allowed_file_patterns,
68
80
  ignore_patterns=self.ignore_patterns)
69
81
  # Remove the `.cache` folder if it exists
70
82
  cache_path = os.path.join(checkpoint_path, ".cache")
@@ -75,7 +87,8 @@ class HuggingFaceLoader:
75
87
  logger.error(f"Error downloading model checkpoints {e}")
76
88
  return False
77
89
  finally:
78
- is_downloaded = self.validate_download(checkpoint_path)
90
+ is_downloaded = self.validate_download(checkpoint_path, allowed_file_patterns,
91
+ ignore_file_patterns)
79
92
  if not is_downloaded:
80
93
  logger.error("Error validating downloaded model checkpoints")
81
94
  return False
@@ -109,9 +122,13 @@ class HuggingFaceLoader:
109
122
  from huggingface_hub import file_exists, repo_exists
110
123
  except ImportError:
111
124
  raise ImportError(self.HF_DOWNLOAD_TEXT)
112
- return repo_exists(self.repo_id) and file_exists(self.repo_id, 'config.json')
125
+ if self.clarifai_model_type_id in CONCEPTS_REQUIRED_MODEL_TYPE:
126
+ return repo_exists(self.repo_id) and file_exists(self.repo_id, 'config.json')
127
+ else:
128
+ return repo_exists(self.repo_id)
113
129
 
114
- def validate_download(self, checkpoint_path: str):
130
+ def validate_download(self, checkpoint_path: str, allowed_file_patterns: list,
131
+ ignore_file_patterns: list):
115
132
  # check if model exists on HF
116
133
  try:
117
134
  from huggingface_hub import list_repo_files
@@ -120,7 +137,20 @@ class HuggingFaceLoader:
120
137
  # Get the list of files on the repo
121
138
  repo_files = list_repo_files(self.repo_id, token=self.token)
122
139
 
140
+ # Get the list of files on the repo that are allowed
141
+ if allowed_file_patterns:
142
+
143
+ def should_allow(file_path):
144
+ return any(fnmatch.fnmatch(file_path, pattern) for pattern in allowed_file_patterns)
145
+
146
+ repo_files = [f for f in repo_files if should_allow(f)]
147
+
123
148
  self.ignore_patterns = self._get_ignore_patterns()
149
+ if ignore_file_patterns:
150
+ if self.ignore_patterns:
151
+ self.ignore_patterns.extend(ignore_file_patterns)
152
+ else:
153
+ self.ignore_patterns = ignore_file_patterns
124
154
  # Get the list of files on the repo that are not ignored
125
155
  if getattr(self, "ignore_patterns", None):
126
156
  patterns = self.ignore_patterns