clarifai 10.8.5__py3-none-any.whl → 10.8.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clarifai/__init__.py +1 -1
- clarifai/client/app.py +2 -3
- clarifai/client/auth/helper.py +9 -11
- clarifai/client/base.py +4 -3
- clarifai/client/compute_cluster.py +196 -0
- clarifai/client/dataset.py +2 -2
- clarifai/client/deployment.py +51 -0
- clarifai/client/input.py +2 -2
- clarifai/client/model.py +2 -2
- clarifai/client/module.py +2 -2
- clarifai/client/nodepool.py +207 -0
- clarifai/client/user.py +133 -2
- clarifai/client/workflow.py +2 -2
- clarifai/constants/base.py +1 -0
- clarifai/datasets/export/inputs_annotations.py +1 -3
- clarifai/rag/rag.py +2 -2
- clarifai/runners/models/model_upload.py +89 -47
- clarifai/runners/server.py +1 -1
- clarifai/runners/utils/loader.py +13 -7
- clarifai/runners/utils/url_fetcher.py +1 -1
- clarifai/utils/evaluation/helpers.py +1 -2
- clarifai/utils/evaluation/main.py +1 -2
- clarifai/utils/logging.py +6 -0
- {clarifai-10.8.5.dist-info → clarifai-10.8.7.dist-info}/METADATA +1 -1
- {clarifai-10.8.5.dist-info → clarifai-10.8.7.dist-info}/RECORD +29 -26
- clarifai/runners/utils/logging.py +0 -6
- {clarifai-10.8.5.dist-info → clarifai-10.8.7.dist-info}/LICENSE +0 -0
- {clarifai-10.8.5.dist-info → clarifai-10.8.7.dist-info}/WHEEL +0 -0
- {clarifai-10.8.5.dist-info → clarifai-10.8.7.dist-info}/entry_points.txt +0 -0
- {clarifai-10.8.5.dist-info → clarifai-10.8.7.dist-info}/top_level.txt +0 -0
clarifai/client/user.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1
|
+
import os
|
1
2
|
from typing import Any, Dict, Generator, List
|
2
3
|
|
4
|
+
import yaml
|
3
5
|
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
|
4
6
|
from clarifai_grpc.grpc.api.status import status_code_pb2
|
5
7
|
from google.protobuf.json_format import MessageToDict
|
@@ -7,9 +9,10 @@ from google.protobuf.wrappers_pb2 import BoolValue
|
|
7
9
|
|
8
10
|
from clarifai.client.app import App
|
9
11
|
from clarifai.client.base import BaseClient
|
12
|
+
from clarifai.client.compute_cluster import ComputeCluster
|
10
13
|
from clarifai.client.lister import Lister
|
11
14
|
from clarifai.errors import UserError
|
12
|
-
from clarifai.utils.logging import
|
15
|
+
from clarifai.utils.logging import logger
|
13
16
|
|
14
17
|
|
15
18
|
class User(Lister, BaseClient):
|
@@ -34,7 +37,7 @@ class User(Lister, BaseClient):
|
|
34
37
|
"""
|
35
38
|
self.kwargs = {**kwargs, 'id': user_id}
|
36
39
|
self.user_info = resources_pb2.User(**self.kwargs)
|
37
|
-
self.logger =
|
40
|
+
self.logger = logger
|
38
41
|
BaseClient.__init__(
|
39
42
|
self,
|
40
43
|
user_id=self.id,
|
@@ -109,6 +112,37 @@ class User(Lister, BaseClient):
|
|
109
112
|
for runner_info in all_runners_info:
|
110
113
|
yield dict(auth=self.auth_helper, check_runner_exists=False, **runner_info)
|
111
114
|
|
115
|
+
def list_compute_clusters(self, page_no: int = None,
|
116
|
+
per_page: int = None) -> Generator[dict, None, None]:
|
117
|
+
"""List all compute clusters for the user
|
118
|
+
|
119
|
+
Args:
|
120
|
+
page_no (int): The page number to list.
|
121
|
+
per_page (int): The number of items per page.
|
122
|
+
|
123
|
+
Yields:
|
124
|
+
Dict: Dictionaries containing information about the compute clusters.
|
125
|
+
|
126
|
+
Example:
|
127
|
+
>>> from clarifai.client.user import User
|
128
|
+
>>> client = User(user_id="user_id")
|
129
|
+
>>> all_compute_clusters= list(client.list_compute_clusters())
|
130
|
+
|
131
|
+
Note:
|
132
|
+
Defaults to 16 per page if page_no is specified and per_page is not specified.
|
133
|
+
If both page_no and per_page are None, then lists all the resources.
|
134
|
+
"""
|
135
|
+
request_data = dict(user_app_id=self.user_app_id)
|
136
|
+
all_compute_clusters_info = self.list_pages_generator(
|
137
|
+
self.STUB.ListComputeClusters,
|
138
|
+
service_pb2.ListComputeClustersRequest,
|
139
|
+
request_data,
|
140
|
+
per_page=per_page,
|
141
|
+
page_no=page_no)
|
142
|
+
|
143
|
+
for compute_cluster_info in all_compute_clusters_info:
|
144
|
+
yield ComputeCluster.from_auth_helper(self.auth_helper, **compute_cluster_info)
|
145
|
+
|
112
146
|
def create_app(self, app_id: str, base_workflow: str = 'Empty', **kwargs) -> App:
|
113
147
|
"""Creates an app for the user.
|
114
148
|
|
@@ -172,6 +206,59 @@ class User(Lister, BaseClient):
|
|
172
206
|
description=description,
|
173
207
|
check_runner_exists=False)
|
174
208
|
|
209
|
+
def _process_compute_cluster_config(self, config_filepath: str) -> Dict[str, Any]:
|
210
|
+
with open(config_filepath, "r") as file:
|
211
|
+
compute_cluster_config = yaml.safe_load(file)
|
212
|
+
|
213
|
+
assert "compute_cluster" in compute_cluster_config, "compute cluster info not found in the config file"
|
214
|
+
compute_cluster = compute_cluster_config['compute_cluster']
|
215
|
+
assert "region" in compute_cluster, "region not found in the config file"
|
216
|
+
assert "managed_by" in compute_cluster, "managed_by not found in the config file"
|
217
|
+
assert "cluster_type" in compute_cluster, "cluster_type not found in the config file"
|
218
|
+
compute_cluster['cloud_provider'] = resources_pb2.CloudProvider(
|
219
|
+
**compute_cluster['cloud_provider'])
|
220
|
+
compute_cluster['key'] = resources_pb2.Key(id=self.pat)
|
221
|
+
if "visibility" in compute_cluster:
|
222
|
+
compute_cluster["visibility"] = resources_pb2.Visibility(**compute_cluster["visibility"])
|
223
|
+
return compute_cluster
|
224
|
+
|
225
|
+
def create_compute_cluster(self, compute_cluster_id: str,
|
226
|
+
config_filepath: str) -> ComputeCluster:
|
227
|
+
"""Creates a compute cluster for the user.
|
228
|
+
|
229
|
+
Args:
|
230
|
+
compute_cluster_id (str): The compute cluster ID for the compute cluster to create.
|
231
|
+
config_filepath (str): The path to the compute cluster config file.
|
232
|
+
|
233
|
+
Returns:
|
234
|
+
ComputeCluster: A Compute Cluster object for the specified compute cluster ID.
|
235
|
+
|
236
|
+
Example:
|
237
|
+
>>> from clarifai.client.user import User
|
238
|
+
>>> client = User(user_id="user_id")
|
239
|
+
>>> compute_cluster = client.create_compute_cluster(compute_cluster_id="compute_cluster_id", config_filepath="config.yml")
|
240
|
+
"""
|
241
|
+
if not os.path.exists(config_filepath):
|
242
|
+
raise UserError(f"Compute Cluster config file not found at {config_filepath}")
|
243
|
+
|
244
|
+
compute_cluster_config = self._process_compute_cluster_config(config_filepath)
|
245
|
+
|
246
|
+
if 'id' in compute_cluster_config:
|
247
|
+
compute_cluster_id = compute_cluster_config['id']
|
248
|
+
compute_cluster_config.pop('id')
|
249
|
+
|
250
|
+
request = service_pb2.PostComputeClustersRequest(
|
251
|
+
user_app_id=self.user_app_id,
|
252
|
+
compute_clusters=[
|
253
|
+
resources_pb2.ComputeCluster(id=compute_cluster_id, **compute_cluster_config)
|
254
|
+
])
|
255
|
+
response = self._grpc_request(self.STUB.PostComputeClusters, request)
|
256
|
+
if response.status.code != status_code_pb2.SUCCESS:
|
257
|
+
raise Exception(response.status)
|
258
|
+
self.logger.info("\nCompute Cluster created\n%s", response.status)
|
259
|
+
return ComputeCluster.from_auth_helper(
|
260
|
+
auth=self.auth_helper, compute_cluster_id=compute_cluster_id)
|
261
|
+
|
175
262
|
def app(self, app_id: str, **kwargs) -> App:
|
176
263
|
"""Returns an App object for the specified app ID.
|
177
264
|
|
@@ -223,6 +310,31 @@ class User(Lister, BaseClient):
|
|
223
310
|
|
224
311
|
return dict(self.auth_helper, check_runner_exists=False, **kwargs)
|
225
312
|
|
313
|
+
def compute_cluster(self, compute_cluster_id: str) -> ComputeCluster:
|
314
|
+
"""Returns an Compute Cluster object for the specified compute cluster ID.
|
315
|
+
|
316
|
+
Args:
|
317
|
+
compute_cluster_id (str): The compute cluster ID for the compute cluster to interact with.
|
318
|
+
|
319
|
+
Returns:
|
320
|
+
ComputeCluster: A Compute Cluster object for the specified compute cluster ID.
|
321
|
+
|
322
|
+
Example:
|
323
|
+
>>> from clarifai.client.user import User
|
324
|
+
>>> compute_cluster = User("user_id").compute_cluster("compute_cluster_id")
|
325
|
+
"""
|
326
|
+
request = service_pb2.GetComputeClusterRequest(
|
327
|
+
user_app_id=self.user_app_id, compute_cluster_id=compute_cluster_id)
|
328
|
+
response = self._grpc_request(self.STUB.GetComputeCluster, request)
|
329
|
+
if response.status.code != status_code_pb2.SUCCESS:
|
330
|
+
raise Exception(response.status)
|
331
|
+
|
332
|
+
dict_response = MessageToDict(response, preserving_proto_field_name=True)
|
333
|
+
kwargs = self.process_response_keys(dict_response[list(dict_response.keys())[1]],
|
334
|
+
list(dict_response.keys())[1])
|
335
|
+
|
336
|
+
return ComputeCluster.from_auth_helper(auth=self.auth_helper, **kwargs)
|
337
|
+
|
226
338
|
def patch_app(self, app_id: str, action: str = 'overwrite', **kwargs) -> App:
|
227
339
|
"""Patch an app for the user.
|
228
340
|
|
@@ -290,6 +402,25 @@ class User(Lister, BaseClient):
|
|
290
402
|
raise Exception(response.status)
|
291
403
|
self.logger.info("\nRunner Deleted\n%s", response.status)
|
292
404
|
|
405
|
+
def delete_compute_clusters(self, compute_cluster_ids: List[str]) -> None:
|
406
|
+
"""Deletes a list of compute clusters for the user.
|
407
|
+
|
408
|
+
Args:
|
409
|
+
compute_cluster_ids (List[str]): The compute cluster IDs of the user to delete.
|
410
|
+
|
411
|
+
Example:
|
412
|
+
>>> from clarifai.client.user import User
|
413
|
+
>>> user = User("user_id").delete_compute_clusters(compute_cluster_ids=["compute_cluster_id1", "compute_cluster_id2"])
|
414
|
+
"""
|
415
|
+
assert isinstance(compute_cluster_ids, list), "compute_cluster_ids param should be a list"
|
416
|
+
|
417
|
+
request = service_pb2.DeleteComputeClustersRequest(
|
418
|
+
user_app_id=self.user_app_id, ids=compute_cluster_ids)
|
419
|
+
response = self._grpc_request(self.STUB.DeleteComputeClusters, request)
|
420
|
+
if response.status.code != status_code_pb2.SUCCESS:
|
421
|
+
raise Exception(response.status)
|
422
|
+
self.logger.info("\nCompute Cluster Deleted\n%s", response.status)
|
423
|
+
|
293
424
|
def __getattr__(self, name):
|
294
425
|
return getattr(self.user_info, name)
|
295
426
|
|
clarifai/client/workflow.py
CHANGED
@@ -12,7 +12,7 @@ from clarifai.client.lister import Lister
|
|
12
12
|
from clarifai.constants.workflow import MAX_WORKFLOW_PREDICT_INPUTS
|
13
13
|
from clarifai.errors import UserError
|
14
14
|
from clarifai.urls.helper import ClarifaiUrlHelper
|
15
|
-
from clarifai.utils.logging import
|
15
|
+
from clarifai.utils.logging import logger
|
16
16
|
from clarifai.utils.misc import BackoffIterator
|
17
17
|
from clarifai.workflows.export import Exporter
|
18
18
|
|
@@ -59,7 +59,7 @@ class Workflow(Lister, BaseClient):
|
|
59
59
|
self.kwargs = {**kwargs, 'id': workflow_id, 'version': workflow_version}
|
60
60
|
self.output_config = output_config
|
61
61
|
self.workflow_info = resources_pb2.Workflow(**self.kwargs)
|
62
|
-
self.logger =
|
62
|
+
self.logger = logger
|
63
63
|
BaseClient.__init__(
|
64
64
|
self,
|
65
65
|
user_id=self.user_id,
|
@@ -0,0 +1 @@
|
|
1
|
+
COMPUTE_ORCHESTRATION_RESOURCES = ['Runner', 'ComputeCluster', 'Nodepool', 'Deployment']
|
@@ -14,9 +14,7 @@ from tqdm import tqdm
|
|
14
14
|
|
15
15
|
from clarifai.constants.dataset import CONTENT_TYPE
|
16
16
|
from clarifai.errors import UserError
|
17
|
-
from clarifai.utils.logging import
|
18
|
-
|
19
|
-
logger = get_logger("INFO", __name__)
|
17
|
+
from clarifai.utils.logging import logger
|
20
18
|
|
21
19
|
|
22
20
|
class DatasetExportReader:
|
clarifai/rag/rag.py
CHANGED
@@ -15,7 +15,7 @@ from clarifai.errors import UserError
|
|
15
15
|
from clarifai.rag.utils import (convert_messages_to_str, format_assistant_message, load_documents,
|
16
16
|
split_document)
|
17
17
|
from clarifai.utils.constants import CLARIFAI_USER_ID_ENV_VAR
|
18
|
-
from clarifai.utils.logging import
|
18
|
+
from clarifai.utils.logging import logger
|
19
19
|
from clarifai.utils.misc import get_from_dict_or_env
|
20
20
|
|
21
21
|
DEFAULT_RAG_PROMPT_TEMPLATE = "Context information is below:\n{data.hits}\nGiven the context information and not prior knowledge, answer the query.\nQuery: {data.text.raw}\nAnswer: "
|
@@ -40,7 +40,7 @@ class RAG:
|
|
40
40
|
**kwargs):
|
41
41
|
"""Initialize an empty or existing RAG.
|
42
42
|
"""
|
43
|
-
self.logger =
|
43
|
+
self.logger = logger
|
44
44
|
if workflow_url is not None and workflow is None:
|
45
45
|
self.logger.info("workflow_url:%s", workflow_url)
|
46
46
|
w = Workflow(workflow_url, base_url=base_url, pat=pat)
|
@@ -10,8 +10,8 @@ from google.protobuf import json_format
|
|
10
10
|
from rich import print
|
11
11
|
|
12
12
|
from clarifai.client import BaseClient
|
13
|
-
|
14
13
|
from clarifai.runners.utils.loader import HuggingFaceLoarder
|
14
|
+
from clarifai.utils.logging import logger
|
15
15
|
|
16
16
|
|
17
17
|
def _clear_line(n: int = 1) -> None:
|
@@ -28,12 +28,11 @@ class ModelUploader:
|
|
28
28
|
]
|
29
29
|
|
30
30
|
def __init__(self, folder: str):
|
31
|
+
self._client = None
|
31
32
|
self.folder = self._validate_folder(folder)
|
32
33
|
self.config = self._load_config(os.path.join(self.folder, 'config.yaml'))
|
33
|
-
self.initialize_client()
|
34
34
|
self.model_proto = self._get_model_proto()
|
35
35
|
self.model_id = self.model_proto.id
|
36
|
-
self.user_app_id = self.client.user_app_id
|
37
36
|
self.inference_compute_info = self._get_inference_compute_info()
|
38
37
|
self.is_v3 = True # Do model build for v3
|
39
38
|
|
@@ -41,7 +40,9 @@ class ModelUploader:
|
|
41
40
|
def _validate_folder(folder):
|
42
41
|
if not folder.startswith("/"):
|
43
42
|
folder = os.path.join(os.getcwd(), folder)
|
44
|
-
|
43
|
+
logger.info(f"Validating folder: {folder}")
|
44
|
+
if not os.path.exists(folder):
|
45
|
+
raise FileNotFoundError(f"Folder {folder} not found, please provide a valid folder path")
|
45
46
|
files = os.listdir(folder)
|
46
47
|
assert "requirements.txt" in files, "requirements.txt not found in the folder"
|
47
48
|
assert "config.yaml" in files, "config.yaml not found in the folder"
|
@@ -56,18 +57,21 @@ class ModelUploader:
|
|
56
57
|
config = yaml.safe_load(file)
|
57
58
|
return config
|
58
59
|
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
60
|
+
@property
|
61
|
+
def client(self):
|
62
|
+
if self._client is None:
|
63
|
+
assert "model" in self.config, "model info not found in the config file"
|
64
|
+
model = self.config.get('model')
|
65
|
+
assert "user_id" in model, "user_id not found in the config file"
|
66
|
+
assert "app_id" in model, "app_id not found in the config file"
|
67
|
+
user_id = model.get('user_id')
|
68
|
+
app_id = model.get('app_id')
|
66
69
|
|
67
|
-
|
70
|
+
base = os.environ.get('CLARIFAI_API_BASE', 'https://api-dev.clarifai.com')
|
68
71
|
|
69
|
-
|
70
|
-
|
72
|
+
self._client = BaseClient(user_id=user_id, app_id=app_id, base=base)
|
73
|
+
logger.info(f"Client initialized for user {user_id} and app {app_id}")
|
74
|
+
return self._client
|
71
75
|
|
72
76
|
def _get_model_proto(self):
|
73
77
|
assert "model" in self.config, "model info not found in the config file"
|
@@ -96,7 +100,7 @@ class ModelUploader:
|
|
96
100
|
service_pb2.GetModelRequest(
|
97
101
|
user_app_id=self.client.user_app_id, model_id=self.model_proto.id))
|
98
102
|
if resp.status.code == status_code_pb2.SUCCESS:
|
99
|
-
|
103
|
+
logger.info(
|
100
104
|
f"Model '{self.client.user_app_id.user_id}/{self.client.user_app_id.app_id}/models/{self.model_proto.id}' already exists, "
|
101
105
|
f"will create a new version for it.")
|
102
106
|
return resp
|
@@ -127,7 +131,15 @@ class ModelUploader:
|
|
127
131
|
|
128
132
|
# Get the Python version from the config file
|
129
133
|
build_info = self.config.get('build_info', {})
|
130
|
-
|
134
|
+
if 'python_version' in build_info:
|
135
|
+
python_version = build_info['python_version']
|
136
|
+
logger.info(
|
137
|
+
f"Using Python version {python_version} from the config file to build the Dockerfile")
|
138
|
+
else:
|
139
|
+
logger.info(
|
140
|
+
f"Python version not found in the config file, using default Python version: {self.DEFAULT_PYTHON_VERSION}"
|
141
|
+
)
|
142
|
+
python_version = self.DEFAULT_PYTHON_VERSION
|
131
143
|
|
132
144
|
# Replace placeholders with actual values
|
133
145
|
dockerfile_content = dockerfile_template.safe_substitute(
|
@@ -141,25 +153,33 @@ class ModelUploader:
|
|
141
153
|
|
142
154
|
def download_checkpoints(self):
|
143
155
|
if not self.config.get("checkpoints"):
|
144
|
-
|
156
|
+
logger.info("No checkpoints specified in the config file")
|
145
157
|
return
|
146
158
|
|
147
159
|
assert "type" in self.config.get("checkpoints"), "No loader type specified in the config file"
|
148
160
|
loader_type = self.config.get("checkpoints").get("type")
|
149
161
|
if not loader_type:
|
150
|
-
|
162
|
+
logger.info("No loader type specified in the config file for checkpoints")
|
151
163
|
assert loader_type == "huggingface", "Only huggingface loader supported for now"
|
152
164
|
if loader_type == "huggingface":
|
153
165
|
assert "repo_id" in self.config.get("checkpoints"), "No repo_id specified in the config file"
|
154
166
|
repo_id = self.config.get("checkpoints").get("repo_id")
|
155
167
|
|
156
|
-
|
168
|
+
# prefer env var for HF_TOKEN but if not provided then use the one from config.yaml if any.
|
169
|
+
if 'HF_TOKEN' in os.environ:
|
170
|
+
hf_token = os.environ['HF_TOKEN']
|
171
|
+
else:
|
172
|
+
hf_token = self.config.get("checkpoints").get("hf_token", None)
|
173
|
+
assert hf_token != 'hf_token', "The default 'hf_token' is not valid. Please provide a valid token or leave that field out of config.yaml if not needed."
|
157
174
|
loader = HuggingFaceLoarder(repo_id=repo_id, token=hf_token)
|
158
175
|
|
159
176
|
checkpoint_path = os.path.join(self.folder, '1', 'checkpoints')
|
160
|
-
loader.download_checkpoints(checkpoint_path)
|
177
|
+
success = loader.download_checkpoints(checkpoint_path)
|
161
178
|
|
162
|
-
|
179
|
+
if not success:
|
180
|
+
logger.error(f"Failed to download checkpoints for model {repo_id}")
|
181
|
+
return
|
182
|
+
logger.info(f"Downloaded checkpoints for model {repo_id}")
|
163
183
|
|
164
184
|
def _concepts_protos_from_concepts(self, concepts):
|
165
185
|
concept_protos = []
|
@@ -183,7 +203,7 @@ class ModelUploader:
|
|
183
203
|
with open(config_file, 'w') as file:
|
184
204
|
yaml.dump(config, file, sort_keys=False)
|
185
205
|
concepts = config.get('concepts')
|
186
|
-
|
206
|
+
logger.info(f"Updated config.yaml with {len(concepts)} concepts.")
|
187
207
|
|
188
208
|
def _get_model_version_proto(self):
|
189
209
|
|
@@ -209,11 +229,11 @@ class ModelUploader:
|
|
209
229
|
|
210
230
|
def upload_model_version(self):
|
211
231
|
file_path = f"{self.folder}.tar.gz"
|
212
|
-
|
232
|
+
logger.info(f"Will tar it into file: {file_path}")
|
213
233
|
|
214
234
|
# Tar the folder
|
215
235
|
os.system(f"tar --exclude=*~ -czvf {self.folder}.tar.gz -C {self.folder} .")
|
216
|
-
|
236
|
+
logger.info("Tarring complete, about to start upload.")
|
217
237
|
|
218
238
|
model_version = self._get_model_version_proto()
|
219
239
|
|
@@ -230,14 +250,15 @@ class ModelUploader:
|
|
230
250
|
print(
|
231
251
|
f"Status: {response.status.description}, "
|
232
252
|
f"Progress: {percent_completed}% - {details} ",
|
253
|
+
f"request_id: {response.status.req_id}",
|
233
254
|
end='\r',
|
234
255
|
flush=True)
|
235
256
|
print()
|
236
257
|
if response.status.code != status_code_pb2.MODEL_BUILDING:
|
237
|
-
|
258
|
+
logger.error(f"Failed to upload model version: {response.status.description}")
|
238
259
|
return
|
239
260
|
model_version_id = response.model_version_id
|
240
|
-
|
261
|
+
logger.info(f"Created Model Version ID: {model_version_id}")
|
241
262
|
|
242
263
|
self.monitor_model_build(model_version_id)
|
243
264
|
|
@@ -247,24 +268,35 @@ class ModelUploader:
|
|
247
268
|
file_size = os.path.getsize(file_path)
|
248
269
|
chunk_size = int(127 * 1024 * 1024) # 127MB chunk size
|
249
270
|
num_chunks = (file_size // chunk_size) + 1
|
250
|
-
|
271
|
+
logger.info("Uploading file...")
|
272
|
+
logger.info("File size: ", file_size)
|
273
|
+
logger.info("Chunk size: ", chunk_size)
|
274
|
+
logger.info("Number of chunks: ", num_chunks)
|
251
275
|
read_so_far = 0
|
252
276
|
for part_id in range(num_chunks):
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
277
|
+
try:
|
278
|
+
chunk_size = min(chunk_size, file_size - read_so_far)
|
279
|
+
chunk = f.read(chunk_size)
|
280
|
+
if not chunk:
|
281
|
+
break
|
282
|
+
read_so_far += len(chunk)
|
283
|
+
yield service_pb2.PostModelVersionsUploadRequest(
|
284
|
+
content_part=resources_pb2.UploadContentPart(
|
285
|
+
data=chunk,
|
286
|
+
part_number=part_id + 1,
|
287
|
+
range_start=read_so_far,
|
288
|
+
))
|
289
|
+
except Exception as e:
|
290
|
+
logger.exception(f"\nError uploading file: {e}")
|
291
|
+
break
|
292
|
+
|
293
|
+
if read_so_far == file_size:
|
294
|
+
logger.info("\nUpload complete!, waiting for model build...")
|
262
295
|
|
263
296
|
def init_upload_model_version(self, model_version, file_path):
|
264
297
|
file_size = os.path.getsize(file_path)
|
265
|
-
|
266
|
-
|
267
|
-
)
|
298
|
+
logger.info(f"Uploading model version '{model_version.id}' of model {self.model_proto.id}")
|
299
|
+
logger.info(f"Using file '{os.path.basename(file_path)}' of size: {file_size} bytes")
|
268
300
|
return service_pb2.PostModelVersionsUploadRequest(
|
269
301
|
upload_config=service_pb2.PostModelVersionsUploadConfig(
|
270
302
|
user_app_id=self.client.user_app_id,
|
@@ -285,22 +317,25 @@ class ModelUploader:
|
|
285
317
|
))
|
286
318
|
status_code = resp.model_version.status.code
|
287
319
|
if status_code == status_code_pb2.MODEL_BUILDING:
|
288
|
-
|
320
|
+
logger.info(
|
321
|
+
f"Model is building... (elapsed {time.time() - st:.1f}s)", end='\r', flush=True)
|
289
322
|
time.sleep(1)
|
290
323
|
elif status_code == status_code_pb2.MODEL_TRAINED:
|
291
|
-
|
292
|
-
|
293
|
-
f"Check out the model at https://clarifai.com/{self.user_app_id.user_id}/apps/{self.user_app_id.app_id}/models/{self.model_id}/versions/{model_version_id}"
|
324
|
+
logger.info("\nModel build complete!")
|
325
|
+
logger.info(
|
326
|
+
f"Check out the model at https://clarifai.com/{self.client.user_app_id.user_id}/apps/{self.client.user_app_id.app_id}/models/{self.model_id}/versions/{model_version_id}"
|
294
327
|
)
|
295
328
|
break
|
296
329
|
else:
|
297
|
-
|
330
|
+
logger.info(
|
331
|
+
f"\nModel build failed with status: {resp.model_version.status} and response {resp}")
|
298
332
|
break
|
299
333
|
|
300
334
|
|
301
|
-
def main(folder):
|
335
|
+
def main(folder, download_checkpoints):
|
302
336
|
uploader = ModelUploader(folder)
|
303
|
-
|
337
|
+
if download_checkpoints:
|
338
|
+
uploader.download_checkpoints()
|
304
339
|
uploader.create_dockerfile()
|
305
340
|
input("Press Enter to continue...")
|
306
341
|
uploader.upload_model_version()
|
@@ -310,6 +345,13 @@ if __name__ == "__main__":
|
|
310
345
|
parser = argparse.ArgumentParser()
|
311
346
|
parser.add_argument(
|
312
347
|
'--model_path', type=str, help='Path of the model folder to upload', required=True)
|
348
|
+
# flag to default to not download checkpoints
|
349
|
+
parser.add_argument(
|
350
|
+
'--download_checkpoints',
|
351
|
+
action='store_true',
|
352
|
+
help=
|
353
|
+
'Flag to download checkpoints before uploading and including them in the tar file that is uploaded. Defaults to False, which will attempt to download them at docker build time.',
|
354
|
+
)
|
313
355
|
args = parser.parse_args()
|
314
356
|
|
315
|
-
main(args.model_path)
|
357
|
+
main(args.model_path, args.download_checkpoints)
|
clarifai/runners/server.py
CHANGED
@@ -15,7 +15,7 @@ from clarifai_protocol import BaseRunner
|
|
15
15
|
from clarifai_protocol.utils.grpc_server import GRPCServer
|
16
16
|
|
17
17
|
from clarifai.runners.models.model_servicer import ModelServicer
|
18
|
-
from clarifai.
|
18
|
+
from clarifai.utils.logging import logger
|
19
19
|
|
20
20
|
|
21
21
|
def main():
|
clarifai/runners/utils/loader.py
CHANGED
@@ -3,6 +3,8 @@ import json
|
|
3
3
|
import os
|
4
4
|
import subprocess
|
5
5
|
|
6
|
+
from clarifai.utils.logging import logger
|
7
|
+
|
6
8
|
|
7
9
|
class HuggingFaceLoarder:
|
8
10
|
|
@@ -29,22 +31,24 @@ class HuggingFaceLoarder:
|
|
29
31
|
"The 'huggingface_hub' package is not installed. Please install it using 'pip install huggingface_hub'."
|
30
32
|
)
|
31
33
|
if os.path.exists(checkpoint_path) and self.validate_download(checkpoint_path):
|
32
|
-
|
34
|
+
logger.info("Checkpoints already exist")
|
35
|
+
return True
|
33
36
|
else:
|
34
37
|
os.makedirs(checkpoint_path, exist_ok=True)
|
35
38
|
try:
|
36
39
|
is_hf_model_exists = self.validate_hf_model()
|
37
40
|
if not is_hf_model_exists:
|
38
|
-
|
41
|
+
logger.error("Model %s not found on Hugging Face" % (self.repo_id))
|
39
42
|
return False
|
40
|
-
snapshot_download(
|
43
|
+
snapshot_download(
|
44
|
+
repo_id=self.repo_id, local_dir=checkpoint_path, local_dir_use_symlinks=False)
|
41
45
|
except Exception as e:
|
42
|
-
|
46
|
+
logger.exception(f"Error downloading model checkpoints {e}")
|
43
47
|
return False
|
44
48
|
finally:
|
45
49
|
is_downloaded = self.validate_download(checkpoint_path)
|
46
50
|
if not is_downloaded:
|
47
|
-
|
51
|
+
logger.error("Error validating downloaded model checkpoints")
|
48
52
|
return False
|
49
53
|
return True
|
50
54
|
|
@@ -57,8 +61,10 @@ class HuggingFaceLoarder:
|
|
57
61
|
def validate_download(self, checkpoint_path: str):
|
58
62
|
# check if model exists on HF
|
59
63
|
from huggingface_hub import list_repo_files
|
60
|
-
|
61
|
-
|
64
|
+
checkpoint_dir_files = [
|
65
|
+
f for dp, dn, fn in os.walk(os.path.expanduser(checkpoint_path)) for f in fn
|
66
|
+
]
|
67
|
+
return (len(checkpoint_dir_files) >= len(list_repo_files(self.repo_id))) and len(
|
62
68
|
list_repo_files(self.repo_id)) > 0
|
63
69
|
|
64
70
|
def fetch_labels(self, checkpoint_path: str):
|
clarifai/utils/logging.py
CHANGED
@@ -19,6 +19,8 @@ from rich.tree import Tree
|
|
19
19
|
|
20
20
|
install()
|
21
21
|
|
22
|
+
# The default logger to use throughout the SDK is defined at bottom of this file.
|
23
|
+
|
22
24
|
# For the json logger.
|
23
25
|
JSON_LOGGER_NAME = "clarifai-json"
|
24
26
|
JSON_LOG_KEY = 'msg'
|
@@ -357,3 +359,7 @@ class JsonFormatter(logging.Formatter):
|
|
357
359
|
default=self.json_default,
|
358
360
|
cls=self.json_cls,
|
359
361
|
)
|
362
|
+
|
363
|
+
|
364
|
+
# the default logger for the SDK.
|
365
|
+
logger = get_logger(logger_level=os.environ.get("LOG_LEVEL", "INFO"), name="clarifai")
|