clarifai 11.5.1 → 11.5.3 (clarifai-11.5.1-py3-none-any.whl → clarifai-11.5.3-py3-none-any.whl)

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.
Files changed (30)
  1. clarifai/__init__.py +1 -1
  2. clarifai/cli/model.py +42 -1
  3. clarifai/cli/pipeline.py +137 -0
  4. clarifai/cli/pipeline_step.py +104 -0
  5. clarifai/cli/templates/__init__.py +1 -0
  6. clarifai/cli/templates/pipeline_step_templates.py +64 -0
  7. clarifai/cli/templates/pipeline_templates.py +150 -0
  8. clarifai/client/auth/helper.py +46 -21
  9. clarifai/client/auth/register.py +5 -0
  10. clarifai/client/auth/stub.py +116 -12
  11. clarifai/client/base.py +9 -0
  12. clarifai/client/model.py +111 -7
  13. clarifai/client/model_client.py +355 -6
  14. clarifai/client/user.py +81 -0
  15. clarifai/runners/models/model_builder.py +52 -9
  16. clarifai/runners/pipeline_steps/__init__.py +0 -0
  17. clarifai/runners/pipeline_steps/pipeline_step_builder.py +510 -0
  18. clarifai/runners/pipelines/__init__.py +0 -0
  19. clarifai/runners/pipelines/pipeline_builder.py +313 -0
  20. clarifai/runners/utils/code_script.py +40 -7
  21. clarifai/runners/utils/const.py +2 -2
  22. clarifai/runners/utils/model_utils.py +135 -0
  23. clarifai/runners/utils/pipeline_validation.py +153 -0
  24. {clarifai-11.5.1.dist-info → clarifai-11.5.3.dist-info}/METADATA +1 -1
  25. {clarifai-11.5.1.dist-info → clarifai-11.5.3.dist-info}/RECORD +30 -19
  26. clarifai/cli/{model_templates.py → templates/model_templates.py} +0 -0
  27. {clarifai-11.5.1.dist-info → clarifai-11.5.3.dist-info}/WHEEL +0 -0
  28. {clarifai-11.5.1.dist-info → clarifai-11.5.3.dist-info}/entry_points.txt +0 -0
  29. {clarifai-11.5.1.dist-info → clarifai-11.5.3.dist-info}/licenses/LICENSE +0 -0
  30. {clarifai-11.5.1.dist-info → clarifai-11.5.3.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,4 @@
1
+ import asyncio
1
2
  import itertools
2
3
  import time
3
4
  from concurrent.futures import ThreadPoolExecutor
@@ -31,6 +32,10 @@ def validate_response(response, attempt, max_attempts):
31
32
  else:
32
33
  return response
33
34
 
35
+ # Check if response is an async iterator
36
+ if hasattr(response, '__aiter__'):
37
+ return response # Return async iterator directly for handling in _async_call__
38
+
34
39
  # Check if the response is an instance of a gRPC streaming call
35
40
  if isinstance(response, grpc._channel._MultiThreadedRendezvous):
36
41
  try:
@@ -49,7 +54,9 @@ def validate_response(response, attempt, max_attempts):
49
54
  return handle_simple_response(response)
50
55
 
51
56
 
52
- def create_stub(auth_helper: ClarifaiAuthHelper = None, max_retry_attempts: int = 10) -> V2Stub:
57
+ def create_stub(
58
+ auth_helper: ClarifaiAuthHelper = None, max_retry_attempts: int = 10, is_async: bool = False
59
+ ) -> V2Stub:
53
60
  """
54
61
  Create client stub that handles authorization and basic retries for
55
62
  unavailable or throttled connections.
@@ -58,41 +65,48 @@ def create_stub(auth_helper: ClarifaiAuthHelper = None, max_retry_attempts: int
58
65
  auth_helper: ClarifaiAuthHelper to use for auth metadata (default: from env)
59
66
  max_retry_attempts: max attempts to retry rpcs with retryable failures
60
67
  """
61
- stub = AuthorizedStub(auth_helper)
68
+ stub = AuthorizedStub(auth_helper, is_async=is_async)
62
69
  if max_retry_attempts > 0:
63
- return RetryStub(stub, max_retry_attempts)
70
+ return RetryStub(stub, max_retry_attempts, is_async=is_async)
64
71
  return stub
65
72
 
66
73
 
67
74
  class AuthorizedStub(V2Stub):
68
75
  """V2Stub proxy that inserts metadata authorization in rpc calls."""
69
76
 
70
- def __init__(self, auth_helper: ClarifaiAuthHelper = None):
77
+ def __init__(self, auth_helper: ClarifaiAuthHelper = None, is_async: bool = False):
71
78
  if auth_helper is None:
72
79
  auth_helper = ClarifaiAuthHelper.from_env()
73
- self.stub = auth_helper.get_stub()
80
+ self.is_async = is_async
81
+ self.stub = auth_helper.get_async_stub() if is_async else auth_helper.get_stub()
74
82
  self.metadata = auth_helper.metadata
75
83
 
76
84
  def __getattr__(self, name):
77
85
  value = getattr(self.stub, name)
78
86
  if isinstance(value, RpcCallable):
79
- value = _AuthorizedRpcCallable(value, self.metadata)
87
+ value = _AuthorizedRpcCallable(value, self.metadata, self.is_async)
80
88
  return value
81
89
 
82
90
 
83
91
  class _AuthorizedRpcCallable(RpcCallable):
84
92
  """Adds metadata(authorization header) to rpc calls"""
85
93
 
86
- def __init__(self, func, metadata):
94
+ def __init__(self, func, metadata, is_async):
87
95
  self.f = func
88
96
  self.metadata = metadata
97
+ self.is_async = is_async or asyncio.iscoroutinefunction(func)
89
98
 
90
99
  def __repr__(self):
91
100
  return repr(self.f)
92
101
 
93
102
  def __call__(self, *args, **kwargs):
94
103
  metadata = kwargs.pop('metadata', self.metadata)
95
- return self.f(*args, **kwargs, metadata=metadata)
104
+
105
+ return self.f(
106
+ *args,
107
+ **kwargs,
108
+ metadata=metadata,
109
+ )
96
110
 
97
111
  def future(self, *args, **kwargs):
98
112
  metadata = kwargs.pop('metadata', self.metadata)
@@ -107,30 +121,109 @@ class RetryStub(V2Stub):
107
121
  V2Stub proxy that retries requests (currently on unavailable server or throttle codes)
108
122
  """
109
123
 
110
- def __init__(self, stub, max_attempts=10, backoff_time=5):
124
+ def __init__(self, stub, max_attempts=10, backoff_time=5, is_async=False):
111
125
  self.stub = stub
112
126
  self.max_attempts = max_attempts
113
127
  self.backoff_time = backoff_time
128
+ self.is_async = is_async
114
129
 
115
130
  def __getattr__(self, name):
116
131
  value = getattr(self.stub, name)
117
132
  if isinstance(value, RpcCallable):
118
- value = _RetryRpcCallable(value, self.max_attempts, self.backoff_time)
133
+ value = _RetryRpcCallable(value, self.max_attempts, self.backoff_time, self.is_async)
119
134
  return value
120
135
 
121
136
 
122
137
  class _RetryRpcCallable(RpcCallable):
123
138
  """Retries rpc calls on unavailable server or throttle codes"""
124
139
 
125
- def __init__(self, func, max_attempts, backoff_time):
140
+ def __init__(self, func, max_attempts, backoff_time, is_async=False):
126
141
  self.f = func
127
142
  self.max_attempts = max_attempts
128
143
  self.backoff_time = backoff_time
144
+ self.is_async = is_async or asyncio.iscoroutinefunction(func)
129
145
 
130
146
  def __repr__(self):
131
147
  return repr(self.f)
132
148
 
133
- def __call__(self, *args, **kwargs):
149
+ async def _async_call__(self, *args, **kwargs):
150
+ """Handle async RPC calls with retries and validation"""
151
+ for attempt in range(1, self.max_attempts + 1):
152
+ if attempt != 1:
153
+ await asyncio.sleep(self.backoff_time)
154
+
155
+ try:
156
+ response = self.f(*args, **kwargs)
157
+
158
+ # Handle streaming response
159
+ if hasattr(response, '__aiter__'):
160
+ return await self._handle_streaming_response(response, attempt)
161
+
162
+ # Handle regular async response
163
+ validated_response = await self._handle_regular_response(response, attempt)
164
+ if validated_response is not None:
165
+ return validated_response
166
+
167
+ except grpc.RpcError as e:
168
+ if not self._should_retry(e, attempt):
169
+ raise
170
+ logger.debug(
171
+ f'Retrying after error {e.code()} (attempt {attempt}/{self.max_attempts})'
172
+ )
173
+
174
+ raise Exception(f'Max attempts reached ({self.max_attempts}) without success')
175
+
176
+ async def _handle_streaming_response(self, response, attempt):
177
+ """Handle streaming response validation and processing"""
178
+
179
+ async def validated_stream():
180
+ try:
181
+ async for item in response:
182
+ if not self._is_valid_response(item):
183
+ if attempt < self.max_attempts:
184
+ yield None # Signal for retry
185
+ raise Exception(
186
+ f'Validation failed on streaming response (attempt {attempt})'
187
+ )
188
+ yield item
189
+ except grpc.RpcError as e:
190
+ if not self._should_retry(e, attempt):
191
+ raise
192
+ yield None # Signal for retry
193
+
194
+ return validated_stream()
195
+
196
+ async def _handle_regular_response(self, response, attempt):
197
+ """Handle regular async response validation and processing"""
198
+ try:
199
+ result = await response
200
+ if not self._is_valid_response(result):
201
+ if attempt < self.max_attempts:
202
+ return None # Signal for retry
203
+ raise Exception(f'Validation failed on response (attempt {attempt})')
204
+ return result
205
+ except grpc.RpcError as e:
206
+ if not self._should_retry(e, attempt):
207
+ raise
208
+ return None # Signal for retry
209
+
210
+ def _is_valid_response(self, response):
211
+ """Check if response status is valid"""
212
+ return not (
213
+ hasattr(response, 'status')
214
+ and hasattr(response.status, 'code')
215
+ and response.status.code in throttle_status_codes
216
+ )
217
+
218
+ def _should_retry(self, error, attempt):
219
+ """Determine if we should retry based on error and attempt count"""
220
+ return (
221
+ isinstance(error, grpc.RpcError)
222
+ and error.code() in retry_codes_grpc
223
+ and attempt < self.max_attempts
224
+ )
225
+
226
+ def _sync_call__(self, *args, **kwargs):
134
227
  attempt = 0
135
228
  while attempt < self.max_attempts:
136
229
  attempt += 1
@@ -147,8 +240,19 @@ class _RetryRpcCallable(RpcCallable):
147
240
  else:
148
241
  raise
149
242
 
243
+ def __call__(self, *args, **kwargs):
244
+ if self.is_async:
245
+ return self._async_call__(*args, **kwargs)
246
+ return self._sync_call__(*args, **kwargs)
247
+
248
+ async def __call_async__(self, *args, **kwargs):
249
+ """Explicit async call method"""
250
+ return await self._async_call(*args, **kwargs)
251
+
150
252
  def future(self, *args, **kwargs):
151
253
  # TODO use single result event loop thread with asyncio
254
+ if self.is_async:
255
+ return asyncio.create_task(self._async_call(*args, **kwargs))
152
256
  return _threadpool.submit(self, *args, **kwargs)
153
257
 
154
258
  def __getattr__(self, name):
clarifai/client/base.py CHANGED
@@ -52,6 +52,7 @@ class BaseClient:
52
52
 
53
53
  self.auth_helper = ClarifaiAuthHelper(**kwargs, validate=False)
54
54
  self.STUB = create_stub(self.auth_helper)
55
+ self._async_stub = None
55
56
  self.metadata = self.auth_helper.metadata
56
57
  self.pat = self.auth_helper.pat
57
58
  self.token = self.auth_helper._token
@@ -59,6 +60,14 @@ class BaseClient:
59
60
  self.base = self.auth_helper.base
60
61
  self.root_certificates_path = self.auth_helper._root_certificates_path
61
62
 
63
+ @property
64
+ def async_stub(self):
65
+ """Returns the asynchronous gRPC stub for the API interaction.
66
+ Lazy initialization of async stub"""
67
+ if self._async_stub is None:
68
+ self._async_stub = create_stub(self.auth_helper, is_async=True)
69
+ return self._async_stub
70
+
62
71
  @classmethod
63
72
  def from_env(cls, validate: bool = False):
64
73
  auth = ClarifaiAuthHelper.from_env(validate=validate)
clarifai/client/model.py CHANGED
@@ -104,11 +104,6 @@ class Model(Lister, BaseClient):
104
104
  self.input_types = None
105
105
  self._client = None
106
106
  self._added_methods = False
107
- self._set_runner_selector(
108
- compute_cluster_id=compute_cluster_id,
109
- nodepool_id=nodepool_id,
110
- deployment_id=deployment_id,
111
- )
112
107
  BaseClient.__init__(
113
108
  self,
114
109
  user_id=self.user_id,
@@ -120,6 +115,12 @@ class Model(Lister, BaseClient):
120
115
  )
121
116
  Lister.__init__(self)
122
117
 
118
+ self._set_runner_selector(
119
+ compute_cluster_id=compute_cluster_id,
120
+ nodepool_id=nodepool_id,
121
+ deployment_id=deployment_id,
122
+ )
123
+
123
124
  @classmethod
124
125
  def from_current_context(cls, **kwargs) -> 'Model':
125
126
  from clarifai.urls.helper import ClarifaiUrlHelper
@@ -502,7 +503,9 @@ class Model(Lister, BaseClient):
502
503
  model=self.model_info,
503
504
  runner_selector=self._runner_selector,
504
505
  )
505
- self._client = ModelClient(self.STUB, request_template=request_template)
506
+ self._client = ModelClient(
507
+ stub=self.STUB, async_stub=self.async_stub, request_template=request_template
508
+ )
506
509
  return self._client
507
510
 
508
511
  def predict(self, *args, **kwargs):
@@ -530,6 +533,35 @@ class Model(Lister, BaseClient):
530
533
 
531
534
  return self.client.predict(*args, **kwargs)
532
535
 
536
+ async def async_predict(self, *args, **kwargs):
537
+ """
538
+ Calls the model's async predict() method with the given arguments.
539
+
540
+ If passed in request_pb2.PostModelOutputsRequest values, will send the model the raw
541
+ protos directly for compatibility with previous versions of the SDK.
542
+ """
543
+ inputs = None
544
+ if 'inputs' in kwargs:
545
+ inputs = kwargs['inputs']
546
+ elif args:
547
+ inputs = args[0]
548
+ if inputs and isinstance(inputs, list) and isinstance(inputs[0], resources_pb2.Input):
549
+ assert len(args) <= 1, (
550
+ "Cannot pass in raw protos and additional arguments at the same time."
551
+ )
552
+ inference_params = kwargs.get('inference_params', {})
553
+ output_config = kwargs.get('output_config', {})
554
+ return await self.client._async_predict_by_proto(
555
+ inputs=inputs, inference_params=inference_params, output_config=output_config
556
+ )
557
+
558
+ # Adding try-except, since the await works differently with jupyter kernels and in regular python scripts.
559
+ try:
560
+ return await self.client.predict(*args, **kwargs)
561
+ except TypeError:
562
+ # In jupyter, it returns a str object instead of a co-routine.
563
+ return self.client.predict(*args, **kwargs)
564
+
533
565
  def __getattr__(self, name):
534
566
  try:
535
567
  return getattr(self.model_info, name)
@@ -603,7 +635,11 @@ class Model(Lister, BaseClient):
603
635
  if any([deployment_id, nodepool_id, compute_cluster_id]):
604
636
  from clarifai.client.user import User
605
637
 
606
- user_id = User().get_user_info(user_id='me').user.id
638
+ user_id = (
639
+ User(pat=self.auth_helper.pat, token=self.auth_helper._token)
640
+ .get_user_info(user_id='me')
641
+ .user.id
642
+ )
607
643
 
608
644
  runner_selector = None
609
645
  if deployment_id and (compute_cluster_id or nodepool_id):
@@ -761,6 +797,30 @@ class Model(Lister, BaseClient):
761
797
 
762
798
  return self.client.generate(*args, **kwargs)
763
799
 
800
+ async def async_generate(self, *args, **kwargs):
801
+ """
802
+ Calls the model's async generate() method with the given arguments.
803
+
804
+ If passed in request_pb2.PostModelOutputsRequest values, will send the model the raw
805
+ protos directly for compatibility with previous versions of the SDK.
806
+ """
807
+ inputs = None
808
+ if 'inputs' in kwargs:
809
+ inputs = kwargs['inputs']
810
+ elif args:
811
+ inputs = args[0]
812
+ if inputs and isinstance(inputs, list) and isinstance(inputs[0], resources_pb2.Input):
813
+ assert len(args) <= 1, (
814
+ "Cannot pass in raw protos and additional arguments at the same time."
815
+ )
816
+ inference_params = kwargs.get('inference_params', {})
817
+ output_config = kwargs.get('output_config', {})
818
+ return self.client._async_generate_by_proto(
819
+ inputs=inputs, inference_params=inference_params, output_config=output_config
820
+ )
821
+
822
+ return self.client.generate(*args, **kwargs)
823
+
764
824
  def generate_by_filepath(
765
825
  self,
766
826
  filepath: str,
@@ -926,6 +986,50 @@ class Model(Lister, BaseClient):
926
986
 
927
987
  return self.client.stream(*args, **kwargs)
928
988
 
989
+ async def async_stream(self, *args, **kwargs):
990
+ """
991
+ Calls the model's async stream() method with the given arguments.
992
+
993
+ If passed in request_pb2.PostModelOutputsRequest values, will send the model the raw
994
+ protos directly for compatibility with previous versions of the SDK.
995
+ """
996
+
997
+ use_proto_call = False
998
+ inputs = None
999
+ if 'inputs' in kwargs:
1000
+ inputs = kwargs['inputs']
1001
+ elif args:
1002
+ inputs = args[0]
1003
+ if inputs and isinstance(inputs, Iterable):
1004
+ inputs_iter = inputs
1005
+ try:
1006
+ peek = next(inputs_iter)
1007
+ except StopIteration:
1008
+ pass
1009
+ else:
1010
+ use_proto_call = (
1011
+ peek and isinstance(peek, list) and isinstance(peek[0], resources_pb2.Input)
1012
+ )
1013
+ # put back the peeked value
1014
+ if inputs_iter is inputs:
1015
+ inputs = itertools.chain([peek], inputs_iter)
1016
+ if 'inputs' in kwargs:
1017
+ kwargs['inputs'] = inputs
1018
+ else:
1019
+ args = (inputs,) + args[1:]
1020
+
1021
+ if use_proto_call:
1022
+ assert len(args) <= 1, (
1023
+ "Cannot pass in raw protos and additional arguments at the same time."
1024
+ )
1025
+ inference_params = kwargs.get('inference_params', {})
1026
+ output_config = kwargs.get('output_config', {})
1027
+ return self.client._async_stream_by_proto(
1028
+ inputs=inputs, inference_params=inference_params, output_config=output_config
1029
+ )
1030
+
1031
+ return self.client.async_stream(*args, **kwargs)
1032
+
929
1033
  def stream_by_filepath(
930
1034
  self,
931
1035
  filepath: str,