cocoindex 0.1.72__cp311-cp311-win_amd64.whl → 0.1.74__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cocoindex/op.py CHANGED
@@ -6,11 +6,31 @@ import asyncio
6
6
  import dataclasses
7
7
  import inspect
8
8
  from enum import Enum
9
- from typing import Any, Awaitable, Callable, Protocol, dataclass_transform, Annotated
9
+ from typing import (
10
+ Any,
11
+ Awaitable,
12
+ Callable,
13
+ Protocol,
14
+ dataclass_transform,
15
+ Annotated,
16
+ get_args,
17
+ )
10
18
 
11
19
  from . import _engine # type: ignore
12
- from .convert import encode_engine_value, make_engine_value_decoder
13
- from .typing import TypeAttr, encode_enriched_type, resolve_forward_ref
20
+ from .convert import (
21
+ encode_engine_value,
22
+ make_engine_value_decoder,
23
+ make_engine_struct_decoder,
24
+ )
25
+ from .typing import (
26
+ TypeAttr,
27
+ encode_enriched_type,
28
+ resolve_forward_ref,
29
+ analyze_type_info,
30
+ AnalyzedAnyType,
31
+ AnalyzedBasicType,
32
+ AnalyzedDictType,
33
+ )
14
34
 
15
35
 
16
36
  class OpCategory(Enum):
@@ -65,6 +85,22 @@ class Executor(Protocol):
65
85
  op_category: OpCategory
66
86
 
67
87
 
88
def _load_spec_from_engine(spec_cls: type, spec: dict[str, Any]) -> Any:
    """
    Build a ``spec_cls`` instance from the field mapping received from the engine.

    Every key of ``spec`` is forwarded to the ``spec_cls`` constructor as a
    keyword argument.
    """
    constructor_kwargs = spec
    return spec_cls(**constructor_kwargs)
93
+
94
+
95
def _get_required_method(cls: type, name: str) -> Callable[..., Any]:
    """
    Fetch the attribute ``name`` from ``cls`` and require it to be a plain function.

    Raises:
        ValueError: If ``cls`` has no such attribute (or it is ``None``), or if
            the attribute is not a function (e.g. a plain value).
    """
    candidate = getattr(cls, name, None)
    if candidate is None:
        raise ValueError(f"Method {name}() is required for {cls.__name__}")
    if inspect.isfunction(candidate):
        return candidate
    raise ValueError(f"Method {cls.__name__}.{name}() is not a function")
102
+
103
+
68
104
  class _FunctionExecutorFactory:
69
105
  _spec_cls: type
70
106
  _executor_cls: type
@@ -76,7 +112,7 @@ class _FunctionExecutorFactory:
76
112
  def __call__(
77
113
  self, spec: dict[str, Any], *args: Any, **kwargs: Any
78
114
  ) -> tuple[dict[str, Any], Executor]:
79
- spec = self._spec_cls(**spec)
115
+ spec = _load_spec_from_engine(self._spec_cls, spec)
80
116
  executor = self._executor_cls(spec)
81
117
  result_type = executor.analyze(*args, **kwargs)
82
118
  return (encode_enriched_type(result_type), executor)
@@ -185,7 +221,9 @@ def _register_op_factory(
185
221
  )
186
222
  self._args_decoders.append(
187
223
  make_engine_value_decoder(
188
- [arg_name], arg.value_type["type"], arg_param.annotation
224
+ [arg_name],
225
+ arg.value_type["type"],
226
+ analyze_type_info(arg_param.annotation),
189
227
  )
190
228
  )
191
229
  process_attribute(arg_name, arg)
@@ -217,7 +255,9 @@ def _register_op_factory(
217
255
  )
218
256
  arg_param = expected_arg[1]
219
257
  self._kwargs_decoders[kwarg_name] = make_engine_value_decoder(
220
- [kwarg_name], kwarg.value_type["type"], arg_param.annotation
258
+ [kwarg_name],
259
+ kwarg.value_type["type"],
260
+ analyze_type_info(arg_param.annotation),
221
261
  )
222
262
  process_attribute(kwarg_name, kwarg)
223
263
 
@@ -359,3 +399,220 @@ def function(**args: Any) -> Callable[[Callable[..., Any]], FunctionSpec]:
359
399
  return _Spec()
360
400
 
361
401
  return _inner
402
+
403
+
404
+ ########################################################
405
+ # Custom target connector
406
+ ########################################################
407
+
408
+
409
@dataclasses.dataclass()
class _TargetConnectorContext:
    """Per-export state built by ``_TargetConnector.create_export_context``."""

    # Name of the export target, as given by the engine.
    target_name: str
    # Spec object loaded from the engine-provided dict.
    spec: Any
    # Output of the connector's optional `prepare()`; `spec` itself when absent.
    prepared_spec: Any
    # Converters from engine-encoded keys / values to Python values.
    key_decoder: Callable[[Any], Any]
    value_decoder: Callable[[Any], Any]
416
+
417
+
418
class _TargetConnector:
    """
    The connector class passed to the engine.

    Wraps a user-provided connector class and adapts its methods
    (`get_persistent_key`, `apply_setup_change`, `mutate`, and the optional
    `prepare` / `describe`) to the calls the engine makes on this object.
    """

    _spec_cls: type
    _connector_cls: type

    # Invoked as `fn(loaded_spec, target_name)`; see `get_persistent_key`.
    _get_persistent_key_fn: Callable[[Any, str], Any]
    # Invoked as `fn(key, previous_spec_or_None, current_spec_or_None)` with
    # specs already loaded into `_spec_cls` instances.
    _apply_setup_change_async_fn: Callable[[Any, Any | None, Any | None], Awaitable[None]]
    _mutate_async_fn: Callable[..., Awaitable[None]]
    # Dict type annotated on the mutation payload of `mutate(*args)`, or None
    # when the annotation is absent / `Any`.
    _mutation_type: AnalyzedDictType | None

    def __init__(self, spec_cls: type, connector_cls: type):
        self._spec_cls = spec_cls
        self._connector_cls = connector_cls

        self._get_persistent_key_fn = _get_required_method(
            connector_cls, "get_persistent_key"
        )
        self._apply_setup_change_async_fn = _to_async_call(
            _get_required_method(connector_cls, "apply_setup_change")
        )

        mutate_fn = _get_required_method(connector_cls, "mutate")
        self._mutate_async_fn = _to_async_call(mutate_fn)

        # Store the type annotation for later use when building decoders.
        self._mutation_type = self._analyze_mutate_mutation_type(
            connector_cls, mutate_fn
        )

    @staticmethod
    def _analyze_mutate_mutation_type(
        connector_cls: type, mutate_fn: Callable[..., Any]
    ) -> AnalyzedDictType | None:
        """
        Validate the signature of `mutate` and extract its mutation dict type.

        `mutate` must take exactly one `*args` parameter, annotated either as
        `Any` (or unannotated in a way analyzed as `Any`), a bare `tuple`, or
        `tuple[SpecType, dict[K, V]]`.

        Returns:
            The analyzed dict type of the second tuple element, or None when
            no concrete mutation type is declared.

        Raises:
            ValueError: If the signature or annotation does not match.
        """
        mutate_sig = inspect.signature(mutate_fn)
        params = list(mutate_sig.parameters.values())

        if len(params) != 1:
            raise ValueError(
                f"Method {connector_cls.__name__}.mutate(*args) must have exactly one parameter, "
                f"got {len(params)}"
            )

        param = params[0]
        if param.kind != inspect.Parameter.VAR_POSITIONAL:
            raise ValueError(
                f"Method {connector_cls.__name__}.mutate(*args) parameter must be *args format, "
                f"got {param.kind.name}"
            )

        # Extract type annotation
        analyzed_args_type = analyze_type_info(param.annotation)
        if isinstance(analyzed_args_type.variant, AnalyzedAnyType):
            return None

        if analyzed_args_type.base_type is tuple:
            args = get_args(analyzed_args_type.core_type)
            if not args:
                # Bare `tuple`: no element types to extract.
                return None
            if len(args) == 2:
                mutation_type = analyze_type_info(args[1])
                if isinstance(mutation_type.variant, AnalyzedAnyType):
                    return None
                if isinstance(mutation_type.variant, AnalyzedDictType):
                    return mutation_type.variant

        # The original message lacked the `f` prefix and referenced an
        # undefined name, so the offending annotation was never shown.
        raise ValueError(
            f"Method {connector_cls.__name__}.mutate(*args) parameter must be a tuple with "
            f"2 elements (tuple[SpecType, dict[str, ValueStruct]], spec and mutation in dict), "
            f"got {param.annotation}"
        )

    def create_export_context(
        self,
        name: str,
        spec: dict[str, Any],
        key_fields_schema: list[Any],
        value_fields_schema: list[Any],
    ) -> _TargetConnectorContext:
        """
        Build the per-export context: key/value decoders plus loaded and
        prepared specs.
        """
        key_annotation, value_annotation = (
            (
                self._mutation_type.key_type,
                self._mutation_type.value_type,
            )
            if self._mutation_type is not None
            else (None, None)
        )

        key_type_info = analyze_type_info(key_annotation)
        if (
            len(key_fields_schema) == 1
            and key_fields_schema[0]["type"]["kind"] != "Struct"
            and isinstance(key_type_info.variant, (AnalyzedAnyType, AnalyzedBasicType))
        ):
            # Special case for ease of use: single key column can be mapped to a basic type without the wrapper struct.
            key_decoder = make_engine_value_decoder(
                ["(key)"],
                key_fields_schema[0]["type"],
                key_type_info,
            )
        else:
            key_decoder = make_engine_struct_decoder(
                ["(key)"], key_fields_schema, key_type_info
            )

        value_decoder = make_engine_struct_decoder(
            ["(value)"], value_fields_schema, analyze_type_info(value_annotation)
        )

        loaded_spec = _load_spec_from_engine(self._spec_cls, spec)
        # `prepare` is optional; without it the loaded spec is used as-is.
        prepare_method = getattr(self._connector_cls, "prepare", None)
        if prepare_method is None:
            prepared_spec = loaded_spec
        else:
            prepared_spec = prepare_method(loaded_spec)

        return _TargetConnectorContext(
            target_name=name,
            spec=loaded_spec,
            prepared_spec=prepared_spec,
            key_decoder=key_decoder,
            value_decoder=value_decoder,
        )

    def get_persistent_key(self, export_context: _TargetConnectorContext) -> Any:
        """Delegate to the connector's `get_persistent_key(spec, target_name)`."""
        return self._get_persistent_key_fn(
            export_context.spec, export_context.target_name
        )

    def describe_resource(self, key: Any) -> str:
        """Describe a resource via the optional `describe`, else `str(key)`."""
        describe_fn = getattr(self._connector_cls, "describe", None)
        if describe_fn is None:
            return str(key)
        return str(describe_fn(key))

    async def apply_setup_changes_async(
        self,
        changes: list[tuple[Any, list[dict[str, Any] | None], dict[str, Any] | None]],
    ) -> None:
        """
        Forward each setup change to the connector's `apply_setup_change`.

        For every `(key, previous, current)` entry, `apply_setup_change` is
        called once per previous spec with the loaded current spec.
        """
        for key, previous, current in changes:
            prev_specs = [
                _load_spec_from_engine(self._spec_cls, spec)
                if spec is not None
                else None
                for spec in previous
            ]
            curr_spec = (
                _load_spec_from_engine(self._spec_cls, current)
                if current is not None
                else None
            )
            # An empty `previous` list would otherwise skip the connector
            # entirely, so a target with no recorded prior state (presumably a
            # newly added target — confirm against the engine) would never be
            # set up. Treat it as "no previous spec" when there is something to
            # apply.
            if not prev_specs and curr_spec is not None:
                prev_specs = [None]
            for prev_spec in prev_specs:
                await self._apply_setup_change_async_fn(key, prev_spec, curr_spec)

    @staticmethod
    def _decode_mutation(
        context: _TargetConnectorContext, mutation: list[tuple[Any, Any | None]]
    ) -> tuple[Any, dict[Any, Any | None]]:
        """
        Decode one target's mutation into `(prepared_spec, {key: value})`.

        A `None` value is kept as `None`; decoded keys are used as dict keys,
        so they must be hashable.
        """
        return (
            context.prepared_spec,
            {
                context.key_decoder(key): (
                    context.value_decoder(value) if value is not None else None
                )
                for key, value in mutation
            },
        )

    async def mutate_async(
        self,
        mutations: list[tuple[_TargetConnectorContext, list[tuple[Any, Any | None]]]],
    ) -> None:
        """Call the connector's `mutate` with one decoded tuple per target."""
        await self._mutate_async_fn(
            *(
                self._decode_mutation(context, mutation)
                for context, mutation in mutations
            )
        )
601
+
602
+
603
def target_connector(spec_cls: type) -> Callable[[type], type]:
    """
    Decorate a class to provide a target connector for an op.

    The decorated class is wrapped in a ``_TargetConnector`` and registered
    with the engine under ``spec_cls.__name__``; the class itself is returned
    unchanged.

    Args:
        spec_cls: The ``TargetSpec`` subclass handled by the connector.

    Raises:
        ValueError: If ``spec_cls`` is not a ``TargetSpec`` subclass.
    """

    # Validate the spec_cls is a TargetSpec. Checking for a class first means a
    # non-class argument (e.g. a spec instance) raises the intended ValueError
    # instead of `issubclass`'s "arg 1 must be a class" TypeError.
    if not (isinstance(spec_cls, type) and issubclass(spec_cls, TargetSpec)):
        raise ValueError(f"Expect a TargetSpec, got {spec_cls}")

    def _inner(connector_cls: type) -> type:
        # Register the target connector.
        connector = _TargetConnector(spec_cls, connector_cls)
        _engine.register_target_connector(spec_cls.__name__, connector)
        return connector_cls

    return _inner
@@ -1,6 +1,6 @@
1
1
  import datetime
2
2
  import uuid
3
- from dataclasses import dataclass, make_dataclass
3
+ from dataclasses import dataclass, make_dataclass, field
4
4
  from typing import Annotated, Any, Callable, Literal, NamedTuple
5
5
 
6
6
  import numpy as np
@@ -19,6 +19,7 @@ from cocoindex.typing import (
19
19
  TypeKind,
20
20
  Vector,
21
21
  encode_enriched_type,
22
+ analyze_type_info,
22
23
  )
23
24
 
24
25
 
@@ -75,7 +76,9 @@ def build_engine_value_decoder(
75
76
  If python_type is not specified, uses engine_type_in_py as the target.
76
77
  """
77
78
  engine_type = encode_enriched_type(engine_type_in_py)["type"]
78
- return make_engine_value_decoder([], engine_type, python_type or engine_type_in_py)
79
+ return make_engine_value_decoder(
80
+ [], engine_type, analyze_type_info(python_type or engine_type_in_py)
81
+ )
79
82
 
80
83
 
81
84
  def validate_full_roundtrip_to(
@@ -103,9 +106,13 @@ def validate_full_roundtrip_to(
103
106
  )
104
107
 
105
108
  for other_value, other_type in decoded_values:
106
- decoder = make_engine_value_decoder([], encoded_output_type, other_type)
109
+ decoder = make_engine_value_decoder(
110
+ [], encoded_output_type, analyze_type_info(other_type)
111
+ )
107
112
  other_decoded_value = decoder(value_from_engine)
108
- assert eq(other_decoded_value, other_value)
113
+ assert eq(other_decoded_value, other_value), (
114
+ f"Expected {other_value} but got {other_decoded_value} for {other_type}"
115
+ )
109
116
 
110
117
 
111
118
  def validate_full_roundtrip(
@@ -362,7 +369,9 @@ def test_decode_scalar_numpy_values() -> None:
362
369
  ({"kind": "Float64"}, np.float64, 2.718, np.float64(2.718)),
363
370
  ]
364
371
  for src_type, dst_type, input_value, expected in test_cases:
365
- decoder = make_engine_value_decoder(["field"], src_type, dst_type)
372
+ decoder = make_engine_value_decoder(
373
+ ["field"], src_type, analyze_type_info(dst_type)
374
+ )
366
375
  result = decoder(input_value)
367
376
  assert isinstance(result, dst_type)
368
377
  assert result == expected
@@ -376,7 +385,9 @@ def test_non_ndarray_vector_decoding() -> None:
376
385
  "dimension": None,
377
386
  }
378
387
  dst_type_float = list[np.float64]
379
- decoder = make_engine_value_decoder(["field"], src_type, dst_type_float)
388
+ decoder = make_engine_value_decoder(
389
+ ["field"], src_type, analyze_type_info(dst_type_float)
390
+ )
380
391
  input_numbers = [1.0, 2.0, 3.0]
381
392
  result = decoder(input_numbers)
382
393
  assert isinstance(result, list)
@@ -386,7 +397,9 @@ def test_non_ndarray_vector_decoding() -> None:
386
397
  # Test list[Uuid]
387
398
  src_type = {"kind": "Vector", "element_type": {"kind": "Uuid"}, "dimension": None}
388
399
  dst_type_uuid = list[uuid.UUID]
389
- decoder = make_engine_value_decoder(["field"], src_type, dst_type_uuid)
400
+ decoder = make_engine_value_decoder(
401
+ ["field"], src_type, analyze_type_info(dst_type_uuid)
402
+ )
390
403
  uuid1 = uuid.uuid4()
391
404
  uuid2 = uuid.uuid4()
392
405
  input_uuids = [uuid1, uuid2]
@@ -396,124 +409,15 @@ def test_non_ndarray_vector_decoding() -> None:
396
409
  assert result == [uuid1, uuid2]
397
410
 
398
411
 
399
- @pytest.mark.parametrize(
400
- "data_type, engine_val, expected",
401
- [
402
- # All fields match (dataclass)
403
- (
404
- Order,
405
- ["O123", "mixed nuts", 25.0, "default_extra"],
406
- Order("O123", "mixed nuts", 25.0, "default_extra"),
407
- ),
408
- # All fields match (NamedTuple)
409
- (
410
- OrderNamedTuple,
411
- ["O123", "mixed nuts", 25.0, "default_extra"],
412
- OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra"),
413
- ),
414
- # Extra field in engine value (should ignore extra)
415
- (
416
- Order,
417
- ["O123", "mixed nuts", 25.0, "default_extra", "unexpected"],
418
- Order("O123", "mixed nuts", 25.0, "default_extra"),
419
- ),
420
- (
421
- OrderNamedTuple,
422
- ["O123", "mixed nuts", 25.0, "default_extra", "unexpected"],
423
- OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra"),
424
- ),
425
- # Fewer fields in engine value (should fill with default)
426
- (
427
- Order,
428
- ["O123", "mixed nuts", 0.0, "default_extra"],
429
- Order("O123", "mixed nuts", 0.0, "default_extra"),
430
- ),
431
- (
432
- OrderNamedTuple,
433
- ["O123", "mixed nuts", 0.0, "default_extra"],
434
- OrderNamedTuple("O123", "mixed nuts", 0.0, "default_extra"),
435
- ),
436
- # More fields in engine value (should ignore extra)
437
- (
438
- Order,
439
- ["O123", "mixed nuts", 25.0, "unexpected"],
440
- Order("O123", "mixed nuts", 25.0, "unexpected"),
441
- ),
442
- (
443
- OrderNamedTuple,
444
- ["O123", "mixed nuts", 25.0, "unexpected"],
445
- OrderNamedTuple("O123", "mixed nuts", 25.0, "unexpected"),
446
- ),
447
- # Truly extra field (should ignore the fifth field)
448
- (
449
- Order,
450
- ["O123", "mixed nuts", 25.0, "default_extra", "ignored"],
451
- Order("O123", "mixed nuts", 25.0, "default_extra"),
452
- ),
453
- (
454
- OrderNamedTuple,
455
- ["O123", "mixed nuts", 25.0, "default_extra", "ignored"],
456
- OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra"),
457
- ),
458
- # Missing optional field in engine value (tags=None)
459
- (
460
- Customer,
461
- ["Alice", ["O1", "item1", 10.0, "default_extra"], None],
462
- Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), None),
463
- ),
464
- (
465
- CustomerNamedTuple,
466
- ["Alice", ["O1", "item1", 10.0, "default_extra"], None],
467
- CustomerNamedTuple(
468
- "Alice", OrderNamedTuple("O1", "item1", 10.0, "default_extra"), None
469
- ),
470
- ),
471
- # Extra field in engine value for Customer (should ignore)
472
- (
473
- Customer,
474
- ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]], "extra"],
475
- Customer(
476
- "Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip")]
477
- ),
478
- ),
479
- (
480
- CustomerNamedTuple,
481
- ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]], "extra"],
482
- CustomerNamedTuple(
483
- "Alice",
484
- OrderNamedTuple("O1", "item1", 10.0, "default_extra"),
485
- [Tag("vip")],
486
- ),
487
- ),
488
- # Missing optional field with default
489
- (
490
- Order,
491
- ["O123", "mixed nuts", 25.0],
492
- Order("O123", "mixed nuts", 25.0, "default_extra"),
493
- ),
494
- (
495
- OrderNamedTuple,
496
- ["O123", "mixed nuts", 25.0],
497
- OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra"),
498
- ),
499
- # Partial optional fields
500
- (
501
- Customer,
502
- ["Alice", ["O1", "item1", 10.0]],
503
- Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), None),
504
- ),
505
- (
506
- CustomerNamedTuple,
507
- ["Alice", ["O1", "item1", 10.0]],
508
- CustomerNamedTuple(
509
- "Alice", OrderNamedTuple("O1", "item1", 10.0, "default_extra"), None
510
- ),
511
- ),
512
- ],
513
- )
514
- def test_struct_decoder_cases(data_type: Any, engine_val: Any, expected: Any) -> None:
515
- decoder = build_engine_value_decoder(data_type)
516
- assert decoder(engine_val) == expected
412
def test_roundtrip_struct() -> None:
    """A dataclass struct and a NamedTuple struct both survive an engine roundtrip."""
    cases = [
        (Order("O123", "mixed nuts", 25.0, "default_extra"), Order),
        (
            OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra"),
            OrderNamedTuple,
        ),
    ]
    for struct_value, struct_type in cases:
        validate_full_roundtrip(struct_value, struct_type)
517
421
 
518
422
 
519
423
  def test_make_engine_value_decoder_list_of_struct() -> None:
@@ -972,7 +876,9 @@ def test_decode_nullable_ndarray_none_or_value_input() -> None:
972
876
  "dimension": None,
973
877
  }
974
878
  dst_annotation = NDArrayFloat32Type | None
975
- decoder = make_engine_value_decoder([], src_type_dict, dst_annotation)
879
+ decoder = make_engine_value_decoder(
880
+ [], src_type_dict, analyze_type_info(dst_annotation)
881
+ )
976
882
 
977
883
  none_engine_value = None
978
884
  decoded_array = decoder(none_engine_value)
@@ -995,7 +901,9 @@ def test_decode_vector_string() -> None:
995
901
  "element_type": {"kind": "Str"},
996
902
  "dimension": None,
997
903
  }
998
- decoder = make_engine_value_decoder([], src_type_dict, Vector[str])
904
+ decoder = make_engine_value_decoder(
905
+ [], src_type_dict, analyze_type_info(Vector[str])
906
+ )
999
907
  assert decoder(["hello", "world"]) == ["hello", "world"]
1000
908
 
1001
909
 
@@ -1006,7 +914,9 @@ def test_decode_error_non_nullable_or_non_list_vector() -> None:
1006
914
  "element_type": {"kind": "Float32"},
1007
915
  "dimension": None,
1008
916
  }
1009
- decoder = make_engine_value_decoder([], src_type_dict, NDArrayFloat32Type)
917
+ decoder = make_engine_value_decoder(
918
+ [], src_type_dict, analyze_type_info(NDArrayFloat32Type)
919
+ )
1010
920
  with pytest.raises(ValueError, match="Received null for non-nullable vector"):
1011
921
  decoder(None)
1012
922
  with pytest.raises(TypeError, match="Expected NDArray or list for vector"):
@@ -1096,6 +1006,25 @@ def test_full_roundtrip_vector_numeric_types() -> None:
1096
1006
  validate_full_roundtrip(value_u64, Vector[np.uint64, Literal[3]])
1097
1007
 
1098
1008
 
1009
def test_full_roundtrip_vector_of_vector() -> None:
    """Test full roundtrip for vector of vector."""
    rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
    matrix = np.array(rows, dtype=np.float32)
    nested_vector_type = Vector[Vector[np.float32, Literal[3]], Literal[2]]

    # Alternative Python representations the engine value must decode into.
    alternatives = (
        (rows, list[list[np.float32]]),
        (rows, list[list[cocoindex.Float32]]),
        (rows, list[Vector[cocoindex.Float32, Literal[3]]]),
        (matrix, np.typing.NDArray[np.typing.NDArray[np.float32]]),
    )
    validate_full_roundtrip(matrix, nested_vector_type, *alternatives)
1026
+
1027
+
1099
1028
  def test_full_roundtrip_vector_other_types() -> None:
1100
1029
  """Test full roundtrip for Vector with non-numeric basic types."""
1101
1030
  uuid_list = [uuid.uuid4(), uuid.uuid4()]
@@ -1216,7 +1145,7 @@ def test_full_roundtrip_scalar_with_python_types() -> None:
1216
1145
  numpy_float: np.float64
1217
1146
  python_float: float
1218
1147
  string: str
1219
- annotated_int: Annotated[np.int64, TypeKind("int")]
1148
+ annotated_int: Annotated[np.int64, TypeKind("Int64")]
1220
1149
  annotated_float: Float32
1221
1150
 
1222
1151
  instance = MixedStruct(
@@ -1468,3 +1397,41 @@ def test_roundtrip_ktable_with_list_fields() -> None:
1468
1397
 
1469
1398
  # Test Any annotation
1470
1399
  validate_full_roundtrip(teams, dict[str, Team], (expected_dict_dict, Any))
1400
+
1401
+
1402
def test_auto_default_for_supported_and_unsupported_types() -> None:
    """
    Nullable and LTable struct fields roundtrip cleanly, while decoding into a
    struct whose missing field is a plain non-nullable `int` without a default
    is rejected when the decoder is built.
    """

    @dataclass
    class Base:
        a: int

    @dataclass
    class NullableField:
        a: int
        b: int | None

    @dataclass
    class LTableField:
        a: int
        b: list[Base]

    @dataclass
    class UnsupportedField:
        a: int
        b: int

    validate_full_roundtrip(NullableField(1, None), NullableField)

    validate_full_roundtrip(LTableField(1, []), LTableField)

    # TODO: a KTable-field roundtrip (a struct with `b: dict[str, Base]`,
    # value `{}`) was left disabled here; presumably restore it once KTable
    # fields are handled.

    # `Base` has no field `b`, and `int` has no automatic default, so building a
    # decoder from `Base` into `UnsupportedField` must fail.
    with pytest.raises(
        ValueError,
        match=r"Field 'b' \(type <class 'int'>\) without default value is missing in input: ",
    ):
        build_engine_value_decoder(Base, UnsupportedField)