cocoindex 0.1.72__cp312-cp312-macosx_11_0_arm64.whl → 0.1.73__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cocoindex/_engine.cpython-312-darwin.so +0 -0
- cocoindex/convert.py +138 -95
- cocoindex/tests/test_convert.py +23 -2
- cocoindex/tests/test_typing.py +97 -207
- cocoindex/typing.py +173 -130
- {cocoindex-0.1.72.dist-info → cocoindex-0.1.73.dist-info}/METADATA +1 -1
- {cocoindex-0.1.72.dist-info → cocoindex-0.1.73.dist-info}/RECORD +10 -10
- {cocoindex-0.1.72.dist-info → cocoindex-0.1.73.dist-info}/WHEEL +0 -0
- {cocoindex-0.1.72.dist-info → cocoindex-0.1.73.dist-info}/entry_points.txt +0 -0
- {cocoindex-0.1.72.dist-info → cocoindex-0.1.73.dist-info}/licenses/LICENSE +0 -0
Binary file
|
cocoindex/convert.py
CHANGED
@@ -13,12 +13,19 @@ import numpy as np
|
|
13
13
|
from .typing import (
|
14
14
|
KEY_FIELD_NAME,
|
15
15
|
TABLE_TYPES,
|
16
|
-
DtypeRegistry,
|
17
16
|
analyze_type_info,
|
18
17
|
encode_enriched_type,
|
19
|
-
extract_ndarray_scalar_dtype,
|
20
18
|
is_namedtuple_type,
|
21
19
|
is_struct_type,
|
20
|
+
AnalyzedTypeInfo,
|
21
|
+
AnalyzedAnyType,
|
22
|
+
AnalyzedDictType,
|
23
|
+
AnalyzedListType,
|
24
|
+
AnalyzedBasicType,
|
25
|
+
AnalyzedUnionType,
|
26
|
+
AnalyzedUnknownType,
|
27
|
+
AnalyzedStructType,
|
28
|
+
is_numpy_number_type,
|
22
29
|
)
|
23
30
|
|
24
31
|
|
@@ -79,46 +86,88 @@ def make_engine_value_decoder(
|
|
79
86
|
Returns:
|
80
87
|
A decoder from an engine value to a Python value.
|
81
88
|
"""
|
89
|
+
|
82
90
|
src_type_kind = src_type["kind"]
|
83
91
|
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
92
|
+
dst_type_info = analyze_type_info(dst_annotation)
|
93
|
+
dst_type_variant = dst_type_info.variant
|
94
|
+
|
95
|
+
if isinstance(dst_type_variant, AnalyzedUnknownType):
|
96
|
+
raise ValueError(
|
97
|
+
f"Type mismatch for `{''.join(field_path)}`: "
|
98
|
+
f"declared `{dst_type_info.core_type}`, an unsupported type"
|
99
|
+
)
|
100
|
+
|
101
|
+
if src_type_kind == "Struct":
|
102
|
+
return _make_engine_struct_value_decoder(
|
103
|
+
field_path,
|
104
|
+
src_type["fields"],
|
105
|
+
dst_type_info,
|
106
|
+
)
|
107
|
+
|
108
|
+
if src_type_kind in TABLE_TYPES:
|
109
|
+
field_path.append("[*]")
|
110
|
+
engine_fields_schema = src_type["row"]["fields"]
|
111
|
+
|
112
|
+
if src_type_kind == "LTable":
|
113
|
+
if isinstance(dst_type_variant, AnalyzedAnyType):
|
96
114
|
return _make_engine_ltable_to_list_dict_decoder(
|
97
|
-
field_path,
|
115
|
+
field_path, engine_fields_schema
|
116
|
+
)
|
117
|
+
if not isinstance(dst_type_variant, AnalyzedListType):
|
118
|
+
raise ValueError(
|
119
|
+
f"Type mismatch for `{''.join(field_path)}`: "
|
120
|
+
f"declared `{dst_type_info.core_type}`, a list type expected"
|
98
121
|
)
|
99
|
-
|
122
|
+
row_decoder = _make_engine_struct_value_decoder(
|
123
|
+
field_path,
|
124
|
+
engine_fields_schema,
|
125
|
+
analyze_type_info(dst_type_variant.elem_type),
|
126
|
+
)
|
127
|
+
|
128
|
+
def decode(value: Any) -> Any | None:
|
129
|
+
if value is None:
|
130
|
+
return None
|
131
|
+
return [row_decoder(v) for v in value]
|
132
|
+
|
133
|
+
elif src_type_kind == "KTable":
|
134
|
+
if isinstance(dst_type_variant, AnalyzedAnyType):
|
100
135
|
return _make_engine_ktable_to_dict_dict_decoder(
|
101
|
-
field_path,
|
136
|
+
field_path, engine_fields_schema
|
137
|
+
)
|
138
|
+
if not isinstance(dst_type_variant, AnalyzedDictType):
|
139
|
+
raise ValueError(
|
140
|
+
f"Type mismatch for `{''.join(field_path)}`: "
|
141
|
+
f"declared `{dst_type_info.core_type}`, a dict type expected"
|
102
142
|
)
|
103
|
-
return lambda value: value
|
104
143
|
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
144
|
+
key_field_schema = engine_fields_schema[0]
|
145
|
+
field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}")
|
146
|
+
key_decoder = make_engine_value_decoder(
|
147
|
+
field_path, key_field_schema["type"], dst_type_variant.key_type
|
148
|
+
)
|
149
|
+
field_path.pop()
|
150
|
+
value_decoder = _make_engine_struct_value_decoder(
|
151
|
+
field_path,
|
152
|
+
engine_fields_schema[1:],
|
153
|
+
analyze_type_info(dst_type_variant.value_type),
|
154
|
+
)
|
115
155
|
|
116
|
-
|
156
|
+
def decode(value: Any) -> Any | None:
|
157
|
+
if value is None:
|
158
|
+
return None
|
159
|
+
return {key_decoder(v[0]): value_decoder(v[1:]) for v in value}
|
160
|
+
|
161
|
+
field_path.pop()
|
162
|
+
return decode
|
117
163
|
|
118
164
|
if src_type_kind == "Union":
|
165
|
+
if isinstance(dst_type_variant, AnalyzedAnyType):
|
166
|
+
return lambda value: value[1]
|
167
|
+
|
119
168
|
dst_type_variants = (
|
120
|
-
|
121
|
-
if
|
169
|
+
dst_type_variant.variant_types
|
170
|
+
if isinstance(dst_type_variant, AnalyzedUnionType)
|
122
171
|
else [dst_annotation]
|
123
172
|
)
|
124
173
|
src_type_variants = src_type["types"]
|
@@ -142,43 +191,33 @@ def make_engine_value_decoder(
|
|
142
191
|
decoders.append(decoder)
|
143
192
|
return lambda value: decoders[value[0]](value[1])
|
144
193
|
|
145
|
-
if
|
146
|
-
|
147
|
-
f"Type mismatch for `{''.join(field_path)}`: "
|
148
|
-
f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})"
|
149
|
-
)
|
150
|
-
|
151
|
-
if dst_type_info.kind in ("Float32", "Float64", "Int64"):
|
152
|
-
dst_core_type = dst_type_info.core_type
|
153
|
-
|
154
|
-
def decode_scalar(value: Any) -> Any | None:
|
155
|
-
if value is None:
|
156
|
-
if dst_type_info.nullable:
|
157
|
-
return None
|
158
|
-
raise ValueError(
|
159
|
-
f"Received null for non-nullable scalar `{''.join(field_path)}`"
|
160
|
-
)
|
161
|
-
return dst_core_type(value)
|
162
|
-
|
163
|
-
return decode_scalar
|
194
|
+
if isinstance(dst_type_variant, AnalyzedAnyType):
|
195
|
+
return lambda value: value
|
164
196
|
|
165
197
|
if src_type_kind == "Vector":
|
166
198
|
field_path_str = "".join(field_path)
|
199
|
+
if not isinstance(dst_type_variant, AnalyzedListType):
|
200
|
+
raise ValueError(
|
201
|
+
f"Type mismatch for `{''.join(field_path)}`: "
|
202
|
+
f"declared `{dst_type_info.core_type}`, a list type expected"
|
203
|
+
)
|
167
204
|
expected_dim = (
|
168
|
-
|
205
|
+
dst_type_variant.vector_info.dim
|
206
|
+
if dst_type_variant and dst_type_variant.vector_info
|
207
|
+
else None
|
169
208
|
)
|
170
209
|
|
171
|
-
|
210
|
+
vec_elem_decoder = None
|
172
211
|
scalar_dtype = None
|
173
|
-
if dst_type_info.
|
174
|
-
|
212
|
+
if dst_type_variant and dst_type_info.base_type is np.ndarray:
|
213
|
+
if is_numpy_number_type(dst_type_variant.elem_type):
|
214
|
+
scalar_dtype = dst_type_variant.elem_type
|
215
|
+
else:
|
216
|
+
vec_elem_decoder = make_engine_value_decoder(
|
175
217
|
field_path + ["[*]"],
|
176
218
|
src_type["element_type"],
|
177
|
-
|
219
|
+
dst_type_variant and dst_type_variant.elem_type,
|
178
220
|
)
|
179
|
-
else: # for NDArray vector
|
180
|
-
scalar_dtype = extract_ndarray_scalar_dtype(dst_type_info.np_number_type)
|
181
|
-
_ = DtypeRegistry.validate_dtype_and_get_kind(scalar_dtype)
|
182
221
|
|
183
222
|
def decode_vector(value: Any) -> Any | None:
|
184
223
|
if value is None:
|
@@ -197,54 +236,33 @@ def make_engine_value_decoder(
|
|
197
236
|
f"expected {expected_dim}, got {len(value)}"
|
198
237
|
)
|
199
238
|
|
200
|
-
if
|
201
|
-
return [
|
239
|
+
if vec_elem_decoder is not None: # for Non-NDArray vector
|
240
|
+
return [vec_elem_decoder(v) for v in value]
|
202
241
|
else: # for NDArray vector
|
203
242
|
return np.array(value, dtype=scalar_dtype)
|
204
243
|
|
205
244
|
return decode_vector
|
206
245
|
|
207
|
-
if
|
208
|
-
|
209
|
-
field_path, src_type["fields"], dst_type_info.struct_type
|
210
|
-
)
|
211
|
-
|
212
|
-
if src_type_kind in TABLE_TYPES:
|
213
|
-
field_path.append("[*]")
|
214
|
-
elem_type_info = analyze_type_info(dst_type_info.elem_type)
|
215
|
-
if elem_type_info.struct_type is None:
|
246
|
+
if isinstance(dst_type_variant, AnalyzedBasicType):
|
247
|
+
if not _is_type_kind_convertible_to(src_type_kind, dst_type_variant.kind):
|
216
248
|
raise ValueError(
|
217
249
|
f"Type mismatch for `{''.join(field_path)}`: "
|
218
|
-
f"declared
|
219
|
-
)
|
220
|
-
engine_fields_schema = src_type["row"]["fields"]
|
221
|
-
if elem_type_info.key_type is not None:
|
222
|
-
key_field_schema = engine_fields_schema[0]
|
223
|
-
field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}")
|
224
|
-
key_decoder = make_engine_value_decoder(
|
225
|
-
field_path, key_field_schema["type"], elem_type_info.key_type
|
226
|
-
)
|
227
|
-
field_path.pop()
|
228
|
-
value_decoder = _make_engine_struct_value_decoder(
|
229
|
-
field_path, engine_fields_schema[1:], elem_type_info.struct_type
|
250
|
+
f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_variant.kind})"
|
230
251
|
)
|
231
252
|
|
232
|
-
|
233
|
-
|
234
|
-
return None
|
235
|
-
return {key_decoder(v[0]): value_decoder(v[1:]) for v in value}
|
236
|
-
else:
|
237
|
-
elem_decoder = _make_engine_struct_value_decoder(
|
238
|
-
field_path, engine_fields_schema, elem_type_info.struct_type
|
239
|
-
)
|
253
|
+
if dst_type_variant.kind in ("Float32", "Float64", "Int64"):
|
254
|
+
dst_core_type = dst_type_info.core_type
|
240
255
|
|
241
|
-
def
|
256
|
+
def decode_scalar(value: Any) -> Any | None:
|
242
257
|
if value is None:
|
243
|
-
|
244
|
-
|
258
|
+
if dst_type_info.nullable:
|
259
|
+
return None
|
260
|
+
raise ValueError(
|
261
|
+
f"Received null for non-nullable scalar `{''.join(field_path)}`"
|
262
|
+
)
|
263
|
+
return dst_core_type(value)
|
245
264
|
|
246
|
-
|
247
|
-
return decode
|
265
|
+
return decode_scalar
|
248
266
|
|
249
267
|
return lambda value: value
|
250
268
|
|
@@ -252,11 +270,36 @@ def make_engine_value_decoder(
|
|
252
270
|
def _make_engine_struct_value_decoder(
|
253
271
|
field_path: list[str],
|
254
272
|
src_fields: list[dict[str, Any]],
|
255
|
-
|
273
|
+
dst_type_info: AnalyzedTypeInfo,
|
256
274
|
) -> Callable[[list[Any]], Any]:
|
257
275
|
"""Make a decoder from an engine field values to a Python value."""
|
258
276
|
|
277
|
+
dst_type_variant = dst_type_info.variant
|
278
|
+
|
279
|
+
use_dict = False
|
280
|
+
if isinstance(dst_type_variant, AnalyzedAnyType):
|
281
|
+
use_dict = True
|
282
|
+
elif isinstance(dst_type_variant, AnalyzedDictType):
|
283
|
+
analyzed_key_type = analyze_type_info(dst_type_variant.key_type)
|
284
|
+
analyzed_value_type = analyze_type_info(dst_type_variant.value_type)
|
285
|
+
use_dict = (
|
286
|
+
isinstance(analyzed_key_type.variant, AnalyzedAnyType)
|
287
|
+
or (
|
288
|
+
isinstance(analyzed_key_type.variant, AnalyzedBasicType)
|
289
|
+
and analyzed_key_type.variant.kind == "Str"
|
290
|
+
)
|
291
|
+
) and isinstance(analyzed_value_type.variant, AnalyzedAnyType)
|
292
|
+
if use_dict:
|
293
|
+
return _make_engine_struct_to_dict_decoder(field_path, src_fields)
|
294
|
+
|
295
|
+
if not isinstance(dst_type_variant, AnalyzedStructType):
|
296
|
+
raise ValueError(
|
297
|
+
f"Type mismatch for `{''.join(field_path)}`: "
|
298
|
+
f"declared `{dst_type_info.core_type}`, a dataclass, NamedTuple or dict[str, Any] expected"
|
299
|
+
)
|
300
|
+
|
259
301
|
src_name_to_idx = {f["name"]: i for i, f in enumerate(src_fields)}
|
302
|
+
dst_struct_type = dst_type_variant.struct_type
|
260
303
|
|
261
304
|
parameters: Mapping[str, inspect.Parameter]
|
262
305
|
if dataclasses.is_dataclass(dst_struct_type):
|
cocoindex/tests/test_convert.py
CHANGED
@@ -105,7 +105,9 @@ def validate_full_roundtrip_to(
|
|
105
105
|
for other_value, other_type in decoded_values:
|
106
106
|
decoder = make_engine_value_decoder([], encoded_output_type, other_type)
|
107
107
|
other_decoded_value = decoder(value_from_engine)
|
108
|
-
assert eq(other_decoded_value, other_value)
|
108
|
+
assert eq(other_decoded_value, other_value), (
|
109
|
+
f"Expected {other_value} but got {other_decoded_value} for {other_type}"
|
110
|
+
)
|
109
111
|
|
110
112
|
|
111
113
|
def validate_full_roundtrip(
|
@@ -1096,6 +1098,25 @@ def test_full_roundtrip_vector_numeric_types() -> None:
|
|
1096
1098
|
validate_full_roundtrip(value_u64, Vector[np.uint64, Literal[3]])
|
1097
1099
|
|
1098
1100
|
|
1101
|
+
def test_full_roundtrip_vector_of_vector() -> None:
|
1102
|
+
"""Test full roundtrip for vector of vector."""
|
1103
|
+
value_f32 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
|
1104
|
+
validate_full_roundtrip(
|
1105
|
+
value_f32,
|
1106
|
+
Vector[Vector[np.float32, Literal[3]], Literal[2]],
|
1107
|
+
([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], list[list[np.float32]]),
|
1108
|
+
([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], list[list[cocoindex.Float32]]),
|
1109
|
+
(
|
1110
|
+
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
|
1111
|
+
list[Vector[cocoindex.Float32, Literal[3]]],
|
1112
|
+
),
|
1113
|
+
(
|
1114
|
+
value_f32,
|
1115
|
+
np.typing.NDArray[np.typing.NDArray[np.float32]],
|
1116
|
+
),
|
1117
|
+
)
|
1118
|
+
|
1119
|
+
|
1099
1120
|
def test_full_roundtrip_vector_other_types() -> None:
|
1100
1121
|
"""Test full roundtrip for Vector with non-numeric basic types."""
|
1101
1122
|
uuid_list = [uuid.uuid4(), uuid.uuid4()]
|
@@ -1216,7 +1237,7 @@ def test_full_roundtrip_scalar_with_python_types() -> None:
|
|
1216
1237
|
numpy_float: np.float64
|
1217
1238
|
python_float: float
|
1218
1239
|
string: str
|
1219
|
-
annotated_int: Annotated[np.int64, TypeKind("
|
1240
|
+
annotated_int: Annotated[np.int64, TypeKind("Int64")]
|
1220
1241
|
annotated_float: Float32
|
1221
1242
|
|
1222
1243
|
instance = MixedStruct(
|