cocoindex 0.1.49__cp311-cp311-win_amd64.whl → 0.1.51__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,11 +1,22 @@
1
1
  import uuid
2
2
  import datetime
3
3
  from dataclasses import dataclass, make_dataclass
4
- from typing import NamedTuple, Literal
4
+ from typing import NamedTuple, Literal, Any, Callable, Union
5
5
  import pytest
6
6
  import cocoindex
7
- from cocoindex.typing import encode_enriched_type
8
- from cocoindex.convert import encode_engine_value, make_engine_value_decoder
7
+ from cocoindex.typing import (
8
+ encode_enriched_type,
9
+ Vector,
10
+ Float32,
11
+ Float64,
12
+ )
13
+ from cocoindex.convert import (
14
+ encode_engine_value,
15
+ make_engine_value_decoder,
16
+ dump_engine_object,
17
+ )
18
+ import numpy as np
19
+ from numpy.typing import NDArray
9
20
 
10
21
 
11
22
  @dataclass
@@ -23,7 +34,7 @@ class Tag:
23
34
 
24
35
  @dataclass
25
36
  class Basket:
26
- items: list
37
+ items: list[str]
27
38
 
28
39
 
29
40
  @dataclass
@@ -53,7 +64,9 @@ class CustomerNamedTuple(NamedTuple):
53
64
  tags: list[Tag] | None = None
54
65
 
55
66
 
56
- def build_engine_value_decoder(engine_type_in_py, python_type=None):
67
+ def build_engine_value_decoder(
68
+ engine_type_in_py: Any, python_type: Any | None = None
69
+ ) -> Callable[[Any], Any]:
57
70
  """
58
71
  Helper to build a converter for the given engine-side type (as represented in Python).
59
72
  If python_type is not specified, uses engine_type_in_py as the target.
@@ -62,19 +75,51 @@ def build_engine_value_decoder(engine_type_in_py, python_type=None):
62
75
  return make_engine_value_decoder([], engine_type, python_type or engine_type_in_py)
63
76
 
64
77
 
65
- def test_encode_engine_value_basic_types():
78
+ def validate_full_roundtrip(
79
+ value: Any,
80
+ value_type: Any = None,
81
+ *other_decoded_values: tuple[Any, Any],
82
+ ) -> None:
83
+ """
84
+ Validate the given value doesn't change after encoding, sending to engine (using output_type), receiving back and decoding (using input_type).
85
+
86
+ `other_decoded_values` is a tuple of (value, type) pairs.
87
+ If provided, also validate the value can be decoded to the other types.
88
+ """
89
+ from cocoindex import _engine # type: ignore
90
+
91
+ encoded_value = encode_engine_value(value)
92
+ value_type = value_type or type(value)
93
+ encoded_output_type = encode_enriched_type(value_type)["type"]
94
+ value_from_engine = _engine.testutil.seder_roundtrip(
95
+ encoded_value, encoded_output_type
96
+ )
97
+ decoded_value = build_engine_value_decoder(value_type, value_type)(
98
+ value_from_engine
99
+ )
100
+ np.testing.assert_array_equal(decoded_value, value)
101
+
102
+ if other_decoded_values is not None:
103
+ for other_value, other_type in other_decoded_values:
104
+ other_decoded_value = build_engine_value_decoder(other_type, other_type)(
105
+ value_from_engine
106
+ )
107
+ np.testing.assert_array_equal(other_decoded_value, other_value)
108
+
109
+
110
+ def test_encode_engine_value_basic_types() -> None:
66
111
  assert encode_engine_value(123) == 123
67
112
  assert encode_engine_value(3.14) == 3.14
68
113
  assert encode_engine_value("hello") == "hello"
69
114
  assert encode_engine_value(True) is True
70
115
 
71
116
 
72
- def test_encode_engine_value_uuid():
117
+ def test_encode_engine_value_uuid() -> None:
73
118
  u = uuid.uuid4()
74
119
  assert encode_engine_value(u) == u.bytes
75
120
 
76
121
 
77
- def test_encode_engine_value_date_time_types():
122
+ def test_encode_engine_value_date_time_types() -> None:
78
123
  d = datetime.date(2024, 1, 1)
79
124
  assert encode_engine_value(d) == d
80
125
  t = datetime.time(12, 30)
@@ -83,7 +128,7 @@ def test_encode_engine_value_date_time_types():
83
128
  assert encode_engine_value(dt) == dt
84
129
 
85
130
 
86
- def test_encode_engine_value_struct():
131
+ def test_encode_engine_value_struct() -> None:
87
132
  order = Order(order_id="O123", name="mixed nuts", price=25.0)
88
133
  assert encode_engine_value(order) == ["O123", "mixed nuts", 25.0, "default_extra"]
89
134
 
@@ -96,7 +141,7 @@ def test_encode_engine_value_struct():
96
141
  ]
97
142
 
98
143
 
99
- def test_encode_engine_value_list_of_structs():
144
+ def test_encode_engine_value_list_of_structs() -> None:
100
145
  orders = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]
101
146
  assert encode_engine_value(orders) == [
102
147
  ["O1", "item1", 10.0, "default_extra"],
@@ -113,12 +158,12 @@ def test_encode_engine_value_list_of_structs():
113
158
  ]
114
159
 
115
160
 
116
- def test_encode_engine_value_struct_with_list():
161
+ def test_encode_engine_value_struct_with_list() -> None:
117
162
  basket = Basket(items=["apple", "banana"])
118
163
  assert encode_engine_value(basket) == [["apple", "banana"]]
119
164
 
120
165
 
121
- def test_encode_engine_value_nested_struct():
166
+ def test_encode_engine_value_nested_struct() -> None:
122
167
  customer = Customer(name="Alice", order=Order("O1", "item1", 10.0))
123
168
  assert encode_engine_value(customer) == [
124
169
  "Alice",
@@ -136,12 +181,12 @@ def test_encode_engine_value_nested_struct():
136
181
  ]
137
182
 
138
183
 
139
- def test_encode_engine_value_empty_list():
184
+ def test_encode_engine_value_empty_list() -> None:
140
185
  assert encode_engine_value([]) == []
141
186
  assert encode_engine_value([[]]) == [[]]
142
187
 
143
188
 
144
- def test_encode_engine_value_tuple():
189
+ def test_encode_engine_value_tuple() -> None:
145
190
  assert encode_engine_value(()) == []
146
191
  assert encode_engine_value((1, 2, 3)) == [1, 2, 3]
147
192
  assert encode_engine_value(((1, 2), (3, 4))) == [[1, 2], [3, 4]]
@@ -149,20 +194,23 @@ def test_encode_engine_value_tuple():
149
194
  assert encode_engine_value(((),)) == [[]]
150
195
 
151
196
 
152
- def test_encode_engine_value_none():
197
+ def test_encode_engine_value_none() -> None:
153
198
  assert encode_engine_value(None) is None
154
199
 
155
200
 
156
- def test_make_engine_value_decoder_basic_types():
157
- for engine_type_in_py, value in [
158
- (int, 42),
159
- (float, 3.14),
160
- (str, "hello"),
161
- (bool, True),
162
- # (type(None), None), # Removed unsupported NoneType
163
- ]:
164
- decoder = build_engine_value_decoder(engine_type_in_py)
165
- assert decoder(value) == value
201
+ def test_roundtrip_basic_types() -> None:
202
+ validate_full_roundtrip(42, int)
203
+ validate_full_roundtrip(3.25, float, (3.25, Float64))
204
+ validate_full_roundtrip(3.25, Float64, (3.25, float))
205
+ validate_full_roundtrip(3.25, Float32)
206
+ validate_full_roundtrip("hello", str)
207
+ validate_full_roundtrip(True, bool)
208
+ validate_full_roundtrip(False, bool)
209
+ validate_full_roundtrip(datetime.date(2025, 1, 1), datetime.date)
210
+ validate_full_roundtrip(datetime.datetime.now(), cocoindex.LocalDateTime)
211
+ validate_full_roundtrip(
212
+ datetime.datetime.now(datetime.UTC), cocoindex.OffsetDateTime
213
+ )
166
214
 
167
215
 
168
216
  @pytest.mark.parametrize(
@@ -280,18 +328,18 @@ def test_make_engine_value_decoder_basic_types():
280
328
  ),
281
329
  ],
282
330
  )
283
- def test_struct_decoder_cases(data_type, engine_val, expected):
331
+ def test_struct_decoder_cases(data_type: Any, engine_val: Any, expected: Any) -> None:
284
332
  decoder = build_engine_value_decoder(data_type)
285
333
  assert decoder(engine_val) == expected
286
334
 
287
335
 
288
- def test_make_engine_value_decoder_collections():
336
+ def test_make_engine_value_decoder_list_of_struct() -> None:
289
337
  # List of structs (dataclass)
290
- decoder = build_engine_value_decoder(list[Order])
291
338
  engine_val = [
292
339
  ["O1", "item1", 10.0, "default_extra"],
293
340
  ["O2", "item2", 20.0, "default_extra"],
294
341
  ]
342
+ decoder = build_engine_value_decoder(list[Order])
295
343
  assert decoder(engine_val) == [
296
344
  Order("O1", "item1", 10.0, "default_extra"),
297
345
  Order("O2", "item2", 20.0, "default_extra"),
@@ -304,13 +352,15 @@ def test_make_engine_value_decoder_collections():
304
352
  OrderNamedTuple("O2", "item2", 20.0, "default_extra"),
305
353
  ]
306
354
 
355
+
356
+ def test_make_engine_value_decoder_struct_of_list() -> None:
307
357
  # Struct with list field
308
- decoder = build_engine_value_decoder(Customer)
309
358
  engine_val = [
310
359
  "Alice",
311
360
  ["O1", "item1", 10.0, "default_extra"],
312
361
  [["vip"], ["premium"]],
313
362
  ]
363
+ decoder = build_engine_value_decoder(Customer)
314
364
  assert decoder(engine_val) == Customer(
315
365
  "Alice",
316
366
  Order("O1", "item1", 10.0, "default_extra"),
@@ -325,8 +375,9 @@ def test_make_engine_value_decoder_collections():
325
375
  [Tag("vip"), Tag("premium")],
326
376
  )
327
377
 
378
+
379
+ def test_make_engine_value_decoder_struct_of_struct() -> None:
328
380
  # Struct with struct field
329
- decoder = build_engine_value_decoder(NestedStruct)
330
381
  engine_val = [
331
382
  ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]]],
332
383
  [
@@ -335,6 +386,7 @@ def test_make_engine_value_decoder_collections():
335
386
  ],
336
387
  2,
337
388
  ]
389
+ decoder = build_engine_value_decoder(NestedStruct)
338
390
  assert decoder(engine_val) == NestedStruct(
339
391
  Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip")]),
340
392
  [
@@ -345,11 +397,13 @@ def test_make_engine_value_decoder_collections():
345
397
  )
346
398
 
347
399
 
348
- def make_engine_order(fields):
400
+ def make_engine_order(fields: list[tuple[str, type]]) -> type:
349
401
  return make_dataclass("EngineOrder", fields)
350
402
 
351
403
 
352
- def make_python_order(fields, defaults=None):
404
+ def make_python_order(
405
+ fields: list[tuple[str, type]], defaults: dict[str, Any] | None = None
406
+ ) -> type:
353
407
  if defaults is None:
354
408
  defaults = {}
355
409
  # Move all fields with defaults to the end (Python dataclass requirement)
@@ -423,8 +477,12 @@ def make_python_order(fields, defaults=None):
423
477
  ],
424
478
  )
425
479
  def test_field_position_cases(
426
- engine_fields, python_fields, python_defaults, engine_val, expected_python_val
427
- ):
480
+ engine_fields: list[tuple[str, type]],
481
+ python_fields: list[tuple[str, type]],
482
+ python_defaults: dict[str, Any],
483
+ engine_val: list[Any],
484
+ expected_python_val: tuple[Any, ...],
485
+ ) -> None:
428
486
  EngineOrder = make_engine_order(engine_fields)
429
487
  PythonOrder = make_python_order(python_fields, python_defaults)
430
488
  decoder = build_engine_value_decoder(EngineOrder, PythonOrder)
@@ -434,57 +492,33 @@ def test_field_position_cases(
434
492
  assert decoder(engine_val) == PythonOrder(**expected_dict)
435
493
 
436
494
 
437
- def test_roundtrip_ltable():
495
+ def test_roundtrip_ltable() -> None:
438
496
  t = list[Order]
439
497
  value = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]
440
- encoded = encode_engine_value(value)
441
- assert encoded == [
442
- ["O1", "item1", 10.0, "default_extra"],
443
- ["O2", "item2", 20.0, "default_extra"],
444
- ]
445
- decoded = build_engine_value_decoder(t)(encoded)
446
- assert decoded == value
498
+ validate_full_roundtrip(value, t)
447
499
 
448
500
  t_nt = list[OrderNamedTuple]
449
501
  value_nt = [
450
502
  OrderNamedTuple("O1", "item1", 10.0),
451
503
  OrderNamedTuple("O2", "item2", 20.0),
452
504
  ]
453
- encoded = encode_engine_value(value_nt)
454
- assert encoded == [
455
- ["O1", "item1", 10.0, "default_extra"],
456
- ["O2", "item2", 20.0, "default_extra"],
457
- ]
458
- decoded = build_engine_value_decoder(t_nt)(encoded)
459
- assert decoded == value_nt
505
+ validate_full_roundtrip(value_nt, t_nt)
460
506
 
461
507
 
462
- def test_roundtrip_ktable_str_key():
508
+ def test_roundtrip_ktable_str_key() -> None:
463
509
  t = dict[str, Order]
464
510
  value = {"K1": Order("O1", "item1", 10.0), "K2": Order("O2", "item2", 20.0)}
465
- encoded = encode_engine_value(value)
466
- assert encoded == [
467
- ["K1", "O1", "item1", 10.0, "default_extra"],
468
- ["K2", "O2", "item2", 20.0, "default_extra"],
469
- ]
470
- decoded = build_engine_value_decoder(t)(encoded)
471
- assert decoded == value
511
+ validate_full_roundtrip(value, t)
472
512
 
473
513
  t_nt = dict[str, OrderNamedTuple]
474
514
  value_nt = {
475
515
  "K1": OrderNamedTuple("O1", "item1", 10.0),
476
516
  "K2": OrderNamedTuple("O2", "item2", 20.0),
477
517
  }
478
- encoded = encode_engine_value(value_nt)
479
- assert encoded == [
480
- ["K1", "O1", "item1", 10.0, "default_extra"],
481
- ["K2", "O2", "item2", 20.0, "default_extra"],
482
- ]
483
- decoded = build_engine_value_decoder(t_nt)(encoded)
484
- assert decoded == value_nt
518
+ validate_full_roundtrip(value_nt, t_nt)
485
519
 
486
520
 
487
- def test_roundtrip_ktable_struct_key():
521
+ def test_roundtrip_ktable_struct_key() -> None:
488
522
  @dataclass(frozen=True)
489
523
  class OrderKey:
490
524
  shop_id: str
@@ -495,37 +529,25 @@ def test_roundtrip_ktable_struct_key():
495
529
  OrderKey("A", 3): Order("O1", "item1", 10.0),
496
530
  OrderKey("B", 4): Order("O2", "item2", 20.0),
497
531
  }
498
- encoded = encode_engine_value(value)
499
- assert encoded == [
500
- [["A", 3], "O1", "item1", 10.0, "default_extra"],
501
- [["B", 4], "O2", "item2", 20.0, "default_extra"],
502
- ]
503
- decoded = build_engine_value_decoder(t)(encoded)
504
- assert decoded == value
532
+ validate_full_roundtrip(value, t)
505
533
 
506
534
  t_nt = dict[OrderKey, OrderNamedTuple]
507
535
  value_nt = {
508
536
  OrderKey("A", 3): OrderNamedTuple("O1", "item1", 10.0),
509
537
  OrderKey("B", 4): OrderNamedTuple("O2", "item2", 20.0),
510
538
  }
511
- encoded = encode_engine_value(value_nt)
512
- assert encoded == [
513
- [["A", 3], "O1", "item1", 10.0, "default_extra"],
514
- [["B", 4], "O2", "item2", 20.0, "default_extra"],
515
- ]
516
- decoded = build_engine_value_decoder(t_nt)(encoded)
517
- assert decoded == value_nt
539
+ validate_full_roundtrip(value_nt, t_nt)
518
540
 
519
541
 
520
- IntVectorType = cocoindex.Vector[int, Literal[5]]
542
+ IntVectorType = cocoindex.Vector[np.int64, Literal[5]]
521
543
 
522
544
 
523
545
  def test_vector_as_vector() -> None:
524
- value: IntVectorType = [1, 2, 3, 4, 5]
546
+ value = np.array([1, 2, 3, 4, 5], dtype=np.int64)
525
547
  encoded = encode_engine_value(value)
526
- assert encoded == [1, 2, 3, 4, 5]
548
+ assert np.array_equal(encoded, value)
527
549
  decoded = build_engine_value_decoder(IntVectorType)(encoded)
528
- assert decoded == value
550
+ assert np.array_equal(decoded, value)
529
551
 
530
552
 
531
553
  ListIntType = list[int]
@@ -536,4 +558,257 @@ def test_vector_as_list() -> None:
536
558
  encoded = encode_engine_value(value)
537
559
  assert encoded == [1, 2, 3, 4, 5]
538
560
  decoded = build_engine_value_decoder(ListIntType)(encoded)
539
- assert decoded == value
561
+ assert np.array_equal(decoded, value)
562
+
563
+
564
+ Float64VectorTypeNoDim = Vector[np.float64]
565
+ Float32VectorType = Vector[np.float32, Literal[3]]
566
+ Float64VectorType = Vector[np.float64, Literal[3]]
567
+ Int64VectorType = Vector[np.int64, Literal[3]]
568
+ Int32VectorType = Vector[np.int32, Literal[3]]
569
+ UInt8VectorType = Vector[np.uint8, Literal[3]]
570
+ UInt16VectorType = Vector[np.uint16, Literal[3]]
571
+ UInt32VectorType = Vector[np.uint32, Literal[3]]
572
+ UInt64VectorType = Vector[np.uint64, Literal[3]]
573
+ StrVectorType = Vector[str]
574
+ NDArrayFloat32Type = NDArray[np.float32]
575
+ NDArrayFloat64Type = NDArray[np.float64]
576
+ NDArrayInt64Type = NDArray[np.int64]
577
+
578
+
579
+ def test_encode_engine_value_ndarray() -> None:
580
+ """Test encoding NDArray vectors to lists for the Rust engine."""
581
+ vec_f32: Float32VectorType = np.array([1.0, 2.0, 3.0], dtype=np.float32)
582
+ assert np.array_equal(encode_engine_value(vec_f32), [1.0, 2.0, 3.0])
583
+ vec_f64: Float64VectorType = np.array([1.0, 2.0, 3.0], dtype=np.float64)
584
+ assert np.array_equal(encode_engine_value(vec_f64), [1.0, 2.0, 3.0])
585
+ vec_i64: Int64VectorType = np.array([1, 2, 3], dtype=np.int64)
586
+ assert np.array_equal(encode_engine_value(vec_i64), [1, 2, 3])
587
+ vec_nd_f32: NDArrayFloat32Type = np.array([1.0, 2.0, 3.0], dtype=np.float32)
588
+ assert np.array_equal(encode_engine_value(vec_nd_f32), [1.0, 2.0, 3.0])
589
+
590
+
591
+ def test_make_engine_value_decoder_ndarray() -> None:
592
+ """Test decoding engine lists to NDArray vectors."""
593
+ decoder_f32 = build_engine_value_decoder(Float32VectorType)
594
+ result_f32 = decoder_f32([1.0, 2.0, 3.0])
595
+ assert isinstance(result_f32, np.ndarray)
596
+ assert result_f32.dtype == np.float32
597
+ assert np.array_equal(result_f32, np.array([1.0, 2.0, 3.0], dtype=np.float32))
598
+ decoder_f64 = build_engine_value_decoder(Float64VectorType)
599
+ result_f64 = decoder_f64([1.0, 2.0, 3.0])
600
+ assert isinstance(result_f64, np.ndarray)
601
+ assert result_f64.dtype == np.float64
602
+ assert np.array_equal(result_f64, np.array([1.0, 2.0, 3.0], dtype=np.float64))
603
+ decoder_i64 = build_engine_value_decoder(Int64VectorType)
604
+ result_i64 = decoder_i64([1, 2, 3])
605
+ assert isinstance(result_i64, np.ndarray)
606
+ assert result_i64.dtype == np.int64
607
+ assert np.array_equal(result_i64, np.array([1, 2, 3], dtype=np.int64))
608
+ decoder_nd_f32 = build_engine_value_decoder(NDArrayFloat32Type)
609
+ result_nd_f32 = decoder_nd_f32([1.0, 2.0, 3.0])
610
+ assert isinstance(result_nd_f32, np.ndarray)
611
+ assert result_nd_f32.dtype == np.float32
612
+ assert np.array_equal(result_nd_f32, np.array([1.0, 2.0, 3.0], dtype=np.float32))
613
+
614
+
615
+ def test_roundtrip_ndarray_vector() -> None:
616
+ """Test roundtrip encoding and decoding of NDArray vectors."""
617
+ value_f32 = np.array([1.0, 2.0, 3.0], dtype=np.float32)
618
+ encoded_f32 = encode_engine_value(value_f32)
619
+ np.array_equal(encoded_f32, [1.0, 2.0, 3.0])
620
+ decoded_f32 = build_engine_value_decoder(Float32VectorType)(encoded_f32)
621
+ assert isinstance(decoded_f32, np.ndarray)
622
+ assert decoded_f32.dtype == np.float32
623
+ assert np.array_equal(decoded_f32, value_f32)
624
+ value_i64 = np.array([1, 2, 3], dtype=np.int64)
625
+ encoded_i64 = encode_engine_value(value_i64)
626
+ assert np.array_equal(encoded_i64, [1, 2, 3])
627
+ decoded_i64 = build_engine_value_decoder(Int64VectorType)(encoded_i64)
628
+ assert isinstance(decoded_i64, np.ndarray)
629
+ assert decoded_i64.dtype == np.int64
630
+ assert np.array_equal(decoded_i64, value_i64)
631
+ value_nd_f64: NDArrayFloat64Type = np.array([1.0, 2.0, 3.0], dtype=np.float64)
632
+ encoded_nd_f64 = encode_engine_value(value_nd_f64)
633
+ assert np.array_equal(encoded_nd_f64, [1.0, 2.0, 3.0])
634
+ decoded_nd_f64 = build_engine_value_decoder(NDArrayFloat64Type)(encoded_nd_f64)
635
+ assert isinstance(decoded_nd_f64, np.ndarray)
636
+ assert decoded_nd_f64.dtype == np.float64
637
+ assert np.array_equal(decoded_nd_f64, value_nd_f64)
638
+
639
+
640
+ def test_ndarray_dimension_mismatch() -> None:
641
+ """Test dimension enforcement for Vector with specified dimension."""
642
+ value = np.array([1.0, 2.0], dtype=np.float32)
643
+ encoded = encode_engine_value(value)
644
+ assert np.array_equal(encoded, [1.0, 2.0])
645
+ with pytest.raises(ValueError, match="Vector dimension mismatch"):
646
+ build_engine_value_decoder(Float32VectorType)(encoded)
647
+
648
+
649
+ def test_list_vector_backward_compatibility() -> None:
650
+ """Test that list-based vectors still work for backward compatibility."""
651
+ value = [1, 2, 3, 4, 5]
652
+ encoded = encode_engine_value(value)
653
+ assert encoded == [1, 2, 3, 4, 5]
654
+ decoded = build_engine_value_decoder(IntVectorType)(encoded)
655
+ assert isinstance(decoded, np.ndarray)
656
+ assert decoded.dtype == np.int64
657
+ assert np.array_equal(decoded, np.array([1, 2, 3, 4, 5], dtype=np.int64))
658
+ value_list: ListIntType = [1, 2, 3, 4, 5]
659
+ encoded = encode_engine_value(value_list)
660
+ assert np.array_equal(encoded, [1, 2, 3, 4, 5])
661
+ decoded = build_engine_value_decoder(ListIntType)(encoded)
662
+ assert np.array_equal(decoded, [1, 2, 3, 4, 5])
663
+
664
+
665
+ def test_encode_complex_structure_with_ndarray() -> None:
666
+ """Test encoding a complex structure that includes an NDArray."""
667
+
668
+ @dataclass
669
+ class MyStructWithNDArray:
670
+ name: str
671
+ data: NDArray[np.float32]
672
+ value: int
673
+
674
+ original = MyStructWithNDArray(
675
+ name="test_np", data=np.array([1.0, 0.5], dtype=np.float32), value=100
676
+ )
677
+ encoded = encode_engine_value(original)
678
+
679
+ assert encoded[0] == original.name
680
+ assert np.array_equal(encoded[1], original.data)
681
+ assert encoded[2] == original.value
682
+
683
+
684
+ def test_decode_nullable_ndarray_none_or_value_input() -> None:
685
+ """Test decoding a nullable NDArray with None or value inputs."""
686
+ src_type_dict = {
687
+ "kind": "Vector",
688
+ "element_type": {"kind": "Float32"},
689
+ "dimension": None,
690
+ }
691
+ dst_annotation = NDArrayFloat32Type | None
692
+ decoder = make_engine_value_decoder([], src_type_dict, dst_annotation)
693
+
694
+ none_engine_value = None
695
+ decoded_array = decoder(none_engine_value)
696
+ assert decoded_array is None
697
+
698
+ engine_value = [1.0, 2.0, 3.0]
699
+ decoded_array = decoder(engine_value)
700
+
701
+ assert isinstance(decoded_array, np.ndarray)
702
+ assert decoded_array.dtype == np.float32
703
+ np.testing.assert_array_equal(
704
+ decoded_array, np.array([1.0, 2.0, 3.0], dtype=np.float32)
705
+ )
706
+
707
+
708
+ def test_decode_vector_string() -> None:
709
+ """Test decoding a vector of strings works for Python native list type."""
710
+ src_type_dict = {
711
+ "kind": "Vector",
712
+ "element_type": {"kind": "Str"},
713
+ "dimension": None,
714
+ }
715
+ decoder = make_engine_value_decoder([], src_type_dict, Vector[str])
716
+ assert decoder(["hello", "world"]) == ["hello", "world"]
717
+
718
+
719
+ def test_decode_error_non_nullable_or_non_list_vector() -> None:
720
+ """Test decoding errors for non-nullable vectors or non-list inputs."""
721
+ src_type_dict = {
722
+ "kind": "Vector",
723
+ "element_type": {"kind": "Float32"},
724
+ "dimension": None,
725
+ }
726
+ decoder = make_engine_value_decoder([], src_type_dict, NDArrayFloat32Type)
727
+ with pytest.raises(ValueError, match="Received null for non-nullable vector"):
728
+ decoder(None)
729
+ with pytest.raises(TypeError, match="Expected NDArray or list for vector"):
730
+ decoder("not a list")
731
+
732
+
733
+ def test_dump_vector_type_annotation_with_dim() -> None:
734
+ """Test dumping a vector type annotation with a specified dimension."""
735
+ expected_dump = {
736
+ "type": {
737
+ "kind": "Vector",
738
+ "element_type": {"kind": "Float32"},
739
+ "dimension": 3,
740
+ }
741
+ }
742
+ assert dump_engine_object(Float32VectorType) == expected_dump
743
+
744
+
745
+ def test_dump_vector_type_annotation_no_dim() -> None:
746
+ """Test dumping a vector type annotation with no dimension."""
747
+ expected_dump_no_dim = {
748
+ "type": {
749
+ "kind": "Vector",
750
+ "element_type": {"kind": "Float64"},
751
+ "dimension": None,
752
+ }
753
+ }
754
+ assert dump_engine_object(Float64VectorTypeNoDim) == expected_dump_no_dim
755
+
756
+
757
+ def test_full_roundtrip_vector_numeric_types() -> None:
758
+ """Test full roundtrip for numeric vector types using NDArray."""
759
+ value_f32: Vector[np.float32, Literal[3]] = np.array(
760
+ [1.0, 2.0, 3.0], dtype=np.float32
761
+ )
762
+ validate_full_roundtrip(value_f32, Vector[np.float32, Literal[3]])
763
+ value_f64: Vector[np.float64, Literal[3]] = np.array(
764
+ [1.0, 2.0, 3.0], dtype=np.float64
765
+ )
766
+ validate_full_roundtrip(value_f64, Vector[np.float64, Literal[3]])
767
+ value_i64: Vector[np.int64, Literal[3]] = np.array([1, 2, 3], dtype=np.int64)
768
+ validate_full_roundtrip(value_i64, Vector[np.int64, Literal[3]])
769
+ value_i32: Vector[np.int32, Literal[3]] = np.array([1, 2, 3], dtype=np.int32)
770
+ with pytest.raises(ValueError, match="type unsupported yet"):
771
+ validate_full_roundtrip(value_i32, Vector[np.int32, Literal[3]])
772
+ value_u8: Vector[np.uint8, Literal[3]] = np.array([1, 2, 3], dtype=np.uint8)
773
+ with pytest.raises(ValueError, match="type unsupported yet"):
774
+ validate_full_roundtrip(value_u8, Vector[np.uint8, Literal[3]])
775
+ value_u16: Vector[np.uint16, Literal[3]] = np.array([1, 2, 3], dtype=np.uint16)
776
+ with pytest.raises(ValueError, match="type unsupported yet"):
777
+ validate_full_roundtrip(value_u16, Vector[np.uint16, Literal[3]])
778
+ value_u32: Vector[np.uint32, Literal[3]] = np.array([1, 2, 3], dtype=np.uint32)
779
+ with pytest.raises(ValueError, match="type unsupported yet"):
780
+ validate_full_roundtrip(value_u32, Vector[np.uint32, Literal[3]])
781
+ value_u64: Vector[np.uint64, Literal[3]] = np.array([1, 2, 3], dtype=np.uint64)
782
+ with pytest.raises(ValueError, match="type unsupported yet"):
783
+ validate_full_roundtrip(value_u64, Vector[np.uint64, Literal[3]])
784
+
785
+
786
+ def test_roundtrip_vector_no_dimension() -> None:
787
+ """Test full roundtrip for vector types without dimension annotation."""
788
+ value_f64: Vector[np.float64] = np.array([1.0, 2.0, 3.0], dtype=np.float64)
789
+ validate_full_roundtrip(value_f64, Vector[np.float64])
790
+
791
+
792
+ def test_roundtrip_string_vector() -> None:
793
+ """Test full roundtrip for string vector using list."""
794
+ value_str: Vector[str] = ["hello", "world"]
795
+ validate_full_roundtrip(value_str, Vector[str])
796
+
797
+
798
+ def test_roundtrip_empty_vector() -> None:
799
+ """Test full roundtrip for empty numeric vector."""
800
+ value_empty: Vector[np.float32] = np.array([], dtype=np.float32)
801
+ validate_full_roundtrip(value_empty, Vector[np.float32])
802
+
803
+
804
+ def test_roundtrip_dimension_mismatch() -> None:
805
+ """Test that dimension mismatch raises an error during roundtrip."""
806
+ value_f32: Vector[np.float32, Literal[3]] = np.array([1.0, 2.0], dtype=np.float32)
807
+ with pytest.raises(ValueError, match="Vector dimension mismatch"):
808
+ validate_full_roundtrip(value_f32, Vector[np.float32, Literal[3]])
809
+
810
+
811
+ def test_roundtrip_list_backward_compatibility() -> None:
812
+ """Test full roundtrip for list-based vectors for backward compatibility."""
813
+ value_list: list[int] = [1, 2, 3]
814
+ validate_full_roundtrip(value_list, list[int])