cocoindex 0.2.3__cp311-abi3-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cocoindex/__init__.py +92 -0
- cocoindex/_engine.pyd +0 -0
- cocoindex/auth_registry.py +51 -0
- cocoindex/cli.py +697 -0
- cocoindex/convert.py +621 -0
- cocoindex/flow.py +1205 -0
- cocoindex/functions.py +357 -0
- cocoindex/index.py +29 -0
- cocoindex/lib.py +32 -0
- cocoindex/llm.py +46 -0
- cocoindex/op.py +628 -0
- cocoindex/py.typed +0 -0
- cocoindex/runtime.py +37 -0
- cocoindex/setting.py +181 -0
- cocoindex/setup.py +92 -0
- cocoindex/sources.py +102 -0
- cocoindex/subprocess_exec.py +279 -0
- cocoindex/targets.py +135 -0
- cocoindex/tests/__init__.py +0 -0
- cocoindex/tests/conftest.py +38 -0
- cocoindex/tests/test_convert.py +1543 -0
- cocoindex/tests/test_optional_database.py +249 -0
- cocoindex/tests/test_transform_flow.py +207 -0
- cocoindex/tests/test_typing.py +429 -0
- cocoindex/tests/test_validation.py +134 -0
- cocoindex/typing.py +473 -0
- cocoindex/user_app_loader.py +51 -0
- cocoindex/utils.py +20 -0
- cocoindex/validation.py +104 -0
- cocoindex-0.2.3.dist-info/METADATA +262 -0
- cocoindex-0.2.3.dist-info/RECORD +34 -0
- cocoindex-0.2.3.dist-info/WHEEL +4 -0
- cocoindex-0.2.3.dist-info/entry_points.txt +2 -0
- cocoindex-0.2.3.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,1543 @@
|
|
1
|
+
import datetime
|
2
|
+
import inspect
|
3
|
+
import uuid
|
4
|
+
from dataclasses import dataclass, make_dataclass
|
5
|
+
from typing import Annotated, Any, Callable, Literal, NamedTuple, Type
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
import pytest
|
9
|
+
from numpy.typing import NDArray
|
10
|
+
|
11
|
+
import cocoindex
|
12
|
+
from cocoindex.convert import (
|
13
|
+
dump_engine_object,
|
14
|
+
make_engine_value_encoder,
|
15
|
+
make_engine_value_decoder,
|
16
|
+
)
|
17
|
+
from cocoindex.typing import (
|
18
|
+
Float32,
|
19
|
+
Float64,
|
20
|
+
TypeKind,
|
21
|
+
Vector,
|
22
|
+
analyze_type_info,
|
23
|
+
encode_enriched_type,
|
24
|
+
)
|
25
|
+
|
26
|
+
|
27
|
+
@dataclass
class Order:
    """Dataclass struct fixture used throughout the conversion tests."""

    order_id: str
    name: str
    price: float
    # Defaulted so tests can build an Order from just three positional values.
    extra_field: str = "default_extra"
|
33
|
+
|
34
|
+
|
35
|
+
@dataclass
class Tag:
    """Single-field struct fixture; used as the element type of ``Customer.tags``."""

    name: str
|
38
|
+
|
39
|
+
|
40
|
+
@dataclass
class Basket:
    """Struct fixture whose only field is a list of strings."""

    items: list[str]
|
43
|
+
|
44
|
+
|
45
|
+
@dataclass
class Customer:
    """Struct fixture nesting another struct (``Order``) and an optional list of structs."""

    name: str
    order: Order
    # Optional list-of-struct field; left as None when not supplied.
    tags: list[Tag] | None = None
|
50
|
+
|
51
|
+
|
52
|
+
@dataclass
class NestedStruct:
    """Struct fixture combining a nested struct, a list of structs and a scalar."""

    customer: Customer
    orders: list[Order]
    count: int = 0
|
57
|
+
|
58
|
+
|
59
|
+
class OrderNamedTuple(NamedTuple):
    """NamedTuple counterpart of ``Order`` with the same fields and default."""

    order_id: str
    name: str
    price: float
    extra_field: str = "default_extra"
|
64
|
+
|
65
|
+
|
66
|
+
class CustomerNamedTuple(NamedTuple):
    """NamedTuple counterpart of ``Customer``.

    ``order`` is an ``OrderNamedTuple``, while ``tags`` still holds ``Tag``
    dataclass instances.
    """

    name: str
    order: OrderNamedTuple
    tags: list[Tag] | None = None
|
70
|
+
|
71
|
+
|
72
|
+
def encode_engine_value(value: Any, type_hint: Type[Any] | str) -> Any:
    """
    Convert ``value`` into its engine representation.

    The encoder is derived from the analyzed form of ``type_hint``.
    """
    type_info = analyze_type_info(type_hint)
    return make_engine_value_encoder(type_info)(value)
|
78
|
+
|
79
|
+
|
80
|
+
def build_engine_value_decoder(
    engine_type_in_py: Any, python_type: Any | None = None
) -> Callable[[Any], Any]:
    """
    Build a decoder from an engine-side type (expressed as a Python type).

    Args:
        engine_type_in_py: Python type whose enriched encoding describes the
            engine value being decoded.
        python_type: Target Python type; defaults to ``engine_type_in_py``.

    Returns:
        A callable mapping an engine value to a Python value.
    """
    engine_type = encode_enriched_type(engine_type_in_py)["type"]
    target_info = analyze_type_info(python_type or engine_type_in_py)
    return make_engine_value_decoder([], engine_type, target_info)
|
91
|
+
|
92
|
+
|
93
|
+
def validate_full_roundtrip_to(
    value: Any,
    value_type: Any,
    *decoded_values: tuple[Any, Any],
) -> None:
    """
    Validate that ``value`` decodes to specific values after an engine round trip.

    ``value`` is encoded using ``value_type``, or ``type(value)`` when
    ``value_type`` is None. It is then passed through
    ``_engine.testutil.seder_roundtrip`` and decoded once per entry of
    ``decoded_values``.

    `decoded_values` is a sequence of (expected_value, target_type) pairs. Each
    decoded result must have exactly the type of ``expected_value`` and compare
    equal to it. NumPy arrays are compared element-wise.
    """
    from cocoindex import _engine  # type: ignore

    def eq(a: Any, b: Any) -> bool:
        # `==` on arrays yields an array, so compare them element-wise instead.
        if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
            return np.array_equal(a, b)
        # Require identical types so that e.g. np.float64 vs float or bool vs
        # int mismatches are reported even when the values compare equal.
        return type(a) is type(b) and bool(a == b)

    # Resolve the fallback type *before* encoding, so the encoder and the
    # engine type description are both derived from the same type.
    if value_type is None:
        value_type = type(value)
    encoded_value = encode_engine_value(value, value_type)
    encoded_output_type = encode_enriched_type(value_type)["type"]
    value_from_engine = _engine.testutil.seder_roundtrip(
        encoded_value, encoded_output_type
    )

    for other_value, other_type in decoded_values:
        decoder = make_engine_value_decoder(
            [], encoded_output_type, analyze_type_info(other_type)
        )
        other_decoded_value = decoder(value_from_engine)
        assert eq(other_decoded_value, other_value), (
            f"Expected {other_value} but got {other_decoded_value} for {other_type}"
        )
|
125
|
+
|
126
|
+
|
127
|
+
def validate_full_roundtrip(
    value: Any,
    value_type: Any,
    *other_decoded_values: tuple[Any, Any],
) -> None:
    """
    Validate that ``value`` is unchanged by an engine round trip.

    The value is encoded with ``value_type``, passed through the engine and
    decoded back with ``value_type``; the result must equal ``value``.

    `other_decoded_values` holds extra (expected_value, target_type) pairs.
    When given, the same engine value must also decode to each expected value.
    """
    identity_check = (value, value_type)
    validate_full_roundtrip_to(
        value, value_type, identity_check, *other_decoded_values
    )
|
141
|
+
|
142
|
+
|
143
|
+
def test_encode_engine_value_basic_types() -> None:
    """Primitive Python values pass through the encoder unchanged."""
    for raw, hint in ((123, int), (3.14, float), ("hello", str)):
        assert encode_engine_value(raw, hint) == raw
    # Checked by identity so a bool is not silently accepted as the int 1.
    assert encode_engine_value(True, bool) is True
|
148
|
+
|
149
|
+
|
150
|
+
def test_encode_engine_value_uuid() -> None:
    """A UUID encodes to an equal UUID."""
    sample = uuid.uuid4()
    encoded = encode_engine_value(sample, uuid.UUID)
    assert encoded == sample
|
153
|
+
|
154
|
+
|
155
|
+
def test_encode_engine_value_date_time_types() -> None:
    """date, time and naive datetime values encode to equal values."""
    samples = [
        (datetime.date(2024, 1, 1), datetime.date),
        (datetime.time(12, 30), datetime.time),
        (datetime.datetime(2024, 1, 1, 12, 30), datetime.datetime),
    ]
    for sample, hint in samples:
        assert encode_engine_value(sample, hint) == sample
|
162
|
+
|
163
|
+
|
164
|
+
def test_encode_scalar_numpy_values() -> None:
    """NumPy scalars encode to numerically equal int/float values."""
    cases = [
        (np.int64(42), 42),
        (np.float32(3.14), pytest.approx(3.14)),
        (np.float64(2.718), pytest.approx(2.718)),
    ]
    for scalar, want in cases:
        got = encode_engine_value(scalar, type(scalar))
        assert got == want
        # The encoded value must be an int or float instance.
        assert isinstance(got, (int, float))
|
175
|
+
|
176
|
+
|
177
|
+
def test_encode_engine_value_struct() -> None:
    """Dataclass and NamedTuple structs both encode to a positional field list."""
    expected = ["O123", "mixed nuts", 25.0, "default_extra"]

    as_dataclass = Order(order_id="O123", name="mixed nuts", price=25.0)
    assert encode_engine_value(as_dataclass, Order) == expected

    as_namedtuple = OrderNamedTuple(order_id="O123", name="mixed nuts", price=25.0)
    assert encode_engine_value(as_namedtuple, OrderNamedTuple) == expected
|
193
|
+
|
194
|
+
|
195
|
+
def test_encode_engine_value_list_of_structs() -> None:
    """A list of structs encodes to a list of positional field lists."""
    expected_rows = [
        ["O1", "item1", 10.0, "default_extra"],
        ["O2", "item2", 20.0, "default_extra"],
    ]

    dataclass_rows = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]
    assert encode_engine_value(dataclass_rows, list[Order]) == expected_rows

    namedtuple_rows = [
        OrderNamedTuple("O1", "item1", 10.0),
        OrderNamedTuple("O2", "item2", 20.0),
    ]
    encoded = encode_engine_value(namedtuple_rows, list[OrderNamedTuple])
    assert encoded == expected_rows
|
210
|
+
|
211
|
+
|
212
|
+
def test_encode_engine_value_struct_with_list() -> None:
    """A list-valued struct field encodes as a nested list."""
    encoded = encode_engine_value(Basket(items=["apple", "banana"]), Basket)
    assert encoded == [["apple", "banana"]]
|
215
|
+
|
216
|
+
|
217
|
+
def test_encode_engine_value_nested_struct() -> None:
    """Nested structs encode recursively; an unset optional field encodes as None."""
    expected = ["Alice", ["O1", "item1", 10.0, "default_extra"], None]

    customer = Customer(name="Alice", order=Order("O1", "item1", 10.0))
    assert encode_engine_value(customer, Customer) == expected

    customer_nt = CustomerNamedTuple(
        name="Alice", order=OrderNamedTuple("O1", "item1", 10.0)
    )
    assert encode_engine_value(customer_nt, CustomerNamedTuple) == expected
|
233
|
+
|
234
|
+
|
235
|
+
def test_encode_engine_value_empty_list() -> None:
    """Empty lists, including an empty list nested in a list, encode unchanged."""
    for empty, hint, want in (([], list, []), ([[]], list[list[Any]], [[]])):
        assert encode_engine_value(empty, hint) == want
|
238
|
+
|
239
|
+
|
240
|
+
def test_encode_engine_value_tuple() -> None:
    """With an ``Any`` hint, tuples encode to lists, recursively."""
    expectations = [
        ((), []),
        ((1, 2, 3), [1, 2, 3]),
        (((1, 2), (3, 4)), [[1, 2], [3, 4]]),
        (([],), [[]]),
        (((),), [[]]),
    ]
    for raw, want in expectations:
        assert encode_engine_value(raw, Any) == want
|
246
|
+
|
247
|
+
|
248
|
+
def test_encode_engine_value_none() -> None:
    """None encodes to None."""
    result = encode_engine_value(None, Any)
    assert result is None
|
250
|
+
|
251
|
+
|
252
|
+
def test_roundtrip_basic_types() -> None:
    """Round-trip scalar values, also decoding them into alternative Python types."""
    # bytes, including non-UTF-8 content. Decoding with no annotation
    # (Parameter.empty) or with Any still yields bytes.
    greeting = b"hello world"
    validate_full_roundtrip(
        greeting,
        bytes,
        (greeting, inspect.Parameter.empty),
        (greeting, Any),
    )
    validate_full_roundtrip(b"\x00\x01\x02\xff\xfe", bytes)

    validate_full_roundtrip("hello", str, ("hello", Any))
    for flag in (True, False):
        validate_full_roundtrip(flag, bool, (flag, Any))

    # 64-bit integers decode to int, np.int64 or Any alike.
    validate_full_roundtrip(
        42, cocoindex.Int64, (42, int), (np.int64(42), np.int64), (42, Any)
    )
    validate_full_roundtrip(42, int, (42, cocoindex.Int64))
    validate_full_roundtrip(np.int64(42), np.int64, (42, cocoindex.Int64))

    # 64-bit floats.
    validate_full_roundtrip(
        3.25, Float64, (3.25, float), (np.float64(3.25), np.float64), (3.25, Any)
    )
    validate_full_roundtrip(3.25, float, (3.25, Float64))
    validate_full_roundtrip(np.float64(3.25), np.float64, (3.25, Float64))

    # 32-bit floats; 3.25 is exactly representable, so nothing is lost.
    validate_full_roundtrip(
        3.25,
        Float32,
        (3.25, float),
        (np.float32(3.25), np.float32),
        (np.float64(3.25), np.float64),
        (3.25, Float64),
        (3.25, Any),
    )
    validate_full_roundtrip(np.float32(3.25), np.float32, (3.25, Float32))
|
285
|
+
|
286
|
+
|
287
|
+
def test_roundtrip_uuid() -> None:
    """A random UUID survives the engine round trip, also when decoded as Any."""
    sample = uuid.uuid4()
    validate_full_roundtrip(sample, uuid.UUID, (sample, Any))
|
290
|
+
|
291
|
+
|
292
|
+
def test_roundtrip_range() -> None:
    """(start, end) pairs round-trip as ``cocoindex.Range``, including empty ranges."""
    for bounds in ((0, 100), (50, 50), (0, 1_000_000_000)):
        validate_full_roundtrip(bounds, cocoindex.Range, (bounds, Any))
|
299
|
+
|
300
|
+
|
301
|
+
def test_roundtrip_time() -> None:
    """Round-trip time-of-day, date and datetime values.

    Despite the name, this also covers ``datetime.date`` and the
    ``LocalDateTime`` / ``OffsetDateTime`` datetime conversions.
    """
    for moment in (
        datetime.time(10, 30, 50, 123456),
        datetime.time(23, 59, 59),
        datetime.time(0, 0, 0),
    ):
        validate_full_roundtrip(moment, datetime.time, (moment, Any))

    day = datetime.date(2025, 1, 1)
    validate_full_roundtrip(day, datetime.date, (day, Any))

    naive = datetime.datetime(2025, 1, 2, 3, 4, 5, 123456)
    validate_full_roundtrip(
        naive, cocoindex.LocalDateTime, (naive, datetime.datetime)
    )

    plus_five = datetime.timezone(datetime.timedelta(hours=5))
    aware = datetime.datetime(2025, 1, 2, 3, 4, 5, 123456, plus_five)
    validate_full_roundtrip(
        aware, cocoindex.OffsetDateTime, (aware, datetime.datetime)
    )
    validate_full_roundtrip(
        aware, datetime.datetime, (aware, cocoindex.OffsetDateTime)
    )

    # A naive datetime that goes through an offset-datetime type comes back
    # with UTC attached.
    naive_as_utc = naive.replace(tzinfo=datetime.UTC)
    validate_full_roundtrip_to(
        naive, cocoindex.OffsetDateTime, (naive_as_utc, datetime.datetime)
    )
    validate_full_roundtrip_to(
        naive, datetime.datetime, (naive_as_utc, cocoindex.OffsetDateTime)
    )
|
349
|
+
|
350
|
+
|
351
|
+
def test_roundtrip_timedelta() -> None:
    """Positive, negative and zero durations survive the engine round trip."""
    spans = [
        datetime.timedelta(
            days=5, seconds=10, microseconds=123, milliseconds=456, minutes=30, hours=2
        ),
        datetime.timedelta(days=-5, hours=-2),
        datetime.timedelta(0),
    ]
    for span in spans:
        validate_full_roundtrip(span, datetime.timedelta, (span, Any))
|
360
|
+
|
361
|
+
|
362
|
+
def test_roundtrip_json() -> None:
    """JSON-compatible dicts and lists, nested or empty, round-trip as ``Json``."""
    documents: list[Any] = [
        {"key": "value", "number": 123, "bool": True, "float": 1.23},
        [1, "string", False, None, 4.56],
        {
            "name": "Test Json",
            "version": 1.0,
            "items": [
                {"id": 1, "value": "item1"},
                {"id": 2, "value": None, "props": {"active": True}},
            ],
            "metadata": None,
        },
        {},
        [],
    ]
    for document in documents:
        validate_full_roundtrip(document, cocoindex.Json)
|
382
|
+
|
383
|
+
|
384
|
+
def test_decode_scalar_numpy_values() -> None:
    """Engine scalars decode into the requested NumPy scalar types."""
    cases = [
        ({"kind": "Int64"}, np.int64, 42, np.int64(42)),
        ({"kind": "Float32"}, np.float32, 3.14, np.float32(3.14)),
        ({"kind": "Float64"}, np.float64, 2.718, np.float64(2.718)),
    ]
    for engine_type, numpy_type, raw, want in cases:
        decode = make_engine_value_decoder(
            ["field"], engine_type, analyze_type_info(numpy_type)
        )
        decoded = decode(raw)
        assert isinstance(decoded, numpy_type)
        assert decoded == want
|
397
|
+
|
398
|
+
|
399
|
+
def test_non_ndarray_vector_decoding() -> None:
    """Vectors decode into plain lists when the target type is ``list[...]``."""
    # list[np.float64]: each element becomes a NumPy float64 scalar.
    float_decoder = make_engine_value_decoder(
        ["field"],
        {"kind": "Vector", "element_type": {"kind": "Float64"}, "dimension": None},
        analyze_type_info(list[np.float64]),
    )
    floats = float_decoder([1.0, 2.0, 3.0])
    assert isinstance(floats, list)
    assert all(isinstance(item, np.float64) for item in floats)
    assert floats == [np.float64(1.0), np.float64(2.0), np.float64(3.0)]

    # list[uuid.UUID]
    uuid_decoder = make_engine_value_decoder(
        ["field"],
        {"kind": "Vector", "element_type": {"kind": "Uuid"}, "dimension": None},
        analyze_type_info(list[uuid.UUID]),
    )
    first, second = uuid.uuid4(), uuid.uuid4()
    decoded_ids = uuid_decoder([first, second])
    assert isinstance(decoded_ids, list)
    assert all(isinstance(item, uuid.UUID) for item in decoded_ids)
    assert decoded_ids == [first, second]
|
429
|
+
|
430
|
+
|
431
|
+
def test_roundtrip_struct() -> None:
    """A fully-populated struct round-trips as a dataclass and as a NamedTuple."""
    field_values = ("O123", "mixed nuts", 25.0, "default_extra")
    validate_full_roundtrip(Order(*field_values), Order)
    validate_full_roundtrip(OrderNamedTuple(*field_values), OrderNamedTuple)
|
440
|
+
|
441
|
+
|
442
|
+
def test_make_engine_value_decoder_list_of_struct() -> None:
    """Engine rows decode into a list of dataclasses or a list of NamedTuples."""
    rows = [
        ["O1", "item1", 10.0, "default_extra"],
        ["O2", "item2", 20.0, "default_extra"],
    ]

    as_dataclasses = build_engine_value_decoder(list[Order])(rows)
    assert as_dataclasses == [
        Order("O1", "item1", 10.0, "default_extra"),
        Order("O2", "item2", 20.0, "default_extra"),
    ]

    as_namedtuples = build_engine_value_decoder(list[OrderNamedTuple])(rows)
    assert as_namedtuples == [
        OrderNamedTuple("O1", "item1", 10.0, "default_extra"),
        OrderNamedTuple("O2", "item2", 20.0, "default_extra"),
    ]
|
460
|
+
|
461
|
+
|
462
|
+
def test_make_engine_value_decoder_struct_of_list() -> None:
    """A struct with struct and list-of-struct fields decodes recursively."""
    engine_row = [
        "Alice",
        ["O1", "item1", 10.0, "default_extra"],
        [["vip"], ["premium"]],
    ]
    tags = [Tag("vip"), Tag("premium")]

    expected_dataclass = Customer(
        "Alice", Order("O1", "item1", 10.0, "default_extra"), tags
    )
    assert build_engine_value_decoder(Customer)(engine_row) == expected_dataclass

    expected_namedtuple = CustomerNamedTuple(
        "Alice", OrderNamedTuple("O1", "item1", 10.0, "default_extra"), tags
    )
    decoded = build_engine_value_decoder(CustomerNamedTuple)(engine_row)
    assert decoded == expected_namedtuple
|
483
|
+
|
484
|
+
|
485
|
+
def test_make_engine_value_decoder_struct_of_struct() -> None:
    """A struct containing structs, lists of structs and a scalar decodes fully."""
    engine_row = [
        ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]]],
        [
            ["O1", "item1", 10.0, "default_extra"],
            ["O2", "item2", 20.0, "default_extra"],
        ],
        2,
    ]
    first = Order("O1", "item1", 10.0, "default_extra")
    second = Order("O2", "item2", 20.0, "default_extra")
    expected = NestedStruct(
        Customer("Alice", first, [Tag("vip")]),
        [first, second],
        2,
    )
    assert build_engine_value_decoder(NestedStruct)(engine_row) == expected
|
504
|
+
|
505
|
+
|
506
|
+
def make_engine_order(fields: list[tuple[str, type]]) -> type:
    """Create an ``EngineOrder`` dataclass with the given (name, type) fields, in order."""
    engine_cls = make_dataclass(cls_name="EngineOrder", fields=fields)
    return engine_cls
|
508
|
+
|
509
|
+
|
510
|
+
def make_python_order(
|
511
|
+
fields: list[tuple[str, type]], defaults: dict[str, Any] | None = None
|
512
|
+
) -> type:
|
513
|
+
if defaults is None:
|
514
|
+
defaults = {}
|
515
|
+
# Move all fields with defaults to the end (Python dataclass requirement)
|
516
|
+
non_default_fields = [(n, t) for n, t in fields if n not in defaults]
|
517
|
+
default_fields = [(n, t) for n, t in fields if n in defaults]
|
518
|
+
ordered_fields = non_default_fields + default_fields
|
519
|
+
# Prepare the namespace for defaults (only for fields at the end)
|
520
|
+
namespace = {k: defaults[k] for k, _ in default_fields}
|
521
|
+
return make_dataclass("PythonOrder", ordered_fields, namespace=namespace)
|
522
|
+
|
523
|
+
|
524
|
+
@pytest.mark.parametrize(
    "engine_fields, python_fields, python_defaults, engine_val, expected_python_val",
    [
        # Extra field in Python (middle)
        (
            [("id", str), ("name", str)],
            [("id", str), ("price", float), ("name", str)],
            {"price": 0.0},
            ["O123", "mixed nuts"],
            ("O123", 0.0, "mixed nuts"),
        ),
        # Missing field in Python (middle)
        (
            [("id", str), ("price", float), ("name", str)],
            [("id", str), ("name", str)],
            {},
            ["O123", 25.0, "mixed nuts"],
            ("O123", "mixed nuts"),
        ),
        # Extra field in Python (start)
        (
            [("name", str), ("price", float)],
            [("extra", str), ("name", str), ("price", float)],
            {"extra": "default"},
            ["mixed nuts", 25.0],
            ("default", "mixed nuts", 25.0),
        ),
        # Missing field in Python (start)
        (
            [("extra", str), ("name", str), ("price", float)],
            [("name", str), ("price", float)],
            {},
            ["unexpected", "mixed nuts", 25.0],
            ("mixed nuts", 25.0),
        ),
        # Field order difference (should map by name)
        (
            [("id", str), ("name", str), ("price", float)],
            [("name", str), ("id", str), ("price", float), ("extra", str)],
            {"extra": "default"},
            ["O123", "mixed nuts", 25.0],
            ("mixed nuts", "O123", 25.0, "default"),
        ),
        # Extra field (Python has extra field with default)
        (
            [("id", str), ("name", str)],
            [("id", str), ("name", str), ("price", float)],
            {"price": 0.0},
            ["O123", "mixed nuts"],
            ("O123", "mixed nuts", 0.0),
        ),
        # Missing field (Engine has extra field)
        (
            [("id", str), ("name", str), ("price", float)],
            [("id", str), ("name", str)],
            {},
            ["O123", "mixed nuts", 25.0],
            ("O123", "mixed nuts"),
        ),
    ],
)
def test_field_position_cases(
    engine_fields: list[tuple[str, type]],
    python_fields: list[tuple[str, type]],
    python_defaults: dict[str, Any],
    engine_val: list[Any],
    expected_python_val: tuple[Any, ...],
) -> None:
    """Engine struct fields map onto Python fields by name, not by position.

    Python-only fields fall back to their defaults, and engine-only fields are
    dropped.
    """
    engine_cls = make_engine_order(engine_fields)
    python_cls = make_python_order(python_fields, python_defaults)
    decode = build_engine_value_decoder(engine_cls, python_cls)
    # `expected_python_val` follows `python_fields` order, which can differ
    # from the generated dataclass's field order, so build it by keyword.
    expected_kwargs = {
        field_name: field_value
        for (field_name, _), field_value in zip(python_fields, expected_python_val)
    }
    assert decode(engine_val) == python_cls(**expected_kwargs)
|
599
|
+
|
600
|
+
|
601
|
+
def test_roundtrip_union_simple() -> None:
    """A float round-trips through an ``int | str | float`` union."""
    validate_full_roundtrip(10.4, int | str | float)
|
605
|
+
|
606
|
+
|
607
|
+
def test_roundtrip_union_with_active_uuid() -> None:
    """A UUID value selects the UUID member of a str/UUID/int union."""
    union_type = str | uuid.UUID | int
    validate_full_roundtrip(uuid.uuid4(), union_type)
|
611
|
+
|
612
|
+
|
613
|
+
def test_roundtrip_union_with_inactive_uuid() -> None:
    """A UUID-shaped *string* must stay a string in a str/UUID/int union."""
    union_type = str | uuid.UUID | int
    uuid_like_text = "5a9f8f6a-318f-4f1f-929d-566d7444a62d"
    validate_full_roundtrip(uuid_like_text, union_type)
|
617
|
+
|
618
|
+
|
619
|
+
def test_roundtrip_union_offset_datetime() -> None:
    """A timezone-aware datetime round-trips through a wide scalar union."""
    validate_full_roundtrip(
        datetime.datetime.now(datetime.UTC),
        str | uuid.UUID | float | int | datetime.datetime,
    )
|
623
|
+
|
624
|
+
|
625
|
+
def test_roundtrip_union_date() -> None:
    """A date round-trips through a wide scalar union."""
    validate_full_roundtrip(
        datetime.date.today(),
        str | uuid.UUID | float | int | datetime.date,
    )
|
629
|
+
|
630
|
+
|
631
|
+
def test_roundtrip_union_time() -> None:
    """Midnight (``datetime.time()``) round-trips through a wide scalar union."""
    validate_full_roundtrip(
        datetime.time(),
        str | uuid.UUID | float | int | datetime.time,
    )
|
635
|
+
|
636
|
+
|
637
|
+
def test_roundtrip_union_timedelta() -> None:
    """A multi-day timedelta round-trips through a wide scalar union."""
    validate_full_roundtrip(
        datetime.timedelta(hours=39, minutes=10, seconds=1),
        str | uuid.UUID | float | int | datetime.timedelta,
    )
|
641
|
+
|
642
|
+
|
643
|
+
def test_roundtrip_vector_of_union() -> None:
    """A list whose elements mix union members round-trips element by element."""
    validate_full_roundtrip(["a", 1], list[str | int])
|
647
|
+
|
648
|
+
|
649
|
+
def test_roundtrip_union_with_vector() -> None:
    """An NDArray member of a union round-trips, and can also decode to a list."""
    vector = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    validate_full_roundtrip(
        vector,
        NDArray[np.float32] | str,
        ([1.0, 2.0, 3.0], list[float] | str),
    )
|
653
|
+
|
654
|
+
|
655
|
+
def test_roundtrip_union_with_misc_types() -> None:
    """Unions containing bytes, Range and Json members round-trip each member."""
    bytes_union = int | bytes | str
    for sample in (b"test_bytes", 123):
        validate_full_roundtrip(sample, bytes_union)

    range_union = cocoindex.Range | str | bool
    for sample in ((100, 200), "test_string"):
        validate_full_roundtrip(sample, range_union)

    json_union = cocoindex.Json | int | bytes
    for sample in ({"a": 1, "b": [2, 3]}, b"another_byte_string"):
        validate_full_roundtrip(sample, json_union)
|
668
|
+
|
669
|
+
|
670
|
+
def test_roundtrip_ltable() -> None:
    """Lists of structs round-trip with dataclass rows and with NamedTuple rows."""
    dataclass_rows = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]
    validate_full_roundtrip(dataclass_rows, list[Order])

    namedtuple_rows = [
        OrderNamedTuple("O1", "item1", 10.0),
        OrderNamedTuple("O2", "item2", 20.0),
    ]
    validate_full_roundtrip(namedtuple_rows, list[OrderNamedTuple])
|
681
|
+
|
682
|
+
|
683
|
+
def test_roundtrip_ktable_various_key_types() -> None:
    """Dicts of structs round-trip for each supported scalar key type."""

    @dataclass
    class SimpleValue:
        data: str

    # (table, table_type) pairs, one per key type.
    tables: list[tuple[Any, Any]] = [
        (
            {b"key1": SimpleValue("val1"), b"key2": SimpleValue("val2")},
            dict[bytes, SimpleValue],
        ),
        (
            {1: SimpleValue("val1"), 2: SimpleValue("val2")},
            dict[int, SimpleValue],
        ),
        (
            {True: SimpleValue("val_true"), False: SimpleValue("val_false")},
            dict[bool, SimpleValue],
        ),
        (
            {"K1": Order("O1", "item1", 10.0), "K2": Order("O2", "item2", 20.0)},
            dict[str, Order],
        ),
        (
            {
                "K1": OrderNamedTuple("O1", "item1", 10.0),
                "K2": OrderNamedTuple("O2", "item2", 20.0),
            },
            dict[str, OrderNamedTuple],
        ),
        (
            {
                (1, 10): SimpleValue("val_range1"),
                (20, 30): SimpleValue("val_range2"),
            },
            dict[cocoindex.Range, SimpleValue],
        ),
        (
            {
                datetime.date(2023, 1, 1): SimpleValue("val_date1"),
                datetime.date(2024, 2, 2): SimpleValue("val_date2"),
            },
            dict[datetime.date, SimpleValue],
        ),
        (
            {
                uuid.uuid4(): SimpleValue("val_uuid1"),
                uuid.uuid4(): SimpleValue("val_uuid2"),
            },
            dict[uuid.UUID, SimpleValue],
        ),
    ]
    for table, table_type in tables:
        validate_full_roundtrip(table, table_type)
|
731
|
+
|
732
|
+
|
733
|
+
def test_roundtrip_ktable_struct_key() -> None:
    """Dicts keyed by a frozen dataclass round-trip for both value styles."""

    @dataclass(frozen=True)
    class OrderKey:
        shop_id: str
        version: int

    key_a, key_b = OrderKey("A", 3), OrderKey("B", 4)

    validate_full_roundtrip(
        {key_a: Order("O1", "item1", 10.0), key_b: Order("O2", "item2", 20.0)},
        dict[OrderKey, Order],
    )
    validate_full_roundtrip(
        {
            key_a: OrderNamedTuple("O1", "item1", 10.0),
            key_b: OrderNamedTuple("O2", "item2", 20.0),
        },
        dict[OrderKey, OrderNamedTuple],
    )
|
752
|
+
|
753
|
+
|
754
|
+
# Vector of NumPy int64 elements. The Literal[5] argument appears to fix the
# dimension at 5 (compare the "NoDim" alias below) — TODO confirm.
IntVectorType = cocoindex.Vector[np.int64, Literal[5]]
|
755
|
+
|
756
|
+
|
757
|
+
def test_vector_as_vector() -> None:
    """An int64 NumPy array survives encode/decode under a Vector annotation."""
    original = np.array([1, 2, 3, 4, 5], dtype=np.int64)
    engine_value = encode_engine_value(original, IntVectorType)
    assert np.array_equal(engine_value, original)
    decode = build_engine_value_decoder(IntVectorType)
    assert np.array_equal(decode(engine_value), original)
|
763
|
+
|
764
|
+
|
765
|
+
# Plain list annotation, used to check that a Python list (not an ndarray)
# can be encoded and decoded as a vector.
ListIntType = list[int]
|
766
|
+
|
767
|
+
|
768
|
+
def test_vector_as_list() -> None:
    """A ``list[int]`` value encodes to an equal list and decodes back to its contents."""
    items: ListIntType = [1, 2, 3, 4, 5]
    engine_value = encode_engine_value(items, ListIntType)
    assert engine_value == [1, 2, 3, 4, 5]
    round_tripped = build_engine_value_decoder(ListIntType)(engine_value)
    assert np.array_equal(round_tripped, items)
|
774
|
+
|
775
|
+
|
776
|
+
# Vector annotations used by the NDArray encode/decode tests below.
# A Vector without a Literal argument (as the "NoDim" name suggests) leaves the
# dimension unspecified; Literal[3] presumably pins it to 3 — TODO confirm.
Float64VectorTypeNoDim = Vector[np.float64]
Float32VectorType = Vector[np.float32, Literal[3]]
Float64VectorType = Vector[np.float64, Literal[3]]
Int64VectorType = Vector[np.int64, Literal[3]]
# Bare NumPy array annotations (no cocoindex Vector wrapper).
NDArrayFloat32Type = NDArray[np.float32]
NDArrayFloat64Type = NDArray[np.float64]
NDArrayInt64Type = NDArray[np.int64]
|
783
|
+
|
784
|
+
|
785
|
+
def test_encode_engine_value_ndarray() -> None:
    """Vector/NDArray values encode to sequences equal to their elements."""
    f32_vec: Float32VectorType = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    f64_vec: Float64VectorType = np.array([1.0, 2.0, 3.0], dtype=np.float64)
    i64_vec: Int64VectorType = np.array([1, 2, 3], dtype=np.int64)
    nd_f32_vec: NDArrayFloat32Type = np.array([1.0, 2.0, 3.0], dtype=np.float32)

    checks: list[tuple[Any, Any, list[Any]]] = [
        (f32_vec, Float32VectorType, [1.0, 2.0, 3.0]),
        (f64_vec, Float64VectorType, [1.0, 2.0, 3.0]),
        (i64_vec, Int64VectorType, [1, 2, 3]),
        (nd_f32_vec, NDArrayFloat32Type, [1.0, 2.0, 3.0]),
    ]
    for vec, vec_type, expected in checks:
        assert np.array_equal(encode_engine_value(vec, vec_type), expected)
|
801
|
+
|
802
|
+
|
803
|
+
def test_make_engine_value_decoder_ndarray() -> None:
    """Decoding engine lists yields NumPy arrays carrying the annotated dtype."""
    # (target annotation, engine value, expected dtype)
    cases = [
        (Float32VectorType, [1.0, 2.0, 3.0], np.float32),
        (Float64VectorType, [1.0, 2.0, 3.0], np.float64),
        (Int64VectorType, [1, 2, 3], np.int64),
        (NDArrayFloat32Type, [1.0, 2.0, 3.0], np.float32),
    ]
    for annotation, engine_value, expected_dtype in cases:
        decode = build_engine_value_decoder(annotation)
        result = decode(engine_value)
        assert isinstance(result, np.ndarray)
        assert result.dtype == expected_dtype
        assert np.array_equal(result, np.array(engine_value, dtype=expected_dtype))
|
825
|
+
|
826
|
+
|
827
|
+
def test_roundtrip_ndarray_vector() -> None:
    """Test roundtrip encoding and decoding of NDArray vectors.

    Each case encodes a NumPy array, checks the engine value, then decodes it
    back and checks the result is an ndarray with the original dtype and values.
    """
    value_f32 = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    encoded_f32 = encode_engine_value(value_f32, Float32VectorType)
    # Must be asserted: a bare np.array_equal(...) call checks nothing.
    assert np.array_equal(encoded_f32, [1.0, 2.0, 3.0])
    decoded_f32 = build_engine_value_decoder(Float32VectorType)(encoded_f32)
    assert isinstance(decoded_f32, np.ndarray)
    assert decoded_f32.dtype == np.float32
    assert np.array_equal(decoded_f32, value_f32)

    value_i64 = np.array([1, 2, 3], dtype=np.int64)
    encoded_i64 = encode_engine_value(value_i64, Int64VectorType)
    assert np.array_equal(encoded_i64, [1, 2, 3])
    decoded_i64 = build_engine_value_decoder(Int64VectorType)(encoded_i64)
    assert isinstance(decoded_i64, np.ndarray)
    assert decoded_i64.dtype == np.int64
    assert np.array_equal(decoded_i64, value_i64)

    value_nd_f64: NDArrayFloat64Type = np.array([1.0, 2.0, 3.0], dtype=np.float64)
    encoded_nd_f64 = encode_engine_value(value_nd_f64, NDArrayFloat64Type)
    assert np.array_equal(encoded_nd_f64, [1.0, 2.0, 3.0])
    decoded_nd_f64 = build_engine_value_decoder(NDArrayFloat64Type)(encoded_nd_f64)
    assert isinstance(decoded_nd_f64, np.ndarray)
    assert decoded_nd_f64.dtype == np.float64
    assert np.array_equal(decoded_nd_f64, value_nd_f64)
|
850
|
+
|
851
|
+
|
852
|
+
def test_ndarray_dimension_mismatch() -> None:
    """A 2-element array decoded as a 3-dimensional Vector raises ValueError."""
    two_elements = np.array([1.0, 2.0], dtype=np.float32)
    engine_value = encode_engine_value(two_elements, NDArray[np.float32])
    assert np.array_equal(engine_value, [1.0, 2.0])

    with pytest.raises(ValueError, match="Vector dimension mismatch"):
        decode_as_dim3 = build_engine_value_decoder(Float32VectorType)
        decode_as_dim3(engine_value)
|
859
|
+
|
860
|
+
|
861
|
+
def test_list_vector_backward_compatibility() -> None:
    """Plain Python lists still encode/decode where vectors are expected."""
    # Encode via a list[int] annotation, decode into an int64 NumPy vector.
    raw = [1, 2, 3, 4, 5]
    engine_value = encode_engine_value(raw, list[int])
    assert engine_value == [1, 2, 3, 4, 5]
    as_ndarray = build_engine_value_decoder(IntVectorType)(engine_value)
    assert isinstance(as_ndarray, np.ndarray)
    assert as_ndarray.dtype == np.int64
    assert np.array_equal(as_ndarray, np.array([1, 2, 3, 4, 5], dtype=np.int64))

    # Encode and decode both through the ListIntType alias.
    typed: ListIntType = [1, 2, 3, 4, 5]
    engine_value = encode_engine_value(typed, ListIntType)
    assert np.array_equal(engine_value, [1, 2, 3, 4, 5])
    as_list = build_engine_value_decoder(ListIntType)(engine_value)
    assert np.array_equal(as_list, [1, 2, 3, 4, 5])
|
875
|
+
|
876
|
+
|
877
|
+
def test_encode_complex_structure_with_ndarray() -> None:
    """A dataclass holding an NDArray field encodes each field positionally."""

    @dataclass
    class MyStructWithNDArray:
        name: str
        data: NDArray[np.float32]
        value: int

    payload = np.array([1.0, 0.5], dtype=np.float32)
    original = MyStructWithNDArray(name="test_np", data=payload, value=100)
    encoded = encode_engine_value(original, MyStructWithNDArray)

    # Encoded fields appear in declaration order: name, data, value.
    assert encoded[0] == original.name
    assert np.array_equal(encoded[1], original.data)
    assert encoded[2] == original.value
|
894
|
+
|
895
|
+
|
896
|
+
def test_decode_nullable_ndarray_none_or_value_input() -> None:
    """An optional NDArray decoder maps None to None and lists to float32 arrays."""
    engine_type = {
        "kind": "Vector",
        "element_type": {"kind": "Float32"},
        "dimension": None,
    }
    decode = make_engine_value_decoder(
        [], engine_type, analyze_type_info(NDArrayFloat32Type | None)
    )

    assert decode(None) is None

    decoded = decode([1.0, 2.0, 3.0])
    assert isinstance(decoded, np.ndarray)
    assert decoded.dtype == np.float32
    expected = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    np.testing.assert_array_equal(decoded, expected)
|
920
|
+
|
921
|
+
|
922
|
+
def test_decode_vector_string() -> None:
    """A ``Vector[str]`` decodes to a native Python list of strings."""
    engine_type = {"kind": "Vector", "element_type": {"kind": "Str"}, "dimension": None}
    target_info = analyze_type_info(Vector[str])
    decode = make_engine_value_decoder([], engine_type, target_info)
    assert decode(["hello", "world"]) == ["hello", "world"]
|
933
|
+
|
934
|
+
|
935
|
+
def test_decode_error_non_nullable_or_non_list_vector() -> None:
    """A non-nullable vector decoder rejects None and non-list inputs."""
    engine_type = {
        "kind": "Vector",
        "element_type": {"kind": "Float32"},
        "dimension": None,
    }
    target_info = analyze_type_info(NDArrayFloat32Type)
    decode = make_engine_value_decoder([], engine_type, target_info)

    with pytest.raises(ValueError, match="Received null for non-nullable vector"):
        decode(None)
    with pytest.raises(TypeError, match="Expected NDArray or list for vector"):
        decode("not a list")
|
949
|
+
|
950
|
+
|
951
|
+
def test_dump_vector_type_annotation_with_dim() -> None:
    """Dumping a dimensioned float32 Vector alias records its dimension."""
    vector_schema = {
        "kind": "Vector",
        "element_type": {"kind": "Float32"},
        "dimension": 3,
    }
    assert dump_engine_object(Float32VectorType) == {"type": vector_schema}
|
961
|
+
|
962
|
+
|
963
|
+
def test_dump_vector_type_annotation_no_dim() -> None:
    """Dumping a Vector alias without a dimension records ``dimension: None``."""
    vector_schema = {
        "kind": "Vector",
        "element_type": {"kind": "Float64"},
        "dimension": None,
    }
    assert dump_engine_object(Float64VectorTypeNoDim) == {"type": vector_schema}
|
973
|
+
|
974
|
+
|
975
|
+
def test_full_roundtrip_vector_numeric_types() -> None:
    """Numeric vectors round-trip across Vector, NDArray and list annotations.

    float32, float64 and int64 are supported; the int32 and unsigned dtypes
    checked at the end are rejected with "Unsupported NumPy dtype".
    """
    # float32 with a dimensioned Vector, a bare NDArray, and a list as the source.
    f32 = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    validate_full_roundtrip(
        f32,
        Vector[np.float32, Literal[3]],
        ([np.float32(1.0), np.float32(2.0), np.float32(3.0)], list[np.float32]),
        ([1.0, 2.0, 3.0], list[cocoindex.Float32]),
        ([1.0, 2.0, 3.0], list[float]),
    )
    validate_full_roundtrip(
        f32,
        np.typing.NDArray[np.float32],
        ([np.float32(1.0), np.float32(2.0), np.float32(3.0)], list[np.float32]),
        ([1.0, 2.0, 3.0], list[cocoindex.Float32]),
        ([1.0, 2.0, 3.0], list[float]),
    )
    validate_full_roundtrip(
        f32.tolist(),
        list[np.float32],
        (f32, Vector[np.float32, Literal[3]]),
        ([1.0, 2.0, 3.0], list[cocoindex.Float32]),
        ([1.0, 2.0, 3.0], list[float]),
    )

    # float64
    f64 = np.array([1.0, 2.0, 3.0], dtype=np.float64)
    validate_full_roundtrip(
        f64,
        Vector[np.float64, Literal[3]],
        ([np.float64(1.0), np.float64(2.0), np.float64(3.0)], list[np.float64]),
        ([1.0, 2.0, 3.0], list[cocoindex.Float64]),
        ([1.0, 2.0, 3.0], list[float]),
    )

    # int64
    i64 = np.array([1, 2, 3], dtype=np.int64)
    validate_full_roundtrip(
        i64,
        Vector[np.int64, Literal[3]],
        ([np.int64(1), np.int64(2), np.int64(3)], list[np.int64]),
        ([1, 2, 3], list[int]),
    )

    # Dtypes the engine does not support.
    for dtype in (np.int32, np.uint8, np.uint16, np.uint32, np.uint64):
        unsupported = np.array([1, 2, 3], dtype=dtype)
        with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
            validate_full_roundtrip(unsupported, Vector[dtype, Literal[3]])
|
1032
|
+
|
1033
|
+
|
1034
|
+
def test_full_roundtrip_vector_of_vector() -> None:
    """A 2x3 float32 matrix round-trips as nested Vectors, nested lists and an NDArray."""
    matrix = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
    validate_full_roundtrip(
        matrix,
        Vector[Vector[np.float32, Literal[3]], Literal[2]],
        (matrix.tolist(), list[list[np.float32]]),
        (matrix.tolist(), list[list[cocoindex.Float32]]),
        (matrix.tolist(), list[Vector[cocoindex.Float32, Literal[3]]]),
        (matrix, np.typing.NDArray[np.float32]),
    )
|
1051
|
+
|
1052
|
+
|
1053
|
+
def test_full_roundtrip_vector_other_types() -> None:
    """Vectors of UUIDs, dates and bools round-trip, including empty vectors."""
    ids = [uuid.uuid4(), uuid.uuid4()]
    validate_full_roundtrip(ids, Vector[uuid.UUID], (ids, list[uuid.UUID]))

    days = [datetime.date(2023, 1, 1), datetime.date(2024, 10, 5)]
    validate_full_roundtrip(days, Vector[datetime.date], (days, list[datetime.date]))

    flags = [True, False, True, False]
    validate_full_roundtrip(flags, Vector[bool], (flags, list[bool]))

    # Empty vectors for each of the element types above.
    for element_type in (uuid.UUID, datetime.date, bool):
        validate_full_roundtrip([], Vector[element_type], ([], list[element_type]))
|
1069
|
+
|
1070
|
+
|
1071
|
+
def test_roundtrip_vector_no_dimension() -> None:
    """A Vector annotation without a dimension round-trips float64 data."""
    samples = np.array([1.0, 2.0, 3.0], dtype=np.float64)
    validate_full_roundtrip(
        samples,
        Vector[np.float64],
        (samples.tolist(), list[float]),
        (samples.copy(), np.typing.NDArray[np.float64]),
    )
|
1080
|
+
|
1081
|
+
|
1082
|
+
def test_roundtrip_string_vector() -> None:
    """A list of strings round-trips through a ``Vector[str]`` annotation."""
    words: Vector[str] = ["hello", "world"]
    validate_full_roundtrip(words, Vector[str])
|
1086
|
+
|
1087
|
+
|
1088
|
+
def test_roundtrip_empty_vector() -> None:
    """An empty float32 array round-trips through ``Vector[np.float32]``."""
    empty: Vector[np.float32] = np.array([], dtype=np.float32)
    validate_full_roundtrip(empty, Vector[np.float32])
|
1092
|
+
|
1093
|
+
|
1094
|
+
def test_roundtrip_dimension_mismatch() -> None:
    """Round-tripping a 2-element array as a 3-dimensional Vector raises ValueError."""
    too_short: Vector[np.float32, Literal[3]] = np.array([1.0, 2.0], dtype=np.float32)
    with pytest.raises(ValueError, match="Vector dimension mismatch"):
        validate_full_roundtrip(too_short, Vector[np.float32, Literal[3]])
|
1099
|
+
|
1100
|
+
|
1101
|
+
def test_full_roundtrip_scalar_numeric_types() -> None:
    """Supported NumPy scalars round-trip; the listed unsupported ones raise ValueError."""
    # (scalar value, source annotation, alternative decoded form)
    supported_cases = (
        (np.int64(42), np.int64, (42, int)),
        (np.float32(3.25), np.float32, (3.25, cocoindex.Float32)),
        (np.float64(3.25), np.float64, (3.25, cocoindex.Float64)),
    )
    for scalar, annotation, alternative in supported_cases:
        validate_full_roundtrip(scalar, annotation, alternative)

    for scalar_type in (np.int32, np.uint8, np.uint16, np.uint32, np.uint64):
        with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
            validate_full_roundtrip(scalar_type(1), scalar_type)
|
1112
|
+
|
1113
|
+
|
1114
|
+
def test_full_roundtrip_nullable_scalar() -> None:
    """Optional NumPy scalar annotations round-trip both values and None."""
    samples = (
        (np.int64, np.int64(42)),
        (np.float32, np.float32(3.14)),
        (np.float64, np.float64(2.718)),
    )
    # Non-null values first, then None for each optional annotation.
    for scalar_type, sample in samples:
        validate_full_roundtrip(sample, scalar_type | None)
    for scalar_type, _ in samples:
        validate_full_roundtrip(None, scalar_type | None)
|
1125
|
+
|
1126
|
+
|
1127
|
+
def test_full_roundtrip_scalar_in_struct() -> None:
    """A dataclass whose fields are NumPy scalars round-trips."""

    @dataclass
    class NumericStruct:
        int_field: np.int64
        float32_field: np.float32
        float64_field: np.float64

    validate_full_roundtrip(
        NumericStruct(np.int64(42), np.float32(3.14), np.float64(2.718)),
        NumericStruct,
    )
|
1142
|
+
|
1143
|
+
|
1144
|
+
def test_full_roundtrip_scalar_in_nested_struct() -> None:
    """NumPy scalar fields round-trip when one dataclass is nested in another."""

    @dataclass
    class InnerStruct:
        value: np.float64

    @dataclass
    class OuterStruct:
        inner: InnerStruct
        count: np.int64

    nested = OuterStruct(InnerStruct(np.float64(2.718)), np.int64(1))
    validate_full_roundtrip(nested, OuterStruct)
|
1161
|
+
|
1162
|
+
|
1163
|
+
def test_full_roundtrip_scalar_with_python_types() -> None:
    """A dataclass mixing NumPy, Python and annotated scalar fields round-trips."""

    @dataclass
    class MixedStruct:
        numpy_int: np.int64
        python_int: int
        numpy_float: np.float64
        python_float: float
        string: str
        annotated_int: Annotated[np.int64, TypeKind("Int64")]
        annotated_float: Float32

    mixed = MixedStruct(
        np.int64(42),
        43,
        np.float64(2.718),
        3.14,
        "hello, world",
        np.int64(42),
        2.0,
    )
    validate_full_roundtrip(mixed, MixedStruct)
|
1186
|
+
|
1187
|
+
|
1188
|
+
def test_roundtrip_simple_struct_to_dict_binding() -> None:
    """Test binding an all-``str`` struct to untyped and typed dict annotations.

    Because every field is ``str``, both ``Any`` and ``str`` are accepted as the
    dict value type. A mismatched value type (``dict[str, int]``) or a non-``str``
    key type (``dict[int, Any]``) raises ValueError.
    """

    @dataclass
    class SimpleStruct:
        first_name: str
        last_name: str

    instance = SimpleStruct("John", "Doe")
    expected_dict = {"first_name": "John", "last_name": "Doe"}

    validate_full_roundtrip(
        instance,
        SimpleStruct,
        (expected_dict, Any),
        (expected_dict, dict),
        (expected_dict, dict[Any, Any]),
        (expected_dict, dict[str, Any]),
        # For simple struct, all fields have the same type, so we can directly use the type as the dict value type.
        (expected_dict, dict[Any, str]),
        (expected_dict, dict[str, str]),
    )

    with pytest.raises(ValueError):
        validate_full_roundtrip(instance, SimpleStruct, (expected_dict, dict[str, int]))

    with pytest.raises(ValueError):
        validate_full_roundtrip(instance, SimpleStruct, (expected_dict, dict[int, Any]))
|
1217
|
+
|
1218
|
+
|
1219
|
+
def test_roundtrip_struct_to_dict_binding() -> None:
    """A mixed-type struct binds to generic dicts; typed mismatches raise ValueError."""

    @dataclass
    class SimpleStruct:
        name: str
        value: int
        price: float

    instance = SimpleStruct("test", 42, 3.14)
    expected_dict = {"name": "test", "value": 42, "price": 3.14}

    accepted = (Any, dict, dict[Any, Any], dict[str, Any])
    validate_full_roundtrip(
        instance, SimpleStruct, *((expected_dict, target) for target in accepted)
    )

    # A single value type cannot cover mixed fields, and keys must be str.
    for rejected in (dict[str, str], dict[int, Any]):
        with pytest.raises(ValueError):
            validate_full_roundtrip(instance, SimpleStruct, (expected_dict, rejected))
|
1246
|
+
|
1247
|
+
|
1248
|
+
def test_roundtrip_struct_to_dict_explicit() -> None:
    """A struct binds to explicit ``dict`` and ``dict[str, Any]`` annotations."""

    @dataclass
    class Product:
        id: str
        name: str
        price: float
        active: bool

    widget = Product("P1", "Widget", 29.99, True)
    as_dict = {"id": "P1", "name": "Widget", "price": 29.99, "active": True}
    validate_full_roundtrip(widget, Product, (as_dict, dict), (as_dict, dict[str, Any]))
|
1265
|
+
|
1266
|
+
|
1267
|
+
def test_roundtrip_struct_to_dict_with_none_annotation() -> None:
    """Test struct -> dict binding when the target annotation is missing.

    The target annotation is ``inspect.Parameter.empty`` (the sentinel for an
    unannotated parameter), not ``None``; it should be treated like ``Any``.
    """

    @dataclass
    class Config:
        host: str
        port: int
        debug: bool

    instance = Config("localhost", 8080, True)
    expected_dict = {"host": "localhost", "port": 8080, "debug": True}

    # Test empty annotation (should be treated as Any)
    validate_full_roundtrip(instance, Config, (expected_dict, inspect.Parameter.empty))
|
1281
|
+
|
1282
|
+
|
1283
|
+
def test_roundtrip_struct_to_dict_nested() -> None:
    """A nested struct field is bound to a nested dict."""

    @dataclass
    class Address:
        street: str
        city: str

    @dataclass
    class Person:
        name: str
        age: int
        address: Address

    person = Person("John", 30, Address("123 Main St", "Anytown"))

    address_dict = {"street": "123 Main St", "city": "Anytown"}
    person_dict = {"name": "John", "age": 30, "address": address_dict}

    validate_full_roundtrip(person, Person, (person_dict, dict[str, Any]))
|
1307
|
+
|
1308
|
+
|
1309
|
+
def test_roundtrip_struct_to_dict_with_list() -> None:
    """A struct with a list field binds to a plain ``dict``."""

    @dataclass
    class Team:
        name: str
        members: list[str]
        active: bool

    team = Team("Dev Team", ["Alice", "Bob", "Charlie"], True)
    team_dict = {
        "name": "Dev Team",
        "members": ["Alice", "Bob", "Charlie"],
        "active": True,
    }
    validate_full_roundtrip(team, Team, (team_dict, dict))
|
1326
|
+
|
1327
|
+
|
1328
|
+
def test_roundtrip_namedtuple_to_dict_binding() -> None:
    """A NamedTuple binds to ``dict`` and ``Any`` annotations."""

    class Point(NamedTuple):
        x: float
        y: float
        z: float

    point = Point(1.0, 2.0, 3.0)
    point_dict = dict(x=1.0, y=2.0, z=3.0)
    validate_full_roundtrip(point, Point, (point_dict, dict), (point_dict, Any))
|
1342
|
+
|
1343
|
+
|
1344
|
+
def test_roundtrip_ltable_to_list_dict_binding() -> None:
    """An LTable (``list[User]``) binds to ``Any``, ``list[Any]`` and ``list[dict]``."""

    @dataclass
    class User:
        id: str
        name: str
        age: int

    rows = [("u1", "Alice", 25), ("u2", "Bob", 30), ("u3", "Charlie", 35)]
    users = [User(*row) for row in rows]
    expected_list_dict = [dict(zip(("id", "name", "age"), row)) for row in rows]

    validate_full_roundtrip(
        users,
        list[User],
        (expected_list_dict, Any),
        (expected_list_dict, list[Any]),
        (expected_list_dict, list[dict[str, Any]]),
    )
|
1368
|
+
|
1369
|
+
|
1370
|
+
def test_roundtrip_ktable_to_dict_dict_binding() -> None:
    """A KTable (``dict[str, Product]``) binds to a range of generic dict annotations."""

    @dataclass
    class Product:
        name: str
        price: float
        active: bool

    catalog = {
        "p1": ("Widget", 29.99, True),
        "p2": ("Gadget", 49.99, False),
        "p3": ("Tool", 19.99, True),
    }
    products = {key: Product(*fields) for key, fields in catalog.items()}
    expected_dict_dict = {
        key: dict(zip(("name", "price", "active"), fields))
        for key, fields in catalog.items()
    }

    targets = (
        Any,
        dict,
        dict[Any, Any],
        dict[str, Any],
        dict[Any, dict[Any, Any]],
        dict[str, dict[Any, Any]],
        dict[str, dict[str, Any]],
    )
    validate_full_roundtrip(
        products,
        dict[str, Product],
        *((expected_dict_dict, target) for target in targets),
    )
|
1402
|
+
|
1403
|
+
|
1404
|
+
def test_roundtrip_ktable_with_complex_key() -> None:
    """A KTable keyed by a struct binds with tuple keys or struct keys as requested."""

    @dataclass(frozen=True)
    class OrderKey:
        shop_id: str
        version: int

    @dataclass
    class Order:
        customer: str
        total: float

    orders = {
        OrderKey("shop1", 1): Order("Alice", 100.0),
        OrderKey("shop2", 2): Order("Bob", 200.0),
    }

    # Untyped keys decode as plain tuples; untyped values decode as dicts.
    tuple_keys_dict_values = {
        ("shop1", 1): {"customer": "Alice", "total": 100.0},
        ("shop2", 2): {"customer": "Bob", "total": 200.0},
    }
    tuple_keys_order_values = {
        ("shop1", 1): Order("Alice", 100.0),
        ("shop2", 2): Order("Bob", 200.0),
    }
    struct_keys_dict_values = {
        OrderKey("shop1", 1): {"customer": "Alice", "total": 100.0},
        OrderKey("shop2", 2): {"customer": "Bob", "total": 200.0},
    }

    validate_full_roundtrip(
        orders,
        dict[OrderKey, Order],
        (tuple_keys_dict_values, Any),
        (tuple_keys_dict_values, dict),
        (tuple_keys_dict_values, dict[Any, Any]),
        (tuple_keys_dict_values, dict[Any, dict[str, Any]]),
        (tuple_keys_order_values, dict[Any, Order]),
        (struct_keys_dict_values, dict[OrderKey, Any]),
    )
|
1449
|
+
|
1450
|
+
|
1451
|
+
def test_roundtrip_ltable_with_nested_structs() -> None:
    """An LTable whose rows contain nested structs binds to a list of nested dicts."""

    @dataclass
    class Address:
        street: str
        city: str

    @dataclass
    class Person:
        name: str
        age: int
        address: Address

    records = [
        ("John", 30, "123 Main St", "Anytown"),
        ("Jane", 25, "456 Oak Ave", "Somewhere"),
    ]
    people = [Person(name, age, Address(street, city)) for name, age, street, city in records]
    expected_list_dict = [
        {"name": name, "age": age, "address": {"street": street, "city": city}}
        for name, age, street, city in records
    ]

    validate_full_roundtrip(people, list[Person], (expected_list_dict, Any))
|
1484
|
+
|
1485
|
+
|
1486
|
+
def test_roundtrip_ktable_with_list_fields() -> None:
    """A KTable whose rows have list fields binds to a dict of dicts under ``Any``."""

    @dataclass
    class Team:
        name: str
        members: list[str]
        active: bool

    teams = {
        "team1": Team("Dev Team", ["Alice", "Bob"], True),
        "team2": Team("QA Team", ["Charlie", "David"], False),
    }
    expected_dict_dict = {
        "team1": dict(name="Dev Team", members=["Alice", "Bob"], active=True),
        "team2": dict(name="QA Team", members=["Charlie", "David"], active=False),
    }

    validate_full_roundtrip(teams, dict[str, Team], (expected_dict_dict, Any))
|
1506
|
+
|
1507
|
+
|
1508
|
+
def test_auto_default_for_supported_and_unsupported_types() -> None:
    """Missing fields get automatic defaults for nullable/LTable/KTable types only.

    Decoding ``Base`` data into ``UnsupportedField`` fails because the extra
    plain ``int`` field ``b`` has no default.
    """

    @dataclass
    class Base:
        a: int

    @dataclass
    class NullableField:
        a: int
        b: int | None

    @dataclass
    class LTableField:
        a: int
        b: list[Base]

    @dataclass
    class KTableField:
        a: int
        b: dict[str, Base]

    @dataclass
    class UnsupportedField:
        a: int
        b: int

    for instance in (NullableField(1, None), LTableField(1, []), KTableField(1, {})):
        validate_full_roundtrip(instance, type(instance))

    with pytest.raises(
        ValueError,
        match=r"Field 'b' \(type <class 'int'>\) without default value is missing in input: ",
    ):
        build_engine_value_decoder(Base, UnsupportedField)
|