cocoindex 0.1.25__cp311-cp311-win_amd64.whl → 0.1.27__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cocoindex/__init__.py CHANGED
@@ -9,4 +9,5 @@ from .llm import LlmSpec, LlmApiType
9
9
  from .index import VectorSimilarityMetric, VectorIndexDef, IndexOptions
10
10
  from .auth_registry import AuthEntryReference, add_auth_entry, ref_auth_entry
11
11
  from .lib import *
12
- from ._engine import OpArgSchema
12
+ from ._engine import OpArgSchema
13
+ from .typing import Float32, Float64, LocalDateTime, OffsetDateTime, Range, Vector, Json
Binary file
cocoindex/cli.py CHANGED
@@ -4,6 +4,7 @@ import datetime
4
4
 
5
5
  from . import flow, lib
6
6
  from .setup import sync_setup, drop_setup, flow_names_with_setup, apply_setup_changes
7
+ from .runtime import execution_context
7
8
 
8
9
  @click.group()
9
10
  def cli():
@@ -113,11 +114,13 @@ def update(flow_name: str | None, live: bool, quiet: bool):
113
114
  Update the index to reflect the latest data from data sources.
114
115
  """
115
116
  options = flow.FlowLiveUpdaterOptions(live_mode=live, print_stats=not quiet)
116
- if flow_name is None:
117
- asyncio.run(flow.update_all_flows(options))
118
- else:
119
- updater = flow.FlowLiveUpdater(_flow_by_name(flow_name), options)
120
- asyncio.run(updater.wait())
117
+ async def _update():
118
+ if flow_name is None:
119
+ await flow.update_all_flows(options)
120
+ else:
121
+ updater = await flow.FlowLiveUpdater.create(_flow_by_name(flow_name), options)
122
+ await updater.wait()
123
+ execution_context.run(_update())
121
124
 
122
125
  @cli.command()
123
126
  @click.argument("flow_name", type=str, required=False)
@@ -167,7 +170,7 @@ def server(address: str, live_update: bool, quiet: bool, cors_origin: str | None
167
170
  lib.start_server(lib.ServerSettings(address=address, cors_origin=cors_origin))
168
171
  if live_update:
169
172
  options = flow.FlowLiveUpdaterOptions(live_mode=True, print_stats=not quiet)
170
- asyncio.run(flow.update_all_flows(options))
173
+ execution_context.run(flow.update_all_flows(options))
171
174
  input("Press Enter to stop...")
172
175
 
173
176
 
cocoindex/convert.py CHANGED
@@ -8,25 +8,28 @@ import uuid
8
8
 
9
9
  from enum import Enum
10
10
  from typing import Any, Callable, get_origin
11
- from .typing import analyze_type_info, encode_enriched_type, COLLECTION_TYPES
11
+ from .typing import analyze_type_info, encode_enriched_type, TABLE_TYPES, KEY_FIELD_NAME
12
12
 
13
- def to_engine_value(value: Any) -> Any:
14
- """Convert a Python value to an engine value."""
13
+
14
+ def encode_engine_value(value: Any) -> Any:
15
+ """Encode a Python value to an engine value."""
15
16
  if dataclasses.is_dataclass(value):
16
- return [to_engine_value(getattr(value, f.name)) for f in dataclasses.fields(value)]
17
+ return [encode_engine_value(getattr(value, f.name)) for f in dataclasses.fields(value)]
17
18
  if isinstance(value, (list, tuple)):
18
- return [to_engine_value(v) for v in value]
19
+ return [encode_engine_value(v) for v in value]
20
+ if isinstance(value, dict):
21
+ return [[encode_engine_value(k)] + encode_engine_value(v) for k, v in value.items()]
19
22
  if isinstance(value, uuid.UUID):
20
23
  return value.bytes
21
24
  return value
22
25
 
23
- def make_engine_value_converter(
26
+ def make_engine_value_decoder(
24
27
  field_path: list[str],
25
28
  src_type: dict[str, Any],
26
29
  dst_annotation,
27
30
  ) -> Callable[[Any], Any]:
28
31
  """
29
- Make a converter from an engine value to a Python value.
32
+ Make a decoder from an engine value to a Python value.
30
33
 
31
34
  Args:
32
35
  field_path: The path to the field in the engine value. For error messages.
@@ -34,13 +37,13 @@ def make_engine_value_converter(
34
37
  dst_annotation: The type annotation of the Python value.
35
38
 
36
39
  Returns:
37
- A converter from an engine value to a Python value.
40
+ A decoder from an engine value to a Python value.
38
41
  """
39
42
 
40
43
  src_type_kind = src_type['kind']
41
44
 
42
45
  if dst_annotation is inspect.Parameter.empty:
43
- if src_type_kind == 'Struct' or src_type_kind in COLLECTION_TYPES:
46
+ if src_type_kind == 'Struct' or src_type_kind in TABLE_TYPES:
44
47
  raise ValueError(f"Missing type annotation for `{''.join(field_path)}`."
45
48
  f"It's required for {src_type_kind} type.")
46
49
  return lambda value: value
@@ -53,41 +56,59 @@ def make_engine_value_converter(
53
56
  f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})")
54
57
 
55
58
  if dst_type_info.dataclass_type is not None:
56
- return _make_engine_struct_value_converter(
59
+ return _make_engine_struct_value_decoder(
57
60
  field_path, src_type['fields'], dst_type_info.dataclass_type)
58
61
 
59
- if src_type_kind in COLLECTION_TYPES:
62
+ if src_type_kind in TABLE_TYPES:
60
63
  field_path.append('[*]')
61
64
  elem_type_info = analyze_type_info(dst_type_info.elem_type)
62
65
  if elem_type_info.dataclass_type is None:
63
66
  raise ValueError(f"Type mismatch for `{''.join(field_path)}`: "
64
- f"declared `{dst_type_info.kind}`, a dataclass type expected")
65
- elem_converter = _make_engine_struct_value_converter(
66
- field_path, src_type['row']['fields'], elem_type_info.dataclass_type)
67
+ f"declared `{dst_type_info.kind}`, a dataclass type expected")
68
+ engine_fields_schema = src_type['row']['fields']
69
+ if elem_type_info.key_type is not None:
70
+ key_field_schema = engine_fields_schema[0]
71
+ field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}")
72
+ key_decoder = make_engine_value_decoder(
73
+ field_path, key_field_schema['type'], elem_type_info.key_type)
74
+ field_path.pop()
75
+ value_decoder = _make_engine_struct_value_decoder(
76
+ field_path, engine_fields_schema[1:], elem_type_info.dataclass_type)
77
+ def decode(value):
78
+ if value is None:
79
+ return None
80
+ return {key_decoder(v[0]): value_decoder(v[1:]) for v in value}
81
+ else:
82
+ elem_decoder = _make_engine_struct_value_decoder(
83
+ field_path, engine_fields_schema, elem_type_info.dataclass_type)
84
+ def decode(value):
85
+ if value is None:
86
+ return None
87
+ return [elem_decoder(v) for v in value]
67
88
  field_path.pop()
68
- return lambda value: [elem_converter(v) for v in value] if value is not None else None
89
+ return decode
69
90
 
70
91
  if src_type_kind == 'Uuid':
71
92
  return lambda value: uuid.UUID(bytes=value)
72
93
 
73
94
  return lambda value: value
74
95
 
75
- def _make_engine_struct_value_converter(
96
+ def _make_engine_struct_value_decoder(
76
97
  field_path: list[str],
77
98
  src_fields: list[dict[str, Any]],
78
99
  dst_dataclass_type: type,
79
100
  ) -> Callable[[list], Any]:
80
- """Make a converter from an engine field values to a Python value."""
101
+ """Make a decoder from an engine field values to a Python value."""
81
102
 
82
103
  src_name_to_idx = {f['name']: i for i, f in enumerate(src_fields)}
83
104
  def make_closure_for_value(name: str, param: inspect.Parameter) -> Callable[[list], Any]:
84
105
  src_idx = src_name_to_idx.get(name)
85
106
  if src_idx is not None:
86
107
  field_path.append(f'.{name}')
87
- field_converter = make_engine_value_converter(
108
+ field_decoder = make_engine_value_decoder(
88
109
  field_path, src_fields[src_idx]['type'], param.annotation)
89
110
  field_path.pop()
90
- return lambda values: field_converter(values[src_idx])
111
+ return lambda values: field_decoder(values[src_idx])
91
112
 
92
113
  default_value = param.default
93
114
  if default_value is inspect.Parameter.empty:
@@ -96,12 +117,12 @@ def _make_engine_struct_value_converter(
96
117
 
97
118
  return lambda _: default_value
98
119
 
99
- field_value_converters = [
120
+ field_value_decoder = [
100
121
  make_closure_for_value(name, param)
101
122
  for (name, param) in inspect.signature(dst_dataclass_type).parameters.items()]
102
123
 
103
124
  return lambda values: dst_dataclass_type(
104
- *(converter(values) for converter in field_value_converters))
125
+ *(decoder(values) for decoder in field_value_decoder))
105
126
 
106
127
  def dump_engine_object(v: Any) -> Any:
107
128
  """Recursively dump an object for engine. Engine side uses `Pythonized` to catch."""
cocoindex/flow.py CHANGED
@@ -19,7 +19,7 @@ from . import index
19
19
  from . import op
20
20
  from .convert import dump_engine_object
21
21
  from .typing import encode_enriched_type
22
- from .runtime import op_execution_context
22
+ from .runtime import execution_context
23
23
 
24
24
  class _NameBuilder:
25
25
  _existing_names: set[str]
@@ -142,9 +142,9 @@ class DataSlice:
142
142
 
143
143
  def row(self) -> DataScope:
144
144
  """
145
- Return a scope representing each entry of the collection.
145
+ Return a scope representing each row of the table.
146
146
  """
147
- row_scope = self._state.engine_data_slice.collection_entry_scope()
147
+ row_scope = self._state.engine_data_slice.table_row_scope()
148
148
  return DataScope(self._state.flow_builder_state, row_scope)
149
149
 
150
150
  def for_each(self, f: Callable[[DataScope], None]) -> None:
@@ -355,6 +355,12 @@ class FlowBuilder:
355
355
  name
356
356
  )
357
357
 
358
+ def declare(self, spec: op.DeclarationSpec):
359
+ """
360
+ Add a declaration to the flow.
361
+ """
362
+ self._state.engine_flow_builder.declare(dump_engine_object(spec))
363
+
358
364
  @dataclass
359
365
  class FlowLiveUpdaterOptions:
360
366
  """
@@ -369,16 +375,29 @@ class FlowLiveUpdater:
369
375
  """
370
376
  _engine_live_updater: _engine.FlowLiveUpdater
371
377
 
372
- def __init__(self, fl: Flow, options: FlowLiveUpdaterOptions | None = None):
373
- self._engine_live_updater = _engine.FlowLiveUpdater(
374
- fl._lazy_engine_flow(), dump_engine_object(options or FlowLiveUpdaterOptions()))
378
+ def __init__(self, arg: Flow | _engine.FlowLiveUpdater, options: FlowLiveUpdaterOptions | None = None):
379
+ if isinstance(arg, _engine.FlowLiveUpdater):
380
+ self._engine_live_updater = arg
381
+ else:
382
+ self._engine_live_updater = execution_context.run(_engine.FlowLiveUpdater(
383
+ arg.internal_flow(), dump_engine_object(options or FlowLiveUpdaterOptions())))
384
+
385
+ @staticmethod
386
+ async def create(fl: Flow, options: FlowLiveUpdaterOptions | None = None) -> FlowLiveUpdater:
387
+ """
388
+ Create a live updater for a flow.
389
+ """
390
+ engine_live_updater = await _engine.FlowLiveUpdater.create(
391
+ await fl.ainternal_flow(),
392
+ dump_engine_object(options or FlowLiveUpdaterOptions()))
393
+ return FlowLiveUpdater(engine_live_updater)
375
394
 
376
395
  def __enter__(self) -> FlowLiveUpdater:
377
396
  return self
378
397
 
379
398
  def __exit__(self, exc_type, exc_value, traceback):
380
399
  self.abort()
381
- asyncio.run(self.wait())
400
+ execution_context.run(self.wait())
382
401
 
383
402
  async def __aenter__(self) -> FlowLiveUpdater:
384
403
  return self
@@ -450,7 +469,7 @@ class Flow:
450
469
  Update the index defined by the flow.
451
470
  Once the function returns, the indice is fresh up to the moment when the function is called.
452
471
  """
453
- updater = FlowLiveUpdater(self, FlowLiveUpdaterOptions(live_mode=False))
472
+ updater = await FlowLiveUpdater.create(self, FlowLiveUpdaterOptions(live_mode=False))
454
473
  await updater.wait()
455
474
  return updater.update_stats()
456
475
 
@@ -466,6 +485,12 @@ class Flow:
466
485
  """
467
486
  return self._lazy_engine_flow()
468
487
 
488
+ async def ainternal_flow(self) -> _engine.Flow:
489
+ """
490
+ Get the engine flow. The async version.
491
+ """
492
+ return await asyncio.to_thread(self.internal_flow)
493
+
469
494
  def _create_lazy_flow(name: str | None, fl_def: Callable[[FlowBuilder, DataScope], None]) -> Flow:
470
495
  """
471
496
  Create a flow without really building it yet.
@@ -476,7 +501,7 @@ def _create_lazy_flow(name: str | None, fl_def: Callable[[FlowBuilder, DataScope
476
501
  root_scope = DataScope(
477
502
  flow_builder_state, flow_builder_state.engine_flow_builder.root_scope())
478
503
  fl_def(FlowBuilder(flow_builder_state), root_scope)
479
- return flow_builder_state.engine_flow_builder.build_flow(op_execution_context.event_loop)
504
+ return flow_builder_state.engine_flow_builder.build_flow(execution_context.event_loop)
480
505
 
481
506
  return Flow(_create_engine_flow)
482
507
 
@@ -523,17 +548,23 @@ def ensure_all_flows_built() -> None:
523
548
  """
524
549
  Ensure all flows are built.
525
550
  """
526
- with _flows_lock:
527
- for fl in _flows.values():
528
- fl.internal_flow()
551
+ for fl in flows():
552
+ fl.internal_flow()
553
+
554
+ async def aensure_all_flows_built() -> None:
555
+ """
556
+ Ensure all flows are built.
557
+ """
558
+ for fl in flows():
559
+ await fl.ainternal_flow()
529
560
 
530
561
  async def update_all_flows(options: FlowLiveUpdaterOptions) -> dict[str, _engine.IndexUpdateInfo]:
531
562
  """
532
563
  Update all flows.
533
564
  """
534
- ensure_all_flows_built()
565
+ await aensure_all_flows_built()
535
566
  async def _update_flow(fl: Flow) -> _engine.IndexUpdateInfo:
536
- updater = FlowLiveUpdater(fl, options)
567
+ updater = await FlowLiveUpdater.create(fl, options)
537
568
  await updater.wait()
538
569
  return updater.update_stats()
539
570
  fls = flows()
@@ -572,7 +603,7 @@ class TransientFlow:
572
603
  flow_builder_state.engine_flow_builder.set_direct_output(
573
604
  _data_slice_state(output).engine_data_slice)
574
605
  self._engine_flow = flow_builder_state.engine_flow_builder.build_transient_flow(
575
- op_execution_context.event_loop)
606
+ execution_context.event_loop)
576
607
 
577
608
  def __str__(self):
578
609
  return str(self._engine_flow)
cocoindex/functions.py CHANGED
@@ -41,7 +41,7 @@ class SentenceTransformerEmbedExecutor:
41
41
  args = self.spec.args or {}
42
42
  self._model = sentence_transformers.SentenceTransformer(self.spec.model, **args)
43
43
  dim = self._model.get_sentence_embedding_dimension()
44
- return Annotated[list[Float32], Vector(dim=dim), TypeAttr("cocoindex.io/vector_origin_text", text.analyzed_value)]
44
+ return Annotated[Vector[Float32, dim], TypeAttr("cocoindex.io/vector_origin_text", text.analyzed_value)]
45
45
 
46
46
  def __call__(self, text: str) -> list[Float32]:
47
47
  return self._model.encode(text).tolist()
cocoindex/lib.py CHANGED
@@ -1,7 +1,6 @@
1
1
  """
2
2
  Library level functions and states.
3
3
  """
4
- import asyncio
5
4
  import os
6
5
  import sys
7
6
  import functools
@@ -12,6 +11,7 @@ from dataclasses import dataclass
12
11
 
13
12
  from . import _engine
14
13
  from . import flow, query, cli
14
+ from .convert import dump_engine_object
15
15
 
16
16
 
17
17
  def _load_field(target: dict[str, str], name: str, env_name: str, required: bool = False):
@@ -22,24 +22,32 @@ def _load_field(target: dict[str, str], name: str, env_name: str, required: bool
22
22
  else:
23
23
  target[name] = value
24
24
 
25
+ @dataclass
26
+ class DatabaseConnectionSpec:
27
+ url: str
28
+ user: str | None = None
29
+ password: str | None = None
30
+
25
31
  @dataclass
26
32
  class Settings:
27
33
  """Settings for the cocoindex library."""
28
- database_url: str
34
+ database: DatabaseConnectionSpec
29
35
 
30
36
  @classmethod
31
37
  def from_env(cls) -> Self:
32
38
  """Load settings from environment variables."""
33
39
 
34
- kwargs: dict[str, str] = dict()
35
- _load_field(kwargs, "database_url", "COCOINDEX_DATABASE_URL", required=True)
36
-
37
- return cls(**kwargs)
40
+ db_kwargs: dict[str, str] = dict()
41
+ _load_field(db_kwargs, "url", "COCOINDEX_DATABASE_URL", required=True)
42
+ _load_field(db_kwargs, "user", "COCOINDEX_DATABASE_USER")
43
+ _load_field(db_kwargs, "password", "COCOINDEX_DATABASE_PASSWORD")
44
+ database = DatabaseConnectionSpec(**db_kwargs)
45
+ return cls(database=database)
38
46
 
39
47
 
40
48
  def init(settings: Settings):
41
49
  """Initialize the cocoindex library."""
42
- _engine.init(settings.__dict__)
50
+ _engine.init(dump_engine_object(settings))
43
51
 
44
52
  @dataclass
45
53
  class ServerSettings:
@@ -101,7 +109,8 @@ def main_fn(
101
109
  try:
102
110
  if _should_run_cli():
103
111
  # Schedule to a separate thread as it invokes nested event loop.
104
- return await asyncio.to_thread(_run_cli)
112
+ # return await asyncio.to_thread(_run_cli)
113
+ return _run_cli()
105
114
  return await fn(*args, **kwargs)
106
115
  finally:
107
116
  stop()
cocoindex/llm.py CHANGED
@@ -5,6 +5,8 @@ class LlmApiType(Enum):
5
5
  """The type of LLM API to use."""
6
6
  OPENAI = "OpenAi"
7
7
  OLLAMA = "Ollama"
8
+ GEMINI = "Gemini"
9
+ ANTHROPIC = "Anthropic"
8
10
 
9
11
  @dataclass
10
12
  class LlmSpec:
cocoindex/op.py CHANGED
@@ -7,10 +7,9 @@ import inspect
7
7
 
8
8
  from typing import get_type_hints, Protocol, Any, Callable, Awaitable, dataclass_transform
9
9
  from enum import Enum
10
- from functools import partial
11
10
 
12
11
  from .typing import encode_enriched_type
13
- from .convert import to_engine_value, make_engine_value_converter
12
+ from .convert import encode_engine_value, make_engine_value_decoder
14
13
  from . import _engine
15
14
 
16
15
  class OpCategory(Enum):
@@ -18,7 +17,7 @@ class OpCategory(Enum):
18
17
  FUNCTION = "function"
19
18
  SOURCE = "source"
20
19
  STORAGE = "storage"
21
-
20
+ DECLARATION = "declaration"
22
21
  @dataclass_transform()
23
22
  class SpecMeta(type):
24
23
  """Meta class for spec classes."""
@@ -41,6 +40,8 @@ class FunctionSpec(metaclass=SpecMeta, category=OpCategory.FUNCTION): # pylint:
41
40
  class StorageSpec(metaclass=SpecMeta, category=OpCategory.STORAGE): # pylint: disable=too-few-public-methods
42
41
  """A storage spec. All its subclass can be instantiated similar to a dataclass, i.e. ClassName(field1=value1, field2=value2, ...)"""
43
42
 
43
+ class DeclarationSpec(metaclass=SpecMeta, category=OpCategory.DECLARATION): # pylint: disable=too-few-public-methods
44
+ """A declaration spec. All its subclass can be instantiated similar to a dataclass, i.e. ClassName(field1=value1, field2=value2, ...)"""
44
45
  class Executor(Protocol):
45
46
  """An executor for an operation."""
46
47
  op_category: OpCategory
@@ -128,7 +129,7 @@ def _register_op_factory(
128
129
  raise ValueError(
129
130
  f"Too many positional arguments passed in: {len(args)} > {next_param_idx}")
130
131
  self._args_converters.append(
131
- make_engine_value_converter(
132
+ make_engine_value_decoder(
132
133
  [arg_name], arg.value_type['type'], arg_param.annotation))
133
134
  if arg_param.kind != inspect.Parameter.VAR_POSITIONAL:
134
135
  next_param_idx += 1
@@ -145,7 +146,7 @@ def _register_op_factory(
145
146
  if expected_arg is None:
146
147
  raise ValueError(f"Unexpected keyword argument passed in: {kwarg_name}")
147
148
  arg_param = expected_arg[1]
148
- self._kwargs_converters[kwarg_name] = make_engine_value_converter(
149
+ self._kwargs_converters[kwarg_name] = make_engine_value_decoder(
149
150
  [kwarg_name], kwarg.value_type['type'], arg_param.annotation)
150
151
 
151
152
  missing_args = [name for (name, arg) in expected_kwargs
@@ -187,7 +188,7 @@ def _register_op_factory(
187
188
  output = await self._acall(*converted_args, **converted_kwargs)
188
189
  else:
189
190
  output = await self._acall(*converted_args, **converted_kwargs)
190
- return to_engine_value(output)
191
+ return encode_engine_value(output)
191
192
 
192
193
  _WrappedClass.__name__ = executor_cls.__name__
193
194
  _WrappedClass.__doc__ = executor_cls.__doc__
cocoindex/query.py CHANGED
@@ -66,15 +66,16 @@ class SimpleSemanticsQueryHandler:
66
66
  return self._lazy_query_handler()
67
67
 
68
68
  def search(self, query: str, limit: int, vector_field_name: str | None = None,
69
- similarity_matric: index.VectorSimilarityMetric | None = None) -> tuple[list[QueryResult], SimpleSemanticsQueryInfo]:
69
+ similarity_metric: index.VectorSimilarityMetric | None = None
70
+ ) -> tuple[list[QueryResult], SimpleSemanticsQueryInfo]:
70
71
  """
71
72
  Search the index with the given query, limit, vector field name, and similarity metric.
72
73
  """
73
74
  internal_results, internal_info = self.internal_handler().search(
74
75
  query, limit, vector_field_name,
75
- similarity_matric.value if similarity_matric is not None else None)
76
- fields = [field['name'] for field in internal_results['fields']]
77
- results = [QueryResult(data=dict(zip(fields, result['data'])), score=result['score']) for result in internal_results['results']]
76
+ similarity_metric.value if similarity_metric is not None else None)
77
+ results = [QueryResult(data=result['data'], score=result['score'])
78
+ for result in internal_results]
78
79
  info = SimpleSemanticsQueryInfo(
79
80
  similarity_metric=index.VectorSimilarityMetric(internal_info['similarity_metric']),
80
81
  query_vector=internal_info['query_vector'],
cocoindex/runtime.py CHANGED
@@ -1,7 +1,12 @@
1
+ """
2
+ This module provides a standalone execution runtime for executing coroutines in a thread-safe
3
+ manner.
4
+ """
5
+
1
6
  import threading
2
7
  import asyncio
3
-
4
- class _OpExecutionContext:
8
+ from typing import Coroutine
9
+ class _ExecutionContext:
5
10
  _lock: threading.Lock
6
11
  _event_loop: asyncio.AbstractEventLoop | None = None
7
12
 
@@ -14,8 +19,11 @@ class _OpExecutionContext:
14
19
  with self._lock:
15
20
  if self._event_loop is None:
16
21
  self._event_loop = asyncio.new_event_loop()
17
- asyncio.set_event_loop(self._event_loop)
18
22
  threading.Thread(target=self._event_loop.run_forever, daemon=True).start()
19
23
  return self._event_loop
20
24
 
21
- op_execution_context = _OpExecutionContext()
25
+ def run(self, coro: Coroutine):
26
+ """Run a coroutine in the event loop, blocking until it finishes. Return its result."""
27
+ return asyncio.run_coroutine_threadsafe(coro, self.event_loop).result()
28
+
29
+ execution_context = _ExecutionContext()
cocoindex/storages.py CHANGED
@@ -9,7 +9,7 @@ from .auth_registry import AuthEntryReference
9
9
  class Postgres(op.StorageSpec):
10
10
  """Storage powered by Postgres and pgvector."""
11
11
 
12
- database_url: str | None = None
12
+ database: AuthEntryReference | None = None
13
13
  table_name: str | None = None
14
14
 
15
15
  @dataclass
@@ -43,8 +43,9 @@ class NodeReferenceMapping:
43
43
  fields: list[TargetFieldMapping]
44
44
 
45
45
  @dataclass
46
- class NodeStorageSpec:
46
+ class ReferencedNode:
47
47
  """Storage spec for a graph node."""
48
+ label: str
48
49
  primary_key_fields: Sequence[str]
49
50
  vector_indexes: Sequence[index.VectorIndexDef] = ()
50
51
 
@@ -63,10 +64,16 @@ class RelationshipMapping:
63
64
  rel_type: str
64
65
  source: NodeReferenceMapping
65
66
  target: NodeReferenceMapping
66
- nodes_storage_spec: dict[str, NodeStorageSpec] | None = None
67
67
 
68
68
  class Neo4j(op.StorageSpec):
69
69
  """Graph storage powered by Neo4j."""
70
70
 
71
71
  connection: AuthEntryReference
72
72
  mapping: NodeMapping | RelationshipMapping
73
+
74
+ class Neo4jDeclarations(op.DeclarationSpec):
75
+ """Declarations for Neo4j."""
76
+
77
+ kind = "Neo4j"
78
+ connection: AuthEntryReference
79
+ referenced_nodes: Sequence[ReferencedNode] = ()
@@ -1,15 +1,21 @@
1
- import dataclasses
2
1
  import uuid
3
2
  import datetime
4
- from dataclasses import dataclass
3
+ from dataclasses import dataclass, make_dataclass
5
4
  import pytest
6
- from cocoindex.convert import to_engine_value
7
-
5
+ import cocoindex
6
+ from cocoindex.typing import encode_enriched_type
7
+ from cocoindex.convert import encode_engine_value, make_engine_value_decoder
8
+ from typing import Literal
8
9
  @dataclass
9
10
  class Order:
10
11
  order_id: str
11
12
  name: str
12
13
  price: float
14
+ extra_field: str = "default_extra"
15
+
16
+ @dataclass
17
+ class Tag:
18
+ name: str
13
19
 
14
20
  @dataclass
15
21
  class Basket:
@@ -19,51 +25,256 @@ class Basket:
19
25
  class Customer:
20
26
  name: str
21
27
  order: Order
28
+ tags: list[Tag] | None = None
29
+
30
+ @dataclass
31
+ class NestedStruct:
32
+ customer: Customer
33
+ orders: list[Order]
34
+ count: int = 0
35
+
36
+ def build_engine_value_decoder(engine_type_in_py, python_type=None):
37
+ """
38
+ Helper to build a converter for the given engine-side type (as represented in Python).
39
+ If python_type is not specified, uses engine_type_in_py as the target.
40
+ """
41
+ engine_type = encode_enriched_type(engine_type_in_py)["type"]
42
+ return make_engine_value_decoder([], engine_type, python_type or engine_type_in_py)
22
43
 
23
- def test_to_engine_value_basic_types():
24
- assert to_engine_value(123) == 123
25
- assert to_engine_value(3.14) == 3.14
26
- assert to_engine_value("hello") == "hello"
27
- assert to_engine_value(True) is True
44
+ def test_encode_engine_value_basic_types():
45
+ assert encode_engine_value(123) == 123
46
+ assert encode_engine_value(3.14) == 3.14
47
+ assert encode_engine_value("hello") == "hello"
48
+ assert encode_engine_value(True) is True
28
49
 
29
- def test_to_engine_value_uuid():
50
+ def test_encode_engine_value_uuid():
30
51
  u = uuid.uuid4()
31
- assert to_engine_value(u) == u.bytes
52
+ assert encode_engine_value(u) == u.bytes
32
53
 
33
- def test_to_engine_value_date_time_types():
54
+ def test_encode_engine_value_date_time_types():
34
55
  d = datetime.date(2024, 1, 1)
35
- assert to_engine_value(d) == d
56
+ assert encode_engine_value(d) == d
36
57
  t = datetime.time(12, 30)
37
- assert to_engine_value(t) == t
58
+ assert encode_engine_value(t) == t
38
59
  dt = datetime.datetime(2024, 1, 1, 12, 30)
39
- assert to_engine_value(dt) == dt
60
+ assert encode_engine_value(dt) == dt
40
61
 
41
- def test_to_engine_value_struct():
62
+ def test_encode_engine_value_struct():
42
63
  order = Order(order_id="O123", name="mixed nuts", price=25.0)
43
- assert to_engine_value(order) == ["O123", "mixed nuts", 25.0]
64
+ assert encode_engine_value(order) == ["O123", "mixed nuts", 25.0, "default_extra"]
44
65
 
45
- def test_to_engine_value_list_of_structs():
66
+ def test_encode_engine_value_list_of_structs():
46
67
  orders = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]
47
- assert to_engine_value(orders) == [["O1", "item1", 10.0], ["O2", "item2", 20.0]]
68
+ assert encode_engine_value(orders) == [["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]]
48
69
 
49
- def test_to_engine_value_struct_with_list():
70
+ def test_encode_engine_value_struct_with_list():
50
71
  basket = Basket(items=["apple", "banana"])
51
- assert to_engine_value(basket) == [["apple", "banana"]]
72
+ assert encode_engine_value(basket) == [["apple", "banana"]]
52
73
 
53
- def test_to_engine_value_nested_struct():
74
+ def test_encode_engine_value_nested_struct():
54
75
  customer = Customer(name="Alice", order=Order("O1", "item1", 10.0))
55
- assert to_engine_value(customer) == ["Alice", ["O1", "item1", 10.0]]
76
+ assert encode_engine_value(customer) == ["Alice", ["O1", "item1", 10.0, "default_extra"], None]
77
+
78
+ def test_encode_engine_value_empty_list():
79
+ assert encode_engine_value([]) == []
80
+ assert encode_engine_value([[]]) == [[]]
81
+
82
+ def test_encode_engine_value_tuple():
83
+ assert encode_engine_value(()) == []
84
+ assert encode_engine_value((1, 2, 3)) == [1, 2, 3]
85
+ assert encode_engine_value(((1, 2), (3, 4))) == [[1, 2], [3, 4]]
86
+ assert encode_engine_value(([],)) == [[]]
87
+ assert encode_engine_value(((),)) == [[]]
88
+
89
+ def test_encode_engine_value_none():
90
+ assert encode_engine_value(None) is None
91
+
92
+ def test_make_engine_value_decoder_basic_types():
93
+ for engine_type_in_py, value in [
94
+ (int, 42),
95
+ (float, 3.14),
96
+ (str, "hello"),
97
+ (bool, True),
98
+ # (type(None), None), # Removed unsupported NoneType
99
+ ]:
100
+ decoder = build_engine_value_decoder(engine_type_in_py)
101
+ assert decoder(value) == value
102
+
103
+ @pytest.mark.parametrize(
104
+ "data_type, engine_val, expected",
105
+ [
106
+ # All fields match
107
+ (Order, ["O123", "mixed nuts", 25.0, "default_extra"], Order("O123", "mixed nuts", 25.0, "default_extra")),
108
+ # Extra field in engine value (should ignore extra)
109
+ (Order, ["O123", "mixed nuts", 25.0, "default_extra", "unexpected"], Order("O123", "mixed nuts", 25.0, "default_extra")),
110
+ # Fewer fields in engine value (should fill with default)
111
+ (Order, ["O123", "mixed nuts", 0.0, "default_extra"], Order("O123", "mixed nuts", 0.0, "default_extra")),
112
+ # More fields in engine value (should ignore extra)
113
+ (Order, ["O123", "mixed nuts", 25.0, "unexpected"], Order("O123", "mixed nuts", 25.0, "unexpected")),
114
+ # Truly extra field (should ignore the fifth field)
115
+ (Order, ["O123", "mixed nuts", 25.0, "default_extra", "ignored"], Order("O123", "mixed nuts", 25.0, "default_extra")),
116
+ # Missing optional field in engine value (tags=None)
117
+ (Customer, ["Alice", ["O1", "item1", 10.0, "default_extra"], None], Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), None)),
118
+ # Extra field in engine value for Customer (should ignore)
119
+ (Customer, ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]], "extra"], Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip")])),
120
+ ]
121
+ )
122
+ def test_struct_decoder_cases(data_type, engine_val, expected):
123
+ decoder = build_engine_value_decoder(data_type)
124
+ assert decoder(engine_val) == expected
125
+
126
+ def test_make_engine_value_decoder_collections():
127
+ # List of structs
128
+ decoder = build_engine_value_decoder(list[Order])
129
+ engine_val = [
130
+ ["O1", "item1", 10.0, "default_extra"],
131
+ ["O2", "item2", 20.0, "default_extra"]
132
+ ]
133
+ assert decoder(engine_val) == [Order("O1", "item1", 10.0, "default_extra"), Order("O2", "item2", 20.0, "default_extra")]
134
+ # Struct with list field
135
+ decoder = build_engine_value_decoder(Customer)
136
+ engine_val = ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"], ["premium"]]]
137
+ assert decoder(engine_val) == Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip"), Tag("premium")])
138
+ # Struct with struct field
139
+ decoder = build_engine_value_decoder(NestedStruct)
140
+ engine_val = [
141
+ ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]]],
142
+ [["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]],
143
+ 2
144
+ ]
145
+ assert decoder(engine_val) == NestedStruct(
146
+ Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip")]),
147
+ [Order("O1", "item1", 10.0, "default_extra"), Order("O2", "item2", 20.0, "default_extra")],
148
+ 2
149
+ )
150
+
151
+ def make_engine_order(fields):
152
+ return make_dataclass('EngineOrder', fields)
153
+
154
+ def make_python_order(fields, defaults=None):
155
+ if defaults is None:
156
+ defaults = {}
157
+ # Move all fields with defaults to the end (Python dataclass requirement)
158
+ non_default_fields = [(n, t) for n, t in fields if n not in defaults]
159
+ default_fields = [(n, t) for n, t in fields if n in defaults]
160
+ ordered_fields = non_default_fields + default_fields
161
+ # Prepare the namespace for defaults (only for fields at the end)
162
+ namespace = {k: defaults[k] for k, _ in default_fields}
163
+ return make_dataclass('PythonOrder', ordered_fields, namespace=namespace)
164
+
165
+ @pytest.mark.parametrize(
166
+ "engine_fields, python_fields, python_defaults, engine_val, expected_python_val",
167
+ [
168
+ # Extra field in Python (middle)
169
+ (
170
+ [("id", str), ("name", str)],
171
+ [("id", str), ("price", float), ("name", str)],
172
+ {"price": 0.0},
173
+ ["O123", "mixed nuts"],
174
+ ("O123", 0.0, "mixed nuts"),
175
+ ),
176
+ # Missing field in Python (middle)
177
+ (
178
+ [("id", str), ("price", float), ("name", str)],
179
+ [("id", str), ("name", str)],
180
+ {},
181
+ ["O123", 25.0, "mixed nuts"],
182
+ ("O123", "mixed nuts"),
183
+ ),
184
+ # Extra field in Python (start)
185
+ (
186
+ [("name", str), ("price", float)],
187
+ [("extra", str), ("name", str), ("price", float)],
188
+ {"extra": "default"},
189
+ ["mixed nuts", 25.0],
190
+ ("default", "mixed nuts", 25.0),
191
+ ),
192
+ # Missing field in Python (start)
193
+ (
194
+ [("extra", str), ("name", str), ("price", float)],
195
+ [("name", str), ("price", float)],
196
+ {},
197
+ ["unexpected", "mixed nuts", 25.0],
198
+ ("mixed nuts", 25.0),
199
+ ),
200
+ # Field order difference (should map by name)
201
+ (
202
+ [("id", str), ("name", str), ("price", float)],
203
+ [("name", str), ("id", str), ("price", float), ("extra", str)],
204
+ {"extra": "default"},
205
+ ["O123", "mixed nuts", 25.0],
206
+ ("mixed nuts", "O123", 25.0, "default"),
207
+ ),
208
+ # Extra field (Python has extra field with default)
209
+ (
210
+ [("id", str), ("name", str)],
211
+ [("id", str), ("name", str), ("price", float)],
212
+ {"price": 0.0},
213
+ ["O123", "mixed nuts"],
214
+ ("O123", "mixed nuts", 0.0),
215
+ ),
216
+ # Missing field (Engine has extra field)
217
+ (
218
+ [("id", str), ("name", str), ("price", float)],
219
+ [("id", str), ("name", str)],
220
+ {},
221
+ ["O123", "mixed nuts", 25.0],
222
+ ("O123", "mixed nuts"),
223
+ ),
224
+ ]
225
+ )
226
+ def test_field_position_cases(engine_fields, python_fields, python_defaults, engine_val, expected_python_val):
227
+ EngineOrder = make_engine_order(engine_fields)
228
+ PythonOrder = make_python_order(python_fields, python_defaults)
229
+ decoder = build_engine_value_decoder(EngineOrder, PythonOrder)
230
+ # Map field names to expected values
231
+ expected_dict = dict(zip([f[0] for f in python_fields], expected_python_val))
232
+ # Instantiate using keyword arguments (order doesn't matter)
233
+ assert decoder(engine_val) == PythonOrder(**expected_dict)
234
+
235
+ def test_roundtrip_ltable():
236
+ t = list[Order]
237
+ value = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]
238
+ encoded = encode_engine_value(value)
239
+ assert encoded == [["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]]
240
+ decoded = build_engine_value_decoder(t)(encoded)
241
+ assert decoded == value
242
+
243
+ def test_roundtrip_ktable_str_key():
244
+ t = dict[str, Order]
245
+ value = {"K1": Order("O1", "item1", 10.0), "K2": Order("O2", "item2", 20.0)}
246
+ encoded = encode_engine_value(value)
247
+ assert encoded == [["K1", "O1", "item1", 10.0, "default_extra"], ["K2", "O2", "item2", 20.0, "default_extra"]]
248
+ decoded = build_engine_value_decoder(t)(encoded)
249
+ assert decoded == value
250
+
251
+ def test_roundtrip_ktable_struct_key():
252
+ @dataclass(frozen=True)
253
+ class OrderKey:
254
+ shop_id: str
255
+ version: int
256
+
257
+ t = dict[OrderKey, Order]
258
+ value = {OrderKey("A", 3): Order("O1", "item1", 10.0), OrderKey("B", 4): Order("O2", "item2", 20.0)}
259
+ encoded = encode_engine_value(value)
260
+ assert encoded == [[["A", 3], "O1", "item1", 10.0, "default_extra"],
261
+ [["B", 4], "O2", "item2", 20.0, "default_extra"]]
262
+ decoded = build_engine_value_decoder(t)(encoded)
263
+ assert decoded == value
56
264
 
57
- def test_to_engine_value_empty_list():
58
- assert to_engine_value([]) == []
59
- assert to_engine_value([[]]) == [[]]
265
+ IntVectorType = cocoindex.Vector[int, Literal[5]]
266
+ def test_vector_as_vector() -> None:
267
+ value: IntVectorType = [1, 2, 3, 4, 5]
268
+ encoded = encode_engine_value(value)
269
+ assert encoded == [1, 2, 3, 4, 5]
270
+ decoded = build_engine_value_decoder(IntVectorType)(encoded)
271
+ assert decoded == value
60
272
 
61
- def test_to_engine_value_tuple():
62
- assert to_engine_value(()) == []
63
- assert to_engine_value((1, 2, 3)) == [1, 2, 3]
64
- assert to_engine_value(((1, 2), (3, 4))) == [[1, 2], [3, 4]]
65
- assert to_engine_value(([],)) == [[]]
66
- assert to_engine_value(((),)) == [[]]
273
+ ListIntType = list[int]
274
+ def test_vector_as_list() -> None:
275
+ value: ListIntType = [1, 2, 3, 4, 5]
276
+ encoded = encode_engine_value(value)
277
+ assert encoded == [1, 2, 3, 4, 5]
278
+ decoded = build_engine_value_decoder(ListIntType)(encoded)
279
+ assert decoded == value
67
280
 
68
- def test_to_engine_value_none():
69
- assert to_engine_value(None) is None
cocoindex/typing.py CHANGED
@@ -5,9 +5,9 @@ import datetime
5
5
  import types
6
6
  import inspect
7
7
  import uuid
8
- from typing import Annotated, NamedTuple, Any, TypeVar, TYPE_CHECKING, overload
8
+ from typing import Annotated, NamedTuple, Any, TypeVar, TYPE_CHECKING, overload, Sequence, Protocol, Generic, Literal
9
9
 
10
- class Vector(NamedTuple):
10
+ class VectorInfo(NamedTuple):
11
11
  dim: int | None
12
12
 
13
13
  class TypeKind(NamedTuple):
@@ -21,7 +21,7 @@ class TypeAttr:
21
21
  self.key = key
22
22
  self.value = value
23
23
 
24
- Annotation = Vector | TypeKind | TypeAttr
24
+ Annotation = TypeKind | TypeAttr | VectorInfo
25
25
 
26
26
  Float32 = Annotated[float, TypeKind('Float32')]
27
27
  Float64 = Annotated[float, TypeKind('Float64')]
@@ -30,29 +30,34 @@ Json = Annotated[Any, TypeKind('Json')]
30
30
  LocalDateTime = Annotated[datetime.datetime, TypeKind('LocalDateTime')]
31
31
  OffsetDateTime = Annotated[datetime.datetime, TypeKind('OffsetDateTime')]
32
32
 
33
- COLLECTION_TYPES = ('Table', 'List')
34
-
35
- R = TypeVar("R")
36
-
37
33
  if TYPE_CHECKING:
38
- Table = Annotated[list[R], TypeKind('Table')]
39
- List = Annotated[list[R], TypeKind('List')]
34
+ T_co = TypeVar('T_co', covariant=True)
35
+ Dim_co = TypeVar('Dim_co', bound=int, covariant=True)
36
+
37
+ class Vector(Sequence[T_co], Generic[T_co, Dim_co], Protocol):
38
+ """Vector[T, Dim] is a special typing alias for a list[T] with optional dimension info"""
40
39
  else:
41
- # pylint: disable=too-few-public-methods
42
- class Table: # type: ignore[unreachable]
43
- """
44
- A Table type, which has a list of rows. The first field of each row is the key.
45
- """
46
- def __class_getitem__(cls, item: type[R]):
47
- return Annotated[list[item], TypeKind('Table')]
48
-
49
- # pylint: disable=too-few-public-methods
50
- class List: # type: ignore[unreachable]
51
- """
52
- A List type, which has a list of ordered rows.
53
- """
54
- def __class_getitem__(cls, item: type[R]):
55
- return Annotated[list[item], TypeKind('List')]
40
+ class Vector: # type: ignore[unreachable]
41
+ """ A special typing alias for a list[T] with optional dimension info """
42
+ def __class_getitem__(self, params):
43
+ if not isinstance(params, tuple):
44
+ # Only element type provided
45
+ elem_type = params
46
+ return Annotated[list[elem_type], VectorInfo(dim=None)]
47
+ else:
48
+ # Element type and dimension provided
49
+ elem_type, dim = params
50
+ if typing.get_origin(dim) is Literal:
51
+ dim = typing.get_args(dim)[0] # Extract the literal value
52
+ return Annotated[list[elem_type], VectorInfo(dim=dim)]
53
+
54
+ TABLE_TYPES = ('KTable', 'LTable')
55
+ KEY_FIELD_NAME = '_key'
56
+
57
+ ElementType = type | tuple[type, type]
58
+
59
+ def _is_struct_type(t) -> bool:
60
+ return isinstance(t, type) and dataclasses.is_dataclass(t)
56
61
 
57
62
  @dataclasses.dataclass
58
63
  class AnalyzedTypeInfo:
@@ -60,9 +65,12 @@ class AnalyzedTypeInfo:
60
65
  Analyzed info of a Python type.
61
66
  """
62
67
  kind: str
63
- vector_info: Vector | None
64
- elem_type: type | None
65
- dataclass_type: type | None
68
+ vector_info: VectorInfo | None # For Vector
69
+ elem_type: ElementType | None # For Vector and Table
70
+
71
+ key_type: type | None # For element of KTable
72
+ dataclass_type: type | None # For Struct
73
+
66
74
  attrs: dict[str, Any] | None
67
75
  nullable: bool = False
68
76
 
@@ -70,6 +78,12 @@ def analyze_type_info(t) -> AnalyzedTypeInfo:
70
78
  """
71
79
  Analyze a Python type and return the analyzed info.
72
80
  """
81
+ if isinstance(t, tuple) and len(t) == 2:
82
+ key_type, value_type = t
83
+ result = analyze_type_info(value_type)
84
+ result.key_type = key_type
85
+ return result
86
+
73
87
  annotations: tuple[Annotation, ...] = ()
74
88
  base_type = None
75
89
  nullable = False
@@ -98,33 +112,41 @@ def analyze_type_info(t) -> AnalyzedTypeInfo:
98
112
  if attrs is None:
99
113
  attrs = dict()
100
114
  attrs[attr.key] = attr.value
101
- elif isinstance(attr, Vector):
115
+ elif isinstance(attr, VectorInfo):
102
116
  vector_info = attr
103
117
  elif isinstance(attr, TypeKind):
104
118
  kind = attr.kind
105
119
 
106
120
  dataclass_type = None
107
121
  elem_type = None
108
- if isinstance(t, type) and dataclasses.is_dataclass(t):
122
+ key_type = None
123
+ if _is_struct_type(t):
109
124
  if kind is None:
110
125
  kind = 'Struct'
111
126
  elif kind != 'Struct':
112
127
  raise ValueError(f"Unexpected type kind for struct: {kind}")
113
128
  dataclass_type = t
114
129
  elif base_type is collections.abc.Sequence or base_type is list:
130
+ args = typing.get_args(t)
131
+ elem_type = args[0]
132
+
115
133
  if kind is None:
116
- kind = 'Vector' if vector_info is not None else 'List'
117
- elif not (kind == 'Vector' or kind in COLLECTION_TYPES):
134
+ if _is_struct_type(elem_type):
135
+ kind = 'LTable'
136
+ if vector_info is not None:
137
+ raise ValueError("Vector element must be a simple type, not a struct")
138
+ else:
139
+ kind = 'Vector'
140
+ if vector_info is None:
141
+ vector_info = VectorInfo(dim=None)
142
+ elif not (kind == 'Vector' or kind in TABLE_TYPES):
118
143
  raise ValueError(f"Unexpected type kind for list: {kind}")
119
-
144
+ elif base_type is collections.abc.Mapping or base_type is dict:
120
145
  args = typing.get_args(t)
121
- if len(args) != 1:
122
- raise ValueError(f"{kind} must have exactly one type argument")
123
- elem_type = args[0]
146
+ elem_type = (args[0], args[1])
147
+ kind = 'KTable'
124
148
  elif kind is None:
125
- if base_type is collections.abc.Sequence or base_type is list:
126
- kind = 'Vector' if vector_info is not None else 'List'
127
- elif t is bytes:
149
+ if t is bytes:
128
150
  kind = 'Bytes'
129
151
  elif t is str:
130
152
  kind = 'Str'
@@ -145,20 +167,26 @@ def analyze_type_info(t) -> AnalyzedTypeInfo:
145
167
  else:
146
168
  raise ValueError(f"type unsupported yet: {t}")
147
169
 
148
- return AnalyzedTypeInfo(kind=kind, vector_info=vector_info, elem_type=elem_type,
149
- dataclass_type=dataclass_type, attrs=attrs, nullable=nullable)
170
+ return AnalyzedTypeInfo(kind=kind, vector_info=vector_info,
171
+ elem_type=elem_type, key_type=key_type, dataclass_type=dataclass_type,
172
+ attrs=attrs, nullable=nullable)
150
173
 
151
- def _encode_fields_schema(dataclass_type: type) -> list[dict[str, Any]]:
174
+ def _encode_fields_schema(dataclass_type: type, key_type: type | None = None) -> list[dict[str, Any]]:
152
175
  result = []
153
- for field in dataclasses.fields(dataclass_type):
176
+ def add_field(name: str, t) -> None:
154
177
  try:
155
- type_info = encode_enriched_type_info(analyze_type_info(field.type))
178
+ type_info = encode_enriched_type_info(analyze_type_info(t))
156
179
  except ValueError as e:
157
180
  e.add_note(f"Failed to encode annotation for field - "
158
- f"{dataclass_type.__name__}.{field.name}: {field.type}")
181
+ f"{dataclass_type.__name__}.{name}: {t}")
159
182
  raise
160
- type_info['name'] = field.name
183
+ type_info['name'] = name
161
184
  result.append(type_info)
185
+
186
+ if key_type is not None:
187
+ add_field(KEY_FIELD_NAME, key_type)
188
+ for field in dataclasses.fields(dataclass_type):
189
+ add_field(field.name, field.type)
162
190
  return result
163
191
 
164
192
  def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
@@ -167,7 +195,7 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
167
195
  if type_info.kind == 'Struct':
168
196
  if type_info.dataclass_type is None:
169
197
  raise ValueError("Struct type must have a dataclass type")
170
- encoded_type['fields'] = _encode_fields_schema(type_info.dataclass_type)
198
+ encoded_type['fields'] = _encode_fields_schema(type_info.dataclass_type, type_info.key_type)
171
199
  if doc := inspect.getdoc(type_info.dataclass_type):
172
200
  encoded_type['description'] = doc
173
201
 
@@ -179,7 +207,7 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
179
207
  encoded_type['element_type'] = _encode_type(analyze_type_info(type_info.elem_type))
180
208
  encoded_type['dimension'] = type_info.vector_info.dim
181
209
 
182
- elif type_info.kind in COLLECTION_TYPES:
210
+ elif type_info.kind in TABLE_TYPES:
183
211
  if type_info.elem_type is None:
184
212
  raise ValueError(f"{type_info.kind} type must have an element type")
185
213
  row_type_info = analyze_type_info(type_info.elem_type)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cocoindex
3
- Version: 0.1.25
3
+ Version: 0.1.27
4
4
  Requires-Dist: sentence-transformers>=3.3.1
5
5
  Requires-Dist: click>=8.1.8
6
6
  Requires-Dist: pytest ; extra == 'test'
@@ -117,7 +117,7 @@ Go to the [examples directory](examples) to try out with any of the examples, fo
117
117
  | [PDF Embedding](examples/pdf_embedding) | Parse PDF and index text embeddings for semantic search |
118
118
  | [Manuals LLM Extraction](examples/manuals_llm_extraction) | Extract structured information from a manual using LLM |
119
119
  | [Google Drive Text Embedding](examples/gdrive_text_embedding) | Index text documents from Google Drive |
120
- | [Docs to Knowledge Graph](examples/docs_to_kg) | Extract relationships from Markdown documents and build a knowledge graph |
120
+ | [Docs to Knowledge Graph](examples/docs_to_knowledge_graph) | Extract relationships from Markdown documents and build a knowledge graph |
121
121
  | [Embeddings to Qdrant](examples/text_embedding_qdrant) | Index documents in a Qdrant collection for semantic search |
122
122
 
123
123
  More coming and stay tuned! If there's any specific examples you would like to see, please let us know in our [Discord community](https://discord.com/invite/zpA9S2DR7s) 🌱.
@@ -0,0 +1,24 @@
1
+ cocoindex-0.1.27.dist-info/METADATA,sha256=O6XL0i2OUPY0KehlPYXxywyfX6uMdNk1xuTYxdJOcnA,8209
2
+ cocoindex-0.1.27.dist-info/WHEEL,sha256=tAGdc4C2KTz7B2CZ8Jf3DcKSAviAbCg44UH9ma2gYww,96
3
+ cocoindex-0.1.27.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
4
+ cocoindex/auth_registry.py,sha256=QQI8P7NMvA_NFIHXzV-yS94wVaZuXaIjY992Vrcii1E,641
5
+ cocoindex/cli.py,sha256=kX2PQx7MJYX9UBGvyzIMqpzZSJfZA-NQJl9QI3m1ICU,7252
6
+ cocoindex/convert.py,sha256=M8wRr38AgPpd43eklyWpNMaAvKHsIEyy_L3IlU7Q0oA,6095
7
+ cocoindex/flow.py,sha256=Gu9uwDtoNz7IWnQFX6GtaSgee1NpODT1_4qS90OxqaE,21672
8
+ cocoindex/functions.py,sha256=2yPL_s908AFxb6vVleFOVs8-6SCyT5cO-M7JLDqVFzA,1694
9
+ cocoindex/index.py,sha256=32iiQI60VKhRRHle17rpoEVa_tsAHnyXXXnrLaX68uQ,557
10
+ cocoindex/lib.py,sha256=V1bNKH6lpb1WT4f4F_4qD7e4Uarq-sS-donKXH5PjFo,4001
11
+ cocoindex/llm.py,sha256=POMdB-huMvPkRDpvcaBeOgXfbm0YZa0swNOdD2s4TRc,364
12
+ cocoindex/op.py,sha256=SYjXRlMxi304OcuOLPGlovv3CYa_DeQUunx4YJjmRc0,10871
13
+ cocoindex/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
+ cocoindex/query.py,sha256=GU3V8AoSK--xOU2yX0Ao-Rz1t3mlrsJ4JIhzsLVwFH8,3297
15
+ cocoindex/runtime.py,sha256=WwyDSpJhvaHjnxH4q6r2MVftJB-QkNVH_TxZLxuDtFw,1009
16
+ cocoindex/setup.py,sha256=euHPZ9-7a68DV4PbdK2A0cZ9ECa6vT4qBraY3WKZk4s,512
17
+ cocoindex/sources.py,sha256=4JWsW7sjFuiizaxOkxXJpZgVmVgF9ys9ATvvRY-SBjE,966
18
+ cocoindex/storages.py,sha256=AYYua_rzDUyvYb2ou3YFjO3jvRUcJKUTQlIYZzoTDnE,2036
19
+ cocoindex/tests/test_convert.py,sha256=l5VpBgDw8F-0eMT_kEitUQTyI5newGnxqsm3gb6c2XA,11516
20
+ cocoindex/tests/__init__.py,sha256=frcCV1k9oG9oKj3dpUqdJg1PxRT2RSN_XKdLCPjaYaY,2
21
+ cocoindex/typing.py,sha256=RP322ATJqGkHkC7ht0s0tuRMg39CeNKUhk3cR_f78EY,8705
22
+ cocoindex/__init__.py,sha256=FZo8bM7rmJ3Ppg_ysAjjlWxiHPQPrNNwyHb4V-onXT0,673
23
+ cocoindex/_engine.cp311-win_amd64.pyd,sha256=79tgqK7SCVNjQldNdOWBdZRNynaQviO03U8C5FFwJpM,58859008
24
+ cocoindex-0.1.27.dist-info/RECORD,,
@@ -1,24 +0,0 @@
1
- cocoindex-0.1.25.dist-info/METADATA,sha256=-ufLCvjdS1aeIfab5HeYCq0rwMWGp5L98HNWnjijtL8,8196
2
- cocoindex-0.1.25.dist-info/WHEEL,sha256=tAGdc4C2KTz7B2CZ8Jf3DcKSAviAbCg44UH9ma2gYww,96
3
- cocoindex-0.1.25.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
4
- cocoindex/auth_registry.py,sha256=QQI8P7NMvA_NFIHXzV-yS94wVaZuXaIjY992Vrcii1E,641
5
- cocoindex/cli.py,sha256=-7lJjjXQbgYZTgsbrH9yl1DvWz-8vwPCmaWHKX_IdCM,7119
6
- cocoindex/convert.py,sha256=BIIc2YGr3bc5zVgn4M_kcJeAyVkbl7YJ1ZVNgYIAmpQ,5144
7
- cocoindex/flow.py,sha256=izi6qxDAB35_rYuyc62-J2Petq88iIs6qBEHLBdm_3Q,20560
8
- cocoindex/functions.py,sha256=kF4GUK1uvCu4IE5MLlCidoGyONp3b6lm47qcjMaFCtg,1704
9
- cocoindex/index.py,sha256=32iiQI60VKhRRHle17rpoEVa_tsAHnyXXXnrLaX68uQ,557
10
- cocoindex/lib.py,sha256=FTzT07U1dxjP1QsMDLndD4pBem7Baerf1ZbSaVq-OVM,3582
11
- cocoindex/llm.py,sha256=t6EcAYM17tl7UgMSAizSRbXF8M96g3tQdynASLsqx6I,312
12
- cocoindex/op.py,sha256=ee9Zu0pNygALCb9iPHPGqjQhD2PnNFcdVNPIezMEcPw,10609
13
- cocoindex/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
- cocoindex/query.py,sha256=FO_w1lsh4-lu7wPqkEDcZyvqKeis3e3Bq1L1h--91fg,3364
15
- cocoindex/runtime.py,sha256=R7FBysOH-4TL75cZNeDvAeEZ5jNmJ0m0zfkSUE7l0Lg,713
16
- cocoindex/setup.py,sha256=euHPZ9-7a68DV4PbdK2A0cZ9ECa6vT4qBraY3WKZk4s,512
17
- cocoindex/sources.py,sha256=4JWsW7sjFuiizaxOkxXJpZgVmVgF9ys9ATvvRY-SBjE,966
18
- cocoindex/storages.py,sha256=zICYAP3qHSFbGiVuNJNqn73-P6waNoBnU_SKOPfeygc,1882
19
- cocoindex/tests/test_convert.py,sha256=usmtQ7O91FYCfw131dxtmOrpFHo5I1zS-CFjPs8eHAs,2127
20
- cocoindex/tests/__init__.py,sha256=frcCV1k9oG9oKj3dpUqdJg1PxRT2RSN_XKdLCPjaYaY,2
21
- cocoindex/typing.py,sha256=SlO_QNa8lLJIdqJgESxISNsP7LSThC5DoWvFDcSgHK8,7501
22
- cocoindex/__init__.py,sha256=hdpaL27JWb9Jr1I0gVaZa9QffBLxb_MABkEcNTfU524,583
23
- cocoindex/_engine.cp311-win_amd64.pyd,sha256=aJ4bqXJtDQXCcfYWTCg0D7Aeymd96EDa2FqMbL9-KRA,58040832
24
- cocoindex-0.1.25.dist-info/RECORD,,