cocoindex 0.2.3__cp311-abi3-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cocoindex/convert.py ADDED
@@ -0,0 +1,621 @@
1
+ """
2
+ Utilities to convert between Python and engine values.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import dataclasses
8
+ import datetime
9
+ import inspect
10
+ import warnings
11
+ from enum import Enum
12
+ from typing import Any, Callable, Mapping, Sequence, Type, get_origin
13
+
14
+ import numpy as np
15
+
16
+ from .typing import (
17
+ TABLE_TYPES,
18
+ AnalyzedAnyType,
19
+ AnalyzedBasicType,
20
+ AnalyzedDictType,
21
+ AnalyzedListType,
22
+ AnalyzedStructType,
23
+ AnalyzedTypeInfo,
24
+ AnalyzedUnionType,
25
+ AnalyzedUnknownType,
26
+ analyze_type_info,
27
+ encode_enriched_type,
28
+ is_namedtuple_type,
29
+ is_numpy_number_type,
30
+ )
31
+
32
+
33
class ChildFieldPath:
    """Scoped push/pop of one segment on a shared ``field_path`` list.

    Entering the context appends ``field_name`` to the list handed to the
    constructor; leaving the context (normally or because of an exception)
    removes the last element again. The list is mutated in place, so every
    holder of the same list object observes the change.
    """

    _field_path: list[str]
    _field_name: str

    def __init__(self, field_path: list[str], field_name: str):
        # Keep a reference (not a copy) so the caller's list is the one mutated.
        self._field_name = field_name
        self._field_path = field_path

    def __enter__(self) -> ChildFieldPath:
        self._field_path.append(self._field_name)
        return self

    def __exit__(self, _exc_type: Any, _exc_val: Any, _exc_tb: Any) -> None:
        # Implicitly returns None, so an in-flight exception keeps propagating.
        self._field_path.pop()
49
+
50
+
51
+ _CONVERTIBLE_KINDS = {
52
+ ("Float32", "Float64"),
53
+ ("LocalDateTime", "OffsetDateTime"),
54
+ }
55
+
56
+
57
+ def _is_type_kind_convertible_to(src_type_kind: str, dst_type_kind: str) -> bool:
58
+ return (
59
+ src_type_kind == dst_type_kind
60
+ or (src_type_kind, dst_type_kind) in _CONVERTIBLE_KINDS
61
+ )
62
+
63
+
64
# Pre-computed type info for missing/Any type annotations.
# It is the result of analyzing `inspect.Parameter.empty`, computed once at
# import time. In this module it is the fallback used when a list type has no
# element type and when a NamedTuple field has no annotation.
ANY_TYPE_INFO = analyze_type_info(inspect.Parameter.empty)
66
+
67
+
68
def make_engine_value_encoder(type_info: AnalyzedTypeInfo) -> Callable[[Any], Any]:
    """
    Build a closure that encodes Python values of one analyzed type for the engine.

    The shape of the returned encoder depends on ``type_info.variant``:

    - list whose element type is a struct: each element is encoded with the
      struct encoder (``None`` stays ``None``);
    - dict: encoded as a list of rows, each row being the encoded key followed
      by the encoded struct value; a falsy dict (including ``None``) gives ``[]``;
    - dataclass / NamedTuple: encoded as a list of encoded field values in
      declaration order (``None`` stays ``None``);
    - anything else: numpy scalars are unwrapped with ``.item()``, lists and
      tuples are encoded element-wise into lists, other values pass through.

    Raises:
        ValueError: if the type is an unknown/unsupported annotation, or if a
            dict's value type is not a struct.
    """
    variant = type_info.variant

    if isinstance(variant, AnalyzedUnknownType):
        raise ValueError(f"Type annotation `{type_info.core_type}` is unsupported")

    if isinstance(variant, AnalyzedListType):
        if variant.elem_type:
            elem_type_info = analyze_type_info(variant.elem_type)
        else:
            elem_type_info = ANY_TYPE_INFO

        if isinstance(elem_type_info.variant, AnalyzedStructType):
            elem_encoder = make_engine_value_encoder(elem_type_info)

            def encode_struct_list(value: Any) -> Any:
                if value is None:
                    return None
                return [elem_encoder(item) for item in value]

            return encode_struct_list

        # Otherwise it's a vector, falling into basic type in the engine.

    if isinstance(variant, AnalyzedDictType):
        value_type_info = analyze_type_info(variant.value_type)
        if not isinstance(value_type_info.variant, AnalyzedStructType):
            raise ValueError(
                f"Value type for dict is required to be a struct (e.g. dataclass or NamedTuple), got {variant.value_type}. "
                f"If you want a free-formed dict, use `cocoindex.Json` instead."
            )
        value_encoder = make_engine_value_encoder(value_type_info)

        key_type_info = analyze_type_info(variant.key_type)
        key_encoder = make_engine_value_encoder(key_type_info)

        if isinstance(key_type_info.variant, AnalyzedBasicType):
            # A basic key occupies exactly one leading column of the row.
            def encode_row(k: Any, v: Any) -> Any:
                return [key_encoder(k)] + value_encoder(v)

        else:
            # Any other key encoder is expected to return a list of key parts
            # that is concatenated directly with the value columns.
            def encode_row(k: Any, v: Any) -> Any:
                return key_encoder(k) + value_encoder(v)

        def encode_struct_dict(value: Any) -> Any:
            if not value:
                return []
            rows = []
            for k, v in value.items():
                rows.append(encode_row(k, v))
            return rows

        return encode_struct_dict

    if isinstance(variant, AnalyzedStructType):
        struct_type = variant.struct_type

        if dataclasses.is_dataclass(struct_type):
            dc_fields = dataclasses.fields(struct_type)
            field_encoders = [
                make_engine_value_encoder(analyze_type_info(dc_field.type))
                for dc_field in dc_fields
            ]
            field_names = [dc_field.name for dc_field in dc_fields]

            def encode_dataclass(value: Any) -> Any:
                if value is None:
                    return None
                encoded = []
                for encoder, name in zip(field_encoders, field_names):
                    encoded.append(encoder(getattr(value, name)))
                return encoded

            return encode_dataclass

        elif is_namedtuple_type(struct_type):
            annotations = struct_type.__annotations__
            field_names = list(getattr(struct_type, "_fields", ()))
            field_encoders = []
            for name in field_names:
                # Fields without an annotation are encoded as "any" values.
                if name in annotations:
                    field_type_info = analyze_type_info(annotations[name])
                else:
                    field_type_info = ANY_TYPE_INFO
                field_encoders.append(make_engine_value_encoder(field_type_info))

            def encode_namedtuple(value: Any) -> Any:
                if value is None:
                    return None
                encoded = []
                for encoder, name in zip(field_encoders, field_names):
                    encoded.append(encoder(getattr(value, name)))
                return encoded

            return encode_namedtuple

    def encode_basic_value(value: Any) -> Any:
        if isinstance(value, (list, tuple)):
            return [encode_basic_value(item) for item in value]
        if isinstance(value, np.number):
            # Unwrap numpy scalars into plain Python numbers.
            return value.item()
        # ndarrays and every other value are handed over unchanged.
        return value

    return encode_basic_value
171
+
172
+
173
def make_engine_key_decoder(
    field_path: list[str],
    key_fields_schema: list[dict[str, Any]],
    dst_type_info: AnalyzedTypeInfo,
) -> Callable[[Any], Any]:
    """
    Create a decoder closure for a key type.

    When the key consists of exactly one field and the declared Python type is
    a basic type or Any, the returned decoder takes the list of key parts and
    decodes only its first element. In every other case the key parts are
    decoded as a struct via ``make_engine_struct_decoder`` with ``for_key=True``.
    """
    is_single_scalar_key = len(key_fields_schema) == 1 and isinstance(
        dst_type_info.variant, (AnalyzedBasicType, AnalyzedAnyType)
    )
    if not is_single_scalar_key:
        return make_engine_struct_decoder(
            field_path,
            key_fields_schema,
            dst_type_info,
            for_key=True,
        )

    single_key_decoder = make_engine_value_decoder(
        field_path,
        key_fields_schema[0]["type"],
        dst_type_info,
        for_key=True,
    )

    def key_decoder(value: list[Any]) -> Any:
        # The engine hands over a list of key parts even for a one-part key.
        return single_key_decoder(value[0])

    return key_decoder
202
+
203
+
204
def make_engine_value_decoder(
    field_path: list[str],
    src_type: dict[str, Any],
    dst_type_info: AnalyzedTypeInfo,
    for_key: bool = False,
) -> Callable[[Any], Any]:
    """
    Make a decoder from an engine value to a Python value.

    Args:
        field_path: The path to the field in the engine value. For error messages.
            This list is shared and temporarily mutated (via `ChildFieldPath`)
            while nested decoders are built, so any path text needed later at
            decode time is captured as a string while the decoder is built.
        src_type: The type of the engine value, mapped from a `cocoindex::base::schema::ValueType`.
        dst_type_info: The analyzed type info of the declared Python type.
        for_key: Forwarded to `make_engine_struct_decoder` when the source
            type is a struct.

    Returns:
        A decoder from an engine value to a Python value.

    Raises:
        ValueError: If the declared Python type is unsupported or does not
            match the engine type.
    """

    src_type_kind = src_type["kind"]

    dst_type_variant = dst_type_info.variant

    if isinstance(dst_type_variant, AnalyzedUnknownType):
        raise ValueError(
            f"Type mismatch for `{''.join(field_path)}`: "
            f"declared `{dst_type_info.core_type}`, an unsupported type"
        )

    if src_type_kind == "Struct":
        return make_engine_struct_decoder(
            field_path,
            src_type["fields"],
            dst_type_info,
            for_key=for_key,
        )

    if src_type_kind in TABLE_TYPES:
        # Rows of a table are addressed as `[*]` in error messages.
        with ChildFieldPath(field_path, "[*]"):
            engine_fields_schema = src_type["row"]["fields"]

            if src_type_kind == "LTable":
                if isinstance(dst_type_variant, AnalyzedAnyType):
                    dst_elem_type = Any
                elif isinstance(dst_type_variant, AnalyzedListType):
                    dst_elem_type = dst_type_variant.elem_type
                else:
                    raise ValueError(
                        f"Type mismatch for `{''.join(field_path)}`: "
                        f"declared `{dst_type_info.core_type}`, a list type expected"
                    )
                row_decoder = make_engine_struct_decoder(
                    field_path,
                    engine_fields_schema,
                    analyze_type_info(dst_elem_type),
                )

                def decode(value: Any) -> Any | None:
                    if value is None:
                        return None
                    return [row_decoder(v) for v in value]

            elif src_type_kind == "KTable":
                if isinstance(dst_type_variant, AnalyzedAnyType):
                    key_type, value_type = Any, Any
                elif isinstance(dst_type_variant, AnalyzedDictType):
                    key_type = dst_type_variant.key_type
                    value_type = dst_type_variant.value_type
                else:
                    raise ValueError(
                        f"Type mismatch for `{''.join(field_path)}`: "
                        f"declared `{dst_type_info.core_type}`, a dict type expected"
                    )

                # The first `num_key_parts` columns of each row form the key,
                # the remaining columns form the value struct.
                num_key_parts = src_type.get("num_key_parts", 1)
                key_decoder = make_engine_key_decoder(
                    field_path,
                    engine_fields_schema[0:num_key_parts],
                    analyze_type_info(key_type),
                )
                value_decoder = make_engine_struct_decoder(
                    field_path,
                    engine_fields_schema[num_key_parts:],
                    analyze_type_info(value_type),
                )

                def decode(value: Any) -> Any | None:
                    if value is None:
                        return None
                    return {
                        key_decoder(v[0:num_key_parts]): value_decoder(
                            v[num_key_parts:]
                        )
                        for v in value
                    }

        return decode

    if src_type_kind == "Union":
        # A union value arrives as a pair: (variant index, payload).
        if isinstance(dst_type_variant, AnalyzedAnyType):
            return lambda value: value[1]

        dst_type_info_variants = (
            [analyze_type_info(t) for t in dst_type_variant.variant_types]
            if isinstance(dst_type_variant, AnalyzedUnionType)
            else [dst_type_info]
        )
        src_type_variants = src_type["types"]
        decoders = []
        for i, src_type_variant in enumerate(src_type_variants):
            with ChildFieldPath(field_path, f"[{i}]"):
                decoder = None
                # Pick the first declared variant that accepts this source
                # variant; a ValueError from a candidate just means "no match".
                for dst_type_info_variant in dst_type_info_variants:
                    try:
                        decoder = make_engine_value_decoder(
                            field_path, src_type_variant, dst_type_info_variant
                        )
                        break
                    except ValueError:
                        pass
                if decoder is None:
                    raise ValueError(
                        f"Type mismatch for `{''.join(field_path)}`: "
                        f"cannot find matched target type for source type variant {src_type_variant}"
                    )
                decoders.append(decoder)
        return lambda value: decoders[value[0]](value[1])

    if isinstance(dst_type_variant, AnalyzedAnyType):
        return lambda value: value

    if src_type_kind == "Vector":
        # Snapshot the path now: `field_path` is mutated after this returns.
        field_path_str = "".join(field_path)
        if not isinstance(dst_type_variant, AnalyzedListType):
            raise ValueError(
                f"Type mismatch for `{field_path_str}`: "
                f"declared `{dst_type_info.core_type}`, a list type expected"
            )
        expected_dim = (
            dst_type_variant.vector_info.dim
            if dst_type_variant and dst_type_variant.vector_info
            else None
        )

        vec_elem_decoder = None
        scalar_dtype = None
        if dst_type_variant and dst_type_info.base_type is np.ndarray:
            if is_numpy_number_type(dst_type_variant.elem_type):
                scalar_dtype = dst_type_variant.elem_type
        else:
            vec_elem_decoder = make_engine_value_decoder(
                field_path + ["[*]"],
                src_type["element_type"],
                analyze_type_info(
                    dst_type_variant.elem_type if dst_type_variant else Any
                ),
            )

        def decode_vector(value: Any) -> Any | None:
            if value is None:
                if dst_type_info.nullable:
                    return None
                raise ValueError(
                    f"Received null for non-nullable vector `{field_path_str}`"
                )
            if not isinstance(value, (np.ndarray, list)):
                raise TypeError(
                    f"Expected NDArray or list for vector `{field_path_str}`, got {type(value)}"
                )
            if expected_dim is not None and len(value) != expected_dim:
                raise ValueError(
                    f"Vector dimension mismatch for `{field_path_str}`: "
                    f"expected {expected_dim}, got {len(value)}"
                )

            if vec_elem_decoder is not None:  # for Non-NDArray vector
                return [vec_elem_decoder(v) for v in value]
            else:  # for NDArray vector
                return np.array(value, dtype=scalar_dtype)

        return decode_vector

    if isinstance(dst_type_variant, AnalyzedBasicType):
        if not _is_type_kind_convertible_to(src_type_kind, dst_type_variant.kind):
            raise ValueError(
                f"Type mismatch for `{''.join(field_path)}`: "
                f"passed in {src_type_kind}, declared {dst_type_info.core_type} ({dst_type_variant.kind})"
            )

        if dst_type_variant.kind in ("Float32", "Float64", "Int64"):
            dst_core_type = dst_type_info.core_type
            # Snapshot the path while it still names this field. Joining
            # `field_path` lazily inside the closure would run after the
            # enclosing `ChildFieldPath` contexts have popped their segments,
            # so the error would report a truncated path.
            field_path_str = "".join(field_path)

            def decode_scalar(value: Any) -> Any | None:
                if value is None:
                    if dst_type_info.nullable:
                        return None
                    raise ValueError(
                        f"Received null for non-nullable scalar `{field_path_str}`"
                    )
                # Coerce to the declared Python type.
                return dst_core_type(value)

            return decode_scalar

    return lambda value: value
407
+
408
+
409
+ def _get_auto_default_for_type(
410
+ type_info: AnalyzedTypeInfo,
411
+ ) -> tuple[Any, bool]:
412
+ """
413
+ Get an auto-default value for a type annotation if it's safe to do so.
414
+
415
+ Returns:
416
+ A tuple of (default_value, is_supported) where:
417
+ - default_value: The default value if auto-defaulting is supported
418
+ - is_supported: True if auto-defaulting is supported for this type
419
+ """
420
+ # Case 1: Nullable types (Optional[T] or T | None)
421
+ if type_info.nullable:
422
+ return None, True
423
+
424
+ # Case 2: Table types (KTable or LTable) - check if it's a list or dict type
425
+ if isinstance(type_info.variant, AnalyzedListType):
426
+ return [], True
427
+ elif isinstance(type_info.variant, AnalyzedDictType):
428
+ return {}, True
429
+
430
+ return None, False
431
+
432
+
433
def make_engine_struct_decoder(
    field_path: list[str],
    src_fields: list[dict[str, Any]],
    dst_type_info: AnalyzedTypeInfo,
    for_key: bool = False,
) -> Callable[[list[Any]], Any]:
    """Make a decoder from an engine field values to a Python value.

    Args:
        field_path: Path of the struct being decoded, used in error and
            warning messages. It is temporarily extended with `.{name}` while
            each field decoder is built.
        src_fields: Engine-side field schemas; each entry is read for its
            `"name"` and `"type"`.
        dst_type_info: Analyzed type info of the declared Python type.
        for_key: When the declared type is Any, selects a tuple decoder
            instead of a dict decoder. It is also forwarded to the per-field
            value decoders.

    Returns:
        A callable that takes the list of engine field values and returns a
        dict, a tuple, or an instance of the declared dataclass / NamedTuple.

    Raises:
        ValueError: If the declared type is not a supported struct-like type,
            or a destination field is missing from the source and has neither
            an explicit default nor an auto-default.
    """

    dst_type_variant = dst_type_info.variant

    if isinstance(dst_type_variant, AnalyzedAnyType):
        if for_key:
            return _make_engine_struct_to_tuple_decoder(field_path, src_fields)
        else:
            return _make_engine_struct_to_dict_decoder(field_path, src_fields, Any)
    elif isinstance(dst_type_variant, AnalyzedDictType):
        analyzed_key_type = analyze_type_info(dst_type_variant.key_type)
        # Only dicts keyed by str (or Any) can stand in for a struct; other
        # dict types fall through to the type-mismatch error below.
        if (
            isinstance(analyzed_key_type.variant, AnalyzedAnyType)
            or analyzed_key_type.core_type is str
        ):
            return _make_engine_struct_to_dict_decoder(
                field_path, src_fields, dst_type_variant.value_type
            )

    if not isinstance(dst_type_variant, AnalyzedStructType):
        raise ValueError(
            f"Type mismatch for `{''.join(field_path)}`: "
            f"declared `{dst_type_info.core_type}`, a dataclass, NamedTuple or dict[str, Any] expected"
        )

    # Map each source field name to its position in the engine value list.
    src_name_to_idx = {f["name"]: i for i, f in enumerate(src_fields)}
    dst_struct_type = dst_type_variant.struct_type

    # Describe the destination fields uniformly as inspect.Parameter objects
    # (name, annotation, default), whichever struct flavor is used.
    parameters: Mapping[str, inspect.Parameter]
    if dataclasses.is_dataclass(dst_struct_type):
        parameters = inspect.signature(dst_struct_type).parameters
    elif is_namedtuple_type(dst_struct_type):
        defaults = getattr(dst_struct_type, "_field_defaults", {})
        fields = getattr(dst_struct_type, "_fields", ())
        parameters = {
            name: inspect.Parameter(
                name=name,
                kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                default=defaults.get(name, inspect.Parameter.empty),
                annotation=dst_struct_type.__annotations__.get(
                    name, inspect.Parameter.empty
                ),
            )
            for name in fields
        }
    else:
        raise ValueError(f"Unsupported struct type: {dst_struct_type}")

    def make_closure_for_field(
        name: str, param: inspect.Parameter
    ) -> Callable[[list[Any]], Any]:
        # Build the extractor for one destination field. Resolution order:
        # source value, then the declared default, then an auto-default
        # (with a warning), otherwise an error.
        src_idx = src_name_to_idx.get(name)
        type_info = analyze_type_info(param.annotation)

        with ChildFieldPath(field_path, f".{name}"):
            if src_idx is not None:
                field_decoder = make_engine_value_decoder(
                    field_path, src_fields[src_idx]["type"], type_info, for_key=for_key
                )
                return lambda values: field_decoder(values[src_idx])

            default_value = param.default
            if default_value is not inspect.Parameter.empty:
                return lambda _: default_value

            auto_default, is_supported = _get_auto_default_for_type(type_info)
            if is_supported:
                # NOTE(review): stacklevel=4 presumably targets the frame that
                # requested this decoder - confirm against the call chain.
                warnings.warn(
                    f"Field '{name}' (type {param.annotation}) without default value is missing in input: "
                    f"{''.join(field_path)}. Auto-assigning default value: {auto_default}",
                    UserWarning,
                    stacklevel=4,
                )
                return lambda _: auto_default

            raise ValueError(
                f"Field '{name}' (type {param.annotation}) without default value is missing in input: {''.join(field_path)}"
            )

    # One extractor per destination field, in the struct's parameter order.
    field_value_decoder = [
        make_closure_for_field(name, param) for (name, param) in parameters.items()
    ]

    # Construct the struct positionally from the extracted field values.
    return lambda values: dst_struct_type(
        *(decoder(values) for decoder in field_value_decoder)
    )
525
+
526
+
527
def _make_engine_struct_to_dict_decoder(
    field_path: list[str],
    src_fields: list[dict[str, Any]],
    value_type_annotation: Any,
) -> Callable[[list[Any] | None], dict[str, Any] | None]:
    """Make a decoder that turns engine field values into a ``{name: value}`` dict.

    Every source field is decoded against the same ``value_type_annotation``.
    The returned decoder passes ``None`` through and raises ``ValueError`` when
    the number of incoming values differs from the number of source fields.
    """

    value_type_info = analyze_type_info(value_type_annotation)

    named_decoders = []
    for schema in src_fields:
        name = schema["name"]
        # Extend the path while building so nested errors name this field.
        with ChildFieldPath(field_path, f".{name}"):
            decoder = make_engine_value_decoder(
                field_path,
                schema["type"],
                value_type_info,
            )
        named_decoders.append((name, decoder))

    def decode_to_dict(values: list[Any] | None) -> dict[str, Any] | None:
        if values is None:
            return None
        if len(named_decoders) != len(values):
            raise ValueError(
                f"Field count mismatch: expected {len(named_decoders)}, got {len(values)}"
            )
        decoded = {}
        for raw, (name, decoder) in zip(values, named_decoders):
            decoded[name] = decoder(raw)
        return decoded

    return decode_to_dict
559
+
560
+
561
def _make_engine_struct_to_tuple_decoder(
    field_path: list[str],
    src_fields: list[dict[str, Any]],
) -> Callable[[list[Any] | None], tuple[Any, ...] | None]:
    """Make a decoder that turns engine field values into a positional tuple.

    Every source field is decoded as ``Any``. The returned decoder passes
    ``None`` through and raises ``ValueError`` when the number of incoming
    values differs from the number of source fields.
    """

    any_type_info = analyze_type_info(Any)

    decoders = []
    for schema in src_fields:
        name = schema["name"]
        # The field name is only needed for the error-message path.
        with ChildFieldPath(field_path, f".{name}"):
            decoder = make_engine_value_decoder(
                field_path,
                schema["type"],
                any_type_info,
            )
            decoders.append(decoder)

    def decode_to_tuple(values: list[Any] | None) -> tuple[Any, ...] | None:
        if values is None:
            return None
        if len(decoders) != len(values):
            raise ValueError(
                f"Field count mismatch: expected {len(decoders)}, got {len(values)}"
            )
        decoded = []
        for raw, decoder in zip(values, decoders):
            decoded.append(decoder(raw))
        return tuple(decoded)

    return decode_to_tuple
592
+
593
+
594
def dump_engine_object(v: Any) -> Any:
    """Recursively dump an object for engine. Engine side uses `Pythonized` to catch.

    Conversion rules, checked in this order:

    - ``None`` -> ``None``
    - a class or a parameterized generic -> ``encode_enriched_type(v)``
    - ``Enum`` member -> its ``.value``
    - ``datetime.timedelta`` -> ``{"secs": int, "nanos": int}``; both parts
      carry the sign of the duration and ``secs`` is truncated toward zero
    - any object with ``__dict__`` -> dict of its non-None attributes (dumped
      recursively), plus ``"kind"`` if the object exposes a ``kind`` attribute
      that is not already among the dumped attributes
    - list / tuple -> list of dumped items
    - dict -> dict with the same keys and dumped values
    - anything else is returned unchanged
    """
    if v is None:
        return None
    elif isinstance(v, type) or get_origin(v) is not None:
        return encode_enriched_type(v)
    elif isinstance(v, Enum):
        return v.value
    elif isinstance(v, datetime.timedelta):
        # Use exact integer arithmetic. Going through total_seconds() (a float)
        # and truncating loses precision, e.g. 1s + 1us came out as 999 nanos.
        total_micros = v // datetime.timedelta(microseconds=1)
        sign = -1 if total_micros < 0 else 1
        secs, micros = divmod(abs(total_micros), 1_000_000)
        return {"secs": sign * secs, "nanos": sign * micros * 1000}
    elif hasattr(v, "__dict__"):
        s = {}
        for k, val in v.__dict__.items():
            if val is None:
                # Skip None values
                continue
            s[k] = dump_engine_object(val)
        # `kind` may be a class-level attribute and thus absent from __dict__.
        if hasattr(v, "kind") and "kind" not in s:
            s["kind"] = v.kind
        return s
    elif isinstance(v, (list, tuple)):
        return [dump_engine_object(item) for item in v]
    elif isinstance(v, dict):
        return {key: dump_engine_object(item) for key, item in v.items()}
    return v