cocoindex 0.1.49__cp311-cp311-macosx_11_0_arm64.whl → 0.1.50__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cocoindex/__init__.py CHANGED
@@ -2,7 +2,7 @@
2
2
  Cocoindex is a framework for building and running indexing pipelines.
3
3
  """
4
4
 
5
- from . import functions, query, sources, storages, cli, utils
5
+ from . import functions, sources, storages, cli, utils
6
6
 
7
7
  from .auth_registry import AuthEntryReference, add_auth_entry, ref_auth_entry
8
8
  from .flow import FlowBuilder, DataScope, DataSlice, Flow, transform_flow
Binary file
cocoindex/convert.py CHANGED
@@ -6,6 +6,7 @@ import dataclasses
6
6
  import datetime
7
7
  import inspect
8
8
  import uuid
9
+ import numpy as np
9
10
 
10
11
  from enum import Enum
11
12
  from typing import Any, Callable, get_origin, Mapping
@@ -15,6 +16,7 @@ from .typing import (
15
16
  is_namedtuple_type,
16
17
  TABLE_TYPES,
17
18
  KEY_FIELD_NAME,
19
+ DtypeRegistry,
18
20
  )
19
21
 
20
22
 
@@ -27,6 +29,8 @@ def encode_engine_value(value: Any) -> Any:
27
29
  ]
28
30
  if is_namedtuple_type(type(value)):
29
31
  return [encode_engine_value(getattr(value, name)) for name in value._fields]
32
+ if isinstance(value, np.ndarray):
33
+ return value
30
34
  if isinstance(value, (list, tuple)):
31
35
  return [encode_engine_value(v) for v in value]
32
36
  if isinstance(value, dict):
@@ -122,6 +126,38 @@ def make_engine_value_decoder(
122
126
  if src_type_kind == "Uuid":
123
127
  return lambda value: uuid.UUID(bytes=value)
124
128
 
129
+ if src_type_kind == "Vector":
130
+ elem_coco_type_info = analyze_type_info(dst_type_info.elem_type)
131
+ dtype_info = DtypeRegistry.get_by_kind(elem_coco_type_info.kind)
132
+
133
+ def decode_vector(value: Any) -> Any | None:
134
+ if value is None:
135
+ if dst_type_info.nullable:
136
+ return None
137
+ raise ValueError(
138
+ f"Received null for non-nullable vector `{''.join(field_path)}`"
139
+ )
140
+
141
+ if not isinstance(value, (np.ndarray, list)):
142
+ raise TypeError(
143
+ f"Expected NDArray or list for vector `{''.join(field_path)}`, got {type(value)}"
144
+ )
145
+ expected_dim = (
146
+ dst_type_info.vector_info.dim if dst_type_info.vector_info else None
147
+ )
148
+ if expected_dim is not None and len(value) != expected_dim:
149
+ raise ValueError(
150
+ f"Vector dimension mismatch for `{''.join(field_path)}`: "
151
+ f"expected {expected_dim}, got {len(value)}"
152
+ )
153
+
154
+ # Use NDArray for supported numeric dtypes, else return list
155
+ if dtype_info is not None:
156
+ return np.array(value, dtype=dtype_info.numpy_dtype)
157
+ return value
158
+
159
+ return decode_vector
160
+
125
161
  return lambda value: value
126
162
 
127
163
 
cocoindex/functions.py CHANGED
@@ -1,6 +1,8 @@
1
1
  """All builtin functions."""
2
2
 
3
- from typing import Annotated, Any, TYPE_CHECKING
3
+ from typing import Annotated, Any, TYPE_CHECKING, Literal
4
+ import numpy as np
5
+ from numpy.typing import NDArray
4
6
  import dataclasses
5
7
 
6
8
  from .typing import Float32, Vector, TypeAttr
@@ -66,11 +68,11 @@ class SentenceTransformerEmbedExecutor:
66
68
  self._model = sentence_transformers.SentenceTransformer(self.spec.model, **args)
67
69
  dim = self._model.get_sentence_embedding_dimension()
68
70
  result: type = Annotated[
69
- Vector[Float32, dim], # type: ignore
71
+ Vector[np.float32, Literal[dim]], # type: ignore
70
72
  TypeAttr("cocoindex.io/vector_origin_text", text.analyzed_value),
71
73
  ]
72
74
  return result
73
75
 
74
- def __call__(self, text: str) -> list[Float32]:
75
- result: list[Float32] = self._model.encode(text).tolist()
76
+ def __call__(self, text: str) -> NDArray[np.float32]:
77
+ result: NDArray[np.float32] = self._model.encode(text, convert_to_numpy=True)
76
78
  return result
cocoindex/lib.py CHANGED
@@ -6,7 +6,7 @@ import warnings
6
6
  from typing import Callable, Any
7
7
 
8
8
  from . import _engine # type: ignore
9
- from . import flow, query, setting
9
+ from . import flow, setting
10
10
  from .convert import dump_engine_object
11
11
 
12
12
 
@@ -24,7 +24,6 @@ def init(settings: setting.Settings | None = None) -> None:
24
24
  def start_server(settings: setting.ServerSettings) -> None:
25
25
  """Start the cocoindex server."""
26
26
  flow.ensure_all_flows_built()
27
- query.ensure_all_handlers_built()
28
27
  _engine.start_server(settings.__dict__)
29
28
 
30
29
 
@@ -1,11 +1,20 @@
1
1
  import uuid
2
2
  import datetime
3
3
  from dataclasses import dataclass, make_dataclass
4
- from typing import NamedTuple, Literal
4
+ from typing import NamedTuple, Literal, Any, Callable
5
5
  import pytest
6
6
  import cocoindex
7
- from cocoindex.typing import encode_enriched_type
8
- from cocoindex.convert import encode_engine_value, make_engine_value_decoder
7
+ from cocoindex.typing import (
8
+ encode_enriched_type,
9
+ Vector,
10
+ )
11
+ from cocoindex.convert import (
12
+ encode_engine_value,
13
+ make_engine_value_decoder,
14
+ dump_engine_object,
15
+ )
16
+ import numpy as np
17
+ from numpy.typing import NDArray
9
18
 
10
19
 
11
20
  @dataclass
@@ -53,7 +62,9 @@ class CustomerNamedTuple(NamedTuple):
53
62
  tags: list[Tag] | None = None
54
63
 
55
64
 
56
- def build_engine_value_decoder(engine_type_in_py, python_type=None):
65
+ def build_engine_value_decoder(
66
+ engine_type_in_py: Any, python_type: Any | None = None
67
+ ) -> Callable[[Any], Any]:
57
68
  """
58
69
  Helper to build a converter for the given engine-side type (as represented in Python).
59
70
  If python_type is not specified, uses engine_type_in_py as the target.
@@ -62,6 +73,27 @@ def build_engine_value_decoder(engine_type_in_py, python_type=None):
62
73
  return make_engine_value_decoder([], engine_type, python_type or engine_type_in_py)
63
74
 
64
75
 
76
+ def validate_full_roundtrip(
77
+ value: Any, output_type: Any, input_type: Any | None = None
78
+ ) -> None:
79
+ """
80
+ Validate the given value doesn't change after encoding, sending to engine (using output_type), receiving back and decoding (using input_type).
81
+
82
+ If `input_type` is not specified, uses `output_type` as the target.
83
+ """
84
+ from cocoindex import _engine
85
+
86
+ encoded_value = encode_engine_value(value)
87
+ encoded_output_type = encode_enriched_type(output_type)["type"]
88
+ value_from_engine = _engine.testutil.seder_roundtrip(
89
+ encoded_value, encoded_output_type
90
+ )
91
+ decoded_value = build_engine_value_decoder(input_type or output_type, output_type)(
92
+ value_from_engine
93
+ )
94
+ assert decoded_value == value
95
+
96
+
65
97
  def test_encode_engine_value_basic_types():
66
98
  assert encode_engine_value(123) == 123
67
99
  assert encode_engine_value(3.14) == 3.14
@@ -434,57 +466,33 @@ def test_field_position_cases(
434
466
  assert decoder(engine_val) == PythonOrder(**expected_dict)
435
467
 
436
468
 
437
- def test_roundtrip_ltable():
469
+ def test_roundtrip_ltable() -> None:
438
470
  t = list[Order]
439
471
  value = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]
440
- encoded = encode_engine_value(value)
441
- assert encoded == [
442
- ["O1", "item1", 10.0, "default_extra"],
443
- ["O2", "item2", 20.0, "default_extra"],
444
- ]
445
- decoded = build_engine_value_decoder(t)(encoded)
446
- assert decoded == value
472
+ validate_full_roundtrip(value, t)
447
473
 
448
474
  t_nt = list[OrderNamedTuple]
449
475
  value_nt = [
450
476
  OrderNamedTuple("O1", "item1", 10.0),
451
477
  OrderNamedTuple("O2", "item2", 20.0),
452
478
  ]
453
- encoded = encode_engine_value(value_nt)
454
- assert encoded == [
455
- ["O1", "item1", 10.0, "default_extra"],
456
- ["O2", "item2", 20.0, "default_extra"],
457
- ]
458
- decoded = build_engine_value_decoder(t_nt)(encoded)
459
- assert decoded == value_nt
479
+ validate_full_roundtrip(value_nt, t_nt)
460
480
 
461
481
 
462
- def test_roundtrip_ktable_str_key():
482
+ def test_roundtrip_ktable_str_key() -> None:
463
483
  t = dict[str, Order]
464
484
  value = {"K1": Order("O1", "item1", 10.0), "K2": Order("O2", "item2", 20.0)}
465
- encoded = encode_engine_value(value)
466
- assert encoded == [
467
- ["K1", "O1", "item1", 10.0, "default_extra"],
468
- ["K2", "O2", "item2", 20.0, "default_extra"],
469
- ]
470
- decoded = build_engine_value_decoder(t)(encoded)
471
- assert decoded == value
485
+ validate_full_roundtrip(value, t)
472
486
 
473
487
  t_nt = dict[str, OrderNamedTuple]
474
488
  value_nt = {
475
489
  "K1": OrderNamedTuple("O1", "item1", 10.0),
476
490
  "K2": OrderNamedTuple("O2", "item2", 20.0),
477
491
  }
478
- encoded = encode_engine_value(value_nt)
479
- assert encoded == [
480
- ["K1", "O1", "item1", 10.0, "default_extra"],
481
- ["K2", "O2", "item2", 20.0, "default_extra"],
482
- ]
483
- decoded = build_engine_value_decoder(t_nt)(encoded)
484
- assert decoded == value_nt
492
+ validate_full_roundtrip(value_nt, t_nt)
485
493
 
486
494
 
487
- def test_roundtrip_ktable_struct_key():
495
+ def test_roundtrip_ktable_struct_key() -> None:
488
496
  @dataclass(frozen=True)
489
497
  class OrderKey:
490
498
  shop_id: str
@@ -495,29 +503,17 @@ def test_roundtrip_ktable_struct_key():
495
503
  OrderKey("A", 3): Order("O1", "item1", 10.0),
496
504
  OrderKey("B", 4): Order("O2", "item2", 20.0),
497
505
  }
498
- encoded = encode_engine_value(value)
499
- assert encoded == [
500
- [["A", 3], "O1", "item1", 10.0, "default_extra"],
501
- [["B", 4], "O2", "item2", 20.0, "default_extra"],
502
- ]
503
- decoded = build_engine_value_decoder(t)(encoded)
504
- assert decoded == value
506
+ validate_full_roundtrip(value, t)
505
507
 
506
508
  t_nt = dict[OrderKey, OrderNamedTuple]
507
509
  value_nt = {
508
510
  OrderKey("A", 3): OrderNamedTuple("O1", "item1", 10.0),
509
511
  OrderKey("B", 4): OrderNamedTuple("O2", "item2", 20.0),
510
512
  }
511
- encoded = encode_engine_value(value_nt)
512
- assert encoded == [
513
- [["A", 3], "O1", "item1", 10.0, "default_extra"],
514
- [["B", 4], "O2", "item2", 20.0, "default_extra"],
515
- ]
516
- decoded = build_engine_value_decoder(t_nt)(encoded)
517
- assert decoded == value_nt
513
+ validate_full_roundtrip(value_nt, t_nt)
518
514
 
519
515
 
520
- IntVectorType = cocoindex.Vector[int, Literal[5]]
516
+ IntVectorType = cocoindex.Vector[np.int32, Literal[5]]
521
517
 
522
518
 
523
519
  def test_vector_as_vector() -> None:
@@ -525,7 +521,7 @@ def test_vector_as_vector() -> None:
525
521
  encoded = encode_engine_value(value)
526
522
  assert encoded == [1, 2, 3, 4, 5]
527
523
  decoded = build_engine_value_decoder(IntVectorType)(encoded)
528
- assert decoded == value
524
+ assert np.array_equal(decoded, value)
529
525
 
530
526
 
531
527
  ListIntType = list[int]
@@ -536,4 +532,236 @@ def test_vector_as_list() -> None:
536
532
  encoded = encode_engine_value(value)
537
533
  assert encoded == [1, 2, 3, 4, 5]
538
534
  decoded = build_engine_value_decoder(ListIntType)(encoded)
539
- assert decoded == value
535
+ assert np.array_equal(decoded, value)
536
+
537
+
538
+ Float64VectorTypeNoDim = Vector[np.float64]
539
+ Float32VectorType = Vector[np.float32, Literal[3]]
540
+ Float64VectorType = Vector[np.float64, Literal[3]]
541
+ Int64VectorType = Vector[np.int64, Literal[3]]
542
+ Int32VectorType = Vector[np.int32, Literal[3]]
543
+ NDArrayFloat32Type = NDArray[np.float32]
544
+ NDArrayFloat64Type = NDArray[np.float64]
545
+ NDArrayInt64Type = NDArray[np.int64]
546
+
547
+
548
+ def test_encode_engine_value_ndarray():
549
+ """Test encoding NDArray vectors to lists for the Rust engine."""
550
+ vec_f32: Float32VectorType = np.array([1.0, 2.0, 3.0], dtype=np.float32)
551
+ assert np.array_equal(encode_engine_value(vec_f32), [1.0, 2.0, 3.0])
552
+ vec_f64: Float64VectorType = np.array([1.0, 2.0, 3.0], dtype=np.float64)
553
+ assert np.array_equal(encode_engine_value(vec_f64), [1.0, 2.0, 3.0])
554
+ vec_i64: Int64VectorType = np.array([1, 2, 3], dtype=np.int64)
555
+ assert np.array_equal(encode_engine_value(vec_i64), [1, 2, 3])
556
+ vec_nd_f32: NDArrayFloat32Type = np.array([1.0, 2.0, 3.0], dtype=np.float32)
557
+ assert np.array_equal(encode_engine_value(vec_nd_f32), [1.0, 2.0, 3.0])
558
+
559
+
560
+ def test_make_engine_value_decoder_ndarray():
561
+ """Test decoding engine lists to NDArray vectors."""
562
+ decoder_f32 = build_engine_value_decoder(Float32VectorType)
563
+ result_f32 = decoder_f32([1.0, 2.0, 3.0])
564
+ assert isinstance(result_f32, np.ndarray)
565
+ assert result_f32.dtype == np.float32
566
+ assert np.array_equal(result_f32, np.array([1.0, 2.0, 3.0], dtype=np.float32))
567
+ decoder_f64 = build_engine_value_decoder(Float64VectorType)
568
+ result_f64 = decoder_f64([1.0, 2.0, 3.0])
569
+ assert isinstance(result_f64, np.ndarray)
570
+ assert result_f64.dtype == np.float64
571
+ assert np.array_equal(result_f64, np.array([1.0, 2.0, 3.0], dtype=np.float64))
572
+ decoder_i64 = build_engine_value_decoder(Int64VectorType)
573
+ result_i64 = decoder_i64([1, 2, 3])
574
+ assert isinstance(result_i64, np.ndarray)
575
+ assert result_i64.dtype == np.int64
576
+ assert np.array_equal(result_i64, np.array([1, 2, 3], dtype=np.int64))
577
+ decoder_nd_f32 = build_engine_value_decoder(NDArrayFloat32Type)
578
+ result_nd_f32 = decoder_nd_f32([1.0, 2.0, 3.0])
579
+ assert isinstance(result_nd_f32, np.ndarray)
580
+ assert result_nd_f32.dtype == np.float32
581
+ assert np.array_equal(result_nd_f32, np.array([1.0, 2.0, 3.0], dtype=np.float32))
582
+
583
+
584
+ def test_roundtrip_ndarray_vector():
585
+ """Test roundtrip encoding and decoding of NDArray vectors."""
586
+ value_f32: Float32VectorType = np.array([1.0, 2.0, 3.0], dtype=np.float32)
587
+ encoded_f32 = encode_engine_value(value_f32)
588
+ np.array_equal(encoded_f32, [1.0, 2.0, 3.0])
589
+ decoded_f32 = build_engine_value_decoder(Float32VectorType)(encoded_f32)
590
+ assert isinstance(decoded_f32, np.ndarray)
591
+ assert decoded_f32.dtype == np.float32
592
+ assert np.array_equal(decoded_f32, value_f32)
593
+ value_i64: Int64VectorType = np.array([1, 2, 3], dtype=np.int64)
594
+ encoded_i64 = encode_engine_value(value_i64)
595
+ assert np.array_equal(encoded_i64, [1, 2, 3])
596
+ decoded_i64 = build_engine_value_decoder(Int64VectorType)(encoded_i64)
597
+ assert isinstance(decoded_i64, np.ndarray)
598
+ assert decoded_i64.dtype == np.int64
599
+ assert np.array_equal(decoded_i64, value_i64)
600
+ value_nd_f64: NDArrayFloat64Type = np.array([1.0, 2.0, 3.0], dtype=np.float64)
601
+ encoded_nd_f64 = encode_engine_value(value_nd_f64)
602
+ assert np.array_equal(encoded_nd_f64, [1.0, 2.0, 3.0])
603
+ decoded_nd_f64 = build_engine_value_decoder(NDArrayFloat64Type)(encoded_nd_f64)
604
+ assert isinstance(decoded_nd_f64, np.ndarray)
605
+ assert decoded_nd_f64.dtype == np.float64
606
+ assert np.array_equal(decoded_nd_f64, value_nd_f64)
607
+
608
+
609
+ def test_uint_support():
610
+ """Test encoding and decoding of unsigned integer vectors."""
611
+ value_uint8 = np.array([1, 2, 3, 4], dtype=np.uint8)
612
+ encoded = encode_engine_value(value_uint8)
613
+ assert np.array_equal(encoded, [1, 2, 3, 4])
614
+ decoder = make_engine_value_decoder(
615
+ [], {"kind": "Vector", "element_type": {"kind": "UInt8"}}, NDArray[np.uint8]
616
+ )
617
+ decoded = decoder(encoded)
618
+ assert np.array_equal(decoded, value_uint8)
619
+ assert decoded.dtype == np.uint8
620
+ value_uint16 = np.array([1, 2, 3, 4], dtype=np.uint16)
621
+ encoded = encode_engine_value(value_uint16)
622
+ assert np.array_equal(encoded, [1, 2, 3, 4])
623
+ decoder = make_engine_value_decoder(
624
+ [], {"kind": "Vector", "element_type": {"kind": "UInt16"}}, NDArray[np.uint16]
625
+ )
626
+ decoded = decoder(encoded)
627
+ assert np.array_equal(decoded, value_uint16)
628
+ assert decoded.dtype == np.uint16
629
+ value_uint32 = np.array([1, 2, 3], dtype=np.uint32)
630
+ encoded = encode_engine_value(value_uint32)
631
+ assert np.array_equal(encoded, [1, 2, 3])
632
+ decoder = make_engine_value_decoder(
633
+ [], {"kind": "Vector", "element_type": {"kind": "UInt32"}}, NDArray[np.uint32]
634
+ )
635
+ decoded = decoder(encoded)
636
+ assert np.array_equal(decoded, value_uint32)
637
+ assert decoded.dtype == np.uint32
638
+ value_uint64 = np.array([1, 2, 3], dtype=np.uint64)
639
+ encoded = encode_engine_value(value_uint64)
640
+ assert np.array_equal(encoded, [1, 2, 3])
641
+ decoder = make_engine_value_decoder(
642
+ [], {"kind": "Vector", "element_type": {"kind": "UInt8"}}, NDArray[np.uint64]
643
+ )
644
+ decoded = decoder(encoded)
645
+ assert np.array_equal(decoded, value_uint64)
646
+ assert decoded.dtype == np.uint64
647
+
648
+
649
+ def test_ndarray_dimension_mismatch():
650
+ """Test dimension enforcement for Vector with specified dimension."""
651
+ value: Float32VectorType = np.array([1.0, 2.0], dtype=np.float32)
652
+ encoded = encode_engine_value(value)
653
+ assert np.array_equal(encoded, [1.0, 2.0])
654
+ with pytest.raises(ValueError, match="Vector dimension mismatch"):
655
+ build_engine_value_decoder(Float32VectorType)(encoded)
656
+
657
+
658
+ def test_list_vector_backward_compatibility():
659
+ """Test that list-based vectors still work for backward compatibility."""
660
+ value: IntVectorType = [1, 2, 3, 4, 5]
661
+ encoded = encode_engine_value(value)
662
+ assert encoded == [1, 2, 3, 4, 5]
663
+ decoded = build_engine_value_decoder(IntVectorType)(encoded)
664
+ assert isinstance(decoded, np.ndarray)
665
+ assert decoded.dtype == np.int32
666
+ assert np.array_equal(decoded, np.array([1, 2, 3, 4, 5], dtype=np.int64))
667
+ value_list: ListIntType = [1, 2, 3, 4, 5]
668
+ encoded = encode_engine_value(value_list)
669
+ assert np.array_equal(encoded, [1, 2, 3, 4, 5])
670
+ decoded = build_engine_value_decoder(ListIntType)(encoded)
671
+ assert np.array_equal(decoded, [1, 2, 3, 4, 5])
672
+
673
+
674
+ def test_encode_complex_structure_with_ndarray():
675
+ """Test encoding a complex structure that includes an NDArray."""
676
+
677
+ @dataclass
678
+ class MyStructWithNDArray:
679
+ name: str
680
+ data: NDArray[np.float32]
681
+ value: int
682
+
683
+ original = MyStructWithNDArray(
684
+ name="test_np", data=np.array([1.0, 0.5], dtype=np.float32), value=100
685
+ )
686
+ encoded = encode_engine_value(original)
687
+ expected = [
688
+ "test_np",
689
+ [1.0, 0.5],
690
+ 100,
691
+ ]
692
+ assert encoded[0] == expected[0]
693
+ assert np.array_equal(encoded[1], expected[1])
694
+ assert encoded[2] == expected[2]
695
+
696
+
697
+ def test_decode_nullable_ndarray_none_or_value_input():
698
+ """Test decoding a nullable NDArray with None or value inputs."""
699
+ src_type_dict = {
700
+ "kind": "Vector",
701
+ "element_type": {"kind": "Float32"},
702
+ "dimension": None,
703
+ }
704
+ dst_annotation = NDArrayFloat32Type | None
705
+ decoder = make_engine_value_decoder([], src_type_dict, dst_annotation)
706
+
707
+ none_engine_value = None
708
+ decoded_array = decoder(none_engine_value)
709
+ assert decoded_array is None
710
+
711
+ engine_value = [1.0, 2.0, 3.0]
712
+ decoded_array = decoder(engine_value)
713
+
714
+ assert isinstance(decoded_array, np.ndarray)
715
+ assert decoded_array.dtype == np.float32
716
+ np.testing.assert_array_equal(
717
+ decoded_array, np.array([1.0, 2.0, 3.0], dtype=np.float32)
718
+ )
719
+
720
+
721
+ def test_decode_vector_string():
722
+ """Test decoding a vector of strings works for Python native list type."""
723
+ src_type_dict = {
724
+ "kind": "Vector",
725
+ "element_type": {"kind": "Str"},
726
+ "dimension": None,
727
+ }
728
+ decoder = make_engine_value_decoder([], src_type_dict, Vector[str])
729
+ assert decoder(["hello", "world"]) == ["hello", "world"]
730
+
731
+
732
+ def test_decode_error_non_nullable_or_non_list_vector():
733
+ """Test decoding errors for non-nullable vectors or non-list inputs."""
734
+ src_type_dict = {
735
+ "kind": "Vector",
736
+ "element_type": {"kind": "Float32"},
737
+ "dimension": None,
738
+ }
739
+ decoder = make_engine_value_decoder([], src_type_dict, NDArrayFloat32Type)
740
+ with pytest.raises(ValueError, match="Received null for non-nullable vector"):
741
+ decoder(None)
742
+ with pytest.raises(TypeError, match="Expected NDArray or list for vector"):
743
+ decoder("not a list")
744
+
745
+
746
+ def test_dump_vector_type_annotation_with_dim():
747
+ """Test dumping a vector type annotation with a specified dimension."""
748
+ expected_dump = {
749
+ "type": {
750
+ "kind": "Vector",
751
+ "element_type": {"kind": "Float32"},
752
+ "dimension": 3,
753
+ }
754
+ }
755
+ assert dump_engine_object(Float32VectorType) == expected_dump
756
+
757
+
758
+ def test_dump_vector_type_annotation_no_dim():
759
+ """Test dumping a vector type annotation with no dimension."""
760
+ expected_dump_no_dim = {
761
+ "type": {
762
+ "kind": "Vector",
763
+ "element_type": {"kind": "Float64"},
764
+ "dimension": None,
765
+ }
766
+ }
767
+ assert dump_engine_object(Float64VectorTypeNoDim) == expected_dump_no_dim
@@ -0,0 +1,499 @@
1
+ import dataclasses
2
+ import datetime
3
+ import uuid
4
+ from typing import (
5
+ Annotated,
6
+ List,
7
+ Dict,
8
+ Literal,
9
+ Any,
10
+ get_args,
11
+ NamedTuple,
12
+ )
13
+ from collections.abc import Sequence, Mapping
14
+ import pytest
15
+ import numpy as np
16
+ from numpy.typing import NDArray
17
+
18
+ from cocoindex.typing import (
19
+ analyze_type_info,
20
+ Vector,
21
+ VectorInfo,
22
+ TypeKind,
23
+ TypeAttr,
24
+ Float32,
25
+ Float64,
26
+ encode_enriched_type,
27
+ AnalyzedTypeInfo,
28
+ )
29
+
30
+
31
+ @dataclasses.dataclass
32
+ class SimpleDataclass:
33
+ name: str
34
+ value: int
35
+
36
+
37
+ class SimpleNamedTuple(NamedTuple):
38
+ name: str
39
+ value: Any
40
+
41
+
42
+ def test_ndarray_float32_no_dim():
43
+ typ = NDArray[np.float32]
44
+ result = analyze_type_info(typ)
45
+ assert result == AnalyzedTypeInfo(
46
+ kind="Vector",
47
+ vector_info=VectorInfo(dim=None),
48
+ elem_type=Float32,
49
+ key_type=None,
50
+ struct_type=None,
51
+ attrs=None,
52
+ nullable=False,
53
+ )
54
+
55
+
56
+ def test_vector_float32_no_dim():
57
+ typ = Vector[np.float32]
58
+ result = analyze_type_info(typ)
59
+ assert result == AnalyzedTypeInfo(
60
+ kind="Vector",
61
+ vector_info=VectorInfo(dim=None),
62
+ elem_type=Float32,
63
+ key_type=None,
64
+ struct_type=None,
65
+ attrs=None,
66
+ nullable=False,
67
+ )
68
+
69
+
70
+ def test_ndarray_float64_with_dim():
71
+ typ = Annotated[NDArray[np.float64], VectorInfo(dim=128)]
72
+ result = analyze_type_info(typ)
73
+ assert result == AnalyzedTypeInfo(
74
+ kind="Vector",
75
+ vector_info=VectorInfo(dim=128),
76
+ elem_type=Float64,
77
+ key_type=None,
78
+ struct_type=None,
79
+ attrs=None,
80
+ nullable=False,
81
+ )
82
+
83
+
84
+ def test_vector_float32_with_dim():
85
+ typ = Vector[np.float32, Literal[384]]
86
+ result = analyze_type_info(typ)
87
+ assert result == AnalyzedTypeInfo(
88
+ kind="Vector",
89
+ vector_info=VectorInfo(dim=384),
90
+ elem_type=Float32,
91
+ key_type=None,
92
+ struct_type=None,
93
+ attrs=None,
94
+ nullable=False,
95
+ )
96
+
97
+
98
+ def test_ndarray_int64_no_dim():
99
+ typ = NDArray[np.int64]
100
+ result = analyze_type_info(typ)
101
+ assert result.kind == "Vector"
102
+ assert result.vector_info == VectorInfo(dim=None)
103
+ assert get_args(result.elem_type) == (int, TypeKind("Int64"))
104
+ assert not result.nullable
105
+
106
+
107
+ def test_ndarray_int32_with_dim():
108
+ typ = Annotated[NDArray[np.int32], VectorInfo(dim=10)]
109
+ result = analyze_type_info(typ)
110
+ assert result.kind == "Vector"
111
+ assert result.vector_info == VectorInfo(dim=10)
112
+ assert get_args(result.elem_type) == (int, TypeKind("Int32"))
113
+ assert not result.nullable
114
+
115
+
116
+ def test_ndarray_uint8_no_dim():
117
+ typ = NDArray[np.uint8]
118
+ result = analyze_type_info(typ)
119
+ assert result.kind == "Vector"
120
+ assert result.vector_info == VectorInfo(dim=None)
121
+ assert get_args(result.elem_type) == (int, TypeKind("UInt8"))
122
+ assert not result.nullable
123
+
124
+
125
+ def test_nullable_ndarray():
126
+ typ = NDArray[np.float32] | None
127
+ result = analyze_type_info(typ)
128
+ assert result == AnalyzedTypeInfo(
129
+ kind="Vector",
130
+ vector_info=VectorInfo(dim=None),
131
+ elem_type=Float32,
132
+ key_type=None,
133
+ struct_type=None,
134
+ attrs=None,
135
+ nullable=True,
136
+ )
137
+
138
+
139
+ def test_vector_str():
140
+ typ = Vector[str]
141
+ result = analyze_type_info(typ)
142
+ assert result.kind == "Vector"
143
+ assert result.elem_type == str
144
+ assert result.vector_info == VectorInfo(dim=None)
145
+
146
+
147
+ def test_vector_complex64():
148
+ typ = Vector[np.complex64]
149
+ result = analyze_type_info(typ)
150
+ assert result.kind == "Vector"
151
+ assert result.elem_type == np.complex64
152
+ assert result.vector_info == VectorInfo(dim=None)
153
+
154
+
155
+ def test_non_numpy_vector():
156
+ typ = Vector[float, Literal[3]]
157
+ result = analyze_type_info(typ)
158
+ assert result.kind == "Vector"
159
+ assert result.elem_type == float
160
+ assert result.vector_info == VectorInfo(dim=3)
161
+
162
+
163
+ def test_ndarray_any_dtype():
164
+ typ = NDArray[Any]
165
+ with pytest.raises(
166
+ TypeError, match="NDArray for Vector must use a concrete numpy dtype"
167
+ ):
168
+ analyze_type_info(typ)
169
+
170
+
171
+ def test_list_of_primitives():
172
+ typ = List[str]
173
+ result = analyze_type_info(typ)
174
+ assert result == AnalyzedTypeInfo(
175
+ kind="Vector",
176
+ vector_info=VectorInfo(dim=None),
177
+ elem_type=str,
178
+ key_type=None,
179
+ struct_type=None,
180
+ attrs=None,
181
+ nullable=False,
182
+ )
183
+
184
+
185
+ def test_list_of_structs():
186
+ typ = List[SimpleDataclass]
187
+ result = analyze_type_info(typ)
188
+ assert result == AnalyzedTypeInfo(
189
+ kind="LTable",
190
+ vector_info=None,
191
+ elem_type=SimpleDataclass,
192
+ key_type=None,
193
+ struct_type=None,
194
+ attrs=None,
195
+ nullable=False,
196
+ )
197
+
198
+
199
+ def test_sequence_of_int():
200
+ typ = Sequence[int]
201
+ result = analyze_type_info(typ)
202
+ assert result == AnalyzedTypeInfo(
203
+ kind="Vector",
204
+ vector_info=VectorInfo(dim=None),
205
+ elem_type=int,
206
+ key_type=None,
207
+ struct_type=None,
208
+ attrs=None,
209
+ nullable=False,
210
+ )
211
+
212
+
213
+ def test_list_with_vector_info():
214
+ typ = Annotated[List[int], VectorInfo(dim=5)]
215
+ result = analyze_type_info(typ)
216
+ assert result == AnalyzedTypeInfo(
217
+ kind="Vector",
218
+ vector_info=VectorInfo(dim=5),
219
+ elem_type=int,
220
+ key_type=None,
221
+ struct_type=None,
222
+ attrs=None,
223
+ nullable=False,
224
+ )
225
+
226
+
227
+ def test_dict_str_int():
228
+ typ = Dict[str, int]
229
+ result = analyze_type_info(typ)
230
+ assert result == AnalyzedTypeInfo(
231
+ kind="KTable",
232
+ vector_info=None,
233
+ elem_type=(str, int),
234
+ key_type=None,
235
+ struct_type=None,
236
+ attrs=None,
237
+ nullable=False,
238
+ )
239
+
240
+
241
+ def test_mapping_str_dataclass():
242
+ typ = Mapping[str, SimpleDataclass]
243
+ result = analyze_type_info(typ)
244
+ assert result == AnalyzedTypeInfo(
245
+ kind="KTable",
246
+ vector_info=None,
247
+ elem_type=(str, SimpleDataclass),
248
+ key_type=None,
249
+ struct_type=None,
250
+ attrs=None,
251
+ nullable=False,
252
+ )
253
+
254
+
255
+ def test_dataclass():
256
+ typ = SimpleDataclass
257
+ result = analyze_type_info(typ)
258
+ assert result == AnalyzedTypeInfo(
259
+ kind="Struct",
260
+ vector_info=None,
261
+ elem_type=None,
262
+ key_type=None,
263
+ struct_type=SimpleDataclass,
264
+ attrs=None,
265
+ nullable=False,
266
+ )
267
+
268
+
269
+ def test_named_tuple():
270
+ typ = SimpleNamedTuple
271
+ result = analyze_type_info(typ)
272
+ assert result == AnalyzedTypeInfo(
273
+ kind="Struct",
274
+ vector_info=None,
275
+ elem_type=None,
276
+ key_type=None,
277
+ struct_type=SimpleNamedTuple,
278
+ attrs=None,
279
+ nullable=False,
280
+ )
281
+
282
+
283
+ def test_tuple_key_value():
284
+ typ = (str, int)
285
+ result = analyze_type_info(typ)
286
+ assert result == AnalyzedTypeInfo(
287
+ kind="Int64",
288
+ vector_info=None,
289
+ elem_type=None,
290
+ key_type=str,
291
+ struct_type=None,
292
+ attrs=None,
293
+ nullable=False,
294
+ )
295
+
296
+
297
+ def test_str():
298
+ typ = str
299
+ result = analyze_type_info(typ)
300
+ assert result == AnalyzedTypeInfo(
301
+ kind="Str",
302
+ vector_info=None,
303
+ elem_type=None,
304
+ key_type=None,
305
+ struct_type=None,
306
+ attrs=None,
307
+ nullable=False,
308
+ )
309
+
310
+
311
+ def test_bool():
312
+ typ = bool
313
+ result = analyze_type_info(typ)
314
+ assert result == AnalyzedTypeInfo(
315
+ kind="Bool",
316
+ vector_info=None,
317
+ elem_type=None,
318
+ key_type=None,
319
+ struct_type=None,
320
+ attrs=None,
321
+ nullable=False,
322
+ )
323
+
324
+
325
+ def test_bytes():
326
+ typ = bytes
327
+ result = analyze_type_info(typ)
328
+ assert result == AnalyzedTypeInfo(
329
+ kind="Bytes",
330
+ vector_info=None,
331
+ elem_type=None,
332
+ key_type=None,
333
+ struct_type=None,
334
+ attrs=None,
335
+ nullable=False,
336
+ )
337
+
338
+
339
+ def test_uuid():
340
+ typ = uuid.UUID
341
+ result = analyze_type_info(typ)
342
+ assert result == AnalyzedTypeInfo(
343
+ kind="Uuid",
344
+ vector_info=None,
345
+ elem_type=None,
346
+ key_type=None,
347
+ struct_type=None,
348
+ attrs=None,
349
+ nullable=False,
350
+ )
351
+
352
+
353
+ def test_date():
354
+ typ = datetime.date
355
+ result = analyze_type_info(typ)
356
+ assert result == AnalyzedTypeInfo(
357
+ kind="Date",
358
+ vector_info=None,
359
+ elem_type=None,
360
+ key_type=None,
361
+ struct_type=None,
362
+ attrs=None,
363
+ nullable=False,
364
+ )
365
+
366
+
367
+ def test_time():
368
+ typ = datetime.time
369
+ result = analyze_type_info(typ)
370
+ assert result == AnalyzedTypeInfo(
371
+ kind="Time",
372
+ vector_info=None,
373
+ elem_type=None,
374
+ key_type=None,
375
+ struct_type=None,
376
+ attrs=None,
377
+ nullable=False,
378
+ )
379
+
380
+
381
+ def test_timedelta():
382
+ typ = datetime.timedelta
383
+ result = analyze_type_info(typ)
384
+ assert result == AnalyzedTypeInfo(
385
+ kind="TimeDelta",
386
+ vector_info=None,
387
+ elem_type=None,
388
+ key_type=None,
389
+ struct_type=None,
390
+ attrs=None,
391
+ nullable=False,
392
+ )
393
+
394
+
395
+ def test_float():
396
+ typ = float
397
+ result = analyze_type_info(typ)
398
+ assert result == AnalyzedTypeInfo(
399
+ kind="Float64",
400
+ vector_info=None,
401
+ elem_type=None,
402
+ key_type=None,
403
+ struct_type=None,
404
+ attrs=None,
405
+ nullable=False,
406
+ )
407
+
408
+
409
+ def test_int():
410
+ typ = int
411
+ result = analyze_type_info(typ)
412
+ assert result == AnalyzedTypeInfo(
413
+ kind="Int64",
414
+ vector_info=None,
415
+ elem_type=None,
416
+ key_type=None,
417
+ struct_type=None,
418
+ attrs=None,
419
+ nullable=False,
420
+ )
421
+
422
+
423
+ def test_type_with_attributes():
424
+ typ = Annotated[str, TypeAttr("key", "value")]
425
+ result = analyze_type_info(typ)
426
+ assert result == AnalyzedTypeInfo(
427
+ kind="Str",
428
+ vector_info=None,
429
+ elem_type=None,
430
+ key_type=None,
431
+ struct_type=None,
432
+ attrs={"key": "value"},
433
+ nullable=False,
434
+ )
435
+
436
+
437
+ def test_encode_enriched_type_none():
438
+ typ = None
439
+ result = encode_enriched_type(typ)
440
+ assert result is None
441
+
442
+
443
+ def test_encode_enriched_type_struct():
444
+ typ = SimpleDataclass
445
+ result = encode_enriched_type(typ)
446
+ assert result["type"]["kind"] == "Struct"
447
+ assert len(result["type"]["fields"]) == 2
448
+ assert result["type"]["fields"][0]["name"] == "name"
449
+ assert result["type"]["fields"][0]["type"]["kind"] == "Str"
450
+ assert result["type"]["fields"][1]["name"] == "value"
451
+ assert result["type"]["fields"][1]["type"]["kind"] == "Int64"
452
+
453
+
454
+ def test_encode_enriched_type_vector():
455
+ typ = NDArray[np.float32]
456
+ result = encode_enriched_type(typ)
457
+ assert result["type"]["kind"] == "Vector"
458
+ assert result["type"]["element_type"]["kind"] == "Float32"
459
+ assert result["type"]["dimension"] is None
460
+
461
+
462
+ def test_encode_enriched_type_ltable():
463
+ typ = List[SimpleDataclass]
464
+ result = encode_enriched_type(typ)
465
+ assert result["type"]["kind"] == "LTable"
466
+ assert result["type"]["row"]["kind"] == "Struct"
467
+ assert len(result["type"]["row"]["fields"]) == 2
468
+
469
+
470
+ def test_encode_enriched_type_with_attrs():
471
+ typ = Annotated[str, TypeAttr("key", "value")]
472
+ result = encode_enriched_type(typ)
473
+ assert result["type"]["kind"] == "Str"
474
+ assert result["attrs"] == {"key": "value"}
475
+
476
+
477
+ def test_encode_enriched_type_nullable():
478
+ typ = str | None
479
+ result = encode_enriched_type(typ)
480
+ assert result["type"]["kind"] == "Str"
481
+ assert result["nullable"] is True
482
+
483
+
484
+ def test_invalid_struct_kind():
485
+ typ = Annotated[SimpleDataclass, TypeKind("Vector")]
486
+ with pytest.raises(ValueError, match="Unexpected type kind for struct: Vector"):
487
+ analyze_type_info(typ)
488
+
489
+
490
+ def test_invalid_list_kind():
491
+ typ = Annotated[List[int], TypeKind("Struct")]
492
+ with pytest.raises(ValueError, match="Unexpected type kind for list: Struct"):
493
+ analyze_type_info(typ)
494
+
495
+
496
+ def test_unsupported_type():
497
+ typ = set
498
+ with pytest.raises(ValueError, match="type unsupported yet: <class 'set'>"):
499
+ analyze_type_info(typ)
cocoindex/typing.py CHANGED
@@ -9,14 +9,16 @@ from typing import (
9
9
  Annotated,
10
10
  NamedTuple,
11
11
  Any,
12
+ KeysView,
12
13
  TypeVar,
13
14
  TYPE_CHECKING,
14
15
  overload,
15
- Sequence,
16
16
  Generic,
17
17
  Literal,
18
18
  Protocol,
19
19
  )
20
+ import numpy as np
21
+ from numpy.typing import NDArray
20
22
 
21
23
 
22
24
  class VectorInfo(NamedTuple):
@@ -50,7 +52,7 @@ if TYPE_CHECKING:
50
52
  Dim_co = TypeVar("Dim_co", bound=int, covariant=True)
51
53
 
52
54
  class Vector(Protocol, Generic[T_co, Dim_co]):
53
- """Vector[T, Dim] is a special typing alias for a list[T] with optional dimension info"""
55
+ """Vector[T, Dim] is a special typing alias for an NDArray[T] with optional dimension info"""
54
56
 
55
57
  def __getitem__(self, index: int) -> T_co: ...
56
58
  def __len__(self) -> int: ...
@@ -58,25 +60,34 @@ if TYPE_CHECKING:
58
60
  else:
59
61
 
60
62
  class Vector: # type: ignore[unreachable]
61
- """A special typing alias for a list[T] with optional dimension info"""
63
+ """A special typing alias for an NDArray[T] with optional dimension info"""
62
64
 
63
65
  def __class_getitem__(self, params):
64
66
  if not isinstance(params, tuple):
65
- # Only element type provided
66
- elem_type = params
67
- return Annotated[list[elem_type], VectorInfo(dim=None)]
67
+ # No dimension provided, e.g., Vector[np.float32]
68
+ dtype = params
69
+ # Use NDArray for supported numeric dtypes, else list
70
+ if DtypeRegistry.get_by_dtype(dtype) is not None:
71
+ return Annotated[NDArray[dtype], VectorInfo(dim=None)]
72
+ return Annotated[list[dtype], VectorInfo(dim=None)]
68
73
  else:
69
- # Element type and dimension provided
70
- elem_type, dim = params
71
- if typing.get_origin(dim) is Literal:
72
- dim = typing.get_args(dim)[0] # Extract the literal value
73
- return Annotated[list[elem_type], VectorInfo(dim=dim)]
74
+ # Element type and dimension provided, e.g., Vector[np.float32, Literal[3]]
75
+ dtype, dim_literal = params
76
+ # Extract the literal value
77
+ dim_val = (
78
+ typing.get_args(dim_literal)[0]
79
+ if typing.get_origin(dim_literal) is Literal
80
+ else None
81
+ )
82
+ if DtypeRegistry.get_by_dtype(dtype) is not None:
83
+ return Annotated[NDArray[dtype], VectorInfo(dim=dim_val)]
84
+ return Annotated[list[dtype], VectorInfo(dim=dim_val)]
74
85
 
75
86
 
76
87
  TABLE_TYPES: tuple[str, str] = ("KTable", "LTable")
77
88
  KEY_FIELD_NAME: str = "_key"
78
89
 
79
- ElementType = type | tuple[type, type]
90
+ ElementType = type | tuple[type, type] | Annotated[Any, TypeKind]
80
91
 
81
92
 
82
93
  def is_namedtuple_type(t: type) -> bool:
@@ -89,6 +100,48 @@ def _is_struct_type(t: ElementType | None) -> bool:
89
100
  )
90
101
 
91
102
 
103
+ class DtypeInfo:
104
+ """Metadata for a NumPy dtype."""
105
+
106
+ def __init__(self, numpy_dtype: type, kind: str, python_type: type) -> None:
107
+ self.numpy_dtype = numpy_dtype
108
+ self.kind = kind
109
+ self.python_type = python_type
110
+ self.annotated_type = Annotated[python_type, TypeKind(kind)]
111
+
112
+
113
+ class DtypeRegistry:
114
+ _mappings: dict[type, DtypeInfo] = {
115
+ np.float32: DtypeInfo(np.float32, "Float32", float),
116
+ np.float64: DtypeInfo(np.float64, "Float64", float),
117
+ np.int32: DtypeInfo(np.int32, "Int32", int),
118
+ np.int64: DtypeInfo(np.int64, "Int64", int),
119
+ np.uint8: DtypeInfo(np.uint8, "UInt8", int),
120
+ np.uint16: DtypeInfo(np.uint16, "UInt16", int),
121
+ np.uint32: DtypeInfo(np.uint32, "UInt32", int),
122
+ np.uint64: DtypeInfo(np.uint64, "UInt64", int),
123
+ }
124
+
125
+ @classmethod
126
+ def get_by_dtype(cls, dtype: Any) -> DtypeInfo | None:
127
+ if dtype is Any:
128
+ raise TypeError(
129
+ "NDArray for Vector must use a concrete numpy dtype, got `Any`."
130
+ )
131
+ return cls._mappings.get(dtype)
132
+
133
+ @staticmethod
134
+ def get_by_kind(kind: str) -> DtypeInfo | None:
135
+ return next(
136
+ (info for info in DtypeRegistry._mappings.values() if info.kind == kind),
137
+ None,
138
+ )
139
+
140
+ @staticmethod
141
+ def supported_dtypes() -> KeysView[type]:
142
+ return DtypeRegistry._mappings.keys()
143
+
144
+
92
145
  @dataclasses.dataclass
93
146
  class AnalyzedTypeInfo:
94
147
  """
@@ -179,12 +232,34 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
179
232
  vector_info = VectorInfo(dim=None)
180
233
  elif not (kind == "Vector" or kind in TABLE_TYPES):
181
234
  raise ValueError(f"Unexpected type kind for list: {kind}")
235
+ elif base_type is np.ndarray:
236
+ kind = "Vector"
237
+ args = typing.get_args(t)
238
+ _, dtype_spec = args
239
+
240
+ dtype_args = typing.get_args(dtype_spec)
241
+ if not dtype_args:
242
+ raise ValueError("Invalid dtype specification for NDArray")
243
+
244
+ numpy_dtype = dtype_args[0]
245
+ dtype_info = DtypeRegistry.get_by_dtype(numpy_dtype)
246
+ if dtype_info is None:
247
+ raise ValueError(
248
+ f"Unsupported numpy dtype for NDArray: {numpy_dtype}. "
249
+ f"Supported dtypes: {DtypeRegistry.supported_dtypes()}"
250
+ )
251
+ elem_type = dtype_info.annotated_type
252
+ vector_info = VectorInfo(dim=None) if vector_info is None else vector_info
253
+
182
254
  elif base_type is collections.abc.Mapping or base_type is dict:
183
255
  args = typing.get_args(t)
184
256
  elem_type = (args[0], args[1])
185
257
  kind = "KTable"
186
258
  elif kind is None:
187
- if t is bytes:
259
+ dtype_info = DtypeRegistry.get_by_dtype(t)
260
+ if dtype_info is not None:
261
+ kind = dtype_info.kind
262
+ elif t is bytes:
188
263
  kind = "Bytes"
189
264
  elif t is str:
190
265
  kind = "Str"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cocoindex
3
- Version: 0.1.49
3
+ Version: 0.1.50
4
4
  Requires-Dist: sentence-transformers>=3.3.1
5
5
  Requires-Dist: click>=8.1.8
6
6
  Requires-Dist: rich>=14.0.0
@@ -1,27 +1,27 @@
1
- cocoindex-0.1.49.dist-info/METADATA,sha256=8Fow__E5M04be_uy0-Gz-ojEaokr4scBR4lQdRMEVGo,9875
2
- cocoindex-0.1.49.dist-info/WHEEL,sha256=p_tvkyHH2UmMBrR2Gemb1ahXJMM2SXUIsCLrWZgJvB8,104
3
- cocoindex-0.1.49.dist-info/entry_points.txt,sha256=_NretjYVzBdNTn7dK-zgwr7YfG2afz1u1uSE-5bZXF8,46
4
- cocoindex-0.1.49.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
5
- cocoindex/__init__.py,sha256=TkfOR-7upXpzETEvBQUhcnYVuI_A3sqjOdXYeTPAwO0,818
6
- cocoindex/_engine.cpython-311-darwin.so,sha256=LdbQ9EL9P0L19Xq9Nh07NNZY8MYTcv4O4CiJIpI7p7s,57420944
1
+ cocoindex-0.1.50.dist-info/METADATA,sha256=PdFYx9yT4tBqCXVaP77iQjRZMB26D2I8BWXoA8i7ZqQ,9875
2
+ cocoindex-0.1.50.dist-info/WHEEL,sha256=8hL2oHqulIPF_TJ_OfB-8kD2X9O6HJL0emTmeguFbqc,104
3
+ cocoindex-0.1.50.dist-info/entry_points.txt,sha256=_NretjYVzBdNTn7dK-zgwr7YfG2afz1u1uSE-5bZXF8,46
4
+ cocoindex-0.1.50.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
5
+ cocoindex/__init__.py,sha256=BC6BvotY0LcKHCUNnZWjKZX7Afl18TDR2Bfv5R6Y4DM,811
6
+ cocoindex/_engine.cpython-311-darwin.so,sha256=hqncg8nnjBaQicxXkJbOf4Hl3-UoE9onm_Q7On3dvQk,57265296
7
7
  cocoindex/auth_registry.py,sha256=1XqO7ibjmBBd8i11XSJTvTgdz8p1ptW-ZpuSgo_5zzk,716
8
8
  cocoindex/cli.py,sha256=p92s-Ya6VpT4EY5xfbl-R0yyIJv8ySK-2eyxnwC8KqA,17853
9
- cocoindex/convert.py,sha256=CfsZDRs1kwMeNCeJ3-gifHQm_9GzyqySpUmZyCyWXMo,7316
9
+ cocoindex/convert.py,sha256=dxlMUd6Jweun1xESJv0w-WbIx-NifMXlUF9RqgeggWM,8735
10
10
  cocoindex/flow.py,sha256=sC7eg0L9MkZuo7MlvGA4eYe-mEReIyPjKIILFtkWm-Y,30017
11
- cocoindex/functions.py,sha256=RnrYJFJCfj263wow3rEx89UBtff44ldQ_torgwArPEw,2294
11
+ cocoindex/functions.py,sha256=9A61Jj5a3vQoI2MIAhjXvJrDxSzDhe6VncQWbiVtwcg,2393
12
12
  cocoindex/index.py,sha256=j93B9jEvvLXHtpzKWL88SY6wCGEoPgpsQhEGHlyYGFg,540
13
- cocoindex/lib.py,sha256=W3hPh0QQIjLRRe2tpyLCvL_6ajzXhGGSNJubikkS27s,2985
13
+ cocoindex/lib.py,sha256=BeRUn3RqE_wSsVtsgCzbFFKe1LXgRyRmMOcmwWBuEXo,2940
14
14
  cocoindex/llm.py,sha256=KO-R4mrAWtxXD82-Yv5ixpkKMVfkwpbdWwqPVZygLu4,352
15
15
  cocoindex/op.py,sha256=hOJoHC8elhLHMNVRTGTBWNSK5uYrQb-FfRV-qt08j8g,11815
16
16
  cocoindex/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
- cocoindex/query.py,sha256=UswtEp0wZwwNPeR9x9lRokFpNNZOHkyt5tYfCmGIWWU,3450
18
17
  cocoindex/runtime.py,sha256=bAdHYaXFWiiUWyAgzmKTeaAaRR0D_AmaqVCIdPO-v00,1056
19
18
  cocoindex/setting.py,sha256=ePqHw1i95rVtqYYRVqzVwtBifRO4SfH1rlMW_AG3Zek,3418
20
19
  cocoindex/setup.py,sha256=u5dYZFKfz4yZLiGHD0guNaR0s4zY9JAoZWrWHpAHw_0,773
21
20
  cocoindex/sources.py,sha256=JCnOhv1w4o28e03i7yvo4ESicWYAhckkBg5bQlxNH4U,1330
22
21
  cocoindex/storages.py,sha256=i9cPzj5LbLJVHjskEWyltUXF9kqCyc6ftMEqll6FDwI,2804
23
22
  cocoindex/tests/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
24
- cocoindex/tests/test_convert.py,sha256=tTsDveMCgflhcYmDq-fQq3MRGAmN2xlbtFIEDYlRpGw,17095
25
- cocoindex/typing.py,sha256=0WQ4F04kFhXwHphv9AHJqc1cThYmLe27eqqGEuGWHAU,9462
23
+ cocoindex/tests/test_convert.py,sha256=EsR7NEGKAypQpr3FxRz8xja_z0sKmYpTjJuCMSUdyYw,26217
24
+ cocoindex/tests/test_typing.py,sha256=P6bnRjxMjKH2B5RQ0ncZQAMyOPSdviFVmML7fCS2iR4,11884
25
+ cocoindex/typing.py,sha256=9366o49R6-YbrG9rc9UOxJjgTYITC-7DjDkJSVIjZL0,12316
26
26
  cocoindex/utils.py,sha256=5a3ubVzDowJUJyUl8ecd75Th_OzON3G-MV2vXf0dQSk,503
27
- cocoindex-0.1.49.dist-info/RECORD,,
27
+ cocoindex-0.1.50.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: maturin (1.8.6)
2
+ Generator: maturin (1.8.7)
3
3
  Root-Is-Purelib: false
4
4
  Tag: cp311-cp311-macosx_11_0_arm64
cocoindex/query.py DELETED
@@ -1,115 +0,0 @@
1
- from typing import Callable, Any
2
- from dataclasses import dataclass
3
- from threading import Lock
4
-
5
- from . import flow as fl
6
- from . import index
7
- from . import _engine # type: ignore
8
-
9
- _handlers_lock = Lock()
10
- _handlers: dict[str, _engine.SimpleSemanticsQueryHandler] = {}
11
-
12
-
13
- @dataclass
14
- class SimpleSemanticsQueryInfo:
15
- """
16
- Additional information about the query.
17
- """
18
-
19
- similarity_metric: index.VectorSimilarityMetric
20
- query_vector: list[float]
21
- vector_field_name: str
22
-
23
-
24
- @dataclass
25
- class QueryResult:
26
- """
27
- A single result from the query.
28
- """
29
-
30
- data: dict[str, Any]
31
- score: float
32
-
33
-
34
- class SimpleSemanticsQueryHandler:
35
- """
36
- A query handler that uses simple semantics to query the index.
37
- """
38
-
39
- _lazy_query_handler: Callable[[], _engine.SimpleSemanticsQueryHandler]
40
-
41
- def __init__(
42
- self,
43
- name: str,
44
- flow: fl.Flow,
45
- target_name: str,
46
- query_transform_flow: Callable[..., fl.DataSlice[Any]],
47
- default_similarity_metric: index.VectorSimilarityMetric = index.VectorSimilarityMetric.COSINE_SIMILARITY,
48
- ) -> None:
49
- engine_handler = None
50
- lock = Lock()
51
-
52
- def _lazy_handler() -> _engine.SimpleSemanticsQueryHandler:
53
- nonlocal engine_handler, lock
54
- if engine_handler is None:
55
- with lock:
56
- if engine_handler is None:
57
- engine_handler = _engine.SimpleSemanticsQueryHandler(
58
- flow.internal_flow(),
59
- target_name,
60
- fl.TransformFlow(
61
- query_transform_flow, [str]
62
- ).internal_flow(),
63
- default_similarity_metric.value,
64
- )
65
- engine_handler.register_query_handler(name)
66
- return engine_handler
67
-
68
- self._lazy_query_handler = _lazy_handler
69
-
70
- with _handlers_lock:
71
- _handlers[name] = self
72
-
73
- def internal_handler(self) -> _engine.SimpleSemanticsQueryHandler:
74
- """
75
- Get the internal query handler.
76
- """
77
- return self._lazy_query_handler()
78
-
79
- def search(
80
- self,
81
- query: str,
82
- limit: int,
83
- vector_field_name: str | None = None,
84
- similarity_metric: index.VectorSimilarityMetric | None = None,
85
- ) -> tuple[list[QueryResult], SimpleSemanticsQueryInfo]:
86
- """
87
- Search the index with the given query, limit, vector field name, and similarity metric.
88
- """
89
- internal_results, internal_info = self.internal_handler().search(
90
- query,
91
- limit,
92
- vector_field_name,
93
- similarity_metric.value if similarity_metric is not None else None,
94
- )
95
- results = [
96
- QueryResult(data=result["data"], score=result["score"])
97
- for result in internal_results
98
- ]
99
- info = SimpleSemanticsQueryInfo(
100
- similarity_metric=index.VectorSimilarityMetric(
101
- internal_info["similarity_metric"]
102
- ),
103
- query_vector=internal_info["query_vector"],
104
- vector_field_name=internal_info["vector_field_name"],
105
- )
106
- return results, info
107
-
108
-
109
- def ensure_all_handlers_built() -> None:
110
- """
111
- Ensure all handlers are built.
112
- """
113
- with _handlers_lock:
114
- for handler in _handlers.values():
115
- handler.internal_handler()