clarifai 11.1.5__py3-none-any.whl → 11.1.5rc6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. clarifai/__init__.py +1 -1
  2. clarifai/__pycache__/__init__.cpython-310.pyc +0 -0
  3. clarifai/__pycache__/errors.cpython-310.pyc +0 -0
  4. clarifai/__pycache__/versions.cpython-310.pyc +0 -0
  5. clarifai/cli/__main__.py~ +4 -0
  6. clarifai/cli/__pycache__/__init__.cpython-310.pyc +0 -0
  7. clarifai/cli/__pycache__/__main__.cpython-310.pyc +0 -0
  8. clarifai/cli/__pycache__/base.cpython-310.pyc +0 -0
  9. clarifai/cli/__pycache__/compute_cluster.cpython-310.pyc +0 -0
  10. clarifai/cli/__pycache__/deployment.cpython-310.pyc +0 -0
  11. clarifai/cli/__pycache__/model.cpython-310.pyc +0 -0
  12. clarifai/cli/__pycache__/nodepool.cpython-310.pyc +0 -0
  13. clarifai/cli/model.py +25 -0
  14. clarifai/client/__pycache__/__init__.cpython-310.pyc +0 -0
  15. clarifai/client/__pycache__/app.cpython-310.pyc +0 -0
  16. clarifai/client/__pycache__/base.cpython-310.pyc +0 -0
  17. clarifai/client/__pycache__/dataset.cpython-310.pyc +0 -0
  18. clarifai/client/__pycache__/input.cpython-310.pyc +0 -0
  19. clarifai/client/__pycache__/lister.cpython-310.pyc +0 -0
  20. clarifai/client/__pycache__/model.cpython-310.pyc +0 -0
  21. clarifai/client/__pycache__/module.cpython-310.pyc +0 -0
  22. clarifai/client/__pycache__/runner.cpython-310.pyc +0 -0
  23. clarifai/client/__pycache__/search.cpython-310.pyc +0 -0
  24. clarifai/client/__pycache__/user.cpython-310.pyc +0 -0
  25. clarifai/client/__pycache__/workflow.cpython-310.pyc +0 -0
  26. clarifai/client/auth/__pycache__/__init__.cpython-310.pyc +0 -0
  27. clarifai/client/auth/__pycache__/helper.cpython-310.pyc +0 -0
  28. clarifai/client/auth/__pycache__/register.cpython-310.pyc +0 -0
  29. clarifai/client/auth/__pycache__/stub.cpython-310.pyc +0 -0
  30. clarifai/client/model.py +95 -362
  31. clarifai/client/model_client.py +432 -0
  32. clarifai/constants/__pycache__/dataset.cpython-310.pyc +0 -0
  33. clarifai/constants/__pycache__/model.cpython-310.pyc +0 -0
  34. clarifai/constants/__pycache__/search.cpython-310.pyc +0 -0
  35. clarifai/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
  36. clarifai/datasets/export/__pycache__/__init__.cpython-310.pyc +0 -0
  37. clarifai/datasets/export/__pycache__/inputs_annotations.cpython-310.pyc +0 -0
  38. clarifai/datasets/upload/__pycache__/__init__.cpython-310.pyc +0 -0
  39. clarifai/datasets/upload/__pycache__/base.cpython-310.pyc +0 -0
  40. clarifai/datasets/upload/__pycache__/features.cpython-310.pyc +0 -0
  41. clarifai/datasets/upload/__pycache__/image.cpython-310.pyc +0 -0
  42. clarifai/datasets/upload/__pycache__/text.cpython-310.pyc +0 -0
  43. clarifai/datasets/upload/__pycache__/utils.cpython-310.pyc +0 -0
  44. clarifai/datasets/upload/loaders/__pycache__/__init__.cpython-310.pyc +0 -0
  45. clarifai/datasets/upload/loaders/__pycache__/coco_detection.cpython-310.pyc +0 -0
  46. clarifai/models/__pycache__/__init__.cpython-310.pyc +0 -0
  47. clarifai/models/model_serving/__pycache__/__init__.cpython-310.pyc +0 -0
  48. clarifai/models/model_serving/__pycache__/constants.cpython-310.pyc +0 -0
  49. clarifai/models/model_serving/cli/__pycache__/__init__.cpython-310.pyc +0 -0
  50. clarifai/models/model_serving/cli/__pycache__/_utils.cpython-310.pyc +0 -0
  51. clarifai/models/model_serving/cli/__pycache__/base.cpython-310.pyc +0 -0
  52. clarifai/models/model_serving/cli/__pycache__/build.cpython-310.pyc +0 -0
  53. clarifai/models/model_serving/cli/__pycache__/create.cpython-310.pyc +0 -0
  54. clarifai/models/model_serving/model_config/__pycache__/__init__.cpython-310.pyc +0 -0
  55. clarifai/models/model_serving/model_config/__pycache__/base.cpython-310.pyc +0 -0
  56. clarifai/models/model_serving/model_config/__pycache__/config.cpython-310.pyc +0 -0
  57. clarifai/models/model_serving/model_config/__pycache__/inference_parameter.cpython-310.pyc +0 -0
  58. clarifai/models/model_serving/model_config/__pycache__/output.cpython-310.pyc +0 -0
  59. clarifai/models/model_serving/model_config/triton/__pycache__/__init__.cpython-310.pyc +0 -0
  60. clarifai/models/model_serving/model_config/triton/__pycache__/serializer.cpython-310.pyc +0 -0
  61. clarifai/models/model_serving/model_config/triton/__pycache__/triton_config.cpython-310.pyc +0 -0
  62. clarifai/models/model_serving/model_config/triton/__pycache__/wrappers.cpython-310.pyc +0 -0
  63. clarifai/models/model_serving/repo_build/__pycache__/__init__.cpython-310.pyc +0 -0
  64. clarifai/models/model_serving/repo_build/__pycache__/build.cpython-310.pyc +0 -0
  65. clarifai/models/model_serving/repo_build/static_files/__pycache__/base_test.cpython-310-pytest-7.2.0.pyc +0 -0
  66. clarifai/rag/__pycache__/__init__.cpython-310.pyc +0 -0
  67. clarifai/rag/__pycache__/rag.cpython-310.pyc +0 -0
  68. clarifai/rag/__pycache__/utils.cpython-310.pyc +0 -0
  69. clarifai/runners/__init__.py +2 -7
  70. clarifai/runners/__pycache__/__init__.cpython-310.pyc +0 -0
  71. clarifai/runners/__pycache__/server.cpython-310.pyc +0 -0
  72. clarifai/runners/dockerfile_template/Dockerfile.debug +11 -0
  73. clarifai/runners/dockerfile_template/Dockerfile.debug~ +9 -0
  74. clarifai/runners/dockerfile_template/Dockerfile.template +3 -0
  75. clarifai/runners/models/__pycache__/__init__.cpython-310.pyc +0 -0
  76. clarifai/runners/models/__pycache__/base_typed_model.cpython-310.pyc +0 -0
  77. clarifai/runners/models/__pycache__/model_builder.cpython-310.pyc +0 -0
  78. clarifai/runners/models/__pycache__/model_class.cpython-310.pyc +0 -0
  79. clarifai/runners/models/__pycache__/model_run_locally.cpython-310.pyc +0 -0
  80. clarifai/runners/models/__pycache__/model_runner.cpython-310.pyc +0 -0
  81. clarifai/runners/models/__pycache__/model_servicer.cpython-310.pyc +0 -0
  82. clarifai/runners/models/__pycache__/model_upload.cpython-310.pyc +0 -0
  83. clarifai/runners/models/model_builder.py +33 -7
  84. clarifai/runners/models/model_class.py +273 -28
  85. clarifai/runners/models/model_run_locally.py +3 -78
  86. clarifai/runners/models/model_runner.py +2 -0
  87. clarifai/runners/models/model_servicer.py +11 -2
  88. clarifai/runners/server.py +5 -1
  89. clarifai/runners/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  90. clarifai/runners/utils/__pycache__/const.cpython-310.pyc +0 -0
  91. clarifai/runners/utils/__pycache__/data_handler.cpython-310.pyc +0 -0
  92. clarifai/runners/utils/__pycache__/data_types.cpython-310.pyc +0 -0
  93. clarifai/runners/utils/__pycache__/data_utils.cpython-310.pyc +0 -0
  94. clarifai/runners/utils/__pycache__/loader.cpython-310.pyc +0 -0
  95. clarifai/runners/utils/__pycache__/logging.cpython-310.pyc +0 -0
  96. clarifai/runners/utils/__pycache__/method_signatures.cpython-310.pyc +0 -0
  97. clarifai/runners/utils/__pycache__/serializers.cpython-310.pyc +0 -0
  98. clarifai/runners/utils/__pycache__/url_fetcher.cpython-310.pyc +0 -0
  99. clarifai/runners/utils/data_handler.py +308 -205
  100. clarifai/runners/utils/data_types.py +334 -0
  101. clarifai/runners/utils/method_signatures.py +452 -0
  102. clarifai/runners/utils/serializers.py +132 -0
  103. clarifai/schema/__pycache__/search.cpython-310.pyc +0 -0
  104. clarifai/urls/__pycache__/helper.cpython-310.pyc +0 -0
  105. clarifai/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  106. clarifai/utils/__pycache__/logging.cpython-310.pyc +0 -0
  107. clarifai/utils/__pycache__/misc.cpython-310.pyc +0 -0
  108. clarifai/utils/__pycache__/model_train.cpython-310.pyc +0 -0
  109. clarifai/utils/evaluation/__pycache__/__init__.cpython-310.pyc +0 -0
  110. clarifai/utils/evaluation/__pycache__/helpers.cpython-310.pyc +0 -0
  111. clarifai/utils/evaluation/__pycache__/main.cpython-310.pyc +0 -0
  112. clarifai/workflows/__pycache__/__init__.cpython-310.pyc +0 -0
  113. clarifai/workflows/__pycache__/export.cpython-310.pyc +0 -0
  114. clarifai/workflows/__pycache__/utils.cpython-310.pyc +0 -0
  115. clarifai/workflows/__pycache__/validate.cpython-310.pyc +0 -0
  116. {clarifai-11.1.5.dist-info → clarifai-11.1.5rc6.dist-info}/METADATA +16 -26
  117. clarifai-11.1.5rc6.dist-info/RECORD +203 -0
  118. {clarifai-11.1.5.dist-info → clarifai-11.1.5rc6.dist-info}/WHEEL +1 -1
  119. clarifai/runners/models/base_typed_model.py +0 -238
  120. clarifai-11.1.5.dist-info/RECORD +0 -101
  121. {clarifai-11.1.5.dist-info → clarifai-11.1.5rc6.dist-info}/LICENSE +0 -0
  122. {clarifai-11.1.5.dist-info → clarifai-11.1.5rc6.dist-info}/entry_points.txt +0 -0
  123. {clarifai-11.1.5.dist-info → clarifai-11.1.5rc6.dist-info}/top_level.txt +0 -0
clarifai/runners/utils/method_signatures.py
@@ -0,0 +1,452 @@
+ import inspect
+ import json
+ import re
+ import types
+ from collections import OrderedDict, namedtuple
+ from typing import List, get_args, get_origin
+
+ import numpy as np
+ import PIL.Image
+ import yaml
+ from clarifai_grpc.grpc.api import resources_pb2
+ from google.protobuf.message import Message as MessageProto
+
+ from clarifai.runners.utils import data_types
+ from clarifai.runners.utils.serializers import (AtomicFieldSerializer, ImageSerializer,
+                                                 ListSerializer, MessageSerializer,
+                                                 NDArraySerializer, NullValueSerializer, Serializer)
+
+
+ def build_function_signature(func, method_type: str):
+   '''
+   Build a signature for the given function.
+   '''
+   sig = inspect.signature(func)
+
+   # check if func is bound, and if not, remove self/cls
+   if getattr(func, '__self__', None) is None and sig.parameters and list(
+       sig.parameters.values())[0].name in ('self', 'cls'):
+     sig = sig.replace(parameters=list(sig.parameters.values())[1:])
+
+   return_annotation = sig.return_annotation
+   if return_annotation == inspect.Parameter.empty:
+     raise ValueError('Function must have a return annotation')
+   # check for multiple return values and convert to dict for named values
+   return_streaming = False
+   if get_origin(return_annotation) == data_types.Stream:
+     return_annotation = get_args(return_annotation)[0]
+     return_streaming = True
+   if get_origin(return_annotation) == tuple:
+     return_annotation = tuple(get_args(return_annotation))
+   if isinstance(return_annotation, tuple):
+     return_annotation = {'return.%s' % i: tp for i, tp in enumerate(return_annotation)}
+   if not isinstance(return_annotation, dict):
+     return_annotation = {'return': return_annotation}
+
+   input_vars = build_variables_signature(sig.parameters.values())
+   output_vars = build_variables_signature(
+       [
+           # XXX inspect.Parameter errors for the special return names, so use SimpleNamespace
+           types.SimpleNamespace(name=name, annotation=tp, default=inspect.Parameter.empty)
+           for name, tp in return_annotation.items()
+       ],
+       is_output=True)
+   if return_streaming:
+     for var in output_vars:
+       var.streaming = True
+
+   # check for streams
+   if method_type == 'predict':
+     for var in input_vars:
+       if var.streaming:
+         raise TypeError('Stream inputs are not supported for predict methods')
+     for var in output_vars:
+       if var.streaming:
+         raise TypeError('Stream outputs are not supported for predict methods')
+   elif method_type == 'generate':
+     for var in input_vars:
+       if var.streaming:
+         raise TypeError('Stream inputs are not supported for generate methods')
+     if not all(var.streaming for var in output_vars):
+       raise TypeError('Generate methods must return a stream')
+   elif method_type == 'stream':
+     input_stream_vars = [var for var in input_vars if var.streaming]
+     if len(input_stream_vars) == 0:
+       raise TypeError('Stream methods must include a Stream input')
+     if not all(var.streaming for var in output_vars):
+       raise TypeError('Stream methods must return a single Stream')
+   else:
+     raise TypeError('Invalid method type: %s' % method_type)
+
+   #method_signature = resources_pb2.MethodSignature()  # TODO
+   method_signature = _NamedFields()  #for now
+
+   method_signature.name = func.__name__
+   #method_signature.method_type = getattr(resources_pb2.RunnerMethodType, method_type)
+   assert method_type in ('predict', 'generate', 'stream')
+   method_signature.method_type = method_type
+   method_signature.docstring = func.__doc__
+
+   #method_signature.inputs.extend(input_vars)
+   #method_signature.outputs.extend(output_vars)
+   method_signature.inputs = input_vars
+   method_signature.outputs = output_vars
+   return method_signature
+
+
+ def build_variables_signature(parameters: List[inspect.Parameter], is_output=False):
+   '''
+   Build a data proto signature for the given variable or return type annotation.
+   '''
+
+   vars = []
+
+   # check valid names (should already be constrained by python naming, but check anyway)
+   for param in parameters:
+     if not param.name.isidentifier() and not (is_output and
+                                               re.match(r'return(\.\d+)?', param.name)):
+       raise ValueError(f'Invalid variable name: {param.name}')
+
+   # get fields for each variable based on type
+   for param in parameters:
+     param_types, streaming = _normalize_types(param, is_output=is_output)
+
+     for name, tp in param_types.items():
+       #var = resources_pb2.MethodVariable()  # TODO
+       var = _NamedFields()
+       var.name = name
+       var.data_type = _DATA_TYPES[tp].data_type
+       var.data_field = _DATA_TYPES[tp].data_field
+       var.streaming = streaming
+       if not is_output:
+         var.required = (param.default is inspect.Parameter.empty)
+         if not var.required:
+           var.default = param.default
+       vars.append(var)
+
+   # check if any fields are used more than once, and if so, use parts
+   # also if more than one field uses parts lists, also use parts, since the lists can be different lengths
+   # NOTE this is a little fancy, another way would just be to check if there is more than one arg
+   fields_unique = (len(set(var.data_field for var in vars)) == len(vars))
+   num_parts_lists = sum(int(var.data_field.startswith('parts[]')) for var in vars)
+   if not fields_unique or num_parts_lists > 1:
+     for var in vars:
+       var.data_field = 'parts[%s].%s' % (var.name, var.data_field)
+
+   return vars
+
+
+ def signatures_to_json(signatures):
+   assert isinstance(
+       signatures, dict), 'Expected dict of signatures {name: signature}, got %s' % type(signatures)
+   return json.dumps(signatures, default=repr)
+
+
+ def signatures_from_json(json_str):
+   return json.loads(json_str, object_pairs_hook=_NamedFields)
+
+
+ def signatures_to_yaml(signatures):
+   # XXX go in/out of json to get the correct format and python dict types
+   d = json.loads(signatures_to_json(signatures))
+   return yaml.dump(d, default_flow_style=False)
+
+
+ def signatures_from_yaml(yaml_str):
+   d = yaml.safe_load(yaml_str)
+   return signatures_from_json(json.dumps(d))
+
+
+ def serialize(kwargs, signatures, proto=None, is_output=False):
+   '''
+   Serialize the given kwargs into the proto using the given signatures.
+   '''
+   if proto is None:
+     proto = resources_pb2.Data()
+   if not is_output:  # TODO: use this consistently for return keys also
+     kwargs = flatten_nested_keys(kwargs, signatures, is_output)
+   unknown = set(kwargs.keys()) - set(sig.name for sig in signatures)
+   if unknown:
+     if unknown == {'return'} and len(signatures) > 1:
+       raise TypeError('Got a single return value, but expected multiple outputs {%s}' %
+                       ', '.join(sig.name for sig in signatures))
+     raise TypeError('Got unexpected key: %s' % ', '.join(unknown))
+   for sig in signatures:
+     if sig.name not in kwargs:
+       if sig.required:
+         raise TypeError(f'Missing required argument: {sig.name}')
+       continue  # skip missing fields, they can be set to default on the server
+     data = kwargs[sig.name]
+     force_named_part = (_is_empty_proto_data(data) and not is_output and not sig.required)
+     data_proto, field = _get_data_part(
+         proto, sig, is_output=is_output, serializing=True, force_named_part=force_named_part)
+     serializer = get_serializer(sig.data_type)
+     serializer.serialize(data_proto, field, data)
+   return proto
+
+
+ def deserialize(proto, signatures, is_output=False):
+   '''
+   Deserialize the given proto into kwargs using the given signatures.
+   '''
+   kwargs = {}
+   for sig in signatures:
+     data_proto, field = _get_data_part(proto, sig, is_output=is_output, serializing=False)
+     if data_proto is None:
+       # not set in proto, check if required or skip if optional arg
+       if not is_output and sig.required:
+         raise ValueError(f'Missing required field: {sig.name}')
+       continue
+     serializer = get_serializer(sig.data_type)
+     data = serializer.deserialize(data_proto, field)
+     kwargs[sig.name] = data
+   if is_output:
+     if len(kwargs) == 1 and 'return' in kwargs:  # case for single return value
+       return kwargs['return']
+     if kwargs and 'return.0' in kwargs:  # case for tuple return values
+       return tuple(kwargs[f'return.{i}'] for i in range(len(kwargs)))
+     return data_types.Output(kwargs)
+   kwargs = unflatten_nested_keys(kwargs, signatures, is_output)
+   return kwargs
+
+
+ def get_serializer(data_type: str) -> Serializer:
+   if data_type in _SERIALIZERS_BY_TYPE_STRING:
+     return _SERIALIZERS_BY_TYPE_STRING[data_type]
+   if data_type.startswith('List['):
+     inner_type_string = data_type[len('List['):-1]
+     inner_serializer = get_serializer(inner_type_string)
+     return ListSerializer(inner_serializer)
+   raise ValueError(f'Unsupported type: "{data_type}"')
+
+
+ def flatten_nested_keys(kwargs, signatures, is_output):
+   '''
+   Flatten nested keys into a single key with a dot, e.g. {'a': {'b': 1}} -> {'a.b': 1}
+   in the kwargs, using the given signatures to determine which keys are nested.
+   '''
+   nested_keys = [sig.name for sig in signatures if '.' in sig.name]
+   outer_keys = set(key.split('.')[0] for key in nested_keys)
+   for outer in outer_keys:
+     if outer not in kwargs:
+       continue
+     kwargs.update({outer + '.' + k: v for k, v in kwargs.pop(outer).items()})
+   return kwargs
+
+
+ def unflatten_nested_keys(kwargs, signatures, is_output):
+   '''
+   Unflatten nested keys in kwargs into a dict, e.g. {'a.b': 1} -> {'a': {'b': 1}}
+   Uses the signatures to determine which keys are nested.
+   The dict subclass is Input or Output, depending on the is_output flag.
+   Preserves the order of args from the signatures.
+   '''
+   unflattened = OrderedDict()
+   for sig in signatures:
+     if '.' not in sig.name:
+       if sig.name in kwargs:
+         unflattened[sig.name] = kwargs[sig.name]
+       continue
+     if sig.name not in kwargs:
+       continue
+     parts = sig.name.split('.')
+     assert len(parts) == 2, 'Only one level of nested keys is supported'
+     if parts[0] not in unflattened:
+       unflattened[parts[0]] = data_types.Output() if is_output else data_types.Input()
+     unflattened[parts[0]][parts[1]] = kwargs[sig.name]
+   return unflattened
+
+
+ def get_stream_from_signature(signatures):
+   streaming_signatures = [var for var in signatures if var.streaming]
+   if not streaming_signatures:
+     return None, []
+   stream_argname = set([var.name.split('.', 1)[0] for var in streaming_signatures])
+   assert len(stream_argname) == 1, 'streaming methods must have exactly one streaming function arg'
+   stream_argname = stream_argname.pop()
+   return stream_argname, streaming_signatures
+
+
+ def _is_empty_proto_data(data):
+   if isinstance(data, np.ndarray):
+     return False
+   if isinstance(data, MessageProto):
+     return not data.ByteSize()
+   return not data
+
+
+ def _get_data_part(proto, sig, is_output, serializing, force_named_part=False):
+   field = sig.data_field
+
+   # check if we need to force a named part, to distinguish between empty and unset values
+   if force_named_part and not field.startswith('parts['):
+     field = f'parts[{sig.name}].{field}'
+
+   # gets the named part from the proto, according to the field path
+   # note we only support one level of named parts
+   #parts = field.replace(' ', '').split('.')
+   # split on . but not if it is inside brackets, e.g. parts[outer.inner].field
+   parts = re.split(r'\.(?![^\[]*\])', field.replace(' ', ''))
+
+   if len(parts) not in (1, 2, 3):  # field, parts[name].field, parts[name].parts[].field
+     raise ValueError('Invalid field: %s' % field)
+
+   if len(parts) == 1:
+     # also need to check if there is an explicitly named part, e.g. for empty values
+     part = next((part for part in proto.parts if part.id == sig.name), None)
+     if part:
+       return part.data, field
+     if not serializing and not is_output and _is_empty_proto_data(getattr(proto, field)):
+       return None, field
+     return proto, field
+
+   # list
+   if parts[0] == 'parts[]':
+     if len(parts) != 2:
+       raise ValueError('Invalid field: %s' % field)
+     return proto, field  # return the data that contains the list itself
+
+   # named part
+   if not (m := re.match(r'parts\[([\w.]+)\]', parts[0])):
+     raise ValueError('Invalid field: %s' % field)
+   if not (name := m.group(1)):
+     raise ValueError('Invalid field: %s' % field)
+   assert len(parts) in (2, 3)  # parts[name].field, parts[name].parts[].field
+   part = next((part for part in proto.parts if part.id == name), None)
+   if part is None:
+     if not serializing:
+       raise ValueError('Missing part: %s' % name)
+     part = proto.parts.add()
+     part.id = name
+   return part.data, '.'.join(parts[1:])
+
+
+ def _normalize_types(param, is_output=False):
+   '''
+   Normalize the types for the given parameter. Returns a dict of names to types,
+   including named return values for outputs, and a flag indicating if streaming is used.
+   '''
+   tp = param.annotation
+
+   # stream type indicates streaming, not part of the data itself
+   streaming = (get_origin(tp) == data_types.Stream)
+   if streaming:
+     tp = get_args(tp)[0]
+
+   if is_output or streaming:  # named types can be used for outputs or streaming inputs
+     # output type used for named return values, each with their own data type
+     if isinstance(tp, (dict, data_types.Output, data_types.Input)):
+       return {param.name + '.' + name: _normalize_data_type(val)
+               for name, val in tp.items()}, streaming
+     if tp == data_types.Output:  # check for Output type without values
+       if not is_output:
+         raise TypeError('Output types can only be used for output values')
+       raise TypeError('Output types must be instantiated with inner type values for each key')
+     if tp == data_types.Input:  # check for Input type without values
+       if is_output:
+         raise TypeError('Input types can only be used for input values')
+       raise TypeError(
+           'Stream[Input(...)] types must be instantiated with inner type values for each key')
+
+   return {param.name: _normalize_data_type(tp)}, streaming
+
+
+ def _normalize_data_type(tp):
+   # check if list, and if so, get inner type
+   is_list = (get_origin(tp) == list)
+   if is_list:
+     tp = get_args(tp)[0]
+
+   # check if numpy array, and if so, use ndarray
+   if get_origin(tp) == np.ndarray:
+     tp = np.ndarray
+
+   # check for PIL images (sometimes types use the module, sometimes the class)
+   # set these to use the Image data handler
+   if tp in (PIL.Image, PIL.Image.Image):
+     tp = data_types.Image
+
+   # put back list
+   if is_list:
+     tp = List[tp]
+
+   # check if supported type
+   if tp not in _DATA_TYPES:
+     raise ValueError(f'Unsupported type: {tp}')
+
+   return tp
+
+
+ class _NamedFields(dict):
+   __getattr__ = dict.__getitem__
+   __setattr__ = dict.__setitem__
+
+
+ # data_type: name of the data type
+ # data_field: name of the field in the data proto
+ # serializer: serializer for the data type
+ _DataType = namedtuple('_DataType', ('data_type', 'data_field', 'serializer'))
+
+ # mapping of supported python types to data type names, fields, and serializers
+ _DATA_TYPES = {
+     str:
+         _DataType('str', 'string_value', AtomicFieldSerializer()),
+     bytes:
+         _DataType('bytes', 'bytes_value', AtomicFieldSerializer()),
+     int:
+         _DataType('int', 'int_value', AtomicFieldSerializer()),
+     float:
+         _DataType('float', 'float_value', AtomicFieldSerializer()),
+     bool:
+         _DataType('bool', 'bool_value', AtomicFieldSerializer()),
+     None:
+         _DataType('None', '', NullValueSerializer()),
+     np.ndarray:
+         _DataType('ndarray', 'ndarray', NDArraySerializer()),
+     data_types.Text:
+         _DataType('Text', 'text', MessageSerializer(data_types.Text)),
+     data_types.Image:
+         _DataType('Image', 'image', ImageSerializer()),
+     data_types.Concept:
+         _DataType('Concept', 'concepts', MessageSerializer(data_types.Concept)),
+     data_types.Region:
+         _DataType('Region', 'regions', MessageSerializer(data_types.Region)),
+     data_types.Frame:
+         _DataType('Frame', 'frames', MessageSerializer(data_types.Frame)),
+     data_types.Audio:
+         _DataType('Audio', 'audio', MessageSerializer(data_types.Audio)),
+     data_types.Video:
+         _DataType('Video', 'video', MessageSerializer(data_types.Video)),
+
+     # lists handled specially, not as generic lists using parts
+     List[int]:
+         _DataType('ndarray', 'ndarray', NDArraySerializer()),
+     List[float]:
+         _DataType('ndarray', 'ndarray', NDArraySerializer()),
+     List[bool]:
+         _DataType('ndarray', 'ndarray', NDArraySerializer()),
+ }
+
+
+ # add generic lists using parts, for all supported types
+ def _add_list_fields():
+   for tp in list(_DATA_TYPES.keys()):
+     if List[tp] in _DATA_TYPES:
+       # already added as special case
+       continue
+
+     # check if data field is repeated, and if so, use repeated field for list
+     field_name = _DATA_TYPES[tp].data_field
+     descriptor = resources_pb2.Data.DESCRIPTOR.fields_by_name.get(field_name)
+     repeated = descriptor and descriptor.label == descriptor.LABEL_REPEATED
+
+     # add to supported types
+     data_type = 'List[%s]' % _DATA_TYPES[tp].data_type
+     data_field = field_name if repeated else 'parts[].' + field_name
+     serializer = ListSerializer(_DATA_TYPES[tp].serializer)
+
+     _DATA_TYPES[List[tp]] = _DataType(data_type, data_field, serializer)
+
+
+ _add_list_fields()
+ _SERIALIZERS_BY_TYPE_STRING = {dt.data_type: dt.serializer for dt in _DATA_TYPES.values()}
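
A minimal sketch of how the signature helpers in the new module above might be exercised once this wheel is installed. The caption method here is hypothetical, invented for illustration; only the module path and function names come from the diff itself.

from clarifai.runners.utils import data_types
from clarifai.runners.utils.method_signatures import (build_function_signature,
                                                      signatures_to_yaml)

def caption(image: data_types.Image, prompt: str = 'Describe the image') -> str:
  '''Hypothetical typed model method, not part of the package.'''
  return 'a caption'

# Records each argument's data type, proto field, default, and streaming flag;
# raises if the return annotation is missing or a Stream is used with 'predict'.
sig = build_function_signature(caption, method_type='predict')
print(signatures_to_yaml({'caption': sig}))
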
clarifai/runners/utils/serializers.py
@@ -0,0 +1,132 @@
+ from typing import Iterable
+
+ import numpy as np
+ from clarifai_grpc.grpc.api import resources_pb2
+ from PIL import Image as PILImage
+
+ from clarifai.runners.utils.data_types import Image, MessageData
+
+
+ class Serializer:
+
+   def serialize(self, data_proto, field, value):
+     pass
+
+   def deserialize(self, data_proto, field):
+     pass
+
+
+ def is_repeated(field):
+   return hasattr(field, 'add')
+
+
+ class AtomicFieldSerializer(Serializer):
+
+   def serialize(self, data_proto, field, value):
+     try:
+       setattr(data_proto, field, value)
+     except TypeError as e:
+       raise TypeError(f"Incompatible type for {field}: {type(value)}") from e
+
+   def deserialize(self, data_proto, field):
+     return getattr(data_proto, field)
+
+
+ class MessageSerializer(Serializer):
+
+   def __init__(self, message_class):
+     self.message_class = message_class
+
+   def serialize(self, data_proto, field, value):
+     if isinstance(value, MessageData):
+       value = value.to_proto()
+     dst = getattr(data_proto, field)
+     try:
+       if is_repeated(dst):
+         dst.add().CopyFrom(value)
+       else:
+         dst.CopyFrom(value)
+     except TypeError as e:
+       raise TypeError(f"Incompatible type for {field}: {type(value)}") from e
+
+   def deserialize(self, data_proto, field):
+     src = getattr(data_proto, field)
+     if is_repeated(src):
+       return [self.message_class.from_proto(item) for item in src]
+     else:
+       return self.message_class.from_proto(src)
+
+
+ class ImageSerializer(Serializer):
+
+   def serialize(self, data_proto, field, value):
+     if not isinstance(value, (PILImage.Image, Image, resources_pb2.Image)):
+       raise TypeError(f"Expected Image, got {type(value)}")
+     if isinstance(value, PILImage.Image):
+       value = Image.from_pil(value)
+     if isinstance(value, MessageData):
+       value = value.to_proto()
+     getattr(data_proto, field).CopyFrom(value)
+
+   def deserialize(self, data_proto, field):
+     value = getattr(data_proto, field)
+     return Image.from_proto(value)
+
+
+ class NDArraySerializer(Serializer):
+
+   def serialize(self, data_proto, field, value):
+     value = np.asarray(value)
+     if not np.issubdtype(value.dtype, np.number):
+       raise TypeError(f"Expected number array, got {value.dtype}")
+     proto = getattr(data_proto, field)
+     proto.buffer = value.tobytes()
+     proto.shape.extend(value.shape)
+     proto.dtype = str(value.dtype)
+
+   def deserialize(self, data_proto, field):
+     proto = getattr(data_proto, field)
+     array = np.frombuffer(proto.buffer, dtype=np.dtype(proto.dtype)).reshape(proto.shape)
+     return array
+
+
+ class NullValueSerializer(Serializer):
+
+   def serialize(self, data_proto, field, value):
+     pass
+
+   def deserialize(self, data_proto, field):
+     return None
+
+
+ class ListSerializer(Serializer):
+
+   def __init__(self, inner_serializer):
+     self.inner_serializer = inner_serializer
+
+   def serialize(self, data_proto, field, value):
+     if not isinstance(value, Iterable):
+       raise TypeError(f"Expected iterable, got {type(value)}")
+     if field.startswith('parts[].'):
+       inner_field = field[len('parts[].'):]
+       for item in value:
+         part = data_proto.parts.add()
+         self.inner_serializer.serialize(part.data, inner_field, item)
+       return
+     repeated = getattr(data_proto, field)
+     assert is_repeated(repeated), f"Field {field} is not repeated"
+     for item in value:
+       self.inner_serializer.serialize(data_proto, field, item)  # appends to repeated field
+
+   def deserialize(self, data_proto, field):
+     if field.startswith('parts[].'):
+       inner_field = field[len('parts[].'):]
+       return [
+           self.inner_serializer.deserialize(part.data, inner_field) for part in data_proto.parts
+       ]
+     repeated = getattr(data_proto, field)
+     assert is_repeated(repeated), f"Field {field} is not repeated"
+     return self.inner_serializer.deserialize(data_proto, field)  # returns repeated field list
+
+
+ # TODO dict serializer, maybe json only?
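
A round-trip sketch for the serializers above, assuming the ndarray field that this diff's own code references on the clarifai-grpc Data proto; illustrative only, not a test shipped with the package.

import numpy as np
from clarifai_grpc.grpc.api import resources_pb2

from clarifai.runners.utils.serializers import NDArraySerializer

proto = resources_pb2.Data()
serializer = NDArraySerializer()
# serialize() packs the array's raw bytes, shape, and dtype into proto.ndarray;
# deserialize() rebuilds the array with np.frombuffer + reshape.
serializer.serialize(proto, 'ndarray', np.arange(6, dtype=np.float32).reshape(2, 3))
restored = serializer.deserialize(proto, 'ndarray')
assert restored.shape == (2, 3) and restored.dtype == np.float32
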
{clarifai-11.1.5.dist-info → clarifai-11.1.5rc6.dist-info}/METADATA
@@ -1,6 +1,6 @@
- Metadata-Version: 2.2
+ Metadata-Version: 2.1
  Name: clarifai
- Version: 11.1.5
+ Version: 11.1.5rc6
  Summary: Clarifai Python SDK
  Home-page: https://github.com/Clarifai/clarifai-python
  Author: Clarifai
@@ -20,31 +20,21 @@ Classifier: Operating System :: OS Independent
  Requires-Python: >=3.8
  Description-Content-Type: text/markdown
  License-File: LICENSE
- Requires-Dist: clarifai-grpc>=11.1.3
- Requires-Dist: clarifai-protocol>=0.0.16
- Requires-Dist: numpy>=1.22.0
- Requires-Dist: tqdm>=4.65.0
- Requires-Dist: rich>=13.4.2
- Requires-Dist: PyYAML>=6.0.1
- Requires-Dist: schema==0.7.5
- Requires-Dist: Pillow>=9.5.0
- Requires-Dist: tabulate>=0.9.0
- Requires-Dist: fsspec>=2024.6.1
- Requires-Dist: click>=8.1.7
- Requires-Dist: requests>=2.32.3
+ Requires-Dist: clarifai-grpc >=11.1.3
+ Requires-Dist: clarifai-protocol >=0.0.16
+ Requires-Dist: numpy >=1.22.0
+ Requires-Dist: tqdm >=4.65.0
+ Requires-Dist: rich >=13.4.2
+ Requires-Dist: PyYAML >=6.0.1
+ Requires-Dist: schema ==0.7.5
+ Requires-Dist: Pillow >=9.5.0
+ Requires-Dist: tabulate >=0.9.0
+ Requires-Dist: fsspec >=2024.6.1
+ Requires-Dist: click >=8.1.7
+ Requires-Dist: requests >=2.32.3
+ Requires-Dist: aiohttp >=3.8.1
  Provides-Extra: all
- Requires-Dist: pycocotools==2.0.6; extra == "all"
- Dynamic: author
- Dynamic: author-email
- Dynamic: classifier
- Dynamic: description
- Dynamic: description-content-type
- Dynamic: home-page
- Dynamic: license
- Dynamic: provides-extra
- Dynamic: requires-dist
- Dynamic: requires-python
- Dynamic: summary
+ Requires-Dist: pycocotools ==2.0.6 ; extra == 'all'

  <h1 align="center">
  <a href="https://www.clarifai.com/"><img alt="Clarifai" title="Clarifai" src="https://github.com/user-attachments/assets/623b883b-7fe5-4b95-bbfa-8691f5779af4"></a>