clarifai 11.1.4rc2__py3-none-any.whl → 11.1.5rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clarifai/__init__.py +1 -1
- clarifai/cli/__pycache__/model.cpython-310.pyc +0 -0
- clarifai/cli/model.py +46 -10
- clarifai/client/model.py +89 -364
- clarifai/client/model_client.py +400 -0
- clarifai/client/workflow.py +2 -2
- clarifai/datasets/upload/loaders/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/datasets/upload/loaders/__pycache__/coco_detection.cpython-310.pyc +0 -0
- clarifai/rag/__pycache__/rag.cpython-310.pyc +0 -0
- clarifai/runners/__init__.py +2 -7
- clarifai/runners/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/runners/__pycache__/server.cpython-310.pyc +0 -0
- clarifai/runners/dockerfile_template/Dockerfile.template +4 -32
- clarifai/runners/models/__pycache__/base_typed_model.cpython-310.pyc +0 -0
- clarifai/runners/models/__pycache__/model_builder.cpython-310.pyc +0 -0
- clarifai/runners/models/__pycache__/model_class.cpython-310.pyc +0 -0
- clarifai/runners/models/__pycache__/model_run_locally.cpython-310.pyc +0 -0
- clarifai/runners/models/__pycache__/model_runner.cpython-310.pyc +0 -0
- clarifai/runners/models/__pycache__/model_servicer.cpython-310.pyc +0 -0
- clarifai/runners/models/model_builder.py +47 -20
- clarifai/runners/models/model_class.py +249 -25
- clarifai/runners/models/model_run_locally.py +5 -2
- clarifai/runners/models/model_runner.py +2 -0
- clarifai/runners/models/model_servicer.py +11 -2
- clarifai/runners/server.py +26 -9
- clarifai/runners/utils/__pycache__/const.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/data_handler.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/method_signatures.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/serializers.cpython-310.pyc +0 -0
- clarifai/runners/utils/const.py +1 -1
- clarifai/runners/utils/data_handler.py +308 -205
- clarifai/runners/utils/method_signatures.py +437 -0
- clarifai/runners/utils/serializers.py +132 -0
- clarifai/utils/evaluation/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/utils/evaluation/__pycache__/helpers.cpython-310.pyc +0 -0
- clarifai/utils/evaluation/__pycache__/main.cpython-310.pyc +0 -0
- clarifai/utils/misc.py +12 -0
- {clarifai-11.1.4rc2.dist-info → clarifai-11.1.5rc1.dist-info}/METADATA +3 -2
- {clarifai-11.1.4rc2.dist-info → clarifai-11.1.5rc1.dist-info}/RECORD +43 -36
- clarifai/runners/models/base_typed_model.py +0 -238
- clarifai/runners/models/model_upload.py +0 -607
- clarifai/runners/utils/#const.py# +0 -30
- {clarifai-11.1.4rc2.dist-info → clarifai-11.1.5rc1.dist-info}/LICENSE +0 -0
- {clarifai-11.1.4rc2.dist-info → clarifai-11.1.5rc1.dist-info}/WHEEL +0 -0
- {clarifai-11.1.4rc2.dist-info → clarifai-11.1.5rc1.dist-info}/entry_points.txt +0 -0
- {clarifai-11.1.4rc2.dist-info → clarifai-11.1.5rc1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,437 @@
|
|
1
|
+
import inspect
|
2
|
+
import json
|
3
|
+
import re
|
4
|
+
import types
|
5
|
+
from collections import namedtuple
|
6
|
+
from typing import List, get_args, get_origin
|
7
|
+
|
8
|
+
import numpy as np
|
9
|
+
import PIL.Image
|
10
|
+
import yaml
|
11
|
+
from clarifai_grpc.grpc.api import resources_pb2
|
12
|
+
from google.protobuf.message import Message as MessageProto
|
13
|
+
|
14
|
+
from clarifai.runners.utils import data_handler
|
15
|
+
from clarifai.runners.utils.serializers import (AtomicFieldSerializer, ImageSerializer,
|
16
|
+
ListSerializer, MessageSerializer,
|
17
|
+
NDArraySerializer, NullValueSerializer, Serializer)
|
18
|
+
|
19
|
+
|
20
|
+
def build_function_signature(func, method_type: str):
  '''
  Inspect ``func`` and describe it as a method signature.

  The result is a ``_NamedFields`` with the keys ``name``, ``method_type``, ``inputs`` and
  ``outputs`` (in that order), where inputs/outputs are the lists produced by
  ``build_variables_signature``.

  Raises:
    ValueError: if ``func`` has no return annotation.
    TypeError: if ``method_type`` is not 'predict', 'generate' or 'stream', or if the
      use of Stream annotations does not fit the method type.
  '''
  signature = inspect.signature(func)

  # an unbound function still lists its receiver, so strip a leading self/cls
  declared = list(signature.parameters.values())
  unbound = getattr(func, '__self__', None) is None
  if unbound and declared and declared[0].name in ('self', 'cls'):
    signature = signature.replace(parameters=declared[1:])

  returns = signature.return_annotation
  if returns == inspect.Parameter.empty:
    raise ValueError('Function must have a return annotation')

  # Stream[...] only flags the outputs as streaming; the payload type is its argument
  return_streaming = get_origin(returns) == data_handler.Stream
  if return_streaming:
    returns = get_args(returns)[0]

  # normalize the annotation to {name: type}; tuple members become return.0, return.1, ...
  if get_origin(returns) == tuple:
    returns = tuple(get_args(returns))
  if isinstance(returns, tuple):
    returns = {'return.%s' % position: tp for position, tp in enumerate(returns)}
  elif not isinstance(returns, dict):
    returns = {'return': returns}

  input_vars = build_variables_signature(signature.parameters.values())

  # XXX inspect.Parameter errors for the special return names, so use SimpleNamespace
  return_params = [
      types.SimpleNamespace(name=name, annotation=tp, default=inspect.Parameter.empty)
      for name, tp in returns.items()
  ]
  output_vars = build_variables_signature(return_params, is_output=True)
  if return_streaming:
    for out in output_vars:
      out.streaming = True

  # each method type allows a different combination of streaming inputs/outputs
  streams_in = any(var.streaming for var in input_vars)
  if method_type == 'predict':
    if streams_in:
      raise TypeError('Stream inputs are not supported for predict methods')
    if any(var.streaming for var in output_vars):
      raise TypeError('Stream outputs are not supported for predict methods')
  elif method_type == 'generate':
    if streams_in:
      raise TypeError('Stream inputs are not supported for generate methods')
    if not all(var.streaming for var in output_vars):
      raise TypeError('Generate methods must return a stream')
  elif method_type == 'stream':
    if not streams_in:
      raise TypeError('Stream methods must include a Stream input')
    if len(output_vars) != 1 or not output_vars[0].streaming:
      raise TypeError('Stream methods must return a single Stream')
  else:
    raise TypeError('Invalid method type: %s' % method_type)

  assert method_type in ('predict', 'generate', 'stream')

  # TODO: build a resources_pb2.MethodSignature instead of the dict-based stand-in
  return _NamedFields(
      name=func.__name__,
      method_type=method_type,
      inputs=input_vars,
      outputs=output_vars,
  )
|
94
|
+
|
95
|
+
|
96
|
+
def build_variables_signature(parameters: List[inspect.Parameter], is_output=False):
  '''
  Build the list of variable signatures for a group of parameters or return values.

  Args:
    parameters: iterable of objects exposing ``name``, ``annotation`` and ``default``
      (``inspect.Parameter`` or a stand-in with the same attributes).
    is_output: True when the entries describe return values. Output names may also be
      ``return`` or ``return.<index>``, and no ``required``/``default`` is recorded.

  Returns:
    A list of ``_NamedFields``, one per (possibly nested) variable, carrying ``name``,
    ``data_type``, ``data_field`` and ``streaming``; inputs additionally carry
    ``required`` and, when optional, ``default``.

  Raises:
    ValueError: if a name is neither a valid identifier nor an allowed return name.
  '''
  # the parameters are walked twice below; materialize them so that a one-shot
  # iterable (generator, iterator) is not silently exhausted by the name check
  parameters = list(parameters)
  variables = []

  # check valid names (should already be constrained by python naming, but check anyway)
  for param in parameters:
    # fullmatch: a plain match() only anchors the start and would let names such as
    # 'return.0junk' through
    is_return_name = is_output and re.fullmatch(r'return(\.\d+)?', param.name)
    if not param.name.isidentifier() and not is_return_name:
      raise ValueError(f'Invalid variable name: {param.name}')

  # get fields for each variable based on type
  for param in parameters:
    param_types, streaming = _normalize_types(param, is_output=is_output)

    for name, tp in param_types.items():
      #var = resources_pb2.MethodVariable()  # TODO
      var = _NamedFields()
      var.name = name
      var.data_type = _DATA_TYPES[tp].data_type
      var.data_field = _DATA_TYPES[tp].data_field
      var.streaming = streaming
      if not is_output:
        var.required = (param.default is inspect.Parameter.empty)
        if not var.required:
          var.default = param.default
      variables.append(var)

  # check if any fields are used more than once, and if so, use parts
  # also if more than one field uses parts lists, also use parts, since the lists can be
  # different lengths
  # NOTE this is a little fancy, another way would just be to check if there is more than one arg
  fields_unique = (len({var.data_field for var in variables}) == len(variables))
  num_parts_lists = sum(int(var.data_field.startswith('parts[]')) for var in variables)
  if not fields_unique or num_parts_lists > 1:
    for var in variables:
      var.data_field = 'parts[%s].%s' % (var.name, var.data_field)

  return variables
|
136
|
+
|
137
|
+
|
138
|
+
def signatures_to_json(signatures):
  '''
  Encode a ``{name: signature}`` dict as a JSON string.

  Values the json module cannot encode natively are written as their ``repr``.
  '''
  is_mapping = isinstance(signatures, dict)
  assert is_mapping, 'Expected dict of signatures {name: signature}, got %s' % type(signatures)
  encoded = json.dumps(signatures, default=repr)
  return encoded
|
142
|
+
|
143
|
+
|
144
|
+
def signatures_from_json(json_str):
  '''
  Parse a JSON string of signatures, building every JSON object as a ``_NamedFields``
  so that its keys can also be read as attributes.
  '''
  decoded = json.loads(json_str, object_pairs_hook=_NamedFields)
  return decoded
|
146
|
+
|
147
|
+
|
148
|
+
def signatures_to_yaml(signatures):
  '''
  Encode a ``{name: signature}`` dict as a block-style YAML document.
  '''
  # XXX round-trip through json first so that only plain python dict/list/scalar
  # types (in the json layout) reach the yaml dumper
  as_json = signatures_to_json(signatures)
  plain = json.loads(as_json)
  return yaml.dump(plain, default_flow_style=False)
|
152
|
+
|
153
|
+
|
154
|
+
def signatures_from_yaml(yaml_str):
  '''
  Parse a YAML document of signatures into ``_NamedFields`` objects.

  The YAML is loaded with the safe loader and then passed through the JSON reader.
  '''
  plain = yaml.safe_load(yaml_str)
  as_json = json.dumps(plain)
  return signatures_from_json(as_json)
|
157
|
+
|
158
|
+
|
159
|
+
def serialize(kwargs, signatures, proto=None, is_output=False):
  '''
  Serialize the given kwargs into the proto using the given signatures.

  Args:
    kwargs: dict of values keyed by variable name. For inputs, nested dicts are
      flattened to dotted keys first; the caller's dict is left untouched.
    signatures: list of variable signatures (``name``, ``data_type``, ``data_field`` and,
      for inputs, ``required``).
    proto: proto to fill; a new ``resources_pb2.Data`` is created when None.
    is_output: True when serializing return values.

  Returns:
    The filled proto.

  Raises:
    TypeError: if kwargs has a key no signature describes, or a required input is missing.
  '''
  if proto is None:
    proto = resources_pb2.Data()
  if not is_output:  # TODO: use this consistently for return keys also
    # flatten a shallow copy: flattening pops and re-inserts keys in place, which
    # would otherwise rewrite the dict the caller handed in
    kwargs = flatten_nested_keys(dict(kwargs), signatures, is_output)

  unknown = set(kwargs.keys()) - {sig.name for sig in signatures}
  if unknown:
    if unknown == {'return'} and len(signatures) > 1:
      raise TypeError('Got a single return value, but expected multiple outputs {%s}' %
                      ', '.join(sig.name for sig in signatures))
    # sorted so the message does not depend on set iteration order
    raise TypeError('Got unexpected key: %s' % ', '.join(sorted(unknown)))

  for sig in signatures:
    if sig.name not in kwargs:
      # only input signatures carry `required` (same guard as in deserialize);
      # reading it on an output signature raised KeyError('required')
      if not is_output and sig.required:
        raise TypeError(f'Missing required argument: {sig.name}')
      continue  # skip missing fields, they can be set to default on the server
    data = kwargs[sig.name]
    # an empty optional input is written into a named part so the reader can tell
    # "explicitly empty" apart from "not set"
    force_named_part = (_is_empty_proto_data(data) and not is_output and not sig.required)
    data_proto, field = _get_data_part(
        proto, sig, is_output=is_output, serializing=True, force_named_part=force_named_part)
    serializer = get_serializer(sig.data_type)
    serializer.serialize(data_proto, field, data)
  return proto
|
185
|
+
|
186
|
+
|
187
|
+
def deserialize(proto, signatures, is_output=False):
  '''
  Read the values described by ``signatures`` back out of ``proto``.

  For inputs the result is a kwargs dict with dotted keys folded back into nested
  containers. For outputs the result is the bare value for a single ``return``, a tuple
  for ``return.0``, ``return.1``, ..., and otherwise a ``data_handler.Output``.

  Raises:
    ValueError: if a required input is not set in the proto.
  '''
  values = {}
  for sig in signatures:
    part, field = _get_data_part(proto, sig, is_output=is_output, serializing=False)
    if part is not None:
      reader = get_serializer(sig.data_type)
      values[sig.name] = reader.deserialize(part, field)
      continue
    # not set in proto: an error for required inputs, otherwise simply left out
    if not is_output and sig.required:
      raise ValueError(f'Missing required field: {sig.name}')

  if not is_output:
    unflatten_nested_keys(values, signatures, is_output)
    return values

  if len(values) == 1 and 'return' in values:  # case for single return value
    return values['return']
  if values and 'return.0' in values:  # case for tuple return values
    return tuple(values[f'return.{i}'] for i in range(len(values)))
  return data_handler.Output(values)
|
210
|
+
|
211
|
+
|
212
|
+
def get_serializer(data_type: str) -> Serializer:
  '''
  Look up the serializer for a data type name.

  Names of the form ``List[<inner>]`` that are not registered directly are resolved
  recursively by wrapping the inner type's serializer in a ``ListSerializer``.

  Raises:
    ValueError: if the name is neither registered nor a ``List[...]`` form.
  '''
  try:
    return _SERIALIZERS_BY_TYPE_STRING[data_type]
  except KeyError:
    pass

  prefix = 'List['
  if not data_type.startswith(prefix):
    raise ValueError(f'Unsupported type: "{data_type}"')

  # strip the 'List[' prefix and the closing bracket
  inner_name = data_type[len(prefix):-1]
  return ListSerializer(get_serializer(inner_name))
|
220
|
+
|
221
|
+
|
222
|
+
def flatten_nested_keys(kwargs, signatures, is_output):
  '''
  Replace nested dict values with dotted keys, e.g. {'a': {'b': 1}} -> {'a.b': 1}.

  Only keys that appear as the prefix of a dotted signature name are flattened.
  ``kwargs`` is modified in place and also returned.
  '''
  prefixes = set()
  for sig in signatures:
    if '.' in sig.name:
      prefixes.add(sig.name.split('.')[0])

  for prefix in prefixes:
    if prefix in kwargs:
      inner = kwargs.pop(prefix)
      flattened = {prefix + '.' + key: value for key, value in inner.items()}
      kwargs.update(flattened)
  return kwargs
|
234
|
+
|
235
|
+
|
236
|
+
def unflatten_nested_keys(kwargs, signatures, is_output):
  '''
  Fold dotted keys back into nested containers, e.g. {'a.b': 1} -> {'a': {'b': 1}}.

  Only dotted signature names present in ``kwargs`` are touched. A missing outer
  container is created as ``data_handler.Output`` when ``is_output`` is set and as
  ``data_handler.Input`` otherwise. ``kwargs`` is modified in place and also returned.
  '''
  for sig in signatures:
    flat_key = sig.name
    if '.' not in flat_key or flat_key not in kwargs:
      continue
    pieces = flat_key.split('.')
    assert len(pieces) == 2, 'Only one level of nested keys is supported'
    outer, inner = pieces[0], pieces[1]
    if outer not in kwargs:
      kwargs[outer] = data_handler.Output() if is_output else data_handler.Input()
    kwargs[outer][inner] = kwargs.pop(flat_key)
  return kwargs
|
253
|
+
|
254
|
+
|
255
|
+
def _is_empty_proto_data(data):
  '''
  Tell whether ``data`` counts as an empty value.

  Numpy arrays are never considered empty, proto messages are empty when their
  serialized size is zero, and anything else is judged by its truthiness.
  '''
  # checked first: arrays are reported non-empty without evaluating their truth value
  if isinstance(data, np.ndarray):
    return False
  if not isinstance(data, MessageProto):
    return not data
  return not data.ByteSize()
|
261
|
+
|
262
|
+
|
263
|
+
def _get_data_part(proto, sig, is_output, serializing, force_named_part=False):
  '''
  Resolve the field path of ``sig`` to the data proto that holds the value.

  Supported paths are ``field``, ``parts[].field``, ``parts[name].field`` and
  ``parts[name].parts[].field``.

  Returns:
    ``(data_proto, field)`` where ``field`` is the path remaining inside
    ``data_proto``. ``data_proto`` is None when reading an input whose plain field is
    empty and for which no part named after the variable exists.

  Raises:
    ValueError: if the path is malformed, or a named part is missing while reading.
  '''

  def find_part(part_id):
    # first part carrying the given id, or None
    for candidate in proto.parts:
      if candidate.id == part_id:
        return candidate
    return None

  field = sig.data_field

  # check if we need to force a named part, to distinguish between empty and unset values
  if force_named_part and not field.startswith('parts['):
    field = f'parts[{sig.name}].{field}'

  # split on . but not if it is inside brackets, e.g. parts[outer.inner].field
  # note we only support one level of named parts
  segments = re.split(r'\.(?![^\[]*\])', field.replace(' ', ''))
  depth = len(segments)

  if not 1 <= depth <= 3:  # field, parts[name].field, parts[name].parts[].field
    raise ValueError('Invalid field: %s' % field)

  if depth == 1:
    # also need to check if there is an explicitly named part, e.g. for empty values
    named = find_part(sig.name)
    if named:
      return named.data, field
    reading_input = not (serializing or is_output)
    if reading_input and _is_empty_proto_data(getattr(proto, field)):
      return None, field
    return proto, field

  head = segments[0]

  # list
  if head == 'parts[]':
    if depth != 2:
      raise ValueError('Invalid field: %s' % field)
    return proto, field  # return the data that contains the list itself

  # named part
  matched = re.match(r'parts\[([\w.]+)\]', head)
  if not matched:
    raise ValueError('Invalid field: %s' % field)
  name = matched.group(1)
  if not name:
    raise ValueError('Invalid field: %s' % field)
  assert depth in (2, 3)  # parts[name].field, parts[name].parts[].field

  target = find_part(name)
  if target is None:
    if not serializing:
      raise ValueError('Missing part: %s' % name)
    # writing: create the named part on demand
    target = proto.parts.add()
    target.id = name
  return target.data, '.'.join(segments[1:])
|
307
|
+
|
308
|
+
|
309
|
+
def _normalize_types(param, is_output=False):
  '''
  Expand the annotation of ``param`` into a dict of variable names to normalized types.

  Returns:
    ``(types_by_name, streaming)``. A ``Stream[...]`` annotation sets ``streaming`` and
    is replaced by its argument. For outputs and streaming inputs, a dict-like
    annotation expands to one ``<param name>.<key>`` entry per key.

  Raises:
    TypeError: if the bare ``Output`` or ``Input`` class is used as the annotation.
  '''
  annotation = param.annotation

  # stream type indicates streaming, not part of the data itself
  streaming = (get_origin(annotation) == data_handler.Stream)
  if streaming:
    annotation = get_args(annotation)[0]

  # named types can only be used for outputs or streaming inputs
  if not (is_output or streaming):
    return {param.name: _normalize_data_type(annotation)}, streaming

  # instantiated Output/Input (or a plain dict): each key has its own data type
  if isinstance(annotation, (dict, data_handler.Output, data_handler.Input)):
    named = {}
    for name, val in annotation.items():
      named[param.name + '.' + name] = _normalize_data_type(val)
    return named, streaming

  if annotation == data_handler.Output:  # check for Output type without values
    if not is_output:
      raise TypeError('Output types can only be used for output values')
    raise TypeError('Output types must be instantiated with inner type values for each key')

  if annotation == data_handler.Input:  # check for Input type without values
    if is_output:
      raise TypeError('Input types can only be used for input values')
    raise TypeError(
        'Stream[Input(...)] types must be instantiated with inner type values for each key')

  return {param.name: _normalize_data_type(annotation)}, streaming
|
337
|
+
|
338
|
+
|
339
|
+
def _normalize_data_type(tp):
  '''
  Map a type annotation onto the matching key of ``_DATA_TYPES``.

  Parameterized numpy array types collapse to ``np.ndarray``, PIL images (module or
  class) become ``data_handler.Image``, and both rules also apply to the element type
  of a ``List[...]``.

  Raises:
    ValueError: if the resulting type is not supported, including a ``List`` that does
      not have exactly one element type.
  '''
  # check if list, and if so, get inner type
  is_list = (get_origin(tp) == list)
  if is_list:
    inner = get_args(tp)
    # a bare `List` has no arguments (indexing them raised an opaque IndexError) and
    # `list[a, b]` has several; neither names a single element type
    if len(inner) != 1:
      raise ValueError(f'Unsupported type: {tp}')
    tp = inner[0]

  # check if numpy array, and if so, use ndarray
  if get_origin(tp) == np.ndarray:
    tp = np.ndarray

  # check for PIL images (sometimes types use the module, sometimes the class)
  # set these to use the Image data handler
  if tp in (PIL.Image, PIL.Image.Image):
    tp = data_handler.Image

  # put back list
  if is_list:
    tp = List[tp]

  # check if supported type
  if tp not in _DATA_TYPES:
    raise ValueError(f'Unsupported type: {tp}')

  return tp
|
363
|
+
|
364
|
+
|
365
|
+
class _MissingFieldError(AttributeError, KeyError):
  '''
  Raised when a ``_NamedFields`` attribute has no matching key.

  It is an AttributeError so that ``hasattr``, ``getattr(obj, name, default)`` and
  ``copy.deepcopy`` behave normally, and still a KeyError so that existing
  ``except KeyError`` handlers keep working.
  '''


class _NamedFields(dict):
  '''
  A dict whose items can also be read and written as attributes.

  Attribute assignment stores an item; attribute reads that are not resolved the
  normal way fall back to an item lookup.
  '''

  def __getattr__(self, name):
    # only called when normal attribute lookup fails; the attribute protocol requires
    # an AttributeError here, a bare KeyError escapes hasattr()/getattr()/deepcopy()
    try:
      return self[name]
    except KeyError:
      raise _MissingFieldError(name) from None

  __setattr__ = dict.__setitem__
|
368
|
+
|
369
|
+
|
370
|
+
# Description of one supported type:
# data_type: name of the data type (the string stored in signatures and used to look up
#            the serializer again)
# data_field: name of the field in the data proto
# serializer: serializer for the data type
_DataType = namedtuple('_DataType', ('data_type', 'data_field', 'serializer'))

# mapping of supported python types to data type names, fields, and serializers
# NOTE: entry order matters. Generic List[...] entries are derived from these keys in
# order, and when several entries share a data_type name the serializer table built
# from this mapping keeps the last one.
_DATA_TYPES = {
    str:
        _DataType('str', 'string_value', AtomicFieldSerializer()),
    bytes:
        _DataType('bytes', 'bytes_value', AtomicFieldSerializer()),
    int:
        _DataType('int', 'int_value', AtomicFieldSerializer()),
    float:
        _DataType('float', 'float_value', AtomicFieldSerializer()),
    bool:
        _DataType('bool', 'bool_value', AtomicFieldSerializer()),
    # None has no proto field of its own, hence the empty field name
    None:
        _DataType('None', '', NullValueSerializer()),
    np.ndarray:
        _DataType('ndarray', 'ndarray', NDArraySerializer()),
    data_handler.Text:
        _DataType('Text', 'text', MessageSerializer(data_handler.Text)),
    data_handler.Image:
        _DataType('Image', 'image', ImageSerializer()),
    data_handler.Concept:
        _DataType('Concept', 'concepts', MessageSerializer(data_handler.Concept)),
    data_handler.Region:
        _DataType('Region', 'regions', MessageSerializer(data_handler.Region)),
    data_handler.Frame:
        _DataType('Frame', 'frames', MessageSerializer(data_handler.Frame)),
    data_handler.Audio:
        _DataType('Audio', 'audio', MessageSerializer(data_handler.Audio)),
    data_handler.Video:
        _DataType('Video', 'video', MessageSerializer(data_handler.Video)),

    # lists handled specially, not as generic lists using parts:
    # lists of numbers/bools are stored as a single ndarray
    List[int]:
        _DataType('ndarray', 'ndarray', NDArraySerializer()),
    List[float]:
        _DataType('ndarray', 'ndarray', NDArraySerializer()),
    List[bool]:
        _DataType('ndarray', 'ndarray', NDArraySerializer()),
}
|
414
|
+
|
415
|
+
|
416
|
+
# add generic lists using parts, for all supported types
def _add_list_fields():
  '''
  Register a ``List[tp]`` entry in ``_DATA_TYPES`` for every type that has none yet.

  When the element's field is a repeated field of the Data proto, the list uses that
  field directly; otherwise it is stored as ``parts[].<field>``.
  '''
  data_fields = resources_pb2.Data.DESCRIPTOR.fields_by_name
  # iterate over a snapshot, since entries are added while looping
  for tp, base in list(_DATA_TYPES.items()):
    list_tp = List[tp]
    if list_tp in _DATA_TYPES:
      # already added as special case
      continue

    # check if data field is repeated, and if so, use repeated field for list
    descriptor = data_fields.get(base.data_field)
    if descriptor and descriptor.label == descriptor.LABEL_REPEATED:
      list_field = base.data_field
    else:
      list_field = 'parts[].' + base.data_field

    # add to supported types
    _DATA_TYPES[list_tp] = _DataType('List[%s]' % base.data_type, list_field,
                                     ListSerializer(base.serializer))


_add_list_fields()

# serializers keyed by data type name; for a repeated name the last entry wins
_SERIALIZERS_BY_TYPE_STRING = dict(
    (entry.data_type, entry.serializer) for entry in _DATA_TYPES.values())
|
@@ -0,0 +1,132 @@
|
|
1
|
+
from typing import Iterable
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
from clarifai_grpc.grpc.api import resources_pb2
|
5
|
+
from PIL import Image as PILImage
|
6
|
+
|
7
|
+
from clarifai.runners.utils.data_handler import Image, MessageData
|
8
|
+
|
9
|
+
|
10
|
+
class Serializer:
  '''
  Base interface for reading and writing one value of a data proto field.

  Both methods do nothing here and return None.
  '''

  def serialize(self, data_proto, field, value):
    '''Write ``value`` into ``field`` of ``data_proto``. No-op in the base class.'''
    return None

  def deserialize(self, data_proto, field):
    '''Read the value of ``field`` from ``data_proto``. Returns None in the base class.'''
    return None
|
17
|
+
|
18
|
+
|
19
|
+
def is_repeated(field):
  '''Return True when ``field`` exposes an ``add`` attribute, which is how a repeated field is recognized here.'''
  try:
    field.add
  except AttributeError:
    return False
  return True
|
21
|
+
|
22
|
+
|
23
|
+
class AtomicFieldSerializer(Serializer):
  '''Serializer for values that are stored by plain attribute assignment on the proto.'''

  def serialize(self, data_proto, field, value):
    '''
    Assign ``value`` to ``data_proto.<field>``.

    Raises:
      TypeError: if the assignment is rejected with a TypeError; the message names the
        field and the type of the value.
    '''
    try:
      setattr(data_proto, field, value)
    except TypeError as err:
      message = f"Incompatible type for {field}: {type(value)}"
      raise TypeError(message) from err

  def deserialize(self, data_proto, field):
    '''Return ``data_proto.<field>`` as is.'''
    stored = getattr(data_proto, field)
    return stored
|
33
|
+
|
34
|
+
|
35
|
+
class MessageSerializer(Serializer):
  '''
  Serializer for message-valued fields.

  Values are read back through ``message_class.from_proto``.
  '''

  def __init__(self, message_class):
    self.message_class = message_class

  def serialize(self, data_proto, field, value):
    '''
    Copy ``value`` into ``data_proto.<field>``; for a repeated field a new element is
    appended and filled. ``MessageData`` values are converted with ``to_proto`` first.

    Raises:
      TypeError: if the copy is rejected with a TypeError.
    '''
    if isinstance(value, MessageData):
      value = value.to_proto()
    destination = getattr(data_proto, field)
    try:
      target = destination.add() if is_repeated(destination) else destination
      target.CopyFrom(value)
    except TypeError as err:
      raise TypeError(f"Incompatible type for {field}: {type(value)}") from err

  def deserialize(self, data_proto, field):
    '''Return the wrapped message, or a list of wrapped messages for a repeated field.'''
    source = getattr(data_proto, field)
    if not is_repeated(source):
      return self.message_class.from_proto(source)
    wrapped = []
    for item in source:
      wrapped.append(self.message_class.from_proto(item))
    return wrapped
|
58
|
+
|
59
|
+
|
60
|
+
class ImageSerializer(Serializer):
  '''Serializer for image fields; accepts PIL images, ``Image`` wrappers and image protos.'''

  def serialize(self, data_proto, field, value):
    '''
    Copy the image into ``data_proto.<field>``.

    PIL images are wrapped with ``Image.from_pil`` and ``MessageData`` values are
    converted with ``to_proto`` before copying.

    Raises:
      TypeError: if ``value`` is none of the accepted image types.
    '''
    accepted = (PILImage.Image, Image, resources_pb2.Image)
    if not isinstance(value, accepted):
      raise TypeError(f"Expected Image, got {type(value)}")
    if isinstance(value, PILImage.Image):
      value = Image.from_pil(value)
    if isinstance(value, MessageData):
      value = value.to_proto()
    destination = getattr(data_proto, field)
    destination.CopyFrom(value)

  def deserialize(self, data_proto, field):
    '''Return the field wrapped via ``Image.from_proto``.'''
    return Image.from_proto(getattr(data_proto, field))
|
74
|
+
|
75
|
+
|
76
|
+
class NDArraySerializer(Serializer):
  '''
  Serializer for numeric and boolean arrays, stored as raw bytes plus shape and dtype name.
  '''

  def serialize(self, data_proto, field, value):
    '''
    Store ``value`` (anything ``np.asarray`` accepts) in ``data_proto.<field>``.

    Raises:
      TypeError: if the resulting array is neither numeric nor boolean.
    '''
    value = np.asarray(value)
    # np.bool_ is not a sub-dtype of np.number, so it has to be allowed explicitly;
    # otherwise every boolean array/list would be rejected
    dtype = value.dtype
    if not (np.issubdtype(dtype, np.number) or np.issubdtype(dtype, np.bool_)):
      raise TypeError(f"Expected number array, got {value.dtype}")
    proto = getattr(data_proto, field)
    proto.buffer = value.tobytes()
    # replace rather than append: buffer and dtype are overwritten, so a shape left
    # over from an earlier write into the same proto must not accumulate
    del proto.shape[:]
    proto.shape.extend(value.shape)
    proto.dtype = str(value.dtype)

  def deserialize(self, data_proto, field):
    '''Rebuild the array from the stored buffer, dtype name and shape.'''
    proto = getattr(data_proto, field)
    array = np.frombuffer(proto.buffer, dtype=np.dtype(proto.dtype)).reshape(proto.shape)
    return array
|
91
|
+
|
92
|
+
|
93
|
+
class NullValueSerializer(Serializer):
  '''Serializer for None values: nothing is written and None is always read back.'''

  def serialize(self, data_proto, field, value):
    '''Write nothing.'''
    return None

  def deserialize(self, data_proto, field):
    '''Always return None.'''
    return None
|
100
|
+
|
101
|
+
|
102
|
+
class ListSerializer(Serializer):
  '''
  Serializer for lists of values handled by ``inner_serializer``.

  A field path starting with ``parts[].`` stores each element in its own part;
  any other field must be a repeated field that the inner serializer appends to.
  '''

  def __init__(self, inner_serializer):
    self.inner_serializer = inner_serializer

  def serialize(self, data_proto, field, value):
    '''
    Write every item of ``value`` into ``data_proto``.

    Raises:
      TypeError: if ``value`` is not iterable.
    '''
    if not isinstance(value, Iterable):
      raise TypeError(f"Expected iterable, got {type(value)}")
    if field.startswith('parts[].'):
      inner_field = field[len('parts[].'):]
      for item in value:
        # list elements are written as parts without an id
        part = data_proto.parts.add()
        self.inner_serializer.serialize(part.data, inner_field, item)
      return
    repeated = getattr(data_proto, field)
    assert is_repeated(repeated), f"Field {field} is not repeated"
    for item in value:
      self.inner_serializer.serialize(data_proto, field, item)  # appends to repeated field

  def deserialize(self, data_proto, field):
    '''Read the list back, either from the id-less parts or from the repeated field.'''
    if field.startswith('parts[].'):
      inner_field = field[len('parts[].'):]
      # only parts without an id are list elements (that is how serialize writes
      # them); parts carrying an id are named parts holding other variables and
      # must not show up as extra list items
      return [
          self.inner_serializer.deserialize(part.data, inner_field)
          for part in data_proto.parts
          if not part.id
      ]
    repeated = getattr(data_proto, field)
    assert is_repeated(repeated), f"Field {field} is not repeated"
    return self.inner_serializer.deserialize(data_proto, field)  # returns repeated field list
|
130
|
+
|
131
|
+
|
132
|
+
# TODO dict serializer, maybe json only?
|
Binary file
|
Binary file
|
Binary file
|
clarifai/utils/misc.py
CHANGED
@@ -3,8 +3,20 @@ import re
|
|
3
3
|
import uuid
|
4
4
|
from typing import Any, Dict, List
|
5
5
|
|
6
|
+
from clarifai_grpc.grpc.api.status import status_code_pb2
|
7
|
+
|
6
8
|
from clarifai.errors import UserError
|
7
9
|
|
10
|
+
# status codes for which a request may be retried
RETRYABLE_CODES = [
    status_code_pb2.MODEL_DEPLOYING,
    status_code_pb2.MODEL_LOADING,
    status_code_pb2.MODEL_BUSY_PLEASE_RETRY,
]


def status_is_retryable(status_code: int) -> bool:
  """Return True when ``status_code`` is one of ``RETRYABLE_CODES``."""
  for retryable in RETRYABLE_CODES:
    if status_code == retryable:
      return True
  return False
|
19
|
+
|
8
20
|
|
9
21
|
class Chunker:
|
10
22
|
"""Split an input sequence into small chunks."""
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: clarifai
|
3
|
-
Version: 11.1.4rc2
|
3
|
+
Version: 11.1.5rc1
|
4
4
|
Summary: Clarifai Python SDK
|
5
5
|
Home-page: https://github.com/Clarifai/clarifai-python
|
6
6
|
Author: Clarifai
|
@@ -20,7 +20,7 @@ Classifier: Operating System :: OS Independent
|
|
20
20
|
Requires-Python: >=3.8
|
21
21
|
Description-Content-Type: text/markdown
|
22
22
|
License-File: LICENSE
|
23
|
-
Requires-Dist: clarifai-grpc >=11.
|
23
|
+
Requires-Dist: clarifai-grpc >=11.1.3
|
24
24
|
Requires-Dist: clarifai-protocol >=0.0.16
|
25
25
|
Requires-Dist: numpy >=1.22.0
|
26
26
|
Requires-Dist: tqdm >=4.65.0
|
@@ -32,6 +32,7 @@ Requires-Dist: tabulate >=0.9.0
|
|
32
32
|
Requires-Dist: fsspec >=2024.6.1
|
33
33
|
Requires-Dist: click >=8.1.7
|
34
34
|
Requires-Dist: requests >=2.32.3
|
35
|
+
Requires-Dist: aiohttp >=3.8.1
|
35
36
|
Provides-Extra: all
|
36
37
|
Requires-Dist: pycocotools ==2.0.6 ; extra == 'all'
|
37
38
|
|