cocoindex 0.1.73__cp313-cp313-manylinux_2_28_aarch64.whl → 0.1.75__cp313-cp313-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cocoindex/op.py CHANGED
@@ -6,11 +6,31 @@ import asyncio
6
6
  import dataclasses
7
7
  import inspect
8
8
  from enum import Enum
9
- from typing import Any, Awaitable, Callable, Protocol, dataclass_transform, Annotated
9
+ from typing import (
10
+ Any,
11
+ Awaitable,
12
+ Callable,
13
+ Protocol,
14
+ dataclass_transform,
15
+ Annotated,
16
+ get_args,
17
+ )
10
18
 
11
19
  from . import _engine # type: ignore
12
- from .convert import encode_engine_value, make_engine_value_decoder
13
- from .typing import TypeAttr, encode_enriched_type, resolve_forward_ref
20
+ from .convert import (
21
+ encode_engine_value,
22
+ make_engine_value_decoder,
23
+ make_engine_struct_decoder,
24
+ )
25
+ from .typing import (
26
+ TypeAttr,
27
+ encode_enriched_type,
28
+ resolve_forward_ref,
29
+ analyze_type_info,
30
+ AnalyzedAnyType,
31
+ AnalyzedBasicType,
32
+ AnalyzedDictType,
33
+ )
14
34
 
15
35
 
16
36
  class OpCategory(Enum):
@@ -65,6 +85,22 @@ class Executor(Protocol):
65
85
  op_category: OpCategory
66
86
 
67
87
 
88
+ def _load_spec_from_engine(spec_cls: type, spec: dict[str, Any]) -> Any:
89
+ """
90
+ Load a spec from the engine.
91
+ """
92
+ return spec_cls(**spec)
93
+
94
+
95
+ def _get_required_method(cls: type, name: str) -> Callable[..., Any]:
96
+ method = getattr(cls, name, None)
97
+ if method is None:
98
+ raise ValueError(f"Method {name}() is required for {cls.__name__}")
99
+ if not inspect.isfunction(method):
100
+ raise ValueError(f"Method {cls.__name__}.{name}() is not a function")
101
+ return method
102
+
103
+
68
104
  class _FunctionExecutorFactory:
69
105
  _spec_cls: type
70
106
  _executor_cls: type
@@ -76,7 +112,7 @@ class _FunctionExecutorFactory:
76
112
  def __call__(
77
113
  self, spec: dict[str, Any], *args: Any, **kwargs: Any
78
114
  ) -> tuple[dict[str, Any], Executor]:
79
- spec = self._spec_cls(**spec)
115
+ spec = _load_spec_from_engine(self._spec_cls, spec)
80
116
  executor = self._executor_cls(spec)
81
117
  result_type = executor.analyze(*args, **kwargs)
82
118
  return (encode_enriched_type(result_type), executor)
@@ -185,7 +221,9 @@ def _register_op_factory(
185
221
  )
186
222
  self._args_decoders.append(
187
223
  make_engine_value_decoder(
188
- [arg_name], arg.value_type["type"], arg_param.annotation
224
+ [arg_name],
225
+ arg.value_type["type"],
226
+ analyze_type_info(arg_param.annotation),
189
227
  )
190
228
  )
191
229
  process_attribute(arg_name, arg)
@@ -217,7 +255,9 @@ def _register_op_factory(
217
255
  )
218
256
  arg_param = expected_arg[1]
219
257
  self._kwargs_decoders[kwarg_name] = make_engine_value_decoder(
220
- [kwarg_name], kwarg.value_type["type"], arg_param.annotation
258
+ [kwarg_name],
259
+ kwarg.value_type["type"],
260
+ analyze_type_info(arg_param.annotation),
221
261
  )
222
262
  process_attribute(kwarg_name, kwarg)
223
263
 
@@ -359,3 +399,221 @@ def function(**args: Any) -> Callable[[Callable[..., Any]], FunctionSpec]:
359
399
  return _Spec()
360
400
 
361
401
  return _inner
402
+
403
+
404
+ ########################################################
405
+ # Custom target connector
406
+ ########################################################
407
+
408
+
409
+ @dataclasses.dataclass
410
+ class _TargetConnectorContext:
411
+ target_name: str
412
+ spec: Any
413
+ prepared_spec: Any
414
+ key_decoder: Callable[[Any], Any]
415
+ value_decoder: Callable[[Any], Any]
416
+
417
+
418
+ class _TargetConnector:
419
+ """
420
+ The connector class passed to the engine.
421
+ """
422
+
423
+ _spec_cls: type
424
+ _connector_cls: type
425
+
426
+ _get_persistent_key_fn: Callable[[_TargetConnectorContext, str], Any]
427
+ _apply_setup_change_async_fn: Callable[
428
+ [Any, dict[str, Any] | None, dict[str, Any] | None], Awaitable[None]
429
+ ]
430
+ _mutate_async_fn: Callable[..., Awaitable[None]]
431
+ _mutatation_type: AnalyzedDictType | None
432
+
433
+ def __init__(self, spec_cls: type, connector_cls: type):
434
+ self._spec_cls = spec_cls
435
+ self._connector_cls = connector_cls
436
+
437
+ self._get_persistent_key_fn = _get_required_method(
438
+ connector_cls, "get_persistent_key"
439
+ )
440
+ self._apply_setup_change_async_fn = _to_async_call(
441
+ _get_required_method(connector_cls, "apply_setup_change")
442
+ )
443
+
444
+ mutate_fn = _get_required_method(connector_cls, "mutate")
445
+ self._mutate_async_fn = _to_async_call(mutate_fn)
446
+
447
+ # Store the type annotation for later use
448
+ self._mutatation_type = self._analyze_mutate_mutation_type(
449
+ connector_cls, mutate_fn
450
+ )
451
+
452
+ @staticmethod
453
+ def _analyze_mutate_mutation_type(
454
+ connector_cls: type, mutate_fn: Callable[..., Any]
455
+ ) -> AnalyzedDictType | None:
456
+ # Validate mutate_fn signature and extract type annotation
457
+ mutate_sig = inspect.signature(mutate_fn)
458
+ params = list(mutate_sig.parameters.values())
459
+
460
+ if len(params) != 1:
461
+ raise ValueError(
462
+ f"Method {connector_cls.__name__}.mutate(*args) must have exactly one parameter, "
463
+ f"got {len(params)}"
464
+ )
465
+
466
+ param = params[0]
467
+ if param.kind != inspect.Parameter.VAR_POSITIONAL:
468
+ raise ValueError(
469
+ f"Method {connector_cls.__name__}.mutate(*args) parameter must be *args format, "
470
+ f"got {param.kind.name}"
471
+ )
472
+
473
+ # Extract type annotation
474
+ analyzed_args_type = analyze_type_info(param.annotation)
475
+ if isinstance(analyzed_args_type.variant, AnalyzedAnyType):
476
+ return None
477
+
478
+ if analyzed_args_type.base_type is tuple:
479
+ args = get_args(analyzed_args_type.core_type)
480
+ if not args:
481
+ return None
482
+ if len(args) == 2:
483
+ mutation_type = analyze_type_info(args[1])
484
+ if isinstance(mutation_type.variant, AnalyzedAnyType):
485
+ return None
486
+ if isinstance(mutation_type.variant, AnalyzedDictType):
487
+ return mutation_type.variant
488
+
489
+ raise ValueError(
490
+ f"Method {connector_cls.__name__}.mutate(*args) parameter must be a tuple with "
491
+ f"2 elements (tuple[SpecType, dict[str, ValueStruct]], spec and mutation in dict), "
492
+ "got {args_type}"
493
+ )
494
+
495
+ def create_export_context(
496
+ self,
497
+ name: str,
498
+ spec: dict[str, Any],
499
+ key_fields_schema: list[Any],
500
+ value_fields_schema: list[Any],
501
+ ) -> _TargetConnectorContext:
502
+ key_annotation, value_annotation = (
503
+ (
504
+ self._mutatation_type.key_type,
505
+ self._mutatation_type.value_type,
506
+ )
507
+ if self._mutatation_type is not None
508
+ else (Any, Any)
509
+ )
510
+
511
+ key_type_info = analyze_type_info(key_annotation)
512
+ if (
513
+ len(key_fields_schema) == 1
514
+ and key_fields_schema[0]["type"]["kind"] != "Struct"
515
+ and isinstance(key_type_info.variant, (AnalyzedAnyType, AnalyzedBasicType))
516
+ ):
517
+ # Special case for ease of use: single key column can be mapped to a basic type without the wrapper struct.
518
+ key_decoder = make_engine_value_decoder(
519
+ ["(key)"],
520
+ key_fields_schema[0]["type"],
521
+ key_type_info,
522
+ for_key=True,
523
+ )
524
+ else:
525
+ key_decoder = make_engine_struct_decoder(
526
+ ["(key)"], key_fields_schema, key_type_info, for_key=True
527
+ )
528
+
529
+ value_decoder = make_engine_struct_decoder(
530
+ ["(value)"], value_fields_schema, analyze_type_info(value_annotation)
531
+ )
532
+
533
+ loaded_spec = _load_spec_from_engine(self._spec_cls, spec)
534
+ prepare_method = getattr(self._connector_cls, "prepare", None)
535
+ if prepare_method is None:
536
+ prepared_spec = loaded_spec
537
+ else:
538
+ prepared_spec = prepare_method(loaded_spec)
539
+
540
+ return _TargetConnectorContext(
541
+ target_name=name,
542
+ spec=loaded_spec,
543
+ prepared_spec=prepared_spec,
544
+ key_decoder=key_decoder,
545
+ value_decoder=value_decoder,
546
+ )
547
+
548
+ def get_persistent_key(self, export_context: _TargetConnectorContext) -> Any:
549
+ return self._get_persistent_key_fn(
550
+ export_context.spec, export_context.target_name
551
+ )
552
+
553
+ def describe_resource(self, key: Any) -> str:
554
+ describe_fn = getattr(self._connector_cls, "describe", None)
555
+ if describe_fn is None:
556
+ return str(key)
557
+ return str(describe_fn(key))
558
+
559
+ async def apply_setup_changes_async(
560
+ self,
561
+ changes: list[tuple[Any, list[dict[str, Any] | None], dict[str, Any] | None]],
562
+ ) -> None:
563
+ for key, previous, current in changes:
564
+ prev_specs = [
565
+ _load_spec_from_engine(self._spec_cls, spec)
566
+ if spec is not None
567
+ else None
568
+ for spec in previous
569
+ ]
570
+ curr_spec = (
571
+ _load_spec_from_engine(self._spec_cls, current)
572
+ if current is not None
573
+ else None
574
+ )
575
+ for prev_spec in prev_specs:
576
+ await self._apply_setup_change_async_fn(key, prev_spec, curr_spec)
577
+
578
+ @staticmethod
579
+ def _decode_mutation(
580
+ context: _TargetConnectorContext, mutation: list[tuple[Any, Any | None]]
581
+ ) -> tuple[Any, dict[Any, Any | None]]:
582
+ return (
583
+ context.prepared_spec,
584
+ {
585
+ context.key_decoder(key): (
586
+ context.value_decoder(value) if value is not None else None
587
+ )
588
+ for key, value in mutation
589
+ },
590
+ )
591
+
592
+ async def mutate_async(
593
+ self,
594
+ mutations: list[tuple[_TargetConnectorContext, list[tuple[Any, Any | None]]]],
595
+ ) -> None:
596
+ await self._mutate_async_fn(
597
+ *(
598
+ self._decode_mutation(context, mutation)
599
+ for context, mutation in mutations
600
+ )
601
+ )
602
+
603
+
604
def target_connector(spec_cls: type) -> Callable[[type], type]:
    """
    Class decorator factory that registers a target connector for `spec_cls`.

    The decorated class is wrapped in a `_TargetConnector`, registered with
    the engine under `spec_cls.__name__`, and returned unchanged.

    Raises:
        ValueError: if `spec_cls` is not a subclass of `TargetSpec`.
    """
    # Reject anything that is not a TargetSpec before building the decorator.
    if not issubclass(spec_cls, TargetSpec):
        raise ValueError(f"Expect a TargetSpec, got {spec_cls}")

    def _inner(connector_cls: type) -> type:
        # Wrap the user class and hand the wrapper to the engine.
        _engine.register_target_connector(
            spec_cls.__name__, _TargetConnector(spec_cls, connector_cls)
        )
        return connector_cls

    return _inner