clarifai 11.1.5rc8__py3-none-any.whl → 11.1.6__py3-none-any.whl

This diff compares the contents of two publicly available package versions as released to their public registry. It is provided for informational purposes only.
Files changed (123)
  1. clarifai/__init__.py +1 -1
  2. clarifai/cli/model.py +40 -50
  3. clarifai/client/model.py +393 -157
  4. clarifai/runners/__init__.py +7 -2
  5. clarifai/runners/dockerfile_template/Dockerfile.template +1 -4
  6. clarifai/runners/models/base_typed_model.py +238 -0
  7. clarifai/runners/models/model_builder.py +9 -26
  8. clarifai/runners/models/model_class.py +28 -256
  9. clarifai/runners/models/model_run_locally.py +78 -3
  10. clarifai/runners/models/model_runner.py +0 -2
  11. clarifai/runners/models/model_servicer.py +2 -11
  12. clarifai/runners/utils/data_handler.py +205 -308
  13. {clarifai-11.1.5rc8.dist-info → clarifai-11.1.6.dist-info}/METADATA +26 -16
  14. clarifai-11.1.6.dist-info/RECORD +101 -0
  15. {clarifai-11.1.5rc8.dist-info → clarifai-11.1.6.dist-info}/WHEEL +1 -1
  16. clarifai/__pycache__/__init__.cpython-310.pyc +0 -0
  17. clarifai/__pycache__/errors.cpython-310.pyc +0 -0
  18. clarifai/__pycache__/versions.cpython-310.pyc +0 -0
  19. clarifai/cli/__main__.py~ +0 -4
  20. clarifai/cli/__pycache__/__init__.cpython-310.pyc +0 -0
  21. clarifai/cli/__pycache__/__main__.cpython-310.pyc +0 -0
  22. clarifai/cli/__pycache__/base.cpython-310.pyc +0 -0
  23. clarifai/cli/__pycache__/compute_cluster.cpython-310.pyc +0 -0
  24. clarifai/cli/__pycache__/deployment.cpython-310.pyc +0 -0
  25. clarifai/cli/__pycache__/model.cpython-310.pyc +0 -0
  26. clarifai/cli/__pycache__/nodepool.cpython-310.pyc +0 -0
  27. clarifai/client/#model_client.py# +0 -430
  28. clarifai/client/__pycache__/__init__.cpython-310.pyc +0 -0
  29. clarifai/client/__pycache__/app.cpython-310.pyc +0 -0
  30. clarifai/client/__pycache__/base.cpython-310.pyc +0 -0
  31. clarifai/client/__pycache__/dataset.cpython-310.pyc +0 -0
  32. clarifai/client/__pycache__/input.cpython-310.pyc +0 -0
  33. clarifai/client/__pycache__/lister.cpython-310.pyc +0 -0
  34. clarifai/client/__pycache__/model.cpython-310.pyc +0 -0
  35. clarifai/client/__pycache__/module.cpython-310.pyc +0 -0
  36. clarifai/client/__pycache__/runner.cpython-310.pyc +0 -0
  37. clarifai/client/__pycache__/search.cpython-310.pyc +0 -0
  38. clarifai/client/__pycache__/user.cpython-310.pyc +0 -0
  39. clarifai/client/__pycache__/workflow.cpython-310.pyc +0 -0
  40. clarifai/client/auth/__pycache__/__init__.cpython-310.pyc +0 -0
  41. clarifai/client/auth/__pycache__/helper.cpython-310.pyc +0 -0
  42. clarifai/client/auth/__pycache__/register.cpython-310.pyc +0 -0
  43. clarifai/client/auth/__pycache__/stub.cpython-310.pyc +0 -0
  44. clarifai/client/model_client.py +0 -447
  45. clarifai/constants/__pycache__/dataset.cpython-310.pyc +0 -0
  46. clarifai/constants/__pycache__/model.cpython-310.pyc +0 -0
  47. clarifai/constants/__pycache__/search.cpython-310.pyc +0 -0
  48. clarifai/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
  49. clarifai/datasets/export/__pycache__/__init__.cpython-310.pyc +0 -0
  50. clarifai/datasets/export/__pycache__/inputs_annotations.cpython-310.pyc +0 -0
  51. clarifai/datasets/upload/__pycache__/__init__.cpython-310.pyc +0 -0
  52. clarifai/datasets/upload/__pycache__/base.cpython-310.pyc +0 -0
  53. clarifai/datasets/upload/__pycache__/features.cpython-310.pyc +0 -0
  54. clarifai/datasets/upload/__pycache__/image.cpython-310.pyc +0 -0
  55. clarifai/datasets/upload/__pycache__/text.cpython-310.pyc +0 -0
  56. clarifai/datasets/upload/__pycache__/utils.cpython-310.pyc +0 -0
  57. clarifai/datasets/upload/loaders/__pycache__/__init__.cpython-310.pyc +0 -0
  58. clarifai/datasets/upload/loaders/__pycache__/coco_detection.cpython-310.pyc +0 -0
  59. clarifai/models/__pycache__/__init__.cpython-310.pyc +0 -0
  60. clarifai/models/model_serving/__pycache__/__init__.cpython-310.pyc +0 -0
  61. clarifai/models/model_serving/__pycache__/constants.cpython-310.pyc +0 -0
  62. clarifai/models/model_serving/cli/__pycache__/__init__.cpython-310.pyc +0 -0
  63. clarifai/models/model_serving/cli/__pycache__/_utils.cpython-310.pyc +0 -0
  64. clarifai/models/model_serving/cli/__pycache__/base.cpython-310.pyc +0 -0
  65. clarifai/models/model_serving/cli/__pycache__/build.cpython-310.pyc +0 -0
  66. clarifai/models/model_serving/cli/__pycache__/create.cpython-310.pyc +0 -0
  67. clarifai/models/model_serving/model_config/__pycache__/__init__.cpython-310.pyc +0 -0
  68. clarifai/models/model_serving/model_config/__pycache__/base.cpython-310.pyc +0 -0
  69. clarifai/models/model_serving/model_config/__pycache__/config.cpython-310.pyc +0 -0
  70. clarifai/models/model_serving/model_config/__pycache__/inference_parameter.cpython-310.pyc +0 -0
  71. clarifai/models/model_serving/model_config/__pycache__/output.cpython-310.pyc +0 -0
  72. clarifai/models/model_serving/model_config/triton/__pycache__/__init__.cpython-310.pyc +0 -0
  73. clarifai/models/model_serving/model_config/triton/__pycache__/serializer.cpython-310.pyc +0 -0
  74. clarifai/models/model_serving/model_config/triton/__pycache__/triton_config.cpython-310.pyc +0 -0
  75. clarifai/models/model_serving/model_config/triton/__pycache__/wrappers.cpython-310.pyc +0 -0
  76. clarifai/models/model_serving/repo_build/__pycache__/__init__.cpython-310.pyc +0 -0
  77. clarifai/models/model_serving/repo_build/__pycache__/build.cpython-310.pyc +0 -0
  78. clarifai/models/model_serving/repo_build/static_files/__pycache__/base_test.cpython-310-pytest-7.2.0.pyc +0 -0
  79. clarifai/rag/__pycache__/__init__.cpython-310.pyc +0 -0
  80. clarifai/rag/__pycache__/rag.cpython-310.pyc +0 -0
  81. clarifai/rag/__pycache__/utils.cpython-310.pyc +0 -0
  82. clarifai/runners/__pycache__/__init__.cpython-310.pyc +0 -0
  83. clarifai/runners/__pycache__/server.cpython-310.pyc +0 -0
  84. clarifai/runners/dockerfile_template/Dockerfile.debug +0 -11
  85. clarifai/runners/dockerfile_template/Dockerfile.debug~ +0 -9
  86. clarifai/runners/models/__pycache__/__init__.cpython-310.pyc +0 -0
  87. clarifai/runners/models/__pycache__/base_typed_model.cpython-310.pyc +0 -0
  88. clarifai/runners/models/__pycache__/model_builder.cpython-310.pyc +0 -0
  89. clarifai/runners/models/__pycache__/model_class.cpython-310.pyc +0 -0
  90. clarifai/runners/models/__pycache__/model_run_locally.cpython-310.pyc +0 -0
  91. clarifai/runners/models/__pycache__/model_runner.cpython-310.pyc +0 -0
  92. clarifai/runners/models/__pycache__/model_servicer.cpython-310.pyc +0 -0
  93. clarifai/runners/models/__pycache__/model_upload.cpython-310.pyc +0 -0
  94. clarifai/runners/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  95. clarifai/runners/utils/__pycache__/const.cpython-310.pyc +0 -0
  96. clarifai/runners/utils/__pycache__/data_handler.cpython-310.pyc +0 -0
  97. clarifai/runners/utils/__pycache__/data_types.cpython-310.pyc +0 -0
  98. clarifai/runners/utils/__pycache__/data_utils.cpython-310.pyc +0 -0
  99. clarifai/runners/utils/__pycache__/loader.cpython-310.pyc +0 -0
  100. clarifai/runners/utils/__pycache__/logging.cpython-310.pyc +0 -0
  101. clarifai/runners/utils/__pycache__/method_signatures.cpython-310.pyc +0 -0
  102. clarifai/runners/utils/__pycache__/serializers.cpython-310.pyc +0 -0
  103. clarifai/runners/utils/__pycache__/url_fetcher.cpython-310.pyc +0 -0
  104. clarifai/runners/utils/data_types.py +0 -386
  105. clarifai/runners/utils/method_signatures.py +0 -435
  106. clarifai/runners/utils/serializers.py +0 -208
  107. clarifai/schema/__pycache__/search.cpython-310.pyc +0 -0
  108. clarifai/urls/__pycache__/helper.cpython-310.pyc +0 -0
  109. clarifai/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  110. clarifai/utils/__pycache__/logging.cpython-310.pyc +0 -0
  111. clarifai/utils/__pycache__/misc.cpython-310.pyc +0 -0
  112. clarifai/utils/__pycache__/model_train.cpython-310.pyc +0 -0
  113. clarifai/utils/evaluation/__pycache__/__init__.cpython-310.pyc +0 -0
  114. clarifai/utils/evaluation/__pycache__/helpers.cpython-310.pyc +0 -0
  115. clarifai/utils/evaluation/__pycache__/main.cpython-310.pyc +0 -0
  116. clarifai/workflows/__pycache__/__init__.cpython-310.pyc +0 -0
  117. clarifai/workflows/__pycache__/export.cpython-310.pyc +0 -0
  118. clarifai/workflows/__pycache__/utils.cpython-310.pyc +0 -0
  119. clarifai/workflows/__pycache__/validate.cpython-310.pyc +0 -0
  120. clarifai-11.1.5rc8.dist-info/RECORD +0 -204
  121. {clarifai-11.1.5rc8.dist-info → clarifai-11.1.6.dist-info}/LICENSE +0 -0
  122. {clarifai-11.1.5rc8.dist-info → clarifai-11.1.6.dist-info}/entry_points.txt +0 -0
  123. {clarifai-11.1.5rc8.dist-info → clarifai-11.1.6.dist-info}/top_level.txt +0 -0
clarifai/runners/utils/method_signatures.py
@@ -1,435 +0,0 @@
-import ast
-import inspect
-import json
-import textwrap
-from collections import namedtuple
-from typing import List, Tuple, get_args, get_origin
-
-import numpy as np
-import PIL.Image
-import yaml
-from clarifai_grpc.grpc.api import resources_pb2
-from google.protobuf.message import Message as MessageProto
-
-from clarifai.runners.utils import data_types
-from clarifai.runners.utils.serializers import (AtomicFieldSerializer, ListSerializer,
-                                                MessageSerializer, NamedFieldsSerializer,
-                                                NDArraySerializer, Serializer, TupleSerializer)
-
-
-def build_function_signature(func):
-  '''
-  Build a signature for the given function.
-  '''
-  sig = inspect.signature(func)
-
-  # check if func is bound, and if not, remove self/cls
-  if getattr(func, '__self__', None) is None and sig.parameters and list(
-      sig.parameters.values())[0].name in ('self', 'cls'):
-    sig = sig.replace(parameters=list(sig.parameters.values())[1:])
-
-  return_annotation = sig.return_annotation
-  if return_annotation == inspect.Parameter.empty:
-    raise TypeError('Function must have a return annotation')
-
-  input_sigs = [
-      build_variable_signature(p.name, p.annotation, p.default) for p in sig.parameters.values()
-  ]
-  input_sigs, input_types, input_streaming = zip(*input_sigs)
-  output_sig, output_type, output_streaming = build_variable_signature(
-      'return', return_annotation, is_output=True)
-  # TODO: flatten out "return" layer if not needed
-
-  # check for streams and determine method type
-  if sum(input_streaming) > 1:
-    raise TypeError('streaming methods must have at most one streaming input')
-  input_streaming = any(input_streaming)
-  if not (input_streaming or output_streaming):
-    method_type = 'predict'
-  elif not input_streaming and output_streaming:
-    method_type = 'generate'
-  elif input_streaming and output_streaming:
-    method_type = 'stream'
-  else:
-    raise TypeError('stream methods with streaming inputs must have streaming outputs')
-
-  #method_signature = resources_pb2.MethodSignature()  # TODO
-  method_signature = _SignatureDict()  # for now
-
-  method_signature.name = func.__name__
-  #method_signature.method_type = getattr(resources_pb2.RunnerMethodType, method_type)
-  assert method_type in ('predict', 'generate', 'stream')
-  method_signature.method_type = method_type
-  method_signature.docstring = func.__doc__
-  method_signature.annotations_json = json.dumps(_get_annotations_source(func))
-
-  #method_signature.inputs.extend(input_vars)
-  #method_signature.outputs.extend(output_vars)
-  method_signature.inputs = input_sigs
-  method_signature.outputs = output_sig
-  return method_signature
-
-
-def _get_annotations_source(func):
-  """Extracts raw annotation strings from the function source."""
-  source = inspect.getsource(func)  # Get function source code
-  source = textwrap.dedent(source)  # Dedent source code
-  tree = ast.parse(source)  # Parse into AST
-  func_node = next(node for node in tree.body
-                   if isinstance(node, ast.FunctionDef))  # Get function node
-
-  annotations = {}
-  for arg in func_node.args.args:  # Process arguments
-    if arg.annotation:
-      annotations[arg.arg] = ast.unparse(arg.annotation)  # Get raw annotation string
-
-  if func_node.returns:  # Process return type
-    annotations["return"] = ast.unparse(func_node.returns)
-
-  return annotations
-
-
-def build_variable_signature(name, annotation, default=inspect.Parameter.empty, is_output=False):
-  '''
-  Build a data proto signature and get the normalized python type for the given annotation.
-  '''
-
-  # check valid names (should already be constrained by python naming, but check anyway)
-  if not name.isidentifier():
-    raise ValueError(f'Invalid variable name: {name}')
-
-  # get fields for each variable based on type
-  tp, streaming = _normalize_type(annotation)
-
-  #var = resources_pb2.VariableSignature()  # TODO
-  sig = _VariableSignature()  # for now
-  sig.name = name
-
-  _fill_signature_type(sig, tp)
-
-  sig.streaming = streaming
-
-  if not is_output:
-    sig.required = (default is inspect.Parameter.empty)
-    if not sig.required:
-      sig.default = default
-
-  return sig, tp, streaming
-
-
-def _fill_signature_type(sig, tp):
-  try:
-    if tp in _DATA_TYPES:
-      sig.data_type = _DATA_TYPES[tp].data_type
-      return
-  except TypeError:
-    pass  # not hashable type
-
-  if isinstance(tp, data_types.NamedFields):
-    sig.data_type = DataType.NAMED_FIELDS
-    for name, inner_type in tp.items():
-      # inner_sig = sig.type_args.add()
-      sig.type_args.append(inner_sig := _VariableSignature())
-      inner_sig.name = name
-      _fill_signature_type(inner_sig, inner_type)
-    return
-
-  if get_origin(tp) == tuple:
-    sig.data_type = DataType.TUPLE
-    for inner_type in get_args(tp):
-      #inner_sig = sig.type_args.add()
-      sig.type_args.append(inner_sig := _VariableSignature())
-      _fill_signature_type(inner_sig, inner_type)
-    return
-
-  if get_origin(tp) == list:
-    sig.data_type = DataType.LIST
-    inner_type = get_args(tp)[0]
-    #inner_sig = sig.type_args.add()
-    sig.type_args.append(inner_sig := _VariableSignature())
-    _fill_signature_type(inner_sig, inner_type)
-    return
-
-  raise TypeError(f'Unsupported type: {tp}')
-
-
-def serializer_from_signature(signature):
-  '''
-  Get the serializer for the given signature.
-  '''
-  if signature.data_type in _SERIALIZERS_BY_TYPE_ENUM:
-    return _SERIALIZERS_BY_TYPE_ENUM[signature.data_type]
-  if signature.data_type == DataType.LIST:
-    return ListSerializer(serializer_from_signature(signature.type_args[0]))
-  if signature.data_type == DataType.TUPLE:
-    return TupleSerializer([serializer_from_signature(sig) for sig in signature.type_args])
-  if signature.data_type == DataType.NAMED_FIELDS:
-    return NamedFieldsSerializer(
-        {sig.name: serializer_from_signature(sig)
-         for sig in signature.type_args})
-  raise ValueError(f'Unsupported type: {signature.data_type}')
-
-
-def signatures_to_json(signatures):
-  assert isinstance(
-      signatures, dict), 'Expected dict of signatures {name: signature}, got %s' % type(signatures)
-  return json.dumps(signatures, default=repr)
-
-
-def signatures_from_json(json_str):
-  d = json.loads(json_str, object_pairs_hook=_SignatureDict)
-  return d
-
-
-def signatures_to_yaml(signatures):
-  # XXX go in/out of json to get the correct format and python dict types
-  d = json.loads(signatures_to_json(signatures))
-  return yaml.dump(d, default_flow_style=False)
-
-
-def signatures_from_yaml(yaml_str):
-  d = yaml.safe_load(yaml_str)
-  return signatures_from_json(json.dumps(d))
-
-
-def serialize(kwargs, signatures, proto=None, is_output=False):
-  '''
-  Serialize the given kwargs into the proto using the given signatures.
-  '''
-  if proto is None:
-    proto = resources_pb2.Data()
-  unknown = set(kwargs.keys()) - set(sig.name for sig in signatures)
-  if unknown:
-    if unknown == {'return'} and len(signatures) > 1:
-      raise TypeError('Got a single return value, but expected multiple outputs {%s}' %
-                      ', '.join(sig.name for sig in signatures))
-    raise TypeError('Got unexpected key: %s' % ', '.join(unknown))
-  for sig in signatures:
-    if sig.name not in kwargs:
-      if sig.required:
-        raise TypeError(f'Missing required argument: {sig.name}')
-      continue  # skip missing fields, they can be set to default on the server
-    data = kwargs[sig.name]
-    serializer = serializer_from_signature(sig)
-    # TODO determine if any (esp the first) var can go in the proto without parts
-    # and whether to put this in the signature or dynamically determine it
-    part = proto.parts.add()
-    part.id = sig.name
-    serializer.serialize(part.data, data)
-  return proto
-
-
-def deserialize(proto, signatures, is_output=False):
-  '''
-  Deserialize the given proto into kwargs using the given signatures.
-  '''
-  if isinstance(signatures, dict):
-    signatures = [signatures]  # TODO update return key level and make consistent
-  kwargs = {}
-  parts_by_name = {part.id: part for part in proto.parts}
-  for sig in signatures:
-    serializer = serializer_from_signature(sig)
-    part = parts_by_name.get(sig.name)
-    if part is None:
-      if sig.required or is_output:  # TODO allow optional outputs?
-        raise ValueError(f'Missing required field: {sig.name}')
-      continue
-    kwargs[sig.name] = serializer.deserialize(part.data)
-  if len(kwargs) == 1 and 'return' in kwargs:
-    return kwargs['return']
-  return kwargs
-
-
-def get_stream_from_signature(signatures):
-  '''
-  Get the stream signature from the given signatures.
-  '''
-  for sig in signatures:
-    if sig.streaming:
-      return sig
-  return None
-
-
-def _is_empty_proto_data(data):
-  if isinstance(data, np.ndarray):
-    return False
-  if isinstance(data, MessageProto):
-    return not data.ByteSize()
-  return not data
-
-
-def _normalize_type(tp):
-  '''
-  Normalize the types for the given parameter.
-  Returns the normalized type and whether the parameter is streaming.
-  '''
-  # stream type indicates streaming, not part of the data itself
-  # it can only be used at the top-level of the var type
-  streaming = (get_origin(tp) == data_types.Stream)
-  if streaming:
-    tp = get_args(tp)[0]
-
-  return _normalize_data_type(tp), streaming
-
-
-def _normalize_data_type(tp):
-  # check if list, and if so, get inner type
-  if get_origin(tp) == list:
-    tp = get_args(tp)[0]
-    return List[_normalize_data_type(tp)]
-
-  if isinstance(tp, (tuple, list)):
-    return Tuple[tuple(_normalize_data_type(val) for val in tp)]
-
-  if isinstance(tp, (dict, data_types.NamedFields)):
-    return data_types.NamedFields(**{name: _normalize_data_type(val) for name, val in tp.items()})
-
-  # check if numpy array type, and if so, use ndarray
-  if get_origin(tp) == np.ndarray:
-    return np.ndarray
-
-  # check for PIL images (sometimes types use the module, sometimes the class)
-  # set these to use the Image data handler
-  if tp in (data_types.Image, PIL.Image, PIL.Image.Image):
-    return data_types.Image
-
-  # check for jsonable types
-  # TODO should we include dict vs list in the data type somehow?
-  if tp == dict or (get_origin(tp) == dict and tp not in _DATA_TYPES and _is_jsonable(tp)):
-    return data_types.JSON
-  if tp == list or (get_origin(tp) == list and tp not in _DATA_TYPES and _is_jsonable(tp)):
-    return data_types.JSON
-
-  # check for known data types
-  try:
-    if tp in _DATA_TYPES:
-      return tp
-  except TypeError:
-    pass  # not hashable type
-
-  raise TypeError(f'Unsupported type: {tp}')
-
-
-def _is_jsonable(tp):
-  if tp in (dict, list, tuple, str, int, float, bool, type(None)):
-    return True
-  if get_origin(tp) == list:
-    return _is_jsonable(get_args(tp)[0])
-  if get_origin(tp) == dict:
-    return all(_is_jsonable(val) for val in get_args(tp))
-  return False
-
-
-# TODO --- tmp classes to stand-in for protos until they are defined and built into this package
-class _SignatureDict(dict):
-  __getattr__ = dict.__getitem__
-  __setattr__ = dict.__setitem__
-
-
-class _VariableSignature(_SignatureDict):
-
-  def __init__(self):
-    super().__init__()
-    self.name = ''
-    self.type = ''
-    self.type_args = []
-    self.streaming = False
-    self.required = False
-    self.default = ''
-    self.description = ''
-
-
-# data_type: name of the data type
-# data_field: name of the field in the data proto
-# serializer: serializer for the data type
-_DataType = namedtuple('_DataType', ('data_type', 'serializer'))
-
-
-# this will come from the proto module, but for now, define it here
-class DataType:
-  NOT_SET = 'NOT_SET'
-
-  STR = 'STR'
-  BYTES = 'BYTES'
-  INT = 'INT'
-  FLOAT = 'FLOAT'
-  BOOL = 'BOOL'
-  NDARRAY = 'NDARRAY'
-  JSON = 'JSON'
-
-  TEXT = 'TEXT'
-  IMAGE = 'IMAGE'
-  CONCEPT = 'CONCEPT'
-  REGION = 'REGION'
-  FRAME = 'FRAME'
-  AUDIO = 'AUDIO'
-  VIDEO = 'VIDEO'
-
-  NAMED_FIELDS = 'NAMED_FIELDS'
-  TUPLE = 'TUPLE'
-  LIST = 'LIST'
-
-
-# simple, non-container types that correspond directly to a data field
-_DATA_TYPES = {
-    str: _DataType(DataType.STR, AtomicFieldSerializer('string_value')),
-    bytes: _DataType(DataType.BYTES, AtomicFieldSerializer('bytes_value')),
-    int: _DataType(DataType.INT, AtomicFieldSerializer('int_value')),
-    float: _DataType(DataType.FLOAT, AtomicFieldSerializer('float_value')),
-    bool: _DataType(DataType.BOOL, AtomicFieldSerializer('bool_value')),
-    np.ndarray: _DataType(DataType.NDARRAY, NDArraySerializer('ndarray')),
-    data_types.Text: _DataType(DataType.TEXT, MessageSerializer('text', data_types.Text)),
-    data_types.Image: _DataType(DataType.IMAGE, MessageSerializer('image', data_types.Image)),
-    data_types.Concept: _DataType(DataType.CONCEPT, MessageSerializer('concepts', data_types.Concept)),
-    data_types.Region: _DataType(DataType.REGION, MessageSerializer('regions', data_types.Region)),
-    data_types.Frame: _DataType(DataType.FRAME, MessageSerializer('frames', data_types.Frame)),
-    data_types.Audio: _DataType(DataType.AUDIO, MessageSerializer('audio', data_types.Audio)),
-    data_types.Video: _DataType(DataType.VIDEO, MessageSerializer('video', data_types.Video)),
-}
-
-_SERIALIZERS_BY_TYPE_ENUM = {dt.data_type: dt.serializer for dt in _DATA_TYPES.values()}
-
-
-class CompatibilitySerializer(Serializer):
-  '''
-  Serialization of basic value types, used for backwards compatibility
-  with older models that don't have type signatures.
-  '''
-
-  def serialize(self, data_proto, value):
-    tp = _normalize_data_type(type(value))
-
-    try:
-      serializer = _DATA_TYPES[tp].serializer
-    except (KeyError, TypeError):
-      raise TypeError(f'serializer currently only supports basic types, got {tp}')
-
-    serializer.serialize(data_proto, value)
-
-  def deserialize(self, data_proto):
-    fields = [k.name for k, _ in data_proto.ListFields()]
-    if 'parts' in fields:
-      raise ValueError('serializer does not support parts')
-    serializers = [
-        serializer for serializer in _SERIALIZERS_BY_TYPE_ENUM.values()
-        if serializer.field_name in fields
-    ]
-    if not serializers:
-      raise ValueError('Returned data not recognized')
-    if len(serializers) != 1:
-      raise ValueError('Only single output supported for serializer')
-    serializer = serializers[0]
-    return serializer.deserialize(data_proto)
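
For context on what 11.1.6 removes here: method_signatures.py derived a typed method signature from a Python function's annotations and classified the method as predict, generate, or stream based on where Stream[...] appears. Below is a minimal sketch of that flow, written against the 11.1.5rc8 code shown above (the module no longer exists in 11.1.6); the predict function and its parameter names are hypothetical examples, not part of the package.

from typing import List

from clarifai.runners.utils.method_signatures import (build_function_signature,
                                                      signatures_to_yaml)


def predict(prompt: str, temperature: float = 0.7) -> List[str]:
  '''Hypothetical model method; only the annotations matter to the builder.'''
  return [prompt]


sig = build_function_signature(predict)
assert sig.method_type == 'predict'     # no Stream[...] on inputs or output
assert sig.inputs[1].required is False  # temperature has a default
print(signatures_to_yaml({'predict': sig}))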
clarifai/runners/utils/serializers.py
@@ -1,208 +0,0 @@
-import json
-from typing import Dict, Iterable
-
-import numpy as np
-from clarifai_grpc.grpc.api import resources_pb2
-
-from clarifai.runners.utils import data_types
-
-
-class Serializer:
-
-  def serialize(self, data_proto, value):
-    pass
-
-  def deserialize(self, data_proto):
-    pass
-
-  def handles_list(self):
-    return False
-
-
-def is_repeated_field(field_name):
-  descriptor = resources_pb2.Data.DESCRIPTOR.fields_by_name.get(field_name)
-  return descriptor and descriptor.label == descriptor.LABEL_REPEATED
-
-
-class AtomicFieldSerializer(Serializer):
-
-  def __init__(self, field_name):
-    self.field_name = field_name
-
-  def serialize(self, data_proto, value):
-    try:
-      setattr(data_proto, self.field_name, value)
-    except TypeError as e:
-      raise TypeError(f"Incompatible type for {self.field_name}: {type(value)}") from e
-
-  def deserialize(self, data_proto):
-    return getattr(data_proto, self.field_name)
-
-
-class MessageSerializer(Serializer):
-
-  def __init__(self, field_name, message_class):
-    self.field_name = field_name
-    self.message_class = message_class
-    self.is_repeated_field = is_repeated_field(field_name)
-
-  def handles_list(self):
-    return self.is_repeated_field
-
-  def serialize(self, data_proto, value):
-    value = self.message_class.from_value(value).to_proto()
-    dst = getattr(data_proto, self.field_name)
-    try:
-      if self.is_repeated_field:
-        dst.add().CopyFrom(value)
-      else:
-        dst.CopyFrom(value)
-    except TypeError as e:
-      raise TypeError(f"Incompatible type for {self.field_name}: {type(value)}") from e
-
-  def serialize_list(self, data_proto, values):
-    assert self.is_repeated_field
-    dst = getattr(data_proto, self.field_name)
-    dst.extend([self.message_class.from_value(value).to_proto() for value in values])
-
-  def deserialize(self, data_proto):
-    src = getattr(data_proto, self.field_name)
-    if self.is_repeated_field:
-      values = [self.message_class.from_proto(x) for x in src]
-      if len(values) == 1:
-        return values[0]
-      return values
-    else:
-      return self.message_class.from_proto(src)
-
-  def deserialize_list(self, data_proto):
-    assert self.is_repeated_field
-    src = getattr(data_proto, self.field_name)
-    return [self.message_class.from_proto(x) for x in src]
-
-
-class NDArraySerializer(Serializer):
-
-  def __init__(self, field_name, as_list=False):
-    self.field_name = field_name
-    self.as_list = as_list
-
-  def serialize(self, data_proto, value):
-    if self.as_list and not isinstance(value, Iterable):
-      raise TypeError(f"Expected list, got {type(value)}")
-    value = np.asarray(value)
-    if not np.issubdtype(value.dtype, np.number):
-      raise TypeError(f"Expected number array, got {value.dtype}")
-    proto = getattr(data_proto, self.field_name)
-    proto.buffer = value.tobytes()
-    proto.shape.extend(value.shape)
-    proto.dtype = str(value.dtype)
-
-  def deserialize(self, data_proto):
-    proto = getattr(data_proto, self.field_name)
-    array = np.frombuffer(proto.buffer, dtype=np.dtype(proto.dtype)).reshape(proto.shape)
-    if self.as_list:
-      return array.tolist()
-    return array
-
-
-class JSONSerializer(Serializer):
-
-  def __init__(self, field_name, type=None):
-    self.field_name = field_name
-    self.type = type
-
-  def serialize(self, data_proto, value):
-    #if self.type is not None and not isinstance(value, self.type):
-    #  raise TypeError(f"Expected {self.type}, got {type(value)}")
-    try:
-      setattr(data_proto, self.field_name, json.dumps(value))
-    except TypeError as e:
-      raise TypeError(f"Incompatible type for {self.field_name}: {type(value)}") from e
-
-  def deserialize(self, data_proto):
-    return json.loads(getattr(data_proto, self.field_name))
-
-
-class ListSerializer(Serializer):
-
-  def __init__(self, inner_serializer):
-    self.field_name = 'parts'
-    self.inner_serializer = inner_serializer
-
-  def handles_list(self):
-    # if handles_list() is called on this serializer, it means that we're
-    # trying to serialize a list of lists. In this case, we need to use
-    # parts[] for the outer list, so we return False here (we can't inline it).
-    return False
-
-  def serialize(self, data_proto, value):
-    if not isinstance(value, Iterable):
-      raise TypeError(f"Expected iterable, got {type(value)}")
-    if self.inner_serializer.handles_list():
-      self.inner_serializer.serialize_list(data_proto, value)
-    else:
-      for item in value:
-        part = data_proto.parts.add()
-        self.inner_serializer.serialize(part.data, item)
-
-  def deserialize(self, data_proto):
-    if self.inner_serializer.handles_list():
-      return self.inner_serializer.deserialize_list(data_proto)
-    return [self.inner_serializer.deserialize(part.data) for part in data_proto.parts]
-
-
-class TupleSerializer(Serializer):
-
-  def __init__(self, inner_serializers):
-    self.field_name = 'parts'
-    self.inner_serializers = inner_serializers
-
-  def serialize(self, data_proto, value):
-    if not isinstance(value, (tuple, list)):
-      raise TypeError(f"Expected tuple, got {type(value)}")
-    if len(value) != len(self.inner_serializers):
-      raise ValueError(f"Expected tuple of length {len(self.inner_serializers)}, got {len(value)}")
-    for i, (serializer, item) in enumerate(zip(self.inner_serializers, value)):
-      part = data_proto.parts.add()
-      part.id = str(i)
-      serializer.serialize(part.data, item)
-
-  def deserialize(self, data_proto):
-    return tuple(
-        serializer.deserialize(part.data)
-        for serializer, part in zip(self.inner_serializers, data_proto.parts))
-
-
-class NamedFieldsSerializer(Serializer):
-
-  def __init__(self, named_field_serializers: Dict[str, Serializer]):
-    self.field_name = 'parts'
-    self.named_field_serializers = named_field_serializers
-
-  def serialize(self, data_proto, value):
-    for name, serializer in self.named_field_serializers.items():
-      if name not in value:
-        raise KeyError(f"Missing field {name}")
-      part = self._get_part(data_proto, name, add=True)
-      serializer.serialize(part.data, value[name])
-
-  def deserialize(self, data_proto):
-    value = data_types.NamedFields()
-    for name, serializer in self.named_field_serializers.items():
-      part = self._get_part(data_proto, name)
-      value[name] = serializer.deserialize(part.data)
-    return value
-
-  def _get_part(self, data_proto, name, add=False):
-    for part in data_proto.parts:
-      if part.id == name:
-        return part
-    if add:
-      part = data_proto.parts.add()
-      part.id = name
-      return part
-    raise KeyError(f"Missing part with key {name}")
-
-
-# TODO dict serializer, maybe json only?
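
Each serializer above owns a single field of the resources_pb2.Data proto, while the container serializers recurse into data.parts. A minimal round-trip sketch against 11.1.5rc8 (these classes are gone in 11.1.6), assuming a clarifai_grpc build whose Data proto carries the ndarray field that _DATA_TYPES in method_signatures.py maps np.ndarray to:

import numpy as np
from clarifai_grpc.grpc.api import resources_pb2

from clarifai.runners.utils.serializers import NDArraySerializer

serializer = NDArraySerializer('ndarray')  # field name taken from _DATA_TYPES above
proto = resources_pb2.Data()
original = np.arange(6, dtype=np.float32).reshape(2, 3)

serializer.serialize(proto, original)     # writes buffer, shape, and dtype into Data.ndarray
restored = serializer.deserialize(proto)  # frombuffer + reshape recovers the array

assert np.array_equal(original, restored)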