cocoindex-0.1.72-cp312-cp312-manylinux_2_28_aarch64.whl → cocoindex-0.1.74-cp312-cp312-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cocoindex/_engine.cpython-312-aarch64-linux-gnu.so +0 -0
- cocoindex/convert.py +237 -142
- cocoindex/flow.py +2 -1
- cocoindex/op.py +263 -6
- cocoindex/tests/test_convert.py +96 -129
- cocoindex/tests/test_typing.py +94 -211
- cocoindex/typing.py +176 -133
- {cocoindex-0.1.72.dist-info → cocoindex-0.1.74.dist-info}/METADATA +2 -2
- {cocoindex-0.1.72.dist-info → cocoindex-0.1.74.dist-info}/RECORD +12 -12
- {cocoindex-0.1.72.dist-info → cocoindex-0.1.74.dist-info}/WHEEL +0 -0
- {cocoindex-0.1.72.dist-info → cocoindex-0.1.74.dist-info}/entry_points.txt +0 -0
- {cocoindex-0.1.72.dist-info → cocoindex-0.1.74.dist-info}/licenses/LICENSE +0 -0
cocoindex/op.py
CHANGED
@@ -6,11 +6,31 @@ import asyncio
 import dataclasses
 import inspect
 from enum import Enum
-from typing import
+from typing import (
+    Any,
+    Awaitable,
+    Callable,
+    Protocol,
+    dataclass_transform,
+    Annotated,
+    get_args,
+)

 from . import _engine  # type: ignore
-from .convert import
-
+from .convert import (
+    encode_engine_value,
+    make_engine_value_decoder,
+    make_engine_struct_decoder,
+)
+from .typing import (
+    TypeAttr,
+    encode_enriched_type,
+    resolve_forward_ref,
+    analyze_type_info,
+    AnalyzedAnyType,
+    AnalyzedBasicType,
+    AnalyzedDictType,
+)


 class OpCategory(Enum):
@@ -65,6 +85,22 @@ class Executor(Protocol):
     op_category: OpCategory


+def _load_spec_from_engine(spec_cls: type, spec: dict[str, Any]) -> Any:
+    """
+    Load a spec from the engine.
+    """
+    return spec_cls(**spec)
+
+
+def _get_required_method(cls: type, name: str) -> Callable[..., Any]:
+    method = getattr(cls, name, None)
+    if method is None:
+        raise ValueError(f"Method {name}() is required for {cls.__name__}")
+    if not inspect.isfunction(method):
+        raise ValueError(f"Method {cls.__name__}.{name}() is not a function")
+    return method
+
+
 class _FunctionExecutorFactory:
     _spec_cls: type
     _executor_cls: type
@@ -76,7 +112,7 @@ class _FunctionExecutorFactory:
     def __call__(
         self, spec: dict[str, Any], *args: Any, **kwargs: Any
     ) -> tuple[dict[str, Any], Executor]:
-        spec = self._spec_cls
+        spec = _load_spec_from_engine(self._spec_cls, spec)
         executor = self._executor_cls(spec)
         result_type = executor.analyze(*args, **kwargs)
         return (encode_enriched_type(result_type), executor)
@@ -185,7 +221,9 @@ def _register_op_factory(
             )
             self._args_decoders.append(
                 make_engine_value_decoder(
-                    [arg_name],
+                    [arg_name],
+                    arg.value_type["type"],
+                    analyze_type_info(arg_param.annotation),
                 )
             )
             process_attribute(arg_name, arg)
@@ -217,7 +255,9 @@ def _register_op_factory(
             )
             arg_param = expected_arg[1]
             self._kwargs_decoders[kwarg_name] = make_engine_value_decoder(
-                [kwarg_name],
+                [kwarg_name],
+                kwarg.value_type["type"],
+                analyze_type_info(arg_param.annotation),
             )
             process_attribute(kwarg_name, kwarg)

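Both hunks above move `make_engine_value_decoder` to a three-argument form: the field path, the engine-side type descriptor, and the result of `analyze_type_info()` on the Python annotation. A minimal sketch of that calling convention, modeled on the call shape in the tests further down; the field name and the `Int64` descriptor are illustrative, not taken from the package:

```python
from cocoindex.convert import make_engine_value_decoder
from cocoindex.typing import analyze_type_info

# Engine-side type descriptor, in the same dict shape the tests use.
engine_type = {"kind": "Int64"}

# Build a decoder for a single field: path (used in error messages),
# engine type, and the analyzed Python-side annotation.
decoder = make_engine_value_decoder(["my_field"], engine_type, analyze_type_info(int))

assert decoder(42) == 42
```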
@@ -359,3 +399,220 @@ def function(**args: Any) -> Callable[[Callable[..., Any]], FunctionSpec]:
         return _Spec()

     return _inner
+
+
+########################################################
+# Custom target connector
+########################################################
+
+
+@dataclasses.dataclass
+class _TargetConnectorContext:
+    target_name: str
+    spec: Any
+    prepared_spec: Any
+    key_decoder: Callable[[Any], Any]
+    value_decoder: Callable[[Any], Any]
+
+
+class _TargetConnector:
+    """
+    The connector class passed to the engine.
+    """
+
+    _spec_cls: type
+    _connector_cls: type
+
+    _get_persistent_key_fn: Callable[[_TargetConnectorContext, str], Any]
+    _apply_setup_change_async_fn: Callable[
+        [Any, dict[str, Any] | None, dict[str, Any] | None], Awaitable[None]
+    ]
+    _mutate_async_fn: Callable[..., Awaitable[None]]
+    _mutatation_type: AnalyzedDictType | None
+
+    def __init__(self, spec_cls: type, connector_cls: type):
+        self._spec_cls = spec_cls
+        self._connector_cls = connector_cls
+
+        self._get_persistent_key_fn = _get_required_method(
+            connector_cls, "get_persistent_key"
+        )
+        self._apply_setup_change_async_fn = _to_async_call(
+            _get_required_method(connector_cls, "apply_setup_change")
+        )
+
+        mutate_fn = _get_required_method(connector_cls, "mutate")
+        self._mutate_async_fn = _to_async_call(mutate_fn)
+
+        # Store the type annotation for later use
+        self._mutatation_type = self._analyze_mutate_mutation_type(
+            connector_cls, mutate_fn
+        )
+
+    @staticmethod
+    def _analyze_mutate_mutation_type(
+        connector_cls: type, mutate_fn: Callable[..., Any]
+    ) -> AnalyzedDictType | None:
+        # Validate mutate_fn signature and extract type annotation
+        mutate_sig = inspect.signature(mutate_fn)
+        params = list(mutate_sig.parameters.values())
+
+        if len(params) != 1:
+            raise ValueError(
+                f"Method {connector_cls.__name__}.mutate(*args) must have exactly one parameter, "
+                f"got {len(params)}"
+            )
+
+        param = params[0]
+        if param.kind != inspect.Parameter.VAR_POSITIONAL:
+            raise ValueError(
+                f"Method {connector_cls.__name__}.mutate(*args) parameter must be *args format, "
+                f"got {param.kind.name}"
+            )
+
+        # Extract type annotation
+        analyzed_args_type = analyze_type_info(param.annotation)
+        if isinstance(analyzed_args_type.variant, AnalyzedAnyType):
+            return None
+
+        if analyzed_args_type.base_type is tuple:
+            args = get_args(analyzed_args_type.core_type)
+            if not args:
+                return None
+            if len(args) == 2:
+                mutation_type = analyze_type_info(args[1])
+                if isinstance(mutation_type.variant, AnalyzedAnyType):
+                    return None
+                if isinstance(mutation_type.variant, AnalyzedDictType):
+                    return mutation_type.variant
+
+        raise ValueError(
+            f"Method {connector_cls.__name__}.mutate(*args) parameter must be a tuple with "
+            f"2 elements (tuple[SpecType, dict[str, ValueStruct]], spec and mutation in dict), "
+            "got {args_type}"
+        )
+
+    def create_export_context(
+        self,
+        name: str,
+        spec: dict[str, Any],
+        key_fields_schema: list[Any],
+        value_fields_schema: list[Any],
+    ) -> _TargetConnectorContext:
+        key_annotation, value_annotation = (
+            (
+                self._mutatation_type.key_type,
+                self._mutatation_type.value_type,
+            )
+            if self._mutatation_type is not None
+            else (None, None)
+        )
+
+        key_type_info = analyze_type_info(key_annotation)
+        if (
+            len(key_fields_schema) == 1
+            and key_fields_schema[0]["type"]["kind"] != "Struct"
+            and isinstance(key_type_info.variant, (AnalyzedAnyType, AnalyzedBasicType))
+        ):
+            # Special case for ease of use: single key column can be mapped to a basic type without the wrapper struct.
+            key_decoder = make_engine_value_decoder(
+                ["(key)"],
+                key_fields_schema[0]["type"],
+                key_type_info,
+            )
+        else:
+            key_decoder = make_engine_struct_decoder(
+                ["(key)"], key_fields_schema, key_type_info
+            )
+
+        value_decoder = make_engine_struct_decoder(
+            ["(value)"], value_fields_schema, analyze_type_info(value_annotation)
+        )
+
+        loaded_spec = _load_spec_from_engine(self._spec_cls, spec)
+        prepare_method = getattr(self._connector_cls, "prepare", None)
+        if prepare_method is None:
+            prepared_spec = loaded_spec
+        else:
+            prepared_spec = prepare_method(loaded_spec)
+
+        return _TargetConnectorContext(
+            target_name=name,
+            spec=loaded_spec,
+            prepared_spec=prepared_spec,
+            key_decoder=key_decoder,
+            value_decoder=value_decoder,
+        )
+
+    def get_persistent_key(self, export_context: _TargetConnectorContext) -> Any:
+        return self._get_persistent_key_fn(
+            export_context.spec, export_context.target_name
+        )
+
+    def describe_resource(self, key: Any) -> str:
+        describe_fn = getattr(self._connector_cls, "describe", None)
+        if describe_fn is None:
+            return str(key)
+        return str(describe_fn(key))
+
+    async def apply_setup_changes_async(
+        self,
+        changes: list[tuple[Any, list[dict[str, Any] | None], dict[str, Any] | None]],
+    ) -> None:
+        for key, previous, current in changes:
+            prev_specs = [
+                _load_spec_from_engine(self._spec_cls, spec)
+                if spec is not None
+                else None
+                for spec in previous
+            ]
+            curr_spec = (
+                _load_spec_from_engine(self._spec_cls, current)
+                if current is not None
+                else None
+            )
+            for prev_spec in prev_specs:
+                await self._apply_setup_change_async_fn(key, prev_spec, curr_spec)
+
+    @staticmethod
+    def _decode_mutation(
+        context: _TargetConnectorContext, mutation: list[tuple[Any, Any | None]]
+    ) -> tuple[Any, dict[Any, Any | None]]:
+        return (
+            context.prepared_spec,
+            {
+                context.key_decoder(key): (
+                    context.value_decoder(value) if value is not None else None
+                )
+                for key, value in mutation
+            },
+        )
+
+    async def mutate_async(
+        self,
+        mutations: list[tuple[_TargetConnectorContext, list[tuple[Any, Any | None]]]],
+    ) -> None:
+        await self._mutate_async_fn(
+            *(
+                self._decode_mutation(context, mutation)
+                for context, mutation in mutations
+            )
+        )
+
+
+def target_connector(spec_cls: type) -> Callable[[type], type]:
+    """
+    Decorate a class to provide a target connector for an op.
+    """
+
+    # Validate the spec_cls is a TargetSpec.
+    if not issubclass(spec_cls, TargetSpec):
+        raise ValueError(f"Expect a TargetSpec, got {spec_cls}")
+
+    # Register the target connector.
+    def _inner(connector_cls: type) -> type:
+        connector = _TargetConnector(spec_cls, connector_cls)
+        _engine.register_target_connector(spec_cls.__name__, connector)
+        return connector_cls
+
+    return _inner
cocoindex/tests/test_convert.py
CHANGED
@@ -1,6 +1,6 @@
 import datetime
 import uuid
-from dataclasses import dataclass, make_dataclass
+from dataclasses import dataclass, make_dataclass, field
 from typing import Annotated, Any, Callable, Literal, NamedTuple

 import numpy as np
@@ -19,6 +19,7 @@ from cocoindex.typing import (
     TypeKind,
     Vector,
     encode_enriched_type,
+    analyze_type_info,
 )


@@ -75,7 +76,9 @@ def build_engine_value_decoder(
     If python_type is not specified, uses engine_type_in_py as the target.
     """
     engine_type = encode_enriched_type(engine_type_in_py)["type"]
-    return make_engine_value_decoder(
+    return make_engine_value_decoder(
+        [], engine_type, analyze_type_info(python_type or engine_type_in_py)
+    )


 def validate_full_roundtrip_to(
@@ -103,9 +106,13 @@ def validate_full_roundtrip_to(
     )

     for other_value, other_type in decoded_values:
-        decoder = make_engine_value_decoder(
+        decoder = make_engine_value_decoder(
+            [], encoded_output_type, analyze_type_info(other_type)
+        )
         other_decoded_value = decoder(value_from_engine)
-        assert eq(other_decoded_value, other_value)
+        assert eq(other_decoded_value, other_value), (
+            f"Expected {other_value} but got {other_decoded_value} for {other_type}"
+        )


 def validate_full_roundtrip(
@@ -362,7 +369,9 @@ def test_decode_scalar_numpy_values() -> None:
         ({"kind": "Float64"}, np.float64, 2.718, np.float64(2.718)),
     ]
     for src_type, dst_type, input_value, expected in test_cases:
-        decoder = make_engine_value_decoder(
+        decoder = make_engine_value_decoder(
+            ["field"], src_type, analyze_type_info(dst_type)
+        )
         result = decoder(input_value)
         assert isinstance(result, dst_type)
         assert result == expected
@@ -376,7 +385,9 @@ def test_non_ndarray_vector_decoding() -> None:
         "dimension": None,
     }
     dst_type_float = list[np.float64]
-    decoder = make_engine_value_decoder(
+    decoder = make_engine_value_decoder(
+        ["field"], src_type, analyze_type_info(dst_type_float)
+    )
     input_numbers = [1.0, 2.0, 3.0]
     result = decoder(input_numbers)
     assert isinstance(result, list)
@@ -386,7 +397,9 @@ def test_non_ndarray_vector_decoding() -> None:
     # Test list[Uuid]
     src_type = {"kind": "Vector", "element_type": {"kind": "Uuid"}, "dimension": None}
     dst_type_uuid = list[uuid.UUID]
-    decoder = make_engine_value_decoder(
+    decoder = make_engine_value_decoder(
+        ["field"], src_type, analyze_type_info(dst_type_uuid)
+    )
     uuid1 = uuid.uuid4()
     uuid2 = uuid.uuid4()
     input_uuids = [uuid1, uuid2]
@@ -396,124 +409,15 @@ def test_non_ndarray_vector_decoding() -> None:
     assert result == [uuid1, uuid2]


-
-
-
-
-
-
-
-
-
-        # All fields match (NamedTuple)
-        (
-            OrderNamedTuple,
-            ["O123", "mixed nuts", 25.0, "default_extra"],
-            OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra"),
-        ),
-        # Extra field in engine value (should ignore extra)
-        (
-            Order,
-            ["O123", "mixed nuts", 25.0, "default_extra", "unexpected"],
-            Order("O123", "mixed nuts", 25.0, "default_extra"),
-        ),
-        (
-            OrderNamedTuple,
-            ["O123", "mixed nuts", 25.0, "default_extra", "unexpected"],
-            OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra"),
-        ),
-        # Fewer fields in engine value (should fill with default)
-        (
-            Order,
-            ["O123", "mixed nuts", 0.0, "default_extra"],
-            Order("O123", "mixed nuts", 0.0, "default_extra"),
-        ),
-        (
-            OrderNamedTuple,
-            ["O123", "mixed nuts", 0.0, "default_extra"],
-            OrderNamedTuple("O123", "mixed nuts", 0.0, "default_extra"),
-        ),
-        # More fields in engine value (should ignore extra)
-        (
-            Order,
-            ["O123", "mixed nuts", 25.0, "unexpected"],
-            Order("O123", "mixed nuts", 25.0, "unexpected"),
-        ),
-        (
-            OrderNamedTuple,
-            ["O123", "mixed nuts", 25.0, "unexpected"],
-            OrderNamedTuple("O123", "mixed nuts", 25.0, "unexpected"),
-        ),
-        # Truly extra field (should ignore the fifth field)
-        (
-            Order,
-            ["O123", "mixed nuts", 25.0, "default_extra", "ignored"],
-            Order("O123", "mixed nuts", 25.0, "default_extra"),
-        ),
-        (
-            OrderNamedTuple,
-            ["O123", "mixed nuts", 25.0, "default_extra", "ignored"],
-            OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra"),
-        ),
-        # Missing optional field in engine value (tags=None)
-        (
-            Customer,
-            ["Alice", ["O1", "item1", 10.0, "default_extra"], None],
-            Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), None),
-        ),
-        (
-            CustomerNamedTuple,
-            ["Alice", ["O1", "item1", 10.0, "default_extra"], None],
-            CustomerNamedTuple(
-                "Alice", OrderNamedTuple("O1", "item1", 10.0, "default_extra"), None
-            ),
-        ),
-        # Extra field in engine value for Customer (should ignore)
-        (
-            Customer,
-            ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]], "extra"],
-            Customer(
-                "Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip")]
-            ),
-        ),
-        (
-            CustomerNamedTuple,
-            ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]], "extra"],
-            CustomerNamedTuple(
-                "Alice",
-                OrderNamedTuple("O1", "item1", 10.0, "default_extra"),
-                [Tag("vip")],
-            ),
-        ),
-        # Missing optional field with default
-        (
-            Order,
-            ["O123", "mixed nuts", 25.0],
-            Order("O123", "mixed nuts", 25.0, "default_extra"),
-        ),
-        (
-            OrderNamedTuple,
-            ["O123", "mixed nuts", 25.0],
-            OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra"),
-        ),
-        # Partial optional fields
-        (
-            Customer,
-            ["Alice", ["O1", "item1", 10.0]],
-            Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), None),
-        ),
-        (
-            CustomerNamedTuple,
-            ["Alice", ["O1", "item1", 10.0]],
-            CustomerNamedTuple(
-                "Alice", OrderNamedTuple("O1", "item1", 10.0, "default_extra"), None
-            ),
-        ),
-    ],
-)
-def test_struct_decoder_cases(data_type: Any, engine_val: Any, expected: Any) -> None:
-    decoder = build_engine_value_decoder(data_type)
-    assert decoder(engine_val) == expected
+def test_roundtrip_struct() -> None:
+    validate_full_roundtrip(
+        Order("O123", "mixed nuts", 25.0, "default_extra"),
+        Order,
+    )
+    validate_full_roundtrip(
+        OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra"),
+        OrderNamedTuple,
+    )


 def test_make_engine_value_decoder_list_of_struct() -> None:
@@ -972,7 +876,9 @@ def test_decode_nullable_ndarray_none_or_value_input() -> None:
         "dimension": None,
     }
     dst_annotation = NDArrayFloat32Type | None
-    decoder = make_engine_value_decoder(
+    decoder = make_engine_value_decoder(
+        [], src_type_dict, analyze_type_info(dst_annotation)
+    )

     none_engine_value = None
     decoded_array = decoder(none_engine_value)
@@ -995,7 +901,9 @@ def test_decode_vector_string() -> None:
         "element_type": {"kind": "Str"},
         "dimension": None,
     }
-    decoder = make_engine_value_decoder(
+    decoder = make_engine_value_decoder(
+        [], src_type_dict, analyze_type_info(Vector[str])
+    )
     assert decoder(["hello", "world"]) == ["hello", "world"]


@@ -1006,7 +914,9 @@ def test_decode_error_non_nullable_or_non_list_vector() -> None:
         "element_type": {"kind": "Float32"},
         "dimension": None,
     }
-    decoder = make_engine_value_decoder(
+    decoder = make_engine_value_decoder(
+        [], src_type_dict, analyze_type_info(NDArrayFloat32Type)
+    )
     with pytest.raises(ValueError, match="Received null for non-nullable vector"):
         decoder(None)
     with pytest.raises(TypeError, match="Expected NDArray or list for vector"):
@@ -1096,6 +1006,25 @@ def test_full_roundtrip_vector_numeric_types() -> None:
     validate_full_roundtrip(value_u64, Vector[np.uint64, Literal[3]])


+def test_full_roundtrip_vector_of_vector() -> None:
+    """Test full roundtrip for vector of vector."""
+    value_f32 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
+    validate_full_roundtrip(
+        value_f32,
+        Vector[Vector[np.float32, Literal[3]], Literal[2]],
+        ([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], list[list[np.float32]]),
+        ([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], list[list[cocoindex.Float32]]),
+        (
+            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
+            list[Vector[cocoindex.Float32, Literal[3]]],
+        ),
+        (
+            value_f32,
+            np.typing.NDArray[np.typing.NDArray[np.float32]],
+        ),
+    )
+
+
 def test_full_roundtrip_vector_other_types() -> None:
     """Test full roundtrip for Vector with non-numeric basic types."""
     uuid_list = [uuid.uuid4(), uuid.uuid4()]
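The new roundtrip test above relies on nesting `Vector` annotations. As a quick illustration of the typing side only (a sketch; the exact encoded schema shape is not reproduced here), the nested annotation can be encoded like any other enriched type:

```python
from typing import Literal

import numpy as np

from cocoindex.typing import Vector, encode_enriched_type

# A 2x3 float32 matrix typed as a 2-element vector of 3-element vectors.
Matrix2x3 = Vector[Vector[np.float32, Literal[3]], Literal[2]]

# Encodes to a Vector type whose element type is itself a Vector.
print(encode_enriched_type(Matrix2x3))
```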
@@ -1216,7 +1145,7 @@ def test_full_roundtrip_scalar_with_python_types() -> None:
         numpy_float: np.float64
         python_float: float
         string: str
-        annotated_int: Annotated[np.int64, TypeKind("
+        annotated_int: Annotated[np.int64, TypeKind("Int64")]
         annotated_float: Float32

     instance = MixedStruct(
@@ -1468,3 +1397,41 @@ def test_roundtrip_ktable_with_list_fields() -> None:

     # Test Any annotation
     validate_full_roundtrip(teams, dict[str, Team], (expected_dict_dict, Any))
+
+
+def test_auto_default_for_supported_and_unsupported_types() -> None:
+    @dataclass
+    class Base:
+        a: int
+
+    @dataclass
+    class NullableField:
+        a: int
+        b: int | None
+
+    @dataclass
+    class LTableField:
+        a: int
+        b: list[Base]
+
+    @dataclass
+    class KTableField:
+        a: int
+        b: dict[str, Base]
+
+    @dataclass
+    class UnsupportedField:
+        a: int
+        b: int
+
+    validate_full_roundtrip(NullableField(1, None), NullableField)
+
+    validate_full_roundtrip(LTableField(1, []), LTableField)
+
+    # validate_full_roundtrip(KTableField(1, {}), KTableField)
+
+    with pytest.raises(
+        ValueError,
+        match=r"Field 'b' \(type <class 'int'>\) without default value is missing in input: ",
+    ):
+        build_engine_value_decoder(Base, UnsupportedField)