clarifai 10.8.6__py3-none-any.whl → 10.8.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
clarifai/client/user.py CHANGED
@@ -1,5 +1,7 @@
+import os
 from typing import Any, Dict, Generator, List
 
+import yaml
 from clarifai_grpc.grpc.api import resources_pb2, service_pb2
 from clarifai_grpc.grpc.api.status import status_code_pb2
 from google.protobuf.json_format import MessageToDict
@@ -7,9 +9,10 @@ from google.protobuf.wrappers_pb2 import BoolValue
 
 from clarifai.client.app import App
 from clarifai.client.base import BaseClient
+from clarifai.client.compute_cluster import ComputeCluster
 from clarifai.client.lister import Lister
 from clarifai.errors import UserError
-from clarifai.utils.logging import get_logger
+from clarifai.utils.logging import logger
 
 
 class User(Lister, BaseClient):
@@ -34,7 +37,7 @@ class User(Lister, BaseClient):
     """
     self.kwargs = {**kwargs, 'id': user_id}
     self.user_info = resources_pb2.User(**self.kwargs)
-    self.logger = get_logger(logger_level="INFO", name=__name__)
+    self.logger = logger
     BaseClient.__init__(
         self,
         user_id=self.id,
@@ -109,6 +112,37 @@ class User(Lister, BaseClient):
     for runner_info in all_runners_info:
       yield dict(auth=self.auth_helper, check_runner_exists=False, **runner_info)
 
+  def list_compute_clusters(self, page_no: int = None,
+                            per_page: int = None) -> Generator[dict, None, None]:
+    """Lists all compute clusters for the user.
+
+    Args:
+        page_no (int): The page number to list.
+        per_page (int): The number of items per page.
+
+    Yields:
+        ComputeCluster: ComputeCluster objects for the user's compute clusters.
+
+    Example:
+        >>> from clarifai.client.user import User
+        >>> client = User(user_id="user_id")
+        >>> all_compute_clusters = list(client.list_compute_clusters())
+
+    Note:
+        Defaults to 16 per page if page_no is specified and per_page is not specified.
+        If both page_no and per_page are None, then lists all the resources.
+    """
+    request_data = dict(user_app_id=self.user_app_id)
+    all_compute_clusters_info = self.list_pages_generator(
+        self.STUB.ListComputeClusters,
+        service_pb2.ListComputeClustersRequest,
+        request_data,
+        per_page=per_page,
+        page_no=page_no)
+
+    for compute_cluster_info in all_compute_clusters_info:
+      yield ComputeCluster.from_auth_helper(self.auth_helper, **compute_cluster_info)
+
   def create_app(self, app_id: str, base_workflow: str = 'Empty', **kwargs) -> App:
     """Creates an app for the user.
 
@@ -172,6 +206,59 @@ class User(Lister, BaseClient):
         description=description,
         check_runner_exists=False)
 
+  def _process_compute_cluster_config(self, config_filepath: str) -> Dict[str, Any]:
+    with open(config_filepath, "r") as file:
+      compute_cluster_config = yaml.safe_load(file)
+
+    assert "compute_cluster" in compute_cluster_config, "compute cluster info not found in the config file"
+    compute_cluster = compute_cluster_config['compute_cluster']
+    assert "region" in compute_cluster, "region not found in the config file"
+    assert "managed_by" in compute_cluster, "managed_by not found in the config file"
+    assert "cluster_type" in compute_cluster, "cluster_type not found in the config file"
+    compute_cluster['cloud_provider'] = resources_pb2.CloudProvider(
+        **compute_cluster['cloud_provider'])
+    compute_cluster['key'] = resources_pb2.Key(id=self.pat)
+    if "visibility" in compute_cluster:
+      compute_cluster["visibility"] = resources_pb2.Visibility(**compute_cluster["visibility"])
+    return compute_cluster
+
+  def create_compute_cluster(self, compute_cluster_id: str,
+                             config_filepath: str) -> ComputeCluster:
+    """Creates a compute cluster for the user.
+
+    Args:
+        compute_cluster_id (str): The compute cluster ID for the compute cluster to create.
+        config_filepath (str): The path to the compute cluster config file.
+
+    Returns:
+        ComputeCluster: A Compute Cluster object for the specified compute cluster ID.
+
+    Example:
+        >>> from clarifai.client.user import User
+        >>> client = User(user_id="user_id")
+        >>> compute_cluster = client.create_compute_cluster(compute_cluster_id="compute_cluster_id", config_filepath="config.yml")
+    """
+    if not os.path.exists(config_filepath):
+      raise UserError(f"Compute Cluster config file not found at {config_filepath}")
+
+    compute_cluster_config = self._process_compute_cluster_config(config_filepath)
+
+    if 'id' in compute_cluster_config:
+      compute_cluster_id = compute_cluster_config['id']
+      compute_cluster_config.pop('id')
+
+    request = service_pb2.PostComputeClustersRequest(
+        user_app_id=self.user_app_id,
+        compute_clusters=[
+            resources_pb2.ComputeCluster(id=compute_cluster_id, **compute_cluster_config)
+        ])
+    response = self._grpc_request(self.STUB.PostComputeClusters, request)
+    if response.status.code != status_code_pb2.SUCCESS:
+      raise Exception(response.status)
+    self.logger.info("\nCompute Cluster created\n%s", response.status)
+    return ComputeCluster.from_auth_helper(
+        auth=self.auth_helper, compute_cluster_id=compute_cluster_id)
+
   def app(self, app_id: str, **kwargs) -> App:
     """Returns an App object for the specified app ID.
 
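For orientation, `_process_compute_cluster_config` above implies the shape of the YAML file that `create_compute_cluster` consumes: a top-level `compute_cluster` key that must contain `region`, `managed_by`, and `cluster_type`, a `cloud_provider` mapping splatted into `resources_pb2.CloudProvider`, and optionally `visibility` and an `id` that overrides the `compute_cluster_id` argument. A minimal sketch of writing such a config and creating the cluster; every value is a placeholder assumption, only the key names come from the diff:

# Hedged sketch: builds a config matching the asserts in _process_compute_cluster_config.
# All field values are placeholder assumptions; only the key names appear in the diff.
import yaml

from clarifai.client.user import User

config = {
    "compute_cluster": {
        "region": "us-east-1",  # required by the asserts above
        "managed_by": "clarifai",  # required
        "cluster_type": "dedicated",  # required
        "cloud_provider": {"id": "aws"},  # splatted into resources_pb2.CloudProvider
        "visibility": {"gettable": 10},  # optional; splatted into resources_pb2.Visibility
    }
}
with open("config.yml", "w") as f:
  yaml.safe_dump(config, f, sort_keys=False)

client = User(user_id="user_id")
compute_cluster = client.create_compute_cluster(
    compute_cluster_id="my-compute-cluster", config_filepath="config.yml")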
@@ -223,6 +310,31 @@ class User(Lister, BaseClient):
 
     return dict(self.auth_helper, check_runner_exists=False, **kwargs)
 
+  def compute_cluster(self, compute_cluster_id: str) -> ComputeCluster:
+    """Returns a Compute Cluster object for the specified compute cluster ID.
+
+    Args:
+        compute_cluster_id (str): The compute cluster ID for the compute cluster to interact with.
+
+    Returns:
+        ComputeCluster: A Compute Cluster object for the specified compute cluster ID.
+
+    Example:
+        >>> from clarifai.client.user import User
+        >>> compute_cluster = User("user_id").compute_cluster("compute_cluster_id")
+    """
+    request = service_pb2.GetComputeClusterRequest(
+        user_app_id=self.user_app_id, compute_cluster_id=compute_cluster_id)
+    response = self._grpc_request(self.STUB.GetComputeCluster, request)
+    if response.status.code != status_code_pb2.SUCCESS:
+      raise Exception(response.status)
+
+    dict_response = MessageToDict(response, preserving_proto_field_name=True)
+    kwargs = self.process_response_keys(dict_response[list(dict_response.keys())[1]],
+                                        list(dict_response.keys())[1])
+
+    return ComputeCluster.from_auth_helper(auth=self.auth_helper, **kwargs)
+
   def patch_app(self, app_id: str, action: str = 'overwrite', **kwargs) -> App:
     """Patch an app for the user.
 
@@ -290,6 +402,25 @@ class User(Lister, BaseClient):
       raise Exception(response.status)
     self.logger.info("\nRunner Deleted\n%s", response.status)
 
+  def delete_compute_clusters(self, compute_cluster_ids: List[str]) -> None:
+    """Deletes a list of compute clusters for the user.
+
+    Args:
+        compute_cluster_ids (List[str]): The IDs of the compute clusters to delete.
+
+    Example:
+        >>> from clarifai.client.user import User
+        >>> user = User("user_id").delete_compute_clusters(compute_cluster_ids=["compute_cluster_id1", "compute_cluster_id2"])
+    """
+    assert isinstance(compute_cluster_ids, list), "compute_cluster_ids param should be a list"
+
+    request = service_pb2.DeleteComputeClustersRequest(
+        user_app_id=self.user_app_id, ids=compute_cluster_ids)
+    response = self._grpc_request(self.STUB.DeleteComputeClusters, request)
+    if response.status.code != status_code_pb2.SUCCESS:
+      raise Exception(response.status)
+    self.logger.info("\nCompute Cluster Deleted\n%s", response.status)
+
   def __getattr__(self, name):
     return getattr(self.user_info, name)
 
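Between `create_compute_cluster`, `compute_cluster`, `list_compute_clusters`, and `delete_compute_clusters`, the `User` client now covers the full compute cluster lifecycle. A short usage sketch stitched together from the docstring examples above; IDs and the config path are placeholders:

# Hedged lifecycle sketch; IDs and the config path are placeholders.
from clarifai.client.user import User

client = User(user_id="user_id")
client.create_compute_cluster(
    compute_cluster_id="my-compute-cluster", config_filepath="config.yml")

for cluster in client.list_compute_clusters():  # yields ComputeCluster objects
  print(cluster.id)

existing = client.compute_cluster("my-compute-cluster")  # fetch a single cluster
client.delete_compute_clusters(compute_cluster_ids=["my-compute-cluster"])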
clarifai/client/workflow.py CHANGED
@@ -12,7 +12,7 @@ from clarifai.client.lister import Lister
 from clarifai.constants.workflow import MAX_WORKFLOW_PREDICT_INPUTS
 from clarifai.errors import UserError
 from clarifai.urls.helper import ClarifaiUrlHelper
-from clarifai.utils.logging import get_logger
+from clarifai.utils.logging import logger
 from clarifai.utils.misc import BackoffIterator
 from clarifai.workflows.export import Exporter
 
@@ -59,7 +59,7 @@ class Workflow(Lister, BaseClient):
     self.kwargs = {**kwargs, 'id': workflow_id, 'version': workflow_version}
     self.output_config = output_config
     self.workflow_info = resources_pb2.Workflow(**self.kwargs)
-    self.logger = get_logger(logger_level="INFO", name=__name__)
+    self.logger = logger
     BaseClient.__init__(
         self,
         user_id=self.user_id,
@@ -0,0 +1 @@
+COMPUTE_ORCHESTRATION_RESOURCES = ['Runner', 'ComputeCluster', 'Nodepool', 'Deployment']
@@ -14,9 +14,7 @@ from tqdm import tqdm
 
 from clarifai.constants.dataset import CONTENT_TYPE
 from clarifai.errors import UserError
-from clarifai.utils.logging import get_logger
-
-logger = get_logger("INFO", __name__)
+from clarifai.utils.logging import logger
 
 
 class DatasetExportReader:
clarifai/rag/rag.py CHANGED
@@ -15,7 +15,7 @@ from clarifai.errors import UserError
 from clarifai.rag.utils import (convert_messages_to_str, format_assistant_message, load_documents,
                                 split_document)
 from clarifai.utils.constants import CLARIFAI_USER_ID_ENV_VAR
-from clarifai.utils.logging import get_logger
+from clarifai.utils.logging import logger
 from clarifai.utils.misc import get_from_dict_or_env
 
 DEFAULT_RAG_PROMPT_TEMPLATE = "Context information is below:\n{data.hits}\nGiven the context information and not prior knowledge, answer the query.\nQuery: {data.text.raw}\nAnswer: "
@@ -40,7 +40,7 @@ class RAG:
                **kwargs):
     """Initialize an empty or existing RAG.
     """
-    self.logger = get_logger(logger_level="INFO", name=__name__)
+    self.logger = logger
     if workflow_url is not None and workflow is None:
       self.logger.info("workflow_url:%s", workflow_url)
       w = Workflow(workflow_url, base_url=base_url, pat=pat)
clarifai/runners/models/model_servicer.py CHANGED
@@ -13,6 +13,11 @@ class ModelServicer(service_pb2_grpc.V2Servicer):
   """
 
   def __init__(self, model_class):
+    """
+    Args:
+      model_class: The class that will handle the model logic. Must implement predict(),
+        generate(), and stream().
+    """
     self.model_class = model_class
 
   def PostModelOutputs(self, request: service_pb2.PostModelOutputsRequest,
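The new docstring pins down the servicer's contract: whatever object is passed in must expose `predict()`, `generate()`, and `stream()`. A hypothetical minimal stand-in; the request/response types are not shown in this diff, so the signatures beyond the three method names are assumptions:

# Hypothetical minimal object satisfying the ModelServicer contract described above.
from clarifai.runners.models.model_servicer import ModelServicer


class EchoModel:

  def predict(self, request):  # unary: one request, one response
    ...

  def generate(self, request):  # server streaming: one request, many responses
    yield ...

  def stream(self, request_iterator):  # bidirectional streaming
    for request in request_iterator:
      yield ...


servicer = ModelServicer(EchoModel())  # mirrors servicer = ModelServicer(runner) below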
clarifai/runners/models/model_upload.py CHANGED
@@ -10,8 +10,8 @@ from google.protobuf import json_format
 from rich import print
 
 from clarifai.client import BaseClient
-
 from clarifai.runners.utils.loader import HuggingFaceLoarder
+from clarifai.utils.logging import logger
 
 
 def _clear_line(n: int = 1) -> None:
@@ -40,7 +40,7 @@ class ModelUploader:
   def _validate_folder(folder):
     if not folder.startswith("/"):
       folder = os.path.join(os.getcwd(), folder)
-    print(f"Validating folder: {folder}")
+    logger.info(f"Validating folder: {folder}")
     if not os.path.exists(folder):
       raise FileNotFoundError(f"Folder {folder} not found, please provide a valid folder path")
     files = os.listdir(folder)
@@ -70,7 +70,7 @@ class ModelUploader:
     base = os.environ.get('CLARIFAI_API_BASE', 'https://api-dev.clarifai.com')
 
     self._client = BaseClient(user_id=user_id, app_id=app_id, base=base)
-    print(f"Client initialized for user {user_id} and app {app_id}")
+    logger.info(f"Client initialized for user {user_id} and app {app_id}")
     return self._client
 
   def _get_model_proto(self):
@@ -100,7 +100,7 @@ class ModelUploader:
         service_pb2.GetModelRequest(
             user_app_id=self.client.user_app_id, model_id=self.model_proto.id))
     if resp.status.code == status_code_pb2.SUCCESS:
-      print(
+      logger.info(
           f"Model '{self.client.user_app_id.user_id}/{self.client.user_app_id.app_id}/models/{self.model_proto.id}' already exists, "
           f"will create a new version for it.")
       return resp
@@ -133,9 +133,10 @@ class ModelUploader:
     build_info = self.config.get('build_info', {})
     if 'python_version' in build_info:
       python_version = build_info['python_version']
-      print(f"Using Python version {python_version} from the config file to build the Dockerfile")
+      logger.info(
+          f"Using Python version {python_version} from the config file to build the Dockerfile")
     else:
-      print(
+      logger.info(
           f"Python version not found in the config file, using default Python version: {self.DEFAULT_PYTHON_VERSION}"
       )
       python_version = self.DEFAULT_PYTHON_VERSION
@@ -152,13 +153,13 @@ class ModelUploader:
 
   def download_checkpoints(self):
     if not self.config.get("checkpoints"):
-      print("No checkpoints specified in the config file")
+      logger.info("No checkpoints specified in the config file")
       return
 
     assert "type" in self.config.get("checkpoints"), "No loader type specified in the config file"
     loader_type = self.config.get("checkpoints").get("type")
     if not loader_type:
-      print("No loader type specified in the config file for checkpoints")
+      logger.info("No loader type specified in the config file for checkpoints")
     assert loader_type == "huggingface", "Only huggingface loader supported for now"
     if loader_type == "huggingface":
       assert "repo_id" in self.config.get("checkpoints"), "No repo_id specified in the config file"
@@ -173,9 +174,12 @@ class ModelUploader:
       loader = HuggingFaceLoarder(repo_id=repo_id, token=hf_token)
 
       checkpoint_path = os.path.join(self.folder, '1', 'checkpoints')
-      loader.download_checkpoints(checkpoint_path)
+      success = loader.download_checkpoints(checkpoint_path)
 
-      print(f"Downloaded checkpoints for model {repo_id}")
+      if not success:
+        logger.error(f"Failed to download checkpoints for model {repo_id}")
+        return
+      logger.info(f"Downloaded checkpoints for model {repo_id}")
 
   def _concepts_protos_from_concepts(self, concepts):
     concept_protos = []
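`download_checkpoints` is driven entirely by the `checkpoints` section of config.yaml: the asserts above require `type: huggingface` and a `repo_id`, and the `token=hf_token` call suggests a token is read alongside them. A hedged sketch of that section; the `hf_token` key name is an assumption inferred from the call site, not confirmed by this diff:

# Hedged sketch of the config section consumed by download_checkpoints above.
checkpoints = {
    "type": "huggingface",  # the only loader type accepted by the assert
    "repo_id": "owner/model-name",  # required; passed to HuggingFaceLoarder
    "hf_token": "hf_xxx",  # assumed key name, surfaced as token=hf_token
}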
@@ -199,7 +203,7 @@ class ModelUploader:
     with open(config_file, 'w') as file:
       yaml.dump(config, file, sort_keys=False)
     concepts = config.get('concepts')
-    print(f"Updated config.yaml with {len(concepts)} concepts.")
+    logger.info(f"Updated config.yaml with {len(concepts)} concepts.")
 
   def _get_model_version_proto(self):
 
@@ -225,11 +229,11 @@ class ModelUploader:
 
   def upload_model_version(self):
     file_path = f"{self.folder}.tar.gz"
-    print(f"Will tar it into file: {file_path}")
+    logger.info(f"Will tar it into file: {file_path}")
 
     # Tar the folder
     os.system(f"tar --exclude=*~ -czvf {self.folder}.tar.gz -C {self.folder} .")
-    print("Tarring complete, about to start upload.")
+    logger.info("Tarring complete, about to start upload.")
 
     model_version = self._get_model_version_proto()
 
@@ -251,10 +255,10 @@ class ModelUploader:
           flush=True)
     print()
     if response.status.code != status_code_pb2.MODEL_BUILDING:
-      print(f"Failed to upload model version: {response.status.description}")
+      logger.error(f"Failed to upload model version: {response}")
       return
     model_version_id = response.model_version_id
-    print(f"Created Model Version ID: {model_version_id}")
+    logger.info(f"Created Model Version ID: {model_version_id}")
 
     self.monitor_model_build(model_version_id)
 
@@ -264,10 +268,10 @@ class ModelUploader:
264
268
  file_size = os.path.getsize(file_path)
265
269
  chunk_size = int(127 * 1024 * 1024) # 127MB chunk size
266
270
  num_chunks = (file_size // chunk_size) + 1
267
- print("Uploading file...")
268
- print("File size: ", file_size)
269
- print("Chunk size: ", chunk_size)
270
- print("Number of chunks: ", num_chunks)
271
+ logger.info("Uploading file...")
272
+ logger.info(f"File size: {file_size}")
273
+ logger.info(f"Chunk size: {chunk_size}")
274
+ logger.info(f"Number of chunks: {num_chunks}")
271
275
  read_so_far = 0
272
276
  for part_id in range(num_chunks):
273
277
  try:
@@ -283,16 +287,16 @@ class ModelUploader:
                 range_start=read_so_far,
             ))
       except Exception as e:
-        print(f"\nError uploading file: {e}")
+        logger.exception(f"\nError uploading file: {e}")
         break
 
     if read_so_far == file_size:
-      print("\nUpload complete!, waiting for model build...")
+      logger.info("\nUpload complete!, waiting for model build...")
 
   def init_upload_model_version(self, model_version, file_path):
     file_size = os.path.getsize(file_path)
-    print(f"Uploading model version '{model_version.id}' of model {self.model_proto.id}")
-    print(f"Using file '{os.path.basename(file_path)}' of size: {file_size} bytes")
+    logger.info(f"Uploading model version '{model_version.id}' of model {self.model_proto.id}")
+    logger.info(f"Using file '{os.path.basename(file_path)}' of size: {file_size} bytes")
     return service_pb2.PostModelVersionsUploadRequest(
         upload_config=service_pb2.PostModelVersionsUploadConfig(
             user_app_id=self.client.user_app_id,
@@ -316,13 +320,14 @@ class ModelUploader:
       print(f"Model is building... (elapsed {time.time() - st:.1f}s)", end='\r', flush=True)
       time.sleep(1)
     elif status_code == status_code_pb2.MODEL_TRAINED:
-      print("\nModel build complete!")
-      print(
+      logger.info("\nModel build complete!")
+      logger.info(
           f"Check out the model at https://clarifai.com/{self.client.user_app_id.user_id}/apps/{self.client.user_app_id.app_id}/models/{self.model_id}/versions/{model_version_id}"
       )
       break
     else:
-      print(f"\nModel build failed with status: {resp.model_version.status}")
+      logger.info(
+          f"\nModel build failed with status: {resp.model_version.status} and response {resp}")
       break
 
 
@@ -15,7 +15,8 @@ from clarifai_protocol import BaseRunner
 from clarifai_protocol.utils.grpc_server import GRPCServer
 
 from clarifai.runners.models.model_servicer import ModelServicer
-from clarifai.runners.utils.logging import logger
+from clarifai.runners.models.model_upload import ModelUploader
+from clarifai.utils.logging import logger
 
 
 def main():
@@ -93,21 +94,26 @@ def main():
 
   MyRunner = classes[0]
 
-  # initialize the Runner class. This is what the user implements.
-  # (Note) do we want to set runner_id, nodepool_id, compute_cluster_id, base_url, num_parallel_polls as env vars? or as args?
-  runner = MyRunner(
-      runner_id=os.environ["CLARIFAI_RUNNER_ID"],
-      nodepool_id=os.environ["CLARIFAI_NODEPOOL_ID"],
-      compute_cluster_id=os.environ["CLARIFAI_COMPUTE_CLUSTER_ID"],
-      base_url=os.environ["CLARIFAI_API_BASE"],
-      num_parallel_polls=int(os.environ.get("CLARIFAI_NUM_THREADS", 1)),
-  )
-
-  # initialize the servicer
-  servicer = ModelServicer(runner)
-
   # Setup the grpc server for local development.
   if parsed_args.start_dev_server:
+
+    # We validate that we have checkpoints downloaded before constructing MyRunner which
+    # will call load_model()
+    uploader = ModelUploader(parsed_args.model_path)
+    uploader.download_checkpoints()
+
+    # initialize the Runner class. This is what the user implements.
+    # we aren't going to call runner.start() to engage with the API so IDs are not necessary.
+    runner = MyRunner(
+        runner_id="n/a",
+        nodepool_id="n/a",
+        compute_cluster_id="n/a",
+        health_check_port=None,  # not needed when running local server
+    )
+
+    # initialize the servicer with the runner so that it gets the predict(), generate(), stream() methods.
+    servicer = ModelServicer(runner)
+
     server = GRPCServer(
         futures.ThreadPoolExecutor(
             max_workers=parsed_args.pool_size,
@@ -121,9 +127,18 @@ def main():
     service_pb2_grpc.add_V2Servicer_to_server(servicer, server)
     server.start()
     logger.info("Started server on port %s", parsed_args.port)
-    # server.wait_for_termination() # won't get here currently.
-
-  runner.start()  # start the runner loop to fetch work from the API.
+    server.wait_for_termination()
+  else:  # start the runner with the proper env variables and as a runner protocol.
+
+    # initialize the Runner class. This is what the user implements.
+    runner = MyRunner(
+        runner_id=os.environ["CLARIFAI_RUNNER_ID"],
+        nodepool_id=os.environ["CLARIFAI_NODEPOOL_ID"],
+        compute_cluster_id=os.environ["CLARIFAI_COMPUTE_CLUSTER_ID"],
+        base_url=os.environ["CLARIFAI_API_BASE"],
+        num_parallel_polls=int(os.environ.get("CLARIFAI_NUM_THREADS", 1)),
+    )
+    runner.start()  # start the runner to fetch work from the API.
 
 
 if __name__ == '__main__':
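The `else` branch shows the runner's runtime contract: everything is injected through environment variables. A hedged launch sketch; all values are placeholders, and `main()` refers to the entry point defined above:

# Hedged sketch of the environment the runner branch above expects; values are placeholders.
import os

os.environ["CLARIFAI_RUNNER_ID"] = "runner-id"
os.environ["CLARIFAI_NODEPOOL_ID"] = "nodepool-id"
os.environ["CLARIFAI_COMPUTE_CLUSTER_ID"] = "compute-cluster-id"
os.environ["CLARIFAI_API_BASE"] = "https://api.clarifai.com"
os.environ["CLARIFAI_NUM_THREADS"] = "2"  # optional; defaults to 1

main()  # without --start_dev_server this falls through to runner.start()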
clarifai/runners/utils/loader.py CHANGED
@@ -3,6 +3,8 @@ import json
 import os
 import subprocess
 
+from clarifai.utils.logging import logger
+
 
 
 class HuggingFaceLoarder:
@@ -29,22 +31,24 @@ class HuggingFaceLoarder:
           "The 'huggingface_hub' package is not installed. Please install it using 'pip install huggingface_hub'."
       )
     if os.path.exists(checkpoint_path) and self.validate_download(checkpoint_path):
-      print("Checkpoints already exist")
+      logger.info("Checkpoints already exist")
+      return True
     else:
       os.makedirs(checkpoint_path, exist_ok=True)
       try:
         is_hf_model_exists = self.validate_hf_model()
         if not is_hf_model_exists:
-          print("Model not found on Hugging Face")
+          logger.error("Model %s not found on Hugging Face" % (self.repo_id))
           return False
-        snapshot_download(repo_id=self.repo_id, local_dir=checkpoint_path)
+        snapshot_download(
+            repo_id=self.repo_id, local_dir=checkpoint_path, local_dir_use_symlinks=False)
       except Exception as e:
-        print("Error downloading model checkpoints ", e)
+        logger.exception(f"Error downloading model checkpoints {e}")
         return False
       finally:
         is_downloaded = self.validate_download(checkpoint_path)
         if not is_downloaded:
-          print("Error downloading model checkpoints")
+          logger.error("Error validating downloaded model checkpoints")
           return False
       return True
 
@@ -57,7 +61,10 @@ class HuggingFaceLoarder:
   def validate_download(self, checkpoint_path: str):
     # check if model exists on HF
     from huggingface_hub import list_repo_files
-    return (len(os.listdir(checkpoint_path)) >= len(list_repo_files(self.repo_id))) and len(
+    checkpoint_dir_files = [
+        f for dp, dn, fn in os.walk(os.path.expanduser(checkpoint_path)) for f in fn
+    ]
+    return (len(checkpoint_dir_files) >= len(list_repo_files(self.repo_id))) and len(
         list_repo_files(self.repo_id)) > 0
 
   def fetch_labels(self, checkpoint_path: str):
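Mirroring the call site in `ModelUploader.download_checkpoints` above, the loader now returns a boolean, so callers can branch on failure instead of parsing printed output. A hedged usage sketch; the repo id and path are placeholders, and passing `token=None` for public repos is an assumption:

# Hedged usage sketch mirroring the ModelUploader call site above.
from clarifai.runners.utils.loader import HuggingFaceLoarder

loader = HuggingFaceLoarder(repo_id="owner/model-name", token=None)  # token=None assumed OK for public repos
if not loader.download_checkpoints("/path/to/1/checkpoints"):
  raise RuntimeError("checkpoint download failed or did not validate")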
@@ -2,7 +2,7 @@ import concurrent.futures
 
 import fsspec
 
-from .logging import logger
+from clarifai.utils.logging import logger
 
 
 def download_input(input):
@@ -20,8 +20,7 @@ except ImportError:
 try:
   from loguru import logger
 except ImportError:
-  from ..logging import get_logger
-  logger = get_logger(logger_level="INFO", name=__name__)
+  from ..logging import logger
 
 MACRO_AVG = "macro_avg"
 
@@ -26,8 +26,7 @@ except ImportError:
 try:
   from loguru import logger
 except ImportError:
-  from ..logging import get_logger
-  logger = get_logger(logger_level="INFO", name=__name__)
+  from ..logging import logger
 
 __all__ = ['EvalResultCompare']
 
clarifai/utils/logging.py CHANGED
@@ -19,6 +19,8 @@ from rich.tree import Tree
 
 install()
 
+# The default logger to use throughout the SDK is defined at bottom of this file.
+
 # For the json logger.
 JSON_LOGGER_NAME = "clarifai-json"
 JSON_LOG_KEY = 'msg'
@@ -357,3 +359,7 @@ class JsonFormatter(logging.Formatter):
         default=self.json_default,
         cls=self.json_cls,
     )
+
+
+# the default logger for the SDK.
+logger = get_logger(logger_level=os.environ.get("LOG_LEVEL", "INFO"), name="clarifai")
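This module-level instance is what every call site touched in this diff (user.py, workflow.py, rag.py, the runners) now imports, so the whole SDK shares one logger named "clarifai" whose level is controlled once via the LOG_LEVEL environment variable rather than per-module get_logger() calls:

# Minimal sketch: the shared SDK logger replacing per-module get_logger() calls.
# Set LOG_LEVEL (e.g. "DEBUG") in the environment before import to change verbosity.
from clarifai.utils.logging import logger

logger.info("hello from the shared 'clarifai' logger")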
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: clarifai
-Version: 10.8.6
+Version: 10.8.8
 Summary: Clarifai Python SDK
 Home-page: https://github.com/Clarifai/clarifai-python
 Author: Clarifai
@@ -21,7 +21,7 @@ Requires-Python: >=3.8
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: clarifai-grpc >=10.8.7
-Requires-Dist: clarifai-protocol >=0.0.4
+Requires-Dist: clarifai-protocol >=0.0.6
 Requires-Dist: numpy >=1.22.0
 Requires-Dist: tqdm >=4.65.0
 Requires-Dist: tritonclient >=2.34.0