clarifai 10.8.5__py3-none-any.whl → 10.8.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
clarifai/client/user.py CHANGED
@@ -1,5 +1,7 @@
1
+ import os
1
2
  from typing import Any, Dict, Generator, List
2
3
 
4
+ import yaml
3
5
  from clarifai_grpc.grpc.api import resources_pb2, service_pb2
4
6
  from clarifai_grpc.grpc.api.status import status_code_pb2
5
7
  from google.protobuf.json_format import MessageToDict
@@ -7,9 +9,10 @@ from google.protobuf.wrappers_pb2 import BoolValue
7
9
 
8
10
  from clarifai.client.app import App
9
11
  from clarifai.client.base import BaseClient
12
+ from clarifai.client.compute_cluster import ComputeCluster
10
13
  from clarifai.client.lister import Lister
11
14
  from clarifai.errors import UserError
12
- from clarifai.utils.logging import get_logger
15
+ from clarifai.utils.logging import logger
13
16
 
14
17
 
15
18
  class User(Lister, BaseClient):
@@ -34,7 +37,7 @@ class User(Lister, BaseClient):
34
37
  """
35
38
  self.kwargs = {**kwargs, 'id': user_id}
36
39
  self.user_info = resources_pb2.User(**self.kwargs)
37
- self.logger = get_logger(logger_level="INFO", name=__name__)
40
+ self.logger = logger
38
41
  BaseClient.__init__(
39
42
  self,
40
43
  user_id=self.id,
@@ -109,6 +112,37 @@ class User(Lister, BaseClient):
109
112
  for runner_info in all_runners_info:
110
113
  yield dict(auth=self.auth_helper, check_runner_exists=False, **runner_info)
111
114
 
115
def list_compute_clusters(self, page_no: int = None,
                          per_page: int = None) -> Generator[ComputeCluster, None, None]:
  """List all compute clusters for the user.

  Args:
    page_no (int): The page number to list.
    per_page (int): The number of items per page.

  Yields:
    ComputeCluster: A ComputeCluster client object for each compute cluster listed.

  Example:
    >>> from clarifai.client.user import User
    >>> client = User(user_id="user_id")
    >>> all_compute_clusters = list(client.list_compute_clusters())

  Note:
    Defaults to 16 per page if page_no is specified and per_page is not specified.
    If both page_no and per_page are None, then lists all the resources.
  """
  request_data = dict(user_app_id=self.user_app_id)
  all_compute_clusters_info = self.list_pages_generator(
      self.STUB.ListComputeClusters,
      service_pb2.ListComputeClustersRequest,
      request_data,
      per_page=per_page,
      page_no=page_no)

  for compute_cluster_info in all_compute_clusters_info:
    # Each listed item is unpacked as keyword arguments into a ComputeCluster built
    # from this user's auth_helper, so callers receive client objects, not dicts.
    yield ComputeCluster.from_auth_helper(self.auth_helper, **compute_cluster_info)
145
+
112
146
  def create_app(self, app_id: str, base_workflow: str = 'Empty', **kwargs) -> App:
113
147
  """Creates an app for the user.
114
148
 
@@ -172,6 +206,59 @@ class User(Lister, BaseClient):
172
206
  description=description,
173
207
  check_runner_exists=False)
174
208
 
209
def _process_compute_cluster_config(self, config_filepath: str) -> Dict[str, Any]:
  """Load a compute cluster YAML config and turn it into ComputeCluster proto kwargs.

  Args:
    config_filepath (str): Path to a YAML file with a top-level ``compute_cluster`` mapping.

  Returns:
    Dict[str, Any]: The ``compute_cluster`` mapping with ``cloud_provider`` (and
    ``visibility``, when present) converted to protobuf messages and ``key`` set to a
    ``resources_pb2.Key`` built from ``self.pat``.

  Raises:
    AssertionError: If the file does not hold a mapping, or a required field
      (``compute_cluster``, ``region``, ``managed_by``, ``cluster_type``,
      ``cloud_provider``) is missing.
  """
  with open(config_filepath, "r") as file:
    compute_cluster_config = yaml.safe_load(file)

  # safe_load returns None for an empty file; report that clearly instead of a TypeError
  # from the membership test below.
  assert isinstance(compute_cluster_config,
                    dict), "compute cluster config file must contain a mapping"
  assert "compute_cluster" in compute_cluster_config, "compute cluster info not found in the config file"
  compute_cluster = compute_cluster_config['compute_cluster']
  assert isinstance(compute_cluster, dict), "compute_cluster in the config file must be a mapping"
  assert "region" in compute_cluster, "region not found in the config file"
  assert "managed_by" in compute_cluster, "managed_by not found in the config file"
  assert "cluster_type" in compute_cluster, "cluster_type not found in the config file"
  # cloud_provider is dereferenced unconditionally below, so validate it like the other
  # required fields rather than failing with a bare KeyError.
  assert "cloud_provider" in compute_cluster, "cloud_provider not found in the config file"
  compute_cluster['cloud_provider'] = resources_pb2.CloudProvider(
      **compute_cluster['cloud_provider'])
  compute_cluster['key'] = resources_pb2.Key(id=self.pat)
  if "visibility" in compute_cluster:
    compute_cluster["visibility"] = resources_pb2.Visibility(**compute_cluster["visibility"])
  return compute_cluster
224
+
225
def create_compute_cluster(self, compute_cluster_id: str,
                           config_filepath: str) -> ComputeCluster:
  """Creates a compute cluster for the user.

  Args:
    compute_cluster_id (str): The compute cluster ID for the compute cluster to create.
      If the config file's ``compute_cluster`` section has an ``id`` key, that value is
      used instead, and a warning is logged when it differs from this argument.
    config_filepath (str): The path to the compute cluster YAML config file.

  Returns:
    ComputeCluster: A Compute Cluster object for the created compute cluster ID.

  Raises:
    UserError: If ``config_filepath`` does not point to an existing file.
    Exception: If the PostComputeClusters request does not return SUCCESS.

  Example:
    >>> from clarifai.client.user import User
    >>> client = User(user_id="user_id")
    >>> compute_cluster = client.create_compute_cluster(compute_cluster_id="compute_cluster_id", config_filepath="config.yml")
  """
  # isfile rather than exists: a directory path would otherwise pass this check and
  # fail later inside open() with a less helpful error.
  if not os.path.isfile(config_filepath):
    raise UserError(f"Compute Cluster config file not found at {config_filepath}")

  compute_cluster_config = self._process_compute_cluster_config(config_filepath)

  # An id in the config takes precedence over the argument (existing behavior); surface
  # the override instead of silently discarding the caller's value.
  if 'id' in compute_cluster_config:
    config_id = compute_cluster_config.pop('id')
    if config_id != compute_cluster_id:
      self.logger.warning(
          "Using compute cluster id '%s' from %s instead of compute_cluster_id argument '%s'",
          config_id, config_filepath, compute_cluster_id)
    compute_cluster_id = config_id

  request = service_pb2.PostComputeClustersRequest(
      user_app_id=self.user_app_id,
      compute_clusters=[
          resources_pb2.ComputeCluster(id=compute_cluster_id, **compute_cluster_config)
      ])
  response = self._grpc_request(self.STUB.PostComputeClusters, request)
  if response.status.code != status_code_pb2.SUCCESS:
    raise Exception(response.status)
  self.logger.info("\nCompute Cluster created\n%s", response.status)
  return ComputeCluster.from_auth_helper(
      auth=self.auth_helper, compute_cluster_id=compute_cluster_id)
261
+
175
262
  def app(self, app_id: str, **kwargs) -> App:
176
263
  """Returns an App object for the specified app ID.
177
264
 
@@ -223,6 +310,31 @@ class User(Lister, BaseClient):
223
310
 
224
311
  return dict(self.auth_helper, check_runner_exists=False, **kwargs)
225
312
 
313
def compute_cluster(self, compute_cluster_id: str) -> ComputeCluster:
  """Fetch one compute cluster by ID and wrap it in a ComputeCluster client.

  Args:
    compute_cluster_id (str): ID of the compute cluster to interact with.

  Returns:
    ComputeCluster: A Compute Cluster object for the specified compute cluster ID.

  Raises:
    Exception: If the GetComputeCluster request does not return SUCCESS.

  Example:
    >>> from clarifai.client.user import User
    >>> compute_cluster = User("user_id").compute_cluster("compute_cluster_id")
  """
  get_request = service_pb2.GetComputeClusterRequest(
      user_app_id=self.user_app_id, compute_cluster_id=compute_cluster_id)
  get_response = self._grpc_request(self.STUB.GetComputeCluster, get_request)
  if get_response.status.code != status_code_pb2.SUCCESS:
    raise Exception(get_response.status)

  response_dict = MessageToDict(get_response, preserving_proto_field_name=True)
  # The resource payload is read from the second top-level key of the converted
  # response (the first is presumably `status` — confirm against the proto).
  resource_key = list(response_dict)[1]
  cluster_kwargs = self.process_response_keys(response_dict[resource_key], resource_key)

  return ComputeCluster.from_auth_helper(auth=self.auth_helper, **cluster_kwargs)
337
+
226
338
  def patch_app(self, app_id: str, action: str = 'overwrite', **kwargs) -> App:
227
339
  """Patch an app for the user.
228
340
 
@@ -290,6 +402,25 @@ class User(Lister, BaseClient):
290
402
  raise Exception(response.status)
291
403
  self.logger.info("\nRunner Deleted\n%s", response.status)
292
404
 
405
def delete_compute_clusters(self, compute_cluster_ids: List[str]) -> None:
  """Delete several of the user's compute clusters in a single request.

  Args:
    compute_cluster_ids (List[str]): IDs of the user's compute clusters to delete.

  Raises:
    AssertionError: If ``compute_cluster_ids`` is not a list.
    Exception: If the DeleteComputeClusters request does not return SUCCESS.

  Example:
    >>> from clarifai.client.user import User
    >>> User("user_id").delete_compute_clusters(compute_cluster_ids=["compute_cluster_id1", "compute_cluster_id2"])
  """
  assert isinstance(compute_cluster_ids, list), "compute_cluster_ids param should be a list"

  delete_request = service_pb2.DeleteComputeClustersRequest(
      user_app_id=self.user_app_id, ids=compute_cluster_ids)
  delete_response = self._grpc_request(self.STUB.DeleteComputeClusters, delete_request)
  status = delete_response.status
  if status.code != status_code_pb2.SUCCESS:
    raise Exception(status)
  self.logger.info("\nCompute Cluster Deleted\n%s", status)
423
+
293
424
  def __getattr__(self, name):
294
425
  return getattr(self.user_info, name)
295
426
 
@@ -12,7 +12,7 @@ from clarifai.client.lister import Lister
12
12
  from clarifai.constants.workflow import MAX_WORKFLOW_PREDICT_INPUTS
13
13
  from clarifai.errors import UserError
14
14
  from clarifai.urls.helper import ClarifaiUrlHelper
15
- from clarifai.utils.logging import get_logger
15
+ from clarifai.utils.logging import logger
16
16
  from clarifai.utils.misc import BackoffIterator
17
17
  from clarifai.workflows.export import Exporter
18
18
 
@@ -59,7 +59,7 @@ class Workflow(Lister, BaseClient):
59
59
  self.kwargs = {**kwargs, 'id': workflow_id, 'version': workflow_version}
60
60
  self.output_config = output_config
61
61
  self.workflow_info = resources_pb2.Workflow(**self.kwargs)
62
- self.logger = get_logger(logger_level="INFO", name=__name__)
62
+ self.logger = logger
63
63
  BaseClient.__init__(
64
64
  self,
65
65
  user_id=self.user_id,
@@ -0,0 +1 @@
1
# Names of the compute orchestration resource types.
COMPUTE_ORCHESTRATION_RESOURCES = [
    'Runner',
    'ComputeCluster',
    'Nodepool',
    'Deployment',
]
@@ -14,9 +14,7 @@ from tqdm import tqdm
14
14
 
15
15
  from clarifai.constants.dataset import CONTENT_TYPE
16
16
  from clarifai.errors import UserError
17
- from clarifai.utils.logging import get_logger
18
-
19
- logger = get_logger("INFO", __name__)
17
+ from clarifai.utils.logging import logger
20
18
 
21
19
 
22
20
  class DatasetExportReader:
clarifai/rag/rag.py CHANGED
@@ -15,7 +15,7 @@ from clarifai.errors import UserError
15
15
  from clarifai.rag.utils import (convert_messages_to_str, format_assistant_message, load_documents,
16
16
  split_document)
17
17
  from clarifai.utils.constants import CLARIFAI_USER_ID_ENV_VAR
18
- from clarifai.utils.logging import get_logger
18
+ from clarifai.utils.logging import logger
19
19
  from clarifai.utils.misc import get_from_dict_or_env
20
20
 
21
21
  DEFAULT_RAG_PROMPT_TEMPLATE = "Context information is below:\n{data.hits}\nGiven the context information and not prior knowledge, answer the query.\nQuery: {data.text.raw}\nAnswer: "
@@ -40,7 +40,7 @@ class RAG:
40
40
  **kwargs):
41
41
  """Initialize an empty or existing RAG.
42
42
  """
43
- self.logger = get_logger(logger_level="INFO", name=__name__)
43
+ self.logger = logger
44
44
  if workflow_url is not None and workflow is None:
45
45
  self.logger.info("workflow_url:%s", workflow_url)
46
46
  w = Workflow(workflow_url, base_url=base_url, pat=pat)
@@ -10,8 +10,8 @@ from google.protobuf import json_format
10
10
  from rich import print
11
11
 
12
12
  from clarifai.client import BaseClient
13
-
14
13
  from clarifai.runners.utils.loader import HuggingFaceLoarder
14
+ from clarifai.utils.logging import logger
15
15
 
16
16
 
17
17
  def _clear_line(n: int = 1) -> None:
@@ -28,12 +28,11 @@ class ModelUploader:
28
28
  ]
29
29
 
30
30
  def __init__(self, folder: str):
31
+ self._client = None
31
32
  self.folder = self._validate_folder(folder)
32
33
  self.config = self._load_config(os.path.join(self.folder, 'config.yaml'))
33
- self.initialize_client()
34
34
  self.model_proto = self._get_model_proto()
35
35
  self.model_id = self.model_proto.id
36
- self.user_app_id = self.client.user_app_id
37
36
  self.inference_compute_info = self._get_inference_compute_info()
38
37
  self.is_v3 = True # Do model build for v3
39
38
 
@@ -41,7 +40,9 @@ class ModelUploader:
41
40
  def _validate_folder(folder):
42
41
  if not folder.startswith("/"):
43
42
  folder = os.path.join(os.getcwd(), folder)
44
- print(f"Validating folder: {folder}")
43
+ logger.info(f"Validating folder: {folder}")
44
+ if not os.path.exists(folder):
45
+ raise FileNotFoundError(f"Folder {folder} not found, please provide a valid folder path")
45
46
  files = os.listdir(folder)
46
47
  assert "requirements.txt" in files, "requirements.txt not found in the folder"
47
48
  assert "config.yaml" in files, "config.yaml not found in the folder"
@@ -56,18 +57,21 @@ class ModelUploader:
56
57
  config = yaml.safe_load(file)
57
58
  return config
58
59
 
59
@property
def client(self):
  """Return the BaseClient for this model's user and app, creating it on first access.

  The user and app IDs come from the ``model`` section of config.yaml. The API base URL
  is read from the ``CLARIFAI_API_BASE`` environment variable and defaults to the
  production endpoint ``https://api.clarifai.com``.

  Returns:
    BaseClient: The client, cached on ``self._client`` after the first call.

  Raises:
    AssertionError: If ``model``, ``model.user_id`` or ``model.app_id`` is missing from
      the config.
  """
  if self._client is None:
    assert "model" in self.config, "model info not found in the config file"
    model = self.config.get('model')
    assert "user_id" in model, "user_id not found in the config file"
    assert "app_id" in model, "app_id not found in the config file"
    user_id = model.get('user_id')
    app_id = model.get('app_id')

    # Default to production, consistent with the https://clarifai.com model link this
    # uploader logs after a successful build. Set CLARIFAI_API_BASE to target another
    # environment (e.g. https://api-dev.clarifai.com).
    base = os.environ.get('CLARIFAI_API_BASE', 'https://api.clarifai.com')

    self._client = BaseClient(user_id=user_id, app_id=app_id, base=base)
    logger.info("Client initialized for user %s and app %s", user_id, app_id)
  return self._client
71
75
 
72
76
  def _get_model_proto(self):
73
77
  assert "model" in self.config, "model info not found in the config file"
@@ -96,7 +100,7 @@ class ModelUploader:
96
100
  service_pb2.GetModelRequest(
97
101
  user_app_id=self.client.user_app_id, model_id=self.model_proto.id))
98
102
  if resp.status.code == status_code_pb2.SUCCESS:
99
- print(
103
+ logger.info(
100
104
  f"Model '{self.client.user_app_id.user_id}/{self.client.user_app_id.app_id}/models/{self.model_proto.id}' already exists, "
101
105
  f"will create a new version for it.")
102
106
  return resp
@@ -127,7 +131,15 @@ class ModelUploader:
127
131
 
128
132
  # Get the Python version from the config file
129
133
  build_info = self.config.get('build_info', {})
130
- python_version = build_info.get('python_version', self.DEFAULT_PYTHON_VERSION)
134
+ if 'python_version' in build_info:
135
+ python_version = build_info['python_version']
136
+ logger.info(
137
+ f"Using Python version {python_version} from the config file to build the Dockerfile")
138
+ else:
139
+ logger.info(
140
+ f"Python version not found in the config file, using default Python version: {self.DEFAULT_PYTHON_VERSION}"
141
+ )
142
+ python_version = self.DEFAULT_PYTHON_VERSION
131
143
 
132
144
  # Replace placeholders with actual values
133
145
  dockerfile_content = dockerfile_template.safe_substitute(
@@ -141,25 +153,33 @@ class ModelUploader:
141
153
 
142
154
  def download_checkpoints(self):
143
155
  if not self.config.get("checkpoints"):
144
- print("No checkpoints specified in the config file")
156
+ logger.info("No checkpoints specified in the config file")
145
157
  return
146
158
 
147
159
  assert "type" in self.config.get("checkpoints"), "No loader type specified in the config file"
148
160
  loader_type = self.config.get("checkpoints").get("type")
149
161
  if not loader_type:
150
- print("No loader type specified in the config file for checkpoints")
162
+ logger.info("No loader type specified in the config file for checkpoints")
151
163
  assert loader_type == "huggingface", "Only huggingface loader supported for now"
152
164
  if loader_type == "huggingface":
153
165
  assert "repo_id" in self.config.get("checkpoints"), "No repo_id specified in the config file"
154
166
  repo_id = self.config.get("checkpoints").get("repo_id")
155
167
 
156
- hf_token = self.config.get("checkpoints").get("hf_token", None)
168
+ # prefer env var for HF_TOKEN but if not provided then use the one from config.yaml if any.
169
+ if 'HF_TOKEN' in os.environ:
170
+ hf_token = os.environ['HF_TOKEN']
171
+ else:
172
+ hf_token = self.config.get("checkpoints").get("hf_token", None)
173
+ assert hf_token != 'hf_token', "The default 'hf_token' is not valid. Please provide a valid token or leave that field out of config.yaml if not needed."
157
174
  loader = HuggingFaceLoarder(repo_id=repo_id, token=hf_token)
158
175
 
159
176
  checkpoint_path = os.path.join(self.folder, '1', 'checkpoints')
160
- loader.download_checkpoints(checkpoint_path)
177
+ success = loader.download_checkpoints(checkpoint_path)
161
178
 
162
- print(f"Downloaded checkpoints for model {repo_id}")
179
+ if not success:
180
+ logger.error(f"Failed to download checkpoints for model {repo_id}")
181
+ return
182
+ logger.info(f"Downloaded checkpoints for model {repo_id}")
163
183
 
164
184
  def _concepts_protos_from_concepts(self, concepts):
165
185
  concept_protos = []
@@ -183,7 +203,7 @@ class ModelUploader:
183
203
  with open(config_file, 'w') as file:
184
204
  yaml.dump(config, file, sort_keys=False)
185
205
  concepts = config.get('concepts')
186
- print(f"Updated config.yaml with {len(concepts)} concepts.")
206
+ logger.info(f"Updated config.yaml with {len(concepts)} concepts.")
187
207
 
188
208
  def _get_model_version_proto(self):
189
209
 
@@ -209,11 +229,11 @@ class ModelUploader:
209
229
 
210
230
  def upload_model_version(self):
211
231
  file_path = f"{self.folder}.tar.gz"
212
- print(f"Will tar it into file: {file_path}")
232
+ logger.info(f"Will tar it into file: {file_path}")
213
233
 
214
234
  # Tar the folder
215
235
  os.system(f"tar --exclude=*~ -czvf {self.folder}.tar.gz -C {self.folder} .")
216
- print("Tarring complete, about to start upload.")
236
+ logger.info("Tarring complete, about to start upload.")
217
237
 
218
238
  model_version = self._get_model_version_proto()
219
239
 
@@ -230,14 +250,15 @@ class ModelUploader:
230
250
  print(
231
251
  f"Status: {response.status.description}, "
232
252
  f"Progress: {percent_completed}% - {details} ",
253
+ f"request_id: {response.status.req_id}",
233
254
  end='\r',
234
255
  flush=True)
235
256
  print()
236
257
  if response.status.code != status_code_pb2.MODEL_BUILDING:
237
- print(f"Failed to upload model version: {response.status.description}")
258
+ logger.error(f"Failed to upload model version: {response.status.description}")
238
259
  return
239
260
  model_version_id = response.model_version_id
240
- print(f"Created Model Version ID: {model_version_id}")
261
+ logger.info(f"Created Model Version ID: {model_version_id}")
241
262
 
242
263
  self.monitor_model_build(model_version_id)
243
264
 
@@ -247,24 +268,35 @@ class ModelUploader:
247
268
  file_size = os.path.getsize(file_path)
248
269
  chunk_size = int(127 * 1024 * 1024) # 127MB chunk size
249
270
  num_chunks = (file_size // chunk_size) + 1
250
-
271
+ logger.info("Uploading file...")
272
+ logger.info("File size: ", file_size)
273
+ logger.info("Chunk size: ", chunk_size)
274
+ logger.info("Number of chunks: ", num_chunks)
251
275
  read_so_far = 0
252
276
  for part_id in range(num_chunks):
253
- chunk = f.read(chunk_size)
254
- read_so_far += len(chunk)
255
- yield service_pb2.PostModelVersionsUploadRequest(
256
- content_part=resources_pb2.UploadContentPart(
257
- data=chunk,
258
- part_number=part_id + 1,
259
- range_start=read_so_far,
260
- ))
261
- print("\nUpload complete!, waiting for model build...")
277
+ try:
278
+ chunk_size = min(chunk_size, file_size - read_so_far)
279
+ chunk = f.read(chunk_size)
280
+ if not chunk:
281
+ break
282
+ read_so_far += len(chunk)
283
+ yield service_pb2.PostModelVersionsUploadRequest(
284
+ content_part=resources_pb2.UploadContentPart(
285
+ data=chunk,
286
+ part_number=part_id + 1,
287
+ range_start=read_so_far,
288
+ ))
289
+ except Exception as e:
290
+ logger.exception(f"\nError uploading file: {e}")
291
+ break
292
+
293
+ if read_so_far == file_size:
294
+ logger.info("\nUpload complete!, waiting for model build...")
262
295
 
263
296
  def init_upload_model_version(self, model_version, file_path):
264
297
  file_size = os.path.getsize(file_path)
265
- print(
266
- f"Uploading model version '{model_version.id}' with file '{os.path.basename(file_path)}' of size {file_size} bytes..."
267
- )
298
+ logger.info(f"Uploading model version '{model_version.id}' of model {self.model_proto.id}")
299
+ logger.info(f"Using file '{os.path.basename(file_path)}' of size: {file_size} bytes")
268
300
  return service_pb2.PostModelVersionsUploadRequest(
269
301
  upload_config=service_pb2.PostModelVersionsUploadConfig(
270
302
  user_app_id=self.client.user_app_id,
@@ -285,22 +317,25 @@ class ModelUploader:
285
317
  ))
286
318
  status_code = resp.model_version.status.code
287
319
  if status_code == status_code_pb2.MODEL_BUILDING:
288
- print(f"Model is building... (elapsed {time.time() - st:.1f}s)", end='\r', flush=True)
320
+ logger.info(
321
+ f"Model is building... (elapsed {time.time() - st:.1f}s)", end='\r', flush=True)
289
322
  time.sleep(1)
290
323
  elif status_code == status_code_pb2.MODEL_TRAINED:
291
- print("\nModel build complete!")
292
- print(
293
- f"Check out the model at https://clarifai.com/{self.user_app_id.user_id}/apps/{self.user_app_id.app_id}/models/{self.model_id}/versions/{model_version_id}"
324
+ logger.info("\nModel build complete!")
325
+ logger.info(
326
+ f"Check out the model at https://clarifai.com/{self.client.user_app_id.user_id}/apps/{self.client.user_app_id.app_id}/models/{self.model_id}/versions/{model_version_id}"
294
327
  )
295
328
  break
296
329
  else:
297
- print(f"\nModel build failed with status: {resp.model_version.status}")
330
+ logger.info(
331
+ f"\nModel build failed with status: {resp.model_version.status} and response {resp}")
298
332
  break
299
333
 
300
334
 
301
def main(folder, download_checkpoints=False):
  """Upload the model in ``folder``: optionally fetch checkpoints, write the Dockerfile, upload.

  Args:
    folder (str): Path to the model folder (must contain config.yaml and requirements.txt).
    download_checkpoints (bool): If True, download checkpoints into the folder before it
      is tarred and uploaded, so they are included in the upload. Defaults to False,
      matching the ``--download_checkpoints`` CLI flag. The default also keeps older
      ``main(folder)`` calls working; those versions always downloaded checkpoints, so
      pass True to keep that behavior.
  """
  uploader = ModelUploader(folder)
  if download_checkpoints:
    uploader.download_checkpoints()
  uploader.create_dockerfile()
  # Pause for the user's confirmation before starting the upload.
  input("Press Enter to continue...")
  uploader.upload_model_version()
@@ -310,6 +345,13 @@ if __name__ == "__main__":
310
345
  parser = argparse.ArgumentParser()
311
346
  parser.add_argument(
312
347
  '--model_path', type=str, help='Path of the model folder to upload', required=True)
348
+ # flag to default to not download checkpoints
349
+ parser.add_argument(
350
+ '--download_checkpoints',
351
+ action='store_true',
352
+ help=
353
+ 'Flag to download checkpoints before uploading and including them in the tar file that is uploaded. Defaults to False, which will attempt to download them at docker build time.',
354
+ )
313
355
  args = parser.parse_args()
314
356
 
315
- main(args.model_path)
357
+ main(args.model_path, args.download_checkpoints)
@@ -15,7 +15,7 @@ from clarifai_protocol import BaseRunner
15
15
  from clarifai_protocol.utils.grpc_server import GRPCServer
16
16
 
17
17
  from clarifai.runners.models.model_servicer import ModelServicer
18
- from clarifai.runners.utils.logging import logger
18
+ from clarifai.utils.logging import logger
19
19
 
20
20
 
21
21
  def main():
@@ -3,6 +3,8 @@ import json
3
3
  import os
4
4
  import subprocess
5
5
 
6
+ from clarifai.utils.logging import logger
7
+
6
8
 
7
9
  class HuggingFaceLoarder:
8
10
 
@@ -29,22 +31,24 @@ class HuggingFaceLoarder:
29
31
  "The 'huggingface_hub' package is not installed. Please install it using 'pip install huggingface_hub'."
30
32
  )
31
33
  if os.path.exists(checkpoint_path) and self.validate_download(checkpoint_path):
32
- print("Checkpoints already exist")
34
+ logger.info("Checkpoints already exist")
35
+ return True
33
36
  else:
34
37
  os.makedirs(checkpoint_path, exist_ok=True)
35
38
  try:
36
39
  is_hf_model_exists = self.validate_hf_model()
37
40
  if not is_hf_model_exists:
38
- print("Model not found on Hugging Face")
41
+ logger.error("Model %s not found on Hugging Face" % (self.repo_id))
39
42
  return False
40
- snapshot_download(repo_id=self.repo_id, local_dir=checkpoint_path)
43
+ snapshot_download(
44
+ repo_id=self.repo_id, local_dir=checkpoint_path, local_dir_use_symlinks=False)
41
45
  except Exception as e:
42
- print("Error downloading model checkpoints ", e)
46
+ logger.exception(f"Error downloading model checkpoints {e}")
43
47
  return False
44
48
  finally:
45
49
  is_downloaded = self.validate_download(checkpoint_path)
46
50
  if not is_downloaded:
47
- print("Error downloading model checkpoints")
51
+ logger.error("Error validating downloaded model checkpoints")
48
52
  return False
49
53
  return True
50
54
 
@@ -57,8 +61,10 @@ class HuggingFaceLoarder:
57
61
  def validate_download(self, checkpoint_path: str):
58
62
  # check if model exists on HF
59
63
  from huggingface_hub import list_repo_files
60
-
61
- return (len(os.listdir(checkpoint_path)) >= len(list_repo_files(self.repo_id))) and len(
64
+ checkpoint_dir_files = [
65
+ f for dp, dn, fn in os.walk(os.path.expanduser(checkpoint_path)) for f in fn
66
+ ]
67
+ return (len(checkpoint_dir_files) >= len(list_repo_files(self.repo_id))) and len(
62
68
  list_repo_files(self.repo_id)) > 0
63
69
 
64
70
  def fetch_labels(self, checkpoint_path: str):
@@ -2,7 +2,7 @@ import concurrent.futures
2
2
 
3
3
  import fsspec
4
4
 
5
- from .logging import logger
5
+ from clarifai.utils.logging import logger
6
6
 
7
7
 
8
8
  def download_input(input):
@@ -20,8 +20,7 @@ except ImportError:
20
20
  try:
21
21
  from loguru import logger
22
22
  except ImportError:
23
- from ..logging import get_logger
24
- logger = get_logger(logger_level="INFO", name=__name__)
23
+ from ..logging import logger
25
24
 
26
25
  MACRO_AVG = "macro_avg"
27
26
 
@@ -26,8 +26,7 @@ except ImportError:
26
26
  try:
27
27
  from loguru import logger
28
28
  except ImportError:
29
- from ..logging import get_logger
30
- logger = get_logger(logger_level="INFO", name=__name__)
29
+ from ..logging import logger
31
30
 
32
31
  __all__ = ['EvalResultCompare']
33
32
 
clarifai/utils/logging.py CHANGED
@@ -19,6 +19,8 @@ from rich.tree import Tree
19
19
 
20
20
  install()
21
21
 
22
+ # The default logger to use throughout the SDK is defined at the bottom of this file.
23
+
22
24
  # For the json logger.
23
25
  JSON_LOGGER_NAME = "clarifai-json"
24
26
  JSON_LOG_KEY = 'msg'
@@ -357,3 +359,7 @@ class JsonFormatter(logging.Formatter):
357
359
  default=self.json_default,
358
360
  cls=self.json_cls,
359
361
  )
362
+
363
+
364
# Default SDK-wide logger, named "clarifai". Its level is taken from the LOG_LEVEL
# environment variable and falls back to "INFO" when that variable is unset.
logger = get_logger(
    logger_level=os.environ.get("LOG_LEVEL", "INFO"),
    name="clarifai",
)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: clarifai
3
- Version: 10.8.5
3
+ Version: 10.8.7
4
4
  Summary: Clarifai Python SDK
5
5
  Home-page: https://github.com/Clarifai/clarifai-python
6
6
  Author: Clarifai