cocoindex 0.2.10__cp311-abi3-win_amd64.whl → 0.2.12__cp311-abi3-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cocoindex/op.py CHANGED
@@ -2,7 +2,6 @@
2
2
  Facilities for defining cocoindex operations.
3
3
  """
4
4
 
5
- import asyncio
6
5
  import dataclasses
7
6
  import inspect
8
7
  from enum import Enum
@@ -19,6 +18,8 @@ from typing import (
19
18
  from . import _engine # type: ignore
20
19
  from .subprocess_exec import executor_stub
21
20
  from .convert import (
21
+ dump_engine_object,
22
+ load_engine_object,
22
23
  make_engine_value_encoder,
23
24
  make_engine_value_decoder,
24
25
  make_engine_key_decoder,
@@ -31,7 +32,12 @@ from .typing import (
31
32
  analyze_type_info,
32
33
  AnalyzedAnyType,
33
34
  AnalyzedDictType,
35
+ EnrichedValueType,
36
+ decode_engine_field_schemas,
37
+ FieldSchema,
34
38
  )
39
+ from .runtime import to_async_call
40
+ from .index import IndexOptions
35
41
 
36
42
 
37
43
  class OpCategory(Enum):
@@ -86,15 +92,6 @@ class Executor(Protocol):
86
92
  op_category: OpCategory
87
93
 
88
94
 
89
- def _load_spec_from_engine(
90
- spec_loader: Callable[..., Any], spec: dict[str, Any]
91
- ) -> Any:
92
- """
93
- Load a spec from the engine.
94
- """
95
- return spec_loader(**spec)
96
-
97
-
98
95
  def _get_required_method(cls: type, name: str) -> Callable[..., Any]:
99
96
  method = getattr(cls, name, None)
100
97
  if method is None:
@@ -105,7 +102,7 @@ def _get_required_method(cls: type, name: str) -> Callable[..., Any]:
105
102
 
106
103
 
107
104
  class _EngineFunctionExecutorFactory:
108
- _spec_loader: Callable[..., Any]
105
+ _spec_loader: Callable[[Any], Any]
109
106
  _executor_cls: type
110
107
 
111
108
  def __init__(self, spec_loader: Callable[..., Any], executor_cls: type):
@@ -113,9 +110,9 @@ class _EngineFunctionExecutorFactory:
113
110
  self._executor_cls = executor_cls
114
111
 
115
112
  def __call__(
116
- self, spec: dict[str, Any], *args: Any, **kwargs: Any
113
+ self, raw_spec: dict[str, Any], *args: Any, **kwargs: Any
117
114
  ) -> tuple[dict[str, Any], Executor]:
118
- spec = _load_spec_from_engine(self._spec_loader, spec)
115
+ spec = self._spec_loader(raw_spec)
119
116
  executor = self._executor_cls(spec)
120
117
  result_type = executor.analyze_schema(*args, **kwargs)
121
118
  return (result_type, executor)
@@ -150,12 +147,6 @@ class OpArgs:
150
147
  arg_relationship: tuple[ArgRelationship, str] | None = None
151
148
 
152
149
 
153
- def _to_async_call(call: Callable[..., Any]) -> Callable[..., Awaitable[Any]]:
154
- if inspect.iscoroutinefunction(call):
155
- return call
156
- return lambda *args, **kwargs: asyncio.to_thread(lambda: call(*args, **kwargs))
157
-
158
-
159
150
  @dataclasses.dataclass
160
151
  class _ArgInfo:
161
152
  decoder: Callable[[Any], Any]
@@ -218,8 +209,9 @@ def _register_op_factory(
218
209
  TypeAttr(related_attr.value, actual_arg.analyzed_value)
219
210
  )
220
211
  type_info = analyze_type_info(arg_param.annotation)
212
+ enriched = EnrichedValueType.decode(actual_arg.value_type)
221
213
  decoder = make_engine_value_decoder(
222
- [arg_name], actual_arg.value_type["type"], type_info
214
+ [arg_name], enriched.type, type_info
223
215
  )
224
216
  is_required = not type_info.nullable
225
217
  if is_required and actual_arg.value_type.get("nullable", False):
@@ -319,8 +311,8 @@ def _register_op_factory(
319
311
  """
320
312
  prepare_method = getattr(self._executor, "prepare", None)
321
313
  if prepare_method is not None:
322
- await _to_async_call(prepare_method)()
323
- self._acall = _to_async_call(self._executor.__call__)
314
+ await to_async_call(prepare_method)()
315
+ self._acall = to_async_call(self._executor.__call__)
324
316
 
325
317
  async def __call__(self, *args: Any, **kwargs: Any) -> Any:
326
318
  decoded_args = []
@@ -379,7 +371,7 @@ def executor_class(**args: Any) -> Callable[[type], type]:
379
371
  expected_args=list(sig.parameters.items())[1:], # First argument is `self`
380
372
  expected_return=sig.return_annotation,
381
373
  executor_factory=cls,
382
- spec_loader=spec_cls,
374
+ spec_loader=lambda v: load_engine_object(spec_cls, v),
383
375
  op_kind=spec_cls.__name__,
384
376
  op_args=op_args,
385
377
  )
@@ -415,7 +407,7 @@ def function(**args: Any) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
415
407
  expected_args=list(sig.parameters.items()),
416
408
  expected_return=sig.return_annotation,
417
409
  executor_factory=_SimpleFunctionExecutor,
418
- spec_loader=lambda: fn,
410
+ spec_loader=lambda _: fn,
419
411
  op_kind=op_kind,
420
412
  op_args=op_args,
421
413
  )
@@ -435,8 +427,44 @@ class _TargetConnectorContext:
435
427
  target_name: str
436
428
  spec: Any
437
429
  prepared_spec: Any
430
+ key_fields_schema: list[FieldSchema]
438
431
  key_decoder: Callable[[Any], Any]
432
+ value_fields_schema: list[FieldSchema]
439
433
  value_decoder: Callable[[Any], Any]
434
+ index_options: IndexOptions
435
+ setup_state: Any
436
+
437
+
438
+ def _build_args(
439
+ method: Callable[..., Any], num_required_args: int, **kwargs: Any
440
+ ) -> list[Any]:
441
+ signature = inspect.signature(method)
442
+ for param in signature.parameters.values():
443
+ if param.kind not in (
444
+ inspect.Parameter.POSITIONAL_ONLY,
445
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
446
+ ):
447
+ raise ValueError(
448
+ f"Method {method.__name__} should only have positional arguments, got {param.kind.name}"
449
+ )
450
+ if len(signature.parameters) < num_required_args:
451
+ raise ValueError(
452
+ f"Method {method.__name__} must have at least {num_required_args} required arguments: "
453
+ f"{', '.join(list(kwargs.keys())[:num_required_args])}"
454
+ )
455
+ if len(kwargs) > len(kwargs):
456
+ raise ValueError(
457
+ f"Method {method.__name__} can only have at most {num_required_args} arguments: {', '.join(kwargs.keys())}"
458
+ )
459
+ return [v for _, v in zip(signature.parameters, kwargs.values())]
460
+
461
+
462
class TargetStateCompatibility(Enum):
    """
    Outcome of comparing a desired target setup state with an existing one.

    Members are passed through ``dump_engine_object`` when reported back, so
    the string values below are the serialized spellings and must stay stable.
    """

    COMPATIBLE = "Compatible"
    # Note: the value spelling ("PartialCompatible") differs from the member name.
    PARTIALLY_COMPATIBLE = "PartialCompatible"
    NOT_COMPATIBLE = "NotCompatible"
440
468
 
441
469
 
442
470
  class _TargetConnector:
@@ -444,8 +472,10 @@ class _TargetConnector:
444
472
  The connector class passed to the engine.
445
473
  """
446
474
 
447
- _spec_cls: type
448
- _connector_cls: type
475
+ _spec_cls: type[Any]
476
+ _persistent_key_type: Any
477
+ _setup_state_cls: type[Any]
478
+ _connector_cls: type[Any]
449
479
 
450
480
  _get_persistent_key_fn: Callable[[_TargetConnectorContext, str], Any]
451
481
  _apply_setup_change_async_fn: Callable[
@@ -454,19 +484,27 @@ class _TargetConnector:
454
484
  _mutate_async_fn: Callable[..., Awaitable[None]]
455
485
  _mutatation_type: AnalyzedDictType | None
456
486
 
457
- def __init__(self, spec_cls: type, connector_cls: type):
487
+ def __init__(
488
+ self,
489
+ spec_cls: type[Any],
490
+ persistent_key_type: Any,
491
+ setup_state_cls: type[Any],
492
+ connector_cls: type[Any],
493
+ ):
458
494
  self._spec_cls = spec_cls
495
+ self._persistent_key_type = persistent_key_type
496
+ self._setup_state_cls = setup_state_cls
459
497
  self._connector_cls = connector_cls
460
498
 
461
499
  self._get_persistent_key_fn = _get_required_method(
462
500
  connector_cls, "get_persistent_key"
463
501
  )
464
- self._apply_setup_change_async_fn = _to_async_call(
502
+ self._apply_setup_change_async_fn = to_async_call(
465
503
  _get_required_method(connector_cls, "apply_setup_change")
466
504
  )
467
505
 
468
506
  mutate_fn = _get_required_method(connector_cls, "mutate")
469
- self._mutate_async_fn = _to_async_call(mutate_fn)
507
+ self._mutate_async_fn = to_async_call(mutate_fn)
470
508
 
471
509
  # Store the type annotation for later use
472
510
  self._mutatation_type = self._analyze_mutate_mutation_type(
@@ -513,15 +551,16 @@ class _TargetConnector:
513
551
  raise ValueError(
514
552
  f"Method {connector_cls.__name__}.mutate(*args) parameter must be a tuple with "
515
553
  f"2 elements (tuple[SpecType, dict[str, ValueStruct]], spec and mutation in dict), "
516
- "got {args_type}"
554
+ f"got {analyzed_args_type.core_type}"
517
555
  )
518
556
 
519
557
  def create_export_context(
520
558
  self,
521
559
  name: str,
522
- spec: dict[str, Any],
523
- key_fields_schema: list[Any],
524
- value_fields_schema: list[Any],
560
+ raw_spec: dict[str, Any],
561
+ raw_key_fields_schema: list[Any],
562
+ raw_value_fields_schema: list[Any],
563
+ raw_index_options: dict[str, Any],
525
564
  ) -> _TargetConnectorContext:
526
565
  key_annotation, value_annotation = (
527
566
  (
@@ -532,34 +571,103 @@ class _TargetConnector:
532
571
  else (Any, Any)
533
572
  )
534
573
 
574
+ key_fields_schema = decode_engine_field_schemas(raw_key_fields_schema)
535
575
  key_decoder = make_engine_key_decoder(
536
- ["(key)"], key_fields_schema, analyze_type_info(key_annotation)
576
+ ["<key>"], key_fields_schema, analyze_type_info(key_annotation)
537
577
  )
578
+ value_fields_schema = decode_engine_field_schemas(raw_value_fields_schema)
538
579
  value_decoder = make_engine_struct_decoder(
539
- ["(value)"], value_fields_schema, analyze_type_info(value_annotation)
580
+ ["<value>"], value_fields_schema, analyze_type_info(value_annotation)
540
581
  )
541
582
 
542
- loaded_spec = _load_spec_from_engine(self._spec_cls, spec)
543
- prepare_method = getattr(self._connector_cls, "prepare", None)
544
- if prepare_method is None:
545
- prepared_spec = loaded_spec
546
- else:
547
- prepared_spec = prepare_method(loaded_spec)
548
-
583
+ spec = load_engine_object(self._spec_cls, raw_spec)
584
+ index_options = load_engine_object(IndexOptions, raw_index_options)
549
585
  return _TargetConnectorContext(
550
586
  target_name=name,
551
- spec=loaded_spec,
552
- prepared_spec=prepared_spec,
587
+ spec=spec,
588
+ prepared_spec=None,
589
+ key_fields_schema=key_fields_schema,
553
590
  key_decoder=key_decoder,
591
+ value_fields_schema=value_fields_schema,
554
592
  value_decoder=value_decoder,
593
+ index_options=index_options,
594
+ setup_state=None,
555
595
  )
556
596
 
557
597
  def get_persistent_key(self, export_context: _TargetConnectorContext) -> Any:
558
- return self._get_persistent_key_fn(
559
- export_context.spec, export_context.target_name
598
+ args = _build_args(
599
+ self._get_persistent_key_fn,
600
+ 1,
601
+ spec=export_context.spec,
602
+ target_name=export_context.target_name,
603
+ )
604
+ return dump_engine_object(self._get_persistent_key_fn(*args))
605
+
606
+ def get_setup_state(self, export_context: _TargetConnectorContext) -> Any:
607
+ get_setup_state_fn = getattr(self._connector_cls, "get_setup_state", None)
608
+ if get_setup_state_fn is None:
609
+ state = export_context.spec
610
+ if not isinstance(state, self._setup_state_cls):
611
+ raise ValueError(
612
+ f"Expect a get_setup_state() method for {self._connector_cls} that returns an instance of {self._setup_state_cls}"
613
+ )
614
+ else:
615
+ args = _build_args(
616
+ get_setup_state_fn,
617
+ 1,
618
+ spec=export_context.spec,
619
+ key_fields_schema=export_context.key_fields_schema,
620
+ value_fields_schema=export_context.value_fields_schema,
621
+ index_options=export_context.index_options,
622
+ )
623
+ state = get_setup_state_fn(*args)
624
+ if not isinstance(state, self._setup_state_cls):
625
+ raise ValueError(
626
+ f"Method {get_setup_state_fn.__name__} must return an instance of {self._setup_state_cls}, got {type(state)}"
627
+ )
628
+ export_context.setup_state = state
629
+ return dump_engine_object(state)
630
+
631
+ def check_state_compatibility(
632
+ self, raw_desired_state: Any, raw_existing_state: Any
633
+ ) -> Any:
634
+ check_state_compatibility_fn = getattr(
635
+ self._connector_cls, "check_state_compatibility", None
560
636
  )
637
+ if check_state_compatibility_fn is not None:
638
+ compatibility = check_state_compatibility_fn(
639
+ load_engine_object(self._setup_state_cls, raw_desired_state),
640
+ load_engine_object(self._setup_state_cls, raw_existing_state),
641
+ )
642
+ else:
643
+ compatibility = (
644
+ TargetStateCompatibility.COMPATIBLE
645
+ if raw_desired_state == raw_existing_state
646
+ else TargetStateCompatibility.PARTIALLY_COMPATIBLE
647
+ )
648
+ return dump_engine_object(compatibility)
561
649
 
562
- def describe_resource(self, key: Any) -> str:
650
+ async def prepare_async(
651
+ self,
652
+ export_context: _TargetConnectorContext,
653
+ ) -> None:
654
+ prepare_fn = getattr(self._connector_cls, "prepare", None)
655
+ if prepare_fn is None:
656
+ export_context.prepared_spec = export_context.spec
657
+ return
658
+ args = _build_args(
659
+ prepare_fn,
660
+ 1,
661
+ spec=export_context.spec,
662
+ setup_state=export_context.setup_state,
663
+ key_fields_schema=export_context.key_fields_schema,
664
+ value_fields_schema=export_context.value_fields_schema,
665
+ )
666
+ async_prepare_fn = to_async_call(prepare_fn)
667
+ export_context.prepared_spec = await async_prepare_fn(*args)
668
+
669
+ def describe_resource(self, raw_key: Any) -> str:
670
+ key = load_engine_object(self._persistent_key_type, raw_key)
563
671
  describe_fn = getattr(self._connector_cls, "describe", None)
564
672
  if describe_fn is None:
565
673
  return str(key)
@@ -569,15 +677,16 @@ class _TargetConnector:
569
677
  self,
570
678
  changes: list[tuple[Any, list[dict[str, Any] | None], dict[str, Any] | None]],
571
679
  ) -> None:
572
- for key, previous, current in changes:
680
+ for raw_key, previous, current in changes:
681
+ key = load_engine_object(self._persistent_key_type, raw_key)
573
682
  prev_specs = [
574
- _load_spec_from_engine(self._spec_cls, spec)
683
+ load_engine_object(self._setup_state_cls, spec)
575
684
  if spec is not None
576
685
  else None
577
686
  for spec in previous
578
687
  ]
579
688
  curr_spec = (
580
- _load_spec_from_engine(self._spec_cls, current)
689
+ load_engine_object(self._setup_state_cls, current)
581
690
  if current is not None
582
691
  else None
583
692
  )
@@ -610,7 +719,12 @@ class _TargetConnector:
610
719
  )
611
720
 
612
721
 
613
- def target_connector(spec_cls: type) -> Callable[[type], type]:
722
+ def target_connector(
723
+ *,
724
+ spec_cls: type[Any],
725
+ persistent_key_type: Any = Any,
726
+ setup_state_cls: type[Any] | None = None,
727
+ ) -> Callable[[type], type]:
614
728
  """
615
729
  Decorate a class to provide a target connector for an op.
616
730
  """
@@ -621,7 +735,9 @@ def target_connector(spec_cls: type) -> Callable[[type], type]:
621
735
 
622
736
  # Register the target connector.
623
737
  def _inner(connector_cls: type) -> type:
624
- connector = _TargetConnector(spec_cls, connector_cls)
738
+ connector = _TargetConnector(
739
+ spec_cls, persistent_key_type, setup_state_cls or spec_cls, connector_cls
740
+ )
625
741
  _engine.register_target_connector(spec_cls.__name__, connector)
626
742
  return connector_cls
627
743
 
@@ -0,0 +1,51 @@
1
+ import dataclasses
2
+ import numpy as np
3
+ from numpy import typing as npt
4
+ from typing import Generic, TypeVar
5
+ from .index import VectorSimilarityMetric
6
+
7
+
8
+ @dataclasses.dataclass
9
+ class QueryHandlerResultFields:
10
+ """
11
+ Specify field names in query results returned by the query handler.
12
+ This provides metadata for tools like CocoInsight to recognize structure of the query results.
13
+ """
14
+
15
+ embedding: list[str] = dataclasses.field(default_factory=list)
16
+ score: str | None = None
17
+
18
+
19
@dataclasses.dataclass
class QueryHandlerInfo:
    """
    Configuration for a query handler.
    """

    # Optional description of the field names in the handler's results.
    result_fields: QueryHandlerResultFields | None = None
26
+
27
+
28
@dataclasses.dataclass
class QueryInfo:
    """
    Details describing a query that was executed.
    """

    # The query embedding, as a plain float list or a float32 NumPy array; None if absent.
    embedding: list[float] | npt.NDArray[np.float32] | None = None
    # Similarity metric associated with the query, if known.
    similarity_metric: VectorSimilarityMetric | None = None
36
+
37
+
38
# Type of each individual result carried by a QueryOutput.
R = TypeVar("R")


@dataclasses.dataclass
class QueryOutput(Generic[R]):
    """
    Value returned by a query handler.

    results: the query results; each one may be a dict or a dataclass.
    query_info: details about the query (a default QueryInfo if omitted).
    """

    results: list[R]
    # default_factory gives every output its own QueryInfo instance.
    query_info: QueryInfo = dataclasses.field(default_factory=QueryInfo)
cocoindex/runtime.py CHANGED
@@ -5,7 +5,8 @@ manner.
5
5
 
6
6
  import threading
7
7
  import asyncio
8
- from typing import Any, Coroutine, TypeVar
8
+ import inspect
9
+ from typing import Any, Callable, Coroutine, TypeVar, Awaitable
9
10
 
10
11
 
11
12
  T = TypeVar("T")
@@ -35,3 +36,9 @@ class _ExecutionContext:
35
36
 
36
37
 
37
38
  execution_context = _ExecutionContext()
39
+
40
+
41
def to_async_call(call: Callable[..., Any]) -> Callable[..., Awaitable[Any]]:
    """
    Adapt ``call`` so that invoking it always produces an awaitable.

    Callables that are already asynchronous are returned unchanged:
    - coroutine functions, including bound ``async def`` methods;
    - objects whose class defines ``async def __call__``.

    Any other callable is wrapped. Each invocation of the wrapper runs it in a
    worker thread via ``asyncio.to_thread``, so blocking work does not stall
    the event loop.
    """
    # Checking the *type's* __call__ (not the instance's) catches callable
    # objects with an async __call__. It also avoids treating a class that
    # defines async __call__ as async itself; calling the class just
    # constructs an instance.
    if inspect.iscoroutinefunction(call) or inspect.iscoroutinefunction(
        getattr(type(call), "__call__", None)
    ):
        return call

    def _run_in_thread(*args: Any, **kwargs: Any) -> Awaitable[Any]:
        # asyncio.to_thread forwards args/kwargs and propagates contextvars.
        return asyncio.to_thread(call, *args, **kwargs)

    return _run_in_thread
@@ -0,0 +1,5 @@
1
+ """
2
+ Targets supported by CocoIndex.
3
+ """
4
+
5
+ from ._engine_builtin_specs import *
@@ -3,10 +3,10 @@
3
3
  from dataclasses import dataclass
4
4
  from typing import Sequence
5
5
 
6
- from . import op
7
- from . import index
8
- from .auth_registry import AuthEntryReference
9
- from .setting import DatabaseConnectionSpec
6
+ from .. import op
7
+ from .. import index
8
+ from ..auth_registry import AuthEntryReference
9
+ from ..setting import DatabaseConnectionSpec
10
10
 
11
11
 
12
12
  class Postgres(op.TargetSpec):