cocoindex-0.3.4-cp311-abi3-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. cocoindex/__init__.py +114 -0
  2. cocoindex/_engine.abi3.so +0 -0
  3. cocoindex/auth_registry.py +44 -0
  4. cocoindex/cli.py +830 -0
  5. cocoindex/engine_object.py +214 -0
  6. cocoindex/engine_value.py +550 -0
  7. cocoindex/flow.py +1281 -0
  8. cocoindex/functions/__init__.py +40 -0
  9. cocoindex/functions/_engine_builtin_specs.py +66 -0
  10. cocoindex/functions/colpali.py +247 -0
  11. cocoindex/functions/sbert.py +77 -0
  12. cocoindex/index.py +50 -0
  13. cocoindex/lib.py +75 -0
  14. cocoindex/llm.py +47 -0
  15. cocoindex/op.py +1047 -0
  16. cocoindex/py.typed +0 -0
  17. cocoindex/query_handler.py +57 -0
  18. cocoindex/runtime.py +78 -0
  19. cocoindex/setting.py +171 -0
  20. cocoindex/setup.py +92 -0
  21. cocoindex/sources/__init__.py +5 -0
  22. cocoindex/sources/_engine_builtin_specs.py +120 -0
  23. cocoindex/subprocess_exec.py +277 -0
  24. cocoindex/targets/__init__.py +5 -0
  25. cocoindex/targets/_engine_builtin_specs.py +153 -0
  26. cocoindex/targets/lancedb.py +466 -0
  27. cocoindex/tests/__init__.py +0 -0
  28. cocoindex/tests/test_engine_object.py +331 -0
  29. cocoindex/tests/test_engine_value.py +1724 -0
  30. cocoindex/tests/test_optional_database.py +249 -0
  31. cocoindex/tests/test_transform_flow.py +300 -0
  32. cocoindex/tests/test_typing.py +553 -0
  33. cocoindex/tests/test_validation.py +134 -0
  34. cocoindex/typing.py +834 -0
  35. cocoindex/user_app_loader.py +53 -0
  36. cocoindex/utils.py +20 -0
  37. cocoindex/validation.py +104 -0
  38. cocoindex-0.3.4.dist-info/METADATA +288 -0
  39. cocoindex-0.3.4.dist-info/RECORD +42 -0
  40. cocoindex-0.3.4.dist-info/WHEEL +4 -0
  41. cocoindex-0.3.4.dist-info/entry_points.txt +2 -0
  42. cocoindex-0.3.4.dist-info/licenses/THIRD_PARTY_NOTICES.html +13249 -0
@@ -0,0 +1,1724 @@
1
+ import datetime
2
+ import inspect
3
+ import uuid
4
+ from dataclasses import dataclass, make_dataclass
5
+ from typing import Annotated, Any, Callable, Literal, NamedTuple, Type
6
+
7
+ import numpy as np
8
+ import pytest
9
+ from numpy.typing import NDArray
10
+
11
+ # Optional Pydantic support for testing
12
+ try:
13
+ from pydantic import BaseModel, Field
14
+
15
+ PYDANTIC_AVAILABLE = True
16
+ except ImportError:
17
+ BaseModel = None # type: ignore[misc,assignment]
18
+ Field = None # type: ignore[misc,assignment]
19
+ PYDANTIC_AVAILABLE = False
20
+
21
+ import cocoindex
22
+ from cocoindex.engine_value import (
23
+ make_engine_value_encoder,
24
+ make_engine_value_decoder,
25
+ )
26
+ from cocoindex.typing import (
27
+ Float32,
28
+ Float64,
29
+ TypeKind,
30
+ Vector,
31
+ analyze_type_info,
32
+ encode_enriched_type,
33
+ decode_engine_value_type,
34
+ )
35
+
36
+
37
@dataclass()
class Order:
    """Sample order struct used by the encode/decode tests.

    Fields are encoded in declaration order, so the order below matters.
    """

    order_id: str
    name: str
    price: float
    # Defaulted field: tests expect it to appear as "default_extra" when omitted.
    extra_field: str = "default_extra"
43
+
44
+
45
@dataclass()
class Tag:
    """Single-field struct used as a list-table row in nested-struct tests."""

    name: str
48
+
49
+
50
@dataclass()
class Basket:
    """Struct whose only field is a list of strings."""

    items: list[str]
53
+
54
+
55
@dataclass()
class Customer:
    """Struct nesting another struct (``Order``) and an optional list of ``Tag`` rows."""

    name: str
    order: Order
    # Optional list-table field; encodes as None when not provided.
    tags: list[Tag] | None = None
60
+
61
+
62
@dataclass()
class NestedStruct:
    """Two-level nesting: a ``Customer`` struct plus a list of ``Order`` rows."""

    customer: Customer
    orders: list[Order]
    count: int = 0
67
+
68
+
69
class OrderNamedTuple(NamedTuple):
    """NamedTuple twin of ``Order`` with the same field names, order and default."""

    order_id: str
    name: str
    price: float
    extra_field: str = "default_extra"
74
+
75
+
76
class CustomerNamedTuple(NamedTuple):
    """NamedTuple twin of ``Customer``; its ``tags`` rows are still ``Tag`` dataclasses."""

    name: str
    order: OrderNamedTuple
    tags: list[Tag] | None = None
80
+
81
+
82
# Pydantic mirrors of the dataclass fixtures; only defined when pydantic imports.
if PYDANTIC_AVAILABLE:

    class OrderPydantic(BaseModel):
        """Pydantic counterpart of ``Order``."""

        order_id: str
        name: str
        price: float
        extra_field: str = "default_extra"

    class TagPydantic(BaseModel):
        """Pydantic counterpart of ``Tag``."""

        name: str

    class CustomerPydantic(BaseModel):
        """Pydantic counterpart of ``Customer``."""

        name: str
        order: OrderPydantic
        tags: list[TagPydantic] | None = None

    class NestedStructPydantic(BaseModel):
        """Pydantic counterpart of ``NestedStruct``."""

        customer: CustomerPydantic
        orders: list[OrderPydantic]
        count: int = 0
103
+
104
+
105
def encode_engine_value(value: Any, type_hint: Type[Any] | str) -> Any:
    """
    Convert ``value`` into its engine-side representation, guided by ``type_hint``.
    """
    type_info = analyze_type_info(type_hint)
    encode = make_engine_value_encoder(type_info)
    return encode(value)
111
+
112
+
113
def build_engine_value_decoder(
    engine_type_in_py: Any, python_type: Any | None = None
) -> Callable[[Any], Any]:
    """
    Build a decoder from an engine value to a Python value.

    Args:
        engine_type_in_py: Python type describing the value as the engine holds it;
            its encoded schema becomes the decoder's source type.
        python_type: Target Python type; ``engine_type_in_py`` is used when omitted.

    Returns:
        A callable mapping an engine value to the target Python type.
    """
    target_type = python_type or engine_type_in_py
    source_schema = decode_engine_value_type(
        encode_enriched_type(engine_type_in_py)["type"]
    )
    return make_engine_value_decoder([], source_schema, analyze_type_info(target_type))
126
+
127
+
128
def validate_full_roundtrip_to(
    value: Any,
    value_type: Any,
    *decoded_values: tuple[Any, Any],
) -> None:
    """
    Validate that ``value`` becomes specific values after a full engine roundtrip.

    ``value`` is encoded with ``value_type``, serialized and deserialized by the
    engine using the schema derived from ``value_type``, and then decoded once
    per expected pair.

    Args:
        value: The Python value to encode.
        value_type: Type hint used for encoding and as the engine-side schema.
            Falls back to ``type(value)`` when falsy.
        *decoded_values: ``(expected_value, target_type)`` pairs. The engine value
            is decoded to each ``target_type`` and must match ``expected_value``:
            arrays via ``np.array_equal``, everything else by exact type and ``==``.

    Raises:
        AssertionError: If a decoded value does not match its expectation.
    """
    from cocoindex import _engine  # type: ignore

    def eq(a: Any, b: Any) -> bool:
        if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
            return bool(np.array_equal(a, b))
        # Exact type identity so e.g. np.int64(42) is not accepted for 42.
        return type(a) is type(b) and bool(a == b)

    # Resolve the fallback type *before* encoding so the encoder actually uses it.
    value_type = value_type or type(value)
    encoded_value = encode_engine_value(value, value_type)
    encoded_output_type = encode_enriched_type(value_type)["type"]
    value_from_engine = _engine.testutil.seder_roundtrip(
        encoded_value, encoded_output_type
    )

    # The engine-side schema is the same for every target type; decode it once.
    engine_value_type = decode_engine_value_type(encoded_output_type)
    for other_value, other_type in decoded_values:
        decoder = make_engine_value_decoder(
            [], engine_value_type, analyze_type_info(other_type)
        )
        other_decoded_value = decoder(value_from_engine)
        # repr() makes type mismatches (42 vs np.int64(42)) visible in failures.
        assert eq(other_decoded_value, other_value), (
            f"Expected {other_value!r} but got {other_decoded_value!r} for {other_type}"
        )
162
+
163
+
164
def validate_full_roundtrip(
    value: Any,
    value_type: Any,
    *other_decoded_values: tuple[Any, Any],
) -> None:
    """
    Validate that ``value`` survives an engine roundtrip unchanged.

    Decoding back with ``value_type`` must reproduce ``value``. Each extra
    ``(expected_value, target_type)`` pair in ``other_decoded_values`` is also
    checked against the same engine value.
    """
    expectations = [(value, value_type), *other_decoded_values]
    validate_full_roundtrip_to(value, value_type, *expectations)
178
+
179
+
180
def test_encode_engine_value_basic_types() -> None:
    """Plain Python scalars encode to themselves."""
    for raw, hint in ((123, int), (3.14, float), ("hello", str)):
        assert encode_engine_value(raw, hint) == raw
    assert encode_engine_value(True, bool) is True
185
+
186
+
187
def test_encode_engine_value_uuid() -> None:
    """A UUID passes through encoding unchanged."""
    original = uuid.uuid4()
    encoded = encode_engine_value(original, uuid.UUID)
    assert encoded == original
190
+
191
+
192
def test_encode_engine_value_date_time_types() -> None:
    """date, time and naive datetime values pass through encoding unchanged."""
    samples = [
        (datetime.date(2024, 1, 1), datetime.date),
        (datetime.time(12, 30), datetime.time),
        (datetime.datetime(2024, 1, 1, 12, 30), datetime.datetime),
    ]
    for sample, hint in samples:
        assert encode_engine_value(sample, hint) == sample
199
+
200
+
201
def test_encode_scalar_numpy_values() -> None:
    """NumPy scalar values encode to numbers the engine accepts."""
    cases = [
        (np.int64(42), 42),
        (np.float32(3.14), pytest.approx(3.14)),
        (np.float64(2.718), pytest.approx(2.718)),
    ]
    for scalar, want in cases:
        got = encode_engine_value(scalar, type(scalar))
        assert got == want
        assert isinstance(got, (int, float))
212
+
213
+
214
def test_encode_engine_value_struct() -> None:
    """Dataclass and NamedTuple structs encode to a field-ordered list, defaults included."""
    expected = ["O123", "mixed nuts", 25.0, "default_extra"]

    as_dataclass = Order(order_id="O123", name="mixed nuts", price=25.0)
    assert encode_engine_value(as_dataclass, Order) == expected

    as_namedtuple = OrderNamedTuple(order_id="O123", name="mixed nuts", price=25.0)
    assert encode_engine_value(as_namedtuple, OrderNamedTuple) == expected
230
+
231
+
232
def test_encode_engine_value_list_of_structs() -> None:
    """A list of structs encodes to a list of field-ordered lists."""
    expected_rows = [
        ["O1", "item1", 10.0, "default_extra"],
        ["O2", "item2", 20.0, "default_extra"],
    ]

    dc_rows = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]
    assert encode_engine_value(dc_rows, list[Order]) == expected_rows

    nt_rows = [
        OrderNamedTuple("O1", "item1", 10.0),
        OrderNamedTuple("O2", "item2", 20.0),
    ]
    assert encode_engine_value(nt_rows, list[OrderNamedTuple]) == expected_rows
247
+
248
+
249
def test_encode_engine_value_struct_with_list() -> None:
    """A list-typed struct field is encoded as a nested list."""
    encoded = encode_engine_value(Basket(items=["apple", "banana"]), Basket)
    assert encoded == [["apple", "banana"]]
252
+
253
+
254
def test_encode_engine_value_nested_struct() -> None:
    """A nested struct encodes recursively; an unset optional field encodes as None."""
    expected = ["Alice", ["O1", "item1", 10.0, "default_extra"], None]

    customer = Customer(name="Alice", order=Order("O1", "item1", 10.0))
    assert encode_engine_value(customer, Customer) == expected

    customer_nt = CustomerNamedTuple(
        name="Alice", order=OrderNamedTuple("O1", "item1", 10.0)
    )
    assert encode_engine_value(customer_nt, CustomerNamedTuple) == expected
270
+
271
+
272
def test_encode_engine_value_empty_list() -> None:
    """Empty lists, including an empty list nested in a list, encode unchanged."""
    cases: list[tuple[Any, Any, Any]] = [
        ([], list, []),
        ([[]], list[list[Any]], [[]]),
    ]
    for raw, hint, want in cases:
        assert encode_engine_value(raw, hint) == want
275
+
276
+
277
def test_encode_engine_value_tuple() -> None:
    """Tuples encoded under an ``Any`` hint become (nested) lists."""
    cases: list[tuple[Any, Any]] = [
        ((), []),
        ((1, 2, 3), [1, 2, 3]),
        (((1, 2), (3, 4)), [[1, 2], [3, 4]]),
        (([],), [[]]),
        (((),), [[]]),
    ]
    for raw, want in cases:
        assert encode_engine_value(raw, Any) == want
283
+
284
+
285
def test_encode_engine_value_none() -> None:
    """``None`` under an ``Any`` hint encodes to ``None``."""
    encoded = encode_engine_value(None, Any)
    assert encoded is None
287
+
288
+
289
def test_roundtrip_basic_types() -> None:
    """Scalars survive the engine roundtrip and decode into compatible types."""
    greeting = b"hello world"
    validate_full_roundtrip(
        greeting,
        bytes,
        (greeting, inspect.Parameter.empty),
        (greeting, Any),
    )
    validate_full_roundtrip(b"\x00\x01\x02\xff\xfe", bytes)
    validate_full_roundtrip("hello", str, ("hello", Any))
    for flag in (True, False):
        validate_full_roundtrip(flag, bool, (flag, Any))

    # Integers: int, np.int64 and cocoindex.Int64 decode into one another.
    validate_full_roundtrip(
        42, cocoindex.Int64, (42, int), (np.int64(42), np.int64), (42, Any)
    )
    validate_full_roundtrip(42, int, (42, cocoindex.Int64))
    validate_full_roundtrip(np.int64(42), np.int64, (42, cocoindex.Int64))

    # Double-precision floats.
    validate_full_roundtrip(
        3.25, Float64, (3.25, float), (np.float64(3.25), np.float64), (3.25, Any)
    )
    validate_full_roundtrip(3.25, float, (3.25, Float64))
    validate_full_roundtrip(np.float64(3.25), np.float64, (3.25, Float64))

    # Single-precision floats also decode into double-precision targets.
    validate_full_roundtrip(
        3.25,
        Float32,
        (3.25, float),
        (np.float32(3.25), np.float32),
        (np.float64(3.25), np.float64),
        (3.25, Float64),
        (3.25, Any),
    )
    validate_full_roundtrip(np.float32(3.25), np.float32, (3.25, Float32))
322
+
323
+
324
def test_roundtrip_uuid() -> None:
    """A random UUID roundtrips, including when decoded as ``Any``."""
    sample = uuid.uuid4()
    validate_full_roundtrip(sample, uuid.UUID, (sample, Any))
327
+
328
+
329
def test_roundtrip_range() -> None:
    """Range tuples, including empty and large spans, roundtrip unchanged."""
    for bounds in ((0, 100), (50, 50), (0, 1_000_000_000)):
        validate_full_roundtrip(bounds, cocoindex.Range, (bounds, Any))
336
+
337
+
338
def test_roundtrip_time() -> None:
    """Time, date and datetime values roundtrip; naive values sent as offset datetimes come back in UTC."""
    for moment in (
        datetime.time(10, 30, 50, 123456),
        datetime.time(23, 59, 59),
        datetime.time(0, 0, 0),
    ):
        validate_full_roundtrip(moment, datetime.time, (moment, Any))

    day = datetime.date(2025, 1, 1)
    validate_full_roundtrip(day, datetime.date, (day, Any))

    naive = datetime.datetime(2025, 1, 2, 3, 4, 5, 123456)
    validate_full_roundtrip(
        naive, cocoindex.LocalDateTime, (naive, datetime.datetime)
    )

    plus_five = datetime.timezone(datetime.timedelta(hours=5))
    aware = datetime.datetime(2025, 1, 2, 3, 4, 5, 123456, plus_five)
    validate_full_roundtrip(
        aware, cocoindex.OffsetDateTime, (aware, datetime.datetime)
    )
    validate_full_roundtrip(
        aware, datetime.datetime, (aware, cocoindex.OffsetDateTime)
    )

    # A naive datetime encoded as an offset datetime is expected back in UTC.
    in_utc = datetime.datetime(2025, 1, 2, 3, 4, 5, 123456, datetime.UTC)
    validate_full_roundtrip_to(
        naive, cocoindex.OffsetDateTime, (in_utc, datetime.datetime)
    )
    validate_full_roundtrip_to(
        naive, datetime.datetime, (in_utc, cocoindex.OffsetDateTime)
    )
386
+
387
+
388
def test_roundtrip_timedelta() -> None:
    """Positive, negative and zero timedeltas roundtrip unchanged."""
    samples = [
        datetime.timedelta(
            days=5, seconds=10, microseconds=123, milliseconds=456, minutes=30, hours=2
        ),
        datetime.timedelta(days=-5, hours=-2),
        datetime.timedelta(0),
    ]
    for delta in samples:
        validate_full_roundtrip(delta, datetime.timedelta, (delta, Any))
397
+
398
+
399
def test_roundtrip_json() -> None:
    """Flat, nested and empty JSON documents roundtrip through ``cocoindex.Json``."""
    documents: list[Any] = [
        {"key": "value", "number": 123, "bool": True, "float": 1.23},
        [1, "string", False, None, 4.56],
        {
            "name": "Test Json",
            "version": 1.0,
            "items": [
                {"id": 1, "value": "item1"},
                {"id": 2, "value": None, "props": {"active": True}},
            ],
            "metadata": None,
        },
        {},
        [],
    ]
    for doc in documents:
        validate_full_roundtrip(doc, cocoindex.Json)
419
+
420
+
421
def test_decode_scalar_numpy_values() -> None:
    """Engine scalars decode into the requested NumPy scalar types."""
    cases: list[tuple[str, Any, Any]] = [
        ("Int64", np.int64, 42),
        ("Float32", np.float32, 3.14),
        ("Float64", np.float64, 2.718),
    ]
    for kind, np_type, raw in cases:
        decode = make_engine_value_decoder(
            ["field"],
            decode_engine_value_type({"kind": kind}),
            analyze_type_info(np_type),
        )
        decoded = decode(raw)
        assert isinstance(decoded, np_type)
        assert decoded == np_type(raw)
444
+
445
+
446
def test_non_ndarray_vector_decoding() -> None:
    """Engine vectors decode into plain lists when the target type is ``list[...]``."""

    def vector_decoder(element_kind: str, target: Any) -> Callable[[Any], Any]:
        schema = decode_engine_value_type(
            {
                "kind": "Vector",
                "element_type": {"kind": element_kind},
                "dimension": None,
            }
        )
        return make_engine_value_decoder(["field"], schema, analyze_type_info(target))

    # list[np.float64]
    floats = vector_decoder("Float64", list[np.float64])([1.0, 2.0, 3.0])
    assert isinstance(floats, list)
    assert all(isinstance(x, np.float64) for x in floats)
    assert floats == [np.float64(1.0), np.float64(2.0), np.float64(3.0)]

    # list[uuid.UUID]
    first_id, second_id = uuid.uuid4(), uuid.uuid4()
    ids = vector_decoder("Uuid", list[uuid.UUID])([first_id, second_id])
    assert isinstance(ids, list)
    assert all(isinstance(x, uuid.UUID) for x in ids)
    assert ids == [first_id, second_id]
480
+
481
+
482
def test_roundtrip_struct() -> None:
    """Dataclass and NamedTuple structs with all fields set roundtrip unchanged."""
    field_values = ("O123", "mixed nuts", 25.0, "default_extra")
    validate_full_roundtrip(Order(*field_values), Order)
    validate_full_roundtrip(OrderNamedTuple(*field_values), OrderNamedTuple)
491
+
492
+
493
def test_make_engine_value_decoder_list_of_struct() -> None:
    """The same engine rows decode into lists of dataclass or NamedTuple structs."""
    rows = [
        ["O1", "item1", 10.0, "default_extra"],
        ["O2", "item2", 20.0, "default_extra"],
    ]

    decode_dc = build_engine_value_decoder(list[Order])
    assert decode_dc(rows) == [
        Order("O1", "item1", 10.0, "default_extra"),
        Order("O2", "item2", 20.0, "default_extra"),
    ]

    decode_nt = build_engine_value_decoder(list[OrderNamedTuple])
    assert decode_nt(rows) == [
        OrderNamedTuple("O1", "item1", 10.0, "default_extra"),
        OrderNamedTuple("O2", "item2", 20.0, "default_extra"),
    ]
511
+
512
+
513
def test_make_engine_value_decoder_struct_of_list() -> None:
    """Struct fields holding a struct and a list of structs decode into typed objects."""
    engine_val = [
        "Alice",
        ["O1", "item1", 10.0, "default_extra"],
        [["vip"], ["premium"]],
    ]
    expected_tags = [Tag("vip"), Tag("premium")]

    as_dataclass = build_engine_value_decoder(Customer)(engine_val)
    assert as_dataclass == Customer(
        "Alice", Order("O1", "item1", 10.0, "default_extra"), expected_tags
    )

    as_namedtuple = build_engine_value_decoder(CustomerNamedTuple)(engine_val)
    assert as_namedtuple == CustomerNamedTuple(
        "Alice", OrderNamedTuple("O1", "item1", 10.0, "default_extra"), expected_tags
    )
534
+
535
+
536
def test_make_engine_value_decoder_struct_of_struct() -> None:
    """A struct nesting a struct and a list of structs decodes recursively."""
    engine_val = [
        ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]]],
        [
            ["O1", "item1", 10.0, "default_extra"],
            ["O2", "item2", 20.0, "default_extra"],
        ],
        2,
    ]

    first = Order("O1", "item1", 10.0, "default_extra")
    second = Order("O2", "item2", 20.0, "default_extra")
    expected = NestedStruct(
        customer=Customer("Alice", first, [Tag("vip")]),
        orders=[first, second],
        count=2,
    )
    assert build_engine_value_decoder(NestedStruct)(engine_val) == expected
555
+
556
+
557
def make_engine_order(fields: list[tuple[str, type]]) -> type:
    """Create an ``EngineOrder`` dataclass whose fields follow ``fields`` exactly, in order."""
    engine_cls = make_dataclass("EngineOrder", fields)
    return engine_cls
559
+
560
+
561
+ def make_python_order(
562
+ fields: list[tuple[str, type]], defaults: dict[str, Any] | None = None
563
+ ) -> type:
564
+ if defaults is None:
565
+ defaults = {}
566
+ # Move all fields with defaults to the end (Python dataclass requirement)
567
+ non_default_fields = [(n, t) for n, t in fields if n not in defaults]
568
+ default_fields = [(n, t) for n, t in fields if n in defaults]
569
+ ordered_fields = non_default_fields + default_fields
570
+ # Prepare the namespace for defaults (only for fields at the end)
571
+ namespace = {k: defaults[k] for k, _ in default_fields}
572
+ return make_dataclass("PythonOrder", ordered_fields, namespace=namespace)
573
+
574
+
575
@pytest.mark.parametrize(
    "engine_fields, python_fields, python_defaults, engine_val, expected_python_val",
    [
        # Extra field in Python (middle)
        (
            [("id", str), ("name", str)],
            [("id", str), ("price", float), ("name", str)],
            {"price": 0.0},
            ["O123", "mixed nuts"],
            ("O123", 0.0, "mixed nuts"),
        ),
        # Missing field in Python (middle)
        (
            [("id", str), ("price", float), ("name", str)],
            [("id", str), ("name", str)],
            {},
            ["O123", 25.0, "mixed nuts"],
            ("O123", "mixed nuts"),
        ),
        # Extra field in Python (start)
        (
            [("name", str), ("price", float)],
            [("extra", str), ("name", str), ("price", float)],
            {"extra": "default"},
            ["mixed nuts", 25.0],
            ("default", "mixed nuts", 25.0),
        ),
        # Missing field in Python (start)
        (
            [("extra", str), ("name", str), ("price", float)],
            [("name", str), ("price", float)],
            {},
            ["unexpected", "mixed nuts", 25.0],
            ("mixed nuts", 25.0),
        ),
        # Field order difference (should map by name)
        (
            [("id", str), ("name", str), ("price", float)],
            [("name", str), ("id", str), ("price", float), ("extra", str)],
            {"extra": "default"},
            ["O123", "mixed nuts", 25.0],
            ("mixed nuts", "O123", 25.0, "default"),
        ),
        # Extra field (Python has extra field with default)
        (
            [("id", str), ("name", str)],
            [("id", str), ("name", str), ("price", float)],
            {"price": 0.0},
            ["O123", "mixed nuts"],
            ("O123", "mixed nuts", 0.0),
        ),
        # Missing field (Engine has extra field)
        (
            [("id", str), ("name", str), ("price", float)],
            [("id", str), ("name", str)],
            {},
            ["O123", "mixed nuts", 25.0],
            ("O123", "mixed nuts"),
        ),
    ],
)
def test_field_position_cases(
    engine_fields: list[tuple[str, type]],
    python_fields: list[tuple[str, type]],
    python_defaults: dict[str, Any],
    engine_val: list[Any],
    expected_python_val: tuple[Any, ...],
) -> None:
    """Struct fields are matched by name between the engine and Python schemas.

    Engine-only fields are dropped. Python-only fields take their defaults.
    """
    engine_cls = make_engine_order(engine_fields)
    python_cls = make_python_order(python_fields, python_defaults)
    decode = build_engine_value_decoder(engine_cls, python_cls)
    # Expected values are listed in python_fields order; pass them by keyword.
    field_names = (name for name, _ in python_fields)
    expected = python_cls(**dict(zip(field_names, expected_python_val)))
    assert decode(engine_val) == expected
650
+
651
+
652
def test_roundtrip_union_simple() -> None:
    """A float roundtrips through an ``int | str | float`` union."""
    validate_full_roundtrip(10.4, int | str | float)
656
+
657
+
658
def test_roundtrip_union_with_active_uuid() -> None:
    """A UUID value selects the UUID member of a ``str | UUID | int`` union."""
    validate_full_roundtrip(uuid.uuid4(), str | uuid.UUID | int)
662
+
663
+
664
def test_roundtrip_union_with_inactive_uuid() -> None:
    """A UUID-shaped *string* stays a string in a ``str | UUID | int`` union."""
    uuid_like_text = "5a9f8f6a-318f-4f1f-929d-566d7444a62d"
    validate_full_roundtrip(uuid_like_text, str | uuid.UUID | int)
668
+
669
+
670
def test_roundtrip_union_offset_datetime() -> None:
    """A timezone-aware datetime roundtrips through a mixed scalar union."""
    union_type = str | uuid.UUID | float | int | datetime.datetime
    validate_full_roundtrip(datetime.datetime.now(datetime.UTC), union_type)
674
+
675
+
676
def test_roundtrip_union_date() -> None:
    """A date roundtrips through a mixed scalar union."""
    union_type = str | uuid.UUID | float | int | datetime.date
    validate_full_roundtrip(datetime.date.today(), union_type)
680
+
681
+
682
def test_roundtrip_union_time() -> None:
    """Midnight (``datetime.time()``) roundtrips through a mixed scalar union."""
    union_type = str | uuid.UUID | float | int | datetime.time
    validate_full_roundtrip(datetime.time(), union_type)
686
+
687
+
688
def test_roundtrip_union_timedelta() -> None:
    """A multi-day timedelta roundtrips through a mixed scalar union."""
    union_type = str | uuid.UUID | float | int | datetime.timedelta
    validate_full_roundtrip(
        datetime.timedelta(hours=39, minutes=10, seconds=1), union_type
    )
692
+
693
+
694
def test_roundtrip_vector_of_union() -> None:
    """A list whose elements are ``str | int`` roundtrips with mixed elements."""
    validate_full_roundtrip(["a", 1], list[str | int])
698
+
699
+
700
def test_roundtrip_union_with_vector() -> None:
    """An NDArray member of a union roundtrips and also decodes into a list-based union."""
    vec = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    validate_full_roundtrip(
        vec, NDArray[np.float32] | str, ([1.0, 2.0, 3.0], list[float] | str)
    )
704
+
705
+
706
def test_roundtrip_union_with_misc_types() -> None:
    """Bytes, Range and Json union members roundtrip alongside scalar members."""
    bytes_union = int | bytes | str
    range_union = cocoindex.Range | str | bool
    json_union = cocoindex.Json | int | bytes
    cases: list[tuple[Any, Any]] = [
        (b"test_bytes", bytes_union),
        (123, bytes_union),
        ((100, 200), range_union),
        ("test_string", range_union),
        ({"a": 1, "b": [2, 3]}, json_union),
        (b"another_byte_string", json_union),
    ]
    for sample, union_type in cases:
        validate_full_roundtrip(sample, union_type)
719
+
720
+
721
def test_roundtrip_ltable() -> None:
    """Lists of dataclass or NamedTuple rows roundtrip as list tables."""
    dc_rows = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]
    validate_full_roundtrip(dc_rows, list[Order])

    nt_rows = [
        OrderNamedTuple("O1", "item1", 10.0),
        OrderNamedTuple("O2", "item2", 20.0),
    ]
    validate_full_roundtrip(nt_rows, list[OrderNamedTuple])
732
+
733
+
734
def test_roundtrip_ktable_various_key_types() -> None:
    """Keyed tables roundtrip for bytes, int, bool, str, Range, date and UUID keys."""

    @dataclass
    class SimpleValue:
        data: str

    cases: list[tuple[Any, Any]] = [
        (
            {b"key1": SimpleValue("val1"), b"key2": SimpleValue("val2")},
            dict[bytes, SimpleValue],
        ),
        (
            {1: SimpleValue("val1"), 2: SimpleValue("val2")},
            dict[int, SimpleValue],
        ),
        (
            {True: SimpleValue("val_true"), False: SimpleValue("val_false")},
            dict[bool, SimpleValue],
        ),
        (
            {"K1": Order("O1", "item1", 10.0), "K2": Order("O2", "item2", 20.0)},
            dict[str, Order],
        ),
        (
            {
                "K1": OrderNamedTuple("O1", "item1", 10.0),
                "K2": OrderNamedTuple("O2", "item2", 20.0),
            },
            dict[str, OrderNamedTuple],
        ),
        (
            {
                (1, 10): SimpleValue("val_range1"),
                (20, 30): SimpleValue("val_range2"),
            },
            dict[cocoindex.Range, SimpleValue],
        ),
        (
            {
                datetime.date(2023, 1, 1): SimpleValue("val_date1"),
                datetime.date(2024, 2, 2): SimpleValue("val_date2"),
            },
            dict[datetime.date, SimpleValue],
        ),
        (
            {
                uuid.uuid4(): SimpleValue("val_uuid1"),
                uuid.uuid4(): SimpleValue("val_uuid2"),
            },
            dict[uuid.UUID, SimpleValue],
        ),
    ]
    for table, table_type in cases:
        validate_full_roundtrip(table, table_type)
782
+
783
+
784
def test_roundtrip_ktable_struct_key() -> None:
    """Keyed tables with a frozen-dataclass key roundtrip for both value struct kinds."""

    @dataclass(frozen=True)
    class OrderKey:
        shop_id: str
        version: int

    key_a, key_b = OrderKey("A", 3), OrderKey("B", 4)

    validate_full_roundtrip(
        {key_a: Order("O1", "item1", 10.0), key_b: Order("O2", "item2", 20.0)},
        dict[OrderKey, Order],
    )
    validate_full_roundtrip(
        {
            key_a: OrderNamedTuple("O1", "item1", 10.0),
            key_b: OrderNamedTuple("O2", "item2", 20.0),
        },
        dict[OrderKey, OrderNamedTuple],
    )
803
+
804
+
805
# int64 vector with a fixed dimension of 5.
IntVectorType = cocoindex.Vector[np.int64, Literal[5]]


def test_vector_as_vector() -> None:
    """An int64 NumPy vector encodes and decodes back to an equal array."""
    original = np.array([1, 2, 3, 4, 5], dtype=np.int64)
    engine_value = encode_engine_value(original, IntVectorType)
    assert np.array_equal(engine_value, original)
    roundtripped = build_engine_value_decoder(IntVectorType)(engine_value)
    assert np.array_equal(roundtripped, original)
814
+
815
+
816
# Plain-list vector type (no NumPy, no fixed dimension).
ListIntType = list[int]


def test_vector_as_list() -> None:
    """A plain ``list[int]`` encodes to an equal list and decodes back."""
    numbers: ListIntType = [1, 2, 3, 4, 5]
    engine_value = encode_engine_value(numbers, ListIntType)
    assert engine_value == [1, 2, 3, 4, 5]
    roundtripped = build_engine_value_decoder(ListIntType)(engine_value)
    assert np.array_equal(roundtripped, numbers)
825
+
826
+
827
# Vector aliases: with a Literal dimension (3) or without, plus bare NDArray forms.
Float64VectorTypeNoDim = Vector[np.float64]
Float32VectorType = Vector[np.float32, Literal[3]]
Float64VectorType = Vector[np.float64, Literal[3]]
Int64VectorType = Vector[np.int64, Literal[3]]
NDArrayFloat32Type = NDArray[np.float32]
NDArrayFloat64Type = NDArray[np.float64]
NDArrayInt64Type = NDArray[np.int64]


def test_encode_engine_value_ndarray() -> None:
    """NumPy vectors encode to values equal to the plain element lists."""
    cases: list[tuple[Any, Any, list[Any]]] = [
        (np.array([1.0, 2.0, 3.0], dtype=np.float32), Float32VectorType, [1.0, 2.0, 3.0]),
        (np.array([1.0, 2.0, 3.0], dtype=np.float64), Float64VectorType, [1.0, 2.0, 3.0]),
        (np.array([1, 2, 3], dtype=np.int64), Int64VectorType, [1, 2, 3]),
        (np.array([1.0, 2.0, 3.0], dtype=np.float32), NDArrayFloat32Type, [1.0, 2.0, 3.0]),
    ]
    for array, hint, expected in cases:
        assert np.array_equal(encode_engine_value(array, hint), expected)
852
+
853
+
854
def test_make_engine_value_decoder_ndarray() -> None:
    """Engine lists decode to NumPy arrays with the dtype of the target vector type."""
    cases: list[tuple[Any, list[Any], Any]] = [
        (Float32VectorType, [1.0, 2.0, 3.0], np.float32),
        (Float64VectorType, [1.0, 2.0, 3.0], np.float64),
        (Int64VectorType, [1, 2, 3], np.int64),
        (NDArrayFloat32Type, [1.0, 2.0, 3.0], np.float32),
    ]
    for hint, raw, dtype in cases:
        decoded = build_engine_value_decoder(hint)(raw)
        assert isinstance(decoded, np.ndarray)
        assert decoded.dtype == dtype
        assert np.array_equal(decoded, np.array(raw, dtype=dtype))
876
+
877
+
878
def test_roundtrip_ndarray_vector() -> None:
    """Test roundtrip encoding and decoding of NDArray vectors.

    For each case the encoded value must equal the plain element list. The
    decoded value must be an ``np.ndarray`` with the original dtype and contents.
    """
    cases: list[tuple[Any, Any, list[Any]]] = [
        (np.array([1.0, 2.0, 3.0], dtype=np.float32), Float32VectorType, [1.0, 2.0, 3.0]),
        (np.array([1, 2, 3], dtype=np.int64), Int64VectorType, [1, 2, 3]),
        (np.array([1.0, 2.0, 3.0], dtype=np.float64), NDArrayFloat64Type, [1.0, 2.0, 3.0]),
    ]
    for original, hint, expected_encoded in cases:
        encoded = encode_engine_value(original, hint)
        # Previously the float32 case computed this comparison without asserting it.
        assert np.array_equal(encoded, expected_encoded)
        decoded = build_engine_value_decoder(hint)(encoded)
        assert isinstance(decoded, np.ndarray)
        assert decoded.dtype == original.dtype
        assert np.array_equal(decoded, original)
901
+
902
+
903
def test_ndarray_dimension_mismatch() -> None:
    """Decoding into a vector type with a declared dimension rejects a wrong length.

    The 2-element array is encoded through the dimension-less
    ``NDArray[np.float32]`` annotation, which accepts it. It is then decoded as
    ``Float32VectorType``, whose specified dimension it does not match.
    """
    short_vector = np.array([1.0, 2.0], dtype=np.float32)
    engine_value = encode_engine_value(short_vector, NDArray[np.float32])
    assert np.array_equal(engine_value, [1.0, 2.0])

    with pytest.raises(ValueError, match="Vector dimension mismatch"):
        decoder = build_engine_value_decoder(Float32VectorType)
        decoder(engine_value)
910
+
911
+
912
def test_list_vector_backward_compatibility() -> None:
    """Plain Python lists keep working both as vector inputs and as vector outputs."""
    # A bare list[int] encodes unchanged and decodes into IntVectorType as an ndarray.
    raw_ints = [1, 2, 3, 4, 5]
    engine_ints = encode_engine_value(raw_ints, list[int])
    assert engine_ints == [1, 2, 3, 4, 5]
    as_array = build_engine_value_decoder(IntVectorType)(engine_ints)
    assert isinstance(as_array, np.ndarray)
    assert as_array.dtype == np.int64
    assert np.array_equal(as_array, np.array([1, 2, 3, 4, 5], dtype=np.int64))

    # Round-trip through the ListIntType alias.
    typed_ints: ListIntType = [1, 2, 3, 4, 5]
    engine_typed = encode_engine_value(typed_ints, ListIntType)
    assert np.array_equal(engine_typed, [1, 2, 3, 4, 5])
    round_tripped = build_engine_value_decoder(ListIntType)(engine_typed)
    assert np.array_equal(round_tripped, [1, 2, 3, 4, 5])
926
+
927
+
928
def test_encode_complex_structure_with_ndarray() -> None:
    """A dataclass holding an NDArray field encodes to a positional field sequence."""

    @dataclass
    class MyStructWithNDArray:
        name: str
        data: NDArray[np.float32]
        value: int

    struct_value = MyStructWithNDArray(
        name="test_np", data=np.array([1.0, 0.5], dtype=np.float32), value=100
    )
    fields = encode_engine_value(struct_value, MyStructWithNDArray)

    # Encoded fields appear in declaration order.
    assert fields[0] == struct_value.name
    assert np.array_equal(fields[1], struct_value.data)
    assert fields[2] == struct_value.value
945
+
946
+
947
def test_decode_nullable_ndarray_none_or_value_input() -> None:
    """A decoder for ``NDArrayFloat32Type | None`` passes None through and converts lists."""
    engine_type = decode_engine_value_type(
        {"kind": "Vector", "element_type": {"kind": "Float32"}, "dimension": None}
    )
    decoder = make_engine_value_decoder(
        [], engine_type, analyze_type_info(NDArrayFloat32Type | None)
    )

    # A null engine value decodes to None.
    assert decoder(None) is None

    # A list of floats decodes to a float32 ndarray.
    result = decoder([1.0, 2.0, 3.0])
    assert isinstance(result, np.ndarray)
    assert result.dtype == np.float32
    np.testing.assert_array_equal(result, np.array([1.0, 2.0, 3.0], dtype=np.float32))
973
+
974
+
975
def test_decode_vector_string() -> None:
    """A string vector decodes into a plain Python list when targeting ``Vector[str]``."""
    str_vector_type = decode_engine_value_type(
        {"kind": "Vector", "element_type": {"kind": "Str"}, "dimension": None}
    )
    target_info = analyze_type_info(Vector[str])
    decoder = make_engine_value_decoder([], str_vector_type, target_info)
    assert decoder(["hello", "world"]) == ["hello", "world"]
988
+
989
+
990
def test_decode_error_non_nullable_or_non_list_vector() -> None:
    """A non-nullable NDArray decoder rejects None and non-list inputs."""
    float_vector_type = decode_engine_value_type(
        {"kind": "Vector", "element_type": {"kind": "Float32"}, "dimension": None}
    )
    strict_decoder = make_engine_value_decoder(
        [], float_vector_type, analyze_type_info(NDArrayFloat32Type)
    )

    # None is rejected because the target annotation is not nullable.
    with pytest.raises(ValueError, match="Received null for non-nullable vector"):
        strict_decoder(None)

    # A string is neither an ndarray nor a list.
    with pytest.raises(TypeError, match="Expected NDArray or list for vector"):
        strict_decoder("not a list")
1006
+
1007
+
1008
def test_full_roundtrip_vector_numeric_types() -> None:
    """Round-trip numeric vectors through Vector, NDArray and list annotations.

    float32, float64 and int64 vectors must round-trip and rebind to each listed
    alternative type. Vectors of the other NumPy integer dtypes exercised here
    must be rejected with "Unsupported NumPy dtype".
    """
    f32_array = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    # The same float32 array must round-trip under both array-style annotations.
    for f32_annotation in (
        Vector[np.float32, Literal[3]],
        np.typing.NDArray[np.float32],
    ):
        validate_full_roundtrip(
            f32_array,
            f32_annotation,
            ([np.float32(1.0), np.float32(2.0), np.float32(3.0)], list[np.float32]),
            ([1.0, 2.0, 3.0], list[cocoindex.Float32]),
            ([1.0, 2.0, 3.0], list[float]),
        )
    # Start from a plain list; binding it as a fixed-dimension Vector gives back the array.
    validate_full_roundtrip(
        f32_array.tolist(),
        list[np.float32],
        (f32_array, Vector[np.float32, Literal[3]]),
        ([1.0, 2.0, 3.0], list[cocoindex.Float32]),
        ([1.0, 2.0, 3.0], list[float]),
    )

    validate_full_roundtrip(
        np.array([1.0, 2.0, 3.0], dtype=np.float64),
        Vector[np.float64, Literal[3]],
        ([np.float64(1.0), np.float64(2.0), np.float64(3.0)], list[np.float64]),
        ([1.0, 2.0, 3.0], list[cocoindex.Float64]),
        ([1.0, 2.0, 3.0], list[float]),
    )

    validate_full_roundtrip(
        np.array([1, 2, 3], dtype=np.int64),
        Vector[np.int64, Literal[3]],
        ([np.int64(1), np.int64(2), np.int64(3)], list[np.int64]),
        ([1, 2, 3], list[int]),
    )

    # Every other integer dtype must be rejected.
    int32_vec = np.array([1, 2, 3], dtype=np.int32)
    with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
        validate_full_roundtrip(int32_vec, Vector[np.int32, Literal[3]])
    uint8_vec = np.array([1, 2, 3], dtype=np.uint8)
    with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
        validate_full_roundtrip(uint8_vec, Vector[np.uint8, Literal[3]])
    uint16_vec = np.array([1, 2, 3], dtype=np.uint16)
    with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
        validate_full_roundtrip(uint16_vec, Vector[np.uint16, Literal[3]])
    uint32_vec = np.array([1, 2, 3], dtype=np.uint32)
    with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
        validate_full_roundtrip(uint32_vec, Vector[np.uint32, Literal[3]])
    uint64_vec = np.array([1, 2, 3], dtype=np.uint64)
    with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
        validate_full_roundtrip(uint64_vec, Vector[np.uint64, Literal[3]])
1065
+
1066
+
1067
def test_full_roundtrip_vector_of_vector() -> None:
    """Round-trip a 2x3 float32 matrix through nested Vector, list and NDArray types."""
    rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
    matrix = np.array(rows, dtype=np.float32)
    validate_full_roundtrip(
        matrix,
        Vector[Vector[np.float32, Literal[3]], Literal[2]],
        (rows, list[list[np.float32]]),
        (rows, list[list[cocoindex.Float32]]),
        (rows, list[Vector[cocoindex.Float32, Literal[3]]]),
        (matrix, np.typing.NDArray[np.float32]),
    )
1084
+
1085
+
1086
def test_full_roundtrip_vector_other_types() -> None:
    """Round-trip vectors of UUID, date and bool elements, both populated and empty."""
    uuids = [uuid.uuid4(), uuid.uuid4()]
    dates = [datetime.date(2023, 1, 1), datetime.date(2024, 10, 5)]
    flags = [True, False, True, False]

    validate_full_roundtrip(uuids, Vector[uuid.UUID], (uuids, list[uuid.UUID]))
    validate_full_roundtrip(dates, Vector[datetime.date], (dates, list[datetime.date]))
    validate_full_roundtrip(flags, Vector[bool], (flags, list[bool]))

    # Empty vectors must also round-trip for each element type.
    validate_full_roundtrip([], Vector[uuid.UUID], ([], list[uuid.UUID]))
    validate_full_roundtrip([], Vector[datetime.date], ([], list[datetime.date]))
    validate_full_roundtrip([], Vector[bool], ([], list[bool]))
1102
+
1103
+
1104
def test_roundtrip_vector_no_dimension() -> None:
    """A dimension-less ``Vector[np.float64]`` round-trips and rebinds as list or NDArray."""
    elements = [1.0, 2.0, 3.0]
    validate_full_roundtrip(
        np.array(elements, dtype=np.float64),
        Vector[np.float64],
        (list(elements), list[float]),
        (np.array(elements, dtype=np.float64), np.typing.NDArray[np.float64]),
    )
1113
+
1114
+
1115
def test_roundtrip_string_vector() -> None:
    """A list of strings round-trips unchanged as ``Vector[str]``."""
    words: Vector[str] = ["hello", "world"]
    validate_full_roundtrip(words, Vector[str])
1119
+
1120
+
1121
def test_roundtrip_empty_vector() -> None:
    """A zero-length float32 array round-trips as ``Vector[np.float32]``."""
    empty_f32: Vector[np.float32] = np.array([], dtype=np.float32)
    validate_full_roundtrip(empty_f32, Vector[np.float32])
1125
+
1126
+
1127
def test_roundtrip_dimension_mismatch() -> None:
    """Round-tripping a 2-element array as a dimension-3 vector raises ValueError."""
    too_short: Vector[np.float32, Literal[3]] = np.array([1.0, 2.0], dtype=np.float32)
    with pytest.raises(ValueError, match="Vector dimension mismatch"):
        validate_full_roundtrip(too_short, Vector[np.float32, Literal[3]])
1132
+
1133
+
1134
def test_full_roundtrip_scalar_numeric_types() -> None:
    """int64/float32/float64 NumPy scalars round-trip; other dtypes are rejected."""
    supported = [
        (np.int64(42), np.int64, (42, int)),
        (np.float32(3.25), np.float32, (3.25, cocoindex.Float32)),
        (np.float64(3.25), np.float64, (3.25, cocoindex.Float64)),
    ]
    for scalar, scalar_type, python_binding in supported:
        validate_full_roundtrip(scalar, scalar_type, python_binding)

    for unsupported_type in (np.int32, np.uint8, np.uint16, np.uint32, np.uint64):
        with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
            validate_full_roundtrip(unsupported_type(1), unsupported_type)
1145
+
1146
+
1147
def test_full_roundtrip_nullable_scalar() -> None:
    """NumPy scalar types wrapped in ``| None`` round-trip both values and None."""
    present_values = [
        (np.int64(42), np.int64 | None),
        (np.float32(3.14), np.float32 | None),
        (np.float64(2.718), np.float64 | None),
    ]
    for scalar, nullable_type in present_values:
        validate_full_roundtrip(scalar, nullable_type)

    for optional_type in (np.int64 | None, np.float32 | None, np.float64 | None):
        validate_full_roundtrip(None, optional_type)
1158
+
1159
+
1160
def test_full_roundtrip_scalar_in_struct() -> None:
    """NumPy scalar fields inside a dataclass survive a full round-trip."""

    @dataclass
    class NumericStruct:
        int_field: np.int64
        float32_field: np.float32
        float64_field: np.float64

    validate_full_roundtrip(
        NumericStruct(
            int_field=np.int64(42),
            float32_field=np.float32(3.14),
            float64_field=np.float64(2.718),
        ),
        NumericStruct,
    )
1175
+
1176
+
1177
def test_full_roundtrip_scalar_in_nested_struct() -> None:
    """NumPy scalar fields survive a round-trip when nested one struct deep."""

    @dataclass
    class InnerStruct:
        value: np.float64

    @dataclass
    class OuterStruct:
        inner: InnerStruct
        count: np.int64

    inner_part = InnerStruct(value=np.float64(2.718))
    outer = OuterStruct(inner=inner_part, count=np.int64(1))
    validate_full_roundtrip(outer, OuterStruct)
1194
+
1195
+
1196
def test_full_roundtrip_scalar_with_python_types() -> None:
    """A struct mixing NumPy scalars, Python scalars and annotated types round-trips."""

    @dataclass
    class MixedStruct:
        numpy_int: np.int64
        python_int: int
        numpy_float: np.float64
        python_float: float
        string: str
        annotated_int: Annotated[np.int64, TypeKind("Int64")]
        annotated_float: Float32

    mixed = MixedStruct(
        numpy_int=np.int64(42),
        python_int=43,
        numpy_float=np.float64(2.718),
        python_float=3.14,
        string="hello, world",
        annotated_int=np.int64(42),
        annotated_float=2.0,
    )
    validate_full_roundtrip(mixed, MixedStruct)
1219
+
1220
+
1221
def test_roundtrip_simple_struct_to_dict_binding() -> None:
    """An all-``str`` struct binds to dict targets whose key/value types fit.

    Targets with an ``int`` value type or an ``int`` key type must raise ValueError.
    """

    @dataclass
    class SimpleStruct:
        first_name: str
        last_name: str

    person = SimpleStruct("John", "Doe")
    as_dict = {"first_name": "John", "last_name": "Doe"}

    validate_full_roundtrip(
        person,
        SimpleStruct,
        (as_dict, Any),
        (as_dict, dict),
        (as_dict, dict[Any, Any]),
        (as_dict, dict[str, Any]),
        # Every field is a str, so str also works as the dict value type.
        (as_dict, dict[Any, str]),
        (as_dict, dict[str, str]),
    )

    for rejected_target in (dict[str, int], dict[int, Any]):
        with pytest.raises(ValueError):
            validate_full_roundtrip(person, SimpleStruct, (as_dict, rejected_target))
1250
+
1251
+
1252
def test_roundtrip_struct_to_dict_binding() -> None:
    """A mixed-type struct binds to untyped/``str``-keyed dicts but not narrower ones.

    A ``str`` value type (fields include int and float) or an ``int`` key type
    must raise ValueError.
    """

    @dataclass
    class SimpleStruct:
        name: str
        value: int
        price: float

    record = SimpleStruct("test", 42, 3.14)
    record_dict = {"name": "test", "value": 42, "price": 3.14}

    validate_full_roundtrip(
        record,
        SimpleStruct,
        (record_dict, Any),
        (record_dict, dict),
        (record_dict, dict[Any, Any]),
        (record_dict, dict[str, Any]),
    )

    for rejected_target in (dict[str, str], dict[int, Any]):
        with pytest.raises(ValueError):
            validate_full_roundtrip(record, SimpleStruct, (record_dict, rejected_target))
1279
+
1280
+
1281
def test_roundtrip_struct_to_dict_explicit() -> None:
    """A mixed-type struct binds to bare ``dict`` and ``dict[str, Any]`` targets."""

    @dataclass
    class Product:
        id: str
        name: str
        price: float
        active: bool

    widget = Product("P1", "Widget", 29.99, True)
    widget_dict = {"id": "P1", "name": "Widget", "price": 29.99, "active": True}

    validate_full_roundtrip(
        widget,
        Product,
        (widget_dict, dict),
        (widget_dict, dict[str, Any]),
    )
1298
+
1299
+
1300
def test_roundtrip_struct_to_dict_with_none_annotation() -> None:
    """Test struct -> dict binding when the target has no annotation.

    ``inspect.Parameter.empty`` (the marker ``inspect`` uses for a missing
    annotation) should be treated like ``Any``, so the struct decodes to a dict.
    """

    @dataclass
    class Config:
        host: str
        port: int
        debug: bool

    instance = Config("localhost", 8080, True)
    expected_dict = {"host": "localhost", "port": 8080, "debug": True}

    # Empty annotation (should be treated as Any).
    validate_full_roundtrip(instance, Config, (expected_dict, inspect.Parameter.empty))
1314
+
1315
+
1316
def test_roundtrip_struct_to_dict_nested() -> None:
    """A struct containing another struct binds to a dict with a nested dict value."""

    @dataclass
    class Address:
        street: str
        city: str

    @dataclass
    class Person:
        name: str
        age: int
        address: Address

    john = Person("John", 30, Address("123 Main St", "Anytown"))
    john_dict = {
        "name": "John",
        "age": 30,
        "address": {"street": "123 Main St", "city": "Anytown"},
    }

    # The inner Address is expected to come back as a plain dict too.
    validate_full_roundtrip(john, Person, (john_dict, dict[str, Any]))
1340
+
1341
+
1342
def test_roundtrip_struct_to_dict_with_list() -> None:
    """A struct with a ``list[str]`` field binds to a bare ``dict`` target."""

    @dataclass
    class Team:
        name: str
        members: list[str]
        active: bool

    roster = ["Alice", "Bob", "Charlie"]
    team = Team("Dev Team", roster, True)
    team_dict = {
        "name": "Dev Team",
        "members": ["Alice", "Bob", "Charlie"],
        "active": True,
    }
    validate_full_roundtrip(team, Team, (team_dict, dict))
1359
+
1360
+
1361
def test_roundtrip_namedtuple_to_dict_binding() -> None:
    """A NamedTuple binds to ``dict`` and ``Any`` targets as a field-name -> value dict."""

    class Point(NamedTuple):
        x: float
        y: float
        z: float

    point = Point(x=1.0, y=2.0, z=3.0)
    point_dict = {"x": 1.0, "y": 2.0, "z": 3.0}
    validate_full_roundtrip(point, Point, (point_dict, dict), (point_dict, Any))
1375
+
1376
+
1377
def test_roundtrip_ltable_to_list_dict_binding() -> None:
    """A list of structs (LTable) binds to ``Any``, ``list[Any]`` and ``list[dict]``."""

    @dataclass
    class User:
        id: str
        name: str
        age: int

    rows = [("u1", "Alice", 25), ("u2", "Bob", 30), ("u3", "Charlie", 35)]
    users = [User(*row) for row in rows]
    user_dicts = [{"id": uid, "name": name, "age": age} for uid, name, age in rows]

    validate_full_roundtrip(
        users,
        list[User],
        (user_dicts, Any),
        (user_dicts, list[Any]),
        (user_dicts, list[dict[str, Any]]),
    )
1401
+
1402
+
1403
def test_roundtrip_ktable_to_dict_dict_binding() -> None:
    """A str-keyed dict of structs (KTable) binds to dict-of-dict style targets."""

    @dataclass
    class Product:
        name: str
        price: float
        active: bool

    catalog = {
        "p1": Product("Widget", 29.99, True),
        "p2": Product("Gadget", 49.99, False),
        "p3": Product("Tool", 19.99, True),
    }
    catalog_dicts = {
        "p1": {"name": "Widget", "price": 29.99, "active": True},
        "p2": {"name": "Gadget", "price": 49.99, "active": False},
        "p3": {"name": "Tool", "price": 19.99, "active": True},
    }

    validate_full_roundtrip(
        catalog,
        dict[str, Product],
        # Untyped or loosely typed targets.
        (catalog_dicts, Any),
        (catalog_dicts, dict),
        (catalog_dicts, dict[Any, Any]),
        (catalog_dicts, dict[str, Any]),
        # Targets that spell out the inner dict type.
        (catalog_dicts, dict[Any, dict[Any, Any]]),
        (catalog_dicts, dict[str, dict[Any, Any]]),
        (catalog_dicts, dict[str, dict[str, Any]]),
    )
1435
+
1436
+
1437
def test_roundtrip_ktable_with_complex_key() -> None:
    """A KTable keyed by a struct binds to dicts keyed by tuples or by the key struct."""

    @dataclass(frozen=True)
    class OrderKey:
        shop_id: str
        version: int

    @dataclass
    class Order:
        customer: str
        total: float

    orders_by_key = {
        OrderKey("shop1", 1): Order("Alice", 100.0),
        OrderKey("shop2", 2): Order("Bob", 200.0),
    }
    # With an untyped key target, struct keys come back as plain tuples.
    tuple_keyed_dicts = {
        ("shop1", 1): {"customer": "Alice", "total": 100.0},
        ("shop2", 2): {"customer": "Bob", "total": 200.0},
    }
    tuple_keyed_orders = {
        ("shop1", 1): Order("Alice", 100.0),
        ("shop2", 2): Order("Bob", 200.0),
    }
    struct_keyed_dicts = {
        OrderKey("shop1", 1): {"customer": "Alice", "total": 100.0},
        OrderKey("shop2", 2): {"customer": "Bob", "total": 200.0},
    }

    validate_full_roundtrip(
        orders_by_key,
        dict[OrderKey, Order],
        (tuple_keyed_dicts, Any),
        (tuple_keyed_dicts, dict),
        (tuple_keyed_dicts, dict[Any, Any]),
        (tuple_keyed_dicts, dict[Any, dict[str, Any]]),
        (tuple_keyed_orders, dict[Any, Order]),
        (struct_keyed_dicts, dict[OrderKey, Any]),
    )
1482
+
1483
+
1484
def test_roundtrip_ltable_with_nested_structs() -> None:
    """An LTable whose rows contain a nested struct binds to a list of nested dicts."""

    @dataclass
    class Address:
        street: str
        city: str

    @dataclass
    class Person:
        name: str
        age: int
        address: Address

    people = [
        Person("John", 30, Address("123 Main St", "Anytown")),
        Person("Jane", 25, Address("456 Oak Ave", "Somewhere")),
    ]
    people_dicts = [
        {
            "name": "John",
            "age": 30,
            "address": {"street": "123 Main St", "city": "Anytown"},
        },
        {
            "name": "Jane",
            "age": 25,
            "address": {"street": "456 Oak Ave", "city": "Somewhere"},
        },
    ]

    validate_full_roundtrip(people, list[Person], (people_dicts, Any))
1517
+
1518
+
1519
def test_roundtrip_ktable_with_list_fields() -> None:
    """A KTable whose values have a ``list[str]`` field binds to an ``Any`` target."""

    @dataclass
    class Team:
        name: str
        members: list[str]
        active: bool

    teams_by_id = {
        "team1": Team("Dev Team", ["Alice", "Bob"], True),
        "team2": Team("QA Team", ["Charlie", "David"], False),
    }
    team_dicts = {
        "team1": {"name": "Dev Team", "members": ["Alice", "Bob"], "active": True},
        "team2": {"name": "QA Team", "members": ["Charlie", "David"], "active": False},
    }

    validate_full_roundtrip(teams_by_id, dict[str, Team], (team_dicts, Any))
1539
+
1540
+
1541
def test_auto_default_for_supported_and_unsupported_types() -> None:
    """Check which field types are accepted without an explicit default value.

    Structs whose second field is nullable, an LTable (``list[Base]``) or a
    KTable (``dict[str, Base]``) round-trip while holding None or empty values.
    Decoding a ``Base`` value into a struct whose extra plain ``int`` field has
    no default raises ValueError.

    NOTE(review): going by the test name, the first three field kinds are
    presumably the ones eligible for auto-defaults. Confirm against the decoder.
    """

    @dataclass
    class Base:
        a: int

    @dataclass
    class NullableField:
        a: int
        b: int | None

    @dataclass
    class LTableField:
        a: int
        b: list[Base]

    @dataclass
    class KTableField:
        a: int
        b: dict[str, Base]

    @dataclass
    class UnsupportedField:
        a: int
        b: int

    for struct_value, struct_type in (
        (NullableField(1, None), NullableField),
        (LTableField(1, []), LTableField),
        (KTableField(1, {}), KTableField),
    ):
        validate_full_roundtrip(struct_value, struct_type)

    # A plain int has no automatic default, so the missing field is an error.
    with pytest.raises(
        ValueError,
        match=r"Field 'b' \(type <class 'int'>\) without default value is missing in input: ",
    ):
        build_engine_value_decoder(Base, UnsupportedField)
1577
+
1578
+
1579
+ # Pydantic model tests
1580
@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
def test_pydantic_simple_struct() -> None:
    """A flat Pydantic model round-trips through the engine encoding."""
    validate_full_roundtrip(
        OrderPydantic(order_id="O1", name="item1", price=10.0), OrderPydantic
    )
1585
+
1586
+
1587
@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
def test_pydantic_struct_with_defaults() -> None:
    """A Pydantic model that leaves ``extra_field`` at its default round-trips."""
    defaulted = OrderPydantic(order_id="O1", name="item1", price=10.0)
    # extra_field was not passed, so it must hold its declared default.
    assert defaulted.extra_field == "default_extra"
    validate_full_roundtrip(defaulted, OrderPydantic)
1593
+
1594
+
1595
@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
def test_pydantic_nested_struct() -> None:
    """A Pydantic model holding another Pydantic model round-trips."""
    validate_full_roundtrip(
        CustomerPydantic(
            name="Alice",
            order=OrderPydantic(order_id="O1", name="item1", price=10.0),
        ),
        CustomerPydantic,
    )
1601
+
1602
+
1603
@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
def test_pydantic_struct_with_list() -> None:
    """A Pydantic model with a list of nested models round-trips."""
    first_order = OrderPydantic(order_id="O1", name="item1", price=10.0)
    customer = CustomerPydantic(
        name="Alice",
        order=first_order,
        tags=[TagPydantic(name=tag_name) for tag_name in ("vip", "premium")],
    )
    validate_full_roundtrip(customer, CustomerPydantic)
1610
+
1611
+
1612
@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
def test_pydantic_complex_nested_struct() -> None:
    """A Pydantic model combining a nested customer and a list of orders round-trips."""
    orders = [
        OrderPydantic(order_id="O1", name="item1", price=10.0),
        OrderPydantic(order_id="O2", name="item2", price=20.0),
    ]
    customer = CustomerPydantic(
        name="Alice", order=orders[0], tags=[TagPydantic(name="vip")]
    )
    nested = NestedStructPydantic(customer=customer, orders=orders, count=2)
    validate_full_roundtrip(nested, NestedStructPydantic)
1622
+
1623
+
1624
@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
def test_pydantic_struct_to_dict_binding() -> None:
    """A Pydantic model binds to untyped and ``dict`` targets as a field dict."""
    custom_order = OrderPydantic(
        order_id="O1", name="item1", price=10.0, extra_field="custom"
    )
    order_dict = {
        "order_id": "O1",
        "name": "item1",
        "price": 10.0,
        "extra_field": "custom",
    }

    validate_full_roundtrip(
        custom_order,
        OrderPydantic,
        (order_dict, Any),
        (order_dict, dict),
        (order_dict, dict[Any, Any]),
        (order_dict, dict[str, Any]),
    )
1643
+
1644
+
1645
@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
def test_make_engine_value_decoder_pydantic_struct() -> None:
    """Decode a positional engine row into an ``OrderPydantic`` instance."""
    decode_order = build_engine_value_decoder(OrderPydantic)
    decoded = decode_order(["O1", "item1", 10.0, "default_extra"])

    assert isinstance(decoded, OrderPydantic)
    # Row positions map onto the model's fields in order.
    assert decoded.order_id == "O1"
    assert decoded.name == "item1"
    assert decoded.price == 10.0
    assert decoded.extra_field == "default_extra"
1657
+
1658
+
1659
@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
def test_make_engine_value_decoder_pydantic_nested() -> None:
    """Decode a row with a nested struct and a list of structs into Pydantic models."""
    order_row = ["O1", "item1", 10.0, "default_extra"]
    tag_rows = [["vip"]]
    decoded = build_engine_value_decoder(CustomerPydantic)(
        ["Alice", order_row, tag_rows]
    )

    assert isinstance(decoded, CustomerPydantic)
    assert decoded.name == "Alice"
    assert isinstance(decoded.order, OrderPydantic)
    assert decoded.order.order_id == "O1"
    assert decoded.tags is not None
    assert len(decoded.tags) == 1
    first_tag = decoded.tags[0]
    assert isinstance(first_tag, TagPydantic)
    assert first_tag.name == "vip"
1678
+
1679
+
1680
@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
def test_pydantic_mixed_with_dataclass() -> None:
    """A dataclass with a Pydantic-model field round-trips."""

    @dataclass
    class MixedStruct:
        name: str
        pydantic_order: OrderPydantic

    validate_full_roundtrip(
        MixedStruct(
            name="test",
            pydantic_order=OrderPydantic(order_id="O1", name="item1", price=10.0),
        ),
        MixedStruct,
    )
1693
+
1694
+
1695
def test_forward_ref_in_dataclass() -> None:
    """Test string (forward-reference) field annotations in a dataclass.

    Both ``"str"`` and ``"Tag"`` are given as strings and must resolve so the
    struct round-trips. ``Tag`` is not defined in this function; it is
    presumably a module-level type in this test file.
    """

    @dataclass
    class Event:
        name: "str"
        tag: "Tag"

    validate_full_roundtrip(Event(name="E1", tag=Tag(name="T1")), Event)
1704
+
1705
+
1706
def test_forward_ref_in_namedtuple() -> None:
    """Test string (forward-reference) field annotations in a NamedTuple.

    Both ``"str"`` and ``"Tag"`` are given as strings and must resolve so the
    tuple round-trips. ``Tag`` is not defined in this function; it is
    presumably a module-level type in this test file.
    """

    class Event(NamedTuple):
        name: "str"
        tag: "Tag"

    validate_full_roundtrip(Event(name="E1", tag=Tag(name="T1")), Event)
1714
+
1715
+
1716
@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
def test_forward_ref_in_pydantic() -> None:
    """Test string (forward-reference) field annotations in a Pydantic model.

    Both ``"str"`` and ``"Tag"`` are given as strings and must resolve so the
    model round-trips. ``Tag`` is not defined in this function; it is
    presumably a module-level type in this test file.
    """

    class Event(BaseModel):
        name: "str"
        tag: "Tag"

    validate_full_roundtrip(Event(name="E1", tag=Tag(name="T1")), Event)