cocoindex-0.3.4-cp311-abi3-manylinux_2_28_x86_64.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cocoindex/__init__.py +114 -0
- cocoindex/_engine.abi3.so +0 -0
- cocoindex/auth_registry.py +44 -0
- cocoindex/cli.py +830 -0
- cocoindex/engine_object.py +214 -0
- cocoindex/engine_value.py +550 -0
- cocoindex/flow.py +1281 -0
- cocoindex/functions/__init__.py +40 -0
- cocoindex/functions/_engine_builtin_specs.py +66 -0
- cocoindex/functions/colpali.py +247 -0
- cocoindex/functions/sbert.py +77 -0
- cocoindex/index.py +50 -0
- cocoindex/lib.py +75 -0
- cocoindex/llm.py +47 -0
- cocoindex/op.py +1047 -0
- cocoindex/py.typed +0 -0
- cocoindex/query_handler.py +57 -0
- cocoindex/runtime.py +78 -0
- cocoindex/setting.py +171 -0
- cocoindex/setup.py +92 -0
- cocoindex/sources/__init__.py +5 -0
- cocoindex/sources/_engine_builtin_specs.py +120 -0
- cocoindex/subprocess_exec.py +277 -0
- cocoindex/targets/__init__.py +5 -0
- cocoindex/targets/_engine_builtin_specs.py +153 -0
- cocoindex/targets/lancedb.py +466 -0
- cocoindex/tests/__init__.py +0 -0
- cocoindex/tests/test_engine_object.py +331 -0
- cocoindex/tests/test_engine_value.py +1724 -0
- cocoindex/tests/test_optional_database.py +249 -0
- cocoindex/tests/test_transform_flow.py +300 -0
- cocoindex/tests/test_typing.py +553 -0
- cocoindex/tests/test_validation.py +134 -0
- cocoindex/typing.py +834 -0
- cocoindex/user_app_loader.py +53 -0
- cocoindex/utils.py +20 -0
- cocoindex/validation.py +104 -0
- cocoindex-0.3.4.dist-info/METADATA +288 -0
- cocoindex-0.3.4.dist-info/RECORD +42 -0
- cocoindex-0.3.4.dist-info/WHEEL +4 -0
- cocoindex-0.3.4.dist-info/entry_points.txt +2 -0
- cocoindex-0.3.4.dist-info/licenses/THIRD_PARTY_NOTICES.html +13249 -0
cocoindex/targets/lancedb.py
@@ -0,0 +1,466 @@
import dataclasses
import logging
import threading
import uuid
import weakref
import datetime

from typing import Any

import lancedb  # type: ignore
import pyarrow as pa  # type: ignore

from .. import op
from ..typing import (
    FieldSchema,
    EnrichedValueType,
    BasicValueType,
    StructType,
    ValueType,
    VectorTypeSchema,
    TableType,
)
from ..index import VectorIndexDef, IndexOptions, VectorSimilarityMetric

_logger = logging.getLogger(__name__)

_LANCEDB_VECTOR_METRIC: dict[VectorSimilarityMetric, str] = {
    VectorSimilarityMetric.COSINE_SIMILARITY: "cosine",
    VectorSimilarityMetric.L2_DISTANCE: "l2",
    VectorSimilarityMetric.INNER_PRODUCT: "dot",
}


@dataclasses.dataclass
class DatabaseOptions:
    storage_options: dict[str, Any] | None = None


class LanceDB(op.TargetSpec):
    db_uri: str
    table_name: str
    db_options: DatabaseOptions | None = None
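
# Example (illustrative, not part of the module): constructing the spec.
# Keyword construction and the `storage_options` key shown are assumptions;
# `storage_options` is passed through as-is to `lancedb.connect_async()`.
#
#   spec = LanceDB(
#       db_uri="/tmp/lancedb",
#       table_name="docs",
#       db_options=DatabaseOptions(storage_options={"timeout": "30s"}),
#   )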


@dataclasses.dataclass
class _VectorIndex:
    name: str
    field_name: str
    metric: VectorSimilarityMetric


@dataclasses.dataclass
class _State:
    key_field_schema: FieldSchema
    value_fields_schema: list[FieldSchema]
    vector_indexes: list[_VectorIndex] | None = None
    db_options: DatabaseOptions | None = None


@dataclasses.dataclass
class _TableKey:
    db_uri: str
    table_name: str


_DbConnectionsLock = threading.Lock()
_DbConnections: weakref.WeakValueDictionary[str, lancedb.AsyncConnection] = (
    weakref.WeakValueDictionary()
)


async def connect_async(
    db_uri: str,
    *,
    db_options: DatabaseOptions | None = None,
    read_consistency_interval: datetime.timedelta | None = None,
) -> lancedb.AsyncConnection:
    """
    Helper function to connect to a LanceDB database.
    Reuses an existing connection for the same `db_uri` if one is still alive.
    The connection is shared with the target used by cocoindex, so reads
    through it see the target's writes (strong consistency).
    """
    with _DbConnectionsLock:
        conn = _DbConnections.get(db_uri)
        if conn is None:
            db_options = db_options or DatabaseOptions()
            _DbConnections[db_uri] = conn = await lancedb.connect_async(
                db_uri,
                storage_options=db_options.storage_options,
                read_consistency_interval=read_consistency_interval,
            )
        return conn
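
# Example (illustrative, not part of the module): querying through the shared
# connection so reads observe the target's latest writes. The query-builder
# calls assume LanceDB's async API.
#
#   conn = await connect_async("/tmp/lancedb")
#   table = await conn.open_table("docs")
#   results = await table.query().nearest_to(query_vector).limit(5).to_list()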


def make_pa_schema(
    key_field_schema: FieldSchema, value_fields_schema: list[FieldSchema]
) -> pa.Schema:
    """Convert FieldSchema list to PyArrow schema."""
    fields = [
        _convert_field_to_pa_field(field)
        for field in [key_field_schema] + value_fields_schema
    ]
    return pa.schema(fields)
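
# Illustrative result (field names and types assumed): for a non-null key
# "id" (Str) plus values "text" (Str) and "embedding" (Vector[Float32] with
# dimension 384), the resulting schema is roughly:
#
#   id: string not null
#   text: string
#   embedding: fixed_size_list<item: float>[384]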


def _convert_field_to_pa_field(field_schema: FieldSchema) -> pa.Field:
    """Convert a FieldSchema to a PyArrow Field."""
    pa_type = _convert_value_type_to_pa_type(field_schema.value_type)

    # Handle nullable fields
    nullable = field_schema.value_type.nullable

    return pa.field(field_schema.name, pa_type, nullable=nullable)


def _convert_value_type_to_pa_type(value_type: EnrichedValueType) -> pa.DataType:
    """Convert EnrichedValueType to PyArrow DataType."""
    base_type: ValueType = value_type.type

    if isinstance(base_type, StructType):
        # Handle struct types
        return _convert_struct_fields_to_pa_type(base_type.fields)
    elif isinstance(base_type, BasicValueType):
        # Handle basic types
        return _convert_basic_type_to_pa_type(base_type)
    elif isinstance(base_type, TableType):
        return pa.list_(_convert_struct_fields_to_pa_type(base_type.row.fields))

    assert False, f"Unhandled value type: {value_type}"


def _convert_struct_fields_to_pa_type(
    fields_schema: list[FieldSchema],
) -> pa.StructType:
    """Convert a list of FieldSchema to a PyArrow StructType."""
    return pa.struct([_convert_field_to_pa_field(field) for field in fields_schema])


def _convert_basic_type_to_pa_type(basic_type: BasicValueType) -> pa.DataType:
    """Convert BasicValueType to PyArrow DataType."""
    kind: str = basic_type.kind

    # Map basic types to PyArrow types
    type_mapping = {
        "Bytes": pa.binary(),
        "Str": pa.string(),
        "Bool": pa.bool_(),
        "Int64": pa.int64(),
        "Float32": pa.float32(),
        "Float64": pa.float64(),
        "Uuid": pa.uuid(),
        "Date": pa.date32(),
        "Time": pa.time64("us"),
        "LocalDateTime": pa.timestamp("us"),
        "OffsetDateTime": pa.timestamp("us", tz="UTC"),
        "TimeDelta": pa.duration("us"),
        "Json": pa.json_(),
    }

    if kind in type_mapping:
        return type_mapping[kind]

    if kind == "Vector":
        vector_schema: VectorTypeSchema | None = basic_type.vector
        if vector_schema is None:
            raise ValueError("Vector type missing vector schema")
        element_type = _convert_basic_type_to_pa_type(vector_schema.element_type)

        if vector_schema.dimension is not None:
            return pa.list_(element_type, vector_schema.dimension)
        else:
            return pa.list_(element_type)

    if kind == "Range":
        # Range as a struct with start and end
        return pa.struct([pa.field("start", pa.int64()), pa.field("end", pa.int64())])

    assert False, f"Unsupported type kind for LanceDB: {kind}"
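
# Example (illustrative): a Vector of Float32 with dimension 384 maps to
# pa.list_(pa.float32(), 384), a fixed-size list; with no dimension it maps
# to a variable-length pa.list_(pa.float32()).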


def _convert_key_value_to_sql(v: Any) -> str:
    """Render a key value as a SQL literal, for use in delete conditions."""
    if isinstance(v, str):
        escaped = v.replace("'", "''")
        return f"'{escaped}'"

    if isinstance(v, uuid.UUID):
        return f"x'{v.hex}'"

    return str(v)
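
# Examples (illustrative):
#
#   _convert_key_value_to_sql("O'Brien")         # -> "'O''Brien'"
#   _convert_key_value_to_sql(uuid.UUID(int=1))  # -> "x'00000000000000000000000000000001'"
#   _convert_key_value_to_sql(42)                # -> "42"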


def _convert_fields_to_pyarrow(fields: list[FieldSchema], v: Any) -> Any:
    if isinstance(v, dict):
        return {
            field.name: _convert_value_for_pyarrow(
                field.value_type.type, v.get(field.name)
            )
            for field in fields
        }
    elif isinstance(v, tuple):
        return {
            field.name: _convert_value_for_pyarrow(field.value_type.type, value)
            for field, value in zip(fields, v)
        }
    else:
        field = fields[0]
        return {field.name: _convert_value_for_pyarrow(field.value_type.type, v)}


def _convert_value_for_pyarrow(t: ValueType, v: Any) -> Any:
    if v is None:
        return None

    if isinstance(t, BasicValueType):
        if isinstance(v, uuid.UUID):
            return v.bytes

        if t.kind == "Range":
            return {"start": v[0], "end": v[1]}

        if t.vector is not None:
            return [_convert_value_for_pyarrow(t.vector.element_type, e) for e in v]

        return v

    elif isinstance(t, StructType):
        return _convert_fields_to_pyarrow(t.fields, v)

    elif isinstance(t, TableType):
        if isinstance(v, list):
            return [_convert_fields_to_pyarrow(t.row.fields, value) for value in v]
        else:
            key_fields = t.row.fields[: t.num_key_parts]
            value_fields = t.row.fields[t.num_key_parts :]
            return [
                _convert_fields_to_pyarrow(key_fields, value[0 : t.num_key_parts])
                | _convert_fields_to_pyarrow(value_fields, value[t.num_key_parts :])
                for value in v
            ]

    assert False, f"Unsupported value type: {t}"
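
# Examples (illustrative): a uuid.UUID is stored as its 16 raw bytes, and a
# Range value such as (10, 42) becomes {"start": 10, "end": 42}, matching the
# struct type emitted by _convert_basic_type_to_pa_type.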


@dataclasses.dataclass
class _MutateContext:
    table: lancedb.AsyncTable
    key_field_schema: FieldSchema
    value_fields_type: list[ValueType]
    pa_schema: pa.Schema


# Not used for now, because of https://github.com/lancedb/lance/issues/3443
#
# async def _update_table_schema(
#     table: lancedb.AsyncTable,
#     expected_schema: pa.Schema,
# ) -> None:
#     existing_schema = await table.schema()
#     unseen_existing_field_names = {field.name: field for field in existing_schema}
#     new_columns = []
#     updated_columns = []
#     for field in expected_schema:
#         existing_field = unseen_existing_field_names.pop(field.name, None)
#         if existing_field is None:
#             new_columns.append(field)
#         else:
#             if field.type != existing_field.type:
#                 updated_columns.append(
#                     {
#                         "path": field.name,
#                         "data_type": field.type,
#                         "nullable": field.nullable,
#                     }
#                 )
#     if new_columns:
#         table.add_columns(new_columns)
#     if updated_columns:
#         table.alter_columns(*updated_columns)
#     if unseen_existing_field_names:
#         table.drop_columns(unseen_existing_field_names.keys())


@op.target_connector(
    spec_cls=LanceDB, persistent_key_type=_TableKey, setup_state_cls=_State
)
class _Connector:
    @staticmethod
    def get_persistent_key(spec: LanceDB) -> _TableKey:
        return _TableKey(db_uri=spec.db_uri, table_name=spec.table_name)

    @staticmethod
    def get_setup_state(
        spec: LanceDB,
        key_fields_schema: list[FieldSchema],
        value_fields_schema: list[FieldSchema],
        index_options: IndexOptions,
    ) -> _State:
        if len(key_fields_schema) != 1:
            raise ValueError("LanceDB only supports a single key field")
        if index_options.vector_indexes is not None:
            for vector_index in index_options.vector_indexes:
                if vector_index.method is not None:
                    raise ValueError(
                        "Vector index method is not configurable for LanceDB yet"
                    )
        return _State(
            key_field_schema=key_fields_schema[0],
            value_fields_schema=value_fields_schema,
            db_options=spec.db_options,
            vector_indexes=(
                [
                    _VectorIndex(
                        name=f"__{index.field_name}__{_LANCEDB_VECTOR_METRIC[index.metric]}__idx",
                        field_name=index.field_name,
                        metric=index.metric,
                    )
                    for index in index_options.vector_indexes
                ]
                if index_options.vector_indexes is not None
                else None
            ),
        )

    @staticmethod
    def describe(key: _TableKey) -> str:
        return f"LanceDB table {key.table_name}@{key.db_uri}"

    @staticmethod
    def check_state_compatibility(
        previous: _State, current: _State
    ) -> op.TargetStateCompatibility:
        if (
            previous.key_field_schema != current.key_field_schema
            or previous.value_fields_schema != current.value_fields_schema
        ):
            return op.TargetStateCompatibility.NOT_COMPATIBLE

        return op.TargetStateCompatibility.COMPATIBLE

    @staticmethod
    async def apply_setup_change(
        key: _TableKey, previous: _State | None, current: _State | None
    ) -> None:
        latest_state = current or previous
        if not latest_state:
            return
        db_conn = await connect_async(key.db_uri, db_options=latest_state.db_options)

        reuse_table = (
            previous is not None
            and current is not None
            and previous.key_field_schema == current.key_field_schema
            and previous.value_fields_schema == current.value_fields_schema
        )
        if previous is not None:
            if not reuse_table:
                await db_conn.drop_table(key.table_name, ignore_missing=True)

        if current is None:
            return

        table: lancedb.AsyncTable | None = None
        if reuse_table:
            try:
                table = await db_conn.open_table(key.table_name)
            except Exception as e:  # pylint: disable=broad-exception-caught
                _logger.warning(
                    "Exception in opening table %s, creating it",
                    key.table_name,
                    exc_info=e,
                )
                table = None

        if table is None:
            table = await db_conn.create_table(
                key.table_name,
                schema=make_pa_schema(
                    current.key_field_schema, current.value_fields_schema
                ),
                mode="overwrite",
            )
            await table.create_index(
                current.key_field_schema.name, config=lancedb.index.BTree()
            )

        unseen_prev_vector_indexes = {
            index.name for index in (previous and previous.vector_indexes) or []
        }
        existing_vector_indexes = {index.name for index in await table.list_indices()}

        for index in current.vector_indexes or []:
            if index.name in unseen_prev_vector_indexes:
                unseen_prev_vector_indexes.remove(index.name)
            else:
                try:
                    await table.create_index(
                        index.field_name,
                        name=index.name,
                        config=lancedb.index.HnswPq(
                            distance_type=_LANCEDB_VECTOR_METRIC[index.metric]
                        ),
                    )
                except Exception as e:  # pylint: disable=broad-exception-caught
                    raise RuntimeError(
                        f"Exception in creating index {index.name} on field {index.field_name}. "
                        "This may be caused by a limitation of LanceDB, "
                        "which requires data existing in the table to train the index. "
                        "See: https://github.com/lancedb/lance/issues/4034"
                    ) from e

        for vector_index_name in unseen_prev_vector_indexes:
            if vector_index_name in existing_vector_indexes:
                await table.drop_index(vector_index_name)

    @staticmethod
    async def prepare(
        spec: LanceDB,
        setup_state: _State,
    ) -> _MutateContext:
        db_conn = await connect_async(spec.db_uri, db_options=spec.db_options)
        table = await db_conn.open_table(spec.table_name)
        return _MutateContext(
            table=table,
            key_field_schema=setup_state.key_field_schema,
            value_fields_type=[
                field.value_type.type for field in setup_state.value_fields_schema
            ],
            pa_schema=make_pa_schema(
                setup_state.key_field_schema, setup_state.value_fields_schema
            ),
        )

    @staticmethod
    async def mutate(
        *all_mutations: tuple[_MutateContext, dict[Any, dict[str, Any] | None]],
    ) -> None:
        for context, mutations in all_mutations:
            key_name = context.key_field_schema.name
            value_types = context.value_fields_type

            rows_to_upserts = []
            keys_sql_to_deletes = []
            for key, value in mutations.items():
                if value is None:
                    keys_sql_to_deletes.append(_convert_key_value_to_sql(key))
                else:
                    fields = {
                        key_name: _convert_value_for_pyarrow(
                            context.key_field_schema.value_type.type, key
                        )
                    }
                    # Pair value fields positionally; the dict is assumed to
                    # preserve the schema's field order.
                    for (name, field_value), value_type in zip(
                        value.items(), value_types
                    ):
                        fields[name] = _convert_value_for_pyarrow(
                            value_type, field_value
                        )
                    rows_to_upserts.append(fields)
            record_batch = pa.RecordBatch.from_pylist(
                rows_to_upserts, context.pa_schema
            )
            # Upsert changed rows in one merge-insert keyed on the key column;
            # deletions are applied via the not-matched-by-source clause below.
            builder = (
                context.table.merge_insert(key_name)
                .when_matched_update_all()
                .when_not_matched_insert_all()
            )
            if keys_sql_to_deletes:
                delete_cond_sql = f"{key_name} IN ({','.join(keys_sql_to_deletes)})"
                builder = builder.when_not_matched_by_source_delete(delete_cond_sql)
            await builder.execute(record_batch)
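
# Example usage (illustrative sketch, not from this file): exporting a
# collector to this target from a cocoindex flow. The `export()` call and its
# parameters follow cocoindex's flow API; the collector and field names here
# are assumed.
#
#   doc_embeddings.export(
#       "doc_embeddings",
#       LanceDB(db_uri="/tmp/lancedb", table_name="doc_embeddings"),
#       primary_key_fields=["id"],
#       vector_indexes=[
#           VectorIndexDef(
#               field_name="embedding",
#               metric=VectorSimilarityMetric.COSINE_SIMILARITY,
#           )
#       ],
#   )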