cocoindex 0.1.25-cp311-cp311-win_amd64.whl → 0.1.27-cp311-cp311-win_amd64.whl
This diff compares the contents of two publicly released package versions, as they appear in their respective public registries. It is provided for informational purposes only.
- cocoindex/__init__.py +2 -1
- cocoindex/_engine.cp311-win_amd64.pyd +0 -0
- cocoindex/cli.py +9 -6
- cocoindex/convert.py +42 -21
- cocoindex/flow.py +46 -15
- cocoindex/functions.py +1 -1
- cocoindex/lib.py +17 -8
- cocoindex/llm.py +2 -0
- cocoindex/op.py +7 -6
- cocoindex/query.py +5 -4
- cocoindex/runtime.py +12 -4
- cocoindex/storages.py +10 -3
- cocoindex/tests/test_convert.py +245 -34
- cocoindex/typing.py +75 -47
- {cocoindex-0.1.25.dist-info → cocoindex-0.1.27.dist-info}/METADATA +2 -2
- cocoindex-0.1.27.dist-info/RECORD +24 -0
- cocoindex-0.1.25.dist-info/RECORD +0 -24
- {cocoindex-0.1.25.dist-info → cocoindex-0.1.27.dist-info}/WHEEL +0 -0
- {cocoindex-0.1.25.dist-info → cocoindex-0.1.27.dist-info}/licenses/LICENSE +0 -0
cocoindex/__init__.py
CHANGED
@@ -9,4 +9,5 @@ from .llm import LlmSpec, LlmApiType
 from .index import VectorSimilarityMetric, VectorIndexDef, IndexOptions
 from .auth_registry import AuthEntryReference, add_auth_entry, ref_auth_entry
 from .lib import *
-from ._engine import OpArgSchema
+from ._engine import OpArgSchema
+from .typing import Float32, Float64, LocalDateTime, OffsetDateTime, Range, Vector, Json

cocoindex/_engine.cp311-win_amd64.pyd
CHANGED
Binary file
cocoindex/cli.py
CHANGED
@@ -4,6 +4,7 @@ import datetime

 from . import flow, lib
 from .setup import sync_setup, drop_setup, flow_names_with_setup, apply_setup_changes
+from .runtime import execution_context

 @click.group()
 def cli():
@@ -113,11 +114,13 @@ def update(flow_name: str | None, live: bool, quiet: bool):
     Update the index to reflect the latest data from data sources.
     """
     options = flow.FlowLiveUpdaterOptions(live_mode=live, print_stats=not quiet)
-
-
-
-
-
+    async def _update():
+        if flow_name is None:
+            await flow.update_all_flows(options)
+        else:
+            updater = await flow.FlowLiveUpdater.create(_flow_by_name(flow_name), options)
+            await updater.wait()
+    execution_context.run(_update())

 @cli.command()
 @click.argument("flow_name", type=str, required=False)
@@ -167,7 +170,7 @@ def server(address: str, live_update: bool, quiet: bool, cors_origin: str | None
     lib.start_server(lib.ServerSettings(address=address, cors_origin=cors_origin))
     if live_update:
         options = flow.FlowLiveUpdaterOptions(live_mode=True, print_stats=not quiet)
-
+        execution_context.run(flow.update_all_flows(options))
     input("Press Enter to stop...")


cocoindex/convert.py
CHANGED
@@ -8,25 +8,28 @@ import uuid

 from enum import Enum
 from typing import Any, Callable, get_origin
-from .typing import analyze_type_info, encode_enriched_type, …
+from .typing import analyze_type_info, encode_enriched_type, TABLE_TYPES, KEY_FIELD_NAME

-
-
+
+def encode_engine_value(value: Any) -> Any:
+    """Encode a Python value to an engine value."""
     if dataclasses.is_dataclass(value):
-        return [ …
+        return [encode_engine_value(getattr(value, f.name)) for f in dataclasses.fields(value)]
     if isinstance(value, (list, tuple)):
-        return [ …
+        return [encode_engine_value(v) for v in value]
+    if isinstance(value, dict):
+        return [[encode_engine_value(k)] + encode_engine_value(v) for k, v in value.items()]
     if isinstance(value, uuid.UUID):
         return value.bytes
     return value

-def …
+def make_engine_value_decoder(
     field_path: list[str],
     src_type: dict[str, Any],
     dst_annotation,
 ) -> Callable[[Any], Any]:
     """
-    Make a …
+    Make a decoder from an engine value to a Python value.

     Args:
         field_path: The path to the field in the engine value. For error messages.
@@ -34,13 +37,13 @@ def make_engine_value_converter(
         dst_annotation: The type annotation of the Python value.

     Returns:
-        A …
+        A decoder from an engine value to a Python value.
     """

     src_type_kind = src_type['kind']

     if dst_annotation is inspect.Parameter.empty:
-        if src_type_kind == 'Struct' or src_type_kind in …
+        if src_type_kind == 'Struct' or src_type_kind in TABLE_TYPES:
             raise ValueError(f"Missing type annotation for `{''.join(field_path)}`."
                              f"It's required for {src_type_kind} type.")
         return lambda value: value
@@ -53,41 +56,59 @@ def make_engine_value_converter(
                          f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})")

     if dst_type_info.dataclass_type is not None:
-        return …
+        return _make_engine_struct_value_decoder(
             field_path, src_type['fields'], dst_type_info.dataclass_type)

-    if src_type_kind in …
+    if src_type_kind in TABLE_TYPES:
         field_path.append('[*]')
         elem_type_info = analyze_type_info(dst_type_info.elem_type)
         if elem_type_info.dataclass_type is None:
             raise ValueError(f"Type mismatch for `{''.join(field_path)}`: "
-
-
-
+                             f"declared `{dst_type_info.kind}`, a dataclass type expected")
+        engine_fields_schema = src_type['row']['fields']
+        if elem_type_info.key_type is not None:
+            key_field_schema = engine_fields_schema[0]
+            field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}")
+            key_decoder = make_engine_value_decoder(
+                field_path, key_field_schema['type'], elem_type_info.key_type)
+            field_path.pop()
+            value_decoder = _make_engine_struct_value_decoder(
+                field_path, engine_fields_schema[1:], elem_type_info.dataclass_type)
+            def decode(value):
+                if value is None:
+                    return None
+                return {key_decoder(v[0]): value_decoder(v[1:]) for v in value}
+        else:
+            elem_decoder = _make_engine_struct_value_decoder(
+                field_path, engine_fields_schema, elem_type_info.dataclass_type)
+            def decode(value):
+                if value is None:
+                    return None
+                return [elem_decoder(v) for v in value]
         field_path.pop()
-        return …
+        return decode

     if src_type_kind == 'Uuid':
         return lambda value: uuid.UUID(bytes=value)

     return lambda value: value

-def …
+def _make_engine_struct_value_decoder(
     field_path: list[str],
     src_fields: list[dict[str, Any]],
     dst_dataclass_type: type,
 ) -> Callable[[list], Any]:
-    """Make a …
+    """Make a decoder from an engine field values to a Python value."""

     src_name_to_idx = {f['name']: i for i, f in enumerate(src_fields)}
     def make_closure_for_value(name: str, param: inspect.Parameter) -> Callable[[list], Any]:
         src_idx = src_name_to_idx.get(name)
         if src_idx is not None:
             field_path.append(f'.{name}')
-
+            field_decoder = make_engine_value_decoder(
                 field_path, src_fields[src_idx]['type'], param.annotation)
             field_path.pop()
-            return lambda values: …
+            return lambda values: field_decoder(values[src_idx])

         default_value = param.default
         if default_value is inspect.Parameter.empty:
@@ -96,12 +117,12 @@ def _make_engine_struct_value_converter(

         return lambda _: default_value

-
+    field_value_decoder = [
         make_closure_for_value(name, param)
         for (name, param) in inspect.signature(dst_dataclass_type).parameters.items()]

     return lambda values: dst_dataclass_type(
-        *( …
+        *(decoder(values) for decoder in field_value_decoder))

 def dump_engine_object(v: Any) -> Any:
     """Recursively dump an object for engine. Engine side uses `Pythonized` to catch."""
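The renamed helpers pair up: encode_engine_value turns dataclasses, lists/tuples, and now dicts into the engine's positional representation, while make_engine_value_decoder builds the reverse mapping from an engine type descriptor. A minimal round-trip sketch, mirroring the build_engine_value_decoder helper used in the updated tests below (the Order dataclass here is only an illustration):

from dataclasses import dataclass

from cocoindex.convert import encode_engine_value, make_engine_value_decoder
from cocoindex.typing import encode_enriched_type

@dataclass
class Order:
    order_id: str
    name: str
    price: float

# dict[...] is now encoded as a KTable: one row per entry, key first.
value = {"K1": Order("O1", "item1", 10.0)}
encoded = encode_engine_value(value)          # [["K1", "O1", "item1", 10.0]]

engine_type = encode_enriched_type(dict[str, Order])["type"]
decoder = make_engine_value_decoder([], engine_type, dict[str, Order])
assert decoder(encoded) == value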
cocoindex/flow.py
CHANGED
@@ -19,7 +19,7 @@ from . import index
 from . import op
 from .convert import dump_engine_object
 from .typing import encode_enriched_type
-from .runtime import …
+from .runtime import execution_context

 class _NameBuilder:
     _existing_names: set[str]
@@ -142,9 +142,9 @@ class DataSlice:

     def row(self) -> DataScope:
         """
-        Return a scope representing each …
+        Return a scope representing each row of the table.
         """
-        row_scope = self._state.engine_data_slice. …
+        row_scope = self._state.engine_data_slice.table_row_scope()
         return DataScope(self._state.flow_builder_state, row_scope)

     def for_each(self, f: Callable[[DataScope], None]) -> None:
@@ -355,6 +355,12 @@ class FlowBuilder:
             name
         )

+    def declare(self, spec: op.DeclarationSpec):
+        """
+        Add a declaration to the flow.
+        """
+        self._state.engine_flow_builder.declare(dump_engine_object(spec))
+
 @dataclass
 class FlowLiveUpdaterOptions:
     """
@@ -369,16 +375,29 @@ class FlowLiveUpdater:
     """
     _engine_live_updater: _engine.FlowLiveUpdater

-    def __init__(self, …
-
-
+    def __init__(self, arg: Flow | _engine.FlowLiveUpdater, options: FlowLiveUpdaterOptions | None = None):
+        if isinstance(arg, _engine.FlowLiveUpdater):
+            self._engine_live_updater = arg
+        else:
+            self._engine_live_updater = execution_context.run(_engine.FlowLiveUpdater(
+                arg.internal_flow(), dump_engine_object(options or FlowLiveUpdaterOptions())))
+
+    @staticmethod
+    async def create(fl: Flow, options: FlowLiveUpdaterOptions | None = None) -> FlowLiveUpdater:
+        """
+        Create a live updater for a flow.
+        """
+        engine_live_updater = await _engine.FlowLiveUpdater.create(
+            await fl.ainternal_flow(),
+            dump_engine_object(options or FlowLiveUpdaterOptions()))
+        return FlowLiveUpdater(engine_live_updater)

     def __enter__(self) -> FlowLiveUpdater:
         return self

     def __exit__(self, exc_type, exc_value, traceback):
         self.abort()
-
+        execution_context.run(self.wait())

     async def __aenter__(self) -> FlowLiveUpdater:
         return self
@@ -450,7 +469,7 @@ class Flow:
         Update the index defined by the flow.
         Once the function returns, the indice is fresh up to the moment when the function is called.
         """
-        updater = FlowLiveUpdater(self, FlowLiveUpdaterOptions(live_mode=False))
+        updater = await FlowLiveUpdater.create(self, FlowLiveUpdaterOptions(live_mode=False))
         await updater.wait()
         return updater.update_stats()

@@ -466,6 +485,12 @@ class Flow:
         """
         return self._lazy_engine_flow()

+    async def ainternal_flow(self) -> _engine.Flow:
+        """
+        Get the engine flow. The async version.
+        """
+        return await asyncio.to_thread(self.internal_flow)
+
 def _create_lazy_flow(name: str | None, fl_def: Callable[[FlowBuilder, DataScope], None]) -> Flow:
     """
     Create a flow without really building it yet.
@@ -476,7 +501,7 @@ def _create_lazy_flow(name: str | None, fl_def: Callable[[FlowBuilder, DataScope
         root_scope = DataScope(
             flow_builder_state, flow_builder_state.engine_flow_builder.root_scope())
         fl_def(FlowBuilder(flow_builder_state), root_scope)
-        return flow_builder_state.engine_flow_builder.build_flow( …
+        return flow_builder_state.engine_flow_builder.build_flow(execution_context.event_loop)

     return Flow(_create_engine_flow)

@@ -523,17 +548,23 @@ def ensure_all_flows_built() -> None:
     """
     Ensure all flows are built.
     """
-
-
-
+    for fl in flows():
+        fl.internal_flow()
+
+async def aensure_all_flows_built() -> None:
+    """
+    Ensure all flows are built.
+    """
+    for fl in flows():
+        await fl.ainternal_flow()

 async def update_all_flows(options: FlowLiveUpdaterOptions) -> dict[str, _engine.IndexUpdateInfo]:
     """
     Update all flows.
     """
-
+    await aensure_all_flows_built()
     async def _update_flow(fl: Flow) -> _engine.IndexUpdateInfo:
-        updater = FlowLiveUpdater(fl, options)
+        updater = await FlowLiveUpdater.create(fl, options)
         await updater.wait()
         return updater.update_stats()
     fls = flows()
@@ -572,7 +603,7 @@ class TransientFlow:
         flow_builder_state.engine_flow_builder.set_direct_output(
             _data_slice_state(output).engine_data_slice)
         self._engine_flow = flow_builder_state.engine_flow_builder.build_transient_flow(
-
+            execution_context.event_loop)

     def __str__(self):
         return str(self._engine_flow)
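FlowLiveUpdater now has two entry points: the async FlowLiveUpdater.create() factory for code already running on an event loop, and the plain constructor, which drives the engine through execution_context for synchronous callers. A usage sketch modeled on the new cli.py code path (the flow object fl is assumed to be defined elsewhere):

from cocoindex import flow
from cocoindex.runtime import execution_context

async def update_one(fl: flow.Flow):
    # Async path: build the updater without blocking the caller's event loop.
    updater = await flow.FlowLiveUpdater.create(fl, flow.FlowLiveUpdaterOptions(live_mode=False))
    await updater.wait()
    return updater.update_stats()

def update_everything(live: bool = False, quiet: bool = False):
    # Sync path: the CLI wraps the async API with execution_context.run().
    options = flow.FlowLiveUpdaterOptions(live_mode=live, print_stats=not quiet)
    return execution_context.run(flow.update_all_flows(options))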
cocoindex/functions.py
CHANGED
@@ -41,7 +41,7 @@ class SentenceTransformerEmbedExecutor:
         args = self.spec.args or {}
         self._model = sentence_transformers.SentenceTransformer(self.spec.model, **args)
         dim = self._model.get_sentence_embedding_dimension()
-        return Annotated[ …
+        return Annotated[Vector[Float32, dim], TypeAttr("cocoindex.io/vector_origin_text", text.analyzed_value)]

     def __call__(self, text: str) -> list[Float32]:
         return self._model.encode(text).tolist()
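The executor's analyze step now expresses its result type with the Vector alias introduced in typing.py (see below), so the engine learns the embedding dimension. At runtime the alias expands to an Annotated list carrying VectorInfo; a hedged sketch with an assumed dimension of 384 and a placeholder attribute value:

from typing import Annotated
from cocoindex.typing import Float32, TypeAttr, Vector, VectorInfo

dim = 384  # hypothetical embedding dimension reported by the model
result_type = Annotated[Vector[Float32, dim],
                        TypeAttr("cocoindex.io/vector_origin_text", "<analyzed input slice>")]
# Vector[Float32, 384] resolves to Annotated[list[Float32], VectorInfo(dim=384)] at runtime.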
cocoindex/lib.py
CHANGED
@@ -1,7 +1,6 @@
 """
 Library level functions and states.
 """
-import asyncio
 import os
 import sys
 import functools
@@ -12,6 +11,7 @@ from dataclasses import dataclass

 from . import _engine
 from . import flow, query, cli
+from .convert import dump_engine_object


 def _load_field(target: dict[str, str], name: str, env_name: str, required: bool = False):
@@ -22,24 +22,32 @@ def _load_field(target: dict[str, str], name: str, env_name: str, required: bool
     else:
         target[name] = value

+@dataclass
+class DatabaseConnectionSpec:
+    url: str
+    user: str | None = None
+    password: str | None = None
+
 @dataclass
 class Settings:
     """Settings for the cocoindex library."""
-
+    database: DatabaseConnectionSpec

     @classmethod
     def from_env(cls) -> Self:
         """Load settings from environment variables."""

-
-        _load_field( …
-
-
+        db_kwargs: dict[str, str] = dict()
+        _load_field(db_kwargs, "url", "COCOINDEX_DATABASE_URL", required=True)
+        _load_field(db_kwargs, "user", "COCOINDEX_DATABASE_USER")
+        _load_field(db_kwargs, "password", "COCOINDEX_DATABASE_PASSWORD")
+        database = DatabaseConnectionSpec(**db_kwargs)
+        return cls(database=database)


 def init(settings: Settings):
     """Initialize the cocoindex library."""
-    _engine.init(settings …
+    _engine.init(dump_engine_object(settings))

 @dataclass
 class ServerSettings:
@@ -101,7 +109,8 @@ def main_fn(
     try:
         if _should_run_cli():
             # Schedule to a separate thread as it invokes nested event loop.
-            return await asyncio.to_thread(_run_cli)
+            # return await asyncio.to_thread(_run_cli)
+            return _run_cli()
         return await fn(*args, **kwargs)
     finally:
         stop()
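Settings now nests the database connection in a DatabaseConnectionSpec instead of flat fields, and from_env() reads the three COCOINDEX_DATABASE_* variables (only the URL is required). A hedged sketch of initializing the library with the new shape; the connection URL below is a placeholder:

import os
from cocoindex.lib import DatabaseConnectionSpec, Settings, init

os.environ["COCOINDEX_DATABASE_URL"] = "postgresql://localhost:5432/cocoindex"  # placeholder
settings = Settings.from_env()   # Settings(database=DatabaseConnectionSpec(url=..., user=None, password=None))
init(settings)                   # handed to the engine via dump_engine_object()

# Equivalent, constructed explicitly:
init(Settings(database=DatabaseConnectionSpec(url="postgresql://localhost:5432/cocoindex")))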
cocoindex/llm.py
CHANGED
cocoindex/op.py
CHANGED
@@ -7,10 +7,9 @@ import inspect

 from typing import get_type_hints, Protocol, Any, Callable, Awaitable, dataclass_transform
 from enum import Enum
-from functools import partial

 from .typing import encode_enriched_type
-from .convert import …
+from .convert import encode_engine_value, make_engine_value_decoder
 from . import _engine

 class OpCategory(Enum):
@@ -18,7 +17,7 @@ class OpCategory(Enum):
     FUNCTION = "function"
     SOURCE = "source"
     STORAGE = "storage"
-
+    DECLARATION = "declaration"
 @dataclass_transform()
 class SpecMeta(type):
     """Meta class for spec classes."""
@@ -41,6 +40,8 @@ class FunctionSpec(metaclass=SpecMeta, category=OpCategory.FUNCTION): # pylint:
 class StorageSpec(metaclass=SpecMeta, category=OpCategory.STORAGE): # pylint: disable=too-few-public-methods
     """A storage spec. All its subclass can be instantiated similar to a dataclass, i.e. ClassName(field1=value1, field2=value2, ...)"""

+class DeclarationSpec(metaclass=SpecMeta, category=OpCategory.DECLARATION): # pylint: disable=too-few-public-methods
+    """A declaration spec. All its subclass can be instantiated similar to a dataclass, i.e. ClassName(field1=value1, field2=value2, ...)"""
 class Executor(Protocol):
     """An executor for an operation."""
     op_category: OpCategory
@@ -128,7 +129,7 @@ def _register_op_factory(
             raise ValueError(
                 f"Too many positional arguments passed in: {len(args)} > {next_param_idx}")
         self._args_converters.append(
-
+            make_engine_value_decoder(
                 [arg_name], arg.value_type['type'], arg_param.annotation))
         if arg_param.kind != inspect.Parameter.VAR_POSITIONAL:
             next_param_idx += 1
@@ -145,7 +146,7 @@ def _register_op_factory(
         if expected_arg is None:
             raise ValueError(f"Unexpected keyword argument passed in: {kwarg_name}")
         arg_param = expected_arg[1]
-        self._kwargs_converters[kwarg_name] = …
+        self._kwargs_converters[kwarg_name] = make_engine_value_decoder(
             [kwarg_name], kwarg.value_type['type'], arg_param.annotation)

     missing_args = [name for (name, arg) in expected_kwargs
@@ -187,7 +188,7 @@ def _register_op_factory(
             output = await self._acall(*converted_args, **converted_kwargs)
         else:
             output = await self._acall(*converted_args, **converted_kwargs)
-        return …
+        return encode_engine_value(output)

     _WrappedClass.__name__ = executor_cls.__name__
     _WrappedClass.__doc__ = executor_cls.__doc__
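DECLARATION joins the existing spec categories, and DeclarationSpec gives flows a way to describe backend-level objects that are not tied to a single export target. A subclass just sets a kind and declares fields, mirroring Neo4jDeclarations in storages.py below; the class in this sketch is a hypothetical illustration, not part of the package:

from cocoindex import op
from cocoindex.auth_registry import AuthEntryReference

class ExampleGraphDeclarations(op.DeclarationSpec):  # hypothetical spec, for illustration only
    """Declarations consumed by an assumed graph storage backend."""
    kind = "ExampleGraph"
    connection: AuthEntryReference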
cocoindex/query.py
CHANGED
@@ -66,15 +66,16 @@ class SimpleSemanticsQueryHandler:
         return self._lazy_query_handler()

     def search(self, query: str, limit: int, vector_field_name: str | None = None,
-
+               similarity_metric: index.VectorSimilarityMetric | None = None
+               ) -> tuple[list[QueryResult], SimpleSemanticsQueryInfo]:
         """
         Search the index with the given query, limit, vector field name, and similarity metric.
         """
         internal_results, internal_info = self.internal_handler().search(
             query, limit, vector_field_name,
-
-
-
+            similarity_metric.value if similarity_metric is not None else None)
+        results = [QueryResult(data=result['data'], score=result['score'])
+                   for result in internal_results]
         info = SimpleSemanticsQueryInfo(
             similarity_metric=index.VectorSimilarityMetric(internal_info['similarity_metric']),
             query_vector=internal_info['query_vector'],
cocoindex/runtime.py
CHANGED
@@ -1,7 +1,12 @@
+"""
+This module provides a standalone execution runtime for executing coroutines in a thread-safe
+manner.
+"""
+
 import threading
 import asyncio
-
-class …
+from typing import Coroutine
+class _ExecutionContext:
     _lock: threading.Lock
     _event_loop: asyncio.AbstractEventLoop | None = None

@@ -14,8 +19,11 @@ class _OpExecutionContext:
         with self._lock:
             if self._event_loop is None:
                 self._event_loop = asyncio.new_event_loop()
-                asyncio.set_event_loop(self._event_loop)
                 threading.Thread(target=self._event_loop.run_forever, daemon=True).start()
             return self._event_loop

-
+    def run(self, coro: Coroutine):
+        """Run a coroutine in the event loop, blocking until it finishes. Return its result."""
+        return asyncio.run_coroutine_threadsafe(coro, self.event_loop).result()
+
+execution_context = _ExecutionContext()
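execution_context is the piece that lets synchronous code (the CLI, FlowLiveUpdater's constructor and __exit__) block on engine coroutines: the first access to event_loop lazily starts a daemon thread running a dedicated loop, and run() submits a coroutine to it and waits for the result. A small self-contained sketch:

import asyncio
from cocoindex.runtime import execution_context

async def add(a: int, b: int) -> int:
    await asyncio.sleep(0)        # stand-in for real async engine work
    return a + b

result = execution_context.run(add(1, 2))   # blocks the calling thread until done
assert result == 3
loop = execution_context.event_loop          # the lazily created background loop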
cocoindex/storages.py
CHANGED
@@ -9,7 +9,7 @@ from .auth_registry import AuthEntryReference
 class Postgres(op.StorageSpec):
     """Storage powered by Postgres and pgvector."""

-
+    database: AuthEntryReference | None = None
     table_name: str | None = None

 @dataclass
@@ -43,8 +43,9 @@ class NodeReferenceMapping:
     fields: list[TargetFieldMapping]

 @dataclass
-class …
+class ReferencedNode:
     """Storage spec for a graph node."""
+    label: str
     primary_key_fields: Sequence[str]
     vector_indexes: Sequence[index.VectorIndexDef] = ()

@@ -63,10 +64,16 @@ class RelationshipMapping:
     rel_type: str
     source: NodeReferenceMapping
     target: NodeReferenceMapping
-    nodes_storage_spec: dict[str, NodeStorageSpec] | None = None

 class Neo4j(op.StorageSpec):
     """Graph storage powered by Neo4j."""

     connection: AuthEntryReference
     mapping: NodeMapping | RelationshipMapping
+
+class Neo4jDeclarations(op.DeclarationSpec):
+    """Declarations for Neo4j."""
+
+    kind = "Neo4j"
+    connection: AuthEntryReference
+    referenced_nodes: Sequence[ReferencedNode] = ()
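Neo4jDeclarations is registered on a flow through the new FlowBuilder.declare() added in flow.py, with ReferencedNode describing nodes that relationships point at but that the flow does not itself export. A sketch of the wiring; the label and field names are placeholders, and connection is an AuthEntryReference obtained elsewhere (e.g. via add_auth_entry):

from cocoindex.storages import Neo4jDeclarations, ReferencedNode

def declare_referenced_nodes(flow_builder, connection):
    # flow_builder is the cocoindex.flow.FlowBuilder passed to the flow definition.
    flow_builder.declare(Neo4jDeclarations(
        connection=connection,
        referenced_nodes=[
            ReferencedNode(label="Document", primary_key_fields=["filename"]),
        ],
    ))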
cocoindex/tests/test_convert.py
CHANGED
@@ -1,15 +1,21 @@
-import dataclasses
 import uuid
 import datetime
-from dataclasses import dataclass
+from dataclasses import dataclass, make_dataclass
 import pytest
-
-
+import cocoindex
+from cocoindex.typing import encode_enriched_type
+from cocoindex.convert import encode_engine_value, make_engine_value_decoder
+from typing import Literal
 @dataclass
 class Order:
     order_id: str
     name: str
     price: float
+    extra_field: str = "default_extra"
+
+@dataclass
+class Tag:
+    name: str

 @dataclass
 class Basket:
@@ -19,51 +25,256 @@ class Basket:
 class Customer:
     name: str
     order: Order
+    tags: list[Tag] | None = None
+
+@dataclass
+class NestedStruct:
+    customer: Customer
+    orders: list[Order]
+    count: int = 0
+
+def build_engine_value_decoder(engine_type_in_py, python_type=None):
+    """
+    Helper to build a converter for the given engine-side type (as represented in Python).
+    If python_type is not specified, uses engine_type_in_py as the target.
+    """
+    engine_type = encode_enriched_type(engine_type_in_py)["type"]
+    return make_engine_value_decoder([], engine_type, python_type or engine_type_in_py)

-def …
-    assert …
-    assert …
-    assert …
-    assert …
+def test_encode_engine_value_basic_types():
+    assert encode_engine_value(123) == 123
+    assert encode_engine_value(3.14) == 3.14
+    assert encode_engine_value("hello") == "hello"
+    assert encode_engine_value(True) is True

-def …
+def test_encode_engine_value_uuid():
     u = uuid.uuid4()
-    assert …
+    assert encode_engine_value(u) == u.bytes

-def …
+def test_encode_engine_value_date_time_types():
     d = datetime.date(2024, 1, 1)
-    assert …
+    assert encode_engine_value(d) == d
     t = datetime.time(12, 30)
-    assert …
+    assert encode_engine_value(t) == t
     dt = datetime.datetime(2024, 1, 1, 12, 30)
-    assert …
+    assert encode_engine_value(dt) == dt

-def …
+def test_encode_engine_value_struct():
     order = Order(order_id="O123", name="mixed nuts", price=25.0)
-    assert …
+    assert encode_engine_value(order) == ["O123", "mixed nuts", 25.0, "default_extra"]

-def …
+def test_encode_engine_value_list_of_structs():
     orders = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]
-    assert …
+    assert encode_engine_value(orders) == [["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]]

-def …
+def test_encode_engine_value_struct_with_list():
     basket = Basket(items=["apple", "banana"])
-    assert …
+    assert encode_engine_value(basket) == [["apple", "banana"]]

-def …
+def test_encode_engine_value_nested_struct():
     customer = Customer(name="Alice", order=Order("O1", "item1", 10.0))
-    assert …
+    assert encode_engine_value(customer) == ["Alice", ["O1", "item1", 10.0, "default_extra"], None]
+
+def test_encode_engine_value_empty_list():
+    assert encode_engine_value([]) == []
+    assert encode_engine_value([[]]) == [[]]
+
+def test_encode_engine_value_tuple():
+    assert encode_engine_value(()) == []
+    assert encode_engine_value((1, 2, 3)) == [1, 2, 3]
+    assert encode_engine_value(((1, 2), (3, 4))) == [[1, 2], [3, 4]]
+    assert encode_engine_value(([],)) == [[]]
+    assert encode_engine_value(((),)) == [[]]
+
+def test_encode_engine_value_none():
+    assert encode_engine_value(None) is None
+
+def test_make_engine_value_decoder_basic_types():
+    for engine_type_in_py, value in [
+        (int, 42),
+        (float, 3.14),
+        (str, "hello"),
+        (bool, True),
+        # (type(None), None),  # Removed unsupported NoneType
+    ]:
+        decoder = build_engine_value_decoder(engine_type_in_py)
+        assert decoder(value) == value
+
+@pytest.mark.parametrize(
+    "data_type, engine_val, expected",
+    [
+        # All fields match
+        (Order, ["O123", "mixed nuts", 25.0, "default_extra"], Order("O123", "mixed nuts", 25.0, "default_extra")),
+        # Extra field in engine value (should ignore extra)
+        (Order, ["O123", "mixed nuts", 25.0, "default_extra", "unexpected"], Order("O123", "mixed nuts", 25.0, "default_extra")),
+        # Fewer fields in engine value (should fill with default)
+        (Order, ["O123", "mixed nuts", 0.0, "default_extra"], Order("O123", "mixed nuts", 0.0, "default_extra")),
+        # More fields in engine value (should ignore extra)
+        (Order, ["O123", "mixed nuts", 25.0, "unexpected"], Order("O123", "mixed nuts", 25.0, "unexpected")),
+        # Truly extra field (should ignore the fifth field)
+        (Order, ["O123", "mixed nuts", 25.0, "default_extra", "ignored"], Order("O123", "mixed nuts", 25.0, "default_extra")),
+        # Missing optional field in engine value (tags=None)
+        (Customer, ["Alice", ["O1", "item1", 10.0, "default_extra"], None], Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), None)),
+        # Extra field in engine value for Customer (should ignore)
+        (Customer, ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]], "extra"], Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip")])),
+    ]
+)
+def test_struct_decoder_cases(data_type, engine_val, expected):
+    decoder = build_engine_value_decoder(data_type)
+    assert decoder(engine_val) == expected
+
+def test_make_engine_value_decoder_collections():
+    # List of structs
+    decoder = build_engine_value_decoder(list[Order])
+    engine_val = [
+        ["O1", "item1", 10.0, "default_extra"],
+        ["O2", "item2", 20.0, "default_extra"]
+    ]
+    assert decoder(engine_val) == [Order("O1", "item1", 10.0, "default_extra"), Order("O2", "item2", 20.0, "default_extra")]
+    # Struct with list field
+    decoder = build_engine_value_decoder(Customer)
+    engine_val = ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"], ["premium"]]]
+    assert decoder(engine_val) == Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip"), Tag("premium")])
+    # Struct with struct field
+    decoder = build_engine_value_decoder(NestedStruct)
+    engine_val = [
+        ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]]],
+        [["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]],
+        2
+    ]
+    assert decoder(engine_val) == NestedStruct(
+        Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip")]),
+        [Order("O1", "item1", 10.0, "default_extra"), Order("O2", "item2", 20.0, "default_extra")],
+        2
+    )
+
+def make_engine_order(fields):
+    return make_dataclass('EngineOrder', fields)
+
+def make_python_order(fields, defaults=None):
+    if defaults is None:
+        defaults = {}
+    # Move all fields with defaults to the end (Python dataclass requirement)
+    non_default_fields = [(n, t) for n, t in fields if n not in defaults]
+    default_fields = [(n, t) for n, t in fields if n in defaults]
+    ordered_fields = non_default_fields + default_fields
+    # Prepare the namespace for defaults (only for fields at the end)
+    namespace = {k: defaults[k] for k, _ in default_fields}
+    return make_dataclass('PythonOrder', ordered_fields, namespace=namespace)
+
+@pytest.mark.parametrize(
+    "engine_fields, python_fields, python_defaults, engine_val, expected_python_val",
+    [
+        # Extra field in Python (middle)
+        (
+            [("id", str), ("name", str)],
+            [("id", str), ("price", float), ("name", str)],
+            {"price": 0.0},
+            ["O123", "mixed nuts"],
+            ("O123", 0.0, "mixed nuts"),
+        ),
+        # Missing field in Python (middle)
+        (
+            [("id", str), ("price", float), ("name", str)],
+            [("id", str), ("name", str)],
+            {},
+            ["O123", 25.0, "mixed nuts"],
+            ("O123", "mixed nuts"),
+        ),
+        # Extra field in Python (start)
+        (
+            [("name", str), ("price", float)],
+            [("extra", str), ("name", str), ("price", float)],
+            {"extra": "default"},
+            ["mixed nuts", 25.0],
+            ("default", "mixed nuts", 25.0),
+        ),
+        # Missing field in Python (start)
+        (
+            [("extra", str), ("name", str), ("price", float)],
+            [("name", str), ("price", float)],
+            {},
+            ["unexpected", "mixed nuts", 25.0],
+            ("mixed nuts", 25.0),
+        ),
+        # Field order difference (should map by name)
+        (
+            [("id", str), ("name", str), ("price", float)],
+            [("name", str), ("id", str), ("price", float), ("extra", str)],
+            {"extra": "default"},
+            ["O123", "mixed nuts", 25.0],
+            ("mixed nuts", "O123", 25.0, "default"),
+        ),
+        # Extra field (Python has extra field with default)
+        (
+            [("id", str), ("name", str)],
+            [("id", str), ("name", str), ("price", float)],
+            {"price": 0.0},
+            ["O123", "mixed nuts"],
+            ("O123", "mixed nuts", 0.0),
+        ),
+        # Missing field (Engine has extra field)
+        (
+            [("id", str), ("name", str), ("price", float)],
+            [("id", str), ("name", str)],
+            {},
+            ["O123", "mixed nuts", 25.0],
+            ("O123", "mixed nuts"),
+        ),
+    ]
+)
+def test_field_position_cases(engine_fields, python_fields, python_defaults, engine_val, expected_python_val):
+    EngineOrder = make_engine_order(engine_fields)
+    PythonOrder = make_python_order(python_fields, python_defaults)
+    decoder = build_engine_value_decoder(EngineOrder, PythonOrder)
+    # Map field names to expected values
+    expected_dict = dict(zip([f[0] for f in python_fields], expected_python_val))
+    # Instantiate using keyword arguments (order doesn't matter)
+    assert decoder(engine_val) == PythonOrder(**expected_dict)
+
+def test_roundtrip_ltable():
+    t = list[Order]
+    value = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]
+    encoded = encode_engine_value(value)
+    assert encoded == [["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]]
+    decoded = build_engine_value_decoder(t)(encoded)
+    assert decoded == value
+
+def test_roundtrip_ktable_str_key():
+    t = dict[str, Order]
+    value = {"K1": Order("O1", "item1", 10.0), "K2": Order("O2", "item2", 20.0)}
+    encoded = encode_engine_value(value)
+    assert encoded == [["K1", "O1", "item1", 10.0, "default_extra"], ["K2", "O2", "item2", 20.0, "default_extra"]]
+    decoded = build_engine_value_decoder(t)(encoded)
+    assert decoded == value
+
+def test_roundtrip_ktable_struct_key():
+    @dataclass(frozen=True)
+    class OrderKey:
+        shop_id: str
+        version: int
+
+    t = dict[OrderKey, Order]
+    value = {OrderKey("A", 3): Order("O1", "item1", 10.0), OrderKey("B", 4): Order("O2", "item2", 20.0)}
+    encoded = encode_engine_value(value)
+    assert encoded == [[["A", 3], "O1", "item1", 10.0, "default_extra"],
+                       [["B", 4], "O2", "item2", 20.0, "default_extra"]]
+    decoded = build_engine_value_decoder(t)(encoded)
+    assert decoded == value

-
-
-
+IntVectorType = cocoindex.Vector[int, Literal[5]]
+def test_vector_as_vector() -> None:
+    value: IntVectorType = [1, 2, 3, 4, 5]
+    encoded = encode_engine_value(value)
+    assert encoded == [1, 2, 3, 4, 5]
+    decoded = build_engine_value_decoder(IntVectorType)(encoded)
+    assert decoded == value

-
-
-
-
-    assert …
-
+ListIntType = list[int]
+def test_vector_as_list() -> None:
+    value: ListIntType = [1, 2, 3, 4, 5]
+    encoded = encode_engine_value(value)
+    assert encoded == [1, 2, 3, 4, 5]
+    decoded = build_engine_value_decoder(ListIntType)(encoded)
+    assert decoded == value

-def test_to_engine_value_none():
-    assert to_engine_value(None) is None
cocoindex/typing.py
CHANGED
@@ -5,9 +5,9 @@ import datetime
 import types
 import inspect
 import uuid
-from typing import Annotated, NamedTuple, Any, TypeVar, TYPE_CHECKING, overload
+from typing import Annotated, NamedTuple, Any, TypeVar, TYPE_CHECKING, overload, Sequence, Protocol, Generic, Literal

-class …
+class VectorInfo(NamedTuple):
     dim: int | None

 class TypeKind(NamedTuple):
@@ -21,7 +21,7 @@ class TypeAttr:
         self.key = key
         self.value = value

-Annotation = …
+Annotation = TypeKind | TypeAttr | VectorInfo

 Float32 = Annotated[float, TypeKind('Float32')]
 Float64 = Annotated[float, TypeKind('Float64')]
@@ -30,29 +30,34 @@ Json = Annotated[Any, TypeKind('Json')]
 LocalDateTime = Annotated[datetime.datetime, TypeKind('LocalDateTime')]
 OffsetDateTime = Annotated[datetime.datetime, TypeKind('OffsetDateTime')]

-COLLECTION_TYPES = ('Table', 'List')
-
-R = TypeVar("R")
-
 if TYPE_CHECKING:
-
-
+    T_co = TypeVar('T_co', covariant=True)
+    Dim_co = TypeVar('Dim_co', bound=int, covariant=True)
+
+    class Vector(Sequence[T_co], Generic[T_co, Dim_co], Protocol):
+        """Vector[T, Dim] is a special typing alias for a list[T] with optional dimension info"""
 else:
-    # …
-
-
-
-
-
-
-
-
-
-
-
-
-
+    class Vector:  # type: ignore[unreachable]
+        """ A special typing alias for a list[T] with optional dimension info """
+        def __class_getitem__(self, params):
+            if not isinstance(params, tuple):
+                # Only element type provided
+                elem_type = params
+                return Annotated[list[elem_type], VectorInfo(dim=None)]
+            else:
+                # Element type and dimension provided
+                elem_type, dim = params
+                if typing.get_origin(dim) is Literal:
+                    dim = typing.get_args(dim)[0]  # Extract the literal value
+                return Annotated[list[elem_type], VectorInfo(dim=dim)]
+
+TABLE_TYPES = ('KTable', 'LTable')
+KEY_FIELD_NAME = '_key'
+
+ElementType = type | tuple[type, type]
+
+def _is_struct_type(t) -> bool:
+    return isinstance(t, type) and dataclasses.is_dataclass(t)

 @dataclasses.dataclass
 class AnalyzedTypeInfo:
@@ -60,9 +65,12 @@ class AnalyzedTypeInfo:
     Analyzed info of a Python type.
     """
     kind: str
-    vector_info: …
-    elem_type: …
-
+    vector_info: VectorInfo | None  # For Vector
+    elem_type: ElementType | None  # For Vector and Table
+
+    key_type: type | None  # For element of KTable
+    dataclass_type: type | None  # For Struct
+
     attrs: dict[str, Any] | None
     nullable: bool = False

@@ -70,6 +78,12 @@ def analyze_type_info(t) -> AnalyzedTypeInfo:
     """
     Analyze a Python type and return the analyzed info.
     """
+    if isinstance(t, tuple) and len(t) == 2:
+        key_type, value_type = t
+        result = analyze_type_info(value_type)
+        result.key_type = key_type
+        return result
+
     annotations: tuple[Annotation, ...] = ()
     base_type = None
     nullable = False
@@ -98,33 +112,41 @@ def analyze_type_info(t) -> AnalyzedTypeInfo:
             if attrs is None:
                 attrs = dict()
             attrs[attr.key] = attr.value
-        elif isinstance(attr, …
+        elif isinstance(attr, VectorInfo):
             vector_info = attr
         elif isinstance(attr, TypeKind):
             kind = attr.kind

     dataclass_type = None
     elem_type = None
-
+    key_type = None
+    if _is_struct_type(t):
         if kind is None:
             kind = 'Struct'
         elif kind != 'Struct':
             raise ValueError(f"Unexpected type kind for struct: {kind}")
         dataclass_type = t
     elif base_type is collections.abc.Sequence or base_type is list:
+        args = typing.get_args(t)
+        elem_type = args[0]
+
         if kind is None:
-
-
+            if _is_struct_type(elem_type):
+                kind = 'LTable'
+                if vector_info is not None:
+                    raise ValueError("Vector element must be a simple type, not a struct")
+            else:
+                kind = 'Vector'
+                if vector_info is None:
+                    vector_info = VectorInfo(dim=None)
+        elif not (kind == 'Vector' or kind in TABLE_TYPES):
            raise ValueError(f"Unexpected type kind for list: {kind}")
-
+    elif base_type is collections.abc.Mapping or base_type is dict:
         args = typing.get_args(t)
-
-
-        elem_type = args[0]
+        elem_type = (args[0], args[1])
+        kind = 'KTable'
     elif kind is None:
-        if …
-            kind = 'Vector' if vector_info is not None else 'List'
-        elif t is bytes:
+        if t is bytes:
             kind = 'Bytes'
         elif t is str:
             kind = 'Str'
@@ -145,20 +167,26 @@ def analyze_type_info(t) -> AnalyzedTypeInfo:
     else:
         raise ValueError(f"type unsupported yet: {t}")

-    return AnalyzedTypeInfo(kind=kind, vector_info=vector_info, …
-
+    return AnalyzedTypeInfo(kind=kind, vector_info=vector_info,
+                            elem_type=elem_type, key_type=key_type, dataclass_type=dataclass_type,
+                            attrs=attrs, nullable=nullable)

-def _encode_fields_schema(dataclass_type: type) -> list[dict[str, Any]]:
+def _encode_fields_schema(dataclass_type: type, key_type: type | None = None) -> list[dict[str, Any]]:
     result = []
-
+    def add_field(name: str, t) -> None:
         try:
-            type_info = encode_enriched_type_info(analyze_type_info( …
+            type_info = encode_enriched_type_info(analyze_type_info(t))
         except ValueError as e:
             e.add_note(f"Failed to encode annotation for field - "
-                       f"{dataclass_type.__name__}.{ …
+                       f"{dataclass_type.__name__}.{name}: {t}")
             raise
-        type_info['name'] = …
+        type_info['name'] = name
         result.append(type_info)
+
+    if key_type is not None:
+        add_field(KEY_FIELD_NAME, key_type)
+    for field in dataclasses.fields(dataclass_type):
+        add_field(field.name, field.type)
     return result

 def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
@@ -167,7 +195,7 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
     if type_info.kind == 'Struct':
         if type_info.dataclass_type is None:
             raise ValueError("Struct type must have a dataclass type")
-        encoded_type['fields'] = _encode_fields_schema(type_info.dataclass_type)
+        encoded_type['fields'] = _encode_fields_schema(type_info.dataclass_type, type_info.key_type)
         if doc := inspect.getdoc(type_info.dataclass_type):
             encoded_type['description'] = doc

@@ -179,7 +207,7 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
         encoded_type['element_type'] = _encode_type(analyze_type_info(type_info.elem_type))
         encoded_type['dimension'] = type_info.vector_info.dim

-    elif type_info.kind in …
+    elif type_info.kind in TABLE_TYPES:
         if type_info.elem_type is None:
             raise ValueError(f"{type_info.kind} type must have an element type")
         row_type_info = analyze_type_info(type_info.elem_type)
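The net effect of the typing changes: a list of a dataclass maps to LTable, a dict maps to KTable with the key emitted as a "_key" field, and Vector[...] carries an optional dimension through VectorInfo. A quick sketch of how the new kinds surface through encode_enriched_type (the Row dataclass is illustrative):

from dataclasses import dataclass
from typing import Literal
from cocoindex.typing import Float32, Vector, encode_enriched_type

@dataclass
class Row:
    name: str
    score: float

assert encode_enriched_type(Vector[Float32, Literal[384]])["type"]["kind"] == "Vector"
assert encode_enriched_type(list[Row])["type"]["kind"] == "LTable"
assert encode_enriched_type(dict[str, Row])["type"]["kind"] == "KTable"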
{cocoindex-0.1.25.dist-info → cocoindex-0.1.27.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cocoindex
-Version: 0.1.25
+Version: 0.1.27
 Requires-Dist: sentence-transformers>=3.3.1
 Requires-Dist: click>=8.1.8
 Requires-Dist: pytest ; extra == 'test'
@@ -117,7 +117,7 @@ Go to the [examples directory](examples) to try out with any of the examples, fo…
 | [PDF Embedding](examples/pdf_embedding) | Parse PDF and index text embeddings for semantic search |
 | [Manuals LLM Extraction](examples/manuals_llm_extraction) | Extract structured information from a manual using LLM |
 | [Google Drive Text Embedding](examples/gdrive_text_embedding) | Index text documents from Google Drive |
-| [Docs to Knowledge Graph](examples/ …
+| [Docs to Knowledge Graph](examples/docs_to_knowledge_graph) | Extract relationships from Markdown documents and build a knowledge graph |
 | [Embeddings to Qdrant](examples/text_embedding_qdrant) | Index documents in a Qdrant collection for semantic search |

 More coming and stay tuned! If there's any specific examples you would like to see, please let us know in our [Discord community](https://discord.com/invite/zpA9S2DR7s) 🌱.
cocoindex-0.1.27.dist-info/RECORD
ADDED
@@ -0,0 +1,24 @@
+cocoindex-0.1.27.dist-info/METADATA,sha256=O6XL0i2OUPY0KehlPYXxywyfX6uMdNk1xuTYxdJOcnA,8209
+cocoindex-0.1.27.dist-info/WHEEL,sha256=tAGdc4C2KTz7B2CZ8Jf3DcKSAviAbCg44UH9ma2gYww,96
+cocoindex-0.1.27.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+cocoindex/auth_registry.py,sha256=QQI8P7NMvA_NFIHXzV-yS94wVaZuXaIjY992Vrcii1E,641
+cocoindex/cli.py,sha256=kX2PQx7MJYX9UBGvyzIMqpzZSJfZA-NQJl9QI3m1ICU,7252
+cocoindex/convert.py,sha256=M8wRr38AgPpd43eklyWpNMaAvKHsIEyy_L3IlU7Q0oA,6095
+cocoindex/flow.py,sha256=Gu9uwDtoNz7IWnQFX6GtaSgee1NpODT1_4qS90OxqaE,21672
+cocoindex/functions.py,sha256=2yPL_s908AFxb6vVleFOVs8-6SCyT5cO-M7JLDqVFzA,1694
+cocoindex/index.py,sha256=32iiQI60VKhRRHle17rpoEVa_tsAHnyXXXnrLaX68uQ,557
+cocoindex/lib.py,sha256=V1bNKH6lpb1WT4f4F_4qD7e4Uarq-sS-donKXH5PjFo,4001
+cocoindex/llm.py,sha256=POMdB-huMvPkRDpvcaBeOgXfbm0YZa0swNOdD2s4TRc,364
+cocoindex/op.py,sha256=SYjXRlMxi304OcuOLPGlovv3CYa_DeQUunx4YJjmRc0,10871
+cocoindex/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+cocoindex/query.py,sha256=GU3V8AoSK--xOU2yX0Ao-Rz1t3mlrsJ4JIhzsLVwFH8,3297
+cocoindex/runtime.py,sha256=WwyDSpJhvaHjnxH4q6r2MVftJB-QkNVH_TxZLxuDtFw,1009
+cocoindex/setup.py,sha256=euHPZ9-7a68DV4PbdK2A0cZ9ECa6vT4qBraY3WKZk4s,512
+cocoindex/sources.py,sha256=4JWsW7sjFuiizaxOkxXJpZgVmVgF9ys9ATvvRY-SBjE,966
+cocoindex/storages.py,sha256=AYYua_rzDUyvYb2ou3YFjO3jvRUcJKUTQlIYZzoTDnE,2036
+cocoindex/tests/test_convert.py,sha256=l5VpBgDw8F-0eMT_kEitUQTyI5newGnxqsm3gb6c2XA,11516
+cocoindex/tests/__init__.py,sha256=frcCV1k9oG9oKj3dpUqdJg1PxRT2RSN_XKdLCPjaYaY,2
+cocoindex/typing.py,sha256=RP322ATJqGkHkC7ht0s0tuRMg39CeNKUhk3cR_f78EY,8705
+cocoindex/__init__.py,sha256=FZo8bM7rmJ3Ppg_ysAjjlWxiHPQPrNNwyHb4V-onXT0,673
+cocoindex/_engine.cp311-win_amd64.pyd,sha256=79tgqK7SCVNjQldNdOWBdZRNynaQviO03U8C5FFwJpM,58859008
+cocoindex-0.1.27.dist-info/RECORD,,
cocoindex-0.1.25.dist-info/RECORD
REMOVED
@@ -1,24 +0,0 @@
-cocoindex-0.1.25.dist-info/METADATA,sha256=-ufLCvjdS1aeIfab5HeYCq0rwMWGp5L98HNWnjijtL8,8196
-cocoindex-0.1.25.dist-info/WHEEL,sha256=tAGdc4C2KTz7B2CZ8Jf3DcKSAviAbCg44UH9ma2gYww,96
-cocoindex-0.1.25.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
-cocoindex/auth_registry.py,sha256=QQI8P7NMvA_NFIHXzV-yS94wVaZuXaIjY992Vrcii1E,641
-cocoindex/cli.py,sha256=-7lJjjXQbgYZTgsbrH9yl1DvWz-8vwPCmaWHKX_IdCM,7119
-cocoindex/convert.py,sha256=BIIc2YGr3bc5zVgn4M_kcJeAyVkbl7YJ1ZVNgYIAmpQ,5144
-cocoindex/flow.py,sha256=izi6qxDAB35_rYuyc62-J2Petq88iIs6qBEHLBdm_3Q,20560
-cocoindex/functions.py,sha256=kF4GUK1uvCu4IE5MLlCidoGyONp3b6lm47qcjMaFCtg,1704
-cocoindex/index.py,sha256=32iiQI60VKhRRHle17rpoEVa_tsAHnyXXXnrLaX68uQ,557
-cocoindex/lib.py,sha256=FTzT07U1dxjP1QsMDLndD4pBem7Baerf1ZbSaVq-OVM,3582
-cocoindex/llm.py,sha256=t6EcAYM17tl7UgMSAizSRbXF8M96g3tQdynASLsqx6I,312
-cocoindex/op.py,sha256=ee9Zu0pNygALCb9iPHPGqjQhD2PnNFcdVNPIezMEcPw,10609
-cocoindex/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cocoindex/query.py,sha256=FO_w1lsh4-lu7wPqkEDcZyvqKeis3e3Bq1L1h--91fg,3364
-cocoindex/runtime.py,sha256=R7FBysOH-4TL75cZNeDvAeEZ5jNmJ0m0zfkSUE7l0Lg,713
-cocoindex/setup.py,sha256=euHPZ9-7a68DV4PbdK2A0cZ9ECa6vT4qBraY3WKZk4s,512
-cocoindex/sources.py,sha256=4JWsW7sjFuiizaxOkxXJpZgVmVgF9ys9ATvvRY-SBjE,966
-cocoindex/storages.py,sha256=zICYAP3qHSFbGiVuNJNqn73-P6waNoBnU_SKOPfeygc,1882
-cocoindex/tests/test_convert.py,sha256=usmtQ7O91FYCfw131dxtmOrpFHo5I1zS-CFjPs8eHAs,2127
-cocoindex/tests/__init__.py,sha256=frcCV1k9oG9oKj3dpUqdJg1PxRT2RSN_XKdLCPjaYaY,2
-cocoindex/typing.py,sha256=SlO_QNa8lLJIdqJgESxISNsP7LSThC5DoWvFDcSgHK8,7501
-cocoindex/__init__.py,sha256=hdpaL27JWb9Jr1I0gVaZa9QffBLxb_MABkEcNTfU524,583
-cocoindex/_engine.cp311-win_amd64.pyd,sha256=aJ4bqXJtDQXCcfYWTCg0D7Aeymd96EDa2FqMbL9-KRA,58040832
-cocoindex-0.1.25.dist-info/RECORD,,
{cocoindex-0.1.25.dist-info → cocoindex-0.1.27.dist-info}/WHEEL
File without changes
{cocoindex-0.1.25.dist-info → cocoindex-0.1.27.dist-info}/licenses/LICENSE
File without changes