pyvastbase 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyvastbase/__init__.py +364 -0
- pyvastbase/async_impl/__init__.py +159 -0
- pyvastbase/async_impl/collection.py +1367 -0
- pyvastbase/async_impl/connections.py +561 -0
- pyvastbase/async_impl/index_builder.py +299 -0
- pyvastbase/async_impl/search_builder.py +338 -0
- pyvastbase/async_impl/utility.py +814 -0
- pyvastbase/client.py +611 -0
- pyvastbase/collection.py +1053 -0
- pyvastbase/core/__init__.py +1 -0
- pyvastbase/core/collection_core.py +902 -0
- pyvastbase/core/connections.py +443 -0
- pyvastbase/core/constants.py +202 -0
- pyvastbase/core/exceptions.py +166 -0
- pyvastbase/core/executor.py +115 -0
- pyvastbase/core/mutation_result.py +143 -0
- pyvastbase/core/schema.py +610 -0
- pyvastbase/core/search_result.py +138 -0
- pyvastbase/core/vector_types.py +224 -0
- pyvastbase/executor/__init__.py +1 -0
- pyvastbase/executor/async_impl.py +144 -0
- pyvastbase/executor/sync.py +119 -0
- pyvastbase/index/__init__.py +1 -0
- pyvastbase/index/index_builder.py +485 -0
- pyvastbase/operations/__init__.py +130 -0
- pyvastbase/operations/admin.py +256 -0
- pyvastbase/operations/ddl.py +75 -0
- pyvastbase/operations/delete.py +44 -0
- pyvastbase/operations/index.py +170 -0
- pyvastbase/operations/insert.py +131 -0
- pyvastbase/operations/query.py +94 -0
- pyvastbase/operations/result.py +35 -0
- pyvastbase/operations/schema_reflect.py +106 -0
- pyvastbase/operations/search.py +40 -0
- pyvastbase/operations/search_result.py +54 -0
- pyvastbase/operations/upsert.py +34 -0
- pyvastbase/orm/__init__.py +66 -0
- pyvastbase/orm/base.py +711 -0
- pyvastbase/orm/fields.py +280 -0
- pyvastbase/search/__init__.py +1 -0
- pyvastbase/search/search_builder.py +691 -0
- pyvastbase/test_conn.py +17 -0
- pyvastbase/utility.py +2677 -0
- pyvastbase/utils/__init__.py +163 -0
- pyvastbase/utils/distance_utils.py +531 -0
- pyvastbase/utils/filter_utils.py +333 -0
- pyvastbase/utils/index_utils.py +387 -0
- pyvastbase/utils/result_utils.py +435 -0
- pyvastbase/utils/sql_utils.py +230 -0
- pyvastbase/utils/vector_utils.py +708 -0
- pyvastbase-0.2.2.dist-info/METADATA +584 -0
- pyvastbase-0.2.2.dist-info/RECORD +54 -0
- pyvastbase-0.2.2.dist-info/WHEEL +5 -0
- pyvastbase-0.2.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1367 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Vastbase Vector Database SDK - Async Collection Module
|
|
3
|
+
|
|
4
|
+
Async implementation of Collection class.
|
|
5
|
+
Fully compatible with the sync API - same method names, parameters, and return types.
|
|
6
|
+
|
|
7
|
+
All methods are async and return awaitables. Use 'await' to call them.
|
|
8
|
+
|
|
9
|
+
Example:
|
|
10
|
+
from pyvastbase.async_impl import AsyncCollection, async_connect
|
|
11
|
+
|
|
12
|
+
await async_connect(host="localhost", database="mydb")
|
|
13
|
+
collection = AsyncCollection("my_vectors")
|
|
14
|
+
await collection.insert([{"id": 1, "embedding": [0.1] * 128}])
|
|
15
|
+
|
|
16
|
+
results = await collection.search(data=[[0.15] * 128], limit=10)
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
import json
|
|
20
|
+
import functools
|
|
21
|
+
from typing import Optional, List, Dict, Any, Union, TYPE_CHECKING, Iterator
|
|
22
|
+
|
|
23
|
+
from ..core.constants import DistanceType, IndexType
|
|
24
|
+
from ..core.collection_core import CollectionCore
|
|
25
|
+
from ..core.exceptions import (
|
|
26
|
+
CollectionError, CollectionNotExistsError,
|
|
27
|
+
SchemaError, ParamError, DataError, VectorDimensionError
|
|
28
|
+
)
|
|
29
|
+
from ..core.mutation_result import MutationResult
|
|
30
|
+
from ..operations import (
|
|
31
|
+
build_delete_sql,
|
|
32
|
+
build_upsert_clauses,
|
|
33
|
+
classify_insert_columns,
|
|
34
|
+
build_columns_query_sql,
|
|
35
|
+
build_existence_check_sql,
|
|
36
|
+
build_pk_query_sql,
|
|
37
|
+
parse_column_row,
|
|
38
|
+
validate_and_normalize_queries,
|
|
39
|
+
)
|
|
40
|
+
from ..core.schema import CollectionSchema, FieldSchema, DataType
|
|
41
|
+
from ..index.index_builder import IndexParams, IndexBuilder
|
|
42
|
+
from ..search.search_builder import SearchParams, SearchQueryBuilder, SearchResult
|
|
43
|
+
from ..utils import (
|
|
44
|
+
normalize_vector, validate_vector, ensure_vector_dim,
|
|
45
|
+
batch_iterator,
|
|
46
|
+
format_entity, extract_ids, to_list_dict,
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
if TYPE_CHECKING:
|
|
50
|
+
from .connections import AsyncVastbaseConnection
|
|
51
|
+
|
|
52
|
+
from ..executor.async_impl import AsyncExecutor, AsyncTransaction
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class QueryIterator:
    """
    Async iterator that yields query results one batch at a time.

    Every step awaits ``collection._fetch_batch(...)`` with a row limit of at
    most ``batch_size`` and an offset that advances by the number of rows
    already handed out.  Iteration stops when a fetch returns no rows or,
    when ``limit`` is set, once ``limit`` rows have been yielded in total.

    Args:
        collection: Object exposing an awaitable ``_fetch_batch`` method.
        filter_expr: Filter expression forwarded unchanged to ``_fetch_batch``.
        batch_size: Maximum number of rows per yielded batch.
        output_fields: Field names forwarded unchanged to ``_fetch_batch``.
        limit: Total number of rows to yield; a falsy value (``None`` or 0)
            means "no limit".
        offset: Starting offset of the first fetch (``None`` is treated as 0).

    Example:
        # Iterate over results in batches
        async for batch in collection.query(batch_size=100):
            await process(batch)
    """

    def __init__(
        self,
        collection: 'AsyncCollection',
        filter_expr: str = "",
        batch_size: int = 100,
        output_fields: Optional[List[str]] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None
    ):
        self._collection = collection
        self._filter_expr = filter_expr
        self._batch_size = batch_size
        self._output_fields = output_fields
        self._limit = limit
        self._offset = offset or 0
        # Rows handed out so far.  Kept separate from the SQL offset so that
        # a non-zero starting offset does not count against the limit.
        self._yielded = 0
        self._buffer: List[Dict[str, Any]] = []
        self._exhausted = False

    @property
    def batch_size(self) -> int:
        """Get batch size (for test compatibility)"""
        return self._batch_size

    def __aiter__(self) -> 'QueryIterator':
        return self

    async def __anext__(self) -> List[Dict[str, Any]]:
        """Return the next batch, or raise StopAsyncIteration when done."""
        if self._exhausted:
            raise StopAsyncIteration

        # If buffer is empty, fetch next batch
        if not self._buffer:
            fetch_limit = self._batch_size
            if self._limit:
                remaining = self._limit - self._yielded
                if remaining <= 0:
                    # The limit has been reached: stop instead of silently
                    # falling back to an unbounded fetch.
                    self._exhausted = True
                    raise StopAsyncIteration
                fetch_limit = min(fetch_limit, remaining)

            rows = await self._collection._fetch_batch(
                filter_expr=self._filter_expr,
                output_fields=self._output_fields,
                limit=fetch_limit,
                offset=self._offset
            )

            if not rows:
                self._exhausted = True
                raise StopAsyncIteration

            self._buffer = rows

        # Return batch from buffer
        batch = self._buffer[:self._batch_size]
        self._buffer = self._buffer[self._batch_size:]
        self._offset += len(batch)
        self._yielded += len(batch)

        return batch
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
# ---------------------------------------------------------------------------
|
|
122
|
+
# Decorator — eliminates connection boilerplate from every public method
|
|
123
|
+
# ---------------------------------------------------------------------------
|
|
124
|
+
|
|
125
|
+
def _with_connection(method=None, *, schema: bool = False):
    """Decorator that prepares the connection before an async method runs.

    Usable bare (``@_with_connection``) or parameterized
    (``@_with_connection(schema=True)``).

    Before the wrapped coroutine is awaited, the wrapper calls the synchronous
    ``self._get_connection()``, awaits ``self._ensure_connection()`` and, when
    *schema* is True, also awaits ``self._ensure_schema_loaded()``.
    """
    def _decorate(target):
        @functools.wraps(target)
        async def _connected(self, *args, **kwargs):
            self._get_connection()
            await self._ensure_connection()
            if schema:
                await self._ensure_schema_loaded()
            outcome = await target(self, *args, **kwargs)
            return outcome
        return _connected

    # Bare usage passes the function directly; parameterized usage passes
    # nothing and expects the actual decorator back.
    return _decorate if method is None else _decorate(method)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
class AsyncCollection:
|
|
149
|
+
"""
|
|
150
|
+
Async version of Collection class.
|
|
151
|
+
|
|
152
|
+
Provides the same API as Collection, but all database operations are async.
|
|
153
|
+
Use 'await' to call methods.
|
|
154
|
+
|
|
155
|
+
Example:
|
|
156
|
+
from pyvastbase.async_impl import AsyncCollection, async_connect
|
|
157
|
+
|
|
158
|
+
await async_connect(host="localhost", database="mydb")
|
|
159
|
+
collection = AsyncCollection("my_vectors")
|
|
160
|
+
|
|
161
|
+
# Insert data
|
|
162
|
+
await collection.insert([
|
|
163
|
+
{"id": 1, "embedding": [0.1] * 128, "text": "hello"},
|
|
164
|
+
{"id": 2, "embedding": [0.2] * 128, "text": "world"}
|
|
165
|
+
])
|
|
166
|
+
|
|
167
|
+
# Search
|
|
168
|
+
results = await collection.search(
|
|
169
|
+
data=[[0.15] * 128],
|
|
170
|
+
anns_field="embedding",
|
|
171
|
+
limit=10
|
|
172
|
+
)
|
|
173
|
+
|
|
174
|
+
# Query with async iterator
|
|
175
|
+
async for batch in collection.query(batch_size=100):
|
|
176
|
+
print(batch)
|
|
177
|
+
"""
|
|
178
|
+
|
|
179
|
+
def __init__(
|
|
180
|
+
self,
|
|
181
|
+
name: str,
|
|
182
|
+
schema: Optional[CollectionSchema] = None,
|
|
183
|
+
using: str = "default",
|
|
184
|
+
timeout: Optional[float] = None,
|
|
185
|
+
_connection_provider=None,
|
|
186
|
+
**kwargs
|
|
187
|
+
):
|
|
188
|
+
self._name = name
|
|
189
|
+
self._schema = schema
|
|
190
|
+
self._using = using
|
|
191
|
+
self._timeout = timeout
|
|
192
|
+
self._connection_provider = _connection_provider
|
|
193
|
+
self._extra_params = kwargs
|
|
194
|
+
|
|
195
|
+
self._conn: Optional['AsyncVastbaseConnection'] = None
|
|
196
|
+
self._core_obj: Optional[CollectionCore] = None
|
|
197
|
+
self._executor: Optional[AsyncExecutor] = None
|
|
198
|
+
|
|
199
|
+
# =========================================================================
|
|
200
|
+
# Properties
|
|
201
|
+
# =========================================================================
|
|
202
|
+
|
|
203
|
+
@property
|
|
204
|
+
def schema(self) -> CollectionSchema:
|
|
205
|
+
"""Get collection schema (synchronous, cached)"""
|
|
206
|
+
if self._schema is None:
|
|
207
|
+
raise CollectionError(
|
|
208
|
+
"Schema not loaded. Use 'await collection.schema_async' or "
|
|
209
|
+
"call 'await collection.load_schema()' first."
|
|
210
|
+
)
|
|
211
|
+
return self._schema
|
|
212
|
+
|
|
213
|
+
@property
|
|
214
|
+
async def schema_async(self) -> CollectionSchema:
|
|
215
|
+
"""Get collection schema (async, auto-loads if needed)"""
|
|
216
|
+
if self._schema is None:
|
|
217
|
+
await self._load_schema_async()
|
|
218
|
+
return self._schema
|
|
219
|
+
|
|
220
|
+
@property
|
|
221
|
+
def name(self) -> str:
|
|
222
|
+
"""Get collection name"""
|
|
223
|
+
return self._name
|
|
224
|
+
|
|
225
|
+
@property
|
|
226
|
+
def description(self) -> str:
|
|
227
|
+
"""Get collection description"""
|
|
228
|
+
return self.schema.description
|
|
229
|
+
|
|
230
|
+
@property
|
|
231
|
+
async def num_entities(self) -> int:
|
|
232
|
+
"""Get number of entities in collection (async property)"""
|
|
233
|
+
try:
|
|
234
|
+
rows = await self._executor.execute(
|
|
235
|
+
f'SELECT COUNT(*) AS cnt FROM "{self._name}"', []
|
|
236
|
+
)
|
|
237
|
+
if rows:
|
|
238
|
+
return rows[0].get("cnt", next(iter(rows[0].values())))
|
|
239
|
+
return 0
|
|
240
|
+
except Exception as e:
|
|
241
|
+
raise CollectionError(f"Failed to get entity count: {e}")
|
|
242
|
+
|
|
243
|
+
@property
|
|
244
|
+
async def is_empty(self) -> bool:
|
|
245
|
+
"""Check if collection is empty"""
|
|
246
|
+
count = await self.num_entities
|
|
247
|
+
return count == 0
|
|
248
|
+
|
|
249
|
+
async def index_params_async(self) -> Dict[str, Any]:
|
|
250
|
+
"""Get current index parameters (async version)"""
|
|
251
|
+
try:
|
|
252
|
+
rows = await self._executor.execute(
|
|
253
|
+
"""
|
|
254
|
+
SELECT indexname, indexdef
|
|
255
|
+
FROM pg_indexes
|
|
256
|
+
WHERE tablename = %s
|
|
257
|
+
""",
|
|
258
|
+
[self._name]
|
|
259
|
+
)
|
|
260
|
+
return {row["indexname"]: row["indexdef"] for row in rows}
|
|
261
|
+
except Exception as e:
|
|
262
|
+
raise CollectionError(f"Failed to get index params: {e}")
|
|
263
|
+
|
|
264
|
+
# =========================================================================
|
|
265
|
+
# Connection Management
|
|
266
|
+
# =========================================================================
|
|
267
|
+
|
|
268
|
+
def _get_connection(self) -> 'AsyncVastbaseConnection':
|
|
269
|
+
"""Get database connection"""
|
|
270
|
+
if self._conn is None:
|
|
271
|
+
if self._connection_provider:
|
|
272
|
+
self._conn = self._connection_provider.get_connection(self._using)
|
|
273
|
+
else:
|
|
274
|
+
from .connections import AsyncConnections
|
|
275
|
+
self._conn = AsyncConnections.get_connection(self._using)
|
|
276
|
+
return self._conn
|
|
277
|
+
|
|
278
|
+
async def _ensure_connection(self) -> None:
|
|
279
|
+
"""Ensure connection is established"""
|
|
280
|
+
self._conn = self._get_connection()
|
|
281
|
+
if not self._conn.is_connected:
|
|
282
|
+
await self._conn.connect()
|
|
283
|
+
if self._executor is None:
|
|
284
|
+
self._executor = AsyncExecutor(self._conn)
|
|
285
|
+
|
|
286
|
+
@property
|
|
287
|
+
def _core(self) -> CollectionCore:
|
|
288
|
+
"""Lazy-initialized CollectionCore with the AsyncExecutor.
|
|
289
|
+
|
|
290
|
+
Used for delegating business logic to the shared CollectionCore layer.
|
|
291
|
+
Schema must be loaded before this property is accessed.
|
|
292
|
+
"""
|
|
293
|
+
if self._core_obj is None:
|
|
294
|
+
self._core_obj = CollectionCore(
|
|
295
|
+
self._schema, self._executor, name=self._name
|
|
296
|
+
)
|
|
297
|
+
return self._core_obj
|
|
298
|
+
|
|
299
|
+
async def _ensure_schema_loaded(self) -> None:
|
|
300
|
+
"""Ensure schema is loaded before delegating to core."""
|
|
301
|
+
if self._schema is None:
|
|
302
|
+
await self._load_schema_async()
|
|
303
|
+
|
|
304
|
+
def _get_core(self) -> CollectionCore:
    """Build a validation-only CollectionCore backed by a FakeExecutor.

    In contrast to ``self._core`` (real AsyncExecutor, cached), a new core
    is created on every call and is handed a ``FakeExecutor``.  It is used
    for ``_process_vectors_for_insert`` / ``_validate_insert_data``.

    Raises:
        CollectionError: If the schema has not been loaded.
    """
    schema = self._schema
    if schema is None:
        raise CollectionError(
            "Schema not loaded. Use 'await collection.schema_async' or "
            "call 'await collection.load_schema()' first."
        )

    from ..core.executor import FakeExecutor

    validation_executor = FakeExecutor()
    return CollectionCore(schema, validation_executor, name=self._name)
|
|
320
|
+
|
|
321
|
+
# =========================================================================
|
|
322
|
+
# Schema Operations
|
|
323
|
+
# =========================================================================
|
|
324
|
+
|
|
325
|
+
@_with_connection
async def load_schema(self) -> CollectionSchema:
    """Reload the schema from the database and return it.

    Always re-reads, even when a schema is already cached.
    """
    await self._load_schema_async()
    loaded = self._schema
    return loaded
|
|
330
|
+
|
|
331
|
+
async def _load_schema_async(self) -> None:
    """Reflect the table definition into ``self._schema``.

    Runs three queries through ``self._executor``: an existence check, the
    primary-key lookup and the column listing.  Each column row is turned
    into a field via ``parse_column_row``.

    Raises:
        CollectionNotExistsError: If the existence check returns no rows.
        CollectionError: For any other failure while loading.
    """
    table = self._name
    try:
        # Existence check comes first so a missing table is reported as such.
        check_sql = build_existence_check_sql()
        found = await self._executor.execute(check_sql, [table])
        if not found:
            raise CollectionNotExistsError(table)

        # Primary keys
        pk_rows = await self._executor.execute(build_pk_query_sql(), [table])
        pk_columns = set(row["column_name"] for row in pk_rows)

        # Columns with atttypmod dimensions
        column_rows = await self._executor.execute(
            build_columns_query_sql(), [table]
        )

        self._schema = CollectionSchema(
            name=table,
            fields=[parse_column_row(row, pk_columns) for row in column_rows],
        )
    except CollectionNotExistsError:
        raise
    except Exception as e:
        raise CollectionError(f"Failed to load schema: {e}")
|
|
358
|
+
|
|
359
|
+
async def _check_schema_version(self, conn) -> None:
    """Verify server version supports all vector types in the schema.

    For every vector field the registry's minimum version is looked up and,
    once per distinct version, ``conn.require_version`` is awaited with the
    (major, minor) pair, the third and fourth version components, and an
    explanatory message.

    Args:
        conn: Connection object providing an awaitable ``require_version``.
    """
    from ..core.vector_types import get_registry

    registry = get_registry()
    seen_versions: set = set()
    vector_fields = (fld for fld in self._schema.fields if fld.is_vector_field())
    for fld in vector_fields:
        required = registry.get_min_version(fld.dtype)
        if required is None or required in seen_versions:
            continue
        seen_versions.add(required)
        dotted = ".".join(map(str, required))
        await conn.require_version(
            (required[0], required[1]),
            required[2],
            required[3],
            f"Vector type '{fld.dtype}' requires Vastbase {dotted}",
        )
|
|
384
|
+
|
|
385
|
+
# =========================================================================
|
|
386
|
+
# DDL Operations
|
|
387
|
+
# =========================================================================
|
|
388
|
+
|
|
389
|
+
@_with_connection
async def create(
    self,
    timeout: Optional[float] = None,
    **kwargs
) -> None:
    """
    Create the collection's table from the in-memory schema.

    The DDL comes from ``schema.to_sql(if_not_exists=True)``.  Before it is
    executed, the schema's vector types are checked against the server
    version.

    Args:
        timeout: Operation timeout in seconds (not referenced in this body)

    Raises:
        SchemaError: If no schema was supplied.
        CollectionError: If building or executing the DDL fails.
    """
    if self._schema is None:
        raise SchemaError("Schema is required to create collection")

    # Version-gating: check if schema uses newer vector types
    await self._check_schema_version(self._conn)

    try:
        ddl = self._schema.to_sql(if_not_exists=True)
        await self._executor.execute(ddl, [])
    except Exception as e:
        raise CollectionError(f"Failed to create collection: {e}")
|
|
412
|
+
|
|
413
|
+
@_with_connection
async def drop(self, timeout: Optional[float] = None, **kwargs) -> None:
    """Drop the collection's table.

    Raises:
        CollectionError: If building or executing the statement fails.
    """
    from ..operations.ddl import build_drop_table_sql
    try:
        statement, arguments = build_drop_table_sql(self._name)
        await self._executor.execute(statement, arguments)
    except Exception as e:
        raise CollectionError(f"Failed to drop collection: {e}")
|
|
422
|
+
|
|
423
|
+
@_with_connection
async def exists(self) -> bool:
    """Check if collection exists.

    Returns:
        True when the existence query returns at least one row; False when
        it returns none or when the check itself fails.
    """
    try:
        from ..operations.schema_reflect import build_existence_check_sql
        sql = build_existence_check_sql()
        # Pass the table name as a positional parameter list, the same way
        # _load_schema_async() runs this existence query; a mismatched
        # parameter style would be swallowed below and always yield False.
        # NOTE(review): confirm against build_existence_check_sql's placeholder style.
        rows = await self._executor.execute(sql, [self._name])
        return len(rows) > 0
    except Exception:
        # Best-effort probe: any failure is reported as "does not exist".
        return False
|
|
433
|
+
|
|
434
|
+
@_with_connection
async def truncate(self, timeout: Optional[float] = None, **kwargs) -> None:
    """Remove all rows from the collection.

    Raises:
        CollectionError: If building or executing the statement fails.
    """
    from ..operations.ddl import build_truncate_table_sql
    try:
        statement, arguments = build_truncate_table_sql(self._name)
        await self._executor.execute(statement, arguments)
    except Exception as e:
        raise CollectionError(f"Failed to truncate collection: {e}")
|
|
443
|
+
|
|
444
|
+
@_with_connection
async def refresh_collection(self, timeout: Optional[float] = None) -> None:
    """Run the statistics-refresh statement for this collection.

    Raises:
        CollectionError: If building or executing the statement fails.
    """
    from ..operations.ddl import build_refresh_collection_sql
    try:
        statement, arguments = build_refresh_collection_sql(self._name)
        await self._executor.execute(statement, arguments)
    except Exception as e:
        raise CollectionError(f"Failed to refresh collection: {e}")
|
|
453
|
+
|
|
454
|
+
@_with_connection
async def has_collection_revision(
    self,
    revision: Optional[int] = None
) -> Union[bool, Dict[str, Any]]:
    """Check a specific revision, or fetch revision information.

    The rows of the revision query are interpreted by
    ``parse_has_revision_result``.  When that yields a dict, the collection
    name is added under ``"collection_name"``.

    Returns:
        Whatever ``parse_has_revision_result`` produced, or ``False`` when
        the lookup fails and a ``revision`` was supplied.

    Raises:
        CollectionError: If the lookup fails and no ``revision`` was given.
    """
    from ..operations.ddl import build_has_revision_sql, parse_has_revision_result
    try:
        statement, arguments = build_has_revision_sql(self._name)
        rows = await self._executor.execute(statement, arguments)
        outcome = parse_has_revision_result(rows, revision)
        if isinstance(outcome, dict):
            outcome["collection_name"] = self._name
        return outcome

    except Exception as e:
        if revision is None:
            raise CollectionError(f"Failed to get collection revision: {e}")
        return False
|
|
473
|
+
|
|
474
|
+
# =========================================================================
|
|
475
|
+
# Index Operations
|
|
476
|
+
# =========================================================================
|
|
477
|
+
|
|
478
|
+
@_with_connection
async def create_index(
    self,
    field_name: str,
    index_params: Optional[IndexParams] = None,
    timeout: Optional[float] = None,
    **kwargs
) -> str:
    """
    Create index on a field.

    The index is named ``idx_<collection>_<field_name>``.

    Args:
        field_name: Field name to index
        index_params: Index parameters (a default ``IndexParams()`` is used
            when omitted)
        timeout: Operation timeout in seconds (not referenced in this body)

    Returns:
        The name of the created index.

    Raises:
        CollectionError: If the field is not in the schema, or if building
            or executing the statement fails.
    """
    if index_params is None:
        index_params = IndexParams()

    # Get field info (load schema if needed)
    if self._schema is None:
        await self._load_schema_async()

    field = self._schema.get_field(field_name)
    if field is None:
        raise CollectionError(f"Field '{field_name}' not found")

    # Generate index name
    index_name = f"idx_{self._name}_{field_name}"

    # Build index using builder
    builder = IndexBuilder(self._name)
    builder.set_index_name(index_name)
    builder.set_column(field_name)
    builder.set_params(index_params)

    try:
        sql = builder.build()
        await self._executor.execute(sql, [])
        return index_name
    except Exception as e:
        raise CollectionError(f"Failed to create index: {e}") from e
|
|
520
|
+
|
|
521
|
+
@_with_connection
async def drop_index(
    self,
    field_name: str,
    timeout: Optional[float] = None,
    **kwargs
) -> None:
    """Drop the index named ``idx_<collection>_<field_name>``.

    Raises:
        CollectionError: If building or executing the statement fails.
    """
    from ..operations.index import build_drop_index_sql
    target_index = f"idx_{self._name}_{field_name}"
    try:
        statement, arguments = build_drop_index_sql(target_index)
        await self._executor.execute(statement, arguments)
    except Exception as e:
        raise CollectionError(f"Failed to drop index: {e}")
|
|
536
|
+
|
|
537
|
+
@_with_connection
async def has_index(self, field_name: Optional[str] = None) -> bool:
    """Report whether an index exists.

    With a ``field_name`` the lookup targets ``idx_<collection>_<field>``;
    without one, ``None`` is passed as the index name to
    ``build_has_index_sql``.

    Returns:
        True when the query returns at least one row; False when it returns
        none or when the lookup fails.
    """
    from ..operations.index import build_has_index_sql
    try:
        target_index = None
        if field_name:
            target_index = f"idx_{self._name}_{field_name}"
        statement, arguments = build_has_index_sql(self._name, target_index)
        found = await self._executor.execute(statement, arguments)
        return len(found) > 0
    except Exception:
        return False
|
|
548
|
+
|
|
549
|
+
@_with_connection
async def list_indexes(self) -> List[str]:
    """Return the collection's indexes as parsed by ``parse_list_indexes_result``.

    Raises:
        CollectionError: If building, executing or parsing fails.
    """
    from ..operations.index import build_list_indexes_sql, parse_list_indexes_result
    try:
        statement, arguments = build_list_indexes_sql(self._name)
        index_rows = await self._executor.execute(statement, arguments)
        return parse_list_indexes_result(index_rows)
    except Exception as e:
        raise CollectionError(f"Failed to list indexes: {e}")
|
|
559
|
+
|
|
560
|
+
@_with_connection
async def get_index(self, field_name: Optional[str] = None) -> Dict[str, Any]:
    """Return index information as parsed by ``parse_get_index_result``.

    Args:
        field_name: Passed through to ``build_get_index_sql``.

    Raises:
        CollectionError: If building, executing or parsing fails.
    """
    from ..operations.index import build_get_index_sql, parse_get_index_result
    try:
        statement, arguments = build_get_index_sql(self._name, field_name)
        index_rows = await self._executor.execute(statement, arguments)
        return parse_get_index_result(index_rows, self._name)
    except Exception as e:
        raise CollectionError(f"Failed to get index: {e}")
|
|
570
|
+
|
|
571
|
+
# =========================================================================
|
|
572
|
+
# Data Operations - Helpers
|
|
573
|
+
# =========================================================================
|
|
574
|
+
# Data Operations - Insert
|
|
575
|
+
# =========================================================================
|
|
576
|
+
|
|
577
|
+
@_with_connection(schema=True)
async def insert(
    self,
    data: Union[List[Dict[str, Any]], Dict[str, Any]],
    partition_name: Optional[str] = None,
    timeout: Optional[float] = None,
    normalize_vectors: bool = False,
    **kwargs
) -> MutationResult:
    """
    Insert one or more records, one INSERT statement per record.

    The input is copied through a JSON round-trip before processing, so the
    caller's objects are not modified.  Column names are taken from the
    first record.

    Args:
        data: List of records or single record
        partition_name: Not referenced in this body
        timeout: Operation timeout in seconds (not referenced in this body)
        normalize_vectors: Passed as ``normalize`` to the core's vector
            processing step

    Returns:
        MutationResult built from the primary-key values returned by the
        INSERT statements (empty when ``data`` is empty)

    Raises:
        DataError: If any statement fails; the connection is rolled back first.

    Example:
        >>> result = await collection.insert([{"id": 1, "embedding": [0.1] * 128}])
        >>> # Access as list (backward compatible)
        >>> ids = list(result)
        >>> # Access as MutationResult
        >>> print(result.insert_count)
    """
    # A single record is treated as a one-element list
    records = [data] if isinstance(data, dict) else data

    if not records:
        return MutationResult(primary_keys=[])

    # Deep copy to avoid mutating original data
    records = json.loads(json.dumps(records))

    # Vector processing and schema validation use the FakeExecutor-backed core
    validator = self._get_core()
    validator._process_vectors_for_insert(records, normalize=normalize_vectors)
    validator._validate_insert_data(records)

    # Build insert statement
    columns = list(records[0].keys())
    from ..operations.insert import build_insert_col_names, build_insert_row_sql
    from ..operations.result import parse_insert_result
    from ..operations.insert import build_value_parts
    col_names = build_insert_col_names(columns)
    pk_field = self._schema.get_primary_field()
    field_map = {f.name: f for f in self._schema.fields}

    try:
        inserted_ids = []
        for record in records:
            value_parts, params = build_value_parts(record, columns, field_map)
            statement = build_insert_row_sql(
                self._name, col_names, pk_field.name, value_parts
            )
            returned = await self._executor.execute(statement, params)
            pk_col = pk_field.name
            returned_ids = [r[pk_col] for r in returned]
            inserted_ids.extend(returned_ids)

        return parse_insert_result(inserted_ids)
    except Exception as e:
        await self._conn.rollback()
        raise DataError(f"Failed to insert data: {e}")
|
|
644
|
+
|
|
645
|
+
@_with_connection(schema=True)
async def batch_insert(
    self,
    data: Union[List[Dict[str, Any]], Dict[str, Any]],
    batch_size: int = 100,
    timeout: Optional[float] = None,
    normalize_vectors: bool = False,
    **kwargs
) -> MutationResult:
    """
    Insert records chunk by chunk.

    The input is split with ``batch_iterator``.  Each chunk is copied through
    a JSON round-trip, vector-processed and validated as a unit, and then
    written with one INSERT statement per record.  Column names are taken
    from the first record of each chunk.

    Args:
        data: List of records or single record
        batch_size: Number of records per batch
        timeout: Operation timeout in seconds (not referenced in this body)
        normalize_vectors: Passed as ``normalize`` to the core's vector
            processing step

    Returns:
        MutationResult built from all returned primary-key values (empty
        when ``data`` is empty)

    Raises:
        DataError: If a statement fails; the connection is rolled back first.

    Example:
        >>> result = await collection.batch_insert(large_data, batch_size=1000)
        >>> print(result.insert_count)
        >>> print(result.primary_keys[:5])
    """
    # A single record is treated as a one-element list
    records = [data] if isinstance(data, dict) else data

    if not records:
        return MutationResult(primary_keys=[])

    collected_ids = []

    # Use batch_iterator utility to process data in batches
    for chunk in batch_iterator(records, batch_size):
        # Work on a private copy, then validate/normalize vectors
        chunk_rows = json.loads(json.dumps(list(chunk)))
        validator = self._get_core()
        validator._process_vectors_for_insert(chunk_rows, normalize=normalize_vectors)

        # Validate batch
        validator._validate_insert_data(chunk_rows)

        # Build batch insert SQL
        columns = list(chunk_rows[0].keys())
        from ..operations.insert import build_insert_col_names, build_insert_row_sql
        from ..operations.insert import build_value_parts
        col_names = build_insert_col_names(columns)
        field_map = {f.name: f for f in self._schema.fields}
        pk_field = self._schema.get_primary_field()

        try:
            chunk_ids = []
            for record in chunk_rows:
                value_parts, params = build_value_parts(record, columns, field_map)
                statement = build_insert_row_sql(
                    self._name, col_names, pk_field.name, value_parts
                )
                returned = await self._executor.execute(statement, params)
                pk_col = pk_field.name
                chunk_ids.extend(r[pk_col] for r in returned)
            collected_ids.extend(chunk_ids)
        except Exception as e:
            await self._conn.rollback()
            raise DataError(f"Batch insert failed: {e}")

    from ..operations.result import parse_insert_result
    return parse_insert_result(collected_ids)
|
|
713
|
+
|
|
714
|
+
# =========================================================================
|
|
715
|
+
# Data Operations - Delete/Upsert
|
|
716
|
+
# =========================================================================
|
|
717
|
+
|
|
718
|
+
@_with_connection(schema=True)
async def delete(
    self,
    expr: Optional[str] = None,
    pks: Optional[List[Any]] = None,
    partition_name: Optional[str] = None,
    timeout: Optional[float] = None,
    **kwargs
) -> MutationResult:
    """
    Delete entities from collection.

    Args:
        expr: Filter expression (SQL WHERE clause)
        pks: List of primary key values to delete
        partition_name: Accepted for API compatibility; not used by this method
        timeout: Operation timeout in seconds (accepted; not used by this method)
        **kwargs: Accepted for API compatibility; not used by this method

    Returns:
        MutationResult with delete count (backward compatible with int).
        When build_delete_sql() returns a SELECT statement, the primary keys
        it yields are the ones reported; otherwise the caller's ``pks`` are
        reported as given.

    Raises:
        ParamError: If build_delete_sql() rejects the arguments (ValueError)
        DataError: If executing the statements fails; the connection is
            rolled back before the error is raised

    Example:
        >>> result = await collection.delete(pks=[1, 2, 3])
        >>> # Access as int (backward compatible)
        >>> count = len(result)
        >>> # Access as MutationResult
        >>> print(result.delete_count)
    """
    # The returned connection object is not needed here; the call itself is
    # retained in case it validates connection state -- TODO confirm.
    self._get_connection()
    await self._ensure_connection()

    if self._schema is None:
        await self._load_schema_async()

    # Looked up once and shared by the SQL builder and the result parsing.
    pk_name = self._schema.get_primary_field().name

    try:
        select_sql, delete_sql, params = build_delete_sql(
            self._name,
            pk_name,
            expr=expr if expr else None,
            pks=pks if pks else None,
        )
    except ValueError as e:
        raise ParamError(str(e)) from e

    try:
        from ..operations.result import parse_delete_result

        if select_sql:
            # NOTE(review): the keys are read before the DELETE runs, so the
            # reported keys reflect the SELECT, not the DELETE itself.
            rows = await self._executor.execute(select_sql, [])
            deleted_pks = [row[pk_name] for row in rows]
        else:
            deleted_pks = list(pks)

        await self._executor.execute(delete_sql, params if params else [])
        return parse_delete_result(deleted_pks)
    except Exception as e:
        await self._conn.rollback()
        # Chain the cause so the driver's traceback stays attached.
        raise DataError(f"Failed to delete data: {e}") from e
|
|
775
|
+
|
|
776
|
+
@_with_connection(schema=True)
async def upsert(
    self,
    data: Union[List[Dict[str, Any]], Dict[str, Any]],
    partition_name: Optional[str] = None,
    timeout: Optional[float] = None,
    **kwargs
) -> MutationResult:
    """
    Upsert data (insert or update on conflict).

    Rows are written one statement at a time. The column list is taken from
    the first row and reused for every row.

    Args:
        data: Data to upsert (a single row dict or a list of row dicts)
        partition_name: Accepted for API compatibility; not used by this method
        timeout: Operation timeout in seconds (accepted; not used by this method)
        **kwargs: Accepted for API compatibility; not used by this method

    Returns:
        MutationResult with upserted IDs (backward compatible with List[int]).
        Empty input returns a MutationResult with no primary keys.

    Raises:
        CollectionError: If the schema has no primary key field
        DataError: If any statement fails; the connection is rolled back
            before the error is raised

    Example:
        >>> result = await collection.upsert(data)
        >>> print(result.insert_count)
        >>> print(result.primary_keys[:5])
    """
    if isinstance(data, dict):
        data = [data]

    if not data:
        return MutationResult(primary_keys=[])

    pk_field = self._schema.get_primary_field()
    if not pk_field:
        raise CollectionError("Upsert requires a primary key")

    # Validate data
    self._get_core()._validate_insert_data(data)

    from ..operations.upsert import build_upsert_clauses
    from ..operations.insert import build_upsert_row_sql, build_value_parts
    from ..operations.result import parse_upsert_result

    columns = list(data[0].keys())
    pk_col = pk_field.name  # loop-invariant, so resolved once
    col_names, on_conflict = build_upsert_clauses(columns, pk_col)
    field_map = {f.name: f for f in self._schema.fields}

    try:
        ids = []
        for row in data:
            value_parts, params = build_value_parts(row, columns, field_map)
            sql = build_upsert_row_sql(self._name, col_names, value_parts, on_conflict)
            rows = await self._executor.execute(sql, params)
            ids.extend(r[pk_col] for r in rows)

        return parse_upsert_result(ids)
    except Exception as e:
        await self._conn.rollback()
        # Chain the cause so the driver's traceback stays attached.
        raise DataError(f"Failed to upsert data: {e}") from e
|
|
833
|
+
|
|
834
|
+
# =========================================================================
|
|
835
|
+
# Search Operations
|
|
836
|
+
# =========================================================================
|
|
837
|
+
|
|
838
|
+
@_with_connection(schema=True)
async def search(
    self,
    data: Union[List[float], List[List[float]]],
    anns_field: Optional[str] = None,
    param: Optional[Dict[str, Any]] = None,
    limit: int = 10,
    expr: Optional[str] = None,
    partition_names: Optional[List[str]] = None,
    output_fields: Optional[List[str]] = None,
    timeout: Optional[float] = None,
    round_decimal: int = -1,
    ranker: Optional[Any] = None,
    highlighter: Optional[Any] = None,
    normalize_vectors: bool = False,
    **kwargs
) -> SearchResult:
    """
    Search for similar vectors.

    Args:
        data: Query vector(s)
        anns_field: Name of vector field to search (defaults to the first
            vector field of the schema)
        param: Search parameters; the keys read here are "ef" and
            "metric_type"
        limit: Number of results to return
        expr: Filter expression
        partition_names: Accepted for API compatibility; not used by this method
        output_fields: Fields to return in results
        timeout: Operation timeout in seconds (accepted; not used by this method)
        round_decimal: Accepted for API compatibility; not used by this method
        ranker: Must be None; any other value raises ParamError
        highlighter: Must be None; any other value raises ParamError
        normalize_vectors: Whether to normalize query vectors before search

    Returns:
        SearchResult object

    Raises:
        ParamError: If ranker/highlighter is given or ``data`` is empty
        CollectionError: If no usable vector field exists or the query fails
        ValueError: If param["metric_type"] is not a known DistanceType
    """
    if ranker is not None:
        raise ParamError("ranker is not supported by Vastbase (PostgreSQL-based). "
                         "Vastbase does not support custom ranking strategies.")
    if highlighter is not None:
        raise ParamError("highlighter is not supported by Vastbase. "
                         "Use bm25_search_highlight utility for BM25 highlighting.")

    # Get vector field
    if anns_field is None:
        vector_fields = self._schema.vector_field_names
        if not vector_fields:
            raise CollectionError("No vector field found in schema")
        anns_field = vector_fields[0]

    # Get field info for dimension validation
    field = self._schema.get_field(anns_field)
    if field is None:
        raise CollectionError(f"Vector field '{anns_field}' not found")

    expected_dim = field.dim

    # Get metric type from schema, falling back to L2 when none is set
    metric_type = field.metric_type or DistanceType.L2

    # An empty query used to surface as a bare IndexError from data[0].
    # len() is used rather than truthiness so array-like inputs keep working.
    if len(data) == 0:
        raise ParamError("data must not be empty")

    # Normalize queries (ensure correct dimension)
    queries = data if isinstance(data[0], list) else [data]

    processed_queries = validate_and_normalize_queries(
        queries,
        expected_dim,
        normalize=normalize_vectors,
    )

    # Build search parameters
    search_params = SearchParams(
        anns_field=anns_field,
        top_k=limit,
        metric_type=metric_type
    )

    if param:
        if "ef" in param:
            search_params.hnsw_ef_search = param["ef"]
        if "metric_type" in param:
            raw_metric = param["metric_type"]
            try:
                search_params.metric_type = DistanceType(raw_metric)
            except ValueError:
                # hybrid_search() lower-cases metric names before the lookup;
                # accept the same spellings here instead of rejecting "L2".
                if not isinstance(raw_metric, str):
                    raise
                search_params.metric_type = DistanceType(raw_metric.lower())

    # Build query using SearchQueryBuilder
    builder = SearchQueryBuilder(self._name)
    builder.set_vector_field(anns_field)
    builder.set_queries(processed_queries)
    builder.set_params(search_params)
    builder.set_limit(limit)

    if expr:
        builder.set_filter(expr)

    if output_fields:
        builder.set_output_fields(output_fields)

    # Build and execute query
    build_result = builder.build()

    try:
        # NOTE(review): only the first parameter set (params[0]) is executed
        # and its rows are stored under query index 0, while num_queries
        # reports len(processed_queries) -- confirm multi-vector behaviour.
        if isinstance(build_result, tuple) and len(build_result) == 3:
            # Three-element result: SET statements run in the same
            # transaction as the main query.
            set_clauses, main_sql, params = build_result
            exec_params = list(params[0]) if params else []
            async with self._executor.transaction() as tx:
                for set_sql in set_clauses:
                    await tx.execute(set_sql, [])
                rows = await tx.execute(main_sql, exec_params)
        else:
            sql, params = build_result
            exec_params = list(params[0]) if params else []
            rows = await self._executor.execute(sql, exec_params)

        # Parse results
        from ..operations.search_result import parse_search_rows
        pk_field = self._schema.get_primary_field()
        query_results = parse_search_rows(rows, pk_field.name if pk_field else None, output_fields)

        results = SearchResult()
        results.num_queries = len(processed_queries)
        results.queries = processed_queries
        results.add_result(0, query_results)
        return results

    except Exception as e:
        # Chain the cause so the driver's traceback stays attached.
        raise CollectionError(f"Search failed: {e}") from e
|
|
962
|
+
|
|
963
|
+
async def hybrid_search(
    self,
    reqs: Optional[List[Any]] = None,
    rerank: Optional[Any] = None,
    limit: int = 10,
    partition_names: Optional[List[str]] = None,
    output_fields: Optional[List[str]] = None,
    timeout: Optional[float] = None,
    round_decimal: int = -1,
    ranker: Optional[Any] = None,
    parallel: bool = True,
    **kwargs
) -> SearchResult:
    """
    Multi-way vector recall with reranking (pymilvus-compatible, async).

    Issues independent ANN search SQL per AnnSearchRequest, merges
    results using the specified reranking strategy.

    Args:
        reqs: List of AnnSearchRequest objects; each must expose
            ``anns_field``, ``param``, ``limit`` and ``data``
        rerank: BaseRanker instance (RRFRanker, WeightedRRFRanker, WeightedRanker)
        limit: Max merged results
        partition_names: Accepted for API compatibility; not used by this method
        output_fields: Fields to include
        timeout: Accepted for API compatibility; not used by this method
        round_decimal: Accepted for API compatibility; not used by this method
        ranker: Accepted for API compatibility; not used by this method
        parallel: Execute searches concurrently via asyncio.gather (default True)

    Returns:
        SearchResult holding the merged hits under query index 0

    Raises:
        ParamError: If ``reqs`` is empty, ``rerank`` is not a BaseRanker, a
            request names an unknown field, or its metric_type is unknown
    """
    from ..core.exceptions import ParamError
    from ..search.search_builder import BaseRanker
    # search() and hybrid_ann_search() import this helper locally; without
    # the same import here the nested coroutine cannot resolve the name.
    from ..operations.search_result import parse_search_rows

    if not reqs:
        raise ParamError("reqs must not be empty")
    if not isinstance(rerank, BaseRanker):
        raise ParamError(f"rerank must be a BaseRanker instance, got {type(rerank).__name__}")

    # NOTE(review): unlike search(), this method carries no
    # @_with_connection(schema=True) decorator, so self._schema is used as-is
    # -- confirm callers always load the schema first.
    pk_field = self._schema.get_primary_field()
    all_results: list[tuple[list, DistanceType]] = []

    async def _execute_single_req(req):
        """Run one ANN request and return (parsed hits, metric type)."""
        field = self._schema.get_field(req.anns_field)
        if field is None:
            raise ParamError(f"Field '{req.anns_field}' not found in schema")

        # Tolerate a request whose param is None instead of failing on .get
        req_param = req.param or {}

        metric_str = req_param.get("metric_type", "l2")
        try:
            metric_type = DistanceType(metric_str.lower())
        except ValueError:
            raise ParamError(f"Unknown metric_type '{metric_str}'")

        params = SearchParams(
            anns_field=req.anns_field,
            top_k=req.limit,
            metric_type=metric_type,
        )
        if "ef" in req_param:
            params.hnsw_ef_search = req_param["ef"]

        # Wrap a single flat vector so the builder always gets a list of vectors
        query_vec = req.data if req.data and isinstance(req.data[0], list) else [req.data]

        builder = SearchQueryBuilder(table_name=self._name)
        builder.set_vector_field(req.anns_field)
        builder.set_queries(query_vec)
        builder.set_params(params)
        builder.set_limit(req.limit)
        if output_fields:
            builder.set_output_fields(output_fields)

        build_result = builder.build()
        if isinstance(build_result, tuple) and len(build_result) == 3:
            # Three-element result: SET statements run in the same
            # transaction as the main query.
            set_clauses, main_sql, sql_params = build_result
            async with self._executor.transaction() as tx:
                for set_sql in set_clauses:
                    await tx.execute(set_sql, [])
                rows = await tx.execute(
                    main_sql, list(sql_params[0]) if sql_params else []
                )
        else:
            sql, sql_params = build_result
            rows = await self._executor.execute(sql, list(sql_params[0]) if sql_params else [])
        items = parse_search_rows(rows, pk_field.name if pk_field else None, output_fields)
        return items, metric_type

    if parallel and len(reqs) > 1:
        import asyncio
        # NOTE(review): every request goes through the same self._executor
        # concurrently -- confirm it supports overlapping statements.
        pairs = await asyncio.gather(*[_execute_single_req(req) for req in reqs])
        all_results = list(pairs)
    else:
        for req in reqs:
            items, metric_type = await _execute_single_req(req)
            all_results.append((items, metric_type))

    merged = rerank.merge(all_results, limit)

    results = SearchResult()
    results.num_queries = 1
    results.add_result(0, merged)
    return results
|
|
1059
|
+
|
|
1060
|
+
async def hybrid_ann_search(
    self,
    data: Union[List[float], List[List[float]]],
    anns_field: Optional[str] = None,
    param: Optional[Dict[str, Any]] = None,
    limit: int = 10,
    filter: Optional[Union[Dict[str, Any], str]] = None,
    output_fields: Optional[List[str]] = None,
    timeout: Optional[float] = None,
    normalize_vectors: bool = False,
    **kwargs
) -> SearchResult:
    """
    Hybrid ANN search: vector similarity search with scalar field filtering.

    Vastbase-native vector + scalar hybrid search using HybridANN index.

    Args:
        data: Query vector(s) - single vector or list of vectors
        anns_field: Name of vector field to search (auto-detected if None)
        param: Search parameters; the keys read here are "ef" and
            "metric_type"
        limit: Maximum number of results
        filter: Filter expression - can be:
            - dict: MongoDB-style filter {"field": {"$op": value}}
            - str: Raw SQL WHERE clause
        output_fields: Fields to return in results
        timeout: Operation timeout in seconds (accepted; not used by this method)
        normalize_vectors: Whether to L2-normalize query vectors

    Returns:
        SearchResult object

    Raises:
        ParamError: If ``data`` is empty
        CollectionError: If no usable vector field exists or the query fails
        ValueError: If param["metric_type"] is not a known DistanceType

    Supported Filter Operators:
        - Comparison: $eq, $ne, $gt, $gte, $lt, $lte
        - Set: $in, $nin
        - Range: $between
        - String: $like, $ilike
        - Logical: $and, $or

    Example:
        # Simple filter
        results = await collection.hybrid_ann_search(
            data=[0.1] * 128,
            filter={"category": {"$eq": "electronics"}},
            limit=10
        )

        # Complex filter with AND/OR
        results = await collection.hybrid_ann_search(
            data=[0.1] * 128,
            filter={
                "$and": [
                    {"price": {"$lte": 1000}},
                    {"category": {"$in": ["electronics", "books"]}}
                ]
            },
            limit=20
        )
    """
    # Get vector field
    if anns_field is None:
        vector_fields = self._schema.vector_field_names
        if not vector_fields:
            raise CollectionError("No vector field found in schema")
        anns_field = vector_fields[0]

    # Get field info for dimension validation
    field = self._schema.get_field(anns_field)
    if field is None:
        raise CollectionError(f"Vector field '{anns_field}' not found")

    expected_dim = field.dim

    # Get metric type from schema, falling back to L2 when none is set
    metric_type = field.metric_type or DistanceType.L2

    # An empty query used to surface as a bare IndexError from data[0].
    # len() is used rather than truthiness so array-like inputs keep working.
    if len(data) == 0:
        raise ParamError("data must not be empty")

    # Normalize queries (ensure correct dimension)
    queries = data if isinstance(data[0], list) else [data]

    processed_queries = validate_and_normalize_queries(
        queries,
        expected_dim,
        normalize=normalize_vectors,
    )

    # Parse filter expression to SQL
    where_clause = ""
    filter_params = []

    if filter is not None:
        # Imported only when a filter is supplied, so unfiltered searches do
        # not depend on this module being importable.
        # NOTE(review): confirm pyvastbase.utils.filter_utils ships with the
        # package.
        from ..utils.filter_utils import parse_filter_expression

        where_clause, filter_params = parse_filter_expression(
            filter, param_style="dollar"
        )

    # Build search parameters
    search_params = SearchParams(
        anns_field=anns_field,
        top_k=limit,
        metric_type=metric_type
    )

    if param:
        if "ef" in param:
            search_params.hnsw_ef_search = param["ef"]
        if "metric_type" in param:
            raw_metric = param["metric_type"]
            try:
                search_params.metric_type = DistanceType(raw_metric)
            except ValueError:
                # hybrid_search() lower-cases metric names before the lookup;
                # accept the same spellings here instead of rejecting "L2".
                if not isinstance(raw_metric, str):
                    raise
                search_params.metric_type = DistanceType(raw_metric.lower())

    # Build query using SearchQueryBuilder
    builder = SearchQueryBuilder(self._name)
    builder.set_vector_field(anns_field)
    builder.set_queries(processed_queries)
    builder.set_limit(limit)
    builder.set_params(search_params)

    if where_clause:
        builder.set_filter(where_clause)

    if output_fields:
        builder.set_output_fields(output_fields)

    # Build and execute query
    build_result = builder.build()

    try:
        if isinstance(build_result, tuple) and len(build_result) == 3:
            set_clauses, main_sql, params = build_result
        else:
            set_clauses = None
            main_sql, params = build_result

        # Coerce both halves to lists: the other search methods wrap
        # params[0] in list(), so it may be a tuple, and list + tuple raises.
        # NOTE(review): filter values are placed before the builder's own
        # parameters -- confirm this matches the placeholder numbering.
        final_params = list(filter_params) + list(params[0] if params else [])

        if set_clauses is not None:
            # SET statements run in the same transaction as the main query.
            async with self._executor.transaction() as tx:
                for set_sql in set_clauses:
                    await tx.execute(set_sql, [])
                rows = await tx.execute(main_sql, final_params)
        else:
            rows = await self._executor.execute(main_sql, final_params)

        # Parse results
        from ..operations.search_result import parse_search_rows
        pk_field = self._schema.get_primary_field()
        query_results = parse_search_rows(rows, pk_field.name if pk_field else None, output_fields)

        results = SearchResult()
        results.num_queries = len(processed_queries)
        results.queries = processed_queries
        results.add_result(0, query_results)
        return results

    except Exception as e:
        # Chain the cause so the driver's traceback stays attached.
        raise CollectionError(f"Hybrid search failed: {e}") from e
|
|
1216
|
+
|
|
1217
|
+
async def _fetch_batch(
    self,
    filter_expr: str,
    output_fields: Optional[List[str]] = None,
    limit: Optional[int] = None,
    offset: int = 0
) -> List[Dict[str, Any]]:
    """
    Internal method to fetch a batch of results.

    Args:
        filter_expr: Filter expression passed to build_query_sql()
        output_fields: Fields to return
        limit: Maximum number of rows
        offset: Rows to skip; values <= 0 are passed to the SQL builder as None

    Returns:
        Rows as parsed by parse_query_result()

    Raises:
        CollectionError: If executing or parsing the query fails
    """
    await self._ensure_connection()
    from ..operations.query import build_query_sql, parse_query_result
    sql, params = build_query_sql(self._name, filter_expr, output_fields, limit, offset if offset > 0 else None)
    try:
        rows = await self._executor.execute(sql, params)
        return parse_query_result(rows, output_fields)
    except Exception as e:
        # Chain the cause so the driver's traceback stays attached.
        raise CollectionError(f"Query failed: {e}") from e
|
|
1233
|
+
|
|
1234
|
+
@_with_connection(schema=True)
async def query(
    self,
    expr: Optional[str] = None,
    output_fields: Optional[List[str]] = None,
    limit: Optional[int] = None,
    offset: Optional[int] = None,
    partition_names: Optional[List[str]] = None,
    timeout: Optional[float] = None,
    batch_size: Optional[int] = None,
    **kwargs
) -> Union[List[Dict[str, Any]], QueryIterator]:
    """
    Query entities with filters.

    Args:
        expr: Filter expression (SQL WHERE clause)
        output_fields: Fields to return
        limit: Maximum results
        offset: Offset for pagination
        partition_names: Accepted for API compatibility; not used by this method
        timeout: Operation timeout in seconds (accepted; not used by this method)
        batch_size: If a positive number, no SQL is run here and a
            QueryIterator is returned instead

    Returns:
        List of matching records, or QueryIterator if batch_size specified

    Raises:
        CollectionError: If executing or parsing the query fails
    """
    # If batch_size specified, return async iterator
    if batch_size is not None and batch_size > 0:
        # NOTE(review): query_iterator() constructs QueryIterator with
        # data=... instead of collection=/filter_expr= -- confirm the
        # constructor accepts both keyword sets.
        return QueryIterator(
            collection=self,
            filter_expr=expr or "",
            batch_size=batch_size,
            output_fields=output_fields,
            limit=limit,
            offset=offset
        )

    # Build and execute query
    from ..operations.query import build_query_sql, parse_query_result
    sql, params = build_query_sql(self._name, expr, output_fields, limit, offset)
    try:
        rows = await self._executor.execute(sql, params)
        return parse_query_result(rows, output_fields)
    except Exception as e:
        # Chain the cause so the driver's traceback stays attached.
        raise CollectionError(f"Query failed: {e}") from e
|
|
1279
|
+
|
|
1280
|
+
@_with_connection(schema=True)
async def query_iterator(
    self,
    expr: Optional[str] = None,
    output_fields: Optional[List[str]] = None,
    limit: Optional[int] = None,
    offset: Optional[int] = None,
    batch_size: int = 1000,
    partition_names: Optional[List[str]] = None,
    timeout: Optional[float] = None,
    **kwargs
) -> QueryIterator:
    """
    Return a QueryIterator for paginated result access.

    The matching rows are fetched with a single _fetch_batch() call and the
    resulting list is handed to the iterator together with ``batch_size``.

    Args:
        expr: Filter expression; None is treated as an empty string
        output_fields: Fields to return
        limit: Maximum number of rows to fetch
        offset: Rows to skip; None is treated as 0
        batch_size: Batch size forwarded to the iterator
        partition_names: Accepted for API compatibility; not used by this method
        timeout: Accepted for API compatibility; not used by this method

    Returns:
        QueryIterator built over the fetched rows
    """
    # Replace missing values with the neutral defaults _fetch_batch() expects.
    where_text = expr or ""
    skip_rows = offset or 0

    fetched_rows = await self._fetch_batch(
        filter_expr=where_text,
        output_fields=output_fields,
        limit=limit,
        offset=skip_rows,
    )

    iterator = QueryIterator(
        data=fetched_rows,
        batch_size=batch_size,
        output_fields=output_fields,
    )
    return iterator
|
|
1304
|
+
|
|
1305
|
+
@_with_connection(schema=True)
async def get(
    self,
    ids: Union[List[Any], Any],
    output_fields: Optional[List[str]] = None,
    partition_names: Optional[List[str]] = None,
    timeout: Optional[float] = None,
    **kwargs
) -> List[Dict[str, Any]]:
    """
    Get entities by primary key.

    Args:
        ids: Primary key value(s); any value that is not a list is treated
            as a single key
        output_fields: Fields to return
        partition_names: Accepted for API compatibility; not used by this method
        timeout: Operation timeout in seconds (accepted; not used by this method)

    Returns:
        List of matching records; an empty id list returns an empty list
        without issuing a query

    Raises:
        CollectionError: If the schema has no primary key field, or if
            executing or parsing the query fails
    """
    if not isinstance(ids, list):
        ids = [ids]

    pk_field = self._schema.get_primary_field()
    if not pk_field:
        raise CollectionError("Get requires a primary key")

    # Nothing to look up: skip building and running a statement with an
    # empty key list.
    if not ids:
        return []

    from ..operations.query import build_get_sql, parse_query_result
    sql, params = build_get_sql(self._name, pk_field.name, ids, output_fields)
    try:
        rows = await self._executor.execute(sql, params)
        return parse_query_result(rows, output_fields)
    except Exception as e:
        # Chain the cause so the driver's traceback stays attached.
        raise CollectionError(f"Failed to get entities: {e}") from e
|
|
1339
|
+
|
|
1340
|
+
async def get_entity_by_id(
    self,
    ids: Union[List[Any], Any],
    output_fields: Optional[List[str]] = None,
    timeout: Optional[float] = None,
    **kwargs
) -> List[Dict[str, Any]]:
    """
    Get entities by primary key (PyMilvus API compatibility).

    Deprecated: use get() instead. A DeprecationWarning is emitted on every
    call, then the arguments are forwarded unchanged to get().

    Args:
        ids: Primary key value(s)
        output_fields: Fields to return
        timeout: Operation timeout in seconds
        **kwargs: Extra keyword arguments forwarded to get()

    Returns:
        List of matching records
    """
    import warnings

    deprecation_text = "get_entity_by_id() is deprecated, use get() instead"
    # stacklevel=2 attributes the warning to the caller, not this wrapper.
    warnings.warn(deprecation_text, DeprecationWarning, stacklevel=2)

    records = await self.get(
        ids=ids,
        output_fields=output_fields,
        timeout=timeout,
        **kwargs,
    )
    return records
|
|
1367
|
+
|