cocoindex 0.2.10__cp311-abi3-win_amd64.whl → 0.2.12__cp311-abi3-win_amd64.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
cocoindex/__init__.py CHANGED
@@ -24,6 +24,7 @@ from .llm import LlmSpec, LlmApiType
  from .index import VectorSimilarityMetric, VectorIndexDef, IndexOptions
  from .setting import DatabaseConnectionSpec, Settings, ServerSettings
  from .setting import get_app_namespace
+ from .query_handler import QueryHandlerResultFields, QueryInfo, QueryOutput
  from .typing import (
      Int64,
      Float32,
@@ -95,4 +96,8 @@ __all__ = [
      "Range",
      "Vector",
      "Json",
+     # Query handler
+     "QueryHandlerResultFields",
+     "QueryInfo",
+     "QueryOutput",
  ]
cocoindex/_engine.pyd CHANGED
Binary file
cocoindex/auth_registry.py CHANGED
@@ -4,25 +4,12 @@ Auth registry is used to register and reference auth entries.
 
  from dataclasses import dataclass
  from typing import Generic, TypeVar
- import threading
 
  from . import _engine  # type: ignore
- from .convert import dump_engine_object
+ from .convert import dump_engine_object, load_engine_object
 
  T = TypeVar("T")
 
- # Global atomic counter for generating unique auth entry keys
- _counter_lock = threading.Lock()
- _auth_key_counter = 0
-
-
- def _generate_auth_key() -> str:
-     """Generate a unique auth entry key using a global atomic counter."""
-     global _auth_key_counter  # pylint: disable=global-statement
-     with _counter_lock:
-         _auth_key_counter += 1
-         return f"__auth_{_auth_key_counter}"
-
 
  @dataclass
  class TransientAuthEntryReference(Generic[T]):
@@ -37,7 +24,8 @@ class AuthEntryReference(TransientAuthEntryReference[T]):
 
  def add_transient_auth_entry(value: T) -> TransientAuthEntryReference[T]:
      """Add an auth entry to the registry. Returns its reference."""
-     return add_auth_entry(_generate_auth_key(), value)
+     key = _engine.add_transient_auth_entry(dump_engine_object(value))
+     return TransientAuthEntryReference(key)
 
 
  def add_auth_entry(key: str, value: T) -> AuthEntryReference[T]:
@@ -49,3 +37,8 @@ def add_auth_entry(key: str, value: T) -> AuthEntryReference[T]:
  def ref_auth_entry(key: str) -> AuthEntryReference[T]:
      """Reference an auth entry by its key."""
      return AuthEntryReference(key)
+
+
+ def get_auth_entry(cls: type[T], ref: TransientAuthEntryReference[T]) -> T:
+     """Get an auth entry by its key."""
+     return load_engine_object(cls, _engine.get_auth_entry(ref.key))
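
With this change, transient auth entries no longer derive their keys from a Python-side counter: the engine assigns the key in add_transient_auth_entry, and the new get_auth_entry resolves a reference back into a typed value via load_engine_object. A minimal usage sketch (the DbCredentials dataclass is hypothetical, and a configured cocoindex environment is assumed):

from dataclasses import dataclass

from cocoindex.auth_registry import add_transient_auth_entry, get_auth_entry


@dataclass
class DbCredentials:
    # Hypothetical spec type, for illustration only.
    url: str
    password: str


# The engine allocates the key; no client-side counter is involved.
ref = add_transient_auth_entry(DbCredentials(url="postgres://localhost/db", password="secret"))

# Resolve the reference back into a typed Python object.
creds = get_auth_entry(DbCredentials, ref)
assert creds.url == "postgres://localhost/db"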
cocoindex/convert.py CHANGED
@@ -9,12 +9,11 @@ import datetime
  import inspect
  import warnings
  from enum import Enum
- from typing import Any, Callable, Mapping, Sequence, Type, get_origin
+ from typing import Any, Callable, Mapping, get_origin, TypeVar, overload
 
  import numpy as np
 
  from .typing import (
-     TABLE_TYPES,
      AnalyzedAnyType,
      AnalyzedBasicType,
      AnalyzedDictType,
@@ -23,13 +22,23 @@ from .typing import (
      AnalyzedTypeInfo,
      AnalyzedUnionType,
      AnalyzedUnknownType,
+     EnrichedValueType,
      analyze_type_info,
      encode_enriched_type,
      is_namedtuple_type,
      is_numpy_number_type,
+     extract_ndarray_elem_dtype,
+     ValueType,
+     FieldSchema,
+     BasicValueType,
+     StructType,
+     TableType,
  )
 
 
+ T = TypeVar("T")
+
+
  class ChildFieldPath:
      """Context manager to append a field to field_path on enter and pop it on exit."""
 
@@ -172,7 +181,7 @@ def make_engine_value_encoder(type_info: AnalyzedTypeInfo) -> Callable[[Any], An
 
  def make_engine_key_decoder(
      field_path: list[str],
-     key_fields_schema: list[dict[str, Any]],
+     key_fields_schema: list[FieldSchema],
      dst_type_info: AnalyzedTypeInfo,
  ) -> Callable[[Any], Any]:
      """
@@ -183,7 +192,7 @@
      ):
          single_key_decoder = make_engine_value_decoder(
              field_path,
-             key_fields_schema[0]["type"],
+             key_fields_schema[0].value_type.type,
              dst_type_info,
              for_key=True,
          )
@@ -203,7 +212,7 @@
 
  def make_engine_value_decoder(
      field_path: list[str],
-     src_type: dict[str, Any],
+     src_type: ValueType,
      dst_type_info: AnalyzedTypeInfo,
      for_key: bool = False,
  ) -> Callable[[Any], Any]:
@@ -219,7 +228,7 @@
          A decoder from an engine value to a Python value.
      """
 
-     src_type_kind = src_type["kind"]
+     src_type_kind = src_type.kind
 
      dst_type_variant = dst_type_info.variant
 
@@ -229,19 +238,19 @@
              f"declared `{dst_type_info.core_type}`, an unsupported type"
          )
 
-     if src_type_kind == "Struct":
+     if isinstance(src_type, StructType):  # type: ignore[redundant-cast]
          return make_engine_struct_decoder(
              field_path,
-             src_type["fields"],
+             src_type.fields,
              dst_type_info,
              for_key=for_key,
          )
 
-     if src_type_kind in TABLE_TYPES:
+     if isinstance(src_type, TableType):  # type: ignore[redundant-cast]
          with ChildFieldPath(field_path, "[*]"):
-             engine_fields_schema = src_type["row"]["fields"]
+             engine_fields_schema = src_type.row.fields
 
-             if src_type_kind == "LTable":
+             if src_type.kind == "LTable":
                  if isinstance(dst_type_variant, AnalyzedAnyType):
                      dst_elem_type = Any
                  elif isinstance(dst_type_variant, AnalyzedListType):
@@ -262,7 +271,7 @@
                          return None
                      return [row_decoder(v) for v in value]
 
-             elif src_type_kind == "KTable":
+             elif src_type.kind == "KTable":
                  if isinstance(dst_type_variant, AnalyzedAnyType):
                      key_type, value_type = Any, Any
                  elif isinstance(dst_type_variant, AnalyzedDictType):
@@ -274,7 +283,7 @@
                          f"declared `{dst_type_info.core_type}`, a dict type expected"
                      )
 
-                 num_key_parts = src_type.get("num_key_parts", 1)
+                 num_key_parts = src_type.num_key_parts or 1
                  key_decoder = make_engine_key_decoder(
                      field_path,
                      engine_fields_schema[0:num_key_parts],
@@ -298,7 +307,7 @@
 
          return decode
 
-     if src_type_kind == "Union":
+     if isinstance(src_type, BasicValueType) and src_type.kind == "Union":
          if isinstance(dst_type_variant, AnalyzedAnyType):
              return lambda value: value[1]
 
@@ -307,7 +316,10 @@
              if isinstance(dst_type_variant, AnalyzedUnionType)
              else [dst_type_info]
          )
-         src_type_variants = src_type["types"]
+         # mypy: union info exists for Union kind
+         assert src_type.union is not None  # type: ignore[unreachable]
+         src_type_variants_basic: list[BasicValueType] = src_type.union.variants
+         src_type_variants = src_type_variants_basic
          decoders = []
         for i, src_type_variant in enumerate(src_type_variants):
              with ChildFieldPath(field_path, f"[{i}]"):
@@ -331,7 +343,7 @@
      if isinstance(dst_type_variant, AnalyzedAnyType):
          return lambda value: value
 
-     if src_type_kind == "Vector":
+     if isinstance(src_type, BasicValueType) and src_type.kind == "Vector":
          field_path_str = "".join(field_path)
          if not isinstance(dst_type_variant, AnalyzedListType):
              raise ValueError(
@@ -350,9 +362,11 @@
          if is_numpy_number_type(dst_type_variant.elem_type):
              scalar_dtype = dst_type_variant.elem_type
          else:
+             # mypy: vector info exists for Vector kind
+             assert src_type.vector is not None  # type: ignore[unreachable]
              vec_elem_decoder = make_engine_value_decoder(
                  field_path + ["[*]"],
-                 src_type["element_type"],
+                 src_type.vector.element_type,
                  analyze_type_info(
                      dst_type_variant.elem_type if dst_type_variant else Any
                  ),
@@ -432,7 +446,7 @@
 
  def make_engine_struct_decoder(
      field_path: list[str],
-     src_fields: list[dict[str, Any]],
+     src_fields: list[FieldSchema],
      dst_type_info: AnalyzedTypeInfo,
      for_key: bool = False,
  ) -> Callable[[list[Any]], Any]:
@@ -461,7 +475,7 @@
              f"declared `{dst_type_info.core_type}`, a dataclass, NamedTuple or dict[str, Any] expected"
          )
 
-     src_name_to_idx = {f["name"]: i for i, f in enumerate(src_fields)}
+     src_name_to_idx = {f.name: i for i, f in enumerate(src_fields)}
      dst_struct_type = dst_type_variant.struct_type
 
      parameters: Mapping[str, inspect.Parameter]
@@ -493,7 +507,10 @@
          with ChildFieldPath(field_path, f".{name}"):
              if src_idx is not None:
                  field_decoder = make_engine_value_decoder(
-                     field_path, src_fields[src_idx]["type"], type_info, for_key=for_key
+                     field_path,
+                     src_fields[src_idx].value_type.type,
+                     type_info,
+                     for_key=for_key,
                  )
                  return lambda values: field_decoder(values[src_idx])
 
@@ -526,7 +543,7 @@
 
  def _make_engine_struct_to_dict_decoder(
      field_path: list[str],
-     src_fields: list[dict[str, Any]],
+     src_fields: list[FieldSchema],
      value_type_annotation: Any,
  ) -> Callable[[list[Any] | None], dict[str, Any] | None]:
      """Make a decoder from engine field values to a Python dict."""
@@ -534,11 +551,11 @@
      field_decoders = []
      value_type_info = analyze_type_info(value_type_annotation)
      for field_schema in src_fields:
-         field_name = field_schema["name"]
+         field_name = field_schema.name
          with ChildFieldPath(field_path, f".{field_name}"):
              field_decoder = make_engine_value_decoder(
                  field_path,
-                 field_schema["type"],
+                 field_schema.value_type.type,
                  value_type_info,
              )
              field_decoders.append((field_name, field_decoder))
@@ -560,19 +577,19 @@
 
  def _make_engine_struct_to_tuple_decoder(
      field_path: list[str],
-     src_fields: list[dict[str, Any]],
+     src_fields: list[FieldSchema],
  ) -> Callable[[list[Any] | None], tuple[Any, ...] | None]:
      """Make a decoder from engine field values to a Python tuple."""
 
      field_decoders = []
      value_type_info = analyze_type_info(Any)
      for field_schema in src_fields:
-         field_name = field_schema["name"]
+         field_name = field_schema.name
          with ChildFieldPath(field_path, f".{field_name}"):
              field_decoders.append(
                  make_engine_value_decoder(
                      field_path,
-                     field_schema["type"],
+                     field_schema.value_type.type,
                      value_type_info,
                  )
              )
@@ -595,6 +612,10 @@ def dump_engine_object(v: Any) -> Any:
      """Recursively dump an object for engine. Engine side uses `Pythonized` to catch."""
      if v is None:
          return None
+     elif isinstance(v, EnrichedValueType):
+         return v.encode()
+     elif isinstance(v, FieldSchema):
+         return v.encode()
      elif isinstance(v, type) or get_origin(v) is not None:
          return encode_enriched_type(v)
      elif isinstance(v, Enum):
@@ -604,7 +625,17 @@
          secs = int(total_secs)
          nanos = int((total_secs - secs) * 1e9)
          return {"secs": secs, "nanos": nanos}
-     elif hasattr(v, "__dict__"):
+     elif is_namedtuple_type(type(v)):
+         # Handle NamedTuple objects specifically to use dict format
+         field_names = list(getattr(type(v), "_fields", ()))
+         result = {}
+         for name in field_names:
+             val = getattr(v, name)
+             result[name] = dump_engine_object(val)  # Include all values, including None
+         if hasattr(v, "kind") and "kind" not in result:
+             result["kind"] = v.kind
+         return result
+     elif hasattr(v, "__dict__"):  # for dataclass-like objects
          s = {}
          for k, val in v.__dict__.items():
              if val is None:
@@ -616,6 +647,133 @@ def dump_engine_object(v: Any) -> Any:
          return s
      elif isinstance(v, (list, tuple)):
          return [dump_engine_object(item) for item in v]
+     elif isinstance(v, np.ndarray):
+         return v.tolist()
      elif isinstance(v, dict):
          return {k: dump_engine_object(v) for k, v in v.items()}
      return v
+
+
+ @overload
+ def load_engine_object(expected_type: type[T], v: Any) -> T: ...
+ @overload
+ def load_engine_object(expected_type: Any, v: Any) -> Any: ...
+ def load_engine_object(expected_type: Any, v: Any) -> Any:
+     """Recursively load an object that was produced by dump_engine_object().
+
+     Args:
+         expected_type: The Python type annotation to reconstruct to.
+         v: The engine-facing Pythonized object (e.g., dict/list/primitive) to convert.
+
+     Returns:
+         A Python object matching the expected_type where possible.
+     """
+     # Fast path
+     if v is None:
+         return None
+
+     type_info = analyze_type_info(expected_type)
+     variant = type_info.variant
+
+     if type_info.core_type is EnrichedValueType:
+         return EnrichedValueType.decode(v)
+     if type_info.core_type is FieldSchema:
+         return FieldSchema.decode(v)
+
+     # Any or unknown → return as-is
+     if isinstance(variant, AnalyzedAnyType) or type_info.base_type is Any:
+         return v
+
+     # Enum handling
+     if isinstance(expected_type, type) and issubclass(expected_type, Enum):
+         return expected_type(v)
+
+     # TimeDelta special form {secs, nanos}
+     if isinstance(variant, AnalyzedBasicType) and variant.kind == "TimeDelta":
+         if isinstance(v, Mapping) and "secs" in v and "nanos" in v:
+             secs = int(v["secs"])  # type: ignore[index]
+             nanos = int(v["nanos"])  # type: ignore[index]
+             return datetime.timedelta(seconds=secs, microseconds=nanos / 1_000)
+         return v
+
+     # List, NDArray (Vector-ish), or general sequences
+     if isinstance(variant, AnalyzedListType):
+         elem_type = variant.elem_type if variant.elem_type else Any
+         if type_info.base_type is np.ndarray:
+             # Reconstruct NDArray with appropriate dtype if available
+             try:
+                 dtype = extract_ndarray_elem_dtype(type_info.core_type)
+             except (TypeError, ValueError, AttributeError):
+                 dtype = None
+             return np.array(v, dtype=dtype)
+         # Regular Python list
+         return [load_engine_object(elem_type, item) for item in v]
+
+     # Dict / Mapping
+     if isinstance(variant, AnalyzedDictType):
+         key_t = variant.key_type
+         val_t = variant.value_type
+         return {
+             load_engine_object(key_t, k): load_engine_object(val_t, val)
+             for k, val in v.items()
+         }
+
+     # Structs (dataclass or NamedTuple)
+     if isinstance(variant, AnalyzedStructType):
+         struct_type = variant.struct_type
+         if dataclasses.is_dataclass(struct_type):
+             if not isinstance(v, Mapping):
+                 raise ValueError(f"Expected dict for dataclass, got {type(v)}")
+             # Drop auxiliary discriminator "kind" if present
+             dc_init_kwargs: dict[str, Any] = {}
+             field_types = {f.name: f.type for f in dataclasses.fields(struct_type)}
+             for name, f_type in field_types.items():
+                 if name in v:
+                     dc_init_kwargs[name] = load_engine_object(f_type, v[name])
+             return struct_type(**dc_init_kwargs)
+         elif is_namedtuple_type(struct_type):
+             if not isinstance(v, Mapping):
+                 raise ValueError(f"Expected dict for NamedTuple, got {type(v)}")
+             # Dict format (from dump/load functions)
+             annotations = getattr(struct_type, "__annotations__", {})
+             field_names = list(getattr(struct_type, "_fields", ()))
+             nt_init_kwargs: dict[str, Any] = {}
+             for name in field_names:
+                 f_type = annotations.get(name, Any)
+                 if name in v:
+                     nt_init_kwargs[name] = load_engine_object(f_type, v[name])
+             return struct_type(**nt_init_kwargs)
+         return v
+
+     # Union with discriminator support via "kind"
+     if isinstance(variant, AnalyzedUnionType):
+         if isinstance(v, Mapping) and "kind" in v:
+             discriminator = v["kind"]
+             for typ in variant.variant_types:
+                 t_info = analyze_type_info(typ)
+                 if isinstance(t_info.variant, AnalyzedStructType):
+                     t_struct = t_info.variant.struct_type
+                     candidate_kind = getattr(t_struct, "kind", None)
+                     if candidate_kind == discriminator:
+                         # Remove discriminator for constructor
+                         v_wo_kind = dict(v)
+                         v_wo_kind.pop("kind", None)
+                         return load_engine_object(t_struct, v_wo_kind)
+         # Fallback: try each variant until one succeeds
+         for typ in variant.variant_types:
+             try:
+                 return load_engine_object(typ, v)
+             except (TypeError, ValueError):
+                 continue
+         return v
+
+     # Basic types and everything else: handle numpy scalars and passthrough
+     if isinstance(v, np.ndarray) and type_info.base_type is list:
+         return v.tolist()
+     if isinstance(v, (list, tuple)) and type_info.base_type not in (list, tuple):
+         # If a non-sequence basic type expected, attempt direct cast
+         try:
+             return type_info.core_type(v)
+         except (TypeError, ValueError):
+             return v
+     return v
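
The new load_engine_object is the inverse of dump_engine_object: dataclasses and NamedTuples dump to dicts (optionally carrying a "kind" discriminator), datetime.timedelta dumps to {"secs", "nanos"}, and loading reconstructs the annotated Python type. A small round-trip sketch under those assumptions (RetryPolicy is hypothetical):

import datetime
from dataclasses import dataclass

from cocoindex.convert import dump_engine_object, load_engine_object


@dataclass
class RetryPolicy:
    # Hypothetical type, for illustration only.
    max_attempts: int
    backoff: datetime.timedelta


original = RetryPolicy(max_attempts=3, backoff=datetime.timedelta(seconds=1.5))

dumped = dump_engine_object(original)
# -> {"max_attempts": 3, "backoff": {"secs": 1, "nanos": 500000000}}

restored = load_engine_object(RetryPolicy, dumped)
assert restored == original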
cocoindex/flow.py CHANGED
@@ -38,9 +38,10 @@ from .convert import (
      make_engine_value_encoder,
  )
  from .op import FunctionSpec
- from .runtime import execution_context
+ from .runtime import execution_context, to_async_call
  from .setup import SetupChangeBundle
- from .typing import analyze_type_info, encode_enriched_type
+ from .typing import analyze_type_info, encode_enriched_type, decode_engine_value_type
+ from .query_handler import QueryHandlerInfo, QueryHandlerResultFields
  from .validation import (
      validate_flow_name,
      validate_full_flow_name,
@@ -694,23 +695,18 @@ class Flow:
      """
 
      _name: str
-     _lazy_engine_flow: Callable[[], _engine.Flow] | None
+     _engine_flow_creator: Callable[[], _engine.Flow]
+
+     _lazy_flow_lock: Lock
+     _lazy_query_handler_args: list[tuple[Any, ...]]
+     _lazy_engine_flow: _engine.Flow | None = None
 
      def __init__(self, name: str, engine_flow_creator: Callable[[], _engine.Flow]):
          validate_flow_name(name)
          self._name = name
-         engine_flow = None
-         lock = Lock()
-
-         def _lazy_engine_flow() -> _engine.Flow:
-             nonlocal engine_flow, lock
-             if engine_flow is None:
-                 with lock:
-                     if engine_flow is None:
-                         engine_flow = engine_flow_creator()
-             return engine_flow
-
-         self._lazy_engine_flow = _lazy_engine_flow
+         self._engine_flow_creator = engine_flow_creator
+         self._lazy_flow_lock = Lock()
+         self._lazy_query_handler_args = []
 
      def _render_spec(self, verbose: bool = False) -> Tree:
          """
@@ -794,15 +790,33 @@ class Flow:
          """
          Get the engine flow.
          """
-         if self._lazy_engine_flow is None:
-             raise RuntimeError(f"Flow {self.full_name} is already removed")
-         return self._lazy_engine_flow()
+         if self._lazy_engine_flow is not None:
+             return self._lazy_engine_flow
+         return self._internal_flow()
 
      async def internal_flow_async(self) -> _engine.Flow:
          """
          Get the engine flow. The async version.
          """
-         return await asyncio.to_thread(self.internal_flow)
+         if self._lazy_engine_flow is not None:
+             return self._lazy_engine_flow
+         return await asyncio.to_thread(self._internal_flow)
+
+     def _internal_flow(self) -> _engine.Flow:
+         """
+         Get the engine flow, creating it on first use (synchronous helper).
+         """
+         with self._lazy_flow_lock:
+             if self._lazy_engine_flow is not None:
+                 return self._lazy_engine_flow
+
+             engine_flow = self._engine_flow_creator()
+             self._lazy_engine_flow = engine_flow
+             for args in self._lazy_query_handler_args:
+                 engine_flow.add_query_handler(*args)
+             self._lazy_query_handler_args = []
+
+             return engine_flow
 
      def setup(self, report_to_stdout: bool = False) -> None:
          """
@@ -847,6 +861,53 @@ class Flow:
          with _flows_lock:
              del _flows[self.name]
 
+     def add_query_handler(
+         self,
+         name: str,
+         handler: Callable[[str], Any],
+         /,
+         *,
+         result_fields: QueryHandlerResultFields | None = None,
+     ) -> None:
+         """
+         Add a query handler to the flow.
+         """
+         async_handler = to_async_call(handler)
+
+         async def _handler(query: str) -> dict[str, Any]:
+             handler_result = await async_handler(query)
+             return {
+                 "results": [
+                     [(k, dump_engine_object(v)) for (k, v) in result.items()]
+                     for result in handler_result.results
+                 ],
+                 "query_info": dump_engine_object(handler_result.query_info),
+             }
+
+         handler_info = dump_engine_object(QueryHandlerInfo(result_fields=result_fields))
+         with self._lazy_flow_lock:
+             if self._lazy_engine_flow is not None:
+                 self._lazy_engine_flow.add_query_handler(name, _handler, handler_info)
+             else:
+                 self._lazy_query_handler_args.append((name, _handler, handler_info))
+
+     def query_handler(
+         self,
+         name: str | None = None,
+         result_fields: QueryHandlerResultFields | None = None,
+     ) -> Callable[[Callable[[str], Any]], Callable[[str], Any]]:
+         """
+         A decorator to declare a query handler.
+         """
+
+         def _inner(handler: Callable[[str], Any]) -> Callable[[str], Any]:
+             self.add_query_handler(
+                 name or handler.__qualname__, handler, result_fields=result_fields
+             )
+             return handler
+
+         return _inner
+
 
  def _create_lazy_flow(
      name: str | None, fl_def: Callable[[FlowBuilder, DataScope], None]
@@ -1103,7 +1164,9 @@ class TransformFlow(Generic[T]):
              inspect.signature(self._flow_fn).return_annotation
          )
          result_decoder = make_engine_value_decoder(
-             [], engine_return_type["type"], analyze_type_info(python_return_type)
+             [],
+             decode_engine_value_type(engine_return_type["type"]),
+             analyze_type_info(python_return_type),
          )
 
          return TransformFlowInfo(engine_flow, result_decoder)
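
Flow.add_query_handler registers the handler on the live engine flow when one exists; otherwise the arguments are queued and replayed by _internal_flow when the flow is first built. A sketch of declaring a handler with the new decorator (the flow body and result fields are illustrative, and QueryInfo is assumed to be constructible with defaults; QueryOutput and QueryInfo are the new exports from cocoindex.query_handler):

import cocoindex
from cocoindex import QueryInfo, QueryOutput


@cocoindex.flow_def(name="DemoFlow")
def demo_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope) -> None:
    ...  # indexing logic omitted


@demo_flow.query_handler(name="search")
def search(query: str) -> QueryOutput:
    # Each result is a mapping; values are serialized via dump_engine_object.
    return QueryOutput(
        results=[{"filename": "a.md", "score": 0.9}],
        query_info=QueryInfo(),
    )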