cocoindex 0.1.49__cp311-cp311-macosx_11_0_arm64.whl → 0.1.50__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cocoindex/__init__.py +1 -1
- cocoindex/_engine.cpython-311-darwin.so +0 -0
- cocoindex/convert.py +36 -0
- cocoindex/functions.py +6 -4
- cocoindex/lib.py +1 -2
- cocoindex/tests/test_convert.py +280 -52
- cocoindex/tests/test_typing.py +499 -0
- cocoindex/typing.py +88 -13
- {cocoindex-0.1.49.dist-info → cocoindex-0.1.50.dist-info}/METADATA +1 -1
- {cocoindex-0.1.49.dist-info → cocoindex-0.1.50.dist-info}/RECORD +13 -13
- {cocoindex-0.1.49.dist-info → cocoindex-0.1.50.dist-info}/WHEEL +1 -1
- cocoindex/query.py +0 -115
- {cocoindex-0.1.49.dist-info → cocoindex-0.1.50.dist-info}/entry_points.txt +0 -0
- {cocoindex-0.1.49.dist-info → cocoindex-0.1.50.dist-info}/licenses/LICENSE +0 -0
cocoindex/__init__.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2
2
|
Cocoindex is a framework for building and running indexing pipelines.
|
3
3
|
"""
|
4
4
|
|
5
|
-
from . import functions,
|
5
|
+
from . import functions, sources, storages, cli, utils
|
6
6
|
|
7
7
|
from .auth_registry import AuthEntryReference, add_auth_entry, ref_auth_entry
|
8
8
|
from .flow import FlowBuilder, DataScope, DataSlice, Flow, transform_flow
|
Binary file
|
cocoindex/convert.py
CHANGED
@@ -6,6 +6,7 @@ import dataclasses
|
|
6
6
|
import datetime
|
7
7
|
import inspect
|
8
8
|
import uuid
|
9
|
+
import numpy as np
|
9
10
|
|
10
11
|
from enum import Enum
|
11
12
|
from typing import Any, Callable, get_origin, Mapping
|
@@ -15,6 +16,7 @@ from .typing import (
|
|
15
16
|
is_namedtuple_type,
|
16
17
|
TABLE_TYPES,
|
17
18
|
KEY_FIELD_NAME,
|
19
|
+
DtypeRegistry,
|
18
20
|
)
|
19
21
|
|
20
22
|
|
@@ -27,6 +29,8 @@ def encode_engine_value(value: Any) -> Any:
|
|
27
29
|
]
|
28
30
|
if is_namedtuple_type(type(value)):
|
29
31
|
return [encode_engine_value(getattr(value, name)) for name in value._fields]
|
32
|
+
if isinstance(value, np.ndarray):
|
33
|
+
return value
|
30
34
|
if isinstance(value, (list, tuple)):
|
31
35
|
return [encode_engine_value(v) for v in value]
|
32
36
|
if isinstance(value, dict):
|
@@ -122,6 +126,38 @@ def make_engine_value_decoder(
|
|
122
126
|
if src_type_kind == "Uuid":
|
123
127
|
return lambda value: uuid.UUID(bytes=value)
|
124
128
|
|
129
|
+
if src_type_kind == "Vector":
|
130
|
+
elem_coco_type_info = analyze_type_info(dst_type_info.elem_type)
|
131
|
+
dtype_info = DtypeRegistry.get_by_kind(elem_coco_type_info.kind)
|
132
|
+
|
133
|
+
def decode_vector(value: Any) -> Any | None:
|
134
|
+
if value is None:
|
135
|
+
if dst_type_info.nullable:
|
136
|
+
return None
|
137
|
+
raise ValueError(
|
138
|
+
f"Received null for non-nullable vector `{''.join(field_path)}`"
|
139
|
+
)
|
140
|
+
|
141
|
+
if not isinstance(value, (np.ndarray, list)):
|
142
|
+
raise TypeError(
|
143
|
+
f"Expected NDArray or list for vector `{''.join(field_path)}`, got {type(value)}"
|
144
|
+
)
|
145
|
+
expected_dim = (
|
146
|
+
dst_type_info.vector_info.dim if dst_type_info.vector_info else None
|
147
|
+
)
|
148
|
+
if expected_dim is not None and len(value) != expected_dim:
|
149
|
+
raise ValueError(
|
150
|
+
f"Vector dimension mismatch for `{''.join(field_path)}`: "
|
151
|
+
f"expected {expected_dim}, got {len(value)}"
|
152
|
+
)
|
153
|
+
|
154
|
+
# Use NDArray for supported numeric dtypes, else return list
|
155
|
+
if dtype_info is not None:
|
156
|
+
return np.array(value, dtype=dtype_info.numpy_dtype)
|
157
|
+
return value
|
158
|
+
|
159
|
+
return decode_vector
|
160
|
+
|
125
161
|
return lambda value: value
|
126
162
|
|
127
163
|
|
cocoindex/functions.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1
1
|
"""All builtin functions."""
|
2
2
|
|
3
|
-
from typing import Annotated, Any, TYPE_CHECKING
|
3
|
+
from typing import Annotated, Any, TYPE_CHECKING, Literal
|
4
|
+
import numpy as np
|
5
|
+
from numpy.typing import NDArray
|
4
6
|
import dataclasses
|
5
7
|
|
6
8
|
from .typing import Float32, Vector, TypeAttr
|
@@ -66,11 +68,11 @@ class SentenceTransformerEmbedExecutor:
|
|
66
68
|
self._model = sentence_transformers.SentenceTransformer(self.spec.model, **args)
|
67
69
|
dim = self._model.get_sentence_embedding_dimension()
|
68
70
|
result: type = Annotated[
|
69
|
-
Vector[
|
71
|
+
Vector[np.float32, Literal[dim]], # type: ignore
|
70
72
|
TypeAttr("cocoindex.io/vector_origin_text", text.analyzed_value),
|
71
73
|
]
|
72
74
|
return result
|
73
75
|
|
74
|
-
def __call__(self, text: str) ->
|
75
|
-
result:
|
76
|
+
def __call__(self, text: str) -> NDArray[np.float32]:
|
77
|
+
result: NDArray[np.float32] = self._model.encode(text, convert_to_numpy=True)
|
76
78
|
return result
|
cocoindex/lib.py
CHANGED
@@ -6,7 +6,7 @@ import warnings
|
|
6
6
|
from typing import Callable, Any
|
7
7
|
|
8
8
|
from . import _engine # type: ignore
|
9
|
-
from . import flow,
|
9
|
+
from . import flow, setting
|
10
10
|
from .convert import dump_engine_object
|
11
11
|
|
12
12
|
|
@@ -24,7 +24,6 @@ def init(settings: setting.Settings | None = None) -> None:
|
|
24
24
|
def start_server(settings: setting.ServerSettings) -> None:
|
25
25
|
"""Start the cocoindex server."""
|
26
26
|
flow.ensure_all_flows_built()
|
27
|
-
query.ensure_all_handlers_built()
|
28
27
|
_engine.start_server(settings.__dict__)
|
29
28
|
|
30
29
|
|
cocoindex/tests/test_convert.py
CHANGED
@@ -1,11 +1,20 @@
|
|
1
1
|
import uuid
|
2
2
|
import datetime
|
3
3
|
from dataclasses import dataclass, make_dataclass
|
4
|
-
from typing import NamedTuple, Literal
|
4
|
+
from typing import NamedTuple, Literal, Any, Callable
|
5
5
|
import pytest
|
6
6
|
import cocoindex
|
7
|
-
from cocoindex.typing import
|
8
|
-
|
7
|
+
from cocoindex.typing import (
|
8
|
+
encode_enriched_type,
|
9
|
+
Vector,
|
10
|
+
)
|
11
|
+
from cocoindex.convert import (
|
12
|
+
encode_engine_value,
|
13
|
+
make_engine_value_decoder,
|
14
|
+
dump_engine_object,
|
15
|
+
)
|
16
|
+
import numpy as np
|
17
|
+
from numpy.typing import NDArray
|
9
18
|
|
10
19
|
|
11
20
|
@dataclass
|
@@ -53,7 +62,9 @@ class CustomerNamedTuple(NamedTuple):
|
|
53
62
|
tags: list[Tag] | None = None
|
54
63
|
|
55
64
|
|
56
|
-
def build_engine_value_decoder(
|
65
|
+
def build_engine_value_decoder(
|
66
|
+
engine_type_in_py: Any, python_type: Any | None = None
|
67
|
+
) -> Callable[[Any], Any]:
|
57
68
|
"""
|
58
69
|
Helper to build a converter for the given engine-side type (as represented in Python).
|
59
70
|
If python_type is not specified, uses engine_type_in_py as the target.
|
@@ -62,6 +73,27 @@ def build_engine_value_decoder(engine_type_in_py, python_type=None):
|
|
62
73
|
return make_engine_value_decoder([], engine_type, python_type or engine_type_in_py)
|
63
74
|
|
64
75
|
|
76
|
+
def validate_full_roundtrip(
|
77
|
+
value: Any, output_type: Any, input_type: Any | None = None
|
78
|
+
) -> None:
|
79
|
+
"""
|
80
|
+
Validate the given value doesn't change after encoding, sending to engine (using output_type), receiving back and decoding (using input_type).
|
81
|
+
|
82
|
+
If `input_type` is not specified, uses `output_type` as the target.
|
83
|
+
"""
|
84
|
+
from cocoindex import _engine
|
85
|
+
|
86
|
+
encoded_value = encode_engine_value(value)
|
87
|
+
encoded_output_type = encode_enriched_type(output_type)["type"]
|
88
|
+
value_from_engine = _engine.testutil.seder_roundtrip(
|
89
|
+
encoded_value, encoded_output_type
|
90
|
+
)
|
91
|
+
decoded_value = build_engine_value_decoder(input_type or output_type, output_type)(
|
92
|
+
value_from_engine
|
93
|
+
)
|
94
|
+
assert decoded_value == value
|
95
|
+
|
96
|
+
|
65
97
|
def test_encode_engine_value_basic_types():
|
66
98
|
assert encode_engine_value(123) == 123
|
67
99
|
assert encode_engine_value(3.14) == 3.14
|
@@ -434,57 +466,33 @@ def test_field_position_cases(
|
|
434
466
|
assert decoder(engine_val) == PythonOrder(**expected_dict)
|
435
467
|
|
436
468
|
|
437
|
-
def test_roundtrip_ltable():
|
469
|
+
def test_roundtrip_ltable() -> None:
|
438
470
|
t = list[Order]
|
439
471
|
value = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]
|
440
|
-
|
441
|
-
assert encoded == [
|
442
|
-
["O1", "item1", 10.0, "default_extra"],
|
443
|
-
["O2", "item2", 20.0, "default_extra"],
|
444
|
-
]
|
445
|
-
decoded = build_engine_value_decoder(t)(encoded)
|
446
|
-
assert decoded == value
|
472
|
+
validate_full_roundtrip(value, t)
|
447
473
|
|
448
474
|
t_nt = list[OrderNamedTuple]
|
449
475
|
value_nt = [
|
450
476
|
OrderNamedTuple("O1", "item1", 10.0),
|
451
477
|
OrderNamedTuple("O2", "item2", 20.0),
|
452
478
|
]
|
453
|
-
|
454
|
-
assert encoded == [
|
455
|
-
["O1", "item1", 10.0, "default_extra"],
|
456
|
-
["O2", "item2", 20.0, "default_extra"],
|
457
|
-
]
|
458
|
-
decoded = build_engine_value_decoder(t_nt)(encoded)
|
459
|
-
assert decoded == value_nt
|
479
|
+
validate_full_roundtrip(value_nt, t_nt)
|
460
480
|
|
461
481
|
|
462
|
-
def test_roundtrip_ktable_str_key():
|
482
|
+
def test_roundtrip_ktable_str_key() -> None:
|
463
483
|
t = dict[str, Order]
|
464
484
|
value = {"K1": Order("O1", "item1", 10.0), "K2": Order("O2", "item2", 20.0)}
|
465
|
-
|
466
|
-
assert encoded == [
|
467
|
-
["K1", "O1", "item1", 10.0, "default_extra"],
|
468
|
-
["K2", "O2", "item2", 20.0, "default_extra"],
|
469
|
-
]
|
470
|
-
decoded = build_engine_value_decoder(t)(encoded)
|
471
|
-
assert decoded == value
|
485
|
+
validate_full_roundtrip(value, t)
|
472
486
|
|
473
487
|
t_nt = dict[str, OrderNamedTuple]
|
474
488
|
value_nt = {
|
475
489
|
"K1": OrderNamedTuple("O1", "item1", 10.0),
|
476
490
|
"K2": OrderNamedTuple("O2", "item2", 20.0),
|
477
491
|
}
|
478
|
-
|
479
|
-
assert encoded == [
|
480
|
-
["K1", "O1", "item1", 10.0, "default_extra"],
|
481
|
-
["K2", "O2", "item2", 20.0, "default_extra"],
|
482
|
-
]
|
483
|
-
decoded = build_engine_value_decoder(t_nt)(encoded)
|
484
|
-
assert decoded == value_nt
|
492
|
+
validate_full_roundtrip(value_nt, t_nt)
|
485
493
|
|
486
494
|
|
487
|
-
def test_roundtrip_ktable_struct_key():
|
495
|
+
def test_roundtrip_ktable_struct_key() -> None:
|
488
496
|
@dataclass(frozen=True)
|
489
497
|
class OrderKey:
|
490
498
|
shop_id: str
|
@@ -495,29 +503,17 @@ def test_roundtrip_ktable_struct_key():
|
|
495
503
|
OrderKey("A", 3): Order("O1", "item1", 10.0),
|
496
504
|
OrderKey("B", 4): Order("O2", "item2", 20.0),
|
497
505
|
}
|
498
|
-
|
499
|
-
assert encoded == [
|
500
|
-
[["A", 3], "O1", "item1", 10.0, "default_extra"],
|
501
|
-
[["B", 4], "O2", "item2", 20.0, "default_extra"],
|
502
|
-
]
|
503
|
-
decoded = build_engine_value_decoder(t)(encoded)
|
504
|
-
assert decoded == value
|
506
|
+
validate_full_roundtrip(value, t)
|
505
507
|
|
506
508
|
t_nt = dict[OrderKey, OrderNamedTuple]
|
507
509
|
value_nt = {
|
508
510
|
OrderKey("A", 3): OrderNamedTuple("O1", "item1", 10.0),
|
509
511
|
OrderKey("B", 4): OrderNamedTuple("O2", "item2", 20.0),
|
510
512
|
}
|
511
|
-
|
512
|
-
assert encoded == [
|
513
|
-
[["A", 3], "O1", "item1", 10.0, "default_extra"],
|
514
|
-
[["B", 4], "O2", "item2", 20.0, "default_extra"],
|
515
|
-
]
|
516
|
-
decoded = build_engine_value_decoder(t_nt)(encoded)
|
517
|
-
assert decoded == value_nt
|
513
|
+
validate_full_roundtrip(value_nt, t_nt)
|
518
514
|
|
519
515
|
|
520
|
-
IntVectorType = cocoindex.Vector[
|
516
|
+
IntVectorType = cocoindex.Vector[np.int32, Literal[5]]
|
521
517
|
|
522
518
|
|
523
519
|
def test_vector_as_vector() -> None:
|
@@ -525,7 +521,7 @@ def test_vector_as_vector() -> None:
|
|
525
521
|
encoded = encode_engine_value(value)
|
526
522
|
assert encoded == [1, 2, 3, 4, 5]
|
527
523
|
decoded = build_engine_value_decoder(IntVectorType)(encoded)
|
528
|
-
assert decoded
|
524
|
+
assert np.array_equal(decoded, value)
|
529
525
|
|
530
526
|
|
531
527
|
ListIntType = list[int]
|
@@ -536,4 +532,236 @@ def test_vector_as_list() -> None:
|
|
536
532
|
encoded = encode_engine_value(value)
|
537
533
|
assert encoded == [1, 2, 3, 4, 5]
|
538
534
|
decoded = build_engine_value_decoder(ListIntType)(encoded)
|
539
|
-
assert decoded
|
535
|
+
assert np.array_equal(decoded, value)
|
536
|
+
|
537
|
+
|
538
|
+
Float64VectorTypeNoDim = Vector[np.float64]
|
539
|
+
Float32VectorType = Vector[np.float32, Literal[3]]
|
540
|
+
Float64VectorType = Vector[np.float64, Literal[3]]
|
541
|
+
Int64VectorType = Vector[np.int64, Literal[3]]
|
542
|
+
Int32VectorType = Vector[np.int32, Literal[3]]
|
543
|
+
NDArrayFloat32Type = NDArray[np.float32]
|
544
|
+
NDArrayFloat64Type = NDArray[np.float64]
|
545
|
+
NDArrayInt64Type = NDArray[np.int64]
|
546
|
+
|
547
|
+
|
548
|
+
def test_encode_engine_value_ndarray():
|
549
|
+
"""Test encoding NDArray vectors to lists for the Rust engine."""
|
550
|
+
vec_f32: Float32VectorType = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
551
|
+
assert np.array_equal(encode_engine_value(vec_f32), [1.0, 2.0, 3.0])
|
552
|
+
vec_f64: Float64VectorType = np.array([1.0, 2.0, 3.0], dtype=np.float64)
|
553
|
+
assert np.array_equal(encode_engine_value(vec_f64), [1.0, 2.0, 3.0])
|
554
|
+
vec_i64: Int64VectorType = np.array([1, 2, 3], dtype=np.int64)
|
555
|
+
assert np.array_equal(encode_engine_value(vec_i64), [1, 2, 3])
|
556
|
+
vec_nd_f32: NDArrayFloat32Type = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
557
|
+
assert np.array_equal(encode_engine_value(vec_nd_f32), [1.0, 2.0, 3.0])
|
558
|
+
|
559
|
+
|
560
|
+
def test_make_engine_value_decoder_ndarray():
|
561
|
+
"""Test decoding engine lists to NDArray vectors."""
|
562
|
+
decoder_f32 = build_engine_value_decoder(Float32VectorType)
|
563
|
+
result_f32 = decoder_f32([1.0, 2.0, 3.0])
|
564
|
+
assert isinstance(result_f32, np.ndarray)
|
565
|
+
assert result_f32.dtype == np.float32
|
566
|
+
assert np.array_equal(result_f32, np.array([1.0, 2.0, 3.0], dtype=np.float32))
|
567
|
+
decoder_f64 = build_engine_value_decoder(Float64VectorType)
|
568
|
+
result_f64 = decoder_f64([1.0, 2.0, 3.0])
|
569
|
+
assert isinstance(result_f64, np.ndarray)
|
570
|
+
assert result_f64.dtype == np.float64
|
571
|
+
assert np.array_equal(result_f64, np.array([1.0, 2.0, 3.0], dtype=np.float64))
|
572
|
+
decoder_i64 = build_engine_value_decoder(Int64VectorType)
|
573
|
+
result_i64 = decoder_i64([1, 2, 3])
|
574
|
+
assert isinstance(result_i64, np.ndarray)
|
575
|
+
assert result_i64.dtype == np.int64
|
576
|
+
assert np.array_equal(result_i64, np.array([1, 2, 3], dtype=np.int64))
|
577
|
+
decoder_nd_f32 = build_engine_value_decoder(NDArrayFloat32Type)
|
578
|
+
result_nd_f32 = decoder_nd_f32([1.0, 2.0, 3.0])
|
579
|
+
assert isinstance(result_nd_f32, np.ndarray)
|
580
|
+
assert result_nd_f32.dtype == np.float32
|
581
|
+
assert np.array_equal(result_nd_f32, np.array([1.0, 2.0, 3.0], dtype=np.float32))
|
582
|
+
|
583
|
+
|
584
|
+
def test_roundtrip_ndarray_vector():
|
585
|
+
"""Test roundtrip encoding and decoding of NDArray vectors."""
|
586
|
+
value_f32: Float32VectorType = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
587
|
+
encoded_f32 = encode_engine_value(value_f32)
|
588
|
+
np.array_equal(encoded_f32, [1.0, 2.0, 3.0])
|
589
|
+
decoded_f32 = build_engine_value_decoder(Float32VectorType)(encoded_f32)
|
590
|
+
assert isinstance(decoded_f32, np.ndarray)
|
591
|
+
assert decoded_f32.dtype == np.float32
|
592
|
+
assert np.array_equal(decoded_f32, value_f32)
|
593
|
+
value_i64: Int64VectorType = np.array([1, 2, 3], dtype=np.int64)
|
594
|
+
encoded_i64 = encode_engine_value(value_i64)
|
595
|
+
assert np.array_equal(encoded_i64, [1, 2, 3])
|
596
|
+
decoded_i64 = build_engine_value_decoder(Int64VectorType)(encoded_i64)
|
597
|
+
assert isinstance(decoded_i64, np.ndarray)
|
598
|
+
assert decoded_i64.dtype == np.int64
|
599
|
+
assert np.array_equal(decoded_i64, value_i64)
|
600
|
+
value_nd_f64: NDArrayFloat64Type = np.array([1.0, 2.0, 3.0], dtype=np.float64)
|
601
|
+
encoded_nd_f64 = encode_engine_value(value_nd_f64)
|
602
|
+
assert np.array_equal(encoded_nd_f64, [1.0, 2.0, 3.0])
|
603
|
+
decoded_nd_f64 = build_engine_value_decoder(NDArrayFloat64Type)(encoded_nd_f64)
|
604
|
+
assert isinstance(decoded_nd_f64, np.ndarray)
|
605
|
+
assert decoded_nd_f64.dtype == np.float64
|
606
|
+
assert np.array_equal(decoded_nd_f64, value_nd_f64)
|
607
|
+
|
608
|
+
|
609
|
+
def test_uint_support():
|
610
|
+
"""Test encoding and decoding of unsigned integer vectors."""
|
611
|
+
value_uint8 = np.array([1, 2, 3, 4], dtype=np.uint8)
|
612
|
+
encoded = encode_engine_value(value_uint8)
|
613
|
+
assert np.array_equal(encoded, [1, 2, 3, 4])
|
614
|
+
decoder = make_engine_value_decoder(
|
615
|
+
[], {"kind": "Vector", "element_type": {"kind": "UInt8"}}, NDArray[np.uint8]
|
616
|
+
)
|
617
|
+
decoded = decoder(encoded)
|
618
|
+
assert np.array_equal(decoded, value_uint8)
|
619
|
+
assert decoded.dtype == np.uint8
|
620
|
+
value_uint16 = np.array([1, 2, 3, 4], dtype=np.uint16)
|
621
|
+
encoded = encode_engine_value(value_uint16)
|
622
|
+
assert np.array_equal(encoded, [1, 2, 3, 4])
|
623
|
+
decoder = make_engine_value_decoder(
|
624
|
+
[], {"kind": "Vector", "element_type": {"kind": "UInt16"}}, NDArray[np.uint16]
|
625
|
+
)
|
626
|
+
decoded = decoder(encoded)
|
627
|
+
assert np.array_equal(decoded, value_uint16)
|
628
|
+
assert decoded.dtype == np.uint16
|
629
|
+
value_uint32 = np.array([1, 2, 3], dtype=np.uint32)
|
630
|
+
encoded = encode_engine_value(value_uint32)
|
631
|
+
assert np.array_equal(encoded, [1, 2, 3])
|
632
|
+
decoder = make_engine_value_decoder(
|
633
|
+
[], {"kind": "Vector", "element_type": {"kind": "UInt32"}}, NDArray[np.uint32]
|
634
|
+
)
|
635
|
+
decoded = decoder(encoded)
|
636
|
+
assert np.array_equal(decoded, value_uint32)
|
637
|
+
assert decoded.dtype == np.uint32
|
638
|
+
value_uint64 = np.array([1, 2, 3], dtype=np.uint64)
|
639
|
+
encoded = encode_engine_value(value_uint64)
|
640
|
+
assert np.array_equal(encoded, [1, 2, 3])
|
641
|
+
decoder = make_engine_value_decoder(
|
642
|
+
[], {"kind": "Vector", "element_type": {"kind": "UInt8"}}, NDArray[np.uint64]
|
643
|
+
)
|
644
|
+
decoded = decoder(encoded)
|
645
|
+
assert np.array_equal(decoded, value_uint64)
|
646
|
+
assert decoded.dtype == np.uint64
|
647
|
+
|
648
|
+
|
649
|
+
def test_ndarray_dimension_mismatch():
|
650
|
+
"""Test dimension enforcement for Vector with specified dimension."""
|
651
|
+
value: Float32VectorType = np.array([1.0, 2.0], dtype=np.float32)
|
652
|
+
encoded = encode_engine_value(value)
|
653
|
+
assert np.array_equal(encoded, [1.0, 2.0])
|
654
|
+
with pytest.raises(ValueError, match="Vector dimension mismatch"):
|
655
|
+
build_engine_value_decoder(Float32VectorType)(encoded)
|
656
|
+
|
657
|
+
|
658
|
+
def test_list_vector_backward_compatibility():
|
659
|
+
"""Test that list-based vectors still work for backward compatibility."""
|
660
|
+
value: IntVectorType = [1, 2, 3, 4, 5]
|
661
|
+
encoded = encode_engine_value(value)
|
662
|
+
assert encoded == [1, 2, 3, 4, 5]
|
663
|
+
decoded = build_engine_value_decoder(IntVectorType)(encoded)
|
664
|
+
assert isinstance(decoded, np.ndarray)
|
665
|
+
assert decoded.dtype == np.int32
|
666
|
+
assert np.array_equal(decoded, np.array([1, 2, 3, 4, 5], dtype=np.int64))
|
667
|
+
value_list: ListIntType = [1, 2, 3, 4, 5]
|
668
|
+
encoded = encode_engine_value(value_list)
|
669
|
+
assert np.array_equal(encoded, [1, 2, 3, 4, 5])
|
670
|
+
decoded = build_engine_value_decoder(ListIntType)(encoded)
|
671
|
+
assert np.array_equal(decoded, [1, 2, 3, 4, 5])
|
672
|
+
|
673
|
+
|
674
|
+
def test_encode_complex_structure_with_ndarray():
|
675
|
+
"""Test encoding a complex structure that includes an NDArray."""
|
676
|
+
|
677
|
+
@dataclass
|
678
|
+
class MyStructWithNDArray:
|
679
|
+
name: str
|
680
|
+
data: NDArray[np.float32]
|
681
|
+
value: int
|
682
|
+
|
683
|
+
original = MyStructWithNDArray(
|
684
|
+
name="test_np", data=np.array([1.0, 0.5], dtype=np.float32), value=100
|
685
|
+
)
|
686
|
+
encoded = encode_engine_value(original)
|
687
|
+
expected = [
|
688
|
+
"test_np",
|
689
|
+
[1.0, 0.5],
|
690
|
+
100,
|
691
|
+
]
|
692
|
+
assert encoded[0] == expected[0]
|
693
|
+
assert np.array_equal(encoded[1], expected[1])
|
694
|
+
assert encoded[2] == expected[2]
|
695
|
+
|
696
|
+
|
697
|
+
def test_decode_nullable_ndarray_none_or_value_input():
|
698
|
+
"""Test decoding a nullable NDArray with None or value inputs."""
|
699
|
+
src_type_dict = {
|
700
|
+
"kind": "Vector",
|
701
|
+
"element_type": {"kind": "Float32"},
|
702
|
+
"dimension": None,
|
703
|
+
}
|
704
|
+
dst_annotation = NDArrayFloat32Type | None
|
705
|
+
decoder = make_engine_value_decoder([], src_type_dict, dst_annotation)
|
706
|
+
|
707
|
+
none_engine_value = None
|
708
|
+
decoded_array = decoder(none_engine_value)
|
709
|
+
assert decoded_array is None
|
710
|
+
|
711
|
+
engine_value = [1.0, 2.0, 3.0]
|
712
|
+
decoded_array = decoder(engine_value)
|
713
|
+
|
714
|
+
assert isinstance(decoded_array, np.ndarray)
|
715
|
+
assert decoded_array.dtype == np.float32
|
716
|
+
np.testing.assert_array_equal(
|
717
|
+
decoded_array, np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
718
|
+
)
|
719
|
+
|
720
|
+
|
721
|
+
def test_decode_vector_string():
|
722
|
+
"""Test decoding a vector of strings works for Python native list type."""
|
723
|
+
src_type_dict = {
|
724
|
+
"kind": "Vector",
|
725
|
+
"element_type": {"kind": "Str"},
|
726
|
+
"dimension": None,
|
727
|
+
}
|
728
|
+
decoder = make_engine_value_decoder([], src_type_dict, Vector[str])
|
729
|
+
assert decoder(["hello", "world"]) == ["hello", "world"]
|
730
|
+
|
731
|
+
|
732
|
+
def test_decode_error_non_nullable_or_non_list_vector():
|
733
|
+
"""Test decoding errors for non-nullable vectors or non-list inputs."""
|
734
|
+
src_type_dict = {
|
735
|
+
"kind": "Vector",
|
736
|
+
"element_type": {"kind": "Float32"},
|
737
|
+
"dimension": None,
|
738
|
+
}
|
739
|
+
decoder = make_engine_value_decoder([], src_type_dict, NDArrayFloat32Type)
|
740
|
+
with pytest.raises(ValueError, match="Received null for non-nullable vector"):
|
741
|
+
decoder(None)
|
742
|
+
with pytest.raises(TypeError, match="Expected NDArray or list for vector"):
|
743
|
+
decoder("not a list")
|
744
|
+
|
745
|
+
|
746
|
+
def test_dump_vector_type_annotation_with_dim():
|
747
|
+
"""Test dumping a vector type annotation with a specified dimension."""
|
748
|
+
expected_dump = {
|
749
|
+
"type": {
|
750
|
+
"kind": "Vector",
|
751
|
+
"element_type": {"kind": "Float32"},
|
752
|
+
"dimension": 3,
|
753
|
+
}
|
754
|
+
}
|
755
|
+
assert dump_engine_object(Float32VectorType) == expected_dump
|
756
|
+
|
757
|
+
|
758
|
+
def test_dump_vector_type_annotation_no_dim():
|
759
|
+
"""Test dumping a vector type annotation with no dimension."""
|
760
|
+
expected_dump_no_dim = {
|
761
|
+
"type": {
|
762
|
+
"kind": "Vector",
|
763
|
+
"element_type": {"kind": "Float64"},
|
764
|
+
"dimension": None,
|
765
|
+
}
|
766
|
+
}
|
767
|
+
assert dump_engine_object(Float64VectorTypeNoDim) == expected_dump_no_dim
|
@@ -0,0 +1,499 @@
|
|
1
|
+
import dataclasses
|
2
|
+
import datetime
|
3
|
+
import uuid
|
4
|
+
from typing import (
|
5
|
+
Annotated,
|
6
|
+
List,
|
7
|
+
Dict,
|
8
|
+
Literal,
|
9
|
+
Any,
|
10
|
+
get_args,
|
11
|
+
NamedTuple,
|
12
|
+
)
|
13
|
+
from collections.abc import Sequence, Mapping
|
14
|
+
import pytest
|
15
|
+
import numpy as np
|
16
|
+
from numpy.typing import NDArray
|
17
|
+
|
18
|
+
from cocoindex.typing import (
|
19
|
+
analyze_type_info,
|
20
|
+
Vector,
|
21
|
+
VectorInfo,
|
22
|
+
TypeKind,
|
23
|
+
TypeAttr,
|
24
|
+
Float32,
|
25
|
+
Float64,
|
26
|
+
encode_enriched_type,
|
27
|
+
AnalyzedTypeInfo,
|
28
|
+
)
|
29
|
+
|
30
|
+
|
31
|
+
@dataclasses.dataclass
|
32
|
+
class SimpleDataclass:
|
33
|
+
name: str
|
34
|
+
value: int
|
35
|
+
|
36
|
+
|
37
|
+
class SimpleNamedTuple(NamedTuple):
|
38
|
+
name: str
|
39
|
+
value: Any
|
40
|
+
|
41
|
+
|
42
|
+
def test_ndarray_float32_no_dim():
|
43
|
+
typ = NDArray[np.float32]
|
44
|
+
result = analyze_type_info(typ)
|
45
|
+
assert result == AnalyzedTypeInfo(
|
46
|
+
kind="Vector",
|
47
|
+
vector_info=VectorInfo(dim=None),
|
48
|
+
elem_type=Float32,
|
49
|
+
key_type=None,
|
50
|
+
struct_type=None,
|
51
|
+
attrs=None,
|
52
|
+
nullable=False,
|
53
|
+
)
|
54
|
+
|
55
|
+
|
56
|
+
def test_vector_float32_no_dim():
|
57
|
+
typ = Vector[np.float32]
|
58
|
+
result = analyze_type_info(typ)
|
59
|
+
assert result == AnalyzedTypeInfo(
|
60
|
+
kind="Vector",
|
61
|
+
vector_info=VectorInfo(dim=None),
|
62
|
+
elem_type=Float32,
|
63
|
+
key_type=None,
|
64
|
+
struct_type=None,
|
65
|
+
attrs=None,
|
66
|
+
nullable=False,
|
67
|
+
)
|
68
|
+
|
69
|
+
|
70
|
+
def test_ndarray_float64_with_dim():
|
71
|
+
typ = Annotated[NDArray[np.float64], VectorInfo(dim=128)]
|
72
|
+
result = analyze_type_info(typ)
|
73
|
+
assert result == AnalyzedTypeInfo(
|
74
|
+
kind="Vector",
|
75
|
+
vector_info=VectorInfo(dim=128),
|
76
|
+
elem_type=Float64,
|
77
|
+
key_type=None,
|
78
|
+
struct_type=None,
|
79
|
+
attrs=None,
|
80
|
+
nullable=False,
|
81
|
+
)
|
82
|
+
|
83
|
+
|
84
|
+
def test_vector_float32_with_dim():
|
85
|
+
typ = Vector[np.float32, Literal[384]]
|
86
|
+
result = analyze_type_info(typ)
|
87
|
+
assert result == AnalyzedTypeInfo(
|
88
|
+
kind="Vector",
|
89
|
+
vector_info=VectorInfo(dim=384),
|
90
|
+
elem_type=Float32,
|
91
|
+
key_type=None,
|
92
|
+
struct_type=None,
|
93
|
+
attrs=None,
|
94
|
+
nullable=False,
|
95
|
+
)
|
96
|
+
|
97
|
+
|
98
|
+
def test_ndarray_int64_no_dim():
|
99
|
+
typ = NDArray[np.int64]
|
100
|
+
result = analyze_type_info(typ)
|
101
|
+
assert result.kind == "Vector"
|
102
|
+
assert result.vector_info == VectorInfo(dim=None)
|
103
|
+
assert get_args(result.elem_type) == (int, TypeKind("Int64"))
|
104
|
+
assert not result.nullable
|
105
|
+
|
106
|
+
|
107
|
+
def test_ndarray_int32_with_dim():
|
108
|
+
typ = Annotated[NDArray[np.int32], VectorInfo(dim=10)]
|
109
|
+
result = analyze_type_info(typ)
|
110
|
+
assert result.kind == "Vector"
|
111
|
+
assert result.vector_info == VectorInfo(dim=10)
|
112
|
+
assert get_args(result.elem_type) == (int, TypeKind("Int32"))
|
113
|
+
assert not result.nullable
|
114
|
+
|
115
|
+
|
116
|
+
def test_ndarray_uint8_no_dim():
|
117
|
+
typ = NDArray[np.uint8]
|
118
|
+
result = analyze_type_info(typ)
|
119
|
+
assert result.kind == "Vector"
|
120
|
+
assert result.vector_info == VectorInfo(dim=None)
|
121
|
+
assert get_args(result.elem_type) == (int, TypeKind("UInt8"))
|
122
|
+
assert not result.nullable
|
123
|
+
|
124
|
+
|
125
|
+
def test_nullable_ndarray():
|
126
|
+
typ = NDArray[np.float32] | None
|
127
|
+
result = analyze_type_info(typ)
|
128
|
+
assert result == AnalyzedTypeInfo(
|
129
|
+
kind="Vector",
|
130
|
+
vector_info=VectorInfo(dim=None),
|
131
|
+
elem_type=Float32,
|
132
|
+
key_type=None,
|
133
|
+
struct_type=None,
|
134
|
+
attrs=None,
|
135
|
+
nullable=True,
|
136
|
+
)
|
137
|
+
|
138
|
+
|
139
|
+
def test_vector_str():
|
140
|
+
typ = Vector[str]
|
141
|
+
result = analyze_type_info(typ)
|
142
|
+
assert result.kind == "Vector"
|
143
|
+
assert result.elem_type == str
|
144
|
+
assert result.vector_info == VectorInfo(dim=None)
|
145
|
+
|
146
|
+
|
147
|
+
def test_vector_complex64():
|
148
|
+
typ = Vector[np.complex64]
|
149
|
+
result = analyze_type_info(typ)
|
150
|
+
assert result.kind == "Vector"
|
151
|
+
assert result.elem_type == np.complex64
|
152
|
+
assert result.vector_info == VectorInfo(dim=None)
|
153
|
+
|
154
|
+
|
155
|
+
def test_non_numpy_vector():
|
156
|
+
typ = Vector[float, Literal[3]]
|
157
|
+
result = analyze_type_info(typ)
|
158
|
+
assert result.kind == "Vector"
|
159
|
+
assert result.elem_type == float
|
160
|
+
assert result.vector_info == VectorInfo(dim=3)
|
161
|
+
|
162
|
+
|
163
|
+
def test_ndarray_any_dtype():
|
164
|
+
typ = NDArray[Any]
|
165
|
+
with pytest.raises(
|
166
|
+
TypeError, match="NDArray for Vector must use a concrete numpy dtype"
|
167
|
+
):
|
168
|
+
analyze_type_info(typ)
|
169
|
+
|
170
|
+
|
171
|
+
def test_list_of_primitives():
|
172
|
+
typ = List[str]
|
173
|
+
result = analyze_type_info(typ)
|
174
|
+
assert result == AnalyzedTypeInfo(
|
175
|
+
kind="Vector",
|
176
|
+
vector_info=VectorInfo(dim=None),
|
177
|
+
elem_type=str,
|
178
|
+
key_type=None,
|
179
|
+
struct_type=None,
|
180
|
+
attrs=None,
|
181
|
+
nullable=False,
|
182
|
+
)
|
183
|
+
|
184
|
+
|
185
|
+
def test_list_of_structs():
|
186
|
+
typ = List[SimpleDataclass]
|
187
|
+
result = analyze_type_info(typ)
|
188
|
+
assert result == AnalyzedTypeInfo(
|
189
|
+
kind="LTable",
|
190
|
+
vector_info=None,
|
191
|
+
elem_type=SimpleDataclass,
|
192
|
+
key_type=None,
|
193
|
+
struct_type=None,
|
194
|
+
attrs=None,
|
195
|
+
nullable=False,
|
196
|
+
)
|
197
|
+
|
198
|
+
|
199
|
+
def test_sequence_of_int():
|
200
|
+
typ = Sequence[int]
|
201
|
+
result = analyze_type_info(typ)
|
202
|
+
assert result == AnalyzedTypeInfo(
|
203
|
+
kind="Vector",
|
204
|
+
vector_info=VectorInfo(dim=None),
|
205
|
+
elem_type=int,
|
206
|
+
key_type=None,
|
207
|
+
struct_type=None,
|
208
|
+
attrs=None,
|
209
|
+
nullable=False,
|
210
|
+
)
|
211
|
+
|
212
|
+
|
213
|
+
def test_list_with_vector_info():
|
214
|
+
typ = Annotated[List[int], VectorInfo(dim=5)]
|
215
|
+
result = analyze_type_info(typ)
|
216
|
+
assert result == AnalyzedTypeInfo(
|
217
|
+
kind="Vector",
|
218
|
+
vector_info=VectorInfo(dim=5),
|
219
|
+
elem_type=int,
|
220
|
+
key_type=None,
|
221
|
+
struct_type=None,
|
222
|
+
attrs=None,
|
223
|
+
nullable=False,
|
224
|
+
)
|
225
|
+
|
226
|
+
|
227
|
+
def test_dict_str_int():
|
228
|
+
typ = Dict[str, int]
|
229
|
+
result = analyze_type_info(typ)
|
230
|
+
assert result == AnalyzedTypeInfo(
|
231
|
+
kind="KTable",
|
232
|
+
vector_info=None,
|
233
|
+
elem_type=(str, int),
|
234
|
+
key_type=None,
|
235
|
+
struct_type=None,
|
236
|
+
attrs=None,
|
237
|
+
nullable=False,
|
238
|
+
)
|
239
|
+
|
240
|
+
|
241
|
+
def test_mapping_str_dataclass():
|
242
|
+
typ = Mapping[str, SimpleDataclass]
|
243
|
+
result = analyze_type_info(typ)
|
244
|
+
assert result == AnalyzedTypeInfo(
|
245
|
+
kind="KTable",
|
246
|
+
vector_info=None,
|
247
|
+
elem_type=(str, SimpleDataclass),
|
248
|
+
key_type=None,
|
249
|
+
struct_type=None,
|
250
|
+
attrs=None,
|
251
|
+
nullable=False,
|
252
|
+
)
|
253
|
+
|
254
|
+
|
255
|
+
def test_dataclass():
|
256
|
+
typ = SimpleDataclass
|
257
|
+
result = analyze_type_info(typ)
|
258
|
+
assert result == AnalyzedTypeInfo(
|
259
|
+
kind="Struct",
|
260
|
+
vector_info=None,
|
261
|
+
elem_type=None,
|
262
|
+
key_type=None,
|
263
|
+
struct_type=SimpleDataclass,
|
264
|
+
attrs=None,
|
265
|
+
nullable=False,
|
266
|
+
)
|
267
|
+
|
268
|
+
|
269
|
+
def test_named_tuple():
|
270
|
+
typ = SimpleNamedTuple
|
271
|
+
result = analyze_type_info(typ)
|
272
|
+
assert result == AnalyzedTypeInfo(
|
273
|
+
kind="Struct",
|
274
|
+
vector_info=None,
|
275
|
+
elem_type=None,
|
276
|
+
key_type=None,
|
277
|
+
struct_type=SimpleNamedTuple,
|
278
|
+
attrs=None,
|
279
|
+
nullable=False,
|
280
|
+
)
|
281
|
+
|
282
|
+
|
283
|
+
def test_tuple_key_value():
|
284
|
+
typ = (str, int)
|
285
|
+
result = analyze_type_info(typ)
|
286
|
+
assert result == AnalyzedTypeInfo(
|
287
|
+
kind="Int64",
|
288
|
+
vector_info=None,
|
289
|
+
elem_type=None,
|
290
|
+
key_type=str,
|
291
|
+
struct_type=None,
|
292
|
+
attrs=None,
|
293
|
+
nullable=False,
|
294
|
+
)
|
295
|
+
|
296
|
+
|
297
|
+
def test_str():
|
298
|
+
typ = str
|
299
|
+
result = analyze_type_info(typ)
|
300
|
+
assert result == AnalyzedTypeInfo(
|
301
|
+
kind="Str",
|
302
|
+
vector_info=None,
|
303
|
+
elem_type=None,
|
304
|
+
key_type=None,
|
305
|
+
struct_type=None,
|
306
|
+
attrs=None,
|
307
|
+
nullable=False,
|
308
|
+
)
|
309
|
+
|
310
|
+
|
311
|
+
def test_bool():
|
312
|
+
typ = bool
|
313
|
+
result = analyze_type_info(typ)
|
314
|
+
assert result == AnalyzedTypeInfo(
|
315
|
+
kind="Bool",
|
316
|
+
vector_info=None,
|
317
|
+
elem_type=None,
|
318
|
+
key_type=None,
|
319
|
+
struct_type=None,
|
320
|
+
attrs=None,
|
321
|
+
nullable=False,
|
322
|
+
)
|
323
|
+
|
324
|
+
|
325
|
+
def test_bytes():
|
326
|
+
typ = bytes
|
327
|
+
result = analyze_type_info(typ)
|
328
|
+
assert result == AnalyzedTypeInfo(
|
329
|
+
kind="Bytes",
|
330
|
+
vector_info=None,
|
331
|
+
elem_type=None,
|
332
|
+
key_type=None,
|
333
|
+
struct_type=None,
|
334
|
+
attrs=None,
|
335
|
+
nullable=False,
|
336
|
+
)
|
337
|
+
|
338
|
+
|
339
|
+
def test_uuid():
|
340
|
+
typ = uuid.UUID
|
341
|
+
result = analyze_type_info(typ)
|
342
|
+
assert result == AnalyzedTypeInfo(
|
343
|
+
kind="Uuid",
|
344
|
+
vector_info=None,
|
345
|
+
elem_type=None,
|
346
|
+
key_type=None,
|
347
|
+
struct_type=None,
|
348
|
+
attrs=None,
|
349
|
+
nullable=False,
|
350
|
+
)
|
351
|
+
|
352
|
+
|
353
|
+
def test_date():
|
354
|
+
typ = datetime.date
|
355
|
+
result = analyze_type_info(typ)
|
356
|
+
assert result == AnalyzedTypeInfo(
|
357
|
+
kind="Date",
|
358
|
+
vector_info=None,
|
359
|
+
elem_type=None,
|
360
|
+
key_type=None,
|
361
|
+
struct_type=None,
|
362
|
+
attrs=None,
|
363
|
+
nullable=False,
|
364
|
+
)
|
365
|
+
|
366
|
+
|
367
|
+
def test_time():
|
368
|
+
typ = datetime.time
|
369
|
+
result = analyze_type_info(typ)
|
370
|
+
assert result == AnalyzedTypeInfo(
|
371
|
+
kind="Time",
|
372
|
+
vector_info=None,
|
373
|
+
elem_type=None,
|
374
|
+
key_type=None,
|
375
|
+
struct_type=None,
|
376
|
+
attrs=None,
|
377
|
+
nullable=False,
|
378
|
+
)
|
379
|
+
|
380
|
+
|
381
|
+
def test_timedelta():
|
382
|
+
typ = datetime.timedelta
|
383
|
+
result = analyze_type_info(typ)
|
384
|
+
assert result == AnalyzedTypeInfo(
|
385
|
+
kind="TimeDelta",
|
386
|
+
vector_info=None,
|
387
|
+
elem_type=None,
|
388
|
+
key_type=None,
|
389
|
+
struct_type=None,
|
390
|
+
attrs=None,
|
391
|
+
nullable=False,
|
392
|
+
)
|
393
|
+
|
394
|
+
|
395
|
+
def test_float():
|
396
|
+
typ = float
|
397
|
+
result = analyze_type_info(typ)
|
398
|
+
assert result == AnalyzedTypeInfo(
|
399
|
+
kind="Float64",
|
400
|
+
vector_info=None,
|
401
|
+
elem_type=None,
|
402
|
+
key_type=None,
|
403
|
+
struct_type=None,
|
404
|
+
attrs=None,
|
405
|
+
nullable=False,
|
406
|
+
)
|
407
|
+
|
408
|
+
|
409
|
+
def test_int():
|
410
|
+
typ = int
|
411
|
+
result = analyze_type_info(typ)
|
412
|
+
assert result == AnalyzedTypeInfo(
|
413
|
+
kind="Int64",
|
414
|
+
vector_info=None,
|
415
|
+
elem_type=None,
|
416
|
+
key_type=None,
|
417
|
+
struct_type=None,
|
418
|
+
attrs=None,
|
419
|
+
nullable=False,
|
420
|
+
)
|
421
|
+
|
422
|
+
|
423
|
+
def test_type_with_attributes():
|
424
|
+
typ = Annotated[str, TypeAttr("key", "value")]
|
425
|
+
result = analyze_type_info(typ)
|
426
|
+
assert result == AnalyzedTypeInfo(
|
427
|
+
kind="Str",
|
428
|
+
vector_info=None,
|
429
|
+
elem_type=None,
|
430
|
+
key_type=None,
|
431
|
+
struct_type=None,
|
432
|
+
attrs={"key": "value"},
|
433
|
+
nullable=False,
|
434
|
+
)
|
435
|
+
|
436
|
+
|
437
|
+
def test_encode_enriched_type_none():
|
438
|
+
typ = None
|
439
|
+
result = encode_enriched_type(typ)
|
440
|
+
assert result is None
|
441
|
+
|
442
|
+
|
443
|
+
def test_encode_enriched_type_struct():
|
444
|
+
typ = SimpleDataclass
|
445
|
+
result = encode_enriched_type(typ)
|
446
|
+
assert result["type"]["kind"] == "Struct"
|
447
|
+
assert len(result["type"]["fields"]) == 2
|
448
|
+
assert result["type"]["fields"][0]["name"] == "name"
|
449
|
+
assert result["type"]["fields"][0]["type"]["kind"] == "Str"
|
450
|
+
assert result["type"]["fields"][1]["name"] == "value"
|
451
|
+
assert result["type"]["fields"][1]["type"]["kind"] == "Int64"
|
452
|
+
|
453
|
+
|
454
|
+
def test_encode_enriched_type_vector():
|
455
|
+
typ = NDArray[np.float32]
|
456
|
+
result = encode_enriched_type(typ)
|
457
|
+
assert result["type"]["kind"] == "Vector"
|
458
|
+
assert result["type"]["element_type"]["kind"] == "Float32"
|
459
|
+
assert result["type"]["dimension"] is None
|
460
|
+
|
461
|
+
|
462
|
+
def test_encode_enriched_type_ltable():
|
463
|
+
typ = List[SimpleDataclass]
|
464
|
+
result = encode_enriched_type(typ)
|
465
|
+
assert result["type"]["kind"] == "LTable"
|
466
|
+
assert result["type"]["row"]["kind"] == "Struct"
|
467
|
+
assert len(result["type"]["row"]["fields"]) == 2
|
468
|
+
|
469
|
+
|
470
|
+
def test_encode_enriched_type_with_attrs():
|
471
|
+
typ = Annotated[str, TypeAttr("key", "value")]
|
472
|
+
result = encode_enriched_type(typ)
|
473
|
+
assert result["type"]["kind"] == "Str"
|
474
|
+
assert result["attrs"] == {"key": "value"}
|
475
|
+
|
476
|
+
|
477
|
+
def test_encode_enriched_type_nullable():
|
478
|
+
typ = str | None
|
479
|
+
result = encode_enriched_type(typ)
|
480
|
+
assert result["type"]["kind"] == "Str"
|
481
|
+
assert result["nullable"] is True
|
482
|
+
|
483
|
+
|
484
|
+
def test_invalid_struct_kind():
|
485
|
+
typ = Annotated[SimpleDataclass, TypeKind("Vector")]
|
486
|
+
with pytest.raises(ValueError, match="Unexpected type kind for struct: Vector"):
|
487
|
+
analyze_type_info(typ)
|
488
|
+
|
489
|
+
|
490
|
+
def test_invalid_list_kind():
|
491
|
+
typ = Annotated[List[int], TypeKind("Struct")]
|
492
|
+
with pytest.raises(ValueError, match="Unexpected type kind for list: Struct"):
|
493
|
+
analyze_type_info(typ)
|
494
|
+
|
495
|
+
|
496
|
+
def test_unsupported_type():
|
497
|
+
typ = set
|
498
|
+
with pytest.raises(ValueError, match="type unsupported yet: <class 'set'>"):
|
499
|
+
analyze_type_info(typ)
|
cocoindex/typing.py
CHANGED
@@ -9,14 +9,16 @@ from typing import (
|
|
9
9
|
Annotated,
|
10
10
|
NamedTuple,
|
11
11
|
Any,
|
12
|
+
KeysView,
|
12
13
|
TypeVar,
|
13
14
|
TYPE_CHECKING,
|
14
15
|
overload,
|
15
|
-
Sequence,
|
16
16
|
Generic,
|
17
17
|
Literal,
|
18
18
|
Protocol,
|
19
19
|
)
|
20
|
+
import numpy as np
|
21
|
+
from numpy.typing import NDArray
|
20
22
|
|
21
23
|
|
22
24
|
class VectorInfo(NamedTuple):
|
@@ -50,7 +52,7 @@ if TYPE_CHECKING:
|
|
50
52
|
Dim_co = TypeVar("Dim_co", bound=int, covariant=True)
|
51
53
|
|
52
54
|
class Vector(Protocol, Generic[T_co, Dim_co]):
|
53
|
-
"""Vector[T, Dim] is a special typing alias for
|
55
|
+
"""Vector[T, Dim] is a special typing alias for an NDArray[T] with optional dimension info"""
|
54
56
|
|
55
57
|
def __getitem__(self, index: int) -> T_co: ...
|
56
58
|
def __len__(self) -> int: ...
|
@@ -58,25 +60,34 @@ if TYPE_CHECKING:
|
|
58
60
|
else:
|
59
61
|
|
60
62
|
class Vector: # type: ignore[unreachable]
|
61
|
-
"""A special typing alias for
|
63
|
+
"""A special typing alias for an NDArray[T] with optional dimension info"""
|
62
64
|
|
63
65
|
def __class_getitem__(self, params):
|
64
66
|
if not isinstance(params, tuple):
|
65
|
-
#
|
66
|
-
|
67
|
-
|
67
|
+
# No dimension provided, e.g., Vector[np.float32]
|
68
|
+
dtype = params
|
69
|
+
# Use NDArray for supported numeric dtypes, else list
|
70
|
+
if DtypeRegistry.get_by_dtype(dtype) is not None:
|
71
|
+
return Annotated[NDArray[dtype], VectorInfo(dim=None)]
|
72
|
+
return Annotated[list[dtype], VectorInfo(dim=None)]
|
68
73
|
else:
|
69
|
-
# Element type and dimension provided
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
+
# Element type and dimension provided, e.g., Vector[np.float32, Literal[3]]
|
75
|
+
dtype, dim_literal = params
|
76
|
+
# Extract the literal value
|
77
|
+
dim_val = (
|
78
|
+
typing.get_args(dim_literal)[0]
|
79
|
+
if typing.get_origin(dim_literal) is Literal
|
80
|
+
else None
|
81
|
+
)
|
82
|
+
if DtypeRegistry.get_by_dtype(dtype) is not None:
|
83
|
+
return Annotated[NDArray[dtype], VectorInfo(dim=dim_val)]
|
84
|
+
return Annotated[list[dtype], VectorInfo(dim=dim_val)]
|
74
85
|
|
75
86
|
|
76
87
|
TABLE_TYPES: tuple[str, str] = ("KTable", "LTable")
|
77
88
|
KEY_FIELD_NAME: str = "_key"
|
78
89
|
|
79
|
-
ElementType = type | tuple[type, type]
|
90
|
+
ElementType = type | tuple[type, type] | Annotated[Any, TypeKind]
|
80
91
|
|
81
92
|
|
82
93
|
def is_namedtuple_type(t: type) -> bool:
|
@@ -89,6 +100,48 @@ def _is_struct_type(t: ElementType | None) -> bool:
|
|
89
100
|
)
|
90
101
|
|
91
102
|
|
103
|
+
class DtypeInfo:
|
104
|
+
"""Metadata for a NumPy dtype."""
|
105
|
+
|
106
|
+
def __init__(self, numpy_dtype: type, kind: str, python_type: type) -> None:
|
107
|
+
self.numpy_dtype = numpy_dtype
|
108
|
+
self.kind = kind
|
109
|
+
self.python_type = python_type
|
110
|
+
self.annotated_type = Annotated[python_type, TypeKind(kind)]
|
111
|
+
|
112
|
+
|
113
|
+
class DtypeRegistry:
|
114
|
+
_mappings: dict[type, DtypeInfo] = {
|
115
|
+
np.float32: DtypeInfo(np.float32, "Float32", float),
|
116
|
+
np.float64: DtypeInfo(np.float64, "Float64", float),
|
117
|
+
np.int32: DtypeInfo(np.int32, "Int32", int),
|
118
|
+
np.int64: DtypeInfo(np.int64, "Int64", int),
|
119
|
+
np.uint8: DtypeInfo(np.uint8, "UInt8", int),
|
120
|
+
np.uint16: DtypeInfo(np.uint16, "UInt16", int),
|
121
|
+
np.uint32: DtypeInfo(np.uint32, "UInt32", int),
|
122
|
+
np.uint64: DtypeInfo(np.uint64, "UInt64", int),
|
123
|
+
}
|
124
|
+
|
125
|
+
@classmethod
|
126
|
+
def get_by_dtype(cls, dtype: Any) -> DtypeInfo | None:
|
127
|
+
if dtype is Any:
|
128
|
+
raise TypeError(
|
129
|
+
"NDArray for Vector must use a concrete numpy dtype, got `Any`."
|
130
|
+
)
|
131
|
+
return cls._mappings.get(dtype)
|
132
|
+
|
133
|
+
@staticmethod
|
134
|
+
def get_by_kind(kind: str) -> DtypeInfo | None:
|
135
|
+
return next(
|
136
|
+
(info for info in DtypeRegistry._mappings.values() if info.kind == kind),
|
137
|
+
None,
|
138
|
+
)
|
139
|
+
|
140
|
+
@staticmethod
|
141
|
+
def supported_dtypes() -> KeysView[type]:
|
142
|
+
return DtypeRegistry._mappings.keys()
|
143
|
+
|
144
|
+
|
92
145
|
@dataclasses.dataclass
|
93
146
|
class AnalyzedTypeInfo:
|
94
147
|
"""
|
@@ -179,12 +232,34 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
|
|
179
232
|
vector_info = VectorInfo(dim=None)
|
180
233
|
elif not (kind == "Vector" or kind in TABLE_TYPES):
|
181
234
|
raise ValueError(f"Unexpected type kind for list: {kind}")
|
235
|
+
elif base_type is np.ndarray:
|
236
|
+
kind = "Vector"
|
237
|
+
args = typing.get_args(t)
|
238
|
+
_, dtype_spec = args
|
239
|
+
|
240
|
+
dtype_args = typing.get_args(dtype_spec)
|
241
|
+
if not dtype_args:
|
242
|
+
raise ValueError("Invalid dtype specification for NDArray")
|
243
|
+
|
244
|
+
numpy_dtype = dtype_args[0]
|
245
|
+
dtype_info = DtypeRegistry.get_by_dtype(numpy_dtype)
|
246
|
+
if dtype_info is None:
|
247
|
+
raise ValueError(
|
248
|
+
f"Unsupported numpy dtype for NDArray: {numpy_dtype}. "
|
249
|
+
f"Supported dtypes: {DtypeRegistry.supported_dtypes()}"
|
250
|
+
)
|
251
|
+
elem_type = dtype_info.annotated_type
|
252
|
+
vector_info = VectorInfo(dim=None) if vector_info is None else vector_info
|
253
|
+
|
182
254
|
elif base_type is collections.abc.Mapping or base_type is dict:
|
183
255
|
args = typing.get_args(t)
|
184
256
|
elem_type = (args[0], args[1])
|
185
257
|
kind = "KTable"
|
186
258
|
elif kind is None:
|
187
|
-
|
259
|
+
dtype_info = DtypeRegistry.get_by_dtype(t)
|
260
|
+
if dtype_info is not None:
|
261
|
+
kind = dtype_info.kind
|
262
|
+
elif t is bytes:
|
188
263
|
kind = "Bytes"
|
189
264
|
elif t is str:
|
190
265
|
kind = "Str"
|
@@ -1,27 +1,27 @@
|
|
1
|
-
cocoindex-0.1.
|
2
|
-
cocoindex-0.1.
|
3
|
-
cocoindex-0.1.
|
4
|
-
cocoindex-0.1.
|
5
|
-
cocoindex/__init__.py,sha256=
|
6
|
-
cocoindex/_engine.cpython-311-darwin.so,sha256=
|
1
|
+
cocoindex-0.1.50.dist-info/METADATA,sha256=PdFYx9yT4tBqCXVaP77iQjRZMB26D2I8BWXoA8i7ZqQ,9875
|
2
|
+
cocoindex-0.1.50.dist-info/WHEEL,sha256=8hL2oHqulIPF_TJ_OfB-8kD2X9O6HJL0emTmeguFbqc,104
|
3
|
+
cocoindex-0.1.50.dist-info/entry_points.txt,sha256=_NretjYVzBdNTn7dK-zgwr7YfG2afz1u1uSE-5bZXF8,46
|
4
|
+
cocoindex-0.1.50.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
5
|
+
cocoindex/__init__.py,sha256=BC6BvotY0LcKHCUNnZWjKZX7Afl18TDR2Bfv5R6Y4DM,811
|
6
|
+
cocoindex/_engine.cpython-311-darwin.so,sha256=hqncg8nnjBaQicxXkJbOf4Hl3-UoE9onm_Q7On3dvQk,57265296
|
7
7
|
cocoindex/auth_registry.py,sha256=1XqO7ibjmBBd8i11XSJTvTgdz8p1ptW-ZpuSgo_5zzk,716
|
8
8
|
cocoindex/cli.py,sha256=p92s-Ya6VpT4EY5xfbl-R0yyIJv8ySK-2eyxnwC8KqA,17853
|
9
|
-
cocoindex/convert.py,sha256=
|
9
|
+
cocoindex/convert.py,sha256=dxlMUd6Jweun1xESJv0w-WbIx-NifMXlUF9RqgeggWM,8735
|
10
10
|
cocoindex/flow.py,sha256=sC7eg0L9MkZuo7MlvGA4eYe-mEReIyPjKIILFtkWm-Y,30017
|
11
|
-
cocoindex/functions.py,sha256=
|
11
|
+
cocoindex/functions.py,sha256=9A61Jj5a3vQoI2MIAhjXvJrDxSzDhe6VncQWbiVtwcg,2393
|
12
12
|
cocoindex/index.py,sha256=j93B9jEvvLXHtpzKWL88SY6wCGEoPgpsQhEGHlyYGFg,540
|
13
|
-
cocoindex/lib.py,sha256=
|
13
|
+
cocoindex/lib.py,sha256=BeRUn3RqE_wSsVtsgCzbFFKe1LXgRyRmMOcmwWBuEXo,2940
|
14
14
|
cocoindex/llm.py,sha256=KO-R4mrAWtxXD82-Yv5ixpkKMVfkwpbdWwqPVZygLu4,352
|
15
15
|
cocoindex/op.py,sha256=hOJoHC8elhLHMNVRTGTBWNSK5uYrQb-FfRV-qt08j8g,11815
|
16
16
|
cocoindex/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
17
|
-
cocoindex/query.py,sha256=UswtEp0wZwwNPeR9x9lRokFpNNZOHkyt5tYfCmGIWWU,3450
|
18
17
|
cocoindex/runtime.py,sha256=bAdHYaXFWiiUWyAgzmKTeaAaRR0D_AmaqVCIdPO-v00,1056
|
19
18
|
cocoindex/setting.py,sha256=ePqHw1i95rVtqYYRVqzVwtBifRO4SfH1rlMW_AG3Zek,3418
|
20
19
|
cocoindex/setup.py,sha256=u5dYZFKfz4yZLiGHD0guNaR0s4zY9JAoZWrWHpAHw_0,773
|
21
20
|
cocoindex/sources.py,sha256=JCnOhv1w4o28e03i7yvo4ESicWYAhckkBg5bQlxNH4U,1330
|
22
21
|
cocoindex/storages.py,sha256=i9cPzj5LbLJVHjskEWyltUXF9kqCyc6ftMEqll6FDwI,2804
|
23
22
|
cocoindex/tests/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
24
|
-
cocoindex/tests/test_convert.py,sha256=
|
25
|
-
cocoindex/
|
23
|
+
cocoindex/tests/test_convert.py,sha256=EsR7NEGKAypQpr3FxRz8xja_z0sKmYpTjJuCMSUdyYw,26217
|
24
|
+
cocoindex/tests/test_typing.py,sha256=P6bnRjxMjKH2B5RQ0ncZQAMyOPSdviFVmML7fCS2iR4,11884
|
25
|
+
cocoindex/typing.py,sha256=9366o49R6-YbrG9rc9UOxJjgTYITC-7DjDkJSVIjZL0,12316
|
26
26
|
cocoindex/utils.py,sha256=5a3ubVzDowJUJyUl8ecd75Th_OzON3G-MV2vXf0dQSk,503
|
27
|
-
cocoindex-0.1.
|
27
|
+
cocoindex-0.1.50.dist-info/RECORD,,
|
cocoindex/query.py
DELETED
@@ -1,115 +0,0 @@
|
|
1
|
-
from typing import Callable, Any
|
2
|
-
from dataclasses import dataclass
|
3
|
-
from threading import Lock
|
4
|
-
|
5
|
-
from . import flow as fl
|
6
|
-
from . import index
|
7
|
-
from . import _engine # type: ignore
|
8
|
-
|
9
|
-
_handlers_lock = Lock()
|
10
|
-
_handlers: dict[str, _engine.SimpleSemanticsQueryHandler] = {}
|
11
|
-
|
12
|
-
|
13
|
-
@dataclass
|
14
|
-
class SimpleSemanticsQueryInfo:
|
15
|
-
"""
|
16
|
-
Additional information about the query.
|
17
|
-
"""
|
18
|
-
|
19
|
-
similarity_metric: index.VectorSimilarityMetric
|
20
|
-
query_vector: list[float]
|
21
|
-
vector_field_name: str
|
22
|
-
|
23
|
-
|
24
|
-
@dataclass
|
25
|
-
class QueryResult:
|
26
|
-
"""
|
27
|
-
A single result from the query.
|
28
|
-
"""
|
29
|
-
|
30
|
-
data: dict[str, Any]
|
31
|
-
score: float
|
32
|
-
|
33
|
-
|
34
|
-
class SimpleSemanticsQueryHandler:
|
35
|
-
"""
|
36
|
-
A query handler that uses simple semantics to query the index.
|
37
|
-
"""
|
38
|
-
|
39
|
-
_lazy_query_handler: Callable[[], _engine.SimpleSemanticsQueryHandler]
|
40
|
-
|
41
|
-
def __init__(
|
42
|
-
self,
|
43
|
-
name: str,
|
44
|
-
flow: fl.Flow,
|
45
|
-
target_name: str,
|
46
|
-
query_transform_flow: Callable[..., fl.DataSlice[Any]],
|
47
|
-
default_similarity_metric: index.VectorSimilarityMetric = index.VectorSimilarityMetric.COSINE_SIMILARITY,
|
48
|
-
) -> None:
|
49
|
-
engine_handler = None
|
50
|
-
lock = Lock()
|
51
|
-
|
52
|
-
def _lazy_handler() -> _engine.SimpleSemanticsQueryHandler:
|
53
|
-
nonlocal engine_handler, lock
|
54
|
-
if engine_handler is None:
|
55
|
-
with lock:
|
56
|
-
if engine_handler is None:
|
57
|
-
engine_handler = _engine.SimpleSemanticsQueryHandler(
|
58
|
-
flow.internal_flow(),
|
59
|
-
target_name,
|
60
|
-
fl.TransformFlow(
|
61
|
-
query_transform_flow, [str]
|
62
|
-
).internal_flow(),
|
63
|
-
default_similarity_metric.value,
|
64
|
-
)
|
65
|
-
engine_handler.register_query_handler(name)
|
66
|
-
return engine_handler
|
67
|
-
|
68
|
-
self._lazy_query_handler = _lazy_handler
|
69
|
-
|
70
|
-
with _handlers_lock:
|
71
|
-
_handlers[name] = self
|
72
|
-
|
73
|
-
def internal_handler(self) -> _engine.SimpleSemanticsQueryHandler:
|
74
|
-
"""
|
75
|
-
Get the internal query handler.
|
76
|
-
"""
|
77
|
-
return self._lazy_query_handler()
|
78
|
-
|
79
|
-
def search(
|
80
|
-
self,
|
81
|
-
query: str,
|
82
|
-
limit: int,
|
83
|
-
vector_field_name: str | None = None,
|
84
|
-
similarity_metric: index.VectorSimilarityMetric | None = None,
|
85
|
-
) -> tuple[list[QueryResult], SimpleSemanticsQueryInfo]:
|
86
|
-
"""
|
87
|
-
Search the index with the given query, limit, vector field name, and similarity metric.
|
88
|
-
"""
|
89
|
-
internal_results, internal_info = self.internal_handler().search(
|
90
|
-
query,
|
91
|
-
limit,
|
92
|
-
vector_field_name,
|
93
|
-
similarity_metric.value if similarity_metric is not None else None,
|
94
|
-
)
|
95
|
-
results = [
|
96
|
-
QueryResult(data=result["data"], score=result["score"])
|
97
|
-
for result in internal_results
|
98
|
-
]
|
99
|
-
info = SimpleSemanticsQueryInfo(
|
100
|
-
similarity_metric=index.VectorSimilarityMetric(
|
101
|
-
internal_info["similarity_metric"]
|
102
|
-
),
|
103
|
-
query_vector=internal_info["query_vector"],
|
104
|
-
vector_field_name=internal_info["vector_field_name"],
|
105
|
-
)
|
106
|
-
return results, info
|
107
|
-
|
108
|
-
|
109
|
-
def ensure_all_handlers_built() -> None:
|
110
|
-
"""
|
111
|
-
Ensure all handlers are built.
|
112
|
-
"""
|
113
|
-
with _handlers_lock:
|
114
|
-
for handler in _handlers.values():
|
115
|
-
handler.internal_handler()
|
File without changes
|
File without changes
|