clarifai 11.1.4rc2__py3-none-any.whl → 11.1.5rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. clarifai/__init__.py +1 -1
  2. clarifai/cli/__pycache__/model.cpython-310.pyc +0 -0
  3. clarifai/cli/model.py +46 -10
  4. clarifai/client/model.py +89 -364
  5. clarifai/client/model_client.py +400 -0
  6. clarifai/client/workflow.py +2 -2
  7. clarifai/datasets/upload/loaders/__pycache__/__init__.cpython-310.pyc +0 -0
  8. clarifai/datasets/upload/loaders/__pycache__/coco_detection.cpython-310.pyc +0 -0
  9. clarifai/rag/__pycache__/rag.cpython-310.pyc +0 -0
  10. clarifai/runners/__init__.py +2 -7
  11. clarifai/runners/__pycache__/__init__.cpython-310.pyc +0 -0
  12. clarifai/runners/__pycache__/server.cpython-310.pyc +0 -0
  13. clarifai/runners/dockerfile_template/Dockerfile.template +4 -32
  14. clarifai/runners/models/__pycache__/base_typed_model.cpython-310.pyc +0 -0
  15. clarifai/runners/models/__pycache__/model_builder.cpython-310.pyc +0 -0
  16. clarifai/runners/models/__pycache__/model_class.cpython-310.pyc +0 -0
  17. clarifai/runners/models/__pycache__/model_run_locally.cpython-310.pyc +0 -0
  18. clarifai/runners/models/__pycache__/model_runner.cpython-310.pyc +0 -0
  19. clarifai/runners/models/__pycache__/model_servicer.cpython-310.pyc +0 -0
  20. clarifai/runners/models/model_builder.py +47 -20
  21. clarifai/runners/models/model_class.py +249 -25
  22. clarifai/runners/models/model_run_locally.py +5 -2
  23. clarifai/runners/models/model_runner.py +2 -0
  24. clarifai/runners/models/model_servicer.py +11 -2
  25. clarifai/runners/server.py +26 -9
  26. clarifai/runners/utils/__pycache__/const.cpython-310.pyc +0 -0
  27. clarifai/runners/utils/__pycache__/data_handler.cpython-310.pyc +0 -0
  28. clarifai/runners/utils/__pycache__/method_signatures.cpython-310.pyc +0 -0
  29. clarifai/runners/utils/__pycache__/serializers.cpython-310.pyc +0 -0
  30. clarifai/runners/utils/const.py +1 -1
  31. clarifai/runners/utils/data_handler.py +308 -205
  32. clarifai/runners/utils/method_signatures.py +437 -0
  33. clarifai/runners/utils/serializers.py +132 -0
  34. clarifai/utils/evaluation/__pycache__/__init__.cpython-310.pyc +0 -0
  35. clarifai/utils/evaluation/__pycache__/helpers.cpython-310.pyc +0 -0
  36. clarifai/utils/evaluation/__pycache__/main.cpython-310.pyc +0 -0
  37. clarifai/utils/misc.py +12 -0
  38. {clarifai-11.1.4rc2.dist-info → clarifai-11.1.5rc1.dist-info}/METADATA +3 -2
  39. {clarifai-11.1.4rc2.dist-info → clarifai-11.1.5rc1.dist-info}/RECORD +43 -36
  40. clarifai/runners/models/base_typed_model.py +0 -238
  41. clarifai/runners/models/model_upload.py +0 -607
  42. clarifai/runners/utils/#const.py# +0 -30
  43. {clarifai-11.1.4rc2.dist-info → clarifai-11.1.5rc1.dist-info}/LICENSE +0 -0
  44. {clarifai-11.1.4rc2.dist-info → clarifai-11.1.5rc1.dist-info}/WHEEL +0 -0
  45. {clarifai-11.1.4rc2.dist-info → clarifai-11.1.5rc1.dist-info}/entry_points.txt +0 -0
  46. {clarifai-11.1.4rc2.dist-info → clarifai-11.1.5rc1.dist-info}/top_level.txt +0 -0
clarifai/runners/utils/method_signatures.py ADDED
@@ -0,0 +1,437 @@
+import inspect
+import json
+import re
+import types
+from collections import namedtuple
+from typing import List, get_args, get_origin
+
+import numpy as np
+import PIL.Image
+import yaml
+from clarifai_grpc.grpc.api import resources_pb2
+from google.protobuf.message import Message as MessageProto
+
+from clarifai.runners.utils import data_handler
+from clarifai.runners.utils.serializers import (AtomicFieldSerializer, ImageSerializer,
+                                                ListSerializer, MessageSerializer,
+                                                NDArraySerializer, NullValueSerializer, Serializer)
+
+
+def build_function_signature(func, method_type: str):
+  '''
+  Build a signature for the given function.
+  '''
+  sig = inspect.signature(func)
+
+  # check if func is bound, and if not, remove self/cls
+  if getattr(func, '__self__', None) is None and sig.parameters and list(
+      sig.parameters.values())[0].name in ('self', 'cls'):
+    sig = sig.replace(parameters=list(sig.parameters.values())[1:])
+
+  return_annotation = sig.return_annotation
+  if return_annotation == inspect.Parameter.empty:
+    raise ValueError('Function must have a return annotation')
+  # check for multiple return values and convert to dict for named values
+  return_streaming = False
+  if get_origin(return_annotation) == data_handler.Stream:
+    return_annotation = get_args(return_annotation)[0]
+    return_streaming = True
+  if get_origin(return_annotation) == tuple:
+    return_annotation = tuple(get_args(return_annotation))
+  if isinstance(return_annotation, tuple):
+    return_annotation = {'return.%s' % i: tp for i, tp in enumerate(return_annotation)}
+  if not isinstance(return_annotation, dict):
+    return_annotation = {'return': return_annotation}
+
+  input_vars = build_variables_signature(sig.parameters.values())
+  output_vars = build_variables_signature(
+      [
+          # XXX inspect.Parameter errors for the special return names, so use SimpleNamespace
+          types.SimpleNamespace(name=name, annotation=tp, default=inspect.Parameter.empty)
+          for name, tp in return_annotation.items()
+      ],
+      is_output=True)
+  if return_streaming:
+    for var in output_vars:
+      var.streaming = True
+
+  # check for streams
+  if method_type == 'predict':
+    for var in input_vars:
+      if var.streaming:
+        raise TypeError('Stream inputs are not supported for predict methods')
+    for var in output_vars:
+      if var.streaming:
+        raise TypeError('Stream outputs are not supported for predict methods')
+  elif method_type == 'generate':
+    for var in input_vars:
+      if var.streaming:
+        raise TypeError('Stream inputs are not supported for generate methods')
+    if not all(var.streaming for var in output_vars):
+      raise TypeError('Generate methods must return a stream')
+  elif method_type == 'stream':
+    input_stream_vars = [var for var in input_vars if var.streaming]
+    if len(input_stream_vars) == 0:
+      raise TypeError('Stream methods must include a Stream input')
+    if len(output_vars) != 1 or not output_vars[0].streaming:
+      raise TypeError('Stream methods must return a single Stream')
+  else:
+    raise TypeError('Invalid method type: %s' % method_type)
+
+  #method_signature = resources_pb2.MethodSignature()  # TODO
+  method_signature = _NamedFields()  #for now
+
+  method_signature.name = func.__name__
+  #method_signature.method_type = getattr(resources_pb2.RunnerMethodType, method_type)
+  assert method_type in ('predict', 'generate', 'stream')
+  method_signature.method_type = method_type
+
+  #method_signature.inputs.extend(input_vars)
+  #method_signature.outputs.extend(output_vars)
+  method_signature.inputs = input_vars
+  method_signature.outputs = output_vars
+  return method_signature
+
+
+def build_variables_signature(parameters: List[inspect.Parameter], is_output=False):
+  '''
+  Build a data proto signature for the given variable or return type annotation.
+  '''
+
+  vars = []
+
+  # check valid names (should already be constrained by python naming, but check anyway)
+  for param in parameters:
+    if not param.name.isidentifier() and not (is_output and
+                                              re.match(r'return(\.\d+)?', param.name)):
+      raise ValueError(f'Invalid variable name: {param.name}')
+
+  # get fields for each variable based on type
+  for param in parameters:
+    param_types, streaming = _normalize_types(param, is_output=is_output)
+
+    for name, tp in param_types.items():
+      #var = resources_pb2.MethodVariable()  # TODO
+      var = _NamedFields()
+      var.name = name
+      var.data_type = _DATA_TYPES[tp].data_type
+      var.data_field = _DATA_TYPES[tp].data_field
+      var.streaming = streaming
+      if not is_output:
+        var.required = (param.default is inspect.Parameter.empty)
+        if not var.required:
+          var.default = param.default
+      vars.append(var)
+
+  # check if any fields are used more than once, and if so, use parts
+  # also if more than one field uses parts lists, also use parts, since the lists can be different lengths
+  # NOTE this is a little fancy, another way would just be to check if there is more than one arg
+  fields_unique = (len(set(var.data_field for var in vars)) == len(vars))
+  num_parts_lists = sum(int(var.data_field.startswith('parts[]')) for var in vars)
+  if not fields_unique or num_parts_lists > 1:
+    for var in vars:
+      var.data_field = 'parts[%s].%s' % (var.name, var.data_field)
+
+  return vars
+
+
+def signatures_to_json(signatures):
+  assert isinstance(
+      signatures, dict), 'Expected dict of signatures {name: signature}, got %s' % type(signatures)
+  return json.dumps(signatures, default=repr)
+
+
+def signatures_from_json(json_str):
+  return json.loads(json_str, object_pairs_hook=_NamedFields)
+
+
+def signatures_to_yaml(signatures):
+  # XXX go in/out of json to get the correct format and python dict types
+  d = json.loads(signatures_to_json(signatures))
+  return yaml.dump(d, default_flow_style=False)
+
+
+def signatures_from_yaml(yaml_str):
+  d = yaml.safe_load(yaml_str)
+  return signatures_from_json(json.dumps(d))
+
+
+def serialize(kwargs, signatures, proto=None, is_output=False):
+  '''
+  Serialize the given kwargs into the proto using the given signatures.
+  '''
+  if proto is None:
+    proto = resources_pb2.Data()
+  if not is_output:  # TODO: use this consistently for return keys also
+    flatten_nested_keys(kwargs, signatures, is_output)
+  unknown = set(kwargs.keys()) - set(sig.name for sig in signatures)
+  if unknown:
+    if unknown == {'return'} and len(signatures) > 1:
+      raise TypeError('Got a single return value, but expected multiple outputs {%s}' %
+                      ', '.join(sig.name for sig in signatures))
+    raise TypeError('Got unexpected key: %s' % ', '.join(unknown))
+  for sig in signatures:
+    if sig.name not in kwargs:
+      if sig.required:
+        raise TypeError(f'Missing required argument: {sig.name}')
+      continue  # skip missing fields, they can be set to default on the server
+    data = kwargs[sig.name]
+    force_named_part = (_is_empty_proto_data(data) and not is_output and not sig.required)
+    data_proto, field = _get_data_part(
+        proto, sig, is_output=is_output, serializing=True, force_named_part=force_named_part)
+    serializer = get_serializer(sig.data_type)
+    serializer.serialize(data_proto, field, data)
+  return proto
+
+
+def deserialize(proto, signatures, is_output=False):
+  '''
+  Deserialize the given proto into kwargs using the given signatures.
+  '''
+  kwargs = {}
+  for sig in signatures:
+    data_proto, field = _get_data_part(proto, sig, is_output=is_output, serializing=False)
+    if data_proto is None:
+      # not set in proto, check if required or skip if optional arg
+      if not is_output and sig.required:
+        raise ValueError(f'Missing required field: {sig.name}')
+      continue
+    serializer = get_serializer(sig.data_type)
+    data = serializer.deserialize(data_proto, field)
+    kwargs[sig.name] = data
+  if is_output:
+    if len(kwargs) == 1 and 'return' in kwargs:  # case for single return value
+      return kwargs['return']
+    if kwargs and 'return.0' in kwargs:  # case for tuple return values
+      return tuple(kwargs[f'return.{i}'] for i in range(len(kwargs)))
+    return data_handler.Output(kwargs)
+  unflatten_nested_keys(kwargs, signatures, is_output)
+  return kwargs
+
+
+def get_serializer(data_type: str) -> Serializer:
+  if data_type in _SERIALIZERS_BY_TYPE_STRING:
+    return _SERIALIZERS_BY_TYPE_STRING[data_type]
+  if data_type.startswith('List['):
+    inner_type_string = data_type[len('List['):-1]
+    inner_serializer = get_serializer(inner_type_string)
+    return ListSerializer(inner_serializer)
+  raise ValueError(f'Unsupported type: "{data_type}"')
+
+
+def flatten_nested_keys(kwargs, signatures, is_output):
+  '''
+  Flatten nested keys into a single key with a dot, e.g. {'a': {'b': 1}} -> {'a.b': 1}
+  in the kwargs, using the given signatures to determine which keys are nested.
+  '''
+  nested_keys = [sig.name for sig in signatures if '.' in sig.name]
+  outer_keys = set(key.split('.')[0] for key in nested_keys)
+  for outer in outer_keys:
+    if outer not in kwargs:
+      continue
+    kwargs.update({outer + '.' + k: v for k, v in kwargs.pop(outer).items()})
+  return kwargs
+
+
+def unflatten_nested_keys(kwargs, signatures, is_output):
+  '''
+  Unflatten nested keys in kwargs into a dict, e.g. {'a.b': 1} -> {'a': {'b': 1}}
+  Uses the signatures to determine which keys are nested.
+  The dict subclass is Input or Output, depending on the is_output flag.
+  '''
+  for sig in signatures:
+    if '.' not in sig.name:
+      continue
+    if sig.name not in kwargs:
+      continue
+    parts = sig.name.split('.')
+    assert len(parts) == 2, 'Only one level of nested keys is supported'
+    if parts[0] not in kwargs:
+      kwargs[parts[0]] = data_handler.Output() if is_output else data_handler.Input()
+    kwargs[parts[0]][parts[1]] = kwargs.pop(sig.name)
+  return kwargs
+
+
+def _is_empty_proto_data(data):
+  if isinstance(data, np.ndarray):
+    return False
+  if isinstance(data, MessageProto):
+    return not data.ByteSize()
+  return not data
+
+
+def _get_data_part(proto, sig, is_output, serializing, force_named_part=False):
+  field = sig.data_field
+
+  # check if we need to force a named part, to distinguish between empty and unset values
+  if force_named_part and not field.startswith('parts['):
+    field = f'parts[{sig.name}].{field}'
+
+  # gets the named part from the proto, according to the field path
+  # note we only support one level of named parts
+  #parts = field.replace(' ', '').split('.')
+  # split on . but not if it is inside brackets, e.g. parts[outer.inner].field
+  parts = re.split(r'\.(?![^\[]*\])', field.replace(' ', ''))
+
+  if len(parts) not in (1, 2, 3):  # field, parts[name].field, parts[name].parts[].field
+    raise ValueError('Invalid field: %s' % field)
+
+  if len(parts) == 1:
+    # also need to check if there is an explicitly named part, e.g. for empty values
+    part = next((part for part in proto.parts if part.id == sig.name), None)
+    if part:
+      return part.data, field
+    if not serializing and not is_output and _is_empty_proto_data(getattr(proto, field)):
+      return None, field
+    return proto, field
+
+  # list
+  if parts[0] == 'parts[]':
+    if len(parts) != 2:
+      raise ValueError('Invalid field: %s' % field)
+    return proto, field  # return the data that contains the list itself
+
+  # named part
+  if not (m := re.match(r'parts\[([\w.]+)\]', parts[0])):
+    raise ValueError('Invalid field: %s' % field)
+  if not (name := m.group(1)):
+    raise ValueError('Invalid field: %s' % field)
+  assert len(parts) in (2, 3)  # parts[name].field, parts[name].parts[].field
+  part = next((part for part in proto.parts if part.id == name), None)
+  if part is None:
+    if not serializing:
+      raise ValueError('Missing part: %s' % name)
+    part = proto.parts.add()
+    part.id = name
+  return part.data, '.'.join(parts[1:])
+
+
+def _normalize_types(param, is_output=False):
+  '''
+  Normalize the types for the given parameter. Returns a dict of names to types,
+  including named return values for outputs, and a flag indicating if streaming is used.
+  '''
+  tp = param.annotation
+
+  # stream type indicates streaming, not part of the data itself
+  streaming = (get_origin(tp) == data_handler.Stream)
+  if streaming:
+    tp = get_args(tp)[0]
+
+  if is_output or streaming:  # named types can be used for outputs or streaming inputs
+    # output type used for named return values, each with their own data type
+    if isinstance(tp, (dict, data_handler.Output, data_handler.Input)):
+      return {param.name + '.' + name: _normalize_data_type(val)
+              for name, val in tp.items()}, streaming
+    if tp == data_handler.Output:  # check for Output type without values
+      if not is_output:
+        raise TypeError('Output types can only be used for output values')
+      raise TypeError('Output types must be instantiated with inner type values for each key')
+    if tp == data_handler.Input:  # check for Output type without values
+      if is_output:
+        raise TypeError('Input types can only be used for input values')
+      raise TypeError(
+          'Stream[Input(...)] types must be instantiated with inner type values for each key')
+
+  return {param.name: _normalize_data_type(tp)}, streaming
+
+
+def _normalize_data_type(tp):
+  # check if list, and if so, get inner type
+  is_list = (get_origin(tp) == list)
+  if is_list:
+    tp = get_args(tp)[0]
+
+  # check if numpy array, and if so, use ndarray
+  if get_origin(tp) == np.ndarray:
+    tp = np.ndarray
+
+  # check for PIL images (sometimes types use the module, sometimes the class)
+  # set these to use the Image data handler
+  if tp in (PIL.Image, PIL.Image.Image):
+    tp = data_handler.Image
+
+  # put back list
+  if is_list:
+    tp = List[tp]
+
+  # check if supported type
+  if tp not in _DATA_TYPES:
+    raise ValueError(f'Unsupported type: {tp}')
+
+  return tp
+
+
+class _NamedFields(dict):
+  __getattr__ = dict.__getitem__
+  __setattr__ = dict.__setitem__
+
+
+# data_type: name of the data type
+# data_field: name of the field in the data proto
+# serializer: serializer for the data type
+_DataType = namedtuple('_DataType', ('data_type', 'data_field', 'serializer'))
+
+# mapping of supported python types to data type names, fields, and serializers
+_DATA_TYPES = {
+    str:
+        _DataType('str', 'string_value', AtomicFieldSerializer()),
+    bytes:
+        _DataType('bytes', 'bytes_value', AtomicFieldSerializer()),
+    int:
+        _DataType('int', 'int_value', AtomicFieldSerializer()),
+    float:
+        _DataType('float', 'float_value', AtomicFieldSerializer()),
+    bool:
+        _DataType('bool', 'bool_value', AtomicFieldSerializer()),
+    None:
+        _DataType('None', '', NullValueSerializer()),
+    np.ndarray:
+        _DataType('ndarray', 'ndarray', NDArraySerializer()),
+    data_handler.Text:
+        _DataType('Text', 'text', MessageSerializer(data_handler.Text)),
+    data_handler.Image:
+        _DataType('Image', 'image', ImageSerializer()),
+    data_handler.Concept:
+        _DataType('Concept', 'concepts', MessageSerializer(data_handler.Concept)),
+    data_handler.Region:
+        _DataType('Region', 'regions', MessageSerializer(data_handler.Region)),
+    data_handler.Frame:
+        _DataType('Frame', 'frames', MessageSerializer(data_handler.Frame)),
+    data_handler.Audio:
+        _DataType('Audio', 'audio', MessageSerializer(data_handler.Audio)),
+    data_handler.Video:
+        _DataType('Video', 'video', MessageSerializer(data_handler.Video)),
+
+    # lists handled specially, not as generic lists using parts
+    List[int]:
+        _DataType('ndarray', 'ndarray', NDArraySerializer()),
+    List[float]:
+        _DataType('ndarray', 'ndarray', NDArraySerializer()),
+    List[bool]:
+        _DataType('ndarray', 'ndarray', NDArraySerializer()),
+}
+
+
+# add generic lists using parts, for all supported types
+def _add_list_fields():
+  for tp in list(_DATA_TYPES.keys()):
+    if List[tp] in _DATA_TYPES:
+      # already added as special case
+      continue
+
+    # check if data field is repeated, and if so, use repeated field for list
+    field_name = _DATA_TYPES[tp].data_field
+    descriptor = resources_pb2.Data.DESCRIPTOR.fields_by_name.get(field_name)
+    repeated = descriptor and descriptor.label == descriptor.LABEL_REPEATED
+
+    # add to supported types
+    data_type = 'List[%s]' % _DATA_TYPES[tp].data_type
+    data_field = field_name if repeated else 'parts[].' + field_name
+    serializer = ListSerializer(_DATA_TYPES[tp].serializer)
+
+    _DATA_TYPES[List[tp]] = _DataType(data_type, data_field, serializer)
+
+
+_add_list_fields()
+_SERIALIZERS_BY_TYPE_STRING = {dt.data_type: dt.serializer for dt in _DATA_TYPES.values()}
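For orientation, here is a minimal sketch of how these new helpers compose. The `predict` method below is hypothetical (not part of this diff), and the round trip assumes `resources_pb2.Data` in clarifai-grpc >=11.1.3 exposes the `string_value` and `float_value` fields that the `_DATA_TYPES` mapping targets:

```python
from clarifai.runners.utils.method_signatures import (build_function_signature, deserialize,
                                                      serialize, signatures_to_yaml)


# Hypothetical model method; only its name and annotations matter to the builder.
def predict(self, prompt: str, temperature: float = 0.5) -> str:
  ...


sig = build_function_signature(predict, method_type='predict')  # unbound, so 'self' is stripped
print(sig.inputs[0].name, sig.inputs[0].data_type)  # prompt str
print(sig.inputs[1].required)  # False, since temperature has a default
print(signatures_to_yaml({'predict': sig}))

# Pack kwargs into a Data proto and recover them using the same signature.
kwargs = {'prompt': 'hello', 'temperature': 0.5}
proto = serialize(kwargs, sig.inputs)
assert deserialize(proto, sig.inputs) == kwargs
```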
clarifai/runners/utils/serializers.py ADDED
@@ -0,0 +1,132 @@
+from typing import Iterable
+
+import numpy as np
+from clarifai_grpc.grpc.api import resources_pb2
+from PIL import Image as PILImage
+
+from clarifai.runners.utils.data_handler import Image, MessageData
+
+
+class Serializer:
+
+  def serialize(self, data_proto, field, value):
+    pass
+
+  def deserialize(self, data_proto, field):
+    pass
+
+
+def is_repeated(field):
+  return hasattr(field, 'add')
+
+
+class AtomicFieldSerializer(Serializer):
+
+  def serialize(self, data_proto, field, value):
+    try:
+      setattr(data_proto, field, value)
+    except TypeError as e:
+      raise TypeError(f"Incompatible type for {field}: {type(value)}") from e
+
+  def deserialize(self, data_proto, field):
+    return getattr(data_proto, field)
+
+
+class MessageSerializer(Serializer):
+
+  def __init__(self, message_class):
+    self.message_class = message_class
+
+  def serialize(self, data_proto, field, value):
+    if isinstance(value, MessageData):
+      value = value.to_proto()
+    dst = getattr(data_proto, field)
+    try:
+      if is_repeated(dst):
+        dst.add().CopyFrom(value)
+      else:
+        dst.CopyFrom(value)
+    except TypeError as e:
+      raise TypeError(f"Incompatible type for {field}: {type(value)}") from e
+
+  def deserialize(self, data_proto, field):
+    src = getattr(data_proto, field)
+    if is_repeated(src):
+      return [self.message_class.from_proto(item) for item in src]
+    else:
+      return self.message_class.from_proto(src)
+
+
+class ImageSerializer(Serializer):
+
+  def serialize(self, data_proto, field, value):
+    if not isinstance(value, (PILImage.Image, Image, resources_pb2.Image)):
+      raise TypeError(f"Expected Image, got {type(value)}")
+    if isinstance(value, PILImage.Image):
+      value = Image.from_pil(value)
+    if isinstance(value, MessageData):
+      value = value.to_proto()
+    getattr(data_proto, field).CopyFrom(value)
+
+  def deserialize(self, data_proto, field):
+    value = getattr(data_proto, field)
+    return Image.from_proto(value)
+
+
+class NDArraySerializer(Serializer):
+
+  def serialize(self, data_proto, field, value):
+    value = np.asarray(value)
+    if not np.issubdtype(value.dtype, np.number):
+      raise TypeError(f"Expected number array, got {value.dtype}")
+    proto = getattr(data_proto, field)
+    proto.buffer = value.tobytes()
+    proto.shape.extend(value.shape)
+    proto.dtype = str(value.dtype)
+
+  def deserialize(self, data_proto, field):
+    proto = getattr(data_proto, field)
+    array = np.frombuffer(proto.buffer, dtype=np.dtype(proto.dtype)).reshape(proto.shape)
+    return array
+
+
+class NullValueSerializer(Serializer):
+
+  def serialize(self, data_proto, field, value):
+    pass
+
+  def deserialize(self, data_proto, field):
+    return None
+
+
+class ListSerializer(Serializer):
+
+  def __init__(self, inner_serializer):
+    self.inner_serializer = inner_serializer
+
+  def serialize(self, data_proto, field, value):
+    if not isinstance(value, Iterable):
+      raise TypeError(f"Expected iterable, got {type(value)}")
+    if field.startswith('parts[].'):
+      inner_field = field[len('parts[].'):]
+      for item in value:
+        part = data_proto.parts.add()
+        self.inner_serializer.serialize(part.data, inner_field, item)
+      return
+    repeated = getattr(data_proto, field)
+    assert is_repeated(repeated), f"Field {field} is not repeated"
+    for item in value:
+      self.inner_serializer.serialize(data_proto, field, item)  # appends to repeated field
+
+  def deserialize(self, data_proto, field):
+    if field.startswith('parts[].'):
+      inner_field = field[len('parts[].'):]
+      return [
+          self.inner_serializer.deserialize(part.data, inner_field) for part in data_proto.parts
+      ]
+    repeated = getattr(data_proto, field)
+    assert is_repeated(repeated), f"Field {field} is not repeated"
+    return self.inner_serializer.deserialize(data_proto, field)  # returns repeated field list
+
+
+# TODO dict serializer, maybe json only?
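A hedged round-trip sketch of the two main serializer families, assuming `resources_pb2.Data` in clarifai-grpc >=11.1.3 carries the `ndarray`, `parts`, and `string_value` fields these classes address (the same assumption the diff's `_add_list_fields` makes):

```python
import numpy as np
from clarifai_grpc.grpc.api import resources_pb2

from clarifai.runners.utils.serializers import (AtomicFieldSerializer, ListSerializer,
                                                NDArraySerializer)

data = resources_pb2.Data()

# Numeric arrays are packed as raw bytes plus shape and dtype metadata.
arr = np.arange(6, dtype=np.float32).reshape(2, 3)
NDArraySerializer().serialize(data, 'ndarray', arr)
assert np.array_equal(NDArraySerializer().deserialize(data, 'ndarray'), arr)

# Lists of atomic values route through repeated parts, one part per item.
strings = ListSerializer(AtomicFieldSerializer())
strings.serialize(data, 'parts[].string_value', ['a', 'b'])
assert strings.deserialize(data, 'parts[].string_value') == ['a', 'b']
```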
clarifai/utils/misc.py CHANGED
@@ -3,8 +3,20 @@ import re
 import uuid
 from typing import Any, Dict, List
 
+from clarifai_grpc.grpc.api.status import status_code_pb2
+
 from clarifai.errors import UserError
 
+RETRYABLE_CODES = [
+    status_code_pb2.MODEL_DEPLOYING, status_code_pb2.MODEL_LOADING,
+    status_code_pb2.MODEL_BUSY_PLEASE_RETRY
+]
+
+
+def status_is_retryable(status_code: int) -> bool:
+  """Check if a status code is retryable."""
+  return status_code in RETRYABLE_CODES
+
 
 class Chunker:
   """Split an input sequence into small chunks."""
{clarifai-11.1.4rc2.dist-info → clarifai-11.1.5rc1.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: clarifai
-Version: 11.1.4rc2
+Version: 11.1.5rc1
 Summary: Clarifai Python SDK
 Home-page: https://github.com/Clarifai/clarifai-python
 Author: Clarifai
@@ -20,7 +20,7 @@ Classifier: Operating System :: OS Independent
 Requires-Python: >=3.8
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: clarifai-grpc >=11.0.7
+Requires-Dist: clarifai-grpc >=11.1.3
 Requires-Dist: clarifai-protocol >=0.0.16
 Requires-Dist: numpy >=1.22.0
 Requires-Dist: tqdm >=4.65.0
@@ -32,6 +32,7 @@ Requires-Dist: tabulate >=0.9.0
 Requires-Dist: fsspec >=2024.6.1
 Requires-Dist: click >=8.1.7
 Requires-Dist: requests >=2.32.3
+Requires-Dist: aiohttp >=3.8.1
 Provides-Extra: all
 Requires-Dist: pycocotools ==2.0.6 ; extra == 'all'