cocoindex 0.2.3__cp311-abi3-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cocoindex/__init__.py +92 -0
- cocoindex/_engine.pyd +0 -0
- cocoindex/auth_registry.py +51 -0
- cocoindex/cli.py +697 -0
- cocoindex/convert.py +621 -0
- cocoindex/flow.py +1205 -0
- cocoindex/functions.py +357 -0
- cocoindex/index.py +29 -0
- cocoindex/lib.py +32 -0
- cocoindex/llm.py +46 -0
- cocoindex/op.py +628 -0
- cocoindex/py.typed +0 -0
- cocoindex/runtime.py +37 -0
- cocoindex/setting.py +181 -0
- cocoindex/setup.py +92 -0
- cocoindex/sources.py +102 -0
- cocoindex/subprocess_exec.py +279 -0
- cocoindex/targets.py +135 -0
- cocoindex/tests/__init__.py +0 -0
- cocoindex/tests/conftest.py +38 -0
- cocoindex/tests/test_convert.py +1543 -0
- cocoindex/tests/test_optional_database.py +249 -0
- cocoindex/tests/test_transform_flow.py +207 -0
- cocoindex/tests/test_typing.py +429 -0
- cocoindex/tests/test_validation.py +134 -0
- cocoindex/typing.py +473 -0
- cocoindex/user_app_loader.py +51 -0
- cocoindex/utils.py +20 -0
- cocoindex/validation.py +104 -0
- cocoindex-0.2.3.dist-info/METADATA +262 -0
- cocoindex-0.2.3.dist-info/RECORD +34 -0
- cocoindex-0.2.3.dist-info/WHEEL +4 -0
- cocoindex-0.2.3.dist-info/entry_points.txt +2 -0
- cocoindex-0.2.3.dist-info/licenses/LICENSE +201 -0
cocoindex/convert.py
ADDED
@@ -0,0 +1,621 @@
|
|
1
|
+
"""
|
2
|
+
Utilities to convert between Python and engine values.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from __future__ import annotations
|
6
|
+
|
7
|
+
import dataclasses
|
8
|
+
import datetime
|
9
|
+
import inspect
|
10
|
+
import warnings
|
11
|
+
from enum import Enum
|
12
|
+
from typing import Any, Callable, Mapping, Sequence, Type, get_origin
|
13
|
+
|
14
|
+
import numpy as np
|
15
|
+
|
16
|
+
from .typing import (
|
17
|
+
TABLE_TYPES,
|
18
|
+
AnalyzedAnyType,
|
19
|
+
AnalyzedBasicType,
|
20
|
+
AnalyzedDictType,
|
21
|
+
AnalyzedListType,
|
22
|
+
AnalyzedStructType,
|
23
|
+
AnalyzedTypeInfo,
|
24
|
+
AnalyzedUnionType,
|
25
|
+
AnalyzedUnknownType,
|
26
|
+
analyze_type_info,
|
27
|
+
encode_enriched_type,
|
28
|
+
is_namedtuple_type,
|
29
|
+
is_numpy_number_type,
|
30
|
+
)
|
31
|
+
|
32
|
+
|
33
|
+
class ChildFieldPath:
    """Context manager that temporarily extends a field path by one segment.

    Entering appends `field_name` to the shared `field_path` list; exiting
    (normally or because of an exception) removes that last segment again, so
    the list is back to its previous contents afterwards.
    """

    _field_path: list[str]
    _field_name: str

    def __init__(self, field_path: list[str], field_name: str):
        # Hold a reference, not a copy: the caller's list is what gets mutated.
        self._field_path, self._field_name = field_path, field_name

    def __enter__(self) -> ChildFieldPath:
        self._field_path.append(self._field_name)
        return self

    def __exit__(self, _exc_type: Any, _exc_val: Any, _exc_tb: Any) -> None:
        # Returning None (falsy) lets any in-flight exception propagate.
        del self._field_path[-1]
|
49
|
+
|
50
|
+
|
51
|
+
_CONVERTIBLE_KINDS = {
|
52
|
+
("Float32", "Float64"),
|
53
|
+
("LocalDateTime", "OffsetDateTime"),
|
54
|
+
}
|
55
|
+
|
56
|
+
|
57
|
+
def _is_type_kind_convertible_to(src_type_kind: str, dst_type_kind: str) -> bool:
|
58
|
+
return (
|
59
|
+
src_type_kind == dst_type_kind
|
60
|
+
or (src_type_kind, dst_type_kind) in _CONVERTIBLE_KINDS
|
61
|
+
)
|
62
|
+
|
63
|
+
|
64
|
+
# Analyzed type info for an absent annotation (`inspect.Parameter.empty`),
# computed once at import time and reused wherever no annotation is available.
ANY_TYPE_INFO: AnalyzedTypeInfo = analyze_type_info(inspect.Parameter.empty)
|
66
|
+
|
67
|
+
|
68
|
+
def make_engine_value_encoder(type_info: AnalyzedTypeInfo) -> Callable[[Any], Any]:
    """
    Create an encoder closure that turns Python values of a specific type into
    engine values.

    - Lists whose element type is a struct are encoded element-wise; `None`
      stays `None`.
    - Dicts must have struct values (KTable); every item becomes one row made of
      the encoded key part(s) followed by the encoded value fields. An empty or
      `None` dict encodes to `[]`.
    - Dataclasses and NamedTuples become the list of their encoded field values
      in declaration order; `None` stays `None`.
    - Everything else is a basic value: numpy scalars become Python scalars,
      numpy arrays pass through, lists/tuples become lists encoded element-wise.

    Args:
        type_info: The analyzed Python type of the values to encode.

    Returns:
        A function mapping a Python value to its engine representation.

    Raises:
        ValueError: If the type (or a nested type) is unsupported, or a dict's
            value type is not a struct.
    """
    variant = type_info.variant

    if isinstance(variant, AnalyzedUnknownType):
        raise ValueError(f"Type annotation `{type_info.core_type}` is unsupported")

    if isinstance(variant, AnalyzedListType):
        elem_type_info = (
            analyze_type_info(variant.elem_type) if variant.elem_type else ANY_TYPE_INFO
        )
        if isinstance(elem_type_info.variant, AnalyzedStructType):
            elem_encoder = make_engine_value_encoder(elem_type_info)

            def encode_struct_list(value: Any) -> Any:
                return None if value is None else [elem_encoder(v) for v in value]

            return encode_struct_list

        # Otherwise it's a vector, falling into basic type in the engine.

    if isinstance(variant, AnalyzedDictType):
        value_type_info = analyze_type_info(variant.value_type)
        if not isinstance(value_type_info.variant, AnalyzedStructType):
            raise ValueError(
                f"Value type for dict is required to be a struct (e.g. dataclass or NamedTuple), got {variant.value_type}. "
                f"If you want a free-formed dict, use `cocoindex.Json` instead."
            )
        value_encoder = make_engine_value_encoder(value_type_info)

        key_type_info = analyze_type_info(variant.key_type)
        key_encoder = make_engine_value_encoder(key_type_info)
        if isinstance(key_type_info.variant, AnalyzedBasicType):

            def encode_row(k: Any, v: Any) -> Any:
                return [key_encoder(k)] + value_encoder(v)

        else:

            def encode_row(k: Any, v: Any) -> Any:
                encoded_key = key_encoder(k)
                # Struct keys encode to a list of key parts, but a key of an
                # undeclared (`Any`) type may encode to a bare scalar; wrap it so
                # it can be prepended to the value fields instead of failing on
                # `scalar + list`.
                if not isinstance(encoded_key, list):
                    encoded_key = [encoded_key]
                return encoded_key + value_encoder(v)

        def encode_struct_dict(value: Any) -> Any:
            if not value:
                return []
            return [encode_row(k, v) for k, v in value.items()]

        return encode_struct_dict

    if isinstance(variant, AnalyzedStructType):
        struct_type = variant.struct_type
        field_names: list[str] | None = None
        field_encoders: list[Callable[[Any], Any]] = []

        if dataclasses.is_dataclass(struct_type):
            fields = dataclasses.fields(struct_type)
            field_encoders = [
                make_engine_value_encoder(analyze_type_info(f.type)) for f in fields
            ]
            field_names = [f.name for f in fields]
        elif is_namedtuple_type(struct_type):
            annotations = struct_type.__annotations__
            field_names = list(getattr(struct_type, "_fields", ()))
            field_encoders = [
                make_engine_value_encoder(
                    analyze_type_info(annotations[name])
                    if name in annotations
                    else ANY_TYPE_INFO
                )
                for name in field_names
            ]

        if field_names is not None:
            # Dataclasses and NamedTuples share the same encoding: the list of
            # encoded attribute values, in field order.
            struct_field_names: list[str] = field_names

            def encode_struct(value: Any) -> Any:
                if value is None:
                    return None
                return [
                    encoder(getattr(value, name))
                    for encoder, name in zip(field_encoders, struct_field_names)
                ]

            return encode_struct

    def encode_basic_value(value: Any) -> Any:
        if isinstance(value, np.number):
            return value.item()
        if isinstance(value, np.ndarray):
            return value
        if isinstance(value, (list, tuple)):
            return [encode_basic_value(v) for v in value]
        return value

    return encode_basic_value
|
171
|
+
|
172
|
+
|
173
|
+
def make_engine_key_decoder(
    field_path: list[str],
    key_fields_schema: list[dict[str, Any]],
    dst_type_info: AnalyzedTypeInfo,
) -> Callable[[Any], Any]:
    """
    Create a decoder closure for a (possibly multi-part) table key.

    A key with exactly one part whose declared type is basic or `Any` is
    unwrapped from its one-element list of key parts and decoded directly.
    Every other key is decoded as a struct over the key fields, with
    `for_key=True`.
    """
    is_single_scalar_key = len(key_fields_schema) == 1 and isinstance(
        dst_type_info.variant, (AnalyzedBasicType, AnalyzedAnyType)
    )
    if not is_single_scalar_key:
        return make_engine_struct_decoder(
            field_path, key_fields_schema, dst_type_info, for_key=True
        )

    decode_part = make_engine_value_decoder(
        field_path, key_fields_schema[0]["type"], dst_type_info, for_key=True
    )
    return lambda key_parts: decode_part(key_parts[0])
|
202
|
+
|
203
|
+
|
204
|
+
def make_engine_value_decoder(
    field_path: list[str],
    src_type: dict[str, Any],
    dst_type_info: AnalyzedTypeInfo,
    for_key: bool = False,
) -> Callable[[Any], Any]:
    """
    Make a decoder from an engine value to a Python value.

    Args:
        field_path: The path to the field in the engine value. For error messages.
            Segments are temporarily pushed onto this list while nested decoders
            are built; it has its original contents again when this returns.
        src_type: The type of the engine value, mapped from a
            `cocoindex::base::schema::ValueType`.
        dst_type_info: The analyzed type annotation of the Python value.
        for_key: Whether the value is (part of) a table key; forwarded to struct
            decoders.

    Returns:
        A decoder from an engine value to a Python value.

    Raises:
        ValueError: If the declared Python type is unsupported or does not match
            `src_type`.
    """

    src_type_kind = src_type["kind"]

    dst_type_variant = dst_type_info.variant

    # `field_path` is a shared, mutable list that is extended and shrunk while
    # nested decoders are built. Closures reporting errors at decode time must use
    # this snapshot: re-joining `field_path` later would no longer name this field.
    field_path_str = "".join(field_path)

    if isinstance(dst_type_variant, AnalyzedUnknownType):
        raise ValueError(
            f"Type mismatch for `{field_path_str}`: "
            f"declared `{dst_type_info.core_type}`, an unsupported type"
        )

    if src_type_kind == "Struct":
        return make_engine_struct_decoder(
            field_path,
            src_type["fields"],
            dst_type_info,
            for_key=for_key,
        )

    if src_type_kind in TABLE_TYPES:
        with ChildFieldPath(field_path, "[*]"):
            engine_fields_schema = src_type["row"]["fields"]

            if src_type_kind == "LTable":
                if isinstance(dst_type_variant, AnalyzedAnyType):
                    dst_elem_type = Any
                elif isinstance(dst_type_variant, AnalyzedListType):
                    dst_elem_type = dst_type_variant.elem_type
                    if dst_elem_type is None:
                        # Un-parameterized list: decode rows as `Any`, mirroring
                        # the encoder's fallback for a missing element type.
                        dst_elem_type = Any
                else:
                    raise ValueError(
                        f"Type mismatch for `{''.join(field_path)}`: "
                        f"declared `{dst_type_info.core_type}`, a list type expected"
                    )
                row_decoder = make_engine_struct_decoder(
                    field_path,
                    engine_fields_schema,
                    analyze_type_info(dst_elem_type),
                )

                def decode(value: Any) -> Any | None:
                    if value is None:
                        return None
                    return [row_decoder(v) for v in value]

            elif src_type_kind == "KTable":
                if isinstance(dst_type_variant, AnalyzedAnyType):
                    key_type, value_type = Any, Any
                elif isinstance(dst_type_variant, AnalyzedDictType):
                    key_type = dst_type_variant.key_type
                    value_type = dst_type_variant.value_type
                else:
                    raise ValueError(
                        f"Type mismatch for `{''.join(field_path)}`: "
                        f"declared `{dst_type_info.core_type}`, a dict type expected"
                    )

                # Each row is the key part(s) followed by the value fields.
                num_key_parts = src_type.get("num_key_parts", 1)
                key_decoder = make_engine_key_decoder(
                    field_path,
                    engine_fields_schema[0:num_key_parts],
                    analyze_type_info(key_type),
                )
                value_decoder = make_engine_struct_decoder(
                    field_path,
                    engine_fields_schema[num_key_parts:],
                    analyze_type_info(value_type),
                )

                def decode(value: Any) -> Any | None:
                    if value is None:
                        return None
                    return {
                        key_decoder(v[0:num_key_parts]): value_decoder(
                            v[num_key_parts:]
                        )
                        for v in value
                    }

        return decode

    if src_type_kind == "Union":
        # Engine union values are `[variant_index, payload]`.
        if isinstance(dst_type_variant, AnalyzedAnyType):
            return lambda value: None if value is None else value[1]

        dst_type_info_variants = (
            [analyze_type_info(t) for t in dst_type_variant.variant_types]
            if isinstance(dst_type_variant, AnalyzedUnionType)
            else [dst_type_info]
        )
        src_type_variants = src_type["types"]
        decoders = []
        for i, src_type_variant in enumerate(src_type_variants):
            with ChildFieldPath(field_path, f"[{i}]"):
                decoder = None
                # Use the first declared variant that accepts this engine variant.
                for dst_type_info_variant in dst_type_info_variants:
                    try:
                        decoder = make_engine_value_decoder(
                            field_path, src_type_variant, dst_type_info_variant
                        )
                        break
                    except ValueError:
                        pass
                if decoder is None:
                    raise ValueError(
                        f"Type mismatch for `{''.join(field_path)}`: "
                        f"cannot find matched target type for source type variant {src_type_variant}"
                    )
                decoders.append(decoder)

        def decode_union(value: Any) -> Any | None:
            # A null union value carries no variant index; keep it as None.
            if value is None:
                return None
            return decoders[value[0]](value[1])

        return decode_union

    if isinstance(dst_type_variant, AnalyzedAnyType):
        return lambda value: value

    if src_type_kind == "Vector":
        if not isinstance(dst_type_variant, AnalyzedListType):
            raise ValueError(
                f"Type mismatch for `{field_path_str}`: "
                f"declared `{dst_type_info.core_type}`, a list type expected"
            )
        expected_dim = (
            dst_type_variant.vector_info.dim if dst_type_variant.vector_info else None
        )
        dst_elem_type = dst_type_variant.elem_type

        vec_elem_decoder = None
        scalar_dtype = None
        if dst_type_info.base_type is np.ndarray:
            # NDArray destination: build the array in one go, with the declared
            # numpy scalar dtype when there is one.
            if is_numpy_number_type(dst_elem_type):
                scalar_dtype = dst_elem_type
        else:
            vec_elem_decoder = make_engine_value_decoder(
                field_path + ["[*]"],
                src_type["element_type"],
                analyze_type_info(dst_elem_type if dst_elem_type is not None else Any),
            )

        def decode_vector(value: Any) -> Any | None:
            if value is None:
                if dst_type_info.nullable:
                    return None
                raise ValueError(
                    f"Received null for non-nullable vector `{field_path_str}`"
                )
            if not isinstance(value, (np.ndarray, list)):
                raise TypeError(
                    f"Expected NDArray or list for vector `{field_path_str}`, got {type(value)}"
                )
            if expected_dim is not None and len(value) != expected_dim:
                raise ValueError(
                    f"Vector dimension mismatch for `{field_path_str}`: "
                    f"expected {expected_dim}, got {len(value)}"
                )

            if vec_elem_decoder is not None:  # for Non-NDArray vector
                return [vec_elem_decoder(v) for v in value]
            else:  # for NDArray vector
                return np.array(value, dtype=scalar_dtype)

        return decode_vector

    if isinstance(dst_type_variant, AnalyzedBasicType):
        if not _is_type_kind_convertible_to(src_type_kind, dst_type_variant.kind):
            raise ValueError(
                f"Type mismatch for `{field_path_str}`: "
                f"passed in {src_type_kind}, declared {dst_type_info.core_type} ({dst_type_variant.kind})"
            )

        if dst_type_variant.kind in ("Float32", "Float64", "Int64"):
            # Convert into the declared core type (e.g. a numpy scalar type).
            dst_core_type = dst_type_info.core_type

            def decode_scalar(value: Any) -> Any | None:
                if value is None:
                    if dst_type_info.nullable:
                        return None
                    raise ValueError(
                        f"Received null for non-nullable scalar `{field_path_str}`"
                    )
                return dst_core_type(value)

            return decode_scalar

    return lambda value: value
|
407
|
+
|
408
|
+
|
409
|
+
def _get_auto_default_for_type(
|
410
|
+
type_info: AnalyzedTypeInfo,
|
411
|
+
) -> tuple[Any, bool]:
|
412
|
+
"""
|
413
|
+
Get an auto-default value for a type annotation if it's safe to do so.
|
414
|
+
|
415
|
+
Returns:
|
416
|
+
A tuple of (default_value, is_supported) where:
|
417
|
+
- default_value: The default value if auto-defaulting is supported
|
418
|
+
- is_supported: True if auto-defaulting is supported for this type
|
419
|
+
"""
|
420
|
+
# Case 1: Nullable types (Optional[T] or T | None)
|
421
|
+
if type_info.nullable:
|
422
|
+
return None, True
|
423
|
+
|
424
|
+
# Case 2: Table types (KTable or LTable) - check if it's a list or dict type
|
425
|
+
if isinstance(type_info.variant, AnalyzedListType):
|
426
|
+
return [], True
|
427
|
+
elif isinstance(type_info.variant, AnalyzedDictType):
|
428
|
+
return {}, True
|
429
|
+
|
430
|
+
return None, False
|
431
|
+
|
432
|
+
|
433
|
+
def make_engine_struct_decoder(
    field_path: list[str],
    src_fields: list[dict[str, Any]],
    dst_type_info: AnalyzedTypeInfo,
    for_key: bool = False,
) -> Callable[[list[Any]], Any]:
    """Make a decoder from an engine field values to a Python value.

    The declared destination type selects the result:
    - `Any`: a tuple when decoding a key (`for_key=True`), otherwise a dict keyed
      by field name.
    - `dict` with `str` or `Any` keys: a dict keyed by field name.
    - A dataclass or NamedTuple: an instance whose fields are matched to engine
      fields by name. A field missing from the engine input uses its declared
      default, else an automatic default from `_get_auto_default_for_type` (with a
      `UserWarning`), else a `ValueError` is raised while building the decoder.

    Args:
        field_path: Path of the struct, used in error messages.
        src_fields: Engine field schemas (each with `"name"` and `"type"`).
        dst_type_info: The analyzed Python type to decode into.
        for_key: Whether the struct is (part of) a table key.

    Returns:
        A decoder taking the list of engine field values. For dataclass and
        NamedTuple destinations a `None` input decodes to `None`.

    Raises:
        ValueError: If the destination type is unsupported, or a field without a
            usable default is missing from the engine input.
    """

    dst_type_variant = dst_type_info.variant

    if isinstance(dst_type_variant, AnalyzedAnyType):
        if for_key:
            return _make_engine_struct_to_tuple_decoder(field_path, src_fields)
        else:
            return _make_engine_struct_to_dict_decoder(field_path, src_fields, Any)
    elif isinstance(dst_type_variant, AnalyzedDictType):
        analyzed_key_type = analyze_type_info(dst_type_variant.key_type)
        if (
            isinstance(analyzed_key_type.variant, AnalyzedAnyType)
            or analyzed_key_type.core_type is str
        ):
            return _make_engine_struct_to_dict_decoder(
                field_path, src_fields, dst_type_variant.value_type
            )

    if not isinstance(dst_type_variant, AnalyzedStructType):
        raise ValueError(
            f"Type mismatch for `{''.join(field_path)}`: "
            f"declared `{dst_type_info.core_type}`, a dataclass, NamedTuple or dict[str, Any] expected"
        )

    src_name_to_idx = {f["name"]: i for i, f in enumerate(src_fields)}
    dst_struct_type = dst_type_variant.struct_type

    parameters: Mapping[str, inspect.Parameter]
    if dataclasses.is_dataclass(dst_struct_type):
        parameters = inspect.signature(dst_struct_type).parameters
    elif is_namedtuple_type(dst_struct_type):
        # NamedTuples: synthesize positional-or-keyword parameters from the
        # declared fields, defaults and annotations.
        defaults = getattr(dst_struct_type, "_field_defaults", {})
        fields = getattr(dst_struct_type, "_fields", ())
        parameters = {
            name: inspect.Parameter(
                name=name,
                kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                default=defaults.get(name, inspect.Parameter.empty),
                annotation=dst_struct_type.__annotations__.get(
                    name, inspect.Parameter.empty
                ),
            )
            for name in fields
        }
    else:
        raise ValueError(f"Unsupported struct type: {dst_struct_type}")

    def make_closure_for_field(
        name: str, param: inspect.Parameter
    ) -> Callable[[list[Any]], Any]:
        """Build a function extracting and decoding one constructor argument."""
        src_idx = src_name_to_idx.get(name)
        type_info = analyze_type_info(param.annotation)

        with ChildFieldPath(field_path, f".{name}"):
            if src_idx is not None:
                field_decoder = make_engine_value_decoder(
                    field_path, src_fields[src_idx]["type"], type_info, for_key=for_key
                )
                return lambda values: field_decoder(values[src_idx])

            default_value = param.default
            if default_value is not inspect.Parameter.empty:
                return lambda _: default_value

            auto_default, is_supported = _get_auto_default_for_type(type_info)
            if is_supported:
                warnings.warn(
                    f"Field '{name}' (type {param.annotation}) without default value is missing in input: "
                    f"{''.join(field_path)}. Auto-assigning default value: {auto_default}",
                    UserWarning,
                    stacklevel=4,
                )
                return lambda _: auto_default

            raise ValueError(
                f"Field '{name}' (type {param.annotation}) without default value is missing in input: {''.join(field_path)}"
            )

    # Kept as a comprehension so the warning's `stacklevel` points at the same
    # frame as before.
    field_value_decoders = [
        make_closure_for_field(name, param) for (name, param) in parameters.items()
    ]

    # Keyword-only constructor parameters (e.g. dataclass fields declared with
    # `kw_only=True`) cannot be passed positionally, so they are passed by name.
    # `inspect.signature` lists them after all positional parameters, so the
    # field decoding order is unchanged.
    positional_decoders: list[Callable[[list[Any]], Any]] = []
    keyword_decoders: list[tuple[str, Callable[[list[Any]], Any]]] = []
    for (name, param), decoder in zip(parameters.items(), field_value_decoders):
        if param.kind is inspect.Parameter.KEYWORD_ONLY:
            keyword_decoders.append((name, decoder))
        else:
            positional_decoders.append(decoder)

    def decode_struct(values: list[Any] | None) -> Any:
        # A null struct stays null, like the dict/tuple struct decoders.
        if values is None:
            return None
        return dst_struct_type(
            *(decoder(values) for decoder in positional_decoders),
            **{name: decoder(values) for name, decoder in keyword_decoders},
        )

    return decode_struct
|
525
|
+
|
526
|
+
|
527
|
+
def _make_engine_struct_to_dict_decoder(
    field_path: list[str],
    src_fields: list[dict[str, Any]],
    value_type_annotation: Any,
) -> Callable[[list[Any] | None], dict[str, Any] | None]:
    """Make a decoder that maps engine struct field values to a `{name: value}` dict.

    Every field is decoded against `value_type_annotation`. `None` decodes to
    `None`; a value list whose length differs from `src_fields` raises
    `ValueError`.
    """
    value_type_info = analyze_type_info(value_type_annotation)

    named_decoders: list[tuple[str, Callable[[Any], Any]]] = []
    for schema in src_fields:
        name = schema["name"]
        with ChildFieldPath(field_path, f".{name}"):
            decode_field = make_engine_value_decoder(
                field_path, schema["type"], value_type_info
            )
        named_decoders.append((name, decode_field))

    expected_count = len(named_decoders)

    def decode_to_dict(values: list[Any] | None) -> dict[str, Any] | None:
        if values is None:
            return None
        if len(values) != expected_count:
            raise ValueError(
                f"Field count mismatch: expected {expected_count}, got {len(values)}"
            )
        result: dict[str, Any] = {}
        for (name, decode_field), raw in zip(named_decoders, values):
            result[name] = decode_field(raw)
        return result

    return decode_to_dict
|
559
|
+
|
560
|
+
|
561
|
+
def _make_engine_struct_to_tuple_decoder(
    field_path: list[str],
    src_fields: list[dict[str, Any]],
) -> Callable[[list[Any] | None], tuple[Any, ...] | None]:
    """Make a decoder that maps engine struct field values to a plain tuple.

    Every field is decoded as if annotated `Any`, and the tuple follows the order
    of `src_fields`. `None` decodes to `None`; a value list whose length differs
    from `src_fields` raises `ValueError`.
    """
    any_type_info = analyze_type_info(Any)

    def build_field_decoder(schema: dict[str, Any]) -> Callable[[Any], Any]:
        with ChildFieldPath(field_path, f".{schema['name']}"):
            return make_engine_value_decoder(field_path, schema["type"], any_type_info)

    field_decoders = [build_field_decoder(schema) for schema in src_fields]
    expected_count = len(field_decoders)

    def decode_to_tuple(values: list[Any] | None) -> tuple[Any, ...] | None:
        if values is None:
            return None
        if len(values) != expected_count:
            raise ValueError(
                f"Field count mismatch: expected {expected_count}, got {len(values)}"
            )
        return tuple(decode(raw) for decode, raw in zip(field_decoders, values))

    return decode_to_tuple
|
592
|
+
|
593
|
+
|
594
|
+
def dump_engine_object(v: Any) -> Any:
    """Recursively dump an object for engine. Engine side uses `Pythonized` to catch.

    Conversions, in order of precedence:
    - `None` stays `None`.
    - Types and generic aliases (anything `typing.get_origin` recognizes) are
      encoded with `encode_enriched_type`.
    - `Enum` members become their `.value`.
    - `datetime.timedelta` becomes `{"secs": int, "nanos": int}`.
    - Objects with a `__dict__` become a dict of their non-None attributes, each
      dumped recursively, plus a `"kind"` entry taken from `v.kind` when the
      object has such an attribute and it is not already an instance attribute.
    - Lists and tuples become lists; dicts keep their keys and dump their values.
    - Anything else is returned unchanged.
    """
    if v is None:
        return None
    elif isinstance(v, type) or get_origin(v) is not None:
        return encode_enriched_type(v)
    elif isinstance(v, Enum):
        return v.value
    elif isinstance(v, datetime.timedelta):
        # Use the exact integer components rather than the float returned by
        # `total_seconds()`, whose rounding corrupts `nanos` for large durations.
        # `timedelta` normalizes 0 <= seconds < 86400 and 0 <= microseconds < 10**6,
        # so `nanos` is always non-negative (negative durations floor `secs`).
        secs = v.days * 86400 + v.seconds
        nanos = v.microseconds * 1000
        return {"secs": secs, "nanos": nanos}
    elif hasattr(v, "__dict__"):
        s = {}
        for k, val in v.__dict__.items():
            if val is None:
                # Skip None values
                continue
            s[k] = dump_engine_object(val)
        if hasattr(v, "kind") and "kind" not in s:
            s["kind"] = v.kind
        return s
    elif isinstance(v, (list, tuple)):
        return [dump_engine_object(item) for item in v]
    elif isinstance(v, dict):
        return {key: dump_engine_object(val) for key, val in v.items()}
    return v
|