cocoindex 0.1.73__cp312-cp312-macosx_11_0_arm64.whl → 0.1.75__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cocoindex/__init__.py +5 -3
- cocoindex/_engine.cpython-312-darwin.so +0 -0
- cocoindex/convert.py +188 -167
- cocoindex/flow.py +29 -12
- cocoindex/op.py +264 -6
- cocoindex/tests/test_convert.py +184 -151
- cocoindex/tests/test_transform_flow.py +103 -0
- cocoindex/tests/test_typing.py +5 -12
- cocoindex/typing.py +8 -8
- {cocoindex-0.1.73.dist-info → cocoindex-0.1.75.dist-info}/METADATA +5 -3
- {cocoindex-0.1.73.dist-info → cocoindex-0.1.75.dist-info}/RECORD +14 -13
- {cocoindex-0.1.73.dist-info → cocoindex-0.1.75.dist-info}/WHEEL +1 -1
- {cocoindex-0.1.73.dist-info → cocoindex-0.1.75.dist-info}/entry_points.txt +0 -0
- {cocoindex-0.1.73.dist-info → cocoindex-0.1.75.dist-info}/licenses/LICENSE +0 -0
cocoindex/tests/test_convert.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
import datetime
|
2
|
+
import inspect
|
2
3
|
import uuid
|
3
|
-
from dataclasses import dataclass, make_dataclass
|
4
|
+
from dataclasses import dataclass, make_dataclass, field
|
4
5
|
from typing import Annotated, Any, Callable, Literal, NamedTuple
|
5
6
|
|
6
7
|
import numpy as np
|
@@ -19,6 +20,7 @@ from cocoindex.typing import (
|
|
19
20
|
TypeKind,
|
20
21
|
Vector,
|
21
22
|
encode_enriched_type,
|
23
|
+
analyze_type_info,
|
22
24
|
)
|
23
25
|
|
24
26
|
|
@@ -75,7 +77,9 @@ def build_engine_value_decoder(
|
|
75
77
|
If python_type is not specified, uses engine_type_in_py as the target.
|
76
78
|
"""
|
77
79
|
engine_type = encode_enriched_type(engine_type_in_py)["type"]
|
78
|
-
return make_engine_value_decoder(
|
80
|
+
return make_engine_value_decoder(
|
81
|
+
[], engine_type, analyze_type_info(python_type or engine_type_in_py)
|
82
|
+
)
|
79
83
|
|
80
84
|
|
81
85
|
def validate_full_roundtrip_to(
|
@@ -103,7 +107,9 @@ def validate_full_roundtrip_to(
|
|
103
107
|
)
|
104
108
|
|
105
109
|
for other_value, other_type in decoded_values:
|
106
|
-
decoder = make_engine_value_decoder(
|
110
|
+
decoder = make_engine_value_decoder(
|
111
|
+
[], encoded_output_type, analyze_type_info(other_type)
|
112
|
+
)
|
107
113
|
other_decoded_value = decoder(value_from_engine)
|
108
114
|
assert eq(other_decoded_value, other_value), (
|
109
115
|
f"Expected {other_value} but got {other_decoded_value} for {other_type}"
|
@@ -231,19 +237,24 @@ def test_encode_engine_value_none() -> None:
|
|
231
237
|
|
232
238
|
|
233
239
|
def test_roundtrip_basic_types() -> None:
|
234
|
-
validate_full_roundtrip(
|
240
|
+
validate_full_roundtrip(
|
241
|
+
b"hello world",
|
242
|
+
bytes,
|
243
|
+
(b"hello world", inspect.Parameter.empty),
|
244
|
+
(b"hello world", Any),
|
245
|
+
)
|
235
246
|
validate_full_roundtrip(b"\x00\x01\x02\xff\xfe", bytes)
|
236
|
-
validate_full_roundtrip("hello", str, ("hello",
|
237
|
-
validate_full_roundtrip(True, bool, (True,
|
238
|
-
validate_full_roundtrip(False, bool, (False,
|
247
|
+
validate_full_roundtrip("hello", str, ("hello", Any))
|
248
|
+
validate_full_roundtrip(True, bool, (True, Any))
|
249
|
+
validate_full_roundtrip(False, bool, (False, Any))
|
239
250
|
validate_full_roundtrip(
|
240
|
-
42, cocoindex.Int64, (42, int), (np.int64(42), np.int64), (42,
|
251
|
+
42, cocoindex.Int64, (42, int), (np.int64(42), np.int64), (42, Any)
|
241
252
|
)
|
242
253
|
validate_full_roundtrip(42, int, (42, cocoindex.Int64))
|
243
254
|
validate_full_roundtrip(np.int64(42), np.int64, (42, cocoindex.Int64))
|
244
255
|
|
245
256
|
validate_full_roundtrip(
|
246
|
-
3.25, Float64, (3.25, float), (np.float64(3.25), np.float64), (3.25,
|
257
|
+
3.25, Float64, (3.25, float), (np.float64(3.25), np.float64), (3.25, Any)
|
247
258
|
)
|
248
259
|
validate_full_roundtrip(3.25, float, (3.25, Float64))
|
249
260
|
validate_full_roundtrip(np.float64(3.25), np.float64, (3.25, Float64))
|
@@ -255,35 +266,35 @@ def test_roundtrip_basic_types() -> None:
|
|
255
266
|
(np.float32(3.25), np.float32),
|
256
267
|
(np.float64(3.25), np.float64),
|
257
268
|
(3.25, Float64),
|
258
|
-
(3.25,
|
269
|
+
(3.25, Any),
|
259
270
|
)
|
260
271
|
validate_full_roundtrip(np.float32(3.25), np.float32, (3.25, Float32))
|
261
272
|
|
262
273
|
|
263
274
|
def test_roundtrip_uuid() -> None:
|
264
275
|
uuid_value = uuid.uuid4()
|
265
|
-
validate_full_roundtrip(uuid_value, uuid.UUID, (uuid_value,
|
276
|
+
validate_full_roundtrip(uuid_value, uuid.UUID, (uuid_value, Any))
|
266
277
|
|
267
278
|
|
268
279
|
def test_roundtrip_range() -> None:
|
269
280
|
r1 = (0, 100)
|
270
|
-
validate_full_roundtrip(r1, cocoindex.Range, (r1,
|
281
|
+
validate_full_roundtrip(r1, cocoindex.Range, (r1, Any))
|
271
282
|
r2 = (50, 50)
|
272
|
-
validate_full_roundtrip(r2, cocoindex.Range, (r2,
|
283
|
+
validate_full_roundtrip(r2, cocoindex.Range, (r2, Any))
|
273
284
|
r3 = (0, 1_000_000_000)
|
274
|
-
validate_full_roundtrip(r3, cocoindex.Range, (r3,
|
285
|
+
validate_full_roundtrip(r3, cocoindex.Range, (r3, Any))
|
275
286
|
|
276
287
|
|
277
288
|
def test_roundtrip_time() -> None:
|
278
289
|
t1 = datetime.time(10, 30, 50, 123456)
|
279
|
-
validate_full_roundtrip(t1, datetime.time, (t1,
|
290
|
+
validate_full_roundtrip(t1, datetime.time, (t1, Any))
|
280
291
|
t2 = datetime.time(23, 59, 59)
|
281
|
-
validate_full_roundtrip(t2, datetime.time, (t2,
|
292
|
+
validate_full_roundtrip(t2, datetime.time, (t2, Any))
|
282
293
|
t3 = datetime.time(0, 0, 0)
|
283
|
-
validate_full_roundtrip(t3, datetime.time, (t3,
|
294
|
+
validate_full_roundtrip(t3, datetime.time, (t3, Any))
|
284
295
|
|
285
296
|
validate_full_roundtrip(
|
286
|
-
datetime.date(2025, 1, 1), datetime.date, (datetime.date(2025, 1, 1),
|
297
|
+
datetime.date(2025, 1, 1), datetime.date, (datetime.date(2025, 1, 1), Any)
|
287
298
|
)
|
288
299
|
|
289
300
|
validate_full_roundtrip(
|
@@ -328,11 +339,11 @@ def test_roundtrip_timedelta() -> None:
|
|
328
339
|
td1 = datetime.timedelta(
|
329
340
|
days=5, seconds=10, microseconds=123, milliseconds=456, minutes=30, hours=2
|
330
341
|
)
|
331
|
-
validate_full_roundtrip(td1, datetime.timedelta, (td1,
|
342
|
+
validate_full_roundtrip(td1, datetime.timedelta, (td1, Any))
|
332
343
|
td2 = datetime.timedelta(days=-5, hours=-2)
|
333
|
-
validate_full_roundtrip(td2, datetime.timedelta, (td2,
|
344
|
+
validate_full_roundtrip(td2, datetime.timedelta, (td2, Any))
|
334
345
|
td3 = datetime.timedelta(0)
|
335
|
-
validate_full_roundtrip(td3, datetime.timedelta, (td3,
|
346
|
+
validate_full_roundtrip(td3, datetime.timedelta, (td3, Any))
|
336
347
|
|
337
348
|
|
338
349
|
def test_roundtrip_json() -> None:
|
@@ -364,7 +375,9 @@ def test_decode_scalar_numpy_values() -> None:
|
|
364
375
|
({"kind": "Float64"}, np.float64, 2.718, np.float64(2.718)),
|
365
376
|
]
|
366
377
|
for src_type, dst_type, input_value, expected in test_cases:
|
367
|
-
decoder = make_engine_value_decoder(
|
378
|
+
decoder = make_engine_value_decoder(
|
379
|
+
["field"], src_type, analyze_type_info(dst_type)
|
380
|
+
)
|
368
381
|
result = decoder(input_value)
|
369
382
|
assert isinstance(result, dst_type)
|
370
383
|
assert result == expected
|
@@ -378,7 +391,9 @@ def test_non_ndarray_vector_decoding() -> None:
|
|
378
391
|
"dimension": None,
|
379
392
|
}
|
380
393
|
dst_type_float = list[np.float64]
|
381
|
-
decoder = make_engine_value_decoder(
|
394
|
+
decoder = make_engine_value_decoder(
|
395
|
+
["field"], src_type, analyze_type_info(dst_type_float)
|
396
|
+
)
|
382
397
|
input_numbers = [1.0, 2.0, 3.0]
|
383
398
|
result = decoder(input_numbers)
|
384
399
|
assert isinstance(result, list)
|
@@ -388,7 +403,9 @@ def test_non_ndarray_vector_decoding() -> None:
|
|
388
403
|
# Test list[Uuid]
|
389
404
|
src_type = {"kind": "Vector", "element_type": {"kind": "Uuid"}, "dimension": None}
|
390
405
|
dst_type_uuid = list[uuid.UUID]
|
391
|
-
decoder = make_engine_value_decoder(
|
406
|
+
decoder = make_engine_value_decoder(
|
407
|
+
["field"], src_type, analyze_type_info(dst_type_uuid)
|
408
|
+
)
|
392
409
|
uuid1 = uuid.uuid4()
|
393
410
|
uuid2 = uuid.uuid4()
|
394
411
|
input_uuids = [uuid1, uuid2]
|
@@ -398,124 +415,15 @@ def test_non_ndarray_vector_decoding() -> None:
|
|
398
415
|
assert result == [uuid1, uuid2]
|
399
416
|
|
400
417
|
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
# All fields match (NamedTuple)
|
411
|
-
(
|
412
|
-
OrderNamedTuple,
|
413
|
-
["O123", "mixed nuts", 25.0, "default_extra"],
|
414
|
-
OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra"),
|
415
|
-
),
|
416
|
-
# Extra field in engine value (should ignore extra)
|
417
|
-
(
|
418
|
-
Order,
|
419
|
-
["O123", "mixed nuts", 25.0, "default_extra", "unexpected"],
|
420
|
-
Order("O123", "mixed nuts", 25.0, "default_extra"),
|
421
|
-
),
|
422
|
-
(
|
423
|
-
OrderNamedTuple,
|
424
|
-
["O123", "mixed nuts", 25.0, "default_extra", "unexpected"],
|
425
|
-
OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra"),
|
426
|
-
),
|
427
|
-
# Fewer fields in engine value (should fill with default)
|
428
|
-
(
|
429
|
-
Order,
|
430
|
-
["O123", "mixed nuts", 0.0, "default_extra"],
|
431
|
-
Order("O123", "mixed nuts", 0.0, "default_extra"),
|
432
|
-
),
|
433
|
-
(
|
434
|
-
OrderNamedTuple,
|
435
|
-
["O123", "mixed nuts", 0.0, "default_extra"],
|
436
|
-
OrderNamedTuple("O123", "mixed nuts", 0.0, "default_extra"),
|
437
|
-
),
|
438
|
-
# More fields in engine value (should ignore extra)
|
439
|
-
(
|
440
|
-
Order,
|
441
|
-
["O123", "mixed nuts", 25.0, "unexpected"],
|
442
|
-
Order("O123", "mixed nuts", 25.0, "unexpected"),
|
443
|
-
),
|
444
|
-
(
|
445
|
-
OrderNamedTuple,
|
446
|
-
["O123", "mixed nuts", 25.0, "unexpected"],
|
447
|
-
OrderNamedTuple("O123", "mixed nuts", 25.0, "unexpected"),
|
448
|
-
),
|
449
|
-
# Truly extra field (should ignore the fifth field)
|
450
|
-
(
|
451
|
-
Order,
|
452
|
-
["O123", "mixed nuts", 25.0, "default_extra", "ignored"],
|
453
|
-
Order("O123", "mixed nuts", 25.0, "default_extra"),
|
454
|
-
),
|
455
|
-
(
|
456
|
-
OrderNamedTuple,
|
457
|
-
["O123", "mixed nuts", 25.0, "default_extra", "ignored"],
|
458
|
-
OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra"),
|
459
|
-
),
|
460
|
-
# Missing optional field in engine value (tags=None)
|
461
|
-
(
|
462
|
-
Customer,
|
463
|
-
["Alice", ["O1", "item1", 10.0, "default_extra"], None],
|
464
|
-
Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), None),
|
465
|
-
),
|
466
|
-
(
|
467
|
-
CustomerNamedTuple,
|
468
|
-
["Alice", ["O1", "item1", 10.0, "default_extra"], None],
|
469
|
-
CustomerNamedTuple(
|
470
|
-
"Alice", OrderNamedTuple("O1", "item1", 10.0, "default_extra"), None
|
471
|
-
),
|
472
|
-
),
|
473
|
-
# Extra field in engine value for Customer (should ignore)
|
474
|
-
(
|
475
|
-
Customer,
|
476
|
-
["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]], "extra"],
|
477
|
-
Customer(
|
478
|
-
"Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip")]
|
479
|
-
),
|
480
|
-
),
|
481
|
-
(
|
482
|
-
CustomerNamedTuple,
|
483
|
-
["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]], "extra"],
|
484
|
-
CustomerNamedTuple(
|
485
|
-
"Alice",
|
486
|
-
OrderNamedTuple("O1", "item1", 10.0, "default_extra"),
|
487
|
-
[Tag("vip")],
|
488
|
-
),
|
489
|
-
),
|
490
|
-
# Missing optional field with default
|
491
|
-
(
|
492
|
-
Order,
|
493
|
-
["O123", "mixed nuts", 25.0],
|
494
|
-
Order("O123", "mixed nuts", 25.0, "default_extra"),
|
495
|
-
),
|
496
|
-
(
|
497
|
-
OrderNamedTuple,
|
498
|
-
["O123", "mixed nuts", 25.0],
|
499
|
-
OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra"),
|
500
|
-
),
|
501
|
-
# Partial optional fields
|
502
|
-
(
|
503
|
-
Customer,
|
504
|
-
["Alice", ["O1", "item1", 10.0]],
|
505
|
-
Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), None),
|
506
|
-
),
|
507
|
-
(
|
508
|
-
CustomerNamedTuple,
|
509
|
-
["Alice", ["O1", "item1", 10.0]],
|
510
|
-
CustomerNamedTuple(
|
511
|
-
"Alice", OrderNamedTuple("O1", "item1", 10.0, "default_extra"), None
|
512
|
-
),
|
513
|
-
),
|
514
|
-
],
|
515
|
-
)
|
516
|
-
def test_struct_decoder_cases(data_type: Any, engine_val: Any, expected: Any) -> None:
|
517
|
-
decoder = build_engine_value_decoder(data_type)
|
518
|
-
assert decoder(engine_val) == expected
|
418
|
+
def test_roundtrip_struct() -> None:
|
419
|
+
validate_full_roundtrip(
|
420
|
+
Order("O123", "mixed nuts", 25.0, "default_extra"),
|
421
|
+
Order,
|
422
|
+
)
|
423
|
+
validate_full_roundtrip(
|
424
|
+
OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra"),
|
425
|
+
OrderNamedTuple,
|
426
|
+
)
|
519
427
|
|
520
428
|
|
521
429
|
def test_make_engine_value_decoder_list_of_struct() -> None:
|
@@ -974,7 +882,9 @@ def test_decode_nullable_ndarray_none_or_value_input() -> None:
|
|
974
882
|
"dimension": None,
|
975
883
|
}
|
976
884
|
dst_annotation = NDArrayFloat32Type | None
|
977
|
-
decoder = make_engine_value_decoder(
|
885
|
+
decoder = make_engine_value_decoder(
|
886
|
+
[], src_type_dict, analyze_type_info(dst_annotation)
|
887
|
+
)
|
978
888
|
|
979
889
|
none_engine_value = None
|
980
890
|
decoded_array = decoder(none_engine_value)
|
@@ -997,7 +907,9 @@ def test_decode_vector_string() -> None:
|
|
997
907
|
"element_type": {"kind": "Str"},
|
998
908
|
"dimension": None,
|
999
909
|
}
|
1000
|
-
decoder = make_engine_value_decoder(
|
910
|
+
decoder = make_engine_value_decoder(
|
911
|
+
[], src_type_dict, analyze_type_info(Vector[str])
|
912
|
+
)
|
1001
913
|
assert decoder(["hello", "world"]) == ["hello", "world"]
|
1002
914
|
|
1003
915
|
|
@@ -1008,7 +920,9 @@ def test_decode_error_non_nullable_or_non_list_vector() -> None:
|
|
1008
920
|
"element_type": {"kind": "Float32"},
|
1009
921
|
"dimension": None,
|
1010
922
|
}
|
1011
|
-
decoder = make_engine_value_decoder(
|
923
|
+
decoder = make_engine_value_decoder(
|
924
|
+
[], src_type_dict, analyze_type_info(NDArrayFloat32Type)
|
925
|
+
)
|
1012
926
|
with pytest.raises(ValueError, match="Received null for non-nullable vector"):
|
1013
927
|
decoder(None)
|
1014
928
|
with pytest.raises(TypeError, match="Expected NDArray or list for vector"):
|
@@ -1252,6 +1166,37 @@ def test_full_roundtrip_scalar_with_python_types() -> None:
|
|
1252
1166
|
validate_full_roundtrip(instance, MixedStruct)
|
1253
1167
|
|
1254
1168
|
|
1169
|
+
def test_roundtrip_simple_struct_to_dict_binding() -> None:
|
1170
|
+
"""Test struct -> dict binding with Any annotation."""
|
1171
|
+
|
1172
|
+
@dataclass
|
1173
|
+
class SimpleStruct:
|
1174
|
+
first_name: str
|
1175
|
+
last_name: str
|
1176
|
+
|
1177
|
+
instance = SimpleStruct("John", "Doe")
|
1178
|
+
expected_dict = {"first_name": "John", "last_name": "Doe"}
|
1179
|
+
|
1180
|
+
# Test Any annotation
|
1181
|
+
validate_full_roundtrip(
|
1182
|
+
instance,
|
1183
|
+
SimpleStruct,
|
1184
|
+
(expected_dict, Any),
|
1185
|
+
(expected_dict, dict),
|
1186
|
+
(expected_dict, dict[Any, Any]),
|
1187
|
+
(expected_dict, dict[str, Any]),
|
1188
|
+
# For simple struct, all fields have the same type, so we can directly use the type as the dict value type.
|
1189
|
+
(expected_dict, dict[Any, str]),
|
1190
|
+
(expected_dict, dict[str, str]),
|
1191
|
+
)
|
1192
|
+
|
1193
|
+
with pytest.raises(ValueError):
|
1194
|
+
validate_full_roundtrip(instance, SimpleStruct, (expected_dict, dict[str, int]))
|
1195
|
+
|
1196
|
+
with pytest.raises(ValueError):
|
1197
|
+
validate_full_roundtrip(instance, SimpleStruct, (expected_dict, dict[int, Any]))
|
1198
|
+
|
1199
|
+
|
1255
1200
|
def test_roundtrip_struct_to_dict_binding() -> None:
|
1256
1201
|
"""Test struct -> dict binding with Any annotation."""
|
1257
1202
|
|
@@ -1265,7 +1210,20 @@ def test_roundtrip_struct_to_dict_binding() -> None:
|
|
1265
1210
|
expected_dict = {"name": "test", "value": 42, "price": 3.14}
|
1266
1211
|
|
1267
1212
|
# Test Any annotation
|
1268
|
-
validate_full_roundtrip(
|
1213
|
+
validate_full_roundtrip(
|
1214
|
+
instance,
|
1215
|
+
SimpleStruct,
|
1216
|
+
(expected_dict, Any),
|
1217
|
+
(expected_dict, dict),
|
1218
|
+
(expected_dict, dict[Any, Any]),
|
1219
|
+
(expected_dict, dict[str, Any]),
|
1220
|
+
)
|
1221
|
+
|
1222
|
+
with pytest.raises(ValueError):
|
1223
|
+
validate_full_roundtrip(instance, SimpleStruct, (expected_dict, dict[str, str]))
|
1224
|
+
|
1225
|
+
with pytest.raises(ValueError):
|
1226
|
+
validate_full_roundtrip(instance, SimpleStruct, (expected_dict, dict[int, Any]))
|
1269
1227
|
|
1270
1228
|
|
1271
1229
|
def test_roundtrip_struct_to_dict_explicit() -> None:
|
@@ -1299,8 +1257,8 @@ def test_roundtrip_struct_to_dict_with_none_annotation() -> None:
|
|
1299
1257
|
instance = Config("localhost", 8080, True)
|
1300
1258
|
expected_dict = {"host": "localhost", "port": 8080, "debug": True}
|
1301
1259
|
|
1302
|
-
# Test
|
1303
|
-
validate_full_roundtrip(instance, Config, (expected_dict,
|
1260
|
+
# Test empty annotation (should be treated as Any)
|
1261
|
+
validate_full_roundtrip(instance, Config, (expected_dict, inspect.Parameter.empty))
|
1304
1262
|
|
1305
1263
|
|
1306
1264
|
def test_roundtrip_struct_to_dict_nested() -> None:
|
@@ -1381,7 +1339,13 @@ def test_roundtrip_ltable_to_list_dict_binding() -> None:
|
|
1381
1339
|
]
|
1382
1340
|
|
1383
1341
|
# Test Any annotation
|
1384
|
-
validate_full_roundtrip(
|
1342
|
+
validate_full_roundtrip(
|
1343
|
+
users,
|
1344
|
+
list[User],
|
1345
|
+
(expected_list_dict, Any),
|
1346
|
+
(expected_list_dict, list[Any]),
|
1347
|
+
(expected_list_dict, list[dict[str, Any]]),
|
1348
|
+
)
|
1385
1349
|
|
1386
1350
|
|
1387
1351
|
def test_roundtrip_ktable_to_dict_dict_binding() -> None:
|
@@ -1405,7 +1369,17 @@ def test_roundtrip_ktable_to_dict_dict_binding() -> None:
|
|
1405
1369
|
}
|
1406
1370
|
|
1407
1371
|
# Test Any annotation
|
1408
|
-
validate_full_roundtrip(
|
1372
|
+
validate_full_roundtrip(
|
1373
|
+
products,
|
1374
|
+
dict[str, Product],
|
1375
|
+
(expected_dict_dict, Any),
|
1376
|
+
(expected_dict_dict, dict),
|
1377
|
+
(expected_dict_dict, dict[Any, Any]),
|
1378
|
+
(expected_dict_dict, dict[str, Any]),
|
1379
|
+
(expected_dict_dict, dict[Any, dict[Any, Any]]),
|
1380
|
+
(expected_dict_dict, dict[str, dict[Any, Any]]),
|
1381
|
+
(expected_dict_dict, dict[str, dict[str, Any]]),
|
1382
|
+
)
|
1409
1383
|
|
1410
1384
|
|
1411
1385
|
def test_roundtrip_ktable_with_complex_key() -> None:
|
@@ -1431,7 +1405,28 @@ def test_roundtrip_ktable_with_complex_key() -> None:
|
|
1431
1405
|
}
|
1432
1406
|
|
1433
1407
|
# Test Any annotation
|
1434
|
-
validate_full_roundtrip(
|
1408
|
+
validate_full_roundtrip(
|
1409
|
+
orders,
|
1410
|
+
dict[OrderKey, Order],
|
1411
|
+
(expected_dict_dict, Any),
|
1412
|
+
(expected_dict_dict, dict),
|
1413
|
+
(expected_dict_dict, dict[Any, Any]),
|
1414
|
+
(expected_dict_dict, dict[Any, dict[str, Any]]),
|
1415
|
+
(
|
1416
|
+
{
|
1417
|
+
("shop1", 1): Order("Alice", 100.0),
|
1418
|
+
("shop2", 2): Order("Bob", 200.0),
|
1419
|
+
},
|
1420
|
+
dict[Any, Order],
|
1421
|
+
),
|
1422
|
+
(
|
1423
|
+
{
|
1424
|
+
OrderKey("shop1", 1): {"customer": "Alice", "total": 100.0},
|
1425
|
+
OrderKey("shop2", 2): {"customer": "Bob", "total": 200.0},
|
1426
|
+
},
|
1427
|
+
dict[OrderKey, Any],
|
1428
|
+
),
|
1429
|
+
)
|
1435
1430
|
|
1436
1431
|
|
1437
1432
|
def test_roundtrip_ltable_with_nested_structs() -> None:
|
@@ -1489,3 +1484,41 @@ def test_roundtrip_ktable_with_list_fields() -> None:
|
|
1489
1484
|
|
1490
1485
|
# Test Any annotation
|
1491
1486
|
validate_full_roundtrip(teams, dict[str, Team], (expected_dict_dict, Any))
|
1487
|
+
|
1488
|
+
|
1489
|
+
def test_auto_default_for_supported_and_unsupported_types() -> None:
|
1490
|
+
@dataclass
|
1491
|
+
class Base:
|
1492
|
+
a: int
|
1493
|
+
|
1494
|
+
@dataclass
|
1495
|
+
class NullableField:
|
1496
|
+
a: int
|
1497
|
+
b: int | None
|
1498
|
+
|
1499
|
+
@dataclass
|
1500
|
+
class LTableField:
|
1501
|
+
a: int
|
1502
|
+
b: list[Base]
|
1503
|
+
|
1504
|
+
@dataclass
|
1505
|
+
class KTableField:
|
1506
|
+
a: int
|
1507
|
+
b: dict[str, Base]
|
1508
|
+
|
1509
|
+
@dataclass
|
1510
|
+
class UnsupportedField:
|
1511
|
+
a: int
|
1512
|
+
b: int
|
1513
|
+
|
1514
|
+
validate_full_roundtrip(NullableField(1, None), NullableField)
|
1515
|
+
|
1516
|
+
validate_full_roundtrip(LTableField(1, []), LTableField)
|
1517
|
+
|
1518
|
+
# validate_full_roundtrip(KTableField(1, {}), KTableField)
|
1519
|
+
|
1520
|
+
with pytest.raises(
|
1521
|
+
ValueError,
|
1522
|
+
match=r"Field 'b' \(type <class 'int'>\) without default value is missing in input: ",
|
1523
|
+
):
|
1524
|
+
build_engine_value_decoder(Base, UnsupportedField)
|
@@ -0,0 +1,103 @@
|
|
1
|
+
import typing
|
2
|
+
from dataclasses import dataclass
|
3
|
+
from typing import Any
|
4
|
+
|
5
|
+
import pytest
|
6
|
+
|
7
|
+
import cocoindex
|
8
|
+
|
9
|
+
|
10
|
+
@dataclass
|
11
|
+
class Child:
|
12
|
+
value: int
|
13
|
+
|
14
|
+
|
15
|
+
@dataclass
|
16
|
+
class Parent:
|
17
|
+
children: list[Child]
|
18
|
+
|
19
|
+
|
20
|
+
# Fixture to initialize CocoIndex library
|
21
|
+
@pytest.fixture(scope="session", autouse=True)
|
22
|
+
def init_cocoindex() -> typing.Generator[None, None, None]:
|
23
|
+
cocoindex.init()
|
24
|
+
yield
|
25
|
+
|
26
|
+
|
27
|
+
@cocoindex.op.function()
|
28
|
+
def add_suffix(text: str) -> str:
|
29
|
+
"""Append ' world' to the input text."""
|
30
|
+
return f"{text} world"
|
31
|
+
|
32
|
+
|
33
|
+
@cocoindex.transform_flow()
|
34
|
+
def simple_transform(text: cocoindex.DataSlice[str]) -> cocoindex.DataSlice[str]:
|
35
|
+
"""Transform flow that applies add_suffix to input text."""
|
36
|
+
return text.transform(add_suffix)
|
37
|
+
|
38
|
+
|
39
|
+
@cocoindex.op.function()
|
40
|
+
def extract_value(value: int) -> int:
|
41
|
+
"""Extracts the value."""
|
42
|
+
return value
|
43
|
+
|
44
|
+
|
45
|
+
@cocoindex.transform_flow()
|
46
|
+
def for_each_transform(
|
47
|
+
data: cocoindex.DataSlice[Parent],
|
48
|
+
) -> cocoindex.DataSlice[Any]:
|
49
|
+
"""Transform flow that processes child rows to extract values."""
|
50
|
+
with data["children"].row() as child:
|
51
|
+
child["new_field"] = child["value"].transform(extract_value)
|
52
|
+
return data
|
53
|
+
|
54
|
+
|
55
|
+
def test_simple_transform_flow() -> None:
|
56
|
+
"""Test the simple transform flow."""
|
57
|
+
input_text = "hello"
|
58
|
+
result = simple_transform.eval(input_text)
|
59
|
+
assert result == "hello world", f"Expected 'hello world', got {result}"
|
60
|
+
|
61
|
+
result = simple_transform.eval("")
|
62
|
+
assert result == " world", f"Expected ' world', got {result}"
|
63
|
+
|
64
|
+
|
65
|
+
@pytest.mark.asyncio
|
66
|
+
async def test_simple_transform_flow_async() -> None:
|
67
|
+
"""Test the simple transform flow asynchronously."""
|
68
|
+
input_text = "async"
|
69
|
+
result = await simple_transform.eval_async(input_text)
|
70
|
+
assert result == "async world", f"Expected 'async world', got {result}"
|
71
|
+
|
72
|
+
|
73
|
+
def test_for_each_transform_flow() -> None:
|
74
|
+
"""Test the complex transform flow with child rows."""
|
75
|
+
input_data = Parent(children=[Child(1), Child(2), Child(3)])
|
76
|
+
result = for_each_transform.eval(input_data)
|
77
|
+
expected = {
|
78
|
+
"children": [
|
79
|
+
{"value": 1, "new_field": 1},
|
80
|
+
{"value": 2, "new_field": 2},
|
81
|
+
{"value": 3, "new_field": 3},
|
82
|
+
]
|
83
|
+
}
|
84
|
+
assert result == expected, f"Expected {expected}, got {result}"
|
85
|
+
|
86
|
+
input_data = Parent(children=[])
|
87
|
+
result = for_each_transform.eval(input_data)
|
88
|
+
assert result == {"children": []}, f"Expected {{'children': []}}, got {result}"
|
89
|
+
|
90
|
+
|
91
|
+
@pytest.mark.asyncio
|
92
|
+
async def test_for_each_transform_flow_async() -> None:
|
93
|
+
"""Test the complex transform flow asynchronously."""
|
94
|
+
input_data = Parent(children=[Child(4), Child(5)])
|
95
|
+
result = await for_each_transform.eval_async(input_data)
|
96
|
+
expected = {
|
97
|
+
"children": [
|
98
|
+
{"value": 4, "new_field": 4},
|
99
|
+
{"value": 5, "new_field": 5},
|
100
|
+
]
|
101
|
+
}
|
102
|
+
|
103
|
+
assert result == expected, f"Expected {expected}, got {result}"
|
cocoindex/tests/test_typing.py
CHANGED
@@ -13,6 +13,7 @@ from cocoindex.typing import (
|
|
13
13
|
AnalyzedDictType,
|
14
14
|
AnalyzedListType,
|
15
15
|
AnalyzedStructType,
|
16
|
+
AnalyzedUnknownType,
|
16
17
|
AnalyzedTypeInfo,
|
17
18
|
TypeAttr,
|
18
19
|
TypeKind,
|
@@ -422,15 +423,7 @@ def test_annotated_list_with_type_kind() -> None:
|
|
422
423
|
assert result.variant.kind == "Struct"
|
423
424
|
|
424
425
|
|
425
|
-
def
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
):
|
430
|
-
analyze_type_info(set)
|
431
|
-
|
432
|
-
with pytest.raises(
|
433
|
-
ValueError,
|
434
|
-
match="Unsupported as a specific type annotation for CocoIndex data type.*: <class 'numpy.complex64'>",
|
435
|
-
):
|
436
|
-
Vector[np.complex64]
|
426
|
+
def test_unknown_type() -> None:
|
427
|
+
typ = set
|
428
|
+
result = analyze_type_info(typ)
|
429
|
+
assert isinstance(result.variant, AnalyzedUnknownType)
|
cocoindex/typing.py
CHANGED
@@ -262,7 +262,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
|
|
262
262
|
|
263
263
|
if kind is not None:
|
264
264
|
variant = AnalyzedBasicType(kind=kind)
|
265
|
-
elif base_type is
|
265
|
+
elif base_type is Any or base_type is inspect.Parameter.empty:
|
266
266
|
variant = AnalyzedAnyType()
|
267
267
|
elif is_struct_type(base_type):
|
268
268
|
variant = AnalyzedStructType(struct_type=t)
|
@@ -270,15 +270,15 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
|
|
270
270
|
kind = DtypeRegistry.validate_dtype_and_get_kind(t)
|
271
271
|
variant = AnalyzedBasicType(kind=kind)
|
272
272
|
elif base_type is collections.abc.Sequence or base_type is list:
|
273
|
-
elem_type = type_args[0] if len(type_args) > 0 else
|
273
|
+
elem_type = type_args[0] if len(type_args) > 0 else Any
|
274
274
|
variant = AnalyzedListType(elem_type=elem_type, vector_info=vector_info)
|
275
275
|
elif base_type is np.ndarray:
|
276
276
|
np_number_type = t
|
277
277
|
elem_type = extract_ndarray_elem_dtype(np_number_type)
|
278
278
|
variant = AnalyzedListType(elem_type=elem_type, vector_info=vector_info)
|
279
279
|
elif base_type is collections.abc.Mapping or base_type is dict or t is dict:
|
280
|
-
key_type = type_args[0] if len(type_args) > 0 else
|
281
|
-
elem_type = type_args[1] if len(type_args) > 1 else
|
280
|
+
key_type = type_args[0] if len(type_args) > 0 else Any
|
281
|
+
elem_type = type_args[1] if len(type_args) > 1 else Any
|
282
282
|
variant = AnalyzedDictType(key_type=key_type, value_type=elem_type)
|
283
283
|
elif base_type in (types.UnionType, typing.Union):
|
284
284
|
non_none_types = [arg for arg in type_args if arg not in (None, types.NoneType)]
|
@@ -313,11 +313,11 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
|
|
313
313
|
kind = "OffsetDateTime"
|
314
314
|
elif t is datetime.timedelta:
|
315
315
|
kind = "TimeDelta"
|
316
|
+
|
317
|
+
if kind is None:
|
318
|
+
variant = AnalyzedUnknownType()
|
316
319
|
else:
|
317
|
-
|
318
|
-
f"Unsupported as a specific type annotation for CocoIndex data type (https://cocoindex.io/docs/core/data_types): {t}"
|
319
|
-
)
|
320
|
-
variant = AnalyzedBasicType(kind=kind)
|
320
|
+
variant = AnalyzedBasicType(kind=kind)
|
321
321
|
|
322
322
|
return AnalyzedTypeInfo(
|
323
323
|
core_type=core_type,
|