cocoindex 0.3.4__cp311-abi3-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cocoindex/__init__.py +114 -0
- cocoindex/_engine.abi3.so +0 -0
- cocoindex/auth_registry.py +44 -0
- cocoindex/cli.py +830 -0
- cocoindex/engine_object.py +214 -0
- cocoindex/engine_value.py +550 -0
- cocoindex/flow.py +1281 -0
- cocoindex/functions/__init__.py +40 -0
- cocoindex/functions/_engine_builtin_specs.py +66 -0
- cocoindex/functions/colpali.py +247 -0
- cocoindex/functions/sbert.py +77 -0
- cocoindex/index.py +50 -0
- cocoindex/lib.py +75 -0
- cocoindex/llm.py +47 -0
- cocoindex/op.py +1047 -0
- cocoindex/py.typed +0 -0
- cocoindex/query_handler.py +57 -0
- cocoindex/runtime.py +78 -0
- cocoindex/setting.py +171 -0
- cocoindex/setup.py +92 -0
- cocoindex/sources/__init__.py +5 -0
- cocoindex/sources/_engine_builtin_specs.py +120 -0
- cocoindex/subprocess_exec.py +277 -0
- cocoindex/targets/__init__.py +5 -0
- cocoindex/targets/_engine_builtin_specs.py +153 -0
- cocoindex/targets/lancedb.py +466 -0
- cocoindex/tests/__init__.py +0 -0
- cocoindex/tests/test_engine_object.py +331 -0
- cocoindex/tests/test_engine_value.py +1724 -0
- cocoindex/tests/test_optional_database.py +249 -0
- cocoindex/tests/test_transform_flow.py +300 -0
- cocoindex/tests/test_typing.py +553 -0
- cocoindex/tests/test_validation.py +134 -0
- cocoindex/typing.py +834 -0
- cocoindex/user_app_loader.py +53 -0
- cocoindex/utils.py +20 -0
- cocoindex/validation.py +104 -0
- cocoindex-0.3.4.dist-info/METADATA +288 -0
- cocoindex-0.3.4.dist-info/RECORD +42 -0
- cocoindex-0.3.4.dist-info/WHEEL +4 -0
- cocoindex-0.3.4.dist-info/entry_points.txt +2 -0
- cocoindex-0.3.4.dist-info/licenses/THIRD_PARTY_NOTICES.html +13249 -0
cocoindex/op.py
ADDED
|
@@ -0,0 +1,1047 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Facilities for defining cocoindex operations.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import dataclasses
|
|
6
|
+
import inspect
|
|
7
|
+
from enum import Enum
|
|
8
|
+
from typing import (
|
|
9
|
+
Any,
|
|
10
|
+
Awaitable,
|
|
11
|
+
Callable,
|
|
12
|
+
Iterator,
|
|
13
|
+
Protocol,
|
|
14
|
+
dataclass_transform,
|
|
15
|
+
Annotated,
|
|
16
|
+
TypeVar,
|
|
17
|
+
Generic,
|
|
18
|
+
Literal,
|
|
19
|
+
get_args,
|
|
20
|
+
)
|
|
21
|
+
from collections.abc import AsyncIterator
|
|
22
|
+
|
|
23
|
+
from . import _engine # type: ignore
|
|
24
|
+
from .subprocess_exec import executor_stub
|
|
25
|
+
from .engine_object import dump_engine_object, load_engine_object
|
|
26
|
+
from .engine_value import (
|
|
27
|
+
make_engine_key_encoder,
|
|
28
|
+
make_engine_value_encoder,
|
|
29
|
+
make_engine_value_decoder,
|
|
30
|
+
make_engine_key_decoder,
|
|
31
|
+
make_engine_struct_decoder,
|
|
32
|
+
)
|
|
33
|
+
from .typing import (
|
|
34
|
+
KEY_FIELD_NAME,
|
|
35
|
+
AnalyzedListType,
|
|
36
|
+
AnalyzedTypeInfo,
|
|
37
|
+
StructSchema,
|
|
38
|
+
StructType,
|
|
39
|
+
TableType,
|
|
40
|
+
encode_enriched_type_info,
|
|
41
|
+
resolve_forward_ref,
|
|
42
|
+
analyze_type_info,
|
|
43
|
+
AnalyzedAnyType,
|
|
44
|
+
AnalyzedDictType,
|
|
45
|
+
EnrichedValueType,
|
|
46
|
+
decode_engine_field_schemas,
|
|
47
|
+
FieldSchema,
|
|
48
|
+
ValueType,
|
|
49
|
+
)
|
|
50
|
+
from .runtime import to_async_call
|
|
51
|
+
from .index import IndexOptions
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class OpCategory(Enum):
    """Enumerates the categories an operation spec can belong to."""

    FUNCTION = "function"
    SOURCE = "source"
    TARGET = "target"
    DECLARATION = "declaration"
    TARGET_ATTACHMENT = "target_attachment"
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@dataclass_transform()
class SpecMeta(type):
    """
    Metaclass shared by all spec classes.

    A class created with a `category` keyword is a category base class and is
    tagged with `_op_category`. A class created without it is a concrete spec
    and is turned into a dataclass.
    """

    def __new__(
        mcs,
        name: str,
        bases: tuple[type, ...],
        attrs: dict[str, Any],
        category: OpCategory | None = None,
    ) -> type:
        new_cls: type = super().__new__(mcs, name, bases, attrs)
        if category is None:
            # Concrete spec class: its annotated fields become dataclass fields.
            return dataclasses.dataclass(new_cls)
        # Category base class: record the category on the class itself.
        setattr(new_cls, "_op_category", category)
        return new_cls
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class SourceSpec(metaclass=SpecMeta, category=OpCategory.SOURCE):  # pylint: disable=too-few-public-methods
    """
    Base class for source specs.

    Every subclass can be instantiated like a dataclass, i.e.
    ClassName(field1=value1, field2=value2, ...).
    """
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class FunctionSpec(metaclass=SpecMeta, category=OpCategory.FUNCTION):  # pylint: disable=too-few-public-methods
    """
    Base class for function specs.

    Every subclass can be instantiated like a dataclass, i.e.
    ClassName(field1=value1, field2=value2, ...).
    """
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class TargetSpec(metaclass=SpecMeta, category=OpCategory.TARGET):  # pylint: disable=too-few-public-methods
    """
    Base class for target specs.

    Every subclass can be instantiated like a dataclass, i.e.
    ClassName(field1=value1, field2=value2, ...).
    """
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class TargetAttachmentSpec(metaclass=SpecMeta, category=OpCategory.TARGET_ATTACHMENT):  # pylint: disable=too-few-public-methods
    """
    Base class for target attachment specs.

    Every subclass can be instantiated like a dataclass, i.e.
    ClassName(field1=value1, field2=value2, ...).
    """
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class DeclarationSpec(metaclass=SpecMeta, category=OpCategory.DECLARATION):  # pylint: disable=too-few-public-methods
    """
    Base class for declaration specs.

    Every subclass can be instantiated like a dataclass, i.e.
    ClassName(field1=value1, field2=value2, ...).
    """
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class Executor(Protocol):
    """Structural type for an object that executes an operation."""

    # Category of the operation this executor serves.
    op_category: OpCategory
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _get_required_method(obj: type, name: str) -> Callable[..., Any]:
    """
    Look up attribute `name` on `obj` and return it if it is a function or method.

    Raises:
        ValueError: if the attribute is missing (or None), or if it is neither a
            plain function nor a bound method.
    """
    candidate = getattr(obj, name, None)
    if candidate is None:
        raise ValueError(f"Method {name}() is required for {obj}")
    if not (inspect.isfunction(candidate) or inspect.ismethod(candidate)):
        raise ValueError(f"{obj}.{name}() is not a function; {candidate}")
    return candidate
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class _EngineFunctionExecutorFactory:
    """
    Callable factory that builds an executor from a raw spec.

    Calling it loads the spec, constructs the executor with it, runs the
    executor's `analyze_schema()` on the given arguments and returns
    `(result_type, executor)`.
    """

    _spec_loader: Callable[[Any], Any]
    _executor_cls: type

    def __init__(self, spec_loader: Callable[..., Any], executor_cls: type):
        self._spec_loader = spec_loader
        self._executor_cls = executor_cls

    def __call__(
        self, raw_spec: dict[str, Any], *args: Any, **kwargs: Any
    ) -> tuple[dict[str, Any], Executor]:
        executor = self._executor_cls(self._spec_loader(raw_spec))
        # The schema analysis runs first; the executor is returned alongside its result.
        return (executor.analyze_schema(*args, **kwargs), executor)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
# Common prefix of the attribute names used as `ArgRelationship` values.
_COCOINDEX_ATTR_PREFIX = "cocoindex.io/"


class ArgRelationship(Enum):
    """Describes how an input argument relates to the output of an op."""

    EMBEDDING_ORIGIN_TEXT = _COCOINDEX_ATTR_PREFIX + "embedding_origin_text"
    CHUNKS_BASE_TEXT = _COCOINDEX_ATTR_PREFIX + "chunk_base_text"
    RECTS_BASE_IMAGE = _COCOINDEX_ATTR_PREFIX + "rects_base_image"
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
@dataclasses.dataclass
class OpArgs:
    """
    Options accepted by the `executor_class()` and `function()` decorators.

    - gpu: Whether the executor will be executed on GPU.
    - cache: Whether the executor will be cached.
    - batching: Whether the executor will be batched.
    - max_batch_size: The maximum batch size for the executor. Only valid if
      `batching` is True.
    - behavior_version: The behavior version of the executor. Cache will be
      invalidated if it changes. Must be provided if `cache` is True.
    - arg_relationship: The relationship between an input argument and the
      output, e.g. `(ArgRelationship.CHUNKS_BASE_TEXT, "content")` means the
      output is chunks for the input argument with name `content`.
    """

    gpu: bool = False
    cache: bool = False
    batching: bool = False
    max_batch_size: int | None = None
    behavior_version: int | None = None
    arg_relationship: tuple[ArgRelationship, str] | None = None
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
@dataclasses.dataclass
class _ArgInfo:
    """Per-argument information recorded while analyzing an op's arguments."""

    # Converts the value received for this argument before the executor is called.
    decoder: Callable[[Any], Any]
    # True when the parameter's declared type is not nullable.
    is_required: bool
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def _make_batched_engine_value_decoder(
    field_path: list[str], src_type: ValueType, dst_type_info: AnalyzedTypeInfo
) -> Callable[[Any], Any]:
    """
    Build a decoder for the argument of a batching function.

    The destination type must be a list type. The returned callable applies the
    element decoder to every item of its input and returns the results as a list.

    Raises:
        ValueError: if `dst_type_info` is not a list type.
    """
    list_variant = dst_type_info.variant
    if not isinstance(list_variant, AnalyzedListType):
        raise ValueError("Expected arguments for batching function to be a list type")
    decode_elem = make_engine_value_decoder(
        field_path, src_type, analyze_type_info(list_variant.elem_type)
    )

    def decode_batch(value: Any) -> Any:
        return [decode_elem(item) for item in value]

    return decode_batch
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _register_op_factory(
    category: OpCategory,
    expected_args: list[tuple[str, inspect.Parameter]],
    expected_return: Any,
    executor_factory: Any,
    spec_loader: Callable[..., Any],
    op_kind: str,
    op_args: OpArgs,
) -> None:
    """
    Register an op factory.

    A wrapper executor class is built around `executor_factory`. It matches the
    engine-provided argument schemas against `expected_args`, decodes incoming
    values, calls the wrapped executor and encodes its result. The wrapper is
    then registered with the engine under `op_kind`.

    Raises:
        ValueError: if batching is requested for an op that does not take exactly
            one argument, or if `category` is not `OpCategory.FUNCTION`.
    """

    if op_args.batching and len(expected_args) != 1:
        raise ValueError("Batching is only supported for single argument functions")

    # Parameter kinds that can be supplied by keyword.
    keyword_capable_kinds = (
        inspect.Parameter.KEYWORD_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )

    class _WrappedExecutor:
        _executor: Any
        _args_info: list[_ArgInfo]
        _kwargs_info: dict[str, _ArgInfo]
        _result_encoder: Callable[[Any], Any]
        _acall: Callable[..., Awaitable[Any]] | None = None

        def __init__(self, spec: Any) -> None:
            if op_args.gpu:
                # GPU ops are created through the subprocess executor stub.
                self._executor = executor_stub(executor_factory, spec)
                return
            wrapped = executor_factory()
            wrapped.spec = spec
            self._executor = wrapped

        def analyze_schema(
            self, *args: _engine.OpArgSchema, **kwargs: _engine.OpArgSchema
        ) -> Any:
            """
            Analyze the spec and arguments. In this phase, argument types should be validated.
            It should return the expected result type for the current op.
            """
            self._args_info = []
            self._kwargs_info = {}
            attributes = {}
            potentially_missing_required_arg = False

            def process_arg(
                arg_name: str,
                arg_param: inspect.Parameter,
                actual_arg: _engine.OpArgSchema,
            ) -> _ArgInfo:
                nonlocal potentially_missing_required_arg
                if op_args.arg_relationship is not None:
                    attr_kind, attr_arg_name = op_args.arg_relationship
                    if attr_arg_name == arg_name:
                        attributes[attr_kind.value] = actual_arg.analyzed_value
                type_info = analyze_type_info(arg_param.annotation)
                enriched = EnrichedValueType.decode(actual_arg.value_type)
                make_decoder = (
                    _make_batched_engine_value_decoder
                    if op_args.batching
                    else make_engine_value_decoder
                )
                decoder = make_decoder([arg_name], enriched.type, type_info)
                is_required = not type_info.nullable
                # A nullable input feeding a non-nullable parameter may be absent at runtime.
                if is_required and actual_arg.value_type.get("nullable", False):
                    potentially_missing_required_arg = True
                return _ArgInfo(decoder=decoder, is_required=is_required)

            # Match positional arguments with parameters.
            next_param_idx = 0
            num_expected = len(expected_args)
            for actual_arg in args:
                if next_param_idx >= num_expected:
                    raise ValueError(
                        f"Too many arguments passed in: {len(args)} > {len(expected_args)}"
                    )
                arg_name, arg_param = expected_args[next_param_idx]
                if arg_param.kind in (
                    inspect.Parameter.KEYWORD_ONLY,
                    inspect.Parameter.VAR_KEYWORD,
                ):
                    raise ValueError(
                        f"Too many positional arguments passed in: {len(args)} > {next_param_idx}"
                    )
                self._args_info.append(process_arg(arg_name, arg_param, actual_arg))
                # A `*args` parameter keeps absorbing the remaining positional arguments.
                if arg_param.kind != inspect.Parameter.VAR_POSITIONAL:
                    next_param_idx += 1

            # Parameters not consumed positionally are candidates for keyword arguments.
            expected_kwargs = expected_args[next_param_idx:]

            def find_param(kwarg_name: str) -> inspect.Parameter | None:
                for param_name, param in expected_kwargs:
                    if param.kind == inspect.Parameter.VAR_KEYWORD:
                        return param
                    if param_name == kwarg_name and param.kind in keyword_capable_kinds:
                        return param
                return None

            for kwarg_name, actual_arg in kwargs.items():
                matched_param = find_param(kwarg_name)
                if matched_param is None:
                    raise ValueError(
                        f"Unexpected keyword argument passed in: {kwarg_name}"
                    )
                self._kwargs_info[kwarg_name] = process_arg(
                    kwarg_name, matched_param, actual_arg
                )

            missing_args = []
            for param_name, param in expected_kwargs:
                if param.default is not inspect.Parameter.empty:
                    continue
                if param.kind == inspect.Parameter.POSITIONAL_ONLY or (
                    param.kind in keyword_capable_kinds and param_name not in kwargs
                ):
                    missing_args.append(param_name)
            if missing_args:
                raise ValueError(f"Missing arguments: {', '.join(missing_args)}")

            analyzed_expected_return_type = analyze_type_info(expected_return)
            self._result_encoder = make_engine_value_encoder(
                analyzed_expected_return_type
            )

            # An executor-provided `analyze()` takes precedence over the declared return type.
            base_analyze_method = getattr(self._executor, "analyze", None)
            if base_analyze_method is not None:
                analyzed_result_type = analyze_type_info(base_analyze_method())
            elif op_args.batching:
                if not isinstance(
                    analyzed_expected_return_type.variant, AnalyzedListType
                ):
                    raise ValueError(
                        "Expected return type for batching function to be a list type"
                    )
                analyzed_result_type = analyze_type_info(
                    analyzed_expected_return_type.variant.elem_type
                )
            else:
                analyzed_result_type = analyzed_expected_return_type

            if attributes:
                analyzed_result_type.attrs = attributes
            if potentially_missing_required_arg:
                analyzed_result_type.nullable = True
            return encode_enriched_type_info(analyzed_result_type)

        async def prepare(self) -> None:
            """
            Prepare for execution.
            It's executed after `analyze` and before any `__call__` execution.
            """
            prepare_method = getattr(self._executor, "prepare", None)
            if prepare_method is not None:
                await to_async_call(prepare_method)()
            # Resolved after `prepare()` has run on the wrapped executor.
            self._acall = to_async_call(self._executor.__call__)

        async def __call__(self, *args: Any, **kwargs: Any) -> Any:
            decoded_args = []
            for arg_info, arg in zip(self._args_info, args):
                # A missing value for a required argument short-circuits to None.
                if arg is None and arg_info.is_required:
                    return None
                decoded_args.append(arg_info.decoder(arg))

            decoded_kwargs = {}
            for kwarg_name, arg in kwargs.items():
                kwarg_info = self._kwargs_info.get(kwarg_name)
                if kwarg_info is None:
                    raise ValueError(
                        f"Unexpected keyword argument passed in: {kwarg_name}"
                    )
                if arg is None and kwarg_info.is_required:
                    return None
                decoded_kwargs[kwarg_name] = kwarg_info.decoder(arg)

            assert self._acall is not None
            output = await self._acall(*decoded_args, **decoded_kwargs)
            return self._result_encoder(output)

        def enable_cache(self) -> bool:
            return op_args.cache

        def behavior_version(self) -> int | None:
            return op_args.behavior_version

        def batching_options(self) -> dict[str, Any] | None:
            if not op_args.batching:
                return None
            return {
                "max_batch_size": op_args.max_batch_size,
            }

    if category != OpCategory.FUNCTION:
        raise ValueError(f"Unsupported executor type {category}")
    _engine.register_function_factory(
        op_kind, _EngineFunctionExecutorFactory(spec_loader, _WrappedExecutor)
    )
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
def executor_class(**args: Any) -> Callable[[type], type]:
    """
    Decorate a class to provide an executor for an op.

    The keyword arguments are the fields of `OpArgs`. The decorated class must
    annotate a `spec` field; the spec class found there determines the op
    category and the op kind (its class name).
    """
    op_args = OpArgs(**args)

    def _inner(cls: type[Executor]) -> type:
        """
        Register `cls` as the executor for its spec class and return it unchanged.
        """
        # Use `__annotations__` instead of `get_type_hints`, to avoid resolving forward references.
        type_hints = cls.__annotations__
        if "spec" not in type_hints:
            raise TypeError("Expect a `spec` field with type hint")
        spec_cls = resolve_forward_ref(type_hints["spec"])
        call_params = list(inspect.signature(cls.__call__).parameters.items())
        call_return = inspect.signature(cls.__call__).return_annotation

        def _load_spec(raw_spec: Any) -> Any:
            return load_engine_object(spec_cls, raw_spec)

        _register_op_factory(
            category=spec_cls._op_category,
            expected_args=call_params[1:],  # First argument is `self`
            expected_return=call_return,
            executor_factory=cls,
            spec_loader=_load_spec,
            op_kind=spec_cls.__name__,
            op_args=op_args,
        )
        return cls

    return _inner
|
|
438
|
+
|
|
439
|
+
|
|
440
|
+
class EmptyFunctionSpec(FunctionSpec):
    """A function spec that declares no fields."""
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
class _SimpleFunctionExecutor:
|
|
445
|
+
spec: Callable[..., Any]
|
|
446
|
+
|
|
447
|
+
def prepare(self) -> None:
|
|
448
|
+
self.__call__ = staticmethod(self.spec)
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
def function(**args: Any) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """
    Decorate a function to provide a function for an op.

    The keyword arguments are the fields of `OpArgs`. The op kind is derived
    from the function name and is also stored on the function as
    `__cocoindex_op_kind__`. The function itself is returned unchanged.
    """
    op_args = OpArgs(**args)

    def _inner(fn: Callable[..., Any]) -> Callable[..., Any]:
        # Convert snake case to camel case.
        op_kind = "".join(map(str.capitalize, fn.__name__.split("_")))
        sig = inspect.signature(fn)
        fn.__cocoindex_op_kind__ = op_kind  # type: ignore

        def _load_spec(_: Any) -> Callable[..., Any]:
            # The "spec" of a simple function op is the function itself.
            return fn

        _register_op_factory(
            category=OpCategory.FUNCTION,
            expected_args=list(sig.parameters.items()),
            expected_return=sig.return_annotation,
            executor_factory=_SimpleFunctionExecutor,
            spec_loader=_load_spec,
            op_kind=op_kind,
            op_args=op_args,
        )

        return fn

    return _inner
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
########################################################
|
|
478
|
+
# Custom source connector
|
|
479
|
+
########################################################
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
@dataclasses.dataclass
class SourceReadOptions:
    """
    Options for reading a source row.

    Passed to both the `list()` and `get_value()` methods.
    Unless spelled out otherwise below, an option is not a mandatory
    requirement but a hint that the information is useful in the current context.

    - include_ordinal: Whether to include the ordinal of the source row.
      When provides_ordinal() returns True, you must provide `ordinal` in `list()`
      when `include_ordinal` is True. It's optional for other cases.
      It helps to skip unnecessary reprocessing early, and to avoid output from
      an older version of the input overwriting the latest one when there's
      concurrency (especially multiple processes) and the source updates frequently.

    - include_content_version_fp: Whether to include the content version
      fingerprint of the source row.
      It's always optional even if this is True.
      It helps to skip unnecessary reprocessing early.
      Only consider providing it if you can get it directly without computing
      a hash on the content.

    - include_value: Whether to include the value of the source row.
      You must provide it in `get_value()` when `include_value` is True.
      It's optional for `list()`.
      Consider providing it when it's significantly cheaper than calling another
      `get_value()` for each row; it saves the cost of individual `get_value()` calls.
    """

    include_ordinal: bool = False
    include_content_version_fp: bool = False
    include_value: bool = False
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
# Type variables for the key and value types of a source row.
K = TypeVar("K")
V = TypeVar("V")

# Sentinel for a row value: the row does not exist.
NON_EXISTENCE: Literal["NON_EXISTENCE"] = "NON_EXISTENCE"
# Sentinel for a row ordinal: ordinal is not available for the source.
NO_ORDINAL: Literal["NO_ORDINAL"] = "NO_ORDINAL"
|
|
515
|
+
|
|
516
|
+
|
|
517
|
+
@dataclasses.dataclass
class PartialSourceRowData(Generic[V]):
    """
    Data of one source row; every field is optional and defaults to None.

    - value: The value of the source row. NON_EXISTENCE means the row does not exist.
    - ordinal: The ordinal of the source row. NO_ORDINAL means ordinal is not
      available for the source.
    - content_version_fp: The content version fingerprint of the source row.
    """

    value: V | Literal["NON_EXISTENCE"] | None = None
    ordinal: int | Literal["NO_ORDINAL"] | None = None
    content_version_fp: bytes | None = None
|
|
530
|
+
|
|
531
|
+
|
|
532
|
+
@dataclasses.dataclass
class PartialSourceRow(Generic[K, V]):
    """A source row: its key together with its (possibly partial) data."""

    key: K
    data: PartialSourceRowData[V]
|
|
536
|
+
|
|
537
|
+
|
|
538
|
+
class _SourceExecutorContext:
    """
    Wraps a source executor.

    It encodes the keys and row data returned by the executor's `list()` and
    `get_value()` methods, and decodes the raw keys passed to `get_value()`.
    """

    _executor: Any

    _key_encoder: Callable[[Any], Any]
    _key_decoder: Callable[[Any], Any]

    _value_encoder: Callable[[Any], Any]

    _list_fn: Callable[
        [SourceReadOptions],
        AsyncIterator[PartialSourceRow[Any, Any]]
        | Iterator[PartialSourceRow[Any, Any]],
    ]
    _orig_get_value_fn: Callable[..., Any]
    _get_value_fn: Callable[..., Awaitable[PartialSourceRowData[Any]]]
    _provides_ordinal_fn: Callable[[], bool] | None

    def __init__(
        self,
        executor: Any,
        key_type_info: AnalyzedTypeInfo,
        key_decoder: Callable[[Any], Any],
        value_type_info: AnalyzedTypeInfo,
    ):
        self._executor = executor

        self._key_encoder = make_engine_key_encoder(key_type_info)
        self._key_decoder = key_decoder
        self._value_encoder = make_engine_value_encoder(value_type_info)

        # `list` and `get_value` are mandatory; `provides_ordinal` is optional.
        self._list_fn = _get_required_method(executor, "list")
        get_value_fn = _get_required_method(executor, "get_value")
        # Keep the un-wrapped method too: its signature is inspected when building call arguments.
        self._orig_get_value_fn = get_value_fn
        self._get_value_fn = to_async_call(get_value_fn)
        self._provides_ordinal_fn = getattr(executor, "provides_ordinal", None)

    def provides_ordinal(self) -> bool:
        """Return the executor's `provides_ordinal()` as a bool, or False if it has none."""
        ordinal_fn = self._provides_ordinal_fn
        if ordinal_fn is None:
            return False
        return bool(ordinal_fn())

    async def list_async(
        self, options: dict[str, Any]
    ) -> AsyncIterator[tuple[Any, dict[str, Any]]]:
        """
        Return an async iterator that yields individual rows one by one.
        Each yielded item is a tuple of (key, data).
        """
        read_options = load_engine_object(SourceReadOptions, options)
        call_args = _build_args(self._list_fn, 0, options=read_options)
        rows = self._list_fn(*call_args)

        def encode_row(row: Any) -> tuple[Any, dict[str, Any]]:
            return (
                self._key_encoder(row.key),
                self._encode_source_row_data(row.data),
            )

        # Handle both sync and async iterators
        if hasattr(rows, "__aiter__"):
            async for row in rows:
                yield encode_row(row)
        else:
            for row in rows:
                yield encode_row(row)

    async def get_value_async(
        self,
        raw_key: Any,
        options: dict[str, Any],
    ) -> dict[str, Any]:
        """Decode `raw_key`, fetch the row data from the executor and return it encoded."""
        key = self._key_decoder(raw_key)
        read_options = load_engine_object(SourceReadOptions, options)
        call_args = _build_args(
            self._orig_get_value_fn, 1, key=key, options=read_options
        )
        return self._encode_source_row_data(await self._get_value_fn(*call_args))

    def _encode_source_row_data(
        self, row_data: PartialSourceRowData[Any]
    ) -> dict[str, Any]:
        """Convert Python PartialSourceRowData to the format expected by Rust."""
        encoded: dict[str, Any] = {
            "ordinal": row_data.ordinal,
            "content_version_fp": row_data.content_version_fp,
        }
        # The NON_EXISTENCE marker is passed through as is; any other value is encoded.
        if row_data.value == NON_EXISTENCE:
            encoded["value"] = NON_EXISTENCE
        else:
            encoded["value"] = self._value_encoder(row_data.value)
        return encoded
|
|
629
|
+
|
|
630
|
+
|
|
631
|
+
class _SourceConnector:
    """
    The connector class passed to the engine.

    It holds the analyzed key and value types of a source, the KTable type
    derived from them, and creates executor contexts from raw specs.
    """

    _spec_cls: type[Any]
    _key_type_info: AnalyzedTypeInfo
    _key_decoder: Callable[[Any], Any]
    _value_type_info: AnalyzedTypeInfo
    _table_type: EnrichedValueType
    _connector_cls: type[Any]

    _create_fn: Callable[[Any], Awaitable[Any]]

    def __init__(
        self,
        spec_cls: type[Any],
        key_type: Any,
        value_type: Any,
        connector_cls: type[Any],
    ):
        self._spec_cls = spec_cls
        self._key_type_info = analyze_type_info(key_type)
        self._value_type_info = analyze_type_info(value_type)
        self._connector_cls = connector_cls

        # TODO: We can save the intermediate step after #1083 is fixed.
        engine_key_type = EnrichedValueType.decode(
            encode_enriched_type_info(self._key_type_info)
        )

        # TODO: We can save the intermediate step after #1083 is fixed.
        engine_value_type = EnrichedValueType.decode(
            encode_enriched_type_info(self._value_type_info)
        )

        value_struct = engine_value_type.type
        if not isinstance(value_struct, StructType):
            raise ValueError(f"Expected a StructType, got {engine_value_type.type}")

        # A struct key contributes its own fields; any other key becomes a single key field.
        key_struct = engine_key_type.type
        if isinstance(key_struct, StructType):
            key_fields_schema = key_struct.fields
        else:
            key_fields_schema = [
                FieldSchema(name=KEY_FIELD_NAME, value_type=engine_key_type)
            ]
        self._key_decoder = make_engine_key_decoder(
            [], key_fields_schema, self._key_type_info
        )

        # Table rows are the key fields followed by the value fields.
        row_schema = StructSchema(fields=key_fields_schema + value_struct.fields)
        self._table_type = EnrichedValueType(
            type=TableType(
                kind="KTable",
                row=row_schema,
                num_key_parts=len(key_fields_schema),
            ),
        )

        self._create_fn = to_async_call(_get_required_method(connector_cls, "create"))

    async def create_executor(self, raw_spec: dict[str, Any]) -> _SourceExecutorContext:
        """Load the spec, create the executor via the connector's `create()`, and wrap it."""
        spec = load_engine_object(self._spec_cls, raw_spec)
        executor = await self._create_fn(spec)
        return _SourceExecutorContext(
            executor, self._key_type_info, self._key_decoder, self._value_type_info
        )

    def get_table_type(self) -> Any:
        """Return the dumped engine representation of the source's table type."""
        return dump_engine_object(self._table_type)
|
|
698
|
+
|
|
699
|
+
|
|
700
|
+
def source_connector(
    *,
    spec_cls: type[Any],
    key_type: Any = Any,
    value_type: Any = Any,
) -> Callable[[type], type]:
    """
    Decorate a class to provide a source connector for an op.

    The connector is registered with the engine under the name of `spec_cls`;
    the decorated class is returned unchanged.

    Raises:
        ValueError: if `spec_cls` is not a subclass of `SourceSpec`.
    """

    # Validate the spec_cls is a SourceSpec.
    if not issubclass(spec_cls, SourceSpec):
        raise ValueError(f"Expect a SourceSpec, got {spec_cls}")

    # Register the source connector.
    def _inner(connector_cls: type) -> type:
        _engine.register_source_connector(
            spec_cls.__name__,
            _SourceConnector(spec_cls, key_type, value_type, connector_cls),
        )
        return connector_cls

    return _inner
|
|
721
|
+
|
|
722
|
+
|
|
723
|
+
########################################################
|
|
724
|
+
# Custom target connector
|
|
725
|
+
########################################################
|
|
726
|
+
|
|
727
|
+
|
|
728
|
+
@dataclasses.dataclass
class _TargetConnectorContext:
    """Bundle of per-target data kept by the target connector."""

    target_name: str
    spec: Any
    prepared_spec: Any
    # Schema and decoder for the key fields.
    key_fields_schema: list[FieldSchema]
    key_decoder: Callable[[Any], Any]
    # Schema and decoder for the value fields.
    value_fields_schema: list[FieldSchema]
    value_decoder: Callable[[Any], Any]
    index_options: IndexOptions
    setup_state: Any
|
|
739
|
+
|
|
740
|
+
|
|
741
|
+
def _build_args(
|
|
742
|
+
method: Callable[..., Any], num_required_args: int, **kwargs: Any
|
|
743
|
+
) -> list[Any]:
|
|
744
|
+
signature = inspect.signature(method)
|
|
745
|
+
for param in signature.parameters.values():
|
|
746
|
+
if param.kind not in (
|
|
747
|
+
inspect.Parameter.POSITIONAL_ONLY,
|
|
748
|
+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
749
|
+
):
|
|
750
|
+
raise ValueError(
|
|
751
|
+
f"Method {method.__name__} should only have positional arguments, got {param.kind.name}"
|
|
752
|
+
)
|
|
753
|
+
if len(signature.parameters) < num_required_args:
|
|
754
|
+
raise ValueError(
|
|
755
|
+
f"Method {method.__name__} must have at least {num_required_args} required arguments: "
|
|
756
|
+
f"{', '.join(list(kwargs.keys())[:num_required_args])}"
|
|
757
|
+
)
|
|
758
|
+
if len(kwargs) > len(kwargs):
|
|
759
|
+
raise ValueError(
|
|
760
|
+
f"Method {method.__name__} can only have at most {num_required_args} arguments: {', '.join(kwargs.keys())}"
|
|
761
|
+
)
|
|
762
|
+
return [v for _, v in zip(signature.parameters, kwargs.values())]
|
|
763
|
+
|
|
764
|
+
|
|
765
|
+
class TargetStateCompatibility(Enum):
    """Verdict on how a target's desired setup state relates to its existing one."""

    # The two states are compatible.
    COMPATIBLE = "Compatible"
    # The two states are only partially compatible.
    PARTIALLY_COMPATIBLE = "PartialCompatible"
    # The two states are not compatible.
    NOT_COMPATIBLE = "NotCompatible"
|
|
771
|
+
|
|
772
|
+
|
|
773
|
+
class _TargetConnector:
    """
    Adapter around a user-defined target connector class, passed to the engine.

    It looks up the connector class's methods (`get_persistent_key`,
    `apply_setup_change` and `mutate` are required; `get_setup_state`,
    `check_state_compatibility`, `prepare` and `describe` are optional) and
    converts between raw engine values and the Python objects those methods
    work with.
    """

    _spec_cls: type[Any]
    _persistent_key_type: Any
    _setup_state_cls: type[Any]
    _connector_cls: type[Any]

    _get_persistent_key_fn: Callable[[_TargetConnectorContext, str], Any]
    _apply_setup_change_async_fn: Callable[
        [Any, dict[str, Any] | None, dict[str, Any] | None], Awaitable[None]
    ]
    _mutate_async_fn: Callable[..., Awaitable[None]]
    _mutatation_type: AnalyzedDictType | None

    def __init__(
        self,
        spec_cls: type[Any],
        persistent_key_type: Any,
        setup_state_cls: type[Any],
        connector_cls: type[Any],
    ):
        """Record the configured types and resolve the connector's required methods."""
        self._spec_cls = spec_cls
        self._persistent_key_type = persistent_key_type
        self._setup_state_cls = setup_state_cls
        self._connector_cls = connector_cls

        persistent_key_fn = _get_required_method(connector_cls, "get_persistent_key")
        self._get_persistent_key_fn = persistent_key_fn

        apply_setup_change_fn = _get_required_method(
            connector_cls, "apply_setup_change"
        )
        self._apply_setup_change_async_fn = to_async_call(apply_setup_change_fn)

        user_mutate = _get_required_method(connector_cls, "mutate")
        self._mutate_async_fn = to_async_call(user_mutate)

        # Keep the analyzed mutation dict type (if any) so the key/value
        # decoders can be built against it in create_export_context.
        self._mutatation_type = self._analyze_mutate_mutation_type(
            connector_cls, user_mutate
        )

    @staticmethod
    def _analyze_mutate_mutation_type(
        connector_cls: type, mutate_fn: Callable[..., Any]
    ) -> AnalyzedDictType | None:
        """
        Validate the signature of `mutate` and extract its mutation dict type.

        `mutate` must declare exactly one parameter, in `*args` form. Returns
        None when the annotation analyzes to "any", is a tuple without type
        arguments, or is a 2-tuple whose second element analyzes to "any";
        returns the analyzed dict type when that second element is a dict.

        Raises:
            ValueError: for any other signature or annotation.
        """
        declared = tuple(inspect.signature(mutate_fn).parameters.values())
        if len(declared) != 1:
            raise ValueError(
                f"Method {connector_cls.__name__}.mutate(*args) must have exactly one parameter, "
                f"got {len(declared)}"
            )

        (varargs,) = declared
        if varargs.kind != inspect.Parameter.VAR_POSITIONAL:
            raise ValueError(
                f"Method {connector_cls.__name__}.mutate(*args) parameter must be *args format, "
                f"got {varargs.kind.name}"
            )

        args_info = analyze_type_info(varargs.annotation)
        if isinstance(args_info.variant, AnalyzedAnyType):
            return None

        if args_info.base_type is tuple:
            element_types = get_args(args_info.core_type)
            if not element_types:
                return None
            if len(element_types) == 2:
                # Second tuple element is the mutation (first is the spec).
                mutation_variant = analyze_type_info(element_types[1]).variant
                if isinstance(mutation_variant, AnalyzedAnyType):
                    return None
                if isinstance(mutation_variant, AnalyzedDictType):
                    return mutation_variant

        raise ValueError(
            f"Method {connector_cls.__name__}.mutate(*args) parameter must be a tuple with "
            f"2 elements (tuple[SpecType, dict[str, ValueStruct]], spec and mutation in dict), "
            f"got {args_info.core_type}"
        )

    def create_export_context(
        self,
        name: str,
        raw_spec: dict[str, Any],
        raw_key_fields_schema: list[Any],
        raw_value_fields_schema: list[Any],
        raw_index_options: dict[str, Any],
    ) -> _TargetConnectorContext:
        """
        Build the context for one target from raw engine inputs.

        Key and value decoders are built against the dict type extracted from
        `mutate`'s annotation, or against `Any` when none was extracted. The
        returned context has `prepared_spec` and `setup_state` set to None.
        """
        if self._mutatation_type is not None:
            key_annotation = self._mutatation_type.key_type
            value_annotation = self._mutatation_type.value_type
        else:
            key_annotation = value_annotation = Any

        key_schema = decode_engine_field_schemas(raw_key_fields_schema)
        decode_key = make_engine_key_decoder(
            ["<key>"], key_schema, analyze_type_info(key_annotation)
        )
        value_schema = decode_engine_field_schemas(raw_value_fields_schema)
        decode_value = make_engine_struct_decoder(
            ["<value>"], value_schema, analyze_type_info(value_annotation)
        )

        loaded_spec = load_engine_object(self._spec_cls, raw_spec)
        loaded_index_options = load_engine_object(IndexOptions, raw_index_options)
        return _TargetConnectorContext(
            target_name=name,
            spec=loaded_spec,
            prepared_spec=None,
            key_fields_schema=key_schema,
            key_decoder=decode_key,
            value_fields_schema=value_schema,
            value_decoder=decode_value,
            index_options=loaded_index_options,
            setup_state=None,
        )

    def get_persistent_key(self, export_context: _TargetConnectorContext) -> Any:
        """Call the connector's `get_persistent_key` and dump its result for the engine."""
        key_fn = self._get_persistent_key_fn
        call_args = _build_args(
            key_fn,
            1,
            spec=export_context.spec,
            target_name=export_context.target_name,
        )
        persistent_key = key_fn(*call_args)
        return dump_engine_object(persistent_key)

    def get_setup_state(self, export_context: _TargetConnectorContext) -> Any:
        """
        Compute the setup state, store it on the context, and return it dumped.

        Uses the connector's `get_setup_state` when it defines one; otherwise
        the spec itself serves as the state.

        Raises:
            ValueError: if the state is not an instance of the setup state class.
        """
        user_fn = getattr(self._connector_cls, "get_setup_state", None)
        if user_fn is not None:
            call_args = _build_args(
                user_fn,
                1,
                spec=export_context.spec,
                key_fields_schema=export_context.key_fields_schema,
                value_fields_schema=export_context.value_fields_schema,
                index_options=export_context.index_options,
            )
            state = user_fn(*call_args)
            if not isinstance(state, self._setup_state_cls):
                raise ValueError(
                    f"Method {user_fn.__name__} must return an instance of {self._setup_state_cls}, got {type(state)}"
                )
        else:
            state = export_context.spec
            if not isinstance(state, self._setup_state_cls):
                raise ValueError(
                    f"Expect a get_setup_state() method for {self._connector_cls} that returns an instance of {self._setup_state_cls}"
                )

        export_context.setup_state = state
        return dump_engine_object(state)

    def check_state_compatibility(
        self, raw_desired_state: Any, raw_existing_state: Any
    ) -> Any:
        """
        Compare a desired and an existing raw state, and dump the verdict.

        Delegates to the connector's `check_state_compatibility` (with both
        states loaded as the setup state class) when it defines one. Otherwise
        equal raw states are COMPATIBLE and anything else PARTIALLY_COMPATIBLE.
        """
        user_check = getattr(self._connector_cls, "check_state_compatibility", None)
        if user_check is None:
            if raw_desired_state == raw_existing_state:
                verdict = TargetStateCompatibility.COMPATIBLE
            else:
                verdict = TargetStateCompatibility.PARTIALLY_COMPATIBLE
        else:
            desired_state = load_engine_object(self._setup_state_cls, raw_desired_state)
            existing_state = load_engine_object(
                self._setup_state_cls, raw_existing_state
            )
            verdict = user_check(desired_state, existing_state)
        return dump_engine_object(verdict)

    async def prepare_async(
        self,
        export_context: _TargetConnectorContext,
    ) -> None:
        """
        Populate `export_context.prepared_spec`.

        Without a connector `prepare` method the spec is used as is; otherwise
        the awaited result of `prepare` is stored.
        """
        user_prepare = getattr(self._connector_cls, "prepare", None)
        if user_prepare is None:
            export_context.prepared_spec = export_context.spec
            return

        call_args = _build_args(
            user_prepare,
            1,
            spec=export_context.spec,
            setup_state=export_context.setup_state,
            key_fields_schema=export_context.key_fields_schema,
            value_fields_schema=export_context.value_fields_schema,
        )
        prepare_async_fn = to_async_call(user_prepare)
        export_context.prepared_spec = await prepare_async_fn(*call_args)

    def describe_resource(self, raw_key: Any) -> str:
        """Return a string for the resource: the connector's `describe(key)` if defined, else the key."""
        key = load_engine_object(self._persistent_key_type, raw_key)
        user_describe = getattr(self._connector_cls, "describe", None)
        described = key if user_describe is None else user_describe(key)
        return str(described)

    async def apply_setup_changes_async(
        self,
        changes: list[tuple[Any, list[dict[str, Any] | None], dict[str, Any] | None]],
    ) -> None:
        """
        Apply each setup change through the connector's `apply_setup_change`.

        Each change is `(raw_key, previous_states, current_state)`; the
        connector is awaited once per previous state, in order, with None
        states passed through as None.
        """

        def _load_state(raw_state: dict[str, Any] | None) -> Any:
            # None marks an absent state and is passed through untouched.
            if raw_state is None:
                return None
            return load_engine_object(self._setup_state_cls, raw_state)

        for raw_key, previous, current in changes:
            key = load_engine_object(self._persistent_key_type, raw_key)
            previous_states = [_load_state(raw_state) for raw_state in previous]
            current_state = _load_state(current)
            for previous_state in previous_states:
                await self._apply_setup_change_async_fn(
                    key, previous_state, current_state
                )

    @staticmethod
    def _decode_mutation(
        context: _TargetConnectorContext, mutation: list[tuple[Any, Any | None]]
    ) -> tuple[Any, dict[Any, Any | None]]:
        """
        Decode one target's mutation into `(prepared_spec, {key: value})`.

        Keys go through the context's key decoder; values through its value
        decoder, except None values, which are kept as None.
        """
        prepared_spec = context.prepared_spec
        decoded: dict[Any, Any | None] = {}
        for raw_key, raw_value in mutation:
            decoded_key = context.key_decoder(raw_key)
            decoded[decoded_key] = (
                None if raw_value is None else context.value_decoder(raw_value)
            )
        return prepared_spec, decoded

    async def mutate_async(
        self,
        mutations: list[tuple[_TargetConnectorContext, list[tuple[Any, Any | None]]]],
    ) -> None:
        """Decode every `(context, mutation)` pair and pass them all to `mutate` as `*args`."""
        decoded_batches = [
            self._decode_mutation(context, mutation) for context, mutation in mutations
        ]
        await self._mutate_async_fn(*decoded_batches)
|
|
1023
|
+
|
|
1024
|
+
|
|
1025
|
+
def target_connector(
    *,
    spec_cls: type[Any],
    persistent_key_type: Any = Any,
    setup_state_cls: type[Any] | None = None,
) -> Callable[[type], type]:
    """
    Class decorator that registers the decorated class as a target connector.

    `spec_cls` must be a subclass of `TargetSpec`; this is checked when the
    decorator is created, before any class is decorated. When `setup_state_cls`
    is not given, `spec_cls` is used as the setup state class. The connector is
    registered with the engine under `spec_cls.__name__`, and the decorated
    class is returned unchanged.

    Raises:
        ValueError: if `spec_cls` is not a subclass of `TargetSpec`.
    """
    spec_is_target_spec = issubclass(spec_cls, TargetSpec)
    if not spec_is_target_spec:
        raise ValueError(f"Expect a TargetSpec, got {spec_cls}")

    def _register(connector_cls: type) -> type:
        # Fall back to the spec class when no dedicated state class was given.
        state_cls = setup_state_cls or spec_cls
        wrapped = _TargetConnector(
            spec_cls, persistent_key_type, state_cls, connector_cls
        )
        _engine.register_target_connector(spec_cls.__name__, wrapped)
        return connector_cls

    return _register
|