cocoindex 0.1.73__cp313-cp313t-manylinux_2_28_aarch64.whl → 0.1.75__cp313-cp313t-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,7 @@
1
1
  import datetime
2
+ import inspect
2
3
  import uuid
3
- from dataclasses import dataclass, make_dataclass
4
+ from dataclasses import dataclass, make_dataclass, field
4
5
  from typing import Annotated, Any, Callable, Literal, NamedTuple
5
6
 
6
7
  import numpy as np
@@ -19,6 +20,7 @@ from cocoindex.typing import (
19
20
  TypeKind,
20
21
  Vector,
21
22
  encode_enriched_type,
23
+ analyze_type_info,
22
24
  )
23
25
 
24
26
 
@@ -75,7 +77,9 @@ def build_engine_value_decoder(
75
77
  If python_type is not specified, uses engine_type_in_py as the target.
76
78
  """
77
79
  engine_type = encode_enriched_type(engine_type_in_py)["type"]
78
- return make_engine_value_decoder([], engine_type, python_type or engine_type_in_py)
80
+ return make_engine_value_decoder(
81
+ [], engine_type, analyze_type_info(python_type or engine_type_in_py)
82
+ )
79
83
 
80
84
 
81
85
  def validate_full_roundtrip_to(
@@ -103,7 +107,9 @@ def validate_full_roundtrip_to(
103
107
  )
104
108
 
105
109
  for other_value, other_type in decoded_values:
106
- decoder = make_engine_value_decoder([], encoded_output_type, other_type)
110
+ decoder = make_engine_value_decoder(
111
+ [], encoded_output_type, analyze_type_info(other_type)
112
+ )
107
113
  other_decoded_value = decoder(value_from_engine)
108
114
  assert eq(other_decoded_value, other_value), (
109
115
  f"Expected {other_value} but got {other_decoded_value} for {other_type}"
@@ -231,19 +237,24 @@ def test_encode_engine_value_none() -> None:
231
237
 
232
238
 
233
239
  def test_roundtrip_basic_types() -> None:
234
- validate_full_roundtrip(b"hello world", bytes, (b"hello world", None))
240
+ validate_full_roundtrip(
241
+ b"hello world",
242
+ bytes,
243
+ (b"hello world", inspect.Parameter.empty),
244
+ (b"hello world", Any),
245
+ )
235
246
  validate_full_roundtrip(b"\x00\x01\x02\xff\xfe", bytes)
236
- validate_full_roundtrip("hello", str, ("hello", None))
237
- validate_full_roundtrip(True, bool, (True, None))
238
- validate_full_roundtrip(False, bool, (False, None))
247
+ validate_full_roundtrip("hello", str, ("hello", Any))
248
+ validate_full_roundtrip(True, bool, (True, Any))
249
+ validate_full_roundtrip(False, bool, (False, Any))
239
250
  validate_full_roundtrip(
240
- 42, cocoindex.Int64, (42, int), (np.int64(42), np.int64), (42, None)
251
+ 42, cocoindex.Int64, (42, int), (np.int64(42), np.int64), (42, Any)
241
252
  )
242
253
  validate_full_roundtrip(42, int, (42, cocoindex.Int64))
243
254
  validate_full_roundtrip(np.int64(42), np.int64, (42, cocoindex.Int64))
244
255
 
245
256
  validate_full_roundtrip(
246
- 3.25, Float64, (3.25, float), (np.float64(3.25), np.float64), (3.25, None)
257
+ 3.25, Float64, (3.25, float), (np.float64(3.25), np.float64), (3.25, Any)
247
258
  )
248
259
  validate_full_roundtrip(3.25, float, (3.25, Float64))
249
260
  validate_full_roundtrip(np.float64(3.25), np.float64, (3.25, Float64))
@@ -255,35 +266,35 @@ def test_roundtrip_basic_types() -> None:
255
266
  (np.float32(3.25), np.float32),
256
267
  (np.float64(3.25), np.float64),
257
268
  (3.25, Float64),
258
- (3.25, None),
269
+ (3.25, Any),
259
270
  )
260
271
  validate_full_roundtrip(np.float32(3.25), np.float32, (3.25, Float32))
261
272
 
262
273
 
263
274
  def test_roundtrip_uuid() -> None:
264
275
  uuid_value = uuid.uuid4()
265
- validate_full_roundtrip(uuid_value, uuid.UUID, (uuid_value, None))
276
+ validate_full_roundtrip(uuid_value, uuid.UUID, (uuid_value, Any))
266
277
 
267
278
 
268
279
  def test_roundtrip_range() -> None:
269
280
  r1 = (0, 100)
270
- validate_full_roundtrip(r1, cocoindex.Range, (r1, None))
281
+ validate_full_roundtrip(r1, cocoindex.Range, (r1, Any))
271
282
  r2 = (50, 50)
272
- validate_full_roundtrip(r2, cocoindex.Range, (r2, None))
283
+ validate_full_roundtrip(r2, cocoindex.Range, (r2, Any))
273
284
  r3 = (0, 1_000_000_000)
274
- validate_full_roundtrip(r3, cocoindex.Range, (r3, None))
285
+ validate_full_roundtrip(r3, cocoindex.Range, (r3, Any))
275
286
 
276
287
 
277
288
  def test_roundtrip_time() -> None:
278
289
  t1 = datetime.time(10, 30, 50, 123456)
279
- validate_full_roundtrip(t1, datetime.time, (t1, None))
290
+ validate_full_roundtrip(t1, datetime.time, (t1, Any))
280
291
  t2 = datetime.time(23, 59, 59)
281
- validate_full_roundtrip(t2, datetime.time, (t2, None))
292
+ validate_full_roundtrip(t2, datetime.time, (t2, Any))
282
293
  t3 = datetime.time(0, 0, 0)
283
- validate_full_roundtrip(t3, datetime.time, (t3, None))
294
+ validate_full_roundtrip(t3, datetime.time, (t3, Any))
284
295
 
285
296
  validate_full_roundtrip(
286
- datetime.date(2025, 1, 1), datetime.date, (datetime.date(2025, 1, 1), None)
297
+ datetime.date(2025, 1, 1), datetime.date, (datetime.date(2025, 1, 1), Any)
287
298
  )
288
299
 
289
300
  validate_full_roundtrip(
@@ -328,11 +339,11 @@ def test_roundtrip_timedelta() -> None:
328
339
  td1 = datetime.timedelta(
329
340
  days=5, seconds=10, microseconds=123, milliseconds=456, minutes=30, hours=2
330
341
  )
331
- validate_full_roundtrip(td1, datetime.timedelta, (td1, None))
342
+ validate_full_roundtrip(td1, datetime.timedelta, (td1, Any))
332
343
  td2 = datetime.timedelta(days=-5, hours=-2)
333
- validate_full_roundtrip(td2, datetime.timedelta, (td2, None))
344
+ validate_full_roundtrip(td2, datetime.timedelta, (td2, Any))
334
345
  td3 = datetime.timedelta(0)
335
- validate_full_roundtrip(td3, datetime.timedelta, (td3, None))
346
+ validate_full_roundtrip(td3, datetime.timedelta, (td3, Any))
336
347
 
337
348
 
338
349
  def test_roundtrip_json() -> None:
@@ -364,7 +375,9 @@ def test_decode_scalar_numpy_values() -> None:
364
375
  ({"kind": "Float64"}, np.float64, 2.718, np.float64(2.718)),
365
376
  ]
366
377
  for src_type, dst_type, input_value, expected in test_cases:
367
- decoder = make_engine_value_decoder(["field"], src_type, dst_type)
378
+ decoder = make_engine_value_decoder(
379
+ ["field"], src_type, analyze_type_info(dst_type)
380
+ )
368
381
  result = decoder(input_value)
369
382
  assert isinstance(result, dst_type)
370
383
  assert result == expected
@@ -378,7 +391,9 @@ def test_non_ndarray_vector_decoding() -> None:
378
391
  "dimension": None,
379
392
  }
380
393
  dst_type_float = list[np.float64]
381
- decoder = make_engine_value_decoder(["field"], src_type, dst_type_float)
394
+ decoder = make_engine_value_decoder(
395
+ ["field"], src_type, analyze_type_info(dst_type_float)
396
+ )
382
397
  input_numbers = [1.0, 2.0, 3.0]
383
398
  result = decoder(input_numbers)
384
399
  assert isinstance(result, list)
@@ -388,7 +403,9 @@ def test_non_ndarray_vector_decoding() -> None:
388
403
  # Test list[Uuid]
389
404
  src_type = {"kind": "Vector", "element_type": {"kind": "Uuid"}, "dimension": None}
390
405
  dst_type_uuid = list[uuid.UUID]
391
- decoder = make_engine_value_decoder(["field"], src_type, dst_type_uuid)
406
+ decoder = make_engine_value_decoder(
407
+ ["field"], src_type, analyze_type_info(dst_type_uuid)
408
+ )
392
409
  uuid1 = uuid.uuid4()
393
410
  uuid2 = uuid.uuid4()
394
411
  input_uuids = [uuid1, uuid2]
@@ -398,124 +415,15 @@ def test_non_ndarray_vector_decoding() -> None:
398
415
  assert result == [uuid1, uuid2]
399
416
 
400
417
 
401
- @pytest.mark.parametrize(
402
- "data_type, engine_val, expected",
403
- [
404
- # All fields match (dataclass)
405
- (
406
- Order,
407
- ["O123", "mixed nuts", 25.0, "default_extra"],
408
- Order("O123", "mixed nuts", 25.0, "default_extra"),
409
- ),
410
- # All fields match (NamedTuple)
411
- (
412
- OrderNamedTuple,
413
- ["O123", "mixed nuts", 25.0, "default_extra"],
414
- OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra"),
415
- ),
416
- # Extra field in engine value (should ignore extra)
417
- (
418
- Order,
419
- ["O123", "mixed nuts", 25.0, "default_extra", "unexpected"],
420
- Order("O123", "mixed nuts", 25.0, "default_extra"),
421
- ),
422
- (
423
- OrderNamedTuple,
424
- ["O123", "mixed nuts", 25.0, "default_extra", "unexpected"],
425
- OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra"),
426
- ),
427
- # Fewer fields in engine value (should fill with default)
428
- (
429
- Order,
430
- ["O123", "mixed nuts", 0.0, "default_extra"],
431
- Order("O123", "mixed nuts", 0.0, "default_extra"),
432
- ),
433
- (
434
- OrderNamedTuple,
435
- ["O123", "mixed nuts", 0.0, "default_extra"],
436
- OrderNamedTuple("O123", "mixed nuts", 0.0, "default_extra"),
437
- ),
438
- # More fields in engine value (should ignore extra)
439
- (
440
- Order,
441
- ["O123", "mixed nuts", 25.0, "unexpected"],
442
- Order("O123", "mixed nuts", 25.0, "unexpected"),
443
- ),
444
- (
445
- OrderNamedTuple,
446
- ["O123", "mixed nuts", 25.0, "unexpected"],
447
- OrderNamedTuple("O123", "mixed nuts", 25.0, "unexpected"),
448
- ),
449
- # Truly extra field (should ignore the fifth field)
450
- (
451
- Order,
452
- ["O123", "mixed nuts", 25.0, "default_extra", "ignored"],
453
- Order("O123", "mixed nuts", 25.0, "default_extra"),
454
- ),
455
- (
456
- OrderNamedTuple,
457
- ["O123", "mixed nuts", 25.0, "default_extra", "ignored"],
458
- OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra"),
459
- ),
460
- # Missing optional field in engine value (tags=None)
461
- (
462
- Customer,
463
- ["Alice", ["O1", "item1", 10.0, "default_extra"], None],
464
- Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), None),
465
- ),
466
- (
467
- CustomerNamedTuple,
468
- ["Alice", ["O1", "item1", 10.0, "default_extra"], None],
469
- CustomerNamedTuple(
470
- "Alice", OrderNamedTuple("O1", "item1", 10.0, "default_extra"), None
471
- ),
472
- ),
473
- # Extra field in engine value for Customer (should ignore)
474
- (
475
- Customer,
476
- ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]], "extra"],
477
- Customer(
478
- "Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip")]
479
- ),
480
- ),
481
- (
482
- CustomerNamedTuple,
483
- ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]], "extra"],
484
- CustomerNamedTuple(
485
- "Alice",
486
- OrderNamedTuple("O1", "item1", 10.0, "default_extra"),
487
- [Tag("vip")],
488
- ),
489
- ),
490
- # Missing optional field with default
491
- (
492
- Order,
493
- ["O123", "mixed nuts", 25.0],
494
- Order("O123", "mixed nuts", 25.0, "default_extra"),
495
- ),
496
- (
497
- OrderNamedTuple,
498
- ["O123", "mixed nuts", 25.0],
499
- OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra"),
500
- ),
501
- # Partial optional fields
502
- (
503
- Customer,
504
- ["Alice", ["O1", "item1", 10.0]],
505
- Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), None),
506
- ),
507
- (
508
- CustomerNamedTuple,
509
- ["Alice", ["O1", "item1", 10.0]],
510
- CustomerNamedTuple(
511
- "Alice", OrderNamedTuple("O1", "item1", 10.0, "default_extra"), None
512
- ),
513
- ),
514
- ],
515
- )
516
- def test_struct_decoder_cases(data_type: Any, engine_val: Any, expected: Any) -> None:
517
- decoder = build_engine_value_decoder(data_type)
518
- assert decoder(engine_val) == expected
418
+ def test_roundtrip_struct() -> None:
419
+ validate_full_roundtrip(
420
+ Order("O123", "mixed nuts", 25.0, "default_extra"),
421
+ Order,
422
+ )
423
+ validate_full_roundtrip(
424
+ OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra"),
425
+ OrderNamedTuple,
426
+ )
519
427
 
520
428
 
521
429
  def test_make_engine_value_decoder_list_of_struct() -> None:
@@ -974,7 +882,9 @@ def test_decode_nullable_ndarray_none_or_value_input() -> None:
974
882
  "dimension": None,
975
883
  }
976
884
  dst_annotation = NDArrayFloat32Type | None
977
- decoder = make_engine_value_decoder([], src_type_dict, dst_annotation)
885
+ decoder = make_engine_value_decoder(
886
+ [], src_type_dict, analyze_type_info(dst_annotation)
887
+ )
978
888
 
979
889
  none_engine_value = None
980
890
  decoded_array = decoder(none_engine_value)
@@ -997,7 +907,9 @@ def test_decode_vector_string() -> None:
997
907
  "element_type": {"kind": "Str"},
998
908
  "dimension": None,
999
909
  }
1000
- decoder = make_engine_value_decoder([], src_type_dict, Vector[str])
910
+ decoder = make_engine_value_decoder(
911
+ [], src_type_dict, analyze_type_info(Vector[str])
912
+ )
1001
913
  assert decoder(["hello", "world"]) == ["hello", "world"]
1002
914
 
1003
915
 
@@ -1008,7 +920,9 @@ def test_decode_error_non_nullable_or_non_list_vector() -> None:
1008
920
  "element_type": {"kind": "Float32"},
1009
921
  "dimension": None,
1010
922
  }
1011
- decoder = make_engine_value_decoder([], src_type_dict, NDArrayFloat32Type)
923
+ decoder = make_engine_value_decoder(
924
+ [], src_type_dict, analyze_type_info(NDArrayFloat32Type)
925
+ )
1012
926
  with pytest.raises(ValueError, match="Received null for non-nullable vector"):
1013
927
  decoder(None)
1014
928
  with pytest.raises(TypeError, match="Expected NDArray or list for vector"):
@@ -1252,6 +1166,37 @@ def test_full_roundtrip_scalar_with_python_types() -> None:
1252
1166
  validate_full_roundtrip(instance, MixedStruct)
1253
1167
 
1254
1168
 
1169
+ def test_roundtrip_simple_struct_to_dict_binding() -> None:
1170
+ """Test struct -> dict binding with Any annotation."""
1171
+
1172
+ @dataclass
1173
+ class SimpleStruct:
1174
+ first_name: str
1175
+ last_name: str
1176
+
1177
+ instance = SimpleStruct("John", "Doe")
1178
+ expected_dict = {"first_name": "John", "last_name": "Doe"}
1179
+
1180
+ # Test Any annotation
1181
+ validate_full_roundtrip(
1182
+ instance,
1183
+ SimpleStruct,
1184
+ (expected_dict, Any),
1185
+ (expected_dict, dict),
1186
+ (expected_dict, dict[Any, Any]),
1187
+ (expected_dict, dict[str, Any]),
1188
+ # For simple struct, all fields have the same type, so we can directly use the type as the dict value type.
1189
+ (expected_dict, dict[Any, str]),
1190
+ (expected_dict, dict[str, str]),
1191
+ )
1192
+
1193
+ with pytest.raises(ValueError):
1194
+ validate_full_roundtrip(instance, SimpleStruct, (expected_dict, dict[str, int]))
1195
+
1196
+ with pytest.raises(ValueError):
1197
+ validate_full_roundtrip(instance, SimpleStruct, (expected_dict, dict[int, Any]))
1198
+
1199
+
1255
1200
  def test_roundtrip_struct_to_dict_binding() -> None:
1256
1201
  """Test struct -> dict binding with Any annotation."""
1257
1202
 
@@ -1265,7 +1210,20 @@ def test_roundtrip_struct_to_dict_binding() -> None:
1265
1210
  expected_dict = {"name": "test", "value": 42, "price": 3.14}
1266
1211
 
1267
1212
  # Test Any annotation
1268
- validate_full_roundtrip(instance, SimpleStruct, (expected_dict, Any))
1213
+ validate_full_roundtrip(
1214
+ instance,
1215
+ SimpleStruct,
1216
+ (expected_dict, Any),
1217
+ (expected_dict, dict),
1218
+ (expected_dict, dict[Any, Any]),
1219
+ (expected_dict, dict[str, Any]),
1220
+ )
1221
+
1222
+ with pytest.raises(ValueError):
1223
+ validate_full_roundtrip(instance, SimpleStruct, (expected_dict, dict[str, str]))
1224
+
1225
+ with pytest.raises(ValueError):
1226
+ validate_full_roundtrip(instance, SimpleStruct, (expected_dict, dict[int, Any]))
1269
1227
 
1270
1228
 
1271
1229
  def test_roundtrip_struct_to_dict_explicit() -> None:
@@ -1299,8 +1257,8 @@ def test_roundtrip_struct_to_dict_with_none_annotation() -> None:
1299
1257
  instance = Config("localhost", 8080, True)
1300
1258
  expected_dict = {"host": "localhost", "port": 8080, "debug": True}
1301
1259
 
1302
- # Test None annotation (should be treated as Any)
1303
- validate_full_roundtrip(instance, Config, (expected_dict, None))
1260
+ # Test empty annotation (should be treated as Any)
1261
+ validate_full_roundtrip(instance, Config, (expected_dict, inspect.Parameter.empty))
1304
1262
 
1305
1263
 
1306
1264
  def test_roundtrip_struct_to_dict_nested() -> None:
@@ -1381,7 +1339,13 @@ def test_roundtrip_ltable_to_list_dict_binding() -> None:
1381
1339
  ]
1382
1340
 
1383
1341
  # Test Any annotation
1384
- validate_full_roundtrip(users, list[User], (expected_list_dict, Any))
1342
+ validate_full_roundtrip(
1343
+ users,
1344
+ list[User],
1345
+ (expected_list_dict, Any),
1346
+ (expected_list_dict, list[Any]),
1347
+ (expected_list_dict, list[dict[str, Any]]),
1348
+ )
1385
1349
 
1386
1350
 
1387
1351
  def test_roundtrip_ktable_to_dict_dict_binding() -> None:
@@ -1405,7 +1369,17 @@ def test_roundtrip_ktable_to_dict_dict_binding() -> None:
1405
1369
  }
1406
1370
 
1407
1371
  # Test Any annotation
1408
- validate_full_roundtrip(products, dict[str, Product], (expected_dict_dict, Any))
1372
+ validate_full_roundtrip(
1373
+ products,
1374
+ dict[str, Product],
1375
+ (expected_dict_dict, Any),
1376
+ (expected_dict_dict, dict),
1377
+ (expected_dict_dict, dict[Any, Any]),
1378
+ (expected_dict_dict, dict[str, Any]),
1379
+ (expected_dict_dict, dict[Any, dict[Any, Any]]),
1380
+ (expected_dict_dict, dict[str, dict[Any, Any]]),
1381
+ (expected_dict_dict, dict[str, dict[str, Any]]),
1382
+ )
1409
1383
 
1410
1384
 
1411
1385
  def test_roundtrip_ktable_with_complex_key() -> None:
@@ -1431,7 +1405,28 @@ def test_roundtrip_ktable_with_complex_key() -> None:
1431
1405
  }
1432
1406
 
1433
1407
  # Test Any annotation
1434
- validate_full_roundtrip(orders, dict[OrderKey, Order], (expected_dict_dict, Any))
1408
+ validate_full_roundtrip(
1409
+ orders,
1410
+ dict[OrderKey, Order],
1411
+ (expected_dict_dict, Any),
1412
+ (expected_dict_dict, dict),
1413
+ (expected_dict_dict, dict[Any, Any]),
1414
+ (expected_dict_dict, dict[Any, dict[str, Any]]),
1415
+ (
1416
+ {
1417
+ ("shop1", 1): Order("Alice", 100.0),
1418
+ ("shop2", 2): Order("Bob", 200.0),
1419
+ },
1420
+ dict[Any, Order],
1421
+ ),
1422
+ (
1423
+ {
1424
+ OrderKey("shop1", 1): {"customer": "Alice", "total": 100.0},
1425
+ OrderKey("shop2", 2): {"customer": "Bob", "total": 200.0},
1426
+ },
1427
+ dict[OrderKey, Any],
1428
+ ),
1429
+ )
1435
1430
 
1436
1431
 
1437
1432
  def test_roundtrip_ltable_with_nested_structs() -> None:
@@ -1489,3 +1484,41 @@ def test_roundtrip_ktable_with_list_fields() -> None:
1489
1484
 
1490
1485
  # Test Any annotation
1491
1486
  validate_full_roundtrip(teams, dict[str, Team], (expected_dict_dict, Any))
1487
+
1488
+
1489
+ def test_auto_default_for_supported_and_unsupported_types() -> None:
1490
+ @dataclass
1491
+ class Base:
1492
+ a: int
1493
+
1494
+ @dataclass
1495
+ class NullableField:
1496
+ a: int
1497
+ b: int | None
1498
+
1499
+ @dataclass
1500
+ class LTableField:
1501
+ a: int
1502
+ b: list[Base]
1503
+
1504
+ @dataclass
1505
+ class KTableField:
1506
+ a: int
1507
+ b: dict[str, Base]
1508
+
1509
+ @dataclass
1510
+ class UnsupportedField:
1511
+ a: int
1512
+ b: int
1513
+
1514
+ validate_full_roundtrip(NullableField(1, None), NullableField)
1515
+
1516
+ validate_full_roundtrip(LTableField(1, []), LTableField)
1517
+
1518
+ # validate_full_roundtrip(KTableField(1, {}), KTableField)
1519
+
1520
+ with pytest.raises(
1521
+ ValueError,
1522
+ match=r"Field 'b' \(type <class 'int'>\) without default value is missing in input: ",
1523
+ ):
1524
+ build_engine_value_decoder(Base, UnsupportedField)
@@ -0,0 +1,103 @@
1
+ import typing
2
+ from dataclasses import dataclass
3
+ from typing import Any
4
+
5
+ import pytest
6
+
7
+ import cocoindex
8
+
9
+
10
+ @dataclass
11
+ class Child:
12
+ value: int
13
+
14
+
15
+ @dataclass
16
+ class Parent:
17
+ children: list[Child]
18
+
19
+
20
+ # Fixture to initialize CocoIndex library
21
+ @pytest.fixture(scope="session", autouse=True)
22
+ def init_cocoindex() -> typing.Generator[None, None, None]:
23
+ cocoindex.init()
24
+ yield
25
+
26
+
27
+ @cocoindex.op.function()
28
+ def add_suffix(text: str) -> str:
29
+ """Append ' world' to the input text."""
30
+ return f"{text} world"
31
+
32
+
33
+ @cocoindex.transform_flow()
34
+ def simple_transform(text: cocoindex.DataSlice[str]) -> cocoindex.DataSlice[str]:
35
+ """Transform flow that applies add_suffix to input text."""
36
+ return text.transform(add_suffix)
37
+
38
+
39
+ @cocoindex.op.function()
40
+ def extract_value(value: int) -> int:
41
+ """Extracts the value."""
42
+ return value
43
+
44
+
45
+ @cocoindex.transform_flow()
46
+ def for_each_transform(
47
+ data: cocoindex.DataSlice[Parent],
48
+ ) -> cocoindex.DataSlice[Any]:
49
+ """Transform flow that processes child rows to extract values."""
50
+ with data["children"].row() as child:
51
+ child["new_field"] = child["value"].transform(extract_value)
52
+ return data
53
+
54
+
55
+ def test_simple_transform_flow() -> None:
56
+ """Test the simple transform flow."""
57
+ input_text = "hello"
58
+ result = simple_transform.eval(input_text)
59
+ assert result == "hello world", f"Expected 'hello world', got {result}"
60
+
61
+ result = simple_transform.eval("")
62
+ assert result == " world", f"Expected ' world', got {result}"
63
+
64
+
65
+ @pytest.mark.asyncio
66
+ async def test_simple_transform_flow_async() -> None:
67
+ """Test the simple transform flow asynchronously."""
68
+ input_text = "async"
69
+ result = await simple_transform.eval_async(input_text)
70
+ assert result == "async world", f"Expected 'async world', got {result}"
71
+
72
+
73
+ def test_for_each_transform_flow() -> None:
74
+ """Test the complex transform flow with child rows."""
75
+ input_data = Parent(children=[Child(1), Child(2), Child(3)])
76
+ result = for_each_transform.eval(input_data)
77
+ expected = {
78
+ "children": [
79
+ {"value": 1, "new_field": 1},
80
+ {"value": 2, "new_field": 2},
81
+ {"value": 3, "new_field": 3},
82
+ ]
83
+ }
84
+ assert result == expected, f"Expected {expected}, got {result}"
85
+
86
+ input_data = Parent(children=[])
87
+ result = for_each_transform.eval(input_data)
88
+ assert result == {"children": []}, f"Expected {{'children': []}}, got {result}"
89
+
90
+
91
+ @pytest.mark.asyncio
92
+ async def test_for_each_transform_flow_async() -> None:
93
+ """Test the complex transform flow asynchronously."""
94
+ input_data = Parent(children=[Child(4), Child(5)])
95
+ result = await for_each_transform.eval_async(input_data)
96
+ expected = {
97
+ "children": [
98
+ {"value": 4, "new_field": 4},
99
+ {"value": 5, "new_field": 5},
100
+ ]
101
+ }
102
+
103
+ assert result == expected, f"Expected {expected}, got {result}"
@@ -13,6 +13,7 @@ from cocoindex.typing import (
13
13
  AnalyzedDictType,
14
14
  AnalyzedListType,
15
15
  AnalyzedStructType,
16
+ AnalyzedUnknownType,
16
17
  AnalyzedTypeInfo,
17
18
  TypeAttr,
18
19
  TypeKind,
@@ -422,15 +423,7 @@ def test_annotated_list_with_type_kind() -> None:
422
423
  assert result.variant.kind == "Struct"
423
424
 
424
425
 
425
- def test_unsupported_type() -> None:
426
- with pytest.raises(
427
- ValueError,
428
- match="Unsupported as a specific type annotation for CocoIndex data type.*: <class 'set'>",
429
- ):
430
- analyze_type_info(set)
431
-
432
- with pytest.raises(
433
- ValueError,
434
- match="Unsupported as a specific type annotation for CocoIndex data type.*: <class 'numpy.complex64'>",
435
- ):
436
- Vector[np.complex64]
426
+ def test_unknown_type() -> None:
427
+ typ = set
428
+ result = analyze_type_info(typ)
429
+ assert isinstance(result.variant, AnalyzedUnknownType)
cocoindex/typing.py CHANGED
@@ -262,7 +262,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
262
262
 
263
263
  if kind is not None:
264
264
  variant = AnalyzedBasicType(kind=kind)
265
- elif base_type is None or base_type is Any or base_type is inspect.Parameter.empty:
265
+ elif base_type is Any or base_type is inspect.Parameter.empty:
266
266
  variant = AnalyzedAnyType()
267
267
  elif is_struct_type(base_type):
268
268
  variant = AnalyzedStructType(struct_type=t)
@@ -270,15 +270,15 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
270
270
  kind = DtypeRegistry.validate_dtype_and_get_kind(t)
271
271
  variant = AnalyzedBasicType(kind=kind)
272
272
  elif base_type is collections.abc.Sequence or base_type is list:
273
- elem_type = type_args[0] if len(type_args) > 0 else None
273
+ elem_type = type_args[0] if len(type_args) > 0 else Any
274
274
  variant = AnalyzedListType(elem_type=elem_type, vector_info=vector_info)
275
275
  elif base_type is np.ndarray:
276
276
  np_number_type = t
277
277
  elem_type = extract_ndarray_elem_dtype(np_number_type)
278
278
  variant = AnalyzedListType(elem_type=elem_type, vector_info=vector_info)
279
279
  elif base_type is collections.abc.Mapping or base_type is dict or t is dict:
280
- key_type = type_args[0] if len(type_args) > 0 else None
281
- elem_type = type_args[1] if len(type_args) > 1 else None
280
+ key_type = type_args[0] if len(type_args) > 0 else Any
281
+ elem_type = type_args[1] if len(type_args) > 1 else Any
282
282
  variant = AnalyzedDictType(key_type=key_type, value_type=elem_type)
283
283
  elif base_type in (types.UnionType, typing.Union):
284
284
  non_none_types = [arg for arg in type_args if arg not in (None, types.NoneType)]
@@ -313,11 +313,11 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
313
313
  kind = "OffsetDateTime"
314
314
  elif t is datetime.timedelta:
315
315
  kind = "TimeDelta"
316
+
317
+ if kind is None:
318
+ variant = AnalyzedUnknownType()
316
319
  else:
317
- raise ValueError(
318
- f"Unsupported as a specific type annotation for CocoIndex data type (https://cocoindex.io/docs/core/data_types): {t}"
319
- )
320
- variant = AnalyzedBasicType(kind=kind)
320
+ variant = AnalyzedBasicType(kind=kind)
321
321
 
322
322
  return AnalyzedTypeInfo(
323
323
  core_type=core_type,