cocoindex 0.1.43__cp311-cp311-macosx_11_0_arm64.whl → 0.1.45__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cocoindex/functions.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """All builtin functions."""
2
+
2
3
  from typing import Annotated, Any, TYPE_CHECKING
3
4
 
4
5
  from .typing import Float32, Vector, TypeAttr
@@ -8,12 +9,15 @@ from . import op, llm
8
9
  if TYPE_CHECKING:
9
10
  import sentence_transformers
10
11
 
12
+
11
13
  class ParseJson(op.FunctionSpec):
12
14
  """Parse a text into a JSON object."""
13
15
 
16
+
14
17
  class SplitRecursively(op.FunctionSpec):
15
18
  """Split a document (in string) recursively."""
16
19
 
20
+
17
21
  class ExtractByLlm(op.FunctionSpec):
18
22
  """Extract information from a text using a LLM."""
19
23
 
@@ -21,6 +25,7 @@ class ExtractByLlm(op.FunctionSpec):
21
25
  output_type: type
22
26
  instruction: str | None = None
23
27
 
28
+
24
29
  class SentenceTransformerEmbed(op.FunctionSpec):
25
30
  """
26
31
  `SentenceTransformerEmbed` embeds a text into a vector space using the [SentenceTransformer](https://huggingface.co/sentence-transformers) library.
@@ -30,9 +35,11 @@ class SentenceTransformerEmbed(op.FunctionSpec):
30
35
  model: The name of the SentenceTransformer model to use.
31
36
  args: Additional arguments to pass to the SentenceTransformer constructor. e.g. {"trust_remote_code": True}
32
37
  """
38
+
33
39
  model: str
34
40
  args: dict[str, Any] | None = None
35
41
 
42
+
36
43
  @op.executor_class(gpu=True, cache=True, behavior_version=1)
37
44
  class SentenceTransformerEmbedExecutor:
38
45
  """Executor for SentenceTransformerEmbed."""
@@ -40,12 +47,18 @@ class SentenceTransformerEmbedExecutor:
40
47
  spec: SentenceTransformerEmbed
41
48
  _model: "sentence_transformers.SentenceTransformer"
42
49
 
43
- def analyze(self, text):
44
- import sentence_transformers # pylint: disable=import-outside-toplevel
50
+ def analyze(self, text: Any) -> type:
51
+ import sentence_transformers # pylint: disable=import-outside-toplevel
52
+
45
53
  args = self.spec.args or {}
46
54
  self._model = sentence_transformers.SentenceTransformer(self.spec.model, **args)
47
55
  dim = self._model.get_sentence_embedding_dimension()
48
- return Annotated[Vector[Float32, dim], TypeAttr("cocoindex.io/vector_origin_text", text.analyzed_value)]
56
+ result: type = Annotated[
57
+ Vector[Float32, dim], # type: ignore
58
+ TypeAttr("cocoindex.io/vector_origin_text", text.analyzed_value),
59
+ ]
60
+ return result
49
61
 
50
62
  def __call__(self, text: str) -> list[Float32]:
51
- return self._model.encode(text).tolist()
63
+ result: list[Float32] = self._model.encode(text).tolist()
64
+ return result
cocoindex/index.py CHANGED
@@ -1,23 +1,29 @@
1
1
  from enum import Enum
2
2
  from dataclasses import dataclass
3
3
  from typing import Sequence
4
+
5
+
4
6
  class VectorSimilarityMetric(Enum):
5
7
  COSINE_SIMILARITY = "CosineSimilarity"
6
8
  L2_DISTANCE = "L2Distance"
7
9
  INNER_PRODUCT = "InnerProduct"
8
10
 
11
+
9
12
  @dataclass
10
13
  class VectorIndexDef:
11
14
  """
12
15
  Define a vector index on a field.
13
16
  """
17
+
14
18
  field_name: str
15
19
  metric: VectorSimilarityMetric
16
20
 
21
+
17
22
  @dataclass
18
23
  class IndexOptions:
19
24
  """
20
25
  Options for an index.
21
26
  """
27
+
22
28
  primary_key_fields: Sequence[str]
23
29
  vector_indexes: Sequence[VectorIndexDef] = ()
cocoindex/lib.py CHANGED
@@ -1,14 +1,16 @@
1
1
  """
2
2
  Library level functions and states.
3
3
  """
4
+
4
5
  import warnings
5
6
  from typing import Callable, Any
6
7
 
7
- from . import _engine, flow, query, setting
8
+ from . import _engine # type: ignore
9
+ from . import flow, query, setting
8
10
  from .convert import dump_engine_object
9
11
 
10
12
 
11
- def init(settings: setting.Settings | None = None):
13
+ def init(settings: setting.Settings | None = None) -> None:
12
14
  """
13
15
  Initialize the cocoindex library.
14
16
 
@@ -19,20 +21,22 @@ def init(settings: setting.Settings | None = None):
19
21
  setting.set_app_namespace(settings.app_namespace)
20
22
 
21
23
 
22
- def start_server(settings: setting.ServerSettings):
24
+ def start_server(settings: setting.ServerSettings) -> None:
23
25
  """Start the cocoindex server."""
24
26
  flow.ensure_all_flows_built()
25
27
  query.ensure_all_handlers_built()
26
28
  _engine.start_server(settings.__dict__)
27
29
 
28
- def stop():
30
+
31
+ def stop() -> None:
29
32
  """Stop the cocoindex library."""
30
33
  _engine.stop()
31
34
 
35
+
32
36
  def main_fn(
33
- settings: Any | None = None,
34
- cocoindex_cmd: str | None = None,
35
- ) -> Callable[[Callable], Callable]:
37
+ settings: Any | None = None,
38
+ cocoindex_cmd: str | None = None,
39
+ ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
36
40
  """
37
41
  DEPRECATED: The @cocoindex.main_fn() decorator is obsolete and has no effect.
38
42
  It will be removed in a future version, which will cause an AttributeError.
@@ -63,9 +67,10 @@ def main_fn(
63
67
  "See cocoindex <command> --help for more details.\n"
64
68
  "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n\n",
65
69
  DeprecationWarning,
66
- stacklevel=2
70
+ stacklevel=2,
67
71
  )
68
72
 
69
- def _main_wrapper(fn: Callable) -> Callable:
73
+ def _main_wrapper(fn: Callable[..., Any]) -> Callable[..., Any]:
70
74
  return fn
75
+
71
76
  return _main_wrapper
cocoindex/llm.py CHANGED
@@ -1,16 +1,20 @@
1
1
  from dataclasses import dataclass
2
2
  from enum import Enum
3
3
 
4
+
4
5
  class LlmApiType(Enum):
5
6
  """The type of LLM API to use."""
7
+
6
8
  OPENAI = "OpenAi"
7
9
  OLLAMA = "Ollama"
8
10
  GEMINI = "Gemini"
9
11
  ANTHROPIC = "Anthropic"
10
12
 
13
+
11
14
  @dataclass
12
15
  class LlmSpec:
13
16
  """A specification for a LLM."""
17
+
14
18
  api_type: LlmApiType
15
19
  model: str
16
20
  address: str | None = None
cocoindex/op.py CHANGED
@@ -1,6 +1,7 @@
1
1
  """
2
2
  Facilities for defining cocoindex operations.
3
3
  """
4
+
4
5
  import asyncio
5
6
  import dataclasses
6
7
  import inspect
@@ -10,40 +11,58 @@ from enum import Enum
10
11
 
11
12
  from .typing import encode_enriched_type, resolve_forward_ref
12
13
  from .convert import encode_engine_value, make_engine_value_decoder
13
- from . import _engine
14
+ from . import _engine # type: ignore
15
+
14
16
 
15
17
  class OpCategory(Enum):
16
18
  """The category of the operation."""
19
+
17
20
  FUNCTION = "function"
18
21
  SOURCE = "source"
19
22
  STORAGE = "storage"
20
23
  DECLARATION = "declaration"
24
+
25
+
21
26
  @dataclass_transform()
22
27
  class SpecMeta(type):
23
28
  """Meta class for spec classes."""
24
- def __new__(mcs, name, bases, attrs, category: OpCategory | None = None):
29
+
30
+ def __new__(
31
+ mcs,
32
+ name: str,
33
+ bases: tuple[type, ...],
34
+ attrs: dict[str, Any],
35
+ category: OpCategory | None = None,
36
+ ) -> type:
25
37
  cls: type = super().__new__(mcs, name, bases, attrs)
26
38
  if category is not None:
27
39
  # It's the base class.
28
- setattr(cls, '_op_category', category)
40
+ setattr(cls, "_op_category", category)
29
41
  else:
30
42
  # It's the specific class providing specific fields.
31
43
  cls = dataclasses.dataclass(cls)
32
44
  return cls
33
45
 
34
- class SourceSpec(metaclass=SpecMeta, category=OpCategory.SOURCE): # pylint: disable=too-few-public-methods
46
+
47
+ class SourceSpec(metaclass=SpecMeta, category=OpCategory.SOURCE): # pylint: disable=too-few-public-methods
35
48
  """A source spec. All its subclass can be instantiated similar to a dataclass, i.e. ClassName(field1=value1, field2=value2, ...)"""
36
49
 
37
- class FunctionSpec(metaclass=SpecMeta, category=OpCategory.FUNCTION): # pylint: disable=too-few-public-methods
50
+
51
+ class FunctionSpec(metaclass=SpecMeta, category=OpCategory.FUNCTION): # pylint: disable=too-few-public-methods
38
52
  """A function spec. All its subclass can be instantiated similar to a dataclass, i.e. ClassName(field1=value1, field2=value2, ...)"""
39
53
 
40
- class StorageSpec(metaclass=SpecMeta, category=OpCategory.STORAGE): # pylint: disable=too-few-public-methods
54
+
55
+ class StorageSpec(metaclass=SpecMeta, category=OpCategory.STORAGE): # pylint: disable=too-few-public-methods
41
56
  """A storage spec. All its subclass can be instantiated similar to a dataclass, i.e. ClassName(field1=value1, field2=value2, ...)"""
42
57
 
43
- class DeclarationSpec(metaclass=SpecMeta, category=OpCategory.DECLARATION): # pylint: disable=too-few-public-methods
58
+
59
+ class DeclarationSpec(metaclass=SpecMeta, category=OpCategory.DECLARATION): # pylint: disable=too-few-public-methods
44
60
  """A declaration spec. All its subclass can be instantiated similar to a dataclass, i.e. ClassName(field1=value1, field2=value2, ...)"""
61
+
62
+
45
63
  class Executor(Protocol):
46
64
  """An executor for an operation."""
65
+
47
66
  op_category: OpCategory
48
67
 
49
68
 
@@ -55,7 +74,9 @@ class _FunctionExecutorFactory:
55
74
  self._spec_cls = spec_cls
56
75
  self._executor_cls = executor_cls
57
76
 
58
- def __call__(self, spec: dict[str, Any], *args, **kwargs):
77
+ def __call__(
78
+ self, spec: dict[str, Any], *args: Any, **kwargs: Any
79
+ ) -> tuple[dict[str, Any], Executor]:
59
80
  spec = self._spec_cls(**spec)
60
81
  executor = self._executor_cls(spec)
61
82
  result_type = executor.analyze(*args, **kwargs)
@@ -64,6 +85,7 @@ class _FunctionExecutorFactory:
64
85
 
65
86
  _gpu_dispatch_lock = asyncio.Lock()
66
87
 
88
+
67
89
  @dataclasses.dataclass
68
90
  class OpArgs:
69
91
  """
@@ -72,44 +94,50 @@ class OpArgs:
72
94
  - behavior_version: The behavior version of the executor. Cache will be invalidated if it
73
95
  changes. Must be provided if `cache` is True.
74
96
  """
97
+
75
98
  gpu: bool = False
76
99
  cache: bool = False
77
100
  behavior_version: int | None = None
78
101
 
79
- def _to_async_call(call: Callable) -> Callable[..., Awaitable[Any]]:
102
+
103
+ def _to_async_call(call: Callable[..., Any]) -> Callable[..., Awaitable[Any]]:
80
104
  if inspect.iscoroutinefunction(call):
81
105
  return call
82
106
  return lambda *args, **kwargs: asyncio.to_thread(lambda: call(*args, **kwargs))
83
107
 
108
+
84
109
  def _register_op_factory(
85
- category: OpCategory,
86
- expected_args: list[tuple[str, inspect.Parameter]],
87
- expected_return,
88
- executor_cls: type,
89
- spec_cls: type,
90
- op_args: OpArgs,
91
- ):
110
+ category: OpCategory,
111
+ expected_args: list[tuple[str, inspect.Parameter]],
112
+ expected_return: Any,
113
+ executor_cls: type,
114
+ spec_cls: type,
115
+ op_args: OpArgs,
116
+ ) -> type:
92
117
  """
93
118
  Register an op factory.
94
119
  """
120
+
95
121
  class _Fallback:
96
- def enable_cache(self):
122
+ def enable_cache(self) -> bool:
97
123
  return op_args.cache
98
124
 
99
- def behavior_version(self):
125
+ def behavior_version(self) -> int | None:
100
126
  return op_args.behavior_version
101
127
 
102
- class _WrappedClass(executor_cls, _Fallback):
128
+ class _WrappedClass(executor_cls, _Fallback): # type: ignore[misc]
103
129
  _args_decoders: list[Callable[[Any], Any]]
104
- _kwargs_decoders: dict[str, Callable[[str, Any], Any]]
105
- _acall: Callable
130
+ _kwargs_decoders: dict[str, Callable[[Any], Any]]
131
+ _acall: Callable[..., Awaitable[Any]]
106
132
 
107
- def __init__(self, spec):
133
+ def __init__(self, spec: Any) -> None:
108
134
  super().__init__()
109
135
  self.spec = spec
110
136
  self._acall = _to_async_call(super().__call__)
111
137
 
112
- def analyze(self, *args: _engine.OpArgSchema, **kwargs: _engine.OpArgSchema):
138
+ def analyze(
139
+ self, *args: _engine.OpArgSchema, **kwargs: _engine.OpArgSchema
140
+ ) -> Any:
113
141
  """
114
142
  Analyze the spec and arguments. In this phase, argument types should be validated.
115
143
  It should return the expected result type for the current op.
@@ -122,15 +150,21 @@ def _register_op_factory(
122
150
  for arg in args:
123
151
  if next_param_idx >= len(expected_args):
124
152
  raise ValueError(
125
- f"Too many arguments passed in: {len(args)} > {len(expected_args)}")
153
+ f"Too many arguments passed in: {len(args)} > {len(expected_args)}"
154
+ )
126
155
  arg_name, arg_param = expected_args[next_param_idx]
127
156
  if arg_param.kind in (
128
- inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.VAR_KEYWORD):
157
+ inspect.Parameter.KEYWORD_ONLY,
158
+ inspect.Parameter.VAR_KEYWORD,
159
+ ):
129
160
  raise ValueError(
130
- f"Too many positional arguments passed in: {len(args)} > {next_param_idx}")
161
+ f"Too many positional arguments passed in: {len(args)} > {next_param_idx}"
162
+ )
131
163
  self._args_decoders.append(
132
164
  make_engine_value_decoder(
133
- [arg_name], arg.value_type['type'], arg_param.annotation))
165
+ [arg_name], arg.value_type["type"], arg_param.annotation
166
+ )
167
+ )
134
168
  if arg_param.kind != inspect.Parameter.VAR_POSITIONAL:
135
169
  next_param_idx += 1
136
170
 
@@ -138,45 +172,72 @@ def _register_op_factory(
138
172
 
139
173
  for kwarg_name, kwarg in kwargs.items():
140
174
  expected_arg = next(
141
- (arg for arg in expected_kwargs
142
- if (arg[0] == kwarg_name and arg[1].kind in (
143
- inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD))
144
- or arg[1].kind == inspect.Parameter.VAR_KEYWORD),
145
- None)
175
+ (
176
+ arg
177
+ for arg in expected_kwargs
178
+ if (
179
+ arg[0] == kwarg_name
180
+ and arg[1].kind
181
+ in (
182
+ inspect.Parameter.KEYWORD_ONLY,
183
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
184
+ )
185
+ )
186
+ or arg[1].kind == inspect.Parameter.VAR_KEYWORD
187
+ ),
188
+ None,
189
+ )
146
190
  if expected_arg is None:
147
- raise ValueError(f"Unexpected keyword argument passed in: {kwarg_name}")
191
+ raise ValueError(
192
+ f"Unexpected keyword argument passed in: {kwarg_name}"
193
+ )
148
194
  arg_param = expected_arg[1]
149
195
  self._kwargs_decoders[kwarg_name] = make_engine_value_decoder(
150
- [kwarg_name], kwarg.value_type['type'], arg_param.annotation)
151
-
152
- missing_args = [name for (name, arg) in expected_kwargs
153
- if arg.default is inspect.Parameter.empty
154
- and (arg.kind == inspect.Parameter.POSITIONAL_ONLY or
155
- (arg.kind in (inspect.Parameter.KEYWORD_ONLY,
156
- inspect.Parameter.POSITIONAL_OR_KEYWORD)
157
- and name not in kwargs))]
196
+ [kwarg_name], kwarg.value_type["type"], arg_param.annotation
197
+ )
198
+
199
+ missing_args = [
200
+ name
201
+ for (name, arg) in expected_kwargs
202
+ if arg.default is inspect.Parameter.empty
203
+ and (
204
+ arg.kind == inspect.Parameter.POSITIONAL_ONLY
205
+ or (
206
+ arg.kind
207
+ in (
208
+ inspect.Parameter.KEYWORD_ONLY,
209
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
210
+ )
211
+ and name not in kwargs
212
+ )
213
+ )
214
+ ]
158
215
  if len(missing_args) > 0:
159
216
  raise ValueError(f"Missing arguments: {', '.join(missing_args)}")
160
217
 
161
- prepare_method = getattr(executor_cls, 'analyze', None)
218
+ prepare_method = getattr(executor_cls, "analyze", None)
162
219
  if prepare_method is not None:
163
220
  return prepare_method(self, *args, **kwargs)
164
221
  else:
165
222
  return expected_return
166
223
 
167
- async def prepare(self):
224
+ async def prepare(self) -> None:
168
225
  """
169
226
  Prepare for execution.
170
227
  It's executed after `analyze` and before any `__call__` execution.
171
228
  """
172
- setup_method = getattr(super(), 'prepare', None)
229
+ setup_method = getattr(super(), "prepare", None)
173
230
  if setup_method is not None:
174
231
  await _to_async_call(setup_method)()
175
232
 
176
- async def __call__(self, *args, **kwargs):
177
- decoded_args = (decoder(arg) for decoder, arg in zip(self._args_decoders, args))
178
- decoded_kwargs = {arg_name: self._kwargs_decoders[arg_name](arg)
179
- for arg_name, arg in kwargs.items()}
233
+ async def __call__(self, *args: Any, **kwargs: Any) -> Any:
234
+ decoded_args = (
235
+ decoder(arg) for decoder, arg in zip(self._args_decoders, args)
236
+ )
237
+ decoded_kwargs = {
238
+ arg_name: self._kwargs_decoders[arg_name](arg)
239
+ for arg_name, arg in kwargs.items()
240
+ }
180
241
 
181
242
  if op_args.gpu:
182
243
  # For GPU executions, data-level parallelism is applied, so we don't want to
@@ -198,13 +259,15 @@ def _register_op_factory(
198
259
 
199
260
  if category == OpCategory.FUNCTION:
200
261
  _engine.register_function_factory(
201
- spec_cls.__name__, _FunctionExecutorFactory(spec_cls, _WrappedClass))
262
+ spec_cls.__name__, _FunctionExecutorFactory(spec_cls, _WrappedClass)
263
+ )
202
264
  else:
203
265
  raise ValueError(f"Unsupported executor type {category}")
204
266
 
205
267
  return _WrappedClass
206
268
 
207
- def executor_class(**args) -> Callable[[type], type]:
269
+
270
+ def executor_class(**args: Any) -> Callable[[type], type]:
208
271
  """
209
272
  Decorate a class to provide an executor for an op.
210
273
  """
@@ -216,9 +279,9 @@ def executor_class(**args) -> Callable[[type], type]:
216
279
  """
217
280
  # Use `__annotations__` instead of `get_type_hints`, to avoid resolving forward references.
218
281
  type_hints = cls.__annotations__
219
- if 'spec' not in type_hints:
282
+ if "spec" not in type_hints:
220
283
  raise TypeError("Expect a `spec` field with type hint")
221
- spec_cls = resolve_forward_ref(type_hints['spec'])
284
+ spec_cls = resolve_forward_ref(type_hints["spec"])
222
285
  sig = inspect.signature(cls.__call__)
223
286
  return _register_op_factory(
224
287
  category=spec_cls._op_category,
@@ -226,34 +289,35 @@ def executor_class(**args) -> Callable[[type], type]:
226
289
  expected_return=sig.return_annotation,
227
290
  executor_cls=cls,
228
291
  spec_cls=spec_cls,
229
- op_args=op_args)
292
+ op_args=op_args,
293
+ )
230
294
 
231
295
  return _inner
232
296
 
233
- def function(**args) -> Callable[[Callable], FunctionSpec]:
297
+
298
+ def function(**args: Any) -> Callable[[Callable[..., Any]], FunctionSpec]:
234
299
  """
235
300
  Decorate a function to provide a function for an op.
236
301
  """
237
302
  op_args = OpArgs(**args)
238
303
 
239
- def _inner(fn: Callable) -> FunctionSpec:
240
-
304
+ def _inner(fn: Callable[..., Any]) -> FunctionSpec:
241
305
  # Convert snake case to camel case.
242
- op_name = ''.join(word.capitalize() for word in fn.__name__.split('_'))
306
+ op_name = "".join(word.capitalize() for word in fn.__name__.split("_"))
243
307
  sig = inspect.signature(fn)
244
308
 
245
309
  class _Executor:
246
- def __call__(self, *args, **kwargs):
310
+ def __call__(self, *args: Any, **kwargs: Any) -> Any:
247
311
  return fn(*args, **kwargs)
248
312
 
249
313
  class _Spec(FunctionSpec):
250
- pass
314
+ def __call__(self, *args: Any, **kwargs: Any) -> Any:
315
+ return fn(*args, **kwargs)
251
316
 
252
317
  _Spec.__name__ = op_name
253
318
  _Spec.__doc__ = fn.__doc__
254
319
  _Spec.__module__ = fn.__module__
255
320
  _Spec.__qualname__ = fn.__qualname__
256
- _Spec.__wrapped__ = fn
257
321
 
258
322
  _register_op_factory(
259
323
  category=OpCategory.FUNCTION,
@@ -261,7 +325,8 @@ def function(**args) -> Callable[[Callable], FunctionSpec]:
261
325
  expected_return=sig.return_annotation,
262
326
  executor_cls=_Executor,
263
327
  spec_cls=_Spec,
264
- op_args=op_args)
328
+ op_args=op_args,
329
+ )
265
330
 
266
331
  return _Spec()
267
332
 
cocoindex/query.py CHANGED
@@ -4,25 +4,29 @@ from threading import Lock
4
4
 
5
5
  from . import flow as fl
6
6
  from . import index
7
- from . import _engine
7
+ from . import _engine # type: ignore
8
8
 
9
9
  _handlers_lock = Lock()
10
10
  _handlers: dict[str, _engine.SimpleSemanticsQueryHandler] = {}
11
11
 
12
+
12
13
  @dataclass
13
14
  class SimpleSemanticsQueryInfo:
14
15
  """
15
16
  Additional information about the query.
16
17
  """
18
+
17
19
  similarity_metric: index.VectorSimilarityMetric
18
20
  query_vector: list[float]
19
21
  vector_field_name: str
20
22
 
23
+
21
24
  @dataclass
22
25
  class QueryResult:
23
26
  """
24
27
  A single result from the query.
25
28
  """
29
+
26
30
  data: dict[str, Any]
27
31
  score: float
28
32
 
@@ -31,6 +35,7 @@ class SimpleSemanticsQueryHandler:
31
35
  """
32
36
  A query handler that uses simple semantics to query the index.
33
37
  """
38
+
34
39
  _lazy_query_handler: Callable[[], _engine.SimpleSemanticsQueryHandler]
35
40
 
36
41
  def __init__(
@@ -38,22 +43,28 @@ class SimpleSemanticsQueryHandler:
38
43
  name: str,
39
44
  flow: fl.Flow,
40
45
  target_name: str,
41
- query_transform_flow: Callable[..., fl.DataSlice],
42
- default_similarity_metric: index.VectorSimilarityMetric = index.VectorSimilarityMetric.COSINE_SIMILARITY) -> None:
43
-
46
+ query_transform_flow: Callable[..., fl.DataSlice[Any]],
47
+ default_similarity_metric: index.VectorSimilarityMetric = index.VectorSimilarityMetric.COSINE_SIMILARITY,
48
+ ) -> None:
44
49
  engine_handler = None
45
50
  lock = Lock()
51
+
46
52
  def _lazy_handler() -> _engine.SimpleSemanticsQueryHandler:
47
53
  nonlocal engine_handler, lock
48
54
  if engine_handler is None:
49
55
  with lock:
50
56
  if engine_handler is None:
51
57
  engine_handler = _engine.SimpleSemanticsQueryHandler(
52
- flow.internal_flow(), target_name,
53
- fl.TransformFlow(query_transform_flow, [str]).internal_flow(),
54
- default_similarity_metric.value)
58
+ flow.internal_flow(),
59
+ target_name,
60
+ fl.TransformFlow(
61
+ query_transform_flow, [str]
62
+ ).internal_flow(),
63
+ default_similarity_metric.value,
64
+ )
55
65
  engine_handler.register_query_handler(name)
56
66
  return engine_handler
67
+
57
68
  self._lazy_query_handler = _lazy_handler
58
69
 
59
70
  with _handlers_lock:
@@ -65,24 +76,36 @@ class SimpleSemanticsQueryHandler:
65
76
  """
66
77
  return self._lazy_query_handler()
67
78
 
68
- def search(self, query: str, limit: int, vector_field_name: str | None = None,
69
- similarity_metric: index.VectorSimilarityMetric | None = None
70
- ) -> tuple[list[QueryResult], SimpleSemanticsQueryInfo]:
79
+ def search(
80
+ self,
81
+ query: str,
82
+ limit: int,
83
+ vector_field_name: str | None = None,
84
+ similarity_metric: index.VectorSimilarityMetric | None = None,
85
+ ) -> tuple[list[QueryResult], SimpleSemanticsQueryInfo]:
71
86
  """
72
87
  Search the index with the given query, limit, vector field name, and similarity metric.
73
88
  """
74
89
  internal_results, internal_info = self.internal_handler().search(
75
- query, limit, vector_field_name,
76
- similarity_metric.value if similarity_metric is not None else None)
77
- results = [QueryResult(data=result['data'], score=result['score'])
78
- for result in internal_results]
90
+ query,
91
+ limit,
92
+ vector_field_name,
93
+ similarity_metric.value if similarity_metric is not None else None,
94
+ )
95
+ results = [
96
+ QueryResult(data=result["data"], score=result["score"])
97
+ for result in internal_results
98
+ ]
79
99
  info = SimpleSemanticsQueryInfo(
80
- similarity_metric=index.VectorSimilarityMetric(internal_info['similarity_metric']),
81
- query_vector=internal_info['query_vector'],
82
- vector_field_name=internal_info['vector_field_name']
100
+ similarity_metric=index.VectorSimilarityMetric(
101
+ internal_info["similarity_metric"]
102
+ ),
103
+ query_vector=internal_info["query_vector"],
104
+ vector_field_name=internal_info["vector_field_name"],
83
105
  )
84
106
  return results, info
85
107
 
108
+
86
109
  def ensure_all_handlers_built() -> None:
87
110
  """
88
111
  Ensure all handlers are built.