cocoindex 0.2.3__cp311-abi3-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1543 @@
1
+ import datetime
2
+ import inspect
3
+ import uuid
4
+ from dataclasses import dataclass, make_dataclass
5
+ from typing import Annotated, Any, Callable, Literal, NamedTuple, Type
6
+
7
+ import numpy as np
8
+ import pytest
9
+ from numpy.typing import NDArray
10
+
11
+ import cocoindex
12
+ from cocoindex.convert import (
13
+ dump_engine_object,
14
+ make_engine_value_encoder,
15
+ make_engine_value_decoder,
16
+ )
17
+ from cocoindex.typing import (
18
+ Float32,
19
+ Float64,
20
+ TypeKind,
21
+ Vector,
22
+ analyze_type_info,
23
+ encode_enriched_type,
24
+ )
25
+
26
+
27
@dataclass
class Order:
    """Flat struct fixture used by the encode/decode and roundtrip tests.

    Field order matters: it is the order in which the encoder emits values, e.g.
    ``["O123", "mixed nuts", 25.0, "default_extra"]``.
    """

    order_id: str
    name: str
    price: float
    # Defaulted trailing field; tests often build ``Order`` with three positional
    # arguments and expect this value to appear in the encoded output.
    extra_field: str = "default_extra"
33
+
34
+
35
@dataclass
class Tag:
    """Single-field struct fixture; a ``list[Tag]`` decodes from rows like ``[["vip"]]``."""

    name: str
38
+
39
+
40
@dataclass
class Basket:
    """Struct fixture with a plain list-of-strings field (encoded as ``[[...]]``)."""

    items: list[str]
43
+
44
+
45
@dataclass
class Customer:
    """Struct fixture with a nested struct field and an optional list of structs."""

    name: str
    order: Order
    # Optional; encodes to ``None`` when left unset.
    tags: list[Tag] | None = None
50
+
51
+
52
@dataclass
class NestedStruct:
    """Two-level nesting fixture: a struct, a list of structs, and an int field."""

    customer: Customer
    orders: list[Order]
    count: int = 0
57
+
58
+
59
class OrderNamedTuple(NamedTuple):
    """NamedTuple counterpart of ``Order``; the tests expect identical encodings."""

    order_id: str
    name: str
    price: float
    extra_field: str = "default_extra"
64
+
65
+
66
class CustomerNamedTuple(NamedTuple):
    """NamedTuple counterpart of ``Customer``.

    ``tags`` still holds ``Tag`` dataclass instances rather than NamedTuples.
    """

    name: str
    order: OrderNamedTuple
    tags: list[Tag] | None = None
70
+
71
+
72
def encode_engine_value(value: Any, type_hint: Type[Any] | str) -> Any:
    """Return ``value`` converted to its engine representation.

    The encoder is built from the analyzed ``type_hint`` and applied once.
    """
    type_info = analyze_type_info(type_hint)
    return make_engine_value_encoder(type_info)(value)
78
+
79
+
80
def build_engine_value_decoder(
    engine_type_in_py: Any, python_type: Any | None = None
) -> Callable[[Any], Any]:
    """
    Build a decoder from an engine-side type (expressed as a Python annotation).

    ``python_type`` is the decode target; when omitted, ``engine_type_in_py`` is
    decoded into itself.
    """
    serialized_engine_type = encode_enriched_type(engine_type_in_py)["type"]
    decode_target = python_type or engine_type_in_py
    return make_engine_value_decoder(
        [], serialized_engine_type, analyze_type_info(decode_target)
    )
91
+
92
+
93
def validate_full_roundtrip_to(
    value: Any,
    value_type: Any,
    *decoded_values: tuple[Any, Any],
) -> None:
    """
    Validate that ``value`` decodes to specific values after a trip through the engine.

    ``value`` is encoded with ``value_type`` (falling back to ``type(value)`` when
    ``value_type`` is falsy). It is then serialized and deserialized by the
    engine's test utility, using the engine type derived from the same annotation.
    Finally, it is decoded once for each ``(expected_value, decode_type)`` pair
    in ``decoded_values``.

    Raises:
        AssertionError: if a decoded value differs from its expected value. For
            non-ndarray values, the Python type must also match exactly.
    """
    from cocoindex import _engine  # type: ignore

    def eq(a: Any, b: Any) -> bool:
        # NumPy arrays compare element-wise. Everything else must match in both
        # type and value, so e.g. ``42`` and ``np.int64(42)`` are not equal here.
        if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
            return np.array_equal(a, b)
        return type(a) is type(b) and bool(a == b)

    # Resolve the fallback annotation *before* encoding. The encoder and the engine
    # type must come from the same annotation, or the engine would receive a
    # value encoded for a different type.
    value_type = value_type or type(value)
    encoded_value = encode_engine_value(value, value_type)
    encoded_output_type = encode_enriched_type(value_type)["type"]
    value_from_engine = _engine.testutil.seder_roundtrip(
        encoded_value, encoded_output_type
    )

    for other_value, other_type in decoded_values:
        decoder = make_engine_value_decoder(
            [], encoded_output_type, analyze_type_info(other_type)
        )
        other_decoded_value = decoder(value_from_engine)
        # Use repr so type-only mismatches (e.g. 42 vs np.int64(42)) are visible.
        assert eq(other_decoded_value, other_value), (
            f"Expected {other_value!r} but got {other_decoded_value!r} for {other_type}"
        )
125
+
126
+
127
def validate_full_roundtrip(
    value: Any,
    value_type: Any,
    *other_decoded_values: tuple[Any, Any],
) -> None:
    """
    Check that ``value`` survives an encode -> engine -> decode roundtrip unchanged.

    Decoding with ``value_type`` must reproduce ``value`` exactly. Each extra
    ``(expected_value, decode_type)`` pair in ``other_decoded_values`` is checked
    too, by decoding the same engine value with ``decode_type``.
    """
    expectations = ((value, value_type), *other_decoded_values)
    validate_full_roundtrip_to(value, value_type, *expectations)
141
+
142
+
143
def test_encode_engine_value_basic_types() -> None:
    # Plain Python scalars are expected to pass through the encoder unchanged.
    for raw, hint, expected in ((123, int, 123), (3.14, float, 3.14), ("hello", str, "hello")):
        assert encode_engine_value(raw, hint) == expected
    # Booleans are checked by identity, not just equality.
    assert encode_engine_value(True, bool) is True
148
+
149
+
150
def test_encode_engine_value_uuid() -> None:
    sample_id = uuid.uuid4()
    encoded = encode_engine_value(sample_id, uuid.UUID)
    assert encoded == sample_id
153
+
154
+
155
def test_encode_engine_value_date_time_types() -> None:
    samples: list[tuple[Any, Any]] = [
        (datetime.date(2024, 1, 1), datetime.date),
        (datetime.time(12, 30), datetime.time),
        (datetime.datetime(2024, 1, 1, 12, 30), datetime.datetime),
    ]
    for temporal, hint in samples:
        assert encode_engine_value(temporal, hint) == temporal
162
+
163
+
164
def test_encode_scalar_numpy_values() -> None:
    """NumPy scalars should encode to values equal to the plain Python numbers."""
    cases = (
        (np.int64(42), 42),
        (np.float32(3.14), pytest.approx(3.14)),
        (np.float64(2.718), pytest.approx(2.718)),
    )
    for scalar, want in cases:
        result = encode_engine_value(scalar, type(scalar))
        assert result == want
        assert isinstance(result, (int, float))
175
+
176
+
177
def test_encode_engine_value_struct() -> None:
    # Structs encode to a list of field values in declaration order, defaults included.
    expected_row = ["O123", "mixed nuts", 25.0, "default_extra"]

    dc_order = Order(order_id="O123", name="mixed nuts", price=25.0)
    assert encode_engine_value(dc_order, Order) == expected_row

    nt_order = OrderNamedTuple(order_id="O123", name="mixed nuts", price=25.0)
    assert encode_engine_value(nt_order, OrderNamedTuple) == expected_row
193
+
194
+
195
def test_encode_engine_value_list_of_structs() -> None:
    expected_rows = [
        ["O1", "item1", 10.0, "default_extra"],
        ["O2", "item2", 20.0, "default_extra"],
    ]

    dataclass_orders = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]
    assert encode_engine_value(dataclass_orders, list[Order]) == expected_rows

    tuple_orders = [
        OrderNamedTuple("O1", "item1", 10.0),
        OrderNamedTuple("O2", "item2", 20.0),
    ]
    assert encode_engine_value(tuple_orders, list[OrderNamedTuple]) == expected_rows
210
+
211
+
212
def test_encode_engine_value_struct_with_list() -> None:
    encoded = encode_engine_value(Basket(items=["apple", "banana"]), Basket)
    # The struct becomes a one-element list whose only entry is the list field.
    assert encoded == [["apple", "banana"]]
215
+
216
+
217
def test_encode_engine_value_nested_struct() -> None:
    # ``tags`` is left unset on both variants, so it encodes as ``None``.
    expected = ["Alice", ["O1", "item1", 10.0, "default_extra"], None]

    dc_customer = Customer(name="Alice", order=Order("O1", "item1", 10.0))
    assert encode_engine_value(dc_customer, Customer) == expected

    nt_customer = CustomerNamedTuple(
        name="Alice", order=OrderNamedTuple("O1", "item1", 10.0)
    )
    assert encode_engine_value(nt_customer, CustomerNamedTuple) == expected
233
+
234
+
235
def test_encode_engine_value_empty_list() -> None:
    cases: list[tuple[Any, Any, Any]] = [
        ([], list, []),
        ([[]], list[list[Any]], [[]]),
    ]
    for empty_value, hint, expected in cases:
        assert encode_engine_value(empty_value, hint) == expected
238
+
239
+
240
def test_encode_engine_value_tuple() -> None:
    # With an ``Any`` hint, tuples (empty and nested too) are encoded as lists.
    cases: list[tuple[Any, Any]] = [
        ((), []),
        ((1, 2, 3), [1, 2, 3]),
        (((1, 2), (3, 4)), [[1, 2], [3, 4]]),
        (([],), [[]]),
        (((),), [[]]),
    ]
    for tup, expected in cases:
        assert encode_engine_value(tup, Any) == expected
246
+
247
+
248
def test_encode_engine_value_none() -> None:
    encoded = encode_engine_value(None, Any)
    assert encoded is None
250
+
251
+
252
def test_roundtrip_basic_types() -> None:
    # bytes: also decodable with no annotation at all and with ``Any``.
    greeting = b"hello world"
    validate_full_roundtrip(
        greeting,
        bytes,
        (greeting, inspect.Parameter.empty),
        (greeting, Any),
    )
    validate_full_roundtrip(b"\x00\x01\x02\xff\xfe", bytes)

    # str and bool values also decode through ``Any``.
    validate_full_roundtrip("hello", str, ("hello", Any))
    for flag in (True, False):
        validate_full_roundtrip(flag, bool, (flag, Any))

    # 64-bit integers are interchangeable between int, np.int64 and Int64.
    answer = 42
    validate_full_roundtrip(
        answer,
        cocoindex.Int64,
        (answer, int),
        (np.int64(answer), np.int64),
        (answer, Any),
    )
    validate_full_roundtrip(answer, int, (answer, cocoindex.Int64))
    validate_full_roundtrip(np.int64(answer), np.int64, (answer, cocoindex.Int64))

    # 64-bit floats.
    quarter = 3.25
    validate_full_roundtrip(
        quarter,
        Float64,
        (quarter, float),
        (np.float64(quarter), np.float64),
        (quarter, Any),
    )
    validate_full_roundtrip(quarter, float, (quarter, Float64))
    validate_full_roundtrip(np.float64(quarter), np.float64, (quarter, Float64))

    # 32-bit floats (3.25 is exactly representable, so widening is lossless).
    validate_full_roundtrip(
        quarter,
        Float32,
        (quarter, float),
        (np.float32(quarter), np.float32),
        (np.float64(quarter), np.float64),
        (quarter, Float64),
        (quarter, Any),
    )
    validate_full_roundtrip(np.float32(quarter), np.float32, (quarter, Float32))
285
+
286
+
287
def test_roundtrip_uuid() -> None:
    random_id = uuid.uuid4()
    validate_full_roundtrip(random_id, uuid.UUID, (random_id, Any))
290
+
291
+
292
def test_roundtrip_range() -> None:
    # An ordinary range, an empty one (start == end) and a very wide one.
    for span in ((0, 100), (50, 50), (0, 1_000_000_000)):
        validate_full_roundtrip(span, cocoindex.Range, (span, Any))
299
+
300
+
301
def test_roundtrip_time() -> None:
    # Times of day: with microseconds, end of day, and midnight.
    for time_of_day in (
        datetime.time(10, 30, 50, 123456),
        datetime.time(23, 59, 59),
        datetime.time(0, 0, 0),
    ):
        validate_full_roundtrip(time_of_day, datetime.time, (time_of_day, Any))

    new_year = datetime.date(2025, 1, 1)
    validate_full_roundtrip(new_year, datetime.date, (new_year, Any))

    naive = datetime.datetime(2025, 1, 2, 3, 4, 5, 123456)
    validate_full_roundtrip(
        naive, cocoindex.LocalDateTime, (naive, datetime.datetime)
    )

    plus_five = datetime.timezone(datetime.timedelta(hours=5))
    aware = datetime.datetime(2025, 1, 2, 3, 4, 5, 123456, plus_five)
    validate_full_roundtrip(
        aware, cocoindex.OffsetDateTime, (aware, datetime.datetime)
    )
    validate_full_roundtrip(
        aware, datetime.datetime, (aware, cocoindex.OffsetDateTime)
    )

    # A naive datetime sent through an offset-datetime engine type is expected
    # back as the same wall-clock time in UTC.
    naive_as_utc = naive.replace(tzinfo=datetime.UTC)
    validate_full_roundtrip_to(
        naive, cocoindex.OffsetDateTime, (naive_as_utc, datetime.datetime)
    )
    validate_full_roundtrip_to(
        naive, datetime.datetime, (naive_as_utc, cocoindex.OffsetDateTime)
    )
349
+
350
+
351
def test_roundtrip_timedelta() -> None:
    durations = [
        # Mixes every constructor unit.
        datetime.timedelta(
            days=5, seconds=10, microseconds=123, milliseconds=456, minutes=30, hours=2
        ),
        datetime.timedelta(days=-5, hours=-2),  # negative duration
        datetime.timedelta(0),  # zero duration
    ]
    for delta in durations:
        validate_full_roundtrip(delta, datetime.timedelta, (delta, Any))
360
+
361
+
362
def test_roundtrip_json() -> None:
    json_documents: list[Any] = [
        # Flat object with mixed scalar types.
        {"key": "value", "number": 123, "bool": True, "float": 1.23},
        # Heterogeneous array including a null.
        [1, "string", False, None, 4.56],
        # Nested objects and arrays.
        {
            "name": "Test Json",
            "version": 1.0,
            "items": [
                {"id": 1, "value": "item1"},
                {"id": 2, "value": None, "props": {"active": True}},
            ],
            "metadata": None,
        },
        # Empty containers.
        {},
        [],
    ]
    for document in json_documents:
        validate_full_roundtrip(document, cocoindex.Json)
382
+
383
+
384
def test_decode_scalar_numpy_values() -> None:
    # (engine type, NumPy target type, raw engine value, expected decoded value)
    cases = [
        ({"kind": "Int64"}, np.int64, 42, np.int64(42)),
        ({"kind": "Float32"}, np.float32, 3.14, np.float32(3.14)),
        ({"kind": "Float64"}, np.float64, 2.718, np.float64(2.718)),
    ]
    for engine_type, numpy_type, raw, want in cases:
        decode = make_engine_value_decoder(
            ["field"], engine_type, analyze_type_info(numpy_type)
        )
        decoded = decode(raw)
        assert isinstance(decoded, numpy_type)
        assert decoded == want
397
+
398
+
399
def test_non_ndarray_vector_decoding() -> None:
    def vector_decoder(element_kind: str, target: Any) -> Callable[[Any], Any]:
        # Unsized engine vector of the given element kind, decoded into ``target``.
        engine_type = {
            "kind": "Vector",
            "element_type": {"kind": element_kind},
            "dimension": None,
        }
        return make_engine_value_decoder(
            ["field"], engine_type, analyze_type_info(target)
        )

    # list[np.float64]: plain floats come back as np.float64 elements.
    floats = vector_decoder("Float64", list[np.float64])([1.0, 2.0, 3.0])
    assert isinstance(floats, list)
    assert all(isinstance(item, np.float64) for item in floats)
    assert floats == [np.float64(1.0), np.float64(2.0), np.float64(3.0)]

    # list[uuid.UUID]: UUID elements come back as equal UUID instances.
    first_id, second_id = uuid.uuid4(), uuid.uuid4()
    ids = vector_decoder("Uuid", list[uuid.UUID])([first_id, second_id])
    assert isinstance(ids, list)
    assert all(isinstance(item, uuid.UUID) for item in ids)
    assert ids == [first_id, second_id]
429
+
430
+
431
def test_roundtrip_struct() -> None:
    # Both struct flavours must survive the engine roundtrip unchanged.
    struct_cases: list[tuple[Any, Any]] = [
        (Order("O123", "mixed nuts", 25.0, "default_extra"), Order),
        (
            OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra"),
            OrderNamedTuple,
        ),
    ]
    for struct_value, struct_type in struct_cases:
        validate_full_roundtrip(struct_value, struct_type)
440
+
441
+
442
def test_make_engine_value_decoder_list_of_struct() -> None:
    rows = [
        ["O1", "item1", 10.0, "default_extra"],
        ["O2", "item2", 20.0, "default_extra"],
    ]
    # The same engine rows decode into either a list of dataclasses or NamedTuples.
    targets: list[tuple[Any, Any]] = [
        (
            list[Order],
            [
                Order("O1", "item1", 10.0, "default_extra"),
                Order("O2", "item2", 20.0, "default_extra"),
            ],
        ),
        (
            list[OrderNamedTuple],
            [
                OrderNamedTuple("O1", "item1", 10.0, "default_extra"),
                OrderNamedTuple("O2", "item2", 20.0, "default_extra"),
            ],
        ),
    ]
    for list_type, expected in targets:
        assert build_engine_value_decoder(list_type)(rows) == expected
460
+
461
+
462
def test_make_engine_value_decoder_struct_of_list() -> None:
    row = [
        "Alice",
        ["O1", "item1", 10.0, "default_extra"],
        [["vip"], ["premium"]],
    ]
    expected_tags = [Tag("vip"), Tag("premium")]

    # Dataclass target.
    as_dataclass = build_engine_value_decoder(Customer)(row)
    assert as_dataclass == Customer(
        "Alice", Order("O1", "item1", 10.0, "default_extra"), expected_tags
    )

    # NamedTuple target; its ``tags`` field still holds Tag dataclasses.
    as_namedtuple = build_engine_value_decoder(CustomerNamedTuple)(row)
    assert as_namedtuple == CustomerNamedTuple(
        "Alice", OrderNamedTuple("O1", "item1", 10.0, "default_extra"), expected_tags
    )
483
+
484
+
485
def test_make_engine_value_decoder_struct_of_struct() -> None:
    engine_row = [
        ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]]],
        [
            ["O1", "item1", 10.0, "default_extra"],
            ["O2", "item2", 20.0, "default_extra"],
        ],
        2,
    ]
    first = Order("O1", "item1", 10.0, "default_extra")
    second = Order("O2", "item2", 20.0, "default_extra")
    expected = NestedStruct(
        customer=Customer("Alice", first, [Tag("vip")]),
        orders=[first, second],
        count=2,
    )
    assert build_engine_value_decoder(NestedStruct)(engine_row) == expected
504
+
505
+
506
def make_engine_order(fields: list[tuple[str, type]]) -> type:
    """Create a throwaway ``EngineOrder`` dataclass with ``fields`` in the given order."""
    engine_cls = make_dataclass("EngineOrder", fields)
    return engine_cls
508
+
509
+
510
+ def make_python_order(
511
+ fields: list[tuple[str, type]], defaults: dict[str, Any] | None = None
512
+ ) -> type:
513
+ if defaults is None:
514
+ defaults = {}
515
+ # Move all fields with defaults to the end (Python dataclass requirement)
516
+ non_default_fields = [(n, t) for n, t in fields if n not in defaults]
517
+ default_fields = [(n, t) for n, t in fields if n in defaults]
518
+ ordered_fields = non_default_fields + default_fields
519
+ # Prepare the namespace for defaults (only for fields at the end)
520
+ namespace = {k: defaults[k] for k, _ in default_fields}
521
+ return make_dataclass("PythonOrder", ordered_fields, namespace=namespace)
522
+
523
+
524
@pytest.mark.parametrize(
    "engine_fields, python_fields, python_defaults, engine_val, expected_python_val",
    [
        # Extra field in Python (middle)
        (
            [("id", str), ("name", str)],
            [("id", str), ("price", float), ("name", str)],
            {"price": 0.0},
            ["O123", "mixed nuts"],
            ("O123", 0.0, "mixed nuts"),
        ),
        # Missing field in Python (middle)
        (
            [("id", str), ("price", float), ("name", str)],
            [("id", str), ("name", str)],
            {},
            ["O123", 25.0, "mixed nuts"],
            ("O123", "mixed nuts"),
        ),
        # Extra field in Python (start)
        (
            [("name", str), ("price", float)],
            [("extra", str), ("name", str), ("price", float)],
            {"extra": "default"},
            ["mixed nuts", 25.0],
            ("default", "mixed nuts", 25.0),
        ),
        # Missing field in Python (start)
        (
            [("extra", str), ("name", str), ("price", float)],
            [("name", str), ("price", float)],
            {},
            ["unexpected", "mixed nuts", 25.0],
            ("mixed nuts", 25.0),
        ),
        # Field order difference (should map by name)
        (
            [("id", str), ("name", str), ("price", float)],
            [("name", str), ("id", str), ("price", float), ("extra", str)],
            {"extra": "default"},
            ["O123", "mixed nuts", 25.0],
            ("mixed nuts", "O123", 25.0, "default"),
        ),
        # Extra field (Python has extra field with default)
        (
            [("id", str), ("name", str)],
            [("id", str), ("name", str), ("price", float)],
            {"price": 0.0},
            ["O123", "mixed nuts"],
            ("O123", "mixed nuts", 0.0),
        ),
        # Missing field (Engine has extra field)
        (
            [("id", str), ("name", str), ("price", float)],
            [("id", str), ("name", str)],
            {},
            ["O123", "mixed nuts", 25.0],
            ("O123", "mixed nuts"),
        ),
    ],
)
def test_field_position_cases(
    engine_fields: list[tuple[str, type]],
    python_fields: list[tuple[str, type]],
    python_defaults: dict[str, Any],
    engine_val: list[Any],
    expected_python_val: tuple[Any, ...],
) -> None:
    engine_cls = make_engine_order(engine_fields)
    python_cls = make_python_order(python_fields, python_defaults)
    decode = build_engine_value_decoder(engine_cls, python_cls)
    # ``expected_python_val`` follows ``python_fields`` order. Build the expected
    # instance by keyword so the reordering done by make_python_order (defaulted
    # fields moved last) does not matter.
    expected_kwargs = {
        field_name: expected
        for (field_name, _), expected in zip(python_fields, expected_python_val)
    }
    assert decode(engine_val) == python_cls(**expected_kwargs)
599
+
600
+
601
def test_roundtrip_union_simple() -> None:
    # The float member of the union is the active one.
    validate_full_roundtrip(10.4, int | str | float)
605
+
606
+
607
def test_roundtrip_union_with_active_uuid() -> None:
    union_type = str | uuid.UUID | int
    # The UUID member of the union is the active one.
    validate_full_roundtrip(uuid.uuid4(), union_type)
611
+
612
+
613
def test_roundtrip_union_with_inactive_uuid() -> None:
    union_type = str | uuid.UUID | int
    # A UUID-shaped *string* must come back as a ``str``, not as a ``uuid.UUID``.
    uuid_like_text = "5a9f8f6a-318f-4f1f-929d-566d7444a62d"
    validate_full_roundtrip(uuid_like_text, union_type)
617
+
618
+
619
def test_roundtrip_union_offset_datetime() -> None:
    union_type = str | uuid.UUID | float | int | datetime.datetime
    # A timezone-aware datetime as the active union member.
    validate_full_roundtrip(datetime.datetime.now(datetime.UTC), union_type)
623
+
624
+
625
def test_roundtrip_union_date() -> None:
    union_type = str | uuid.UUID | float | int | datetime.date
    validate_full_roundtrip(datetime.date.today(), union_type)
629
+
630
+
631
def test_roundtrip_union_time() -> None:
    union_type = str | uuid.UUID | float | int | datetime.time
    # ``datetime.time()`` is midnight.
    validate_full_roundtrip(datetime.time(), union_type)
635
+
636
+
637
def test_roundtrip_union_timedelta() -> None:
    union_type = str | uuid.UUID | float | int | datetime.timedelta
    # More than a day's worth of hours, to exercise day carry-over.
    duration = datetime.timedelta(hours=39, minutes=10, seconds=1)
    validate_full_roundtrip(duration, union_type)
641
+
642
+
643
def test_roundtrip_vector_of_union() -> None:
    # Each element picks a different member of the element union.
    validate_full_roundtrip(["a", 1], list[str | int])
647
+
648
+
649
def test_roundtrip_union_with_vector() -> None:
    floats = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    validate_full_roundtrip(
        floats,
        NDArray[np.float32] | str,
        # Decoding with a list-based union yields a plain Python list instead.
        ([1.0, 2.0, 3.0], list[float] | str),
    )
653
+
654
+
655
def test_roundtrip_union_with_misc_types() -> None:
    bytes_union = int | bytes | str
    range_union = cocoindex.Range | str | bool
    json_union = cocoindex.Json | int | bytes
    cases: list[tuple[Any, Any]] = [
        (b"test_bytes", bytes_union),
        (123, bytes_union),
        ((100, 200), range_union),
        ("test_string", range_union),
        ({"a": 1, "b": [2, 3]}, json_union),
        (b"another_byte_string", json_union),
    ]
    for member_value, union_type in cases:
        validate_full_roundtrip(member_value, union_type)
668
+
669
+
670
def test_roundtrip_ltable() -> None:
    # A list of structs ("LTable") with dataclass rows...
    validate_full_roundtrip(
        [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)],
        list[Order],
    )
    # ...and with NamedTuple rows.
    validate_full_roundtrip(
        [
            OrderNamedTuple("O1", "item1", 10.0),
            OrderNamedTuple("O2", "item2", 20.0),
        ],
        list[OrderNamedTuple],
    )
681
+
682
+
683
def test_roundtrip_ktable_various_key_types() -> None:
    @dataclass
    class SimpleValue:
        data: str

    # (keyed table value, dict annotation) pairs, one per supported key type.
    keyed_tables: list[tuple[Any, Any]] = [
        (
            {b"key1": SimpleValue("val1"), b"key2": SimpleValue("val2")},
            dict[bytes, SimpleValue],
        ),
        (
            {1: SimpleValue("val1"), 2: SimpleValue("val2")},
            dict[int, SimpleValue],
        ),
        (
            {True: SimpleValue("val_true"), False: SimpleValue("val_false")},
            dict[bool, SimpleValue],
        ),
        (
            {"K1": Order("O1", "item1", 10.0), "K2": Order("O2", "item2", 20.0)},
            dict[str, Order],
        ),
        (
            {
                "K1": OrderNamedTuple("O1", "item1", 10.0),
                "K2": OrderNamedTuple("O2", "item2", 20.0),
            },
            dict[str, OrderNamedTuple],
        ),
        (
            {
                (1, 10): SimpleValue("val_range1"),
                (20, 30): SimpleValue("val_range2"),
            },
            dict[cocoindex.Range, SimpleValue],
        ),
        (
            {
                datetime.date(2023, 1, 1): SimpleValue("val_date1"),
                datetime.date(2024, 2, 2): SimpleValue("val_date2"),
            },
            dict[datetime.date, SimpleValue],
        ),
        (
            {
                uuid.uuid4(): SimpleValue("val_uuid1"),
                uuid.uuid4(): SimpleValue("val_uuid2"),
            },
            dict[uuid.UUID, SimpleValue],
        ),
    ]
    for table, table_type in keyed_tables:
        validate_full_roundtrip(table, table_type)
731
+
732
+
733
def test_roundtrip_ktable_struct_key() -> None:
    # Frozen so instances are hashable and usable as dict keys.
    @dataclass(frozen=True)
    class OrderKey:
        shop_id: str
        version: int

    key_a, key_b = OrderKey("A", 3), OrderKey("B", 4)

    validate_full_roundtrip(
        {key_a: Order("O1", "item1", 10.0), key_b: Order("O2", "item2", 20.0)},
        dict[OrderKey, Order],
    )
    validate_full_roundtrip(
        {
            key_a: OrderNamedTuple("O1", "item1", 10.0),
            key_b: OrderNamedTuple("O2", "item2", 20.0),
        },
        dict[OrderKey, OrderNamedTuple],
    )
752
+
753
+
754
# int64 vector annotation whose ``Literal[5]`` argument sets the vector dimension.
IntVectorType = cocoindex.Vector[np.int64, Literal[5]]
755
+
756
+
757
def test_vector_as_vector() -> None:
    ints = np.arange(1, 6, dtype=np.int64)  # [1, 2, 3, 4, 5]
    encoded = encode_engine_value(ints, IntVectorType)
    assert np.array_equal(encoded, ints)
    round_tripped = build_engine_value_decoder(IntVectorType)(encoded)
    assert np.array_equal(round_tripped, ints)
763
+
764
+
765
# Plain-list vector annotation, used alongside the NumPy-backed vector types.
ListIntType = list[int]
766
+
767
+
768
def test_vector_as_list() -> None:
    numbers: ListIntType = [1, 2, 3, 4, 5]
    encoded = encode_engine_value(numbers, ListIntType)
    assert encoded == [1, 2, 3, 4, 5]
    round_tripped = build_engine_value_decoder(ListIntType)(encoded)
    assert np.array_equal(round_tripped, numbers)
774
+
775
+
776
# Vector annotations shared by the NDArray encode/decode tests below.
# ``Literal[N]`` fixes the vector dimension; the ``NoDim`` variant leaves it open.
Float64VectorTypeNoDim = Vector[np.float64]
Float32VectorType = Vector[np.float32, Literal[3]]
Float64VectorType = Vector[np.float64, Literal[3]]
Int64VectorType = Vector[np.int64, Literal[3]]
# Bare NumPy array annotations (no dimension information).
NDArrayFloat32Type = NDArray[np.float32]
NDArrayFloat64Type = NDArray[np.float64]
NDArrayInt64Type = NDArray[np.int64]
783
+
784
+
785
def test_encode_engine_value_ndarray() -> None:
    """NumPy-backed vectors encode to values equal to their plain element lists."""
    float_elements = [1.0, 2.0, 3.0]
    cases: list[tuple[Any, Any, list[Any]]] = [
        (np.array([1.0, 2.0, 3.0], dtype=np.float32), Float32VectorType, float_elements),
        (np.array([1.0, 2.0, 3.0], dtype=np.float64), Float64VectorType, float_elements),
        (np.array([1, 2, 3], dtype=np.int64), Int64VectorType, [1, 2, 3]),
        (np.array([1.0, 2.0, 3.0], dtype=np.float32), NDArrayFloat32Type, float_elements),
    ]
    for array, vector_type, expected in cases:
        assert np.array_equal(encode_engine_value(array, vector_type), expected)
801
+
802
+
803
def test_make_engine_value_decoder_ndarray() -> None:
    """Engine lists decode to NumPy arrays carrying the annotated dtype."""
    cases: list[tuple[Any, list[Any], Any]] = [
        (Float32VectorType, [1.0, 2.0, 3.0], np.float32),
        (Float64VectorType, [1.0, 2.0, 3.0], np.float64),
        (Int64VectorType, [1, 2, 3], np.int64),
        (NDArrayFloat32Type, [1.0, 2.0, 3.0], np.float32),
    ]
    for vector_type, engine_list, dtype in cases:
        decoded = build_engine_value_decoder(vector_type)(engine_list)
        assert isinstance(decoded, np.ndarray)
        assert decoded.dtype == dtype
        assert np.array_equal(decoded, np.array(engine_list, dtype=dtype))
825
+
826
+
827
def test_roundtrip_ndarray_vector() -> None:
    """Test roundtrip encoding and decoding of NDArray vectors.

    For each vector annotation, the encoded value must equal the original
    elements. Decoding it back must yield an ``np.ndarray`` with the original
    dtype and values.
    """
    cases: list[tuple[Any, Any, list[Any]]] = [
        (
            np.array([1.0, 2.0, 3.0], dtype=np.float32),
            Float32VectorType,
            [1.0, 2.0, 3.0],
        ),
        (np.array([1, 2, 3], dtype=np.int64), Int64VectorType, [1, 2, 3]),
        (
            np.array([1.0, 2.0, 3.0], dtype=np.float64),
            NDArrayFloat64Type,
            [1.0, 2.0, 3.0],
        ),
    ]
    for original, vector_type, expected_elements in cases:
        encoded = encode_engine_value(original, vector_type)
        # The encoded form is asserted for every case, including float32.
        assert np.array_equal(encoded, expected_elements)
        decoded = build_engine_value_decoder(vector_type)(encoded)
        assert isinstance(decoded, np.ndarray)
        assert decoded.dtype == original.dtype
        assert np.array_equal(decoded, original)
850
+
851
+
852
def test_ndarray_dimension_mismatch() -> None:
    """A 2-element vector is rejected when decoded as a 3-dimensional Vector."""
    two_elements = np.array([1.0, 2.0], dtype=np.float32)
    # Encoding with an unsized NDArray annotation succeeds...
    encoded = encode_engine_value(two_elements, NDArray[np.float32])
    assert np.array_equal(encoded, [1.0, 2.0])
    # ...but decoding into Vector[np.float32, Literal[3]] must fail.
    with pytest.raises(ValueError, match="Vector dimension mismatch"):
        strict_decoder = build_engine_value_decoder(Float32VectorType)
        strict_decoder(encoded)
859
+
860
+
861
def test_list_vector_backward_compatibility() -> None:
    """Plain Python lists remain usable as vector values."""
    expected = [1, 2, 3, 4, 5]

    # list[int]-encoded data decodes into an int64 ndarray for a Vector target.
    plain_encoded = encode_engine_value([1, 2, 3, 4, 5], list[int])
    assert plain_encoded == expected
    as_vector = build_engine_value_decoder(IntVectorType)(plain_encoded)
    assert isinstance(as_vector, np.ndarray)
    assert as_vector.dtype == np.int64
    assert np.array_equal(as_vector, np.array(expected, dtype=np.int64))

    # The list alias roundtrips as well.
    source: ListIntType = [1, 2, 3, 4, 5]
    alias_encoded = encode_engine_value(source, ListIntType)
    assert np.array_equal(alias_encoded, expected)
    as_list = build_engine_value_decoder(ListIntType)(alias_encoded)
    assert np.array_equal(as_list, expected)
875
+
876
+
877
def test_encode_complex_structure_with_ndarray() -> None:
    """A dataclass with an NDArray field encodes field-by-field, keeping the array data."""

    @dataclass
    class MyStructWithNDArray:
        name: str
        data: NDArray[np.float32]
        value: int

    sample = MyStructWithNDArray(
        name="test_np",
        data=np.array([1.0, 0.5], dtype=np.float32),
        value=100,
    )
    fields_out = encode_engine_value(sample, MyStructWithNDArray)

    assert fields_out[0] == sample.name
    assert np.array_equal(fields_out[1], sample.data)
    assert fields_out[2] == sample.value
894
+
895
+
896
def test_decode_nullable_ndarray_none_or_value_input() -> None:
    """A decoder for ``NDArray | None`` maps null to None and lists to float32 arrays."""
    engine_vector_type = {
        "kind": "Vector",
        "element_type": {"kind": "Float32"},
        "dimension": None,
    }
    decode = make_engine_value_decoder(
        [], engine_vector_type, analyze_type_info(NDArrayFloat32Type | None)
    )

    assert decode(None) is None

    decoded = decode([1.0, 2.0, 3.0])
    assert isinstance(decoded, np.ndarray)
    assert decoded.dtype == np.float32
    np.testing.assert_array_equal(
        decoded, np.array([1.0, 2.0, 3.0], dtype=np.float32)
    )
920
+
921
+
922
def test_decode_vector_string() -> None:
    """A ``Vector[str]`` target decodes a string vector into a plain Python list."""
    str_vector_schema = {
        "kind": "Vector",
        "element_type": {"kind": "Str"},
        "dimension": None,
    }
    target_info = analyze_type_info(Vector[str])
    decode = make_engine_value_decoder([], str_vector_schema, target_info)
    assert decode(["hello", "world"]) == ["hello", "world"]
933
+
934
+
935
def test_decode_error_non_nullable_or_non_list_vector() -> None:
    """A non-nullable NDArray decoder rejects None and non-list inputs."""
    decode = make_engine_value_decoder(
        [],
        {"kind": "Vector", "element_type": {"kind": "Float32"}, "dimension": None},
        analyze_type_info(NDArrayFloat32Type),
    )

    # (bad engine value, expected exception type, expected message pattern)
    bad_inputs: list[tuple[Any, type[Exception], str]] = [
        (None, ValueError, "Received null for non-nullable vector"),
        ("not a list", TypeError, "Expected NDArray or list for vector"),
    ]
    for bad_value, exc_type, pattern in bad_inputs:
        with pytest.raises(exc_type, match=pattern):
            decode(bad_value)
949
+
950
+
951
def test_dump_vector_type_annotation_with_dim() -> None:
    """Dumping the fixed-dimension float32 vector annotation records dimension 3."""
    vector_schema = {
        "kind": "Vector",
        "element_type": {"kind": "Float32"},
        "dimension": 3,
    }
    assert dump_engine_object(Float32VectorType) == {"type": vector_schema}
961
+
962
+
963
def test_dump_vector_type_annotation_no_dim() -> None:
    """Dumping a float64 vector annotation without a dimension records None."""
    vector_schema = {
        "kind": "Vector",
        "element_type": {"kind": "Float64"},
        "dimension": None,
    }
    assert dump_engine_object(Float64VectorTypeNoDim) == {"type": vector_schema}
973
+
974
+
975
def test_full_roundtrip_vector_numeric_types() -> None:
    """Roundtrip numeric vectors through ``Vector`` and ``NDArray`` annotations.

    float32, float64 and int64 arrays must survive a full roundtrip and also
    decode into the listed list-based annotations. The other integer dtypes
    exercised here must be rejected with "Unsupported NumPy dtype".
    """

    def f32_list_targets() -> tuple[tuple[Any, Any], ...]:
        # Build fresh (expected value, annotation) alternatives on every call.
        return (
            ([np.float32(1.0), np.float32(2.0), np.float32(3.0)], list[np.float32]),
            ([1.0, 2.0, 3.0], list[cocoindex.Float32]),
            ([1.0, 2.0, 3.0], list[float]),
        )

    f32_array = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    validate_full_roundtrip(
        f32_array, Vector[np.float32, Literal[3]], *f32_list_targets()
    )
    validate_full_roundtrip(
        f32_array, np.typing.NDArray[np.float32], *f32_list_targets()
    )
    # Starting from a plain list, the array form is one of the alternatives.
    validate_full_roundtrip(
        f32_array.tolist(),
        list[np.float32],
        (f32_array, Vector[np.float32, Literal[3]]),
        ([1.0, 2.0, 3.0], list[cocoindex.Float32]),
        ([1.0, 2.0, 3.0], list[float]),
    )

    f64_array = np.array([1.0, 2.0, 3.0], dtype=np.float64)
    validate_full_roundtrip(
        f64_array,
        Vector[np.float64, Literal[3]],
        ([np.float64(1.0), np.float64(2.0), np.float64(3.0)], list[np.float64]),
        ([1.0, 2.0, 3.0], list[cocoindex.Float64]),
        ([1.0, 2.0, 3.0], list[float]),
    )

    i64_array = np.array([1, 2, 3], dtype=np.int64)
    validate_full_roundtrip(
        i64_array,
        Vector[np.int64, Literal[3]],
        ([np.int64(1), np.int64(2), np.int64(3)], list[np.int64]),
        ([1, 2, 3], list[int]),
    )

    # Integer dtypes other than int64 must be rejected.
    with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
        validate_full_roundtrip(
            np.array([1, 2, 3], dtype=np.int32), Vector[np.int32, Literal[3]]
        )
    with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
        validate_full_roundtrip(
            np.array([1, 2, 3], dtype=np.uint8), Vector[np.uint8, Literal[3]]
        )
    with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
        validate_full_roundtrip(
            np.array([1, 2, 3], dtype=np.uint16), Vector[np.uint16, Literal[3]]
        )
    with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
        validate_full_roundtrip(
            np.array([1, 2, 3], dtype=np.uint32), Vector[np.uint32, Literal[3]]
        )
    with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
        validate_full_roundtrip(
            np.array([1, 2, 3], dtype=np.uint64), Vector[np.uint64, Literal[3]]
        )
1032
+
1033
+
1034
def test_full_roundtrip_vector_of_vector() -> None:
    """A 2x3 float32 matrix roundtrips as a vector of vectors and as an NDArray."""
    matrix = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
    alternatives: list[tuple[Any, Any]] = [
        ([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], list[list[np.float32]]),
        ([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], list[list[cocoindex.Float32]]),
        (
            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
            list[Vector[cocoindex.Float32, Literal[3]]],
        ),
        (matrix, np.typing.NDArray[np.float32]),
    ]
    validate_full_roundtrip(
        matrix, Vector[Vector[np.float32, Literal[3]], Literal[2]], *alternatives
    )
1051
+
1052
+
1053
def test_full_roundtrip_vector_other_types() -> None:
    """Vectors of UUIDs, dates and bools roundtrip and decode as plain lists."""
    ids = [uuid.uuid4(), uuid.uuid4()]
    validate_full_roundtrip(ids, Vector[uuid.UUID], (ids, list[uuid.UUID]))

    days = [datetime.date(2023, 1, 1), datetime.date(2024, 10, 5)]
    validate_full_roundtrip(days, Vector[datetime.date], (days, list[datetime.date]))

    flags = [True, False, True, False]
    validate_full_roundtrip(flags, Vector[bool], (flags, list[bool]))

    # Empty vectors must roundtrip for each element type as well.
    validate_full_roundtrip([], Vector[uuid.UUID], ([], list[uuid.UUID]))
    validate_full_roundtrip([], Vector[datetime.date], ([], list[datetime.date]))
    validate_full_roundtrip([], Vector[bool], ([], list[bool]))
1069
+
1070
+
1071
def test_roundtrip_vector_no_dimension() -> None:
    """A dimensionless float64 vector roundtrips and decodes to a list or NDArray."""
    source = np.array([1.0, 2.0, 3.0], dtype=np.float64)
    as_list = ([1.0, 2.0, 3.0], list[float])
    as_ndarray = (
        np.array([1.0, 2.0, 3.0], dtype=np.float64),
        np.typing.NDArray[np.float64],
    )
    validate_full_roundtrip(source, Vector[np.float64], as_list, as_ndarray)
1080
+
1081
+
1082
def test_roundtrip_string_vector() -> None:
    """A ``Vector[str]`` value held as a Python list roundtrips unchanged."""
    greeting: Vector[str] = ["hello", "world"]
    validate_full_roundtrip(greeting, Vector[str])
1086
+
1087
+
1088
def test_roundtrip_empty_vector() -> None:
    """A zero-length float32 array roundtrips through ``Vector[np.float32]``."""
    nothing: Vector[np.float32] = np.array([], dtype=np.float32)
    validate_full_roundtrip(nothing, Vector[np.float32])
1092
+
1093
+
1094
def test_roundtrip_dimension_mismatch() -> None:
    """Roundtripping a 2-element array through a 3-dimension vector type fails."""
    too_short: Vector[np.float32, Literal[3]] = np.array(
        [1.0, 2.0], dtype=np.float32
    )
    with pytest.raises(ValueError, match="Vector dimension mismatch"):
        validate_full_roundtrip(too_short, Vector[np.float32, Literal[3]])
1099
+
1100
+
1101
def test_full_roundtrip_scalar_numeric_types() -> None:
    """Supported NumPy scalars roundtrip; the other integer dtypes are rejected."""
    # (scalar value, its NumPy type, equivalent Python-side alternative)
    supported: list[tuple[Any, Any, tuple[Any, Any]]] = [
        (np.int64(42), np.int64, (42, int)),
        (np.float32(3.25), np.float32, (3.25, cocoindex.Float32)),
        (np.float64(3.25), np.float64, (3.25, cocoindex.Float64)),
    ]
    for scalar, scalar_type, python_alternative in supported:
        validate_full_roundtrip(scalar, scalar_type, python_alternative)

    for rejected_type in [np.int32, np.uint8, np.uint16, np.uint32, np.uint64]:
        with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
            validate_full_roundtrip(rejected_type(1), rejected_type)
1112
+
1113
+
1114
def test_full_roundtrip_nullable_scalar() -> None:
    """Optional NumPy scalar annotations roundtrip both present values and None."""
    cases: list[tuple[Any, Any]] = [
        # Present values.
        (np.int64(42), np.int64 | None),
        (np.float32(3.14), np.float32 | None),
        (np.float64(2.718), np.float64 | None),
        # Absent values.
        (None, np.int64 | None),
        (None, np.float32 | None),
        (None, np.float64 | None),
    ]
    for value, annotation in cases:
        validate_full_roundtrip(value, annotation)
1125
+
1126
+
1127
def test_full_roundtrip_scalar_in_struct() -> None:
    """NumPy scalar fields inside a dataclass survive a full roundtrip."""

    @dataclass
    class NumericStruct:
        int_field: np.int64
        float32_field: np.float32
        float64_field: np.float64

    validate_full_roundtrip(
        NumericStruct(np.int64(42), np.float32(3.14), np.float64(2.718)),
        NumericStruct,
    )
1142
+
1143
+
1144
def test_full_roundtrip_scalar_in_nested_struct() -> None:
    """NumPy scalars nested one dataclass deep survive a full roundtrip."""

    @dataclass
    class InnerStruct:
        value: np.float64

    @dataclass
    class OuterStruct:
        inner: InnerStruct
        count: np.int64

    nested = InnerStruct(np.float64(2.718))
    validate_full_roundtrip(OuterStruct(nested, np.int64(1)), OuterStruct)
1161
+
1162
+
1163
def test_full_roundtrip_scalar_with_python_types() -> None:
    """A dataclass mixing NumPy, Python and annotated scalars roundtrips intact."""

    @dataclass
    class MixedStruct:
        numpy_int: np.int64
        python_int: int
        numpy_float: np.float64
        python_float: float
        string: str
        annotated_int: Annotated[np.int64, TypeKind("Int64")]
        annotated_float: Float32

    # Positional arguments follow the field declaration order above.
    mixed = MixedStruct(
        np.int64(42),
        43,
        np.float64(2.718),
        3.14,
        "hello, world",
        np.int64(42),
        2.0,
    )
    validate_full_roundtrip(mixed, MixedStruct)
1186
+
1187
+
1188
def test_roundtrip_simple_struct_to_dict_binding() -> None:
    """Bind a struct whose fields are all ``str`` to a range of dict annotations."""

    @dataclass
    class SimpleStruct:
        first_name: str
        last_name: str

    person = SimpleStruct("John", "Doe")
    as_dict = {"first_name": "John", "last_name": "Doe"}

    # Loose targets, plus str-valued dicts: every field is a str, so that type
    # can serve directly as the dict value type.
    compatible: list[Any] = [
        Any,
        dict,
        dict[Any, Any],
        dict[str, Any],
        dict[Any, str],
        dict[str, str],
    ]
    validate_full_roundtrip(
        person, SimpleStruct, *[(as_dict, target) for target in compatible]
    )

    # A mismatched value type, or non-str keys, must be rejected.
    for incompatible in (dict[str, int], dict[int, Any]):
        with pytest.raises(ValueError):
            validate_full_roundtrip(person, SimpleStruct, (as_dict, incompatible))
1217
+
1218
+
1219
def test_roundtrip_struct_to_dict_binding() -> None:
    """Bind a struct with mixed field types to loosely typed dict annotations."""

    @dataclass
    class SimpleStruct:
        name: str
        value: int
        price: float

    record = SimpleStruct("test", 42, 3.14)
    as_dict = {"name": "test", "value": 42, "price": 3.14}

    # Annotations that leave the value type open accept the mixed-type fields.
    loose_targets: list[Any] = [Any, dict, dict[Any, Any], dict[str, Any]]
    validate_full_roundtrip(
        record, SimpleStruct, *[(as_dict, target) for target in loose_targets]
    )

    # A single value type that doesn't fit every field, or non-str keys, fails.
    for strict_target in (dict[str, str], dict[int, Any]):
        with pytest.raises(ValueError):
            validate_full_roundtrip(record, SimpleStruct, (as_dict, strict_target))
1246
+
1247
+
1248
def test_roundtrip_struct_to_dict_explicit() -> None:
    """Bind a struct to explicit ``dict`` and ``dict[str, Any]`` annotations."""

    @dataclass
    class Product:
        id: str
        name: str
        price: float
        active: bool

    widget = Product("P1", "Widget", 29.99, True)
    widget_dict = dict(id="P1", name="Widget", price=29.99, active=True)
    validate_full_roundtrip(
        widget,
        Product,
        (widget_dict, dict),
        (widget_dict, dict[str, Any]),
    )
1265
+
1266
+
1267
def test_roundtrip_struct_to_dict_with_none_annotation() -> None:
    """Test struct -> dict binding when the target annotation is empty.

    Despite the "none" in the name, the target annotation used here is
    ``inspect.Parameter.empty`` (the sentinel ``inspect`` uses for a missing
    annotation), not ``None``. It is expected to be treated like ``Any`` and
    bind to a plain dict.
    """

    @dataclass
    class Config:
        host: str
        port: int
        debug: bool

    instance = Config("localhost", 8080, True)
    expected_dict = {"host": "localhost", "port": 8080, "debug": True}

    # Test empty annotation (should be treated as Any)
    validate_full_roundtrip(instance, Config, (expected_dict, inspect.Parameter.empty))
1281
+
1282
+
1283
def test_roundtrip_struct_to_dict_nested() -> None:
    """A struct containing another struct binds to a nested dict."""

    @dataclass
    class Address:
        street: str
        city: str

    @dataclass
    class Person:
        name: str
        age: int
        address: Address

    john = Person(
        name="John",
        age=30,
        address=Address(street="123 Main St", city="Anytown"),
    )
    # The inner struct becomes a dict of its own.
    john_dict = {
        "name": "John",
        "age": 30,
        "address": {"street": "123 Main St", "city": "Anytown"},
    }
    validate_full_roundtrip(john, Person, (john_dict, dict[str, Any]))
1307
+
1308
+
1309
def test_roundtrip_struct_to_dict_with_list() -> None:
    """A struct with a ``list[str]`` field binds to a dict that keeps the list."""

    @dataclass
    class Team:
        name: str
        members: list[str]
        active: bool

    dev_team = Team(name="Dev Team", members=["Alice", "Bob", "Charlie"], active=True)
    dev_team_dict = dict(
        name="Dev Team",
        members=["Alice", "Bob", "Charlie"],
        active=True,
    )
    validate_full_roundtrip(dev_team, Team, (dev_team_dict, dict))
1326
+
1327
+
1328
def test_roundtrip_namedtuple_to_dict_binding() -> None:
    """A NamedTuple binds to ``dict`` and ``Any`` targets keyed by field name."""

    class Point(NamedTuple):
        x: float
        y: float
        z: float

    point = Point(x=1.0, y=2.0, z=3.0)
    point_dict = dict(x=1.0, y=2.0, z=3.0)
    validate_full_roundtrip(point, Point, (point_dict, dict), (point_dict, Any))
1342
+
1343
+
1344
def test_roundtrip_ltable_to_list_dict_binding() -> None:
    """A ``list[User]`` (LTable) binds to list-of-dict style annotations."""

    @dataclass
    class User:
        id: str
        name: str
        age: int

    # Each row is (id, name, age).
    rows = [("u1", "Alice", 25), ("u2", "Bob", 30), ("u3", "Charlie", 35)]
    users = [User(*row) for row in rows]
    as_dicts = [{"id": uid, "name": name, "age": age} for uid, name, age in rows]

    validate_full_roundtrip(
        users,
        list[User],
        (as_dicts, Any),
        (as_dicts, list[Any]),
        (as_dicts, list[dict[str, Any]]),
    )
1368
+
1369
+
1370
def test_roundtrip_ktable_to_dict_dict_binding() -> None:
    """A ``dict[str, Product]`` (KTable) binds to nested-dict style annotations."""

    @dataclass
    class Product:
        name: str
        price: float
        active: bool

    # key -> (name, price, active)
    catalog = {
        "p1": ("Widget", 29.99, True),
        "p2": ("Gadget", 49.99, False),
        "p3": ("Tool", 19.99, True),
    }
    products = {key: Product(*fields) for key, fields in catalog.items()}
    as_nested_dicts = {
        key: {"name": name, "price": price, "active": active}
        for key, (name, price, active) in catalog.items()
    }

    dict_targets: list[Any] = [
        Any,
        dict,
        dict[Any, Any],
        dict[str, Any],
        dict[Any, dict[Any, Any]],
        dict[str, dict[Any, Any]],
        dict[str, dict[str, Any]],
    ]
    validate_full_roundtrip(
        products,
        dict[str, Product],
        *[(as_nested_dicts, target) for target in dict_targets],
    )
1402
+
1403
+
1404
def test_roundtrip_ktable_with_complex_key() -> None:
    """A KTable keyed by a struct binds to dicts keyed by tuples or by the struct."""

    @dataclass(frozen=True)
    class OrderKey:
        shop_id: str
        version: int

    @dataclass
    class Order:
        customer: str
        total: float

    # Each row is ((shop_id, version), (customer, total)).
    rows = [(("shop1", 1), ("Alice", 100.0)), (("shop2", 2), ("Bob", 200.0))]
    orders = {OrderKey(*key): Order(*val) for key, val in rows}
    plain = {
        key: {"customer": customer, "total": total}
        for key, (customer, total) in rows
    }

    validate_full_roundtrip(
        orders,
        dict[OrderKey, Order],
        (plain, Any),
        (plain, dict),
        (plain, dict[Any, Any]),
        (plain, dict[Any, dict[str, Any]]),
        # Values bound to the struct type while keys stay as tuples.
        ({key: Order(*val) for key, val in rows}, dict[Any, Order]),
        # Keys bound to the key struct while values stay as dicts.
        (
            {
                OrderKey(*key): {"customer": customer, "total": total}
                for key, (customer, total) in rows
            },
            dict[OrderKey, Any],
        ),
    )
1449
+
1450
+
1451
def test_roundtrip_ltable_with_nested_structs() -> None:
    """An LTable whose rows contain nested structs binds to a list of nested dicts."""

    @dataclass
    class Address:
        street: str
        city: str

    @dataclass
    class Person:
        name: str
        age: int
        address: Address

    # Each row is (name, age, street, city).
    raw = [
        ("John", 30, "123 Main St", "Anytown"),
        ("Jane", 25, "456 Oak Ave", "Somewhere"),
    ]
    people = [
        Person(name, age, Address(street, city)) for name, age, street, city in raw
    ]
    as_dicts = [
        {"name": name, "age": age, "address": {"street": street, "city": city}}
        for name, age, street, city in raw
    ]

    validate_full_roundtrip(people, list[Person], (as_dicts, Any))
1484
+
1485
+
1486
def test_roundtrip_ktable_with_list_fields() -> None:
    """A KTable whose values hold ``list[str]`` fields binds to nested dicts."""

    @dataclass
    class Team:
        name: str
        members: list[str]
        active: bool

    # key -> (name, members, active)
    roster = {
        "team1": ("Dev Team", ["Alice", "Bob"], True),
        "team2": ("QA Team", ["Charlie", "David"], False),
    }
    teams = {key: Team(*spec) for key, spec in roster.items()}
    # Copy the member lists so expectations never alias the struct's lists.
    as_dicts = {
        key: {"name": name, "members": list(members), "active": active}
        for key, (name, members, active) in roster.items()
    }

    validate_full_roundtrip(teams, dict[str, Team], (as_dicts, Any))
1506
+
1507
+
1508
def test_auto_default_for_supported_and_unsupported_types() -> None:
    """Exercise fields that can and cannot be filled when missing from input.

    Nullable, LTable (``list[...]``) and KTable (``dict[...]``) fields are
    presumably the kinds eligible for an automatic default. A plain ``int``
    field without a default must make decoder construction fail.
    """

    @dataclass
    class Base:
        a: int

    @dataclass
    class NullableField:
        a: int
        b: int | None

    @dataclass
    class LTableField:
        a: int
        b: list[Base]

    @dataclass
    class KTableField:
        a: int
        b: dict[str, Base]

    @dataclass
    class UnsupportedField:
        a: int
        b: int

    # Each auto-defaultable field roundtrips with its "empty" value.
    # NOTE(review): these roundtrip each type against itself, so they do not
    # appear to exercise decoding from Base (where `b` is missing) -- confirm
    # whether that path is covered elsewhere.
    for empty_instance, struct_type in (
        (NullableField(1, None), NullableField),
        (LTableField(1, []), LTableField),
        (KTableField(1, {}), KTableField),
    ):
        validate_full_roundtrip(empty_instance, struct_type)

    # Base has no `b`, and an int field has no fallback, so building the
    # decoder must raise.
    missing_field_pattern = (
        r"Field 'b' \(type <class 'int'>\) without default value is missing in input: "
    )
    with pytest.raises(ValueError, match=missing_field_pattern):
        build_engine_value_decoder(Base, UnsupportedField)