clarifai 11.2.1__py3-none-any.whl → 11.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clarifai/__init__.py +1 -1
- clarifai/cli/base.py +225 -89
- clarifai/cli/compute_cluster.py +24 -21
- clarifai/cli/deployment.py +66 -42
- clarifai/cli/model.py +1 -1
- clarifai/cli/nodepool.py +59 -41
- clarifai/client/app.py +1 -1
- clarifai/client/auth/stub.py +4 -5
- clarifai/client/dataset.py +3 -4
- clarifai/runners/models/model_builder.py +38 -15
- clarifai/runners/models/model_run_locally.py +2 -1
- clarifai/runners/utils/const.py +9 -8
- clarifai/runners/utils/loader.py +36 -6
- clarifai/utils/cli.py +125 -36
- clarifai/utils/config.py +105 -0
- clarifai/utils/constants.py +4 -0
- clarifai/utils/logging.py +64 -21
- clarifai/utils/misc.py +2 -0
- {clarifai-11.2.1.dist-info → clarifai-11.2.3.dist-info}/METADATA +2 -2
- {clarifai-11.2.1.dist-info → clarifai-11.2.3.dist-info}/RECORD +24 -23
- {clarifai-11.2.1.dist-info → clarifai-11.2.3.dist-info}/WHEEL +1 -1
- {clarifai-11.2.1.dist-info → clarifai-11.2.3.dist-info}/entry_points.txt +0 -0
- {clarifai-11.2.1.dist-info → clarifai-11.2.3.dist-info}/licenses/LICENSE +0 -0
- {clarifai-11.2.1.dist-info → clarifai-11.2.3.dist-info}/top_level.txt +0 -0
clarifai/cli/nodepool.py
CHANGED
@@ -1,31 +1,27 @@
 import click
+
 from clarifai.cli.base import cli
-from clarifai.…
-…
+from clarifai.utils.cli import (AliasedGroup, display_co_resources, dump_yaml, from_yaml,
+                                validate_context)
 
 
-@cli.group(['nodepool', 'np'])
+@cli.group(['nodepool', 'np'], cls=AliasedGroup)
 def nodepool():
   """Manage Nodepools: create, delete, list"""
-  pass
 
 
-@nodepool.command()
-@click.option(
-    …
-    '--compute_cluster_id',
-    required=False,
-    help='Compute Cluster ID for the compute cluster to interact with.')
+@nodepool.command(['c'])
+@click.argument('compute_cluster_id')
+@click.argument('nodepool_id')
 @click.option(
     '--config',
     type=click.Path(exists=True),
     required=True,
     help='Path to the nodepool config file.')
-@click.option(
-    '-np_id', '--nodepool_id', required=False, help='New Nodepool ID for the nodepool to create.')
 @click.pass_context
-def create(ctx, compute_cluster_id, config, nodepool_id):
+def create(ctx, compute_cluster_id, nodepool_id, config):
   """Create a new Nodepool with the given config file."""
+  from clarifai.client.compute_cluster import ComputeCluster
 
   validate_context(ctx)
   nodepool_config = from_yaml(config)
@@ -43,52 +39,74 @@ def create(ctx, compute_cluster_id, config, nodepool_id):
 
   compute_cluster = ComputeCluster(
       compute_cluster_id=compute_cluster_id,
-      user_id=ctx.obj['user_id'],
-      pat=ctx.obj['pat'],
-      base_url=ctx.obj['base_url'])
+      user_id=ctx.obj.current.user_id,
+      pat=ctx.obj.current.pat,
+      base_url=ctx.obj.current.api_base)
   if nodepool_id:
     compute_cluster.create_nodepool(config, nodepool_id=nodepool_id)
   else:
     compute_cluster.create_nodepool(config)
 
 
-@nodepool.command()
-@click.option(
-    '-cc_id',
-    '--compute_cluster_id',
-    required=True,
-    help='Compute Cluster ID for the compute cluster to interact with.')
+@nodepool.command(['ls'])
+@click.argument('compute_cluster_id', default="")
 @click.option('--page_no', required=False, help='Page number to list.', default=1)
-@click.option('--per_page', required=False, help='Number of items per page.', default=…)
+@click.option('--per_page', required=False, help='Number of items per page.', default=128)
 @click.pass_context
 def list(ctx, compute_cluster_id, page_no, per_page):
-  """List all nodepools for the user."""
+  """List all nodepools for the user across all compute clusters. If compute_cluster_id is provided
+  it will list only within that compute cluster. """
+  from clarifai.client.compute_cluster import ComputeCluster
+  from clarifai.client.user import User
 
   validate_context(ctx)
-  compute_cluster = ComputeCluster(
-      compute_cluster_id=compute_cluster_id,
-      user_id=ctx.obj['user_id'],
-      pat=ctx.obj['pat'],
-      base_url=ctx.obj['base_url'])
-  response = compute_cluster.list_nodepools(page_no, per_page)
-  display_co_resources(response, "Nodepool")
 
+  cc_id = compute_cluster_id
 
-… (seven deleted lines not captured in this diff view)
+  if cc_id:
+    compute_cluster = ComputeCluster(
+        compute_cluster_id=cc_id,
+        user_id=ctx.obj.current.user_id,
+        pat=ctx.obj.current.pat,
+        base_url=ctx.obj.current.api_base)
+    response = compute_cluster.list_nodepools(page_no, per_page)
+  else:
+    user = User(
+        user_id=ctx.obj.current.user_id,
+        pat=ctx.obj.current.pat,
+        base_url=ctx.obj.current.api_base)
+    ccs = user.list_compute_clusters(page_no, per_page)
+    response = []
+    for cc in ccs:
+      compute_cluster = ComputeCluster(
+          compute_cluster_id=cc.id,
+          user_id=ctx.obj.current.user_id,
+          pat=ctx.obj.current.pat,
+          base_url=ctx.obj.current.api_base)
+      response.extend([i for i in compute_cluster.list_nodepools(page_no, per_page)])
+
+  display_co_resources(
+      response,
+      custom_columns={
+          'ID': lambda c: c.id,
+          'USER_ID': lambda c: c.compute_cluster.user_id,
+          'COMPUTE_CLUSTER_ID': lambda c: c.compute_cluster.id,
+          'DESCRIPTION': lambda c: c.description,
+      })
+
+
+@nodepool.command(['rm'])
+@click.argument('compute_cluster_id')
+@click.argument('nodepool_id')
 @click.pass_context
 def delete(ctx, compute_cluster_id, nodepool_id):
   """Deletes a nodepool for the user."""
+  from clarifai.client.compute_cluster import ComputeCluster
 
   validate_context(ctx)
   compute_cluster = ComputeCluster(
       compute_cluster_id=compute_cluster_id,
-      user_id=ctx.obj['user_id'],
-      pat=ctx.obj['pat'],
-      base_url=ctx.obj['base_url'])
+      user_id=ctx.obj.current.user_id,
+      pat=ctx.obj.current.pat,
+      base_url=ctx.obj.current.api_base)
   compute_cluster.delete_nodepools([nodepool_id])
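The rewrite above moves the compute cluster and nodepool IDs from --options to positional arguments and registers short aliases ('c', 'ls', 'rm') by passing a list of names to the command decorator, with `AliasedGroup` doing the wiring; `clarifai nodepool ls` and `clarifai np ls` then resolve to the same command. Below is a minimal sketch of how a Click group can accept such name lists. It illustrates the pattern only and is not the actual clarifai.utils.cli.AliasedGroup implementation:

import click


class AliasedGroup(click.Group):
  """Sketch: a Group whose command() decorator accepts a list of names."""

  def command(self, *args, **kwargs):
    if args and isinstance(args[0], list):
      names = args[0]

      def decorator(f):
        # Register under the first name, then alias the rest to the same command.
        cmd = super(AliasedGroup, self).command(names[0], **kwargs)(f)
        for alias in names[1:]:
          self.add_command(cmd, name=alias)
        return cmd

      return decorator
    return super().command(*args, **kwargs)


@click.group(cls=AliasedGroup)
def cli():
  pass


@cli.command(['list', 'ls'])
def list_things():
  click.echo("listing")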
clarifai/client/app.py
CHANGED
@@ -629,7 +629,7 @@ class App(Lister, BaseClient):
 
     Args:
         model_id (str): The model ID for the model to interact with.
-        …
+        model_version (Dict): The model version ID for the model version to interact with.
 
     Returns:
         Model: A Model object for the existing model ID.
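For context, the corrected docstring matches how the method is called: `model_version` is a dict carrying the version ID. A hedged usage sketch follows; the IDs and PAT are placeholders, and the constructor arguments are assumed from the SDK's usual pattern rather than shown in this diff:

from clarifai.client.app import App

app = App(user_id="me", app_id="my-app", pat="MY_PAT")
# model_version is a Dict holding the version ID, per the docstring above.
model = app.model(model_id="my-model", model_version={"id": "some-version-id"})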
clarifai/client/auth/stub.py
CHANGED
@@ -1,5 +1,4 @@
 import itertools
-import logging
 import time
 from concurrent.futures import ThreadPoolExecutor
 
@@ -8,7 +7,7 @@ from clarifai_grpc.grpc.api.status import status_code_pb2
 
 from clarifai.client.auth.helper import ClarifaiAuthHelper
 from clarifai.client.auth.register import RpcCallable, V2Stub
-
+from clarifai.utils.logging import logger
 throttle_status_codes = {
     status_code_pb2.CONN_THROTTLED,
     status_code_pb2.CONN_EXCEED_HOURLY_LIMIT,
@@ -26,7 +25,7 @@ def validate_response(response, attempt, max_attempts):
 
   def handle_simple_response(response):
     if hasattr(response, 'status') and hasattr(response.status, 'code'):
       if (response.status.code in throttle_status_codes) and attempt < max_attempts:
-        logging.debug('Retrying with status %s' % str(response.status))
+        logger.debug('Retrying with status %s' % str(response.status))
         return None  # Indicates a retry is needed
       else:
         return response
@@ -42,7 +41,7 @@ def validate_response(response, attempt, max_attempts):
         return itertools.chain([validated_response], response)
       return None  # Indicates a retry is needed
     except grpc.RpcError as e:
-      logging.error('Error processing streaming response: %s' % str(e))
+      logger.error('Error processing streaming response: %s' % str(e))
       return None  # Indicates an error
   else:
     # Handle simple response validation
@@ -143,7 +142,7 @@ class _RetryRpcCallable(RpcCallable):
       return v
     except grpc.RpcError as e:
       if (e.code() in retry_codes_grpc) and attempt < self.max_attempts:
-        logging.debug('Retrying with status %s' % e.code())
+        logger.debug('Retrying with status %s' % e.code())
       else:
         raise
 
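The substantive change here is swapping ad-hoc calls to the root `logging` module for the SDK's shared `logger`, so retry and error messages pick up one consistent configuration. A rough sketch of the pattern; the actual setup in clarifai.utils.logging may add handlers, formatting, and levels beyond this:

import logging

# One named, pre-configured logger for the whole SDK (illustrative setup only).
logger = logging.getLogger("clarifai")
logger.setLevel(logging.DEBUG)
if not logger.handlers:
  logger.addHandler(logging.StreamHandler())

# Call sites then read like the updated stub.py:
logger.debug('Retrying with status %s' % 'CONN_THROTTLED')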
clarifai/client/dataset.py
CHANGED
@@ -1,4 +1,3 @@
-import logging
 import os
 import time
 import uuid
@@ -354,7 +353,7 @@ class Dataset(Lister, BaseClient):
           break
       if failed_input_ids:
         retry_input_ids = [dataset_obj.all_input_ids[id] for id in failed_input_ids]
-        logging.warning(
+        logger.warning(
             f"Retrying upload for {len(failed_input_ids)} inputs in current batch: {retry_input_ids}\n"
         )
         failed_retrying_inputs, _, retry_response = self._upload_inputs_annotations(
@@ -494,7 +493,7 @@ class Dataset(Lister, BaseClient):
     add_file_handler(self.logger, f"Dataset_Upload{str(int(datetime.now().timestamp()))}.log")
 
     if retry_duplicates and duplicate_input_ids:
-      logging.warning(f"Retrying upload for {len(duplicate_input_ids)} duplicate inputs...\n")
+      logger.warning(f"Retrying upload for {len(duplicate_input_ids)} duplicate inputs...\n")
       duplicate_inputs_indexes = [input["Index"] for input in duplicate_input_ids]
       self.upload_dataset(
           dataloader=dataloader,
@@ -505,7 +504,7 @@ class Dataset(Lister, BaseClient):
 
     if failed_input_ids:
       #failed_inputs= ([input["Input_ID"] for input in failed_input_ids])
-      logging.warning(f"Retrying upload for {len(failed_input_ids)} failed inputs...\n")
+      logger.warning(f"Retrying upload for {len(failed_input_ids)} failed inputs...\n")
       failed_input_indexes = [input["Index"] for input in failed_input_ids]
       self.upload_dataset(
           dataloader=dataloader, log_retry_ids=failed_input_indexes, is_log_retry=True, **kwargs)
clarifai/runners/models/model_builder.py
CHANGED
@@ -142,17 +142,21 @@ class ModelBuilder:
   def _validate_config_checkpoints(self):
     """
     Validates the checkpoints section in the config file.
+    return loader_type, repo_id, hf_token, when, allowed_file_patterns, ignore_file_patterns
     :return: loader_type the type of loader or None if no checkpoints.
     :return: repo_id location of checkpoint.
     :return: hf_token token to access checkpoint.
+    :return: when one of ['upload', 'build', 'runtime'] to download checkpoint
+    :return: allowed_file_patterns patterns to allow in downloaded checkpoint
+    :return: ignore_file_patterns patterns to ignore in downloaded checkpoint
     """
     if "checkpoints" not in self.config:
-      return None, None, None, DEFAULT_DOWNLOAD_CHECKPOINT_WHEN
+      return None, None, None, DEFAULT_DOWNLOAD_CHECKPOINT_WHEN, None, None
     assert "type" in self.config.get("checkpoints"), "No loader type specified in the config file"
     loader_type = self.config.get("checkpoints").get("type")
     if not loader_type:
       logger.info("No loader type specified in the config file for checkpoints")
-      return None, None, None
+      return None, None, None, DEFAULT_DOWNLOAD_CHECKPOINT_WHEN, None, None
     checkpoints = self.config.get("checkpoints")
     if 'when' not in checkpoints:
       logger.warn(
@@ -171,15 +175,30 @@ class ModelBuilder:
 
     # get from config.yaml otherwise fall back to HF_TOKEN env var.
     hf_token = self.config.get("checkpoints").get("hf_token", os.environ.get("HF_TOKEN", None))
-    return loader_type, repo_id, hf_token, when
+
+    allowed_file_patterns = self.config.get("checkpoints").get('allowed_file_patterns', None)
+    if isinstance(allowed_file_patterns, str):
+      allowed_file_patterns = [allowed_file_patterns]
+    ignore_file_patterns = self.config.get("checkpoints").get('ignore_file_patterns', None)
+    if isinstance(ignore_file_patterns, str):
+      ignore_file_patterns = [ignore_file_patterns]
+    return loader_type, repo_id, hf_token, when, allowed_file_patterns, ignore_file_patterns
 
   def _check_app_exists(self):
     resp = self.client.STUB.GetApp(service_pb2.GetAppRequest(user_app_id=self.client.user_app_id))
     if resp.status.code == status_code_pb2.SUCCESS:
       return True
+    if resp.status.code == status_code_pb2.CONN_KEY_INVALID:
+      logger.error(
+          f"Invalid PAT provided for user {self.client.user_app_id.user_id}. Please check your PAT and try again."
+      )
+      return False
     logger.error(
         f"Error checking API {self._base_api} for user app {self.client.user_app_id.user_id}/{self.client.user_app_id.app_id}. Error code: {resp.status.code}"
     )
+    logger.error(
+        f"App {self.client.user_app_id.app_id} not found for user {self.client.user_app_id.user_id}. Please create the app first and try again."
+    )
     return False
 
   def _validate_config_model(self):
@@ -200,9 +219,6 @@ class ModelBuilder:
     assert model.get('id') != "", "model_id cannot be empty in the config file"
 
     if not self._check_app_exists():
-      logger.error(
-          f"App {self.client.user_app_id.app_id} not found for user {self.client.user_app_id.user_id}"
-      )
       sys.exit(1)
 
   def _validate_config(self):
@@ -216,7 +232,7 @@ class ModelBuilder:
       assert model_type_id in CONCEPTS_REQUIRED_MODEL_TYPE, f"Model type {model_type_id} not supported for concepts"
 
     if self.config.get("checkpoints"):
-      loader_type, _, hf_token, _ = self._validate_config_checkpoints()
+      loader_type, _, hf_token, _, _, _ = self._validate_config_checkpoints()
 
       if loader_type == "huggingface" and hf_token:
         is_valid_token = HuggingFaceLoader.validate_hftoken(hf_token)
@@ -399,11 +415,12 @@ class ModelBuilder:
     # Sort in reverse so that newer cuda versions come first and are preferred.
     for image in sorted(AVAILABLE_TORCH_IMAGES, reverse=True):
       if torch_version in image and f'py{python_version}' in image:
-        …
+        # like cu124, rocm6.3, etc.
+        gpu_version = image.split('-')[-1]
         final_image = TORCH_BASE_IMAGE.format(
             torch_version=torch_version,
             python_version=python_version,
-            …
+            gpu_version=gpu_version,
         )
         logger.info(f"Using Torch version {torch_version} base image to build the Docker image")
         break
@@ -480,8 +497,10 @@ class ModelBuilder:
     if not self.config.get("checkpoints"):
       logger.info("No checkpoints specified in the config file")
       return path
+    clarifai_model_type_id = self.config.get('model').get('model_type_id')
 
-    loader_type, repo_id, hf_token, when = self._validate_config_checkpoints()
+    loader_type, repo_id, hf_token, when, allowed_file_patterns, ignore_file_patterns = self._validate_config_checkpoints(
+    )
     if stage not in ["build", "upload", "runtime"]:
       raise Exception("Invalid stage provided, must be one of ['build', 'upload', 'runtime']")
     if when != stage:
@@ -490,14 +509,18 @@ class ModelBuilder:
       )
       return path
 
-    success = …
+    success = False
     if loader_type == "huggingface":
-      loader = HuggingFaceLoader(repo_id=repo_id, token=hf_token)
+      loader = HuggingFaceLoader(
+          repo_id=repo_id, token=hf_token, model_type_id=clarifai_model_type_id)
       # for runtime default to /tmp path
       if stage == "runtime" and checkpoint_path_override is None:
         checkpoint_path_override = self.default_runtime_checkpoint_path()
       path = checkpoint_path_override if checkpoint_path_override else self.checkpoint_path
-      success = loader.download_checkpoints(path)
+      success = loader.download_checkpoints(
+          path,
+          allowed_file_patterns=allowed_file_patterns,
+          ignore_file_patterns=ignore_file_patterns)
 
     if loader_type:
       if not success:
@@ -569,7 +592,7 @@ class ModelBuilder:
     logger.debug(f"Will tar it into file: {file_path}")
 
     model_type_id = self.config.get('model').get('model_type_id')
-    loader_type, repo_id, hf_token, when = self._validate_config_checkpoints()
+    loader_type, repo_id, hf_token, when, _, _ = self._validate_config_checkpoints()
 
     if (model_type_id in CONCEPTS_REQUIRED_MODEL_TYPE) and 'concepts' not in self.config:
       logger.info(
@@ -617,7 +640,7 @@ class ModelBuilder:
     # First check for the env variable, then try querying huggingface. If all else fails, use the default.
     checkpoint_size = os.environ.get('CHECKPOINT_SIZE_BYTES', 0)
     if not checkpoint_size:
-      _, repo_id, _, _ = self._validate_config_checkpoints()
+      _, repo_id, _, _, _, _ = self._validate_config_checkpoints()
       checkpoint_size = HuggingFaceLoader.get_huggingface_checkpoint_total_size(repo_id)
       if not checkpoint_size:
         checkpoint_size = self.DEFAULT_CHECKPOINT_SIZE
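Taken together, these hunks widen _validate_config_checkpoints() from a 4-tuple to a 6-tuple and thread the two new pattern lists down to the Hugging Face loader. Below is a hedged sketch of a checkpoints config that exercises the new fields; the key names come straight from the hunks, but the exact config.yaml layout (for example the `repo_id` key) is assumed:

# Assumed shape of the parsed config.yaml this validator reads.
config = {
    "checkpoints": {
        "type": "huggingface",
        "repo_id": "owner/model",                  # assumed key name
        "hf_token": "hf_...",                      # falls back to the HF_TOKEN env var
        "when": "runtime",                         # one of 'upload', 'build', 'runtime'
        "allowed_file_patterns": "*.safetensors",  # a bare str is promoted to [str]
        "ignore_file_patterns": ["*.bin", "*.pth"],
    }
}

# Callers that need only part of the new 6-tuple unpack with underscores:
# loader_type, _, hf_token, _, _, _ = builder._validate_config_checkpoints()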
clarifai/runners/models/model_run_locally.py
CHANGED
@@ -108,6 +108,7 @@ class ModelRunLocally:
 
   def _build_stream_request(self):
     request = self._build_request()
+    ensure_urls_downloaded(request)
     for i in range(1):
       yield request
 
@@ -479,7 +480,7 @@ def main(model_path,
   manager = ModelRunLocally(model_path)
   # get whatever stage is in config.yaml to force download now
   # also always write to where upload/build wants to, not the /tmp folder that runtime stage uses
-  _, _, _, when = manager.builder._validate_config_checkpoints()
+  _, _, _, when, _, _ = manager.builder._validate_config_checkpoints()
   manager.builder.download_checkpoints(
       stage=when, checkpoint_path_override=manager.builder.checkpoint_path)
   if inside_container:
clarifai/runners/utils/const.py
CHANGED
@@ -2,10 +2,10 @@ import os
 
 registry = os.environ.get('CLARIFAI_BASE_IMAGE_REGISTRY', 'public.ecr.aws/clarifai-models')
 
-GIT_SHA = "…"
+GIT_SHA = "b8ae56bf3b7c95e686ca002b07ca83d259c716eb"
 
 PYTHON_BASE_IMAGE = registry + '/python-base:{python_version}-' + GIT_SHA
-TORCH_BASE_IMAGE = registry + '/torch:{torch_version}-py{python_version}-…' + GIT_SHA
+TORCH_BASE_IMAGE = registry + '/torch:{torch_version}-py{python_version}-{gpu_version}-' + GIT_SHA
 
 # List of available python base images
 AVAILABLE_PYTHON_IMAGES = ['3.11', '3.12']
@@ -21,12 +21,13 @@ DEFAULT_RUNTIME_DOWNLOAD_PATH = os.path.join(os.sep, "tmp", ".cache")
 # List of available torch images
 # Keep sorted by most recent cuda version.
 AVAILABLE_TORCH_IMAGES = [
-    '2.4.1-py3.11-…',
-    '2.5.1-py3.11-…',
-    '2.4.1-py3.12-…',
-    '2.5.1-py3.12-…',
-    …
-    …
+    '2.4.1-py3.11-cu124',
+    '2.5.1-py3.11-cu124',
+    '2.4.1-py3.12-cu124',
+    '2.5.1-py3.12-cu124',
+    '2.6.0-py3.12-cu126',
+    '2.7.0-py3.12-cu128',
+    '2.7.0-py3.12-rocm6.3',
 ]
 CONCEPTS_REQUIRED_MODEL_TYPE = [
     'visual-classifier', 'visual-detector', 'visual-segmenter', 'text-classifier'
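The renamed {gpu_version} placeholder is what lets one template cover both CUDA and ROCm tags; the suffix is recovered from the image name with a single split, exactly as in the model_builder hunk above. A quick check of how the template resolves, with the registry and SHA shortened here for readability:

TORCH_BASE_IMAGE = ('public.ecr.aws/clarifai-models' +
                    '/torch:{torch_version}-py{python_version}-{gpu_version}-' + 'b8ae56b')

image = '2.7.0-py3.12-rocm6.3'
gpu_version = image.split('-')[-1]  # 'rocm6.3'; 'cu124' etc. work the same way
print(TORCH_BASE_IMAGE.format(
    torch_version='2.7.0', python_version='3.12', gpu_version=gpu_version))
# -> public.ecr.aws/clarifai-models/torch:2.7.0-py3.12-rocm6.3-b8ae56b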
clarifai/runners/utils/loader.py
CHANGED
@@ -6,6 +6,7 @@ import shutil
 
 import requests
 
+from clarifai.runners.utils.const import CONCEPTS_REQUIRED_MODEL_TYPE
 from clarifai.utils.logging import logger
 
 
@@ -13,9 +14,10 @@ class HuggingFaceLoader:
 
   HF_DOWNLOAD_TEXT = "The 'huggingface_hub' package is not installed. Please install it using 'pip install huggingface_hub'."
 
-  def __init__(self, repo_id=None, token=None):
+  def __init__(self, repo_id=None, token=None, model_type_id=None):
     self.repo_id = repo_id
     self.token = token
+    self.clarifai_model_type_id = model_type_id
     if token:
       if self.validate_hftoken(token):
         try:
@@ -43,13 +45,17 @@ class HuggingFaceLoader:
           f"Error setting up Hugging Face token, please make sure you have the correct token: {e}")
       return False
 
-  def download_checkpoints(self, checkpoint_path: str):
+  def download_checkpoints(self,
+                           checkpoint_path: str,
+                           allowed_file_patterns=None,
+                           ignore_file_patterns=None):
     # throw error if huggingface_hub wasn't installed
     try:
       from huggingface_hub import snapshot_download
     except ImportError:
       raise ImportError(self.HF_DOWNLOAD_TEXT)
-    if os.path.exists(checkpoint_path) and self.validate_download(checkpoint_path):
+    if os.path.exists(checkpoint_path) and self.validate_download(
+        checkpoint_path, allowed_file_patterns, ignore_file_patterns):
       logger.info("Checkpoints already exist")
       return True
     else:
@@ -61,10 +67,16 @@ class HuggingFaceLoader:
         return False
 
       self.ignore_patterns = self._get_ignore_patterns()
+      if ignore_file_patterns:
+        if self.ignore_patterns:
+          self.ignore_patterns.extend(ignore_file_patterns)
+        else:
+          self.ignore_patterns = ignore_file_patterns
       snapshot_download(
          repo_id=self.repo_id,
          local_dir=checkpoint_path,
          local_dir_use_symlinks=False,
+         allow_patterns=allowed_file_patterns,
          ignore_patterns=self.ignore_patterns)
       # Remove the `.cache` folder if it exists
       cache_path = os.path.join(checkpoint_path, ".cache")
@@ -75,7 +87,8 @@ class HuggingFaceLoader:
       logger.error(f"Error downloading model checkpoints {e}")
       return False
     finally:
-      is_downloaded = self.validate_download(checkpoint_path)
+      is_downloaded = self.validate_download(checkpoint_path, allowed_file_patterns,
+                                             ignore_file_patterns)
       if not is_downloaded:
         logger.error("Error validating downloaded model checkpoints")
         return False
@@ -109,9 +122,13 @@ class HuggingFaceLoader:
       from huggingface_hub import file_exists, repo_exists
     except ImportError:
       raise ImportError(self.HF_DOWNLOAD_TEXT)
-    …
+    if self.clarifai_model_type_id in CONCEPTS_REQUIRED_MODEL_TYPE:
+      return repo_exists(self.repo_id) and file_exists(self.repo_id, 'config.json')
+    else:
+      return repo_exists(self.repo_id)
 
-  def validate_download(self, checkpoint_path: str):
+  def validate_download(self, checkpoint_path: str, allowed_file_patterns: list,
+                        ignore_file_patterns: list):
     # check if model exists on HF
     try:
       from huggingface_hub import list_repo_files
@@ -120,7 +137,20 @@ class HuggingFaceLoader:
     # Get the list of files on the repo
     repo_files = list_repo_files(self.repo_id, token=self.token)
 
+    # Get the list of files on the repo that are allowed
+    if allowed_file_patterns:
+
+      def should_allow(file_path):
+        return any(fnmatch.fnmatch(file_path, pattern) for pattern in allowed_file_patterns)
+
+      repo_files = [f for f in repo_files if should_allow(f)]
+
     self.ignore_patterns = self._get_ignore_patterns()
+    if ignore_file_patterns:
+      if self.ignore_patterns:
+        self.ignore_patterns.extend(ignore_file_patterns)
+      else:
+        self.ignore_patterns = ignore_file_patterns
     # Get the list of files on the repo that are not ignored
     if getattr(self, "ignore_patterns", None):
       patterns = self.ignore_patterns
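The allow/ignore logic in validate_download composes in a fixed order: allow-patterns first narrow the repo file list, then ignore-patterns (the loader's built-in ones plus any from the config) drop matches from what remains. A standalone sketch of that filtering with fnmatch, using made-up file names:

import fnmatch

repo_files = ["config.json", "model.safetensors", "model.bin", "README.md"]
allowed_file_patterns = ["*.safetensors", "config.json"]
ignore_file_patterns = ["*.bin"]

# Step 1: keep only files matching an allow pattern.
if allowed_file_patterns:
  repo_files = [
      f for f in repo_files if any(fnmatch.fnmatch(f, p) for p in allowed_file_patterns)
  ]

# Step 2: drop files matching an ignore pattern.
if ignore_file_patterns:
  repo_files = [
      f for f in repo_files if not any(fnmatch.fnmatch(f, p) for p in ignore_file_patterns)
  ]

print(repo_files)  # ['config.json', 'model.safetensors']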