cocoindex 0.3.4__cp311-abi3-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. cocoindex/__init__.py +114 -0
  2. cocoindex/_engine.abi3.so +0 -0
  3. cocoindex/auth_registry.py +44 -0
  4. cocoindex/cli.py +830 -0
  5. cocoindex/engine_object.py +214 -0
  6. cocoindex/engine_value.py +550 -0
  7. cocoindex/flow.py +1281 -0
  8. cocoindex/functions/__init__.py +40 -0
  9. cocoindex/functions/_engine_builtin_specs.py +66 -0
  10. cocoindex/functions/colpali.py +247 -0
  11. cocoindex/functions/sbert.py +77 -0
  12. cocoindex/index.py +50 -0
  13. cocoindex/lib.py +75 -0
  14. cocoindex/llm.py +47 -0
  15. cocoindex/op.py +1047 -0
  16. cocoindex/py.typed +0 -0
  17. cocoindex/query_handler.py +57 -0
  18. cocoindex/runtime.py +78 -0
  19. cocoindex/setting.py +171 -0
  20. cocoindex/setup.py +92 -0
  21. cocoindex/sources/__init__.py +5 -0
  22. cocoindex/sources/_engine_builtin_specs.py +120 -0
  23. cocoindex/subprocess_exec.py +277 -0
  24. cocoindex/targets/__init__.py +5 -0
  25. cocoindex/targets/_engine_builtin_specs.py +153 -0
  26. cocoindex/targets/lancedb.py +466 -0
  27. cocoindex/tests/__init__.py +0 -0
  28. cocoindex/tests/test_engine_object.py +331 -0
  29. cocoindex/tests/test_engine_value.py +1724 -0
  30. cocoindex/tests/test_optional_database.py +249 -0
  31. cocoindex/tests/test_transform_flow.py +300 -0
  32. cocoindex/tests/test_typing.py +553 -0
  33. cocoindex/tests/test_validation.py +134 -0
  34. cocoindex/typing.py +834 -0
  35. cocoindex/user_app_loader.py +53 -0
  36. cocoindex/utils.py +20 -0
  37. cocoindex/validation.py +104 -0
  38. cocoindex-0.3.4.dist-info/METADATA +288 -0
  39. cocoindex-0.3.4.dist-info/RECORD +42 -0
  40. cocoindex-0.3.4.dist-info/WHEEL +4 -0
  41. cocoindex-0.3.4.dist-info/entry_points.txt +2 -0
  42. cocoindex-0.3.4.dist-info/licenses/THIRD_PARTY_NOTICES.html +13249 -0
@@ -0,0 +1,466 @@
1
+ import dataclasses
2
+ import logging
3
+ import threading
4
+ import uuid
5
+ import weakref
6
+ import datetime
7
+
8
+ from typing import Any
9
+
10
+ import lancedb # type: ignore
11
+ import pyarrow as pa # type: ignore
12
+
13
+ from .. import op
14
+ from ..typing import (
15
+ FieldSchema,
16
+ EnrichedValueType,
17
+ BasicValueType,
18
+ StructType,
19
+ ValueType,
20
+ VectorTypeSchema,
21
+ TableType,
22
+ )
23
+ from ..index import VectorIndexDef, IndexOptions, VectorSimilarityMetric
24
+
25
+ _logger = logging.getLogger(__name__)
26
+
27
+ _LANCEDB_VECTOR_METRIC: dict[VectorSimilarityMetric, str] = {
28
+ VectorSimilarityMetric.COSINE_SIMILARITY: "cosine",
29
+ VectorSimilarityMetric.L2_DISTANCE: "l2",
30
+ VectorSimilarityMetric.INNER_PRODUCT: "dot",
31
+ }
32
+
33
+
34
+ class DatabaseOptions:
35
+ storage_options: dict[str, Any] | None = None
36
+
37
+
38
class LanceDB(op.TargetSpec):
    # Target spec for exporting collected data to a LanceDB table.
    db_uri: str  # database URI, passed to lancedb.connect_async()
    table_name: str  # name of the table to create and maintain
    db_options: DatabaseOptions | None = None  # optional connection options
42
+
43
+
44
@dataclasses.dataclass
class _VectorIndex:
    # Describes one vector index that should exist on the target table.
    name: str  # derived index name of the form __{field}__{metric}__idx
    field_name: str  # column the index is built on
    metric: VectorSimilarityMetric  # similarity metric used by the index
49
+
50
+
51
@dataclasses.dataclass
class _State:
    # Desired setup state for a LanceDB table; persisted by the engine and
    # diffed against the previous state in apply_setup_change().
    key_field_schema: FieldSchema  # the single primary-key column
    value_fields_schema: list[FieldSchema]  # non-key columns, in order
    vector_indexes: list[_VectorIndex] | None = None  # None = none requested
    db_options: DatabaseOptions | None = None  # connection options, if any
57
+
58
+
59
@dataclasses.dataclass
class _TableKey:
    # Persistent identity of a target: one LanceDB table per
    # (database URI, table name) pair.
    db_uri: str
    table_name: str
63
+
64
+
65
# Process-wide cache of LanceDB connections keyed by db_uri, so the target and
# user code (via connect_async) share a single connection per database.
# WeakValueDictionary: an entry vanishes once nobody holds the connection.
_DbConnectionsLock = threading.Lock()
_DbConnections: weakref.WeakValueDictionary[str, lancedb.AsyncConnection] = (
    weakref.WeakValueDictionary()
)
69
+
70
+
71
async def connect_async(
    db_uri: str,
    *,
    db_options: DatabaseOptions | None = None,
    read_consistency_interval: datetime.timedelta | None = None,
) -> lancedb.AsyncConnection:
    """
    Helper function to connect to a LanceDB database.
    It will reuse the connection if it already exists.
    The connection will be shared with the target used by cocoindex, so it achieves strong consistency.
    """
    with _DbConnectionsLock:
        conn = _DbConnections.get(db_uri)
    if conn is not None:
        return conn

    # Connect OUTSIDE the lock: holding a threading.Lock across an `await`
    # can deadlock the event loop — a second coroutine would block the whole
    # thread on the lock while the first is suspended on the await.
    db_options = db_options or DatabaseOptions()
    new_conn = await lancedb.connect_async(
        db_uri,
        storage_options=db_options.storage_options,
        read_consistency_interval=read_consistency_interval,
    )

    with _DbConnectionsLock:
        # Another task may have connected concurrently; prefer the cached
        # connection so everyone keeps sharing a single one.
        conn = _DbConnections.get(db_uri)
        if conn is None:
            _DbConnections[db_uri] = conn = new_conn
    return conn
92
+
93
+
94
def make_pa_schema(
    key_field_schema: FieldSchema, value_fields_schema: list[FieldSchema]
) -> pa.Schema:
    """Build the PyArrow schema for a table: key column first, then values."""
    all_fields = [key_field_schema, *value_fields_schema]
    return pa.schema([_convert_field_to_pa_field(f) for f in all_fields])
103
+
104
+
105
def _convert_field_to_pa_field(field_schema: FieldSchema) -> pa.Field:
    """Build the PyArrow field corresponding to a single FieldSchema."""
    # Nullability comes straight from the enriched value type.
    return pa.field(
        field_schema.name,
        _convert_value_type_to_pa_type(field_schema.value_type),
        nullable=field_schema.value_type.nullable,
    )
113
+
114
+
115
def _convert_value_type_to_pa_type(value_type: EnrichedValueType) -> pa.DataType:
    """Convert EnrichedValueType to PyArrow DataType.

    Structs map to ``pa.struct``, tables to a list of row structs, and basic
    types are delegated to ``_convert_basic_type_to_pa_type``.

    Raises:
        ValueError: if the wrapped type is not a recognized ValueType.
    """
    base_type: ValueType = value_type.type

    if isinstance(base_type, StructType):
        # Handle struct types
        return _convert_struct_fields_to_pa_type(base_type.fields)
    elif isinstance(base_type, BasicValueType):
        # Handle basic types
        return _convert_basic_type_to_pa_type(base_type)
    elif isinstance(base_type, TableType):
        # A table is stored as a repeated struct of its row fields.
        return pa.list_(_convert_struct_fields_to_pa_type(base_type.row.fields))

    # Raise (rather than `assert False`): asserts are stripped under -O.
    raise ValueError(f"Unhandled value type: {value_type}")
129
+
130
+
131
def _convert_struct_fields_to_pa_type(
    fields_schema: list[FieldSchema],
) -> pa.StructType:
    """Build a PyArrow struct type from a list of field schemas."""
    pa_fields = []
    for field in fields_schema:
        pa_fields.append(_convert_field_to_pa_field(field))
    return pa.struct(pa_fields)
136
+
137
+
138
def _convert_basic_type_to_pa_type(basic_type: BasicValueType) -> pa.DataType:
    """Convert BasicValueType to PyArrow DataType.

    Raises:
        ValueError: if a Vector type has no vector schema, or the kind is
            not supported by this LanceDB target.
    """
    kind: str = basic_type.kind

    # Direct 1:1 mappings from cocoindex kinds to PyArrow types.
    type_mapping = {
        "Bytes": pa.binary(),
        "Str": pa.string(),
        "Bool": pa.bool_(),
        "Int64": pa.int64(),
        "Float32": pa.float32(),
        "Float64": pa.float64(),
        "Uuid": pa.uuid(),
        "Date": pa.date32(),
        "Time": pa.time64("us"),
        "LocalDateTime": pa.timestamp("us"),
        "OffsetDateTime": pa.timestamp("us", tz="UTC"),
        "TimeDelta": pa.duration("us"),
        "Json": pa.json_(),
    }

    if kind in type_mapping:
        return type_mapping[kind]

    if kind == "Vector":
        vector_schema: VectorTypeSchema | None = basic_type.vector
        if vector_schema is None:
            raise ValueError("Vector type missing vector schema")
        element_type = _convert_basic_type_to_pa_type(vector_schema.element_type)

        # Fixed-size list when the dimension is known, variable-size otherwise.
        if vector_schema.dimension is not None:
            return pa.list_(element_type, vector_schema.dimension)
        else:
            return pa.list_(element_type)

    if kind == "Range":
        # Range as a struct with start and end
        return pa.struct([pa.field("start", pa.int64()), pa.field("end", pa.int64())])

    # Raise (rather than `assert False`): asserts are stripped under -O.
    raise ValueError(f"Unsupported type kind for LanceDB: {kind}")
178
+
179
+
180
+ def _convert_key_value_to_sql(v: Any) -> str:
181
+ if isinstance(v, str):
182
+ escaped = v.replace("'", "''")
183
+ return f"'{escaped}'"
184
+
185
+ if isinstance(v, uuid.UUID):
186
+ return f"x'{v.hex}'"
187
+
188
+ return str(v)
189
+
190
+
191
def _convert_fields_to_pyarrow(fields: list[FieldSchema], v: Any) -> Any:
    """Convert a struct-like value into a dict keyed by field name.

    Accepts a dict (named access), a tuple (positional access), or a single
    scalar for a one-field struct.
    """
    if isinstance(v, dict):
        # Named access: look up each declared field (missing keys become None).
        result = {}
        for field in fields:
            result[field.name] = _convert_value_for_pyarrow(
                field.value_type.type, v.get(field.name)
            )
        return result
    if isinstance(v, tuple):
        # Positional access: pair values with fields in declaration order.
        return {
            f.name: _convert_value_for_pyarrow(f.value_type.type, item)
            for f, item in zip(fields, v)
        }
    # Scalar: a single value for the first (only) declared field.
    first = fields[0]
    return {first.name: _convert_value_for_pyarrow(first.value_type.type, v)}
207
+
208
+
209
def _convert_value_for_pyarrow(t: ValueType, v: Any) -> Any:
    """Convert an engine value into a PyArrow-compatible Python value.

    Args:
        t: Declared type of the value.
        v: Engine-provided value; ``None`` is passed through unchanged.

    Raises:
        ValueError: if ``t`` is not a recognized ValueType.
    """
    if v is None:
        return None

    if isinstance(t, BasicValueType):
        if isinstance(v, uuid.UUID):
            # pa.uuid() stores UUIDs as their 16 raw bytes.
            return v.bytes

        if t.kind == "Range":
            # Ranges are stored as a {start, end} struct.
            return {"start": v[0], "end": v[1]}

        if t.vector is not None:
            return [_convert_value_for_pyarrow(t.vector.element_type, e) for e in v]

        return v

    elif isinstance(t, StructType):
        return _convert_fields_to_pyarrow(t.fields, v)

    elif isinstance(t, TableType):
        if isinstance(v, list):
            # Each row already carries all (key + value) fields together.
            return [_convert_fields_to_pyarrow(t.row.fields, value) for value in v]
        else:
            # NOTE(review): presumably an iterable of rows where each row is
            # sliceable into key parts followed by value parts — confirm the
            # exact runtime representation against the engine.
            key_fields = t.row.fields[: t.num_key_parts]
            value_fields = t.row.fields[t.num_key_parts :]
            return [
                _convert_fields_to_pyarrow(key_fields, value[0 : t.num_key_parts])
                | _convert_fields_to_pyarrow(value_fields, value[t.num_key_parts :])
                for value in v
            ]

    # Raise (rather than `assert False`): asserts are stripped under -O.
    raise ValueError(f"Unsupported value type: {t}")
241
+
242
+
243
@dataclasses.dataclass
class _MutateContext:
    # Per-target context built once in _Connector.prepare() and reused by
    # every mutate() call.
    table: lancedb.AsyncTable  # open handle to the target table
    key_field_schema: FieldSchema  # schema of the single key column
    value_fields_type: list[ValueType]  # value-column types, in schema order
    pa_schema: pa.Schema  # full PyArrow schema used to build record batches
249
+
250
+
251
+ # Not used for now, because of https://github.com/lancedb/lance/issues/3443
252
+ #
253
+ # async def _update_table_schema(
254
+ # table: lancedb.AsyncTable,
255
+ # expected_schema: pa.Schema,
256
+ # ) -> None:
257
+ # existing_schema = await table.schema()
258
+ # unseen_existing_field_names = {field.name: field for field in existing_schema}
259
+ # new_columns = []
260
+ # updated_columns = []
261
+ # for field in expected_schema:
262
+ # existing_field = unseen_existing_field_names.pop(field.name, None)
263
+ # if existing_field is None:
264
+ # new_columns.append(field)
265
+ # else:
266
+ # if field.type != existing_field.type:
267
+ # updated_columns.append(
268
+ # {
269
+ # "path": field.name,
270
+ # "data_type": field.type,
271
+ # "nullable": field.nullable,
272
+ # }
273
+ # )
274
+ # if new_columns:
275
+ # table.add_columns(new_columns)
276
+ # if updated_columns:
277
+ # table.alter_columns(*updated_columns)
278
+ # if unseen_existing_field_names:
279
+ # table.drop_columns(unseen_existing_field_names.keys())
280
+
281
+
282
@op.target_connector(
    spec_cls=LanceDB, persistent_key_type=_TableKey, setup_state_cls=_State
)
class _Connector:
    """Target connector that materializes cocoindex data into a LanceDB table."""

    @staticmethod
    def get_persistent_key(spec: LanceDB) -> _TableKey:
        """A target table is identified by its (db_uri, table_name) pair."""
        return _TableKey(db_uri=spec.db_uri, table_name=spec.table_name)

    @staticmethod
    def get_setup_state(
        spec: LanceDB,
        key_fields_schema: list[FieldSchema],
        value_fields_schema: list[FieldSchema],
        index_options: IndexOptions,
    ) -> _State:
        """Build the desired table state from the spec and collector schema.

        Raises:
            ValueError: if more than one key field is declared, or a vector
                index requests an explicit method (not configurable yet).
        """
        if len(key_fields_schema) != 1:
            raise ValueError("LanceDB only supports a single key field")
        if index_options.vector_indexes is not None:
            for vector_index in index_options.vector_indexes:
                if vector_index.method is not None:
                    raise ValueError(
                        "Vector index method is not configurable for LanceDB yet"
                    )
        return _State(
            key_field_schema=key_fields_schema[0],
            value_fields_schema=value_fields_schema,
            db_options=spec.db_options,
            vector_indexes=(
                [
                    _VectorIndex(
                        # The name encodes field and metric so a metric change
                        # shows up as a different (new) index.
                        name=f"__{index.field_name}__{_LANCEDB_VECTOR_METRIC[index.metric]}__idx",
                        field_name=index.field_name,
                        metric=index.metric,
                    )
                    for index in index_options.vector_indexes
                ]
                if index_options.vector_indexes is not None
                else None
            ),
        )

    @staticmethod
    def describe(key: _TableKey) -> str:
        """Human-readable description of the target for setup output."""
        return f"LanceDB table {key.table_name}@{key.db_uri}"

    @staticmethod
    def check_state_compatibility(
        previous: _State, current: _State
    ) -> op.TargetStateCompatibility:
        """Any column-schema change forces a rebuild; index-only changes don't."""
        if (
            previous.key_field_schema != current.key_field_schema
            or previous.value_fields_schema != current.value_fields_schema
        ):
            return op.TargetStateCompatibility.NOT_COMPATIBLE

        return op.TargetStateCompatibility.COMPATIBLE

    @staticmethod
    async def apply_setup_change(
        key: _TableKey, previous: _State | None, current: _State | None
    ) -> None:
        """Drive the table from `previous` to `current` state.

        Drops/recreates the table on schema changes, then reconciles vector
        indexes: creates newly requested ones and drops previously managed
        ones that are no longer wanted.
        """
        latest_state = current or previous
        if not latest_state:
            return
        db_conn = await connect_async(key.db_uri, db_options=latest_state.db_options)

        # The existing table can be kept only if the full schema is unchanged.
        reuse_table = (
            previous is not None
            and current is not None
            and previous.key_field_schema == current.key_field_schema
            and previous.value_fields_schema == current.value_fields_schema
        )
        if previous is not None:
            if not reuse_table:
                await db_conn.drop_table(key.table_name, ignore_missing=True)

        if current is None:
            # Target removed; nothing left to set up.
            return

        table: lancedb.AsyncTable | None = None
        if reuse_table:
            try:
                table = await db_conn.open_table(key.table_name)
            except Exception as e:  # pylint: disable=broad-exception-caught
                # Best-effort reuse: log and fall through to re-creating it.
                _logger.warning(
                    "Exception in opening table %s, creating it",
                    key.table_name,
                    exc_info=e,
                )
                table = None

        if table is None:
            table = await db_conn.create_table(
                key.table_name,
                schema=make_pa_schema(
                    current.key_field_schema, current.value_fields_schema
                ),
                mode="overwrite",
            )
            # BTree index on the key column speeds up merge_insert matching.
            await table.create_index(
                current.key_field_schema.name, config=lancedb.index.BTree()
            )

        unseen_prev_vector_indexes = {
            index.name for index in (previous and previous.vector_indexes) or []
        }
        existing_vector_indexes = {index.name for index in await table.list_indices()}

        for index in current.vector_indexes or []:
            if index.name in unseen_prev_vector_indexes:
                # Already managed by a previous setup; keep it.
                unseen_prev_vector_indexes.remove(index.name)
            else:
                try:
                    await table.create_index(
                        index.field_name,
                        name=index.name,
                        config=lancedb.index.HnswPq(
                            distance_type=_LANCEDB_VECTOR_METRIC[index.metric]
                        ),
                    )
                except Exception as e:  # pylint: disable=broad-exception-caught
                    # Fix: the index name was previously passed as a stray
                    # second positional arg to RuntimeError; fold it into the
                    # message instead.
                    raise RuntimeError(
                        f"Exception in creating index {index.name} "
                        f"on field {index.field_name}. "
                        f"This may be caused by a limitation of LanceDB, "
                        f"which requires data existing in the table to train the index. "
                        f"See: https://github.com/lancedb/lance/issues/4034"
                    ) from e

        # Drop indexes that were previously managed but are no longer wanted.
        for vector_index_name in unseen_prev_vector_indexes:
            if vector_index_name in existing_vector_indexes:
                await table.drop_index(vector_index_name)

    @staticmethod
    async def prepare(
        spec: LanceDB,
        setup_state: _State,
    ) -> _MutateContext:
        """Open the table once and precompute everything mutate() needs."""
        db_conn = await connect_async(spec.db_uri, db_options=spec.db_options)
        table = await db_conn.open_table(spec.table_name)
        return _MutateContext(
            table=table,
            key_field_schema=setup_state.key_field_schema,
            value_fields_type=[
                field.value_type.type for field in setup_state.value_fields_schema
            ],
            pa_schema=make_pa_schema(
                setup_state.key_field_schema, setup_state.value_fields_schema
            ),
        )

    @staticmethod
    async def mutate(
        *all_mutations: tuple[_MutateContext, dict[Any, dict[str, Any] | None]],
    ) -> None:
        """Apply upserts (value dict) and deletions (None value) per context.

        Each batch is applied with a single merge_insert: matched keys are
        updated, new keys inserted, and deleted keys removed via a
        when_not_matched_by_source_delete condition limited to those keys.
        """
        for context, mutations in all_mutations:
            key_name = context.key_field_schema.name
            value_types = context.value_fields_type

            rows_to_upserts = []
            keys_sql_to_deletes = []
            for key, value in mutations.items():
                if value is None:
                    keys_sql_to_deletes.append(_convert_key_value_to_sql(key))
                else:
                    fields = {
                        key_name: _convert_value_for_pyarrow(
                            context.key_field_schema.value_type.type, key
                        )
                    }
                    # NOTE(review): pairs the dict's insertion order with the
                    # schema's field order — assumes `value` is ordered to
                    # match value_fields_schema; confirm against the engine.
                    # (Renamed the loop variable: it previously shadowed the
                    # outer `value`.)
                    for (name, field_value), value_type in zip(
                        value.items(), value_types
                    ):
                        fields[name] = _convert_value_for_pyarrow(
                            value_type, field_value
                        )
                    rows_to_upserts.append(fields)
            if not rows_to_upserts and not keys_sql_to_deletes:
                # Nothing to do for this context; skip the merge round-trip.
                continue
            record_batch = pa.RecordBatch.from_pylist(
                rows_to_upserts, context.pa_schema
            )
            builder = (
                context.table.merge_insert(key_name)
                .when_matched_update_all()
                .when_not_matched_insert_all()
            )
            if keys_sql_to_deletes:
                delete_cond_sql = f"{key_name} IN ({','.join(keys_sql_to_deletes)})"
                builder = builder.when_not_matched_by_source_delete(delete_cond_sql)
            await builder.execute(record_batch)
File without changes