cocoindex 0.1.49__cp312-cp312-win_amd64.whl → 0.1.51__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cocoindex/__init__.py +52 -1
- cocoindex/_engine.cp312-win_amd64.pyd +0 -0
- cocoindex/cli.py +22 -4
- cocoindex/convert.py +41 -1
- cocoindex/functions.py +6 -4
- cocoindex/lib.py +1 -2
- cocoindex/setting.py +10 -6
- cocoindex/tests/test_convert.py +359 -84
- cocoindex/tests/test_optional_database.py +249 -0
- cocoindex/tests/test_typing.py +505 -0
- cocoindex/typing.py +92 -17
- {cocoindex-0.1.49.dist-info → cocoindex-0.1.51.dist-info}/METADATA +1 -1
- cocoindex-0.1.51.dist-info/RECORD +28 -0
- {cocoindex-0.1.49.dist-info → cocoindex-0.1.51.dist-info}/WHEEL +1 -1
- cocoindex/query.py +0 -115
- cocoindex-0.1.49.dist-info/RECORD +0 -27
- {cocoindex-0.1.49.dist-info → cocoindex-0.1.51.dist-info}/entry_points.txt +0 -0
- {cocoindex-0.1.49.dist-info → cocoindex-0.1.51.dist-info}/licenses/LICENSE +0 -0
cocoindex/tests/test_convert.py
CHANGED
@@ -1,11 +1,22 @@
|
|
1
1
|
import uuid
|
2
2
|
import datetime
|
3
3
|
from dataclasses import dataclass, make_dataclass
|
4
|
-
from typing import NamedTuple, Literal
|
4
|
+
from typing import NamedTuple, Literal, Any, Callable, Union
|
5
5
|
import pytest
|
6
6
|
import cocoindex
|
7
|
-
from cocoindex.typing import
|
8
|
-
|
7
|
+
from cocoindex.typing import (
|
8
|
+
encode_enriched_type,
|
9
|
+
Vector,
|
10
|
+
Float32,
|
11
|
+
Float64,
|
12
|
+
)
|
13
|
+
from cocoindex.convert import (
|
14
|
+
encode_engine_value,
|
15
|
+
make_engine_value_decoder,
|
16
|
+
dump_engine_object,
|
17
|
+
)
|
18
|
+
import numpy as np
|
19
|
+
from numpy.typing import NDArray
|
9
20
|
|
10
21
|
|
11
22
|
@dataclass
|
@@ -23,7 +34,7 @@ class Tag:
|
|
23
34
|
|
24
35
|
@dataclass
|
25
36
|
class Basket:
|
26
|
-
items: list
|
37
|
+
items: list[str]
|
27
38
|
|
28
39
|
|
29
40
|
@dataclass
|
@@ -53,7 +64,9 @@ class CustomerNamedTuple(NamedTuple):
|
|
53
64
|
tags: list[Tag] | None = None
|
54
65
|
|
55
66
|
|
56
|
-
def build_engine_value_decoder(
|
67
|
+
def build_engine_value_decoder(
|
68
|
+
engine_type_in_py: Any, python_type: Any | None = None
|
69
|
+
) -> Callable[[Any], Any]:
|
57
70
|
"""
|
58
71
|
Helper to build a converter for the given engine-side type (as represented in Python).
|
59
72
|
If python_type is not specified, uses engine_type_in_py as the target.
|
@@ -62,19 +75,51 @@ def build_engine_value_decoder(engine_type_in_py, python_type=None):
|
|
62
75
|
return make_engine_value_decoder([], engine_type, python_type or engine_type_in_py)
|
63
76
|
|
64
77
|
|
65
|
-
def
|
78
|
+
def validate_full_roundtrip(
|
79
|
+
value: Any,
|
80
|
+
value_type: Any = None,
|
81
|
+
*other_decoded_values: tuple[Any, Any],
|
82
|
+
) -> None:
|
83
|
+
"""
|
84
|
+
Validate the given value doesn't change after encoding, sending to engine (using output_type), receiving back and decoding (using input_type).
|
85
|
+
|
86
|
+
`other_decoded_values` is a tuple of (value, type) pairs.
|
87
|
+
If provided, also validate the value can be decoded to the other types.
|
88
|
+
"""
|
89
|
+
from cocoindex import _engine # type: ignore
|
90
|
+
|
91
|
+
encoded_value = encode_engine_value(value)
|
92
|
+
value_type = value_type or type(value)
|
93
|
+
encoded_output_type = encode_enriched_type(value_type)["type"]
|
94
|
+
value_from_engine = _engine.testutil.seder_roundtrip(
|
95
|
+
encoded_value, encoded_output_type
|
96
|
+
)
|
97
|
+
decoded_value = build_engine_value_decoder(value_type, value_type)(
|
98
|
+
value_from_engine
|
99
|
+
)
|
100
|
+
np.testing.assert_array_equal(decoded_value, value)
|
101
|
+
|
102
|
+
if other_decoded_values is not None:
|
103
|
+
for other_value, other_type in other_decoded_values:
|
104
|
+
other_decoded_value = build_engine_value_decoder(other_type, other_type)(
|
105
|
+
value_from_engine
|
106
|
+
)
|
107
|
+
np.testing.assert_array_equal(other_decoded_value, other_value)
|
108
|
+
|
109
|
+
|
110
|
+
def test_encode_engine_value_basic_types() -> None:
|
66
111
|
assert encode_engine_value(123) == 123
|
67
112
|
assert encode_engine_value(3.14) == 3.14
|
68
113
|
assert encode_engine_value("hello") == "hello"
|
69
114
|
assert encode_engine_value(True) is True
|
70
115
|
|
71
116
|
|
72
|
-
def test_encode_engine_value_uuid():
|
117
|
+
def test_encode_engine_value_uuid() -> None:
|
73
118
|
u = uuid.uuid4()
|
74
119
|
assert encode_engine_value(u) == u.bytes
|
75
120
|
|
76
121
|
|
77
|
-
def test_encode_engine_value_date_time_types():
|
122
|
+
def test_encode_engine_value_date_time_types() -> None:
|
78
123
|
d = datetime.date(2024, 1, 1)
|
79
124
|
assert encode_engine_value(d) == d
|
80
125
|
t = datetime.time(12, 30)
|
@@ -83,7 +128,7 @@ def test_encode_engine_value_date_time_types():
|
|
83
128
|
assert encode_engine_value(dt) == dt
|
84
129
|
|
85
130
|
|
86
|
-
def test_encode_engine_value_struct():
|
131
|
+
def test_encode_engine_value_struct() -> None:
|
87
132
|
order = Order(order_id="O123", name="mixed nuts", price=25.0)
|
88
133
|
assert encode_engine_value(order) == ["O123", "mixed nuts", 25.0, "default_extra"]
|
89
134
|
|
@@ -96,7 +141,7 @@ def test_encode_engine_value_struct():
|
|
96
141
|
]
|
97
142
|
|
98
143
|
|
99
|
-
def test_encode_engine_value_list_of_structs():
|
144
|
+
def test_encode_engine_value_list_of_structs() -> None:
|
100
145
|
orders = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]
|
101
146
|
assert encode_engine_value(orders) == [
|
102
147
|
["O1", "item1", 10.0, "default_extra"],
|
@@ -113,12 +158,12 @@ def test_encode_engine_value_list_of_structs():
|
|
113
158
|
]
|
114
159
|
|
115
160
|
|
116
|
-
def test_encode_engine_value_struct_with_list():
|
161
|
+
def test_encode_engine_value_struct_with_list() -> None:
|
117
162
|
basket = Basket(items=["apple", "banana"])
|
118
163
|
assert encode_engine_value(basket) == [["apple", "banana"]]
|
119
164
|
|
120
165
|
|
121
|
-
def test_encode_engine_value_nested_struct():
|
166
|
+
def test_encode_engine_value_nested_struct() -> None:
|
122
167
|
customer = Customer(name="Alice", order=Order("O1", "item1", 10.0))
|
123
168
|
assert encode_engine_value(customer) == [
|
124
169
|
"Alice",
|
@@ -136,12 +181,12 @@ def test_encode_engine_value_nested_struct():
|
|
136
181
|
]
|
137
182
|
|
138
183
|
|
139
|
-
def test_encode_engine_value_empty_list():
|
184
|
+
def test_encode_engine_value_empty_list() -> None:
|
140
185
|
assert encode_engine_value([]) == []
|
141
186
|
assert encode_engine_value([[]]) == [[]]
|
142
187
|
|
143
188
|
|
144
|
-
def test_encode_engine_value_tuple():
|
189
|
+
def test_encode_engine_value_tuple() -> None:
|
145
190
|
assert encode_engine_value(()) == []
|
146
191
|
assert encode_engine_value((1, 2, 3)) == [1, 2, 3]
|
147
192
|
assert encode_engine_value(((1, 2), (3, 4))) == [[1, 2], [3, 4]]
|
@@ -149,20 +194,23 @@ def test_encode_engine_value_tuple():
|
|
149
194
|
assert encode_engine_value(((),)) == [[]]
|
150
195
|
|
151
196
|
|
152
|
-
def test_encode_engine_value_none():
|
197
|
+
def test_encode_engine_value_none() -> None:
|
153
198
|
assert encode_engine_value(None) is None
|
154
199
|
|
155
200
|
|
156
|
-
def
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
201
|
+
def test_roundtrip_basic_types() -> None:
|
202
|
+
validate_full_roundtrip(42, int)
|
203
|
+
validate_full_roundtrip(3.25, float, (3.25, Float64))
|
204
|
+
validate_full_roundtrip(3.25, Float64, (3.25, float))
|
205
|
+
validate_full_roundtrip(3.25, Float32)
|
206
|
+
validate_full_roundtrip("hello", str)
|
207
|
+
validate_full_roundtrip(True, bool)
|
208
|
+
validate_full_roundtrip(False, bool)
|
209
|
+
validate_full_roundtrip(datetime.date(2025, 1, 1), datetime.date)
|
210
|
+
validate_full_roundtrip(datetime.datetime.now(), cocoindex.LocalDateTime)
|
211
|
+
validate_full_roundtrip(
|
212
|
+
datetime.datetime.now(datetime.UTC), cocoindex.OffsetDateTime
|
213
|
+
)
|
166
214
|
|
167
215
|
|
168
216
|
@pytest.mark.parametrize(
|
@@ -280,18 +328,18 @@ def test_make_engine_value_decoder_basic_types():
|
|
280
328
|
),
|
281
329
|
],
|
282
330
|
)
|
283
|
-
def test_struct_decoder_cases(data_type, engine_val, expected):
|
331
|
+
def test_struct_decoder_cases(data_type: Any, engine_val: Any, expected: Any) -> None:
|
284
332
|
decoder = build_engine_value_decoder(data_type)
|
285
333
|
assert decoder(engine_val) == expected
|
286
334
|
|
287
335
|
|
288
|
-
def
|
336
|
+
def test_make_engine_value_decoder_list_of_struct() -> None:
|
289
337
|
# List of structs (dataclass)
|
290
|
-
decoder = build_engine_value_decoder(list[Order])
|
291
338
|
engine_val = [
|
292
339
|
["O1", "item1", 10.0, "default_extra"],
|
293
340
|
["O2", "item2", 20.0, "default_extra"],
|
294
341
|
]
|
342
|
+
decoder = build_engine_value_decoder(list[Order])
|
295
343
|
assert decoder(engine_val) == [
|
296
344
|
Order("O1", "item1", 10.0, "default_extra"),
|
297
345
|
Order("O2", "item2", 20.0, "default_extra"),
|
@@ -304,13 +352,15 @@ def test_make_engine_value_decoder_collections():
|
|
304
352
|
OrderNamedTuple("O2", "item2", 20.0, "default_extra"),
|
305
353
|
]
|
306
354
|
|
355
|
+
|
356
|
+
def test_make_engine_value_decoder_struct_of_list() -> None:
|
307
357
|
# Struct with list field
|
308
|
-
decoder = build_engine_value_decoder(Customer)
|
309
358
|
engine_val = [
|
310
359
|
"Alice",
|
311
360
|
["O1", "item1", 10.0, "default_extra"],
|
312
361
|
[["vip"], ["premium"]],
|
313
362
|
]
|
363
|
+
decoder = build_engine_value_decoder(Customer)
|
314
364
|
assert decoder(engine_val) == Customer(
|
315
365
|
"Alice",
|
316
366
|
Order("O1", "item1", 10.0, "default_extra"),
|
@@ -325,8 +375,9 @@ def test_make_engine_value_decoder_collections():
|
|
325
375
|
[Tag("vip"), Tag("premium")],
|
326
376
|
)
|
327
377
|
|
378
|
+
|
379
|
+
def test_make_engine_value_decoder_struct_of_struct() -> None:
|
328
380
|
# Struct with struct field
|
329
|
-
decoder = build_engine_value_decoder(NestedStruct)
|
330
381
|
engine_val = [
|
331
382
|
["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]]],
|
332
383
|
[
|
@@ -335,6 +386,7 @@ def test_make_engine_value_decoder_collections():
|
|
335
386
|
],
|
336
387
|
2,
|
337
388
|
]
|
389
|
+
decoder = build_engine_value_decoder(NestedStruct)
|
338
390
|
assert decoder(engine_val) == NestedStruct(
|
339
391
|
Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip")]),
|
340
392
|
[
|
@@ -345,11 +397,13 @@ def test_make_engine_value_decoder_collections():
|
|
345
397
|
)
|
346
398
|
|
347
399
|
|
348
|
-
def make_engine_order(fields):
|
400
|
+
def make_engine_order(fields: list[tuple[str, type]]) -> type:
|
349
401
|
return make_dataclass("EngineOrder", fields)
|
350
402
|
|
351
403
|
|
352
|
-
def make_python_order(
|
404
|
+
def make_python_order(
|
405
|
+
fields: list[tuple[str, type]], defaults: dict[str, Any] | None = None
|
406
|
+
) -> type:
|
353
407
|
if defaults is None:
|
354
408
|
defaults = {}
|
355
409
|
# Move all fields with defaults to the end (Python dataclass requirement)
|
@@ -423,8 +477,12 @@ def make_python_order(fields, defaults=None):
|
|
423
477
|
],
|
424
478
|
)
|
425
479
|
def test_field_position_cases(
|
426
|
-
engine_fields
|
427
|
-
|
480
|
+
engine_fields: list[tuple[str, type]],
|
481
|
+
python_fields: list[tuple[str, type]],
|
482
|
+
python_defaults: dict[str, Any],
|
483
|
+
engine_val: list[Any],
|
484
|
+
expected_python_val: tuple[Any, ...],
|
485
|
+
) -> None:
|
428
486
|
EngineOrder = make_engine_order(engine_fields)
|
429
487
|
PythonOrder = make_python_order(python_fields, python_defaults)
|
430
488
|
decoder = build_engine_value_decoder(EngineOrder, PythonOrder)
|
@@ -434,57 +492,33 @@ def test_field_position_cases(
|
|
434
492
|
assert decoder(engine_val) == PythonOrder(**expected_dict)
|
435
493
|
|
436
494
|
|
437
|
-
def test_roundtrip_ltable():
|
495
|
+
def test_roundtrip_ltable() -> None:
|
438
496
|
t = list[Order]
|
439
497
|
value = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]
|
440
|
-
|
441
|
-
assert encoded == [
|
442
|
-
["O1", "item1", 10.0, "default_extra"],
|
443
|
-
["O2", "item2", 20.0, "default_extra"],
|
444
|
-
]
|
445
|
-
decoded = build_engine_value_decoder(t)(encoded)
|
446
|
-
assert decoded == value
|
498
|
+
validate_full_roundtrip(value, t)
|
447
499
|
|
448
500
|
t_nt = list[OrderNamedTuple]
|
449
501
|
value_nt = [
|
450
502
|
OrderNamedTuple("O1", "item1", 10.0),
|
451
503
|
OrderNamedTuple("O2", "item2", 20.0),
|
452
504
|
]
|
453
|
-
|
454
|
-
assert encoded == [
|
455
|
-
["O1", "item1", 10.0, "default_extra"],
|
456
|
-
["O2", "item2", 20.0, "default_extra"],
|
457
|
-
]
|
458
|
-
decoded = build_engine_value_decoder(t_nt)(encoded)
|
459
|
-
assert decoded == value_nt
|
505
|
+
validate_full_roundtrip(value_nt, t_nt)
|
460
506
|
|
461
507
|
|
462
|
-
def test_roundtrip_ktable_str_key():
|
508
|
+
def test_roundtrip_ktable_str_key() -> None:
|
463
509
|
t = dict[str, Order]
|
464
510
|
value = {"K1": Order("O1", "item1", 10.0), "K2": Order("O2", "item2", 20.0)}
|
465
|
-
|
466
|
-
assert encoded == [
|
467
|
-
["K1", "O1", "item1", 10.0, "default_extra"],
|
468
|
-
["K2", "O2", "item2", 20.0, "default_extra"],
|
469
|
-
]
|
470
|
-
decoded = build_engine_value_decoder(t)(encoded)
|
471
|
-
assert decoded == value
|
511
|
+
validate_full_roundtrip(value, t)
|
472
512
|
|
473
513
|
t_nt = dict[str, OrderNamedTuple]
|
474
514
|
value_nt = {
|
475
515
|
"K1": OrderNamedTuple("O1", "item1", 10.0),
|
476
516
|
"K2": OrderNamedTuple("O2", "item2", 20.0),
|
477
517
|
}
|
478
|
-
|
479
|
-
assert encoded == [
|
480
|
-
["K1", "O1", "item1", 10.0, "default_extra"],
|
481
|
-
["K2", "O2", "item2", 20.0, "default_extra"],
|
482
|
-
]
|
483
|
-
decoded = build_engine_value_decoder(t_nt)(encoded)
|
484
|
-
assert decoded == value_nt
|
518
|
+
validate_full_roundtrip(value_nt, t_nt)
|
485
519
|
|
486
520
|
|
487
|
-
def test_roundtrip_ktable_struct_key():
|
521
|
+
def test_roundtrip_ktable_struct_key() -> None:
|
488
522
|
@dataclass(frozen=True)
|
489
523
|
class OrderKey:
|
490
524
|
shop_id: str
|
@@ -495,37 +529,25 @@ def test_roundtrip_ktable_struct_key():
|
|
495
529
|
OrderKey("A", 3): Order("O1", "item1", 10.0),
|
496
530
|
OrderKey("B", 4): Order("O2", "item2", 20.0),
|
497
531
|
}
|
498
|
-
|
499
|
-
assert encoded == [
|
500
|
-
[["A", 3], "O1", "item1", 10.0, "default_extra"],
|
501
|
-
[["B", 4], "O2", "item2", 20.0, "default_extra"],
|
502
|
-
]
|
503
|
-
decoded = build_engine_value_decoder(t)(encoded)
|
504
|
-
assert decoded == value
|
532
|
+
validate_full_roundtrip(value, t)
|
505
533
|
|
506
534
|
t_nt = dict[OrderKey, OrderNamedTuple]
|
507
535
|
value_nt = {
|
508
536
|
OrderKey("A", 3): OrderNamedTuple("O1", "item1", 10.0),
|
509
537
|
OrderKey("B", 4): OrderNamedTuple("O2", "item2", 20.0),
|
510
538
|
}
|
511
|
-
|
512
|
-
assert encoded == [
|
513
|
-
[["A", 3], "O1", "item1", 10.0, "default_extra"],
|
514
|
-
[["B", 4], "O2", "item2", 20.0, "default_extra"],
|
515
|
-
]
|
516
|
-
decoded = build_engine_value_decoder(t_nt)(encoded)
|
517
|
-
assert decoded == value_nt
|
539
|
+
validate_full_roundtrip(value_nt, t_nt)
|
518
540
|
|
519
541
|
|
520
|
-
IntVectorType = cocoindex.Vector[
|
542
|
+
IntVectorType = cocoindex.Vector[np.int64, Literal[5]]
|
521
543
|
|
522
544
|
|
523
545
|
def test_vector_as_vector() -> None:
|
524
|
-
value
|
546
|
+
value = np.array([1, 2, 3, 4, 5], dtype=np.int64)
|
525
547
|
encoded = encode_engine_value(value)
|
526
|
-
assert encoded
|
548
|
+
assert np.array_equal(encoded, value)
|
527
549
|
decoded = build_engine_value_decoder(IntVectorType)(encoded)
|
528
|
-
assert decoded
|
550
|
+
assert np.array_equal(decoded, value)
|
529
551
|
|
530
552
|
|
531
553
|
ListIntType = list[int]
|
@@ -536,4 +558,257 @@ def test_vector_as_list() -> None:
|
|
536
558
|
encoded = encode_engine_value(value)
|
537
559
|
assert encoded == [1, 2, 3, 4, 5]
|
538
560
|
decoded = build_engine_value_decoder(ListIntType)(encoded)
|
539
|
-
assert decoded
|
561
|
+
assert np.array_equal(decoded, value)
|
562
|
+
|
563
|
+
|
564
|
+
Float64VectorTypeNoDim = Vector[np.float64]
|
565
|
+
Float32VectorType = Vector[np.float32, Literal[3]]
|
566
|
+
Float64VectorType = Vector[np.float64, Literal[3]]
|
567
|
+
Int64VectorType = Vector[np.int64, Literal[3]]
|
568
|
+
Int32VectorType = Vector[np.int32, Literal[3]]
|
569
|
+
UInt8VectorType = Vector[np.uint8, Literal[3]]
|
570
|
+
UInt16VectorType = Vector[np.uint16, Literal[3]]
|
571
|
+
UInt32VectorType = Vector[np.uint32, Literal[3]]
|
572
|
+
UInt64VectorType = Vector[np.uint64, Literal[3]]
|
573
|
+
StrVectorType = Vector[str]
|
574
|
+
NDArrayFloat32Type = NDArray[np.float32]
|
575
|
+
NDArrayFloat64Type = NDArray[np.float64]
|
576
|
+
NDArrayInt64Type = NDArray[np.int64]
|
577
|
+
|
578
|
+
|
579
|
+
def test_encode_engine_value_ndarray() -> None:
|
580
|
+
"""Test encoding NDArray vectors to lists for the Rust engine."""
|
581
|
+
vec_f32: Float32VectorType = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
582
|
+
assert np.array_equal(encode_engine_value(vec_f32), [1.0, 2.0, 3.0])
|
583
|
+
vec_f64: Float64VectorType = np.array([1.0, 2.0, 3.0], dtype=np.float64)
|
584
|
+
assert np.array_equal(encode_engine_value(vec_f64), [1.0, 2.0, 3.0])
|
585
|
+
vec_i64: Int64VectorType = np.array([1, 2, 3], dtype=np.int64)
|
586
|
+
assert np.array_equal(encode_engine_value(vec_i64), [1, 2, 3])
|
587
|
+
vec_nd_f32: NDArrayFloat32Type = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
588
|
+
assert np.array_equal(encode_engine_value(vec_nd_f32), [1.0, 2.0, 3.0])
|
589
|
+
|
590
|
+
|
591
|
+
def test_make_engine_value_decoder_ndarray() -> None:
|
592
|
+
"""Test decoding engine lists to NDArray vectors."""
|
593
|
+
decoder_f32 = build_engine_value_decoder(Float32VectorType)
|
594
|
+
result_f32 = decoder_f32([1.0, 2.0, 3.0])
|
595
|
+
assert isinstance(result_f32, np.ndarray)
|
596
|
+
assert result_f32.dtype == np.float32
|
597
|
+
assert np.array_equal(result_f32, np.array([1.0, 2.0, 3.0], dtype=np.float32))
|
598
|
+
decoder_f64 = build_engine_value_decoder(Float64VectorType)
|
599
|
+
result_f64 = decoder_f64([1.0, 2.0, 3.0])
|
600
|
+
assert isinstance(result_f64, np.ndarray)
|
601
|
+
assert result_f64.dtype == np.float64
|
602
|
+
assert np.array_equal(result_f64, np.array([1.0, 2.0, 3.0], dtype=np.float64))
|
603
|
+
decoder_i64 = build_engine_value_decoder(Int64VectorType)
|
604
|
+
result_i64 = decoder_i64([1, 2, 3])
|
605
|
+
assert isinstance(result_i64, np.ndarray)
|
606
|
+
assert result_i64.dtype == np.int64
|
607
|
+
assert np.array_equal(result_i64, np.array([1, 2, 3], dtype=np.int64))
|
608
|
+
decoder_nd_f32 = build_engine_value_decoder(NDArrayFloat32Type)
|
609
|
+
result_nd_f32 = decoder_nd_f32([1.0, 2.0, 3.0])
|
610
|
+
assert isinstance(result_nd_f32, np.ndarray)
|
611
|
+
assert result_nd_f32.dtype == np.float32
|
612
|
+
assert np.array_equal(result_nd_f32, np.array([1.0, 2.0, 3.0], dtype=np.float32))
|
613
|
+
|
614
|
+
|
615
|
+
def test_roundtrip_ndarray_vector() -> None:
|
616
|
+
"""Test roundtrip encoding and decoding of NDArray vectors."""
|
617
|
+
value_f32 = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
618
|
+
encoded_f32 = encode_engine_value(value_f32)
|
619
|
+
np.array_equal(encoded_f32, [1.0, 2.0, 3.0])
|
620
|
+
decoded_f32 = build_engine_value_decoder(Float32VectorType)(encoded_f32)
|
621
|
+
assert isinstance(decoded_f32, np.ndarray)
|
622
|
+
assert decoded_f32.dtype == np.float32
|
623
|
+
assert np.array_equal(decoded_f32, value_f32)
|
624
|
+
value_i64 = np.array([1, 2, 3], dtype=np.int64)
|
625
|
+
encoded_i64 = encode_engine_value(value_i64)
|
626
|
+
assert np.array_equal(encoded_i64, [1, 2, 3])
|
627
|
+
decoded_i64 = build_engine_value_decoder(Int64VectorType)(encoded_i64)
|
628
|
+
assert isinstance(decoded_i64, np.ndarray)
|
629
|
+
assert decoded_i64.dtype == np.int64
|
630
|
+
assert np.array_equal(decoded_i64, value_i64)
|
631
|
+
value_nd_f64: NDArrayFloat64Type = np.array([1.0, 2.0, 3.0], dtype=np.float64)
|
632
|
+
encoded_nd_f64 = encode_engine_value(value_nd_f64)
|
633
|
+
assert np.array_equal(encoded_nd_f64, [1.0, 2.0, 3.0])
|
634
|
+
decoded_nd_f64 = build_engine_value_decoder(NDArrayFloat64Type)(encoded_nd_f64)
|
635
|
+
assert isinstance(decoded_nd_f64, np.ndarray)
|
636
|
+
assert decoded_nd_f64.dtype == np.float64
|
637
|
+
assert np.array_equal(decoded_nd_f64, value_nd_f64)
|
638
|
+
|
639
|
+
|
640
|
+
def test_ndarray_dimension_mismatch() -> None:
|
641
|
+
"""Test dimension enforcement for Vector with specified dimension."""
|
642
|
+
value = np.array([1.0, 2.0], dtype=np.float32)
|
643
|
+
encoded = encode_engine_value(value)
|
644
|
+
assert np.array_equal(encoded, [1.0, 2.0])
|
645
|
+
with pytest.raises(ValueError, match="Vector dimension mismatch"):
|
646
|
+
build_engine_value_decoder(Float32VectorType)(encoded)
|
647
|
+
|
648
|
+
|
649
|
+
def test_list_vector_backward_compatibility() -> None:
|
650
|
+
"""Test that list-based vectors still work for backward compatibility."""
|
651
|
+
value = [1, 2, 3, 4, 5]
|
652
|
+
encoded = encode_engine_value(value)
|
653
|
+
assert encoded == [1, 2, 3, 4, 5]
|
654
|
+
decoded = build_engine_value_decoder(IntVectorType)(encoded)
|
655
|
+
assert isinstance(decoded, np.ndarray)
|
656
|
+
assert decoded.dtype == np.int64
|
657
|
+
assert np.array_equal(decoded, np.array([1, 2, 3, 4, 5], dtype=np.int64))
|
658
|
+
value_list: ListIntType = [1, 2, 3, 4, 5]
|
659
|
+
encoded = encode_engine_value(value_list)
|
660
|
+
assert np.array_equal(encoded, [1, 2, 3, 4, 5])
|
661
|
+
decoded = build_engine_value_decoder(ListIntType)(encoded)
|
662
|
+
assert np.array_equal(decoded, [1, 2, 3, 4, 5])
|
663
|
+
|
664
|
+
|
665
|
+
def test_encode_complex_structure_with_ndarray() -> None:
|
666
|
+
"""Test encoding a complex structure that includes an NDArray."""
|
667
|
+
|
668
|
+
@dataclass
|
669
|
+
class MyStructWithNDArray:
|
670
|
+
name: str
|
671
|
+
data: NDArray[np.float32]
|
672
|
+
value: int
|
673
|
+
|
674
|
+
original = MyStructWithNDArray(
|
675
|
+
name="test_np", data=np.array([1.0, 0.5], dtype=np.float32), value=100
|
676
|
+
)
|
677
|
+
encoded = encode_engine_value(original)
|
678
|
+
|
679
|
+
assert encoded[0] == original.name
|
680
|
+
assert np.array_equal(encoded[1], original.data)
|
681
|
+
assert encoded[2] == original.value
|
682
|
+
|
683
|
+
|
684
|
+
def test_decode_nullable_ndarray_none_or_value_input() -> None:
|
685
|
+
"""Test decoding a nullable NDArray with None or value inputs."""
|
686
|
+
src_type_dict = {
|
687
|
+
"kind": "Vector",
|
688
|
+
"element_type": {"kind": "Float32"},
|
689
|
+
"dimension": None,
|
690
|
+
}
|
691
|
+
dst_annotation = NDArrayFloat32Type | None
|
692
|
+
decoder = make_engine_value_decoder([], src_type_dict, dst_annotation)
|
693
|
+
|
694
|
+
none_engine_value = None
|
695
|
+
decoded_array = decoder(none_engine_value)
|
696
|
+
assert decoded_array is None
|
697
|
+
|
698
|
+
engine_value = [1.0, 2.0, 3.0]
|
699
|
+
decoded_array = decoder(engine_value)
|
700
|
+
|
701
|
+
assert isinstance(decoded_array, np.ndarray)
|
702
|
+
assert decoded_array.dtype == np.float32
|
703
|
+
np.testing.assert_array_equal(
|
704
|
+
decoded_array, np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
705
|
+
)
|
706
|
+
|
707
|
+
|
708
|
+
def test_decode_vector_string() -> None:
|
709
|
+
"""Test decoding a vector of strings works for Python native list type."""
|
710
|
+
src_type_dict = {
|
711
|
+
"kind": "Vector",
|
712
|
+
"element_type": {"kind": "Str"},
|
713
|
+
"dimension": None,
|
714
|
+
}
|
715
|
+
decoder = make_engine_value_decoder([], src_type_dict, Vector[str])
|
716
|
+
assert decoder(["hello", "world"]) == ["hello", "world"]
|
717
|
+
|
718
|
+
|
719
|
+
def test_decode_error_non_nullable_or_non_list_vector() -> None:
|
720
|
+
"""Test decoding errors for non-nullable vectors or non-list inputs."""
|
721
|
+
src_type_dict = {
|
722
|
+
"kind": "Vector",
|
723
|
+
"element_type": {"kind": "Float32"},
|
724
|
+
"dimension": None,
|
725
|
+
}
|
726
|
+
decoder = make_engine_value_decoder([], src_type_dict, NDArrayFloat32Type)
|
727
|
+
with pytest.raises(ValueError, match="Received null for non-nullable vector"):
|
728
|
+
decoder(None)
|
729
|
+
with pytest.raises(TypeError, match="Expected NDArray or list for vector"):
|
730
|
+
decoder("not a list")
|
731
|
+
|
732
|
+
|
733
|
+
def test_dump_vector_type_annotation_with_dim() -> None:
|
734
|
+
"""Test dumping a vector type annotation with a specified dimension."""
|
735
|
+
expected_dump = {
|
736
|
+
"type": {
|
737
|
+
"kind": "Vector",
|
738
|
+
"element_type": {"kind": "Float32"},
|
739
|
+
"dimension": 3,
|
740
|
+
}
|
741
|
+
}
|
742
|
+
assert dump_engine_object(Float32VectorType) == expected_dump
|
743
|
+
|
744
|
+
|
745
|
+
def test_dump_vector_type_annotation_no_dim() -> None:
|
746
|
+
"""Test dumping a vector type annotation with no dimension."""
|
747
|
+
expected_dump_no_dim = {
|
748
|
+
"type": {
|
749
|
+
"kind": "Vector",
|
750
|
+
"element_type": {"kind": "Float64"},
|
751
|
+
"dimension": None,
|
752
|
+
}
|
753
|
+
}
|
754
|
+
assert dump_engine_object(Float64VectorTypeNoDim) == expected_dump_no_dim
|
755
|
+
|
756
|
+
|
757
|
+
def test_full_roundtrip_vector_numeric_types() -> None:
|
758
|
+
"""Test full roundtrip for numeric vector types using NDArray."""
|
759
|
+
value_f32: Vector[np.float32, Literal[3]] = np.array(
|
760
|
+
[1.0, 2.0, 3.0], dtype=np.float32
|
761
|
+
)
|
762
|
+
validate_full_roundtrip(value_f32, Vector[np.float32, Literal[3]])
|
763
|
+
value_f64: Vector[np.float64, Literal[3]] = np.array(
|
764
|
+
[1.0, 2.0, 3.0], dtype=np.float64
|
765
|
+
)
|
766
|
+
validate_full_roundtrip(value_f64, Vector[np.float64, Literal[3]])
|
767
|
+
value_i64: Vector[np.int64, Literal[3]] = np.array([1, 2, 3], dtype=np.int64)
|
768
|
+
validate_full_roundtrip(value_i64, Vector[np.int64, Literal[3]])
|
769
|
+
value_i32: Vector[np.int32, Literal[3]] = np.array([1, 2, 3], dtype=np.int32)
|
770
|
+
with pytest.raises(ValueError, match="type unsupported yet"):
|
771
|
+
validate_full_roundtrip(value_i32, Vector[np.int32, Literal[3]])
|
772
|
+
value_u8: Vector[np.uint8, Literal[3]] = np.array([1, 2, 3], dtype=np.uint8)
|
773
|
+
with pytest.raises(ValueError, match="type unsupported yet"):
|
774
|
+
validate_full_roundtrip(value_u8, Vector[np.uint8, Literal[3]])
|
775
|
+
value_u16: Vector[np.uint16, Literal[3]] = np.array([1, 2, 3], dtype=np.uint16)
|
776
|
+
with pytest.raises(ValueError, match="type unsupported yet"):
|
777
|
+
validate_full_roundtrip(value_u16, Vector[np.uint16, Literal[3]])
|
778
|
+
value_u32: Vector[np.uint32, Literal[3]] = np.array([1, 2, 3], dtype=np.uint32)
|
779
|
+
with pytest.raises(ValueError, match="type unsupported yet"):
|
780
|
+
validate_full_roundtrip(value_u32, Vector[np.uint32, Literal[3]])
|
781
|
+
value_u64: Vector[np.uint64, Literal[3]] = np.array([1, 2, 3], dtype=np.uint64)
|
782
|
+
with pytest.raises(ValueError, match="type unsupported yet"):
|
783
|
+
validate_full_roundtrip(value_u64, Vector[np.uint64, Literal[3]])
|
784
|
+
|
785
|
+
|
786
|
+
def test_roundtrip_vector_no_dimension() -> None:
|
787
|
+
"""Test full roundtrip for vector types without dimension annotation."""
|
788
|
+
value_f64: Vector[np.float64] = np.array([1.0, 2.0, 3.0], dtype=np.float64)
|
789
|
+
validate_full_roundtrip(value_f64, Vector[np.float64])
|
790
|
+
|
791
|
+
|
792
|
+
def test_roundtrip_string_vector() -> None:
|
793
|
+
"""Test full roundtrip for string vector using list."""
|
794
|
+
value_str: Vector[str] = ["hello", "world"]
|
795
|
+
validate_full_roundtrip(value_str, Vector[str])
|
796
|
+
|
797
|
+
|
798
|
+
def test_roundtrip_empty_vector() -> None:
|
799
|
+
"""Test full roundtrip for empty numeric vector."""
|
800
|
+
value_empty: Vector[np.float32] = np.array([], dtype=np.float32)
|
801
|
+
validate_full_roundtrip(value_empty, Vector[np.float32])
|
802
|
+
|
803
|
+
|
804
|
+
def test_roundtrip_dimension_mismatch() -> None:
|
805
|
+
"""Test that dimension mismatch raises an error during roundtrip."""
|
806
|
+
value_f32: Vector[np.float32, Literal[3]] = np.array([1.0, 2.0], dtype=np.float32)
|
807
|
+
with pytest.raises(ValueError, match="Vector dimension mismatch"):
|
808
|
+
validate_full_roundtrip(value_f32, Vector[np.float32, Literal[3]])
|
809
|
+
|
810
|
+
|
811
|
+
def test_roundtrip_list_backward_compatibility() -> None:
|
812
|
+
"""Test full roundtrip for list-based vectors for backward compatibility."""
|
813
|
+
value_list: list[int] = [1, 2, 3]
|
814
|
+
validate_full_roundtrip(value_list, list[int])
|