cocoindex 0.2.11__cp311-abi3-macosx_11_0_arm64.whl → 0.2.13__cp311-abi3-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cocoindex/op.py CHANGED
@@ -18,6 +18,8 @@ from typing import (
18
18
  from . import _engine # type: ignore
19
19
  from .subprocess_exec import executor_stub
20
20
  from .convert import (
21
+ dump_engine_object,
22
+ load_engine_object,
21
23
  make_engine_value_encoder,
22
24
  make_engine_value_decoder,
23
25
  make_engine_key_decoder,
@@ -30,8 +32,12 @@ from .typing import (
30
32
  analyze_type_info,
31
33
  AnalyzedAnyType,
32
34
  AnalyzedDictType,
35
+ EnrichedValueType,
36
+ decode_engine_field_schemas,
37
+ FieldSchema,
33
38
  )
34
39
  from .runtime import to_async_call
40
+ from .index import IndexOptions
35
41
 
36
42
 
37
43
  class OpCategory(Enum):
@@ -86,15 +92,6 @@ class Executor(Protocol):
86
92
  op_category: OpCategory
87
93
 
88
94
 
89
- def _load_spec_from_engine(
90
- spec_loader: Callable[..., Any], spec: dict[str, Any]
91
- ) -> Any:
92
- """
93
- Load a spec from the engine.
94
- """
95
- return spec_loader(**spec)
96
-
97
-
98
95
  def _get_required_method(cls: type, name: str) -> Callable[..., Any]:
99
96
  method = getattr(cls, name, None)
100
97
  if method is None:
@@ -105,7 +102,7 @@ def _get_required_method(cls: type, name: str) -> Callable[..., Any]:
105
102
 
106
103
 
107
104
  class _EngineFunctionExecutorFactory:
108
- _spec_loader: Callable[..., Any]
105
+ _spec_loader: Callable[[Any], Any]
109
106
  _executor_cls: type
110
107
 
111
108
  def __init__(self, spec_loader: Callable[..., Any], executor_cls: type):
@@ -113,9 +110,9 @@ class _EngineFunctionExecutorFactory:
113
110
  self._executor_cls = executor_cls
114
111
 
115
112
  def __call__(
116
- self, spec: dict[str, Any], *args: Any, **kwargs: Any
113
+ self, raw_spec: dict[str, Any], *args: Any, **kwargs: Any
117
114
  ) -> tuple[dict[str, Any], Executor]:
118
- spec = _load_spec_from_engine(self._spec_loader, spec)
115
+ spec = self._spec_loader(raw_spec)
119
116
  executor = self._executor_cls(spec)
120
117
  result_type = executor.analyze_schema(*args, **kwargs)
121
118
  return (result_type, executor)
@@ -212,8 +209,9 @@ def _register_op_factory(
212
209
  TypeAttr(related_attr.value, actual_arg.analyzed_value)
213
210
  )
214
211
  type_info = analyze_type_info(arg_param.annotation)
212
+ enriched = EnrichedValueType.decode(actual_arg.value_type)
215
213
  decoder = make_engine_value_decoder(
216
- [arg_name], actual_arg.value_type["type"], type_info
214
+ [arg_name], enriched.type, type_info
217
215
  )
218
216
  is_required = not type_info.nullable
219
217
  if is_required and actual_arg.value_type.get("nullable", False):
@@ -373,7 +371,7 @@ def executor_class(**args: Any) -> Callable[[type], type]:
373
371
  expected_args=list(sig.parameters.items())[1:], # First argument is `self`
374
372
  expected_return=sig.return_annotation,
375
373
  executor_factory=cls,
376
- spec_loader=spec_cls,
374
+ spec_loader=lambda v: load_engine_object(spec_cls, v),
377
375
  op_kind=spec_cls.__name__,
378
376
  op_args=op_args,
379
377
  )
@@ -409,7 +407,7 @@ def function(**args: Any) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
409
407
  expected_args=list(sig.parameters.items()),
410
408
  expected_return=sig.return_annotation,
411
409
  executor_factory=_SimpleFunctionExecutor,
412
- spec_loader=lambda: fn,
410
+ spec_loader=lambda _: fn,
413
411
  op_kind=op_kind,
414
412
  op_args=op_args,
415
413
  )
@@ -429,8 +427,44 @@ class _TargetConnectorContext:
429
427
  target_name: str
430
428
  spec: Any
431
429
  prepared_spec: Any
430
+ key_fields_schema: list[FieldSchema]
432
431
  key_decoder: Callable[[Any], Any]
432
+ value_fields_schema: list[FieldSchema]
433
433
  value_decoder: Callable[[Any], Any]
434
+ index_options: IndexOptions
435
+ setup_state: Any
436
+
437
+
438
def _build_args(
    method: Callable[..., Any], num_required_args: int, **kwargs: Any
) -> list[Any]:
    """
    Build the positional argument list for calling a user-provided method.

    `kwargs` holds every value the caller can supply, in the order the method is
    expected to declare its parameters. The method may declare any prefix of these
    values, but at least `num_required_args` of them. Values are matched to
    parameters by position, not by name.

    Args:
        method: Callable whose signature decides how many values are passed.
        num_required_args: Minimum number of parameters `method` must declare.
        **kwargs: Candidate argument values, in positional order.

    Returns:
        The first `n` values of `kwargs`, where `n` is the number of parameters
        `method` declares.

    Raises:
        ValueError: If `method` declares a parameter that is not positional
            (`*args`, keyword-only, `**kwargs`), declares fewer than
            `num_required_args` parameters, or declares more parameters than
            there are values to supply.
    """
    # Callables such as functools.partial have no __name__. Fall back to repr() so
    # the intended ValueError is raised instead of an AttributeError.
    method_name = getattr(method, "__name__", repr(method))
    parameters = inspect.signature(method).parameters
    for param in parameters.values():
        if param.kind not in (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        ):
            raise ValueError(
                f"Method {method_name} should only have positional arguments, got {param.kind.name}"
            )
    if len(parameters) < num_required_args:
        raise ValueError(
            f"Method {method_name} must have at least {num_required_args} required arguments: "
            f"{', '.join(list(kwargs.keys())[:num_required_args])}"
        )
    # Reject methods that expect more values than can be supplied. Without this
    # check, zip() below would silently drop the trailing parameters, and the call
    # would later fail with a less helpful TypeError.
    if len(parameters) > len(kwargs):
        raise ValueError(
            f"Method {method_name} can only have at most {len(kwargs)} arguments: {', '.join(kwargs.keys())}"
        )
    # zip() stops at the shorter sequence, so exactly as many values as the method
    # declares are passed.
    return [value for _, value in zip(parameters, kwargs.values())]
460
+
461
+
462
class TargetStateCompatibility(Enum):
    """
    Outcome of comparing a desired target setup state with an existing one.

    Member order and string values are part of the contract. The value strings
    are presumably what the engine receives once the member is serialized;
    confirm against the engine side before changing them.
    """

    # The existing state already matches the desired state.
    COMPATIBLE = "Compatible"
    # The states differ, but not necessarily in a way that rules out reuse.
    # NOTE(review): exact engine-side handling is not visible here.
    PARTIALLY_COMPATIBLE = "PartialCompatible"
    # The existing state cannot be reused for the desired one.
    NOT_COMPATIBLE = "NotCompatible"
434
468
 
435
469
 
436
470
  class _TargetConnector:
@@ -438,8 +472,10 @@ class _TargetConnector:
438
472
  The connector class passed to the engine.
439
473
  """
440
474
 
441
- _spec_cls: type
442
- _connector_cls: type
475
+ _spec_cls: type[Any]
476
+ _persistent_key_type: Any
477
+ _setup_state_cls: type[Any]
478
+ _connector_cls: type[Any]
443
479
 
444
480
  _get_persistent_key_fn: Callable[[_TargetConnectorContext, str], Any]
445
481
  _apply_setup_change_async_fn: Callable[
@@ -448,8 +484,16 @@ class _TargetConnector:
448
484
  _mutate_async_fn: Callable[..., Awaitable[None]]
449
485
  _mutatation_type: AnalyzedDictType | None
450
486
 
451
- def __init__(self, spec_cls: type, connector_cls: type):
487
+ def __init__(
488
+ self,
489
+ spec_cls: type[Any],
490
+ persistent_key_type: Any,
491
+ setup_state_cls: type[Any],
492
+ connector_cls: type[Any],
493
+ ):
452
494
  self._spec_cls = spec_cls
495
+ self._persistent_key_type = persistent_key_type
496
+ self._setup_state_cls = setup_state_cls
453
497
  self._connector_cls = connector_cls
454
498
 
455
499
  self._get_persistent_key_fn = _get_required_method(
@@ -507,15 +551,16 @@ class _TargetConnector:
507
551
  raise ValueError(
508
552
  f"Method {connector_cls.__name__}.mutate(*args) parameter must be a tuple with "
509
553
  f"2 elements (tuple[SpecType, dict[str, ValueStruct]], spec and mutation in dict), "
510
- "got {args_type}"
554
+ f"got {analyzed_args_type.core_type}"
511
555
  )
512
556
 
513
557
  def create_export_context(
514
558
  self,
515
559
  name: str,
516
- spec: dict[str, Any],
517
- key_fields_schema: list[Any],
518
- value_fields_schema: list[Any],
560
+ raw_spec: dict[str, Any],
561
+ raw_key_fields_schema: list[Any],
562
+ raw_value_fields_schema: list[Any],
563
+ raw_index_options: dict[str, Any],
519
564
  ) -> _TargetConnectorContext:
520
565
  key_annotation, value_annotation = (
521
566
  (
@@ -526,34 +571,103 @@ class _TargetConnector:
526
571
  else (Any, Any)
527
572
  )
528
573
 
574
+ key_fields_schema = decode_engine_field_schemas(raw_key_fields_schema)
529
575
  key_decoder = make_engine_key_decoder(
530
- ["(key)"], key_fields_schema, analyze_type_info(key_annotation)
576
+ ["<key>"], key_fields_schema, analyze_type_info(key_annotation)
531
577
  )
578
+ value_fields_schema = decode_engine_field_schemas(raw_value_fields_schema)
532
579
  value_decoder = make_engine_struct_decoder(
533
- ["(value)"], value_fields_schema, analyze_type_info(value_annotation)
580
+ ["<value>"], value_fields_schema, analyze_type_info(value_annotation)
534
581
  )
535
582
 
536
- loaded_spec = _load_spec_from_engine(self._spec_cls, spec)
537
- prepare_method = getattr(self._connector_cls, "prepare", None)
538
- if prepare_method is None:
539
- prepared_spec = loaded_spec
540
- else:
541
- prepared_spec = prepare_method(loaded_spec)
542
-
583
+ spec = load_engine_object(self._spec_cls, raw_spec)
584
+ index_options = load_engine_object(IndexOptions, raw_index_options)
543
585
  return _TargetConnectorContext(
544
586
  target_name=name,
545
- spec=loaded_spec,
546
- prepared_spec=prepared_spec,
587
+ spec=spec,
588
+ prepared_spec=None,
589
+ key_fields_schema=key_fields_schema,
547
590
  key_decoder=key_decoder,
591
+ value_fields_schema=value_fields_schema,
548
592
  value_decoder=value_decoder,
593
+ index_options=index_options,
594
+ setup_state=None,
549
595
  )
550
596
 
551
597
  def get_persistent_key(self, export_context: _TargetConnectorContext) -> Any:
552
- return self._get_persistent_key_fn(
553
- export_context.spec, export_context.target_name
598
+ args = _build_args(
599
+ self._get_persistent_key_fn,
600
+ 1,
601
+ spec=export_context.spec,
602
+ target_name=export_context.target_name,
603
+ )
604
+ return dump_engine_object(self._get_persistent_key_fn(*args))
605
+
606
def get_setup_state(self, export_context: _TargetConnectorContext) -> Any:
    """
    Compute the setup state for an export context and return it in engine form.

    If the connector class defines `get_setup_state`, it is called with the first
    N of (spec, key_fields_schema, value_fields_schema, index_options), where N is
    the number of positional parameters it declares (at least 1). Otherwise, the
    loaded spec itself is used as the setup state.

    The resolved state is stored on `export_context.setup_state` before it is
    serialized with `dump_engine_object`.

    Raises:
        ValueError: If the resolved state is not an instance of
            `self._setup_state_cls`, or if the custom method's signature is
            rejected by `_build_args`.
    """
    fn = getattr(self._connector_cls, "get_setup_state", None)
    if fn is not None:
        call_args = _build_args(
            fn,
            1,
            spec=export_context.spec,
            key_fields_schema=export_context.key_fields_schema,
            value_fields_schema=export_context.value_fields_schema,
            index_options=export_context.index_options,
        )
        state = fn(*call_args)
        if not isinstance(state, self._setup_state_cls):
            raise ValueError(
                f"Method {fn.__name__} must return an instance of {self._setup_state_cls}, got {type(state)}"
            )
    else:
        # Without a custom method, the spec doubles as the setup state. This only
        # works when the spec already has the setup-state type.
        state = export_context.spec
        if not isinstance(state, self._setup_state_cls):
            raise ValueError(
                f"Expect a get_setup_state() method for {self._connector_cls} that returns an instance of {self._setup_state_cls}"
            )
    export_context.setup_state = state
    return dump_engine_object(state)
630
+
631
def check_state_compatibility(
    self, raw_desired_state: Any, raw_existing_state: Any
) -> Any:
    """
    Decide how compatible an existing setup state is with the desired one.

    If the connector class defines `check_state_compatibility`, both raw states
    are loaded into `self._setup_state_cls` (desired first, then existing) and
    passed to it. Otherwise, the raw states are compared with `==`: equal states
    are COMPATIBLE, and differing ones are PARTIALLY_COMPATIBLE.

    Returns:
        The compatibility value, serialized with `dump_engine_object`.
    """
    custom_check = getattr(self._connector_cls, "check_state_compatibility", None)
    if custom_check is None:
        if raw_desired_state == raw_existing_state:
            verdict = TargetStateCompatibility.COMPATIBLE
        else:
            verdict = TargetStateCompatibility.PARTIALLY_COMPATIBLE
        return dump_engine_object(verdict)

    desired_state = load_engine_object(self._setup_state_cls, raw_desired_state)
    existing_state = load_engine_object(self._setup_state_cls, raw_existing_state)
    return dump_engine_object(custom_check(desired_state, existing_state))
649
+
650
async def prepare_async(
    self,
    export_context: _TargetConnectorContext,
) -> None:
    """
    Fill in `export_context.prepared_spec`.

    If the connector class defines `prepare`, it is called, after being wrapped
    by `to_async_call` and awaited, with the first N of (spec, setup_state,
    key_fields_schema, value_fields_schema). N is the number of positional
    parameters it declares (at least 1), and its result becomes the prepared
    spec. Without a `prepare` method, the loaded spec is used unchanged.
    """
    prepare_fn = getattr(self._connector_cls, "prepare", None)
    if prepare_fn is not None:
        call_args = _build_args(
            prepare_fn,
            1,
            spec=export_context.spec,
            setup_state=export_context.setup_state,
            key_fields_schema=export_context.key_fields_schema,
            value_fields_schema=export_context.value_fields_schema,
        )
        export_context.prepared_spec = await to_async_call(prepare_fn)(*call_args)
    else:
        export_context.prepared_spec = export_context.spec
555
668
 
556
- def describe_resource(self, key: Any) -> str:
669
+ def describe_resource(self, raw_key: Any) -> str:
670
+ key = load_engine_object(self._persistent_key_type, raw_key)
557
671
  describe_fn = getattr(self._connector_cls, "describe", None)
558
672
  if describe_fn is None:
559
673
  return str(key)
@@ -563,15 +677,16 @@ class _TargetConnector:
563
677
  self,
564
678
  changes: list[tuple[Any, list[dict[str, Any] | None], dict[str, Any] | None]],
565
679
  ) -> None:
566
- for key, previous, current in changes:
680
+ for raw_key, previous, current in changes:
681
+ key = load_engine_object(self._persistent_key_type, raw_key)
567
682
  prev_specs = [
568
- _load_spec_from_engine(self._spec_cls, spec)
683
+ load_engine_object(self._setup_state_cls, spec)
569
684
  if spec is not None
570
685
  else None
571
686
  for spec in previous
572
687
  ]
573
688
  curr_spec = (
574
- _load_spec_from_engine(self._spec_cls, current)
689
+ load_engine_object(self._setup_state_cls, current)
575
690
  if current is not None
576
691
  else None
577
692
  )
@@ -604,7 +719,12 @@ class _TargetConnector:
604
719
  )
605
720
 
606
721
 
607
- def target_connector(spec_cls: type) -> Callable[[type], type]:
722
+ def target_connector(
723
+ *,
724
+ spec_cls: type[Any],
725
+ persistent_key_type: Any = Any,
726
+ setup_state_cls: type[Any] | None = None,
727
+ ) -> Callable[[type], type]:
608
728
  """
609
729
  Decorate a class to provide a target connector for an op.
610
730
  """
@@ -615,7 +735,9 @@ def target_connector(spec_cls: type) -> Callable[[type], type]:
615
735
 
616
736
  # Register the target connector.
617
737
  def _inner(connector_cls: type) -> type:
618
- connector = _TargetConnector(spec_cls, connector_cls)
738
+ connector = _TargetConnector(
739
+ spec_cls, persistent_key_type, setup_state_cls or spec_cls, connector_cls
740
+ )
619
741
  _engine.register_target_connector(spec_cls.__name__, connector)
620
742
  return connector_cls
621
743
 
@@ -1,8 +1,14 @@
1
1
  import dataclasses
2
2
  import numpy as np
3
3
  from numpy import typing as npt
4
- from typing import Generic, TypeVar
4
+ from typing import Generic, Any
5
5
  from .index import VectorSimilarityMetric
6
+ import sys
7
+
8
+ if sys.version_info >= (3, 13):
9
+ from typing import TypeVar
10
+ else:
11
+ from typing_extensions import TypeVar # PEP 696 backport
6
12
 
7
13
 
8
14
  @dataclasses.dataclass
@@ -35,7 +41,7 @@ class QueryInfo:
35
41
  similarity_metric: VectorSimilarityMetric | None = None
36
42
 
37
43
 
38
- R = TypeVar("R")
44
+ R = TypeVar("R", default=Any)
39
45
 
40
46
 
41
47
  @dataclasses.dataclass
@@ -0,0 +1,5 @@
1
+ """
2
+ Targets supported by CocoIndex.
3
+ """
4
+
5
+ from ._engine_builtin_specs import *
@@ -3,10 +3,10 @@
3
3
  from dataclasses import dataclass
4
4
  from typing import Sequence
5
5
 
6
- from . import op
7
- from . import index
8
- from .auth_registry import AuthEntryReference
9
- from .setting import DatabaseConnectionSpec
6
+ from .. import op
7
+ from .. import index
8
+ from ..auth_registry import AuthEntryReference
9
+ from ..setting import DatabaseConnectionSpec
10
10
 
11
11
 
12
12
  class Postgres(op.TargetSpec):