powermem-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. powermem/__init__.py +103 -0
  2. powermem/agent/__init__.py +35 -0
  3. powermem/agent/abstract/__init__.py +22 -0
  4. powermem/agent/abstract/collaboration.py +259 -0
  5. powermem/agent/abstract/context.py +187 -0
  6. powermem/agent/abstract/manager.py +232 -0
  7. powermem/agent/abstract/permission.py +217 -0
  8. powermem/agent/abstract/privacy.py +267 -0
  9. powermem/agent/abstract/scope.py +199 -0
  10. powermem/agent/agent.py +791 -0
  11. powermem/agent/components/__init__.py +18 -0
  12. powermem/agent/components/collaboration_coordinator.py +645 -0
  13. powermem/agent/components/permission_controller.py +586 -0
  14. powermem/agent/components/privacy_protector.py +767 -0
  15. powermem/agent/components/scope_controller.py +685 -0
  16. powermem/agent/factories/__init__.py +16 -0
  17. powermem/agent/factories/agent_factory.py +266 -0
  18. powermem/agent/factories/config_factory.py +308 -0
  19. powermem/agent/factories/memory_factory.py +229 -0
  20. powermem/agent/implementations/__init__.py +16 -0
  21. powermem/agent/implementations/hybrid.py +728 -0
  22. powermem/agent/implementations/multi_agent.py +1040 -0
  23. powermem/agent/implementations/multi_user.py +1020 -0
  24. powermem/agent/types.py +53 -0
  25. powermem/agent/wrappers/__init__.py +14 -0
  26. powermem/agent/wrappers/agent_memory_wrapper.py +427 -0
  27. powermem/agent/wrappers/compatibility_wrapper.py +520 -0
  28. powermem/config_loader.py +318 -0
  29. powermem/configs.py +249 -0
  30. powermem/core/__init__.py +19 -0
  31. powermem/core/async_memory.py +1493 -0
  32. powermem/core/audit.py +258 -0
  33. powermem/core/base.py +165 -0
  34. powermem/core/memory.py +1567 -0
  35. powermem/core/setup.py +162 -0
  36. powermem/core/telemetry.py +215 -0
  37. powermem/integrations/__init__.py +17 -0
  38. powermem/integrations/embeddings/__init__.py +13 -0
  39. powermem/integrations/embeddings/aws_bedrock.py +100 -0
  40. powermem/integrations/embeddings/azure_openai.py +55 -0
  41. powermem/integrations/embeddings/base.py +31 -0
  42. powermem/integrations/embeddings/config/base.py +132 -0
  43. powermem/integrations/embeddings/configs.py +31 -0
  44. powermem/integrations/embeddings/factory.py +48 -0
  45. powermem/integrations/embeddings/gemini.py +39 -0
  46. powermem/integrations/embeddings/huggingface.py +41 -0
  47. powermem/integrations/embeddings/langchain.py +35 -0
  48. powermem/integrations/embeddings/lmstudio.py +29 -0
  49. powermem/integrations/embeddings/mock.py +11 -0
  50. powermem/integrations/embeddings/ollama.py +53 -0
  51. powermem/integrations/embeddings/openai.py +49 -0
  52. powermem/integrations/embeddings/qwen.py +102 -0
  53. powermem/integrations/embeddings/together.py +31 -0
  54. powermem/integrations/embeddings/vertexai.py +54 -0
  55. powermem/integrations/llm/__init__.py +18 -0
  56. powermem/integrations/llm/anthropic.py +87 -0
  57. powermem/integrations/llm/base.py +132 -0
  58. powermem/integrations/llm/config/anthropic.py +56 -0
  59. powermem/integrations/llm/config/azure.py +56 -0
  60. powermem/integrations/llm/config/base.py +62 -0
  61. powermem/integrations/llm/config/deepseek.py +56 -0
  62. powermem/integrations/llm/config/ollama.py +56 -0
  63. powermem/integrations/llm/config/openai.py +79 -0
  64. powermem/integrations/llm/config/qwen.py +68 -0
  65. powermem/integrations/llm/config/qwen_asr.py +46 -0
  66. powermem/integrations/llm/config/vllm.py +56 -0
  67. powermem/integrations/llm/configs.py +26 -0
  68. powermem/integrations/llm/deepseek.py +106 -0
  69. powermem/integrations/llm/factory.py +118 -0
  70. powermem/integrations/llm/gemini.py +201 -0
  71. powermem/integrations/llm/langchain.py +65 -0
  72. powermem/integrations/llm/ollama.py +106 -0
  73. powermem/integrations/llm/openai.py +166 -0
  74. powermem/integrations/llm/openai_structured.py +80 -0
  75. powermem/integrations/llm/qwen.py +207 -0
  76. powermem/integrations/llm/qwen_asr.py +171 -0
  77. powermem/integrations/llm/vllm.py +106 -0
  78. powermem/integrations/rerank/__init__.py +20 -0
  79. powermem/integrations/rerank/base.py +43 -0
  80. powermem/integrations/rerank/config/__init__.py +7 -0
  81. powermem/integrations/rerank/config/base.py +27 -0
  82. powermem/integrations/rerank/configs.py +23 -0
  83. powermem/integrations/rerank/factory.py +68 -0
  84. powermem/integrations/rerank/qwen.py +159 -0
  85. powermem/intelligence/__init__.py +17 -0
  86. powermem/intelligence/ebbinghaus_algorithm.py +354 -0
  87. powermem/intelligence/importance_evaluator.py +361 -0
  88. powermem/intelligence/intelligent_memory_manager.py +284 -0
  89. powermem/intelligence/manager.py +148 -0
  90. powermem/intelligence/plugin.py +229 -0
  91. powermem/prompts/__init__.py +29 -0
  92. powermem/prompts/graph/graph_prompts.py +217 -0
  93. powermem/prompts/graph/graph_tools_prompts.py +469 -0
  94. powermem/prompts/importance_evaluation.py +246 -0
  95. powermem/prompts/intelligent_memory_prompts.py +163 -0
  96. powermem/prompts/templates.py +193 -0
  97. powermem/storage/__init__.py +14 -0
  98. powermem/storage/adapter.py +896 -0
  99. powermem/storage/base.py +109 -0
  100. powermem/storage/config/base.py +13 -0
  101. powermem/storage/config/oceanbase.py +58 -0
  102. powermem/storage/config/pgvector.py +52 -0
  103. powermem/storage/config/sqlite.py +27 -0
  104. powermem/storage/configs.py +159 -0
  105. powermem/storage/factory.py +59 -0
  106. powermem/storage/migration_manager.py +438 -0
  107. powermem/storage/oceanbase/__init__.py +8 -0
  108. powermem/storage/oceanbase/constants.py +162 -0
  109. powermem/storage/oceanbase/oceanbase.py +1384 -0
  110. powermem/storage/oceanbase/oceanbase_graph.py +1441 -0
  111. powermem/storage/pgvector/__init__.py +7 -0
  112. powermem/storage/pgvector/pgvector.py +420 -0
  113. powermem/storage/sqlite/__init__.py +0 -0
  114. powermem/storage/sqlite/sqlite.py +218 -0
  115. powermem/storage/sqlite/sqlite_vector_store.py +311 -0
  116. powermem/utils/__init__.py +35 -0
  117. powermem/utils/utils.py +605 -0
  118. powermem/version.py +23 -0
  119. powermem-0.1.0.dist-info/METADATA +187 -0
  120. powermem-0.1.0.dist-info/RECORD +123 -0
  121. powermem-0.1.0.dist-info/WHEEL +5 -0
  122. powermem-0.1.0.dist-info/licenses/LICENSE +206 -0
  123. powermem-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1384 @@
1
+ """
2
+ OceanBase storage implementation
3
+
4
+ This module provides OceanBase-based storage for memory data.
5
+ """
6
+ import heapq
7
+ import json
8
+ import logging
9
+ import uuid
10
+ from typing import Any, Dict, List, Optional
11
+ from concurrent.futures import ThreadPoolExecutor
12
+ from powermem.storage.base import VectorStoreBase, OutputData
13
+ from powermem.utils.utils import serialize_datetime, generate_snowflake_id
14
+
15
+ try:
16
+ from pyobvector import (
17
+ VECTOR,
18
+ ObVecClient,
19
+ cosine_distance,
20
+ inner_product,
21
+ l2_distance,
22
+ VecIndexType,
23
+ )
24
+ from pyobvector.schema import ReplaceStmt
25
+ from sqlalchemy import JSON, Column, String, Table, func, ColumnElement, BigInteger
26
+ from sqlalchemy import text, and_, or_, not_, select, bindparam, literal_column
27
+ from sqlalchemy.dialects.mysql import LONGTEXT
28
+ except ImportError as e:
29
+ raise ImportError(
30
+ f"Required dependencies not found: {e}. Please install pyobvector and sqlalchemy."
31
+ )
32
+
33
+ from powermem.storage.oceanbase import constants
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+ class OceanBaseVectorStore(VectorStoreBase):
38
+ """OceanBase vector store implementation"""
39
+
40
+ def __init__(
41
+ self,
42
+ collection_name: str,
43
+ connection_args: Optional[Dict[str, Any]] = None,
44
+ vidx_metric_type: str = constants.DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE,
45
+ vidx_algo_params: Optional[Dict] = None,
46
+ index_type: str = constants.DEFAULT_INDEX_TYPE,
47
+ embedding_model_dims: Optional[int] = None,
48
+ primary_field: str = constants.DEFAULT_PRIMARY_FIELD,
49
+ vector_field: str = constants.DEFAULT_VECTOR_FIELD,
50
+ text_field: str = constants.DEFAULT_TEXT_FIELD,
51
+ metadata_field: str = constants.DEFAULT_METADATA_FIELD,
52
+ vidx_name: str = constants.DEFAULT_VIDX_NAME,
53
+ normalize: bool = False,
54
+ include_sparse: bool = False,
55
+ auto_configure_vector_index: bool = True,
56
+ # Connection parameters (for compatibility with config)
57
+ host: Optional[str] = None,
58
+ port: Optional[str] = None,
59
+ user: Optional[str] = None,
60
+ password: Optional[str] = None,
61
+ db_name: Optional[str] = None,
62
+ hybrid_search: bool = True,
63
+ fulltext_parser: str = constants.DEFAULT_FULLTEXT_PARSER,
64
+ vector_weight: float = 0.5,
65
+ fts_weight: float = 0.5,
66
+ reranker: Optional[Any] = None,
67
+ **kwargs,
68
+ ):
69
+ """
70
+ Initialize the OceanBase vector store.
71
+
72
+ Args:
73
+ collection_name (str): Name of the collection/table.
74
+ connection_args (Optional[Dict[str, Any]]): Connection parameters for OceanBase.
75
+ vidx_metric_type (str): Metric method of distance between vectors.
76
+ vidx_algo_params (Optional[Dict]): Index parameters.
77
+ index_type (str): Type of vector index to use.
78
+ embedding_model_dims (Optional[int]): Dimension of vectors.
79
+ primary_field (str): Name of the primary key column.
80
+ vector_field (str): Name of the vector column.
81
+ text_field (str): Name of the text column.
82
+ metadata_field (str): Name of the metadata column.
83
+ vidx_name (str): Name of the vector index.
84
+ normalize (bool): Whether to perform L2 normalization on vectors.
85
+ include_sparse (bool): Whether to include sparse vector support.
86
+ auto_configure_vector_index (bool): Whether to automatically configure vector index settings.
87
+ host (Optional[str]): OceanBase server host.
88
+ port (Optional[str]): OceanBase server port.
89
+ user (Optional[str]): OceanBase username.
90
+ password (Optional[str]): OceanBase password.
91
+ db_name (Optional[str]): OceanBase database name.
92
+ hybrid_search (bool): Whether to use hybrid search.
93
+ vector_weight (float): Weight for vector search in hybrid search (default: 0.5).
94
+ fts_weight (float): Weight for full-text search in hybrid search (default: 0.5).
95
+ """
96
+ self.normalize = normalize
97
+ self.include_sparse = include_sparse
98
+ self.auto_configure_vector_index = auto_configure_vector_index
99
+ self.hybrid_search = hybrid_search
100
+ self.fulltext_parser = fulltext_parser
101
+ self.vector_weight = vector_weight
102
+ self.fts_weight = fts_weight
103
+ self.reranker = reranker
104
+
105
+ # Validate fulltext parser
106
+ if self.fulltext_parser not in constants.OCEANBASE_SUPPORTED_FULLTEXT_PARSERS:
107
+ supported = ', '.join(constants.OCEANBASE_SUPPORTED_FULLTEXT_PARSERS)
108
+ raise ValueError(
109
+ f"Invalid fulltext parser: {self.fulltext_parser}. "
110
+ f"Supported parsers are: {supported}"
111
+ )
112
+
113
+ # Handle connection arguments - prioritize individual parameters over connection_args
114
+ if connection_args is None:
115
+ connection_args = {}
116
+
117
+ # Merge individual connection parameters with connection_args
118
+ final_connection_args = {
119
+ "host": host or connection_args.get("host", constants.DEFAULT_OCEANBASE_CONNECTION["host"]),
120
+ "port": port or connection_args.get("port", constants.DEFAULT_OCEANBASE_CONNECTION["port"]),
121
+ "user": user or connection_args.get("user", constants.DEFAULT_OCEANBASE_CONNECTION["user"]),
122
+ "password": password or connection_args.get("password", constants.DEFAULT_OCEANBASE_CONNECTION["password"]),
123
+ "db_name": db_name or connection_args.get("db_name", constants.DEFAULT_OCEANBASE_CONNECTION["db_name"]),
124
+ }
125
+
126
+ self.connection_args = final_connection_args
127
+
128
+ self.index_type = index_type.upper()
129
+ if self.index_type not in constants.OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPES:
130
+ raise ValueError(
131
+ f"`index_type` should be one of "
132
+ f"{list(constants.OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPES.keys())}. "
133
+ f"Got {self.index_type}"
134
+ )
135
+
136
+ # Set default parameters based on index type
137
+ if vidx_algo_params is None:
138
+ index_param_map = constants.OCEANBASE_BUILD_PARAMS_MAPPING
139
+ self.vidx_algo_params = index_param_map[self.index_type].copy()
140
+
141
+ if self.index_type == "IVF_PQ" and "m" not in self.vidx_algo_params:
142
+ self.vidx_algo_params["m"] = 3
143
+ else:
144
+ self.vidx_algo_params = vidx_algo_params.copy()
145
+
146
+ # Set field names
147
+ self.collection_name = collection_name
148
+ self.embedding_model_dims = embedding_model_dims
149
+ self.primary_field = primary_field
150
+ self.vector_field = vector_field
151
+ self.text_field = text_field
152
+ self.metadata_field = metadata_field
153
+ self.vidx_name = vidx_name
154
+ self.sparse_vector_field = "sparse_embedding"
155
+ self.fulltext_field = "fulltext_content"
156
+
157
+ # Set up vector index parameters
158
+ self.vidx_metric_type = vidx_metric_type.lower()
159
+
160
+ # Initialize client
161
+ self._create_client(**kwargs)
162
+ assert self.obvector is not None
163
+
164
+ # Autoconfigure vector index settings if enabled
165
+ if self.auto_configure_vector_index:
166
+ self._configure_vector_index_settings()
167
+
168
+ self._create_col()
169
+
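For orientation, a minimal construction sketch based on the signature above; the connection values and the 1536-dim setting are placeholders (assumptions, not powermem defaults), and the import path is inferred from the file layout listed earlier:

    # Illustrative only: host/port/user/password/db_name below are placeholder values.
    from powermem.storage.oceanbase.oceanbase import OceanBaseVectorStore

    store = OceanBaseVectorStore(
        collection_name="memories",
        host="127.0.0.1",
        port="2881",
        user="root@test",
        password="",
        db_name="test",
        embedding_model_dims=1536,   # must match the embedder's output dimension
        vidx_metric_type="cosine",
        hybrid_search=True,
        vector_weight=0.5,
        fts_weight=0.5,
    )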
170
+ def _create_client(self, **kwargs):
171
+ """Create and initialize the OceanBase vector client."""
172
+ host = self.connection_args.get("host")
173
+ port = self.connection_args.get("port")
174
+ user = self.connection_args.get("user")
175
+ password = self.connection_args.get("password")
176
+ db_name = self.connection_args.get("db_name")
177
+
178
+ self.obvector = ObVecClient(
179
+ uri=f"{host}:{port}",
180
+ user=user,
181
+ password=password,
182
+ db_name=db_name,
183
+ **kwargs,
184
+ )
185
+
186
+ def _configure_vector_index_settings(self):
187
+ """Configure OceanBase vector index settings automatically."""
188
+ try:
189
+ logger.info("Configuring OceanBase vector index settings...")
190
+
191
+ # Set vector memory limit percentage
192
+ with self.obvector.engine.connect() as conn:
193
+ conn.execute(text("ALTER SYSTEM SET ob_vector_memory_limit_percentage = 30"))
194
+ conn.commit()
195
+ logger.info("Set ob_vector_memory_limit_percentage = 30")
196
+
197
+ logger.info("OceanBase vector index configuration completed")
198
+
199
+ except Exception as e:
200
+ logger.warning(f"Failed to configure vector index settings: {e}")
201
+ logger.warning(" Vector index functionality may not work properly")
202
+
203
+ def _create_table_with_index_by_embedding_model_dims(self) -> None:
204
+ """Create table with vector index based on embedding dimension."""
205
+ cols = [
206
+ # Primary key - Snowflake ID (BIGINT without AUTO_INCREMENT)
207
+ Column(self.primary_field, BigInteger, primary_key=True, autoincrement=False),
208
+ # Vector field
209
+ Column(self.vector_field, VECTOR(self.embedding_model_dims)),
210
+ # Text content field
211
+ Column(self.text_field, LONGTEXT),
212
+ # Metadata field (JSON)
213
+ Column(self.metadata_field, JSON),
214
+ Column("user_id", String(128)), # User identifier
215
+ Column("agent_id", String(128)), # Agent identifier
216
+ Column("run_id", String(128)), # Run identifier
217
+ Column("actor_id", String(128)), # Actor identifier
218
+ Column("hash", String(32)), # MD5 hash (32 chars)
219
+ Column("created_at", String(128)),
220
+ Column("updated_at", String(128)),
221
+ Column("category", String(64)), # Category name
222
+ Column(self.fulltext_field, LONGTEXT)
223
+ ]
224
+
225
+ # Add hybrid search columns if enabled
226
+ if self.include_sparse:
227
+ cols.append(Column(self.sparse_vector_field, JSON))
228
+
229
+ # Create vector index
230
+ vidx_params = self.obvector.prepare_index_params()
231
+ vidx_params.add_index(
232
+ field_name=self.vector_field,
233
+ index_type=constants.OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPES[self.index_type],
234
+ index_name=self.vidx_name,
235
+ metric_type=self.vidx_metric_type,
236
+ params=self.vidx_algo_params,
237
+ )
238
+
239
+ # Add sparse vector index if enabled
240
+ if self.include_sparse:
241
+ logger.warning("Sparse vector indexing not fully implemented yet")
242
+
243
+ # Create table with vector index first
244
+ self.obvector.create_table_with_index_params(
245
+ table_name=self.collection_name,
246
+ columns=cols,
247
+ indexes=None,
248
+ vidxs=vidx_params,
249
+ partitions=None,
250
+ )
251
+
252
+ logger.debug("DEBUG: Table '%s' created successfully", self.collection_name)
253
+
254
+ def _normalize(self, vector: List[float]) -> List[float]:
255
+ """Normalize vector using L2 normalization."""
256
+ import numpy as np
257
+ arr = np.array(vector)
258
+ norm = np.linalg.norm(arr)
259
+ if norm == 0:
260
+ return vector
261
+ arr = arr / norm
262
+ return arr.tolist()
263
+
264
+ def _get_distance_function(self, metric_type: str):
265
+ """Get the appropriate distance function for the given metric type."""
266
+ if metric_type == "inner_product":
267
+ return inner_product
268
+ elif metric_type == "l2":
269
+ return l2_distance
270
+ elif metric_type == "cosine":
271
+ return cosine_distance
272
+ else:
273
+ raise ValueError(f"Unsupported metric type: {metric_type}")
274
+
275
+ def _get_default_search_params(self) -> dict:
276
+ """Get default search parameters based on index type."""
277
+ search_param_map = constants.OCEANBASE_SEARCH_PARAMS_MAPPING
278
+ return search_param_map.get(
279
+ self.index_type, constants.DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM
280
+ )
281
+
282
+ def create_col(self, name: str, vector_size: Optional[int] = None, distance: str = "l2"):
283
+ """Create a new collection."""
284
+ try:
285
+ if vector_size is None:
286
+ raise ValueError("vector_size must be specified to create a collection.")
287
+ distance = distance.lower()
288
+ if distance not in ("l2", "inner_product", "cosine"):
289
+ raise ValueError("distance must be one of 'l2', 'inner_product', or 'cosine'.")
290
+ self.embedding_model_dims = vector_size
291
+ self.vidx_metric_type = distance
292
+ self.collection_name = name
293
+
294
+ self._create_col()
295
+ logger.info(f"Successfully created collection '{name}' with vector size {vector_size} and distance '{distance}'")
296
+
297
+ except ValueError as e:
298
+ logger.error(f"Invalid parameters for creating collection: {e}")
299
+ raise
300
+ except Exception as e:
301
+ logger.error(f"Failed to create collection '{name}': {e}", exc_info=True)
302
+ raise
303
+
304
+ def _create_col(self):
305
+ """Create a new collection."""
306
+
307
+ if self.embedding_model_dims is None:
308
+ raise ValueError(
309
+ "embedding_model_dims is required for OceanBase vector operations. "
310
+ "Please configure embedding_model_dims in your OceanBaseConfig."
311
+ )
312
+
313
+ # Set up vector index parameters
314
+ if self.vidx_metric_type not in ("l2", "inner_product", "cosine"):
315
+ raise ValueError(
316
+ "`vidx_metric_type` should be one of `l2`, `inner_product`, or `cosine`."
317
+ )
318
+
319
+ # Only create table if it doesn't exist (preserve existing data)
320
+ if not self.obvector.check_table_exists(self.collection_name):
321
+ self._create_table_with_index_by_embedding_model_dims()
322
+ logger.info(f"Created new table {self.collection_name}")
323
+ else:
324
+ logger.info(f"Table {self.collection_name} already exists, preserving existing data")
325
+ # Check if the existing table's vector dimension matches the requested dimension
326
+ existing_dim = self._get_existing_vector_dimension()
327
+ if existing_dim is not None and existing_dim != self.embedding_model_dims:
328
+ raise ValueError(
329
+ f"Vector dimension mismatch: existing table '{self.collection_name}' has "
330
+ f"vector dimension {existing_dim}, but requested dimension is {self.embedding_model_dims}. "
331
+ f"Please use a different collection name or delete the existing table."
332
+ )
333
+
334
+ if self.hybrid_search:
335
+ self._check_and_create_fulltext_index()
336
+ self.table = Table(self.collection_name, self.obvector.metadata_obj, autoload_with=self.obvector.engine)
337
+
338
+ def insert(self,
339
+ vectors: List[List[float]],
340
+ payloads: Optional[List[Dict]] = None,
341
+ ids: Optional[List[str]] = None) -> List[int]:
342
+ """
343
+ Insert vectors into the collection.
344
+
345
+ Args:
346
+ vectors: List of vectors to insert
347
+ payloads: Optional list of payload dictionaries
348
+ ids: Deprecated parameter (ignored); IDs are now generated using the Snowflake algorithm
349
+
350
+ Returns:
351
+ List[int]: List of generated Snowflake IDs
352
+ """
353
+ try:
354
+ if not vectors:
355
+ return []
356
+
357
+ if payloads is None:
358
+ payloads = [{} for _ in vectors]
359
+
360
+ # Generate Snowflake IDs for each vector
361
+ generated_ids = [generate_snowflake_id() for _ in range(len(vectors))]
362
+
363
+ # Prepare data for insertion with explicit IDs
364
+ data: List[Dict[str, Any]] = []
365
+ for vector, payload, vector_id in zip(vectors, payloads, generated_ids):
366
+ record = self._build_record_for_insert(vector, payload)
367
+ # Explicitly set the primary key field with Snowflake ID
368
+ record[self.primary_field] = vector_id
369
+ data.append(record)
370
+
371
+ # Use transaction to ensure atomicity of insert
372
+ table = Table(self.collection_name, self.obvector.metadata_obj,
373
+ autoload_with=self.obvector.engine)
374
+
375
+ with self.obvector.engine.connect() as conn:
376
+ with conn.begin():
377
+ # Execute REPLACE INTO (upsert) statement
378
+ upsert_stmt = ReplaceStmt(table).values(data)
379
+ conn.execute(upsert_stmt)
380
+
381
+ logger.debug(f"Successfully inserted {len(vectors)} vectors, generated Snowflake IDs: {generated_ids}")
382
+ return generated_ids
383
+
384
+ except Exception as e:
385
+ logger.error(f"Failed to insert vectors into collection '{self.collection_name}': {e}", exc_info=True)
386
+ raise
387
+
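A minimal sketch of calling insert(), assuming the store constructed earlier and dummy vectors of the configured dimension; the payload keys mirror the columns created by _create_table_with_index_by_embedding_model_dims:

    # Sketch only: embeddings are dummies; payload keys map onto the table columns above.
    vectors = [[0.1] * 1536, [0.2] * 1536]
    payloads = [
        {"data": "User prefers tea over coffee", "user_id": "alice", "category": "preference"},
        {"data": "Project deadline is Friday", "user_id": "alice", "category": "task"},
    ]
    ids = store.insert(vectors, payloads)  # returns the generated Snowflake IDs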
388
+ def _parse_metadata(self, metadata_json):
389
+ """
390
+ Parse metadata from OceanBase.
391
+
392
+ SQLAlchemy's JSON type automatically deserializes to dict, but this method
393
+ handles backward compatibility with legacy string-serialized data.
394
+ """
395
+ if isinstance(metadata_json, dict):
396
+ # SQLAlchemy JSON type returns dict directly (preferred path)
397
+ return metadata_json
398
+ elif isinstance(metadata_json, str):
399
+ # Legacy compatibility: handle manually serialized strings
400
+ try:
401
+ # First attempt to parse
402
+ metadata = json.loads(metadata_json)
403
+ # Check if it's still a string (double encoded - legacy bug)
404
+ if isinstance(metadata, str):
405
+ try:
406
+ # Second attempt to parse
407
+ metadata = json.loads(metadata)
408
+ except json.JSONDecodeError:
409
+ metadata = {}
410
+ return metadata
411
+ except json.JSONDecodeError:
412
+ return {}
413
+ else:
414
+ return {}
415
+
416
+ def _generate_where_clause(self, filters: Optional[Dict] = None) -> Optional[List]:
417
+ """
418
+ Generate a properly formatted where clause for OceanBase.
419
+
420
+ Args:
421
+ filters (Optional[Dict]): The filter conditions.
422
+ Supports both simple and complex formats:
423
+
424
+ Simple format (Open Source):
425
+ - Simple values: {"field": "value"} -> field = 'value'
426
+ - Comparison ops: {"field": {"gte": 10, "lte": 20}}
427
+ - List values: {"field": ["a", "b", "c"]} -> field IN ('a', 'b', 'c')
428
+
429
+ Complex format (Platform):
430
+ - AND logic: {"AND": [{"user_id": "alice"}, {"category": "food"}]}
431
+ - OR logic: {"OR": [{"rating": {"gte": 4.0}}, {"priority": "high"}]}
432
+ - Nested: {"AND": [{"user_id": "alice"}, {"OR": [{"rating": {"gte": 4.0}}, {"priority": "high"}]}]}
433
+
434
+ Returns:
435
+ Optional[List]: List of SQLAlchemy ColumnElement objects for where clause.
436
+ """
437
+
438
+ def get_column(key) -> ColumnElement:
439
+ """Get the appropriate column element for a field."""
440
+ if key in self.table.c:
441
+ return self.table.c[key]
442
+ else:
443
+ # Use ->> operator for unquoted JSON extract (MySQL/PostgreSQL)
444
+ return self.table.c[self.metadata_field].op("->>")(f"$.{key}")
445
+
446
+ def build_condition(key, value):
447
+ """Build a single condition."""
448
+ column = get_column(key)
449
+
450
+ if isinstance(value, dict):
451
+ # Handle comparison operators
452
+ conditions = []
453
+ for op, op_value in value.items():
454
+ op = op.lstrip("$")
455
+ match op:
456
+ case "eq":
457
+ conditions.append(column == op_value)
458
+ case "ne":
459
+ conditions.append(column != op_value)
460
+ case "gt":
461
+ conditions.append(column > op_value)
462
+ case "gte":
463
+ conditions.append(column >= op_value)
464
+ case "lt":
465
+ conditions.append(column < op_value)
466
+ case "lte":
467
+ conditions.append(column <= op_value)
468
+ case "in":
469
+ if not isinstance(op_value, list):
470
+ raise TypeError(f"Value for $in must be a list, got {type(op_value)}")
471
+ conditions.append(column.in_(op_value))
472
+ case "nin":
473
+ if not isinstance(op_value, list):
474
+ raise TypeError(f"Value for $nin must be a list, got {type(op_value)}")
475
+ conditions.append(~column.in_(op_value))
476
+ case "like":
477
+ conditions.append(column.like(str(op_value)))
478
+ case "ilike":
479
+ conditions.append(column.ilike(str(op_value)))
480
+ case _:
481
+ raise ValueError(f"Unsupported operator: {op}")
482
+ return and_(*conditions) if conditions else None
483
+ elif value is None:
484
+ return column.is_(None)
485
+ else:
486
+ return column == value
487
+
488
+ def process_condition(cond):
489
+ """Process a single condition, handling nested AND/OR logic."""
490
+ if isinstance(cond, dict):
491
+ # Handle complex filters with AND/OR
492
+ if "AND" in cond:
493
+ and_conditions = [process_condition(item) for item in cond["AND"]]
494
+ and_conditions = [c for c in and_conditions if c is not None]
495
+ return and_(*and_conditions) if and_conditions else None
496
+ elif "OR" in cond:
497
+ or_conditions = [process_condition(item) for item in cond["OR"]]
498
+ or_conditions = [c for c in or_conditions if c is not None]
499
+ return or_(*or_conditions) if or_conditions else None
500
+ else:
501
+ # Simple key-value filters
502
+ conditions = []
503
+ for k, v in cond.items():
504
+ expr = build_condition(k, v)
505
+ if expr is not None:
506
+ conditions.append(expr)
507
+ return and_(*conditions) if conditions else None
508
+ elif isinstance(cond, list):
509
+ subconditions = [process_condition(c) for c in cond]
510
+ subconditions = [c for c in subconditions if c is not None]
511
+ return and_(*subconditions) if subconditions else None
512
+ else:
513
+ return None
514
+
515
+ # Handle complex filters with AND/OR
516
+ result = process_condition(filters)
517
+ return [result] if result is not None else None
518
+
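The filter grammar documented above composes; a sketch of a nested filter follows (the rating and priority keys are illustrative metadata fields, not schema columns):

    # Real columns (user_id, category) filter directly; other keys go through metadata ->> '$.key'.
    filters = {
        "AND": [
            {"user_id": "alice"},
            {"category": {"in": ["task", "note"]}},
            {"OR": [
                {"rating": {"gte": 4.0}},
                {"priority": "high"},
            ]},
        ]
    }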
519
+ def _parse_row(self, row) -> tuple:
520
+ """Parse a database result row. Returns up to 12 fields, padding with None if needed."""
521
+ padded_row = list(row) + [None] * (12 - len(row))
522
+ return tuple(padded_row[:12])
523
+
524
+ def _build_standard_metadata(self, user_id: str, agent_id: str, run_id: str,
525
+ actor_id: str, hash_val: str, created_at: str,
526
+ updated_at: str, category: str, metadata_json: str) -> Dict:
527
+ """Build standard metadata dictionary from row fields."""
528
+ # Parse the JSON metadata first - this contains user-defined metadata
529
+ user_metadata = self._parse_metadata(metadata_json)
530
+
531
+ # Build complete payload with standard fields at top level and user metadata nested
532
+ metadata = {
533
+ "user_id": user_id,
534
+ "agent_id": agent_id,
535
+ "run_id": run_id,
536
+ "actor_id": actor_id,
537
+ "hash": hash_val,
538
+ "created_at": created_at,
539
+ "updated_at": updated_at,
540
+ "category": category,
541
+ # Store user metadata as nested structure to preserve it
542
+ "metadata": user_metadata
543
+ }
544
+
545
+ return metadata
546
+
547
+ def _create_output_data(self, vector_id: int, text_content: str, score: float,
548
+ metadata: Dict) -> OutputData:
549
+ """Create an OutputData object with standard structure."""
550
+ return OutputData(
551
+ id=vector_id,
552
+ score=score,
553
+ payload={
554
+ "data": text_content,
555
+ **metadata
556
+ }
557
+ )
558
+
559
+ def _build_record_for_insert(self, vector: List[float], payload: Dict) -> Dict[str, Any]:
560
+ """
561
+ Build a record dictionary for insertion with all standard fields.
562
+ Note: Primary key (id) should be set explicitly before insertion.
563
+ """
564
+ # Serialize metadata to handle datetime objects
565
+ metadata = payload.get("metadata", {})
566
+ serialized_metadata = serialize_datetime(metadata) if metadata else {}
567
+
568
+ record = {
569
+ # Primary key (id) will be set explicitly in insert() method with Snowflake ID
570
+ self.vector_field: (
571
+ vector if not self.normalize else self._normalize(vector)
572
+ ),
573
+ self.text_field: payload.get("data", ""),
574
+ self.metadata_field: serialized_metadata,
575
+ "user_id": payload.get("user_id", ""),
576
+ "agent_id": payload.get("agent_id", ""),
577
+ "run_id": payload.get("run_id", ""),
578
+ "actor_id": payload.get("actor_id", ""),
579
+ "hash": payload.get("hash", ""),
580
+ "created_at": serialize_datetime(payload.get("created_at", "")),
581
+ "updated_at": serialize_datetime(payload.get("updated_at", "")),
582
+ "category": payload.get("category", ""),
583
+ }
584
+
585
+ # Add hybrid search fields if enabled
586
+ if self.include_sparse and "sparse_embedding" in payload:
587
+ record[self.sparse_vector_field] = payload["sparse_embedding"] # SQLAlchemy JSON type handles serialization automatically
588
+
589
+ # Always add full-text content (enabled by default)
590
+ fulltext_content = payload.get("fulltext_content") or payload.get("data", "")
591
+ record[self.fulltext_field] = fulltext_content
592
+
593
+ return record
594
+
595
+ def search(self,
596
+ query: str,
597
+ vectors: List[List[float]],
598
+ limit: int = 5,
599
+ filters: Optional[Dict] = None) -> list[OutputData]:
600
+ # Check if hybrid search is enabled, and we have query text
601
+ # Full-text search is always enabled by default
602
+ if self.hybrid_search and query:
603
+ return self._hybrid_search(query, vectors, limit, filters)
604
+ else:
605
+ return self._vector_search(query, vectors, limit, filters)
606
+
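A usage sketch for search(): with hybrid_search enabled, a non-empty query runs the vector and full-text paths and fuses them, while an empty query (or hybrid_search=False) falls back to pure vector search. The embedder call is an assumption standing in for whatever embedding integration is configured:

    query = "tea or coffee preference"
    query_embedding = embedder.embed(query)  # assumed embedder; length == embedding_model_dims
    results = store.search(query=query, vectors=query_embedding, limit=5,
                           filters={"user_id": "alice"})
    for r in results:
        print(r.id, r.score, r.payload["data"])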
607
+ def _vector_search(self,
608
+ query: str,
609
+ vectors: List[List[float]],
610
+ limit: int = 5,
611
+ filters: Optional[Dict] = None) -> list[OutputData]:
612
+ """Perform pure vector search."""
613
+ try:
614
+ # Handle both cases: single vector or list of vectors
615
+ # If vectors is a single vector (list of floats), use it directly
616
+ if isinstance(vectors, list) and len(vectors) > 0 and isinstance(vectors[0], (int, float)):
617
+ query_vector = vectors
618
+ # If vectors is a list of vectors, use the first one
619
+ elif isinstance(vectors, list) and len(vectors) > 0 and isinstance(vectors[0], list):
620
+ query_vector = vectors[0]
621
+ else:
622
+ logger.warning("Invalid vector format provided for search")
623
+ return []
624
+
625
+ # Build where clause from filters
626
+ where_clause = self._generate_where_clause(filters)
627
+
628
+ # Perform vector search - pyobvector expects a single vector, not a list of vectors
629
+ results = self.obvector.ann_search(
630
+ table_name=self.collection_name,
631
+ vec_data=query_vector if not self.normalize else self._normalize(query_vector),
632
+ vec_column_name=self.vector_field,
633
+ distance_func=self._get_distance_function(self.vidx_metric_type),
634
+ with_dist=True,
635
+ topk=limit,
636
+ output_column_names=[
637
+ self.text_field,
638
+ self.metadata_field,
639
+ self.primary_field,
640
+ "user_id",
641
+ "agent_id",
642
+ "run_id",
643
+ "actor_id",
644
+ "hash",
645
+ "created_at",
646
+ "updated_at",
647
+ "category",
648
+ ],
649
+ where_clause=where_clause,
650
+ )
651
+
652
+ # Convert results to OutputData objects
653
+ search_results = []
654
+ for row in results.fetchall():
655
+ (text_content, metadata_json, vector_id, user_id, agent_id, run_id,
656
+ actor_id, hash_val, created_at, updated_at, category, distance) = self._parse_row(row)
657
+
658
+ # Build standard metadata
659
+ metadata = self._build_standard_metadata(
660
+ user_id, agent_id, run_id, actor_id, hash_val,
661
+ created_at, updated_at, category, metadata_json
662
+ )
663
+
664
+ # Convert distance to score based on metric type
665
+ # Handle None distance (shouldn't happen but be defensive)
666
+ if distance is None:
667
+ logger.warning(f"Distance is None for vector_id {vector_id}, using default score 0.0")
668
+ score = 0.0
669
+ elif self.vidx_metric_type == "l2":
670
+ # For L2 distance, lower is better, so we can use 1/(1+distance) or just use distance
671
+ score = float(distance)
672
+ elif self.vidx_metric_type == "cosine":
673
+ # For cosine distance, lower is better
674
+ score = float(distance)
675
+ elif self.vidx_metric_type == "inner_product":
676
+ # For inner product, higher is better, so we negate the distance
677
+ score = -float(distance)
678
+ else:
679
+ score = float(distance)
680
+
681
+ search_results.append(self._create_output_data(vector_id, text_content, score, metadata))
682
+ logger.debug(f"_vector_search results, len : {len(search_results)}")
683
+ return search_results
684
+
685
+ except Exception as e:
686
+ logger.error(f"Vector search failed in collection '{self.collection_name}': {e}", exc_info=True)
687
+ raise
688
+
689
+ def _fulltext_search(self, query: str, limit: int = 5, filters: Optional[Dict] = None) -> list[OutputData]:
690
+ """Perform full-text search using OceanBase FTS with parameterized queries including score."""
691
+ # Skip search if query is empty
692
+ if not query or not query.strip():
693
+ logger.debug("Full-text search query is empty, returning empty results.")
694
+ return []
695
+
696
+ # Generate where clause from filters using the existing method
697
+ filter_where_clause = self._generate_where_clause(filters)
698
+
699
+ # Build the full-text search condition using SQLAlchemy text with parameters
700
+ # Use the same parameter format that SQLAlchemy will use for other parameters
701
+ fts_condition = text(f"MATCH({self.fulltext_field}) AGAINST(:query IN NATURAL LANGUAGE MODE)").bindparams(
702
+ bindparam("query", query)
703
+ )
704
+
705
+ # Combine FTS condition with filter conditions
706
+ where_conditions = [fts_condition]
707
+ if filter_where_clause:
708
+ where_conditions.extend(filter_where_clause)
709
+
710
+ # Build custom query to include score field
711
+ try:
712
+ # Build select statement with specific columns AND score
713
+ columns = [
714
+ self.table.c[self.text_field],
715
+ self.table.c[self.metadata_field],
716
+ self.table.c[self.primary_field],
717
+ self.table.c["user_id"],
718
+ self.table.c["agent_id"],
719
+ self.table.c["run_id"],
720
+ self.table.c["actor_id"],
721
+ self.table.c["hash"],
722
+ self.table.c["created_at"],
723
+ self.table.c["updated_at"],
724
+ self.table.c["category"],
725
+ # Add the score calculation as a column
726
+ text(f"MATCH({self.fulltext_field}) AGAINST(:query IN NATURAL LANGUAGE MODE) as score").bindparams(
727
+ bindparam("query", query)
728
+ )
729
+ ]
730
+
731
+ stmt = select(*columns)
732
+
733
+ # Add where conditions
734
+ for condition in where_conditions:
735
+ stmt = stmt.where(condition)
736
+
737
+ # Order by score DESC to get best matches first
738
+ stmt = stmt.order_by(text('score DESC'))
739
+
740
+ # Add limit
741
+ if limit:
742
+ stmt = stmt.limit(limit)
743
+
744
+ # Execute the query with parameters - use direct parameter passing
745
+ with self.obvector.engine.connect() as conn:
746
+ with conn.begin():
747
+ logger.info(f"Executing FTS query with parameters: query={query}")
748
+ # Execute with parameter dictionary - the standard SQLAlchemy way
749
+ results = conn.execute(stmt)
750
+ rows = results.fetchall()
751
+
752
+ except Exception as e:
753
+ logger.warning(f"Full-text search failed, falling back to LIKE search: {e}")
754
+ try:
755
+ # Fallback to simple LIKE search with parameters
756
+ like_query = f"%{query}%"
757
+ like_condition = text(f"{self.fulltext_field} LIKE :like_query").bindparams(
758
+ bindparam("like_query", like_query)
759
+ )
760
+
761
+ fallback_conditions = [like_condition]
762
+ if filter_where_clause:
763
+ fallback_conditions.extend(filter_where_clause)
764
+
765
+ # Build fallback query with default score
766
+ columns = [
767
+ self.table.c[self.text_field],
768
+ self.table.c[self.metadata_field],
769
+ self.table.c[self.primary_field],
770
+ self.table.c["user_id"],
771
+ self.table.c["agent_id"],
772
+ self.table.c["run_id"],
773
+ self.table.c["actor_id"],
774
+ self.table.c["hash"],
775
+ self.table.c["created_at"],
776
+ self.table.c["updated_at"],
777
+ self.table.c["category"],
778
+ # Default score for LIKE search
779
+ literal_column("1.0").label("score")
780
+ ]
781
+
782
+ stmt = select(*columns)
783
+
784
+ for condition in fallback_conditions:
785
+ stmt = stmt.where(condition)
786
+
787
+ if limit:
788
+ stmt = stmt.limit(limit)
789
+
790
+ # Execute fallback query with parameters
791
+ with self.obvector.engine.connect() as conn:
792
+ with conn.begin():
793
+ logger.info(f"Executing LIKE fallback query with parameters: like_query={like_query}")
794
+ # Execute with parameter dictionary - the standard SQLAlchemy way
795
+ results = conn.execute(stmt)
796
+ rows = results.fetchall()
797
+ except Exception as fallback_error:
798
+ logger.error(f"Both full-text search and LIKE fallback failed: {fallback_error}")
799
+ return []
800
+
801
+ # Convert results to OutputData objects
802
+ fts_results = []
803
+ for row in rows:
804
+ # Parse the row data including score as the last column
805
+ (text_content, metadata_json, vector_id, user_id, agent_id, run_id, actor_id, hash_val,
806
+ created_at, updated_at, category, fts_score) = self._parse_row(row)
807
+
808
+ # Build standard metadata
809
+ metadata = self._build_standard_metadata(
810
+ user_id, agent_id, run_id, actor_id, hash_val,
811
+ created_at, updated_at, category, metadata_json
812
+ )
813
+
814
+ # Use the actual FTS score from the query
815
+ fts_results.append(self._create_output_data(vector_id, text_content, float(fts_score), metadata))
816
+
817
+ logger.info(f"_fulltext_search results, len : {len(fts_results)}, fts_results : {fts_results}")
818
+ return fts_results
819
+
820
+ def _hybrid_search(self, query: str, vectors: List[List[float]], limit: int = 5, filters: Optional[Dict] = None,
821
+ fusion_method: str = "rrf", k: int = 60):
822
+ """Perform hybrid search combining vector and full-text search with optional reranking."""
823
+ # Determine candidate limit for reranking
824
+ candidate_limit = limit * 3 if self.reranker else limit
825
+
826
+ # Perform vector search and full-text search in parallel for better performance
827
+ with ThreadPoolExecutor(max_workers=2) as executor:
828
+ # Submit both searches concurrently
829
+ vector_future = executor.submit(self._vector_search, query, vectors, candidate_limit, filters)
830
+ fts_future = executor.submit(self._fulltext_search, query, candidate_limit, filters)
831
+ # Wait for both to complete and get results
832
+ vector_results = vector_future.result()
833
+ fts_results = fts_future.result()
834
+
835
+ # Step 1: Coarse ranking - Combine results using RRF or weighted fusion
836
+ coarse_ranked_results = self._combine_search_results(
837
+ vector_results, fts_results, candidate_limit, fusion_method, k
838
+ )
839
+ logger.debug(f"Coarse ranking completed, candidates: {len(coarse_ranked_results)}")
840
+
841
+ # Step 2: Fine ranking - Use Rerank model for precision sorting (if enabled)
842
+ if self.reranker and query and coarse_ranked_results:
843
+ try:
844
+ final_results = self._apply_rerank(query, coarse_ranked_results, limit)
845
+ logger.debug(f"Rerank applied, final results: {len(final_results)}")
846
+ return final_results
847
+ except Exception as e:
848
+ logger.warning(f"Rerank failed, falling back to coarse ranking: {e}")
849
+ return coarse_ranked_results[:limit]
850
+ else:
851
+ # No reranker, return coarse ranking results
852
+ return coarse_ranked_results[:limit]
853
+
854
+ def _apply_rerank(self, query: str, candidates: List[OutputData], limit: int) -> List[OutputData]:
855
+ """
856
+ Apply Rerank model for precision sorting.
857
+
858
+ Args:
859
+ query: Search query text
860
+ candidates: Candidate results from coarse ranking
861
+ limit: Number of final results to return
862
+
863
+ Returns:
864
+ List of reranked OutputData objects
865
+ """
866
+ if not candidates:
867
+ return []
868
+
869
+ # Extract document texts from candidates
870
+ documents = [result.payload.get('data', '') for result in candidates]
871
+
872
+ # Call reranker to get reranked indices and scores
873
+ reranked_indices = self.reranker.rerank(query, documents, top_n=limit)
874
+
875
+ # Reconstruct results with rerank scores
876
+ final_results = []
877
+ for idx, rerank_score in reranked_indices:
878
+ result = candidates[idx]
879
+ # Preserve original scores in payload
880
+ result.payload['_fusion_score'] = result.score
881
+ # Update score to rerank score
882
+ result.score = rerank_score
883
+ result.payload['_rerank_score'] = rerank_score
884
+ final_results.append(result)
885
+
886
+ # Reorder results: high scores on both ends, low scores in the middle
887
+ if len(final_results) > 1:
888
+ reordered = [None] * len(final_results)
889
+ left = 0
890
+ right = len(final_results) - 1
891
+
892
+ for i, result in enumerate(final_results):
893
+ if i % 2 == 0:
894
+ # Even indices go to the left side
895
+ reordered[left] = result
896
+ left += 1
897
+ else:
898
+ # Odd indices go to the right side
899
+ reordered[right] = result
900
+ right -= 1
901
+
902
+ final_results = reordered
903
+
904
+ logger.debug(f"Rerank completed: {len(final_results)} results")
905
+
906
+ return final_results
907
+
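A quick trace of the even/odd interleave above, which places the best-reranked items at both ends of the list and the weakest in the middle:

    ranked = ["r1", "r2", "r3", "r4", "r5"]  # rerank order, best first
    reordered = [None] * len(ranked)
    left, right = 0, len(ranked) - 1
    for i, item in enumerate(ranked):
        if i % 2 == 0:
            reordered[left] = item
            left += 1
        else:
            reordered[right] = item
            right -= 1
    print(reordered)  # ['r1', 'r3', 'r5', 'r4', 'r2']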
908
+ def _combine_search_results(self, vector_results: List[OutputData], fts_results: List[OutputData],
909
+ limit: int, fusion_method: str = "rrf", k: int = 60):
910
+ """Combine and rerank vector and full-text search results using RRF or weighted fusion."""
911
+ if fusion_method == "rrf":
912
+ return self._rrf_fusion(vector_results, fts_results, limit, k)
913
+ else:
914
+ return self._weighted_fusion(vector_results, fts_results, limit)
915
+
916
+ def _rrf_fusion(self, vector_results: List[OutputData], fts_results: List[OutputData],
917
+ limit: int, k: int = 60):
918
+ """
919
+ Reciprocal Rank Fusion (RRF) for combining search results.
920
+
921
+ Uses weights configured at initialization (self.vector_weight and self.fts_weight)
922
+ to control the contribution of vector search vs full-text search.
923
+ """
924
+ # Create mapping of document ID to result data
925
+ all_docs = {}
926
+
927
+ # Process vector search results (rank-based scoring with weight)
928
+ for rank, result in enumerate(vector_results, 1):
929
+ rrf_score = self.vector_weight * (1.0 / (k + rank))
930
+ all_docs[result.id] = {
931
+ 'result': result,
932
+ 'vector_rank': rank,
933
+ 'fts_rank': None,
934
+ 'rrf_score': rrf_score
935
+ }
936
+
937
+ # Process FTS results (add or update RRF scores with weight)
938
+ for rank, result in enumerate(fts_results, 1):
939
+ fts_rrf_score = self.fts_weight * (1.0 / (k + rank))
940
+
941
+ if result.id in all_docs:
942
+ # Document found in both searches - combine RRF scores
943
+ all_docs[result.id]['fts_rank'] = rank
944
+ all_docs[result.id]['rrf_score'] += fts_rrf_score
945
+ else:
946
+ # Document only in FTS results
947
+ all_docs[result.id] = {
948
+ 'result': result,
949
+ 'vector_rank': None,
950
+ 'fts_rank': rank,
951
+ 'rrf_score': fts_rrf_score
952
+ }
953
+
954
+ # Convert to final results and sort by RRF score
955
+ heap = []
956
+ for doc_id, doc_data in all_docs.items():
957
+ # Use document ID as tiebreaker to avoid dict comparison when rrf_scores are equal
958
+ if len(heap) < limit:
959
+ heapq.heappush(heap, (doc_data['rrf_score'], doc_id, doc_data))
960
+ elif doc_data['rrf_score'] > heap[0][0]:
961
+ heapq.heapreplace(heap, (doc_data['rrf_score'], doc_id, doc_data))
962
+
963
+ final_results = []
964
+ for score, _, doc_data in sorted(heap, key=lambda x: x[0], reverse=True):
965
+ result = doc_data['result']
966
+ result.score = score
967
+ # Add ranking information to metadata for debugging
968
+ result.payload['_fusion_info'] = {
969
+ 'vector_rank': doc_data['vector_rank'],
970
+ 'fts_rank': doc_data['fts_rank'],
971
+ 'rrf_score': score,
972
+ 'fusion_method': 'rrf',
973
+ 'vector_weight': self.vector_weight,
974
+ 'fts_weight': self.fts_weight
975
+ }
976
+ final_results.append(result)
977
+
978
+ return final_results
979
+
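As a concrete check of the weighted RRF formula above, with the default vector_weight = fts_weight = 0.5 and k = 60, a document ranked 1st by vector search and 3rd by full-text search scores:

    k = 60
    vector_weight = fts_weight = 0.5
    rrf_score = vector_weight * (1.0 / (k + 1)) + fts_weight * (1.0 / (k + 3))
    print(round(rrf_score, 6))  # 0.016133 -- versus ~0.008197 for a rank-1 hit in only one list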
980
+ def _weighted_fusion(self, vector_results: List[OutputData], fts_results: List[OutputData],
981
+ limit: int, vector_weight: float = 0.7, text_weight: float = 0.3):
982
+ """Traditional weighted score fusion (fallback method)."""
983
+ # Create a mapping of id to results for deduplication
984
+ combined_results = {}
985
+
986
+ # Normalize vector scores to 0-1 range
987
+ if vector_results:
988
+ vector_scores = [result.score for result in vector_results]
989
+ min_vector_score = min(vector_scores)
990
+ max_vector_score = max(vector_scores)
991
+ vector_score_range = max_vector_score - min_vector_score
992
+
993
+ for result in vector_results:
994
+ if vector_score_range > 0:
995
+ # For distance metrics, lower is better, so we invert the normalized score
996
+ if self.vidx_metric_type in ["l2", "cosine"]:
997
+ normalized_score = 1.0 - (result.score - min_vector_score) / vector_score_range
998
+ else: # inner_product
999
+ normalized_score = (result.score - min_vector_score) / vector_score_range
1000
+ else:
1001
+ normalized_score = 1.0
1002
+
1003
+ combined_results[result.id] = {
1004
+ 'result': result,
1005
+ 'vector_score': normalized_score,
1006
+ 'fts_score': 0.0
1007
+ }
1008
+
1009
+ # Add FTS results (FTS scores are already normalized to 0-1)
1010
+ for result in fts_results:
1011
+ if result.id in combined_results:
1012
+ # Update existing result with FTS score
1013
+ combined_results[result.id]['fts_score'] = result.score
1014
+ else:
1015
+ # Add new FTS-only result
1016
+ combined_results[result.id] = {
1017
+ 'result': result,
1018
+ 'vector_score': 0.0,
1019
+ 'fts_score': result.score
1020
+ }
1021
+
1022
+ # Calculate combined scores and create final results
1023
+ heap = []
1024
+ for doc_id, doc_data in combined_results.items():
1025
+ combined_score = (vector_weight * doc_data['vector_score'] +
1026
+ text_weight * doc_data['fts_score'])
1027
+
1028
+ if len(heap) < limit:
1029
+ heapq.heappush(heap, (combined_score, doc_id, doc_data))
1030
+ elif combined_score > heap[0][0]:
1031
+ heapq.heapreplace(heap, (combined_score, doc_id, doc_data))
1032
+
1033
+ final_results = []
1034
+ for score, _, doc_data in sorted(heap, key=lambda x: x[0], reverse=True):
1035
+ result = doc_data['result']
1036
+ result.score = score
1037
+ # Add fusion info for debugging
1038
+ result.payload['_fusion_info'] = {
1039
+ 'vector_score': doc_data['vector_score'],
1040
+ 'fts_score': doc_data['fts_score'],
1041
+ 'combined_score': score,
1042
+ 'fusion_method': 'weighted'
1043
+ }
1044
+ final_results.append(result)
1045
+
1046
+ # Return top results
1047
+ return final_results
1048
+
1049
+ def delete(self, vector_id: int):
1050
+ """Delete a vector by ID."""
1051
+ try:
1052
+ self.obvector.delete(
1053
+ table_name=self.collection_name,
1054
+ ids=[vector_id],
1055
+ )
1056
+ logger.debug(f"Successfully deleted vector with ID: {vector_id} from collection '{self.collection_name}'")
1057
+ except Exception as e:
1058
+ logger.error(f"Failed to delete vector with ID {vector_id} from collection '{self.collection_name}': {e}", exc_info=True)
1059
+ raise
1060
+
1061
+ def update(self, vector_id: int, vector: Optional[List[float]] = None, payload: Optional[Dict] = None):
1062
+ """Update a vector and its payload."""
1063
+ try:
1064
+ # Get existing record to preserve fields not being updated
1065
+ existing_result = self.obvector.get(
1066
+ table_name=self.collection_name,
1067
+ ids=[vector_id],
1068
+ output_column_name=[self.vector_field] # Get the existing vector
1069
+ )
1070
+
1071
+ existing_rows = existing_result.fetchall()
1072
+ if not existing_rows:
1073
+ logger.warning(f"Vector with ID {vector_id} not found in collection '{self.collection_name}'")
1074
+ return
1075
+
1076
+ # Prepare update data
1077
+ update_data: Dict[str, Any] = {
1078
+ self.primary_field: vector_id,
1079
+ }
1080
+
1081
+ if vector is not None:
1082
+ update_data[self.vector_field] = (
1083
+ vector if not self.normalize else self._normalize(vector)
1084
+ )
1085
+ else:
1086
+ # Preserve the existing vector to avoid it being cleared by upsert
1087
+ existing_vector = existing_rows[0][0] if existing_rows[0] else None
1088
+ if existing_vector is not None:
1089
+ update_data[self.vector_field] = existing_vector
1090
+ logger.debug(f"Preserving existing vector for ID {vector_id}")
1091
+
1092
+ if payload is not None:
1093
+ # Use the helper method to build fields, then merge with update_data
1094
+ temp_record = self._build_record_for_insert(vector or [], payload)
1095
+
1096
+ # Copy relevant fields from temp_record (excluding primary key and vector if not updating)
1097
+ for key, value in temp_record.items():
1098
+ if key != self.primary_field and (vector is not None or key != self.vector_field):
1099
+ update_data[key] = value
1100
+
1101
+ # Update record
1102
+ self.obvector.upsert(
1103
+ table_name=self.collection_name,
1104
+ data=[update_data],
1105
+ )
1106
+ logger.debug(f"Successfully updated vector with ID: {vector_id} in collection '{self.collection_name}'")
1107
+
1108
+ except Exception as e:
1109
+ logger.error(f"Failed to update vector with ID {vector_id} in collection '{self.collection_name}': {e}", exc_info=True)
1110
+ raise
1111
+
1112
+ def get(self, vector_id: int):
1113
+ """Retrieve a vector by ID."""
1114
+ try:
1115
+ results = self.obvector.get(
1116
+ table_name=self.collection_name,
1117
+ ids=[vector_id],
1118
+ output_column_name=[
1119
+ self.vector_field,
1120
+ self.text_field,
1121
+ self.metadata_field,
1122
+ "user_id",
1123
+ "agent_id",
1124
+ "run_id",
1125
+ "actor_id",
1126
+ "hash",
1127
+ "created_at",
1128
+ "updated_at",
1129
+ "category",
1130
+ ],
1131
+ )
1132
+
1133
+ rows = results.fetchall()
1134
+ if not rows:
1135
+ logger.debug(f"Vector with ID {vector_id} not found in collection '{self.collection_name}'")
1136
+ return None
1137
+
1138
+ (vector, text_content, metadata_json, user_id, agent_id,
1139
+ run_id, actor_id, hash_val, created_at, updated_at, category, _) = self._parse_row(rows[0])
1140
+
1141
+ # Build standard metadata
1142
+ metadata = self._build_standard_metadata(
1143
+ user_id, agent_id, run_id, actor_id, hash_val,
1144
+ created_at, updated_at, category, metadata_json
1145
+ )
1146
+
1147
+ logger.debug(f"Successfully retrieved vector with ID: {vector_id} from collection '{self.collection_name}'")
1148
+ return self._create_output_data(vector_id, text_content, 0.0, metadata)
1149
+
1150
+ except Exception as e:
1151
+ logger.error(f"Failed to get vector with ID {vector_id} from collection '{self.collection_name}': {e}", exc_info=True)
1152
+ raise
1153
+
1154
+ def list_cols(self):
1155
+ """List all collections."""
1156
+ try:
1157
+ # Get all tables from the database using the correct SQLAlchemy API
1158
+ with self.obvector.engine.connect() as conn:
1159
+ result = conn.execute(text("SHOW TABLES"))
1160
+ tables = [row[0] for row in result.fetchall()]
1161
+ logger.debug(f"Successfully listed {len(tables)} collections")
1162
+ return tables
1163
+ except Exception as e:
1164
+ logger.error(f"Failed to list collections: {e}", exc_info=True)
1165
+ raise
1166
+
1167
+ def delete_col(self):
1168
+ """Delete the collection."""
1169
+ try:
1170
+ if self.obvector.check_table_exists(self.collection_name):
1171
+ self.obvector.drop_table_if_exist(self.collection_name)
1172
+ logger.info(f"Successfully deleted collection '{self.collection_name}'")
1173
+ else:
1174
+ logger.warning(f"Collection '{self.collection_name}' does not exist, skipping deletion")
1175
+ except Exception as e:
1176
+ logger.error(f"Failed to delete collection '{self.collection_name}': {e}", exc_info=True)
1177
+ raise
1178
+
1179
+ def _get_existing_vector_dimension(self) -> Optional[int]:
1180
+ """Get the dimension of the existing vector field in the table."""
1181
+ if not self.obvector.check_table_exists(self.collection_name):
1182
+ return None
1183
+
1184
+ try:
1185
+ # Get table schema information using the correct SQLAlchemy API
1186
+ with self.obvector.engine.connect() as conn:
1187
+ result = conn.execute(text(f"DESCRIBE {self.collection_name}"))
1188
+ columns = result.fetchall()
1189
+
1190
+ # Find the vector field and extract its dimension
1191
+ for col in columns:
1192
+ if col[0] == self.vector_field:
1193
+ # Parse vector type like "VECTOR(1536)" to extract dimension
1194
+ col_type = col[1]
1195
+ if col_type.startswith("VECTOR(") and col_type.endswith(")"):
1196
+ dim_str = col_type[7:-1] # Extract dimension from "VECTOR(1536)"
1197
+ return int(dim_str)
1198
+ return None
1199
+ except Exception as e:
1200
+ logger.warning(f"Failed to get vector dimension for table {self.collection_name}: {e}")
1201
+ return None
1202
+
1203
+ def col_info(self):
1204
+ """Get information about the collection."""
1205
+ try:
1206
+ if not self.obvector.check_table_exists(self.collection_name):
1207
+ logger.debug(f"Collection '{self.collection_name}' does not exist")
1208
+ return None
1209
+
1210
+ # Get table schema information using the correct SQLAlchemy API
1211
+ with self.obvector.engine.connect() as conn:
1212
+ result = conn.execute(text(f"DESCRIBE {self.collection_name}"))
1213
+ columns = result.fetchall()
1214
+
1215
+ logger.debug(f"Successfully retrieved info for collection '{self.collection_name}'")
1216
+ return {
1217
+ "name": self.collection_name,
1218
+ "columns": [{"name": col[0], "type": col[1]} for col in columns],
1219
+ "index_type": self.index_type,
1220
+ "metric_type": self.vidx_metric_type,
1221
+ }
1222
+
1223
+ except Exception as e:
1224
+ logger.error(f"Failed to get collection info for '{self.collection_name}': {e}", exc_info=True)
1225
+ raise
1226
+
1227
+ def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None):
1228
+ """List all memories."""
1229
+ try:
1230
+ # Build where clause from filters
1231
+ where_clause = self._generate_where_clause(filters)
1232
+
1233
+ # Get all records
1234
+ results = self.obvector.get(
1235
+ table_name=self.collection_name,
1236
+ ids=None,
1237
+ output_column_name=[
1238
+ self.primary_field,
1239
+ self.vector_field,
1240
+ self.text_field,
1241
+ self.metadata_field,
1242
+ "user_id",
1243
+ "agent_id",
1244
+ "run_id",
1245
+ "actor_id",
1246
+ "hash",
1247
+ "created_at",
1248
+ "updated_at",
1249
+ "category",
1250
+ ],
1251
+ where_clause=where_clause
1252
+ )
1253
+
1254
+ memories = []
1255
+ for row in results.fetchall():
1256
+ (vector_id, vector, text_content, metadata_json, user_id, agent_id, run_id,
1257
+ actor_id, hash_val, created_at, updated_at, category) = self._parse_row(row)
1258
+
1259
+ # Build standard metadata
1260
+ metadata = self._build_standard_metadata(
1261
+ user_id, agent_id, run_id, actor_id, hash_val,
1262
+ created_at, updated_at, category, metadata_json
1263
+ )
1264
+
1265
+ memories.append(self._create_output_data(vector_id, text_content, 0.0, metadata))
1266
+
1267
+ if limit:
1268
+ memories = memories[:limit]
1269
+
1270
+ logger.debug(f"Successfully listed {len(memories)} memories from collection '{self.collection_name}'")
1271
+ return [memories]
1272
+
1273
+ except Exception as e:
1274
+ logger.error(f"Failed to list memories from collection '{self.collection_name}': {e}", exc_info=True)
1275
+ raise
1276
+
1277
+ def reset(self):
1278
+ """Reset by deleting the collection and recreating it."""
1279
+ try:
1280
+ logger.info(f"Resetting collection '{self.collection_name}'")
1281
+ self.delete_col()
1282
+ if self.embedding_model_dims is not None:
1283
+ self._create_table_with_index_by_embedding_model_dims()
1284
+
1285
+ if self.hybrid_search:
1286
+ self._check_and_create_fulltext_index()
1287
+
1288
+ logger.info(f"Successfully reset collection '{self.collection_name}'")
1289
+
1290
+ except Exception as e:
1291
+ logger.error(f"Failed to reset collection '{self.collection_name}': {e}", exc_info=True)
1292
+ raise
1293
+
1294
+ def _check_and_create_fulltext_index(self):
1295
+ # Check whether the full-text index exists, if not, create it
1296
+ if not self._check_fulltext_index_exists():
1297
+ self._create_fulltext_index()
1298
+
1299
+ def _check_fulltext_index_exists(self) -> bool:
1300
+ """
1301
+ Check if the full-text index of the specified table exists.
1302
+ """
1303
+ try:
1304
+ with self.obvector.engine.connect() as conn:
1305
+ result = conn.execute(text(f"SHOW INDEX FROM {self.collection_name}"))
1306
+ indexes = result.fetchall()
1307
+
1308
+ for index in indexes:
1309
+ # Index [2] is the index name, index [4] is the column name, and index [10] is the index type
1310
+ if len(index) > 10 and index[10] == 'FULLTEXT':
1311
+ if self.fulltext_field in str(index[4]):
1312
+ return True
1313
+
1314
+ return False
1315
+
1316
+ except Exception as e:
1317
+ logger.error(f"An error occurred while checking the full-text index: {e}")
1318
+ return False
1319
+
1320
+ def _create_fulltext_index(self):
1321
+ try:
1322
+ logger.debug(
1323
+ "About to create fulltext index for collection '%s' using parser '%s'",
1324
+ self.collection_name,
1325
+ self.fulltext_parser,
1326
+ )
1327
+
1328
+ # Create fulltext index with the specified parser using SQL
1329
+ with self.obvector.engine.connect() as conn:
1330
+ sql_command = text(f"""ALTER TABLE {self.collection_name}
1331
+ ADD FULLTEXT INDEX fulltext_index_for_col_text ({self.fulltext_field}) WITH PARSER {self.fulltext_parser}""")
1332
+
1333
+ logger.debug("DEBUG: Executing SQL: %s", sql_command)
1334
+ conn.execute(sql_command)
1335
+ logger.debug("DEBUG: Fulltext index created successfully for '%s'", self.collection_name)
1336
+
1337
+ except Exception as e:
1338
+ logger.exception("Exception occurred while creating fulltext index")
1339
+ raise Exception(
1340
+ "Failed to add fulltext index to the target table, your OceanBase version must be "
1341
+ "4.3.5.1 or above to support fulltext index and vector index in the same table"
1342
+ ) from e
1343
+
1344
+ # Refresh metadata
1345
+ self.obvector.refresh_metadata([self.collection_name])
1346
+
1347
+ def execute_sql(self, sql: str, params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
1348
+ """
1349
+ Execute a raw SQL statement and return results.
1350
+
1351
+ This method is used by SubStoreMigrationManager to manage migration status table.
1352
+
1353
+ Args:
1354
+ sql: SQL statement to execute
1355
+ params: Optional parameters for the SQL statement
1356
+
1357
+ Returns:
1358
+ List of result rows as dictionaries
1359
+ """
1360
+ try:
1361
+ with self.obvector.engine.connect() as conn:
1362
+ if params:
1363
+ result = conn.execute(text(sql), params)
1364
+ else:
1365
+ result = conn.execute(text(sql))
1366
+
1367
+ # Commit for DDL/DML statements
1368
+ conn.commit()
1369
+
1370
+ # Try to fetch results (for SELECT queries)
1371
+ try:
1372
+ rows = result.fetchall()
1373
+ # Convert rows to dictionaries
1374
+ if rows and result.keys():
1375
+ return [dict(zip(result.keys(), row)) for row in rows]
1376
+ return []
1377
+ except Exception:
1378
+ # No results to fetch (for INSERT/UPDATE/DELETE/CREATE)
1379
+ return []
1380
+
1381
+ except Exception as e:
1382
+ logger.error(f"Failed to execute SQL: {e}")
1383
+ logger.debug(f"SQL statement: {sql}")
1384
+ raise
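Finally, a sketch of execute_sql() for the ad-hoc parameterized queries described in its docstring; the table name and column values are illustrative:

    rows = store.execute_sql(
        "SELECT COUNT(*) AS cnt FROM memories WHERE user_id = :uid",
        params={"uid": "alice"},
    )
    print(rows)  # e.g. [{'cnt': 42}]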