cocoindex 0.2.10__cp311-abi3-manylinux_2_28_aarch64.whl → 0.2.12__cp311-abi3-manylinux_2_28_aarch64.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as published in their respective public registries.
@@ -0,0 +1,460 @@
+ import dataclasses
+ import logging
+ import threading
+ import uuid
+ import weakref
+ import datetime
+
+ from typing import Any
+
+ import lancedb  # type: ignore
+ import pyarrow as pa  # type: ignore
+
+ from .. import op
+ from ..typing import (
+     FieldSchema,
+     EnrichedValueType,
+     BasicValueType,
+     StructType,
+     ValueType,
+     VectorTypeSchema,
+     TableType,
+ )
+ from ..index import VectorIndexDef, IndexOptions, VectorSimilarityMetric
+
+ _logger = logging.getLogger(__name__)
+
+ _LANCEDB_VECTOR_METRIC: dict[VectorSimilarityMetric, str] = {
+     VectorSimilarityMetric.COSINE_SIMILARITY: "cosine",
+     VectorSimilarityMetric.L2_DISTANCE: "l2",
+     VectorSimilarityMetric.INNER_PRODUCT: "dot",
+ }
+
+
+ @dataclasses.dataclass
+ class DatabaseOptions:
+     storage_options: dict[str, Any] | None = None
+
+
+ class LanceDB(op.TargetSpec):
+     db_uri: str
+     table_name: str
+     db_options: DatabaseOptions | None = None
+
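For orientation, a minimal sketch of how this target spec might be wired into a flow. The export call mirrors cocoindex's documented target usage; the collector, field names, paths, and the `cocoindex.targets.LanceDB` import path are illustrative assumptions, not taken from this diff:

```python
import cocoindex

# Inside a flow definition, assuming `doc_embeddings` is a collector whose rows
# carry an "id" key and an "embedding" vector field (names are made up).
doc_embeddings.export(
    "doc_embeddings",
    cocoindex.targets.LanceDB(db_uri="/tmp/lancedb", table_name="doc_embeddings"),
    primary_key_fields=["id"],
    vector_indexes=[
        cocoindex.VectorIndexDef(
            field_name="embedding",
            metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY,
        )
    ],
)
```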
+
+ @dataclasses.dataclass
+ class _VectorIndex:
+     name: str
+     field_name: str
+     metric: VectorSimilarityMetric
+
+
+ @dataclasses.dataclass
+ class _State:
+     key_field_schema: FieldSchema
+     value_fields_schema: list[FieldSchema]
+     vector_indexes: list[_VectorIndex] | None = None
+     db_options: DatabaseOptions | None = None
+
+
+ @dataclasses.dataclass
+ class _TableKey:
+     db_uri: str
+     table_name: str
+
+
+ _DbConnectionsLock = threading.Lock()
+ _DbConnections: weakref.WeakValueDictionary[str, lancedb.AsyncConnection] = (
+     weakref.WeakValueDictionary()
+ )
+
+
+ async def connect_async(
+     db_uri: str,
+     *,
+     db_options: DatabaseOptions | None = None,
+     read_consistency_interval: datetime.timedelta | None = None,
+ ) -> lancedb.AsyncConnection:
+     """
+     Helper function to connect to a LanceDB database.
+     It reuses an existing connection for the same URI if one is already open.
+     The connection is shared with the target used by cocoindex, so reads achieve strong consistency.
+     """
+     with _DbConnectionsLock:
+         conn = _DbConnections.get(db_uri)
+         if conn is None:
+             db_options = db_options or DatabaseOptions()
+             _DbConnections[db_uri] = conn = await lancedb.connect_async(
+                 db_uri,
+                 storage_options=db_options.storage_options,
+                 read_consistency_interval=read_consistency_interval,
+             )
+         return conn
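Because connections are cached per `db_uri`, application code can share the exact connection the indexing target writes through. A small usage sketch (the URI, table name, and the module path in the import are assumptions); note that `db_options` and `read_consistency_interval` only take effect on the first call for a given URI, since later calls return the cached connection:

```python
import asyncio
from cocoindex.targets.lancedb import connect_async  # assumed module path

async def main() -> None:
    conn = await connect_async("/tmp/lancedb")
    table = await conn.open_table("doc_embeddings")
    print(await table.count_rows())

asyncio.run(main())
```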
+
+
+ def make_pa_schema(
+     key_field_schema: FieldSchema, value_fields_schema: list[FieldSchema]
+ ) -> pa.Schema:
+     """Convert FieldSchema list to PyArrow schema."""
+     fields = [
+         _convert_field_to_pa_field(field)
+         for field in [key_field_schema] + value_fields_schema
+     ]
+     return pa.schema(fields)
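As a concrete illustration (not part of the package): a string key plus a nullable text field and a 384-dimensional Float32 vector would map to roughly this schema, per the conversion rules below. Field names and the dimension are assumptions:

```python
import pyarrow as pa

# Sketch of make_pa_schema's output for such a field layout.
expected = pa.schema(
    [
        pa.field("id", pa.string(), nullable=False),
        pa.field("text", pa.string(), nullable=True),
        pa.field("embedding", pa.list_(pa.float32(), 384), nullable=False),
    ]
)
```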
+
+
+ def _convert_field_to_pa_field(field_schema: FieldSchema) -> pa.Field:
+     """Convert a FieldSchema to a PyArrow Field."""
+     pa_type = _convert_value_type_to_pa_type(field_schema.value_type)
+
+     # Handle nullable fields
+     nullable = field_schema.value_type.nullable
+
+     return pa.field(field_schema.name, pa_type, nullable=nullable)
+
+
+ def _convert_value_type_to_pa_type(value_type: EnrichedValueType) -> pa.DataType:
+     """Convert EnrichedValueType to PyArrow DataType."""
+     base_type: ValueType = value_type.type
+
+     if isinstance(base_type, StructType):
+         # Handle struct types
+         return _convert_struct_fields_to_pa_type(base_type.fields)
+     elif isinstance(base_type, BasicValueType):
+         # Handle basic types
+         return _convert_basic_type_to_pa_type(base_type)
+     elif isinstance(base_type, TableType):
+         return pa.list_(_convert_struct_fields_to_pa_type(base_type.row.fields))
+
+     assert False, f"Unhandled value type: {value_type}"
+
+
+ def _convert_struct_fields_to_pa_type(
+     fields_schema: list[FieldSchema],
+ ) -> pa.StructType:
+     """Convert a list of FieldSchema to a PyArrow StructType."""
+     return pa.struct([_convert_field_to_pa_field(field) for field in fields_schema])
+
+
+ def _convert_basic_type_to_pa_type(basic_type: BasicValueType) -> pa.DataType:
+     """Convert BasicValueType to PyArrow DataType."""
+     kind: str = basic_type.kind
+
+     # Map basic types to PyArrow types
+     type_mapping = {
+         "Bytes": pa.binary(),
+         "Str": pa.string(),
+         "Bool": pa.bool_(),
+         "Int64": pa.int64(),
+         "Float32": pa.float32(),
+         "Float64": pa.float64(),
+         "Uuid": pa.uuid(),
+         "Date": pa.date32(),
+         "Time": pa.time64("us"),
+         "LocalDateTime": pa.timestamp("us"),
+         "OffsetDateTime": pa.timestamp("us", tz="UTC"),
+         "TimeDelta": pa.duration("us"),
+         "Json": pa.json_(),
+     }
+
+     if kind in type_mapping:
+         return type_mapping[kind]
+
+     if kind == "Vector":
+         vector_schema: VectorTypeSchema | None = basic_type.vector
+         if vector_schema is None:
+             raise ValueError("Vector type missing vector schema")
+         element_type = _convert_basic_type_to_pa_type(vector_schema.element_type)
+
+         if vector_schema.dimension is not None:
+             return pa.list_(element_type, vector_schema.dimension)
+         else:
+             return pa.list_(element_type)
+
+     if kind == "Range":
+         # Range as a struct with start and end
+         return pa.struct([pa.field("start", pa.int64()), pa.field("end", pa.int64())])
+
+     assert False, f"Unsupported type kind for LanceDB: {kind}"
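Two of these mappings depend on newer PyArrow features: `pa.uuid()` and `pa.json_()` are canonical extension types added in recent PyArrow releases (roughly 18.x and 19.x respectively; treat the exact versions as an assumption to verify). A quick guard if you vendor this mapping:

```python
import pyarrow as pa

# Fail fast on PyArrow builds that predate the uuid/json extension types.
for attr in ("uuid", "json_"):
    if not hasattr(pa, attr):
        raise RuntimeError(f"pyarrow {pa.__version__} lacks pa.{attr}(); upgrade pyarrow")
```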
+
+
+ def _convert_key_value_to_sql(v: Any) -> str:
+     if isinstance(v, str):
+         escaped = v.replace("'", "''")
+         return f"'{escaped}'"
+
+     if isinstance(v, uuid.UUID):
+         return f"x'{v.hex}'"
+
+     return str(v)
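For reference, the SQL literals this produces for a few representative keys (the inputs are illustrative; the function is module-private and feeds the delete filter in `mutate` below):

```python
import uuid

_convert_key_value_to_sql("O'Brien")         # "'O''Brien'" -- embedded quotes doubled
_convert_key_value_to_sql(uuid.UUID(int=1))  # "x'00000000000000000000000000000001'"
_convert_key_value_to_sql(42)                # "42"
```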
+
+
+ def _convert_fields_to_pyarrow(fields: list[FieldSchema], v: Any) -> Any:
+     if isinstance(v, dict):
+         return {
+             field.name: _convert_value_for_pyarrow(
+                 field.value_type.type, v.get(field.name)
+             )
+             for field in fields
+         }
+     elif isinstance(v, tuple):
+         return {
+             field.name: _convert_value_for_pyarrow(field.value_type.type, value)
+             for field, value in zip(fields, v)
+         }
+     else:
+         field = fields[0]
+         return {field.name: _convert_value_for_pyarrow(field.value_type.type, v)}
+
+
+ def _convert_value_for_pyarrow(t: ValueType, v: Any) -> Any:
+     if v is None:
+         return None
+
+     if isinstance(t, BasicValueType):
+         if isinstance(v, uuid.UUID):
+             return v.bytes
+
+         if t.kind == "Range":
+             return {"start": v[0], "end": v[1]}
+
+         if t.vector is not None:
+             return [_convert_value_for_pyarrow(t.vector.element_type, e) for e in v]
+
+         return v
+
+     elif isinstance(t, StructType):
+         return _convert_fields_to_pyarrow(t.fields, v)
+
+     elif isinstance(t, TableType):
+         if isinstance(v, list):
+             return [_convert_fields_to_pyarrow(t.row.fields, value) for value in v]
+         else:
+             key_fields = t.row.fields[: t.num_key_parts]
+             value_fields = t.row.fields[t.num_key_parts :]
+             return [
+                 _convert_fields_to_pyarrow(key_fields, value[0 : t.num_key_parts])
+                 | _convert_fields_to_pyarrow(value_fields, value[t.num_key_parts :])
+                 for value in v
+             ]
+
+     assert False, f"Unsupported value type: {t}"
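To make the conversion concrete, here is how a typical row's values get rewritten on their way into Arrow, following the mappings above (the row itself is made up):

```python
import uuid

row = {
    "id": uuid.UUID("12345678-1234-5678-1234-567812345678"),
    "range": (10, 25),
    "embedding": [0.1, 0.2, 0.3],
}
# After _convert_value_for_pyarrow, field by field:
#   id        -> row["id"].bytes (16 raw bytes, matching pa.uuid())
#   range     -> {"start": 10, "end": 25} (matching the Range struct mapping)
#   embedding -> the list as-is, with elements converted recursively
```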
+
+
+ @dataclasses.dataclass
+ class _MutateContext:
+     table: lancedb.AsyncTable
+     key_field_schema: FieldSchema
+     value_fields_type: list[ValueType]
+     pa_schema: pa.Schema
+
+
+ # Not used for now, because of https://github.com/lancedb/lance/issues/3443
+ #
+ # async def _update_table_schema(
+ #     table: lancedb.AsyncTable,
+ #     expected_schema: pa.Schema,
+ # ) -> None:
+ #     existing_schema = await table.schema()
+ #     unseen_existing_field_names = {field.name: field for field in existing_schema}
+ #     new_columns = []
+ #     updated_columns = []
+ #     for field in expected_schema:
+ #         existing_field = unseen_existing_field_names.pop(field.name, None)
+ #         if existing_field is None:
+ #             new_columns.append(field)
+ #         else:
+ #             if field.type != existing_field.type:
+ #                 updated_columns.append(
+ #                     {
+ #                         "path": field.name,
+ #                         "data_type": field.type,
+ #                         "nullable": field.nullable,
+ #                     }
+ #                 )
+ #     if new_columns:
+ #         table.add_columns(new_columns)
+ #     if updated_columns:
+ #         table.alter_columns(*updated_columns)
+ #     if unseen_existing_field_names:
+ #         table.drop_columns(unseen_existing_field_names.keys())
+
+
+ @op.target_connector(
+     spec_cls=LanceDB, persistent_key_type=_TableKey, setup_state_cls=_State
+ )
+ class _Connector:
+     @staticmethod
+     def get_persistent_key(spec: LanceDB) -> _TableKey:
+         return _TableKey(db_uri=spec.db_uri, table_name=spec.table_name)
+
+     @staticmethod
+     def get_setup_state(
+         spec: LanceDB,
+         key_fields_schema: list[FieldSchema],
+         value_fields_schema: list[FieldSchema],
+         index_options: IndexOptions,
+     ) -> _State:
+         if len(key_fields_schema) != 1:
+             raise ValueError("LanceDB only supports a single key field")
+         return _State(
+             key_field_schema=key_fields_schema[0],
+             value_fields_schema=value_fields_schema,
+             db_options=spec.db_options,
+             vector_indexes=(
+                 [
+                     _VectorIndex(
+                         name=f"__{index.field_name}__{_LANCEDB_VECTOR_METRIC[index.metric]}__idx",
+                         field_name=index.field_name,
+                         metric=index.metric,
+                     )
+                     for index in index_options.vector_indexes
+                 ]
+                 if index_options.vector_indexes is not None
+                 else None
+             ),
+         )
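The generated index names encode both the field and the metric, e.g. for a cosine index on an `embedding` field:

```python
# name=f"__{index.field_name}__{_LANCEDB_VECTOR_METRIC[index.metric]}__idx"
assert f"__{'embedding'}__{'cosine'}__idx" == "__embedding__cosine__idx"
```

Because the metric is part of the name, switching a field's metric surfaces in `apply_setup_change` below as one index to create and one stale index to drop.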
+
+     @staticmethod
+     def describe(key: _TableKey) -> str:
+         return f"LanceDB table {key.table_name}@{key.db_uri}"
+
+     @staticmethod
+     def check_state_compatibility(
+         previous: _State, current: _State
+     ) -> op.TargetStateCompatibility:
+         if (
+             previous.key_field_schema != current.key_field_schema
+             or previous.value_fields_schema != current.value_fields_schema
+         ):
+             return op.TargetStateCompatibility.NOT_COMPATIBLE
+
+         return op.TargetStateCompatibility.COMPATIBLE
+
+     @staticmethod
+     async def apply_setup_change(
+         key: _TableKey, previous: _State | None, current: _State | None
+     ) -> None:
+         latest_state = current or previous
+         if not latest_state:
+             return
+         db_conn = await connect_async(key.db_uri, db_options=latest_state.db_options)
+
+         reuse_table = (
+             previous is not None
+             and current is not None
+             and previous.key_field_schema == current.key_field_schema
+             and previous.value_fields_schema == current.value_fields_schema
+         )
+         if previous is not None:
+             if not reuse_table:
+                 await db_conn.drop_table(key.table_name, ignore_missing=True)
+
+         if current is None:
+             return
+
+         table: lancedb.AsyncTable | None = None
+         if reuse_table:
+             try:
+                 table = await db_conn.open_table(key.table_name)
+             except Exception as e:  # pylint: disable=broad-exception-caught
+                 _logger.warning(
+                     "Exception in opening table %s, creating it",
+                     key.table_name,
+                     exc_info=e,
+                 )
+                 table = None
+
+         if table is None:
+             table = await db_conn.create_table(
+                 key.table_name,
+                 schema=make_pa_schema(
+                     current.key_field_schema, current.value_fields_schema
+                 ),
+                 mode="overwrite",
+             )
+             await table.create_index(
+                 current.key_field_schema.name, config=lancedb.index.BTree()
+             )
+
+         unseen_prev_vector_indexes = {
+             index.name for index in (previous and previous.vector_indexes) or []
+         }
+         existing_vector_indexes = {index.name for index in await table.list_indices()}
+
+         for index in current.vector_indexes or []:
+             if index.name in unseen_prev_vector_indexes:
+                 unseen_prev_vector_indexes.remove(index.name)
+             else:
+                 try:
+                     await table.create_index(
+                         index.field_name,
+                         name=index.name,
+                         config=lancedb.index.HnswPq(
+                             distance_type=_LANCEDB_VECTOR_METRIC[index.metric]
+                         ),
+                     )
+                 except Exception as e:  # pylint: disable=broad-exception-caught
+                     raise RuntimeError(
+                         f"Exception in creating index {index.name} on field {index.field_name}. "
+                         f"This may be caused by a limitation of LanceDB, "
+                         f"which requires existing data in the table to train the index. "
+                         f"See: https://github.com/lancedb/lance/issues/4034"
+                     ) from e
+
+         for vector_index_name in unseen_prev_vector_indexes:
+             if vector_index_name in existing_vector_indexes:
+                 await table.drop_index(vector_index_name)
+
+     @staticmethod
+     async def prepare(
+         spec: LanceDB,
+         setup_state: _State,
+     ) -> _MutateContext:
+         db_conn = await connect_async(spec.db_uri, db_options=spec.db_options)
+         table = await db_conn.open_table(spec.table_name)
+         return _MutateContext(
+             table=table,
+             key_field_schema=setup_state.key_field_schema,
+             value_fields_type=[
+                 field.value_type.type for field in setup_state.value_fields_schema
+             ],
+             pa_schema=make_pa_schema(
+                 setup_state.key_field_schema, setup_state.value_fields_schema
+             ),
+         )
+
+     @staticmethod
+     async def mutate(
+         *all_mutations: tuple[_MutateContext, dict[Any, dict[str, Any] | None]],
+     ) -> None:
+         for context, mutations in all_mutations:
+             key_name = context.key_field_schema.name
+             value_types = context.value_fields_type
+
+             rows_to_upserts = []
+             keys_sql_to_deletes = []
+             for key, value in mutations.items():
+                 if value is None:
+                     keys_sql_to_deletes.append(_convert_key_value_to_sql(key))
+                 else:
+                     fields = {
+                         key_name: _convert_value_for_pyarrow(
+                             context.key_field_schema.value_type.type, key
+                         )
+                     }
+                     for (name, field_value), value_type in zip(value.items(), value_types):
+                         fields[name] = _convert_value_for_pyarrow(value_type, field_value)
+                     rows_to_upserts.append(fields)
+             record_batch = pa.RecordBatch.from_pylist(
+                 rows_to_upserts, context.pa_schema
+             )
+             builder = (
+                 context.table.merge_insert(key_name)
+                 .when_matched_update_all()
+                 .when_not_matched_insert_all()
+             )
+             if keys_sql_to_deletes:
+                 delete_cond_sql = f"{key_name} IN ({','.join(keys_sql_to_deletes)})"
+                 builder = builder.when_not_matched_by_source_delete(delete_cond_sql)
+             await builder.execute(record_batch)
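The whole batch lands in a single LanceDB `merge_insert` round trip: upserts via the matched/not-matched clauses, deletions via the not-matched-by-source clause. A standalone sketch of the same pattern (the schema and keys are made up):

```python
import pyarrow as pa

async def upsert_and_delete(table) -> None:  # table: lancedb.AsyncTable
    schema = pa.schema([pa.field("id", pa.string()), pa.field("text", pa.string())])
    batch = pa.RecordBatch.from_pylist([{"id": "a", "text": "hello"}], schema)
    await (
        table.merge_insert("id")        # match rows on the key column
        .when_matched_update_all()      # overwrite rows whose key is in the batch
        .when_not_matched_insert_all()  # insert rows whose key is new
        .when_not_matched_by_source_delete("id IN ('b')")  # delete unmatched rows passing this filter
        .execute(batch)
    )
```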
@@ -21,6 +21,7 @@ from cocoindex.typing import (
      Vector,
      analyze_type_info,
      encode_enriched_type,
+     decode_engine_value_type,
  )


@@ -86,7 +87,9 @@ def build_engine_value_decoder(
      """
      engine_type = encode_enriched_type(engine_type_in_py)["type"]
      return make_engine_value_decoder(
-         [], engine_type, analyze_type_info(python_type or engine_type_in_py)
+         [],
+         decode_engine_value_type(engine_type),
+         analyze_type_info(python_type or engine_type_in_py),
      )

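These test-file hunks all make the same mechanical change: `make_engine_value_decoder` now takes a decoded value type rather than the raw type-JSON dict, so each call site wraps its dict in `decode_engine_value_type`. A sketch of the pattern, assuming `make_engine_value_decoder` is in scope as in the test module:

```python
from cocoindex.typing import analyze_type_info, decode_engine_value_type

src_type = decode_engine_value_type(
    {"kind": "Vector", "element_type": {"kind": "Float64"}, "dimension": None}
)
decoder = make_engine_value_decoder([], src_type, analyze_type_info(list[float]))
```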
@@ -116,7 +119,9 @@ def validate_full_roundtrip_to(

      for other_value, other_type in decoded_values:
          decoder = make_engine_value_decoder(
-             [], encoded_output_type, analyze_type_info(other_type)
+             [],
+             decode_engine_value_type(encoded_output_type),
+             analyze_type_info(other_type),
          )
          other_decoded_value = decoder(value_from_engine)
          assert eq(other_decoded_value, other_value), (
@@ -383,9 +388,19 @@ def test_roundtrip_json() -> None:

  def test_decode_scalar_numpy_values() -> None:
      test_cases = [
-         ({"kind": "Int64"}, np.int64, 42, np.int64(42)),
-         ({"kind": "Float32"}, np.float32, 3.14, np.float32(3.14)),
-         ({"kind": "Float64"}, np.float64, 2.718, np.float64(2.718)),
+         (decode_engine_value_type({"kind": "Int64"}), np.int64, 42, np.int64(42)),
+         (
+             decode_engine_value_type({"kind": "Float32"}),
+             np.float32,
+             3.14,
+             np.float32(3.14),
+         ),
+         (
+             decode_engine_value_type({"kind": "Float64"}),
+             np.float64,
+             2.718,
+             np.float64(2.718),
+         ),
      ]
      for src_type, dst_type, input_value, expected in test_cases:
          decoder = make_engine_value_decoder(
@@ -398,11 +413,13 @@ def test_decode_scalar_numpy_values() -> None:

  def test_non_ndarray_vector_decoding() -> None:
      # Test list[np.float64]
-     src_type = {
-         "kind": "Vector",
-         "element_type": {"kind": "Float64"},
-         "dimension": None,
-     }
+     src_type = decode_engine_value_type(
+         {
+             "kind": "Vector",
+             "element_type": {"kind": "Float64"},
+             "dimension": None,
+         }
+     )
      dst_type_float = list[np.float64]
      decoder = make_engine_value_decoder(
          ["field"], src_type, analyze_type_info(dst_type_float)
@@ -414,7 +431,9 @@ def test_non_ndarray_vector_decoding() -> None:
      assert result == [np.float64(1.0), np.float64(2.0), np.float64(3.0)]

      # Test list[Uuid]
-     src_type = {"kind": "Vector", "element_type": {"kind": "Uuid"}, "dimension": None}
+     src_type = decode_engine_value_type(
+         {"kind": "Vector", "element_type": {"kind": "Uuid"}, "dimension": None}
+     )
      dst_type_uuid = list[uuid.UUID]
      decoder = make_engine_value_decoder(
          ["field"], src_type, analyze_type_info(dst_type_uuid)
@@ -895,11 +914,13 @@ def test_encode_complex_structure_with_ndarray() -> None:

  def test_decode_nullable_ndarray_none_or_value_input() -> None:
      """Test decoding a nullable NDArray with None or value inputs."""
-     src_type_dict = {
-         "kind": "Vector",
-         "element_type": {"kind": "Float32"},
-         "dimension": None,
-     }
+     src_type_dict = decode_engine_value_type(
+         {
+             "kind": "Vector",
+             "element_type": {"kind": "Float32"},
+             "dimension": None,
+         }
+     )
      dst_annotation = NDArrayFloat32Type | None
      decoder = make_engine_value_decoder(
          [], src_type_dict, analyze_type_info(dst_annotation)
@@ -921,11 +942,13 @@ def test_decode_nullable_ndarray_none_or_value_input() -> None:

  def test_decode_vector_string() -> None:
      """Test decoding a vector of strings works for Python native list type."""
-     src_type_dict = {
-         "kind": "Vector",
-         "element_type": {"kind": "Str"},
-         "dimension": None,
-     }
+     src_type_dict = decode_engine_value_type(
+         {
+             "kind": "Vector",
+             "element_type": {"kind": "Str"},
+             "dimension": None,
+         }
+     )
      decoder = make_engine_value_decoder(
          [], src_type_dict, analyze_type_info(Vector[str])
      )
@@ -934,11 +957,13 @@ def test_decode_vector_string() -> None:

  def test_decode_error_non_nullable_or_non_list_vector() -> None:
      """Test decoding errors for non-nullable vectors or non-list inputs."""
-     src_type_dict = {
-         "kind": "Vector",
-         "element_type": {"kind": "Float32"},
-         "dimension": None,
-     }
+     src_type_dict = decode_engine_value_type(
+         {
+             "kind": "Vector",
+             "element_type": {"kind": "Float32"},
+             "dimension": None,
+         }
+     )
      decoder = make_engine_value_decoder(
          [], src_type_dict, analyze_type_info(NDArrayFloat32Type)
      )