clarifai 11.2.3rc7__py3-none-any.whl → 11.2.3rc9__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
clarifai/cli/nodepool.py CHANGED
@@ -1,32 +1,29 @@
  import click
+
  from clarifai.cli.base import cli
- from clarifai.client.compute_cluster import ComputeCluster
- from clarifai.utils.cli import display_co_resources, dump_yaml, from_yaml
+ from clarifai.utils.cli import (AliasedGroup, display_co_resources, dump_yaml, from_yaml,
+                                 validate_context)
 
 
- @cli.group(['nodepool', 'np'])
+ @cli.group(['nodepool', 'np'], cls=AliasedGroup)
  def nodepool():
    """Manage Nodepools: create, delete, list"""
-   pass
 
 
- @nodepool.command()
- @click.option(
-     '-cc_id',
-     '--compute_cluster_id',
-     required=False,
-     help='Compute Cluster ID for the compute cluster to interact with.')
+ @nodepool.command(['c'])
+ @click.argument('compute_cluster_id')
+ @click.argument('nodepool_id')
  @click.option(
      '--config',
      type=click.Path(exists=True),
      required=True,
      help='Path to the nodepool config file.')
- @click.option(
-     '-np_id', '--nodepool_id', required=False, help='New Nodepool ID for the nodepool to create.')
  @click.pass_context
- def create(ctx, compute_cluster_id, config, nodepool_id):
+ def create(ctx, compute_cluster_id, nodepool_id, config):
    """Create a new Nodepool with the given config file."""
+   from clarifai.client.compute_cluster import ComputeCluster
 
+   validate_context(ctx)
    nodepool_config = from_yaml(config)
    if not compute_cluster_id:
      if 'compute_cluster' not in nodepool_config['nodepool']:
@@ -42,50 +39,74 @@ def create(ctx, compute_cluster_id, config, nodepool_id):
 
    compute_cluster = ComputeCluster(
        compute_cluster_id=compute_cluster_id,
-       user_id=ctx.obj['user_id'],
-       pat=ctx.obj['pat'],
-       base_url=ctx.obj['base_url'])
+       user_id=ctx.obj.current.user_id,
+       pat=ctx.obj.current.pat,
+       base_url=ctx.obj.current.api_base)
    if nodepool_id:
      compute_cluster.create_nodepool(config, nodepool_id=nodepool_id)
    else:
      compute_cluster.create_nodepool(config)
 
 
- @nodepool.command()
- @click.option(
-     '-cc_id',
-     '--compute_cluster_id',
-     required=True,
-     help='Compute Cluster ID for the compute cluster to interact with.')
+ @nodepool.command(['ls'])
+ @click.argument('compute_cluster_id', default="")
  @click.option('--page_no', required=False, help='Page number to list.', default=1)
- @click.option('--per_page', required=False, help='Number of items per page.', default=16)
+ @click.option('--per_page', required=False, help='Number of items per page.', default=128)
  @click.pass_context
  def list(ctx, compute_cluster_id, page_no, per_page):
-   """List all nodepools for the user."""
+   """List all nodepools for the user across all compute clusters. If compute_cluster_id is provided
+   it will list only within that compute cluster. """
+   from clarifai.client.compute_cluster import ComputeCluster
+   from clarifai.client.user import User
 
-   compute_cluster = ComputeCluster(
-       compute_cluster_id=compute_cluster_id,
-       user_id=ctx.obj['user_id'],
-       pat=ctx.obj['pat'],
-       base_url=ctx.obj['base_url'])
-   response = compute_cluster.list_nodepools(page_no, per_page)
-   display_co_resources(response, "Nodepool")
+   validate_context(ctx)
 
+   cc_id = compute_cluster_id
 
- @nodepool.command()
- @click.option(
-     '-cc_id',
-     '--compute_cluster_id',
-     required=True,
-     help='Compute Cluster ID for the compute cluster to interact with.')
- @click.option('-np_id', '--nodepool_id', help='Nodepool ID of the user to delete.')
+   if cc_id:
+     compute_cluster = ComputeCluster(
+         compute_cluster_id=cc_id,
+         user_id=ctx.obj.current.user_id,
+         pat=ctx.obj.current.pat,
+         base_url=ctx.obj.current.api_base)
+     response = compute_cluster.list_nodepools(page_no, per_page)
+   else:
+     user = User(
+         user_id=ctx.obj.current.user_id,
+         pat=ctx.obj.current.pat,
+         base_url=ctx.obj.current.api_base)
+     ccs = user.list_compute_clusters(page_no, per_page)
+     response = []
+     for cc in ccs:
+       compute_cluster = ComputeCluster(
+           compute_cluster_id=cc.id,
+           user_id=ctx.obj.current.user_id,
+           pat=ctx.obj.current.pat,
+           base_url=ctx.obj.current.api_base)
+       response.extend([i for i in compute_cluster.list_nodepools(page_no, per_page)])
+
+   display_co_resources(
+       response,
+       custom_columns={
+           'ID': lambda c: c.id,
+           'USER_ID': lambda c: c.compute_cluster.user_id,
+           'COMPUTE_CLUSTER_ID': lambda c: c.compute_cluster.id,
+           'DESCRIPTION': lambda c: c.description,
+       })
+
+
+ @nodepool.command(['rm'])
+ @click.argument('compute_cluster_id')
+ @click.argument('nodepool_id')
  @click.pass_context
  def delete(ctx, compute_cluster_id, nodepool_id):
    """Deletes a nodepool for the user."""
+   from clarifai.client.compute_cluster import ComputeCluster
 
+   validate_context(ctx)
    compute_cluster = ComputeCluster(
        compute_cluster_id=compute_cluster_id,
-       user_id=ctx.obj['user_id'],
-       pat=ctx.obj['pat'],
-       base_url=ctx.obj['base_url'])
+       user_id=ctx.obj.current.user_id,
+       pat=ctx.obj.current.pat,
+       base_url=ctx.obj.current.api_base)
    compute_cluster.delete_nodepools([nodepool_id])
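
The nodepool commands now take their IDs as positional arguments (with c/ls/rm aliases registered through AliasedGroup) instead of the old -cc_id/-np_id options, and the client classes are imported lazily inside each command. A hedged sketch of the new invocations, driven through click's test runner; the cluster ID, nodepool ID, and config path are hypothetical:

    # Sketch only: exercises the new positional-argument CLI surface shown above.
    from click.testing import CliRunner

    from clarifai.cli.base import cli

    runner = CliRunner()
    # 'create' now takes both IDs positionally; --config is still required.
    runner.invoke(cli, ['nodepool', 'create', 'my-cluster', 'my-nodepool', '--config', 'nodepool.yaml'])
    # 'ls' makes the compute cluster ID optional; omitted, it lists across all clusters.
    runner.invoke(cli, ['np', 'ls'])
    # 'rm' replaces the old delete options with positional IDs.
    runner.invoke(cli, ['np', 'rm', 'my-cluster', 'my-nodepool'])
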
clarifai/client/app.py CHANGED
@@ -629,7 +629,7 @@ class App(Lister, BaseClient):
 
      Args:
          model_id (str): The model ID for the model to interact with.
-         model_version_id (str): The model version ID for the model version to interact with.
+         model_version (Dict): The model version ID for the model version to interact with.
 
      Returns:
          Model: A Model object for the existing model ID.
@@ -1,5 +1,4 @@
  import itertools
- import logging
  import time
  from concurrent.futures import ThreadPoolExecutor
 
@@ -8,7 +7,7 @@ from clarifai_grpc.grpc.api.status import status_code_pb2
 
  from clarifai.client.auth.helper import ClarifaiAuthHelper
  from clarifai.client.auth.register import RpcCallable, V2Stub
-
+ from clarifai.utils.logging import logger
  throttle_status_codes = {
      status_code_pb2.CONN_THROTTLED,
      status_code_pb2.CONN_EXCEED_HOURLY_LIMIT,
@@ -26,7 +25,7 @@ def validate_response(response, attempt, max_attempts):
    def handle_simple_response(response):
      if hasattr(response, 'status') and hasattr(response.status, 'code'):
        if (response.status.code in throttle_status_codes) and attempt < max_attempts:
-         logging.debug('Retrying with status %s' % str(response.status))
+         logger.debug('Retrying with status %s' % str(response.status))
          return None  # Indicates a retry is needed
        else:
          return response
@@ -42,7 +41,7 @@ def validate_response(response, attempt, max_attempts):
          return itertools.chain([validated_response], response)
        return None  # Indicates a retry is needed
      except grpc.RpcError as e:
-       logging.error('Error processing streaming response: %s' % str(e))
+       logger.error('Error processing streaming response: %s' % str(e))
        return None  # Indicates an error
    else:
      # Handle simple response validation
@@ -143,7 +142,7 @@ class _RetryRpcCallable(RpcCallable):
          return v
        except grpc.RpcError as e:
          if (e.code() in retry_codes_grpc) and attempt < self.max_attempts:
-           logging.debug('Retrying with status %s' % e.code())
+           logger.debug('Retrying with status %s' % e.code())
          else:
            raise
 
@@ -1,4 +1,3 @@
- import logging
  import os
  import time
  import uuid
@@ -354,7 +353,7 @@ class Dataset(Lister, BaseClient):
              break
        if failed_input_ids:
          retry_input_ids = [dataset_obj.all_input_ids[id] for id in failed_input_ids]
-         logging.warning(
+         logger.warning(
              f"Retrying upload for {len(failed_input_ids)} inputs in current batch: {retry_input_ids}\n"
          )
          failed_retrying_inputs, _, retry_response = self._upload_inputs_annotations(
@@ -494,7 +493,7 @@ class Dataset(Lister, BaseClient):
      add_file_handler(self.logger, f"Dataset_Upload{str(int(datetime.now().timestamp()))}.log")
 
    if retry_duplicates and duplicate_input_ids:
-     logging.warning(f"Retrying upload for {len(duplicate_input_ids)} duplicate inputs...\n")
+     logger.warning(f"Retrying upload for {len(duplicate_input_ids)} duplicate inputs...\n")
      duplicate_inputs_indexes = [input["Index"] for input in duplicate_input_ids]
      self.upload_dataset(
          dataloader=dataloader,
@@ -505,7 +504,7 @@ class Dataset(Lister, BaseClient):
 
    if failed_input_ids:
      #failed_inputs= ([input["Input_ID"] for input in failed_input_ids])
-     logging.warning(f"Retrying upload for {len(failed_input_ids)} failed inputs...\n")
+     logger.warning(f"Retrying upload for {len(failed_input_ids)} failed inputs...\n")
      failed_input_indexes = [input["Index"] for input in failed_input_ids]
      self.upload_dataset(
          dataloader=dataloader, log_retry_ids=failed_input_indexes, is_log_retry=True, **kwargs)
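
The stub and dataset hunks above apply one mechanical change: module-level calls to the stdlib logging module are routed through the package-wide logger object instead. A minimal sketch of the replacement pattern (the message text is illustrative):

    from clarifai.utils.logging import logger  # shared, pre-configured logger

    # Previously: import logging; logging.warning(...) against the root logger.
    logger.warning("Retrying upload for 3 failed inputs...")
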
@@ -59,12 +59,41 @@ class ModelClient:
      Returns:
        Dict: The method signatures.
      '''
-     #request = resources_pb2.GetModelSignaturesRequest()
-     #response = self.stub.GetModelSignatures(request)
-     #self._method_signatures = json.loads(response.signatures) # or define protos
-     # TODO this could use a new endpoint to get the signatures
-     # for local grpc models, we'll also have to add the endpoint to the model servicer
-     # for now we'll just use the predict endpoint with a special method name
+     try:
+       response = self.client.STUB.GetModelVersion(
+           service_pb2.GetModelVersionRequest(
+               user_app_id=self.request_template.user_app_id,
+               model_id= self.request_template.model_id,
+               version_id=self.request_template.model_version.id,
+           ))
+       method_signatures = None
+       if response.status.code == status_code_pb2.SUCCESS:
+         method_signatures = response.model_version.method_signatures
+       if response.status.code != status_code_pb2.SUCCESS:
+         raise Exception(f"Model failed with response {response!r}")
+       self._method_signatures= {}
+       for method_signature in method_signatures:
+         method_name = method_signature.name
+         # check for duplicate method names
+         if method_name in self._method_signatures:
+           raise ValueError(f"Duplicate method name {method_name}")
+         self._method_signatures[method_name] = method_signature
+       if not self._method_signatures:  # if no method signatures, try to fetch from the model
+         self._fetch_signatures_backup()
+     except Exception as e:
+       logger.info(f"Failed to fetch method signatures from model: {e}")
+       # try to fetch from the model
+       self._fetch_signatures_backup()
+       if not self._method_signatures:
+         raise ValueError("Failed to fetch method signatures from model and backup method")
+
+   def _fetch_signatures_backup(self):
+     '''
+     This is a temporary method of fetching the method signatures from the model.
+
+     Returns:
+       Dict: The method signatures.
+     '''
 
      request = service_pb2.PostModelOutputsRequest()
      request.CopyFrom(self.request_template)
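
Signature fetching now prefers the dedicated GetModelVersion endpoint and falls back to the older predict-based probe (_fetch_signatures_backup) when that call fails or returns nothing. A minimal, self-contained sketch of the same try-then-fallback shape; the fetcher callables here are stand-ins, not clarifai APIs:

    def fetch_signatures(primary, backup):
      # primary/backup are zero-arg callables returning a {method_name: signature} dict
      try:
        sigs = primary()
      except Exception:
        sigs = None
      if not sigs:
        sigs = backup()
      if not sigs:
        raise ValueError("Failed to fetch method signatures from model and backup method")
      return sigs

    def failing_primary():
      raise RuntimeError("endpoint unavailable")  # simulate a failed GetModelVersion call

    # Hypothetical usage: the primary endpoint errors out, the backup probe answers.
    fetch_signatures(failing_primary, lambda: {'predict': object()})
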
@@ -192,17 +192,21 @@ class ModelBuilder:
    def _validate_config_checkpoints(self):
      """
      Validates the checkpoints section in the config file.
+     return loader_type, repo_id, hf_token, when, allowed_file_patterns, ignore_file_patterns
      :return: loader_type the type of loader or None if no checkpoints.
      :return: repo_id location of checkpoint.
      :return: hf_token token to access checkpoint.
+     :return: when one of ['upload', 'build', 'runtime'] to download checkpoint
+     :return: allowed_file_patterns patterns to allow in downloaded checkpoint
+     :return: ignore_file_patterns patterns to ignore in downloaded checkpoint
      """
      if "checkpoints" not in self.config:
-       return None, None, None, DEFAULT_DOWNLOAD_CHECKPOINT_WHEN
+       return None, None, None, DEFAULT_DOWNLOAD_CHECKPOINT_WHEN, None, None
      assert "type" in self.config.get("checkpoints"), "No loader type specified in the config file"
      loader_type = self.config.get("checkpoints").get("type")
      if not loader_type:
        logger.info("No loader type specified in the config file for checkpoints")
-       return None, None, None
+       return None, None, None, DEFAULT_DOWNLOAD_CHECKPOINT_WHEN, None, None
      checkpoints = self.config.get("checkpoints")
      if 'when' not in checkpoints:
        logger.warn(
@@ -221,15 +225,30 @@ class ModelBuilder:
 
      # get from config.yaml otherwise fall back to HF_TOKEN env var.
      hf_token = self.config.get("checkpoints").get("hf_token", os.environ.get("HF_TOKEN", None))
-     return loader_type, repo_id, hf_token, when
+
+     allowed_file_patterns = self.config.get("checkpoints").get('allowed_file_patterns', None)
+     if isinstance(allowed_file_patterns, str):
+       allowed_file_patterns = [allowed_file_patterns]
+     ignore_file_patterns = self.config.get("checkpoints").get('ignore_file_patterns', None)
+     if isinstance(ignore_file_patterns, str):
+       ignore_file_patterns = [ignore_file_patterns]
+     return loader_type, repo_id, hf_token, when, allowed_file_patterns, ignore_file_patterns
 
    def _check_app_exists(self):
      resp = self.client.STUB.GetApp(service_pb2.GetAppRequest(user_app_id=self.client.user_app_id))
      if resp.status.code == status_code_pb2.SUCCESS:
        return True
+     if resp.status.code == status_code_pb2.CONN_KEY_INVALID:
+       logger.error(
+           f"Invalid PAT provided for user {self.client.user_app_id.user_id}. Please check your PAT and try again."
+       )
+       return False
      logger.error(
          f"Error checking API {self._base_api} for user app {self.client.user_app_id.user_id}/{self.client.user_app_id.app_id}. Error code: {resp.status.code}"
      )
+     logger.error(
+         f"App {self.client.user_app_id.app_id} not found for user {self.client.user_app_id.user_id}. Please create the app first and try again."
+     )
      return False
 
    def _validate_config_model(self):
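
The checkpoints section of config.yaml gains optional allowed_file_patterns / ignore_file_patterns keys; each accepts a single pattern or a list, and a bare string is normalized to a one-element list before being handed to the loader. A hedged sketch of the parsed config as _validate_config_checkpoints sees it; the repo id and patterns are hypothetical:

    # Equivalent of a config.yaml 'checkpoints:' block after YAML parsing:
    config = {
        "checkpoints": {
            "type": "huggingface",
            "repo_id": "owner/some-model",
            "when": "runtime",
            "allowed_file_patterns": ["*.safetensors", "*.json"],
            "ignore_file_patterns": "*.bin",  # normalized to ["*.bin"]
        }
    }
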
@@ -250,9 +269,6 @@ class ModelBuilder:
      assert model.get('id') != "", "model_id cannot be empty in the config file"
 
      if not self._check_app_exists():
-       logger.error(
-           f"App {self.client.user_app_id.app_id} not found for user {self.client.user_app_id.user_id}"
-       )
        sys.exit(1)
 
    def _validate_config(self):
@@ -266,7 +282,7 @@ class ModelBuilder:
        assert model_type_id in CONCEPTS_REQUIRED_MODEL_TYPE, f"Model type {model_type_id} not supported for concepts"
 
      if self.config.get("checkpoints"):
-       loader_type, _, hf_token, _ = self._validate_config_checkpoints()
+       loader_type, _, hf_token, _, _, _ = self._validate_config_checkpoints()
 
        if loader_type == "huggingface" and hf_token:
          is_valid_token = HuggingFaceLoader.validate_hftoken(hf_token)
@@ -282,7 +298,7 @@ class ModelBuilder:
            f"`num_threads` must be an integer greater than or equal to 1. Received type {type(num_threads)} with value {num_threads}."
        )
      else:
-       num_threads = int(os.environ.get("CLARIFAI_NUM_THREADS", 1))
+       num_threads = int(os.environ.get("CLARIFAI_NUM_THREADS", 16))
      self.config["num_threads"] = num_threads
 
    @staticmethod
@@ -354,8 +370,9 @@ class ModelBuilder:
 
      assert "model_type_id" in model, "model_type_id not found in the config file"
      assert "id" in model, "model_id not found in the config file"
-     assert "user_id" in model, "user_id not found in the config file"
-     assert "app_id" in model, "app_id not found in the config file"
+     if not self.download_validation_only:
+       assert "user_id" in model, "user_id not found in the config file"
+       assert "app_id" in model, "app_id not found in the config file"
 
      model_proto = json_format.ParseDict(model, resources_pb2.Model())
 
@@ -466,11 +483,12 @@ class ModelBuilder:
      # Sort in reverse so that newer cuda versions come first and are preferred.
      for image in sorted(AVAILABLE_TORCH_IMAGES, reverse=True):
        if torch_version in image and f'py{python_version}' in image:
-         cuda_version = image.split('-')[-1].replace('cuda', '')
+         # like cu124, rocm6.3, etc.
+         gpu_version = image.split('-')[-1]
          final_image = TORCH_BASE_IMAGE.format(
              torch_version=torch_version,
              python_version=python_version,
-             cuda_version=cuda_version,
+             gpu_version=gpu_version,
          )
          logger.info(f"Using Torch version {torch_version} base image to build the Docker image")
          break
@@ -547,8 +565,10 @@ class ModelBuilder:
      if not self.config.get("checkpoints"):
        logger.info("No checkpoints specified in the config file")
        return path
+     clarifai_model_type_id = self.config.get('model').get('model_type_id')
 
-     loader_type, repo_id, hf_token, when = self._validate_config_checkpoints()
+     loader_type, repo_id, hf_token, when, allowed_file_patterns, ignore_file_patterns = self._validate_config_checkpoints(
+     )
      if stage not in ["build", "upload", "runtime"]:
        raise Exception("Invalid stage provided, must be one of ['build', 'upload', 'runtime']")
      if when != stage:
@@ -557,14 +577,18 @@ class ModelBuilder:
        )
        return path
 
-     success = True
+     success = False
      if loader_type == "huggingface":
-       loader = HuggingFaceLoader(repo_id=repo_id, token=hf_token)
+       loader = HuggingFaceLoader(
+           repo_id=repo_id, token=hf_token, model_type_id=clarifai_model_type_id)
        # for runtime default to /tmp path
        if stage == "runtime" and checkpoint_path_override is None:
          checkpoint_path_override = self.default_runtime_checkpoint_path()
        path = checkpoint_path_override if checkpoint_path_override else self.checkpoint_path
-       success = loader.download_checkpoints(path)
+       success = loader.download_checkpoints(
+           path,
+           allowed_file_patterns=allowed_file_patterns,
+           ignore_file_patterns=ignore_file_patterns)
 
      if loader_type:
        if not success:
@@ -598,53 +622,12 @@ class ModelBuilder:
      concepts = config.get('concepts')
      logger.info(f"Updated config.yaml with {len(concepts)} concepts.")
 
-   def filled_params_specs_with_inference_params(self, method_signatures: list[resources_pb2.MethodSignature]) -> list[resources_pb2.ModelTypeField]:
-     """
-     Fills the params_specs with the inference params.
-     """
-     inference_params = set()
-     for i, signature in enumerate(method_signatures):
-       for field in signature.input_fields:
-         if field.is_param:
-           if i==0:
-             inference_params.add(field.name)
-           else:
-             # if field.name not in inference_params then remove from inference_params
-             if field.name not in inference_params:
-               inference_params.remove(field.name)
-     output=[]
-     for signature in method_signatures:
-       for field in signature.input_fields:
-         if field.is_param and field.name in inference_params:
-           field.path = field.name
-           if field.type == resources_pb2.ModelTypeField.DataType.STR:
-             field.default_value= str(field.default)
-             field.field_type = resources_pb2.ModelTypeField.ModelTypeFieldType.STRING
-           elif field.type == resources_pb2.ModelTypeField.DataType.INT:
-             field.default_value= int(field.default)
-             field.field_type = resources_pb2.ModelTypeField.ModelTypeFieldType.NUMBER
-           elif field.type == resources_pb2.ModelTypeField.DataType.FLOAT:
-             field.default_value= float(field.default)
-             field.field_type = resources_pb2.ModelTypeField.ModelTypeFieldType.NUMBER
-           elif field.type == resources_pb2.ModelTypeField.DataType.BOOL:
-             field.default_value= bool(field.default)
-             field.field_type = resources_pb2.ModelTypeField.ModelTypeFieldType.BOOLEAN
-           else:
-             field.default_value= field.default
-             field.field_type = resources_pb2.ModelTypeField.ModelTypeFieldType.STRING
-           output.append(field)
-     return output
-
-
    def get_model_version_proto(self):
      signatures = self.get_method_signatures()
      model_version_proto = resources_pb2.ModelVersion(
          pretrained_model_config=resources_pb2.PretrainedModelConfig(),
          inference_compute_info=self.inference_compute_info,
          method_signatures=signatures,
-         # output_info= resources_pb2.OutputInfo(
-         #     params_specs=self.filled_params_specs_with_inference_params(signatures),
-         # )
      )
 
      model_type_id = self.config.get('model').get('model_type_id')
@@ -678,7 +661,7 @@ class ModelBuilder:
      logger.debug(f"Will tar it into file: {file_path}")
 
      model_type_id = self.config.get('model').get('model_type_id')
-     loader_type, repo_id, hf_token, when = self._validate_config_checkpoints()
+     loader_type, repo_id, hf_token, when, _, _ = self._validate_config_checkpoints()
 
      if (model_type_id in CONCEPTS_REQUIRED_MODEL_TYPE) and 'concepts' not in self.config:
        logger.info(
@@ -726,7 +709,7 @@ class ModelBuilder:
      # First check for the env variable, then try querying huggingface. If all else fails, use the default.
      checkpoint_size = os.environ.get('CHECKPOINT_SIZE_BYTES', 0)
      if not checkpoint_size:
-       _, repo_id, _, _ = self._validate_config_checkpoints()
+       _, repo_id, _, _, _, _ = self._validate_config_checkpoints()
        checkpoint_size = HuggingFaceLoader.get_huggingface_checkpoint_total_size(repo_id)
      if not checkpoint_size:
        checkpoint_size = self.DEFAULT_CHECKPOINT_SIZE
@@ -30,7 +30,8 @@ class ModelClass(ABC):
    Example:
 
      from clarifai.runners.model_class import ModelClass
-     from clarifai.runners.utils.data_types import NamedFields, Stream
+     from clarifai.runners.utils.data_types import NamedFields
+     from typing import List, Iterator
 
      class MyModel(ModelClass):
 
@@ -39,12 +40,12 @@ class ModelClass(ABC):
          return [x] * y
 
        @ModelClass.method
-       def generate(self, x: str, y: int) -> Stream[str]:
+       def generate(self, x: str, y: int) -> Iterator[str]:
          for i in range(y):
            yield x + str(i)
 
        @ModelClass.method
-       def stream(self, input_stream: Stream[NamedFields(x=str, y=int)]) -> Stream[str]:
+       def stream(self, input_stream: Iterator[NamedFields(x=str, y=int)]) -> Iterator[str]:
          for item in input_stream:
            yield item.x + ' ' + str(item.y)
    '''
@@ -270,8 +271,8 @@ class ModelClass(ABC):
        if k not in python_param_types:
          continue
 
-       if hasattr(python_param_types[k], "__args__") and getattr(
-           python_param_types[k], "__origin__", None) == data_types.Stream:
+       if hasattr(python_param_types[k], "__args__") and getattr(python_param_types[k],
+                                                                 "__origin__", None) == Iterator:
          # get the type of the items in the stream
          stream_type = python_param_types[k].__args__[0]
 
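
With the custom Stream generic removed, streaming methods are annotated with typing's Iterator, and stream detection keys off the annotation's __origin__. A small standard-library illustration of the metadata that check relies on:

    import collections.abc
    from typing import Iterator, get_args, get_origin

    ann = Iterator[str]
    assert get_origin(ann) is collections.abc.Iterator  # the runtime origin class
    assert get_args(ann) == (str,)  # the item type of the stream
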
@@ -11,7 +11,6 @@ import traceback
  import venv
 
  from clarifai_grpc.grpc.api import resources_pb2, service_pb2
-
  from clarifai.runners.models.model_builder import ModelBuilder
  from clarifai.utils.logging import logger
 
@@ -104,11 +103,6 @@ class ModelRunLocally:
        ],
    )
 
-   def _build_stream_request(self):
-     request = self._build_request()
-     for i in range(1):
-       yield request
-
    def _run_test(self):
      """Test the model locally by making a prediction."""
      # Create the model
@@ -404,39 +398,33 @@ def main(model_path,
           inside_container=False,
           port=8080,
           keep_env=False,
-          keep_image=False):
+          keep_image=False,
+          skip_dockerfile: bool = False):
 
-   if not os.environ.get("CLARIFAI_PAT", None):
-     logger.error(
-         "CLARIFAI_PAT environment variable is not set! Please set your PAT in the 'CLARIFAI_PAT' environment variable."
-     )
-     sys.exit(1)
    manager = ModelRunLocally(model_path)
    # get whatever stage is in config.yaml to force download now
    # also always write to where upload/build wants to, not the /tmp folder that runtime stage uses
-   _, _, _, when = manager.builder._validate_config_checkpoints()
+   _, _, _, when, _, _ = manager.builder._validate_config_checkpoints()
    manager.builder.download_checkpoints(
        stage=when, checkpoint_path_override=manager.builder.checkpoint_path)
    if inside_container:
      if not manager.is_docker_installed():
        sys.exit(1)
-     manager.builder.create_dockerfile()
+     if not skip_dockerfile:
+       manager.builder.create_dockerfile()
      image_tag = manager._docker_hash()
-     image_name = f"{manager.config['model']['id']}:{image_tag}"
-     container_name = manager.config['model']['id']
+     model_id = manager.config['model']['id'].lower()
+     # must be in lowercase
+     image_name = f"{model_id}:{image_tag}"
+     container_name = model_id
      if not manager.docker_image_exists(image_name):
        manager.build_docker_image(image_name=image_name)
      try:
-       envs = {
-           'CLARIFAI_PAT': os.environ['CLARIFAI_PAT'],
-           'CLARIFAI_API_BASE': os.environ.get('CLARIFAI_API_BASE', 'https://api.clarifai.com')
-       }
        if run_model_server:
          manager.run_docker_container(
-             image_name=image_name, container_name=container_name, port=port, env_vars=envs)
+             image_name=image_name, container_name=container_name, port=port)
        else:
-         manager.test_model_container(
-             image_name=image_name, container_name=container_name, env_vars=envs)
+         manager.test_model_container(image_name=image_name, container_name=container_name)
      finally:
        if manager.container_exists(container_name):
          manager.stop_docker_container(container_name)
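
The image and container names are now derived from the lowercased model id because Docker rejects uppercase repository names; the hard CLARIFAI_PAT check and env-var forwarding are dropped from this entry point. A one-line illustration of the naming constraint (the id and tag are hypothetical):

    # "MyModel:abc123" is an invalid Docker image reference; lowercase the repo part first.
    model_id = "MyModel".lower()  # -> "mymodel"
    image_name = f"{model_id}:abc123"
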
@@ -2,10 +2,10 @@ import os
 
  registry = os.environ.get('CLARIFAI_BASE_IMAGE_REGISTRY', 'public.ecr.aws/clarifai-models')
 
- GIT_SHA = "df565436eea93efb3e8d1eb558a0a46df29523ec"
+ GIT_SHA = "b8ae56bf3b7c95e686ca002b07ca83d259c716eb"
 
  PYTHON_BASE_IMAGE = registry + '/python-base:{python_version}-' + GIT_SHA
- TORCH_BASE_IMAGE = registry + '/torch:{torch_version}-py{python_version}-cuda{cuda_version}-' + GIT_SHA
+ TORCH_BASE_IMAGE = registry + '/torch:{torch_version}-py{python_version}-{gpu_version}-' + GIT_SHA
 
  # List of available python base images
  AVAILABLE_PYTHON_IMAGES = ['3.11', '3.12']
@@ -21,12 +21,13 @@ DEFAULT_RUNTIME_DOWNLOAD_PATH = os.path.join(os.sep, "tmp", ".cache")
  # List of available torch images
  # Keep sorted by most recent cuda version.
  AVAILABLE_TORCH_IMAGES = [
-     '2.4.1-py3.11-cuda124',
-     '2.5.1-py3.11-cuda124',
-     '2.4.1-py3.12-cuda124',
-     '2.5.1-py3.12-cuda124',
-     # '2.4.1-py3.13-cuda124',
-     # '2.5.1-py3.13-cuda124',
+     '2.4.1-py3.11-cu124',
+     '2.5.1-py3.11-cu124',
+     '2.4.1-py3.12-cu124',
+     '2.5.1-py3.12-cu124',
+     '2.6.0-py3.12-cu126',
+     '2.7.0-py3.12-cu128',
+     '2.7.0-py3.12-rocm6.3',
  ]
  CONCEPTS_REQUIRED_MODEL_TYPE = [
      'visual-classifier', 'visual-detector', 'visual-segmenter', 'text-classifier'
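
The torch base-image tag suffix is no longer assumed to be a CUDA version: the final dash-separated token (cu124, rocm6.3, ...) is passed through verbatim as gpu_version. A small sketch of the new parsing against one of the tags above; GITSHA stands in for the pinned GIT_SHA:

    registry = 'public.ecr.aws/clarifai-models'
    TORCH_BASE_IMAGE = registry + '/torch:{torch_version}-py{python_version}-{gpu_version}-' + 'GITSHA'

    image = '2.7.0-py3.12-rocm6.3'
    gpu_version = image.split('-')[-1]  # -> 'rocm6.3'; the old code stripped a 'cuda' prefix instead
    print(TORCH_BASE_IMAGE.format(torch_version='2.7.0', python_version='3.12', gpu_version=gpu_version))
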
@@ -83,10 +83,6 @@ class NamedFields(metaclass=NamedFieldsMeta):
      return list(self.keys())
 
 
- class Stream(Iterable):
-   pass
-
-
  class JSON:
 
    def __init__(self, value):
@@ -221,7 +221,7 @@ class DataConverter:
        new_data.video.CopyFrom(old_data.video)
        return new_data
      elif data_type == resources_pb2.ModelTypeField.DataType.BOOL:
-       if old_data.bool_value != False:
+       if old_data.bool_value is not False:
          new_data.bool_value = old_data.bool_value
        return new_data
      elif data_type == resources_pb2.ModelTypeField.DataType.INT: