cocoindex 0.1.26__cp312-cp312-macosx_11_0_arm64.whl → 0.1.27__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cocoindex/__init__.py +2 -1
- cocoindex/_engine.cpython-312-darwin.so +0 -0
- cocoindex/convert.py +42 -21
- cocoindex/flow.py +2 -2
- cocoindex/functions.py +1 -1
- cocoindex/lib.py +15 -7
- cocoindex/llm.py +1 -0
- cocoindex/op.py +4 -4
- cocoindex/query.py +5 -4
- cocoindex/storages.py +1 -1
- cocoindex/tests/test_convert.py +99 -53
- cocoindex/typing.py +75 -47
- {cocoindex-0.1.26.dist-info → cocoindex-0.1.27.dist-info}/METADATA +1 -1
- cocoindex-0.1.27.dist-info/RECORD +24 -0
- cocoindex-0.1.26.dist-info/RECORD +0 -24
- {cocoindex-0.1.26.dist-info → cocoindex-0.1.27.dist-info}/WHEEL +0 -0
- {cocoindex-0.1.26.dist-info → cocoindex-0.1.27.dist-info}/licenses/LICENSE +0 -0
cocoindex/__init__.py
CHANGED
@@ -9,4 +9,5 @@ from .llm import LlmSpec, LlmApiType
|
|
9
9
|
from .index import VectorSimilarityMetric, VectorIndexDef, IndexOptions
|
10
10
|
from .auth_registry import AuthEntryReference, add_auth_entry, ref_auth_entry
|
11
11
|
from .lib import *
|
12
|
-
from ._engine import OpArgSchema
|
12
|
+
from ._engine import OpArgSchema
|
13
|
+
from .typing import Float32, Float64, LocalDateTime, OffsetDateTime, Range, Vector, Json
|
Binary file
|
cocoindex/convert.py
CHANGED
@@ -8,25 +8,28 @@ import uuid
|
|
8
8
|
|
9
9
|
from enum import Enum
|
10
10
|
from typing import Any, Callable, get_origin
|
11
|
-
from .typing import analyze_type_info, encode_enriched_type,
|
11
|
+
from .typing import analyze_type_info, encode_enriched_type, TABLE_TYPES, KEY_FIELD_NAME
|
12
12
|
|
13
|
-
|
14
|
-
|
13
|
+
|
14
|
+
def encode_engine_value(value: Any) -> Any:
|
15
|
+
"""Encode a Python value to an engine value."""
|
15
16
|
if dataclasses.is_dataclass(value):
|
16
|
-
return [
|
17
|
+
return [encode_engine_value(getattr(value, f.name)) for f in dataclasses.fields(value)]
|
17
18
|
if isinstance(value, (list, tuple)):
|
18
|
-
return [
|
19
|
+
return [encode_engine_value(v) for v in value]
|
20
|
+
if isinstance(value, dict):
|
21
|
+
return [[encode_engine_value(k)] + encode_engine_value(v) for k, v in value.items()]
|
19
22
|
if isinstance(value, uuid.UUID):
|
20
23
|
return value.bytes
|
21
24
|
return value
|
22
25
|
|
23
|
-
def
|
26
|
+
def make_engine_value_decoder(
|
24
27
|
field_path: list[str],
|
25
28
|
src_type: dict[str, Any],
|
26
29
|
dst_annotation,
|
27
30
|
) -> Callable[[Any], Any]:
|
28
31
|
"""
|
29
|
-
Make a
|
32
|
+
Make a decoder from an engine value to a Python value.
|
30
33
|
|
31
34
|
Args:
|
32
35
|
field_path: The path to the field in the engine value. For error messages.
|
@@ -34,13 +37,13 @@ def make_engine_value_converter(
|
|
34
37
|
dst_annotation: The type annotation of the Python value.
|
35
38
|
|
36
39
|
Returns:
|
37
|
-
A
|
40
|
+
A decoder from an engine value to a Python value.
|
38
41
|
"""
|
39
42
|
|
40
43
|
src_type_kind = src_type['kind']
|
41
44
|
|
42
45
|
if dst_annotation is inspect.Parameter.empty:
|
43
|
-
if src_type_kind == 'Struct' or src_type_kind in
|
46
|
+
if src_type_kind == 'Struct' or src_type_kind in TABLE_TYPES:
|
44
47
|
raise ValueError(f"Missing type annotation for `{''.join(field_path)}`."
|
45
48
|
f"It's required for {src_type_kind} type.")
|
46
49
|
return lambda value: value
|
@@ -53,41 +56,59 @@ def make_engine_value_converter(
|
|
53
56
|
f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})")
|
54
57
|
|
55
58
|
if dst_type_info.dataclass_type is not None:
|
56
|
-
return
|
59
|
+
return _make_engine_struct_value_decoder(
|
57
60
|
field_path, src_type['fields'], dst_type_info.dataclass_type)
|
58
61
|
|
59
|
-
if src_type_kind in
|
62
|
+
if src_type_kind in TABLE_TYPES:
|
60
63
|
field_path.append('[*]')
|
61
64
|
elem_type_info = analyze_type_info(dst_type_info.elem_type)
|
62
65
|
if elem_type_info.dataclass_type is None:
|
63
66
|
raise ValueError(f"Type mismatch for `{''.join(field_path)}`: "
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
+
f"declared `{dst_type_info.kind}`, a dataclass type expected")
|
68
|
+
engine_fields_schema = src_type['row']['fields']
|
69
|
+
if elem_type_info.key_type is not None:
|
70
|
+
key_field_schema = engine_fields_schema[0]
|
71
|
+
field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}")
|
72
|
+
key_decoder = make_engine_value_decoder(
|
73
|
+
field_path, key_field_schema['type'], elem_type_info.key_type)
|
74
|
+
field_path.pop()
|
75
|
+
value_decoder = _make_engine_struct_value_decoder(
|
76
|
+
field_path, engine_fields_schema[1:], elem_type_info.dataclass_type)
|
77
|
+
def decode(value):
|
78
|
+
if value is None:
|
79
|
+
return None
|
80
|
+
return {key_decoder(v[0]): value_decoder(v[1:]) for v in value}
|
81
|
+
else:
|
82
|
+
elem_decoder = _make_engine_struct_value_decoder(
|
83
|
+
field_path, engine_fields_schema, elem_type_info.dataclass_type)
|
84
|
+
def decode(value):
|
85
|
+
if value is None:
|
86
|
+
return None
|
87
|
+
return [elem_decoder(v) for v in value]
|
67
88
|
field_path.pop()
|
68
|
-
return
|
89
|
+
return decode
|
69
90
|
|
70
91
|
if src_type_kind == 'Uuid':
|
71
92
|
return lambda value: uuid.UUID(bytes=value)
|
72
93
|
|
73
94
|
return lambda value: value
|
74
95
|
|
75
|
-
def
|
96
|
+
def _make_engine_struct_value_decoder(
|
76
97
|
field_path: list[str],
|
77
98
|
src_fields: list[dict[str, Any]],
|
78
99
|
dst_dataclass_type: type,
|
79
100
|
) -> Callable[[list], Any]:
|
80
|
-
"""Make a
|
101
|
+
"""Make a decoder from an engine field values to a Python value."""
|
81
102
|
|
82
103
|
src_name_to_idx = {f['name']: i for i, f in enumerate(src_fields)}
|
83
104
|
def make_closure_for_value(name: str, param: inspect.Parameter) -> Callable[[list], Any]:
|
84
105
|
src_idx = src_name_to_idx.get(name)
|
85
106
|
if src_idx is not None:
|
86
107
|
field_path.append(f'.{name}')
|
87
|
-
|
108
|
+
field_decoder = make_engine_value_decoder(
|
88
109
|
field_path, src_fields[src_idx]['type'], param.annotation)
|
89
110
|
field_path.pop()
|
90
|
-
return lambda values:
|
111
|
+
return lambda values: field_decoder(values[src_idx])
|
91
112
|
|
92
113
|
default_value = param.default
|
93
114
|
if default_value is inspect.Parameter.empty:
|
@@ -96,12 +117,12 @@ def _make_engine_struct_value_converter(
|
|
96
117
|
|
97
118
|
return lambda _: default_value
|
98
119
|
|
99
|
-
|
120
|
+
field_value_decoder = [
|
100
121
|
make_closure_for_value(name, param)
|
101
122
|
for (name, param) in inspect.signature(dst_dataclass_type).parameters.items()]
|
102
123
|
|
103
124
|
return lambda values: dst_dataclass_type(
|
104
|
-
*(
|
125
|
+
*(decoder(values) for decoder in field_value_decoder))
|
105
126
|
|
106
127
|
def dump_engine_object(v: Any) -> Any:
|
107
128
|
"""Recursively dump an object for engine. Engine side uses `Pythonized` to catch."""
|
cocoindex/flow.py
CHANGED
@@ -142,9 +142,9 @@ class DataSlice:
|
|
142
142
|
|
143
143
|
def row(self) -> DataScope:
|
144
144
|
"""
|
145
|
-
Return a scope representing each
|
145
|
+
Return a scope representing each row of the table.
|
146
146
|
"""
|
147
|
-
row_scope = self._state.engine_data_slice.
|
147
|
+
row_scope = self._state.engine_data_slice.table_row_scope()
|
148
148
|
return DataScope(self._state.flow_builder_state, row_scope)
|
149
149
|
|
150
150
|
def for_each(self, f: Callable[[DataScope], None]) -> None:
|
cocoindex/functions.py
CHANGED
@@ -41,7 +41,7 @@ class SentenceTransformerEmbedExecutor:
|
|
41
41
|
args = self.spec.args or {}
|
42
42
|
self._model = sentence_transformers.SentenceTransformer(self.spec.model, **args)
|
43
43
|
dim = self._model.get_sentence_embedding_dimension()
|
44
|
-
return Annotated[
|
44
|
+
return Annotated[Vector[Float32, dim], TypeAttr("cocoindex.io/vector_origin_text", text.analyzed_value)]
|
45
45
|
|
46
46
|
def __call__(self, text: str) -> list[Float32]:
|
47
47
|
return self._model.encode(text).tolist()
|
cocoindex/lib.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1
1
|
"""
|
2
2
|
Library level functions and states.
|
3
3
|
"""
|
4
|
-
import asyncio
|
5
4
|
import os
|
6
5
|
import sys
|
7
6
|
import functools
|
@@ -12,6 +11,7 @@ from dataclasses import dataclass
|
|
12
11
|
|
13
12
|
from . import _engine
|
14
13
|
from . import flow, query, cli
|
14
|
+
from .convert import dump_engine_object
|
15
15
|
|
16
16
|
|
17
17
|
def _load_field(target: dict[str, str], name: str, env_name: str, required: bool = False):
|
@@ -22,24 +22,32 @@ def _load_field(target: dict[str, str], name: str, env_name: str, required: bool
|
|
22
22
|
else:
|
23
23
|
target[name] = value
|
24
24
|
|
25
|
+
@dataclass
|
26
|
+
class DatabaseConnectionSpec:
|
27
|
+
url: str
|
28
|
+
user: str | None = None
|
29
|
+
password: str | None = None
|
30
|
+
|
25
31
|
@dataclass
|
26
32
|
class Settings:
|
27
33
|
"""Settings for the cocoindex library."""
|
28
|
-
|
34
|
+
database: DatabaseConnectionSpec
|
29
35
|
|
30
36
|
@classmethod
|
31
37
|
def from_env(cls) -> Self:
|
32
38
|
"""Load settings from environment variables."""
|
33
39
|
|
34
|
-
|
35
|
-
_load_field(
|
36
|
-
|
37
|
-
|
40
|
+
db_kwargs: dict[str, str] = dict()
|
41
|
+
_load_field(db_kwargs, "url", "COCOINDEX_DATABASE_URL", required=True)
|
42
|
+
_load_field(db_kwargs, "user", "COCOINDEX_DATABASE_USER")
|
43
|
+
_load_field(db_kwargs, "password", "COCOINDEX_DATABASE_PASSWORD")
|
44
|
+
database = DatabaseConnectionSpec(**db_kwargs)
|
45
|
+
return cls(database=database)
|
38
46
|
|
39
47
|
|
40
48
|
def init(settings: Settings):
|
41
49
|
"""Initialize the cocoindex library."""
|
42
|
-
_engine.init(settings
|
50
|
+
_engine.init(dump_engine_object(settings))
|
43
51
|
|
44
52
|
@dataclass
|
45
53
|
class ServerSettings:
|
cocoindex/llm.py
CHANGED
cocoindex/op.py
CHANGED
@@ -9,7 +9,7 @@ from typing import get_type_hints, Protocol, Any, Callable, Awaitable, dataclass
|
|
9
9
|
from enum import Enum
|
10
10
|
|
11
11
|
from .typing import encode_enriched_type
|
12
|
-
from .convert import
|
12
|
+
from .convert import encode_engine_value, make_engine_value_decoder
|
13
13
|
from . import _engine
|
14
14
|
|
15
15
|
class OpCategory(Enum):
|
@@ -129,7 +129,7 @@ def _register_op_factory(
|
|
129
129
|
raise ValueError(
|
130
130
|
f"Too many positional arguments passed in: {len(args)} > {next_param_idx}")
|
131
131
|
self._args_converters.append(
|
132
|
-
|
132
|
+
make_engine_value_decoder(
|
133
133
|
[arg_name], arg.value_type['type'], arg_param.annotation))
|
134
134
|
if arg_param.kind != inspect.Parameter.VAR_POSITIONAL:
|
135
135
|
next_param_idx += 1
|
@@ -146,7 +146,7 @@ def _register_op_factory(
|
|
146
146
|
if expected_arg is None:
|
147
147
|
raise ValueError(f"Unexpected keyword argument passed in: {kwarg_name}")
|
148
148
|
arg_param = expected_arg[1]
|
149
|
-
self._kwargs_converters[kwarg_name] =
|
149
|
+
self._kwargs_converters[kwarg_name] = make_engine_value_decoder(
|
150
150
|
[kwarg_name], kwarg.value_type['type'], arg_param.annotation)
|
151
151
|
|
152
152
|
missing_args = [name for (name, arg) in expected_kwargs
|
@@ -188,7 +188,7 @@ def _register_op_factory(
|
|
188
188
|
output = await self._acall(*converted_args, **converted_kwargs)
|
189
189
|
else:
|
190
190
|
output = await self._acall(*converted_args, **converted_kwargs)
|
191
|
-
return
|
191
|
+
return encode_engine_value(output)
|
192
192
|
|
193
193
|
_WrappedClass.__name__ = executor_cls.__name__
|
194
194
|
_WrappedClass.__doc__ = executor_cls.__doc__
|
cocoindex/query.py
CHANGED
@@ -66,15 +66,16 @@ class SimpleSemanticsQueryHandler:
|
|
66
66
|
return self._lazy_query_handler()
|
67
67
|
|
68
68
|
def search(self, query: str, limit: int, vector_field_name: str | None = None,
|
69
|
-
|
69
|
+
similarity_metric: index.VectorSimilarityMetric | None = None
|
70
|
+
) -> tuple[list[QueryResult], SimpleSemanticsQueryInfo]:
|
70
71
|
"""
|
71
72
|
Search the index with the given query, limit, vector field name, and similarity metric.
|
72
73
|
"""
|
73
74
|
internal_results, internal_info = self.internal_handler().search(
|
74
75
|
query, limit, vector_field_name,
|
75
|
-
|
76
|
-
|
77
|
-
|
76
|
+
similarity_metric.value if similarity_metric is not None else None)
|
77
|
+
results = [QueryResult(data=result['data'], score=result['score'])
|
78
|
+
for result in internal_results]
|
78
79
|
info = SimpleSemanticsQueryInfo(
|
79
80
|
similarity_metric=index.VectorSimilarityMetric(internal_info['similarity_metric']),
|
80
81
|
query_vector=internal_info['query_vector'],
|
cocoindex/storages.py
CHANGED
@@ -9,7 +9,7 @@ from .auth_registry import AuthEntryReference
|
|
9
9
|
class Postgres(op.StorageSpec):
|
10
10
|
"""Storage powered by Postgres and pgvector."""
|
11
11
|
|
12
|
-
|
12
|
+
database: AuthEntryReference | None = None
|
13
13
|
table_name: str | None = None
|
14
14
|
|
15
15
|
@dataclass
|
cocoindex/tests/test_convert.py
CHANGED
@@ -1,12 +1,11 @@
|
|
1
|
-
import dataclasses
|
2
1
|
import uuid
|
3
2
|
import datetime
|
4
3
|
from dataclasses import dataclass, make_dataclass
|
5
4
|
import pytest
|
5
|
+
import cocoindex
|
6
6
|
from cocoindex.typing import encode_enriched_type
|
7
|
-
from cocoindex.convert import
|
8
|
-
from
|
9
|
-
|
7
|
+
from cocoindex.convert import encode_engine_value, make_engine_value_decoder
|
8
|
+
from typing import Literal
|
10
9
|
@dataclass
|
11
10
|
class Order:
|
12
11
|
order_id: str
|
@@ -26,7 +25,7 @@ class Basket:
|
|
26
25
|
class Customer:
|
27
26
|
name: str
|
28
27
|
order: Order
|
29
|
-
tags: list[Tag] = None
|
28
|
+
tags: list[Tag] | None = None
|
30
29
|
|
31
30
|
@dataclass
|
32
31
|
class NestedStruct:
|
@@ -34,63 +33,63 @@ class NestedStruct:
|
|
34
33
|
orders: list[Order]
|
35
34
|
count: int = 0
|
36
35
|
|
37
|
-
def
|
36
|
+
def build_engine_value_decoder(engine_type_in_py, python_type=None):
|
38
37
|
"""
|
39
38
|
Helper to build a converter for the given engine-side type (as represented in Python).
|
40
39
|
If python_type is not specified, uses engine_type_in_py as the target.
|
41
40
|
"""
|
42
41
|
engine_type = encode_enriched_type(engine_type_in_py)["type"]
|
43
|
-
return
|
42
|
+
return make_engine_value_decoder([], engine_type, python_type or engine_type_in_py)
|
44
43
|
|
45
|
-
def
|
46
|
-
assert
|
47
|
-
assert
|
48
|
-
assert
|
49
|
-
assert
|
44
|
+
def test_encode_engine_value_basic_types():
|
45
|
+
assert encode_engine_value(123) == 123
|
46
|
+
assert encode_engine_value(3.14) == 3.14
|
47
|
+
assert encode_engine_value("hello") == "hello"
|
48
|
+
assert encode_engine_value(True) is True
|
50
49
|
|
51
|
-
def
|
50
|
+
def test_encode_engine_value_uuid():
|
52
51
|
u = uuid.uuid4()
|
53
|
-
assert
|
52
|
+
assert encode_engine_value(u) == u.bytes
|
54
53
|
|
55
|
-
def
|
54
|
+
def test_encode_engine_value_date_time_types():
|
56
55
|
d = datetime.date(2024, 1, 1)
|
57
|
-
assert
|
56
|
+
assert encode_engine_value(d) == d
|
58
57
|
t = datetime.time(12, 30)
|
59
|
-
assert
|
58
|
+
assert encode_engine_value(t) == t
|
60
59
|
dt = datetime.datetime(2024, 1, 1, 12, 30)
|
61
|
-
assert
|
60
|
+
assert encode_engine_value(dt) == dt
|
62
61
|
|
63
|
-
def
|
62
|
+
def test_encode_engine_value_struct():
|
64
63
|
order = Order(order_id="O123", name="mixed nuts", price=25.0)
|
65
|
-
assert
|
64
|
+
assert encode_engine_value(order) == ["O123", "mixed nuts", 25.0, "default_extra"]
|
66
65
|
|
67
|
-
def
|
66
|
+
def test_encode_engine_value_list_of_structs():
|
68
67
|
orders = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]
|
69
|
-
assert
|
68
|
+
assert encode_engine_value(orders) == [["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]]
|
70
69
|
|
71
|
-
def
|
70
|
+
def test_encode_engine_value_struct_with_list():
|
72
71
|
basket = Basket(items=["apple", "banana"])
|
73
|
-
assert
|
72
|
+
assert encode_engine_value(basket) == [["apple", "banana"]]
|
74
73
|
|
75
|
-
def
|
74
|
+
def test_encode_engine_value_nested_struct():
|
76
75
|
customer = Customer(name="Alice", order=Order("O1", "item1", 10.0))
|
77
|
-
assert
|
76
|
+
assert encode_engine_value(customer) == ["Alice", ["O1", "item1", 10.0, "default_extra"], None]
|
78
77
|
|
79
|
-
def
|
80
|
-
assert
|
81
|
-
assert
|
78
|
+
def test_encode_engine_value_empty_list():
|
79
|
+
assert encode_engine_value([]) == []
|
80
|
+
assert encode_engine_value([[]]) == [[]]
|
82
81
|
|
83
|
-
def
|
84
|
-
assert
|
85
|
-
assert
|
86
|
-
assert
|
87
|
-
assert
|
88
|
-
assert
|
82
|
+
def test_encode_engine_value_tuple():
|
83
|
+
assert encode_engine_value(()) == []
|
84
|
+
assert encode_engine_value((1, 2, 3)) == [1, 2, 3]
|
85
|
+
assert encode_engine_value(((1, 2), (3, 4))) == [[1, 2], [3, 4]]
|
86
|
+
assert encode_engine_value(([],)) == [[]]
|
87
|
+
assert encode_engine_value(((),)) == [[]]
|
89
88
|
|
90
|
-
def
|
91
|
-
assert
|
89
|
+
def test_encode_engine_value_none():
|
90
|
+
assert encode_engine_value(None) is None
|
92
91
|
|
93
|
-
def
|
92
|
+
def test_make_engine_value_decoder_basic_types():
|
94
93
|
for engine_type_in_py, value in [
|
95
94
|
(int, 42),
|
96
95
|
(float, 3.14),
|
@@ -98,11 +97,11 @@ def test_make_engine_value_converter_basic_types():
|
|
98
97
|
(bool, True),
|
99
98
|
# (type(None), None), # Removed unsupported NoneType
|
100
99
|
]:
|
101
|
-
|
102
|
-
assert
|
100
|
+
decoder = build_engine_value_decoder(engine_type_in_py)
|
101
|
+
assert decoder(value) == value
|
103
102
|
|
104
103
|
@pytest.mark.parametrize(
|
105
|
-
"
|
104
|
+
"data_type, engine_val, expected",
|
106
105
|
[
|
107
106
|
# All fields match
|
108
107
|
(Order, ["O123", "mixed nuts", 25.0, "default_extra"], Order("O123", "mixed nuts", 25.0, "default_extra")),
|
@@ -120,30 +119,30 @@ def test_make_engine_value_converter_basic_types():
|
|
120
119
|
(Customer, ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]], "extra"], Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip")])),
|
121
120
|
]
|
122
121
|
)
|
123
|
-
def
|
124
|
-
|
125
|
-
assert
|
122
|
+
def test_struct_decoder_cases(data_type, engine_val, expected):
|
123
|
+
decoder = build_engine_value_decoder(data_type)
|
124
|
+
assert decoder(engine_val) == expected
|
126
125
|
|
127
|
-
def
|
126
|
+
def test_make_engine_value_decoder_collections():
|
128
127
|
# List of structs
|
129
|
-
|
128
|
+
decoder = build_engine_value_decoder(list[Order])
|
130
129
|
engine_val = [
|
131
130
|
["O1", "item1", 10.0, "default_extra"],
|
132
131
|
["O2", "item2", 20.0, "default_extra"]
|
133
132
|
]
|
134
|
-
assert
|
133
|
+
assert decoder(engine_val) == [Order("O1", "item1", 10.0, "default_extra"), Order("O2", "item2", 20.0, "default_extra")]
|
135
134
|
# Struct with list field
|
136
|
-
|
135
|
+
decoder = build_engine_value_decoder(Customer)
|
137
136
|
engine_val = ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"], ["premium"]]]
|
138
|
-
assert
|
137
|
+
assert decoder(engine_val) == Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip"), Tag("premium")])
|
139
138
|
# Struct with struct field
|
140
|
-
|
139
|
+
decoder = build_engine_value_decoder(NestedStruct)
|
141
140
|
engine_val = [
|
142
141
|
["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]]],
|
143
142
|
[["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]],
|
144
143
|
2
|
145
144
|
]
|
146
|
-
assert
|
145
|
+
assert decoder(engine_val) == NestedStruct(
|
147
146
|
Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip")]),
|
148
147
|
[Order("O1", "item1", 10.0, "default_extra"), Order("O2", "item2", 20.0, "default_extra")],
|
149
148
|
2
|
@@ -227,8 +226,55 @@ def make_python_order(fields, defaults=None):
|
|
227
226
|
def test_field_position_cases(engine_fields, python_fields, python_defaults, engine_val, expected_python_val):
|
228
227
|
EngineOrder = make_engine_order(engine_fields)
|
229
228
|
PythonOrder = make_python_order(python_fields, python_defaults)
|
230
|
-
|
229
|
+
decoder = build_engine_value_decoder(EngineOrder, PythonOrder)
|
231
230
|
# Map field names to expected values
|
232
231
|
expected_dict = dict(zip([f[0] for f in python_fields], expected_python_val))
|
233
232
|
# Instantiate using keyword arguments (order doesn't matter)
|
234
|
-
assert
|
233
|
+
assert decoder(engine_val) == PythonOrder(**expected_dict)
|
234
|
+
|
235
|
+
def test_roundtrip_ltable():
|
236
|
+
t = list[Order]
|
237
|
+
value = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]
|
238
|
+
encoded = encode_engine_value(value)
|
239
|
+
assert encoded == [["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]]
|
240
|
+
decoded = build_engine_value_decoder(t)(encoded)
|
241
|
+
assert decoded == value
|
242
|
+
|
243
|
+
def test_roundtrip_ktable_str_key():
|
244
|
+
t = dict[str, Order]
|
245
|
+
value = {"K1": Order("O1", "item1", 10.0), "K2": Order("O2", "item2", 20.0)}
|
246
|
+
encoded = encode_engine_value(value)
|
247
|
+
assert encoded == [["K1", "O1", "item1", 10.0, "default_extra"], ["K2", "O2", "item2", 20.0, "default_extra"]]
|
248
|
+
decoded = build_engine_value_decoder(t)(encoded)
|
249
|
+
assert decoded == value
|
250
|
+
|
251
|
+
def test_roundtrip_ktable_struct_key():
|
252
|
+
@dataclass(frozen=True)
|
253
|
+
class OrderKey:
|
254
|
+
shop_id: str
|
255
|
+
version: int
|
256
|
+
|
257
|
+
t = dict[OrderKey, Order]
|
258
|
+
value = {OrderKey("A", 3): Order("O1", "item1", 10.0), OrderKey("B", 4): Order("O2", "item2", 20.0)}
|
259
|
+
encoded = encode_engine_value(value)
|
260
|
+
assert encoded == [[["A", 3], "O1", "item1", 10.0, "default_extra"],
|
261
|
+
[["B", 4], "O2", "item2", 20.0, "default_extra"]]
|
262
|
+
decoded = build_engine_value_decoder(t)(encoded)
|
263
|
+
assert decoded == value
|
264
|
+
|
265
|
+
IntVectorType = cocoindex.Vector[int, Literal[5]]
|
266
|
+
def test_vector_as_vector() -> None:
|
267
|
+
value: IntVectorType = [1, 2, 3, 4, 5]
|
268
|
+
encoded = encode_engine_value(value)
|
269
|
+
assert encoded == [1, 2, 3, 4, 5]
|
270
|
+
decoded = build_engine_value_decoder(IntVectorType)(encoded)
|
271
|
+
assert decoded == value
|
272
|
+
|
273
|
+
ListIntType = list[int]
|
274
|
+
def test_vector_as_list() -> None:
|
275
|
+
value: ListIntType = [1, 2, 3, 4, 5]
|
276
|
+
encoded = encode_engine_value(value)
|
277
|
+
assert encoded == [1, 2, 3, 4, 5]
|
278
|
+
decoded = build_engine_value_decoder(ListIntType)(encoded)
|
279
|
+
assert decoded == value
|
280
|
+
|
cocoindex/typing.py
CHANGED
@@ -5,9 +5,9 @@ import datetime
|
|
5
5
|
import types
|
6
6
|
import inspect
|
7
7
|
import uuid
|
8
|
-
from typing import Annotated, NamedTuple, Any, TypeVar, TYPE_CHECKING, overload
|
8
|
+
from typing import Annotated, NamedTuple, Any, TypeVar, TYPE_CHECKING, overload, Sequence, Protocol, Generic, Literal
|
9
9
|
|
10
|
-
class
|
10
|
+
class VectorInfo(NamedTuple):
|
11
11
|
dim: int | None
|
12
12
|
|
13
13
|
class TypeKind(NamedTuple):
|
@@ -21,7 +21,7 @@ class TypeAttr:
|
|
21
21
|
self.key = key
|
22
22
|
self.value = value
|
23
23
|
|
24
|
-
Annotation =
|
24
|
+
Annotation = TypeKind | TypeAttr | VectorInfo
|
25
25
|
|
26
26
|
Float32 = Annotated[float, TypeKind('Float32')]
|
27
27
|
Float64 = Annotated[float, TypeKind('Float64')]
|
@@ -30,29 +30,34 @@ Json = Annotated[Any, TypeKind('Json')]
|
|
30
30
|
LocalDateTime = Annotated[datetime.datetime, TypeKind('LocalDateTime')]
|
31
31
|
OffsetDateTime = Annotated[datetime.datetime, TypeKind('OffsetDateTime')]
|
32
32
|
|
33
|
-
COLLECTION_TYPES = ('Table', 'List')
|
34
|
-
|
35
|
-
R = TypeVar("R")
|
36
|
-
|
37
33
|
if TYPE_CHECKING:
|
38
|
-
|
39
|
-
|
34
|
+
T_co = TypeVar('T_co', covariant=True)
|
35
|
+
Dim_co = TypeVar('Dim_co', bound=int, covariant=True)
|
36
|
+
|
37
|
+
class Vector(Sequence[T_co], Generic[T_co, Dim_co], Protocol):
|
38
|
+
"""Vector[T, Dim] is a special typing alias for a list[T] with optional dimension info"""
|
40
39
|
else:
|
41
|
-
#
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
40
|
+
class Vector: # type: ignore[unreachable]
|
41
|
+
""" A special typing alias for a list[T] with optional dimension info """
|
42
|
+
def __class_getitem__(self, params):
|
43
|
+
if not isinstance(params, tuple):
|
44
|
+
# Only element type provided
|
45
|
+
elem_type = params
|
46
|
+
return Annotated[list[elem_type], VectorInfo(dim=None)]
|
47
|
+
else:
|
48
|
+
# Element type and dimension provided
|
49
|
+
elem_type, dim = params
|
50
|
+
if typing.get_origin(dim) is Literal:
|
51
|
+
dim = typing.get_args(dim)[0] # Extract the literal value
|
52
|
+
return Annotated[list[elem_type], VectorInfo(dim=dim)]
|
53
|
+
|
54
|
+
TABLE_TYPES = ('KTable', 'LTable')
|
55
|
+
KEY_FIELD_NAME = '_key'
|
56
|
+
|
57
|
+
ElementType = type | tuple[type, type]
|
58
|
+
|
59
|
+
def _is_struct_type(t) -> bool:
|
60
|
+
return isinstance(t, type) and dataclasses.is_dataclass(t)
|
56
61
|
|
57
62
|
@dataclasses.dataclass
|
58
63
|
class AnalyzedTypeInfo:
|
@@ -60,9 +65,12 @@ class AnalyzedTypeInfo:
|
|
60
65
|
Analyzed info of a Python type.
|
61
66
|
"""
|
62
67
|
kind: str
|
63
|
-
vector_info:
|
64
|
-
elem_type:
|
65
|
-
|
68
|
+
vector_info: VectorInfo | None # For Vector
|
69
|
+
elem_type: ElementType | None # For Vector and Table
|
70
|
+
|
71
|
+
key_type: type | None # For element of KTable
|
72
|
+
dataclass_type: type | None # For Struct
|
73
|
+
|
66
74
|
attrs: dict[str, Any] | None
|
67
75
|
nullable: bool = False
|
68
76
|
|
@@ -70,6 +78,12 @@ def analyze_type_info(t) -> AnalyzedTypeInfo:
|
|
70
78
|
"""
|
71
79
|
Analyze a Python type and return the analyzed info.
|
72
80
|
"""
|
81
|
+
if isinstance(t, tuple) and len(t) == 2:
|
82
|
+
key_type, value_type = t
|
83
|
+
result = analyze_type_info(value_type)
|
84
|
+
result.key_type = key_type
|
85
|
+
return result
|
86
|
+
|
73
87
|
annotations: tuple[Annotation, ...] = ()
|
74
88
|
base_type = None
|
75
89
|
nullable = False
|
@@ -98,33 +112,41 @@ def analyze_type_info(t) -> AnalyzedTypeInfo:
|
|
98
112
|
if attrs is None:
|
99
113
|
attrs = dict()
|
100
114
|
attrs[attr.key] = attr.value
|
101
|
-
elif isinstance(attr,
|
115
|
+
elif isinstance(attr, VectorInfo):
|
102
116
|
vector_info = attr
|
103
117
|
elif isinstance(attr, TypeKind):
|
104
118
|
kind = attr.kind
|
105
119
|
|
106
120
|
dataclass_type = None
|
107
121
|
elem_type = None
|
108
|
-
|
122
|
+
key_type = None
|
123
|
+
if _is_struct_type(t):
|
109
124
|
if kind is None:
|
110
125
|
kind = 'Struct'
|
111
126
|
elif kind != 'Struct':
|
112
127
|
raise ValueError(f"Unexpected type kind for struct: {kind}")
|
113
128
|
dataclass_type = t
|
114
129
|
elif base_type is collections.abc.Sequence or base_type is list:
|
130
|
+
args = typing.get_args(t)
|
131
|
+
elem_type = args[0]
|
132
|
+
|
115
133
|
if kind is None:
|
116
|
-
|
117
|
-
|
134
|
+
if _is_struct_type(elem_type):
|
135
|
+
kind = 'LTable'
|
136
|
+
if vector_info is not None:
|
137
|
+
raise ValueError("Vector element must be a simple type, not a struct")
|
138
|
+
else:
|
139
|
+
kind = 'Vector'
|
140
|
+
if vector_info is None:
|
141
|
+
vector_info = VectorInfo(dim=None)
|
142
|
+
elif not (kind == 'Vector' or kind in TABLE_TYPES):
|
118
143
|
raise ValueError(f"Unexpected type kind for list: {kind}")
|
119
|
-
|
144
|
+
elif base_type is collections.abc.Mapping or base_type is dict:
|
120
145
|
args = typing.get_args(t)
|
121
|
-
|
122
|
-
|
123
|
-
elem_type = args[0]
|
146
|
+
elem_type = (args[0], args[1])
|
147
|
+
kind = 'KTable'
|
124
148
|
elif kind is None:
|
125
|
-
if
|
126
|
-
kind = 'Vector' if vector_info is not None else 'List'
|
127
|
-
elif t is bytes:
|
149
|
+
if t is bytes:
|
128
150
|
kind = 'Bytes'
|
129
151
|
elif t is str:
|
130
152
|
kind = 'Str'
|
@@ -145,20 +167,26 @@ def analyze_type_info(t) -> AnalyzedTypeInfo:
|
|
145
167
|
else:
|
146
168
|
raise ValueError(f"type unsupported yet: {t}")
|
147
169
|
|
148
|
-
return AnalyzedTypeInfo(kind=kind, vector_info=vector_info,
|
149
|
-
|
170
|
+
return AnalyzedTypeInfo(kind=kind, vector_info=vector_info,
|
171
|
+
elem_type=elem_type, key_type=key_type, dataclass_type=dataclass_type,
|
172
|
+
attrs=attrs, nullable=nullable)
|
150
173
|
|
151
|
-
def _encode_fields_schema(dataclass_type: type) -> list[dict[str, Any]]:
|
174
|
+
def _encode_fields_schema(dataclass_type: type, key_type: type | None = None) -> list[dict[str, Any]]:
|
152
175
|
result = []
|
153
|
-
|
176
|
+
def add_field(name: str, t) -> None:
|
154
177
|
try:
|
155
|
-
type_info = encode_enriched_type_info(analyze_type_info(
|
178
|
+
type_info = encode_enriched_type_info(analyze_type_info(t))
|
156
179
|
except ValueError as e:
|
157
180
|
e.add_note(f"Failed to encode annotation for field - "
|
158
|
-
f"{dataclass_type.__name__}.{
|
181
|
+
f"{dataclass_type.__name__}.{name}: {t}")
|
159
182
|
raise
|
160
|
-
type_info['name'] =
|
183
|
+
type_info['name'] = name
|
161
184
|
result.append(type_info)
|
185
|
+
|
186
|
+
if key_type is not None:
|
187
|
+
add_field(KEY_FIELD_NAME, key_type)
|
188
|
+
for field in dataclasses.fields(dataclass_type):
|
189
|
+
add_field(field.name, field.type)
|
162
190
|
return result
|
163
191
|
|
164
192
|
def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
|
@@ -167,7 +195,7 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
|
|
167
195
|
if type_info.kind == 'Struct':
|
168
196
|
if type_info.dataclass_type is None:
|
169
197
|
raise ValueError("Struct type must have a dataclass type")
|
170
|
-
encoded_type['fields'] = _encode_fields_schema(type_info.dataclass_type)
|
198
|
+
encoded_type['fields'] = _encode_fields_schema(type_info.dataclass_type, type_info.key_type)
|
171
199
|
if doc := inspect.getdoc(type_info.dataclass_type):
|
172
200
|
encoded_type['description'] = doc
|
173
201
|
|
@@ -179,7 +207,7 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
|
|
179
207
|
encoded_type['element_type'] = _encode_type(analyze_type_info(type_info.elem_type))
|
180
208
|
encoded_type['dimension'] = type_info.vector_info.dim
|
181
209
|
|
182
|
-
elif type_info.kind in
|
210
|
+
elif type_info.kind in TABLE_TYPES:
|
183
211
|
if type_info.elem_type is None:
|
184
212
|
raise ValueError(f"{type_info.kind} type must have an element type")
|
185
213
|
row_type_info = analyze_type_info(type_info.elem_type)
|
@@ -0,0 +1,24 @@
|
|
1
|
+
cocoindex-0.1.27.dist-info/METADATA,sha256=upAIi8VNoLnqN1xUoKPpw1Hqspc3Rpb4cxKB3ahbFTs,8079
|
2
|
+
cocoindex-0.1.27.dist-info/WHEEL,sha256=e_pkmRfvNwfhA3ReKBun4dlOSPa_df-kVPas-U1KpNY,104
|
3
|
+
cocoindex-0.1.27.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
4
|
+
cocoindex/functions.py,sha256=clnpHCYSsjUnc8Spbc1-5sQedG-60fmibodv9LpHgqo,1647
|
5
|
+
cocoindex/query.py,sha256=8_3Lb_EVjZtl2ZyJNZGX16LoKXEd-PL8OjY-zs9GQeA,3205
|
6
|
+
cocoindex/index.py,sha256=LssEOuZi6AqhwKtZM3QFeQpa9T-0ELi8G5DsrYKECvc,534
|
7
|
+
cocoindex/lib.py,sha256=c6D7NuuTJj20WgVhnp0QGyK18lKMUvoDCiFr3PFs71s,3871
|
8
|
+
cocoindex/auth_registry.py,sha256=lZ2rD5_9aC_UpGk7t4TmSYal_rjN7eHgO4_sU7FR0Zw,620
|
9
|
+
cocoindex/convert.py,sha256=mBUTa_Ag39_ut-yE_jc1wqS3zLjtOm6QKet-bqJ-RWc,5947
|
10
|
+
cocoindex/tests/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
11
|
+
cocoindex/tests/test_convert.py,sha256=WPRKp0jv_uSEM81RGWEAmsax-J-FtXt90mZ0yEnvGLs,11236
|
12
|
+
cocoindex/__init__.py,sha256=8atBT1HjclUOeiXd7TSzZWaqOR4x_qr5epvCKB7Z7oY,661
|
13
|
+
cocoindex/flow.py,sha256=zuwvByhQxg0fMYshWCcq0YYe6X2TvAJJie4xetnNIPE,21054
|
14
|
+
cocoindex/llm.py,sha256=_3rtahuKcqcEHPkFSwhXOSrekZyGxVApPoYtlU_chcA,348
|
15
|
+
cocoindex/runtime.py,sha256=jqRnWkkIlAhE04gi4y0Y5bzuq9FX4j0aVNU-nengLJk,980
|
16
|
+
cocoindex/op.py,sha256=ICCKZw6peCFu-CtMeIEaz6vlBxrf5dZwgUs9R4ALYNU,10604
|
17
|
+
cocoindex/sources.py,sha256=wZFU8lwSXjyofJR-syySH9fTyPnBlAPJ6-1hQNX8fGA,936
|
18
|
+
cocoindex/setup.py,sha256=W1HshwYk_K2aeLOVn_e62ZOXBO9yWsoUboRiH4SjF48,496
|
19
|
+
cocoindex/cli.py,sha256=MvEUbQVrJy-sYbGQNsqIaMJvcQXQn1OQVNues22Hph0,7061
|
20
|
+
cocoindex/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
21
|
+
cocoindex/typing.py,sha256=p3FEUnoQc6zmiG8YwO4T155sgZtyc_1AufiJe3bNol8,8458
|
22
|
+
cocoindex/storages.py,sha256=bWsV4rRs4PJp6GemOb_PwoRu5XRgOsG-4sRMziztliU,1957
|
23
|
+
cocoindex/_engine.cpython-312-darwin.so,sha256=JBHj9qaowOY_yuYszvJ7-DNf7zK88xRFRcz2Dt4jRds,59514816
|
24
|
+
cocoindex-0.1.27.dist-info/RECORD,,
|
@@ -1,24 +0,0 @@
|
|
1
|
-
cocoindex-0.1.26.dist-info/METADATA,sha256=BwpB45xAUpYFZugav3l9u1uJ_0s4Kdrnp-Y2FnSJZ0A,8079
|
2
|
-
cocoindex-0.1.26.dist-info/WHEEL,sha256=e_pkmRfvNwfhA3ReKBun4dlOSPa_df-kVPas-U1KpNY,104
|
3
|
-
cocoindex-0.1.26.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
4
|
-
cocoindex/functions.py,sha256=xcAeRQTy9JObfxpjyMn-dPY2y7XhVWjB7759xVyup6o,1657
|
5
|
-
cocoindex/query.py,sha256=XsVY5cBJJ3a70qazkcCHjWZLE1zBqzMQ4HVSulicGMA,3273
|
6
|
-
cocoindex/index.py,sha256=LssEOuZi6AqhwKtZM3QFeQpa9T-0ELi8G5DsrYKECvc,534
|
7
|
-
cocoindex/lib.py,sha256=48nfWSg5IMzTSkVxdrWF8d9Hi-Bw8in_2rs7rQRwAs8,3505
|
8
|
-
cocoindex/auth_registry.py,sha256=lZ2rD5_9aC_UpGk7t4TmSYal_rjN7eHgO4_sU7FR0Zw,620
|
9
|
-
cocoindex/convert.py,sha256=tzlHadc-SaZCRBWxZEp08T4clJQPab_eXt6mUub0iQQ,5017
|
10
|
-
cocoindex/tests/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
11
|
-
cocoindex/tests/test_convert.py,sha256=mT_7HhXRu1PzGFbdmPmrq3O5-bPyDDG7mbr-mHaa2i0,9339
|
12
|
-
cocoindex/__init__.py,sha256=f4LTPg4db7Wm3QO9HirvhsT11OVykiFxGbt1JK6taFA,572
|
13
|
-
cocoindex/flow.py,sha256=1Mx-rYBzPIlFLNsiNVGhPtJKy2u6stZT1xjyPzERbFI,21068
|
14
|
-
cocoindex/llm.py,sha256=4b20wpSHcgfDM7tdxRm1KIo_7C30nT7h0gCsWvs686I,320
|
15
|
-
cocoindex/runtime.py,sha256=jqRnWkkIlAhE04gi4y0Y5bzuq9FX4j0aVNU-nengLJk,980
|
16
|
-
cocoindex/op.py,sha256=zOQzgnVDvET8LtUt7TW8PfHMhaA0eAXoZ3c_EsL6jtU,10602
|
17
|
-
cocoindex/sources.py,sha256=wZFU8lwSXjyofJR-syySH9fTyPnBlAPJ6-1hQNX8fGA,936
|
18
|
-
cocoindex/setup.py,sha256=W1HshwYk_K2aeLOVn_e62ZOXBO9yWsoUboRiH4SjF48,496
|
19
|
-
cocoindex/cli.py,sha256=MvEUbQVrJy-sYbGQNsqIaMJvcQXQn1OQVNues22Hph0,7061
|
20
|
-
cocoindex/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
21
|
-
cocoindex/typing.py,sha256=4mP9VXS75s3VMfF1LDc1LXsBng7uGdR5aD73N8iaeSM,7282
|
22
|
-
cocoindex/storages.py,sha256=GRHkmwuSAU7neF3H0pjAPfeEkmtXv-DJc7CKYjcATvE,1946
|
23
|
-
cocoindex/_engine.cpython-312-darwin.so,sha256=nVnQv2Ok4mUXYFOvKEdd-YlWdzF-YGwDU-DACxkjASQ,59217776
|
24
|
-
cocoindex-0.1.26.dist-info/RECORD,,
|
File without changes
|
File without changes
|