clarifai-11.1.5rc7-py3-none-any.whl → clarifai-11.1.5rc8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. clarifai/__init__.py +1 -1
  2. clarifai/cli/__pycache__/model.cpython-310.pyc +0 -0
  3. clarifai/client/#model_client.py# +430 -0
  4. clarifai/client/model.py +95 -61
  5. clarifai/client/model_client.py +64 -49
  6. clarifai/runners/__pycache__/__init__.cpython-310.pyc +0 -0
  7. clarifai/runners/models/__pycache__/base_typed_model.cpython-310.pyc +0 -0
  8. clarifai/runners/models/__pycache__/model_builder.cpython-310.pyc +0 -0
  9. clarifai/runners/models/__pycache__/model_class.cpython-310.pyc +0 -0
  10. clarifai/runners/models/__pycache__/model_runner.cpython-310.pyc +0 -0
  11. clarifai/runners/models/model_class.py +29 -46
  12. clarifai/runners/utils/__pycache__/data_handler.cpython-310.pyc +0 -0
  13. clarifai/runners/utils/__pycache__/data_types.cpython-310.pyc +0 -0
  14. clarifai/runners/utils/__pycache__/method_signatures.cpython-310.pyc +0 -0
  15. clarifai/runners/utils/__pycache__/serializers.cpython-310.pyc +0 -0
  16. clarifai/runners/utils/data_types.py +62 -10
  17. clarifai/runners/utils/method_signatures.py +278 -295
  18. clarifai/runners/utils/serializers.py +143 -67
  19. {clarifai-11.1.5rc7.dist-info → clarifai-11.1.5rc8.dist-info}/METADATA +1 -1
  20. {clarifai-11.1.5rc7.dist-info → clarifai-11.1.5rc8.dist-info}/RECORD +24 -23
  21. {clarifai-11.1.5rc7.dist-info → clarifai-11.1.5rc8.dist-info}/LICENSE +0 -0
  22. {clarifai-11.1.5rc7.dist-info → clarifai-11.1.5rc8.dist-info}/WHEEL +0 -0
  23. {clarifai-11.1.5rc7.dist-info → clarifai-11.1.5rc8.dist-info}/entry_points.txt +0 -0
  24. {clarifai-11.1.5rc7.dist-info → clarifai-11.1.5rc8.dist-info}/top_level.txt +0 -0
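
Of the 24 files above, only the source diff for clarifai/runners/utils/method_signatures.py (entry 17) is reproduced below.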
@@ -1,9 +1,9 @@
+import ast
 import inspect
 import json
-import re
-import types
-from collections import OrderedDict, namedtuple
-from typing import List, get_args, get_origin
+import textwrap
+from collections import namedtuple
+from typing import List, Tuple, get_args, get_origin
 
 import numpy as np
 import PIL.Image
@@ -12,12 +12,12 @@ from clarifai_grpc.grpc.api import resources_pb2
 from google.protobuf.message import Message as MessageProto
 
 from clarifai.runners.utils import data_types
-from clarifai.runners.utils.serializers import (AtomicFieldSerializer, ImageSerializer,
-                                                ListSerializer, MessageSerializer,
-                                                NDArraySerializer, NullValueSerializer, Serializer)
+from clarifai.runners.utils.serializers import (AtomicFieldSerializer, ListSerializer,
+                                                MessageSerializer, NamedFieldsSerializer,
+                                                NDArraySerializer, Serializer, TupleSerializer)
 
 
-def build_function_signature(func, method_type: str):
+def build_function_signature(func):
   '''
   Build a signature for the given function.
   '''
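
The hunk above removes the explicit method_type argument; as the next hunk shows, the method type is now inferred from which inputs and outputs are annotated as streams. A minimal standalone sketch of that inference rule (illustrative names, not the shipped code):

```python
def infer_method_type(input_streaming, output_streaming):
  """input_streaming: one bool per input; output_streaming: bool for the return."""
  if sum(input_streaming) > 1:
    raise TypeError('streaming methods must have at most one streaming input')
  has_stream_input = any(input_streaming)
  if not has_stream_input and not output_streaming:
    return 'predict'  # plain request/response
  if not has_stream_input and output_streaming:
    return 'generate'  # single input, streamed output
  if has_stream_input and output_streaming:
    return 'stream'  # streamed input and output
  raise TypeError('streaming inputs require streaming outputs')

assert infer_method_type([False, False], False) == 'predict'
assert infer_method_type([False], True) == 'generate'
assert infer_method_type([True], True) == 'stream'
```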
@@ -30,110 +30,144 @@ def build_function_signature(func, method_type: str):
 
   return_annotation = sig.return_annotation
   if return_annotation == inspect.Parameter.empty:
-    raise ValueError('Function must have a return annotation')
-  # check for multiple return values and convert to dict for named values
-  return_streaming = False
-  if get_origin(return_annotation) == data_types.Stream:
-    return_annotation = get_args(return_annotation)[0]
-    return_streaming = True
-  if get_origin(return_annotation) == tuple:
-    return_annotation = tuple(get_args(return_annotation))
-  if isinstance(return_annotation, tuple):
-    return_annotation = {'return.%s' % i: tp for i, tp in enumerate(return_annotation)}
-  if not isinstance(return_annotation, dict):
-    return_annotation = {'return': return_annotation}
-
-  input_vars = build_variables_signature(sig.parameters.values())
-  output_vars = build_variables_signature(
-      [
-          # XXX inspect.Parameter errors for the special return names, so use SimpleNamespace
-          types.SimpleNamespace(name=name, annotation=tp, default=inspect.Parameter.empty)
-          for name, tp in return_annotation.items()
-      ],
-      is_output=True)
-  if return_streaming:
-    for var in output_vars:
-      var.streaming = True
-
-  # check for streams
-  if method_type == 'predict':
-    for var in input_vars:
-      if var.streaming:
-        raise TypeError('Stream inputs are not supported for predict methods')
-    for var in output_vars:
-      if var.streaming:
-        raise TypeError('Stream outputs are not supported for predict methods')
-  elif method_type == 'generate':
-    for var in input_vars:
-      if var.streaming:
-        raise TypeError('Stream inputs are not supported for generate methods')
-    if not all(var.streaming for var in output_vars):
-      raise TypeError('Generate methods must return a stream')
-  elif method_type == 'stream':
-    input_stream_vars = [var for var in input_vars if var.streaming]
-    if len(input_stream_vars) == 0:
-      raise TypeError('Stream methods must include a Stream input')
-    if not all(var.streaming for var in output_vars):
-      raise TypeError('Stream methods must return a single Stream')
+    raise TypeError('Function must have a return annotation')
+
+  input_sigs = [
+      build_variable_signature(p.name, p.annotation, p.default) for p in sig.parameters.values()
+  ]
+  input_sigs, input_types, input_streaming = zip(*input_sigs)
+  output_sig, output_type, output_streaming = build_variable_signature(
+      'return', return_annotation, is_output=True)
+  # TODO: flatten out "return" layer if not needed
+
+  # check for streams and determine method type
+  if sum(input_streaming) > 1:
+    raise TypeError('streaming methods must have at most one streaming input')
+  input_streaming = any(input_streaming)
+  if not (input_streaming or output_streaming):
+    method_type = 'predict'
+  elif not input_streaming and output_streaming:
+    method_type = 'generate'
+  elif input_streaming and output_streaming:
+    method_type = 'stream'
   else:
-    raise TypeError('Invalid method type: %s' % method_type)
+    raise TypeError('stream methods with streaming inputs must have streaming outputs')
 
   #method_signature = resources_pb2.MethodSignature() # TODO
-  method_signature = _NamedFields() #for now
+  method_signature = _SignatureDict() #for now
 
   method_signature.name = func.__name__
   #method_signature.method_type = getattr(resources_pb2.RunnerMethodType, method_type)
   assert method_type in ('predict', 'generate', 'stream')
   method_signature.method_type = method_type
   method_signature.docstring = func.__doc__
+  method_signature.annotations_json = json.dumps(_get_annotations_source(func))
 
   #method_signature.inputs.extend(input_vars)
   #method_signature.outputs.extend(output_vars)
-  method_signature.inputs = input_vars
-  method_signature.outputs = output_vars
+  method_signature.inputs = input_sigs
+  method_signature.outputs = output_sig
   return method_signature
 
 
-def build_variables_signature(parameters: List[inspect.Parameter], is_output=False):
+def _get_annotations_source(func):
+  """Extracts raw annotation strings from the function source."""
+  source = inspect.getsource(func)  # Get function source code
+  source = textwrap.dedent(source)  # Dedent source code
+  tree = ast.parse(source)  # Parse into AST
+  func_node = next(node for node in tree.body
+                   if isinstance(node, ast.FunctionDef))  # Get function node
+
+  annotations = {}
+  for arg in func_node.args.args:  # Process arguments
+    if arg.annotation:
+      annotations[arg.arg] = ast.unparse(arg.annotation)  # Get raw annotation string
+
+  if func_node.returns:  # Process return type
+    annotations["return"] = ast.unparse(func_node.returns)
+
+  return annotations
+
+
+def build_variable_signature(name, annotation, default=inspect.Parameter.empty, is_output=False):
   '''
-  Build a data proto signature for the given variable or return type annotation.
+  Build a data proto signature and get the normalized python type for the given annotation.
   '''
 
-  vars = []
-
   # check valid names (should already be constrained by python naming, but check anyway)
-  for param in parameters:
-    if not param.name.isidentifier() and not (is_output and
-                                              re.match(r'return(\.\d+)?', param.name)):
-      raise ValueError(f'Invalid variable name: {param.name}')
+  if not name.isidentifier():
+    raise ValueError(f'Invalid variable name: {name}')
 
   # get fields for each variable based on type
-  for param in parameters:
-    param_types, streaming = _normalize_types(param, is_output=is_output)
-
-    for name, tp in param_types.items():
-      #var = resources_pb2.MethodVariable() # TODO
-      var = _NamedFields()
-      var.name = name
-      var.data_type = _DATA_TYPES[tp].data_type
-      var.data_field = _DATA_TYPES[tp].data_field
-      var.streaming = streaming
-      if not is_output:
-        var.required = (param.default is inspect.Parameter.empty)
-        if not var.required:
-          var.default = param.default
-      vars.append(var)
-
-  # check if any fields are used more than once, and if so, use parts
-  # also if more than one field uses parts lists, also use parts, since the lists can be different lengths
-  # NOTE this is a little fancy, another way would just be to check if there is more than one arg
-  fields_unique = (len(set(var.data_field for var in vars)) == len(vars))
-  num_parts_lists = sum(int(var.data_field.startswith('parts[]')) for var in vars)
-  if not fields_unique or num_parts_lists > 1:
-    for var in vars:
-      var.data_field = 'parts[%s].%s' % (var.name, var.data_field)
-
-  return vars
+  tp, streaming = _normalize_type(annotation)
+
+  #var = resources_pb2.VariableSignature() # TODO
+  sig = _VariableSignature() #for now
+  sig.name = name
+
+  _fill_signature_type(sig, tp)
+
+  sig.streaming = streaming
+
+  if not is_output:
+    sig.required = (default is inspect.Parameter.empty)
+    if not sig.required:
+      sig.default = default
+
+  return sig, type, streaming
+
+
+def _fill_signature_type(sig, tp):
+  try:
+    if tp in _DATA_TYPES:
+      sig.data_type = _DATA_TYPES[tp].data_type
+      return
+  except TypeError:
+    pass  # not hashable type
+
+  if isinstance(tp, data_types.NamedFields):
+    sig.data_type = DataType.NAMED_FIELDS
+    for name, inner_type in tp.items():
+      # inner_sig = sig.type_args.add()
+      sig.type_args.append(inner_sig := _VariableSignature())
+      inner_sig.name = name
+      _fill_signature_type(inner_sig, inner_type)
+    return
+
+  if get_origin(tp) == tuple:
+    sig.data_type = DataType.TUPLE
+    for inner_type in get_args(tp):
+      #inner_sig = sig.type_args.add()
+      sig.type_args.append(inner_sig := _VariableSignature())
+      _fill_signature_type(inner_sig, inner_type)
+    return
+
+  if get_origin(tp) == list:
+    sig.data_type = DataType.LIST
+    inner_type = get_args(tp)[0]
+    #inner_sig = sig.type_args.add()
+    sig.type_args.append(inner_sig := _VariableSignature())
+    _fill_signature_type(inner_sig, inner_type)
+    return
+
+  raise TypeError(f'Unsupported type: {tp}')
+
+
+def serializer_from_signature(signature):
+  '''
+  Get the serializer for the given signature.
+  '''
+  if signature.data_type in _SERIALIZERS_BY_TYPE_ENUM:
+    return _SERIALIZERS_BY_TYPE_ENUM[signature.data_type]
+  if signature.data_type == DataType.LIST:
+    return ListSerializer(serializer_from_signature(signature.type_args[0]))
+  if signature.data_type == DataType.TUPLE:
+    return TupleSerializer([serializer_from_signature(sig) for sig in signature.type_args])
+  if signature.data_type == DataType.NAMED_FIELDS:
+    return NamedFieldsSerializer(
+        {sig.name: serializer_from_signature(sig)
+         for sig in signature.type_args})
+  raise ValueError(f'Unsupported type: {signature.data_type}')
 
 
 def signatures_to_json(signatures):
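
_get_annotations_source above uses only the standard library to recover annotation strings exactly as written in the source. A self-contained demonstration (Python 3.9+ for ast.unparse; the example function is hypothetical):

```python
import ast
import inspect
import textwrap

def example(x: int, images: 'List[PIL.Image.Image]') -> str:
  return str(x)

source = textwrap.dedent(inspect.getsource(example))
func_node = next(n for n in ast.parse(source).body if isinstance(n, ast.FunctionDef))
annotations = {a.arg: ast.unparse(a.annotation) for a in func_node.args.args if a.annotation}
if func_node.returns:
  annotations['return'] = ast.unparse(func_node.returns)
print(annotations)  # {'x': 'int', 'images': "'List[PIL.Image.Image]'", 'return': 'str'}
```

Note that string annotations come back with their quotes intact, since ast.unparse reproduces the annotation expression as source text.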
@@ -143,7 +177,8 @@ def signatures_to_json(signatures):
 
 
 def signatures_from_json(json_str):
-  return json.loads(json_str, object_pairs_hook=_NamedFields)
+  d = json.loads(json_str, object_pairs_hook=_SignatureDict)
+  return d
 
 
 def signatures_to_yaml(signatures):
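
signatures_from_json decodes through json.loads's object_pairs_hook, so every nested JSON object becomes a _SignatureDict with attribute access. A minimal standalone illustration of the same trick (hypothetical payload):

```python
import json

class AttrDict(dict):
  # same stand-in trick as _SignatureDict: attributes backed by dict entries
  __getattr__ = dict.__getitem__
  __setattr__ = dict.__setitem__

sig = json.loads('{"name": "predict", "inputs": [{"name": "prompt", "streaming": false}]}',
                 object_pairs_hook=AttrDict)
print(sig.name, sig.inputs[0].name)  # predict prompt
```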
@@ -163,8 +198,6 @@ def serialize(kwargs, signatures, proto=None, is_output=False):
   '''
   if proto is None:
     proto = resources_pb2.Data()
-  if not is_output: # TODO: use this consistently for return keys also
-    kwargs = flatten_nested_keys(kwargs, signatures, is_output)
   unknown = set(kwargs.keys()) - set(sig.name for sig in signatures)
   if unknown:
     if unknown == {'return'} and len(signatures) > 1:
@@ -177,11 +210,12 @@ def serialize(kwargs, signatures, proto=None, is_output=False):
         raise TypeError(f'Missing required argument: {sig.name}')
       continue # skip missing fields, they can be set to default on the server
     data = kwargs[sig.name]
-    force_named_part = (_is_empty_proto_data(data) and not is_output and not sig.required)
-    data_proto, field = _get_data_part(
-        proto, sig, is_output=is_output, serializing=True, force_named_part=force_named_part)
-    serializer = get_serializer(sig.data_type)
-    serializer.serialize(data_proto, field, data)
+    serializer = serializer_from_signature(sig)
+    # TODO determine if any (esp the first) var can go in the proto without parts
+    # and whether to put this in the signature or dynamically determine it
+    part = proto.parts.add()
+    part.id = sig.name
+    serializer.serialize(part.data, data)
   return proto
 
 
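With this hunk, serialize stops routing values onto per-type field paths and instead gives every argument its own named part keyed by the argument name. A sketch of the resulting proto layout (assumes a clarifai_grpc version whose Data proto has parts and the scalar value fields used by the serializer table below):

```python
from clarifai_grpc.grpc.api import resources_pb2

proto = resources_pb2.Data()
for name, value in [('prompt', 'hello'), ('max_tokens', 64)]:
  part = proto.parts.add()  # one named part per argument
  part.id = name            # keyed by the argument name
  if isinstance(value, str):
    part.data.string_value = value
  else:
    part.data.int_value = value

# the receiving side can now locate arguments by id alone:
parts_by_name = {p.id: p for p in proto.parts}
assert parts_by_name['prompt'].data.string_value == 'hello'
```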
@@ -189,82 +223,31 @@ def deserialize(proto, signatures, is_output=False):
   '''
   Deserialize the given proto into kwargs using the given signatures.
   '''
+  if isinstance(signatures, dict):
+    signatures = [signatures] # TODO update return key level and make consistnet
   kwargs = {}
+  parts_by_name = {part.id: part for part in proto.parts}
   for sig in signatures:
-    data_proto, field = _get_data_part(proto, sig, is_output=is_output, serializing=False)
-    if data_proto is None:
-      # not set in proto, check if required or skip if optional arg
-      if not is_output and sig.required:
+    serializer = serializer_from_signature(sig)
+    part = parts_by_name.get(sig.name)
+    if part is None:
+      if sig.required or is_output: # TODO allow optional outputs?
        raise ValueError(f'Missing required field: {sig.name}')
      continue
-    serializer = get_serializer(sig.data_type)
-    data = serializer.deserialize(data_proto, field)
-    kwargs[sig.name] = data
-  if is_output:
-    if len(kwargs) == 1 and 'return' in kwargs: # case for single return value
-      return kwargs['return']
-    if kwargs and 'return.0' in kwargs: # case for tuple return values
-      return tuple(kwargs[f'return.{i}'] for i in range(len(kwargs)))
-    return data_types.Output(kwargs)
-  kwargs = unflatten_nested_keys(kwargs, signatures, is_output)
+    kwargs[sig.name] = serializer.deserialize(part.data)
+  if len(kwargs) == 1 and 'return' in kwargs:
+    return kwargs['return']
   return kwargs
 
 
-def get_serializer(data_type: str) -> Serializer:
-  if data_type in _SERIALIZERS_BY_TYPE_STRING:
-    return _SERIALIZERS_BY_TYPE_STRING[data_type]
-  if data_type.startswith('List['):
-    inner_type_string = data_type[len('List['):-1]
-    inner_serializer = get_serializer(inner_type_string)
-    return ListSerializer(inner_serializer)
-  raise ValueError(f'Unsupported type: "{data_type}"')
-
-
-def flatten_nested_keys(kwargs, signatures, is_output):
-  '''
-  Flatten nested keys into a single key with a dot, e.g. {'a': {'b': 1}} -> {'a.b': 1}
-  in the kwargs, using the given signatures to determine which keys are nested.
-  '''
-  nested_keys = [sig.name for sig in signatures if '.' in sig.name]
-  outer_keys = set(key.split('.')[0] for key in nested_keys)
-  for outer in outer_keys:
-    if outer not in kwargs:
-      continue
-    kwargs.update({outer + '.' + k: v for k, v in kwargs.pop(outer).items()})
-  return kwargs
-
-
-def unflatten_nested_keys(kwargs, signatures, is_output):
+def get_stream_from_signature(signatures):
   '''
-  Unflatten nested keys in kwargs into a dict, e.g. {'a.b': 1} -> {'a': {'b': 1}}
-  Uses the signatures to determine which keys are nested.
-  The dict subclass is Input or Output, depending on the is_output flag.
-  Preserves the order of args from the signatures.
+  Get the stream signature from the given signatures.
   '''
-  unflattened = OrderedDict()
   for sig in signatures:
-    if '.' not in sig.name:
-      if sig.name in kwargs:
-        unflattened[sig.name] = kwargs[sig.name]
-      continue
-    if sig.name not in kwargs:
-      continue
-    parts = sig.name.split('.')
-    assert len(parts) == 2, 'Only one level of nested keys is supported'
-    if parts[0] not in unflattened:
-      unflattened[parts[0]] = data_types.Output() if is_output else data_types.Input()
-    unflattened[parts[0]][parts[1]] = kwargs[sig.name]
-  return unflattened
-
-
-def get_stream_from_signature(signatures):
-  streaming_signatures = [var for var in signatures if var.streaming]
-  if not streaming_signatures:
-    return None, []
-  stream_argname = set([var.name.split('.', 1)[0] for var in streaming_signatures])
-  assert len(stream_argname) == 1, 'streaming methods must have exactly one streaming function arg'
-  stream_argname = stream_argname.pop()
-  return stream_argname, streaming_signatures
+    if sig.streaming:
+      return sig
+  return None
 
 
 def _is_empty_proto_data(data):
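
With nested key flattening gone, get_stream_from_signature reduces to scanning for the single signature flagged streaming; a plain-dict sketch of that contract (illustrative values):

```python
sigs = [
    {'name': 'prompt', 'streaming': False},
    {'name': 'chunks', 'streaming': True},
]
stream_sig = next((s for s in sigs if s['streaming']), None)
assert stream_sig['name'] == 'chunks'
```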
@@ -275,178 +258,178 @@ def _is_empty_proto_data(data):
   return not data
 
 
-def _get_data_part(proto, sig, is_output, serializing, force_named_part=False):
-  field = sig.data_field
-
-  # check if we need to force a named part, to distinguish between empty and unset values
-  if force_named_part and not field.startswith('parts['):
-    field = f'parts[{sig.name}].{field}'
-
-  # gets the named part from the proto, according to the field path
-  # note we only support one level of named parts
-  #parts = field.replace(' ', '').split('.')
-  # split on . but not if it is inside brackets, e.g. parts[outer.inner].field
-  parts = re.split(r'\.(?![^\[]*\])', field.replace(' ', ''))
-
-  if len(parts) not in (1, 2, 3): # field, parts[name].field, parts[name].parts[].field
-    raise ValueError('Invalid field: %s' % field)
-
-  if len(parts) == 1:
-    # also need to check if there is an explicitly named part, e.g. for empty values
-    part = next((part for part in proto.parts if part.id == sig.name), None)
-    if part:
-      return part.data, field
-    if not serializing and not is_output and _is_empty_proto_data(getattr(proto, field)):
-      return None, field
-    return proto, field
-
-  # list
-  if parts[0] == 'parts[]':
-    if len(parts) != 2:
-      raise ValueError('Invalid field: %s' % field)
-    return proto, field # return the data that contains the list itself
-
-  # named part
-  if not (m := re.match(r'parts\[([\w.]+)\]', parts[0])):
-    raise ValueError('Invalid field: %s' % field)
-  if not (name := m.group(1)):
-    raise ValueError('Invalid field: %s' % field)
-  assert len(parts) in (2, 3) # parts[name].field, parts[name].parts[].field
-  part = next((part for part in proto.parts if part.id == name), None)
-  if part is None:
-    if not serializing:
-      raise ValueError('Missing part: %s' % name)
-    part = proto.parts.add()
-    part.id = name
-  return part.data, '.'.join(parts[1:])
-
-
-def _normalize_types(param, is_output=False):
+def _normalize_type(tp):
   '''
-  Normalize the types for the given parameter. Returns a dict of names to types,
-  including named return values for outputs, and a flag indicating if streaming is used.
+  Normalize the types for the given parameter.
+  Returns the normalized type and whether the parameter is streaming.
   '''
-  tp = param.annotation
-
   # stream type indicates streaming, not part of the data itself
+  # it can only be used at the top-level of the var type
   streaming = (get_origin(tp) == data_types.Stream)
   if streaming:
     tp = get_args(tp)[0]
 
-  if is_output or streaming: # named types can be used for outputs or streaming inputs
-    # output type used for named return values, each with their own data type
-    if isinstance(tp, (dict, data_types.Output, data_types.Input)):
-      return {param.name + '.' + name: _normalize_data_type(val)
-              for name, val in tp.items()}, streaming
-    if tp == data_types.Output: # check for Output type without values
-      if not is_output:
-        raise TypeError('Output types can only be used for output values')
-      raise TypeError('Output types must be instantiated with inner type values for each key')
-    if tp == data_types.Input: # check for Output type without values
-      if is_output:
-        raise TypeError('Input types can only be used for input values')
-      raise TypeError(
-          'Stream[Input(...)] types must be instantiated with inner type values for each key')
-
-  return {param.name: _normalize_data_type(tp)}, streaming
+  return _normalize_data_type(tp), streaming
 
 
 def _normalize_data_type(tp):
   # check if list, and if so, get inner type
-  is_list = (get_origin(tp) == list)
-  if is_list:
+  if get_origin(tp) == list:
     tp = get_args(tp)[0]
+    return List[_normalize_data_type(tp)]
+
+  if isinstance(tp, (tuple, list)):
+    return Tuple[tuple(_normalize_data_type(val) for val in tp)]
+
+  if isinstance(tp, (dict, data_types.NamedFields)):
+    return data_types.NamedFields(**{name: _normalize_data_type(val) for name, val in tp.items()})
 
-  # check if numpy array, and if so, use ndarray
+  # check if numpy array type, and if so, use ndarray
   if get_origin(tp) == np.ndarray:
-    tp = np.ndarray
+    return np.ndarray
 
   # check for PIL images (sometimes types use the module, sometimes the class)
   # set these to use the Image data handler
-  if tp in (PIL.Image, PIL.Image.Image):
-    tp = data_types.Image
-
-  # put back list
-  if is_list:
-    tp = List[tp]
-
-  # check if supported type
-  if tp not in _DATA_TYPES:
-    raise ValueError(f'Unsupported type: {tp}')
+  if tp in (data_types.Image, PIL.Image, PIL.Image.Image):
+    return data_types.Image
+
+  # check for jsonable types
+  # TODO should we include dict vs list in the data type somehow?
+  if tp == dict or (get_origin(tp) == dict and tp not in _DATA_TYPES and _is_jsonable(tp)):
+    return data_types.JSON
+  if tp == list or (get_origin(tp) == list and tp not in _DATA_TYPES and _is_jsonable(tp)):
+    return data_types.JSON
+
+  # check for known data types
+  try:
+    if tp in _DATA_TYPES:
+      return tp
+  except TypeError:
+    pass # not hashable type
+
+  raise TypeError(f'Unsupported type: {tp}')
+
+
+def _is_jsonable(tp):
+  if tp in (dict, list, tuple, str, int, float, bool, type(None)):
+    return True
+  if get_origin(tp) == list:
+    return _is_jsonable(get_args(tp)[0])
+  if get_origin(tp) == dict:
+    return all(_is_jsonable(val) for val in get_args(tp))
+  return False
+
+
+# TODO --- tmp classes to stand-in for protos until they are defined and built into this package
+class _SignatureDict(dict):
+  __getattr__ = dict.__getitem__
+  __setattr__ = dict.__setitem__
 
-  return tp
 
+class _VariableSignature(_SignatureDict):
 
-class _NamedFields(dict):
-  __getattr__ = dict.__getitem__
-  __setattr__ = dict.__setitem__
+  def __init__(self):
+    super().__init__()
+    self.name = ''
+    self.type = ''
+    self.type_args = []
+    self.streaming = False
+    self.required = False
+    self.default = ''
+    self.description = ''
 
 
 # data_type: name of the data type
 # data_field: name of the field in the data proto
 # serializer: serializer for the data type
-_DataType = namedtuple('_DataType', ('data_type', 'data_field', 'serializer'))
+_DataType = namedtuple('_DataType', ('data_type', 'serializer'))
+
+
+# this will come from the proto module, but for now, define it here
+class DataType:
+  NOT_SET = 'NOT_SET'
+
+  STR = 'STR'
+  BYTES = 'BYTES'
+  INT = 'INT'
+  FLOAT = 'FLOAT'
+  BOOL = 'BOOL'
+  NDARRAY = 'NDARRAY'
+  JSON = 'JSON'
+
+  TEXT = 'TEXT'
+  IMAGE = 'IMAGE'
+  CONCEPT = 'CONCEPT'
+  REGION = 'REGION'
+  FRAME = 'FRAME'
+  AUDIO = 'AUDIO'
+  VIDEO = 'VIDEO'
+
+  NAMED_FIELDS = 'NAMED_FIELDS'
+  TUPLE = 'TUPLE'
+  LIST = 'LIST'
 
-# mapping of supported python types to data type names, fields, and serializers
+
+# simple, non-container types that correspond directly to a data field
 _DATA_TYPES = {
     str:
-        _DataType('str', 'string_value', AtomicFieldSerializer()),
+        _DataType(DataType.STR, AtomicFieldSerializer('string_value')),
     bytes:
-        _DataType('bytes', 'bytes_value', AtomicFieldSerializer()),
+        _DataType(DataType.BYTES, AtomicFieldSerializer('bytes_value')),
     int:
-        _DataType('int', 'int_value', AtomicFieldSerializer()),
+        _DataType(DataType.INT, AtomicFieldSerializer('int_value')),
     float:
-        _DataType('float', 'float_value', AtomicFieldSerializer()),
+        _DataType(DataType.FLOAT, AtomicFieldSerializer('float_value')),
     bool:
-        _DataType('bool', 'bool_value', AtomicFieldSerializer()),
-    None:
-        _DataType('None', '', NullValueSerializer()),
+        _DataType(DataType.BOOL, AtomicFieldSerializer('bool_value')),
     np.ndarray:
-        _DataType('ndarray', 'ndarray', NDArraySerializer()),
+        _DataType(DataType.NDARRAY, NDArraySerializer('ndarray')),
     data_types.Text:
-        _DataType('Text', 'text', MessageSerializer(data_types.Text)),
+        _DataType(DataType.TEXT, MessageSerializer('text', data_types.Text)),
     data_types.Image:
-        _DataType('Image', 'image', ImageSerializer()),
+        _DataType(DataType.IMAGE, MessageSerializer('image', data_types.Image)),
     data_types.Concept:
-        _DataType('Concept', 'concepts', MessageSerializer(data_types.Concept)),
+        _DataType(DataType.CONCEPT, MessageSerializer('concepts', data_types.Concept)),
    data_types.Region:
-        _DataType('Region', 'regions', MessageSerializer(data_types.Region)),
+        _DataType(DataType.REGION, MessageSerializer('regions', data_types.Region)),
     data_types.Frame:
-        _DataType('Frame', 'frames', MessageSerializer(data_types.Frame)),
+        _DataType(DataType.FRAME, MessageSerializer('frames', data_types.Frame)),
     data_types.Audio:
-        _DataType('Audio', 'audio', MessageSerializer(data_types.Audio)),
+        _DataType(DataType.AUDIO, MessageSerializer('audio', data_types.Audio)),
     data_types.Video:
-        _DataType('Video', 'video', MessageSerializer(data_types.Video)),
-
-    # lists handled specially, not as generic lists using parts
-    List[int]:
-        _DataType('ndarray', 'ndarray', NDArraySerializer()),
-    List[float]:
-        _DataType('ndarray', 'ndarray', NDArraySerializer()),
-    List[bool]:
-        _DataType('ndarray', 'ndarray', NDArraySerializer()),
+        _DataType(DataType.VIDEO, MessageSerializer('video', data_types.Video)),
 }
 
+_SERIALIZERS_BY_TYPE_ENUM = {dt.data_type: dt.serializer for dt in _DATA_TYPES.values()}
 
-# add generic lists using parts, for all supported types
-def _add_list_fields():
-  for tp in list(_DATA_TYPES.keys()):
-    if List[tp] in _DATA_TYPES:
-      # already added as special case
-      continue
-
-    # check if data field is repeated, and if so, use repeated field for list
-    field_name = _DATA_TYPES[tp].data_field
-    descriptor = resources_pb2.Data.DESCRIPTOR.fields_by_name.get(field_name)
-    repeated = descriptor and descriptor.label == descriptor.LABEL_REPEATED
-
-    # add to supported types
-    data_type = 'List[%s]' % _DATA_TYPES[tp].data_type
-    data_field = field_name if repeated else 'parts[].' + field_name
-    serializer = ListSerializer(_DATA_TYPES[tp].serializer)
-
-    _DATA_TYPES[List[tp]] = _DataType(data_type, data_field, serializer)
 
+class CompatibilitySerializer(Serializer):
+  '''
+  Serialization of basic value types, used for backwards compatibility
+  with older models that don't have type signatures.
+  '''
 
-_add_list_fields()
-_SERIALIZERS_BY_TYPE_STRING = {dt.data_type: dt.serializer for dt in _DATA_TYPES.values()}
+  def serialize(self, data_proto, value):
+    tp = _normalize_data_type(type(value))
+
+    try:
+      serializer = _DATA_TYPES[tp].serializer
+    except (KeyError, TypeError):
+      raise TypeError(f'serializer currently only supports basic types, got {tp}')
+
+    serializer.serialize(data_proto, value)
+
+  def deserialize(self, data_proto):
+    fields = [k.name for k, _ in data_proto.ListFields()]
+    if 'parts' in fields:
+      raise ValueError('serializer does not support parts')
+    serializers = [
+        serializer for serializer in _SERIALIZERS_BY_TYPE_ENUM.values()
+        if serializer.field_name in fields
+    ]
+    if not serializers:
+      raise ValueError('Returned data not recognized')
+    if len(serializers) != 1:
+      raise ValueError('Only single output supported for serializer')
+    serializer = serializers[0]
+    return serializer.deserialize(data_proto)
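
Taken together, the new normalization rules drop the old coercion of scalar lists to ndarray in favor of first-class LIST, TUPLE, and NAMED_FIELDS signatures. Expected behavior of _normalize_data_type per the code above (hypothetical transcript, requires clarifai and Pillow):

```python
# _normalize_data_type(PIL.Image.Image)        -> data_types.Image
# _normalize_data_type(List[PIL.Image.Image])  -> List[data_types.Image]
# _normalize_data_type((int, str))             -> Tuple[int, str]
# _normalize_data_type(dict)                   -> data_types.JSON
#
# List[int], List[float], and List[bool], formerly special-cased to NDARRAY by
# _add_list_fields(), now normalize to LIST signatures whose element
# serializers are resolved recursively by serializer_from_signature.
```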