pyseekdb 0.1.0.dev3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyseekdb/__init__.py +90 -0
- pyseekdb/client/__init__.py +324 -0
- pyseekdb/client/admin_client.py +202 -0
- pyseekdb/client/base_connection.py +82 -0
- pyseekdb/client/client_base.py +1921 -0
- pyseekdb/client/client_oceanbase_server.py +258 -0
- pyseekdb/client/client_seekdb_embedded.py +324 -0
- pyseekdb/client/client_seekdb_server.py +226 -0
- pyseekdb/client/collection.py +485 -0
- pyseekdb/client/database.py +55 -0
- pyseekdb/client/filters.py +357 -0
- pyseekdb/client/meta_info.py +15 -0
- pyseekdb/client/query_result.py +122 -0
- pyseekdb/client/sql_utils.py +48 -0
- pyseekdb/examples/comprehensive_example.py +412 -0
- pyseekdb/examples/simple_example.py +113 -0
- pyseekdb/tests/__init__.py +0 -0
- pyseekdb/tests/test_admin_database_management.py +307 -0
- pyseekdb/tests/test_client_creation.py +425 -0
- pyseekdb/tests/test_collection_dml.py +652 -0
- pyseekdb/tests/test_collection_get.py +550 -0
- pyseekdb/tests/test_collection_hybrid_search.py +1126 -0
- pyseekdb/tests/test_collection_query.py +428 -0
- pyseekdb-0.1.0.dev3.dist-info/LICENSE +202 -0
- pyseekdb-0.1.0.dev3.dist-info/METADATA +856 -0
- pyseekdb-0.1.0.dev3.dist-info/RECORD +27 -0
- pyseekdb-0.1.0.dev3.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,1921 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Base client interface definition
|
|
3
|
+
"""
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
import re
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
8
|
+
from typing import List, Optional, Sequence, Dict, Any, Union, TYPE_CHECKING, Tuple, Callable
|
|
9
|
+
|
|
10
|
+
from .base_connection import BaseConnection
|
|
11
|
+
from .admin_client import AdminAPI, DEFAULT_TENANT
|
|
12
|
+
from .meta_info import CollectionNames, CollectionFieldNames
|
|
13
|
+
from .query_result import QueryResult
|
|
14
|
+
from .filters import FilterBuilder
|
|
15
|
+
|
|
16
|
+
from .collection import Collection
|
|
17
|
+
|
|
18
|
+
from .database import Database
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
class ClientAPI(ABC):
    """
    Client API interface for collection operations only.

    This is what end users interact with through the Client proxy.
    Concrete implementations (see BaseClient below) provide the actual
    storage-backed behavior.
    """

    @abstractmethod
    def create_collection(
        self,
        name: str,
        dimension: Optional[int] = None,
        **kwargs
    ) -> "Collection":
        """Create collection.

        Args:
            name: Collection name.
            dimension: Vector dimension for the collection's embeddings.
            **kwargs: Implementation-specific options.

        Returns:
            The newly created Collection.
        """
        pass

    @abstractmethod
    def get_collection(self, name: str) -> "Collection":
        """Get collection object by name."""
        pass

    @abstractmethod
    def delete_collection(self, name: str) -> None:
        """Delete collection by name."""
        pass

    @abstractmethod
    def list_collections(self) -> List["Collection"]:
        """List all collections visible to this client."""
        pass

    @abstractmethod
    def has_collection(self, name: str) -> bool:
        """Check if collection exists."""
        pass
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class BaseClient(BaseConnection, AdminAPI):
|
|
60
|
+
"""
|
|
61
|
+
Abstract base class for all clients.
|
|
62
|
+
|
|
63
|
+
Design Pattern:
|
|
64
|
+
1. Provides public collection management methods (create_collection, get_collection, etc.)
|
|
65
|
+
2. Defines internal operation interfaces (_collection_* methods) called by Collection objects
|
|
66
|
+
3. Subclasses implement all abstract methods to provide specific business logic
|
|
67
|
+
|
|
68
|
+
Benefits of this design:
|
|
69
|
+
- Collection object interface is unified regardless of which client created it
|
|
70
|
+
- Different clients can have completely different underlying implementations (SQL/gRPC/REST)
|
|
71
|
+
- Easy to extend with new client types
|
|
72
|
+
|
|
73
|
+
Inherits connection management from BaseConnection and database operations from AdminAPI.
|
|
74
|
+
"""
|
|
75
|
+
|
|
76
|
+
# ==================== Collection Management (User-facing) ====================
|
|
77
|
+
|
|
78
|
+
def create_collection(
|
|
79
|
+
self,
|
|
80
|
+
name: str,
|
|
81
|
+
dimension: Optional[int] = None,
|
|
82
|
+
**kwargs
|
|
83
|
+
) -> "Collection":
|
|
84
|
+
"""
|
|
85
|
+
Create a collection (user-facing API)
|
|
86
|
+
|
|
87
|
+
Args:
|
|
88
|
+
name: Collection name
|
|
89
|
+
dimension: Vector dimension
|
|
90
|
+
**kwargs: Additional parameters
|
|
91
|
+
|
|
92
|
+
Returns:
|
|
93
|
+
Collection object
|
|
94
|
+
"""
|
|
95
|
+
if dimension is None:
|
|
96
|
+
raise ValueError("dimension parameter is required for creating a collection")
|
|
97
|
+
|
|
98
|
+
# Construct table name: c$v1${name}
|
|
99
|
+
table_name = CollectionNames.table_name(name)
|
|
100
|
+
|
|
101
|
+
# Construct CREATE TABLE SQL statement with HEAP organization
|
|
102
|
+
sql = f"""CREATE TABLE `{table_name}` (
|
|
103
|
+
_id varbinary(512) PRIMARY KEY NOT NULL,
|
|
104
|
+
document string,
|
|
105
|
+
embedding vector({dimension}),
|
|
106
|
+
metadata json,
|
|
107
|
+
FULLTEXT INDEX idx1(document),
|
|
108
|
+
VECTOR INDEX idx2 (embedding) with(distance=l2, type=hnsw, lib=vsag)
|
|
109
|
+
) ORGANIZATION = HEAP;"""
|
|
110
|
+
|
|
111
|
+
# Execute SQL to create table
|
|
112
|
+
self.execute(sql)
|
|
113
|
+
|
|
114
|
+
# Create and return Collection object
|
|
115
|
+
return Collection(client=self, name=name, dimension=dimension, **kwargs)
|
|
116
|
+
|
|
117
|
+
def get_collection(self, name: str) -> "Collection":
|
|
118
|
+
"""
|
|
119
|
+
Get a collection object (user-facing API)
|
|
120
|
+
|
|
121
|
+
Args:
|
|
122
|
+
name: Collection name
|
|
123
|
+
|
|
124
|
+
Returns:
|
|
125
|
+
Collection object
|
|
126
|
+
|
|
127
|
+
Raises:
|
|
128
|
+
ValueError: If collection does not exist
|
|
129
|
+
"""
|
|
130
|
+
# Construct table name: c$v1${name}
|
|
131
|
+
table_name = CollectionNames.table_name(name)
|
|
132
|
+
|
|
133
|
+
# Check if table exists by describing it
|
|
134
|
+
try:
|
|
135
|
+
table_info = self.execute(f"DESCRIBE `{table_name}`")
|
|
136
|
+
if not table_info or len(table_info) == 0:
|
|
137
|
+
raise ValueError(f"Collection '{name}' does not exist (table '{table_name}' not found)")
|
|
138
|
+
except Exception as e:
|
|
139
|
+
# If DESCRIBE fails, check if it's because table doesn't exist
|
|
140
|
+
error_msg = str(e).lower()
|
|
141
|
+
if "doesn't exist" in error_msg or "not found" in error_msg or "table" in error_msg:
|
|
142
|
+
raise ValueError(f"Collection '{name}' does not exist (table '{table_name}' not found)") from e
|
|
143
|
+
raise
|
|
144
|
+
|
|
145
|
+
# Extract dimension from embedding column
|
|
146
|
+
dimension = None
|
|
147
|
+
for row in table_info:
|
|
148
|
+
# Handle both dict and tuple formats
|
|
149
|
+
if isinstance(row, dict):
|
|
150
|
+
field_name = row.get('Field', row.get('field', ''))
|
|
151
|
+
field_type = row.get('Type', row.get('type', ''))
|
|
152
|
+
elif isinstance(row, (tuple, list)):
|
|
153
|
+
field_name = row[0] if len(row) > 0 else ''
|
|
154
|
+
field_type = row[1] if len(row) > 1 else ''
|
|
155
|
+
else:
|
|
156
|
+
continue
|
|
157
|
+
|
|
158
|
+
if field_name == 'embedding' and 'vector' in str(field_type).lower():
|
|
159
|
+
# Extract dimension from vector(dimension) format
|
|
160
|
+
match = re.search(r'vector\s*\(\s*(\d+)\s*\)', str(field_type), re.IGNORECASE)
|
|
161
|
+
if match:
|
|
162
|
+
dimension = int(match.group(1))
|
|
163
|
+
break
|
|
164
|
+
|
|
165
|
+
# Create and return Collection object
|
|
166
|
+
return Collection(client=self, name=name, dimension=dimension)
|
|
167
|
+
|
|
168
|
+
def delete_collection(self, name: str) -> None:
|
|
169
|
+
"""
|
|
170
|
+
Delete a collection (user-facing API)
|
|
171
|
+
|
|
172
|
+
Args:
|
|
173
|
+
name: Collection name
|
|
174
|
+
|
|
175
|
+
Raises:
|
|
176
|
+
ValueError: If collection does not exist
|
|
177
|
+
"""
|
|
178
|
+
# Construct table name: c$v1${name}
|
|
179
|
+
table_name = CollectionNames.table_name(name)
|
|
180
|
+
|
|
181
|
+
# Check if table exists first
|
|
182
|
+
if not self.has_collection(name):
|
|
183
|
+
raise ValueError(f"Collection '{name}' does not exist (table '{table_name}' not found)")
|
|
184
|
+
|
|
185
|
+
# Execute DROP TABLE SQL
|
|
186
|
+
self.execute(f"DROP TABLE IF EXISTS `{table_name}`")
|
|
187
|
+
|
|
188
|
+
def list_collections(self) -> List["Collection"]:
|
|
189
|
+
"""
|
|
190
|
+
List all collections (user-facing API)
|
|
191
|
+
|
|
192
|
+
Returns:
|
|
193
|
+
List of Collection objects
|
|
194
|
+
"""
|
|
195
|
+
# List all tables that start with 'c$v1'
|
|
196
|
+
# Use SHOW TABLES LIKE 'c$v1%' to filter collection tables
|
|
197
|
+
try:
|
|
198
|
+
tables = self.execute("SHOW TABLES LIKE 'c$v1$%'")
|
|
199
|
+
except Exception:
|
|
200
|
+
# Fallback: try to query information_schema
|
|
201
|
+
try:
|
|
202
|
+
# Get current database name
|
|
203
|
+
db_result = self.execute("SELECT DATABASE()")
|
|
204
|
+
if db_result and len(db_result) > 0:
|
|
205
|
+
db_name = db_result[0][0] if isinstance(db_result[0], (tuple, list)) else db_result[0].get('DATABASE()', '')
|
|
206
|
+
tables = self.execute(
|
|
207
|
+
f"SELECT TABLE_NAME FROM information_schema.TABLES "
|
|
208
|
+
f"WHERE TABLE_SCHEMA = '{db_name}' AND TABLE_NAME LIKE 'c$v1$%'"
|
|
209
|
+
)
|
|
210
|
+
else:
|
|
211
|
+
return []
|
|
212
|
+
except Exception:
|
|
213
|
+
return []
|
|
214
|
+
|
|
215
|
+
collections = []
|
|
216
|
+
for row in tables:
|
|
217
|
+
# Extract table name
|
|
218
|
+
if isinstance(row, dict):
|
|
219
|
+
# Server client returns dict, get the first value
|
|
220
|
+
table_name = list(row.values())[0] if row else ''
|
|
221
|
+
elif isinstance(row, (tuple, list)):
|
|
222
|
+
# Embedded client returns tuple, first element is table name
|
|
223
|
+
table_name = row[0] if len(row) > 0 else ''
|
|
224
|
+
else:
|
|
225
|
+
table_name = str(row)
|
|
226
|
+
|
|
227
|
+
# Extract collection name from table name (remove 'c$v1$' prefix)
|
|
228
|
+
if table_name.startswith('c$v1$'):
|
|
229
|
+
collection_name = table_name[5:] # Remove 'c$v1$' prefix
|
|
230
|
+
|
|
231
|
+
# Get collection with dimension
|
|
232
|
+
try:
|
|
233
|
+
collection = self.get_collection(collection_name)
|
|
234
|
+
collections.append(collection)
|
|
235
|
+
except Exception:
|
|
236
|
+
# Skip if we can't get collection info
|
|
237
|
+
continue
|
|
238
|
+
|
|
239
|
+
return collections
|
|
240
|
+
|
|
241
|
+
def count_collection(self) -> int:
|
|
242
|
+
"""
|
|
243
|
+
Count the number of collections in the current database
|
|
244
|
+
|
|
245
|
+
Returns:
|
|
246
|
+
Number of collections
|
|
247
|
+
|
|
248
|
+
Examples:
|
|
249
|
+
count = client.count_collection()
|
|
250
|
+
print(f"Database has {count} collections")
|
|
251
|
+
"""
|
|
252
|
+
collections = self.list_collections()
|
|
253
|
+
return len(collections)
|
|
254
|
+
|
|
255
|
+
def has_collection(self, name: str) -> bool:
|
|
256
|
+
"""
|
|
257
|
+
Check if a collection exists (user-facing API)
|
|
258
|
+
|
|
259
|
+
Args:
|
|
260
|
+
name: Collection name
|
|
261
|
+
|
|
262
|
+
Returns:
|
|
263
|
+
True if exists, False otherwise
|
|
264
|
+
"""
|
|
265
|
+
# Construct table name: c$v1${name}
|
|
266
|
+
table_name = CollectionNames.table_name(name)
|
|
267
|
+
|
|
268
|
+
# Check if table exists
|
|
269
|
+
try:
|
|
270
|
+
# Try to describe the table
|
|
271
|
+
table_info = self.execute(f"DESCRIBE `{table_name}`")
|
|
272
|
+
return table_info is not None and len(table_info) > 0
|
|
273
|
+
except Exception:
|
|
274
|
+
# If DESCRIBE fails, table doesn't exist
|
|
275
|
+
return False
|
|
276
|
+
|
|
277
|
+
def get_or_create_collection(
|
|
278
|
+
self,
|
|
279
|
+
name: str,
|
|
280
|
+
dimension: Optional[int] = None,
|
|
281
|
+
**kwargs
|
|
282
|
+
) -> "Collection":
|
|
283
|
+
"""
|
|
284
|
+
Get an existing collection or create it if it doesn't exist (user-facing API)
|
|
285
|
+
|
|
286
|
+
Args:
|
|
287
|
+
name: Collection name
|
|
288
|
+
dimension: Vector dimension (required if creating new collection)
|
|
289
|
+
**kwargs: Additional parameters for create_collection
|
|
290
|
+
|
|
291
|
+
Returns:
|
|
292
|
+
Collection object
|
|
293
|
+
|
|
294
|
+
Raises:
|
|
295
|
+
ValueError: If collection doesn't exist and dimension is not provided
|
|
296
|
+
"""
|
|
297
|
+
# First, try to get the collection
|
|
298
|
+
if self.has_collection(name):
|
|
299
|
+
return self.get_collection(name)
|
|
300
|
+
|
|
301
|
+
# Collection doesn't exist, create it
|
|
302
|
+
if dimension is None:
|
|
303
|
+
raise ValueError(
|
|
304
|
+
f"Collection '{name}' does not exist and dimension parameter is required "
|
|
305
|
+
f"for creating a new collection"
|
|
306
|
+
)
|
|
307
|
+
|
|
308
|
+
return self.create_collection(name=name, dimension=dimension, **kwargs)
|
|
309
|
+
|
|
310
|
+
# ==================== Collection Internal Operations (Called by Collection) ====================
|
|
311
|
+
# These methods are called by Collection objects, different clients implement different logic
|
|
312
|
+
|
|
313
|
+
# -------------------- DML Operations --------------------
|
|
314
|
+
|
|
315
|
+
def _collection_add(
|
|
316
|
+
self,
|
|
317
|
+
collection_id: Optional[str],
|
|
318
|
+
collection_name: str,
|
|
319
|
+
ids: Union[str, List[str]],
|
|
320
|
+
vectors: Optional[Union[List[float], List[List[float]]]] = None,
|
|
321
|
+
metadatas: Optional[Union[Dict, List[Dict]]] = None,
|
|
322
|
+
documents: Optional[Union[str, List[str]]] = None,
|
|
323
|
+
**kwargs
|
|
324
|
+
) -> None:
|
|
325
|
+
"""
|
|
326
|
+
[Internal] Add data to collection - Common SQL-based implementation
|
|
327
|
+
|
|
328
|
+
Args:
|
|
329
|
+
collection_id: Collection ID
|
|
330
|
+
collection_name: Collection name
|
|
331
|
+
ids: Single ID or list of IDs
|
|
332
|
+
vectors: Single vector or list of vectors (optional)
|
|
333
|
+
metadatas: Single metadata dict or list of metadata dicts (optional)
|
|
334
|
+
documents: Single document or list of documents (optional)
|
|
335
|
+
**kwargs: Additional parameters
|
|
336
|
+
"""
|
|
337
|
+
logger.info(f"Adding data to collection '{collection_name}'")
|
|
338
|
+
|
|
339
|
+
# Normalize inputs to lists
|
|
340
|
+
if isinstance(ids, str):
|
|
341
|
+
ids = [ids]
|
|
342
|
+
if isinstance(documents, str):
|
|
343
|
+
documents = [documents]
|
|
344
|
+
if metadatas is not None and isinstance(metadatas, dict):
|
|
345
|
+
metadatas = [metadatas]
|
|
346
|
+
if vectors is not None:
|
|
347
|
+
if isinstance(vectors, list) and len(vectors) > 0 and not isinstance(vectors[0], list):
|
|
348
|
+
vectors = [vectors]
|
|
349
|
+
|
|
350
|
+
# Validate inputs
|
|
351
|
+
if not documents and not vectors and not metadatas:
|
|
352
|
+
raise ValueError("At least one of documents, vectors, or metadatas must be provided")
|
|
353
|
+
|
|
354
|
+
# Determine number of items
|
|
355
|
+
num_items = 0
|
|
356
|
+
if ids:
|
|
357
|
+
num_items = len(ids)
|
|
358
|
+
elif documents:
|
|
359
|
+
num_items = len(documents)
|
|
360
|
+
elif vectors:
|
|
361
|
+
num_items = len(vectors)
|
|
362
|
+
elif metadatas:
|
|
363
|
+
num_items = len(metadatas)
|
|
364
|
+
|
|
365
|
+
if num_items == 0:
|
|
366
|
+
raise ValueError("No items to add")
|
|
367
|
+
|
|
368
|
+
# Validate lengths match
|
|
369
|
+
if ids and len(ids) != num_items:
|
|
370
|
+
raise ValueError(f"Number of ids ({len(ids)}) does not match number of items ({num_items})")
|
|
371
|
+
if documents and len(documents) != num_items:
|
|
372
|
+
raise ValueError(f"Number of documents ({len(documents)}) does not match number of items ({num_items})")
|
|
373
|
+
if metadatas and len(metadatas) != num_items:
|
|
374
|
+
raise ValueError(f"Number of metadatas ({len(metadatas)}) does not match number of items ({num_items})")
|
|
375
|
+
if vectors and len(vectors) != num_items:
|
|
376
|
+
raise ValueError(f"Number of vectors ({len(vectors)}) does not match number of items ({num_items})")
|
|
377
|
+
|
|
378
|
+
# Get table name
|
|
379
|
+
table_name = CollectionNames.table_name(collection_name)
|
|
380
|
+
|
|
381
|
+
# Build INSERT SQL
|
|
382
|
+
values_list = []
|
|
383
|
+
for i in range(num_items):
|
|
384
|
+
# Process ID - support any string format
|
|
385
|
+
id_val = ids[i] if ids else None
|
|
386
|
+
if id_val:
|
|
387
|
+
if not isinstance(id_val, str):
|
|
388
|
+
id_val = str(id_val)
|
|
389
|
+
id_sql = self._convert_id_to_sql(id_val)
|
|
390
|
+
else:
|
|
391
|
+
raise ValueError("ids must be provided for add operation")
|
|
392
|
+
|
|
393
|
+
# Process document
|
|
394
|
+
doc_val = documents[i] if documents else None
|
|
395
|
+
if doc_val is not None:
|
|
396
|
+
# Escape single quotes
|
|
397
|
+
doc_val_escaped = doc_val.replace("'", "''")
|
|
398
|
+
doc_sql = f"'{doc_val_escaped}'"
|
|
399
|
+
else:
|
|
400
|
+
doc_sql = "NULL"
|
|
401
|
+
|
|
402
|
+
# Process metadata
|
|
403
|
+
meta_val = metadatas[i] if metadatas else None
|
|
404
|
+
if meta_val is not None:
|
|
405
|
+
# Convert to JSON string and escape
|
|
406
|
+
meta_json = json.dumps(meta_val, ensure_ascii=False)
|
|
407
|
+
meta_json_escaped = meta_json.replace("'", "''")
|
|
408
|
+
meta_sql = f"'{meta_json_escaped}'"
|
|
409
|
+
else:
|
|
410
|
+
meta_sql = "NULL"
|
|
411
|
+
|
|
412
|
+
# Process vector
|
|
413
|
+
vec_val = vectors[i] if vectors else None
|
|
414
|
+
if vec_val is not None:
|
|
415
|
+
# Convert vector to string format: [1.0,2.0,3.0]
|
|
416
|
+
vec_str = "[" + ",".join(map(str, vec_val)) + "]"
|
|
417
|
+
vec_sql = f"'{vec_str}'"
|
|
418
|
+
else:
|
|
419
|
+
vec_sql = "NULL"
|
|
420
|
+
|
|
421
|
+
values_list.append(f"({id_sql}, {doc_sql}, {meta_sql}, {vec_sql})")
|
|
422
|
+
|
|
423
|
+
# Build final SQL
|
|
424
|
+
sql = f"""INSERT INTO `{table_name}` ({CollectionFieldNames.ID}, {CollectionFieldNames.DOCUMENT}, {CollectionFieldNames.METADATA}, {CollectionFieldNames.EMBEDDING})
|
|
425
|
+
VALUES {','.join(values_list)}"""
|
|
426
|
+
|
|
427
|
+
logger.debug(f"Executing SQL: {sql}")
|
|
428
|
+
self.execute(sql)
|
|
429
|
+
logger.info(f"✅ Successfully added {num_items} item(s) to collection '{collection_name}'")
|
|
430
|
+
|
|
431
|
+
def _collection_update(
|
|
432
|
+
self,
|
|
433
|
+
collection_id: Optional[str],
|
|
434
|
+
collection_name: str,
|
|
435
|
+
ids: Union[str, List[str]],
|
|
436
|
+
vectors: Optional[Union[List[float], List[List[float]]]] = None,
|
|
437
|
+
metadatas: Optional[Union[Dict, List[Dict]]] = None,
|
|
438
|
+
documents: Optional[Union[str, List[str]]] = None,
|
|
439
|
+
**kwargs
|
|
440
|
+
) -> None:
|
|
441
|
+
"""
|
|
442
|
+
[Internal] Update data in collection - Common SQL-based implementation
|
|
443
|
+
|
|
444
|
+
Args:
|
|
445
|
+
collection_id: Collection ID
|
|
446
|
+
collection_name: Collection name
|
|
447
|
+
ids: Single ID or list of IDs to update
|
|
448
|
+
vectors: New vectors (optional)
|
|
449
|
+
metadatas: New metadata (optional)
|
|
450
|
+
documents: New documents (optional)
|
|
451
|
+
**kwargs: Additional parameters
|
|
452
|
+
"""
|
|
453
|
+
logger.info(f"Updating data in collection '{collection_name}'")
|
|
454
|
+
|
|
455
|
+
# Normalize inputs to lists
|
|
456
|
+
if isinstance(ids, str):
|
|
457
|
+
ids = [ids]
|
|
458
|
+
if isinstance(documents, str):
|
|
459
|
+
documents = [documents]
|
|
460
|
+
if metadatas is not None and isinstance(metadatas, dict):
|
|
461
|
+
metadatas = [metadatas]
|
|
462
|
+
if vectors is not None:
|
|
463
|
+
if isinstance(vectors, list) and len(vectors) > 0 and not isinstance(vectors[0], list):
|
|
464
|
+
vectors = [vectors]
|
|
465
|
+
|
|
466
|
+
# Validate inputs
|
|
467
|
+
if not ids:
|
|
468
|
+
raise ValueError("ids must not be empty")
|
|
469
|
+
if not documents and not metadatas and not vectors:
|
|
470
|
+
raise ValueError("You must specify at least one column to update")
|
|
471
|
+
|
|
472
|
+
# Validate lengths match
|
|
473
|
+
if documents and len(documents) != len(ids):
|
|
474
|
+
raise ValueError(f"Number of documents ({len(documents)}) does not match number of ids ({len(ids)})")
|
|
475
|
+
if metadatas and len(metadatas) != len(ids):
|
|
476
|
+
raise ValueError(f"Number of metadatas ({len(metadatas)}) does not match number of ids ({len(ids)})")
|
|
477
|
+
if vectors and len(vectors) != len(ids):
|
|
478
|
+
raise ValueError(f"Number of vectors ({len(vectors)}) does not match number of ids ({len(ids)})")
|
|
479
|
+
|
|
480
|
+
# Get table name
|
|
481
|
+
table_name = CollectionNames.table_name(collection_name)
|
|
482
|
+
|
|
483
|
+
# Update each item
|
|
484
|
+
for i in range(len(ids)):
|
|
485
|
+
# Process ID - support any string format
|
|
486
|
+
id_val = ids[i]
|
|
487
|
+
if not isinstance(id_val, str):
|
|
488
|
+
id_val = str(id_val)
|
|
489
|
+
id_sql = self._convert_id_to_sql(id_val)
|
|
490
|
+
|
|
491
|
+
# Build SET clause
|
|
492
|
+
set_clauses = []
|
|
493
|
+
|
|
494
|
+
if documents:
|
|
495
|
+
doc_val = documents[i]
|
|
496
|
+
if doc_val is not None:
|
|
497
|
+
doc_val_escaped = doc_val.replace("'", "''")
|
|
498
|
+
set_clauses.append(f"{CollectionFieldNames.DOCUMENT} = '{doc_val_escaped}'")
|
|
499
|
+
|
|
500
|
+
if metadatas:
|
|
501
|
+
meta_val = metadatas[i]
|
|
502
|
+
if meta_val is not None:
|
|
503
|
+
meta_json = json.dumps(meta_val, ensure_ascii=False)
|
|
504
|
+
meta_json_escaped = meta_json.replace("'", "''")
|
|
505
|
+
set_clauses.append(f"{CollectionFieldNames.METADATA} = '{meta_json_escaped}'")
|
|
506
|
+
|
|
507
|
+
if vectors:
|
|
508
|
+
vec_val = vectors[i]
|
|
509
|
+
if vec_val is not None:
|
|
510
|
+
vec_str = "[" + ",".join(map(str, vec_val)) + "]"
|
|
511
|
+
set_clauses.append(f"{CollectionFieldNames.EMBEDDING} = '{vec_str}'")
|
|
512
|
+
|
|
513
|
+
if not set_clauses:
|
|
514
|
+
continue
|
|
515
|
+
|
|
516
|
+
# Build UPDATE SQL
|
|
517
|
+
sql = f"UPDATE `{table_name}` SET {', '.join(set_clauses)} WHERE {CollectionFieldNames.ID} = {id_sql}"
|
|
518
|
+
|
|
519
|
+
logger.debug(f"Executing SQL: {sql}")
|
|
520
|
+
self.execute(sql)
|
|
521
|
+
|
|
522
|
+
logger.info(f"✅ Successfully updated {len(ids)} item(s) in collection '{collection_name}'")
|
|
523
|
+
|
|
524
|
+
def _collection_upsert(
|
|
525
|
+
self,
|
|
526
|
+
collection_id: Optional[str],
|
|
527
|
+
collection_name: str,
|
|
528
|
+
ids: Union[str, List[str]],
|
|
529
|
+
vectors: Optional[Union[List[float], List[List[float]]]] = None,
|
|
530
|
+
metadatas: Optional[Union[Dict, List[Dict]]] = None,
|
|
531
|
+
documents: Optional[Union[str, List[str]]] = None,
|
|
532
|
+
**kwargs
|
|
533
|
+
) -> None:
|
|
534
|
+
"""
|
|
535
|
+
[Internal] Insert or update data in collection - Common SQL-based implementation
|
|
536
|
+
|
|
537
|
+
Args:
|
|
538
|
+
collection_id: Collection ID
|
|
539
|
+
collection_name: Collection name
|
|
540
|
+
ids: Single ID or list of IDs
|
|
541
|
+
vectors: Vectors (optional)
|
|
542
|
+
metadatas: Metadata (optional)
|
|
543
|
+
documents: Documents (optional)
|
|
544
|
+
**kwargs: Additional parameters
|
|
545
|
+
"""
|
|
546
|
+
logger.info(f"Upserting data in collection '{collection_name}'")
|
|
547
|
+
|
|
548
|
+
# Normalize inputs to lists
|
|
549
|
+
if isinstance(ids, str):
|
|
550
|
+
ids = [ids]
|
|
551
|
+
if isinstance(documents, str):
|
|
552
|
+
documents = [documents]
|
|
553
|
+
if metadatas is not None and isinstance(metadatas, dict):
|
|
554
|
+
metadatas = [metadatas]
|
|
555
|
+
if vectors is not None:
|
|
556
|
+
if isinstance(vectors, list) and len(vectors) > 0 and not isinstance(vectors[0], list):
|
|
557
|
+
vectors = [vectors]
|
|
558
|
+
|
|
559
|
+
# Validate inputs
|
|
560
|
+
if not ids:
|
|
561
|
+
raise ValueError("ids must not be empty")
|
|
562
|
+
if not documents and not metadatas and not vectors:
|
|
563
|
+
raise ValueError("You must specify at least one column to upsert")
|
|
564
|
+
|
|
565
|
+
# Validate lengths match
|
|
566
|
+
if documents and len(documents) != len(ids):
|
|
567
|
+
raise ValueError(f"Number of documents ({len(documents)}) does not match number of ids ({len(ids)})")
|
|
568
|
+
if metadatas and len(metadatas) != len(ids):
|
|
569
|
+
raise ValueError(f"Number of metadatas ({len(metadatas)}) does not match number of ids ({len(ids)})")
|
|
570
|
+
if vectors and len(vectors) != len(ids):
|
|
571
|
+
raise ValueError(f"Number of vectors ({len(vectors)}) does not match number of ids ({len(ids)})")
|
|
572
|
+
|
|
573
|
+
# Get table name
|
|
574
|
+
table_name = CollectionNames.table_name(collection_name)
|
|
575
|
+
|
|
576
|
+
# Upsert each item
|
|
577
|
+
for i in range(len(ids)):
|
|
578
|
+
# Process ID - support any string format
|
|
579
|
+
id_val = ids[i]
|
|
580
|
+
if not isinstance(id_val, str):
|
|
581
|
+
id_val = str(id_val)
|
|
582
|
+
id_sql = self._convert_id_to_sql(id_val)
|
|
583
|
+
|
|
584
|
+
# Check if record exists
|
|
585
|
+
existing = self._collection_get(
|
|
586
|
+
collection_id=collection_id,
|
|
587
|
+
collection_name=collection_name,
|
|
588
|
+
ids=[ids[i]], # Use original string ID for query
|
|
589
|
+
include=["documents", "metadatas", "embeddings"]
|
|
590
|
+
)
|
|
591
|
+
|
|
592
|
+
# Get values for this item
|
|
593
|
+
doc_val = documents[i] if documents else None
|
|
594
|
+
meta_val = metadatas[i] if metadatas else None
|
|
595
|
+
vec_val = vectors[i] if vectors else None
|
|
596
|
+
|
|
597
|
+
if existing and len(existing) > 0:
|
|
598
|
+
# Update existing record - only update provided fields
|
|
599
|
+
existing_item = existing[0]
|
|
600
|
+
existing_doc = existing_item.document if hasattr(existing_item, 'document') else None
|
|
601
|
+
existing_meta = existing_item.metadata if hasattr(existing_item, 'metadata') else None
|
|
602
|
+
existing_vec = existing_item.embedding if hasattr(existing_item, 'embedding') else None
|
|
603
|
+
|
|
604
|
+
# Use provided values or keep existing values
|
|
605
|
+
final_document = doc_val if doc_val is not None else existing_doc
|
|
606
|
+
final_metadata = meta_val if meta_val is not None else existing_meta
|
|
607
|
+
final_vector = vec_val if vec_val is not None else existing_vec
|
|
608
|
+
|
|
609
|
+
# Build SET clause
|
|
610
|
+
set_clauses = []
|
|
611
|
+
|
|
612
|
+
if doc_val is not None:
|
|
613
|
+
doc_val_escaped = final_document.replace("'", "''") if final_document else "NULL"
|
|
614
|
+
set_clauses.append(f"{CollectionFieldNames.DOCUMENT} = '{doc_val_escaped}'")
|
|
615
|
+
|
|
616
|
+
if meta_val is not None:
|
|
617
|
+
meta_json = json.dumps(final_metadata, ensure_ascii=False) if final_metadata else "{}"
|
|
618
|
+
meta_json_escaped = meta_json.replace("'", "''")
|
|
619
|
+
set_clauses.append(f"{CollectionFieldNames.METADATA} = '{meta_json_escaped}'")
|
|
620
|
+
|
|
621
|
+
if vec_val is not None:
|
|
622
|
+
vec_str = "[" + ",".join(map(str, final_vector)) + "]" if final_vector else "NULL"
|
|
623
|
+
set_clauses.append(f"{CollectionFieldNames.EMBEDDING} = '{vec_str}'")
|
|
624
|
+
|
|
625
|
+
if set_clauses:
|
|
626
|
+
sql = f"UPDATE `{table_name}` SET {', '.join(set_clauses)} WHERE {CollectionFieldNames.ID} = {id_sql}"
|
|
627
|
+
logger.debug(f"Executing SQL: {sql}")
|
|
628
|
+
self.execute(sql)
|
|
629
|
+
else:
|
|
630
|
+
# Insert new record
|
|
631
|
+
if doc_val:
|
|
632
|
+
doc_val_escaped = doc_val.replace("'", "''")
|
|
633
|
+
doc_sql = f"'{doc_val_escaped}'"
|
|
634
|
+
else:
|
|
635
|
+
doc_sql = "NULL"
|
|
636
|
+
|
|
637
|
+
if meta_val is not None:
|
|
638
|
+
meta_json = json.dumps(meta_val, ensure_ascii=False)
|
|
639
|
+
meta_json_escaped = meta_json.replace("'", "''")
|
|
640
|
+
meta_sql = f"'{meta_json_escaped}'"
|
|
641
|
+
else:
|
|
642
|
+
meta_sql = "NULL"
|
|
643
|
+
|
|
644
|
+
if vec_val is not None:
|
|
645
|
+
vec_str = "[" + ",".join(map(str, vec_val)) + "]"
|
|
646
|
+
vec_sql = f"'{vec_str}'"
|
|
647
|
+
else:
|
|
648
|
+
vec_sql = "NULL"
|
|
649
|
+
|
|
650
|
+
sql = f"""INSERT INTO `{table_name}` ({CollectionFieldNames.ID}, {CollectionFieldNames.DOCUMENT}, {CollectionFieldNames.METADATA}, {CollectionFieldNames.EMBEDDING})
|
|
651
|
+
VALUES ({id_sql}, {doc_sql}, {meta_sql}, {vec_sql})"""
|
|
652
|
+
logger.debug(f"Executing SQL: {sql}")
|
|
653
|
+
self.execute(sql)
|
|
654
|
+
|
|
655
|
+
logger.info(f"✅ Successfully upserted {len(ids)} item(s) in collection '{collection_name}'")
|
|
656
|
+
|
|
657
|
+
def _collection_delete(
|
|
658
|
+
self,
|
|
659
|
+
collection_id: Optional[str],
|
|
660
|
+
collection_name: str,
|
|
661
|
+
ids: Optional[Union[str, List[str]]] = None,
|
|
662
|
+
where: Optional[Dict[str, Any]] = None,
|
|
663
|
+
where_document: Optional[Dict[str, Any]] = None,
|
|
664
|
+
**kwargs
|
|
665
|
+
) -> None:
|
|
666
|
+
"""
|
|
667
|
+
[Internal] Delete data from collection - Common SQL-based implementation
|
|
668
|
+
|
|
669
|
+
Args:
|
|
670
|
+
collection_id: Collection ID
|
|
671
|
+
collection_name: Collection name
|
|
672
|
+
ids: Single ID or list of IDs to delete (optional)
|
|
673
|
+
where: Filter condition on metadata (optional)
|
|
674
|
+
where_document: Filter condition on documents (optional)
|
|
675
|
+
**kwargs: Additional parameters
|
|
676
|
+
"""
|
|
677
|
+
logger.info(f"Deleting data from collection '{collection_name}'")
|
|
678
|
+
|
|
679
|
+
# Validate that at least one filter is provided
|
|
680
|
+
if not ids and not where and not where_document:
|
|
681
|
+
raise ValueError("At least one of ids, where, or where_document must be provided")
|
|
682
|
+
|
|
683
|
+
# Normalize ids to list
|
|
684
|
+
id_list = None
|
|
685
|
+
if ids is not None:
|
|
686
|
+
if isinstance(ids, str):
|
|
687
|
+
id_list = [ids]
|
|
688
|
+
else:
|
|
689
|
+
id_list = ids
|
|
690
|
+
|
|
691
|
+
# Get table name
|
|
692
|
+
table_name = CollectionNames.table_name(collection_name)
|
|
693
|
+
|
|
694
|
+
# Build WHERE clause
|
|
695
|
+
where_clause, params = self._build_where_clause(where, where_document, id_list)
|
|
696
|
+
|
|
697
|
+
# Build DELETE SQL
|
|
698
|
+
sql = f"DELETE FROM `{table_name}` {where_clause}"
|
|
699
|
+
|
|
700
|
+
logger.debug(f"Executing SQL: {sql}")
|
|
701
|
+
logger.debug(f"Parameters: {params}")
|
|
702
|
+
|
|
703
|
+
# Execute DELETE using parameterized query
|
|
704
|
+
conn = self._ensure_connection()
|
|
705
|
+
use_context_manager = self._use_context_manager_for_cursor()
|
|
706
|
+
self._execute_query_with_cursor(conn, sql, params, use_context_manager)
|
|
707
|
+
|
|
708
|
+
logger.info(f"✅ Successfully deleted data from collection '{collection_name}'")
|
|
709
|
+
|
|
710
|
+
# -------------------- DQL Operations --------------------
|
|
711
|
+
# Note: _collection_query() and _collection_get() are implemented below with common SQL-based logic
|
|
712
|
+
|
|
713
|
+
def _normalize_query_vectors(
|
|
714
|
+
self,
|
|
715
|
+
query_embeddings: Optional[Union[List[float], List[List[float]]]]
|
|
716
|
+
) -> List[List[float]]:
|
|
717
|
+
"""
|
|
718
|
+
Normalize query vectors to list of lists format
|
|
719
|
+
|
|
720
|
+
Args:
|
|
721
|
+
query_embeddings: Single vector or list of vectors
|
|
722
|
+
|
|
723
|
+
Returns:
|
|
724
|
+
List of vectors (each vector is a list of floats)
|
|
725
|
+
"""
|
|
726
|
+
if query_embeddings is None:
|
|
727
|
+
return []
|
|
728
|
+
|
|
729
|
+
# Check if it's a single vector (list of numbers)
|
|
730
|
+
if query_embeddings and isinstance(query_embeddings[0], (int, float)):
|
|
731
|
+
return [query_embeddings]
|
|
732
|
+
|
|
733
|
+
return query_embeddings
|
|
734
|
+
|
|
735
|
+
def _normalize_include_fields(
|
|
736
|
+
self,
|
|
737
|
+
include: Optional[List[str]]
|
|
738
|
+
) -> Dict[str, bool]:
|
|
739
|
+
"""
|
|
740
|
+
Normalize include parameter to a dictionary
|
|
741
|
+
|
|
742
|
+
Args:
|
|
743
|
+
include: List of fields to include (e.g., ["documents", "metadatas", "embeddings"])
|
|
744
|
+
|
|
745
|
+
Returns:
|
|
746
|
+
Dictionary with field names as keys and True as values
|
|
747
|
+
Default includes: documents, metadatas (but not embeddings)
|
|
748
|
+
"""
|
|
749
|
+
# Default includes documents and metadatas
|
|
750
|
+
default_fields = {"documents": True, "metadatas": True}
|
|
751
|
+
|
|
752
|
+
if include is None:
|
|
753
|
+
return default_fields
|
|
754
|
+
|
|
755
|
+
# Build include dict from list
|
|
756
|
+
include_dict = {}
|
|
757
|
+
for field in include:
|
|
758
|
+
include_dict[field] = True
|
|
759
|
+
|
|
760
|
+
return include_dict
|
|
761
|
+
|
|
762
|
+
def _embed_texts(
|
|
763
|
+
self,
|
|
764
|
+
texts: Union[str, List[str]],
|
|
765
|
+
**kwargs
|
|
766
|
+
) -> List[List[float]]:
|
|
767
|
+
"""
|
|
768
|
+
Embed text(s) to vector(s)
|
|
769
|
+
|
|
770
|
+
Args:
|
|
771
|
+
texts: Single text or list of texts
|
|
772
|
+
**kwargs: Additional parameters for embedding
|
|
773
|
+
|
|
774
|
+
Returns:
|
|
775
|
+
List of vectors
|
|
776
|
+
|
|
777
|
+
Note:
|
|
778
|
+
This is a placeholder method. Subclasses should override this
|
|
779
|
+
to provide actual embedding functionality, or users should
|
|
780
|
+
provide query_embeddings directly.
|
|
781
|
+
"""
|
|
782
|
+
raise NotImplementedError(
|
|
783
|
+
"Text embedding is not implemented yet. "
|
|
784
|
+
"Please provide query_embeddings directly instead of query_texts."
|
|
785
|
+
)
|
|
786
|
+
|
|
787
|
+
def _normalize_row(self, row: Any, cursor_description: Optional[Any] = None) -> Dict[str, Any]:
|
|
788
|
+
"""
|
|
789
|
+
Normalize database row to dictionary format
|
|
790
|
+
|
|
791
|
+
Args:
|
|
792
|
+
row: Database row (can be dict or tuple)
|
|
793
|
+
cursor_description: Cursor description for tuple rows
|
|
794
|
+
|
|
795
|
+
Returns:
|
|
796
|
+
Dictionary with column names as keys
|
|
797
|
+
"""
|
|
798
|
+
if isinstance(row, dict):
|
|
799
|
+
return row
|
|
800
|
+
|
|
801
|
+
# Convert tuple to dict using cursor description
|
|
802
|
+
if cursor_description is not None:
|
|
803
|
+
row_dict = {}
|
|
804
|
+
for idx, col_desc in enumerate(cursor_description):
|
|
805
|
+
row_dict[col_desc[0]] = row[idx]
|
|
806
|
+
return row_dict
|
|
807
|
+
|
|
808
|
+
# Fallback: assume it's already a dict or try to convert
|
|
809
|
+
return dict(row) if hasattr(row, '_asdict') else row
|
|
810
|
+
|
|
811
|
+
def _execute_query_with_cursor(
|
|
812
|
+
self,
|
|
813
|
+
conn: Any,
|
|
814
|
+
sql: str,
|
|
815
|
+
params: List[Any],
|
|
816
|
+
use_context_manager: bool = True
|
|
817
|
+
) -> List[Dict[str, Any]]:
|
|
818
|
+
"""
|
|
819
|
+
Execute SQL query and return normalized rows
|
|
820
|
+
|
|
821
|
+
Args:
|
|
822
|
+
conn: Database connection
|
|
823
|
+
sql: SQL query string
|
|
824
|
+
params: Query parameters
|
|
825
|
+
use_context_manager: Whether to use context manager for cursor (default: True)
|
|
826
|
+
|
|
827
|
+
Returns:
|
|
828
|
+
List of normalized row dictionaries
|
|
829
|
+
"""
|
|
830
|
+
if use_context_manager:
|
|
831
|
+
with conn.cursor() as cursor:
|
|
832
|
+
cursor.execute(sql, params)
|
|
833
|
+
rows = cursor.fetchall()
|
|
834
|
+
# Normalize rows
|
|
835
|
+
normalized_rows = []
|
|
836
|
+
for row in rows:
|
|
837
|
+
normalized_rows.append(self._normalize_row(row, cursor.description))
|
|
838
|
+
return normalized_rows
|
|
839
|
+
else:
|
|
840
|
+
cursor = conn.cursor()
|
|
841
|
+
try:
|
|
842
|
+
cursor.execute(sql, params)
|
|
843
|
+
rows = cursor.fetchall()
|
|
844
|
+
# Normalize rows
|
|
845
|
+
normalized_rows = []
|
|
846
|
+
for row in rows:
|
|
847
|
+
normalized_rows.append(self._normalize_row(row, cursor.description))
|
|
848
|
+
return normalized_rows
|
|
849
|
+
finally:
|
|
850
|
+
cursor.close()
|
|
851
|
+
|
|
852
|
+
def _build_select_clause(self, include_fields: Dict[str, bool]) -> str:
|
|
853
|
+
"""
|
|
854
|
+
Build SELECT clause based on include fields
|
|
855
|
+
|
|
856
|
+
Args:
|
|
857
|
+
include_fields: Dictionary of fields to include
|
|
858
|
+
|
|
859
|
+
Returns:
|
|
860
|
+
SELECT clause string
|
|
861
|
+
"""
|
|
862
|
+
select_fields = ["_id"]
|
|
863
|
+
if include_fields.get("embeddings") or include_fields.get("embedding"):
|
|
864
|
+
select_fields.append("embedding")
|
|
865
|
+
if include_fields.get("documents") or include_fields.get("document"):
|
|
866
|
+
select_fields.append("document")
|
|
867
|
+
if include_fields.get("metadatas") or include_fields.get("metadata"):
|
|
868
|
+
select_fields.append("metadata")
|
|
869
|
+
|
|
870
|
+
return ", ".join(select_fields)
|
|
871
|
+
|
|
872
|
+
def _build_where_clause(
|
|
873
|
+
self,
|
|
874
|
+
where: Optional[Dict[str, Any]] = None,
|
|
875
|
+
where_document: Optional[Dict[str, Any]] = None,
|
|
876
|
+
id_list: Optional[List[str]] = None
|
|
877
|
+
) -> Tuple[str, List[Any]]:
|
|
878
|
+
"""
|
|
879
|
+
Build WHERE clause from filters
|
|
880
|
+
|
|
881
|
+
Args:
|
|
882
|
+
where: Metadata filter
|
|
883
|
+
where_document: Document filter
|
|
884
|
+
id_list: List of IDs to filter
|
|
885
|
+
|
|
886
|
+
Returns:
|
|
887
|
+
Tuple of (where_clause, params)
|
|
888
|
+
"""
|
|
889
|
+
where_clauses = []
|
|
890
|
+
params = []
|
|
891
|
+
|
|
892
|
+
# Add ids filter if provided
|
|
893
|
+
if id_list:
|
|
894
|
+
# Process IDs for varbinary(512) _id field - support any string format
|
|
895
|
+
processed_ids = []
|
|
896
|
+
for id_val in id_list:
|
|
897
|
+
if not isinstance(id_val, str):
|
|
898
|
+
id_val = str(id_val)
|
|
899
|
+
processed_ids.append(self._convert_id_to_sql(id_val))
|
|
900
|
+
|
|
901
|
+
where_clauses.append(f"_id IN ({','.join(processed_ids)})")
|
|
902
|
+
|
|
903
|
+
# Add metadata filter
|
|
904
|
+
if where:
|
|
905
|
+
meta_clause, meta_params = FilterBuilder.build_metadata_filter(where, "metadata")
|
|
906
|
+
if meta_clause:
|
|
907
|
+
where_clauses.append(meta_clause)
|
|
908
|
+
params.extend(meta_params)
|
|
909
|
+
|
|
910
|
+
# Add document filter
|
|
911
|
+
if where_document:
|
|
912
|
+
doc_clause, doc_params = FilterBuilder.build_document_filter(where_document, "document")
|
|
913
|
+
if doc_clause:
|
|
914
|
+
where_clauses.append(doc_clause)
|
|
915
|
+
params.extend(doc_params)
|
|
916
|
+
|
|
917
|
+
where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
|
|
918
|
+
return where_clause, params
|
|
919
|
+
|
|
920
|
+
def _parse_row_value(self, value: Any) -> Any:
|
|
921
|
+
"""
|
|
922
|
+
Parse row value (handle JSON strings)
|
|
923
|
+
|
|
924
|
+
Args:
|
|
925
|
+
value: Raw value from database
|
|
926
|
+
|
|
927
|
+
Returns:
|
|
928
|
+
Parsed value
|
|
929
|
+
"""
|
|
930
|
+
if value is None:
|
|
931
|
+
return None
|
|
932
|
+
|
|
933
|
+
if isinstance(value, str):
|
|
934
|
+
try:
|
|
935
|
+
return json.loads(value)
|
|
936
|
+
except (json.JSONDecodeError, ValueError):
|
|
937
|
+
return value
|
|
938
|
+
|
|
939
|
+
return value
|
|
940
|
+
|
|
941
|
+
def _convert_id_to_sql(self, id_val: str) -> str:
|
|
942
|
+
"""
|
|
943
|
+
Convert string ID to SQL format for varbinary(512) _id field
|
|
944
|
+
|
|
945
|
+
Args:
|
|
946
|
+
id_val: String ID (can be any string like "id1", "item-123", etc.)
|
|
947
|
+
|
|
948
|
+
Returns:
|
|
949
|
+
SQL expression to convert string to binary (e.g., "CAST('id1' AS BINARY)")
|
|
950
|
+
"""
|
|
951
|
+
if not isinstance(id_val, str):
|
|
952
|
+
id_val = str(id_val)
|
|
953
|
+
|
|
954
|
+
# Escape single quotes in the ID
|
|
955
|
+
id_val_escaped = id_val.replace("'", "''")
|
|
956
|
+
# Use CAST to convert string to binary for varbinary(512) field
|
|
957
|
+
return f"CAST('{id_val_escaped}' AS BINARY)"
|
|
958
|
+
|
|
959
|
+
def _convert_id_from_bytes(self, record_id: Any) -> str:
|
|
960
|
+
"""
|
|
961
|
+
Convert _id from bytes to string format
|
|
962
|
+
|
|
963
|
+
Args:
|
|
964
|
+
record_id: Record ID from database (can be bytes, str, or other format)
|
|
965
|
+
|
|
966
|
+
Returns:
|
|
967
|
+
String ID
|
|
968
|
+
"""
|
|
969
|
+
# If it's already a string, return as is
|
|
970
|
+
if isinstance(record_id, str):
|
|
971
|
+
return record_id
|
|
972
|
+
|
|
973
|
+
# Convert bytes to string (UTF-8 decode)
|
|
974
|
+
if isinstance(record_id, bytes):
|
|
975
|
+
try:
|
|
976
|
+
return record_id.decode('utf-8')
|
|
977
|
+
except UnicodeDecodeError:
|
|
978
|
+
# If UTF-8 decode fails, return hex representation as fallback
|
|
979
|
+
return record_id.hex()
|
|
980
|
+
|
|
981
|
+
# For other formats, convert to string
|
|
982
|
+
return str(record_id)
|
|
983
|
+
|
|
984
|
+
def _process_query_row(
|
|
985
|
+
self,
|
|
986
|
+
row: Dict[str, Any],
|
|
987
|
+
include_fields: Dict[str, bool]
|
|
988
|
+
) -> Dict[str, Any]:
|
|
989
|
+
"""
|
|
990
|
+
Process a row from query results
|
|
991
|
+
|
|
992
|
+
Args:
|
|
993
|
+
row: Normalized row dictionary
|
|
994
|
+
include_fields: Fields to include
|
|
995
|
+
|
|
996
|
+
Returns:
|
|
997
|
+
Result item dictionary
|
|
998
|
+
"""
|
|
999
|
+
# Convert _id from bytes to string format
|
|
1000
|
+
record_id = self._convert_id_from_bytes(row["_id"])
|
|
1001
|
+
result_item = {"_id": record_id}
|
|
1002
|
+
|
|
1003
|
+
if "document" in row and row["document"] is not None:
|
|
1004
|
+
result_item["document"] = row["document"]
|
|
1005
|
+
|
|
1006
|
+
if "embedding" in row and row["embedding"] is not None:
|
|
1007
|
+
result_item["embedding"] = self._parse_row_value(row["embedding"])
|
|
1008
|
+
|
|
1009
|
+
if "metadata" in row and row["metadata"] is not None:
|
|
1010
|
+
result_item["metadata"] = self._parse_row_value(row["metadata"])
|
|
1011
|
+
|
|
1012
|
+
if "distance" in row:
|
|
1013
|
+
result_item["distance"] = float(row["distance"])
|
|
1014
|
+
|
|
1015
|
+
return result_item
|
|
1016
|
+
|
|
1017
|
+
def _process_get_row(
|
|
1018
|
+
self,
|
|
1019
|
+
row: Dict[str, Any],
|
|
1020
|
+
include_fields: Dict[str, bool]
|
|
1021
|
+
) -> Dict[str, Any]:
|
|
1022
|
+
"""
|
|
1023
|
+
Process a row from get results
|
|
1024
|
+
|
|
1025
|
+
Args:
|
|
1026
|
+
row: Normalized row dictionary
|
|
1027
|
+
include_fields: Fields to include
|
|
1028
|
+
|
|
1029
|
+
Returns:
|
|
1030
|
+
Result item dictionary with id, document, embedding, metadata
|
|
1031
|
+
"""
|
|
1032
|
+
# Convert _id from bytes to string format
|
|
1033
|
+
record_id = self._convert_id_from_bytes(row["_id"])
|
|
1034
|
+
|
|
1035
|
+
document = None
|
|
1036
|
+
embedding = None
|
|
1037
|
+
metadata = None
|
|
1038
|
+
|
|
1039
|
+
# Include document if requested
|
|
1040
|
+
if include_fields.get("documents") or include_fields.get("document"):
|
|
1041
|
+
if "document" in row:
|
|
1042
|
+
document = row["document"]
|
|
1043
|
+
|
|
1044
|
+
# Include metadata if requested
|
|
1045
|
+
if include_fields.get("metadatas") or include_fields.get("metadata"):
|
|
1046
|
+
if "metadata" in row and row["metadata"] is not None:
|
|
1047
|
+
metadata = self._parse_row_value(row["metadata"])
|
|
1048
|
+
|
|
1049
|
+
# Include embedding if requested
|
|
1050
|
+
if include_fields.get("embeddings") or include_fields.get("embedding"):
|
|
1051
|
+
if "embedding" in row and row["embedding"] is not None:
|
|
1052
|
+
embedding = self._parse_row_value(row["embedding"])
|
|
1053
|
+
|
|
1054
|
+
return {
|
|
1055
|
+
"id": record_id,
|
|
1056
|
+
"document": document,
|
|
1057
|
+
"embedding": embedding,
|
|
1058
|
+
"metadata": metadata
|
|
1059
|
+
}
|
|
1060
|
+
|
|
1061
|
+
def _use_context_manager_for_cursor(self) -> bool:
|
|
1062
|
+
"""
|
|
1063
|
+
Whether to use context manager for cursor
|
|
1064
|
+
|
|
1065
|
+
Returns:
|
|
1066
|
+
True if context manager should be used, False otherwise
|
|
1067
|
+
"""
|
|
1068
|
+
# Default implementation: use context manager
|
|
1069
|
+
# Subclasses can override this if they need different behavior
|
|
1070
|
+
return True
|
|
1071
|
+
|
|
1072
|
+
# -------------------- DQL Operations (Common Implementation) --------------------
|
|
1073
|
+
|
|
1074
|
+
def _collection_query(
    self,
    collection_id: Optional[str],
    collection_name: str,
    query_embeddings: Optional[Union[List[float], List[List[float]]]] = None,
    query_texts: Optional[Union[str, List[str]]] = None,
    n_results: int = 10,
    where: Optional[Dict[str, Any]] = None,
    where_document: Optional[Dict[str, Any]] = None,
    include: Optional[List[str]] = None,
    **kwargs
) -> Union[QueryResult, List[QueryResult]]:
    """
    [Internal] Query collection by vector similarity - Common SQL-based implementation

    Args:
        collection_id: Collection ID (unused here; kept for interface parity).
        collection_name: Collection name.
        query_embeddings: Query vector(s) (preferred over query_texts).
        query_texts: Query text(s); embedded via _embed_texts() when no
            vectors are given.
        n_results: Number of results per query vector (default: 10).
        where: Metadata filter.
        where_document: Document filter.
        include: Fields to include in results.
        **kwargs: Extra parameters, forwarded to text embedding.

    Returns:
        - Single vector/text: one QueryResult.
        - Multiple vectors/texts: a list of QueryResult, one per vector.

    Raises:
        ValueError: If neither query_embeddings nor query_texts is provided.
    """
    logger.info(f"Querying collection '{collection_name}' with n_results={n_results}")
    conn = self._ensure_connection()

    # Resolve the backing table via the shared naming helper so DQL agrees
    # with the DML paths (which already use CollectionNames.table_name),
    # instead of duplicating the "c$v1$" prefix here.
    table_name = CollectionNames.table_name(collection_name)

    # Embed texts only when no vectors were supplied explicitly.
    if query_texts is not None and query_embeddings is None:
        logger.info("Embedding query texts...")
        query_embeddings = self._embed_texts(query_texts, **kwargs)

    # Normalize to a list of vectors.
    query_vectors = self._normalize_query_vectors(query_embeddings)
    if not query_vectors:
        raise ValueError("Either query_embeddings or query_texts must be provided")

    is_multiple_vectors = len(query_vectors) > 1

    include_fields = self._normalize_include_fields(include)
    select_clause = self._build_select_clause(include_fields)
    where_clause, params = self._build_where_clause(where, where_document)
    use_context_manager = self._use_context_manager_for_cursor()

    # One ANN query per input vector, each producing its own QueryResult.
    query_results = []
    for query_vector in query_vectors:
        # Inline the vector literal. NOTE(review): assumes the vector holds
        # only numbers, so inlining is safe — confirm upstream validation.
        vector_str = "[" + ",".join(map(str, query_vector)) + "]"

        # Order by L2 distance; APPROXIMATE enables the vector-index path.
        # The distance is also selected so result rows can expose it.
        sql = f"""
        SELECT {select_clause},
               l2_distance(embedding, '{vector_str}') AS distance
        FROM `{table_name}`
        {where_clause}
        ORDER BY l2_distance(embedding, '{vector_str}')
        APPROXIMATE
        LIMIT %s
        """

        query_params = params + [n_results]
        logger.debug(f"Executing SQL: {sql}")
        logger.debug(f"Parameters: {query_params}")

        rows = self._execute_query_with_cursor(conn, sql, query_params, use_context_manager)

        query_result = QueryResult()
        for row in rows:
            result_item = self._process_query_row(row, include_fields)
            query_result.add_item(
                id=result_item.get("_id"),
                document=result_item.get("document"),
                embedding=result_item.get("embedding"),
                metadata=result_item.get("metadata"),
                distance=result_item.get("distance")
            )

        query_results.append(query_result)

    if is_multiple_vectors:
        logger.info(f"✅ Query completed for '{collection_name}' with {len(query_vectors)} vectors, returning {len(query_results)} QueryResult objects")
        return query_results
    logger.info(f"✅ Query completed for '{collection_name}', found {len(query_results[0])} results")
    return query_results[0]
def _collection_get(
    self,
    collection_id: Optional[str],
    collection_name: str,
    ids: Optional[Union[str, List[str]]] = None,
    where: Optional[Dict[str, Any]] = None,
    where_document: Optional[Dict[str, Any]] = None,
    limit: Optional[int] = None,
    offset: Optional[int] = None,
    include: Optional[List[str]] = None,
    **kwargs
) -> Union[QueryResult, List[QueryResult]]:
    """
    [Internal] Get data from collection by IDs or filters - Common SQL-based implementation

    Args:
        collection_id: Collection ID (unused here; kept for interface parity).
        collection_name: Collection name.
        ids: Single ID or list of IDs (optional).
        where: Filter condition on metadata (optional).
        where_document: Filter condition on documents (optional).
        limit: Maximum number of results per query (default 100).
        offset: Number of results to skip (default 0).
        include: Fields to include in results (optional).
        **kwargs: Additional parameters (unused).

    Returns:
        - Single ID, or any filters: one QueryResult with all matches.
        - Multiple IDs and no filters: a list of QueryResult, one per ID.
    """
    logger.info(f"Getting data from collection '{collection_name}'")
    conn = self._ensure_connection()

    # Resolve the backing table via the shared naming helper so DQL agrees
    # with the DML paths (which already use CollectionNames.table_name),
    # instead of duplicating the "c$v1$" prefix here.
    table_name = CollectionNames.table_name(collection_name)

    # Paging defaults.
    if limit is None:
        limit = 100
    if offset is None:
        offset = 0

    # Normalize ids to a list.
    id_list = None
    if ids is not None:
        id_list = [ids] if isinstance(ids, str) else ids

    # Multiple IDs with no extra filters -> one QueryResult per ID.
    has_filters = where is not None or where_document is not None
    should_return_multiple = id_list is not None and len(id_list) > 1 and not has_filters

    include_fields = self._normalize_include_fields(include)
    # SELECT clause always includes _id.
    select_clause = self._build_select_clause(include_fields)
    use_context_manager = self._use_context_manager_for_cursor()

    def fetch(filter_ids: Optional[List[str]]) -> QueryResult:
        # Run one paged SELECT for the given ID subset and build a QueryResult.
        where_clause, params = self._build_where_clause(where, where_document, filter_ids)
        sql = f"""
        SELECT {select_clause}
        FROM `{table_name}`
        {where_clause}
        LIMIT %s OFFSET %s
        """
        query_params = params + [limit, offset]
        logger.debug(f"Executing SQL: {sql}")
        logger.debug(f"Parameters: {query_params}")
        rows = self._execute_query_with_cursor(conn, sql, query_params, use_context_manager)

        result = QueryResult()
        for row in rows:
            processed = self._process_get_row(row, include_fields)
            result.add_item(
                id=processed["id"],
                document=processed["document"],
                embedding=processed["embedding"],
                metadata=processed["metadata"]
            )
        return result

    if should_return_multiple:
        query_results = [fetch([single_id]) for single_id in id_list]
        logger.info(f"✅ Get completed for '{collection_name}' with {len(id_list)} IDs, returning {len(query_results)} QueryResult objects")
        return query_results

    query_result = fetch(id_list)
    logger.info(f"✅ Get completed for '{collection_name}', found {len(query_result)} results")
    return query_result
def _collection_hybrid_search(
    self,
    collection_id: Optional[str],
    collection_name: str,
    query: Optional[Dict[str, Any]] = None,
    knn: Optional[Dict[str, Any]] = None,
    rank: Optional[Dict[str, Any]] = None,
    n_results: int = 10,
    include: Optional[List[str]] = None,
    **kwargs
) -> Dict[str, Any]:
    """
    [Internal] Hybrid search combining full-text search and vector similarity search - Common SQL-based implementation

    Supports:
    1. Scalar query (metadata filtering only)
    2. Full-text search (with optional metadata filtering)
    3. Vector search (with optional metadata filtering)
    4. Scalar + vector search (with optional metadata filtering)

    Args:
        collection_id: Collection ID (unused here; kept for interface parity).
        collection_name: Collection name.
        query: Full-text search configuration dict with:
            - where_document: Document filter conditions (e.g., {"$contains": "text"})
            - where: Metadata filter conditions (e.g., {"page": {"$gte": 5}})
        knn: Vector search configuration dict with:
            - query_texts: Query text(s) to be embedded (optional if query_embeddings provided)
            - query_embeddings: Query vector(s) (optional if query_texts provided)
            - where: Metadata filter conditions (optional)
            - n_results: Number of results for vector search (optional)
        rank: Ranking configuration dict (e.g., {"rrf": {"rank_window_size": 60, "rank_constant": 60}})
        n_results: Final number of results to return after ranking (default: 10)
        include: Fields to include in results (optional)
        **kwargs: Additional parameters (unused)

    Returns:
        Search results dictionary containing ids, distances, metadatas,
        documents, embeddings (empty lists when GET_SQL yields nothing).
    """
    logger.info(f"Hybrid search in collection '{collection_name}' with n_results={n_results}")
    conn = self._ensure_connection()

    # Resolve the backing table via the shared naming helper so all code
    # paths (DML and DQL) agree on the table-name scheme, instead of
    # duplicating the "c$v1$" prefix here.
    table_name = CollectionNames.table_name(collection_name)

    # Build the search_parm document and serialize it.
    search_parm = self._build_search_parm(query, knn, rank, n_results)
    search_parm_json = json.dumps(search_parm, ensure_ascii=False)

    use_context_manager = self._use_context_manager_for_cursor()

    # Stash the JSON in a session variable to avoid datatype issues when
    # passing it to DBMS_HYBRID_SEARCH. NOTE(review): only single quotes
    # are escaped here; json.dumps output contains no raw backslashes
    # unless the payload itself does — confirm backslash handling for the
    # server's SQL mode.
    escaped_params = search_parm_json.replace("'", "''")
    set_sql = f"SET @search_parm = '{escaped_params}'"
    logger.debug(f"Setting search_parm: {set_sql}")
    logger.debug(f"Search parm JSON: {search_parm_json}")

    # Execute SET statement
    self._execute_query_with_cursor(conn, set_sql, [], use_context_manager)

    # Ask the server to generate the actual hybrid-search SQL.
    get_sql_query = f"SELECT DBMS_HYBRID_SEARCH.GET_SQL('{table_name}', @search_parm) as query_sql FROM dual"
    logger.debug(f"Getting SQL query: {get_sql_query}")

    rows = self._execute_query_with_cursor(conn, get_sql_query, [], use_context_manager)

    if not rows or not rows[0].get("query_sql"):
        logger.warning(f"No SQL query returned from GET_SQL")
        return {
            "ids": [],
            "distances": [],
            "metadatas": [],
            "documents": [],
            "embeddings": []
        }

    # Get the SQL query string
    query_sql = rows[0]["query_sql"]
    if isinstance(query_sql, str):
        # Remove any surrounding quotes if present
        query_sql = query_sql.strip().strip("'\"")

    logger.debug(f"Executing query SQL: {query_sql}")

    # Execute the returned SQL query
    result_rows = self._execute_query_with_cursor(conn, query_sql, [], use_context_manager)

    # Transform SQL query results to standard format
    return self._transform_sql_result(result_rows, include)
def _build_search_parm(
|
|
1419
|
+
self,
|
|
1420
|
+
query: Optional[Dict[str, Any]],
|
|
1421
|
+
knn: Optional[Dict[str, Any]],
|
|
1422
|
+
rank: Optional[Dict[str, Any]],
|
|
1423
|
+
n_results: int
|
|
1424
|
+
) -> Dict[str, Any]:
|
|
1425
|
+
"""
|
|
1426
|
+
Build search_parm JSON from query, knn, and rank parameters
|
|
1427
|
+
|
|
1428
|
+
Args:
|
|
1429
|
+
query: Full-text search configuration dict
|
|
1430
|
+
knn: Vector search configuration dict
|
|
1431
|
+
rank: Ranking configuration dict
|
|
1432
|
+
n_results: Final number of results to return
|
|
1433
|
+
|
|
1434
|
+
Returns:
|
|
1435
|
+
search_parm dictionary
|
|
1436
|
+
"""
|
|
1437
|
+
search_parm = {}
|
|
1438
|
+
|
|
1439
|
+
# Build query part (full-text search or scalar query)
|
|
1440
|
+
if query:
|
|
1441
|
+
query_expr = self._build_query_expression(query)
|
|
1442
|
+
if query_expr:
|
|
1443
|
+
search_parm["query"] = query_expr
|
|
1444
|
+
|
|
1445
|
+
# Build knn part (vector search)
|
|
1446
|
+
if knn:
|
|
1447
|
+
knn_expr = self._build_knn_expression(knn)
|
|
1448
|
+
if knn_expr:
|
|
1449
|
+
search_parm["knn"] = knn_expr
|
|
1450
|
+
|
|
1451
|
+
# Build rank part
|
|
1452
|
+
if rank:
|
|
1453
|
+
search_parm["rank"] = rank
|
|
1454
|
+
|
|
1455
|
+
return search_parm
|
|
1456
|
+
|
|
1457
|
+
def _build_query_expression(self, query: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
|
1458
|
+
"""
|
|
1459
|
+
Build query expression from query dict
|
|
1460
|
+
|
|
1461
|
+
Supports:
|
|
1462
|
+
- Scalar query (metadata filtering only): query.range or query.term
|
|
1463
|
+
- Full-text search: query.query_string
|
|
1464
|
+
- Full-text search with metadata filtering: query.bool with must and filter
|
|
1465
|
+
"""
|
|
1466
|
+
where_document = query.get("where_document")
|
|
1467
|
+
where = query.get("where")
|
|
1468
|
+
|
|
1469
|
+
# Case 1: Scalar query (metadata filtering only, no full-text search)
|
|
1470
|
+
if not where_document and where:
|
|
1471
|
+
filter_conditions = self._build_metadata_filter_for_search_parm(where)
|
|
1472
|
+
if filter_conditions:
|
|
1473
|
+
# If only one filter condition, check its type
|
|
1474
|
+
if len(filter_conditions) == 1:
|
|
1475
|
+
filter_cond = filter_conditions[0]
|
|
1476
|
+
# Check if it's a range query
|
|
1477
|
+
if "range" in filter_cond:
|
|
1478
|
+
return {"range": filter_cond["range"]}
|
|
1479
|
+
# Check if it's a term query
|
|
1480
|
+
elif "term" in filter_cond:
|
|
1481
|
+
return {"term": filter_cond["term"]}
|
|
1482
|
+
# Otherwise, it's a bool query, wrap in filter
|
|
1483
|
+
else:
|
|
1484
|
+
return {"bool": {"filter": filter_conditions}}
|
|
1485
|
+
# Multiple filter conditions, wrap in bool
|
|
1486
|
+
return {"bool": {"filter": filter_conditions}}
|
|
1487
|
+
|
|
1488
|
+
# Case 2: Full-text search (with or without metadata filtering)
|
|
1489
|
+
if where_document:
|
|
1490
|
+
# Build document query using query_string
|
|
1491
|
+
doc_query = self._build_document_query(where_document)
|
|
1492
|
+
if doc_query:
|
|
1493
|
+
# Build filter from where condition
|
|
1494
|
+
filter_conditions = self._build_metadata_filter_for_search_parm(where)
|
|
1495
|
+
|
|
1496
|
+
if filter_conditions:
|
|
1497
|
+
# Full-text search with metadata filtering
|
|
1498
|
+
return {
|
|
1499
|
+
"bool": {
|
|
1500
|
+
"must": [doc_query],
|
|
1501
|
+
"filter": filter_conditions
|
|
1502
|
+
}
|
|
1503
|
+
}
|
|
1504
|
+
else:
|
|
1505
|
+
# Full-text search only
|
|
1506
|
+
return doc_query
|
|
1507
|
+
|
|
1508
|
+
return None
|
|
1509
|
+
|
|
1510
|
+
def _build_document_query(self, where_document: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
|
1511
|
+
"""
|
|
1512
|
+
Build document query from where_document condition using query_string
|
|
1513
|
+
|
|
1514
|
+
Args:
|
|
1515
|
+
where_document: Document filter conditions
|
|
1516
|
+
|
|
1517
|
+
Returns:
|
|
1518
|
+
query_string query dict
|
|
1519
|
+
"""
|
|
1520
|
+
if not where_document:
|
|
1521
|
+
return None
|
|
1522
|
+
|
|
1523
|
+
# Handle $contains - use query_string
|
|
1524
|
+
if "$contains" in where_document:
|
|
1525
|
+
return {
|
|
1526
|
+
"query_string": {
|
|
1527
|
+
"fields": ["document"],
|
|
1528
|
+
"query": where_document["$contains"]
|
|
1529
|
+
}
|
|
1530
|
+
}
|
|
1531
|
+
|
|
1532
|
+
# Handle $and with $contains
|
|
1533
|
+
if "$and" in where_document:
|
|
1534
|
+
and_conditions = where_document["$and"]
|
|
1535
|
+
contains_queries = []
|
|
1536
|
+
for condition in and_conditions:
|
|
1537
|
+
if isinstance(condition, dict) and "$contains" in condition:
|
|
1538
|
+
contains_queries.append(condition["$contains"])
|
|
1539
|
+
|
|
1540
|
+
if contains_queries:
|
|
1541
|
+
# Combine multiple $contains with AND
|
|
1542
|
+
return {
|
|
1543
|
+
"query_string": {
|
|
1544
|
+
"fields": ["document"],
|
|
1545
|
+
"query": " ".join(contains_queries)
|
|
1546
|
+
}
|
|
1547
|
+
}
|
|
1548
|
+
|
|
1549
|
+
# Handle $or with $contains
|
|
1550
|
+
if "$or" in where_document:
|
|
1551
|
+
or_conditions = where_document["$or"]
|
|
1552
|
+
contains_queries = []
|
|
1553
|
+
for condition in or_conditions:
|
|
1554
|
+
if isinstance(condition, dict) and "$contains" in condition:
|
|
1555
|
+
contains_queries.append(condition["$contains"])
|
|
1556
|
+
|
|
1557
|
+
if contains_queries:
|
|
1558
|
+
# Combine multiple $contains with OR
|
|
1559
|
+
return {
|
|
1560
|
+
"query_string": {
|
|
1561
|
+
"fields": ["document"],
|
|
1562
|
+
"query": " OR ".join(contains_queries)
|
|
1563
|
+
}
|
|
1564
|
+
}
|
|
1565
|
+
|
|
1566
|
+
# Default: if it's a string, treat as $contains
|
|
1567
|
+
if isinstance(where_document, str):
|
|
1568
|
+
return {
|
|
1569
|
+
"query_string": {
|
|
1570
|
+
"fields": ["document"],
|
|
1571
|
+
"query": where_document
|
|
1572
|
+
}
|
|
1573
|
+
}
|
|
1574
|
+
|
|
1575
|
+
return None
|
|
1576
|
+
|
|
1577
|
+
def _build_metadata_filter_for_search_parm(self, where: Optional[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
1578
|
+
"""
|
|
1579
|
+
Build metadata filter conditions for search_parm using JSON_EXTRACT format
|
|
1580
|
+
|
|
1581
|
+
Args:
|
|
1582
|
+
where: Metadata filter conditions
|
|
1583
|
+
|
|
1584
|
+
Returns:
|
|
1585
|
+
List of filter conditions in search_parm format
|
|
1586
|
+
Format: {"term": {"(JSON_EXTRACT(metadata, '$.field_name'))": "value"}}
|
|
1587
|
+
or {"range": {"(JSON_EXTRACT(metadata, '$.field_name'))": {"gte": 30, "lte": 90}}}
|
|
1588
|
+
"""
|
|
1589
|
+
if not where:
|
|
1590
|
+
return []
|
|
1591
|
+
|
|
1592
|
+
return self._build_metadata_filter_conditions(where)
|
|
1593
|
+
|
|
1594
|
+
def _build_metadata_filter_conditions(self, condition: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
1595
|
+
"""
|
|
1596
|
+
Recursively build metadata filter conditions from nested dictionary
|
|
1597
|
+
|
|
1598
|
+
Args:
|
|
1599
|
+
condition: Filter condition dictionary
|
|
1600
|
+
|
|
1601
|
+
Returns:
|
|
1602
|
+
List of filter conditions
|
|
1603
|
+
"""
|
|
1604
|
+
if not condition:
|
|
1605
|
+
return []
|
|
1606
|
+
|
|
1607
|
+
result = []
|
|
1608
|
+
|
|
1609
|
+
# Handle logical operators
|
|
1610
|
+
if "$and" in condition:
|
|
1611
|
+
must_conditions = []
|
|
1612
|
+
for sub_condition in condition["$and"]:
|
|
1613
|
+
sub_filters = self._build_metadata_filter_conditions(sub_condition)
|
|
1614
|
+
must_conditions.extend(sub_filters)
|
|
1615
|
+
if must_conditions:
|
|
1616
|
+
result.append({"bool": {"must": must_conditions}})
|
|
1617
|
+
return result
|
|
1618
|
+
|
|
1619
|
+
if "$or" in condition:
|
|
1620
|
+
should_conditions = []
|
|
1621
|
+
for sub_condition in condition["$or"]:
|
|
1622
|
+
sub_filters = self._build_metadata_filter_conditions(sub_condition)
|
|
1623
|
+
should_conditions.extend(sub_filters)
|
|
1624
|
+
if should_conditions:
|
|
1625
|
+
result.append({"bool": {"should": should_conditions}})
|
|
1626
|
+
return result
|
|
1627
|
+
|
|
1628
|
+
if "$not" in condition:
|
|
1629
|
+
not_filters = self._build_metadata_filter_conditions(condition["$not"])
|
|
1630
|
+
if not_filters:
|
|
1631
|
+
result.append({"bool": {"must_not": not_filters}})
|
|
1632
|
+
return result
|
|
1633
|
+
|
|
1634
|
+
# Handle field conditions
|
|
1635
|
+
for key, value in condition.items():
|
|
1636
|
+
if key in ["$and", "$or", "$not"]:
|
|
1637
|
+
continue
|
|
1638
|
+
|
|
1639
|
+
# Build field name with JSON_EXTRACT format
|
|
1640
|
+
field_name = f"(JSON_EXTRACT(metadata, '$.{key}'))"
|
|
1641
|
+
|
|
1642
|
+
if isinstance(value, dict):
|
|
1643
|
+
# Handle comparison operators
|
|
1644
|
+
range_conditions = {}
|
|
1645
|
+
term_value = None
|
|
1646
|
+
|
|
1647
|
+
for op, op_value in value.items():
|
|
1648
|
+
if op == "$eq":
|
|
1649
|
+
term_value = op_value
|
|
1650
|
+
elif op == "$ne":
|
|
1651
|
+
# $ne should be in must_not
|
|
1652
|
+
result.append({"bool": {"must_not": [{"term": {field_name: op_value}}]}})
|
|
1653
|
+
elif op == "$lt":
|
|
1654
|
+
range_conditions["lt"] = op_value
|
|
1655
|
+
elif op == "$lte":
|
|
1656
|
+
range_conditions["lte"] = op_value
|
|
1657
|
+
elif op == "$gt":
|
|
1658
|
+
range_conditions["gt"] = op_value
|
|
1659
|
+
elif op == "$gte":
|
|
1660
|
+
range_conditions["gte"] = op_value
|
|
1661
|
+
elif op == "$in":
|
|
1662
|
+
# For $in, create multiple term queries wrapped in should
|
|
1663
|
+
in_conditions = [{"term": {field_name: val}} for val in op_value]
|
|
1664
|
+
if in_conditions:
|
|
1665
|
+
result.append({"bool": {"should": in_conditions}})
|
|
1666
|
+
elif op == "$nin":
|
|
1667
|
+
# For $nin, create multiple term queries wrapped in must_not
|
|
1668
|
+
nin_conditions = [{"term": {field_name: val}} for val in op_value]
|
|
1669
|
+
if nin_conditions:
|
|
1670
|
+
result.append({"bool": {"must_not": nin_conditions}})
|
|
1671
|
+
|
|
1672
|
+
if range_conditions:
|
|
1673
|
+
result.append({"range": {field_name: range_conditions}})
|
|
1674
|
+
elif term_value is not None:
|
|
1675
|
+
result.append({"term": {field_name: term_value}})
|
|
1676
|
+
else:
|
|
1677
|
+
# Direct equality
|
|
1678
|
+
result.append({"term": {field_name: value}})
|
|
1679
|
+
|
|
1680
|
+
return result
|
|
1681
|
+
|
|
1682
|
+
def _build_knn_expression(self, knn: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
|
1683
|
+
"""
|
|
1684
|
+
Build knn expression from knn dict
|
|
1685
|
+
|
|
1686
|
+
Args:
|
|
1687
|
+
knn: Vector search configuration dict
|
|
1688
|
+
|
|
1689
|
+
Returns:
|
|
1690
|
+
knn expression dict with optional filter
|
|
1691
|
+
"""
|
|
1692
|
+
query_texts = knn.get("query_texts")
|
|
1693
|
+
query_embeddings = knn.get("query_embeddings")
|
|
1694
|
+
where = knn.get("where")
|
|
1695
|
+
n_results = knn.get("n_results", 10)
|
|
1696
|
+
|
|
1697
|
+
# Get query vector
|
|
1698
|
+
query_vector = None
|
|
1699
|
+
if query_embeddings:
|
|
1700
|
+
# Use provided embeddings
|
|
1701
|
+
if isinstance(query_embeddings, list) and len(query_embeddings) > 0:
|
|
1702
|
+
if isinstance(query_embeddings[0], list):
|
|
1703
|
+
query_vector = query_embeddings[0] # Use first vector
|
|
1704
|
+
else:
|
|
1705
|
+
query_vector = query_embeddings
|
|
1706
|
+
elif query_texts:
|
|
1707
|
+
# Convert text to embedding
|
|
1708
|
+
try:
|
|
1709
|
+
texts = query_texts if isinstance(query_texts, list) else [query_texts]
|
|
1710
|
+
embeddings = self._embed_texts(texts[0] if len(texts) > 0 else texts)
|
|
1711
|
+
if embeddings and len(embeddings) > 0:
|
|
1712
|
+
query_vector = embeddings[0]
|
|
1713
|
+
except NotImplementedError:
|
|
1714
|
+
logger.warning("Text embedding not implemented. Please provide query_embeddings directly.")
|
|
1715
|
+
return None
|
|
1716
|
+
else:
|
|
1717
|
+
logger.warning("knn requires either query_texts or query_embeddings")
|
|
1718
|
+
return None
|
|
1719
|
+
|
|
1720
|
+
if not query_vector:
|
|
1721
|
+
return None
|
|
1722
|
+
|
|
1723
|
+
# Build knn expression
|
|
1724
|
+
knn_expr = {
|
|
1725
|
+
"field": "embedding",
|
|
1726
|
+
"k": n_results,
|
|
1727
|
+
"query_vector": query_vector
|
|
1728
|
+
}
|
|
1729
|
+
|
|
1730
|
+
# Add filter using JSON_EXTRACT format
|
|
1731
|
+
filter_conditions = self._build_metadata_filter_for_search_parm(where)
|
|
1732
|
+
if filter_conditions:
|
|
1733
|
+
knn_expr["filter"] = filter_conditions
|
|
1734
|
+
|
|
1735
|
+
return knn_expr
|
|
1736
|
+
|
|
1737
|
+
def _build_source_fields(self, include: Optional[List[str]]) -> List[str]:
|
|
1738
|
+
"""Build _source fields list from include parameter"""
|
|
1739
|
+
if not include:
|
|
1740
|
+
return ["document", "metadata", "embedding"]
|
|
1741
|
+
|
|
1742
|
+
source_fields = []
|
|
1743
|
+
field_mapping = {
|
|
1744
|
+
"documents": "document",
|
|
1745
|
+
"metadatas": "metadata",
|
|
1746
|
+
"embeddings": "embedding"
|
|
1747
|
+
}
|
|
1748
|
+
|
|
1749
|
+
for field in include:
|
|
1750
|
+
mapped = field_mapping.get(field.lower(), field)
|
|
1751
|
+
if mapped not in source_fields:
|
|
1752
|
+
source_fields.append(mapped)
|
|
1753
|
+
|
|
1754
|
+
return source_fields if source_fields else ["document", "metadata", "embedding"]
|
|
1755
|
+
|
|
1756
|
+
def _transform_sql_result(self, result_rows: List[Dict[str, Any]], include: Optional[List[str]]) -> Dict[str, Any]:
|
|
1757
|
+
"""
|
|
1758
|
+
Transform SQL query results to standard format
|
|
1759
|
+
|
|
1760
|
+
Args:
|
|
1761
|
+
result_rows: List of row dictionaries from SQL query
|
|
1762
|
+
include: Fields to include in results (optional)
|
|
1763
|
+
|
|
1764
|
+
Returns:
|
|
1765
|
+
Standard format dictionary with ids, distances, metadatas, documents, embeddings
|
|
1766
|
+
"""
|
|
1767
|
+
if not result_rows:
|
|
1768
|
+
return {
|
|
1769
|
+
"ids": [],
|
|
1770
|
+
"distances": [],
|
|
1771
|
+
"metadatas": [],
|
|
1772
|
+
"documents": [],
|
|
1773
|
+
"embeddings": []
|
|
1774
|
+
}
|
|
1775
|
+
|
|
1776
|
+
ids = []
|
|
1777
|
+
distances = []
|
|
1778
|
+
metadatas = []
|
|
1779
|
+
documents = []
|
|
1780
|
+
embeddings = []
|
|
1781
|
+
|
|
1782
|
+
for row in result_rows:
|
|
1783
|
+
# Extract id (may be in different column names)
|
|
1784
|
+
row_id = row.get("id") or row.get("_id") or row.get("ID")
|
|
1785
|
+
# Convert bytes _id to string format
|
|
1786
|
+
row_id = self._convert_id_from_bytes(row_id)
|
|
1787
|
+
ids.append(row_id)
|
|
1788
|
+
|
|
1789
|
+
# Extract distance/score (may be in different column names)
|
|
1790
|
+
distance = row.get("_distance") or row.get("distance") or row.get("_score") or row.get("score") or row.get("DISTANCE") or row.get("_DISTANCE") or row.get("SCORE") or 0.0
|
|
1791
|
+
distances.append(distance)
|
|
1792
|
+
|
|
1793
|
+
# Extract metadata
|
|
1794
|
+
if include is None or "metadatas" in include or "metadata" in include:
|
|
1795
|
+
metadata = row.get("metadata") or row.get("METADATA")
|
|
1796
|
+
# Parse JSON string if needed
|
|
1797
|
+
if isinstance(metadata, str):
|
|
1798
|
+
try:
|
|
1799
|
+
metadata = json.loads(metadata)
|
|
1800
|
+
except (json.JSONDecodeError, TypeError):
|
|
1801
|
+
pass
|
|
1802
|
+
metadatas.append(metadata)
|
|
1803
|
+
else:
|
|
1804
|
+
metadatas.append(None)
|
|
1805
|
+
|
|
1806
|
+
# Extract document
|
|
1807
|
+
if include is None or "documents" in include or "document" in include:
|
|
1808
|
+
document = row.get("document") or row.get("DOCUMENT")
|
|
1809
|
+
documents.append(document)
|
|
1810
|
+
else:
|
|
1811
|
+
documents.append(None)
|
|
1812
|
+
|
|
1813
|
+
# Extract embedding
|
|
1814
|
+
if include and ("embeddings" in include or "embedding" in include):
|
|
1815
|
+
embedding = row.get("embedding") or row.get("EMBEDDING")
|
|
1816
|
+
# Parse JSON string or list if needed
|
|
1817
|
+
if isinstance(embedding, str):
|
|
1818
|
+
try:
|
|
1819
|
+
embedding = json.loads(embedding)
|
|
1820
|
+
except (json.JSONDecodeError, TypeError):
|
|
1821
|
+
pass
|
|
1822
|
+
embeddings.append(embedding)
|
|
1823
|
+
else:
|
|
1824
|
+
embeddings.append(None)
|
|
1825
|
+
|
|
1826
|
+
return {
|
|
1827
|
+
"ids": ids,
|
|
1828
|
+
"distances": distances,
|
|
1829
|
+
"metadatas": metadatas,
|
|
1830
|
+
"documents": documents,
|
|
1831
|
+
"embeddings": embeddings
|
|
1832
|
+
}
|
|
1833
|
+
|
|
1834
|
+
def _transform_search_result(self, search_result: Dict[str, Any], include: Optional[List[str]]) -> Dict[str, Any]:
|
|
1835
|
+
"""Transform OceanBase search result to standard format"""
|
|
1836
|
+
# OceanBase SEARCH function returns results in a specific format
|
|
1837
|
+
# This needs to be adapted based on actual return format
|
|
1838
|
+
# For now, assuming it returns hits array
|
|
1839
|
+
|
|
1840
|
+
hits = search_result.get("hits", {}).get("hits", [])
|
|
1841
|
+
|
|
1842
|
+
ids = []
|
|
1843
|
+
distances = []
|
|
1844
|
+
metadatas = []
|
|
1845
|
+
documents = []
|
|
1846
|
+
embeddings = []
|
|
1847
|
+
|
|
1848
|
+
for hit in hits:
|
|
1849
|
+
source = hit.get("_source", {})
|
|
1850
|
+
score = hit.get("_score", 0.0)
|
|
1851
|
+
|
|
1852
|
+
ids.append(hit.get("_id"))
|
|
1853
|
+
distances.append(score)
|
|
1854
|
+
|
|
1855
|
+
if include is None or "metadatas" in include or "metadata" in include:
|
|
1856
|
+
metadatas.append(source.get("metadata"))
|
|
1857
|
+
else:
|
|
1858
|
+
metadatas.append(None)
|
|
1859
|
+
|
|
1860
|
+
if include is None or "documents" in include or "document" in include:
|
|
1861
|
+
documents.append(source.get("document"))
|
|
1862
|
+
else:
|
|
1863
|
+
documents.append(None)
|
|
1864
|
+
|
|
1865
|
+
if include and ("embeddings" in include or "embedding" in include):
|
|
1866
|
+
embeddings.append(source.get("embedding"))
|
|
1867
|
+
else:
|
|
1868
|
+
embeddings.append(None)
|
|
1869
|
+
|
|
1870
|
+
return {
|
|
1871
|
+
"ids": ids,
|
|
1872
|
+
"distances": distances,
|
|
1873
|
+
"metadatas": metadatas,
|
|
1874
|
+
"documents": documents,
|
|
1875
|
+
"embeddings": embeddings
|
|
1876
|
+
}
|
|
1877
|
+
|
|
1878
|
+
# -------------------- Collection Info --------------------
|
|
1879
|
+
|
|
1880
|
+
def _collection_count(
    self,
    collection_id: Optional[str],
    collection_name: str
) -> int:
    """
    [Internal] Get the number of items in a collection via ``SELECT COUNT(*)``.

    Common SQL-based implementation shared by the concrete clients.

    Args:
        collection_id: Collection ID. NOTE(review): currently unused by this
            implementation — the lookup is done purely by name; confirm
            whether it is kept for interface symmetry with subclasses.
        collection_name: Collection name (mapped to a table name via
            ``CollectionNames.table_name``).

    Returns:
        Item count (0 when the query returns no rows).
    """
    logger.info(f"Counting items in collection '{collection_name}'")
    conn = self._ensure_connection()

    # Convert collection name to table name.
    table_name = CollectionNames.table_name(collection_name)

    # Execute COUNT query. The table name is interpolated (not a bind
    # parameter — identifiers cannot be parameterized); presumably
    # CollectionNames.table_name yields a safe identifier — TODO confirm.
    sql = f"SELECT COUNT(*) as cnt FROM `{table_name}`"
    logger.debug(f"Executing SQL: {sql}")

    # Some drivers support cursors as context managers, some do not; the
    # helper decides which execution path to use.
    use_context_manager = self._use_context_manager_for_cursor()
    rows = self._execute_query_with_cursor(conn, sql, [], use_context_manager)

    if not rows:
        count = 0
    else:
        # Extract count from result. Row shape varies by driver:
        # dict rows, tuple/list rows, or a bare scalar.
        row = rows[0]
        if isinstance(row, dict):
            count = row.get('cnt', 0)
        elif isinstance(row, (tuple, list)):
            count = row[0] if len(row) > 0 else 0
        else:
            count = int(row) if row else 0

    logger.info(f"✅ Collection '{collection_name}' has {count} items")
    return count
|