clarifai 11.2.3rc6__py3-none-any.whl → 11.2.3rc9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clarifai/__init__.py +1 -1
- clarifai/cli/base.py +228 -81
- clarifai/cli/compute_cluster.py +28 -18
- clarifai/cli/deployment.py +70 -42
- clarifai/cli/model.py +26 -14
- clarifai/cli/nodepool.py +62 -41
- clarifai/client/app.py +1 -1
- clarifai/client/auth/stub.py +4 -5
- clarifai/client/dataset.py +3 -4
- clarifai/client/model_client.py +35 -6
- clarifai/runners/models/model_builder.py +42 -59
- clarifai/runners/models/model_class.py +12 -8
- clarifai/runners/models/model_run_locally.py +11 -23
- clarifai/runners/utils/const.py +9 -8
- clarifai/runners/utils/data_types.py +0 -4
- clarifai/runners/utils/data_utils.py +10 -10
- clarifai/runners/utils/loader.py +36 -6
- clarifai/runners/utils/method_signatures.py +2 -1
- clarifai/utils/cli.py +132 -34
- clarifai/utils/config.py +105 -0
- clarifai/utils/constants.py +4 -0
- clarifai/utils/logging.py +64 -21
- clarifai/utils/misc.py +2 -0
- {clarifai-11.2.3rc6.dist-info → clarifai-11.2.3rc9.dist-info}/METADATA +2 -13
- {clarifai-11.2.3rc6.dist-info → clarifai-11.2.3rc9.dist-info}/RECORD +29 -28
- {clarifai-11.2.3rc6.dist-info → clarifai-11.2.3rc9.dist-info}/WHEEL +1 -1
- {clarifai-11.2.3rc6.dist-info → clarifai-11.2.3rc9.dist-info}/LICENSE +0 -0
- {clarifai-11.2.3rc6.dist-info → clarifai-11.2.3rc9.dist-info}/entry_points.txt +0 -0
- {clarifai-11.2.3rc6.dist-info → clarifai-11.2.3rc9.dist-info}/top_level.txt +0 -0
clarifai/cli/nodepool.py
CHANGED
@@ -1,32 +1,29 @@
 import click
+
 from clarifai.cli.base import cli
-from clarifai.client.compute_cluster import ComputeCluster
-from clarifai.utils.cli import display_co_resources, dump_yaml, from_yaml
+from clarifai.utils.cli import (AliasedGroup, display_co_resources, dump_yaml, from_yaml,
+                                validate_context)
 
 
-@cli.group(['nodepool', 'np'])
+@cli.group(['nodepool', 'np'], cls=AliasedGroup)
 def nodepool():
   """Manage Nodepools: create, delete, list"""
-  pass
 
 
-@nodepool.command()
-@click.option(
-    '-cc_id',
-    '--compute_cluster_id',
-    required=False,
-    help='Compute Cluster ID for the compute cluster to interact with.')
+@nodepool.command(['c'])
+@click.argument('compute_cluster_id')
+@click.argument('nodepool_id')
 @click.option(
     '--config',
     type=click.Path(exists=True),
     required=True,
     help='Path to the nodepool config file.')
-@click.option(
-    '-np_id', '--nodepool_id', required=False, help='New Nodepool ID for the nodepool to create.')
 @click.pass_context
-def create(ctx, compute_cluster_id, config, nodepool_id):
+def create(ctx, compute_cluster_id, nodepool_id, config):
   """Create a new Nodepool with the given config file."""
+  from clarifai.client.compute_cluster import ComputeCluster
 
+  validate_context(ctx)
   nodepool_config = from_yaml(config)
   if not compute_cluster_id:
     if 'compute_cluster' not in nodepool_config['nodepool']:
@@ -42,50 +39,74 @@ def create(ctx, compute_cluster_id, config, nodepool_id):
 
   compute_cluster = ComputeCluster(
       compute_cluster_id=compute_cluster_id,
-      user_id=ctx.obj['user_id'],
-      pat=ctx.obj['pat'],
-      base_url=ctx.obj['base_url'])
+      user_id=ctx.obj.current.user_id,
+      pat=ctx.obj.current.pat,
+      base_url=ctx.obj.current.api_base)
   if nodepool_id:
     compute_cluster.create_nodepool(config, nodepool_id=nodepool_id)
   else:
     compute_cluster.create_nodepool(config)
 
 
-@nodepool.command()
-@click.option(
-    '-cc_id',
-    '--compute_cluster_id',
-    required=True,
-    help='Compute Cluster ID for the compute cluster to interact with.')
+@nodepool.command(['ls'])
+@click.argument('compute_cluster_id', default="")
 @click.option('--page_no', required=False, help='Page number to list.', default=1)
-@click.option('--per_page', required=False, help='Number of items per page.', default=
+@click.option('--per_page', required=False, help='Number of items per page.', default=128)
 @click.pass_context
 def list(ctx, compute_cluster_id, page_no, per_page):
-  """List all nodepools for the user."""
+  """List all nodepools for the user across all compute clusters. If compute_cluster_id is provided
+  it will list only within that compute cluster. """
+  from clarifai.client.compute_cluster import ComputeCluster
+  from clarifai.client.user import User
 
-  compute_cluster = ComputeCluster(
-      compute_cluster_id=compute_cluster_id,
-      user_id=ctx.obj['user_id'],
-      pat=ctx.obj['pat'],
-      base_url=ctx.obj['base_url'])
-  response = compute_cluster.list_nodepools(page_no, per_page)
-  display_co_resources(response, "Nodepool")
+  validate_context(ctx)
 
+  cc_id = compute_cluster_id
 
-@nodepool.command()
-@click.option(
-    '-cc_id',
-    '--compute_cluster_id',
-    required=True,
-    help='Compute Cluster ID for the compute cluster to interact with.')
-
+  if cc_id:
+    compute_cluster = ComputeCluster(
+        compute_cluster_id=cc_id,
+        user_id=ctx.obj.current.user_id,
+        pat=ctx.obj.current.pat,
+        base_url=ctx.obj.current.api_base)
+    response = compute_cluster.list_nodepools(page_no, per_page)
+  else:
+    user = User(
+        user_id=ctx.obj.current.user_id,
+        pat=ctx.obj.current.pat,
+        base_url=ctx.obj.current.api_base)
+    ccs = user.list_compute_clusters(page_no, per_page)
+    response = []
+    for cc in ccs:
+      compute_cluster = ComputeCluster(
+          compute_cluster_id=cc.id,
+          user_id=ctx.obj.current.user_id,
+          pat=ctx.obj.current.pat,
+          base_url=ctx.obj.current.api_base)
+      response.extend([i for i in compute_cluster.list_nodepools(page_no, per_page)])
+
+  display_co_resources(
+      response,
+      custom_columns={
+          'ID': lambda c: c.id,
+          'USER_ID': lambda c: c.compute_cluster.user_id,
+          'COMPUTE_CLUSTER_ID': lambda c: c.compute_cluster.id,
+          'DESCRIPTION': lambda c: c.description,
+      })
+
+
+@nodepool.command(['rm'])
+@click.argument('compute_cluster_id')
+@click.argument('nodepool_id')
 @click.pass_context
 def delete(ctx, compute_cluster_id, nodepool_id):
   """Deletes a nodepool for the user."""
+  from clarifai.client.compute_cluster import ComputeCluster
 
+  validate_context(ctx)
   compute_cluster = ComputeCluster(
       compute_cluster_id=compute_cluster_id,
-      user_id=ctx.obj['user_id'],
-      pat=ctx.obj['pat'],
-      base_url=ctx.obj['base_url'])
+      user_id=ctx.obj.current.user_id,
+      pat=ctx.obj.current.pat,
+      base_url=ctx.obj.current.api_base)
   compute_cluster.delete_nodepools([nodepool_id])
clarifai/client/app.py
CHANGED
@@ -629,7 +629,7 @@ class App(Lister, BaseClient):
 
     Args:
         model_id (str): The model ID for the model to interact with.
-
+        model_version (Dict): The model version ID for the model version to interact with.
 
     Returns:
         Model: A Model object for the existing model ID.
clarifai/client/auth/stub.py
CHANGED
@@ -1,5 +1,4 @@
 import itertools
-import logging
 import time
 from concurrent.futures import ThreadPoolExecutor
 
@@ -8,7 +7,7 @@ from clarifai_grpc.grpc.api.status import status_code_pb2
 
 from clarifai.client.auth.helper import ClarifaiAuthHelper
 from clarifai.client.auth.register import RpcCallable, V2Stub
-
+from clarifai.utils.logging import logger
 throttle_status_codes = {
     status_code_pb2.CONN_THROTTLED,
     status_code_pb2.CONN_EXCEED_HOURLY_LIMIT,
@@ -26,7 +25,7 @@ def validate_response(response, attempt, max_attempts):
   def handle_simple_response(response):
     if hasattr(response, 'status') and hasattr(response.status, 'code'):
       if (response.status.code in throttle_status_codes) and attempt < max_attempts:
-        logging.debug('Retrying with status %s' % str(response.status))
+        logger.debug('Retrying with status %s' % str(response.status))
        return None  # Indicates a retry is needed
       else:
         return response
@@ -42,7 +41,7 @@ def validate_response(response, attempt, max_attempts):
        return itertools.chain([validated_response], response)
      return None  # Indicates a retry is needed
    except grpc.RpcError as e:
-      logging.error('Error processing streaming response: %s' % str(e))
+      logger.error('Error processing streaming response: %s' % str(e))
      return None  # Indicates an error
   else:
     # Handle simple response validation
@@ -143,7 +142,7 @@ class _RetryRpcCallable(RpcCallable):
         return v
       except grpc.RpcError as e:
         if (e.code() in retry_codes_grpc) and attempt < self.max_attempts:
-          logging.debug('Retrying with status %s' % e.code())
+          logger.debug('Retrying with status %s' % e.code())
         else:
           raise
 
clarifai/client/dataset.py
CHANGED
@@ -1,4 +1,3 @@
-import logging
 import os
 import time
 import uuid
@@ -354,7 +353,7 @@ class Dataset(Lister, BaseClient):
         break
     if failed_input_ids:
       retry_input_ids = [dataset_obj.all_input_ids[id] for id in failed_input_ids]
-      logging.warning(
+      logger.warning(
           f"Retrying upload for {len(failed_input_ids)} inputs in current batch: {retry_input_ids}\n"
       )
       failed_retrying_inputs, _, retry_response = self._upload_inputs_annotations(
@@ -494,7 +493,7 @@ class Dataset(Lister, BaseClient):
     add_file_handler(self.logger, f"Dataset_Upload{str(int(datetime.now().timestamp()))}.log")
 
     if retry_duplicates and duplicate_input_ids:
-      logging.warning(f"Retrying upload for {len(duplicate_input_ids)} duplicate inputs...\n")
+      logger.warning(f"Retrying upload for {len(duplicate_input_ids)} duplicate inputs...\n")
       duplicate_inputs_indexes = [input["Index"] for input in duplicate_input_ids]
       self.upload_dataset(
           dataloader=dataloader,
@@ -505,7 +504,7 @@ class Dataset(Lister, BaseClient):
 
     if failed_input_ids:
       #failed_inputs= ([input["Input_ID"] for input in failed_input_ids])
-      logging.warning(f"Retrying upload for {len(failed_input_ids)} failed inputs...\n")
+      logger.warning(f"Retrying upload for {len(failed_input_ids)} failed inputs...\n")
       failed_input_indexes = [input["Index"] for input in failed_input_ids]
       self.upload_dataset(
           dataloader=dataloader, log_retry_ids=failed_input_indexes, is_log_retry=True, **kwargs)
clarifai/client/model_client.py
CHANGED
@@ -59,12 +59,41 @@ class ModelClient:
     Returns:
       Dict: The method signatures.
     '''
-
-
-
-
-
-
+    try:
+      response = self.client.STUB.GetModelVersion(
+          service_pb2.GetModelVersionRequest(
+              user_app_id=self.request_template.user_app_id,
+              model_id= self.request_template.model_id,
+              version_id=self.request_template.model_version.id,
+          ))
+      method_signatures = None
+      if response.status.code == status_code_pb2.SUCCESS:
+        method_signatures = response.model_version.method_signatures
+      if response.status.code != status_code_pb2.SUCCESS:
+        raise Exception(f"Model failed with response {response!r}")
+      self._method_signatures= {}
+      for method_signature in method_signatures:
+        method_name = method_signature.name
+        # check for duplicate method names
+        if method_name in self._method_signatures:
+          raise ValueError(f"Duplicate method name {method_name}")
+        self._method_signatures[method_name] = method_signature
+      if not self._method_signatures:  # if no method signatures, try to fetch from the model
+        self._fetch_signatures_backup()
+    except Exception as e:
+      logger.info(f"Failed to fetch method signatures from model: {e}")
+      # try to fetch from the model
+      self._fetch_signatures_backup()
+    if not self._method_signatures:
+      raise ValueError("Failed to fetch method signatures from model and backup method")
+
+  def _fetch_signatures_backup(self):
+    '''
+    This is a temporary method of fetching the method signatures from the model.
+
+    Returns:
+      Dict: The method signatures.
+    '''
 
     request = service_pb2.PostModelOutputsRequest()
     request.CopyFrom(self.request_template)
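_fetch_signatures now asks the platform for the signatures stored on the model version itself via GetModelVersion, and only falls back to _fetch_signatures_backup (the old probe, whose body is unchanged below the new method) when that call fails or returns nothing. A standalone sketch of the primary path against the public API; the user/app/model/version IDs are placeholders and CLARIFAI_PAT is assumed to be set:

# Sketch: fetch method signatures straight off a model version.
import os

from clarifai_grpc.channel.clarifai_channel import ClarifaiChannel
from clarifai_grpc.grpc.api import resources_pb2, service_pb2, service_pb2_grpc
from clarifai_grpc.grpc.api.status import status_code_pb2

stub = service_pb2_grpc.V2Stub(ClarifaiChannel.get_grpc_channel())
metadata = (("authorization", "Key " + os.environ["CLARIFAI_PAT"]),)

response = stub.GetModelVersion(
    service_pb2.GetModelVersionRequest(
        user_app_id=resources_pb2.UserAppIDSet(user_id="me", app_id="my-app"),
        model_id="my-model",
        version_id="my-version-id",
    ),
    metadata=metadata)

if response.status.code != status_code_pb2.SUCCESS:
    raise Exception(f"Model failed with response {response!r}")
# Empty when the model publishes no signatures; the client then uses the backup path.
for sig in response.model_version.method_signatures:
    print(sig.name)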
clarifai/runners/models/model_builder.py
CHANGED
@@ -192,17 +192,21 @@ class ModelBuilder:
   def _validate_config_checkpoints(self):
     """
     Validates the checkpoints section in the config file.
+    return loader_type, repo_id, hf_token, when, allowed_file_patterns, ignore_file_patterns
     :return: loader_type the type of loader or None if no checkpoints.
     :return: repo_id location of checkpoint.
     :return: hf_token token to access checkpoint.
+    :return: when one of ['upload', 'build', 'runtime'] to download checkpoint
+    :return: allowed_file_patterns patterns to allow in downloaded checkpoint
+    :return: ignore_file_patterns patterns to ignore in downloaded checkpoint
     """
     if "checkpoints" not in self.config:
-      return None, None, None, DEFAULT_DOWNLOAD_CHECKPOINT_WHEN
+      return None, None, None, DEFAULT_DOWNLOAD_CHECKPOINT_WHEN, None, None
     assert "type" in self.config.get("checkpoints"), "No loader type specified in the config file"
     loader_type = self.config.get("checkpoints").get("type")
     if not loader_type:
       logger.info("No loader type specified in the config file for checkpoints")
-      return None, None, None
+      return None, None, None, DEFAULT_DOWNLOAD_CHECKPOINT_WHEN, None, None
     checkpoints = self.config.get("checkpoints")
     if 'when' not in checkpoints:
       logger.warn(
@@ -221,15 +225,30 @@ class ModelBuilder:
 
     # get from config.yaml otherwise fall back to HF_TOKEN env var.
     hf_token = self.config.get("checkpoints").get("hf_token", os.environ.get("HF_TOKEN", None))
-    return loader_type, repo_id, hf_token, when
+
+    allowed_file_patterns = self.config.get("checkpoints").get('allowed_file_patterns', None)
+    if isinstance(allowed_file_patterns, str):
+      allowed_file_patterns = [allowed_file_patterns]
+    ignore_file_patterns = self.config.get("checkpoints").get('ignore_file_patterns', None)
+    if isinstance(ignore_file_patterns, str):
+      ignore_file_patterns = [ignore_file_patterns]
+    return loader_type, repo_id, hf_token, when, allowed_file_patterns, ignore_file_patterns
 
   def _check_app_exists(self):
     resp = self.client.STUB.GetApp(service_pb2.GetAppRequest(user_app_id=self.client.user_app_id))
     if resp.status.code == status_code_pb2.SUCCESS:
       return True
+    if resp.status.code == status_code_pb2.CONN_KEY_INVALID:
+      logger.error(
+          f"Invalid PAT provided for user {self.client.user_app_id.user_id}. Please check your PAT and try again."
+      )
+      return False
     logger.error(
         f"Error checking API {self._base_api} for user app {self.client.user_app_id.user_id}/{self.client.user_app_id.app_id}. Error code: {resp.status.code}"
     )
+    logger.error(
+        f"App {self.client.user_app_id.app_id} not found for user {self.client.user_app_id.user_id}. Please create the app first and try again."
+    )
     return False
 
   def _validate_config_model(self):
@@ -250,9 +269,6 @@ class ModelBuilder:
     assert model.get('id') != "", "model_id cannot be empty in the config file"
 
     if not self._check_app_exists():
-      logger.error(
-          f"App {self.client.user_app_id.app_id} not found for user {self.client.user_app_id.user_id}"
-      )
       sys.exit(1)
 
   def _validate_config(self):
@@ -266,7 +282,7 @@ class ModelBuilder:
       assert model_type_id in CONCEPTS_REQUIRED_MODEL_TYPE, f"Model type {model_type_id} not supported for concepts"
 
     if self.config.get("checkpoints"):
-      loader_type, _, hf_token, _ = self._validate_config_checkpoints()
+      loader_type, _, hf_token, _, _, _ = self._validate_config_checkpoints()
 
       if loader_type == "huggingface" and hf_token:
         is_valid_token = HuggingFaceLoader.validate_hftoken(hf_token)
@@ -282,7 +298,7 @@ class ModelBuilder:
           f"`num_threads` must be an integer greater than or equal to 1. Received type {type(num_threads)} with value {num_threads}."
       )
     else:
-      num_threads = int(os.environ.get("CLARIFAI_NUM_THREADS",
+      num_threads = int(os.environ.get("CLARIFAI_NUM_THREADS", 16))
     self.config["num_threads"] = num_threads
 
   @staticmethod
@@ -354,8 +370,9 @@ class ModelBuilder:
 
     assert "model_type_id" in model, "model_type_id not found in the config file"
     assert "id" in model, "model_id not found in the config file"
-    assert "user_id" in model, "user_id not found in the config file"
-    assert "app_id" in model, "app_id not found in the config file"
+    if not self.download_validation_only:
+      assert "user_id" in model, "user_id not found in the config file"
+      assert "app_id" in model, "app_id not found in the config file"
 
     model_proto = json_format.ParseDict(model, resources_pb2.Model())
 
@@ -466,11 +483,12 @@ class ModelBuilder:
     # Sort in reverse so that newer cuda versions come first and are preferred.
     for image in sorted(AVAILABLE_TORCH_IMAGES, reverse=True):
       if torch_version in image and f'py{python_version}' in image:
-
+        # like cu124, rocm6.3, etc.
+        gpu_version = image.split('-')[-1]
         final_image = TORCH_BASE_IMAGE.format(
             torch_version=torch_version,
             python_version=python_version,
-
+            gpu_version=gpu_version,
         )
         logger.info(f"Using Torch version {torch_version} base image to build the Docker image")
         break
@@ -547,8 +565,10 @@ class ModelBuilder:
     if not self.config.get("checkpoints"):
       logger.info("No checkpoints specified in the config file")
       return path
+    clarifai_model_type_id = self.config.get('model').get('model_type_id')
 
-    loader_type, repo_id, hf_token, when = self._validate_config_checkpoints()
+    loader_type, repo_id, hf_token, when, allowed_file_patterns, ignore_file_patterns = self._validate_config_checkpoints(
+    )
     if stage not in ["build", "upload", "runtime"]:
       raise Exception("Invalid stage provided, must be one of ['build', 'upload', 'runtime']")
     if when != stage:
@@ -557,14 +577,18 @@ class ModelBuilder:
       )
       return path
 
-    success =
+    success = False
     if loader_type == "huggingface":
-      loader = HuggingFaceLoader(repo_id=repo_id, token=hf_token)
+      loader = HuggingFaceLoader(
+          repo_id=repo_id, token=hf_token, model_type_id=clarifai_model_type_id)
       # for runtime default to /tmp path
       if stage == "runtime" and checkpoint_path_override is None:
         checkpoint_path_override = self.default_runtime_checkpoint_path()
       path = checkpoint_path_override if checkpoint_path_override else self.checkpoint_path
-      success = loader.download_checkpoints(path)
+      success = loader.download_checkpoints(
+          path,
+          allowed_file_patterns=allowed_file_patterns,
+          ignore_file_patterns=ignore_file_patterns)
 
     if loader_type:
       if not success:
@@ -598,53 +622,12 @@ class ModelBuilder:
     concepts = config.get('concepts')
     logger.info(f"Updated config.yaml with {len(concepts)} concepts.")
 
-  def filled_params_specs_with_inference_params(self, method_signatures: list[resources_pb2.MethodSignature]) -> list[resources_pb2.ModelTypeField]:
-    """
-    Fills the params_specs with the inference params.
-    """
-    inference_params = set()
-    for i, signature in enumerate(method_signatures):
-      for field in signature.input_fields:
-        if field.is_param:
-          if i==0:
-            inference_params.add(field.name)
-          else:
-            # if field.name not in inference_params then remove from inference_params
-            if field.name not in inference_params:
-              inference_params.remove(field.name)
-    output=[]
-    for signature in method_signatures:
-      for field in signature.input_fields:
-        if field.is_param and field.name in inference_params:
-          field.path = field.name
-          if field.type == resources_pb2.ModelTypeField.DataType.STR:
-            field.default_value= str(field.default)
-            field.field_type = resources_pb2.ModelTypeField.ModelTypeFieldType.STRING
-          elif field.type == resources_pb2.ModelTypeField.DataType.INT:
-            field.default_value= int(field.default)
-            field.field_type = resources_pb2.ModelTypeField.ModelTypeFieldType.NUMBER
-          elif field.type == resources_pb2.ModelTypeField.DataType.FLOAT:
-            field.default_value= float(field.default)
-            field.field_type = resources_pb2.ModelTypeField.ModelTypeFieldType.NUMBER
-          elif field.type == resources_pb2.ModelTypeField.DataType.BOOL:
-            field.default_value= bool(field.default)
-            field.field_type = resources_pb2.ModelTypeField.ModelTypeFieldType.BOOLEAN
-          else:
-            field.default_value= field.default
-            field.field_type = resources_pb2.ModelTypeField.ModelTypeFieldType.STRING
-          output.append(field)
-    return output
-
-
   def get_model_version_proto(self):
     signatures = self.get_method_signatures()
     model_version_proto = resources_pb2.ModelVersion(
         pretrained_model_config=resources_pb2.PretrainedModelConfig(),
         inference_compute_info=self.inference_compute_info,
         method_signatures=signatures,
-        # output_info= resources_pb2.OutputInfo(
-        #   params_specs=self.filled_params_specs_with_inference_params(signatures),
-        # )
     )
 
     model_type_id = self.config.get('model').get('model_type_id')
@@ -678,7 +661,7 @@ class ModelBuilder:
     logger.debug(f"Will tar it into file: {file_path}")
 
     model_type_id = self.config.get('model').get('model_type_id')
-    loader_type, repo_id, hf_token, when = self._validate_config_checkpoints()
+    loader_type, repo_id, hf_token, when, _, _ = self._validate_config_checkpoints()
 
     if (model_type_id in CONCEPTS_REQUIRED_MODEL_TYPE) and 'concepts' not in self.config:
       logger.info(
@@ -726,7 +709,7 @@ class ModelBuilder:
     # First check for the env variable, then try querying huggingface. If all else fails, use the default.
     checkpoint_size = os.environ.get('CHECKPOINT_SIZE_BYTES', 0)
     if not checkpoint_size:
-      _, repo_id, _, _ = self._validate_config_checkpoints()
+      _, repo_id, _, _, _, _ = self._validate_config_checkpoints()
       checkpoint_size = HuggingFaceLoader.get_huggingface_checkpoint_total_size(repo_id)
     if not checkpoint_size:
       checkpoint_size = self.DEFAULT_CHECKPOINT_SIZE
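_validate_config_checkpoints now returns a 6-tuple, with allowed_file_patterns and ignore_file_patterns read from the checkpoints section and normalized so a bare string behaves like a one-element list. A minimal standalone sketch of that contract (a plain-dict stand-in, not the real class; the repo id is a placeholder):

# Sketch of the new checkpoints contract and its pattern normalization.
DEFAULT_DOWNLOAD_CHECKPOINT_WHEN = "runtime"  # assumed default for the sketch

def validate_checkpoints(config: dict):
  ckpt = config.get("checkpoints")
  if not ckpt or not ckpt.get("type"):
    return None, None, None, DEFAULT_DOWNLOAD_CHECKPOINT_WHEN, None, None
  allowed = ckpt.get("allowed_file_patterns")
  if isinstance(allowed, str):  # "*.safetensors" -> ["*.safetensors"]
    allowed = [allowed]
  ignore = ckpt.get("ignore_file_patterns")
  if isinstance(ignore, str):
    ignore = [ignore]
  return (ckpt["type"], ckpt.get("repo_id"), ckpt.get("hf_token"),
          ckpt.get("when", DEFAULT_DOWNLOAD_CHECKPOINT_WHEN), allowed, ignore)

cfg = {"checkpoints": {"type": "huggingface", "repo_id": "org/some-model",
                       "when": "runtime", "allowed_file_patterns": "*.safetensors"}}
print(validate_checkpoints(cfg))
# ('huggingface', 'org/some-model', None, 'runtime', ['*.safetensors'], None)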
clarifai/runners/models/model_class.py
CHANGED
@@ -30,7 +30,8 @@ class ModelClass(ABC):
   Example:
 
     from clarifai.runners.model_class import ModelClass
-    from clarifai.runners.utils.data_types import NamedFields
+    from clarifai.runners.utils.data_types import NamedFields
+    from typing import List, Iterator
 
     class MyModel(ModelClass):
 
@@ -39,12 +40,12 @@ class ModelClass(ABC):
         return [x] * y
 
       @ModelClass.method
-      def generate(self, x: str, y: int) ->
+      def generate(self, x: str, y: int) -> Iterator[str]:
         for i in range(y):
           yield x + str(i)
 
       @ModelClass.method
-      def stream(self, input_stream:
+      def stream(self, input_stream: Iterator[NamedFields(x=str, y=int)]) -> Iterator[str]:
         for item in input_stream:
           yield item.x + ' ' + str(item.y)
   '''
@@ -107,7 +108,8 @@ class ModelClass(ABC):
       is_convert = DataConverter.is_old_format(input.data)
       if is_convert:
         # convert to new format
-        new_data = DataConverter.convert_input_data_to_new_format(input.data, signature.input_fields)
+        new_data = DataConverter.convert_input_data_to_new_format(input.data,
+                                                                  signature.input_fields)
         input.data.CopyFrom(new_data)
     # convert inputs to python types
     inputs = self._convert_input_protos_to_python(request.inputs, inference_params,
@@ -153,7 +155,8 @@ class ModelClass(ABC):
       is_convert = DataConverter.is_old_format(input.data)
       if is_convert:
         # convert to new format
-        new_data = DataConverter.convert_input_data_to_new_format(input.data, signature.input_fields)
+        new_data = DataConverter.convert_input_data_to_new_format(input.data,
+                                                                  signature.input_fields)
         input.data.CopyFrom(new_data)
     inputs = self._convert_input_protos_to_python(request.inputs, inference_params,
                                                   signature.input_fields, python_param_types)
@@ -214,7 +217,8 @@ class ModelClass(ABC):
       is_convert = DataConverter.is_old_format(input.data)
       if is_convert:
         # convert to new format
-        new_data = DataConverter.convert_input_data_to_new_format(input.data, signature.input_fields)
+        new_data = DataConverter.convert_input_data_to_new_format(input.data,
+                                                                  signature.input_fields)
         input.data.CopyFrom(new_data)
     # convert all inputs for the first request, including the first stream value
     inputs = self._convert_input_protos_to_python(request.inputs, inference_params,
@@ -267,8 +271,8 @@ class ModelClass(ABC):
       if k not in python_param_types:
         continue
 
-      if hasattr(python_param_types[k], "__args__") and getattr(
-          python_param_types[k], "__origin__", None) == Iterator:
+      if hasattr(python_param_types[k], "__args__") and getattr(python_param_types[k],
+                                                                "__origin__", None) == Iterator:
         # get the type of the items in the stream
         stream_type = python_param_types[k].__args__[0]
 
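The rewrapped condition in the last hunk is the stream-parameter test: an annotation like Iterator[str] exposes its item type through __args__ and its generic origin through __origin__, while a scalar annotation has neither. An equivalent standalone check written with typing.get_origin/get_args (a sketch, not the library's code):

# Sketch: telling a streaming parameter apart from a scalar one.
import collections.abc
from typing import Iterator, get_args, get_origin

def is_stream_param(annotation) -> bool:
  # Iterator[str] has origin collections.abc.Iterator; plain str has no origin.
  return get_origin(annotation) is collections.abc.Iterator

assert is_stream_param(Iterator[str])
assert not is_stream_param(str)
print(get_args(Iterator[str])[0])  # <class 'str'>: the per-item stream type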
clarifai/runners/models/model_run_locally.py
CHANGED
@@ -11,7 +11,6 @@ import traceback
 import venv
 
 from clarifai_grpc.grpc.api import resources_pb2, service_pb2
-
 from clarifai.runners.models.model_builder import ModelBuilder
 from clarifai.utils.logging import logger
 
@@ -104,11 +103,6 @@ class ModelRunLocally:
         ],
     )
 
-  def _build_stream_request(self):
-    request = self._build_request()
-    for i in range(1):
-      yield request
-
  def _run_test(self):
    """Test the model locally by making a prediction."""
    # Create the model
@@ -404,39 +398,33 @@ def main(model_path,
          inside_container=False,
          port=8080,
          keep_env=False,
-         keep_image=False):
+         keep_image=False,
+         skip_dockerfile: bool = False):
 
-  if not os.environ.get("CLARIFAI_PAT", None):
-    logger.error(
-        "CLARIFAI_PAT environment variable is not set! Please set your PAT in the 'CLARIFAI_PAT' environment variable."
-    )
-    sys.exit(1)
   manager = ModelRunLocally(model_path)
   # get whatever stage is in config.yaml to force download now
   # also always write to where upload/build wants to, not the /tmp folder that runtime stage uses
-  _, _, _, when = manager.builder._validate_config_checkpoints()
+  _, _, _, when, _, _ = manager.builder._validate_config_checkpoints()
   manager.builder.download_checkpoints(
      stage=when, checkpoint_path_override=manager.builder.checkpoint_path)
   if inside_container:
     if not manager.is_docker_installed():
       sys.exit(1)
-    manager.builder.create_dockerfile()
+    if not skip_dockerfile:
+      manager.builder.create_dockerfile()
     image_tag = manager._docker_hash()
-    image_name = f"{manager.config['model']['id']}:{image_tag}"
-    container_name = manager.config['model']['id']
+    model_id = manager.config['model']['id'].lower()
+    # must be in lowercase
+    image_name = f"{model_id}:{image_tag}"
+    container_name = model_id
     if not manager.docker_image_exists(image_name):
       manager.build_docker_image(image_name=image_name)
     try:
-      envs = {
-          'CLARIFAI_PAT': os.environ['CLARIFAI_PAT'],
-          'CLARIFAI_API_BASE': os.environ.get('CLARIFAI_API_BASE', 'https://api.clarifai.com')
-      }
       if run_model_server:
         manager.run_docker_container(
-            image_name=image_name, container_name=container_name, port=port, env_vars=envs)
+            image_name=image_name, container_name=container_name, port=port)
       else:
-        manager.test_model_container(
-            image_name=image_name, container_name=container_name, env_vars=envs)
+        manager.test_model_container(image_name=image_name, container_name=container_name)
     finally:
       if manager.container_exists(container_name):
         manager.stop_docker_container(container_name)
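Callers of main() gain a skip_dockerfile switch and lose the env-var plumbing: the CLARIFAI_PAT precheck and the envs dict are gone, and the container/image name is derived from the lowercased model id. A hedged sketch of invoking the new flow; the model directory is a placeholder and the keyword names are taken from the parameters visible in this hunk:

# Sketch: local test flow with the new flag (main is an internal helper).
from clarifai.runners.models.model_run_locally import main

main("path/to/model_dir",      # hypothetical model directory
     run_model_server=False,   # exercise test_model_container instead of serving
     inside_container=True,
     skip_dockerfile=True)     # new: reuse an existing Dockerfile as-is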
clarifai/runners/utils/const.py
CHANGED
@@ -2,10 +2,10 @@ import os
 
 registry = os.environ.get('CLARIFAI_BASE_IMAGE_REGISTRY', 'public.ecr.aws/clarifai-models')
 
-GIT_SHA = "
+GIT_SHA = "b8ae56bf3b7c95e686ca002b07ca83d259c716eb"
 
 PYTHON_BASE_IMAGE = registry + '/python-base:{python_version}-' + GIT_SHA
-TORCH_BASE_IMAGE = registry + '/torch:{torch_version}-py{python_version}-
+TORCH_BASE_IMAGE = registry + '/torch:{torch_version}-py{python_version}-{gpu_version}-' + GIT_SHA
 
 # List of available python base images
 AVAILABLE_PYTHON_IMAGES = ['3.11', '3.12']
@@ -21,12 +21,13 @@ DEFAULT_RUNTIME_DOWNLOAD_PATH = os.path.join(os.sep, "tmp", ".cache")
 # List of available torch images
 # Keep sorted by most recent cuda version.
 AVAILABLE_TORCH_IMAGES = [
-    '2.4.1-py3.11-
-    '2.5.1-py3.11-
-    '2.4.1-py3.12-
-    '2.5.1-py3.12-
-
-
+    '2.4.1-py3.11-cu124',
+    '2.5.1-py3.11-cu124',
+    '2.4.1-py3.12-cu124',
+    '2.5.1-py3.12-cu124',
+    '2.6.0-py3.12-cu126',
+    '2.7.0-py3.12-cu128',
+    '2.7.0-py3.12-rocm6.3',
 ]
 CONCEPTS_REQUIRED_MODEL_TYPE = [
     'visual-classifier', 'visual-detector', 'visual-segmenter', 'text-classifier'