clarifai 11.5.2__py3-none-any.whl → 11.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clarifai/__init__.py +1 -1
- clarifai/cli/model.py +33 -1
- clarifai/cli/pipeline.py +137 -0
- clarifai/cli/pipeline_step.py +104 -0
- clarifai/cli/templates/__init__.py +1 -0
- clarifai/cli/templates/pipeline_step_templates.py +64 -0
- clarifai/cli/templates/pipeline_templates.py +150 -0
- clarifai/client/auth/helper.py +23 -0
- clarifai/client/auth/register.py +5 -0
- clarifai/client/auth/stub.py +116 -12
- clarifai/client/base.py +9 -0
- clarifai/client/model.py +111 -7
- clarifai/client/model_client.py +355 -6
- clarifai/client/user.py +81 -0
- clarifai/runners/models/model_builder.py +52 -9
- clarifai/runners/pipeline_steps/__init__.py +0 -0
- clarifai/runners/pipeline_steps/pipeline_step_builder.py +510 -0
- clarifai/runners/pipelines/__init__.py +0 -0
- clarifai/runners/pipelines/pipeline_builder.py +313 -0
- clarifai/runners/utils/code_script.py +40 -7
- clarifai/runners/utils/const.py +2 -2
- clarifai/runners/utils/model_utils.py +135 -0
- clarifai/runners/utils/pipeline_validation.py +153 -0
- {clarifai-11.5.2.dist-info → clarifai-11.5.3.dist-info}/METADATA +1 -1
- {clarifai-11.5.2.dist-info → clarifai-11.5.3.dist-info}/RECORD +30 -19
- /clarifai/cli/{model_templates.py → templates/model_templates.py} +0 -0
- {clarifai-11.5.2.dist-info → clarifai-11.5.3.dist-info}/WHEEL +0 -0
- {clarifai-11.5.2.dist-info → clarifai-11.5.3.dist-info}/entry_points.txt +0 -0
- {clarifai-11.5.2.dist-info → clarifai-11.5.3.dist-info}/licenses/LICENSE +0 -0
- {clarifai-11.5.2.dist-info → clarifai-11.5.3.dist-info}/top_level.txt +0 -0
clarifai/client/model_client.py
CHANGED
@@ -1,3 +1,4 @@
+import asyncio
 import inspect
 import time
 from typing import Any, Dict, Iterator, List
@@ -5,6 +6,7 @@ from typing import Any, Dict, Iterator, List
 from clarifai_grpc.grpc.api import resources_pb2, service_pb2
 from clarifai_grpc.grpc.api.status import status_code_pb2
 
+from clarifai.client.auth.register import V2Stub
 from clarifai.constants.model import MAX_MODEL_PREDICT_INPUTS
 from clarifai.errors import UserError
 from clarifai.runners.utils import code_script, method_signatures
@@ -20,12 +22,31 @@ from clarifai.utils.logging import logger
 from clarifai.utils.misc import BackoffIterator, status_is_retryable
 
 
+def is_async_context():
+    """Check if code is running in an async context."""
+    try:
+        asyncio.get_running_loop()
+        import sys
+
+        # In Jupyter, check whether we're actually in an async cell, because by default Jupyter considers it async.
+        if 'ipykernel' in sys.modules:
+            return False
+        return True
+    except RuntimeError:
+        return False
+
+
 class ModelClient:
     '''
     Client for calling model predict, generate, and stream methods.
     '''
 
-    def __init__(self, stub, request_template: service_pb2.PostModelOutputsRequest = None):
+    def __init__(
+        self,
+        stub,
+        async_stub: V2Stub = None,
+        request_template: service_pb2.PostModelOutputsRequest = None,
+    ):
         '''
         Initialize the model client.
 
@@ -35,6 +56,7 @@ class ModelClient:
           common fields like model_id, model_version, cluster, etc.
         '''
         self.STUB = stub
+        self.async_stub = async_stub
         self.request_template = request_template or service_pb2.PostModelOutputsRequest()
         self._method_signatures = None
         self._defined = False
@@ -137,16 +159,19 @@ class ModelClient:
         # define the function in this client instance
         if resources_pb2.RunnerMethodType.Name(method_signature.method_type) == 'UNARY_UNARY':
             call_func = self._predict
+            async_call_func = self._async_predict
         elif (
             resources_pb2.RunnerMethodType.Name(method_signature.method_type)
             == 'UNARY_STREAMING'
         ):
             call_func = self._generate
+            async_call_func = self._async_generate
         elif (
             resources_pb2.RunnerMethodType.Name(method_signature.method_type)
             == 'STREAMING_STREAMING'
         ):
             call_func = self._stream
+            async_call_func = self._async_stream
         else:
             raise ValueError(f"Unknown method type {method_signature.method_type}")
 
@@ -158,8 +183,8 @@ class ModelClient:
                 continue
             method_argnames.append(outer)
 
-        def bind_f(method_name, method_argnames, call_func):
-            def f(*args, **kwargs):
+        def bind_f(method_name, method_argnames, call_func, async_call_func):
+            def sync_f(*args, **kwargs):
                 if len(args) > len(method_argnames):
                     raise TypeError(
                         f"{method_name}() takes {len(method_argnames)} positional arguments but {len(args)} were given"
@@ -187,10 +212,45 @@ class ModelClient:
                    kwargs[name] = arg
                 return call_func(kwargs, method_name)
 
-            return f
+            async def async_f(*args, **kwargs):
+                # Async version to call the async function
+                if len(args) > len(method_argnames):
+                    raise TypeError(
+                        f"{method_name}() takes {len(method_argnames)} positional arguments but {len(args)} were given"
+                    )
+                if len(args) + len(kwargs) > len(method_argnames):
+                    raise TypeError(
+                        f"{method_name}() got an unexpected keyword argument {next(iter(kwargs))}"
+                    )
+                if len(args) == 1 and (not kwargs) and isinstance(args[0], list):
+                    batch_inputs = args[0]
+                    # Validate each input is a dictionary
+                    is_batch_input_valid = all(
+                        isinstance(input, dict) for input in batch_inputs
+                    )
+                    if is_batch_input_valid and (not is_openai_chat_format(batch_inputs)):
+                        # If the batch input is valid, call the function with the batch inputs and the method name
+                        return async_call_func(batch_inputs, method_name)
+
+                for name, arg in zip(
+                    method_argnames, args
+                ):  # handle positional with zip shortest
+                    if name in kwargs:
+                        raise TypeError(f"Multiple values for argument {name}")
+                    kwargs[name] = arg
+
+                return async_call_func(kwargs, method_name)
+
+            class MethodWrapper:
+                def __call__(self, *args, **kwargs):
+                    if is_async_context():
+                        return async_f(*args, **kwargs)
+                    return sync_f(*args, **kwargs)
+
+            return MethodWrapper()
 
         # need to bind method_name to the value, not the mutating loop variable
-        f = bind_f(method_name, method_argnames, call_func)
+        f = bind_f(method_name, method_argnames, call_func, async_call_func)
 
         # set names, annotations and docstrings
         f.__name__ = method_name
@@ -231,7 +291,11 @@ class ModelClient:
             self.fetch()
         return method_signatures.get_method_signature(self._method_signatures[method_name])
 
-    def generate_client_script(self) -> str:
+    def generate_client_script(
+        self,
+        base_url: str = None,
+        use_ctx: bool = False,
+    ) -> str:
        """Generate a client script for this model.
 
        Returns:
@@ -247,6 +311,11 @@ class ModelClient:
            user_id=self.request_template.user_app_id.user_id,
            app_id=self.request_template.user_app_id.app_id,
            model_id=self.request_template.model_id,
+           base_url=base_url,
+           deployment_id=self.request_template.runner_selector.deployment.id,
+           compute_cluster_id=self.request_template.runner_selector.nodepool.compute_cluster.id,
+           nodepool_id=self.request_template.runner_selector.nodepool.id,
+           use_ctx=use_ctx,
        )
 
    def _define_compatability_functions(self):
@@ -354,6 +423,109 @@ class ModelClient:
                break
        return response
 
+    async def _async_predict(
+        self,
+        inputs,
+        method_name: str = 'predict',
+    ) -> Any:
+        """Asynchronously process inputs and make predictions.
+
+        Args:
+            inputs: Input data to process
+            method_name (str): Name of the method to call
+
+        Returns:
+            Processed prediction results
+        """
+        # method_name defaults to 'predict' to replicate the input and output signature behaviour of the sync predict.
+        input_signature = self._method_signatures[method_name].input_fields
+        output_signature = self._method_signatures[method_name].output_fields
+
+        batch_input = True
+        if isinstance(inputs, dict):
+            inputs = [inputs]
+            batch_input = False
+
+        proto_inputs = []
+        for input in inputs:
+            proto = resources_pb2.Input()
+            serialize(input, input_signature, proto.data)
+            proto_inputs.append(proto)
+        response = await self._async_predict_by_proto(proto_inputs, method_name)
+        outputs = []
+        for output in response.outputs:
+            outputs.append(deserialize(output.data, output_signature, is_output=True))
+
+        return outputs if batch_input else outputs[0]
+
+    async def _async_predict_by_proto(
+        self,
+        inputs: List[resources_pb2.Input],
+        method_name: str = None,
+        inference_params: Dict = None,
+        output_config: Dict = None,
+    ) -> service_pb2.MultiOutputResponse:
+        """Asynchronously predicts the model based on the given inputs.
+
+        Args:
+            inputs (List[resources_pb2.Input]): The inputs to predict.
+            method_name (str): The remote method name to call.
+            inference_params (Dict): Inference parameters to override.
+            output_config (Dict): Output configuration to override.
+
+        Returns:
+            service_pb2.MultiOutputResponse: The prediction response(s).
+
+        Raises:
+            UserError: If inputs are invalid or exceed maximum limit.
+            Exception: If the model prediction fails.
+        """
+        if not isinstance(inputs, list):
+            raise UserError('Invalid inputs, inputs must be a list of Input objects.')
+        if len(inputs) > MAX_MODEL_PREDICT_INPUTS:
+            raise UserError(f"Too many inputs. Max is {MAX_MODEL_PREDICT_INPUTS}.")
+
+        request = service_pb2.PostModelOutputsRequest()
+        request.CopyFrom(self.request_template)
+        request.inputs.extend(inputs)
+
+        if method_name:
+            for inp in request.inputs:
+                inp.data.metadata['_method_name'] = method_name
+        if inference_params:
+            request.model.model_version.output_info.params.update(inference_params)
+        if output_config:
+            request.model.model_version.output_info.output_config.MergeFrom(
+                resources_pb2.OutputConfig(**output_config)
+            )
+
+        start_time = time.time()
+        backoff_iterator = BackoffIterator(10)
+
+        while True:
+            try:
+                response = await self.async_stub.PostModelOutputs(request)
+
+                if (
+                    status_is_retryable(response.status.code)
+                    and time.time() - start_time < 60 * 10
+                ):  # 10 minutes
+                    logger.info("Model is still deploying, please wait...")
+                    await asyncio.sleep(next(backoff_iterator))
+                    continue
+
+                if response.status.code != status_code_pb2.SUCCESS:
+                    raise Exception(f"Model predict failed with response {response!r}")
+
+                return response
+
+            except Exception as e:
+                if time.time() - start_time >= 10 * 1:  # 10 minutes timeout
+                    raise Exception("Model prediction timed out after 10 minutes") from e
+                logger.error(f"Error during prediction: {e}")
+                await asyncio.sleep(next(backoff_iterator))
+                continue
+
    def _generate(
        self,
        inputs,  # TODO set up functions according to fetched signatures?
@@ -444,6 +616,99 @@ class ModelClient:
                    raise Exception(f"Model Generate failed with response {response.status!r}")
                yield response
 
+    async def _async_generate(
+        self,
+        inputs,
+        method_name: str = 'generate',
+    ) -> Any:
+        # method_name defaults to 'generate' to replicate the input and output signature behaviour of the sync generate.
+        input_signature = self._method_signatures[method_name].input_fields
+        output_signature = self._method_signatures[method_name].output_fields
+
+        batch_input = True
+        if isinstance(inputs, dict):
+            inputs = [inputs]
+            batch_input = False
+
+        proto_inputs = []
+        for input in inputs:
+            proto = resources_pb2.Input()
+            serialize(input, input_signature, proto.data)
+            proto_inputs.append(proto)
+        response_stream = self._async_generate_by_proto(proto_inputs, method_name)
+
+        async for response in response_stream:
+            outputs = []
+            for output in response.outputs:
+                outputs.append(deserialize(output.data, output_signature, is_output=True))
+            if batch_input:
+                yield outputs
+            else:
+                yield outputs[0]
+
+    async def _async_generate_by_proto(
+        self,
+        inputs: List[resources_pb2.Input],
+        method_name: str = None,
+        inference_params: Dict = {},
+        output_config: Dict = {},
+    ):
+        """Generate the async stream output on model based on the given inputs.
+
+        Args:
+            inputs (list[Input]): The inputs to generate, must be less than 128.
+            method_name (str): The remote method name to call.
+            inference_params (dict): The inference params to override.
+            output_config (dict): The output config to override.
+        """
+        if not isinstance(inputs, list):
+            raise UserError('Invalid inputs, inputs must be a list of Input objects.')
+        if len(inputs) > MAX_MODEL_PREDICT_INPUTS:
+            raise UserError(
+                f"Too many inputs. Max is {MAX_MODEL_PREDICT_INPUTS}."
+            )  # TODO Use Chunker for inputs len > 128
+
+        request = service_pb2.PostModelOutputsRequest()
+        request.CopyFrom(self.request_template)
+
+        request.inputs.extend(inputs)
+
+        if method_name:
+            # TODO put in new proto field?
+            for inp in request.inputs:
+                inp.data.metadata['_method_name'] = method_name
+        if inference_params:
+            request.model.model_version.output_info.params.update(inference_params)
+        if output_config:
+            request.model.model_version.output_info.output_config.MergeFromDict(output_config)
+
+        start_time = time.time()
+        backoff_iterator = BackoffIterator(10)
+        started = False
+        while not started:
+            # stream response returns a gRPC async iterable - UnaryStreamCall
+            stream_response = self.async_stub.GenerateModelOutputs(request)
+            stream_resp = await stream_response  # get the async iterable
+            iterator = stream_resp.__aiter__()  # get the async iterator for the response
+            try:
+                response = await iterator.__anext__()  # getting the first response
+            except StopAsyncIteration:
+                raise Exception("Model Generate failed with no response")
+            if status_is_retryable(response.status.code) and time.time() - start_time < 60 * 10:
+                logger.info("Model is still deploying, please wait...")
+                await asyncio.sleep(next(backoff_iterator))
+                continue
+            if response.status.code != status_code_pb2.SUCCESS:
+                raise Exception(f"Model Generate failed with response {response.status!r}")
+            started = True
+
+            yield response  # yield the first response
+
+            async for response in iterator:
+                if response.status.code != status_code_pb2.SUCCESS:
+                    raise Exception(f"Model Generate failed with response {response.status!r}")
+                yield response
+
    def _stream(
        self,
        inputs,
@@ -551,3 +816,87 @@ class ModelClient:
                if not generation_started:
                    generation_started = True
                yield response
+
+    # TODO: Test async streaming.
+    async def _async_stream(
+        self,
+        inputs,
+        method_name: str = 'stream',
+    ) -> Any:
+        # method_name defaults to 'stream' to replicate the input and output signature behaviour of the sync stream.
+        input_signature = self._method_signatures[method_name].input_fields
+        output_signature = self._method_signatures[method_name].output_fields
+
+        if isinstance(inputs, list):
+            assert len(inputs) == 1, 'streaming methods do not support batched calls'
+            inputs = inputs[0]
+        assert isinstance(inputs, dict)
+        kwargs = inputs
+
+        # find the streaming vars in the input signature, and the streaming input python param
+        stream_sig = get_stream_from_signature(input_signature)
+        if stream_sig is None:
+            raise ValueError("Streaming method must have a Stream input")
+        stream_argname = stream_sig.name
+
+        # get the streaming input generator from the user-provided function arg values
+        user_inputs_generator = kwargs.pop(stream_argname)
+
+        async def _input_proto_stream():
+            # first item contains all the inputs and the first stream item
+            proto = resources_pb2.Input()
+            try:
+                item = await user_inputs_generator.__anext__()
+            except StopAsyncIteration:
+                return  # no items to stream
+            kwargs[stream_argname] = item
+            serialize(kwargs, input_signature, proto.data)
+
+            yield proto
+
+            # subsequent items are just the stream items
+            async for item in user_inputs_generator:
+                proto = resources_pb2.Input()
+                serialize({stream_argname: item}, [stream_sig], proto.data)
+                yield proto
+
+        response_stream = await self._async_stream_by_proto(_input_proto_stream(), method_name)
+
+        async for response in response_stream:
+            assert len(response.outputs) == 1, 'streaming methods must have exactly one output'
+            yield deserialize(response.outputs[0].data, output_signature, is_output=True)
+
+    async def _async_stream_by_proto(
+        self,
+        inputs: Iterator[List[resources_pb2.Input]],
+        method_name: str = None,
+        inference_params: Dict = {},
+        output_config: Dict = {},
+    ):
+        """Generate the async stream output on model based on the given stream of inputs."""
+        # if not isinstance(inputs, Iterator[List[Input]]):
+        #     raise UserError('Invalid inputs, inputs must be a iterator of list of Input objects.')
+
+        request = self._req_iterator(inputs, method_name, inference_params, output_config)
+
+        start_time = time.time()
+        backoff_iterator = BackoffIterator(10)
+        generation_started = False
+        while True:
+            if generation_started:
+                break
+            stream_response = await self.async_stub.StreamModelOutputs(request)
+            async for response in stream_response:
+                if (
+                    status_is_retryable(response.status.code)
+                    and time.time() - start_time < 60 * 10
+                ):
+                    logger.info("Model is still deploying, please wait...")
+                    await asyncio.sleep(next(backoff_iterator))
+                    break
+                if response.status.code != status_code_pb2.SUCCESS:
+                    raise Exception(f"Model Predict failed with response {response.status!r}")
+                else:
+                    if not generation_started:
+                        generation_started = True
+                    yield response
clarifai/client/user.py
CHANGED
@@ -456,3 +456,84 @@ class User(Lister, BaseClient):
            if hasattr(self.user_info, param)
        ]
        return f"Clarifai User Details: \n{', '.join(attribute_strings)}\n"
+
+    def list_models(
+        self,
+        user_id: str = None,
+        app_id: str = None,
+        show: bool = True,
+        return_clarifai_model: bool = False,
+        **kwargs,
+    ):
+        if user_id == "all":
+            params = {}
+        elif user_id:
+            user_app_id = resources_pb2.UserAppIDSet(user_id=user_id, app_id=app_id)
+            params = {"user_app_id": user_app_id}
+        elif not user_id:
+            user_app_id = resources_pb2.UserAppIDSet(
+                user_id=self.user_app_id.user_id, app_id=app_id
+            )
+            params = {"user_app_id": user_app_id}
+
+        params.update(**kwargs)
+        models = self.list_pages_generator(
+            self.STUB.ListModels, service_pb2.ListModelsRequest, request_data=params
+        )
+        all_data = []
+        for model in models:
+            url = (
+                f"https://clarifai.com/{model['user_id']}/{model['app_id']}/models/{model['name']}"
+            )
+            data = dict(
+                user_id=model["user_id"],
+                app_id=model["app_id"],
+                id=model["model_id"],
+                model_type=model["model_type_id"],
+                url=url,
+            )
+            method_types_data = dict(
+                supported_openai_client=False,
+                UNARY_UNARY={"predict"},
+                UNARY_STREAMING=set(),
+                STREAMING_STREAMING=set(),
+            )
+            for each_method in model.get("model_version", {}).get("method_signatures", []):
+                name = each_method["name"]
+                method_type = each_method["method_type"]
+                method_types_data[method_type].add(name)
+            if (
+                "openai_transport" in method_types_data["UNARY_UNARY"]
+                and "openai_stream_transport" in method_types_data["UNARY_STREAMING"]
+            ):
+                method_types_data["supported_openai_client"] = True
+                method_types_data["UNARY_UNARY"].remove("openai_transport")
+                method_types_data["UNARY_STREAMING"].remove("openai_stream_transport")
+            for k, v in method_types_data.items():
+                if k != "supported_openai_client":
+                    if not v:
+                        method_types_data[k] = None
+                    else:
+                        method_types_data[k] = list(v)
+
+            data.update(method_types_data)
+            all_data.append(data)
+
+        if show:
+            from tabulate import tabulate
+
+            print(tabulate(all_data, headers="keys"))
+
+        if return_clarifai_model:
+            from clarifai.client import Model
+
+            models = []
+            for each_data in all_data:
+                model = Model.from_auth_helper(
+                    self.auth_helper,
+                    url=each_data["url"],
+                )
+                models.append(model)
+            return models
+        else:
+            return all_data
clarifai/runners/models/model_builder.py
CHANGED
@@ -10,6 +10,7 @@ import tarfile
 import time
 import webbrowser
 from string import Template
+from typing import Literal
 from unittest.mock import MagicMock
 
 import yaml
@@ -22,6 +23,7 @@ from clarifai.client.user import User
 from clarifai.runners.models.model_class import ModelClass
 from clarifai.runners.utils.const import (
     AMD_PYTHON_BASE_IMAGE,
+    AMD_TORCH_BASE_IMAGE,
     AMD_VLLM_BASE_IMAGE,
     AVAILABLE_PYTHON_IMAGES,
     AVAILABLE_TORCH_IMAGES,
@@ -45,6 +47,7 @@ from clarifai.versions import CLIENT_VERSION
 dependencies = [
     'torch',
     'clarifai',
+    'vllm',
 ]
 
 
@@ -69,6 +72,7 @@ class ModelBuilder:
        folder: str,
        validate_api_ids: bool = True,
        download_validation_only: bool = False,
+        app_not_found_action: Literal["auto_create", "prompt", "error"] = "error",
    ):
        """
        :param folder: The folder containing the model.py, config.yaml, requirements.txt and
@@ -77,7 +81,13 @@ class ModelBuilder:
          deprecate in favor of download_validation_only.
        :param download_validation_only: Whether to skip the API config validation. Set to True if
          just downloading a checkpoint.
+        :param app_not_found_action: Defines how to handle the case when the app is not found.
+          Options: 'auto_create' - create automatically, 'prompt' - ask user, 'error' - raise exception.
        """
+        assert app_not_found_action in ["auto_create", "prompt", "error"], ValueError(
+            f"Expected one of {['auto_create', 'prompt', 'error']}, got {app_not_found_action=}"
+        )
+        self.app_not_found_action = app_not_found_action
        self._client = None
        if not validate_api_ids:  # for backwards compatibility
            download_validation_only = True
@@ -304,13 +314,46 @@ class ModelBuilder:
                f"Invalid PAT provided for user {self.client.user_app_id.user_id}. Please check your PAT and try again."
            )
            return False
-        logger.error(
-            f"Error checking API {self._base_api} for user app {self.client.user_app_id.user_id}/{self.client.user_app_id.app_id}. Error code: {resp.status.code}"
-        )
-        logger.error(
-            f"App {self.client.user_app_id.app_id} not found for user {self.client.user_app_id.user_id}. Please create the app first and try again."
-        )
-        return False
+
+        user_id = self.client.user_app_id.user_id
+        app_id = self.client.user_app_id.app_id
+
+        if self.app_not_found_action == "error":
+            logger.error(
+                f"Error checking API {self._base_api} for user app `{user_id}/{app_id}`. Error code: {resp.status.code}"
+            )
+            logger.error(
+                f"App `{app_id}` not found for user `{user_id}`. Please create the app first and try again."
+            )
+            return False
+        else:
+            user = User(
+                user_id=user_id,
+                pat=self.client.pat,
+                token=self.client.token,
+                base_url=self.client.base,
+            )
+
+            def create_app():
+                logger.info(f"Creating App `{app_id}` user `{user_id}`.")
+                user.create_app(app_id=app_id)
+
+            logger.info(f"App {app_id} not found for user {user_id}.")
+
+            if self.app_not_found_action == "prompt":
+                create_app_prompt = input(f"Do you want to create App `{app_id}`? (y/n): ")
+                if create_app_prompt.lower() == 'y':
+                    create_app()
+                    return True
+                else:
+                    logger.error(
+                        f"App `{app_id}` has not been created for user `{user_id}`. Please create the app first or switch to an existing one, then try again."
+                    )
+                    return False
+
+            elif self.app_not_found_action == "auto_create":
+                create_app()
+                return True
 
    def _validate_config_model(self):
        assert "model" in self.config, "model section not found in the config file"
@@ -759,7 +802,7 @@ class ModelBuilder:
            )
            python_version = DEFAULT_PYTHON_VERSION
            gpu_version = DEFAULT_AMD_GPU_VERSION
-            final_image =
+            final_image = AMD_TORCH_BASE_IMAGE.format(
                torch_version=torch_version,
                python_version=python_version,
                gpu_version=gpu_version,
@@ -1186,7 +1229,7 @@ def upload_model(folder, stage, skip_dockerfile):
    :param stage: The stage we are calling download checkpoints from. Typically this would "upload" and will download checkpoints if config.yaml checkpoints section has when set to "upload". Other options include "runtime" to be used in load_model or "upload" to be used during model upload. Set this stage to whatever you have in config.yaml to force downloading now.
    :param skip_dockerfile: If True, will not create a Dockerfile.
    """
-    builder = ModelBuilder(folder)
+    builder = ModelBuilder(folder, app_not_found_action="prompt")
    builder.download_checkpoints(stage=stage)
    if not skip_dockerfile:
        builder.create_dockerfile()