cocoindex 0.3.4__cp311-abi3-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. cocoindex/__init__.py +114 -0
  2. cocoindex/_engine.abi3.so +0 -0
  3. cocoindex/auth_registry.py +44 -0
  4. cocoindex/cli.py +830 -0
  5. cocoindex/engine_object.py +214 -0
  6. cocoindex/engine_value.py +550 -0
  7. cocoindex/flow.py +1281 -0
  8. cocoindex/functions/__init__.py +40 -0
  9. cocoindex/functions/_engine_builtin_specs.py +66 -0
  10. cocoindex/functions/colpali.py +247 -0
  11. cocoindex/functions/sbert.py +77 -0
  12. cocoindex/index.py +50 -0
  13. cocoindex/lib.py +75 -0
  14. cocoindex/llm.py +47 -0
  15. cocoindex/op.py +1047 -0
  16. cocoindex/py.typed +0 -0
  17. cocoindex/query_handler.py +57 -0
  18. cocoindex/runtime.py +78 -0
  19. cocoindex/setting.py +171 -0
  20. cocoindex/setup.py +92 -0
  21. cocoindex/sources/__init__.py +5 -0
  22. cocoindex/sources/_engine_builtin_specs.py +120 -0
  23. cocoindex/subprocess_exec.py +277 -0
  24. cocoindex/targets/__init__.py +5 -0
  25. cocoindex/targets/_engine_builtin_specs.py +153 -0
  26. cocoindex/targets/lancedb.py +466 -0
  27. cocoindex/tests/__init__.py +0 -0
  28. cocoindex/tests/test_engine_object.py +331 -0
  29. cocoindex/tests/test_engine_value.py +1724 -0
  30. cocoindex/tests/test_optional_database.py +249 -0
  31. cocoindex/tests/test_transform_flow.py +300 -0
  32. cocoindex/tests/test_typing.py +553 -0
  33. cocoindex/tests/test_validation.py +134 -0
  34. cocoindex/typing.py +834 -0
  35. cocoindex/user_app_loader.py +53 -0
  36. cocoindex/utils.py +20 -0
  37. cocoindex/validation.py +104 -0
  38. cocoindex-0.3.4.dist-info/METADATA +288 -0
  39. cocoindex-0.3.4.dist-info/RECORD +42 -0
  40. cocoindex-0.3.4.dist-info/WHEEL +4 -0
  41. cocoindex-0.3.4.dist-info/entry_points.txt +2 -0
  42. cocoindex-0.3.4.dist-info/licenses/THIRD_PARTY_NOTICES.html +13249 -0
cocoindex/op.py ADDED
@@ -0,0 +1,1047 @@
+"""
+Facilities for defining cocoindex operations.
+"""
+
+import dataclasses
+import inspect
+from enum import Enum
+from typing import (
+    Any,
+    Awaitable,
+    Callable,
+    Iterator,
+    Protocol,
+    dataclass_transform,
+    Annotated,
+    TypeVar,
+    Generic,
+    Literal,
+    get_args,
+)
+from collections.abc import AsyncIterator
+
+from . import _engine  # type: ignore
+from .subprocess_exec import executor_stub
+from .engine_object import dump_engine_object, load_engine_object
+from .engine_value import (
+    make_engine_key_encoder,
+    make_engine_value_encoder,
+    make_engine_value_decoder,
+    make_engine_key_decoder,
+    make_engine_struct_decoder,
+)
+from .typing import (
+    KEY_FIELD_NAME,
+    AnalyzedListType,
+    AnalyzedTypeInfo,
+    StructSchema,
+    StructType,
+    TableType,
+    encode_enriched_type_info,
+    resolve_forward_ref,
+    analyze_type_info,
+    AnalyzedAnyType,
+    AnalyzedDictType,
+    EnrichedValueType,
+    decode_engine_field_schemas,
+    FieldSchema,
+    ValueType,
+)
+from .runtime import to_async_call
+from .index import IndexOptions
+
+
54
+class OpCategory(Enum):
+    """The category of the operation."""
+
+    FUNCTION = "function"
+    SOURCE = "source"
+    TARGET = "target"
+    DECLARATION = "declaration"
+    TARGET_ATTACHMENT = "target_attachment"
+
+
+@dataclass_transform()
+class SpecMeta(type):
+    """Meta class for spec classes."""
+
+    def __new__(
+        mcs,
+        name: str,
+        bases: tuple[type, ...],
+        attrs: dict[str, Any],
+        category: OpCategory | None = None,
+    ) -> type:
+        cls: type = super().__new__(mcs, name, bases, attrs)
+        if category is not None:
+            # It's the base class.
+            setattr(cls, "_op_category", category)
+        else:
+            # It's the specific class providing specific fields.
+            cls = dataclasses.dataclass(cls)
+        return cls
+
+
85
+class SourceSpec(metaclass=SpecMeta, category=OpCategory.SOURCE):  # pylint: disable=too-few-public-methods
+    """A source spec. All its subclasses can be instantiated like a dataclass, i.e. ClassName(field1=value1, field2=value2, ...)"""
+
+
+class FunctionSpec(metaclass=SpecMeta, category=OpCategory.FUNCTION):  # pylint: disable=too-few-public-methods
+    """A function spec. All its subclasses can be instantiated like a dataclass, i.e. ClassName(field1=value1, field2=value2, ...)"""
+
+
+class TargetSpec(metaclass=SpecMeta, category=OpCategory.TARGET):  # pylint: disable=too-few-public-methods
+    """A target spec. All its subclasses can be instantiated like a dataclass, i.e. ClassName(field1=value1, field2=value2, ...)"""
+
+
+class TargetAttachmentSpec(metaclass=SpecMeta, category=OpCategory.TARGET_ATTACHMENT):  # pylint: disable=too-few-public-methods
+    """A target attachment spec. All its subclasses can be instantiated like a dataclass, i.e. ClassName(field1=value1, field2=value2, ...)"""
+
+
+class DeclarationSpec(metaclass=SpecMeta, category=OpCategory.DECLARATION):  # pylint: disable=too-few-public-methods
+    """A declaration spec. All its subclasses can be instantiated like a dataclass, i.e. ClassName(field1=value1, field2=value2, ...)"""
+
+
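Because `SpecMeta` turns every concrete subclass into a dataclass, declaring a spec is just listing typed fields. A minimal sketch (the `DemoFiles` name and its fields are illustrative, not part of this package):

    # Sketch only: any spec subclass gains a dataclass-style constructor.
    class DemoFiles(SourceSpec):
        path: str
        included_patterns: list[str] | None = None

    spec = DemoFiles(path="data", included_patterns=["*.md"])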
105
+class Executor(Protocol):
+    """An executor for an operation."""
+
+    op_category: OpCategory
+
+
+def _get_required_method(obj: type, name: str) -> Callable[..., Any]:
+    method = getattr(obj, name, None)
+    if method is None:
+        raise ValueError(f"Method {name}() is required for {obj}")
+    if not inspect.isfunction(method) and not inspect.ismethod(method):
+        raise ValueError(f"{obj}.{name}() is not a function; {method}")
+    return method
+
+
+class _EngineFunctionExecutorFactory:
+    _spec_loader: Callable[[Any], Any]
+    _executor_cls: type
+
+    def __init__(self, spec_loader: Callable[..., Any], executor_cls: type):
+        self._spec_loader = spec_loader
+        self._executor_cls = executor_cls
+
+    def __call__(
+        self, raw_spec: dict[str, Any], *args: Any, **kwargs: Any
+    ) -> tuple[dict[str, Any], Executor]:
+        spec = self._spec_loader(raw_spec)
+        executor = self._executor_cls(spec)
+        result_type = executor.analyze_schema(*args, **kwargs)
+        return (result_type, executor)
+
+
+_COCOINDEX_ATTR_PREFIX = "cocoindex.io/"
+
+
+class ArgRelationship(Enum):
+    """Specifies the relationship between an input argument and the output."""
+
+    EMBEDDING_ORIGIN_TEXT = _COCOINDEX_ATTR_PREFIX + "embedding_origin_text"
+    CHUNKS_BASE_TEXT = _COCOINDEX_ATTR_PREFIX + "chunk_base_text"
+    RECTS_BASE_IMAGE = _COCOINDEX_ATTR_PREFIX + "rects_base_image"
+
+
148
+@dataclasses.dataclass
+class OpArgs:
+    """
+    - gpu: Whether the executor will be executed on GPU.
+    - cache: Whether the executor will be cached.
+    - batching: Whether the executor will be batched.
+    - max_batch_size: The maximum batch size for the executor. Only valid if `batching` is True.
+    - behavior_version: The behavior version of the executor. Cache will be invalidated if it
+      changes. Must be provided if `cache` is True.
+    - arg_relationship: Specifies the relationship between an input argument and the output,
+      e.g. `(ArgRelationship.CHUNKS_BASE_TEXT, "content")` means the output is chunks for the
+      input argument with name `content`.
+    """
+
+    gpu: bool = False
+    cache: bool = False
+    batching: bool = False
+    max_batch_size: int | None = None
+    behavior_version: int | None = None
+    arg_relationship: tuple[ArgRelationship, str] | None = None
+
+
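These options are the keyword arguments accepted by the `executor_class()` and `function()` decorators defined later in this file; each forwards them into `OpArgs(**args)`. A hedged sketch of a typical combination (the op itself is hypothetical):

    # Sketch only: enable caching, which per the docstring above requires a
    # behavior_version so the cache can be invalidated on logic changes.
    @function(cache=True, behavior_version=1)
    def add_greeting(text: str) -> str:
        return "Hello, " + text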
170
+@dataclasses.dataclass
+class _ArgInfo:
+    decoder: Callable[[Any], Any]
+    is_required: bool
+
+
+def _make_batched_engine_value_decoder(
+    field_path: list[str], src_type: ValueType, dst_type_info: AnalyzedTypeInfo
+) -> Callable[[Any], Any]:
+    if not isinstance(dst_type_info.variant, AnalyzedListType):
+        raise ValueError("Expected arguments for batching function to be a list type")
+    elem_type_info = analyze_type_info(dst_type_info.variant.elem_type)
+    base_decoder = make_engine_value_decoder(field_path, src_type, elem_type_info)
+    return lambda value: [base_decoder(v) for v in value]
+
+
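This helper underpins batched ops: the engine hands over a whole batch and each element is decoded with the per-element decoder. Combined with the checks in `_register_op_factory` below (exactly one argument; list-typed argument and return), a batched function looks roughly like this sketch (name and body illustrative):

    # Sketch only: with batching=True the function receives the whole batch.
    @function(batching=True, max_batch_size=32)
    def embed_texts(texts: list[str]) -> list[list[float]]:
        return [[float(len(t))] for t in texts]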
186
+def _register_op_factory(
+    category: OpCategory,
+    expected_args: list[tuple[str, inspect.Parameter]],
+    expected_return: Any,
+    executor_factory: Any,
+    spec_loader: Callable[..., Any],
+    op_kind: str,
+    op_args: OpArgs,
+) -> None:
+    """
+    Register an op factory.
+    """
+
+    if op_args.batching:
+        if len(expected_args) != 1:
+            raise ValueError("Batching is only supported for single argument functions")
+
+    class _WrappedExecutor:
+        _executor: Any
+        _args_info: list[_ArgInfo]
+        _kwargs_info: dict[str, _ArgInfo]
+        _result_encoder: Callable[[Any], Any]
+        _acall: Callable[..., Awaitable[Any]] | None = None
+
+        def __init__(self, spec: Any) -> None:
+            executor: Any
+
+            if op_args.gpu:
+                executor = executor_stub(executor_factory, spec)
+            else:
+                executor = executor_factory()
+                executor.spec = spec
+
+            self._executor = executor
+
+        def analyze_schema(
+            self, *args: _engine.OpArgSchema, **kwargs: _engine.OpArgSchema
+        ) -> Any:
+            """
+            Analyze the spec and arguments. In this phase, argument types should be validated.
+            It should return the expected result type for the current op.
+            """
+            self._args_info = []
+            self._kwargs_info = {}
+            attributes = {}
+            potentially_missing_required_arg = False
+
+            def process_arg(
+                arg_name: str,
+                arg_param: inspect.Parameter,
+                actual_arg: _engine.OpArgSchema,
+            ) -> _ArgInfo:
+                nonlocal potentially_missing_required_arg
+                if op_args.arg_relationship is not None:
+                    related_attr, related_arg_name = op_args.arg_relationship
+                    if related_arg_name == arg_name:
+                        attributes[related_attr.value] = actual_arg.analyzed_value
+                type_info = analyze_type_info(arg_param.annotation)
+                enriched = EnrichedValueType.decode(actual_arg.value_type)
+                if op_args.batching:
+                    decoder = _make_batched_engine_value_decoder(
+                        [arg_name], enriched.type, type_info
+                    )
+                else:
+                    decoder = make_engine_value_decoder(
+                        [arg_name], enriched.type, type_info
+                    )
+                is_required = not type_info.nullable
+                if is_required and actual_arg.value_type.get("nullable", False):
+                    potentially_missing_required_arg = True
+                return _ArgInfo(
+                    decoder=decoder,
+                    is_required=is_required,
+                )
+
+            # Match arguments with parameters.
+            next_param_idx = 0
+            for actual_arg in args:
+                if next_param_idx >= len(expected_args):
+                    raise ValueError(
+                        f"Too many arguments passed in: {len(args)} > {len(expected_args)}"
+                    )
+                arg_name, arg_param = expected_args[next_param_idx]
+                if arg_param.kind in (
+                    inspect.Parameter.KEYWORD_ONLY,
+                    inspect.Parameter.VAR_KEYWORD,
+                ):
+                    raise ValueError(
+                        f"Too many positional arguments passed in: {len(args)} > {next_param_idx}"
+                    )
+                self._args_info.append(process_arg(arg_name, arg_param, actual_arg))
+                if arg_param.kind != inspect.Parameter.VAR_POSITIONAL:
+                    next_param_idx += 1
+
+            expected_kwargs = expected_args[next_param_idx:]
+
+            for kwarg_name, actual_arg in kwargs.items():
+                expected_arg = next(
+                    (
+                        arg
+                        for arg in expected_kwargs
+                        if (
+                            arg[0] == kwarg_name
+                            and arg[1].kind
+                            in (
+                                inspect.Parameter.KEYWORD_ONLY,
+                                inspect.Parameter.POSITIONAL_OR_KEYWORD,
+                            )
+                        )
+                        or arg[1].kind == inspect.Parameter.VAR_KEYWORD
+                    ),
+                    None,
+                )
+                if expected_arg is None:
+                    raise ValueError(
+                        f"Unexpected keyword argument passed in: {kwarg_name}"
+                    )
+                arg_param = expected_arg[1]
+                self._kwargs_info[kwarg_name] = process_arg(
+                    kwarg_name, arg_param, actual_arg
+                )
+
+            missing_args = [
+                name
+                for (name, arg) in expected_kwargs
+                if arg.default is inspect.Parameter.empty
+                and (
+                    arg.kind == inspect.Parameter.POSITIONAL_ONLY
+                    or (
+                        arg.kind
+                        in (
+                            inspect.Parameter.KEYWORD_ONLY,
+                            inspect.Parameter.POSITIONAL_OR_KEYWORD,
+                        )
+                        and name not in kwargs
+                    )
+                )
+            ]
+            if len(missing_args) > 0:
+                raise ValueError(f"Missing arguments: {', '.join(missing_args)}")
+
+            analyzed_expected_return_type = analyze_type_info(expected_return)
+            self._result_encoder = make_engine_value_encoder(
+                analyzed_expected_return_type
+            )
+
+            base_analyze_method = getattr(self._executor, "analyze", None)
+            if base_analyze_method is not None:
+                analyzed_result_type = analyze_type_info(base_analyze_method())
+            else:
+                if op_args.batching:
+                    if not isinstance(
+                        analyzed_expected_return_type.variant, AnalyzedListType
+                    ):
+                        raise ValueError(
+                            "Expected return type for batching function to be a list type"
+                        )
+                    analyzed_result_type = analyze_type_info(
+                        analyzed_expected_return_type.variant.elem_type
+                    )
+                else:
+                    analyzed_result_type = analyzed_expected_return_type
+            if len(attributes) > 0:
+                analyzed_result_type.attrs = attributes
+            if potentially_missing_required_arg:
+                analyzed_result_type.nullable = True
+            encoded_type = encode_enriched_type_info(analyzed_result_type)
+
+            return encoded_type
+
+        async def prepare(self) -> None:
+            """
+            Prepare for execution.
+            It's executed after `analyze` and before any `__call__` execution.
+            """
+            prepare_method = getattr(self._executor, "prepare", None)
+            if prepare_method is not None:
+                await to_async_call(prepare_method)()
+            self._acall = to_async_call(self._executor.__call__)
+
+        async def __call__(self, *args: Any, **kwargs: Any) -> Any:
+            decoded_args = []
+            for arg_info, arg in zip(self._args_info, args):
+                if arg_info.is_required and arg is None:
+                    return None
+                decoded_args.append(arg_info.decoder(arg))
+
+            decoded_kwargs = {}
+            for kwarg_name, arg in kwargs.items():
+                kwarg_info = self._kwargs_info.get(kwarg_name)
+                if kwarg_info is None:
+                    raise ValueError(
+                        f"Unexpected keyword argument passed in: {kwarg_name}"
+                    )
+                if kwarg_info.is_required and arg is None:
+                    return None
+                decoded_kwargs[kwarg_name] = kwarg_info.decoder(arg)
+
+            assert self._acall is not None
+            output = await self._acall(*decoded_args, **decoded_kwargs)
+            return self._result_encoder(output)
+
+        def enable_cache(self) -> bool:
+            return op_args.cache
+
+        def behavior_version(self) -> int | None:
+            return op_args.behavior_version
+
+        def batching_options(self) -> dict[str, Any] | None:
+            if op_args.batching:
+                return {
+                    "max_batch_size": op_args.max_batch_size,
+                }
+            else:
+                return None
+
+    if category == OpCategory.FUNCTION:
+        _engine.register_function_factory(
+            op_kind, _EngineFunctionExecutorFactory(spec_loader, _WrappedExecutor)
+        )
+    else:
+        raise ValueError(f"Unsupported executor type {category}")
+
+
410
+def executor_class(**args: Any) -> Callable[[type], type]:
+    """
+    Decorate a class to provide an executor for an op.
+    """
+    op_args = OpArgs(**args)
+
+    def _inner(cls: type[Executor]) -> type:
+        """
+        Decorate a class to provide an executor for an op.
+        """
+        # Use `__annotations__` instead of `get_type_hints`, to avoid resolving forward references.
+        type_hints = cls.__annotations__
+        if "spec" not in type_hints:
+            raise TypeError("Expected a `spec` field with a type hint")
+        spec_cls = resolve_forward_ref(type_hints["spec"])
+        sig = inspect.signature(cls.__call__)
+        _register_op_factory(
+            category=spec_cls._op_category,
+            expected_args=list(sig.parameters.items())[1:],  # First argument is `self`
+            expected_return=sig.return_annotation,
+            executor_factory=cls,
+            spec_loader=lambda v: load_engine_object(spec_cls, v),
+            op_kind=spec_cls.__name__,
+            op_args=op_args,
+        )
+        return cls
+
+    return _inner
+
+
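A hedged sketch of the pattern this decorator expects (all names are illustrative): the required `spec` field's type selects the op kind and category, `__call__`'s annotations drive schema analysis, and `prepare()`/`analyze()` are optional hooks picked up by `_WrappedExecutor` above.

    class DemoFunctionSpec(FunctionSpec):  # hypothetical spec
        suffix: str

    @executor_class(cache=True, behavior_version=1)
    class DemoExecutor:  # hypothetical executor
        spec: DemoFunctionSpec

        def prepare(self) -> None:
            # Optional hook: runs once after analysis, before any call.
            pass

        def __call__(self, text: str) -> str:
            return text + self.spec.suffix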
440
+class EmptyFunctionSpec(FunctionSpec):
+    pass
+
+
+class _SimpleFunctionExecutor:
+    spec: Callable[..., Any]
+
+    def prepare(self) -> None:
+        self.__call__ = staticmethod(self.spec)
+
+
451
+def function(**args: Any) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
+    """
+    Decorate a function to provide a function for an op.
+    """
+    op_args = OpArgs(**args)
+
+    def _inner(fn: Callable[..., Any]) -> Callable[..., Any]:
+        # Convert snake case to camel case.
+        op_kind = "".join(word.capitalize() for word in fn.__name__.split("_"))
+        sig = inspect.signature(fn)
+        fn.__cocoindex_op_kind__ = op_kind  # type: ignore
+        _register_op_factory(
+            category=OpCategory.FUNCTION,
+            expected_args=list(sig.parameters.items()),
+            expected_return=sig.return_annotation,
+            executor_factory=_SimpleFunctionExecutor,
+            spec_loader=lambda _: fn,
+            op_kind=op_kind,
+            op_args=op_args,
+        )
+
+        return fn
+
+    return _inner
+
+
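A usage sketch (function name and body illustrative): the decorator derives the op kind from the snake_case name, so this registers under the kind "ExtractExtension".

    # Sketch only: the argument and return annotations drive the engine schema.
    @function(cache=True, behavior_version=1)
    def extract_extension(filename: str) -> str:
        return filename.rsplit(".", 1)[-1] if "." in filename else ""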
477
+########################################################
+# Custom source connector
+########################################################
+
+
482
+@dataclasses.dataclass
+class SourceReadOptions:
+    """
+    The options for reading a source row.
+    This is an argument for both the `list()` and `get_value()` methods.
+    In most cases (unless spelled out otherwise below), an option is not a hard requirement,
+    but a hint that the corresponding piece of information is useful in the current context.
+
+    - include_ordinal: Whether to include the ordinal of the source row.
+      When `provides_ordinal()` returns True and `include_ordinal` is True, you must provide
+      `ordinal` in `list()`; it's optional in other cases.
+      Providing it helps skip unnecessary reprocessing early, and prevents output derived from
+      an older version of the input from overwriting the latest one when there's concurrency
+      (especially across multiple processes) and the source updates frequently.
+
+    - include_content_version_fp: Whether to include the content version fingerprint of the source row.
+      Providing it is always optional even when this is True.
+      It helps skip unnecessary reprocessing early.
+      Only consider providing it if you can obtain it directly, without computing a hash over the content.
+
+    - include_value: Whether to include the value of the source row.
+      You must provide it in `get_value()` when `include_value` is True.
+      It's optional for `list()`.
+      Consider providing it in `list()` when doing so is significantly cheaper than a separate
+      `get_value()` call per row, as it saves the cost of those individual calls.
+    """
+
+    include_ordinal: bool = False
+    include_content_version_fp: bool = False
+    include_value: bool = False
+
+
510
+K = TypeVar("K")
+V = TypeVar("V")
+
+NON_EXISTENCE: Literal["NON_EXISTENCE"] = "NON_EXISTENCE"
+NO_ORDINAL: Literal["NO_ORDINAL"] = "NO_ORDINAL"
+
+
517
+@dataclasses.dataclass
+class PartialSourceRowData(Generic[V]):
+    """
+    The data of a source row.
+
+    - value: The value of the source row. NON_EXISTENCE means the row does not exist.
+    - ordinal: The ordinal of the source row. NO_ORDINAL means ordinal is not available for the source.
+    - content_version_fp: The content version fingerprint of the source row.
+    """
+
+    value: V | Literal["NON_EXISTENCE"] | None = None
+    ordinal: int | Literal["NO_ORDINAL"] | None = None
+    content_version_fp: bytes | None = None
+
+
+@dataclasses.dataclass
+class PartialSourceRow(Generic[K, V]):
+    key: K
+    data: PartialSourceRowData[V]
+
+
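A quick sketch of how a connector might populate these (values illustrative): `None` means "not provided", while the sentinels carry explicit meanings.

    # Sketch only: an existing row with an ordinal but no fingerprint...
    row = PartialSourceRow(
        key="doc-1",
        data=PartialSourceRowData(value={"content": "hello"}, ordinal=42),
    )
    # ...and a deleted/absent row whose source has no ordinal concept.
    gone: PartialSourceRowData[dict] = PartialSourceRowData(
        value=NON_EXISTENCE, ordinal=NO_ORDINAL
    )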
538
+class _SourceExecutorContext:
+    _executor: Any
+
+    _key_encoder: Callable[[Any], Any]
+    _key_decoder: Callable[[Any], Any]
+
+    _value_encoder: Callable[[Any], Any]
+
+    _list_fn: Callable[
+        [SourceReadOptions],
+        AsyncIterator[PartialSourceRow[Any, Any]]
+        | Iterator[PartialSourceRow[Any, Any]],
+    ]
+    _orig_get_value_fn: Callable[..., Any]
+    _get_value_fn: Callable[..., Awaitable[PartialSourceRowData[Any]]]
+    _provides_ordinal_fn: Callable[[], bool] | None
+
+    def __init__(
+        self,
+        executor: Any,
+        key_type_info: AnalyzedTypeInfo,
+        key_decoder: Callable[[Any], Any],
+        value_type_info: AnalyzedTypeInfo,
+    ):
+        self._executor = executor
+
+        self._key_encoder = make_engine_key_encoder(key_type_info)
+        self._key_decoder = key_decoder
+        self._value_encoder = make_engine_value_encoder(value_type_info)
+
+        self._list_fn = _get_required_method(executor, "list")
+        self._orig_get_value_fn = _get_required_method(executor, "get_value")
+        self._get_value_fn = to_async_call(self._orig_get_value_fn)
+        self._provides_ordinal_fn = getattr(executor, "provides_ordinal", None)
+
+    def provides_ordinal(self) -> bool:
+        if self._provides_ordinal_fn is not None:
+            result = self._provides_ordinal_fn()
+            return bool(result)
+        else:
+            return False
+
+    async def list_async(
+        self, options: dict[str, Any]
+    ) -> AsyncIterator[tuple[Any, dict[str, Any]]]:
+        """
+        Return an async iterator that yields individual rows one by one.
+        Each yielded item is a tuple of (key, data).
+        """
+        read_options = load_engine_object(SourceReadOptions, options)
+        args = _build_args(self._list_fn, 0, options=read_options)
+        list_result = self._list_fn(*args)
+
+        # Handle both sync and async iterators
+        if hasattr(list_result, "__aiter__"):
+            async for partial_row in list_result:
+                yield (
+                    self._key_encoder(partial_row.key),
+                    self._encode_source_row_data(partial_row.data),
+                )
+        else:
+            for partial_row in list_result:
+                yield (
+                    self._key_encoder(partial_row.key),
+                    self._encode_source_row_data(partial_row.data),
+                )
+
+    async def get_value_async(
+        self,
+        raw_key: Any,
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        key = self._key_decoder(raw_key)
+        read_options = load_engine_object(SourceReadOptions, options)
+        args = _build_args(self._orig_get_value_fn, 1, key=key, options=read_options)
+        row_data = await self._get_value_fn(*args)
+        return self._encode_source_row_data(row_data)
+
+    def _encode_source_row_data(
+        self, row_data: PartialSourceRowData[Any]
+    ) -> dict[str, Any]:
+        """Convert Python PartialSourceRowData to the format expected by Rust."""
+        return {
+            "ordinal": row_data.ordinal,
+            "content_version_fp": row_data.content_version_fp,
+            "value": (
+                NON_EXISTENCE
+                if row_data.value == NON_EXISTENCE
+                else self._value_encoder(row_data.value)
+            ),
+        }
+
+
631
+class _SourceConnector:
+    """
+    The connector class passed to the engine.
+    """
+
+    _spec_cls: type[Any]
+    _key_type_info: AnalyzedTypeInfo
+    _key_decoder: Callable[[Any], Any]
+    _value_type_info: AnalyzedTypeInfo
+    _table_type: EnrichedValueType
+    _connector_cls: type[Any]
+
+    _create_fn: Callable[[Any], Awaitable[Any]]
+
+    def __init__(
+        self,
+        spec_cls: type[Any],
+        key_type: Any,
+        value_type: Any,
+        connector_cls: type[Any],
+    ):
+        self._spec_cls = spec_cls
+        self._key_type_info = analyze_type_info(key_type)
+        self._value_type_info = analyze_type_info(value_type)
+        self._connector_cls = connector_cls
+
+        # TODO: We can save the intermediate step after #1083 is fixed.
+        encoded_engine_key_type = encode_enriched_type_info(self._key_type_info)
+        engine_key_type = EnrichedValueType.decode(encoded_engine_key_type)
+
+        # TODO: We can save the intermediate step after #1083 is fixed.
+        encoded_engine_value_type = encode_enriched_type_info(self._value_type_info)
+        engine_value_type = EnrichedValueType.decode(encoded_engine_value_type)
+
+        if not isinstance(engine_value_type.type, StructType):
+            raise ValueError(f"Expected a StructType, got {engine_value_type.type}")
+
+        if isinstance(engine_key_type.type, StructType):
+            key_fields_schema = engine_key_type.type.fields
+        else:
+            key_fields_schema = [
+                FieldSchema(name=KEY_FIELD_NAME, value_type=engine_key_type)
+            ]
+        self._key_decoder = make_engine_key_decoder(
+            [], key_fields_schema, self._key_type_info
+        )
+        self._table_type = EnrichedValueType(
+            type=TableType(
+                kind="KTable",
+                row=StructSchema(
+                    fields=key_fields_schema + engine_value_type.type.fields
+                ),
+                num_key_parts=len(key_fields_schema),
+            ),
+        )
+
+        self._create_fn = to_async_call(_get_required_method(connector_cls, "create"))
+
+    async def create_executor(self, raw_spec: dict[str, Any]) -> _SourceExecutorContext:
+        spec = load_engine_object(self._spec_cls, raw_spec)
+        executor = await self._create_fn(spec)
+        return _SourceExecutorContext(
+            executor, self._key_type_info, self._key_decoder, self._value_type_info
+        )
+
+    def get_table_type(self) -> Any:
+        return dump_engine_object(self._table_type)
+
+
700
+def source_connector(
+    *,
+    spec_cls: type[Any],
+    key_type: Any = Any,
+    value_type: Any = Any,
+) -> Callable[[type], type]:
+    """
+    Decorate a class to provide a source connector for an op.
+    """
+
+    # Validate the spec_cls is a SourceSpec.
+    if not issubclass(spec_cls, SourceSpec):
+        raise ValueError(f"Expected a SourceSpec, got {spec_cls}")
+
+    # Register the source connector.
+    def _inner(connector_cls: type) -> type:
+        connector = _SourceConnector(spec_cls, key_type, value_type, connector_cls)
+        _engine.register_source_connector(spec_cls.__name__, connector)
+        return connector_cls
+
+    return _inner
+
+
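Putting the pieces above together, a hedged end-to-end sketch of a custom source (every name and the in-memory data are illustrative): the connector class must provide `create()`, `list()` and `get_value()`; `provides_ordinal()` is optional, as wired up in `_SourceExecutorContext`.

    class DemoDictSourceSpec(SourceSpec):  # hypothetical
        data: dict[str, str]

    @dataclasses.dataclass
    class DemoRow:  # hypothetical value struct (value_type must be struct-like)
        content: str

    @source_connector(spec_cls=DemoDictSourceSpec, key_type=str, value_type=DemoRow)
    class DemoDictSourceConnector:
        def __init__(self, spec: DemoDictSourceSpec) -> None:
            self._spec = spec

        @staticmethod
        def create(spec: DemoDictSourceSpec) -> "DemoDictSourceConnector":
            return DemoDictSourceConnector(spec)

        def list(self, options: SourceReadOptions):
            # A sync iterator is fine; list_async() also accepts async ones.
            for key, content in self._spec.data.items():
                yield PartialSourceRow(key, PartialSourceRowData(value=DemoRow(content)))

        def get_value(self, key: str, options: SourceReadOptions) -> PartialSourceRowData[DemoRow]:
            content = self._spec.data.get(key)
            return PartialSourceRowData(
                value=DemoRow(content) if content is not None else NON_EXISTENCE
            )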
723
+########################################################
+# Custom target connector
+########################################################
+
+
728
+@dataclasses.dataclass
+class _TargetConnectorContext:
+    target_name: str
+    spec: Any
+    prepared_spec: Any
+    key_fields_schema: list[FieldSchema]
+    key_decoder: Callable[[Any], Any]
+    value_fields_schema: list[FieldSchema]
+    value_decoder: Callable[[Any], Any]
+    index_options: IndexOptions
+    setup_state: Any
+
+
741
+def _build_args(
+    method: Callable[..., Any], num_required_args: int, **kwargs: Any
+) -> list[Any]:
+    signature = inspect.signature(method)
+    for param in signature.parameters.values():
+        if param.kind not in (
+            inspect.Parameter.POSITIONAL_ONLY,
+            inspect.Parameter.POSITIONAL_OR_KEYWORD,
+        ):
+            raise ValueError(
+                f"Method {method.__name__} should only have positional arguments, got {param.kind.name}"
+            )
+    if len(signature.parameters) < num_required_args:
+        raise ValueError(
+            f"Method {method.__name__} must have at least {num_required_args} required arguments: "
+            f"{', '.join(list(kwargs.keys())[:num_required_args])}"
+        )
+    if len(signature.parameters) > len(kwargs):
+        raise ValueError(
+            f"Method {method.__name__} can only have at most {len(kwargs)} arguments: {', '.join(kwargs.keys())}"
+        )
+    return [v for _, v in zip(signature.parameters, kwargs.values())]
+
+
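In other words, `_build_args` lets a user-supplied method declare any leading subset of the canonical parameter list: the first `num_required_args` parameters are mandatory, the rest optional, and keyword-only parameters are rejected. A comment sketch of the effect (method names illustrative):

    # With _build_args(prepare_fn, 1, spec=..., setup_state=...):
    #   def prepare(spec): ...              -> called as prepare(spec)
    #   def prepare(spec, setup_state): ... -> called as prepare(spec, setup_state)
    #   def prepare(): ...                  -> rejected (fewer than 1 required arg)
    #   def prepare(spec, *, setup_state):  -> rejected (keyword-only not allowed)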
765
+class TargetStateCompatibility(Enum):
+    """The compatibility of the target state."""
+
+    COMPATIBLE = "Compatible"
+    PARTIALLY_COMPATIBLE = "PartialCompatible"
+    NOT_COMPATIBLE = "NotCompatible"
+
+
773
+class _TargetConnector:
+    """
+    The connector class passed to the engine.
+    """
+
+    _spec_cls: type[Any]
+    _persistent_key_type: Any
+    _setup_state_cls: type[Any]
+    _connector_cls: type[Any]
+
+    _get_persistent_key_fn: Callable[[_TargetConnectorContext, str], Any]
+    _apply_setup_change_async_fn: Callable[
+        [Any, dict[str, Any] | None, dict[str, Any] | None], Awaitable[None]
+    ]
+    _mutate_async_fn: Callable[..., Awaitable[None]]
+    _mutation_type: AnalyzedDictType | None
+
+    def __init__(
+        self,
+        spec_cls: type[Any],
+        persistent_key_type: Any,
+        setup_state_cls: type[Any],
+        connector_cls: type[Any],
+    ):
+        self._spec_cls = spec_cls
+        self._persistent_key_type = persistent_key_type
+        self._setup_state_cls = setup_state_cls
+        self._connector_cls = connector_cls
+
+        self._get_persistent_key_fn = _get_required_method(
+            connector_cls, "get_persistent_key"
+        )
+        self._apply_setup_change_async_fn = to_async_call(
+            _get_required_method(connector_cls, "apply_setup_change")
+        )
+
+        mutate_fn = _get_required_method(connector_cls, "mutate")
+        self._mutate_async_fn = to_async_call(mutate_fn)
+
+        # Store the type annotation for later use
+        self._mutation_type = self._analyze_mutate_mutation_type(
+            connector_cls, mutate_fn
+        )
+
+    @staticmethod
+    def _analyze_mutate_mutation_type(
+        connector_cls: type, mutate_fn: Callable[..., Any]
+    ) -> AnalyzedDictType | None:
+        # Validate mutate_fn signature and extract type annotation
+        mutate_sig = inspect.signature(mutate_fn)
+        params = list(mutate_sig.parameters.values())
+
+        if len(params) != 1:
+            raise ValueError(
+                f"Method {connector_cls.__name__}.mutate(*args) must have exactly one parameter, "
+                f"got {len(params)}"
+            )
+
+        param = params[0]
+        if param.kind != inspect.Parameter.VAR_POSITIONAL:
+            raise ValueError(
+                f"Method {connector_cls.__name__}.mutate(*args) parameter must be *args format, "
+                f"got {param.kind.name}"
+            )
+
+        # Extract type annotation
+        analyzed_args_type = analyze_type_info(param.annotation)
+        if isinstance(analyzed_args_type.variant, AnalyzedAnyType):
+            return None
+
+        if analyzed_args_type.base_type is tuple:
+            args = get_args(analyzed_args_type.core_type)
+            if not args:
+                return None
+            if len(args) == 2:
+                mutation_type = analyze_type_info(args[1])
+                if isinstance(mutation_type.variant, AnalyzedAnyType):
+                    return None
+                if isinstance(mutation_type.variant, AnalyzedDictType):
+                    return mutation_type.variant
+
+        raise ValueError(
+            f"Method {connector_cls.__name__}.mutate(*args) parameter must be a tuple with "
+            f"2 elements (tuple[SpecType, dict[str, ValueStruct]], spec and mutation in dict), "
+            f"got {analyzed_args_type.core_type}"
+        )
+
+    def create_export_context(
+        self,
+        name: str,
+        raw_spec: dict[str, Any],
+        raw_key_fields_schema: list[Any],
+        raw_value_fields_schema: list[Any],
+        raw_index_options: dict[str, Any],
+    ) -> _TargetConnectorContext:
+        key_annotation, value_annotation = (
+            (
+                self._mutation_type.key_type,
+                self._mutation_type.value_type,
+            )
+            if self._mutation_type is not None
+            else (Any, Any)
+        )
+
+        key_fields_schema = decode_engine_field_schemas(raw_key_fields_schema)
+        key_decoder = make_engine_key_decoder(
+            ["<key>"], key_fields_schema, analyze_type_info(key_annotation)
+        )
+        value_fields_schema = decode_engine_field_schemas(raw_value_fields_schema)
+        value_decoder = make_engine_struct_decoder(
+            ["<value>"], value_fields_schema, analyze_type_info(value_annotation)
+        )
+
+        spec = load_engine_object(self._spec_cls, raw_spec)
+        index_options = load_engine_object(IndexOptions, raw_index_options)
+        return _TargetConnectorContext(
+            target_name=name,
+            spec=spec,
+            prepared_spec=None,
+            key_fields_schema=key_fields_schema,
+            key_decoder=key_decoder,
+            value_fields_schema=value_fields_schema,
+            value_decoder=value_decoder,
+            index_options=index_options,
+            setup_state=None,
+        )
+
+    def get_persistent_key(self, export_context: _TargetConnectorContext) -> Any:
+        args = _build_args(
+            self._get_persistent_key_fn,
+            1,
+            spec=export_context.spec,
+            target_name=export_context.target_name,
+        )
+        return dump_engine_object(self._get_persistent_key_fn(*args))
+
+    def get_setup_state(self, export_context: _TargetConnectorContext) -> Any:
+        get_setup_state_fn = getattr(self._connector_cls, "get_setup_state", None)
+        if get_setup_state_fn is None:
+            state = export_context.spec
+            if not isinstance(state, self._setup_state_cls):
+                raise ValueError(
+                    f"Expected a get_setup_state() method for {self._connector_cls} that returns an instance of {self._setup_state_cls}"
+                )
+        else:
+            args = _build_args(
+                get_setup_state_fn,
+                1,
+                spec=export_context.spec,
+                key_fields_schema=export_context.key_fields_schema,
+                value_fields_schema=export_context.value_fields_schema,
+                index_options=export_context.index_options,
+            )
+            state = get_setup_state_fn(*args)
+            if not isinstance(state, self._setup_state_cls):
+                raise ValueError(
+                    f"Method {get_setup_state_fn.__name__} must return an instance of {self._setup_state_cls}, got {type(state)}"
+                )
+        export_context.setup_state = state
+        return dump_engine_object(state)
+
+    def check_state_compatibility(
+        self, raw_desired_state: Any, raw_existing_state: Any
+    ) -> Any:
+        check_state_compatibility_fn = getattr(
+            self._connector_cls, "check_state_compatibility", None
+        )
+        if check_state_compatibility_fn is not None:
+            compatibility = check_state_compatibility_fn(
+                load_engine_object(self._setup_state_cls, raw_desired_state),
+                load_engine_object(self._setup_state_cls, raw_existing_state),
+            )
+        else:
+            compatibility = (
+                TargetStateCompatibility.COMPATIBLE
+                if raw_desired_state == raw_existing_state
+                else TargetStateCompatibility.PARTIALLY_COMPATIBLE
+            )
+        return dump_engine_object(compatibility)
+
+    async def prepare_async(
+        self,
+        export_context: _TargetConnectorContext,
+    ) -> None:
+        prepare_fn = getattr(self._connector_cls, "prepare", None)
+        if prepare_fn is None:
+            export_context.prepared_spec = export_context.spec
+            return
+        args = _build_args(
+            prepare_fn,
+            1,
+            spec=export_context.spec,
+            setup_state=export_context.setup_state,
+            key_fields_schema=export_context.key_fields_schema,
+            value_fields_schema=export_context.value_fields_schema,
+        )
+        async_prepare_fn = to_async_call(prepare_fn)
+        export_context.prepared_spec = await async_prepare_fn(*args)
+
+    def describe_resource(self, raw_key: Any) -> str:
+        key = load_engine_object(self._persistent_key_type, raw_key)
+        describe_fn = getattr(self._connector_cls, "describe", None)
+        if describe_fn is None:
+            return str(key)
+        return str(describe_fn(key))
+
+    async def apply_setup_changes_async(
+        self,
+        changes: list[tuple[Any, list[dict[str, Any] | None], dict[str, Any] | None]],
+    ) -> None:
+        for raw_key, previous, current in changes:
+            key = load_engine_object(self._persistent_key_type, raw_key)
+            prev_specs = [
+                load_engine_object(self._setup_state_cls, spec)
+                if spec is not None
+                else None
+                for spec in previous
+            ]
+            curr_spec = (
+                load_engine_object(self._setup_state_cls, current)
+                if current is not None
+                else None
+            )
+            for prev_spec in prev_specs:
+                await self._apply_setup_change_async_fn(key, prev_spec, curr_spec)
+
+    @staticmethod
+    def _decode_mutation(
+        context: _TargetConnectorContext, mutation: list[tuple[Any, Any | None]]
+    ) -> tuple[Any, dict[Any, Any | None]]:
+        return (
+            context.prepared_spec,
+            {
+                context.key_decoder(key): (
+                    context.value_decoder(value) if value is not None else None
+                )
+                for key, value in mutation
+            },
+        )
+
+    async def mutate_async(
+        self,
+        mutations: list[tuple[_TargetConnectorContext, list[tuple[Any, Any | None]]]],
+    ) -> None:
+        await self._mutate_async_fn(
+            *(
+                self._decode_mutation(context, mutation)
+                for context, mutation in mutations
+            )
+        )
+
+
1025
+def target_connector(
+    *,
+    spec_cls: type[Any],
+    persistent_key_type: Any = Any,
+    setup_state_cls: type[Any] | None = None,
+) -> Callable[[type], type]:
+    """
+    Decorate a class to provide a target connector for an op.
+    """
+
+    # Validate the spec_cls is a TargetSpec.
+    if not issubclass(spec_cls, TargetSpec):
+        raise ValueError(f"Expected a TargetSpec, got {spec_cls}")
+
+    # Register the target connector.
+    def _inner(connector_cls: type) -> type:
+        connector = _TargetConnector(
+            spec_cls, persistent_key_type, setup_state_cls or spec_cls, connector_cls
+        )
+        _engine.register_target_connector(spec_cls.__name__, connector)
+        return connector_cls
+
+    return _inner
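And a matching end-to-end sketch for a custom target (all names and the in-process storage are illustrative): `get_persistent_key`, `apply_setup_change` and `mutate` are required; `mutate` must take a single `*args` parameter annotated as a 2-tuple of spec and mutation dict, as enforced by `_analyze_mutate_mutation_type` above, and a `None` mutation value means deletion.

    class DemoTargetSpec(TargetSpec):  # hypothetical
        table_name: str

    @dataclasses.dataclass
    class DemoDoc:  # hypothetical decoded value struct
        text: str

    _DEMO_STORE: dict[str, dict[str, DemoDoc]] = {}  # hypothetical storage

    @target_connector(spec_cls=DemoTargetSpec)
    class DemoTargetConnector:
        @staticmethod
        def get_persistent_key(spec: DemoTargetSpec, target_name: str) -> str:
            return spec.table_name

        @staticmethod
        def apply_setup_change(
            key: str, previous: DemoTargetSpec | None, current: DemoTargetSpec | None
        ) -> None:
            if current is not None:
                _DEMO_STORE.setdefault(key, {})
            elif previous is not None:
                _DEMO_STORE.pop(key, None)  # target removed: drop its table

        @staticmethod
        def mutate(*all_mutations: tuple[DemoTargetSpec, dict[str, DemoDoc | None]]) -> None:
            for spec, mutation in all_mutations:
                table = _DEMO_STORE.setdefault(spec.table_name, {})
                for key, value in mutation.items():
                    if value is None:
                        table.pop(key, None)  # None means the row was deleted
                    else:
                        table[key] = value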