cocoindex-0.2.10-cp311-abi3-macosx_11_0_arm64.whl → cocoindex-0.2.12-cp311-abi3-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cocoindex/__init__.py +5 -0
- cocoindex/_engine.abi3.so +0 -0
- cocoindex/auth_registry.py +8 -15
- cocoindex/convert.py +185 -27
- cocoindex/flow.py +83 -20
- cocoindex/op.py +168 -52
- cocoindex/query_handler.py +51 -0
- cocoindex/runtime.py +8 -1
- cocoindex/targets/__init__.py +5 -0
- cocoindex/{targets.py → targets/_engine_builtin_specs.py} +4 -4
- cocoindex/targets/lancedb.py +460 -0
- cocoindex/tests/test_convert.py +51 -26
- cocoindex/tests/test_load_convert.py +118 -0
- cocoindex/tests/test_typing.py +126 -2
- cocoindex/typing.py +207 -0
- {cocoindex-0.2.10.dist-info → cocoindex-0.2.12.dist-info}/METADATA +4 -1
- cocoindex-0.2.12.dist-info/RECORD +37 -0
- {cocoindex-0.2.10.dist-info → cocoindex-0.2.12.dist-info}/licenses/THIRD_PARTY_NOTICES.html +1 -1
- cocoindex-0.2.10.dist-info/RECORD +0 -33
- {cocoindex-0.2.10.dist-info → cocoindex-0.2.12.dist-info}/WHEEL +0 -0
- {cocoindex-0.2.10.dist-info → cocoindex-0.2.12.dist-info}/entry_points.txt +0 -0
cocoindex/targets/lancedb.py
ADDED
@@ -0,0 +1,460 @@
```python
import dataclasses
import logging
import threading
import uuid
import weakref
import datetime

from typing import Any

import lancedb  # type: ignore
import pyarrow as pa  # type: ignore

from .. import op
from ..typing import (
    FieldSchema,
    EnrichedValueType,
    BasicValueType,
    StructType,
    ValueType,
    VectorTypeSchema,
    TableType,
)
from ..index import VectorIndexDef, IndexOptions, VectorSimilarityMetric

_logger = logging.getLogger(__name__)

_LANCEDB_VECTOR_METRIC: dict[VectorSimilarityMetric, str] = {
    VectorSimilarityMetric.COSINE_SIMILARITY: "cosine",
    VectorSimilarityMetric.L2_DISTANCE: "l2",
    VectorSimilarityMetric.INNER_PRODUCT: "dot",
}


class DatabaseOptions:
    storage_options: dict[str, Any] | None = None


class LanceDB(op.TargetSpec):
    db_uri: str
    table_name: str
    db_options: DatabaseOptions | None = None
```
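For orientation (not part of the diff itself): a spec like this is what flow authors hand to an `export(...)` call. A minimal sketch of that usage, assuming a collector named `doc_embeddings` with hypothetical fields `filename` and `embedding`; the `export` shape follows cocoindex's standard target API:

```python
import cocoindex
from cocoindex.targets.lancedb import LanceDB

# Inside a flow definition body; `doc_embeddings` is a hypothetical collector.
doc_embeddings.export(
    "doc_embeddings",
    LanceDB(db_uri="./lancedb_data", table_name="doc_embeddings"),
    primary_key_fields=["filename"],
    vector_indexes=[
        cocoindex.VectorIndexDef(
            field_name="embedding",
            metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY,
        )
    ],
)
```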
```python
@dataclasses.dataclass
class _VectorIndex:
    name: str
    field_name: str
    metric: VectorSimilarityMetric


@dataclasses.dataclass
class _State:
    key_field_schema: FieldSchema
    value_fields_schema: list[FieldSchema]
    vector_indexes: list[_VectorIndex] | None = None
    db_options: DatabaseOptions | None = None


@dataclasses.dataclass
class _TableKey:
    db_uri: str
    table_name: str


_DbConnectionsLock = threading.Lock()
_DbConnections: weakref.WeakValueDictionary[str, lancedb.AsyncConnection] = (
    weakref.WeakValueDictionary()
)


async def connect_async(
    db_uri: str,
    *,
    db_options: DatabaseOptions | None = None,
    read_consistency_interval: datetime.timedelta | None = None,
) -> lancedb.AsyncConnection:
    """
    Helper function to connect to a LanceDB database.
    It reuses an existing connection for the same URI if one is already open.
    The connection is shared with the target used by cocoindex, so reads achieve strong consistency.
    """
    with _DbConnectionsLock:
        conn = _DbConnections.get(db_uri)
        if conn is None:
            db_options = db_options or DatabaseOptions()
            _DbConnections[db_uri] = conn = await lancedb.connect_async(
                db_uri,
                storage_options=db_options.storage_options,
                read_consistency_interval=read_consistency_interval,
            )
        return conn
```
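Because connections are pooled per `db_uri` in the weak-value dictionary above, application code can reuse the exact connection the target writes through. A minimal sketch (URI and table name are hypothetical); note the cache holds weak references, so a connection lives only as long as some caller keeps a reference to it:

```python
import asyncio
from cocoindex.targets.lancedb import connect_async

async def main() -> None:
    # Reuses the pooled connection if the cocoindex target already opened one
    # for this URI, so reads observe the target's latest writes.
    db = await connect_async("./lancedb_data")
    table = await db.open_table("doc_embeddings")  # hypothetical table name
    print(await table.count_rows())

asyncio.run(main())
```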
```python
def make_pa_schema(
    key_field_schema: FieldSchema, value_fields_schema: list[FieldSchema]
) -> pa.Schema:
    """Convert FieldSchema list to PyArrow schema."""
    fields = [
        _convert_field_to_pa_field(field)
        for field in [key_field_schema] + value_fields_schema
    ]
    return pa.schema(fields)


def _convert_field_to_pa_field(field_schema: FieldSchema) -> pa.Field:
    """Convert a FieldSchema to a PyArrow Field."""
    pa_type = _convert_value_type_to_pa_type(field_schema.value_type)

    # Handle nullable fields
    nullable = field_schema.value_type.nullable

    return pa.field(field_schema.name, pa_type, nullable=nullable)


def _convert_value_type_to_pa_type(value_type: EnrichedValueType) -> pa.DataType:
    """Convert EnrichedValueType to PyArrow DataType."""
    base_type: ValueType = value_type.type

    if isinstance(base_type, StructType):
        # Handle struct types
        return _convert_struct_fields_to_pa_type(base_type.fields)
    elif isinstance(base_type, BasicValueType):
        # Handle basic types
        return _convert_basic_type_to_pa_type(base_type)
    elif isinstance(base_type, TableType):
        return pa.list_(_convert_struct_fields_to_pa_type(base_type.row.fields))

    assert False, f"Unhandled value type: {value_type}"


def _convert_struct_fields_to_pa_type(
    fields_schema: list[FieldSchema],
) -> pa.StructType:
    """Convert StructType to PyArrow StructType."""
    return pa.struct([_convert_field_to_pa_field(field) for field in fields_schema])


def _convert_basic_type_to_pa_type(basic_type: BasicValueType) -> pa.DataType:
    """Convert BasicValueType to PyArrow DataType."""
    kind: str = basic_type.kind

    # Map basic types to PyArrow types
    type_mapping = {
        "Bytes": pa.binary(),
        "Str": pa.string(),
        "Bool": pa.bool_(),
        "Int64": pa.int64(),
        "Float32": pa.float32(),
        "Float64": pa.float64(),
        "Uuid": pa.uuid(),
        "Date": pa.date32(),
        "Time": pa.time64("us"),
        "LocalDateTime": pa.timestamp("us"),
        "OffsetDateTime": pa.timestamp("us", tz="UTC"),
        "TimeDelta": pa.duration("us"),
        "Json": pa.json_(),
    }

    if kind in type_mapping:
        return type_mapping[kind]

    if kind == "Vector":
        vector_schema: VectorTypeSchema | None = basic_type.vector
        if vector_schema is None:
            raise ValueError("Vector type missing vector schema")
        element_type = _convert_basic_type_to_pa_type(vector_schema.element_type)

        if vector_schema.dimension is not None:
            return pa.list_(element_type, vector_schema.dimension)
        else:
            return pa.list_(element_type)

    if kind == "Range":
        # Range as a struct with start and end
        return pa.struct([pa.field("start", pa.int64()), pa.field("end", pa.int64())])

    assert False, f"Unsupported type kind for LanceDB: {kind}"
```
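To make the mapping above concrete, here is the PyArrow schema the conversion is expected to produce for a non-null `Str` key, a nullable `Str` field, and a non-null 3-dimensional `Float32` vector (an illustration written directly against PyArrow, not part of the module; field names are hypothetical):

```python
import pyarrow as pa

# Expected output of make_pa_schema for the schema described above.
expected = pa.schema(
    [
        pa.field("id", pa.string(), nullable=False),
        pa.field("text", pa.string(), nullable=True),
        # A known dimension maps to a fixed-size list; dimension None would
        # map to a variable-length list instead.
        pa.field("embedding", pa.list_(pa.float32(), 3), nullable=False),
    ]
)
print(expected)
```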
```python
def _convert_key_value_to_sql(v: Any) -> str:
    if isinstance(v, str):
        escaped = v.replace("'", "''")
        return f"'{escaped}'"

    if isinstance(v, uuid.UUID):
        return f"x'{v.hex}'"

    return str(v)
```
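The escaping rules in brief: single quotes in strings are doubled, UUIDs become hex blob literals, and everything else is rendered with `str()`. A quick check of the function above (run in this module's scope):

```python
import uuid

assert _convert_key_value_to_sql("O'Hara") == "'O''Hara'"
k = uuid.UUID("12345678-1234-5678-1234-567812345678")
assert _convert_key_value_to_sql(k) == "x'12345678123456781234567812345678'"
assert _convert_key_value_to_sql(42) == "42"
```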
```python
def _convert_fields_to_pyarrow(fields: list[FieldSchema], v: Any) -> Any:
    if isinstance(v, dict):
        return {
            field.name: _convert_value_for_pyarrow(
                field.value_type.type, v.get(field.name)
            )
            for field in fields
        }
    elif isinstance(v, tuple):
        return {
            field.name: _convert_value_for_pyarrow(field.value_type.type, value)
            for field, value in zip(fields, v)
        }
    else:
        field = fields[0]
        return {field.name: _convert_value_for_pyarrow(field.value_type.type, v)}


def _convert_value_for_pyarrow(t: ValueType, v: Any) -> Any:
    if v is None:
        return None

    if isinstance(t, BasicValueType):
        if isinstance(v, uuid.UUID):
            return v.bytes

        if t.kind == "Range":
            return {"start": v[0], "end": v[1]}

        if t.vector is not None:
            return [_convert_value_for_pyarrow(t.vector.element_type, e) for e in v]

        return v

    elif isinstance(t, StructType):
        return _convert_fields_to_pyarrow(t.fields, v)

    elif isinstance(t, TableType):
        if isinstance(v, list):
            return [_convert_fields_to_pyarrow(t.row.fields, value) for value in v]
        else:
            key_fields = t.row.fields[: t.num_key_parts]
            value_fields = t.row.fields[t.num_key_parts :]
            return [
                _convert_fields_to_pyarrow(key_fields, value[0 : t.num_key_parts])
                | _convert_fields_to_pyarrow(value_fields, value[t.num_key_parts :])
                for value in v
            ]

    assert False, f"Unsupported value type: {t}"
```
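Two of the conversions above deserve a concrete look: a UUID is stored as its 16 raw bytes (matching the binary-backed `pa.uuid()` storage type), and a `Range` arrives as a `(start, end)` pair and is stored as a struct. Illustrated with plain values:

```python
import uuid

u = uuid.UUID("12345678-1234-5678-1234-567812345678")
print(u.bytes.hex())           # the 16 raw bytes backing the UUID column
print({"start": 3, "end": 7})  # struct encoding of a Range value (3, 7)
```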
```python
@dataclasses.dataclass
class _MutateContext:
    table: lancedb.AsyncTable
    key_field_schema: FieldSchema
    value_fields_type: list[ValueType]
    pa_schema: pa.Schema


# Not used for now, because of https://github.com/lancedb/lance/issues/3443
#
# async def _update_table_schema(
#     table: lancedb.AsyncTable,
#     expected_schema: pa.Schema,
# ) -> None:
#     existing_schema = await table.schema()
#     unseen_existing_field_names = {field.name: field for field in existing_schema}
#     new_columns = []
#     updated_columns = []
#     for field in expected_schema:
#         existing_field = unseen_existing_field_names.pop(field.name, None)
#         if existing_field is None:
#             new_columns.append(field)
#         else:
#             if field.type != existing_field.type:
#                 updated_columns.append(
#                     {
#                         "path": field.name,
#                         "data_type": field.type,
#                         "nullable": field.nullable,
#                     }
#                 )
#     if new_columns:
#         table.add_columns(new_columns)
#     if updated_columns:
#         table.alter_columns(*updated_columns)
#     if unseen_existing_field_names:
#         table.drop_columns(unseen_existing_field_names.keys())


@op.target_connector(
    spec_cls=LanceDB, persistent_key_type=_TableKey, setup_state_cls=_State
)
class _Connector:
    @staticmethod
    def get_persistent_key(spec: LanceDB) -> _TableKey:
        return _TableKey(db_uri=spec.db_uri, table_name=spec.table_name)

    @staticmethod
    def get_setup_state(
        spec: LanceDB,
        key_fields_schema: list[FieldSchema],
        value_fields_schema: list[FieldSchema],
        index_options: IndexOptions,
    ) -> _State:
        if len(key_fields_schema) != 1:
            raise ValueError("LanceDB only supports a single key field")
        return _State(
            key_field_schema=key_fields_schema[0],
            value_fields_schema=value_fields_schema,
            db_options=spec.db_options,
            vector_indexes=(
                [
                    _VectorIndex(
                        name=f"__{index.field_name}__{_LANCEDB_VECTOR_METRIC[index.metric]}__idx",
                        field_name=index.field_name,
                        metric=index.metric,
                    )
                    for index in index_options.vector_indexes
                ]
                if index_options.vector_indexes is not None
                else None
            ),
        )
```
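Note the derived index name encodes both the field and the metric, so changing a vector index's metric yields a different name, and `apply_setup_change` below will create the new index and drop the stale one rather than mutate it in place. For example (hypothetical field name):

```python
field_name, metric_name = "embedding", "cosine"
print(f"__{field_name}__{metric_name}__idx")  # -> __embedding__cosine__idx
```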
```python
    @staticmethod
    def describe(key: _TableKey) -> str:
        return f"LanceDB table {key.table_name}@{key.db_uri}"

    @staticmethod
    def check_state_compatibility(
        previous: _State, current: _State
    ) -> op.TargetStateCompatibility:
        if (
            previous.key_field_schema != current.key_field_schema
            or previous.value_fields_schema != current.value_fields_schema
        ):
            return op.TargetStateCompatibility.NOT_COMPATIBLE

        return op.TargetStateCompatibility.COMPATIBLE

    @staticmethod
    async def apply_setup_change(
        key: _TableKey, previous: _State | None, current: _State | None
    ) -> None:
        latest_state = current or previous
        if not latest_state:
            return
        db_conn = await connect_async(key.db_uri, db_options=latest_state.db_options)

        reuse_table = (
            previous is not None
            and current is not None
            and previous.key_field_schema == current.key_field_schema
            and previous.value_fields_schema == current.value_fields_schema
        )
        if previous is not None:
            if not reuse_table:
                await db_conn.drop_table(key.table_name, ignore_missing=True)

        if current is None:
            return

        table: lancedb.AsyncTable | None = None
        if reuse_table:
            try:
                table = await db_conn.open_table(key.table_name)
            except Exception as e:  # pylint: disable=broad-exception-caught
                _logger.warning(
                    "Exception in opening table %s, creating it",
                    key.table_name,
                    exc_info=e,
                )
                table = None

        if table is None:
            table = await db_conn.create_table(
                key.table_name,
                schema=make_pa_schema(
                    current.key_field_schema, current.value_fields_schema
                ),
                mode="overwrite",
            )
            await table.create_index(
                current.key_field_schema.name, config=lancedb.index.BTree()
            )

        unseen_prev_vector_indexes = {
            index.name for index in (previous and previous.vector_indexes) or []
        }
        existing_vector_indexes = {index.name for index in await table.list_indices()}

        for index in current.vector_indexes or []:
            if index.name in unseen_prev_vector_indexes:
                unseen_prev_vector_indexes.remove(index.name)
            else:
                try:
                    await table.create_index(
                        index.field_name,
                        name=index.name,
                        config=lancedb.index.HnswPq(
                            distance_type=_LANCEDB_VECTOR_METRIC[index.metric]
                        ),
                    )
                except Exception as e:  # pylint: disable=broad-exception-caught
                    raise RuntimeError(
                        f"Exception in creating index on field {index.field_name}. "
                        f"This may be caused by a limitation of LanceDB, "
                        f"which requires data existing in the table to train the index. "
                        f"See: https://github.com/lancedb/lance/issues/4034",
                        index.name,
                    ) from e

        for vector_index_name in unseen_prev_vector_indexes:
            if vector_index_name in existing_vector_indexes:
                await table.drop_index(vector_index_name)
```
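Once `apply_setup_change` has created the table, its BTree key index, and the HNSW-PQ vector indexes, the table is queryable through the shared connection. A minimal search sketch, assuming a 3-dimensional `embedding` column and the standard lancedb async query builder (URI and names are hypothetical):

```python
import asyncio
from cocoindex.targets.lancedb import connect_async

async def search() -> None:
    db = await connect_async("./lancedb_data")
    table = await db.open_table("doc_embeddings")
    results = (
        await table.query()
        .nearest_to([0.1, 0.2, 0.3])  # query vector for "embedding"
        .distance_type("cosine")      # matches the index's metric
        .limit(5)
        .to_arrow()
    )
    print(results)

asyncio.run(search())
```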
```python
    @staticmethod
    async def prepare(
        spec: LanceDB,
        setup_state: _State,
    ) -> _MutateContext:
        db_conn = await connect_async(spec.db_uri, db_options=spec.db_options)
        table = await db_conn.open_table(spec.table_name)
        return _MutateContext(
            table=table,
            key_field_schema=setup_state.key_field_schema,
            value_fields_type=[
                field.value_type.type for field in setup_state.value_fields_schema
            ],
            pa_schema=make_pa_schema(
                setup_state.key_field_schema, setup_state.value_fields_schema
            ),
        )

    @staticmethod
    async def mutate(
        *all_mutations: tuple[_MutateContext, dict[Any, dict[str, Any] | None]],
    ) -> None:
        for context, mutations in all_mutations:
            key_name = context.key_field_schema.name
            value_types = context.value_fields_type

            rows_to_upserts = []
            keys_sql_to_deletes = []
            for key, value in mutations.items():
                if value is None:
                    keys_sql_to_deletes.append(_convert_key_value_to_sql(key))
                else:
                    fields = {
                        key_name: _convert_value_for_pyarrow(
                            context.key_field_schema.value_type.type, key
                        )
                    }
                    for (name, value), value_type in zip(value.items(), value_types):
                        fields[name] = _convert_value_for_pyarrow(value_type, value)
                    rows_to_upserts.append(fields)
            record_batch = pa.RecordBatch.from_pylist(
                rows_to_upserts, context.pa_schema
            )
            builder = (
                context.table.merge_insert(key_name)
                .when_matched_update_all()
                .when_not_matched_insert_all()
            )
            if keys_sql_to_deletes:
                delete_cond_sql = f"{key_name} IN ({','.join(keys_sql_to_deletes)})"
                builder = builder.when_not_matched_by_source_delete(delete_cond_sql)
            await builder.execute(record_batch)
```
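The `mutate` method above folds a batch's upserts and deletions into a single `merge_insert` statement. As a standalone sketch of that statement's shape (key column `id` is hypothetical): rows present in `batch` are upserted; keys listed in the IN-clause but absent from `batch` are deleted.

```python
import pyarrow as pa

async def apply_batch(
    table, batch: pa.RecordBatch, delete_keys_sql: list[str]
) -> None:
    # Mirrors the builder chain used by mutate() above.
    builder = (
        table.merge_insert("id")
        .when_matched_update_all()
        .when_not_matched_insert_all()
    )
    if delete_keys_sql:
        builder = builder.when_not_matched_by_source_delete(
            f"id IN ({','.join(delete_keys_sql)})"
        )
    await builder.execute(batch)
```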
cocoindex/tests/test_convert.py
CHANGED
```diff
@@ -21,6 +21,7 @@ from cocoindex.typing import (
     Vector,
     analyze_type_info,
     encode_enriched_type,
+    decode_engine_value_type,
 )
 
 
@@ -86,7 +87,9 @@ def build_engine_value_decoder(
     """
     engine_type = encode_enriched_type(engine_type_in_py)["type"]
     return make_engine_value_decoder(
-        [], engine_type, analyze_type_info(python_type or engine_type_in_py)
+        [],
+        decode_engine_value_type(engine_type),
+        analyze_type_info(python_type or engine_type_in_py),
     )
 
 
@@ -116,7 +119,9 @@ def validate_full_roundtrip_to(
 
     for other_value, other_type in decoded_values:
         decoder = make_engine_value_decoder(
-            [], encoded_output_type, analyze_type_info(other_type)
+            [],
+            decode_engine_value_type(encoded_output_type),
+            analyze_type_info(other_type),
         )
         other_decoded_value = decoder(value_from_engine)
         assert eq(other_decoded_value, other_value), (
@@ -383,9 +388,19 @@ def test_roundtrip_json() -> None:
 
 def test_decode_scalar_numpy_values() -> None:
     test_cases = [
-        ({"kind": "Int64"}, np.int64, 42, np.int64(42)),
-        ({"kind": "Float32"}, np.float32, 3.14, np.float32(3.14)),
-        ({"kind": "Float64"}, np.float64, 2.718, np.float64(2.718)),
+        (decode_engine_value_type({"kind": "Int64"}), np.int64, 42, np.int64(42)),
+        (
+            decode_engine_value_type({"kind": "Float32"}),
+            np.float32,
+            3.14,
+            np.float32(3.14),
+        ),
+        (
+            decode_engine_value_type({"kind": "Float64"}),
+            np.float64,
+            2.718,
+            np.float64(2.718),
+        ),
     ]
     for src_type, dst_type, input_value, expected in test_cases:
         decoder = make_engine_value_decoder(
@@ -398,11 +413,13 @@ def test_decode_scalar_numpy_values() -> None:
 
 def test_non_ndarray_vector_decoding() -> None:
     # Test list[np.float64]
-    src_type = {
-        "kind": "Vector",
-        "element_type": {"kind": "Float64"},
-        "dimension": None,
-    }
+    src_type = decode_engine_value_type(
+        {
+            "kind": "Vector",
+            "element_type": {"kind": "Float64"},
+            "dimension": None,
+        }
+    )
     dst_type_float = list[np.float64]
     decoder = make_engine_value_decoder(
         ["field"], src_type, analyze_type_info(dst_type_float)
@@ -414,7 +431,9 @@ def test_non_ndarray_vector_decoding() -> None:
     assert result == [np.float64(1.0), np.float64(2.0), np.float64(3.0)]
 
     # Test list[Uuid]
-    src_type = {"kind": "Vector", "element_type": {"kind": "Uuid"}, "dimension": None}
+    src_type = decode_engine_value_type(
+        {"kind": "Vector", "element_type": {"kind": "Uuid"}, "dimension": None}
+    )
     dst_type_uuid = list[uuid.UUID]
     decoder = make_engine_value_decoder(
         ["field"], src_type, analyze_type_info(dst_type_uuid)
@@ -895,11 +914,13 @@ def test_encode_complex_structure_with_ndarray() -> None:
 
 def test_decode_nullable_ndarray_none_or_value_input() -> None:
     """Test decoding a nullable NDArray with None or value inputs."""
-    src_type_dict = {
-        "kind": "Vector",
-        "element_type": {"kind": "Float32"},
-        "dimension": None,
-    }
+    src_type_dict = decode_engine_value_type(
+        {
+            "kind": "Vector",
+            "element_type": {"kind": "Float32"},
+            "dimension": None,
+        }
+    )
     dst_annotation = NDArrayFloat32Type | None
     decoder = make_engine_value_decoder(
         [], src_type_dict, analyze_type_info(dst_annotation)
@@ -921,11 +942,13 @@ def test_decode_nullable_ndarray_none_or_value_input() -> None:
 
 def test_decode_vector_string() -> None:
     """Test decoding a vector of strings works for Python native list type."""
-    src_type_dict = {
-        "kind": "Vector",
-        "element_type": {"kind": "Str"},
-        "dimension": None,
-    }
+    src_type_dict = decode_engine_value_type(
+        {
+            "kind": "Vector",
+            "element_type": {"kind": "Str"},
+            "dimension": None,
+        }
+    )
     decoder = make_engine_value_decoder(
         [], src_type_dict, analyze_type_info(Vector[str])
     )
@@ -934,11 +957,13 @@ def test_decode_vector_string() -> None:
 
 def test_decode_error_non_nullable_or_non_list_vector() -> None:
     """Test decoding errors for non-nullable vectors or non-list inputs."""
-    src_type_dict = {
-        "kind": "Vector",
-        "element_type": {"kind": "Float32"},
-        "dimension": None,
-    }
+    src_type_dict = decode_engine_value_type(
+        {
+            "kind": "Vector",
+            "element_type": {"kind": "Float32"},
+            "dimension": None,
+        }
+    )
     decoder = make_engine_value_decoder(
         [], src_type_dict, analyze_type_info(NDArrayFloat32Type)
     )
```
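Every hunk above makes the same mechanical change: `make_engine_value_decoder` no longer takes the raw engine type dict; callers first decode it into a typed schema object with `decode_engine_value_type`. Condensed, the new calling convention looks like this (import paths assumed to match the test module's):

```python
from cocoindex.convert import make_engine_value_decoder
from cocoindex.typing import analyze_type_info, decode_engine_value_type

src_type = decode_engine_value_type(
    {"kind": "Vector", "element_type": {"kind": "Float32"}, "dimension": None}
)
decoder = make_engine_value_decoder([], src_type, analyze_type_info(list[float]))
print(decoder([1.0, 2.0, 3.0]))
```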