clarifai 11.1.5rc6__py3-none-any.whl → 11.1.5rc8__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- clarifai/__init__.py +1 -1
- clarifai/cli/__pycache__/model.cpython-310.pyc +0 -0
- clarifai/client/#model_client.py# +430 -0
- clarifai/client/model.py +95 -61
- clarifai/client/model_client.py +64 -49
- clarifai/runners/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/runners/models/__pycache__/base_typed_model.cpython-310.pyc +0 -0
- clarifai/runners/models/__pycache__/model_builder.cpython-310.pyc +0 -0
- clarifai/runners/models/__pycache__/model_class.cpython-310.pyc +0 -0
- clarifai/runners/models/__pycache__/model_runner.cpython-310.pyc +0 -0
- clarifai/runners/models/model_class.py +31 -48
- clarifai/runners/utils/__pycache__/data_handler.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/data_types.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/method_signatures.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/serializers.cpython-310.pyc +0 -0
- clarifai/runners/utils/data_types.py +62 -10
- clarifai/runners/utils/method_signatures.py +278 -295
- clarifai/runners/utils/serializers.py +143 -67
- {clarifai-11.1.5rc6.dist-info → clarifai-11.1.5rc8.dist-info}/METADATA +1 -1
- {clarifai-11.1.5rc6.dist-info → clarifai-11.1.5rc8.dist-info}/RECORD +24 -23
- {clarifai-11.1.5rc6.dist-info → clarifai-11.1.5rc8.dist-info}/LICENSE +0 -0
- {clarifai-11.1.5rc6.dist-info → clarifai-11.1.5rc8.dist-info}/WHEEL +0 -0
- {clarifai-11.1.5rc6.dist-info → clarifai-11.1.5rc8.dist-info}/entry_points.txt +0 -0
- {clarifai-11.1.5rc6.dist-info → clarifai-11.1.5rc8.dist-info}/top_level.txt +0 -0
clarifai/client/model_client.py
CHANGED
@@ -1,4 +1,5 @@
 import inspect
+import json
 import time
 from typing import Any, Dict, Iterator, List

@@ -7,9 +8,10 @@ from clarifai_grpc.grpc.api.status import status_code_pb2

 from clarifai.constants.model import MAX_MODEL_PREDICT_INPUTS
 from clarifai.errors import UserError
-from clarifai.runners.utils.method_signatures import (
-
-
+from clarifai.runners.utils.method_signatures import (CompatibilitySerializer, deserialize,
+                                                       get_stream_from_signature, serialize,
+                                                       signatures_from_json)
+from clarifai.utils.logging import logger
 from clarifai.utils.misc import BackoffIterator, status_is_retryable


@@ -29,8 +31,25 @@ class ModelClient:
     '''
     self.STUB = stub
     self.request_template = request_template or service_pb2.PostModelOutputsRequest()
-    self.
-    self.
+    self._method_signatures = None
+    self._defined = False
+
+  def fetch(self):
+    '''
+    Fetch function signature definitions from the model and define the functions in the client
+    '''
+    if self._defined:
+      return
+    try:
+      self._fetch_signatures()
+      self._define_functions()
+    finally:
+      self._defined = True
+
+  def __getattr__(self, name):
+    if not self._defined:
+      self.fetch()
+    return self.__getattribute__(name)

   def _fetch_signatures(self):
     '''
@@ -58,19 +77,18 @@ class ModelClient:
       response = self.STUB.PostModelOutputs(request)
       if status_is_retryable(
           response.status.code) and time.time() - start_time < 60 * 10:  # 10 minutes
-
+        logger.info(f"Retrying model info fetch with response {response.status!r}")
         time.sleep(next(backoff_iterator))
         continue
       break
     if response.status.code == status_code_pb2.INPUT_UNSUPPORTED_FORMAT:
       # return code from older models that don't support _GET_SIGNATURES
       self._method_signatures = {}
+      self._define_compatability_functions()
       return
     if response.status.code != status_code_pb2.SUCCESS:
       raise Exception(f"Model failed with response {response!r}")
     self._method_signatures = signatures_from_json(response.outputs[0].data.text.raw)
-    import pdb
-    pdb.set_trace()

   def _define_functions(self):
     '''
@@ -116,39 +134,13 @@ class ModelClient:
       # set names, annotations and docstrings
       f.__name__ = method_name
       f.__qualname__ = f'{self.__class__.__name__}.{method_name}'
-
-
-
-      input_annos = unflatten_nested_keys(input_annos, method_signature.inputs, is_output=False)
-      output_annos = unflatten_nested_keys(output_annos, method_signature.outputs, is_output=True)
-
-      # add Stream[] to the stream input annotations for docs
-      input_stream_argname, _ = get_stream_from_signature(method_signature.inputs)
-      if input_stream_argname:
-        input_annos[input_stream_argname] = 'Stream[' + str(
-            input_annos[input_stream_argname]) + ']'
-
-      # handle multiple outputs in the return annotation
-      return_annotation = output_annos
-      name = next(iter(output_annos.keys()))
-      if len(output_annos) == 1 and name == 'return':
-        # single output
-        return_annotation = output_annos[name]
-      elif name.startswith('return.') and name.split('.', 1)[1].isnumeric():
-        # tuple output
-        return_annotation = '(' + ", ".join(output_annos[f'return.{i}']
-                                            for i in range(len(output_annos))) + ')'
-      else:
-        # named output
-        return_annotation = f'Output({", ".join(f"{k}={t}" for k, t in output_annos.items())})'
-      if method_signature.method_type in ['generate', 'stream']:
-        return_annotation = f'Stream[{return_annotation}]'
-
-      # set annotations and docstrings
+      f.__doc__ = method_signature.docstring
+      input_annotations = json.loads(method_signature.annotations_json)
+      return_annotation = input_annotations.pop('return', None)
       sig = inspect.signature(f).replace(
           parameters=[
               inspect.Parameter(k, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=v)
-              for k, v in
+              for k, v in input_annotations.items()
           ],
           return_annotation=return_annotation,
       )
@@ -156,6 +148,28 @@ class ModelClient:
       f.__doc__ = method_signature.docstring
       setattr(self, method_name, f)

+  def _define_compatability_functions(self):
+
+    serializer = CompatibilitySerializer()
+
+    def predict(input: Any) -> Any:
+      proto = resources_pb2.Input()
+      serializer.serialize(proto.data, input)
+      # always use text.raw for compat
+      if proto.data.string_value:
+        proto.data.text.raw = proto.data.string_value
+        proto.data.string_value = ''
+      response = self._predict_by_proto([proto])
+      if response.status.code != status_code_pb2.SUCCESS:
+        raise Exception(f"Model predict failed with response {response!r}")
+      response_data = response.outputs[0].data
+      if response_data.text.raw:
+        response_data.string_value = response_data.text.raw
+        response_data.text.raw = ''
+      return serializer.deserialize(response_data)
+
+    self.predict = predict
+
   def _predict(
       self,
       inputs,  # TODO set up functions according to fetched signatures?
@@ -214,10 +228,9 @@ class ModelClient:
     request.inputs.extend(inputs)

     if method_name:
-      for inp in inputs:
-        inp.data.metadata['_method_name'] = method_name
       # TODO put in new proto field?
-
+      for inp in request.inputs:
+        inp.data.metadata['_method_name'] = method_name
     if inference_params:
       request.model.model_version.output_info.params.update(inference_params)
     if output_config:
@@ -230,7 +243,7 @@ class ModelClient:
       response = self.STUB.PostModelOutputs(request)
       if status_is_retryable(
           response.status.code) and time.time() - start_time < 60 * 10:  # 10 minutes
-
+        logger.info(f"Model predict failed with response {response!r}")
         time.sleep(next(backoff_iterator))
         continue

@@ -298,9 +311,8 @@ class ModelClient:

     if method_name:
       # TODO put in new proto field?
-      for inp in inputs:
+      for inp in request.inputs:
         inp.data.metadata['_method_name'] = method_name
-    # request.model.model_version.output_info.params['_method_name'] = method_name
     if inference_params:
       request.model.model_version.output_info.params.update(inference_params)
     if output_config:
@@ -317,7 +329,7 @@ class ModelClient:
         raise Exception("Model Generate failed with no response")
       if status_is_retryable(response.status.code) and \
           time.time() - start_time < 60 * 10:
-
+        logger.info("Model is still deploying, please wait...")
         time.sleep(next(backoff_iterator))
         continue
       if response.status.code != status_code_pb2.SUCCESS:
@@ -346,7 +358,10 @@ class ModelClient:
       kwargs = inputs

     # find the streaming vars in the input signature, and the streaming input python param
-
+    stream_sig = get_stream_from_signature(input_signature)
+    if stream_sig is None:
+      raise ValueError("Streaming method must have a Stream input")
+    stream_argname = stream_sig.name

     # get the streaming input generator from the user-provided function arg values
     user_inputs_generator = kwargs.pop(stream_argname)
@@ -366,7 +381,7 @@ class ModelClient:
       # subsequent items are just the stream items
       for item in user_inputs_generator:
         proto = resources_pb2.Input()
-        serialize({stream_argname: item},
+        serialize({stream_argname: item}, [stream_sig], proto.data)
         yield proto

     response_stream = self._stream_by_proto(_input_proto_stream(), method_name)
@@ -383,7 +398,6 @@ class ModelClient:
                      output_config: Dict = {}):
     request = service_pb2.PostModelOutputsRequest()
     request.CopyFrom(self.request_template)
-    # request.model.model_version.output_info.params['_method_name'] = method_name
     if inference_params:
       request.model.model_version.output_info.params.update(inference_params)
     if output_config:
@@ -395,6 +409,7 @@ class ModelClient:
         req.inputs.extend(inputs)
       else:
         req.inputs.append(inputs)
+      # TODO: put into new proto field?
       for inp in req.inputs:
         inp.data.metadata['_method_name'] = method_name
       yield req
@@ -421,7 +436,7 @@ class ModelClient:
       for response in stream_response:
         if status_is_retryable(response.status.code) and \
             time.time() - start_time < 60 * 10:
-
+          logger.info("Model is still deploying, please wait...")
          time.sleep(next(backoff_iterator))
          break
        if response.status.code != status_code_pb2.SUCCESS:
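Net effect of the model_client.py changes: method signatures are no longer resolved in `__init__`. Instead, `fetch()` retrieves them on demand (falling back to a plain `predict` built on `CompatibilitySerializer` when an older model answers with INPUT_UNSUPPORTED_FORMAT), and `__getattr__` triggers that fetch on first attribute access. Below is a minimal, self-contained sketch of the lazy-definition pattern, with illustrative names only (not the clarifai API):

```python
class LazyClient:

  def __init__(self):
    self._defined = False

  def fetch(self):
    if self._defined:
      return
    try:
      # In ModelClient this step is _fetch_signatures() + _define_functions().
      setattr(self, 'predict', lambda x: f'predicted {x}')
    finally:
      # set in `finally` so a failed fetch does not re-trigger on every lookup
      self._defined = True

  def __getattr__(self, name):
    # only called when normal attribute lookup fails
    if not self._defined:
      self.fetch()
    return self.__getattribute__(name)


client = LazyClient()
print(client.predict('hello'))  # first access triggers fetch()
```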
Binary files (5): compiled .pyc caches from the file list above, no text diff shown.

clarifai/runners/models/model_class.py
CHANGED
@@ -19,32 +19,11 @@ _METHOD_INFO_ATTR = '_cf_method_info'
 _RAISE_EXCEPTIONS = os.getenv("RAISE_EXCEPTIONS", "false").lower() == "true"


-class methods:
-  '''
-  Decorators to mark methods as predict, generate, or stream methods.
-  '''
-
-  @staticmethod
-  def predict(method):
-    setattr(method, _METHOD_INFO_ATTR, _MethodInfo(method, 'predict'))
-    return method
-
-  @staticmethod
-  def generate(method):
-    setattr(method, _METHOD_INFO_ATTR, _MethodInfo(method, 'generate'))
-    return method
-
-  @staticmethod
-  def stream(method):
-    setattr(method, _METHOD_INFO_ATTR, _MethodInfo(method, 'stream'))
-    return method
-
-
 class ModelClass(ABC):
   '''
   Base class for model classes that can be run as a service.

-  Define
+  Define predict, generate, or stream methods using the @ModelClass.method decorator.

   Example:

@@ -53,21 +32,26 @@ class ModelClass(ABC):

     class MyModel(ModelClass):

-      @
+      @ModelClass.method
       def predict(self, x: str, y: int) -> List[str]:
         return [x] * y

-      @
+      @ModelClass.method
       def generate(self, x: str, y: int) -> Stream[str]:
         for i in range(y):
           yield x + str(i)

-      @
+      @ModelClass.method
       def stream(self, input_stream: Stream[Input(x=str, y=int)]) -> Stream[str]:
         for item in input_stream:
           yield item.x + ' ' + str(item.y)
   '''

+  @staticmethod
+  def method(func):
+    setattr(func, _METHOD_INFO_ATTR, _MethodInfo(func))
+    return func
+
   def load_model(self):
     """Load the model."""

@@ -80,7 +64,7 @@ class ModelClass(ABC):
     output.data.text.raw = signatures_to_json(signatures)
     return resp

-  def
+  def _batch_predict(self, method, inputs: List[Dict[str, Any]]) -> List[Any]:
     """Batch predict method for multiple inputs."""
     outputs = []
     for input in inputs:
@@ -88,7 +72,7 @@ class ModelClass(ABC):
       outputs.append(output)
     return outputs

-  def
+  def _batch_generate(self, method, inputs: List[Dict[str, Any]]) -> Iterator[List[Any]]:
     """Batch generate method for multiple inputs."""
     generators = [method(**input) for input in inputs]
     for outputs in itertools.zip_longest(*generators):
@@ -99,11 +83,9 @@ class ModelClass(ABC):
     outputs = []
     try:
       # TODO add method name field to proto
-      method_name =
-      if len(request.inputs) > 0:
-        method_name = request.inputs[0].data.metadata
-      # call_params = dict(request.model.model_version.output_info.params)
-      # method_name = call_params.get('_method_name', 'predict')
+      method_name = 'predict'
+      if len(request.inputs) > 0 and '_method_name' in request.inputs[0].data.metadata:
+        method_name = request.inputs[0].data.metadata['_method_name']
       if method_name == '_GET_SIGNATURES':  # special case to fetch signatures, TODO add endpoint for this
         return self._handle_get_signatures_request()
       if method_name not in self._get_method_info():
@@ -119,7 +101,7 @@ class ModelClass(ABC):
         output = method(**inputs)
         outputs.append(self._convert_output_to_proto(output, signature.outputs))
       else:
-        outputs = self.
+        outputs = self._batch_predict(method, inputs)
         outputs = [self._convert_output_to_proto(output, signature.outputs) for output in outputs]

       return service_pb2.MultiOutputResponse(
@@ -136,8 +118,9 @@ class ModelClass(ABC):
   def generate_wrapper(self, request: service_pb2.PostModelOutputsRequest
                       ) -> Iterator[service_pb2.MultiOutputResponse]:
     try:
-
-
+      method_name = 'generate'
+      if len(request.inputs) > 0 and '_method_name' in request.inputs[0].data.metadata:
+        method_name = request.inputs[0].data.metadata['_method_name']
       method = getattr(self, method_name)
       method_info = method._cf_method_info
       signature = method_info.signature
@@ -153,7 +136,7 @@ class ModelClass(ABC):
         resp.status.code = status_code_pb2.SUCCESS
         yield resp
       else:
-        for outputs in self.
+        for outputs in self._batch_generate(method, inputs):
          resp = service_pb2.MultiOutputResponse()
          for output in outputs:
            self._convert_output_to_proto(output, signature.outputs, proto=resp.outputs.add())
@@ -174,15 +157,19 @@ class ModelClass(ABC):
       request = next(request_iterator)  # get first request to determine method
       assert len(request.inputs) == 1, "Streaming requires exactly one input"

-
-
+      method_name = 'generate'
+      if len(request.inputs) > 0 and '_method_name' in request.inputs[0].data.metadata:
+        method_name = request.inputs[0].data.metadata['_method_name']
       method = getattr(self, method_name)
       method_info = method._cf_method_info
       signature = method_info.signature
       python_param_types = method_info.python_param_types

       # find the streaming vars in the signature
-
+      stream_sig = get_stream_from_signature(signature.inputs)
+      if stream_sig is None:
+        raise ValueError("Streaming method must have a Stream input")
+      stream_argname = stream_sig.name

       # convert all inputs for the first request, including the first stream value
       inputs = self._convert_input_protos_to_python(request.inputs, signature.inputs,
@@ -197,7 +184,7 @@ class ModelClass(ABC):
         yield first_item
         # subsequent streaming items contain only the streaming input
         for request in request_iterator:
-          item = self._convert_input_protos_to_python(request.inputs,
+          item = self._convert_input_protos_to_python(request.inputs, stream_sig,
                                                       python_param_types)
           item = item[0][stream_argname]
           yield item
@@ -236,11 +223,7 @@ class ModelClass(ABC):
                                proto=None) -> resources_pb2.Output:
     if proto is None:
       proto = resources_pb2.Output()
-
-      output = {f'return.{i}': item for i, item in enumerate(output)}
-    if not isinstance(output, dict):  # TODO Output type, not just dict
-      output = {'return': output}
-    serialize(output, variables_signature, proto.data, is_output=True)
+    serialize({'return': output}, [variables_signature], proto.data, is_output=True)
     proto.status.code = status_code_pb2.SUCCESS
     return proto

@@ -259,7 +242,7 @@ class ModelClass(ABC):
     #     if hasattr(cls, name):
     #       method = getattr(cls, name)
     #       if not hasattr(method, _METHOD_INFO_ATTR):  # not already put in registry
-    #         methods[name] = _MethodInfo(method
+    #         methods[name] = _MethodInfo(method)
     # set method table for this class in the registry
     return methods

@@ -275,9 +258,9 @@ class ModelClass(ABC):

 class _MethodInfo:

-  def __init__(self, method
+  def __init__(self, method):
     self.name = method.__name__
-    self.signature = build_function_signature(method
+    self.signature = build_function_signature(method)
     self.python_param_types = {
         p.name: p.annotation
         for p in inspect.signature(method).parameters.values()
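Summary of the model_class.py changes: the separate `methods.predict`, `methods.generate`, and `methods.stream` decorators collapse into a single `@ModelClass.method` decorator that only tags the function with `_cf_method_info`, and the predict/generate/stream wrappers now read the target method from the `_method_name` entry in the first input's metadata, defaulting to `predict` or `generate`. A rough, self-contained sketch of that tag-and-dispatch pattern (the registry and request shape here are illustrative stand-ins, not the actual clarifai runner code):

```python
_METHOD_INFO_ATTR = '_cf_method_info'  # same marker name as in the diff


class Service:

  @staticmethod
  def method(func):
    # tag the function; the real _MethodInfo also records the signature
    setattr(func, _METHOD_INFO_ATTR, {'name': func.__name__})
    return func

  def _registered_methods(self):
    return {
        name: member
        for name, member in type(self).__dict__.items()
        if hasattr(member, _METHOD_INFO_ATTR)
    }

  def handle(self, request: dict):
    # default to 'predict' when the request carries no method name,
    # mirroring the fallback added to predict_wrapper in this release
    method_name = request.get('_method_name', 'predict')
    method = self._registered_methods()[method_name]
    return method(self, **request.get('kwargs', {}))


class MyModel(Service):

  @Service.method
  def predict(self, x: str) -> str:
    return x.upper()


print(MyModel().handle({'kwargs': {'x': 'hi'}}))  # -> HI
```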
Binary files (4): compiled .pyc caches from the file list above, no text diff shown.

clarifai/runners/utils/data_types.py
CHANGED
@@ -1,4 +1,5 @@
 import io
+import json
 from typing import Iterable, List, get_args, get_origin

 import numpy as np
@@ -21,13 +22,19 @@ class MessageData:
   def from_proto(cls, proto):
     raise NotImplementedError

+  @classmethod
+  def from_value(cls, value):
+    if isinstance(value, cls):
+      return value
+    return cls(value)
+
   def cast(self, python_type):
     if python_type == self.__class__:
       return self
     raise TypeError(f'Incompatible type for {self.__class__.__name__}: {python_type}')


-class
+class NamedFields(dict):
   __getattr__ = dict.__getitem__
   __setattr__ = dict.__setitem__

@@ -38,19 +45,36 @@ class Output(dict):
     return list(self.keys())


-class
-
-  __setattr__ = dict.__setitem__
+class Stream(Iterable):
+  pass

-  def __origin__(self):
-    return self

-
-    return list(self.keys())
+class JSON:

+  def __init__(self, value):
+    self.value = value

-
-
+  def __eq__(self, other):
+    return self.value == other
+
+  def __bool__(self):
+    return bool(self.value)
+
+  def to_json(self):
+    return json.dumps(self.value)
+
+  @classmethod
+  def from_json(cls, json_str):
+    return cls(json.loads(json_str))
+
+  @classmethod
+  def from_value(cls, value):
+    return cls(value)
+
+  def cast(self, python_type):
+    if not isinstance(self.value, python_type):
+      raise TypeError(f'Incompatible type {type(self.value)} for {python_type}')
+    return self.value


 class Text(MessageData):
@@ -59,6 +83,16 @@ class Text(MessageData):
     self.text = text
     self.url = url

+  def __eq__(self, other):
+    if isinstance(other, Text):
+      return self.text == other.text and self.url == other.url
+    if isinstance(other, str):
+      return self.text == other
+    return False
+
+  def __bool__(self):
+    return bool(self.text) or bool(self.url)
+
   def to_proto(self) -> TextProto:
     return TextProto(raw=self.text or '', url=self.url or '')

@@ -66,6 +100,16 @@
   def from_proto(cls, proto: TextProto) -> "Text":
     return cls(proto.raw, proto.url or None)

+  @classmethod
+  def from_value(cls, value):
+    if isinstance(value, str):
+      return cls(value)
+    if isinstance(value, Text):
+      return value
+    if isinstance(value, dict):
+      return cls(value.get('text'), value.get('url'))
+    raise TypeError(f'Incompatible type for Text: {type(value)}')
+
   def cast(self, python_type):
     if python_type == str:
       return self.text
@@ -189,6 +233,14 @@ class Image(MessageData):
   def from_proto(cls, proto: ImageProto) -> "Image":
     return cls(proto)

+  @classmethod
+  def from_value(cls, value):
+    if isinstance(value, PILImage.Image):
+      return cls.from_pil(value)
+    if isinstance(value, Image):
+      return value
+    raise TypeError(f'Incompatible type for Image: {type(value)}')
+
   def cast(self, python_type):
     if python_type == Image:
       return self
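The data_types.py changes add `from_value` coercion hooks and `__eq__`/`__bool__` comparisons to the wrapper types, and introduce `NamedFields`, `Stream`, and `JSON`. A short usage sketch, assuming clarifai 11.1.5rc8 is installed; the expected behavior below is read off the diff above rather than from separate documentation:

```python
from clarifai.runners.utils.data_types import JSON, NamedFields, Text

t = Text.from_value('hello')            # plain str is coerced to a Text
assert t == 'hello'                     # __eq__ now compares against str
assert not Text('')                     # __bool__ is False when text and url are empty

nf = NamedFields(x=1, y='two')          # dict subclass with attribute access
assert nf.x == 1 and nf['y'] == 'two'

j = JSON({'a': 1})
assert j == {'a': 1}                    # __eq__ compares the wrapped value
assert j.to_json() == '{"a": 1}'
```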