cocoindex 0.2.10__cp311-abi3-macosx_10_12_x86_64.whl → 0.2.12__cp311-abi3-macosx_10_12_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cocoindex/__init__.py +5 -0
- cocoindex/_engine.abi3.so +0 -0
- cocoindex/auth_registry.py +8 -15
- cocoindex/convert.py +185 -27
- cocoindex/flow.py +83 -20
- cocoindex/op.py +168 -52
- cocoindex/query_handler.py +51 -0
- cocoindex/runtime.py +8 -1
- cocoindex/targets/__init__.py +5 -0
- cocoindex/{targets.py → targets/_engine_builtin_specs.py} +4 -4
- cocoindex/targets/lancedb.py +460 -0
- cocoindex/tests/test_convert.py +51 -26
- cocoindex/tests/test_load_convert.py +118 -0
- cocoindex/tests/test_typing.py +126 -2
- cocoindex/typing.py +207 -0
- {cocoindex-0.2.10.dist-info → cocoindex-0.2.12.dist-info}/METADATA +4 -1
- cocoindex-0.2.12.dist-info/RECORD +37 -0
- {cocoindex-0.2.10.dist-info → cocoindex-0.2.12.dist-info}/licenses/THIRD_PARTY_NOTICES.html +1 -1
- cocoindex-0.2.10.dist-info/RECORD +0 -33
- {cocoindex-0.2.10.dist-info → cocoindex-0.2.12.dist-info}/WHEEL +0 -0
- {cocoindex-0.2.10.dist-info → cocoindex-0.2.12.dist-info}/entry_points.txt +0 -0
cocoindex/__init__.py
CHANGED
@@ -24,6 +24,7 @@ from .llm import LlmSpec, LlmApiType
|
|
24
24
|
from .index import VectorSimilarityMetric, VectorIndexDef, IndexOptions
|
25
25
|
from .setting import DatabaseConnectionSpec, Settings, ServerSettings
|
26
26
|
from .setting import get_app_namespace
|
27
|
+
from .query_handler import QueryHandlerResultFields, QueryInfo, QueryOutput
|
27
28
|
from .typing import (
|
28
29
|
Int64,
|
29
30
|
Float32,
|
@@ -95,4 +96,8 @@ __all__ = [
|
|
95
96
|
"Range",
|
96
97
|
"Vector",
|
97
98
|
"Json",
|
99
|
+
# Query handler
|
100
|
+
"QueryHandlerResultFields",
|
101
|
+
"QueryInfo",
|
102
|
+
"QueryOutput",
|
98
103
|
]
|
cocoindex/_engine.abi3.so
CHANGED
Binary file
|
cocoindex/auth_registry.py
CHANGED
@@ -4,25 +4,12 @@ Auth registry is used to register and reference auth entries.
|
|
4
4
|
|
5
5
|
from dataclasses import dataclass
|
6
6
|
from typing import Generic, TypeVar
|
7
|
-
import threading
|
8
7
|
|
9
8
|
from . import _engine # type: ignore
|
10
|
-
from .convert import dump_engine_object
|
9
|
+
from .convert import dump_engine_object, load_engine_object
|
11
10
|
|
12
11
|
T = TypeVar("T")
|
13
12
|
|
14
|
-
# Global atomic counter for generating unique auth entry keys
|
15
|
-
_counter_lock = threading.Lock()
|
16
|
-
_auth_key_counter = 0
|
17
|
-
|
18
|
-
|
19
|
-
def _generate_auth_key() -> str:
|
20
|
-
"""Generate a unique auth entry key using a global atomic counter."""
|
21
|
-
global _auth_key_counter # pylint: disable=global-statement
|
22
|
-
with _counter_lock:
|
23
|
-
_auth_key_counter += 1
|
24
|
-
return f"__auth_{_auth_key_counter}"
|
25
|
-
|
26
13
|
|
27
14
|
@dataclass
|
28
15
|
class TransientAuthEntryReference(Generic[T]):
|
@@ -37,7 +24,8 @@ class AuthEntryReference(TransientAuthEntryReference[T]):
|
|
37
24
|
|
38
25
|
def add_transient_auth_entry(value: T) -> TransientAuthEntryReference[T]:
|
39
26
|
"""Add an auth entry to the registry. Returns its reference."""
|
40
|
-
|
27
|
+
key = _engine.add_transient_auth_entry(dump_engine_object(value))
|
28
|
+
return TransientAuthEntryReference(key)
|
41
29
|
|
42
30
|
|
43
31
|
def add_auth_entry(key: str, value: T) -> AuthEntryReference[T]:
|
@@ -49,3 +37,8 @@ def add_auth_entry(key: str, value: T) -> AuthEntryReference[T]:
|
|
49
37
|
def ref_auth_entry(key: str) -> AuthEntryReference[T]:
|
50
38
|
"""Reference an auth entry by its key."""
|
51
39
|
return AuthEntryReference(key)
|
40
|
+
|
41
|
+
|
42
|
+
def get_auth_entry(cls: type[T], ref: TransientAuthEntryReference[T]) -> T:
|
43
|
+
"""Get an auth entry by its key."""
|
44
|
+
return load_engine_object(cls, _engine.get_auth_entry(ref.key))
|
cocoindex/convert.py
CHANGED
@@ -9,12 +9,11 @@ import datetime
|
|
9
9
|
import inspect
|
10
10
|
import warnings
|
11
11
|
from enum import Enum
|
12
|
-
from typing import Any, Callable, Mapping,
|
12
|
+
from typing import Any, Callable, Mapping, get_origin, TypeVar, overload
|
13
13
|
|
14
14
|
import numpy as np
|
15
15
|
|
16
16
|
from .typing import (
|
17
|
-
TABLE_TYPES,
|
18
17
|
AnalyzedAnyType,
|
19
18
|
AnalyzedBasicType,
|
20
19
|
AnalyzedDictType,
|
@@ -23,13 +22,23 @@ from .typing import (
|
|
23
22
|
AnalyzedTypeInfo,
|
24
23
|
AnalyzedUnionType,
|
25
24
|
AnalyzedUnknownType,
|
25
|
+
EnrichedValueType,
|
26
26
|
analyze_type_info,
|
27
27
|
encode_enriched_type,
|
28
28
|
is_namedtuple_type,
|
29
29
|
is_numpy_number_type,
|
30
|
+
extract_ndarray_elem_dtype,
|
31
|
+
ValueType,
|
32
|
+
FieldSchema,
|
33
|
+
BasicValueType,
|
34
|
+
StructType,
|
35
|
+
TableType,
|
30
36
|
)
|
31
37
|
|
32
38
|
|
39
|
+
T = TypeVar("T")
|
40
|
+
|
41
|
+
|
33
42
|
class ChildFieldPath:
|
34
43
|
"""Context manager to append a field to field_path on enter and pop it on exit."""
|
35
44
|
|
@@ -172,7 +181,7 @@ def make_engine_value_encoder(type_info: AnalyzedTypeInfo) -> Callable[[Any], An
|
|
172
181
|
|
173
182
|
def make_engine_key_decoder(
|
174
183
|
field_path: list[str],
|
175
|
-
key_fields_schema: list[
|
184
|
+
key_fields_schema: list[FieldSchema],
|
176
185
|
dst_type_info: AnalyzedTypeInfo,
|
177
186
|
) -> Callable[[Any], Any]:
|
178
187
|
"""
|
@@ -183,7 +192,7 @@ def make_engine_key_decoder(
|
|
183
192
|
):
|
184
193
|
single_key_decoder = make_engine_value_decoder(
|
185
194
|
field_path,
|
186
|
-
key_fields_schema[0]
|
195
|
+
key_fields_schema[0].value_type.type,
|
187
196
|
dst_type_info,
|
188
197
|
for_key=True,
|
189
198
|
)
|
@@ -203,7 +212,7 @@ def make_engine_key_decoder(
|
|
203
212
|
|
204
213
|
def make_engine_value_decoder(
|
205
214
|
field_path: list[str],
|
206
|
-
src_type:
|
215
|
+
src_type: ValueType,
|
207
216
|
dst_type_info: AnalyzedTypeInfo,
|
208
217
|
for_key: bool = False,
|
209
218
|
) -> Callable[[Any], Any]:
|
@@ -219,7 +228,7 @@ def make_engine_value_decoder(
|
|
219
228
|
A decoder from an engine value to a Python value.
|
220
229
|
"""
|
221
230
|
|
222
|
-
src_type_kind = src_type
|
231
|
+
src_type_kind = src_type.kind
|
223
232
|
|
224
233
|
dst_type_variant = dst_type_info.variant
|
225
234
|
|
@@ -229,19 +238,19 @@ def make_engine_value_decoder(
|
|
229
238
|
f"declared `{dst_type_info.core_type}`, an unsupported type"
|
230
239
|
)
|
231
240
|
|
232
|
-
if
|
241
|
+
if isinstance(src_type, StructType): # type: ignore[redundant-cast]
|
233
242
|
return make_engine_struct_decoder(
|
234
243
|
field_path,
|
235
|
-
src_type
|
244
|
+
src_type.fields,
|
236
245
|
dst_type_info,
|
237
246
|
for_key=for_key,
|
238
247
|
)
|
239
248
|
|
240
|
-
if
|
249
|
+
if isinstance(src_type, TableType): # type: ignore[redundant-cast]
|
241
250
|
with ChildFieldPath(field_path, "[*]"):
|
242
|
-
engine_fields_schema = src_type
|
251
|
+
engine_fields_schema = src_type.row.fields
|
243
252
|
|
244
|
-
if
|
253
|
+
if src_type.kind == "LTable":
|
245
254
|
if isinstance(dst_type_variant, AnalyzedAnyType):
|
246
255
|
dst_elem_type = Any
|
247
256
|
elif isinstance(dst_type_variant, AnalyzedListType):
|
@@ -262,7 +271,7 @@ def make_engine_value_decoder(
|
|
262
271
|
return None
|
263
272
|
return [row_decoder(v) for v in value]
|
264
273
|
|
265
|
-
elif
|
274
|
+
elif src_type.kind == "KTable":
|
266
275
|
if isinstance(dst_type_variant, AnalyzedAnyType):
|
267
276
|
key_type, value_type = Any, Any
|
268
277
|
elif isinstance(dst_type_variant, AnalyzedDictType):
|
@@ -274,7 +283,7 @@ def make_engine_value_decoder(
|
|
274
283
|
f"declared `{dst_type_info.core_type}`, a dict type expected"
|
275
284
|
)
|
276
285
|
|
277
|
-
num_key_parts = src_type.
|
286
|
+
num_key_parts = src_type.num_key_parts or 1
|
278
287
|
key_decoder = make_engine_key_decoder(
|
279
288
|
field_path,
|
280
289
|
engine_fields_schema[0:num_key_parts],
|
@@ -298,7 +307,7 @@ def make_engine_value_decoder(
|
|
298
307
|
|
299
308
|
return decode
|
300
309
|
|
301
|
-
if
|
310
|
+
if isinstance(src_type, BasicValueType) and src_type.kind == "Union":
|
302
311
|
if isinstance(dst_type_variant, AnalyzedAnyType):
|
303
312
|
return lambda value: value[1]
|
304
313
|
|
@@ -307,7 +316,10 @@ def make_engine_value_decoder(
|
|
307
316
|
if isinstance(dst_type_variant, AnalyzedUnionType)
|
308
317
|
else [dst_type_info]
|
309
318
|
)
|
310
|
-
|
319
|
+
# mypy: union info exists for Union kind
|
320
|
+
assert src_type.union is not None # type: ignore[unreachable]
|
321
|
+
src_type_variants_basic: list[BasicValueType] = src_type.union.variants
|
322
|
+
src_type_variants = src_type_variants_basic
|
311
323
|
decoders = []
|
312
324
|
for i, src_type_variant in enumerate(src_type_variants):
|
313
325
|
with ChildFieldPath(field_path, f"[{i}]"):
|
@@ -331,7 +343,7 @@ def make_engine_value_decoder(
|
|
331
343
|
if isinstance(dst_type_variant, AnalyzedAnyType):
|
332
344
|
return lambda value: value
|
333
345
|
|
334
|
-
if
|
346
|
+
if isinstance(src_type, BasicValueType) and src_type.kind == "Vector":
|
335
347
|
field_path_str = "".join(field_path)
|
336
348
|
if not isinstance(dst_type_variant, AnalyzedListType):
|
337
349
|
raise ValueError(
|
@@ -350,9 +362,11 @@ def make_engine_value_decoder(
|
|
350
362
|
if is_numpy_number_type(dst_type_variant.elem_type):
|
351
363
|
scalar_dtype = dst_type_variant.elem_type
|
352
364
|
else:
|
365
|
+
# mypy: vector info exists for Vector kind
|
366
|
+
assert src_type.vector is not None # type: ignore[unreachable]
|
353
367
|
vec_elem_decoder = make_engine_value_decoder(
|
354
368
|
field_path + ["[*]"],
|
355
|
-
src_type
|
369
|
+
src_type.vector.element_type,
|
356
370
|
analyze_type_info(
|
357
371
|
dst_type_variant.elem_type if dst_type_variant else Any
|
358
372
|
),
|
@@ -432,7 +446,7 @@ def _get_auto_default_for_type(
|
|
432
446
|
|
433
447
|
def make_engine_struct_decoder(
|
434
448
|
field_path: list[str],
|
435
|
-
src_fields: list[
|
449
|
+
src_fields: list[FieldSchema],
|
436
450
|
dst_type_info: AnalyzedTypeInfo,
|
437
451
|
for_key: bool = False,
|
438
452
|
) -> Callable[[list[Any]], Any]:
|
@@ -461,7 +475,7 @@ def make_engine_struct_decoder(
|
|
461
475
|
f"declared `{dst_type_info.core_type}`, a dataclass, NamedTuple or dict[str, Any] expected"
|
462
476
|
)
|
463
477
|
|
464
|
-
src_name_to_idx = {f
|
478
|
+
src_name_to_idx = {f.name: i for i, f in enumerate(src_fields)}
|
465
479
|
dst_struct_type = dst_type_variant.struct_type
|
466
480
|
|
467
481
|
parameters: Mapping[str, inspect.Parameter]
|
@@ -493,7 +507,10 @@ def make_engine_struct_decoder(
|
|
493
507
|
with ChildFieldPath(field_path, f".{name}"):
|
494
508
|
if src_idx is not None:
|
495
509
|
field_decoder = make_engine_value_decoder(
|
496
|
-
field_path,
|
510
|
+
field_path,
|
511
|
+
src_fields[src_idx].value_type.type,
|
512
|
+
type_info,
|
513
|
+
for_key=for_key,
|
497
514
|
)
|
498
515
|
return lambda values: field_decoder(values[src_idx])
|
499
516
|
|
@@ -526,7 +543,7 @@ def make_engine_struct_decoder(
|
|
526
543
|
|
527
544
|
def _make_engine_struct_to_dict_decoder(
|
528
545
|
field_path: list[str],
|
529
|
-
src_fields: list[
|
546
|
+
src_fields: list[FieldSchema],
|
530
547
|
value_type_annotation: Any,
|
531
548
|
) -> Callable[[list[Any] | None], dict[str, Any] | None]:
|
532
549
|
"""Make a decoder from engine field values to a Python dict."""
|
@@ -534,11 +551,11 @@ def _make_engine_struct_to_dict_decoder(
|
|
534
551
|
field_decoders = []
|
535
552
|
value_type_info = analyze_type_info(value_type_annotation)
|
536
553
|
for field_schema in src_fields:
|
537
|
-
field_name = field_schema
|
554
|
+
field_name = field_schema.name
|
538
555
|
with ChildFieldPath(field_path, f".{field_name}"):
|
539
556
|
field_decoder = make_engine_value_decoder(
|
540
557
|
field_path,
|
541
|
-
field_schema
|
558
|
+
field_schema.value_type.type,
|
542
559
|
value_type_info,
|
543
560
|
)
|
544
561
|
field_decoders.append((field_name, field_decoder))
|
@@ -560,19 +577,19 @@ def _make_engine_struct_to_dict_decoder(
|
|
560
577
|
|
561
578
|
def _make_engine_struct_to_tuple_decoder(
|
562
579
|
field_path: list[str],
|
563
|
-
src_fields: list[
|
580
|
+
src_fields: list[FieldSchema],
|
564
581
|
) -> Callable[[list[Any] | None], tuple[Any, ...] | None]:
|
565
582
|
"""Make a decoder from engine field values to a Python tuple."""
|
566
583
|
|
567
584
|
field_decoders = []
|
568
585
|
value_type_info = analyze_type_info(Any)
|
569
586
|
for field_schema in src_fields:
|
570
|
-
field_name = field_schema
|
587
|
+
field_name = field_schema.name
|
571
588
|
with ChildFieldPath(field_path, f".{field_name}"):
|
572
589
|
field_decoders.append(
|
573
590
|
make_engine_value_decoder(
|
574
591
|
field_path,
|
575
|
-
field_schema
|
592
|
+
field_schema.value_type.type,
|
576
593
|
value_type_info,
|
577
594
|
)
|
578
595
|
)
|
@@ -595,6 +612,10 @@ def dump_engine_object(v: Any) -> Any:
|
|
595
612
|
"""Recursively dump an object for engine. Engine side uses `Pythonized` to catch."""
|
596
613
|
if v is None:
|
597
614
|
return None
|
615
|
+
elif isinstance(v, EnrichedValueType):
|
616
|
+
return v.encode()
|
617
|
+
elif isinstance(v, FieldSchema):
|
618
|
+
return v.encode()
|
598
619
|
elif isinstance(v, type) or get_origin(v) is not None:
|
599
620
|
return encode_enriched_type(v)
|
600
621
|
elif isinstance(v, Enum):
|
@@ -604,7 +625,17 @@ def dump_engine_object(v: Any) -> Any:
|
|
604
625
|
secs = int(total_secs)
|
605
626
|
nanos = int((total_secs - secs) * 1e9)
|
606
627
|
return {"secs": secs, "nanos": nanos}
|
607
|
-
elif
|
628
|
+
elif is_namedtuple_type(type(v)):
|
629
|
+
# Handle NamedTuple objects specifically to use dict format
|
630
|
+
field_names = list(getattr(type(v), "_fields", ()))
|
631
|
+
result = {}
|
632
|
+
for name in field_names:
|
633
|
+
val = getattr(v, name)
|
634
|
+
result[name] = dump_engine_object(val) # Include all values, including None
|
635
|
+
if hasattr(v, "kind") and "kind" not in result:
|
636
|
+
result["kind"] = v.kind
|
637
|
+
return result
|
638
|
+
elif hasattr(v, "__dict__"): # for dataclass-like objects
|
608
639
|
s = {}
|
609
640
|
for k, val in v.__dict__.items():
|
610
641
|
if val is None:
|
@@ -616,6 +647,133 @@ def dump_engine_object(v: Any) -> Any:
|
|
616
647
|
return s
|
617
648
|
elif isinstance(v, (list, tuple)):
|
618
649
|
return [dump_engine_object(item) for item in v]
|
650
|
+
elif isinstance(v, np.ndarray):
|
651
|
+
return v.tolist()
|
619
652
|
elif isinstance(v, dict):
|
620
653
|
return {k: dump_engine_object(v) for k, v in v.items()}
|
621
654
|
return v
|
655
|
+
|
656
|
+
|
657
|
+
@overload
|
658
|
+
def load_engine_object(expected_type: type[T], v: Any) -> T: ...
|
659
|
+
@overload
|
660
|
+
def load_engine_object(expected_type: Any, v: Any) -> Any: ...
|
661
|
+
def load_engine_object(expected_type: Any, v: Any) -> Any:
|
662
|
+
"""Recursively load an object that was produced by dump_engine_object().
|
663
|
+
|
664
|
+
Args:
|
665
|
+
expected_type: The Python type annotation to reconstruct to.
|
666
|
+
v: The engine-facing Pythonized object (e.g., dict/list/primitive) to convert.
|
667
|
+
|
668
|
+
Returns:
|
669
|
+
A Python object matching the expected_type where possible.
|
670
|
+
"""
|
671
|
+
# Fast path
|
672
|
+
if v is None:
|
673
|
+
return None
|
674
|
+
|
675
|
+
type_info = analyze_type_info(expected_type)
|
676
|
+
variant = type_info.variant
|
677
|
+
|
678
|
+
if type_info.core_type is EnrichedValueType:
|
679
|
+
return EnrichedValueType.decode(v)
|
680
|
+
if type_info.core_type is FieldSchema:
|
681
|
+
return FieldSchema.decode(v)
|
682
|
+
|
683
|
+
# Any or unknown → return as-is
|
684
|
+
if isinstance(variant, AnalyzedAnyType) or type_info.base_type is Any:
|
685
|
+
return v
|
686
|
+
|
687
|
+
# Enum handling
|
688
|
+
if isinstance(expected_type, type) and issubclass(expected_type, Enum):
|
689
|
+
return expected_type(v)
|
690
|
+
|
691
|
+
# TimeDelta special form {secs, nanos}
|
692
|
+
if isinstance(variant, AnalyzedBasicType) and variant.kind == "TimeDelta":
|
693
|
+
if isinstance(v, Mapping) and "secs" in v and "nanos" in v:
|
694
|
+
secs = int(v["secs"]) # type: ignore[index]
|
695
|
+
nanos = int(v["nanos"]) # type: ignore[index]
|
696
|
+
return datetime.timedelta(seconds=secs, microseconds=nanos / 1_000)
|
697
|
+
return v
|
698
|
+
|
699
|
+
# List, NDArray (Vector-ish), or general sequences
|
700
|
+
if isinstance(variant, AnalyzedListType):
|
701
|
+
elem_type = variant.elem_type if variant.elem_type else Any
|
702
|
+
if type_info.base_type is np.ndarray:
|
703
|
+
# Reconstruct NDArray with appropriate dtype if available
|
704
|
+
try:
|
705
|
+
dtype = extract_ndarray_elem_dtype(type_info.core_type)
|
706
|
+
except (TypeError, ValueError, AttributeError):
|
707
|
+
dtype = None
|
708
|
+
return np.array(v, dtype=dtype)
|
709
|
+
# Regular Python list
|
710
|
+
return [load_engine_object(elem_type, item) for item in v]
|
711
|
+
|
712
|
+
# Dict / Mapping
|
713
|
+
if isinstance(variant, AnalyzedDictType):
|
714
|
+
key_t = variant.key_type
|
715
|
+
val_t = variant.value_type
|
716
|
+
return {
|
717
|
+
load_engine_object(key_t, k): load_engine_object(val_t, val)
|
718
|
+
for k, val in v.items()
|
719
|
+
}
|
720
|
+
|
721
|
+
# Structs (dataclass or NamedTuple)
|
722
|
+
if isinstance(variant, AnalyzedStructType):
|
723
|
+
struct_type = variant.struct_type
|
724
|
+
if dataclasses.is_dataclass(struct_type):
|
725
|
+
if not isinstance(v, Mapping):
|
726
|
+
raise ValueError(f"Expected dict for dataclass, got {type(v)}")
|
727
|
+
# Drop auxiliary discriminator "kind" if present
|
728
|
+
dc_init_kwargs: dict[str, Any] = {}
|
729
|
+
field_types = {f.name: f.type for f in dataclasses.fields(struct_type)}
|
730
|
+
for name, f_type in field_types.items():
|
731
|
+
if name in v:
|
732
|
+
dc_init_kwargs[name] = load_engine_object(f_type, v[name])
|
733
|
+
return struct_type(**dc_init_kwargs)
|
734
|
+
elif is_namedtuple_type(struct_type):
|
735
|
+
if not isinstance(v, Mapping):
|
736
|
+
raise ValueError(f"Expected dict for NamedTuple, got {type(v)}")
|
737
|
+
# Dict format (from dump/load functions)
|
738
|
+
annotations = getattr(struct_type, "__annotations__", {})
|
739
|
+
field_names = list(getattr(struct_type, "_fields", ()))
|
740
|
+
nt_init_kwargs: dict[str, Any] = {}
|
741
|
+
for name in field_names:
|
742
|
+
f_type = annotations.get(name, Any)
|
743
|
+
if name in v:
|
744
|
+
nt_init_kwargs[name] = load_engine_object(f_type, v[name])
|
745
|
+
return struct_type(**nt_init_kwargs)
|
746
|
+
return v
|
747
|
+
|
748
|
+
# Union with discriminator support via "kind"
|
749
|
+
if isinstance(variant, AnalyzedUnionType):
|
750
|
+
if isinstance(v, Mapping) and "kind" in v:
|
751
|
+
discriminator = v["kind"]
|
752
|
+
for typ in variant.variant_types:
|
753
|
+
t_info = analyze_type_info(typ)
|
754
|
+
if isinstance(t_info.variant, AnalyzedStructType):
|
755
|
+
t_struct = t_info.variant.struct_type
|
756
|
+
candidate_kind = getattr(t_struct, "kind", None)
|
757
|
+
if candidate_kind == discriminator:
|
758
|
+
# Remove discriminator for constructor
|
759
|
+
v_wo_kind = dict(v)
|
760
|
+
v_wo_kind.pop("kind", None)
|
761
|
+
return load_engine_object(t_struct, v_wo_kind)
|
762
|
+
# Fallback: try each variant until one succeeds
|
763
|
+
for typ in variant.variant_types:
|
764
|
+
try:
|
765
|
+
return load_engine_object(typ, v)
|
766
|
+
except (TypeError, ValueError):
|
767
|
+
continue
|
768
|
+
return v
|
769
|
+
|
770
|
+
# Basic types and everything else: handle numpy scalars and passthrough
|
771
|
+
if isinstance(v, np.ndarray) and type_info.base_type is list:
|
772
|
+
return v.tolist()
|
773
|
+
if isinstance(v, (list, tuple)) and type_info.base_type not in (list, tuple):
|
774
|
+
# If a non-sequence basic type expected, attempt direct cast
|
775
|
+
try:
|
776
|
+
return type_info.core_type(v)
|
777
|
+
except (TypeError, ValueError):
|
778
|
+
return v
|
779
|
+
return v
|
cocoindex/flow.py
CHANGED
@@ -38,9 +38,10 @@ from .convert import (
|
|
38
38
|
make_engine_value_encoder,
|
39
39
|
)
|
40
40
|
from .op import FunctionSpec
|
41
|
-
from .runtime import execution_context
|
41
|
+
from .runtime import execution_context, to_async_call
|
42
42
|
from .setup import SetupChangeBundle
|
43
|
-
from .typing import analyze_type_info, encode_enriched_type
|
43
|
+
from .typing import analyze_type_info, encode_enriched_type, decode_engine_value_type
|
44
|
+
from .query_handler import QueryHandlerInfo, QueryHandlerResultFields
|
44
45
|
from .validation import (
|
45
46
|
validate_flow_name,
|
46
47
|
validate_full_flow_name,
|
@@ -694,23 +695,18 @@ class Flow:
|
|
694
695
|
"""
|
695
696
|
|
696
697
|
_name: str
|
697
|
-
|
698
|
+
_engine_flow_creator: Callable[[], _engine.Flow]
|
699
|
+
|
700
|
+
_lazy_flow_lock: Lock
|
701
|
+
_lazy_query_handler_args: list[tuple[Any, ...]]
|
702
|
+
_lazy_engine_flow: _engine.Flow | None = None
|
698
703
|
|
699
704
|
def __init__(self, name: str, engine_flow_creator: Callable[[], _engine.Flow]):
|
700
705
|
validate_flow_name(name)
|
701
706
|
self._name = name
|
702
|
-
|
703
|
-
|
704
|
-
|
705
|
-
def _lazy_engine_flow() -> _engine.Flow:
|
706
|
-
nonlocal engine_flow, lock
|
707
|
-
if engine_flow is None:
|
708
|
-
with lock:
|
709
|
-
if engine_flow is None:
|
710
|
-
engine_flow = engine_flow_creator()
|
711
|
-
return engine_flow
|
712
|
-
|
713
|
-
self._lazy_engine_flow = _lazy_engine_flow
|
707
|
+
self._engine_flow_creator = engine_flow_creator
|
708
|
+
self._lazy_flow_lock = Lock()
|
709
|
+
self._lazy_query_handler_args = []
|
714
710
|
|
715
711
|
def _render_spec(self, verbose: bool = False) -> Tree:
|
716
712
|
"""
|
@@ -794,15 +790,33 @@ class Flow:
|
|
794
790
|
"""
|
795
791
|
Get the engine flow.
|
796
792
|
"""
|
797
|
-
if self._lazy_engine_flow is None:
|
798
|
-
|
799
|
-
return self.
|
793
|
+
if self._lazy_engine_flow is not None:
|
794
|
+
return self._lazy_engine_flow
|
795
|
+
return self._internal_flow()
|
800
796
|
|
801
797
|
async def internal_flow_async(self) -> _engine.Flow:
|
802
798
|
"""
|
803
799
|
Get the engine flow. The async version.
|
804
800
|
"""
|
805
|
-
|
801
|
+
if self._lazy_engine_flow is not None:
|
802
|
+
return self._lazy_engine_flow
|
803
|
+
return await asyncio.to_thread(self._internal_flow)
|
804
|
+
|
805
|
+
def _internal_flow(self) -> _engine.Flow:
|
806
|
+
"""
|
807
|
+
Get the engine flow. The async version.
|
808
|
+
"""
|
809
|
+
with self._lazy_flow_lock:
|
810
|
+
if self._lazy_engine_flow is not None:
|
811
|
+
return self._lazy_engine_flow
|
812
|
+
|
813
|
+
engine_flow = self._engine_flow_creator()
|
814
|
+
self._lazy_engine_flow = engine_flow
|
815
|
+
for args in self._lazy_query_handler_args:
|
816
|
+
engine_flow.add_query_handler(*args)
|
817
|
+
self._lazy_query_handler_args = []
|
818
|
+
|
819
|
+
return engine_flow
|
806
820
|
|
807
821
|
def setup(self, report_to_stdout: bool = False) -> None:
|
808
822
|
"""
|
@@ -847,6 +861,53 @@ class Flow:
|
|
847
861
|
with _flows_lock:
|
848
862
|
del _flows[self.name]
|
849
863
|
|
864
|
+
def add_query_handler(
|
865
|
+
self,
|
866
|
+
name: str,
|
867
|
+
handler: Callable[[str], Any],
|
868
|
+
/,
|
869
|
+
*,
|
870
|
+
result_fields: QueryHandlerResultFields | None = None,
|
871
|
+
) -> None:
|
872
|
+
"""
|
873
|
+
Add a query handler to the flow.
|
874
|
+
"""
|
875
|
+
async_handler = to_async_call(handler)
|
876
|
+
|
877
|
+
async def _handler(query: str) -> dict[str, Any]:
|
878
|
+
handler_result = await async_handler(query)
|
879
|
+
return {
|
880
|
+
"results": [
|
881
|
+
[(k, dump_engine_object(v)) for (k, v) in result.items()]
|
882
|
+
for result in handler_result.results
|
883
|
+
],
|
884
|
+
"query_info": dump_engine_object(handler_result.query_info),
|
885
|
+
}
|
886
|
+
|
887
|
+
handler_info = dump_engine_object(QueryHandlerInfo(result_fields=result_fields))
|
888
|
+
with self._lazy_flow_lock:
|
889
|
+
if self._lazy_engine_flow is not None:
|
890
|
+
self._lazy_engine_flow.add_query_handler(name, _handler, handler_info)
|
891
|
+
else:
|
892
|
+
self._lazy_query_handler_args.append((name, _handler, handler_info))
|
893
|
+
|
894
|
+
def query_handler(
|
895
|
+
self,
|
896
|
+
name: str | None = None,
|
897
|
+
result_fields: QueryHandlerResultFields | None = None,
|
898
|
+
) -> Callable[[Callable[[str], Any]], Callable[[str], Any]]:
|
899
|
+
"""
|
900
|
+
A decorator to declare a query handler.
|
901
|
+
"""
|
902
|
+
|
903
|
+
def _inner(handler: Callable[[str], Any]) -> Callable[[str], Any]:
|
904
|
+
self.add_query_handler(
|
905
|
+
name or handler.__qualname__, handler, result_fields=result_fields
|
906
|
+
)
|
907
|
+
return handler
|
908
|
+
|
909
|
+
return _inner
|
910
|
+
|
850
911
|
|
851
912
|
def _create_lazy_flow(
|
852
913
|
name: str | None, fl_def: Callable[[FlowBuilder, DataScope], None]
|
@@ -1103,7 +1164,9 @@ class TransformFlow(Generic[T]):
|
|
1103
1164
|
inspect.signature(self._flow_fn).return_annotation
|
1104
1165
|
)
|
1105
1166
|
result_decoder = make_engine_value_decoder(
|
1106
|
-
[],
|
1167
|
+
[],
|
1168
|
+
decode_engine_value_type(engine_return_type["type"]),
|
1169
|
+
analyze_type_info(python_return_type),
|
1107
1170
|
)
|
1108
1171
|
|
1109
1172
|
return TransformFlowInfo(engine_flow, result_decoder)
|