cocoindex-0.3.4-cp311-abi3-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cocoindex/__init__.py +114 -0
- cocoindex/_engine.abi3.so +0 -0
- cocoindex/auth_registry.py +44 -0
- cocoindex/cli.py +830 -0
- cocoindex/engine_object.py +214 -0
- cocoindex/engine_value.py +550 -0
- cocoindex/flow.py +1281 -0
- cocoindex/functions/__init__.py +40 -0
- cocoindex/functions/_engine_builtin_specs.py +66 -0
- cocoindex/functions/colpali.py +247 -0
- cocoindex/functions/sbert.py +77 -0
- cocoindex/index.py +50 -0
- cocoindex/lib.py +75 -0
- cocoindex/llm.py +47 -0
- cocoindex/op.py +1047 -0
- cocoindex/py.typed +0 -0
- cocoindex/query_handler.py +57 -0
- cocoindex/runtime.py +78 -0
- cocoindex/setting.py +171 -0
- cocoindex/setup.py +92 -0
- cocoindex/sources/__init__.py +5 -0
- cocoindex/sources/_engine_builtin_specs.py +120 -0
- cocoindex/subprocess_exec.py +277 -0
- cocoindex/targets/__init__.py +5 -0
- cocoindex/targets/_engine_builtin_specs.py +153 -0
- cocoindex/targets/lancedb.py +466 -0
- cocoindex/tests/__init__.py +0 -0
- cocoindex/tests/test_engine_object.py +331 -0
- cocoindex/tests/test_engine_value.py +1724 -0
- cocoindex/tests/test_optional_database.py +249 -0
- cocoindex/tests/test_transform_flow.py +300 -0
- cocoindex/tests/test_typing.py +553 -0
- cocoindex/tests/test_validation.py +134 -0
- cocoindex/typing.py +834 -0
- cocoindex/user_app_loader.py +53 -0
- cocoindex/utils.py +20 -0
- cocoindex/validation.py +104 -0
- cocoindex-0.3.4.dist-info/METADATA +288 -0
- cocoindex-0.3.4.dist-info/RECORD +42 -0
- cocoindex-0.3.4.dist-info/WHEEL +4 -0
- cocoindex-0.3.4.dist-info/entry_points.txt +2 -0
- cocoindex-0.3.4.dist-info/licenses/THIRD_PARTY_NOTICES.html +13249 -0
cocoindex/engine_value.py
@@ -0,0 +1,550 @@
+"""
+Utilities to encode/decode values in cocoindex (for data).
+"""
+
+from __future__ import annotations
+
+import inspect
+import warnings
+from typing import Any, Callable, TypeVar
+
+import numpy as np
+from .typing import (
+    AnalyzedAnyType,
+    AnalyzedBasicType,
+    AnalyzedDictType,
+    AnalyzedListType,
+    AnalyzedStructType,
+    AnalyzedTypeInfo,
+    AnalyzedUnionType,
+    AnalyzedUnknownType,
+    AnalyzedStructFieldInfo,
+    analyze_type_info,
+    is_pydantic_model,
+    is_numpy_number_type,
+    ValueType,
+    FieldSchema,
+    BasicValueType,
+    StructType,
+    TableType,
+)
+from .engine_object import get_auto_default_for_type
+
+
+T = TypeVar("T")
+
+
+class ChildFieldPath:
+    """Context manager to append a field to field_path on enter and pop it on exit."""
+
+    _field_path: list[str]
+    _field_name: str
+
+    def __init__(self, field_path: list[str], field_name: str):
+        self._field_path: list[str] = field_path
+        self._field_name = field_name
+
+    def __enter__(self) -> ChildFieldPath:
+        self._field_path.append(self._field_name)
+        return self
+
+    def __exit__(self, _exc_type: Any, _exc_val: Any, _exc_tb: Any) -> None:
+        self._field_path.pop()
+
+
+_CONVERTIBLE_KINDS = {
+    ("Float32", "Float64"),
+    ("LocalDateTime", "OffsetDateTime"),
+}
+
+
+def _is_type_kind_convertible_to(src_type_kind: str, dst_type_kind: str) -> bool:
+    return (
+        src_type_kind == dst_type_kind
+        or (src_type_kind, dst_type_kind) in _CONVERTIBLE_KINDS
+    )
+
+
+# Pre-computed type info for missing/Any type annotations
+ANY_TYPE_INFO = analyze_type_info(inspect.Parameter.empty)
+
+
+def make_engine_key_encoder(type_info: AnalyzedTypeInfo) -> Callable[[Any], Any]:
+    """
+    Create an encoder closure for a key type.
+    """
+    value_encoder = make_engine_value_encoder(type_info)
+    if isinstance(type_info.variant, AnalyzedBasicType):
+        return lambda value: [value_encoder(value)]
+    else:
+        return value_encoder
+
+
+def make_engine_value_encoder(type_info: AnalyzedTypeInfo) -> Callable[[Any], Any]:
+    """
+    Create an encoder closure for a specific type.
+    """
+    variant = type_info.variant
+
+    if isinstance(variant, AnalyzedUnknownType):
+        raise ValueError(f"Type annotation `{type_info.core_type}` is unsupported")
+
+    if isinstance(variant, AnalyzedListType):
+        elem_type_info = (
+            analyze_type_info(variant.elem_type) if variant.elem_type else ANY_TYPE_INFO
+        )
+        if isinstance(elem_type_info.variant, AnalyzedStructType):
+            elem_encoder = make_engine_value_encoder(elem_type_info)
+
+            def encode_struct_list(value: Any) -> Any:
+                return None if value is None else [elem_encoder(v) for v in value]
+
+            return encode_struct_list
+
+        # Otherwise it's a vector, falling into basic type in the engine.
+
+    if isinstance(variant, AnalyzedDictType):
+        key_type_info = analyze_type_info(variant.key_type)
+        key_encoder = make_engine_key_encoder(key_type_info)
+
+        value_type_info = analyze_type_info(variant.value_type)
+        if not isinstance(value_type_info.variant, AnalyzedStructType):
+            raise ValueError(
+                f"Value type for dict is required to be a struct (e.g. dataclass or NamedTuple), got {variant.value_type}. "
+                f"If you want a free-formed dict, use `cocoindex.Json` instead."
+            )
+        value_encoder = make_engine_value_encoder(value_type_info)
+
+        def encode_struct_dict(value: Any) -> Any:
+            if not value:
+                return []
+            return [key_encoder(k) + value_encoder(v) for k, v in value.items()]
+
+        return encode_struct_dict
+
+    if isinstance(variant, AnalyzedStructType):
+        field_encoders = [
+            (
+                field_info.name,
+                make_engine_value_encoder(analyze_type_info(field_info.type_hint)),
+            )
+            for field_info in variant.fields
+        ]
+
+        def encode_struct(value: Any) -> Any:
+            if value is None:
+                return None
+            return [encoder(getattr(value, name)) for name, encoder in field_encoders]
+
+        return encode_struct
+
+    def encode_basic_value(value: Any) -> Any:
+        if isinstance(value, np.number):
+            return value.item()
+        if isinstance(value, np.ndarray):
+            return value
+        if isinstance(value, (list, tuple)):
+            return [encode_basic_value(v) for v in value]
+        return value
+
+    return encode_basic_value
+
+
+def make_engine_key_decoder(
+    field_path: list[str],
+    key_fields_schema: list[FieldSchema],
+    dst_type_info: AnalyzedTypeInfo,
+) -> Callable[[Any], Any]:
+    """
+    Create an encoder closure for a key type.
+    """
+    if len(key_fields_schema) == 1 and isinstance(
+        dst_type_info.variant, (AnalyzedBasicType, AnalyzedAnyType)
+    ):
+        single_key_decoder = make_engine_value_decoder(
+            field_path,
+            key_fields_schema[0].value_type.type,
+            dst_type_info,
+            for_key=True,
+        )
+
+        def key_decoder(value: list[Any]) -> Any:
+            return single_key_decoder(value[0])
+
+        return key_decoder
+
+    return make_engine_struct_decoder(
+        field_path,
+        key_fields_schema,
+        dst_type_info,
+        for_key=True,
+    )
+
+
+def make_engine_value_decoder(
+    field_path: list[str],
+    src_type: ValueType,
+    dst_type_info: AnalyzedTypeInfo,
+    for_key: bool = False,
+) -> Callable[[Any], Any]:
+    """
+    Make a decoder from an engine value to a Python value.
+
+    Args:
+        field_path: The path to the field in the engine value. For error messages.
+        src_type: The type of the engine value, mapped from a `cocoindex::base::schema::ValueType`.
+        dst_annotation: The type annotation of the Python value.
+
+    Returns:
+        A decoder from an engine value to a Python value.
+    """
+
+    src_type_kind = src_type.kind
+
+    dst_type_variant = dst_type_info.variant
+
+    if isinstance(dst_type_variant, AnalyzedUnknownType):
+        raise ValueError(
+            f"Type mismatch for `{''.join(field_path)}`: "
+            f"declared `{dst_type_info.core_type}`, an unsupported type"
+        )
+
+    if isinstance(src_type, StructType): # type: ignore[redundant-cast]
+        return make_engine_struct_decoder(
+            field_path,
+            src_type.fields,
+            dst_type_info,
+            for_key=for_key,
+        )
+
+    if isinstance(src_type, TableType): # type: ignore[redundant-cast]
+        with ChildFieldPath(field_path, "[*]"):
+            engine_fields_schema = src_type.row.fields
+
+            if src_type.kind == "LTable":
+                if isinstance(dst_type_variant, AnalyzedAnyType):
+                    dst_elem_type = Any
+                elif isinstance(dst_type_variant, AnalyzedListType):
+                    dst_elem_type = dst_type_variant.elem_type
+                else:
+                    raise ValueError(
+                        f"Type mismatch for `{''.join(field_path)}`: "
+                        f"declared `{dst_type_info.core_type}`, a list type expected"
+                    )
+                row_decoder = make_engine_struct_decoder(
+                    field_path,
+                    engine_fields_schema,
+                    analyze_type_info(dst_elem_type),
+                )
+
+                def decode(value: Any) -> Any | None:
+                    if value is None:
+                        return None
+                    return [row_decoder(v) for v in value]
+
+            elif src_type.kind == "KTable":
+                if isinstance(dst_type_variant, AnalyzedAnyType):
+                    key_type, value_type = Any, Any
+                elif isinstance(dst_type_variant, AnalyzedDictType):
+                    key_type = dst_type_variant.key_type
+                    value_type = dst_type_variant.value_type
+                else:
+                    raise ValueError(
+                        f"Type mismatch for `{''.join(field_path)}`: "
+                        f"declared `{dst_type_info.core_type}`, a dict type expected"
+                    )
+
+                num_key_parts = src_type.num_key_parts or 1
+                key_decoder = make_engine_key_decoder(
+                    field_path,
+                    engine_fields_schema[0:num_key_parts],
+                    analyze_type_info(key_type),
+                )
+                value_decoder = make_engine_struct_decoder(
+                    field_path,
+                    engine_fields_schema[num_key_parts:],
+                    analyze_type_info(value_type),
+                )
+
+                def decode(value: Any) -> Any | None:
+                    if value is None:
+                        return None
+                    return {
+                        key_decoder(v[0:num_key_parts]): value_decoder(
+                            v[num_key_parts:]
+                        )
+                        for v in value
+                    }
+
+        return decode
+
+    if isinstance(src_type, BasicValueType) and src_type.kind == "Union":
+        if isinstance(dst_type_variant, AnalyzedAnyType):
+            return lambda value: value[1]
+
+        dst_type_info_variants = (
+            [analyze_type_info(t) for t in dst_type_variant.variant_types]
+            if isinstance(dst_type_variant, AnalyzedUnionType)
+            else [dst_type_info]
+        )
+        # mypy: union info exists for Union kind
+        assert src_type.union is not None # type: ignore[unreachable]
+        src_type_variants_basic: list[BasicValueType] = src_type.union.variants
+        src_type_variants = src_type_variants_basic
+        decoders = []
+        for i, src_type_variant in enumerate(src_type_variants):
+            with ChildFieldPath(field_path, f"[{i}]"):
+                decoder = None
+                for dst_type_info_variant in dst_type_info_variants:
+                    try:
+                        decoder = make_engine_value_decoder(
+                            field_path, src_type_variant, dst_type_info_variant
+                        )
+                        break
+                    except ValueError:
+                        pass
+                if decoder is None:
+                    raise ValueError(
+                        f"Type mismatch for `{''.join(field_path)}`: "
+                        f"cannot find matched target type for source type variant {src_type_variant}"
+                    )
+                decoders.append(decoder)
+        return lambda value: decoders[value[0]](value[1])
+
+    if isinstance(dst_type_variant, AnalyzedAnyType):
+        return lambda value: value
+
+    if isinstance(src_type, BasicValueType) and src_type.kind == "Vector":
+        field_path_str = "".join(field_path)
+        if not isinstance(dst_type_variant, AnalyzedListType):
+            raise ValueError(
+                f"Type mismatch for `{''.join(field_path)}`: "
+                f"declared `{dst_type_info.core_type}`, a list type expected"
+            )
+        expected_dim = (
+            dst_type_variant.vector_info.dim
+            if dst_type_variant and dst_type_variant.vector_info
+            else None
+        )
+
+        vec_elem_decoder = None
+        scalar_dtype = None
+        if dst_type_variant and dst_type_info.base_type is np.ndarray:
+            if is_numpy_number_type(dst_type_variant.elem_type):
+                scalar_dtype = dst_type_variant.elem_type
+        else:
+            # mypy: vector info exists for Vector kind
+            assert src_type.vector is not None # type: ignore[unreachable]
+            vec_elem_decoder = make_engine_value_decoder(
+                field_path + ["[*]"],
+                src_type.vector.element_type,
+                analyze_type_info(
+                    dst_type_variant.elem_type if dst_type_variant else Any
+                ),
+            )
+
+        def decode_vector(value: Any) -> Any | None:
+            if value is None:
+                if dst_type_info.nullable:
+                    return None
+                raise ValueError(
+                    f"Received null for non-nullable vector `{field_path_str}`"
+                )
+            if not isinstance(value, (np.ndarray, list)):
+                raise TypeError(
+                    f"Expected NDArray or list for vector `{field_path_str}`, got {type(value)}"
+                )
+            if expected_dim is not None and len(value) != expected_dim:
+                raise ValueError(
+                    f"Vector dimension mismatch for `{field_path_str}`: "
+                    f"expected {expected_dim}, got {len(value)}"
+                )
+
+            if vec_elem_decoder is not None: # for Non-NDArray vector
+                return [vec_elem_decoder(v) for v in value]
+            else: # for NDArray vector
+                return np.array(value, dtype=scalar_dtype)
+
+        return decode_vector
+
+    if isinstance(dst_type_variant, AnalyzedBasicType):
+        if not _is_type_kind_convertible_to(src_type_kind, dst_type_variant.kind):
+            raise ValueError(
+                f"Type mismatch for `{''.join(field_path)}`: "
+                f"passed in {src_type_kind}, declared {dst_type_info.core_type} ({dst_type_variant.kind})"
+            )
+
+        if dst_type_variant.kind in ("Float32", "Float64", "Int64"):
+            dst_core_type = dst_type_info.core_type

+            def decode_scalar(value: Any) -> Any | None:
+                if value is None:
+                    if dst_type_info.nullable:
+                        return None
+                    raise ValueError(
+                        f"Received null for non-nullable scalar `{''.join(field_path)}`"
+                    )
+                return dst_core_type(value)
+
+            return decode_scalar
+
+    return lambda value: value
+
+
+def make_engine_struct_decoder(
+    field_path: list[str],
+    src_fields: list[FieldSchema],
+    dst_type_info: AnalyzedTypeInfo,
+    for_key: bool = False,
+) -> Callable[[list[Any]], Any]:
+    """Make a decoder from an engine field values to a Python value."""
+
+    dst_type_variant = dst_type_info.variant
+
+    if isinstance(dst_type_variant, AnalyzedAnyType):
+        if for_key:
+            return _make_engine_struct_to_tuple_decoder(field_path, src_fields)
+        else:
+            return _make_engine_struct_to_dict_decoder(field_path, src_fields, Any)
+    elif isinstance(dst_type_variant, AnalyzedDictType):
+        analyzed_key_type = analyze_type_info(dst_type_variant.key_type)
+        if (
+            isinstance(analyzed_key_type.variant, AnalyzedAnyType)
+            or analyzed_key_type.core_type is str
+        ):
+            return _make_engine_struct_to_dict_decoder(
+                field_path, src_fields, dst_type_variant.value_type
+            )
+
+    if not isinstance(dst_type_variant, AnalyzedStructType):
+        raise ValueError(
+            f"Type mismatch for `{''.join(field_path)}`: "
+            f"declared `{dst_type_info.core_type}`, a dataclass, NamedTuple, Pydantic model or dict[str, Any] expected"
+        )
+
+    src_name_to_idx = {f.name: i for i, f in enumerate(src_fields)}
+    dst_struct_type = dst_type_variant.struct_type
+
+    def make_closure_for_field(
+        field_info: AnalyzedStructFieldInfo,
+    ) -> Callable[[list[Any]], Any]:
+        name = field_info.name
+        src_idx = src_name_to_idx.get(name)
+        type_info = analyze_type_info(field_info.type_hint)
+
+        with ChildFieldPath(field_path, f".{name}"):
+            if src_idx is not None:
+                field_decoder = make_engine_value_decoder(
+                    field_path,
+                    src_fields[src_idx].value_type.type,
+                    type_info,
+                    for_key=for_key,
+                )
+                return lambda values: field_decoder(values[src_idx])
+
+            default_value = field_info.default_value
+            if default_value is not inspect.Parameter.empty:
+                return lambda _: default_value
+
+            auto_default, is_supported = get_auto_default_for_type(type_info)
+            if is_supported:
+                warnings.warn(
+                    f"Field '{name}' (type {field_info.type_hint}) without default value is missing in input: "
+                    f"{''.join(field_path)}. Auto-assigning default value: {auto_default}",
+                    UserWarning,
+                    stacklevel=4,
+                )
+                return lambda _: auto_default
+
+            raise ValueError(
+                f"Field '{name}' (type {field_info.type_hint}) without default value is missing in input: {''.join(field_path)}"
+            )
+
+    # Different construction for different struct types
+    if is_pydantic_model(dst_struct_type):
+        # Pydantic models prefer keyword arguments
+        pydantic_fields_decoder = [
+            (field_info.name, make_closure_for_field(field_info))
+            for field_info in dst_type_variant.fields
+        ]
+        return lambda values: dst_struct_type(
+            **{
+                field_name: decoder(values)
+                for field_name, decoder in pydantic_fields_decoder
+            }
+        )
+    else:
+        struct_fields_decoder = [
+            make_closure_for_field(field_info) for field_info in dst_type_variant.fields
+        ]
+        # Dataclasses and NamedTuples can use positional arguments
+        return lambda values: dst_struct_type(
+            *(decoder(values) for decoder in struct_fields_decoder)
+        )
+
+
+def _make_engine_struct_to_dict_decoder(
+    field_path: list[str],
+    src_fields: list[FieldSchema],
+    value_type_annotation: Any,
+) -> Callable[[list[Any] | None], dict[str, Any] | None]:
+    """Make a decoder from engine field values to a Python dict."""
+
+    field_decoders = []
+    value_type_info = analyze_type_info(value_type_annotation)
+    for field_schema in src_fields:
+        field_name = field_schema.name
+        with ChildFieldPath(field_path, f".{field_name}"):
+            field_decoder = make_engine_value_decoder(
+                field_path,
+                field_schema.value_type.type,
+                value_type_info,
+            )
+        field_decoders.append((field_name, field_decoder))
+
+    def decode_to_dict(values: list[Any] | None) -> dict[str, Any] | None:
+        if values is None:
+            return None
+        if len(field_decoders) != len(values):
+            raise ValueError(
+                f"Field count mismatch: expected {len(field_decoders)}, got {len(values)}"
+            )
+        return {
+            field_name: field_decoder(value)
+            for value, (field_name, field_decoder) in zip(values, field_decoders)
+        }
+
+    return decode_to_dict
+
+
+def _make_engine_struct_to_tuple_decoder(
+    field_path: list[str],
+    src_fields: list[FieldSchema],
+) -> Callable[[list[Any] | None], tuple[Any, ...] | None]:
+    """Make a decoder from engine field values to a Python tuple."""
+
+    field_decoders = []
+    value_type_info = analyze_type_info(Any)
+    for field_schema in src_fields:
+        field_name = field_schema.name
+        with ChildFieldPath(field_path, f".{field_name}"):
+            field_decoders.append(
+                make_engine_value_decoder(
+                    field_path,
+                    field_schema.value_type.type,
+                    value_type_info,
+                )
+            )
+
+    def decode_to_tuple(values: list[Any] | None) -> tuple[Any, ...] | None:
+        if values is None:
+            return None
+        if len(field_decoders) != len(values):
+            raise ValueError(
+                f"Field count mismatch: expected {len(field_decoders)}, got {len(values)}"
+            )
+        return tuple(
+            field_decoder(value) for value, field_decoder in zip(values, field_decoders)
+        )
+
+    return decode_to_tuple