clarifai-11.1.5rc6-py3-none-any.whl → clarifai-11.1.5rc8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. clarifai/__init__.py +1 -1
  2. clarifai/cli/__pycache__/model.cpython-310.pyc +0 -0
  3. clarifai/client/#model_client.py# +430 -0
  4. clarifai/client/model.py +95 -61
  5. clarifai/client/model_client.py +64 -49
  6. clarifai/runners/__pycache__/__init__.cpython-310.pyc +0 -0
  7. clarifai/runners/models/__pycache__/base_typed_model.cpython-310.pyc +0 -0
  8. clarifai/runners/models/__pycache__/model_builder.cpython-310.pyc +0 -0
  9. clarifai/runners/models/__pycache__/model_class.cpython-310.pyc +0 -0
  10. clarifai/runners/models/__pycache__/model_runner.cpython-310.pyc +0 -0
  11. clarifai/runners/models/model_class.py +31 -48
  12. clarifai/runners/utils/__pycache__/data_handler.cpython-310.pyc +0 -0
  13. clarifai/runners/utils/__pycache__/data_types.cpython-310.pyc +0 -0
  14. clarifai/runners/utils/__pycache__/method_signatures.cpython-310.pyc +0 -0
  15. clarifai/runners/utils/__pycache__/serializers.cpython-310.pyc +0 -0
  16. clarifai/runners/utils/data_types.py +62 -10
  17. clarifai/runners/utils/method_signatures.py +278 -295
  18. clarifai/runners/utils/serializers.py +143 -67
  19. {clarifai-11.1.5rc6.dist-info → clarifai-11.1.5rc8.dist-info}/METADATA +1 -1
  20. {clarifai-11.1.5rc6.dist-info → clarifai-11.1.5rc8.dist-info}/RECORD +24 -23
  21. {clarifai-11.1.5rc6.dist-info → clarifai-11.1.5rc8.dist-info}/LICENSE +0 -0
  22. {clarifai-11.1.5rc6.dist-info → clarifai-11.1.5rc8.dist-info}/WHEEL +0 -0
  23. {clarifai-11.1.5rc6.dist-info → clarifai-11.1.5rc8.dist-info}/entry_points.txt +0 -0
  24. {clarifai-11.1.5rc6.dist-info → clarifai-11.1.5rc8.dist-info}/top_level.txt +0 -0
clarifai/__init__.py CHANGED
@@ -1 +1 @@
- __version__ = "11.1.5rc6"
+ __version__ = "11.1.5rc8"
clarifai/client/#model_client.py# ADDED
@@ -0,0 +1,430 @@
+ import inspect
+ import time
+ from typing import Any, Dict, Iterator, List
+
+ from clarifai_grpc.grpc.api import resources_pb2, service_pb2
+ from clarifai_grpc.grpc.api.status import status_code_pb2
+
+ from clarifai.constants.model import MAX_MODEL_PREDICT_INPUTS
+ from clarifai.errors import UserError
+ from clarifai.runners.utils.method_signatures import (deserialize, get_stream_from_signature,
+                                                       serialize, signatures_from_json,
+                                                       unflatten_nested_keys)
+ from clarifai.utils.misc import BackoffIterator, status_is_retryable
+
+ from clarifai.utils.logging import logger
+
+
+ class ModelClient:
+   '''
+   Client for calling model predict, generate, and stream methods.
+   '''
+
+   def __init__(self, stub, request_template: service_pb2.PostModelOutputsRequest = None):
+     '''
+     Initialize the model client.
+
+     Args:
+         stub: The gRPC stub for the model.
+         request_template: The template for the request to send to the model, including
+             common fields like model_id, model_version, cluster, etc.
+     '''
+     self.STUB = stub
+     self.request_template = request_template or service_pb2.PostModelOutputsRequest()
+     self._fetch_signatures()
+     self._define_functions()
+
+   def _fetch_signatures(self):
+     '''
+     Fetch the method signatures from the model.
+
+     Returns:
+         Dict: The method signatures.
+     '''
+     #request = resources_pb2.GetModelSignaturesRequest()
+     #response = self.stub.GetModelSignatures(request)
+     #self._method_signatures = json.loads(response.signatures)  # or define protos
+     # TODO this could use a new endpoint to get the signatures
+     # for local grpc models, we'll also have to add the endpoint to the model servicer
+     # for now we'll just use the predict endpoint with a special method name
+
+     request = service_pb2.PostModelOutputsRequest()
+     request.CopyFrom(self.request_template)
+     # request.model.model_version.output_info.params['_method_name'] = '_GET_SIGNATURES'
+     inp = request.inputs.add()  # empty input for this method
+     inp.data.parts.add()  # empty part for this input
+     inp.data.metadata['_method_name'] = '_GET_SIGNATURES'
+     start_time = time.time()
+     backoff_iterator = BackoffIterator(10)
+     while True:
+       response = self.STUB.PostModelOutputs(request)
+       if status_is_retryable(
+           response.status.code) and time.time() - start_time < 60 * 10:  # 10 minutes
+         self.logger.info(f"Retrying model info fetch with response {response.status!r}")
+         time.sleep(next(backoff_iterator))
+         continue
+       break
+     if response.status.code == status_code_pb2.INPUT_UNSUPPORTED_FORMAT:
+       # return code from older models that don't support _GET_SIGNATURES
+       self._method_signatures = {}
+       return
+     if response.status.code != status_code_pb2.SUCCESS:
+       raise Exception(f"Model failed with response {response!r}")
+     self._method_signatures = signatures_from_json(response.outputs[0].data.text.raw)
+
+   def _define_functions(self):
+     '''
+     Define the functions based on the method signatures.
+     '''
+     for method_name, method_signature in self._method_signatures.items():
+       # define the function in this client instance
+       if method_signature.method_type == 'predict':
+         call_func = self._predict
+       elif method_signature.method_type == 'generate':
+         call_func = self._generate
+       elif method_signature.method_type == 'stream':
+         call_func = self._stream
+       else:
+         raise ValueError(f"Unknown method type {method_signature.method_type}")
+
+       # method argnames, in order, collapsing nested keys to corresponding user function args
+       method_argnames = []
+       for var in method_signature.inputs:
+         outer = var.name.split('.', 1)[0]
+         if outer in method_argnames:
+           continue
+         method_argnames.append(outer)
+
+       def bind_f(method_name, method_argnames, call_func):
+
+         def f(*args, **kwargs):
+           if len(args) > len(method_argnames):
+             raise TypeError(
+                 f"{method_name}() takes {len(method_argnames)} positional arguments but {len(args)} were given"
+             )
+           for name, arg in zip(method_argnames, args):  # handle positional with zip shortest
+             if name in kwargs:
+               raise TypeError(f"Multiple values for argument {name}")
+             kwargs[name] = arg
+           return call_func(kwargs, method_name)
+
+         return f
+
+       # need to bind method_name to the value, not the mutating loop variable
+       f = bind_f(method_name, method_argnames, call_func)
+
+       # set names, annotations and docstrings
+       f.__name__ = method_name
+       f.__qualname__ = f'{self.__class__.__name__}.{method_name}'
+       input_annos = {var.name: var.data_type for var in method_signature.inputs}
+       output_annos = {var.name: var.data_type for var in method_signature.outputs}
+       # unflatten nested keys to match the user function args for docs
+       input_annos = unflatten_nested_keys(input_annos, method_signature.inputs, is_output=False)
+       output_annos = unflatten_nested_keys(output_annos, method_signature.outputs, is_output=True)
+
+       # add Stream[] to the stream input annotations for docs
+       input_stream_argname, _ = get_stream_from_signature(method_signature.inputs)
+       if input_stream_argname:
+         input_annos[input_stream_argname] = 'Stream[' + str(
+             input_annos[input_stream_argname]) + ']'
+
+       # handle multiple outputs in the return annotation
+       return_annotation = output_annos
+       name = next(iter(output_annos.keys()))
+       if len(output_annos) == 1 and name == 'return':
+         # single output
+         return_annotation = output_annos[name]
+       elif name.startswith('return.') and name.split('.', 1)[1].isnumeric():
+         # tuple output
+         return_annotation = '(' + ", ".join(output_annos[f'return.{i}']
+                                             for i in range(len(output_annos))) + ')'
+       else:
+         # named output
+         return_annotation = f'Output({", ".join(f"{k}={t}" for k, t in output_annos.items())})'
+       if method_signature.method_type in ['generate', 'stream']:
+         return_annotation = f'Stream[{return_annotation}]'
+
+       # set annotations and docstrings
+       sig = inspect.signature(f).replace(
+           parameters=[
+               inspect.Parameter(k, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=v)
+               for k, v in input_annos.items()
+           ],
+           return_annotation=return_annotation,
+       )
+       f.__signature__ = sig
+       f.__doc__ = method_signature.docstring
+       setattr(self, method_name, f)
+
+   def _predict(
+       self,
+       inputs,  # TODO set up functions according to fetched signatures?
+       method_name: str = 'predict',
+   ) -> Any:
+     input_signature = self._method_signatures[method_name].inputs
+     output_signature = self._method_signatures[method_name].outputs
+
+     batch_input = True
+     if isinstance(inputs, dict):
+       inputs = [inputs]
+       batch_input = False
+
+     proto_inputs = []
+     for input in inputs:
+       proto = resources_pb2.Input()
+       serialize(input, input_signature, proto.data)
+       proto_inputs.append(proto)
+
+     response = self._predict_by_proto(proto_inputs, method_name)
+     #print(response)
+
+     outputs = []
+     for output in response.outputs:
+       outputs.append(deserialize(output.data, output_signature, is_output=True))
+     if batch_input:
+       return outputs
+     return outputs[0]
+
+   def _predict_by_proto(
+       self,
+       inputs: List[resources_pb2.Input],
+       method_name: str = None,
+       inference_params: Dict = None,
+       output_config: Dict = None,
+   ) -> service_pb2.MultiOutputResponse:
+     """Predicts the model based on the given inputs.
+
+     Args:
+         inputs (List[resources_pb2.Input]): The inputs to predict.
+         method_name (str): The remote method name to call.
+         inference_params (Dict): Inference parameters to override.
+         output_config (Dict): Output configuration to override.
+
+     Returns:
+         service_pb2.MultiOutputResponse: The prediction response(s).
+     """
+     if not isinstance(inputs, list):
+       raise UserError('Invalid inputs, inputs must be a list of Input objects.')
+     if len(inputs) > MAX_MODEL_PREDICT_INPUTS:
+       raise UserError(f"Too many inputs. Max is {MAX_MODEL_PREDICT_INPUTS}.")
+
+     request = service_pb2.PostModelOutputsRequest()
+     request.CopyFrom(self.request_template)
+
+     request.inputs.extend(inputs)
+
+     if method_name:
+       # TODO put in new proto field?
+       for inp in request.inputs:
+         inp.data.metadata['_method_name'] = method_name
+     if inference_params:
+       request.model.model_version.output_info.params.update(inference_params)
+     if output_config:
+       request.model.model_version.output_info.output_config.MergeFrom(
+           resources_pb2.OutputConfig(**output_config))
+
+     start_time = time.time()
+     backoff_iterator = BackoffIterator(10)
+     while True:
+       response = self.STUB.PostModelOutputs(request)
+       if status_is_retryable(
+           response.status.code) and time.time() - start_time < 60 * 10:  # 10 minutes
+         self.logger.info(f"Model predict failed with response {response!r}")
+         time.sleep(next(backoff_iterator))
+         continue
+
+       if response.status.code != status_code_pb2.SUCCESS:
+         raise Exception(f"Model predict failed with response {response!r}")
+       break
+
+     return response
+
+   def _generate(
+       self,
+       inputs,  # TODO set up functions according to fetched signatures?
+       method_name: str = 'generate',
+   ) -> Any:
+     input_signature = self._method_signatures[method_name].inputs
+     output_signature = self._method_signatures[method_name].outputs
+
+     batch_input = True
+     if isinstance(inputs, dict):
+       inputs = [inputs]
+       batch_input = False
+
+     proto_inputs = []
+     for input in inputs:
+       proto = resources_pb2.Input()
+       serialize(input, input_signature, proto.data)
+       proto_inputs.append(proto)
+
+     response_stream = self._generate_by_proto(proto_inputs, method_name)
+     #print(response)
+
+     for response in response_stream:
+       outputs = []
+       for output in response.outputs:
+         outputs.append(deserialize(output.data, output_signature, is_output=True))
+       if batch_input:
+         yield outputs
+       yield outputs[0]
+
+   def _generate_by_proto(
+       self,
+       inputs: List[resources_pb2.Input],
+       method_name: str = None,
+       inference_params: Dict = {},
+       output_config: Dict = {},
+   ):
+     """Generate the stream output on model based on the given inputs.
+
+     Args:
+         inputs (list[Input]): The inputs to generate, must be less than 128.
+         method_name (str): The remote method name to call.
+         inference_params (dict): The inference params to override.
+         output_config (dict): The output config to override.
+     """
+     if not isinstance(inputs, list):
+       raise UserError('Invalid inputs, inputs must be a list of Input objects.')
+     if len(inputs) > MAX_MODEL_PREDICT_INPUTS:
+       raise UserError(f"Too many inputs. Max is {MAX_MODEL_PREDICT_INPUTS}."
+                      )  # TODO Use Chunker for inputs len > 128
+
+     request = service_pb2.PostModelOutputsRequest()
+     request.CopyFrom(self.request_template)
+
+     request.inputs.extend(inputs)
+
+     if method_name:
+       # TODO put in new proto field?
+       for inp in request.inputs:
+         inp.data.metadata['_method_name'] = method_name
+     if inference_params:
+       request.model.model_version.output_info.params.update(inference_params)
+     if output_config:
+       request.model.model_version.output_info.output_config.MergeFromDict(output_config)
+
+     start_time = time.time()
+     backoff_iterator = BackoffIterator(10)
+     started = False
+     while not started:
+       stream_response = self.STUB.GenerateModelOutputs(request)
+       try:
+         response = next(stream_response)  # get the first response
+       except StopIteration:
+         raise Exception("Model Generate failed with no response")
+       if status_is_retryable(response.status.code) and \
+           time.time() - start_time < 60 * 10:
+         self.logger.info("Model is still deploying, please wait...")
+         time.sleep(next(backoff_iterator))
+         continue
+       if response.status.code != status_code_pb2.SUCCESS:
+         raise Exception(f"Model Generate failed with response {response.status!r}")
+       started = True
+
+     yield response  # yield the first response
+
+     for response in stream_response:
+       if response.status.code != status_code_pb2.SUCCESS:
+         raise Exception(f"Model Generate failed with response {response.status!r}")
+       yield response
+
+   def _stream(
+       self,
+       inputs,
+       method_name: str = 'stream',
+   ) -> Any:
+     input_signature = self._method_signatures[method_name].inputs
+     output_signature = self._method_signatures[method_name].outputs
+
+     if isinstance(inputs, list):
+       assert len(inputs) == 1, 'streaming methods do not support batched calls'
+       inputs = inputs[0]
+     assert isinstance(inputs, dict)
+     kwargs = inputs
+
+     # find the streaming vars in the input signature, and the streaming input python param
+     stream_argname, streaming_var_signatures = get_stream_from_signature(input_signature)
+
+     # get the streaming input generator from the user-provided function arg values
+     user_inputs_generator = kwargs.pop(stream_argname)
+
+     def _input_proto_stream():
+       # first item contains all the inputs and the first stream item
+       proto = resources_pb2.Input()
+       try:
+         item = next(user_inputs_generator)
+       except StopIteration:
+         return  # no items to stream
+       kwargs[stream_argname] = item
+       serialize(kwargs, input_signature, proto.data)
+
+       yield proto
+
+       # subsequent items are just the stream items
+       for item in user_inputs_generator:
+         proto = resources_pb2.Input()
+         serialize({stream_argname: item}, streaming_var_signatures, proto.data)
+         yield proto
+
+     response_stream = self._stream_by_proto(_input_proto_stream(), method_name)
+     #print(response)
+
+     for response in response_stream:
+       assert len(response.outputs) == 1, 'streaming methods must have exactly one output'
+       yield deserialize(response.outputs[0].data, output_signature, is_output=True)
+
+   def _req_iterator(self,
+                     input_iterator: Iterator[List[resources_pb2.Input]],
+                     method_name: str = None,
+                     inference_params: Dict = {},
+                     output_config: Dict = {}):
+     request = service_pb2.PostModelOutputsRequest()
+     request.CopyFrom(self.request_template)
+     if inference_params:
+       request.model.model_version.output_info.params.update(inference_params)
+     if output_config:
+       request.model.model_version.output_info.output_config.MergeFromDict(output_config)
+     for inputs in input_iterator:
+       req = service_pb2.PostModelOutputsRequest()
+       req.CopyFrom(request)
+       if isinstance(inputs, list):
+         req.inputs.extend(inputs)
+       else:
+         req.inputs.append(inputs)
+       # TODO: put into new proto field?
+       for inp in req.inputs:
+         inp.data.metadata['_method_name'] = method_name
+       yield req
+
+   def _stream_by_proto(self,
+                        inputs: Iterator[List[resources_pb2.Input]],
+                        method_name: str = None,
+                        inference_params: Dict = {},
+                        output_config: Dict = {}):
+     """Generate the stream output on model based on the given stream of inputs.
+     """
+     # if not isinstance(inputs, Iterator[List[Input]]):
+     #   raise UserError('Invalid inputs, inputs must be a iterator of list of Input objects.')
+
+     request = self._req_iterator(inputs, method_name, inference_params, output_config)
+
+     start_time = time.time()
+     backoff_iterator = BackoffIterator(10)
+     generation_started = False
+     while True:
+       if generation_started:
+         break
+       stream_response = self.STUB.StreamModelOutputs(request)
+       for response in stream_response:
+         if status_is_retryable(response.status.code) and \
+             time.time() - start_time < 60 * 10:
+           self.logger.info("Model is still deploying, please wait...")
+           time.sleep(next(backoff_iterator))
+           break
+         if response.status.code != status_code_pb2.SUCCESS:
+           raise Exception(f"Model Predict failed with response {response.status!r}")
+         else:
+           if not generation_started:
+             generation_started = True
+           yield response
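
Note on the new #model_client.py# file above: ModelClient discovers the model's remote methods at construction time (the '_GET_SIGNATURES' probe in _fetch_signatures, retried via BackoffIterator for up to 10 minutes), and _define_functions then binds one real Python callable per remote method, complete with a reconstructed __signature__ and __doc__. A minimal usage sketch follows; the endpoint address, model id, and the method name summarize are hypothetical placeholders, not values from this diff:

import grpc
from clarifai_grpc.grpc.api import service_pb2, service_pb2_grpc
from clarifai.client.model_client import ModelClient

# Hypothetical endpoint and model id, for illustration only.
channel = grpc.insecure_channel('localhost:50051')
stub = service_pb2_grpc.V2Stub(channel)
template = service_pb2.PostModelOutputsRequest(model_id='my-model')

client = ModelClient(stub, request_template=template)

# _define_functions ran in __init__, so each remote method is now a bound
# attribute; bind_f's inner f() folds positional args into kwargs and
# dispatches to _predict, _generate, or _stream by the signature's method_type.
help(client.summarize)  # shows the reconstructed signature and docstring
result = client.summarize(text='A long document...')

The bind_f indirection matters: defining f directly inside the loop would close over the loop variables, so every generated method would dispatch to the last method_name; routing the values through bind_f's arguments freezes each iteration's bindings.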
clarifai/client/model.py CHANGED
@@ -1,7 +1,8 @@
+ import itertools
  import json
  import os
  import time
- from typing import Any, Dict, Generator, Iterator, List, Tuple, Union
+ from typing import Any, Dict, Generator, Iterable, Iterator, List, Tuple, Union

  import numpy as np
  import requests
@@ -77,7 +78,8 @@ class Model(Lister, BaseClient):
      self.logger = logger
      self.training_params = {}
      self.input_types = None
-     self._model_client = None
+     self._client = None
+     self._added_methods = False
      self._set_runner_selector(
          compute_cluster_id=compute_cluster_id,
          nodepool_id=nodepool_id,
@@ -418,8 +420,8 @@ class Model(Lister, BaseClient):
          **dict(self.kwargs, model_version=model_version_info))

    @property
-   def model_client(self):
-     if self._model_client is None:
+   def client(self):
+     if self._client is None:
        request_template = service_pb2.PostModelOutputsRequest(
            user_app_id=self.user_app_id,
            model_id=self.id,
@@ -427,30 +429,46 @@ class Model(Lister, BaseClient):
          model=self.model_info,
          runner_selector=self._runner_selector,
      )
-       self._model_client = ModelClient(self.STUB, request_template=request_template)
-     return self._model_client
+       self._client = ModelClient(self.STUB, request_template=request_template)
+     return self._client

-   def predict(self, inputs: List[Input], inference_params: Dict = {}, output_config: Dict = {}):
-     """Predicts the model based on the given inputs.
-
-     Args:
-         inputs (list[Input]): The inputs to predict, must be less than 128.
+   def predict(self, *args, **kwargs):
      """
+     Calls the model's predict() method with the given arguments.

-     return self.model_client._predict_by_proto(
-         inputs=inputs,
-         inference_params=inference_params,
-         output_config=output_config,
-     )
+     If passed in request_pb2.PostModelOutputsRequest values, will send the model the raw
+     protos directly for compatibility with previous versions of the SDK.
+     """

-   def predict2(self, inputs):
-     """Predicts the model based on the given inputs.
+     inputs = None
+     if 'inputs' in kwargs:
+       inputs = kwargs['inputs']
+     elif args:
+       inputs = args[0]
+     if inputs and isinstance(inputs, list) and isinstance(inputs[0], resources_pb2.Input):
+       assert not args, "Cannot pass in raw protos and additional arguments at the same time."
+       inference_params = kwargs.get('inference_params', {})
+       output_config = kwargs.get('output_config', {})
+       return self.client._predict_by_proto(
+           inputs=inputs, inference_params=inference_params, output_config=output_config)

-     Args:
-         inputs (list[Input]): The inputs to predict, must be less than 128.
-     """
+     return self.client.predict(*args, **kwargs)

-     return self.model_client._predict(inputs=inputs,)
+   def __getattr__(self, name):
+     try:
+       return getattr(self.model_info, name)
+     except AttributeError:
+       pass
+     if not self._added_methods:
+       # fetch and set all the model methods
+       self._added_methods = True
+       self.client.fetch()
+       for method_name in self.client._method_signatures.keys():
+         if not hasattr(self, method_name):
+           setattr(self, method_name, getattr(self.client, method_name))
+     if hasattr(self.client, name):
+       return getattr(self.client, name)
+     raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

    def _check_predict_input_type(self, input_type: str) -> None:
      """Checks if the input type is valid for the model.
@@ -635,24 +653,27 @@ class Model(Lister, BaseClient):
      return self.predict(
          inputs=[input_proto], inference_params=inference_params, output_config=output_config)

-   def generate(
-       self,
-       inputs: List[Input],
-       inference_params: Dict = {},
-       output_config: Dict = {},
-   ):
-     """Generate the stream output on model based on the given inputs.
+   def generate(self, *args, **kwargs):
+     """
+     Calls the model's generate() method with the given arguments.

-     Args:
-         inputs (list[Input]): The inputs to generate, must be less than 128.
-         inference_params (dict): The inference params to override.
-         output_config (dict): The output config to override.
+     If passed in request_pb2.PostModelOutputsRequest values, will send the model the raw
+     protos directly for compatibility with previous versions of the SDK.
      """
-     return self.model_client._generate_by_proto(
-         inputs=inputs,
-         inference_params=inference_params,
-         output_config=output_config,
-     )
+
+     inputs = None
+     if 'inputs' in kwargs:
+       inputs = kwargs['inputs']
+     elif args:
+       inputs = args[0]
+     if inputs and isinstance(inputs, list) and isinstance(inputs[0], resources_pb2.Input):
+       assert not args, "Cannot pass in raw protos and additional arguments at the same time."
+       inference_params = kwargs.get('inference_params', {})
+       output_config = kwargs.get('output_config', {})
+       return self.client._generate_by_proto(
+           inputs=inputs, inference_params=inference_params, output_config=output_config)
+
+     return self.client.generate(*args, **kwargs)

    def generate_by_filepath(self,
                             filepath: str,
@@ -766,28 +787,44 @@ class Model(Lister, BaseClient):
      return self.generate(
          inputs=[input_proto], inference_params=inference_params, output_config=output_config)

-   def stream(self,
-              inputs: Iterator[List[Input]],
-              inference_params: Dict = {},
-              output_config: Dict = {}):
-     """Generate the stream output on model based on the given stream of inputs.
-
-     Args:
-         inputs (Iterator[list[Input]]): stream of inputs to predict, must be less than 128.
+   def stream(self, *args, **kwargs):
+     """
+     Calls the model's stream() method with the given arguments.

-     Example:
-         >>> from clarifai.client.model import Model
-         >>> model = Model("url") # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
-                   or
-         >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
-         >>> stream_response = model.stream(inputs=inputs, runner_selector=runner_selector)
-         >>> list_stream_response = [response for response in stream_response]
+     If passed in request_pb2.PostModelOutputsRequest values, will send the model the raw
+     protos directly for compatibility with previous versions of the SDK.
      """
-     return self.model_client._stream_by_proto(
-         inputs=inputs,
-         inference_params=inference_params,
-         output_config=output_config,
-     )
+
+     use_proto_call = False
+     inputs = None
+     if 'inputs' in kwargs:
+       inputs = kwargs['inputs']
+     elif args:
+       inputs = args[0]
+     if inputs and isinstance(inputs, Iterable):
+       inputs_iter = iter(inputs)
+       try:
+         peek = next(inputs_iter)
+       except StopIteration:
+         pass
+       else:
+         use_proto_call = isinstance(peek, resources_pb2.Input)
+         # put back the peeked value
+         if inputs_iter is inputs:
+           inputs = itertools.chain([peek], inputs_iter)
+           if 'inputs' in kwargs:
+             kwargs['inputs'] = inputs
+           else:
+             args = (inputs,) + args[1:]
+
+     if use_proto_call:
+       assert not args, "Cannot pass in raw protos and additional arguments at the same time."
+       inference_params = kwargs.get('inference_params', {})
+       output_config = kwargs.get('output_config', {})
+       return self.client._stream_by_proto(
+           inputs=inputs, inference_params=inference_params, output_config=output_config)
+
+     return self.client.stream(*args, **kwargs)

    def stream_by_filepath(self,
                           filepath: str,
@@ -946,9 +983,6 @@ class Model(Lister, BaseClient):
      self.kwargs = self.process_response_keys(dict_response['model'])
      self.model_info = resources_pb2.Model(**self.kwargs)

-   def __getattr__(self, name):
-     return getattr(self.model_info, name)
-
    def __str__(self):
      if len(self.kwargs) < 10:
        self.load_info()
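
The model.py changes above keep the pre-existing proto-based call style working while routing everything else through the new signature-based client. A sketch of both styles against the updated Model; the ids and the method name summarize are placeholders, not values from this diff:

from clarifai.client.model import Model
from clarifai_grpc.grpc.api import resources_pb2

model = Model(model_id='model_id', user_id='user_id', app_id='app_id')  # placeholder ids

# Old style: a list of resources_pb2.Input protos. predict() detects this via
# the isinstance check on inputs[0] and short-circuits to _predict_by_proto.
inp = resources_pb2.Input()
inp.data.text.raw = 'hello'
response = model.predict(inputs=[inp])

# New style: plain Python kwargs. The relocated __getattr__ fetches the method
# signatures once (guarded by self._added_methods) and forwards to ModelClient,
# so dynamically defined remote methods appear directly on the Model object.
# 'summarize' is a hypothetical remote method name.
summary = model.summarize(text='hello')

stream() needs the extra peek logic above because its inputs may be a generator, where an isinstance check on inputs[0] is impossible: it pulls one item with next(), inspects it, and re-chains it with itertools.chain so the callee still sees the full stream.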