clarifai-11.5.1-py3-none-any.whl → clarifai-11.5.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clarifai/__init__.py +1 -1
- clarifai/cli/model.py +42 -1
- clarifai/cli/pipeline.py +137 -0
- clarifai/cli/pipeline_step.py +104 -0
- clarifai/cli/templates/__init__.py +1 -0
- clarifai/cli/templates/pipeline_step_templates.py +64 -0
- clarifai/cli/templates/pipeline_templates.py +150 -0
- clarifai/client/auth/helper.py +46 -21
- clarifai/client/auth/register.py +5 -0
- clarifai/client/auth/stub.py +116 -12
- clarifai/client/base.py +9 -0
- clarifai/client/model.py +111 -7
- clarifai/client/model_client.py +355 -6
- clarifai/client/user.py +81 -0
- clarifai/runners/models/model_builder.py +52 -9
- clarifai/runners/pipeline_steps/__init__.py +0 -0
- clarifai/runners/pipeline_steps/pipeline_step_builder.py +510 -0
- clarifai/runners/pipelines/__init__.py +0 -0
- clarifai/runners/pipelines/pipeline_builder.py +313 -0
- clarifai/runners/utils/code_script.py +40 -7
- clarifai/runners/utils/const.py +2 -2
- clarifai/runners/utils/model_utils.py +135 -0
- clarifai/runners/utils/pipeline_validation.py +153 -0
- {clarifai-11.5.1.dist-info → clarifai-11.5.3.dist-info}/METADATA +1 -1
- {clarifai-11.5.1.dist-info → clarifai-11.5.3.dist-info}/RECORD +30 -19
- /clarifai/cli/{model_templates.py → templates/model_templates.py} +0 -0
- {clarifai-11.5.1.dist-info → clarifai-11.5.3.dist-info}/WHEEL +0 -0
- {clarifai-11.5.1.dist-info → clarifai-11.5.3.dist-info}/entry_points.txt +0 -0
- {clarifai-11.5.1.dist-info → clarifai-11.5.3.dist-info}/licenses/LICENSE +0 -0
- {clarifai-11.5.1.dist-info → clarifai-11.5.3.dist-info}/top_level.txt +0 -0
clarifai/client/auth/stub.py
CHANGED
@@ -1,3 +1,4 @@
+import asyncio
 import itertools
 import time
 from concurrent.futures import ThreadPoolExecutor
@@ -31,6 +32,10 @@ def validate_response(response, attempt, max_attempts):
         else:
             return response

+    # Check if response is an async iterator
+    if hasattr(response, '__aiter__'):
+        return response  # Return async iterator directly for handling in _async_call__
+
     # Check if the response is an instance of a gRPC streaming call
     if isinstance(response, grpc._channel._MultiThreadedRendezvous):
         try:
@@ -49,7 +54,9 @@ def validate_response(response, attempt, max_attempts):
         return handle_simple_response(response)


-def create_stub(auth_helper: ClarifaiAuthHelper = None, max_retry_attempts: int = 10) -> V2Stub:
+def create_stub(
+    auth_helper: ClarifaiAuthHelper = None, max_retry_attempts: int = 10, is_async: bool = False
+) -> V2Stub:
     """
     Create client stub that handles authorization and basic retries for
     unavailable or throttled connections.
@@ -58,41 +65,48 @@ def create_stub(auth_helper: ClarifaiAuthHelper = None, max_retry_attempts: int
       auth_helper: ClarifaiAuthHelper to use for auth metadata (default: from env)
       max_retry_attempts: max attempts to retry rpcs with retryable failures
     """
-    stub = AuthorizedStub(auth_helper)
+    stub = AuthorizedStub(auth_helper, is_async=is_async)
     if max_retry_attempts > 0:
-        return RetryStub(stub, max_retry_attempts)
+        return RetryStub(stub, max_retry_attempts, is_async=is_async)
     return stub


 class AuthorizedStub(V2Stub):
     """V2Stub proxy that inserts metadata authorization in rpc calls."""

-    def __init__(self, auth_helper: ClarifaiAuthHelper = None):
+    def __init__(self, auth_helper: ClarifaiAuthHelper = None, is_async: bool = False):
         if auth_helper is None:
             auth_helper = ClarifaiAuthHelper.from_env()
-        self.stub = auth_helper.get_stub()
+        self.is_async = is_async
+        self.stub = auth_helper.get_async_stub() if is_async else auth_helper.get_stub()
         self.metadata = auth_helper.metadata

     def __getattr__(self, name):
         value = getattr(self.stub, name)
         if isinstance(value, RpcCallable):
-            value = _AuthorizedRpcCallable(value, self.metadata)
+            value = _AuthorizedRpcCallable(value, self.metadata, self.is_async)
         return value


 class _AuthorizedRpcCallable(RpcCallable):
     """Adds metadata(authorization header) to rpc calls"""

-    def __init__(self, func, metadata):
+    def __init__(self, func, metadata, is_async):
         self.f = func
         self.metadata = metadata
+        self.is_async = is_async or asyncio.iscoroutinefunction(func)

     def __repr__(self):
         return repr(self.f)

     def __call__(self, *args, **kwargs):
         metadata = kwargs.pop('metadata', self.metadata)
-        return self.f(*args, **kwargs, metadata=metadata)
+
+        return self.f(
+            *args,
+            **kwargs,
+            metadata=metadata,
+        )

     def future(self, *args, **kwargs):
         metadata = kwargs.pop('metadata', self.metadata)
@@ -107,30 +121,109 @@ class RetryStub(V2Stub):
     V2Stub proxy that retries requests (currently on unavailable server or throttle codes)
     """

-    def __init__(self, stub, max_attempts=10, backoff_time=5):
+    def __init__(self, stub, max_attempts=10, backoff_time=5, is_async=False):
         self.stub = stub
         self.max_attempts = max_attempts
         self.backoff_time = backoff_time
+        self.is_async = is_async

     def __getattr__(self, name):
         value = getattr(self.stub, name)
         if isinstance(value, RpcCallable):
-            value = _RetryRpcCallable(value, self.max_attempts, self.backoff_time)
+            value = _RetryRpcCallable(value, self.max_attempts, self.backoff_time, self.is_async)
         return value


 class _RetryRpcCallable(RpcCallable):
     """Retries rpc calls on unavailable server or throttle codes"""

-    def __init__(self, func, max_attempts, backoff_time):
+    def __init__(self, func, max_attempts, backoff_time, is_async=False):
         self.f = func
         self.max_attempts = max_attempts
         self.backoff_time = backoff_time
+        self.is_async = is_async or asyncio.iscoroutinefunction(func)

     def __repr__(self):
         return repr(self.f)

-    def __call__(self, *args, **kwargs):
+    async def _async_call__(self, *args, **kwargs):
+        """Handle async RPC calls with retries and validation"""
+        for attempt in range(1, self.max_attempts + 1):
+            if attempt != 1:
+                await asyncio.sleep(self.backoff_time)
+
+            try:
+                response = self.f(*args, **kwargs)
+
+                # Handle streaming response
+                if hasattr(response, '__aiter__'):
+                    return await self._handle_streaming_response(response, attempt)
+
+                # Handle regular async response
+                validated_response = await self._handle_regular_response(response, attempt)
+                if validated_response is not None:
+                    return validated_response
+
+            except grpc.RpcError as e:
+                if not self._should_retry(e, attempt):
+                    raise
+                logger.debug(
+                    f'Retrying after error {e.code()} (attempt {attempt}/{self.max_attempts})'
+                )
+
+        raise Exception(f'Max attempts reached ({self.max_attempts}) without success')
+
+    async def _handle_streaming_response(self, response, attempt):
+        """Handle streaming response validation and processing"""
+
+        async def validated_stream():
+            try:
+                async for item in response:
+                    if not self._is_valid_response(item):
+                        if attempt < self.max_attempts:
+                            yield None  # Signal for retry
+                        raise Exception(
+                            f'Validation failed on streaming response (attempt {attempt})'
+                        )
+                    yield item
+            except grpc.RpcError as e:
+                if not self._should_retry(e, attempt):
+                    raise
+                yield None  # Signal for retry
+
+        return validated_stream()
+
+    async def _handle_regular_response(self, response, attempt):
+        """Handle regular async response validation and processing"""
+        try:
+            result = await response
+            if not self._is_valid_response(result):
+                if attempt < self.max_attempts:
+                    return None  # Signal for retry
+                raise Exception(f'Validation failed on response (attempt {attempt})')
+            return result
+        except grpc.RpcError as e:
+            if not self._should_retry(e, attempt):
+                raise
+            return None  # Signal for retry
+
+    def _is_valid_response(self, response):
+        """Check if response status is valid"""
+        return not (
+            hasattr(response, 'status')
+            and hasattr(response.status, 'code')
+            and response.status.code in throttle_status_codes
+        )
+
+    def _should_retry(self, error, attempt):
+        """Determine if we should retry based on error and attempt count"""
+        return (
+            isinstance(error, grpc.RpcError)
+            and error.code() in retry_codes_grpc
+            and attempt < self.max_attempts
+        )
+
+    def _sync_call__(self, *args, **kwargs):
         attempt = 0
         while attempt < self.max_attempts:
             attempt += 1
@@ -147,8 +240,19 @@ class _RetryRpcCallable(RpcCallable):
             else:
                 raise

+    def __call__(self, *args, **kwargs):
+        if self.is_async:
+            return self._async_call__(*args, **kwargs)
+        return self._sync_call__(*args, **kwargs)
+
+    async def __call_async__(self, *args, **kwargs):
+        """Explicit async call method"""
+        return await self._async_call(*args, **kwargs)
+
     def future(self, *args, **kwargs):
         # TODO use single result event loop thread with asyncio
+        if self.is_async:
+            return asyncio.create_task(self._async_call(*args, **kwargs))
         return _threadpool.submit(self, *args, **kwargs)

     def __getattr__(self, name):
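Taken together, these changes let the same authorization and retry wrappers sit in front of either a blocking or an asyncio gRPC channel, selected by the new `is_async` flag. A minimal usage sketch, assuming the environment variables that `ClarifaiAuthHelper.from_env()` normally reads are set; the request fields and model ID below are illustrative, not taken from this diff:

import asyncio

from clarifai_grpc.grpc.api import service_pb2

from clarifai.client.auth import create_stub
from clarifai.client.auth.helper import ClarifaiAuthHelper


async def main():
    auth = ClarifaiAuthHelper.from_env()      # picks up CLARIFAI_PAT etc. from the environment
    stub = create_stub(auth, is_async=True)   # RetryStub wrapping an async AuthorizedStub

    # Hypothetical call: any V2 rpc works the same way; with is_async=True the wrapped
    # callable returns an awaitable handled by _RetryRpcCallable._async_call__.
    request = service_pb2.GetModelRequest(
        user_app_id=auth.get_user_app_id_proto(),
        model_id="general-image-recognition",  # placeholder model ID
    )
    response = await stub.GetModel(request)
    print(response.status.description)


asyncio.run(main())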
clarifai/client/base.py
CHANGED
@@ -52,6 +52,7 @@ class BaseClient:

         self.auth_helper = ClarifaiAuthHelper(**kwargs, validate=False)
         self.STUB = create_stub(self.auth_helper)
+        self._async_stub = None
         self.metadata = self.auth_helper.metadata
         self.pat = self.auth_helper.pat
         self.token = self.auth_helper._token
@@ -59,6 +60,14 @@ class BaseClient:
         self.base = self.auth_helper.base
         self.root_certificates_path = self.auth_helper._root_certificates_path

+    @property
+    def async_stub(self):
+        """Returns the asynchronous gRPC stub for the API interaction.
+        Lazy initialization of async stub"""
+        if self._async_stub is None:
+            self._async_stub = create_stub(self.auth_helper, is_async=True)
+        return self._async_stub
+
     @classmethod
     def from_env(cls, validate: bool = False):
         auth = ClarifaiAuthHelper.from_env(validate=validate)
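`BaseClient` now exposes the async stub as a lazily built, cached property, so purely synchronous callers never open the extra channel. A small sketch of the resulting behavior, assuming a CLARIFAI_PAT is configured; the model URL is a placeholder:

from clarifai.client.model import Model

# Nothing async is created at construction time; _async_stub starts as None.
model = Model(url="https://clarifai.com/clarifai/main/models/general-image-recognition")

# First access runs create_stub(auth_helper, is_async=True) and caches the result...
async_stub = model.async_stub
# ...so later accesses return the same wrapped stub instance.
assert async_stub is model.async_stub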
clarifai/client/model.py
CHANGED
@@ -104,11 +104,6 @@ class Model(Lister, BaseClient):
         self.input_types = None
         self._client = None
         self._added_methods = False
-        self._set_runner_selector(
-            compute_cluster_id=compute_cluster_id,
-            nodepool_id=nodepool_id,
-            deployment_id=deployment_id,
-        )
         BaseClient.__init__(
             self,
             user_id=self.user_id,
@@ -120,6 +115,12 @@ class Model(Lister, BaseClient):
         )
         Lister.__init__(self)

+        self._set_runner_selector(
+            compute_cluster_id=compute_cluster_id,
+            nodepool_id=nodepool_id,
+            deployment_id=deployment_id,
+        )
+
     @classmethod
     def from_current_context(cls, **kwargs) -> 'Model':
         from clarifai.urls.helper import ClarifaiUrlHelper
@@ -502,7 +503,9 @@ class Model(Lister, BaseClient):
                 model=self.model_info,
                 runner_selector=self._runner_selector,
             )
-            self._client = ModelClient(stub=self.STUB, request_template=request_template)
+            self._client = ModelClient(
+                stub=self.STUB, async_stub=self.async_stub, request_template=request_template
+            )
         return self._client

     def predict(self, *args, **kwargs):
@@ -530,6 +533,35 @@ class Model(Lister, BaseClient):

         return self.client.predict(*args, **kwargs)

+    async def async_predict(self, *args, **kwargs):
+        """
+        Calls the model's async predict() method with the given arguments.
+
+        If passed in request_pb2.PostModelOutputsRequest values, will send the model the raw
+        protos directly for compatibility with previous versions of the SDK.
+        """
+        inputs = None
+        if 'inputs' in kwargs:
+            inputs = kwargs['inputs']
+        elif args:
+            inputs = args[0]
+        if inputs and isinstance(inputs, list) and isinstance(inputs[0], resources_pb2.Input):
+            assert len(args) <= 1, (
+                "Cannot pass in raw protos and additional arguments at the same time."
+            )
+            inference_params = kwargs.get('inference_params', {})
+            output_config = kwargs.get('output_config', {})
+            return await self.client._async_predict_by_proto(
+                inputs=inputs, inference_params=inference_params, output_config=output_config
+            )
+
+        # Adding try-except, since the await works differently with jupyter kernels and in regular python scripts.
+        try:
+            return await self.client.predict(*args, **kwargs)
+        except TypeError:
+            # In jupyter, it returns a str object instead of a co-routine.
+            return self.client.predict(*args, **kwargs)
+
     def __getattr__(self, name):
         try:
             return getattr(self.model_info, name)
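`async_predict` mirrors the synchronous `predict`: the same pythonic keyword arguments (or raw `resources_pb2.Input` protos) are accepted, but the call is awaited on the async stub. A hedged usage sketch; the model URL and the `prompt` parameter are placeholders for whatever the target model's method signature actually takes:

import asyncio

from clarifai.client.model import Model


async def main():
    # Placeholder URL; any model callable via predict() should behave the same way.
    model = Model(url="https://clarifai.com/openai/chat-completion/models/gpt-4o")

    # Keyword arguments are forwarded to the model's predict() signature,
    # exactly as with the synchronous Model.predict().
    response = await model.async_predict(prompt="What is the capital of France?")
    print(response)


asyncio.run(main())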
@@ -603,7 +635,11 @@ class Model(Lister, BaseClient):
         if any([deployment_id, nodepool_id, compute_cluster_id]):
             from clarifai.client.user import User

-            user_id = User(pat=self.auth_helper.pat, token=self.auth_helper._token).get_user_info(user_id='me').user.id
+            user_id = (
+                User(pat=self.auth_helper.pat, token=self.auth_helper._token)
+                .get_user_info(user_id='me')
+                .user.id
+            )

         runner_selector = None
         if deployment_id and (compute_cluster_id or nodepool_id):
@@ -761,6 +797,30 @@ class Model(Lister, BaseClient):

         return self.client.generate(*args, **kwargs)

+    async def async_generate(self, *args, **kwargs):
+        """
+        Calls the model's async generate() method with the given arguments.
+
+        If passed in request_pb2.PostModelOutputsRequest values, will send the model the raw
+        protos directly for compatibility with previous versions of the SDK.
+        """
+        inputs = None
+        if 'inputs' in kwargs:
+            inputs = kwargs['inputs']
+        elif args:
+            inputs = args[0]
+        if inputs and isinstance(inputs, list) and isinstance(inputs[0], resources_pb2.Input):
+            assert len(args) <= 1, (
+                "Cannot pass in raw protos and additional arguments at the same time."
+            )
+            inference_params = kwargs.get('inference_params', {})
+            output_config = kwargs.get('output_config', {})
+            return self.client._async_generate_by_proto(
+                inputs=inputs, inference_params=inference_params, output_config=output_config
+            )
+
+        return self.client.generate(*args, **kwargs)
+
     def generate_by_filepath(
         self,
         filepath: str,
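`async_generate` follows the same dispatch: raw protos go through `_async_generate_by_proto`, anything else is forwarded to the client's `generate`. Assuming the returned object is an async iterator of partial responses, as the streaming handlers in `stub.py` above suggest, consumption might look like this; the model URL and prompt are placeholders:

import asyncio

from clarifai.client.model import Model


async def main():
    model = Model(url="https://clarifai.com/meta/Llama-3/models/llama-3-8b-instruct")  # placeholder

    # Assumption: like the synchronous generate(), the async variant yields
    # partial outputs; here we simply iterate over whatever stream comes back.
    stream = await model.async_generate(prompt="Tell me a three sentence story")
    async for chunk in stream:
        print(chunk)


asyncio.run(main())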
@@ -926,6 +986,50 @@ class Model(Lister, BaseClient):

         return self.client.stream(*args, **kwargs)

+    async def async_stream(self, *args, **kwargs):
+        """
+        Calls the model's async stream() method with the given arguments.
+
+        If passed in request_pb2.PostModelOutputsRequest values, will send the model the raw
+        protos directly for compatibility with previous versions of the SDK.
+        """
+
+        use_proto_call = False
+        inputs = None
+        if 'inputs' in kwargs:
+            inputs = kwargs['inputs']
+        elif args:
+            inputs = args[0]
+        if inputs and isinstance(inputs, Iterable):
+            inputs_iter = inputs
+            try:
+                peek = next(inputs_iter)
+            except StopIteration:
+                pass
+            else:
+                use_proto_call = (
+                    peek and isinstance(peek, list) and isinstance(peek[0], resources_pb2.Input)
+                )
+                # put back the peeked value
+                if inputs_iter is inputs:
+                    inputs = itertools.chain([peek], inputs_iter)
+                    if 'inputs' in kwargs:
+                        kwargs['inputs'] = inputs
+                    else:
+                        args = (inputs,) + args[1:]
+
+        if use_proto_call:
+            assert len(args) <= 1, (
+                "Cannot pass in raw protos and additional arguments at the same time."
+            )
+            inference_params = kwargs.get('inference_params', {})
+            output_config = kwargs.get('output_config', {})
+            return self.client._async_stream_by_proto(
+                inputs=inputs, inference_params=inference_params, output_config=output_config
+            )
+
+        return self.client.async_stream(*args, **kwargs)
+
     def stream_by_filepath(
         self,
         filepath: str,