clarifai 11.1.5rc7__py3-none-any.whl → 11.1.5rc8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clarifai/__init__.py +1 -1
- clarifai/cli/__pycache__/model.cpython-310.pyc +0 -0
- clarifai/client/#model_client.py# +430 -0
- clarifai/client/model.py +95 -61
- clarifai/client/model_client.py +64 -49
- clarifai/runners/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/runners/models/__pycache__/base_typed_model.cpython-310.pyc +0 -0
- clarifai/runners/models/__pycache__/model_builder.cpython-310.pyc +0 -0
- clarifai/runners/models/__pycache__/model_class.cpython-310.pyc +0 -0
- clarifai/runners/models/__pycache__/model_runner.cpython-310.pyc +0 -0
- clarifai/runners/models/model_class.py +29 -46
- clarifai/runners/utils/__pycache__/data_handler.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/data_types.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/method_signatures.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/serializers.cpython-310.pyc +0 -0
- clarifai/runners/utils/data_types.py +62 -10
- clarifai/runners/utils/method_signatures.py +278 -295
- clarifai/runners/utils/serializers.py +143 -67
- {clarifai-11.1.5rc7.dist-info → clarifai-11.1.5rc8.dist-info}/METADATA +1 -1
- {clarifai-11.1.5rc7.dist-info → clarifai-11.1.5rc8.dist-info}/RECORD +24 -23
- {clarifai-11.1.5rc7.dist-info → clarifai-11.1.5rc8.dist-info}/LICENSE +0 -0
- {clarifai-11.1.5rc7.dist-info → clarifai-11.1.5rc8.dist-info}/WHEEL +0 -0
- {clarifai-11.1.5rc7.dist-info → clarifai-11.1.5rc8.dist-info}/entry_points.txt +0 -0
- {clarifai-11.1.5rc7.dist-info → clarifai-11.1.5rc8.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,9 @@
|
|
1
|
+
import ast
|
1
2
|
import inspect
|
2
3
|
import json
|
3
|
-
import
|
4
|
-
import
|
5
|
-
from
|
6
|
-
from typing import List, get_args, get_origin
|
4
|
+
import textwrap
|
5
|
+
from collections import namedtuple
|
6
|
+
from typing import List, Tuple, get_args, get_origin
|
7
7
|
|
8
8
|
import numpy as np
|
9
9
|
import PIL.Image
|
@@ -12,12 +12,12 @@ from clarifai_grpc.grpc.api import resources_pb2
|
|
12
12
|
from google.protobuf.message import Message as MessageProto
|
13
13
|
|
14
14
|
from clarifai.runners.utils import data_types
|
15
|
-
from clarifai.runners.utils.serializers import (AtomicFieldSerializer,
|
16
|
-
|
17
|
-
NDArraySerializer,
|
15
|
+
from clarifai.runners.utils.serializers import (AtomicFieldSerializer, ListSerializer,
|
16
|
+
MessageSerializer, NamedFieldsSerializer,
|
17
|
+
NDArraySerializer, Serializer, TupleSerializer)
|
18
18
|
|
19
19
|
|
20
|
-
def build_function_signature(func
|
20
|
+
def build_function_signature(func):
|
21
21
|
'''
|
22
22
|
Build a signature for the given function.
|
23
23
|
'''
|
@@ -30,110 +30,144 @@ def build_function_signature(func, method_type: str):
|
|
30
30
|
|
31
31
|
return_annotation = sig.return_annotation
|
32
32
|
if return_annotation == inspect.Parameter.empty:
|
33
|
-
raise
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
if
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
is_output=True)
|
54
|
-
if return_streaming:
|
55
|
-
for var in output_vars:
|
56
|
-
var.streaming = True
|
57
|
-
|
58
|
-
# check for streams
|
59
|
-
if method_type == 'predict':
|
60
|
-
for var in input_vars:
|
61
|
-
if var.streaming:
|
62
|
-
raise TypeError('Stream inputs are not supported for predict methods')
|
63
|
-
for var in output_vars:
|
64
|
-
if var.streaming:
|
65
|
-
raise TypeError('Stream outputs are not supported for predict methods')
|
66
|
-
elif method_type == 'generate':
|
67
|
-
for var in input_vars:
|
68
|
-
if var.streaming:
|
69
|
-
raise TypeError('Stream inputs are not supported for generate methods')
|
70
|
-
if not all(var.streaming for var in output_vars):
|
71
|
-
raise TypeError('Generate methods must return a stream')
|
72
|
-
elif method_type == 'stream':
|
73
|
-
input_stream_vars = [var for var in input_vars if var.streaming]
|
74
|
-
if len(input_stream_vars) == 0:
|
75
|
-
raise TypeError('Stream methods must include a Stream input')
|
76
|
-
if not all(var.streaming for var in output_vars):
|
77
|
-
raise TypeError('Stream methods must return a single Stream')
|
33
|
+
raise TypeError('Function must have a return annotation')
|
34
|
+
|
35
|
+
input_sigs = [
|
36
|
+
build_variable_signature(p.name, p.annotation, p.default) for p in sig.parameters.values()
|
37
|
+
]
|
38
|
+
input_sigs, input_types, input_streaming = zip(*input_sigs)
|
39
|
+
output_sig, output_type, output_streaming = build_variable_signature(
|
40
|
+
'return', return_annotation, is_output=True)
|
41
|
+
# TODO: flatten out "return" layer if not needed
|
42
|
+
|
43
|
+
# check for streams and determine method type
|
44
|
+
if sum(input_streaming) > 1:
|
45
|
+
raise TypeError('streaming methods must have at most one streaming input')
|
46
|
+
input_streaming = any(input_streaming)
|
47
|
+
if not (input_streaming or output_streaming):
|
48
|
+
method_type = 'predict'
|
49
|
+
elif not input_streaming and output_streaming:
|
50
|
+
method_type = 'generate'
|
51
|
+
elif input_streaming and output_streaming:
|
52
|
+
method_type = 'stream'
|
78
53
|
else:
|
79
|
-
raise TypeError('
|
54
|
+
raise TypeError('stream methods with streaming inputs must have streaming outputs')
|
80
55
|
|
81
56
|
#method_signature = resources_pb2.MethodSignature() # TODO
|
82
|
-
method_signature =
|
57
|
+
method_signature = _SignatureDict() #for now
|
83
58
|
|
84
59
|
method_signature.name = func.__name__
|
85
60
|
#method_signature.method_type = getattr(resources_pb2.RunnerMethodType, method_type)
|
86
61
|
assert method_type in ('predict', 'generate', 'stream')
|
87
62
|
method_signature.method_type = method_type
|
88
63
|
method_signature.docstring = func.__doc__
|
64
|
+
method_signature.annotations_json = json.dumps(_get_annotations_source(func))
|
89
65
|
|
90
66
|
#method_signature.inputs.extend(input_vars)
|
91
67
|
#method_signature.outputs.extend(output_vars)
|
92
|
-
method_signature.inputs =
|
93
|
-
method_signature.outputs =
|
68
|
+
method_signature.inputs = input_sigs
|
69
|
+
method_signature.outputs = output_sig
|
94
70
|
return method_signature
|
95
71
|
|
96
72
|
|
97
|
-
def
|
73
|
+
def _get_annotations_source(func):
|
74
|
+
"""Extracts raw annotation strings from the function source."""
|
75
|
+
source = inspect.getsource(func) # Get function source code
|
76
|
+
source = textwrap.dedent(source) # Dedent source code
|
77
|
+
tree = ast.parse(source) # Parse into AST
|
78
|
+
func_node = next(node for node in tree.body
|
79
|
+
if isinstance(node, ast.FunctionDef)) # Get function node
|
80
|
+
|
81
|
+
annotations = {}
|
82
|
+
for arg in func_node.args.args: # Process arguments
|
83
|
+
if arg.annotation:
|
84
|
+
annotations[arg.arg] = ast.unparse(arg.annotation) # Get raw annotation string
|
85
|
+
|
86
|
+
if func_node.returns: # Process return type
|
87
|
+
annotations["return"] = ast.unparse(func_node.returns)
|
88
|
+
|
89
|
+
return annotations
|
90
|
+
|
91
|
+
|
92
|
+
def build_variable_signature(name, annotation, default=inspect.Parameter.empty, is_output=False):
|
98
93
|
'''
|
99
|
-
Build a data proto signature
|
94
|
+
Build a data proto signature and get the normalized python type for the given annotation.
|
100
95
|
'''
|
101
96
|
|
102
|
-
vars = []
|
103
|
-
|
104
97
|
# check valid names (should already be constrained by python naming, but check anyway)
|
105
|
-
|
106
|
-
|
107
|
-
re.match(r'return(\.\d+)?', param.name)):
|
108
|
-
raise ValueError(f'Invalid variable name: {param.name}')
|
98
|
+
if not name.isidentifier():
|
99
|
+
raise ValueError(f'Invalid variable name: {name}')
|
109
100
|
|
110
101
|
# get fields for each variable based on type
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
102
|
+
tp, streaming = _normalize_type(annotation)
|
103
|
+
|
104
|
+
#var = resources_pb2.VariableSignature() # TODO
|
105
|
+
sig = _VariableSignature() #for now
|
106
|
+
sig.name = name
|
107
|
+
|
108
|
+
_fill_signature_type(sig, tp)
|
109
|
+
|
110
|
+
sig.streaming = streaming
|
111
|
+
|
112
|
+
if not is_output:
|
113
|
+
sig.required = (default is inspect.Parameter.empty)
|
114
|
+
if not sig.required:
|
115
|
+
sig.default = default
|
116
|
+
|
117
|
+
return sig, type, streaming
|
118
|
+
|
119
|
+
|
120
|
+
def _fill_signature_type(sig, tp):
|
121
|
+
try:
|
122
|
+
if tp in _DATA_TYPES:
|
123
|
+
sig.data_type = _DATA_TYPES[tp].data_type
|
124
|
+
return
|
125
|
+
except TypeError:
|
126
|
+
pass # not hashable type
|
127
|
+
|
128
|
+
if isinstance(tp, data_types.NamedFields):
|
129
|
+
sig.data_type = DataType.NAMED_FIELDS
|
130
|
+
for name, inner_type in tp.items():
|
131
|
+
# inner_sig = sig.type_args.add()
|
132
|
+
sig.type_args.append(inner_sig := _VariableSignature())
|
133
|
+
inner_sig.name = name
|
134
|
+
_fill_signature_type(inner_sig, inner_type)
|
135
|
+
return
|
136
|
+
|
137
|
+
if get_origin(tp) == tuple:
|
138
|
+
sig.data_type = DataType.TUPLE
|
139
|
+
for inner_type in get_args(tp):
|
140
|
+
#inner_sig = sig.type_args.add()
|
141
|
+
sig.type_args.append(inner_sig := _VariableSignature())
|
142
|
+
_fill_signature_type(inner_sig, inner_type)
|
143
|
+
return
|
144
|
+
|
145
|
+
if get_origin(tp) == list:
|
146
|
+
sig.data_type = DataType.LIST
|
147
|
+
inner_type = get_args(tp)[0]
|
148
|
+
#inner_sig = sig.type_args.add()
|
149
|
+
sig.type_args.append(inner_sig := _VariableSignature())
|
150
|
+
_fill_signature_type(inner_sig, inner_type)
|
151
|
+
return
|
152
|
+
|
153
|
+
raise TypeError(f'Unsupported type: {tp}')
|
154
|
+
|
155
|
+
|
156
|
+
def serializer_from_signature(signature):
|
157
|
+
'''
|
158
|
+
Get the serializer for the given signature.
|
159
|
+
'''
|
160
|
+
if signature.data_type in _SERIALIZERS_BY_TYPE_ENUM:
|
161
|
+
return _SERIALIZERS_BY_TYPE_ENUM[signature.data_type]
|
162
|
+
if signature.data_type == DataType.LIST:
|
163
|
+
return ListSerializer(serializer_from_signature(signature.type_args[0]))
|
164
|
+
if signature.data_type == DataType.TUPLE:
|
165
|
+
return TupleSerializer([serializer_from_signature(sig) for sig in signature.type_args])
|
166
|
+
if signature.data_type == DataType.NAMED_FIELDS:
|
167
|
+
return NamedFieldsSerializer(
|
168
|
+
{sig.name: serializer_from_signature(sig)
|
169
|
+
for sig in signature.type_args})
|
170
|
+
raise ValueError(f'Unsupported type: {signature.data_type}')
|
137
171
|
|
138
172
|
|
139
173
|
def signatures_to_json(signatures):
|
@@ -143,7 +177,8 @@ def signatures_to_json(signatures):
|
|
143
177
|
|
144
178
|
|
145
179
|
def signatures_from_json(json_str):
|
146
|
-
|
180
|
+
d = json.loads(json_str, object_pairs_hook=_SignatureDict)
|
181
|
+
return d
|
147
182
|
|
148
183
|
|
149
184
|
def signatures_to_yaml(signatures):
|
@@ -163,8 +198,6 @@ def serialize(kwargs, signatures, proto=None, is_output=False):
|
|
163
198
|
'''
|
164
199
|
if proto is None:
|
165
200
|
proto = resources_pb2.Data()
|
166
|
-
if not is_output: # TODO: use this consistently for return keys also
|
167
|
-
kwargs = flatten_nested_keys(kwargs, signatures, is_output)
|
168
201
|
unknown = set(kwargs.keys()) - set(sig.name for sig in signatures)
|
169
202
|
if unknown:
|
170
203
|
if unknown == {'return'} and len(signatures) > 1:
|
@@ -177,11 +210,12 @@ def serialize(kwargs, signatures, proto=None, is_output=False):
|
|
177
210
|
raise TypeError(f'Missing required argument: {sig.name}')
|
178
211
|
continue # skip missing fields, they can be set to default on the server
|
179
212
|
data = kwargs[sig.name]
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
213
|
+
serializer = serializer_from_signature(sig)
|
214
|
+
# TODO determine if any (esp the first) var can go in the proto without parts
|
215
|
+
# and whether to put this in the signature or dynamically determine it
|
216
|
+
part = proto.parts.add()
|
217
|
+
part.id = sig.name
|
218
|
+
serializer.serialize(part.data, data)
|
185
219
|
return proto
|
186
220
|
|
187
221
|
|
@@ -189,82 +223,31 @@ def deserialize(proto, signatures, is_output=False):
|
|
189
223
|
'''
|
190
224
|
Deserialize the given proto into kwargs using the given signatures.
|
191
225
|
'''
|
226
|
+
if isinstance(signatures, dict):
|
227
|
+
signatures = [signatures] # TODO update return key level and make consistnet
|
192
228
|
kwargs = {}
|
229
|
+
parts_by_name = {part.id: part for part in proto.parts}
|
193
230
|
for sig in signatures:
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
if
|
231
|
+
serializer = serializer_from_signature(sig)
|
232
|
+
part = parts_by_name.get(sig.name)
|
233
|
+
if part is None:
|
234
|
+
if sig.required or is_output: # TODO allow optional outputs?
|
198
235
|
raise ValueError(f'Missing required field: {sig.name}')
|
199
236
|
continue
|
200
|
-
|
201
|
-
|
202
|
-
kwargs[
|
203
|
-
if is_output:
|
204
|
-
if len(kwargs) == 1 and 'return' in kwargs: # case for single return value
|
205
|
-
return kwargs['return']
|
206
|
-
if kwargs and 'return.0' in kwargs: # case for tuple return values
|
207
|
-
return tuple(kwargs[f'return.{i}'] for i in range(len(kwargs)))
|
208
|
-
return data_types.Output(kwargs)
|
209
|
-
kwargs = unflatten_nested_keys(kwargs, signatures, is_output)
|
237
|
+
kwargs[sig.name] = serializer.deserialize(part.data)
|
238
|
+
if len(kwargs) == 1 and 'return' in kwargs:
|
239
|
+
return kwargs['return']
|
210
240
|
return kwargs
|
211
241
|
|
212
242
|
|
213
|
-
def
|
214
|
-
if data_type in _SERIALIZERS_BY_TYPE_STRING:
|
215
|
-
return _SERIALIZERS_BY_TYPE_STRING[data_type]
|
216
|
-
if data_type.startswith('List['):
|
217
|
-
inner_type_string = data_type[len('List['):-1]
|
218
|
-
inner_serializer = get_serializer(inner_type_string)
|
219
|
-
return ListSerializer(inner_serializer)
|
220
|
-
raise ValueError(f'Unsupported type: "{data_type}"')
|
221
|
-
|
222
|
-
|
223
|
-
def flatten_nested_keys(kwargs, signatures, is_output):
|
224
|
-
'''
|
225
|
-
Flatten nested keys into a single key with a dot, e.g. {'a': {'b': 1}} -> {'a.b': 1}
|
226
|
-
in the kwargs, using the given signatures to determine which keys are nested.
|
227
|
-
'''
|
228
|
-
nested_keys = [sig.name for sig in signatures if '.' in sig.name]
|
229
|
-
outer_keys = set(key.split('.')[0] for key in nested_keys)
|
230
|
-
for outer in outer_keys:
|
231
|
-
if outer not in kwargs:
|
232
|
-
continue
|
233
|
-
kwargs.update({outer + '.' + k: v for k, v in kwargs.pop(outer).items()})
|
234
|
-
return kwargs
|
235
|
-
|
236
|
-
|
237
|
-
def unflatten_nested_keys(kwargs, signatures, is_output):
|
243
|
+
def get_stream_from_signature(signatures):
|
238
244
|
'''
|
239
|
-
|
240
|
-
Uses the signatures to determine which keys are nested.
|
241
|
-
The dict subclass is Input or Output, depending on the is_output flag.
|
242
|
-
Preserves the order of args from the signatures.
|
245
|
+
Get the stream signature from the given signatures.
|
243
246
|
'''
|
244
|
-
unflattened = OrderedDict()
|
245
247
|
for sig in signatures:
|
246
|
-
if
|
247
|
-
|
248
|
-
|
249
|
-
continue
|
250
|
-
if sig.name not in kwargs:
|
251
|
-
continue
|
252
|
-
parts = sig.name.split('.')
|
253
|
-
assert len(parts) == 2, 'Only one level of nested keys is supported'
|
254
|
-
if parts[0] not in unflattened:
|
255
|
-
unflattened[parts[0]] = data_types.Output() if is_output else data_types.Input()
|
256
|
-
unflattened[parts[0]][parts[1]] = kwargs[sig.name]
|
257
|
-
return unflattened
|
258
|
-
|
259
|
-
|
260
|
-
def get_stream_from_signature(signatures):
|
261
|
-
streaming_signatures = [var for var in signatures if var.streaming]
|
262
|
-
if not streaming_signatures:
|
263
|
-
return None, []
|
264
|
-
stream_argname = set([var.name.split('.', 1)[0] for var in streaming_signatures])
|
265
|
-
assert len(stream_argname) == 1, 'streaming methods must have exactly one streaming function arg'
|
266
|
-
stream_argname = stream_argname.pop()
|
267
|
-
return stream_argname, streaming_signatures
|
248
|
+
if sig.streaming:
|
249
|
+
return sig
|
250
|
+
return None
|
268
251
|
|
269
252
|
|
270
253
|
def _is_empty_proto_data(data):
|
@@ -275,178 +258,178 @@ def _is_empty_proto_data(data):
|
|
275
258
|
return not data
|
276
259
|
|
277
260
|
|
278
|
-
def
|
279
|
-
field = sig.data_field
|
280
|
-
|
281
|
-
# check if we need to force a named part, to distinguish between empty and unset values
|
282
|
-
if force_named_part and not field.startswith('parts['):
|
283
|
-
field = f'parts[{sig.name}].{field}'
|
284
|
-
|
285
|
-
# gets the named part from the proto, according to the field path
|
286
|
-
# note we only support one level of named parts
|
287
|
-
#parts = field.replace(' ', '').split('.')
|
288
|
-
# split on . but not if it is inside brackets, e.g. parts[outer.inner].field
|
289
|
-
parts = re.split(r'\.(?![^\[]*\])', field.replace(' ', ''))
|
290
|
-
|
291
|
-
if len(parts) not in (1, 2, 3): # field, parts[name].field, parts[name].parts[].field
|
292
|
-
raise ValueError('Invalid field: %s' % field)
|
293
|
-
|
294
|
-
if len(parts) == 1:
|
295
|
-
# also need to check if there is an explicitly named part, e.g. for empty values
|
296
|
-
part = next((part for part in proto.parts if part.id == sig.name), None)
|
297
|
-
if part:
|
298
|
-
return part.data, field
|
299
|
-
if not serializing and not is_output and _is_empty_proto_data(getattr(proto, field)):
|
300
|
-
return None, field
|
301
|
-
return proto, field
|
302
|
-
|
303
|
-
# list
|
304
|
-
if parts[0] == 'parts[]':
|
305
|
-
if len(parts) != 2:
|
306
|
-
raise ValueError('Invalid field: %s' % field)
|
307
|
-
return proto, field # return the data that contains the list itself
|
308
|
-
|
309
|
-
# named part
|
310
|
-
if not (m := re.match(r'parts\[([\w.]+)\]', parts[0])):
|
311
|
-
raise ValueError('Invalid field: %s' % field)
|
312
|
-
if not (name := m.group(1)):
|
313
|
-
raise ValueError('Invalid field: %s' % field)
|
314
|
-
assert len(parts) in (2, 3) # parts[name].field, parts[name].parts[].field
|
315
|
-
part = next((part for part in proto.parts if part.id == name), None)
|
316
|
-
if part is None:
|
317
|
-
if not serializing:
|
318
|
-
raise ValueError('Missing part: %s' % name)
|
319
|
-
part = proto.parts.add()
|
320
|
-
part.id = name
|
321
|
-
return part.data, '.'.join(parts[1:])
|
322
|
-
|
323
|
-
|
324
|
-
def _normalize_types(param, is_output=False):
|
261
|
+
def _normalize_type(tp):
|
325
262
|
'''
|
326
|
-
Normalize the types for the given parameter.
|
327
|
-
|
263
|
+
Normalize the types for the given parameter.
|
264
|
+
Returns the normalized type and whether the parameter is streaming.
|
328
265
|
'''
|
329
|
-
tp = param.annotation
|
330
|
-
|
331
266
|
# stream type indicates streaming, not part of the data itself
|
267
|
+
# it can only be used at the top-level of the var type
|
332
268
|
streaming = (get_origin(tp) == data_types.Stream)
|
333
269
|
if streaming:
|
334
270
|
tp = get_args(tp)[0]
|
335
271
|
|
336
|
-
|
337
|
-
# output type used for named return values, each with their own data type
|
338
|
-
if isinstance(tp, (dict, data_types.Output, data_types.Input)):
|
339
|
-
return {param.name + '.' + name: _normalize_data_type(val)
|
340
|
-
for name, val in tp.items()}, streaming
|
341
|
-
if tp == data_types.Output: # check for Output type without values
|
342
|
-
if not is_output:
|
343
|
-
raise TypeError('Output types can only be used for output values')
|
344
|
-
raise TypeError('Output types must be instantiated with inner type values for each key')
|
345
|
-
if tp == data_types.Input: # check for Output type without values
|
346
|
-
if is_output:
|
347
|
-
raise TypeError('Input types can only be used for input values')
|
348
|
-
raise TypeError(
|
349
|
-
'Stream[Input(...)] types must be instantiated with inner type values for each key')
|
350
|
-
|
351
|
-
return {param.name: _normalize_data_type(tp)}, streaming
|
272
|
+
return _normalize_data_type(tp), streaming
|
352
273
|
|
353
274
|
|
354
275
|
def _normalize_data_type(tp):
|
355
276
|
# check if list, and if so, get inner type
|
356
|
-
|
357
|
-
if is_list:
|
277
|
+
if get_origin(tp) == list:
|
358
278
|
tp = get_args(tp)[0]
|
279
|
+
return List[_normalize_data_type(tp)]
|
280
|
+
|
281
|
+
if isinstance(tp, (tuple, list)):
|
282
|
+
return Tuple[tuple(_normalize_data_type(val) for val in tp)]
|
283
|
+
|
284
|
+
if isinstance(tp, (dict, data_types.NamedFields)):
|
285
|
+
return data_types.NamedFields(**{name: _normalize_data_type(val) for name, val in tp.items()})
|
359
286
|
|
360
|
-
# check if numpy array, and if so, use ndarray
|
287
|
+
# check if numpy array type, and if so, use ndarray
|
361
288
|
if get_origin(tp) == np.ndarray:
|
362
|
-
|
289
|
+
return np.ndarray
|
363
290
|
|
364
291
|
# check for PIL images (sometimes types use the module, sometimes the class)
|
365
292
|
# set these to use the Image data handler
|
366
|
-
if tp in (PIL.Image, PIL.Image.Image):
|
367
|
-
|
368
|
-
|
369
|
-
#
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
293
|
+
if tp in (data_types.Image, PIL.Image, PIL.Image.Image):
|
294
|
+
return data_types.Image
|
295
|
+
|
296
|
+
# check for jsonable types
|
297
|
+
# TODO should we include dict vs list in the data type somehow?
|
298
|
+
if tp == dict or (get_origin(tp) == dict and tp not in _DATA_TYPES and _is_jsonable(tp)):
|
299
|
+
return data_types.JSON
|
300
|
+
if tp == list or (get_origin(tp) == list and tp not in _DATA_TYPES and _is_jsonable(tp)):
|
301
|
+
return data_types.JSON
|
302
|
+
|
303
|
+
# check for known data types
|
304
|
+
try:
|
305
|
+
if tp in _DATA_TYPES:
|
306
|
+
return tp
|
307
|
+
except TypeError:
|
308
|
+
pass # not hashable type
|
309
|
+
|
310
|
+
raise TypeError(f'Unsupported type: {tp}')
|
311
|
+
|
312
|
+
|
313
|
+
def _is_jsonable(tp):
|
314
|
+
if tp in (dict, list, tuple, str, int, float, bool, type(None)):
|
315
|
+
return True
|
316
|
+
if get_origin(tp) == list:
|
317
|
+
return _is_jsonable(get_args(tp)[0])
|
318
|
+
if get_origin(tp) == dict:
|
319
|
+
return all(_is_jsonable(val) for val in get_args(tp))
|
320
|
+
return False
|
321
|
+
|
322
|
+
|
323
|
+
# TODO --- tmp classes to stand-in for protos until they are defined and built into this package
|
324
|
+
class _SignatureDict(dict):
|
325
|
+
__getattr__ = dict.__getitem__
|
326
|
+
__setattr__ = dict.__setitem__
|
376
327
|
|
377
|
-
return tp
|
378
328
|
|
329
|
+
class _VariableSignature(_SignatureDict):
|
379
330
|
|
380
|
-
|
381
|
-
|
382
|
-
|
331
|
+
def __init__(self):
|
332
|
+
super().__init__()
|
333
|
+
self.name = ''
|
334
|
+
self.type = ''
|
335
|
+
self.type_args = []
|
336
|
+
self.streaming = False
|
337
|
+
self.required = False
|
338
|
+
self.default = ''
|
339
|
+
self.description = ''
|
383
340
|
|
384
341
|
|
385
342
|
# data_type: name of the data type
|
386
343
|
# data_field: name of the field in the data proto
|
387
344
|
# serializer: serializer for the data type
|
388
|
-
_DataType = namedtuple('_DataType', ('data_type', '
|
345
|
+
_DataType = namedtuple('_DataType', ('data_type', 'serializer'))
|
346
|
+
|
347
|
+
|
348
|
+
# this will come from the proto module, but for now, define it here
|
349
|
+
class DataType:
|
350
|
+
NOT_SET = 'NOT_SET'
|
351
|
+
|
352
|
+
STR = 'STR'
|
353
|
+
BYTES = 'BYTES'
|
354
|
+
INT = 'INT'
|
355
|
+
FLOAT = 'FLOAT'
|
356
|
+
BOOL = 'BOOL'
|
357
|
+
NDARRAY = 'NDARRAY'
|
358
|
+
JSON = 'JSON'
|
359
|
+
|
360
|
+
TEXT = 'TEXT'
|
361
|
+
IMAGE = 'IMAGE'
|
362
|
+
CONCEPT = 'CONCEPT'
|
363
|
+
REGION = 'REGION'
|
364
|
+
FRAME = 'FRAME'
|
365
|
+
AUDIO = 'AUDIO'
|
366
|
+
VIDEO = 'VIDEO'
|
367
|
+
|
368
|
+
NAMED_FIELDS = 'NAMED_FIELDS'
|
369
|
+
TUPLE = 'TUPLE'
|
370
|
+
LIST = 'LIST'
|
389
371
|
|
390
|
-
|
372
|
+
|
373
|
+
# simple, non-container types that correspond directly to a data field
|
391
374
|
_DATA_TYPES = {
|
392
375
|
str:
|
393
|
-
_DataType(
|
376
|
+
_DataType(DataType.STR, AtomicFieldSerializer('string_value')),
|
394
377
|
bytes:
|
395
|
-
_DataType(
|
378
|
+
_DataType(DataType.BYTES, AtomicFieldSerializer('bytes_value')),
|
396
379
|
int:
|
397
|
-
_DataType(
|
380
|
+
_DataType(DataType.INT, AtomicFieldSerializer('int_value')),
|
398
381
|
float:
|
399
|
-
_DataType(
|
382
|
+
_DataType(DataType.FLOAT, AtomicFieldSerializer('float_value')),
|
400
383
|
bool:
|
401
|
-
_DataType(
|
402
|
-
None:
|
403
|
-
_DataType('None', '', NullValueSerializer()),
|
384
|
+
_DataType(DataType.BOOL, AtomicFieldSerializer('bool_value')),
|
404
385
|
np.ndarray:
|
405
|
-
_DataType(
|
386
|
+
_DataType(DataType.NDARRAY, NDArraySerializer('ndarray')),
|
406
387
|
data_types.Text:
|
407
|
-
_DataType(
|
388
|
+
_DataType(DataType.TEXT, MessageSerializer('text', data_types.Text)),
|
408
389
|
data_types.Image:
|
409
|
-
_DataType(
|
390
|
+
_DataType(DataType.IMAGE, MessageSerializer('image', data_types.Image)),
|
410
391
|
data_types.Concept:
|
411
|
-
_DataType(
|
392
|
+
_DataType(DataType.CONCEPT, MessageSerializer('concepts', data_types.Concept)),
|
412
393
|
data_types.Region:
|
413
|
-
_DataType(
|
394
|
+
_DataType(DataType.REGION, MessageSerializer('regions', data_types.Region)),
|
414
395
|
data_types.Frame:
|
415
|
-
_DataType(
|
396
|
+
_DataType(DataType.FRAME, MessageSerializer('frames', data_types.Frame)),
|
416
397
|
data_types.Audio:
|
417
|
-
_DataType(
|
398
|
+
_DataType(DataType.AUDIO, MessageSerializer('audio', data_types.Audio)),
|
418
399
|
data_types.Video:
|
419
|
-
_DataType(
|
420
|
-
|
421
|
-
# lists handled specially, not as generic lists using parts
|
422
|
-
List[int]:
|
423
|
-
_DataType('ndarray', 'ndarray', NDArraySerializer()),
|
424
|
-
List[float]:
|
425
|
-
_DataType('ndarray', 'ndarray', NDArraySerializer()),
|
426
|
-
List[bool]:
|
427
|
-
_DataType('ndarray', 'ndarray', NDArraySerializer()),
|
400
|
+
_DataType(DataType.VIDEO, MessageSerializer('video', data_types.Video)),
|
428
401
|
}
|
429
402
|
|
403
|
+
_SERIALIZERS_BY_TYPE_ENUM = {dt.data_type: dt.serializer for dt in _DATA_TYPES.values()}
|
430
404
|
|
431
|
-
# add generic lists using parts, for all supported types
|
432
|
-
def _add_list_fields():
|
433
|
-
for tp in list(_DATA_TYPES.keys()):
|
434
|
-
if List[tp] in _DATA_TYPES:
|
435
|
-
# already added as special case
|
436
|
-
continue
|
437
|
-
|
438
|
-
# check if data field is repeated, and if so, use repeated field for list
|
439
|
-
field_name = _DATA_TYPES[tp].data_field
|
440
|
-
descriptor = resources_pb2.Data.DESCRIPTOR.fields_by_name.get(field_name)
|
441
|
-
repeated = descriptor and descriptor.label == descriptor.LABEL_REPEATED
|
442
|
-
|
443
|
-
# add to supported types
|
444
|
-
data_type = 'List[%s]' % _DATA_TYPES[tp].data_type
|
445
|
-
data_field = field_name if repeated else 'parts[].' + field_name
|
446
|
-
serializer = ListSerializer(_DATA_TYPES[tp].serializer)
|
447
|
-
|
448
|
-
_DATA_TYPES[List[tp]] = _DataType(data_type, data_field, serializer)
|
449
405
|
|
406
|
+
class CompatibilitySerializer(Serializer):
|
407
|
+
'''
|
408
|
+
Serialization of basic value types, used for backwards compatibility
|
409
|
+
with older models that don't have type signatures.
|
410
|
+
'''
|
450
411
|
|
451
|
-
|
452
|
-
|
412
|
+
def serialize(self, data_proto, value):
|
413
|
+
tp = _normalize_data_type(type(value))
|
414
|
+
|
415
|
+
try:
|
416
|
+
serializer = _DATA_TYPES[tp].serializer
|
417
|
+
except (KeyError, TypeError):
|
418
|
+
raise TypeError(f'serializer currently only supports basic types, got {tp}')
|
419
|
+
|
420
|
+
serializer.serialize(data_proto, value)
|
421
|
+
|
422
|
+
def deserialize(self, data_proto):
|
423
|
+
fields = [k.name for k, _ in data_proto.ListFields()]
|
424
|
+
if 'parts' in fields:
|
425
|
+
raise ValueError('serializer does not support parts')
|
426
|
+
serializers = [
|
427
|
+
serializer for serializer in _SERIALIZERS_BY_TYPE_ENUM.values()
|
428
|
+
if serializer.field_name in fields
|
429
|
+
]
|
430
|
+
if not serializers:
|
431
|
+
raise ValueError('Returned data not recognized')
|
432
|
+
if len(serializers) != 1:
|
433
|
+
raise ValueError('Only single output supported for serializer')
|
434
|
+
serializer = serializers[0]
|
435
|
+
return serializer.deserialize(data_proto)
|