cocoindex 0.1.43__cp311-cp311-macosx_11_0_arm64.whl → 0.1.45__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cocoindex/__init__.py +2 -1
- cocoindex/_engine.cpython-311-darwin.so +0 -0
- cocoindex/auth_registry.py +7 -3
- cocoindex/cli.py +186 -67
- cocoindex/convert.py +93 -52
- cocoindex/flow.py +303 -132
- cocoindex/functions.py +17 -4
- cocoindex/index.py +6 -0
- cocoindex/lib.py +14 -9
- cocoindex/llm.py +4 -0
- cocoindex/op.py +126 -61
- cocoindex/query.py +40 -17
- cocoindex/runtime.py +9 -4
- cocoindex/setting.py +35 -12
- cocoindex/setup.py +7 -3
- cocoindex/sources.py +3 -1
- cocoindex/storages.py +50 -7
- cocoindex/tests/test_convert.py +255 -63
- cocoindex/typing.py +116 -70
- cocoindex/utils.py +10 -2
- {cocoindex-0.1.43.dist-info → cocoindex-0.1.45.dist-info}/METADATA +3 -1
- cocoindex-0.1.45.dist-info/RECORD +27 -0
- cocoindex-0.1.43.dist-info/RECORD +0 -27
- {cocoindex-0.1.43.dist-info → cocoindex-0.1.45.dist-info}/WHEEL +0 -0
- {cocoindex-0.1.43.dist-info → cocoindex-0.1.45.dist-info}/entry_points.txt +0 -0
- {cocoindex-0.1.43.dist-info → cocoindex-0.1.45.dist-info}/licenses/LICENSE +0 -0
cocoindex/functions.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
"""All builtin functions."""
|
2
|
+
|
2
3
|
from typing import Annotated, Any, TYPE_CHECKING
|
3
4
|
|
4
5
|
from .typing import Float32, Vector, TypeAttr
|
@@ -8,12 +9,15 @@ from . import op, llm
|
|
8
9
|
if TYPE_CHECKING:
|
9
10
|
import sentence_transformers
|
10
11
|
|
12
|
+
|
11
13
|
class ParseJson(op.FunctionSpec):
|
12
14
|
"""Parse a text into a JSON object."""
|
13
15
|
|
16
|
+
|
14
17
|
class SplitRecursively(op.FunctionSpec):
|
15
18
|
"""Split a document (in string) recursively."""
|
16
19
|
|
20
|
+
|
17
21
|
class ExtractByLlm(op.FunctionSpec):
|
18
22
|
"""Extract information from a text using a LLM."""
|
19
23
|
|
@@ -21,6 +25,7 @@ class ExtractByLlm(op.FunctionSpec):
|
|
21
25
|
output_type: type
|
22
26
|
instruction: str | None = None
|
23
27
|
|
28
|
+
|
24
29
|
class SentenceTransformerEmbed(op.FunctionSpec):
|
25
30
|
"""
|
26
31
|
`SentenceTransformerEmbed` embeds a text into a vector space using the [SentenceTransformer](https://huggingface.co/sentence-transformers) library.
|
@@ -30,9 +35,11 @@ class SentenceTransformerEmbed(op.FunctionSpec):
|
|
30
35
|
model: The name of the SentenceTransformer model to use.
|
31
36
|
args: Additional arguments to pass to the SentenceTransformer constructor. e.g. {"trust_remote_code": True}
|
32
37
|
"""
|
38
|
+
|
33
39
|
model: str
|
34
40
|
args: dict[str, Any] | None = None
|
35
41
|
|
42
|
+
|
36
43
|
@op.executor_class(gpu=True, cache=True, behavior_version=1)
|
37
44
|
class SentenceTransformerEmbedExecutor:
|
38
45
|
"""Executor for SentenceTransformerEmbed."""
|
@@ -40,12 +47,18 @@ class SentenceTransformerEmbedExecutor:
|
|
40
47
|
spec: SentenceTransformerEmbed
|
41
48
|
_model: "sentence_transformers.SentenceTransformer"
|
42
49
|
|
43
|
-
def analyze(self, text):
|
44
|
-
import sentence_transformers
|
50
|
+
def analyze(self, text: Any) -> type:
|
51
|
+
import sentence_transformers # pylint: disable=import-outside-toplevel
|
52
|
+
|
45
53
|
args = self.spec.args or {}
|
46
54
|
self._model = sentence_transformers.SentenceTransformer(self.spec.model, **args)
|
47
55
|
dim = self._model.get_sentence_embedding_dimension()
|
48
|
-
|
56
|
+
result: type = Annotated[
|
57
|
+
Vector[Float32, dim], # type: ignore
|
58
|
+
TypeAttr("cocoindex.io/vector_origin_text", text.analyzed_value),
|
59
|
+
]
|
60
|
+
return result
|
49
61
|
|
50
62
|
def __call__(self, text: str) -> list[Float32]:
|
51
|
-
|
63
|
+
result: list[Float32] = self._model.encode(text).tolist()
|
64
|
+
return result
|
cocoindex/index.py
CHANGED
@@ -1,23 +1,29 @@
|
|
1
1
|
from enum import Enum
|
2
2
|
from dataclasses import dataclass
|
3
3
|
from typing import Sequence
|
4
|
+
|
5
|
+
|
4
6
|
class VectorSimilarityMetric(Enum):
|
5
7
|
COSINE_SIMILARITY = "CosineSimilarity"
|
6
8
|
L2_DISTANCE = "L2Distance"
|
7
9
|
INNER_PRODUCT = "InnerProduct"
|
8
10
|
|
11
|
+
|
9
12
|
@dataclass
|
10
13
|
class VectorIndexDef:
|
11
14
|
"""
|
12
15
|
Define a vector index on a field.
|
13
16
|
"""
|
17
|
+
|
14
18
|
field_name: str
|
15
19
|
metric: VectorSimilarityMetric
|
16
20
|
|
21
|
+
|
17
22
|
@dataclass
|
18
23
|
class IndexOptions:
|
19
24
|
"""
|
20
25
|
Options for an index.
|
21
26
|
"""
|
27
|
+
|
22
28
|
primary_key_fields: Sequence[str]
|
23
29
|
vector_indexes: Sequence[VectorIndexDef] = ()
|
cocoindex/lib.py
CHANGED
@@ -1,14 +1,16 @@
|
|
1
1
|
"""
|
2
2
|
Library level functions and states.
|
3
3
|
"""
|
4
|
+
|
4
5
|
import warnings
|
5
6
|
from typing import Callable, Any
|
6
7
|
|
7
|
-
from . import _engine
|
8
|
+
from . import _engine # type: ignore
|
9
|
+
from . import flow, query, setting
|
8
10
|
from .convert import dump_engine_object
|
9
11
|
|
10
12
|
|
11
|
-
def init(settings: setting.Settings | None = None):
|
13
|
+
def init(settings: setting.Settings | None = None) -> None:
|
12
14
|
"""
|
13
15
|
Initialize the cocoindex library.
|
14
16
|
|
@@ -19,20 +21,22 @@ def init(settings: setting.Settings | None = None):
|
|
19
21
|
setting.set_app_namespace(settings.app_namespace)
|
20
22
|
|
21
23
|
|
22
|
-
def start_server(settings: setting.ServerSettings):
|
24
|
+
def start_server(settings: setting.ServerSettings) -> None:
|
23
25
|
"""Start the cocoindex server."""
|
24
26
|
flow.ensure_all_flows_built()
|
25
27
|
query.ensure_all_handlers_built()
|
26
28
|
_engine.start_server(settings.__dict__)
|
27
29
|
|
28
|
-
|
30
|
+
|
31
|
+
def stop() -> None:
|
29
32
|
"""Stop the cocoindex library."""
|
30
33
|
_engine.stop()
|
31
34
|
|
35
|
+
|
32
36
|
def main_fn(
|
33
|
-
|
34
|
-
|
35
|
-
|
37
|
+
settings: Any | None = None,
|
38
|
+
cocoindex_cmd: str | None = None,
|
39
|
+
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
36
40
|
"""
|
37
41
|
DEPRECATED: The @cocoindex.main_fn() decorator is obsolete and has no effect.
|
38
42
|
It will be removed in a future version, which will cause an AttributeError.
|
@@ -63,9 +67,10 @@ def main_fn(
|
|
63
67
|
"See cocoindex <command> --help for more details.\n"
|
64
68
|
"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n\n",
|
65
69
|
DeprecationWarning,
|
66
|
-
stacklevel=2
|
70
|
+
stacklevel=2,
|
67
71
|
)
|
68
72
|
|
69
|
-
def _main_wrapper(fn: Callable) -> Callable:
|
73
|
+
def _main_wrapper(fn: Callable[..., Any]) -> Callable[..., Any]:
|
70
74
|
return fn
|
75
|
+
|
71
76
|
return _main_wrapper
|
cocoindex/llm.py
CHANGED
@@ -1,16 +1,20 @@
|
|
1
1
|
from dataclasses import dataclass
|
2
2
|
from enum import Enum
|
3
3
|
|
4
|
+
|
4
5
|
class LlmApiType(Enum):
|
5
6
|
"""The type of LLM API to use."""
|
7
|
+
|
6
8
|
OPENAI = "OpenAi"
|
7
9
|
OLLAMA = "Ollama"
|
8
10
|
GEMINI = "Gemini"
|
9
11
|
ANTHROPIC = "Anthropic"
|
10
12
|
|
13
|
+
|
11
14
|
@dataclass
|
12
15
|
class LlmSpec:
|
13
16
|
"""A specification for a LLM."""
|
17
|
+
|
14
18
|
api_type: LlmApiType
|
15
19
|
model: str
|
16
20
|
address: str | None = None
|
cocoindex/op.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
"""
|
2
2
|
Facilities for defining cocoindex operations.
|
3
3
|
"""
|
4
|
+
|
4
5
|
import asyncio
|
5
6
|
import dataclasses
|
6
7
|
import inspect
|
@@ -10,40 +11,58 @@ from enum import Enum
|
|
10
11
|
|
11
12
|
from .typing import encode_enriched_type, resolve_forward_ref
|
12
13
|
from .convert import encode_engine_value, make_engine_value_decoder
|
13
|
-
from . import _engine
|
14
|
+
from . import _engine # type: ignore
|
15
|
+
|
14
16
|
|
15
17
|
class OpCategory(Enum):
|
16
18
|
"""The category of the operation."""
|
19
|
+
|
17
20
|
FUNCTION = "function"
|
18
21
|
SOURCE = "source"
|
19
22
|
STORAGE = "storage"
|
20
23
|
DECLARATION = "declaration"
|
24
|
+
|
25
|
+
|
21
26
|
@dataclass_transform()
|
22
27
|
class SpecMeta(type):
|
23
28
|
"""Meta class for spec classes."""
|
24
|
-
|
29
|
+
|
30
|
+
def __new__(
|
31
|
+
mcs,
|
32
|
+
name: str,
|
33
|
+
bases: tuple[type, ...],
|
34
|
+
attrs: dict[str, Any],
|
35
|
+
category: OpCategory | None = None,
|
36
|
+
) -> type:
|
25
37
|
cls: type = super().__new__(mcs, name, bases, attrs)
|
26
38
|
if category is not None:
|
27
39
|
# It's the base class.
|
28
|
-
setattr(cls,
|
40
|
+
setattr(cls, "_op_category", category)
|
29
41
|
else:
|
30
42
|
# It's the specific class providing specific fields.
|
31
43
|
cls = dataclasses.dataclass(cls)
|
32
44
|
return cls
|
33
45
|
|
34
|
-
|
46
|
+
|
47
|
+
class SourceSpec(metaclass=SpecMeta, category=OpCategory.SOURCE): # pylint: disable=too-few-public-methods
|
35
48
|
"""A source spec. All its subclass can be instantiated similar to a dataclass, i.e. ClassName(field1=value1, field2=value2, ...)"""
|
36
49
|
|
37
|
-
|
50
|
+
|
51
|
+
class FunctionSpec(metaclass=SpecMeta, category=OpCategory.FUNCTION): # pylint: disable=too-few-public-methods
|
38
52
|
"""A function spec. All its subclass can be instantiated similar to a dataclass, i.e. ClassName(field1=value1, field2=value2, ...)"""
|
39
53
|
|
40
|
-
|
54
|
+
|
55
|
+
class StorageSpec(metaclass=SpecMeta, category=OpCategory.STORAGE): # pylint: disable=too-few-public-methods
|
41
56
|
"""A storage spec. All its subclass can be instantiated similar to a dataclass, i.e. ClassName(field1=value1, field2=value2, ...)"""
|
42
57
|
|
43
|
-
|
58
|
+
|
59
|
+
class DeclarationSpec(metaclass=SpecMeta, category=OpCategory.DECLARATION): # pylint: disable=too-few-public-methods
|
44
60
|
"""A declaration spec. All its subclass can be instantiated similar to a dataclass, i.e. ClassName(field1=value1, field2=value2, ...)"""
|
61
|
+
|
62
|
+
|
45
63
|
class Executor(Protocol):
|
46
64
|
"""An executor for an operation."""
|
65
|
+
|
47
66
|
op_category: OpCategory
|
48
67
|
|
49
68
|
|
@@ -55,7 +74,9 @@ class _FunctionExecutorFactory:
|
|
55
74
|
self._spec_cls = spec_cls
|
56
75
|
self._executor_cls = executor_cls
|
57
76
|
|
58
|
-
def __call__(
|
77
|
+
def __call__(
|
78
|
+
self, spec: dict[str, Any], *args: Any, **kwargs: Any
|
79
|
+
) -> tuple[dict[str, Any], Executor]:
|
59
80
|
spec = self._spec_cls(**spec)
|
60
81
|
executor = self._executor_cls(spec)
|
61
82
|
result_type = executor.analyze(*args, **kwargs)
|
@@ -64,6 +85,7 @@ class _FunctionExecutorFactory:
|
|
64
85
|
|
65
86
|
_gpu_dispatch_lock = asyncio.Lock()
|
66
87
|
|
88
|
+
|
67
89
|
@dataclasses.dataclass
|
68
90
|
class OpArgs:
|
69
91
|
"""
|
@@ -72,44 +94,50 @@ class OpArgs:
|
|
72
94
|
- behavior_version: The behavior version of the executor. Cache will be invalidated if it
|
73
95
|
changes. Must be provided if `cache` is True.
|
74
96
|
"""
|
97
|
+
|
75
98
|
gpu: bool = False
|
76
99
|
cache: bool = False
|
77
100
|
behavior_version: int | None = None
|
78
101
|
|
79
|
-
|
102
|
+
|
103
|
+
def _to_async_call(call: Callable[..., Any]) -> Callable[..., Awaitable[Any]]:
|
80
104
|
if inspect.iscoroutinefunction(call):
|
81
105
|
return call
|
82
106
|
return lambda *args, **kwargs: asyncio.to_thread(lambda: call(*args, **kwargs))
|
83
107
|
|
108
|
+
|
84
109
|
def _register_op_factory(
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
110
|
+
category: OpCategory,
|
111
|
+
expected_args: list[tuple[str, inspect.Parameter]],
|
112
|
+
expected_return: Any,
|
113
|
+
executor_cls: type,
|
114
|
+
spec_cls: type,
|
115
|
+
op_args: OpArgs,
|
116
|
+
) -> type:
|
92
117
|
"""
|
93
118
|
Register an op factory.
|
94
119
|
"""
|
120
|
+
|
95
121
|
class _Fallback:
|
96
|
-
def enable_cache(self):
|
122
|
+
def enable_cache(self) -> bool:
|
97
123
|
return op_args.cache
|
98
124
|
|
99
|
-
def behavior_version(self):
|
125
|
+
def behavior_version(self) -> int | None:
|
100
126
|
return op_args.behavior_version
|
101
127
|
|
102
|
-
class _WrappedClass(executor_cls, _Fallback):
|
128
|
+
class _WrappedClass(executor_cls, _Fallback): # type: ignore[misc]
|
103
129
|
_args_decoders: list[Callable[[Any], Any]]
|
104
|
-
_kwargs_decoders: dict[str, Callable[[
|
105
|
-
_acall: Callable
|
130
|
+
_kwargs_decoders: dict[str, Callable[[Any], Any]]
|
131
|
+
_acall: Callable[..., Awaitable[Any]]
|
106
132
|
|
107
|
-
def __init__(self, spec):
|
133
|
+
def __init__(self, spec: Any) -> None:
|
108
134
|
super().__init__()
|
109
135
|
self.spec = spec
|
110
136
|
self._acall = _to_async_call(super().__call__)
|
111
137
|
|
112
|
-
def analyze(
|
138
|
+
def analyze(
|
139
|
+
self, *args: _engine.OpArgSchema, **kwargs: _engine.OpArgSchema
|
140
|
+
) -> Any:
|
113
141
|
"""
|
114
142
|
Analyze the spec and arguments. In this phase, argument types should be validated.
|
115
143
|
It should return the expected result type for the current op.
|
@@ -122,15 +150,21 @@ def _register_op_factory(
|
|
122
150
|
for arg in args:
|
123
151
|
if next_param_idx >= len(expected_args):
|
124
152
|
raise ValueError(
|
125
|
-
f"Too many arguments passed in: {len(args)} > {len(expected_args)}"
|
153
|
+
f"Too many arguments passed in: {len(args)} > {len(expected_args)}"
|
154
|
+
)
|
126
155
|
arg_name, arg_param = expected_args[next_param_idx]
|
127
156
|
if arg_param.kind in (
|
128
|
-
inspect.Parameter.KEYWORD_ONLY,
|
157
|
+
inspect.Parameter.KEYWORD_ONLY,
|
158
|
+
inspect.Parameter.VAR_KEYWORD,
|
159
|
+
):
|
129
160
|
raise ValueError(
|
130
|
-
f"Too many positional arguments passed in: {len(args)} > {next_param_idx}"
|
161
|
+
f"Too many positional arguments passed in: {len(args)} > {next_param_idx}"
|
162
|
+
)
|
131
163
|
self._args_decoders.append(
|
132
164
|
make_engine_value_decoder(
|
133
|
-
[arg_name], arg.value_type[
|
165
|
+
[arg_name], arg.value_type["type"], arg_param.annotation
|
166
|
+
)
|
167
|
+
)
|
134
168
|
if arg_param.kind != inspect.Parameter.VAR_POSITIONAL:
|
135
169
|
next_param_idx += 1
|
136
170
|
|
@@ -138,45 +172,72 @@ def _register_op_factory(
|
|
138
172
|
|
139
173
|
for kwarg_name, kwarg in kwargs.items():
|
140
174
|
expected_arg = next(
|
141
|
-
(
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
175
|
+
(
|
176
|
+
arg
|
177
|
+
for arg in expected_kwargs
|
178
|
+
if (
|
179
|
+
arg[0] == kwarg_name
|
180
|
+
and arg[1].kind
|
181
|
+
in (
|
182
|
+
inspect.Parameter.KEYWORD_ONLY,
|
183
|
+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
184
|
+
)
|
185
|
+
)
|
186
|
+
or arg[1].kind == inspect.Parameter.VAR_KEYWORD
|
187
|
+
),
|
188
|
+
None,
|
189
|
+
)
|
146
190
|
if expected_arg is None:
|
147
|
-
raise ValueError(
|
191
|
+
raise ValueError(
|
192
|
+
f"Unexpected keyword argument passed in: {kwarg_name}"
|
193
|
+
)
|
148
194
|
arg_param = expected_arg[1]
|
149
195
|
self._kwargs_decoders[kwarg_name] = make_engine_value_decoder(
|
150
|
-
[kwarg_name], kwarg.value_type[
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
196
|
+
[kwarg_name], kwarg.value_type["type"], arg_param.annotation
|
197
|
+
)
|
198
|
+
|
199
|
+
missing_args = [
|
200
|
+
name
|
201
|
+
for (name, arg) in expected_kwargs
|
202
|
+
if arg.default is inspect.Parameter.empty
|
203
|
+
and (
|
204
|
+
arg.kind == inspect.Parameter.POSITIONAL_ONLY
|
205
|
+
or (
|
206
|
+
arg.kind
|
207
|
+
in (
|
208
|
+
inspect.Parameter.KEYWORD_ONLY,
|
209
|
+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
210
|
+
)
|
211
|
+
and name not in kwargs
|
212
|
+
)
|
213
|
+
)
|
214
|
+
]
|
158
215
|
if len(missing_args) > 0:
|
159
216
|
raise ValueError(f"Missing arguments: {', '.join(missing_args)}")
|
160
217
|
|
161
|
-
prepare_method = getattr(executor_cls,
|
218
|
+
prepare_method = getattr(executor_cls, "analyze", None)
|
162
219
|
if prepare_method is not None:
|
163
220
|
return prepare_method(self, *args, **kwargs)
|
164
221
|
else:
|
165
222
|
return expected_return
|
166
223
|
|
167
|
-
async def prepare(self):
|
224
|
+
async def prepare(self) -> None:
|
168
225
|
"""
|
169
226
|
Prepare for execution.
|
170
227
|
It's executed after `analyze` and before any `__call__` execution.
|
171
228
|
"""
|
172
|
-
setup_method = getattr(super(),
|
229
|
+
setup_method = getattr(super(), "prepare", None)
|
173
230
|
if setup_method is not None:
|
174
231
|
await _to_async_call(setup_method)()
|
175
232
|
|
176
|
-
async def __call__(self, *args, **kwargs):
|
177
|
-
decoded_args = (
|
178
|
-
|
179
|
-
|
233
|
+
async def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
234
|
+
decoded_args = (
|
235
|
+
decoder(arg) for decoder, arg in zip(self._args_decoders, args)
|
236
|
+
)
|
237
|
+
decoded_kwargs = {
|
238
|
+
arg_name: self._kwargs_decoders[arg_name](arg)
|
239
|
+
for arg_name, arg in kwargs.items()
|
240
|
+
}
|
180
241
|
|
181
242
|
if op_args.gpu:
|
182
243
|
# For GPU executions, data-level parallelism is applied, so we don't want to
|
@@ -198,13 +259,15 @@ def _register_op_factory(
|
|
198
259
|
|
199
260
|
if category == OpCategory.FUNCTION:
|
200
261
|
_engine.register_function_factory(
|
201
|
-
spec_cls.__name__, _FunctionExecutorFactory(spec_cls, _WrappedClass)
|
262
|
+
spec_cls.__name__, _FunctionExecutorFactory(spec_cls, _WrappedClass)
|
263
|
+
)
|
202
264
|
else:
|
203
265
|
raise ValueError(f"Unsupported executor type {category}")
|
204
266
|
|
205
267
|
return _WrappedClass
|
206
268
|
|
207
|
-
|
269
|
+
|
270
|
+
def executor_class(**args: Any) -> Callable[[type], type]:
|
208
271
|
"""
|
209
272
|
Decorate a class to provide an executor for an op.
|
210
273
|
"""
|
@@ -216,9 +279,9 @@ def executor_class(**args) -> Callable[[type], type]:
|
|
216
279
|
"""
|
217
280
|
# Use `__annotations__` instead of `get_type_hints`, to avoid resolving forward references.
|
218
281
|
type_hints = cls.__annotations__
|
219
|
-
if
|
282
|
+
if "spec" not in type_hints:
|
220
283
|
raise TypeError("Expect a `spec` field with type hint")
|
221
|
-
spec_cls = resolve_forward_ref(type_hints[
|
284
|
+
spec_cls = resolve_forward_ref(type_hints["spec"])
|
222
285
|
sig = inspect.signature(cls.__call__)
|
223
286
|
return _register_op_factory(
|
224
287
|
category=spec_cls._op_category,
|
@@ -226,34 +289,35 @@ def executor_class(**args) -> Callable[[type], type]:
|
|
226
289
|
expected_return=sig.return_annotation,
|
227
290
|
executor_cls=cls,
|
228
291
|
spec_cls=spec_cls,
|
229
|
-
op_args=op_args
|
292
|
+
op_args=op_args,
|
293
|
+
)
|
230
294
|
|
231
295
|
return _inner
|
232
296
|
|
233
|
-
|
297
|
+
|
298
|
+
def function(**args: Any) -> Callable[[Callable[..., Any]], FunctionSpec]:
|
234
299
|
"""
|
235
300
|
Decorate a function to provide a function for an op.
|
236
301
|
"""
|
237
302
|
op_args = OpArgs(**args)
|
238
303
|
|
239
|
-
def _inner(fn: Callable) -> FunctionSpec:
|
240
|
-
|
304
|
+
def _inner(fn: Callable[..., Any]) -> FunctionSpec:
|
241
305
|
# Convert snake case to camel case.
|
242
|
-
op_name =
|
306
|
+
op_name = "".join(word.capitalize() for word in fn.__name__.split("_"))
|
243
307
|
sig = inspect.signature(fn)
|
244
308
|
|
245
309
|
class _Executor:
|
246
|
-
def __call__(self, *args, **kwargs):
|
310
|
+
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
247
311
|
return fn(*args, **kwargs)
|
248
312
|
|
249
313
|
class _Spec(FunctionSpec):
|
250
|
-
|
314
|
+
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
315
|
+
return fn(*args, **kwargs)
|
251
316
|
|
252
317
|
_Spec.__name__ = op_name
|
253
318
|
_Spec.__doc__ = fn.__doc__
|
254
319
|
_Spec.__module__ = fn.__module__
|
255
320
|
_Spec.__qualname__ = fn.__qualname__
|
256
|
-
_Spec.__wrapped__ = fn
|
257
321
|
|
258
322
|
_register_op_factory(
|
259
323
|
category=OpCategory.FUNCTION,
|
@@ -261,7 +325,8 @@ def function(**args) -> Callable[[Callable], FunctionSpec]:
|
|
261
325
|
expected_return=sig.return_annotation,
|
262
326
|
executor_cls=_Executor,
|
263
327
|
spec_cls=_Spec,
|
264
|
-
op_args=op_args
|
328
|
+
op_args=op_args,
|
329
|
+
)
|
265
330
|
|
266
331
|
return _Spec()
|
267
332
|
|
cocoindex/query.py
CHANGED
@@ -4,25 +4,29 @@ from threading import Lock
|
|
4
4
|
|
5
5
|
from . import flow as fl
|
6
6
|
from . import index
|
7
|
-
from . import _engine
|
7
|
+
from . import _engine # type: ignore
|
8
8
|
|
9
9
|
_handlers_lock = Lock()
|
10
10
|
_handlers: dict[str, _engine.SimpleSemanticsQueryHandler] = {}
|
11
11
|
|
12
|
+
|
12
13
|
@dataclass
|
13
14
|
class SimpleSemanticsQueryInfo:
|
14
15
|
"""
|
15
16
|
Additional information about the query.
|
16
17
|
"""
|
18
|
+
|
17
19
|
similarity_metric: index.VectorSimilarityMetric
|
18
20
|
query_vector: list[float]
|
19
21
|
vector_field_name: str
|
20
22
|
|
23
|
+
|
21
24
|
@dataclass
|
22
25
|
class QueryResult:
|
23
26
|
"""
|
24
27
|
A single result from the query.
|
25
28
|
"""
|
29
|
+
|
26
30
|
data: dict[str, Any]
|
27
31
|
score: float
|
28
32
|
|
@@ -31,6 +35,7 @@ class SimpleSemanticsQueryHandler:
|
|
31
35
|
"""
|
32
36
|
A query handler that uses simple semantics to query the index.
|
33
37
|
"""
|
38
|
+
|
34
39
|
_lazy_query_handler: Callable[[], _engine.SimpleSemanticsQueryHandler]
|
35
40
|
|
36
41
|
def __init__(
|
@@ -38,22 +43,28 @@ class SimpleSemanticsQueryHandler:
|
|
38
43
|
name: str,
|
39
44
|
flow: fl.Flow,
|
40
45
|
target_name: str,
|
41
|
-
query_transform_flow: Callable[..., fl.DataSlice],
|
42
|
-
default_similarity_metric: index.VectorSimilarityMetric = index.VectorSimilarityMetric.COSINE_SIMILARITY
|
43
|
-
|
46
|
+
query_transform_flow: Callable[..., fl.DataSlice[Any]],
|
47
|
+
default_similarity_metric: index.VectorSimilarityMetric = index.VectorSimilarityMetric.COSINE_SIMILARITY,
|
48
|
+
) -> None:
|
44
49
|
engine_handler = None
|
45
50
|
lock = Lock()
|
51
|
+
|
46
52
|
def _lazy_handler() -> _engine.SimpleSemanticsQueryHandler:
|
47
53
|
nonlocal engine_handler, lock
|
48
54
|
if engine_handler is None:
|
49
55
|
with lock:
|
50
56
|
if engine_handler is None:
|
51
57
|
engine_handler = _engine.SimpleSemanticsQueryHandler(
|
52
|
-
flow.internal_flow(),
|
53
|
-
|
54
|
-
|
58
|
+
flow.internal_flow(),
|
59
|
+
target_name,
|
60
|
+
fl.TransformFlow(
|
61
|
+
query_transform_flow, [str]
|
62
|
+
).internal_flow(),
|
63
|
+
default_similarity_metric.value,
|
64
|
+
)
|
55
65
|
engine_handler.register_query_handler(name)
|
56
66
|
return engine_handler
|
67
|
+
|
57
68
|
self._lazy_query_handler = _lazy_handler
|
58
69
|
|
59
70
|
with _handlers_lock:
|
@@ -65,24 +76,36 @@ class SimpleSemanticsQueryHandler:
|
|
65
76
|
"""
|
66
77
|
return self._lazy_query_handler()
|
67
78
|
|
68
|
-
def search(
|
69
|
-
|
70
|
-
|
79
|
+
def search(
|
80
|
+
self,
|
81
|
+
query: str,
|
82
|
+
limit: int,
|
83
|
+
vector_field_name: str | None = None,
|
84
|
+
similarity_metric: index.VectorSimilarityMetric | None = None,
|
85
|
+
) -> tuple[list[QueryResult], SimpleSemanticsQueryInfo]:
|
71
86
|
"""
|
72
87
|
Search the index with the given query, limit, vector field name, and similarity metric.
|
73
88
|
"""
|
74
89
|
internal_results, internal_info = self.internal_handler().search(
|
75
|
-
query,
|
76
|
-
|
77
|
-
|
78
|
-
|
90
|
+
query,
|
91
|
+
limit,
|
92
|
+
vector_field_name,
|
93
|
+
similarity_metric.value if similarity_metric is not None else None,
|
94
|
+
)
|
95
|
+
results = [
|
96
|
+
QueryResult(data=result["data"], score=result["score"])
|
97
|
+
for result in internal_results
|
98
|
+
]
|
79
99
|
info = SimpleSemanticsQueryInfo(
|
80
|
-
similarity_metric=index.VectorSimilarityMetric(
|
81
|
-
|
82
|
-
|
100
|
+
similarity_metric=index.VectorSimilarityMetric(
|
101
|
+
internal_info["similarity_metric"]
|
102
|
+
),
|
103
|
+
query_vector=internal_info["query_vector"],
|
104
|
+
vector_field_name=internal_info["vector_field_name"],
|
83
105
|
)
|
84
106
|
return results, info
|
85
107
|
|
108
|
+
|
86
109
|
def ensure_all_handlers_built() -> None:
|
87
110
|
"""
|
88
111
|
Ensure all handlers are built.
|