clarifai 11.5.2__py3-none-any.whl → 11.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. clarifai/__init__.py +1 -1
  2. clarifai/cli/model.py +33 -1
  3. clarifai/cli/pipeline.py +137 -0
  4. clarifai/cli/pipeline_step.py +104 -0
  5. clarifai/cli/templates/__init__.py +1 -0
  6. clarifai/cli/templates/pipeline_step_templates.py +64 -0
  7. clarifai/cli/templates/pipeline_templates.py +150 -0
  8. clarifai/client/auth/helper.py +23 -0
  9. clarifai/client/auth/register.py +5 -0
  10. clarifai/client/auth/stub.py +116 -12
  11. clarifai/client/base.py +9 -0
  12. clarifai/client/model.py +111 -7
  13. clarifai/client/model_client.py +355 -6
  14. clarifai/client/user.py +81 -0
  15. clarifai/runners/models/model_builder.py +52 -9
  16. clarifai/runners/pipeline_steps/__init__.py +0 -0
  17. clarifai/runners/pipeline_steps/pipeline_step_builder.py +510 -0
  18. clarifai/runners/pipelines/__init__.py +0 -0
  19. clarifai/runners/pipelines/pipeline_builder.py +313 -0
  20. clarifai/runners/utils/code_script.py +40 -7
  21. clarifai/runners/utils/const.py +2 -2
  22. clarifai/runners/utils/model_utils.py +135 -0
  23. clarifai/runners/utils/pipeline_validation.py +153 -0
  24. {clarifai-11.5.2.dist-info → clarifai-11.5.3.dist-info}/METADATA +1 -1
  25. {clarifai-11.5.2.dist-info → clarifai-11.5.3.dist-info}/RECORD +30 -19
  26. /clarifai/cli/{model_templates.py → templates/model_templates.py} +0 -0
  27. {clarifai-11.5.2.dist-info → clarifai-11.5.3.dist-info}/WHEEL +0 -0
  28. {clarifai-11.5.2.dist-info → clarifai-11.5.3.dist-info}/entry_points.txt +0 -0
  29. {clarifai-11.5.2.dist-info → clarifai-11.5.3.dist-info}/licenses/LICENSE +0 -0
  30. {clarifai-11.5.2.dist-info → clarifai-11.5.3.dist-info}/top_level.txt +0 -0
clarifai/client/model_client.py CHANGED
@@ -1,3 +1,4 @@
+ import asyncio
  import inspect
  import time
  from typing import Any, Dict, Iterator, List
@@ -5,6 +6,7 @@ from typing import Any, Dict, Iterator, List
  from clarifai_grpc.grpc.api import resources_pb2, service_pb2
  from clarifai_grpc.grpc.api.status import status_code_pb2

+ from clarifai.client.auth.register import V2Stub
  from clarifai.constants.model import MAX_MODEL_PREDICT_INPUTS
  from clarifai.errors import UserError
  from clarifai.runners.utils import code_script, method_signatures
@@ -20,12 +22,31 @@ from clarifai.utils.logging import logger
  from clarifai.utils.misc import BackoffIterator, status_is_retryable


+ def is_async_context():
+     """Check if code is running in an async context."""
+     try:
+         asyncio.get_running_loop()
+         import sys
+
+         # In Jupyter, verify we are actually in an async cell: ipykernel always
+         # has a running event loop, so a running loop alone proves nothing there.
+         if 'ipykernel' in sys.modules:
+             return False
+         return True
+     except RuntimeError:
+         return False
+
+
  class ModelClient:
      '''
      Client for calling model predict, generate, and stream methods.
      '''

-     def __init__(self, stub, request_template: service_pb2.PostModelOutputsRequest = None):
+     def __init__(
+         self,
+         stub,
+         async_stub: V2Stub = None,
+         request_template: service_pb2.PostModelOutputsRequest = None,
+     ):
          '''
          Initialize the model client.

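A quick sketch of the new helper's behavior (illustrative only; assumes clarifai 11.5.3 is installed). `is_async_context()` reports True only when an event loop is already running, with a carve-out for ipykernel, which always runs one:

    import asyncio

    from clarifai.client.model_client import is_async_context

    def sync_caller():
        # No running loop here: asyncio.get_running_loop() raises RuntimeError.
        print(is_async_context())  # False

    async def async_caller():
        # Inside asyncio.run() a loop is running, so this returns True
        # (unless running under ipykernel).
        print(is_async_context())  # True

    sync_caller()
    asyncio.run(async_caller())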
@@ -35,6 +56,7 @@ class ModelClient:
          common fields like model_id, model_version, cluster, etc.
          '''
          self.STUB = stub
+         self.async_stub = async_stub
          self.request_template = request_template or service_pb2.PostModelOutputsRequest()
          self._method_signatures = None
          self._defined = False
@@ -137,16 +159,19 @@ class ModelClient:
          # define the function in this client instance
          if resources_pb2.RunnerMethodType.Name(method_signature.method_type) == 'UNARY_UNARY':
              call_func = self._predict
+             async_call_func = self._async_predict
          elif (
              resources_pb2.RunnerMethodType.Name(method_signature.method_type)
              == 'UNARY_STREAMING'
          ):
              call_func = self._generate
+             async_call_func = self._async_generate
          elif (
              resources_pb2.RunnerMethodType.Name(method_signature.method_type)
              == 'STREAMING_STREAMING'
          ):
              call_func = self._stream
+             async_call_func = self._async_stream
          else:
              raise ValueError(f"Unknown method type {method_signature.method_type}")

@@ -158,8 +183,8 @@ class ModelClient:
                  continue
              method_argnames.append(outer)

-         def bind_f(method_name, method_argnames, call_func):
-             def f(*args, **kwargs):
+         def bind_f(method_name, method_argnames, call_func, async_call_func):
+             def sync_f(*args, **kwargs):
                  if len(args) > len(method_argnames):
                      raise TypeError(
                          f"{method_name}() takes {len(method_argnames)} positional arguments but {len(args)} were given"
@@ -187,10 +212,45 @@ class ModelClient:
                      kwargs[name] = arg
                  return call_func(kwargs, method_name)

-             return f
+             async def async_f(*args, **kwargs):
+                 # Async counterpart of sync_f: same validation, but dispatches
+                 # to the async call function.
+                 if len(args) > len(method_argnames):
+                     raise TypeError(
+                         f"{method_name}() takes {len(method_argnames)} positional arguments but {len(args)} were given"
+                     )
+                 if len(args) + len(kwargs) > len(method_argnames):
+                     raise TypeError(
+                         f"{method_name}() got an unexpected keyword argument {next(iter(kwargs))}"
+                     )
+                 if len(args) == 1 and (not kwargs) and isinstance(args[0], list):
+                     batch_inputs = args[0]
+                     # Validate each input is a dictionary
+                     is_batch_input_valid = all(
+                         isinstance(input, dict) for input in batch_inputs
+                     )
+                     if is_batch_input_valid and (not is_openai_chat_format(batch_inputs)):
+                         # If the batch input is valid, call the function with the batch inputs and the method name
+                         return async_call_func(batch_inputs, method_name)
+
+                 for name, arg in zip(
+                     method_argnames, args
+                 ):  # handle positional with zip shortest
+                     if name in kwargs:
+                         raise TypeError(f"Multiple values for argument {name}")
+                     kwargs[name] = arg
+
+                 return async_call_func(kwargs, method_name)
+
+             class MethodWrapper:
+                 def __call__(self, *args, **kwargs):
+                     if is_async_context():
+                         return async_f(*args, **kwargs)
+                     return sync_f(*args, **kwargs)
+
+             return MethodWrapper()

          # need to bind method_name to the value, not the mutating loop variable
-         f = bind_f(method_name, method_argnames, call_func)
+         f = bind_f(method_name, method_argnames, call_func, async_call_func)

          # set names, annotations and docstrings
          f.__name__ = method_name
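The net effect of `MethodWrapper` is that one bound method name serves both worlds: called with no running event loop it executes synchronously, called inside a coroutine it returns an awaitable. A minimal sketch; the model URL, PAT, and `prompt` argument are placeholders that depend on the target model's signature:

    import asyncio

    from clarifai.client.model import Model

    model = Model(url="https://clarifai.com/<user_id>/<app_id>/models/<model_id>", pat="<YOUR_PAT>")

    # Sync context: MethodWrapper.__call__ finds no running loop and runs sync_f.
    print(model.predict(prompt="Hello"))

    async def main():
        # Async context: the same attribute returns async_f's coroutine; await it.
        print(await model.predict(prompt="Hello"))

    asyncio.run(main())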
@@ -231,7 +291,11 @@ class ModelClient:
          self.fetch()
          return method_signatures.get_method_signature(self._method_signatures[method_name])

-     def generate_client_script(self) -> str:
+     def generate_client_script(
+         self,
+         base_url: str = None,
+         use_ctx: bool = False,
+     ) -> str:
          """Generate a client script for this model.

          Returns:
@@ -247,6 +311,11 @@ class ModelClient:
              user_id=self.request_template.user_app_id.user_id,
              app_id=self.request_template.user_app_id.app_id,
              model_id=self.request_template.model_id,
+             base_url=base_url,
+             deployment_id=self.request_template.runner_selector.deployment.id,
+             compute_cluster_id=self.request_template.runner_selector.nodepool.compute_cluster.id,
+             nodepool_id=self.request_template.runner_selector.nodepool.id,
+             use_ctx=use_ctx,
          )

      def _define_compatability_functions(self):
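A hedged sketch of the two new parameters, assuming you hold an initialized `ModelClient`; judging by the names, `base_url` points the generated snippet at a non-default API endpoint and `use_ctx` makes it read credentials from the current CLI context:

    # `model_client` is assumed to be an initialized ModelClient instance.
    script = model_client.generate_client_script(
        base_url="https://api.clarifai.com",
        use_ctx=True,
    )
    print(script)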
@@ -354,6 +423,109 @@ class ModelClient:
                  break
          return response

+     async def _async_predict(
+         self,
+         inputs,
+         method_name: str = 'predict',
+     ) -> Any:
+         """Asynchronously process inputs and make predictions.
+
+         Args:
+             inputs: Input data to process
+             method_name (str): Name of the method to call
+
+         Returns:
+             Processed prediction results
+         """
+         # method_name defaults to 'predict' so the async path mirrors the
+         # input/output signature behaviour of the sync predict.
+         input_signature = self._method_signatures[method_name].input_fields
+         output_signature = self._method_signatures[method_name].output_fields
+
+         batch_input = True
+         if isinstance(inputs, dict):
+             inputs = [inputs]
+             batch_input = False
+
+         proto_inputs = []
+         for input in inputs:
+             proto = resources_pb2.Input()
+             serialize(input, input_signature, proto.data)
+             proto_inputs.append(proto)
+         response = await self._async_predict_by_proto(proto_inputs, method_name)
+         outputs = []
+         for output in response.outputs:
+             outputs.append(deserialize(output.data, output_signature, is_output=True))
+
+         return outputs if batch_input else outputs[0]
+
+     async def _async_predict_by_proto(
+         self,
+         inputs: List[resources_pb2.Input],
+         method_name: str = None,
+         inference_params: Dict = None,
+         output_config: Dict = None,
+     ) -> service_pb2.MultiOutputResponse:
+         """Asynchronously predicts the model based on the given inputs.
+
+         Args:
+             inputs (List[resources_pb2.Input]): The inputs to predict.
+             method_name (str): The remote method name to call.
+             inference_params (Dict): Inference parameters to override.
+             output_config (Dict): Output configuration to override.
+
+         Returns:
+             service_pb2.MultiOutputResponse: The prediction response(s).
+
+         Raises:
+             UserError: If inputs are invalid or exceed maximum limit.
+             Exception: If the model prediction fails.
+         """
+         if not isinstance(inputs, list):
+             raise UserError('Invalid inputs, inputs must be a list of Input objects.')
+         if len(inputs) > MAX_MODEL_PREDICT_INPUTS:
+             raise UserError(f"Too many inputs. Max is {MAX_MODEL_PREDICT_INPUTS}.")
+
+         request = service_pb2.PostModelOutputsRequest()
+         request.CopyFrom(self.request_template)
+         request.inputs.extend(inputs)
+
+         if method_name:
+             for inp in request.inputs:
+                 inp.data.metadata['_method_name'] = method_name
+         if inference_params:
+             request.model.model_version.output_info.params.update(inference_params)
+         if output_config:
+             request.model.model_version.output_info.output_config.MergeFrom(
+                 resources_pb2.OutputConfig(**output_config)
+             )
+
+         start_time = time.time()
+         backoff_iterator = BackoffIterator(10)
+
+         while True:
+             try:
+                 response = await self.async_stub.PostModelOutputs(request)
+
+                 if (
+                     status_is_retryable(response.status.code)
+                     and time.time() - start_time < 60 * 10
+                 ):  # 10 minutes
+                     logger.info("Model is still deploying, please wait...")
+                     await asyncio.sleep(next(backoff_iterator))
+                     continue
+
+                 if response.status.code != status_code_pb2.SUCCESS:
+                     raise Exception(f"Model predict failed with response {response!r}")
+
+                 return response
+
+             except Exception as e:
+                 if time.time() - start_time >= 60 * 10:  # 10 minutes timeout
+                     raise Exception("Model prediction timed out after 10 minutes") from e
+                 logger.error(f"Error during prediction: {e}")
+                 await asyncio.sleep(next(backoff_iterator))
+                 continue
+
      def _generate(
          self,
          inputs,  # TODO set up functions according to fetched signatures?
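Calling the async predict path end to end, per the dispatch above (a sketch; URL, PAT, and `prompt` are placeholders):

    import asyncio

    from clarifai.client.model import Model

    model = Model(url="https://clarifai.com/<user_id>/<app_id>/models/<model_id>", pat="<YOUR_PAT>")

    async def main():
        # In an async context the bound method dispatches to _async_predict, which
        # serializes kwargs into Input protos and awaits _async_predict_by_proto.
        output = await model.predict(prompt="What is Clarifai?")
        print(output)

    asyncio.run(main())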
@@ -444,6 +616,99 @@ class ModelClient:
                  raise Exception(f"Model Generate failed with response {response.status!r}")
              yield response

+     async def _async_generate(
+         self,
+         inputs,
+         method_name: str = 'generate',
+     ) -> Any:
+         # method_name defaults to 'generate' so the async path mirrors the
+         # input/output signature behaviour of the sync generate.
+         input_signature = self._method_signatures[method_name].input_fields
+         output_signature = self._method_signatures[method_name].output_fields
+
+         batch_input = True
+         if isinstance(inputs, dict):
+             inputs = [inputs]
+             batch_input = False
+
+         proto_inputs = []
+         for input in inputs:
+             proto = resources_pb2.Input()
+             serialize(input, input_signature, proto.data)
+             proto_inputs.append(proto)
+         response_stream = self._async_generate_by_proto(proto_inputs, method_name)
+
+         async for response in response_stream:
+             outputs = []
+             for output in response.outputs:
+                 outputs.append(deserialize(output.data, output_signature, is_output=True))
+             if batch_input:
+                 yield outputs
+             else:
+                 yield outputs[0]
+
+     async def _async_generate_by_proto(
+         self,
+         inputs: List[resources_pb2.Input],
+         method_name: str = None,
+         inference_params: Dict = {},
+         output_config: Dict = {},
+     ):
+         """Generate the async stream output on model based on the given inputs.
+
+         Args:
+             inputs (list[Input]): The inputs to generate, must be less than 128.
+             method_name (str): The remote method name to call.
+             inference_params (dict): The inference params to override.
+             output_config (dict): The output config to override.
+         """
+         if not isinstance(inputs, list):
+             raise UserError('Invalid inputs, inputs must be a list of Input objects.')
+         if len(inputs) > MAX_MODEL_PREDICT_INPUTS:
+             raise UserError(
+                 f"Too many inputs. Max is {MAX_MODEL_PREDICT_INPUTS}."
+             )  # TODO Use Chunker for inputs len > 128
+
+         request = service_pb2.PostModelOutputsRequest()
+         request.CopyFrom(self.request_template)
+
+         request.inputs.extend(inputs)
+
+         if method_name:
+             # TODO put in new proto field?
+             for inp in request.inputs:
+                 inp.data.metadata['_method_name'] = method_name
+         if inference_params:
+             request.model.model_version.output_info.params.update(inference_params)
+         if output_config:
+             request.model.model_version.output_info.output_config.MergeFromDict(output_config)
+
+         start_time = time.time()
+         backoff_iterator = BackoffIterator(10)
+         started = False
+         while not started:
+             # stream response returns gRPC async iterable - UnaryStreamCall
+             stream_response = self.async_stub.GenerateModelOutputs(request)
+             stream_resp = await stream_response  # get the async iterable
+             iterator = stream_resp.__aiter__()  # get the async iterator for the response
+             try:
+                 response = await iterator.__anext__()  # getting the first response
+             except StopAsyncIteration:
+                 raise Exception("Model Generate failed with no response")
+             if status_is_retryable(response.status.code) and time.time() - start_time < 60 * 10:
+                 logger.info("Model is still deploying, please wait...")
+                 await asyncio.sleep(next(backoff_iterator))
+                 continue
+             if response.status.code != status_code_pb2.SUCCESS:
+                 raise Exception(f"Model Generate failed with response {response.status!r}")
+             started = True
+
+             yield response  # yield the first response
+
+             async for response in iterator:
+                 if response.status.code != status_code_pb2.SUCCESS:
+                     raise Exception(f"Model Generate failed with response {response.status!r}")
+                 yield response
+
      def _stream(
          self,
          inputs,
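`_async_generate` is an async generator, so streaming consumption uses `async for` (a sketch with placeholder identifiers):

    import asyncio

    from clarifai.client.model import Model

    model = Model(url="https://clarifai.com/<user_id>/<app_id>/models/<model_id>", pat="<YOUR_PAT>")

    async def main():
        # Each response from _async_generate_by_proto is deserialized and yielded
        # as soon as the server sends it.
        async for chunk in model.generate(prompt="Tell me a story"):
            print(chunk, end="", flush=True)

    asyncio.run(main())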
@@ -551,3 +816,87 @@ class ModelClient:
                      if not generation_started:
                          generation_started = True
                      yield response
+
+     # TODO: Test async streaming.
+     async def _async_stream(
+         self,
+         inputs,
+         method_name: str = 'stream',
+     ) -> Any:
+         # method_name defaults to 'stream' so the async path mirrors the
+         # input/output signature behaviour of the sync stream.
+         input_signature = self._method_signatures[method_name].input_fields
+         output_signature = self._method_signatures[method_name].output_fields
+
+         if isinstance(inputs, list):
+             assert len(inputs) == 1, 'streaming methods do not support batched calls'
+             inputs = inputs[0]
+         assert isinstance(inputs, dict)
+         kwargs = inputs
+
+         # find the streaming vars in the input signature, and the streaming input python param
+         stream_sig = get_stream_from_signature(input_signature)
+         if stream_sig is None:
+             raise ValueError("Streaming method must have a Stream input")
+         stream_argname = stream_sig.name
+
+         # get the streaming input generator from the user-provided function arg values
+         user_inputs_generator = kwargs.pop(stream_argname)
+
+         async def _input_proto_stream():
+             # first item contains all the inputs and the first stream item
+             proto = resources_pb2.Input()
+             try:
+                 item = await user_inputs_generator.__anext__()
+             except StopAsyncIteration:
+                 return  # no items to stream
+             kwargs[stream_argname] = item
+             serialize(kwargs, input_signature, proto.data)
+
+             yield proto
+
+             # subsequent items are just the stream items
+             async for item in user_inputs_generator:
+                 proto = resources_pb2.Input()
+                 serialize({stream_argname: item}, [stream_sig], proto.data)
+                 yield proto
+
+         # _async_stream_by_proto is an async generator, so it is iterated, not awaited
+         response_stream = self._async_stream_by_proto(_input_proto_stream(), method_name)
+
+         async for response in response_stream:
+             assert len(response.outputs) == 1, 'streaming methods must have exactly one output'
+             yield deserialize(response.outputs[0].data, output_signature, is_output=True)
+
+     async def _async_stream_by_proto(
+         self,
+         inputs: Iterator[List[resources_pb2.Input]],
+         method_name: str = None,
+         inference_params: Dict = {},
+         output_config: Dict = {},
+     ):
+         """Generate the async stream output on model based on the given stream of inputs."""
+         # if not isinstance(inputs, Iterator[List[Input]]):
+         #     raise UserError('Invalid inputs, inputs must be a iterator of list of Input objects.')
+
+         request = self._req_iterator(inputs, method_name, inference_params, output_config)
+
+         start_time = time.time()
+         backoff_iterator = BackoffIterator(10)
+         generation_started = False
+         while True:
+             if generation_started:
+                 break
+             stream_response = await self.async_stub.StreamModelOutputs(request)
+             async for response in stream_response:
+                 if (
+                     status_is_retryable(response.status.code)
+                     and time.time() - start_time < 60 * 10
+                 ):
+                     logger.info("Model is still deploying, please wait...")
+                     await asyncio.sleep(next(backoff_iterator))
+                     break
+                 if response.status.code != status_code_pb2.SUCCESS:
+                     raise Exception(f"Model Predict failed with response {response.status!r}")
+                 else:
+                     if not generation_started:
+                         generation_started = True
+                     yield response
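The TODO above notes async streaming is untested, so treat this as a tentative sketch. The streaming argument name (`input_stream` below) is hypothetical; it must match the Stream field in the target model's signature, and its value is an async generator:

    import asyncio

    from clarifai.client.model import Model

    model = Model(url="https://clarifai.com/<user_id>/<app_id>/models/<model_id>", pat="<YOUR_PAT>")

    async def input_stream():
        # _input_proto_stream packs the first yielded item together with the
        # non-streaming kwargs; later items ride the stream field alone.
        yield "first chunk"
        yield "second chunk"

    async def main():
        async for output in model.stream(input_stream=input_stream()):
            print(output)

    asyncio.run(main())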
clarifai/client/user.py CHANGED
@@ -456,3 +456,84 @@ class User(Lister, BaseClient):
              if hasattr(self.user_info, param)
          ]
          return f"Clarifai User Details: \n{', '.join(attribute_strings)}\n"
+
+     def list_models(
+         self,
+         user_id: str = None,
+         app_id: str = None,
+         show: bool = True,
+         return_clarifai_model: bool = False,
+         **kwargs,
+     ):
+         if user_id == "all":
+             params = {}
+         elif user_id:
+             user_app_id = resources_pb2.UserAppIDSet(user_id=user_id, app_id=app_id)
+             params = {"user_app_id": user_app_id}
+         elif not user_id:
+             user_app_id = resources_pb2.UserAppIDSet(
+                 user_id=self.user_app_id.user_id, app_id=app_id
+             )
+             params = {"user_app_id": user_app_id}
+
+         params.update(**kwargs)
+         models = self.list_pages_generator(
+             self.STUB.ListModels, service_pb2.ListModelsRequest, request_data=params
+         )
+         all_data = []
+         for model in models:
+             url = (
+                 f"https://clarifai.com/{model['user_id']}/{model['app_id']}/models/{model['name']}"
+             )
+             data = dict(
+                 user_id=model["user_id"],
+                 app_id=model["app_id"],
+                 id=model["model_id"],
+                 model_type=model["model_type_id"],
+                 url=url,
+             )
+             method_types_data = dict(
+                 supported_openai_client=False,
+                 UNARY_UNARY={"predict"},
+                 UNARY_STREAMING=set(),
+                 STREAMING_STREAMING=set(),
+             )
+             for each_method in model.get("model_version", {}).get("method_signatures", []):
+                 name = each_method["name"]
+                 method_type = each_method["method_type"]
+                 method_types_data[method_type].add(name)
+             if (
+                 "openai_transport" in method_types_data["UNARY_UNARY"]
+                 and "openai_stream_transport" in method_types_data["UNARY_STREAMING"]
+             ):
+                 method_types_data["supported_openai_client"] = True
+                 method_types_data["UNARY_UNARY"].remove("openai_transport")
+                 method_types_data["UNARY_STREAMING"].remove("openai_stream_transport")
+             for k, v in method_types_data.items():
+                 if k != "supported_openai_client":
+                     if not v:
+                         method_types_data[k] = None
+                     else:
+                         method_types_data[k] = list(v)
+
+             data.update(method_types_data)
+             all_data.append(data)
+
+         if show:
+             from tabulate import tabulate
+
+             print(tabulate(all_data, headers="keys"))
+
+         if return_clarifai_model:
+             from clarifai.client import Model
+
+             models = []
+             for each_data in all_data:
+                 model = Model.from_auth_helper(
+                     self.auth_helper,
+                     url=each_data["url"],
+                 )
+                 models.append(model)
+             return models
+         else:
+             return all_data
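Usage sketch for the new `User.list_models` (placeholder IDs; `show=True` requires the `tabulate` package):

    from clarifai.client.user import User

    user = User(user_id="<user_id>", pat="<YOUR_PAT>")

    # Tabulate the caller's models in one app, including which method types each
    # version exposes and whether the OpenAI-compatible client is supported.
    rows = user.list_models(app_id="<app_id>", show=True)

    # List community models and get clarifai Model objects instead of dicts.
    models = user.list_models(user_id="all", show=False, return_clarifai_model=True)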
clarifai/runners/models/model_builder.py CHANGED
@@ -10,6 +10,7 @@ import tarfile
  import time
  import webbrowser
  from string import Template
+ from typing import Literal
  from unittest.mock import MagicMock

  import yaml
@@ -22,6 +23,7 @@ from clarifai.client.user import User
  from clarifai.runners.models.model_class import ModelClass
  from clarifai.runners.utils.const import (
      AMD_PYTHON_BASE_IMAGE,
+     AMD_TORCH_BASE_IMAGE,
      AMD_VLLM_BASE_IMAGE,
      AVAILABLE_PYTHON_IMAGES,
      AVAILABLE_TORCH_IMAGES,
@@ -45,6 +47,7 @@ from clarifai.versions import CLIENT_VERSION
  dependencies = [
      'torch',
      'clarifai',
+     'vllm',
  ]


@@ -69,6 +72,7 @@ class ModelBuilder:
          folder: str,
          validate_api_ids: bool = True,
          download_validation_only: bool = False,
+         app_not_found_action: Literal["auto_create", "prompt", "error"] = "error",
      ):
          """
          :param folder: The folder containing the model.py, config.yaml, requirements.txt and
@@ -77,7 +81,13 @@ class ModelBuilder:
              deprecate in favor of download_validation_only.
          :param download_validation_only: Whether to skip the API config validation. Set to True if
              just downloading a checkpoint.
+         :param app_not_found_action: Defines how to handle the case when the app is not found.
+             Options: 'auto_create' - create automatically, 'prompt' - ask user, 'error' - raise exception.
          """
+         assert app_not_found_action in ["auto_create", "prompt", "error"], ValueError(
+             f"Expected one of {['auto_create', 'prompt', 'error']}, got {app_not_found_action=}"
+         )
+         self.app_not_found_action = app_not_found_action
          self._client = None
          if not validate_api_ids:  # for backwards compatibility
              download_validation_only = True
@@ -304,13 +314,46 @@ class ModelBuilder:
                  f"Invalid PAT provided for user {self.client.user_app_id.user_id}. Please check your PAT and try again."
              )
              return False
-         logger.error(
-             f"Error checking API {self._base_api} for user app {self.client.user_app_id.user_id}/{self.client.user_app_id.app_id}. Error code: {resp.status.code}"
-         )
-         logger.error(
-             f"App {self.client.user_app_id.app_id} not found for user {self.client.user_app_id.user_id}. Please create the app first and try again."
-         )
-         return False
+
+         user_id = self.client.user_app_id.user_id
+         app_id = self.client.user_app_id.app_id
+
+         if self.app_not_found_action == "error":
+             logger.error(
+                 f"Error checking API {self._base_api} for user app `{user_id}/{app_id}`. Error code: {resp.status.code}"
+             )
+             logger.error(
+                 f"App `{app_id}` not found for user `{user_id}`. Please create the app first and try again."
+             )
+             return False
+         else:
+             user = User(
+                 user_id=user_id,
+                 pat=self.client.pat,
+                 token=self.client.token,
+                 base_url=self.client.base,
+             )
+
+             def create_app():
+                 logger.info(f"Creating App `{app_id}` for user `{user_id}`.")
+                 user.create_app(app_id=app_id)
+
+             logger.info(f"App {app_id} not found for user {user_id}.")
+
+             if self.app_not_found_action == "prompt":
+                 create_app_prompt = input(f"Do you want to create App `{app_id}`? (y/n): ")
+                 if create_app_prompt.lower() == 'y':
+                     create_app()
+                     return True
+                 else:
+                     logger.error(
+                         f"App `{app_id}` has not been created for user `{user_id}`. Please create the app first or switch to an existing one, then try again."
+                     )
+                     return False
+
+             elif self.app_not_found_action == "auto_create":
+                 create_app()
+                 return True

      def _validate_config_model(self):
          assert "model" in self.config, "model section not found in the config file"
@@ -759,7 +802,7 @@
          )
          python_version = DEFAULT_PYTHON_VERSION
          gpu_version = DEFAULT_AMD_GPU_VERSION
-         final_image = TORCH_BASE_IMAGE.format(
+         final_image = AMD_TORCH_BASE_IMAGE.format(
              torch_version=torch_version,
              python_version=python_version,
              gpu_version=gpu_version,
@@ -1186,7 +1229,7 @@ def upload_model(folder, stage, skip_dockerfile):
      :param stage: The stage to download checkpoints for. Typically this is "upload", which downloads checkpoints when the config.yaml checkpoints section is set to "upload"; "runtime" is used in load_model instead. Set this stage to whatever you have in config.yaml to force downloading now.
      :param skip_dockerfile: If True, will not create a Dockerfile.
      """
-     builder = ModelBuilder(folder)
+     builder = ModelBuilder(folder, app_not_found_action="prompt")
      builder.download_checkpoints(stage=stage)
      if not skip_dockerfile:
          builder.create_dockerfile()