powermem 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- powermem/__init__.py +103 -0
- powermem/agent/__init__.py +35 -0
- powermem/agent/abstract/__init__.py +22 -0
- powermem/agent/abstract/collaboration.py +259 -0
- powermem/agent/abstract/context.py +187 -0
- powermem/agent/abstract/manager.py +232 -0
- powermem/agent/abstract/permission.py +217 -0
- powermem/agent/abstract/privacy.py +267 -0
- powermem/agent/abstract/scope.py +199 -0
- powermem/agent/agent.py +791 -0
- powermem/agent/components/__init__.py +18 -0
- powermem/agent/components/collaboration_coordinator.py +645 -0
- powermem/agent/components/permission_controller.py +586 -0
- powermem/agent/components/privacy_protector.py +767 -0
- powermem/agent/components/scope_controller.py +685 -0
- powermem/agent/factories/__init__.py +16 -0
- powermem/agent/factories/agent_factory.py +266 -0
- powermem/agent/factories/config_factory.py +308 -0
- powermem/agent/factories/memory_factory.py +229 -0
- powermem/agent/implementations/__init__.py +16 -0
- powermem/agent/implementations/hybrid.py +728 -0
- powermem/agent/implementations/multi_agent.py +1040 -0
- powermem/agent/implementations/multi_user.py +1020 -0
- powermem/agent/types.py +53 -0
- powermem/agent/wrappers/__init__.py +14 -0
- powermem/agent/wrappers/agent_memory_wrapper.py +427 -0
- powermem/agent/wrappers/compatibility_wrapper.py +520 -0
- powermem/config_loader.py +318 -0
- powermem/configs.py +249 -0
- powermem/core/__init__.py +19 -0
- powermem/core/async_memory.py +1493 -0
- powermem/core/audit.py +258 -0
- powermem/core/base.py +165 -0
- powermem/core/memory.py +1567 -0
- powermem/core/setup.py +162 -0
- powermem/core/telemetry.py +215 -0
- powermem/integrations/__init__.py +17 -0
- powermem/integrations/embeddings/__init__.py +13 -0
- powermem/integrations/embeddings/aws_bedrock.py +100 -0
- powermem/integrations/embeddings/azure_openai.py +55 -0
- powermem/integrations/embeddings/base.py +31 -0
- powermem/integrations/embeddings/config/base.py +132 -0
- powermem/integrations/embeddings/configs.py +31 -0
- powermem/integrations/embeddings/factory.py +48 -0
- powermem/integrations/embeddings/gemini.py +39 -0
- powermem/integrations/embeddings/huggingface.py +41 -0
- powermem/integrations/embeddings/langchain.py +35 -0
- powermem/integrations/embeddings/lmstudio.py +29 -0
- powermem/integrations/embeddings/mock.py +11 -0
- powermem/integrations/embeddings/ollama.py +53 -0
- powermem/integrations/embeddings/openai.py +49 -0
- powermem/integrations/embeddings/qwen.py +102 -0
- powermem/integrations/embeddings/together.py +31 -0
- powermem/integrations/embeddings/vertexai.py +54 -0
- powermem/integrations/llm/__init__.py +18 -0
- powermem/integrations/llm/anthropic.py +87 -0
- powermem/integrations/llm/base.py +132 -0
- powermem/integrations/llm/config/anthropic.py +56 -0
- powermem/integrations/llm/config/azure.py +56 -0
- powermem/integrations/llm/config/base.py +62 -0
- powermem/integrations/llm/config/deepseek.py +56 -0
- powermem/integrations/llm/config/ollama.py +56 -0
- powermem/integrations/llm/config/openai.py +79 -0
- powermem/integrations/llm/config/qwen.py +68 -0
- powermem/integrations/llm/config/qwen_asr.py +46 -0
- powermem/integrations/llm/config/vllm.py +56 -0
- powermem/integrations/llm/configs.py +26 -0
- powermem/integrations/llm/deepseek.py +106 -0
- powermem/integrations/llm/factory.py +118 -0
- powermem/integrations/llm/gemini.py +201 -0
- powermem/integrations/llm/langchain.py +65 -0
- powermem/integrations/llm/ollama.py +106 -0
- powermem/integrations/llm/openai.py +166 -0
- powermem/integrations/llm/openai_structured.py +80 -0
- powermem/integrations/llm/qwen.py +207 -0
- powermem/integrations/llm/qwen_asr.py +171 -0
- powermem/integrations/llm/vllm.py +106 -0
- powermem/integrations/rerank/__init__.py +20 -0
- powermem/integrations/rerank/base.py +43 -0
- powermem/integrations/rerank/config/__init__.py +7 -0
- powermem/integrations/rerank/config/base.py +27 -0
- powermem/integrations/rerank/configs.py +23 -0
- powermem/integrations/rerank/factory.py +68 -0
- powermem/integrations/rerank/qwen.py +159 -0
- powermem/intelligence/__init__.py +17 -0
- powermem/intelligence/ebbinghaus_algorithm.py +354 -0
- powermem/intelligence/importance_evaluator.py +361 -0
- powermem/intelligence/intelligent_memory_manager.py +284 -0
- powermem/intelligence/manager.py +148 -0
- powermem/intelligence/plugin.py +229 -0
- powermem/prompts/__init__.py +29 -0
- powermem/prompts/graph/graph_prompts.py +217 -0
- powermem/prompts/graph/graph_tools_prompts.py +469 -0
- powermem/prompts/importance_evaluation.py +246 -0
- powermem/prompts/intelligent_memory_prompts.py +163 -0
- powermem/prompts/templates.py +193 -0
- powermem/storage/__init__.py +14 -0
- powermem/storage/adapter.py +896 -0
- powermem/storage/base.py +109 -0
- powermem/storage/config/base.py +13 -0
- powermem/storage/config/oceanbase.py +58 -0
- powermem/storage/config/pgvector.py +52 -0
- powermem/storage/config/sqlite.py +27 -0
- powermem/storage/configs.py +159 -0
- powermem/storage/factory.py +59 -0
- powermem/storage/migration_manager.py +438 -0
- powermem/storage/oceanbase/__init__.py +8 -0
- powermem/storage/oceanbase/constants.py +162 -0
- powermem/storage/oceanbase/oceanbase.py +1384 -0
- powermem/storage/oceanbase/oceanbase_graph.py +1441 -0
- powermem/storage/pgvector/__init__.py +7 -0
- powermem/storage/pgvector/pgvector.py +420 -0
- powermem/storage/sqlite/__init__.py +0 -0
- powermem/storage/sqlite/sqlite.py +218 -0
- powermem/storage/sqlite/sqlite_vector_store.py +311 -0
- powermem/utils/__init__.py +35 -0
- powermem/utils/utils.py +605 -0
- powermem/version.py +23 -0
- powermem-0.1.0.dist-info/METADATA +187 -0
- powermem-0.1.0.dist-info/RECORD +123 -0
- powermem-0.1.0.dist-info/WHEEL +5 -0
- powermem-0.1.0.dist-info/licenses/LICENSE +206 -0
- powermem-0.1.0.dist-info/top_level.txt +1 -0
powermem/storage/oceanbase/oceanbase.py
@@ -0,0 +1,1384 @@
"""
OceanBase storage implementation

This module provides OceanBase-based storage for memory data.
"""
import heapq
import json
import logging
import uuid
from typing import Any, Dict, List, Optional
from concurrent.futures import ThreadPoolExecutor
from powermem.storage.base import VectorStoreBase, OutputData
from powermem.utils.utils import serialize_datetime, generate_snowflake_id

try:
    from pyobvector import (
        VECTOR,
        ObVecClient,
        cosine_distance,
        inner_product,
        l2_distance,
        VecIndexType,
    )
    from pyobvector.schema import ReplaceStmt
    from sqlalchemy import JSON, Column, String, Table, func, ColumnElement, BigInteger
    from sqlalchemy import text, and_, or_, not_, select, bindparam, literal_column
    from sqlalchemy.dialects.mysql import LONGTEXT
except ImportError as e:
    raise ImportError(
        f"Required dependencies not found: {e}. Please install pyobvector and sqlalchemy."
    )

from powermem.storage.oceanbase import constants

logger = logging.getLogger(__name__)

class OceanBaseVectorStore(VectorStoreBase):
    """OceanBase vector store implementation"""

    def __init__(
        self,
        collection_name: str,
        connection_args: Optional[Dict[str, Any]] = None,
        vidx_metric_type: str = constants.DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE,
        vidx_algo_params: Optional[Dict] = None,
        index_type: str = constants.DEFAULT_INDEX_TYPE,
        embedding_model_dims: Optional[int] = None,
        primary_field: str = constants.DEFAULT_PRIMARY_FIELD,
        vector_field: str = constants.DEFAULT_VECTOR_FIELD,
        text_field: str = constants.DEFAULT_TEXT_FIELD,
        metadata_field: str = constants.DEFAULT_METADATA_FIELD,
        vidx_name: str = constants.DEFAULT_VIDX_NAME,
        normalize: bool = False,
        include_sparse: bool = False,
        auto_configure_vector_index: bool = True,
        # Connection parameters (for compatibility with config)
        host: Optional[str] = None,
        port: Optional[str] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
        db_name: Optional[str] = None,
        hybrid_search: bool = True,
        fulltext_parser: str = constants.DEFAULT_FULLTEXT_PARSER,
        vector_weight: float = 0.5,
        fts_weight: float = 0.5,
        reranker: Optional[Any] = None,
        **kwargs,
    ):
        """
        Initialize the OceanBase vector store.

        Args:
            collection_name (str): Name of the collection/table.
            connection_args (Optional[Dict[str, Any]]): Connection parameters for OceanBase.
            vidx_metric_type (str): Metric method of distance between vectors.
            vidx_algo_params (Optional[Dict]): Index parameters.
            index_type (str): Type of vector index to use.
            embedding_model_dims (Optional[int]): Dimension of vectors.
            primary_field (str): Name of the primary key column.
            vector_field (str): Name of the vector column.
            text_field (str): Name of the text column.
            metadata_field (str): Name of the metadata column.
            vidx_name (str): Name of the vector index.
            normalize (bool): Whether to perform L2 normalization on vectors.
            include_sparse (bool): Whether to include sparse vector support.
            auto_configure_vector_index (bool): Whether to automatically configure vector index settings.
            host (Optional[str]): OceanBase server host.
            port (Optional[str]): OceanBase server port.
            user (Optional[str]): OceanBase username.
            password (Optional[str]): OceanBase password.
            db_name (Optional[str]): OceanBase database name.
            hybrid_search (bool): Whether to use hybrid search.
            fulltext_parser (str): Parser used for the full-text index.
            vector_weight (float): Weight for vector search in hybrid search (default: 0.5).
            fts_weight (float): Weight for full-text search in hybrid search (default: 0.5).
            reranker (Optional[Any]): Optional reranker used to refine hybrid search results.
        """
        self.normalize = normalize
        self.include_sparse = include_sparse
        self.auto_configure_vector_index = auto_configure_vector_index
        self.hybrid_search = hybrid_search
        self.fulltext_parser = fulltext_parser
        self.vector_weight = vector_weight
        self.fts_weight = fts_weight
        self.reranker = reranker

        # Validate fulltext parser
        if self.fulltext_parser not in constants.OCEANBASE_SUPPORTED_FULLTEXT_PARSERS:
            supported = ', '.join(constants.OCEANBASE_SUPPORTED_FULLTEXT_PARSERS)
            raise ValueError(
                f"Invalid fulltext parser: {self.fulltext_parser}. "
                f"Supported parsers are: {supported}"
            )

        # Handle connection arguments - prioritize individual parameters over connection_args
        if connection_args is None:
            connection_args = {}

        # Merge individual connection parameters with connection_args
        final_connection_args = {
            "host": host or connection_args.get("host", constants.DEFAULT_OCEANBASE_CONNECTION["host"]),
            "port": port or connection_args.get("port", constants.DEFAULT_OCEANBASE_CONNECTION["port"]),
            "user": user or connection_args.get("user", constants.DEFAULT_OCEANBASE_CONNECTION["user"]),
            "password": password or connection_args.get("password", constants.DEFAULT_OCEANBASE_CONNECTION["password"]),
            "db_name": db_name or connection_args.get("db_name", constants.DEFAULT_OCEANBASE_CONNECTION["db_name"]),
        }

        self.connection_args = final_connection_args

        self.index_type = index_type.upper()
        if self.index_type not in constants.OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPES:
            raise ValueError(
                f"`index_type` should be one of "
                f"{list(constants.OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPES.keys())}. "
                f"Got {self.index_type}"
            )

        # Set default parameters based on index type
        if vidx_algo_params is None:
            index_param_map = constants.OCEANBASE_BUILD_PARAMS_MAPPING
            self.vidx_algo_params = index_param_map[self.index_type].copy()

            if self.index_type == "IVF_PQ" and "m" not in self.vidx_algo_params:
                self.vidx_algo_params["m"] = 3
        else:
            self.vidx_algo_params = vidx_algo_params.copy()

        # Set field names
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.primary_field = primary_field
        self.vector_field = vector_field
        self.text_field = text_field
        self.metadata_field = metadata_field
        self.vidx_name = vidx_name
        self.sparse_vector_field = "sparse_embedding"
        self.fulltext_field = "fulltext_content"

        # Set up vector index parameters
        self.vidx_metric_type = vidx_metric_type.lower()

        # Initialize client
        self._create_client(**kwargs)
        assert self.obvector is not None

        # Auto-configure vector index settings if enabled
        if self.auto_configure_vector_index:
            self._configure_vector_index_settings()

        self._create_col()
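
    # Illustrative usage sketch (comment only): the keyword arguments mirror the
    # __init__ signature above; the connection values are placeholders for
    # demonstration, not the package defaults.
    #
    #   store = OceanBaseVectorStore(
    #       collection_name="memories",
    #       embedding_model_dims=1536,
    #       host="127.0.0.1",
    #       port="2881",
    #       user="root@test",
    #       password="",
    #       db_name="test",
    #   )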

    def _create_client(self, **kwargs):
        """Create and initialize the OceanBase vector client."""
        host = self.connection_args.get("host")
        port = self.connection_args.get("port")
        user = self.connection_args.get("user")
        password = self.connection_args.get("password")
        db_name = self.connection_args.get("db_name")

        self.obvector = ObVecClient(
            uri=f"{host}:{port}",
            user=user,
            password=password,
            db_name=db_name,
            **kwargs,
        )

    def _configure_vector_index_settings(self):
        """Configure OceanBase vector index settings automatically."""
        try:
            logger.info("Configuring OceanBase vector index settings...")

            # Set vector memory limit percentage
            with self.obvector.engine.connect() as conn:
                conn.execute(text("ALTER SYSTEM SET ob_vector_memory_limit_percentage = 30"))
                conn.commit()
                logger.info("Set ob_vector_memory_limit_percentage = 30")

            logger.info("OceanBase vector index configuration completed")

        except Exception as e:
            logger.warning(f"Failed to configure vector index settings: {e}")
            logger.warning(" Vector index functionality may not work properly")

    def _create_table_with_index_by_embedding_model_dims(self) -> None:
        """Create table with vector index based on embedding dimension."""
        cols = [
            # Primary key - Snowflake ID (BIGINT without AUTO_INCREMENT)
            Column(self.primary_field, BigInteger, primary_key=True, autoincrement=False),
            # Vector field
            Column(self.vector_field, VECTOR(self.embedding_model_dims)),
            # Text content field
            Column(self.text_field, LONGTEXT),
            # Metadata field (JSON)
            Column(self.metadata_field, JSON),
            Column("user_id", String(128)),  # User identifier
            Column("agent_id", String(128)),  # Agent identifier
            Column("run_id", String(128)),  # Run identifier
            Column("actor_id", String(128)),  # Actor identifier
            Column("hash", String(32)),  # MD5 hash (32 chars)
            Column("created_at", String(128)),
            Column("updated_at", String(128)),
            Column("category", String(64)),  # Category name
            Column(self.fulltext_field, LONGTEXT)
        ]

        # Add hybrid search columns if enabled
        if self.include_sparse:
            cols.append(Column(self.sparse_vector_field, JSON))

        # Create vector index
        vidx_params = self.obvector.prepare_index_params()
        vidx_params.add_index(
            field_name=self.vector_field,
            index_type=constants.OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPES[self.index_type],
            index_name=self.vidx_name,
            metric_type=self.vidx_metric_type,
            params=self.vidx_algo_params,
        )

        # Add sparse vector index if enabled
        if self.include_sparse:
            logger.warning("Sparse vector indexing not fully implemented yet")

        # Create table with vector index first
        self.obvector.create_table_with_index_params(
            table_name=self.collection_name,
            columns=cols,
            indexes=None,
            vidxs=vidx_params,
            partitions=None,
        )

        logger.debug("DEBUG: Table '%s' created successfully", self.collection_name)

    def _normalize(self, vector: List[float]) -> List[float]:
        """Normalize vector using L2 normalization."""
        import numpy as np
        arr = np.array(vector)
        norm = np.linalg.norm(arr)
        if norm == 0:
            return vector
        arr = arr / norm
        return arr.tolist()

    def _get_distance_function(self, metric_type: str):
        """Get the appropriate distance function for the given metric type."""
        if metric_type == "inner_product":
            return inner_product
        elif metric_type == "l2":
            return l2_distance
        elif metric_type == "cosine":
            return cosine_distance
        else:
            raise ValueError(f"Unsupported metric type: {metric_type}")

    def _get_default_search_params(self) -> dict:
        """Get default search parameters based on index type."""
        search_param_map = constants.OCEANBASE_SEARCH_PARAMS_MAPPING
        return search_param_map.get(
            self.index_type, constants.DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM
        )

    def create_col(self, name: str, vector_size: Optional[int] = None, distance: str = "l2"):
        """Create a new collection."""
        try:
            if vector_size is None:
                raise ValueError("vector_size must be specified to create a collection.")
            distance = distance.lower()
            if distance not in ("l2", "inner_product", "cosine"):
                raise ValueError("distance must be one of 'l2', 'inner_product', or 'cosine'.")
            self.embedding_model_dims = vector_size
            self.vidx_metric_type = distance
            self.collection_name = name

            self._create_col()
            logger.info(f"Successfully created collection '{name}' with vector size {vector_size} and distance '{distance}'")

        except ValueError as e:
            logger.error(f"Invalid parameters for creating collection: {e}")
            raise
        except Exception as e:
            logger.error(f"Failed to create collection '{name}': {e}", exc_info=True)
            raise

    def _create_col(self):
        """Create a new collection."""

        if self.embedding_model_dims is None:
            raise ValueError(
                "embedding_model_dims is required for OceanBase vector operations. "
                "Please configure embedding_model_dims in your OceanBaseConfig."
            )

        # Set up vector index parameters
        if self.vidx_metric_type not in ("l2", "inner_product", "cosine"):
            raise ValueError(
                "`vidx_metric_type` should be set in `l2`/`inner_product`/`cosine`."
            )

        # Only create table if it doesn't exist (preserve existing data)
        if not self.obvector.check_table_exists(self.collection_name):
            self._create_table_with_index_by_embedding_model_dims()
            logger.info(f"Created new table {self.collection_name}")
        else:
            logger.info(f"Table {self.collection_name} already exists, preserving existing data")
            # Check if the existing table's vector dimension matches the requested dimension
            existing_dim = self._get_existing_vector_dimension()
            if existing_dim is not None and existing_dim != self.embedding_model_dims:
                raise ValueError(
                    f"Vector dimension mismatch: existing table '{self.collection_name}' has "
                    f"vector dimension {existing_dim}, but requested dimension is {self.embedding_model_dims}. "
                    f"Please use a different collection name or delete the existing table."
                )

        if self.hybrid_search:
            self._check_and_create_fulltext_index()
        self.table = Table(self.collection_name, self.obvector.metadata_obj, autoload_with=self.obvector.engine)

    def insert(self,
               vectors: List[List[float]],
               payloads: Optional[List[Dict]] = None,
               ids: Optional[List[str]] = None) -> List[int]:
        """
        Insert vectors into the collection.

        Args:
            vectors: List of vectors to insert
            payloads: Optional list of payload dictionaries
            ids: Deprecated parameter (ignored), IDs are now generated using Snowflake algorithm

        Returns:
            List[int]: List of generated Snowflake IDs
        """
        try:
            if not vectors:
                return []

            if payloads is None:
                payloads = [{} for _ in vectors]

            # Generate Snowflake IDs for each vector
            generated_ids = [generate_snowflake_id() for _ in range(len(vectors))]

            # Prepare data for insertion with explicit IDs
            data: List[Dict[str, Any]] = []
            for vector, payload, vector_id in zip(vectors, payloads, generated_ids):
                record = self._build_record_for_insert(vector, payload)
                # Explicitly set the primary key field with Snowflake ID
                record[self.primary_field] = vector_id
                data.append(record)

            # Use transaction to ensure atomicity of insert
            table = Table(self.collection_name, self.obvector.metadata_obj,
                          autoload_with=self.obvector.engine)

            with self.obvector.engine.connect() as conn:
                with conn.begin():
                    # Execute REPLACE INTO (upsert) statement
                    upsert_stmt = ReplaceStmt(table).values(data)
                    conn.execute(upsert_stmt)

            logger.debug(f"Successfully inserted {len(vectors)} vectors, generated Snowflake IDs: {generated_ids}")
            return generated_ids

        except Exception as e:
            logger.error(f"Failed to insert vectors into collection '{self.collection_name}': {e}", exc_info=True)
            raise
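
    # Illustrative call (comment only): a 4-dimensional vector is assumed for
    # brevity, and the payload keys shown are the ones consumed by
    # _build_record_for_insert further below.
    #
    #   ids = store.insert(
    #       vectors=[[0.1, 0.2, 0.3, 0.4]],
    #       payloads=[{
    #           "data": "User prefers tea over coffee",
    #           "user_id": "alice",
    #           "category": "preference",
    #           "metadata": {"source": "chat"},
    #       }],
    #   )  # -> one generated Snowflake ID per vector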

    def _parse_metadata(self, metadata_json):
        """
        Parse metadata from OceanBase.

        SQLAlchemy's JSON type automatically deserializes to dict, but this method
        handles backward compatibility with legacy string-serialized data.
        """
        if isinstance(metadata_json, dict):
            # SQLAlchemy JSON type returns dict directly (preferred path)
            return metadata_json
        elif isinstance(metadata_json, str):
            # Legacy compatibility: handle manually serialized strings
            try:
                # First attempt to parse
                metadata = json.loads(metadata_json)
                # Check if it's still a string (double encoded - legacy bug)
                if isinstance(metadata, str):
                    try:
                        # Second attempt to parse
                        metadata = json.loads(metadata)
                    except json.JSONDecodeError:
                        metadata = {}
                return metadata
            except json.JSONDecodeError:
                return {}
        else:
            return {}

    def _generate_where_clause(self, filters: Optional[Dict] = None) -> Optional[List]:
        """
        Generate a properly formatted where clause for OceanBase.

        Args:
            filters (Optional[Dict]): The filter conditions.
                Supports both simple and complex formats:

                Simple format (Open Source):
                - Simple values: {"field": "value"} -> field = 'value'
                - Comparison ops: {"field": {"gte": 10, "lte": 20}}
                - List values: {"field": ["a", "b", "c"]} -> field IN ('a', 'b', 'c')

                Complex format (Platform):
                - AND logic: {"AND": [{"user_id": "alice"}, {"category": "food"}]}
                - OR logic: {"OR": [{"rating": {"gte": 4.0}}, {"priority": "high"}]}
                - Nested: {"AND": [{"user_id": "alice"}, {"OR": [{"rating": {"gte": 4.0}}, {"priority": "high"}]}]}

        Returns:
            Optional[List]: List of SQLAlchemy ColumnElement objects for where clause.
        """

        def get_column(key) -> ColumnElement:
            """Get the appropriate column element for a field."""
            if key in self.table.c:
                return self.table.c[key]
            else:
                # Use ->> operator for unquoted JSON extract (MySQL/PostgreSQL)
                return self.table.c[self.metadata_field].op("->>")(f"$.{key}")

        def build_condition(key, value):
            """Build a single condition."""
            column = get_column(key)

            if isinstance(value, dict):
                # Handle comparison operators
                conditions = []
                for op, op_value in value.items():
                    op = op.lstrip("$")
                    match op:
                        case "eq":
                            conditions.append(column == op_value)
                        case "ne":
                            conditions.append(column != op_value)
                        case "gt":
                            conditions.append(column > op_value)
                        case "gte":
                            conditions.append(column >= op_value)
                        case "lt":
                            conditions.append(column < op_value)
                        case "lte":
                            conditions.append(column <= op_value)
                        case "in":
                            if not isinstance(op_value, list):
                                raise TypeError(f"Value for $in must be a list, got {type(op_value)}")
                            conditions.append(column.in_(op_value))
                        case "nin":
                            if not isinstance(op_value, list):
                                raise TypeError(f"Value for $nin must be a list, got {type(op_value)}")
                            conditions.append(~column.in_(op_value))
                        case "like":
                            conditions.append(column.like(str(op_value)))
                        case "ilike":
                            conditions.append(column.ilike(str(op_value)))
                        case _:
                            raise ValueError(f"Unsupported operator: {op}")
                return and_(*conditions) if conditions else None
            elif value is None:
                return column.is_(None)
            else:
                return column == value

        def process_condition(cond):
            """Process a single condition, handling nested AND/OR logic."""
            if isinstance(cond, dict):
                # Handle complex filters with AND/OR
                if "AND" in cond:
                    and_conditions = [process_condition(item) for item in cond["AND"]]
                    and_conditions = [c for c in and_conditions if c is not None]
                    return and_(*and_conditions) if and_conditions else None
                elif "OR" in cond:
                    or_conditions = [process_condition(item) for item in cond["OR"]]
                    or_conditions = [c for c in or_conditions if c is not None]
                    return or_(*or_conditions) if or_conditions else None
                else:
                    # Simple key-value filters
                    conditions = []
                    for k, v in cond.items():
                        expr = build_condition(k, v)
                        if expr is not None:
                            conditions.append(expr)
                    return and_(*conditions) if conditions else None
            elif isinstance(cond, list):
                subconditions = [process_condition(c) for c in cond]
                subconditions = [c for c in subconditions if c is not None]
                return and_(*subconditions) if subconditions else None
            else:
                return None

        # Handle complex filters with AND/OR
        result = process_condition(filters)
        return [result] if result is not None else None
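
    # Worked example for the filter grammar documented above. The method returns
    # SQLAlchemy expressions, so the SQL shown here is only an approximation:
    #
    #   {"AND": [{"user_id": "alice"},
    #            {"OR": [{"rating": {"gte": 4.0}}, {"priority": "high"}]}]}
    #
    # maps to roughly:
    #
    #   user_id = 'alice' AND (metadata->>'$.rating' >= 4.0 OR metadata->>'$.priority' = 'high')
    #
    # "user_id" is a real column; "rating" and "priority" are not, so they fall
    # back to JSON extraction from the metadata column via the ->> operator.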

    def _parse_row(self, row) -> tuple:
        """Parse a database result row. Returns up to 12 fields, padding with None if needed."""
        padded_row = list(row) + [None] * (12 - len(row))
        return tuple(padded_row[:12])

    def _build_standard_metadata(self, user_id: str, agent_id: str, run_id: str,
                                 actor_id: str, hash_val: str, created_at: str,
                                 updated_at: str, category: str, metadata_json: str) -> Dict:
        """Build standard metadata dictionary from row fields."""
        # Parse the JSON metadata first - this contains user-defined metadata
        user_metadata = self._parse_metadata(metadata_json)

        # Build complete payload with standard fields at top level and user metadata nested
        metadata = {
            "user_id": user_id,
            "agent_id": agent_id,
            "run_id": run_id,
            "actor_id": actor_id,
            "hash": hash_val,
            "created_at": created_at,
            "updated_at": updated_at,
            "category": category,
            # Store user metadata as nested structure to preserve it
            "metadata": user_metadata
        }

        return metadata

    def _create_output_data(self, vector_id: int, text_content: str, score: float,
                            metadata: Dict) -> OutputData:
        """Create an OutputData object with standard structure."""
        return OutputData(
            id=vector_id,
            score=score,
            payload={
                "data": text_content,
                **metadata
            }
        )

    def _build_record_for_insert(self, vector: List[float], payload: Dict) -> Dict[str, Any]:
        """
        Build a record dictionary for insertion with all standard fields.
        Note: Primary key (id) should be set explicitly before insertion.
        """
        # Serialize metadata to handle datetime objects
        metadata = payload.get("metadata", {})
        serialized_metadata = serialize_datetime(metadata) if metadata else {}

        record = {
            # Primary key (id) will be set explicitly in insert() method with Snowflake ID
            self.vector_field: (
                vector if not self.normalize else self._normalize(vector)
            ),
            self.text_field: payload.get("data", ""),
            self.metadata_field: serialized_metadata,
            "user_id": payload.get("user_id", ""),
            "agent_id": payload.get("agent_id", ""),
            "run_id": payload.get("run_id", ""),
            "actor_id": payload.get("actor_id", ""),
            "hash": payload.get("hash", ""),
            "created_at": serialize_datetime(payload.get("created_at", "")),
            "updated_at": serialize_datetime(payload.get("updated_at", "")),
            "category": payload.get("category", ""),
        }

        # Add hybrid search fields if enabled
        if self.include_sparse and "sparse_embedding" in payload:
            record[self.sparse_vector_field] = payload["sparse_embedding"]  # SQLAlchemy JSON type handles serialization automatically

        # Always add full-text content (enabled by default)
        fulltext_content = payload.get("fulltext_content") or payload.get("data", "")
        record[self.fulltext_field] = fulltext_content

        return record

    def search(self,
               query: str,
               vectors: List[List[float]],
               limit: int = 5,
               filters: Optional[Dict] = None) -> list[OutputData]:
        # Check if hybrid search is enabled, and we have query text
        # Full-text search is always enabled by default
        if self.hybrid_search and query:
            return self._hybrid_search(query, vectors, limit, filters)
        else:
            return self._vector_search(query, vectors, limit, filters)

    def _vector_search(self,
                       query: str,
                       vectors: List[List[float]],
                       limit: int = 5,
                       filters: Optional[Dict] = None) -> list[OutputData]:
        """Perform pure vector search."""
        try:
            # Handle both cases: single vector or list of vectors
            # If vectors is a single vector (list of floats), use it directly
            if isinstance(vectors, list) and len(vectors) > 0 and isinstance(vectors[0], (int, float)):
                query_vector = vectors
            # If vectors is a list of vectors, use the first one
            elif isinstance(vectors, list) and len(vectors) > 0 and isinstance(vectors[0], list):
                query_vector = vectors[0]
            else:
                logger.warning("Invalid vector format provided for search")
                return []

            # Build where clause from filters
            where_clause = self._generate_where_clause(filters)

            # Perform vector search - pyobvector expects a single vector, not a list of vectors
            results = self.obvector.ann_search(
                table_name=self.collection_name,
                vec_data=query_vector if not self.normalize else self._normalize(query_vector),
                vec_column_name=self.vector_field,
                distance_func=self._get_distance_function(self.vidx_metric_type),
                with_dist=True,
                topk=limit,
                output_column_names=[
                    self.text_field,
                    self.metadata_field,
                    self.primary_field,
                    "user_id",
                    "agent_id",
                    "run_id",
                    "actor_id",
                    "hash",
                    "created_at",
                    "updated_at",
                    "category",
                ],
                where_clause=where_clause,
            )

            # Convert results to OutputData objects
            search_results = []
            for row in results.fetchall():
                (text_content, metadata_json, vector_id, user_id, agent_id, run_id,
                 actor_id, hash_val, created_at, updated_at, category, distance) = self._parse_row(row)

                # Build standard metadata
                metadata = self._build_standard_metadata(
                    user_id, agent_id, run_id, actor_id, hash_val,
                    created_at, updated_at, category, metadata_json
                )

                # Convert distance to score based on metric type
                # Handle None distance (shouldn't happen but be defensive)
                if distance is None:
                    logger.warning(f"Distance is None for vector_id {vector_id}, using default score 0.0")
                    score = 0.0
                elif self.vidx_metric_type == "l2":
                    # For L2 distance, lower is better, so we can use 1/(1+distance) or just use distance
                    score = float(distance)
                elif self.vidx_metric_type == "cosine":
                    # For cosine distance, lower is better
                    score = float(distance)
                elif self.vidx_metric_type == "inner_product":
                    # For inner product, higher is better, so we negate the distance
                    score = -float(distance)
                else:
                    score = float(distance)

                search_results.append(self._create_output_data(vector_id, text_content, score, metadata))
            logger.debug(f"_vector_search results, len : {len(search_results)}")
            return search_results

        except Exception as e:
            logger.error(f"Vector search failed in collection '{self.collection_name}': {e}", exc_info=True)
            raise

    def _fulltext_search(self, query: str, limit: int = 5, filters: Optional[Dict] = None) -> list[OutputData]:
        """Perform full-text search using OceanBase FTS with parameterized queries including score."""
        # Skip search if query is empty
        if not query or not query.strip():
            logger.debug("Full-text search query is empty, returning empty results.")
            return []

        # Generate where clause from filters using the existing method
        filter_where_clause = self._generate_where_clause(filters)

        # Build the full-text search condition using SQLAlchemy text with parameters
        # Use the same parameter format that SQLAlchemy will use for other parameters
        fts_condition = text(f"MATCH({self.fulltext_field}) AGAINST(:query IN NATURAL LANGUAGE MODE)").bindparams(
            bindparam("query", query)
        )

        # Combine FTS condition with filter conditions
        where_conditions = [fts_condition]
        if filter_where_clause:
            where_conditions.extend(filter_where_clause)

        # Build custom query to include score field
        try:
            # Build select statement with specific columns AND score
            columns = [
                self.table.c[self.text_field],
                self.table.c[self.metadata_field],
                self.table.c[self.primary_field],
                self.table.c["user_id"],
                self.table.c["agent_id"],
                self.table.c["run_id"],
                self.table.c["actor_id"],
                self.table.c["hash"],
                self.table.c["created_at"],
                self.table.c["updated_at"],
                self.table.c["category"],
                # Add the score calculation as a column
                text(f"MATCH({self.fulltext_field}) AGAINST(:query IN NATURAL LANGUAGE MODE) as score").bindparams(
                    bindparam("query", query)
                )
            ]

            stmt = select(*columns)

            # Add where conditions
            for condition in where_conditions:
                stmt = stmt.where(condition)

            # Order by score DESC to get best matches first
            stmt = stmt.order_by(text('score DESC'))

            # Add limit
            if limit:
                stmt = stmt.limit(limit)

            # Execute the query with parameters - use direct parameter passing
            with self.obvector.engine.connect() as conn:
                with conn.begin():
                    logger.info(f"Executing FTS query with parameters: query={query}")
                    # Execute with parameter dictionary - the standard SQLAlchemy way
                    results = conn.execute(stmt)
                    rows = results.fetchall()

        except Exception as e:
            logger.warning(f"Full-text search failed, falling back to LIKE search: {e}")
            try:
                # Fallback to simple LIKE search with parameters
                like_query = f"%{query}%"
                like_condition = text(f"{self.fulltext_field} LIKE :like_query").bindparams(
                    bindparam("like_query", like_query)
                )

                fallback_conditions = [like_condition]
                if filter_where_clause:
                    fallback_conditions.extend(filter_where_clause)

                # Build fallback query with default score
                columns = [
                    self.table.c[self.text_field],
                    self.table.c[self.metadata_field],
                    self.table.c[self.primary_field],
                    self.table.c["user_id"],
                    self.table.c["agent_id"],
                    self.table.c["run_id"],
                    self.table.c["actor_id"],
                    self.table.c["hash"],
                    self.table.c["created_at"],
                    self.table.c["updated_at"],
                    self.table.c["category"],
                    # Default score for LIKE search
                    '1.0 as score'
                ]

                stmt = select(*columns)

                for condition in fallback_conditions:
                    stmt = stmt.where(condition)

                if limit:
                    stmt = stmt.limit(limit)

                # Execute fallback query with parameters
                with self.obvector.engine.connect() as conn:
                    with conn.begin():
                        logger.info(f"Executing LIKE fallback query with parameters: like_query={like_query}")
                        # Execute with parameter dictionary - the standard SQLAlchemy way
                        results = conn.execute(stmt)
                        rows = results.fetchall()
            except Exception as fallback_error:
                logger.error(f"Both full-text search and LIKE fallback failed: {fallback_error}")
                return []

        # Convert results to OutputData objects
        fts_results = []
        for row in rows:
            # Parse the row data including score as the last column
            (text_content, metadata_json, vector_id, user_id, agent_id, run_id, actor_id, hash_val,
             created_at, updated_at, category, fts_score) = self._parse_row(row)

            # Build standard metadata
            metadata = self._build_standard_metadata(
                user_id, agent_id, run_id, actor_id, hash_val,
                created_at, updated_at, category, metadata_json
            )

            # Use the actual FTS score from the query
            fts_results.append(self._create_output_data(vector_id, text_content, float(fts_score), metadata))

        logger.info(f"_fulltext_search results, len : {len(fts_results)}, fts_results : {fts_results}")
        return fts_results

    def _hybrid_search(self, query: str, vectors: List[List[float]], limit: int = 5, filters: Optional[Dict] = None,
                       fusion_method: str = "rrf", k: int = 60):
        """Perform hybrid search combining vector and full-text search with optional reranking."""
        # Determine candidate limit for reranking
        candidate_limit = limit * 3 if self.reranker else limit

        # Perform vector search and full-text search in parallel for better performance
        with ThreadPoolExecutor(max_workers=2) as executor:
            # Submit both searches concurrently
            vector_future = executor.submit(self._vector_search, query, vectors, candidate_limit, filters)
            fts_future = executor.submit(self._fulltext_search, query, candidate_limit, filters)
            # Wait for both to complete and get results
            vector_results = vector_future.result()
            fts_results = fts_future.result()

        # Step 1: Coarse ranking - Combine results using RRF or weighted fusion
        coarse_ranked_results = self._combine_search_results(
            vector_results, fts_results, candidate_limit, fusion_method, k
        )
        logger.debug(f"Coarse ranking completed, candidates: {len(coarse_ranked_results)}")

        # Step 2: Fine ranking - Use Rerank model for precision sorting (if enabled)
        if self.reranker and query and coarse_ranked_results:
            try:
                final_results = self._apply_rerank(query, coarse_ranked_results, limit)
                logger.debug(f"Rerank applied, final results: {len(final_results)}")
                return final_results
            except Exception as e:
                logger.warning(f"Rerank failed, falling back to coarse ranking: {e}")
                return coarse_ranked_results[:limit]
        else:
            # No reranker, return coarse ranking results
            return coarse_ranked_results[:limit]

    def _apply_rerank(self, query: str, candidates: List[OutputData], limit: int) -> List[OutputData]:
        """
        Apply Rerank model for precision sorting.

        Args:
            query: Search query text
            candidates: Candidate results from coarse ranking
            limit: Number of final results to return

        Returns:
            List of reranked OutputData objects
        """
        if not candidates:
            return []

        # Extract document texts from candidates
        documents = [result.payload.get('data', '') for result in candidates]

        # Call reranker to get reranked indices and scores
        reranked_indices = self.reranker.rerank(query, documents, top_n=limit)

        # Reconstruct results with rerank scores
        final_results = []
        for idx, rerank_score in reranked_indices:
            result = candidates[idx]
            # Preserve original scores in payload
            result.payload['_fusion_score'] = result.score
            # Update score to rerank score
            result.score = rerank_score
            result.payload['_rerank_score'] = rerank_score
            final_results.append(result)

        # Reorder results: high scores on both ends, low scores in the middle
        if len(final_results) > 1:
            reordered = [None] * len(final_results)
            left = 0
            right = len(final_results) - 1

            for i, result in enumerate(final_results):
                if i % 2 == 0:
                    # Even indices go to the left side
                    reordered[left] = result
                    left += 1
                else:
                    # Odd indices go to the right side
                    reordered[right] = result
                    right -= 1

            final_results = reordered

        logger.debug(f"Rerank completed: {len(final_results)} results")

        return final_results
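
    # Worked example of the reordering above (illustrative): if the reranker
    # returns [A, B, C, D] ordered best-first, even indices fill the front and
    # odd indices fill the back, yielding [A, C, D, B] - the two strongest hits
    # sit at both ends of the list and the weaker ones in the middle.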

    def _combine_search_results(self, vector_results: List[OutputData], fts_results: List[OutputData],
                                limit: int, fusion_method: str = "rrf", k: int = 60):
        """Combine and rerank vector and full-text search results using RRF or weighted fusion."""
        if fusion_method == "rrf":
            return self._rrf_fusion(vector_results, fts_results, limit, k)
        else:
            return self._weighted_fusion(vector_results, fts_results, limit)

    def _rrf_fusion(self, vector_results: List[OutputData], fts_results: List[OutputData],
                    limit: int, k: int = 60):
        """
        Reciprocal Rank Fusion (RRF) for combining search results.

        Uses weights configured at initialization (self.vector_weight and self.fts_weight)
        to control the contribution of vector search vs full-text search.
        """
        # Create mapping of document ID to result data
        all_docs = {}

        # Process vector search results (rank-based scoring with weight)
        for rank, result in enumerate(vector_results, 1):
            rrf_score = self.vector_weight * (1.0 / (k + rank))
            all_docs[result.id] = {
                'result': result,
                'vector_rank': rank,
                'fts_rank': None,
                'rrf_score': rrf_score
            }

        # Process FTS results (add or update RRF scores with weight)
        for rank, result in enumerate(fts_results, 1):
            fts_rrf_score = self.fts_weight * (1.0 / (k + rank))

            if result.id in all_docs:
                # Document found in both searches - combine RRF scores
                all_docs[result.id]['fts_rank'] = rank
                all_docs[result.id]['rrf_score'] += fts_rrf_score
            else:
                # Document only in FTS results
                all_docs[result.id] = {
                    'result': result,
                    'vector_rank': None,
                    'fts_rank': rank,
                    'rrf_score': fts_rrf_score
                }

        # Convert to final results and sort by RRF score
        heap = []
        for doc_id, doc_data in all_docs.items():
            # Use document ID as tiebreaker to avoid dict comparison when rrf_scores are equal
            if len(heap) < limit:
                heapq.heappush(heap, (doc_data['rrf_score'], doc_id, doc_data))
            elif doc_data['rrf_score'] > heap[0][0]:
                heapq.heapreplace(heap, (doc_data['rrf_score'], doc_id, doc_data))

        final_results = []
        for score, _, doc_data in sorted(heap, key=lambda x: x[0], reverse=True):
            result = doc_data['result']
            result.score = score
            # Add ranking information to metadata for debugging
            result.payload['_fusion_info'] = {
                'vector_rank': doc_data['vector_rank'],
                'fts_rank': doc_data['fts_rank'],
                'rrf_score': score,
                'fusion_method': 'rrf',
                'vector_weight': self.vector_weight,
                'fts_weight': self.fts_weight
            }
            final_results.append(result)

        return final_results
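
    # Worked example for the RRF formula above, using k=60 and the default
    # weights vector_weight=0.5 and fts_weight=0.5:
    #   - a document ranked 1st by vector search and 3rd by full-text search
    #     scores 0.5 * 1/(60+1) + 0.5 * 1/(60+3) ~= 0.0161
    #   - a document ranked 1st by full-text search only
    #     scores 0.5 * 1/(60+1)                  ~= 0.0082
    # so documents that appear in both result lists are favored.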

    def _weighted_fusion(self, vector_results: List[OutputData], fts_results: List[OutputData],
                         limit: int, vector_weight: float = 0.7, text_weight: float = 0.3):
        """Traditional weighted score fusion (fallback method)."""
        # Create a mapping of id to results for deduplication
        combined_results = {}

        # Normalize vector scores to 0-1 range
        if vector_results:
            vector_scores = [result.score for result in vector_results]
            min_vector_score = min(vector_scores)
            max_vector_score = max(vector_scores)
            vector_score_range = max_vector_score - min_vector_score

            for result in vector_results:
                if vector_score_range > 0:
                    # For distance metrics, lower is better, so we invert the normalized score
                    if self.vidx_metric_type in ["l2", "cosine"]:
                        normalized_score = 1.0 - (result.score - min_vector_score) / vector_score_range
                    else:  # inner_product
                        normalized_score = (result.score - min_vector_score) / vector_score_range
                else:
                    normalized_score = 1.0

                combined_results[result.id] = {
                    'result': result,
                    'vector_score': normalized_score,
                    'fts_score': 0.0
                }

        # Add FTS results (FTS scores are already normalized to 0-1)
        for result in fts_results:
            if result.id in combined_results:
                # Update existing result with FTS score
                combined_results[result.id]['fts_score'] = result.score
            else:
                # Add new FTS-only result
                combined_results[result.id] = {
                    'result': result,
                    'vector_score': 0.0,
                    'fts_score': result.score
                }

        # Calculate combined scores and create final results
        heap = []
        for doc_id, doc_data in combined_results.items():
            combined_score = (vector_weight * doc_data['vector_score'] +
                              text_weight * doc_data['fts_score'])

            if len(heap) < limit:
                heapq.heappush(heap, (combined_score, doc_id, doc_data))
            elif combined_score > heap[0][0]:
                heapq.heapreplace(heap, (combined_score, doc_id, doc_data))

        final_results = []
        for score, _, doc_data in sorted(heap, key=lambda x: x[0], reverse=True):
            result = doc_data['result']
            result.score = score
            # Add fusion info for debugging
            result.payload['_fusion_info'] = {
                'vector_score': doc_data['vector_score'],
                'fts_score': doc_data['fts_score'],
                'combined_score': score,
                'fusion_method': 'weighted'
            }
            final_results.append(result)

        # Return top results
        return final_results
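
    # Worked example for the weighted fusion above (illustrative), with the
    # default vector_weight=0.7 and text_weight=0.3: a document with a
    # normalized vector score of 0.8 and an FTS score of 0.5 gets
    # 0.7 * 0.8 + 0.3 * 0.5 = 0.71.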
|
|
1048
|
+
|
|
1049
|
+
def delete(self, vector_id: int):
|
|
1050
|
+
"""Delete a vector by ID."""
|
|
1051
|
+
try:
|
|
1052
|
+
self.obvector.delete(
|
|
1053
|
+
table_name=self.collection_name,
|
|
1054
|
+
ids=[vector_id],
|
|
1055
|
+
)
|
|
1056
|
+
logger.debug(f"Successfully deleted vector with ID: {vector_id} from collection '{self.collection_name}'")
|
|
1057
|
+
except Exception as e:
|
|
1058
|
+
logger.error(f"Failed to delete vector with ID {vector_id} from collection '{self.collection_name}': {e}", exc_info=True)
|
|
1059
|
+
raise
|
|
1060
|
+
|
|
1061
|
+
def update(self, vector_id: int, vector: Optional[List[float]] = None, payload: Optional[Dict] = None):
|
|
1062
|
+
"""Update a vector and its payload."""
|
|
1063
|
+
try:
|
|
1064
|
+
# Get existing record to preserve fields not being updated
|
|
1065
|
+
existing_result = self.obvector.get(
|
|
1066
|
+
table_name=self.collection_name,
|
|
1067
|
+
ids=[vector_id],
|
|
1068
|
+
output_column_name=[self.vector_field] # Get the existing vector
|
|
1069
|
+
)
|
|
1070
|
+
|
|
1071
|
+
existing_rows = existing_result.fetchall()
|
|
1072
|
+
if not existing_rows:
|
|
1073
|
+
logger.warning(f"Vector with ID {vector_id} not found in collection '{self.collection_name}'")
|
|
1074
|
+
return
|
|
1075
|
+
|
|
1076
|
+
# Prepare update data
|
|
1077
|
+
update_data: Dict[str, Any] = {
|
|
1078
|
+
self.primary_field: vector_id,
|
|
1079
|
+
}
|
|
1080
|
+
|
|
1081
|
+
if vector is not None:
|
|
1082
|
+
update_data[self.vector_field] = (
|
|
1083
|
+
vector if not self.normalize else self._normalize(vector)
|
|
1084
|
+
)
|
|
1085
|
+
else:
|
|
1086
|
+
# Preserve the existing vector to avoid it being cleared by upsert
|
|
1087
|
+
existing_vector = existing_rows[0][0] if existing_rows[0] else None
|
|
1088
|
+
if existing_vector is not None:
|
|
1089
|
+
update_data[self.vector_field] = existing_vector
|
|
1090
|
+
logger.debug(f"Preserving existing vector for ID {vector_id}")
|
|
1091
|
+
|
|
1092
|
+
if payload is not None:
|
|
1093
|
+
# Use the helper method to build fields, then merge with update_data
|
|
1094
|
+
temp_record = self._build_record_for_insert(vector or [], payload)
|
|
1095
|
+
|
|
1096
|
+
# Copy relevant fields from temp_record (excluding primary key and vector if not updating)
|
|
1097
|
+
for key, value in temp_record.items():
|
|
1098
|
+
if key != self.primary_field and (vector is not None or key != self.vector_field):
|
|
1099
|
+
update_data[key] = value
|
|
1100
|
+
|
|
1101
|
+
# Update record
|
|
1102
|
+
self.obvector.upsert(
|
|
1103
|
+
table_name=self.collection_name,
|
|
1104
|
+
data=[update_data],
|
|
1105
|
+
)
|
|
1106
|
+
logger.debug(f"Successfully updated vector with ID: {vector_id} in collection '{self.collection_name}'")
|
|
1107
|
+
|
|
1108
|
+
except Exception as e:
|
|
1109
|
+
logger.error(f"Failed to update vector with ID {vector_id} in collection '{self.collection_name}': {e}", exc_info=True)
|
|
1110
|
+
raise
|
|
1111
|
+
|
|
1112
|
1112 +     def get(self, vector_id: int):
1113 +         """Retrieve a vector by ID."""
1114 +         try:
1115 +             results = self.obvector.get(
1116 +                 table_name=self.collection_name,
1117 +                 ids=[vector_id],
1118 +                 output_column_name=[
1119 +                     self.vector_field,
1120 +                     self.text_field,
1121 +                     self.metadata_field,
1122 +                     "user_id",
1123 +                     "agent_id",
1124 +                     "run_id",
1125 +                     "actor_id",
1126 +                     "hash",
1127 +                     "created_at",
1128 +                     "updated_at",
1129 +                     "category",
1130 +                 ],
1131 +             )
1132 +
1133 +             rows = results.fetchall()
1134 +             if not rows:
1135 +                 logger.debug(f"Vector with ID {vector_id} not found in collection '{self.collection_name}'")
1136 +                 return None
1137 +
1138 +             (vector, text_content, metadata_json, user_id, agent_id,
1139 +              run_id, actor_id, hash_val, created_at, updated_at, category, _) = self._parse_row(rows[0])
1140 +
1141 +             # Build standard metadata
1142 +             metadata = self._build_standard_metadata(
1143 +                 user_id, agent_id, run_id, actor_id, hash_val,
1144 +                 created_at, updated_at, category, metadata_json
1145 +             )
1146 +
1147 +             logger.debug(f"Successfully retrieved vector with ID: {vector_id} from collection '{self.collection_name}'")
1148 +             return self._create_output_data(vector_id, text_content, 0.0, metadata)
1149 +
1150 +         except Exception as e:
1151 +             logger.error(f"Failed to get vector with ID {vector_id} from collection '{self.collection_name}': {e}", exc_info=True)
1152 +             raise
1153 +
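Taken together, `delete()`, `update()`, and `get()` above form the per-record CRUD surface of this store: each call delegates to `self.obvector`, logs the outcome, and re-raises on failure. A minimal usage sketch follows; it assumes `store` is an already-constructed instance of this OceanBase-backed vector store (construction and configuration sit outside this hunk) and that a row with ID 42 exists. Both names are illustrative assumptions, not values from the package.

```python
# Hypothetical round-trip against an existing store instance.
record = store.get(42)          # output record, or None if the ID is missing
if record is not None:
    # Updating only the payload preserves the stored vector (see update()).
    store.update(42, payload={"category": "notes"})
    store.delete(42)            # removes the row by primary key
```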
1154 +     def list_cols(self):
1155 +         """List all collections."""
1156 +         try:
1157 +             # Get all tables from the database using the correct SQLAlchemy API
1158 +             with self.obvector.engine.connect() as conn:
1159 +                 result = conn.execute(text("SHOW TABLES"))
1160 +                 tables = [row[0] for row in result.fetchall()]
1161 +             logger.debug(f"Successfully listed {len(tables)} collections")
1162 +             return tables
1163 +         except Exception as e:
1164 +             logger.error(f"Failed to list collections: {e}", exc_info=True)
1165 +             raise
1166 +
1167 +     def delete_col(self):
1168 +         """Delete the collection."""
1169 +         try:
1170 +             if self.obvector.check_table_exists(self.collection_name):
1171 +                 self.obvector.drop_table_if_exist(self.collection_name)
1172 +                 logger.info(f"Successfully deleted collection '{self.collection_name}'")
1173 +             else:
1174 +                 logger.warning(f"Collection '{self.collection_name}' does not exist, skipping deletion")
1175 +         except Exception as e:
1176 +             logger.error(f"Failed to delete collection '{self.collection_name}': {e}", exc_info=True)
1177 +             raise
1178 +
1179 +     def _get_existing_vector_dimension(self) -> Optional[int]:
1180 +         """Get the dimension of the existing vector field in the table."""
1181 +         if not self.obvector.check_table_exists(self.collection_name):
1182 +             return None
1183 +
1184 +         try:
1185 +             # Get table schema information using the correct SQLAlchemy API
1186 +             with self.obvector.engine.connect() as conn:
1187 +                 result = conn.execute(text(f"DESCRIBE {self.collection_name}"))
1188 +                 columns = result.fetchall()
1189 +
1190 +             # Find the vector field and extract its dimension
1191 +             for col in columns:
1192 +                 if col[0] == self.vector_field:
1193 +                     # Parse vector type like "VECTOR(1536)" to extract dimension
1194 +                     col_type = col[1]
1195 +                     if col_type.startswith("VECTOR(") and col_type.endswith(")"):
1196 +                         dim_str = col_type[7:-1]  # Extract dimension from "VECTOR(1536)"
1197 +                         return int(dim_str)
1198 +             return None
1199 +         except Exception as e:
1200 +             logger.warning(f"Failed to get vector dimension for table {self.collection_name}: {e}")
1201 +             return None
1202 +
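The dimension probe above relies only on string parsing of the `DESCRIBE` output. A standalone sketch of the same logic; the column type strings are illustrative examples, not real table output:

```python
# Mirrors the parsing in _get_existing_vector_dimension():
# "VECTOR(1536)" -> 1536, anything else -> None.
def parse_vector_dimension(col_type: str):
    if col_type.startswith("VECTOR(") and col_type.endswith(")"):
        return int(col_type[7:-1])
    return None

print(parse_vector_dimension("VECTOR(1536)"))  # 1536
print(parse_vector_dimension("varchar(255)"))  # None
```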
1203 +     def col_info(self):
1204 +         """Get information about the collection."""
1205 +         try:
1206 +             if not self.obvector.check_table_exists(self.collection_name):
1207 +                 logger.debug(f"Collection '{self.collection_name}' does not exist")
1208 +                 return None
1209 +
1210 +             # Get table schema information using the correct SQLAlchemy API
1211 +             with self.obvector.engine.connect() as conn:
1212 +                 result = conn.execute(text(f"DESCRIBE {self.collection_name}"))
1213 +                 columns = result.fetchall()
1214 +
1215 +             logger.debug(f"Successfully retrieved info for collection '{self.collection_name}'")
1216 +             return {
1217 +                 "name": self.collection_name,
1218 +                 "columns": [{"name": col[0], "type": col[1]} for col in columns],
1219 +                 "index_type": self.index_type,
1220 +                 "metric_type": self.vidx_metric_type,
1221 +             }
1222 +
1223 +         except Exception as e:
1224 +             logger.error(f"Failed to get collection info for '{self.collection_name}': {e}", exc_info=True)
1225 +             raise
1226 +
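`col_info()` returns `None` when the table is missing; otherwise the dictionary combines the `DESCRIBE` output with the configured index and metric types. A hedged inspection sketch, again assuming an existing `store` instance:

```python
# Hypothetical inspection of the collection schema; `store` is assumed.
info = store.col_info()
if info is None:
    print("collection does not exist")
else:
    print(info["name"], info["index_type"], info["metric_type"])
    for col in info["columns"]:
        print(f'{col["name"]}: {col["type"]}')
```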
1227 +     def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None):
1228 +         """List all memories."""
1229 +         try:
1230 +             # Build where clause from filters
1231 +             where_clause = self._generate_where_clause(filters)
1232 +
1233 +             # Get all records
1234 +             results = self.obvector.get(
1235 +                 table_name=self.collection_name,
1236 +                 ids=None,
1237 +                 output_column_name=[
1238 +                     self.primary_field,
1239 +                     self.vector_field,
1240 +                     self.text_field,
1241 +                     self.metadata_field,
1242 +                     "user_id",
1243 +                     "agent_id",
1244 +                     "run_id",
1245 +                     "actor_id",
1246 +                     "hash",
1247 +                     "created_at",
1248 +                     "updated_at",
1249 +                     "category",
1250 +                 ],
1251 +                 where_clause=where_clause
1252 +             )
1253 +
1254 +             memories = []
1255 +             for row in results.fetchall():
1256 +                 (vector_id, vector, text_content, metadata_json, user_id, agent_id, run_id,
1257 +                  actor_id, hash_val, created_at, updated_at, category) = self._parse_row(row)
1258 +
1259 +                 # Build standard metadata
1260 +                 metadata = self._build_standard_metadata(
1261 +                     user_id, agent_id, run_id, actor_id, hash_val,
1262 +                     created_at, updated_at, category, metadata_json
1263 +                 )
1264 +
1265 +                 memories.append(self._create_output_data(vector_id, text_content, 0.0, metadata))
1266 +
1267 +             if limit:
1268 +                 memories = memories[:limit]
1269 +
1270 +             logger.debug(f"Successfully listed {len(memories)} memories from collection '{self.collection_name}'")
1271 +             return [memories]
1272 +
1273 +         except Exception as e:
1274 +             logger.error(f"Failed to list memories from collection '{self.collection_name}': {e}", exc_info=True)
1275 +             raise
1276 +
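Note that `list()` applies `limit` only after fetching every matching row and wraps its result in an outer list. A hedged usage sketch, assuming an existing `store` instance and a populated `user_id` column as in the column list above; the filter value is a placeholder:

```python
# Hypothetical listing; filter keys map to table columns via
# _generate_where_clause(), which is defined elsewhere in this file.
[memories] = store.list(filters={"user_id": "alice"}, limit=10)
for mem in memories:
    print(mem)  # items come from _create_output_data(), not shown in this hunk
```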
1277 +     def reset(self):
1278 +         """Reset by deleting the collection and recreating it."""
1279 +         try:
1280 +             logger.info(f"Resetting collection '{self.collection_name}'")
1281 +             self.delete_col()
1282 +             if self.embedding_model_dims is not None:
1283 +                 self._create_table_with_index_by_embedding_model_dims()
1284 +
1285 +             if self.hybrid_search:
1286 +                 self._check_and_create_fulltext_index()
1287 +
1288 +             logger.info(f"Successfully reset collection '{self.collection_name}'")
1289 +
1290 +         except Exception as e:
1291 +             logger.error(f"Failed to reset collection '{self.collection_name}': {e}", exc_info=True)
1292 +             raise
1293 +
1294 +     def _check_and_create_fulltext_index(self):
1295 +         # Check whether the full-text index exists; if not, create it
1296 +         if not self._check_fulltext_index_exists():
1297 +             self._create_fulltext_index()
1298 +
1299 +     def _check_fulltext_index_exists(self) -> bool:
1300 +         """
1301 +         Check if the full-text index of the specified table exists.
1302 +         """
1303 +         try:
1304 +             with self.obvector.engine.connect() as conn:
1305 +                 result = conn.execute(text(f"SHOW INDEX FROM {self.collection_name}"))
1306 +                 indexes = result.fetchall()
1307 +
1308 +             for index in indexes:
1309 +                 # index[2] is the index name, index[4] is the column name, and index[10] is the index type
1310 +                 if len(index) > 10 and index[10] == 'FULLTEXT':
1311 +                     if self.fulltext_field in str(index[4]):
1312 +                         return True
1313 +
1314 +             return False
1315 +
1316 +         except Exception as e:
1317 +             logger.error(f"An error occurred while checking the full-text index: {e}")
1318 +             return False
1319 +
1320 +     def _create_fulltext_index(self):
1321 +         try:
1322 +             logger.debug(
1323 +                 "About to create fulltext index for collection '%s' using parser '%s'",
1324 +                 self.collection_name,
1325 +                 self.fulltext_parser,
1326 +             )
1327 +
1328 +             # Create fulltext index with the specified parser using SQL
1329 +             with self.obvector.engine.connect() as conn:
1330 +                 sql_command = text(f"""ALTER TABLE {self.collection_name}
1331 +                     ADD FULLTEXT INDEX fulltext_index_for_col_text ({self.fulltext_field}) WITH PARSER {self.fulltext_parser}""")
1332 +
1333 +                 logger.debug("DEBUG: Executing SQL: %s", sql_command)
1334 +                 conn.execute(sql_command)
1335 +                 logger.debug("DEBUG: Fulltext index created successfully for '%s'", self.collection_name)
1336 +
1337 +         except Exception as e:
1338 +             logger.exception("Exception occurred while creating fulltext index")
1339 +             raise Exception(
1340 +                 "Failed to add fulltext index to the target table, your OceanBase version must be "
1341 +                 "4.3.5.1 or above to support fulltext index and vector index in the same table"
1342 +             ) from e
1343 +
1344 +         # Refresh metadata
1345 +         self.obvector.refresh_metadata([self.collection_name])
1346 +
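The statement built in `_create_fulltext_index()` is plain DDL. The sketch below only reconstructs the string for illustration; the collection name, fulltext field, and parser shown here are placeholders, while the real values come from the instance configuration:

```python
# Illustrative reconstruction of the ALTER TABLE statement built above;
# the names used here are placeholders, not values taken from the package.
collection_name = "memories"
fulltext_field = "text"
fulltext_parser = "ngram"
ddl = (
    f"ALTER TABLE {collection_name} "
    f"ADD FULLTEXT INDEX fulltext_index_for_col_text ({fulltext_field}) "
    f"WITH PARSER {fulltext_parser}"
)
print(ddl)
```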
1347 +     def execute_sql(self, sql: str, params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
1348 +         """
1349 +         Execute a raw SQL statement and return results.
1350 +
1351 +         This method is used by SubStoreMigrationManager to manage the migration status table.
1352 +
1353 +         Args:
1354 +             sql: SQL statement to execute
1355 +             params: Optional parameters for the SQL statement
1356 +
1357 +         Returns:
1358 +             List of result rows as dictionaries
1359 +         """
1360 +         try:
1361 +             with self.obvector.engine.connect() as conn:
1362 +                 if params:
1363 +                     result = conn.execute(text(sql), params)
1364 +                 else:
1365 +                     result = conn.execute(text(sql))
1366 +
1367 +                 # Commit for DDL/DML statements
1368 +                 conn.commit()
1369 +
1370 +                 # Try to fetch results (for SELECT queries)
1371 +                 try:
1372 +                     rows = result.fetchall()
1373 +                     # Convert rows to dictionaries
1374 +                     if rows and result.keys():
1375 +                         return [dict(zip(result.keys(), row)) for row in rows]
1376 +                     return []
1377 +                 except Exception:
1378 +                     # No results to fetch (for INSERT/UPDATE/DELETE/CREATE)
1379 +                     return []
1380 +
1381 +         except Exception as e:
1382 +             logger.error(f"Failed to execute SQL: {e}")
1383 +             logger.debug(f"SQL statement: {sql}")
1384 +             raise
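Because `execute_sql()` commits unconditionally and swallows the fetch failure for non-SELECT statements, DDL/DML calls simply return an empty list while SELECTs come back as dictionaries keyed by column name. A hedged usage sketch, assuming an existing `store` instance; the table and column names are placeholders standing in for the migration-status table mentioned in the docstring:

```python
# Hypothetical parameterized query; :status is bound through SQLAlchemy's text().
rows = store.execute_sql(
    "SELECT id, status FROM migration_status WHERE status = :status",
    params={"status": "done"},
)
for row in rows:
    print(row["id"], row["status"])  # each row is a dict keyed by column name

# DDL/DML returns an empty list:
store.execute_sql(
    "DELETE FROM migration_status WHERE status = :status",
    params={"status": "stale"},
)
```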