powermem 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- powermem/__init__.py +103 -0
- powermem/agent/__init__.py +35 -0
- powermem/agent/abstract/__init__.py +22 -0
- powermem/agent/abstract/collaboration.py +259 -0
- powermem/agent/abstract/context.py +187 -0
- powermem/agent/abstract/manager.py +232 -0
- powermem/agent/abstract/permission.py +217 -0
- powermem/agent/abstract/privacy.py +267 -0
- powermem/agent/abstract/scope.py +199 -0
- powermem/agent/agent.py +791 -0
- powermem/agent/components/__init__.py +18 -0
- powermem/agent/components/collaboration_coordinator.py +645 -0
- powermem/agent/components/permission_controller.py +586 -0
- powermem/agent/components/privacy_protector.py +767 -0
- powermem/agent/components/scope_controller.py +685 -0
- powermem/agent/factories/__init__.py +16 -0
- powermem/agent/factories/agent_factory.py +266 -0
- powermem/agent/factories/config_factory.py +308 -0
- powermem/agent/factories/memory_factory.py +229 -0
- powermem/agent/implementations/__init__.py +16 -0
- powermem/agent/implementations/hybrid.py +728 -0
- powermem/agent/implementations/multi_agent.py +1040 -0
- powermem/agent/implementations/multi_user.py +1020 -0
- powermem/agent/types.py +53 -0
- powermem/agent/wrappers/__init__.py +14 -0
- powermem/agent/wrappers/agent_memory_wrapper.py +427 -0
- powermem/agent/wrappers/compatibility_wrapper.py +520 -0
- powermem/config_loader.py +318 -0
- powermem/configs.py +249 -0
- powermem/core/__init__.py +19 -0
- powermem/core/async_memory.py +1493 -0
- powermem/core/audit.py +258 -0
- powermem/core/base.py +165 -0
- powermem/core/memory.py +1567 -0
- powermem/core/setup.py +162 -0
- powermem/core/telemetry.py +215 -0
- powermem/integrations/__init__.py +17 -0
- powermem/integrations/embeddings/__init__.py +13 -0
- powermem/integrations/embeddings/aws_bedrock.py +100 -0
- powermem/integrations/embeddings/azure_openai.py +55 -0
- powermem/integrations/embeddings/base.py +31 -0
- powermem/integrations/embeddings/config/base.py +132 -0
- powermem/integrations/embeddings/configs.py +31 -0
- powermem/integrations/embeddings/factory.py +48 -0
- powermem/integrations/embeddings/gemini.py +39 -0
- powermem/integrations/embeddings/huggingface.py +41 -0
- powermem/integrations/embeddings/langchain.py +35 -0
- powermem/integrations/embeddings/lmstudio.py +29 -0
- powermem/integrations/embeddings/mock.py +11 -0
- powermem/integrations/embeddings/ollama.py +53 -0
- powermem/integrations/embeddings/openai.py +49 -0
- powermem/integrations/embeddings/qwen.py +102 -0
- powermem/integrations/embeddings/together.py +31 -0
- powermem/integrations/embeddings/vertexai.py +54 -0
- powermem/integrations/llm/__init__.py +18 -0
- powermem/integrations/llm/anthropic.py +87 -0
- powermem/integrations/llm/base.py +132 -0
- powermem/integrations/llm/config/anthropic.py +56 -0
- powermem/integrations/llm/config/azure.py +56 -0
- powermem/integrations/llm/config/base.py +62 -0
- powermem/integrations/llm/config/deepseek.py +56 -0
- powermem/integrations/llm/config/ollama.py +56 -0
- powermem/integrations/llm/config/openai.py +79 -0
- powermem/integrations/llm/config/qwen.py +68 -0
- powermem/integrations/llm/config/qwen_asr.py +46 -0
- powermem/integrations/llm/config/vllm.py +56 -0
- powermem/integrations/llm/configs.py +26 -0
- powermem/integrations/llm/deepseek.py +106 -0
- powermem/integrations/llm/factory.py +118 -0
- powermem/integrations/llm/gemini.py +201 -0
- powermem/integrations/llm/langchain.py +65 -0
- powermem/integrations/llm/ollama.py +106 -0
- powermem/integrations/llm/openai.py +166 -0
- powermem/integrations/llm/openai_structured.py +80 -0
- powermem/integrations/llm/qwen.py +207 -0
- powermem/integrations/llm/qwen_asr.py +171 -0
- powermem/integrations/llm/vllm.py +106 -0
- powermem/integrations/rerank/__init__.py +20 -0
- powermem/integrations/rerank/base.py +43 -0
- powermem/integrations/rerank/config/__init__.py +7 -0
- powermem/integrations/rerank/config/base.py +27 -0
- powermem/integrations/rerank/configs.py +23 -0
- powermem/integrations/rerank/factory.py +68 -0
- powermem/integrations/rerank/qwen.py +159 -0
- powermem/intelligence/__init__.py +17 -0
- powermem/intelligence/ebbinghaus_algorithm.py +354 -0
- powermem/intelligence/importance_evaluator.py +361 -0
- powermem/intelligence/intelligent_memory_manager.py +284 -0
- powermem/intelligence/manager.py +148 -0
- powermem/intelligence/plugin.py +229 -0
- powermem/prompts/__init__.py +29 -0
- powermem/prompts/graph/graph_prompts.py +217 -0
- powermem/prompts/graph/graph_tools_prompts.py +469 -0
- powermem/prompts/importance_evaluation.py +246 -0
- powermem/prompts/intelligent_memory_prompts.py +163 -0
- powermem/prompts/templates.py +193 -0
- powermem/storage/__init__.py +14 -0
- powermem/storage/adapter.py +896 -0
- powermem/storage/base.py +109 -0
- powermem/storage/config/base.py +13 -0
- powermem/storage/config/oceanbase.py +58 -0
- powermem/storage/config/pgvector.py +52 -0
- powermem/storage/config/sqlite.py +27 -0
- powermem/storage/configs.py +159 -0
- powermem/storage/factory.py +59 -0
- powermem/storage/migration_manager.py +438 -0
- powermem/storage/oceanbase/__init__.py +8 -0
- powermem/storage/oceanbase/constants.py +162 -0
- powermem/storage/oceanbase/oceanbase.py +1384 -0
- powermem/storage/oceanbase/oceanbase_graph.py +1441 -0
- powermem/storage/pgvector/__init__.py +7 -0
- powermem/storage/pgvector/pgvector.py +420 -0
- powermem/storage/sqlite/__init__.py +0 -0
- powermem/storage/sqlite/sqlite.py +218 -0
- powermem/storage/sqlite/sqlite_vector_store.py +311 -0
- powermem/utils/__init__.py +35 -0
- powermem/utils/utils.py +605 -0
- powermem/version.py +23 -0
- powermem-0.1.0.dist-info/METADATA +187 -0
- powermem-0.1.0.dist-info/RECORD +123 -0
- powermem-0.1.0.dist-info/WHEEL +5 -0
- powermem-0.1.0.dist-info/licenses/LICENSE +206 -0
- powermem-0.1.0.dist-info/top_level.txt +1 -0
powermem/storage/oceanbase/oceanbase_graph.py
@@ -0,0 +1,1441 @@
"""
OceanBase graph storage implementation

This module provides OceanBase-based graph storage for memory data.
"""
import json
import logging
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
from sqlalchemy import bindparam, text, MetaData, Column, String, Index, Table, BigInteger
from sqlalchemy.dialects.mysql import TIMESTAMP
from sqlalchemy.exc import SAWarning

# Suppress SQLAlchemy warnings about unknown schema content from pyobvector
# These warnings occur because SQLAlchemy doesn't recognize OceanBase VECTOR index syntax
# This is harmless as pyobvector handles VECTOR types correctly
warnings.filterwarnings(
    "ignore",
    message="Unknown schema content",
    category=SAWarning,
    module="pyobvector.*"
)

# Suppress pkg_resources deprecation warning from jieba internal usage
# This warning is from jieba's internal use of deprecated pkg_resources API
warnings.filterwarnings(
    "ignore",
    message="pkg_resources is deprecated as an API",
    category=UserWarning,
    module="jieba.*"
)

from pyobvector import ObVecClient, l2_distance, VECTOR, VecIndexType

from powermem.integrations import EmbedderFactory, LLMFactory
from powermem.storage.base import GraphStoreBase
from powermem.utils.utils import format_entities, remove_code_blocks, generate_snowflake_id

try:
    from rank_bm25 import BM25Okapi
except ImportError:
    raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25")

from powermem.prompts import GraphPrompts, GraphToolsPrompts

from powermem.storage.oceanbase import constants

logger = logging.getLogger(__name__)

# Try to import jieba for better Chinese text segmentation
try:
    import jieba
except ImportError:
    logger.warning("jieba is not installed. Falling back to simple space-based tokenization. "
                   "Install jieba for better Chinese text segmentation: pip install jieba")
    jieba = None
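
# Example usage (illustrative sketch; the config object and filter values below are hypothetical):
#
#     graph = MemoryGraph(config)   # config supplies graph_store, embedder, llm, and vector_store sections
#     graph.add("Alice moved to Berlin", filters={"user_id": "alice"})
#     graph.search("Where does Alice live?", filters={"user_id": "alice"})
#     graph.get_all(filters={"user_id": "alice"}, limit=20)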


class MemoryGraph(GraphStoreBase):
    """OceanBase-based graph memory storage implementation."""

    def __init__(self, config: Any) -> None:
        """Initialize OceanBase graph memory.

        Args:
            config: Memory configuration containing graph_store, embedder, and llm configs.

        Raises:
            ValueError: If embedding_model_dims is not configured.
        """
        self.config = config

        # Get OceanBase config
        ob_config = self.config.graph_store.config

        # Helper function to get config value (supports both dict and object)
        def get_config_value(key: str, default: Any = None) -> Any:
            if isinstance(ob_config, dict):
                return ob_config.get(key, default)
            else:
                return getattr(ob_config, key, default)

        # Get embedding_model_dims (required)
        embedding_model_dims = get_config_value("embedding_model_dims")
        if embedding_model_dims is None:
            raise ValueError(
                "embedding_model_dims is required for OceanBase graph operations. "
                "Please configure embedding_model_dims in your OceanBaseGraphConfig."
            )
        self.embedding_dims = embedding_model_dims

        # Get vidx parameters with defaults.
        self.index_type = get_config_value("index_type", constants.DEFAULT_INDEX_TYPE)
        self.vidx_metric_type = get_config_value("vidx_metric_type",
                                                 constants.DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE)
        self.vidx_name = get_config_value("vidx_name", constants.DEFAULT_VIDX_NAME)

        # Get graph search parameters
        self.max_hops = get_config_value("max_hops", 3)

        # Set vidx_algo_params with defaults based on index_type.
        self.vidx_algo_params = get_config_value("vidx_algo_params", None)
        if not self.vidx_algo_params:
            # Set default parameters based on index type.
            self.vidx_algo_params = constants.get_default_build_params(self.index_type)

        # Initialize embedding model
        self.embedding_model = EmbedderFactory.create(
            self.config.embedder.provider,
            self.config.embedder.config,
            self.config.vector_store.config,
        )

        # Initialize OceanBase client
        host = get_config_value("host", "localhost")
        port = get_config_value("port", "2881")
        user = get_config_value("user", "root")
        password = get_config_value("password", "")
        db_name = get_config_value("db_name", "test")

        self.client = ObVecClient(
            uri=f"{host}:{port}",
            user=user,
            password=password,
            db_name=db_name,
        )
        self.engine = self.client.engine
        self.metadata = MetaData()

        # Create tables
        self._create_tables()

        # Initialize LLM.
        self.llm_provider = self._get_llm_provider()
        llm_config = self._get_llm_config()
        self.llm = LLMFactory.create(self.llm_provider, llm_config)

        # Initialize graph prompts and tools with config
        # Pass graph_store config or full config to prompts
        graph_config = {}
        if self.config.graph_store:
            # Convert GraphStoreConfig to dict if needed
            if hasattr(self.config.graph_store, 'model_dump'):
                graph_config = self.config.graph_store.model_dump()
            elif isinstance(self.config.graph_store, dict):
                graph_config = self.config.graph_store
        # Also include full config for fallback
        prompts_config = {"graph_store": graph_config}
        # Merge top-level config if it's a dict
        if isinstance(self.config, dict):
            prompts_config.update(self.config)

        self.graph_prompts = GraphPrompts(prompts_config)
        self.graph_tools_prompts = GraphToolsPrompts(prompts_config)

    def _get_llm_provider(self) -> str:
        """Get LLM provider from configuration with fallback.

        Returns:
            LLM provider name.
        """
        # Check graph_store.llm.provider first
        if (self.config.graph_store and
                self.config.graph_store.llm and
                self.config.graph_store.llm.provider):
            return self.config.graph_store.llm.provider

        # Check config.llm.provider
        if self.config.llm and self.config.llm.provider:
            return self.config.llm.provider

        # Default fallback
        return constants.DEFAULT_LLM_PROVIDER

    def _get_llm_config(self) -> Optional[Any]:
        """Get LLM config from configuration.

        Returns:
            LLM configuration object or None.
        """
        # Check graph_store.llm.config first
        if (self.config.graph_store and
                self.config.graph_store.llm and
                hasattr(self.config.graph_store.llm, "config")):
            return self.config.graph_store.llm.config

        # Check config.llm.config
        if hasattr(self.config.llm, "config"):
            return self.config.llm.config

        return None

    def _build_user_identity(self, filters: Dict[str, Any]) -> str:
        """Build user identity string from filters.

        Args:
            filters: Dictionary containing user_id, agent_id, run_id.

        Returns:
            Formatted user identity string.
        """
        identity_parts = [f"user_id: {filters['user_id']}"]

        if filters.get("agent_id"):
            identity_parts.append(f"agent_id: {filters['agent_id']}")

        if filters.get("run_id"):
            identity_parts.append(f"run_id: {filters['run_id']}")

        return ", ".join(identity_parts)

    def _build_filter_conditions(
        self,
        filters: Dict[str, Any],
        prefix: str = ""
    ) -> Tuple[List[str], Dict[str, Any]]:
        """Build SQL filter conditions and parameters from filters.

        Args:
            filters: Dictionary containing user_id, agent_id, run_id.
            prefix: Optional prefix for column names (e.g., "r." for table alias).

        Returns:
            Tuple of (filter_conditions_list, params_dict).
        """
        filter_parts = [f"{prefix}user_id = :user_id"]
        params = {"user_id": filters["user_id"]}

        if filters.get("agent_id"):
            filter_parts.append(f"{prefix}agent_id = :agent_id")
            params["agent_id"] = filters["agent_id"]

        if filters.get("run_id"):
            filter_parts.append(f"{prefix}run_id = :run_id")
            params["run_id"] = filters["run_id"]

        return filter_parts, params
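
    # For instance, filters={"user_id": "u1", "agent_id": "a1"} with prefix="r." returns
    # (["r.user_id = :user_id", "r.agent_id = :agent_id"],
    #  {"user_id": "u1", "agent_id": "a1"}).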

    @staticmethod
    def _coerce_tool_response_to_dict(response: Any) -> Dict[str, Any]:
        """Ensure LLM tool response is a dict.

        Some LLM providers may return a JSON string instead of a parsed dict. This helper
        normalizes the response to a dictionary with safe fallbacks.

        Args:
            response: LLM tool call response, may be dict or JSON string.

        Returns:
            Normalized dictionary object, or empty dict if unparseable.
        """
        if isinstance(response, dict):
            return response
        if isinstance(response, str):
            try:
                cleaned = remove_code_blocks(response)
            except Exception:
                cleaned = response
            try:
                parsed = json.loads(cleaned)
                if isinstance(parsed, dict):
                    return parsed
            except Exception:
                pass
        # Fallback to empty dict if un-parseable
        return {}
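
    # Assuming remove_code_blocks strips a surrounding Markdown fence, a provider that
    # returns the string '```json\n{"tool_calls": []}\n```' is normalized to the dict
    # {"tool_calls": []}; anything unparseable degrades to {} instead of raising.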

    def _create_tables(self) -> None:
        """Create graph entities and relationships tables if they don't exist.

        Creates two tables:
        - graph_entities: Stores entity nodes and their vector embeddings
        - graph_relationships: Stores relationships between entities
        """

        if not self.client.check_table_exists(constants.TABLE_ENTITIES):
            # Define columns for entities table
            cols = [
                Column("id", BigInteger, primary_key=True, autoincrement=False),
                Column("name", String(255), nullable=False),
                Column("entity_type", String(64)),
                Column("embedding", VECTOR(self.embedding_dims)),
                Column("created_at", TIMESTAMP, server_default=text("CURRENT_TIMESTAMP")),
                Column("updated_at", TIMESTAMP, server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")),
            ]
            # Define regular indexes
            indexes = [
                Index("idx_name", "name"),
            ]

            # Map index_type string to VecIndexType enum
            index_type_map = constants.OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPES

            # Create vector index parameters
            vidx_params = self.client.prepare_index_params()
            vidx_params.add_index(
                field_name="embedding",
                index_type=index_type_map.get(self.index_type, VecIndexType.HNSW),
                index_name=self.vidx_name,
                metric_type=self.vidx_metric_type,
                params=self.vidx_algo_params,
            )

            # Create table with vector index
            self.client.create_table_with_index_params(
                table_name=constants.TABLE_ENTITIES,
                columns=cols,
                indexes=indexes,
                vidxs=vidx_params,
                partitions=None,
            )

            logger.info("%s table created successfully", constants.TABLE_ENTITIES)
        else:
            logger.info("%s table already exists", constants.TABLE_ENTITIES)
            # Check vector dimension consistency
            existing_dim = self._get_existing_vector_dimension_for_entities()
            if existing_dim is not None and existing_dim != self.embedding_dims:
                raise ValueError(
                    f"Vector dimension mismatch: existing table '{constants.TABLE_ENTITIES}' has "
                    f"vector dimension {existing_dim}, but requested dimension is {self.embedding_dims}. "
                    f"Please use a different configuration or reset the graph."
                )

        # Create relationships table using pyobvector API
        if not self.client.check_table_exists(constants.TABLE_RELATIONSHIPS):

            # Define columns for relationships table
            cols = [
                Column("id", BigInteger, primary_key=True, autoincrement=False),
                Column("source_entity_id", BigInteger, nullable=False),
                Column("relationship_type", String(128), nullable=False),
                Column("destination_entity_id", BigInteger, nullable=False),
                Column("user_id", String(128)),
                Column("agent_id", String(128)),
                Column("run_id", String(128)),
                Column("created_at", TIMESTAMP, server_default=text("CURRENT_TIMESTAMP")),
                Column("updated_at", TIMESTAMP, server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")),
            ]

            # Define regular indexes
            indexes = [
                Index("idx_r_covering", "user_id", "source_entity_id", "destination_entity_id", "relationship_type"),
            ]

            # Create table without vector index (relationships table has no vectors)
            self.client.create_table_with_index_params(
                table_name=constants.TABLE_RELATIONSHIPS,
                columns=cols,
                indexes=indexes,
                vidxs=None,
                partitions=None,
            )

            logger.info("%s table created successfully", constants.TABLE_RELATIONSHIPS)
        else:
            logger.info("%s table already exists", constants.TABLE_RELATIONSHIPS)

    def _get_existing_vector_dimension_for_entities(self) -> Optional[int]:
        """Get the dimension of the existing vector field in entities table.

        Returns:
            Dimension of the vector field, or None if table doesn't exist or field not found.
        """
        if not self.client.check_table_exists(constants.TABLE_ENTITIES):
            return None

        try:
            with self.engine.connect() as conn:
                result = conn.execute(text(f"DESCRIBE {constants.TABLE_ENTITIES}"))
                columns = result.fetchall()

                for col in columns:
                    if col[0] == "embedding":
                        col_type = col[1]
                        if col_type.startswith("VECTOR(") and col_type.endswith(")"):
                            dim_str = col_type[7:-1]
                            return int(dim_str)
                return None
        except Exception as e:
            logger.warning(
                "Failed to get vector dimension for %s: %s",
                constants.TABLE_ENTITIES,
                e
            )
            return None

    def _build_where_clause_with_filters(
        self,
        filters: Dict[str, Any],
        prefix: str = "",
        additional_params: Optional[Dict[str, Any]] = None
    ) -> Tuple[List[Any], Dict[str, Any]]:
        """Build where clause and parameters from filters using bindparam.

        Args:
            filters: Dictionary containing user_id, agent_id, run_id.
            prefix: Optional prefix for column names (e.g., "r." for table alias).
            additional_params: Optional additional parameters to include.

        Returns:
            Tuple of (where_clause_list, params_dict).
        """
        where_conditions = []
        params = {"user_id": filters["user_id"]}
        where_conditions.append(f"{prefix}user_id = :user_id")

        if filters.get("agent_id"):
            where_conditions.append(f"{prefix}agent_id = :agent_id")
            params["agent_id"] = filters["agent_id"]
        if filters.get("run_id"):
            where_conditions.append(f"{prefix}run_id = :run_id")
            params["run_id"] = filters["run_id"]

        # Add additional params if provided
        if additional_params:
            params.update(additional_params)

        where_sql = " AND ".join(where_conditions)
        where_clause_with_params = text(where_sql).bindparams(**params)
        return [where_clause_with_params], params

    def add(self, data: str, filters: Dict[str, Any]) -> Dict[str, Any]:
        """Add data to the graph.

        Args:
            data: The data to add to the graph.
            filters: Dictionary containing filters (user_id, agent_id, run_id).

        Returns:
            Dictionary containing deleted_entities and added_entities.
        """
        entity_type_map = self._retrieve_nodes_from_data(data, filters)
        to_be_added = self._establish_nodes_relations_from_data(data, filters, entity_type_map)
        search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
        to_be_deleted = self._get_delete_entities_from_search_output(search_output, data, filters)

        deleted_entities = self._delete_entities(to_be_deleted, filters)
        added_entities = self._add_entities(to_be_added, filters, entity_type_map)
        logger.debug("Deleted entities: %s, Added entities: %s", deleted_entities, added_entities)
        return {"deleted_entities": deleted_entities, "added_entities": added_entities}

    def search(self, query: str, filters: Dict[str, Any], limit: int = 100) -> List[Dict[str, str]]:
        """Search for memories and related graph data.

        Args:
            query: Query to search for.
            filters: Dictionary containing filters (user_id, agent_id, run_id).
            limit: Maximum number of nodes and relationships to retrieve. Defaults to 100.

        Returns:
            List of search results containing source, relationship, and destination.
        """
        entity_type_map = self._retrieve_nodes_from_data(query, filters)
        search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters, limit=limit)

        if not search_output:
            return []

        # Tokenize search outputs for BM25 with improved segmentation
        search_outputs_sequence = []
        for item in search_output:
            # Combine source, relationship, destination into a single text for better tokenization
            combined_text = f"{item['source']} {item['relationship']} {item['destination']}"
            tokenized_item = self._tokenize_text(combined_text)
            search_outputs_sequence.append(tokenized_item)

        bm25 = BM25Okapi(search_outputs_sequence)

        # Tokenize query using the same method
        tokenized_query = self._tokenize_text(query)

        # Get top N results based on BM25 scores
        scores = bm25.get_scores(tokenized_query)
        # Get indices sorted by score (descending)
        sorted_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        top_n_indices = sorted_indices[:constants.DEFAULT_BM25_TOP_N]

        # Build reranked results
        search_results = []
        for idx in top_n_indices:
            if idx < len(search_output):
                item = search_output[idx]
                search_results.append({
                    "source": item["source"],
                    "relationship": item["relationship"],
                    "destination": item["destination"]
                })

        logger.info("Returned %d search results (from %d candidates)", len(search_results), len(search_output))

        return search_results
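
    # Rerank sketch: each candidate triple is flattened to "source relationship destination",
    # tokenized, and scored against the tokenized query with BM25; only the top
    # constants.DEFAULT_BM25_TOP_N candidates are kept. For a query like "where does alice work",
    # a triple ("alice", "works_at", "acme") would typically outscore ("bob", "lives_in", "berlin").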

    def _tokenize_text(self, text: str) -> List[str]:
        """Tokenize text using jieba for Chinese or simple split for other languages.

        Args:
            text: Text to tokenize.

        Returns:
            List of tokens.
        """
        if jieba is not None:
            # Use jieba for Chinese text segmentation
            # Convert to lowercase for better matching
            tokens = list(jieba.cut(text.lower()))
            # Filter out empty strings and single spaces
            tokens = [t for t in tokens if t.strip()]
            return tokens
        else:
            # Fallback to simple space-based tokenization
            return text.lower().split()
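
    # In the fallback path, "Alice Works At Acme" becomes ["alice", "works", "at", "acme"];
    # with jieba installed, mixed Chinese/English text is segmented into words instead of
    # being split only on whitespace.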

    def delete_all(self, filters: Dict[str, Any]) -> None:
        """Delete all graph data for the given filters.

        Args:
            filters: Filters containing user_id, agent_id, run_id.
        """
        where_clause, _ = self._build_where_clause_with_filters(filters)

        try:
            relationships_results = self.client.get(
                table_name=constants.TABLE_RELATIONSHIPS,
                ids=None,
                output_column_name=["id", "source_entity_id", "destination_entity_id"],
                where_clause=where_clause
            )

            # Collect unique entity IDs from relationships
            entity_ids = set()
            for rel in relationships_results.fetchall():
                entity_ids.add(rel[1])  # source_entity_id
                entity_ids.add(rel[2])  # destination_entity_id

            # Delete the relationships
            self.client.delete(
                table_name=constants.TABLE_RELATIONSHIPS,
                where_clause=where_clause
            )
            logger.info("Deleted relationships for filters: %s", filters)

            # Delete entities that were part of these relationships
            if entity_ids:
                self.client.delete(
                    table_name=constants.TABLE_ENTITIES,
                    ids=list(entity_ids),
                )
                logger.info("Deleted %d entities for filters: %s", len(entity_ids), filters)
        except Exception as e:
            logger.warning("Error deleting graph data: %s", e)

        logger.info("Deleted all graph data for filters: %s", filters)

    def get_all(self, filters: Dict[str, Any], limit: int = 100) -> List[Dict[str, str]]:
        """Retrieve all nodes and relationships from the graph database.

        Args:
            filters: Dictionary containing filters (user_id, agent_id, run_id).
            limit: Maximum number of relationships to retrieve. Defaults to 100.

        Returns:
            List of dictionaries containing source, relationship, and target.
        """

        where_clause, params = self._build_where_clause_with_filters(filters)

        relationships_results = self.client.get(
            table_name=constants.TABLE_RELATIONSHIPS,
            ids=None,
            output_column_name=["id", "source_entity_id", "relationship_type", "destination_entity_id", "updated_at"],
            where_clause=where_clause
        )

        relationships = relationships_results.fetchall()
        if not relationships:
            return []

        # Limit results if needed
        if len(relationships) > limit:
            relationships = relationships[:limit]

        # Extract unique entity IDs from relationships
        entity_ids = set()
        for rel in relationships:
            entity_ids.add(rel[1])  # source_entity_id
            entity_ids.add(rel[3])  # destination_entity_id

        # Get all entities that are referenced in relationships
        entities_results = self.client.get(
            table_name=constants.TABLE_ENTITIES,
            ids=list(entity_ids),
            output_column_name=["id", "name"]
        )

        # Create a mapping from entity_id to entity_name
        entity_map = {entity[0]: entity[1] for entity in entities_results.fetchall()}

        # Build final results with updated_at for sorting
        final_results = []
        for rel in relationships:
            rel_id, source_id, relationship_type, dest_id, updated_at = rel

            source_name = entity_map.get(source_id, f"Unknown_{source_id}")
            dest_name = entity_map.get(dest_id, f"Unknown_{dest_id}")

            final_results.append({
                "source": source_name,
                "relationship": relationship_type,
                "target": dest_name,
                "_updated_at": updated_at,  # Keep for sorting
            })

        # Sort by updated_at (descending)
        final_results.sort(key=lambda x: x["_updated_at"], reverse=True)

        # Remove the temporary _updated_at field
        for result in final_results:
            del result["_updated_at"]

        logger.info("Retrieved %d relationships", len(final_results))
        return final_results

    def _retrieve_nodes_from_data(self, data: str, filters: Dict[str, Any]) -> Dict[str, str]:
        """Extract all the entities mentioned in the query.

        Args:
            data: Input text to extract entities from.
            filters: Dictionary containing user_id, agent_id, run_id.

        Returns:
            Dictionary mapping entity names to entity types.
        """
        _tools = [self.graph_tools_prompts.get_extract_entities_tool()]
        if constants.is_structured_llm_provider(self.llm_provider):
            _tools = [self.graph_tools_prompts.get_extract_entities_tool(structured=True)]

        search_results = self.llm.generate_response(
            messages=[
                {
                    "role": "system",
                    "content": (
                        f"You are a smart assistant who understands entities and their types in a given text. "
                        f"If user message contains self reference such as 'I', 'me', 'my' etc. "
                        f"then use {filters['user_id']} as the source entity. "
                        f"Extract all the entities from the text. "
                        f"***DO NOT*** answer the question itself if the given text is a question."
                    ),
                },
                {"role": "user", "content": data},
            ],
            tools=_tools,
        )

        # Normalize potential string response to dict
        search_results = self._coerce_tool_response_to_dict(search_results)

        entity_type_map = {}

        try:
            for tool_call in search_results.get("tool_calls", []):
                if tool_call["name"] != "extract_entities":
                    continue
                for item in tool_call["arguments"]["entities"]:
                    entity_type_map[item["entity"]] = item["entity_type"]
        except Exception as e:
            logger.exception(
                "Error in search tool: %s, llm_provider=%s, search_results=%s",
                e,
                self.llm_provider,
                search_results
            )

        entity_type_map = {
            k.lower().replace(" ", "_"): v.lower().replace(" ", "_")
            for k, v in entity_type_map.items()
        }
        logger.debug("Entity type map: %s\n search_results=%s", entity_type_map, search_results)
        return entity_type_map

    def _establish_nodes_relations_from_data(
        self,
        data: str,
        filters: Dict[str, Any],
        entity_type_map: Dict[str, str]
    ) -> List[Dict[str, str]]:
        """Establish relations among the extracted nodes.

        Args:
            data: Input text to extract relationships from.
            filters: Dictionary containing user_id, agent_id, run_id.
            entity_type_map: Mapping of entity names to types.

        Returns:
            List of dictionaries containing source, destination, and relationship.
        """
        user_identity = self._build_user_identity(filters)

        if self.config.graph_store.custom_prompt:
            system_content = self.graph_prompts.get_system_prompt("extract_relations")
            system_content = system_content.replace("USER_ID", user_identity)
            system_content = system_content.replace("CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}")
            messages = [
                {"role": "system", "content": system_content},
                {"role": "user", "content": data},
            ]
        else:
            system_content = self.graph_prompts.get_system_prompt("extract_relations")
            system_content = system_content.replace("USER_ID", user_identity)
            messages = [
                {"role": "system", "content": system_content},
                {
                    "role": "user",
                    "content": f"List of entities: {list(entity_type_map.keys())}. \n\nText: {data}"
                },
            ]

        _tools = [self.graph_tools_prompts.get_relations_tool()]
        if constants.is_structured_llm_provider(self.llm_provider):
            _tools = [self.graph_tools_prompts.get_relations_tool(structured=True)]

        extracted_entities = self.llm.generate_response(
            messages=messages,
            tools=_tools,
        )

        # Normalize to dict for consistent access
        extracted_entities = self._coerce_tool_response_to_dict(extracted_entities)

        entities = []
        if extracted_entities.get("tool_calls"):
            first_call = (
                extracted_entities["tool_calls"][0]
                if extracted_entities["tool_calls"]
                else {}
            )
            entities = first_call.get("arguments", {}).get("entities", [])

        entities = self._remove_spaces_from_entities(entities)
        logger.debug("Extracted entities: %s", entities)
        return entities

    def _search_graph_db(
        self,
        node_list: List[str],
        filters: Dict[str, Any],
        limit: int = 100
    ) -> List[Dict[str, Any]]:
        """Search similar nodes and their relationships using vector similarity with multi-hop support.

        Supports multi-hop traversal (up to max_hops, default 3) via iterative per-hop queries
        with early stopping. Results are ordered by path length (1-hop first, then 2-hop, then 3-hop).

        Args:
            node_list: List of node names to search for.
            filters: Dictionary containing user_id, agent_id, run_id.
            limit: Maximum number of results to return. Defaults to 100.

        Returns:
            List of dictionaries containing source, relationship, destination and their IDs.
        """
        result_relations = []

        for node in node_list:
            n_embedding = self.embedding_model.embed(node)

            entities = self._search_node(None, n_embedding, filters, limit=limit)

            if not entities:
                continue

            # Ensure entities is always a list
            if isinstance(entities, dict):
                entities = [entities]

            entity_ids = [e.get("id") for e in entities]

            # Use multi-hop search with early stopping
            multi_hop_results = self._multi_hop_search(entity_ids, filters, limit)
            result_relations.extend(multi_hop_results)

        return result_relations

    def _execute_single_hop_query(
        self,
        source_entity_ids: List[int],
        filters: Dict[str, Any],
        hop_number: int,
        visited_edges: set = None,
        conn=None,
        max_edges_per_hop: int = 1000
    ) -> List[Dict[str, Any]]:
        """Execute a single hop query from given source entities.

        Args:
            source_entity_ids: List of source entity IDs to start from.
            filters: Dictionary containing user_id, agent_id, run_id.
            hop_number: Current hop number (for result annotation).
            visited_edges: Set of visited edges (source_id, dest_id) to avoid cycles.
            conn: Optional database connection to use. If None, creates a new connection.
            max_edges_per_hop: Maximum number of edges to retrieve per hop. Defaults to 1000.

        Returns:
            List of relationship dictionaries with hop_count.

        Note:
            - Prevents memory explosion from high-degree nodes
            - Results are sorted by updated_at DESC, created_at DESC before limiting
            - This ensures most recent edges are retrieved first
        """
        if not source_entity_ids:
            return []

        if visited_edges is None:
            visited_edges = set()

        # Build filter conditions
        filter_parts, params = self._build_filter_conditions(filters, prefix="")
        filter_conditions = " AND ".join(filter_parts)

        # Build query with LIMIT to prevent memory explosion from high-degree nodes
        query = f"""
            SELECT
                e1.name AS source,
                r.source_entity_id,
                r.relationship_type,
                r.id AS relation_id,
                e2.name AS destination,
                r.destination_entity_id
            FROM
                (
                    SELECT
                        id,
                        source_entity_id,
                        destination_entity_id,
                        relationship_type,
                        created_at,
                        updated_at,
                        user_id
                    FROM {constants.TABLE_RELATIONSHIPS}
                    WHERE
                        source_entity_id IN :entity_ids
                        AND {filter_conditions}
                    ORDER BY updated_at DESC, created_at DESC
                    LIMIT :max_edges_per_hop
                ) AS r
            JOIN {constants.TABLE_ENTITIES} e1 ON r.source_entity_id = e1.id
            JOIN {constants.TABLE_ENTITIES} e2 ON r.destination_entity_id = e2.id;
        """

        # Add parameters
        params["entity_ids"] = tuple(source_entity_ids)
        params["max_edges_per_hop"] = max_edges_per_hop
        logger.debug("Executing hop %d with max_edges_per_hop=%d\n query: %s\n params: %s",
                     hop_number, max_edges_per_hop, query, params)

        # Execute query - use provided connection or create new one
        if conn is not None:
            # Reuse existing connection (transactional)
            result = conn.execute(text(query), params)
            rows = result.fetchall()
        else:
            # Create new connection (backward compatibility)
            with self.engine.connect() as new_conn:
                result = new_conn.execute(text(query), params)
                rows = result.fetchall()

        # Format results and filter out cycles
        formatted_results = []
        for row in rows:
            source_id = row[1]
            dest_id = row[5]
            edge_key = (source_id, dest_id)

            # Skip if this edge was already visited (cycle detection)
            if edge_key in visited_edges:
                continue

            formatted_results.append({
                "source": row[0],
                "source_id": source_id,
                "relationship": row[2],
                "relation_id": row[3],
                "destination": row[4],
                "destination_id": dest_id,
                "hop_count": hop_number,
            })

        return formatted_results

    def _multi_hop_search(
        self,
        entity_ids: List[int],
        filters: Dict[str, Any],
        limit: int
    ) -> List[Dict[str, Any]]:
        """Perform multi-hop graph search with application-level early stopping.

        Args:
            entity_ids: List of seed entity IDs to start traversal from.
            filters: Dictionary containing user_id, agent_id, run_id.
            limit: Maximum number of results to return.

        Returns:
            List of dictionaries containing source, relationship, destination and their IDs.

        Note:
            Application-level optimization strategy:
            1. Execute 1-hop query first
            2. Check if results satisfy limit - if yes, return immediately
            3. If not, execute 2-hop query with cycle prevention
            4. Accumulate results and check limit again
            5. Continue until limit is satisfied or max_hops is reached
        """
        if not entity_ids:
            return []

        # Use a transaction to ensure consistent reads across all hops
        # This prevents phantom reads and non-repeatable reads during multi-hop traversal
        with self.engine.begin() as conn:
            logger.debug("Started transaction for multi-hop search")

            all_results = []
            visited_edges = set()  # Track visited edges to prevent cycles
            visited_nodes = set(entity_ids)  # Track all visited nodes (start with seed entities)
            current_source_ids = entity_ids  # Start from seed entities

            # Iteratively execute each hop until limit is satisfied or max_hops reached
            for hop in range(1, self.max_hops + 1):
                # Execute single hop query within the same transaction
                hop_results = self._execute_single_hop_query(
                    source_entity_ids=current_source_ids,
                    filters=filters,
                    hop_number=hop,
                    visited_edges=visited_edges,
                    conn=conn
                )

                # If no results at this hop, stop early
                if not hop_results:
                    logger.info("STOP early: No results at hop %s", hop)
                    break

                # Add results to accumulator
                all_results.extend(hop_results)

                # Update visited edges to prevent cycles in next hop
                for result in hop_results:
                    edge_key = (result["source_id"], result["destination_id"])
                    visited_edges.add(edge_key)

                # Check if we've satisfied the limit (early stopping)
                if len(all_results) >= limit:
                    logger.info("STOP early: Limit satisfied at hop %s", hop)
                    # Truncate to exact limit and return
                    return all_results[:limit]

                # Prepare source IDs for next hop (destination entities become new sources)
                next_source_ids = set([r["destination_id"] for r in hop_results])

                # Check if we have any new nodes that haven't been visited
                new_nodes = next_source_ids - visited_nodes

                # If no new nodes, all destinations are already visited - stop early
                # This means we've exhausted all reachable nodes in the graph
                if not new_nodes:
                    logger.info("STOP early: All destinations are already visited at hop %s", hop)
                    break

                # Update visited nodes and prepare for next hop
                visited_nodes.update(next_source_ids)
                current_source_ids = list(next_source_ids)

            # Return all accumulated results
            logger.debug("Transaction completed for multi-hop search, returning %d results", len(all_results))
            return all_results
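
    # Traversal sketch: with seed {A}, edges A->B, B->C, C->A, and a generous limit,
    # hop 1 yields (A, B), hop 2 yields (B, C), and hop 3 yields (C, A); at that point every
    # destination is already in visited_nodes (and max_hops is reached), so the loop ends
    # with three relations accumulated.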
|
|
979
|
+
|
|
980
|
+
def _get_delete_entities_from_search_output(
|
|
981
|
+
self,
|
|
982
|
+
search_output: List[Dict[str, Any]],
|
|
983
|
+
data: str,
|
|
984
|
+
filters: Dict[str, Any]
|
|
985
|
+
) -> List[Dict[str, str]]:
|
|
986
|
+
"""Get the entities to be deleted from the search output.
|
|
987
|
+
|
|
988
|
+
Args:
|
|
989
|
+
search_output: Search results from graph database.
|
|
990
|
+
data: New input data to compare against.
|
|
991
|
+
filters: Dictionary containing user_id, agent_id, run_id.
|
|
992
|
+
|
|
993
|
+
Returns:
|
|
994
|
+
List of dictionaries containing source, destination, and relationship to delete.
|
|
995
|
+
"""
|
|
996
|
+
search_output_string = format_entities(search_output)
|
|
997
|
+
user_identity = self._build_user_identity(filters)
|
|
998
|
+
system_prompt, user_prompt = self.graph_prompts.get_delete_relations_prompt(search_output_string, data, user_identity)
|
|
999
|
+
|
|
1000
|
+
_tools = [self.graph_tools_prompts.get_delete_tool()]
|
|
1001
|
+
if constants.is_structured_llm_provider(self.llm_provider):
|
|
1002
|
+
_tools = [self.graph_tools_prompts.get_delete_tool(structured=True)]
|
|
1003
|
+
|
|
1004
|
+
memory_updates = self.llm.generate_response(
|
|
1005
|
+
messages=[
|
|
1006
|
+
{"role": "system", "content": system_prompt},
|
|
1007
|
+
{"role": "user", "content": user_prompt},
|
|
1008
|
+
],
|
|
1009
|
+
tools=_tools,
|
|
1010
|
+
)
|
|
1011
|
+
|
|
1012
|
+
# Normalize to dict before access
|
|
1013
|
+
memory_updates = self._coerce_tool_response_to_dict(memory_updates)
|
|
1014
|
+
|
|
1015
|
+
to_be_deleted = []
|
|
1016
|
+
for item in memory_updates.get("tool_calls", []):
|
|
1017
|
+
if item.get("name") == "delete_graph_memory":
|
|
1018
|
+
to_be_deleted.append(item.get("arguments"))
|
|
1019
|
+
|
|
1020
|
+
to_be_deleted = self._remove_spaces_from_entities(to_be_deleted)
|
|
1021
|
+
logger.debug("Deleted relationships: %s", to_be_deleted)
|
|
1022
|
+
return to_be_deleted
|
|
1023
|
+
|
|
1024
|
+
def _delete_entities(
|
|
1025
|
+
self,
|
|
1026
|
+
to_be_deleted: List[Dict[str, str]],
|
|
1027
|
+
filters: Dict[str, Any]
|
|
1028
|
+
) -> List[Dict[str, int]]:
|
|
1029
|
+
"""Delete the specified relationships from the graph.
|
|
1030
|
+
|
|
1031
|
+
Args:
|
|
1032
|
+
to_be_deleted: List of relationships to delete with source, destination, relationship.
|
|
1033
|
+
filters: Dictionary containing user_id, agent_id, run_id.
|
|
1034
|
+
|
|
1035
|
+
Returns:
|
|
1036
|
+
List of dictionaries containing deleted_count for each deletion operation.
|
|
1037
|
+
"""
|
|
1038
|
+
results = []
|
|
1039
|
+
|
|
1040
|
+
for item in to_be_deleted:
|
|
1041
|
+
source = item["source"]
|
|
1042
|
+
destination = item["destination"]
|
|
1043
|
+
relationship = item["relationship"]
|
|
1044
|
+
|
|
1045
|
+
# First, find the source and destination entities by name
|
|
1046
|
+
source_entities = self.client.get(
|
|
1047
|
+
table_name=constants.TABLE_ENTITIES,
|
|
1048
|
+
ids=None,
|
|
1049
|
+
output_column_name=["id", "name"],
|
|
1050
|
+
where_clause=[text(f"name = :source_name").bindparams(
|
|
1051
|
+
bindparam("source_name", source)
|
|
1052
|
+
)]
|
|
1053
|
+
)
|
|
1054
|
+
|
|
1055
|
+
dest_entities = self.client.get(
|
|
1056
|
+
table_name=constants.TABLE_ENTITIES,
|
|
1057
|
+
ids=None,
|
|
1058
|
+
output_column_name=["id", "name"],
|
|
1059
|
+
where_clause=[text(f"name = :dest_name").bindparams(
|
|
1060
|
+
bindparam("dest_name", destination)
|
|
1061
|
+
)]
|
|
1062
|
+
)
|
|
1063
|
+
|
|
1064
|
+
# Get entity IDs
|
|
1065
|
+
source_rows = source_entities.fetchall() if source_entities else []
|
|
1066
|
+
dest_rows = dest_entities.fetchall() if dest_entities else []
|
|
1067
|
+
|
|
1068
|
+
source_ids = [e[0] for e in source_rows]
|
|
1069
|
+
dest_ids = [e[0] for e in dest_rows]
|
|
1070
|
+
|
|
1071
|
+
# Check if we found any entities
|
|
1072
|
+
if not source_ids or not dest_ids:
|
|
1073
|
+
logger.warning(
|
|
1074
|
+
"Could not find entities: source='%s' (found %d), destination='%s' (found %d)",
|
|
1075
|
+
source,
|
|
1076
|
+
len(source_ids),
|
|
1077
|
+
destination,
|
|
1078
|
+
len(dest_ids)
|
|
1079
|
+
)
|
|
1080
|
+
results.append({"deleted_count": 0})
|
|
1081
|
+
continue
|
|
1082
|
+
|
|
1083
|
+
# Build where clause for relationship deletion
|
|
1084
|
+
where_clauses = [
|
|
1085
|
+
"relationship_type = :rel_type",
|
|
1086
|
+
"user_id = :user_id",
|
|
1087
|
+
]
|
|
1088
|
+
params = {
|
|
1089
|
+
"rel_type": relationship,
|
|
1090
|
+
"user_id": filters["user_id"],
|
|
1091
|
+
}
|
|
1092
|
+
|
|
1093
|
+
if filters.get("agent_id"):
|
|
1094
|
+
where_clauses.append("agent_id = :agent_id")
|
|
1095
|
+
params["agent_id"] = filters["agent_id"]
|
|
1096
|
+
if filters.get("run_id"):
|
|
1097
|
+
where_clauses.append("run_id = :run_id")
|
|
1098
|
+
params["run_id"] = filters["run_id"]
|
|
1099
|
+
|
|
1100
|
+
# Add source and destination entity ID conditions.
|
|
1101
|
+
source_conditions = " OR ".join([
|
|
1102
|
+
f"source_entity_id = :src_id_{i}"
|
|
1103
|
+
for i in range(len(source_ids))
|
|
1104
|
+
])
|
|
1105
|
+
dest_conditions = " OR ".join([
|
|
1106
|
+
f"destination_entity_id = :dest_id_{i}"
|
|
1107
|
+
for i in range(len(dest_ids))
|
|
1108
|
+
])
|
|
1109
|
+
|
|
1110
|
+
where_clauses.append(f"({source_conditions}) AND ({dest_conditions})")
|
|
1111
|
+
|
|
1112
|
+
# Add entity ID parameters
|
|
1113
|
+
for i, src_id in enumerate(source_ids):
|
|
1114
|
+
params[f"src_id_{i}"] = src_id
|
|
1115
|
+
for i, dest_id in enumerate(dest_ids):
|
|
1116
|
+
params[f"dest_id_{i}"] = dest_id
|
|
1117
|
+
|
|
1118
|
+
where_str = " AND ".join(where_clauses)
|
|
1119
|
+
where_clause = text(where_str).bindparams(**params)
|
|
1120
|
+
|
|
1121
|
+
# Delete relationships using pyobvector delete method.
|
|
1122
|
+
try:
|
|
1123
|
+
delete_result = self.client.delete(
|
|
1124
|
+
table_name=constants.TABLE_RELATIONSHIPS,
|
|
1125
|
+
where_clause=[where_clause]
|
|
1126
|
+
)
|
|
1127
|
+
deleted_count = (
|
|
1128
|
+
delete_result.rowcount
|
|
1129
|
+
if hasattr(delete_result, "rowcount")
|
|
1130
|
+
else 1
|
|
1131
|
+
)
|
|
1132
|
+
results.append({"deleted_count": deleted_count})
|
|
1133
|
+
except Exception as e:
|
|
1134
|
+
logger.warning("Error deleting relationship: %s", e)
|
|
1135
|
+
results.append({"deleted_count": 0})
|
|
1136
|
+
|
|
1137
|
+
return results
|
|
1138
|
+
|
|
1139
|
+
def _add_entities(
|
|
1140
|
+
self,
|
|
1141
|
+
to_be_added: List[Dict[str, str]],
|
|
1142
|
+
filters: Dict[str, Any],
|
|
1143
|
+
entity_type_map: Dict[str, str]
|
|
1144
|
+
) -> List[Dict[str, str]]:
|
|
1145
|
+
"""Add new entities and relationships to the graph.
|
|
1146
|
+
|
|
1147
|
+
Args:
|
|
1148
|
+
to_be_added: List of relationships to add with source, destination, relationship.
|
|
1149
|
+
filters: Dictionary containing user_id, agent_id, run_id.
|
|
1150
|
+
entity_type_map: Mapping of entity names to types.
|
|
1151
|
+
|
|
1152
|
+
Returns:
|
|
1153
|
+
List of dictionaries containing added source, relationship, and target.
|
|
1154
|
+
"""
|
|
1155
|
+
results = []
|
|
1156
|
+
|
|
1157
|
+
for item in to_be_added:
|
|
1158
|
+
source = item["source"]
|
|
1159
|
+
destination = item["destination"]
|
|
1160
|
+
relationship = item["relationship"]
|
|
1161
|
+
|
|
1162
|
+
source_embedding = self.embedding_model.embed(source)
|
|
1163
|
+
dest_embedding = self.embedding_model.embed(destination)
|
|
1164
|
+
|
|
1165
|
+
# Search for existing similar nodes.
|
|
1166
|
+
source_node = self._search_source_node(source_embedding, filters,
|
|
1167
|
+
threshold=constants.DEFAULT_SIMILARITY_THRESHOLD, limit=1)
|
|
1168
|
+
dest_node = self._search_destination_node(dest_embedding, filters,
|
|
1169
|
+
threshold=constants.DEFAULT_SIMILARITY_THRESHOLD, limit=1)
|
|
1170
|
+
|
|
1171
|
+
# Get or create source entity
|
|
1172
|
+
if source_node:
|
|
1173
|
+
source_id = source_node["id"]
|
|
1174
|
+
else:
|
|
1175
|
+
source_id = self._create_entity(source, entity_type_map.get(source, "entity"),
|
|
1176
|
+
source_embedding, filters)
|
|
1177
|
+
|
|
1178
|
+
# Get or create destination entity
|
|
1179
|
+
if dest_node:
|
|
1180
|
+
dest_id = dest_node["id"]
|
|
1181
|
+
else:
|
|
1182
|
+
dest_id = self._create_entity(destination, entity_type_map.get(destination, "entity"),
|
|
1183
|
+
dest_embedding, filters)
|
|
1184
|
+
|
|
1185
|
+
# Create or update relationship
|
|
1186
|
+
rel_result = self._create_or_update_relationship(source_id, dest_id, relationship, filters)
|
|
1187
|
+
results.append(rel_result)
|
|
1188
|
+
|
|
1189
|
+
return results
|
|
1190
|
+
|
|
1191
|
+
    def _search_node(
        self,
        name: Optional[str],
        embedding: List[float],
        filters: Dict[str, Any],
        threshold: Optional[float] = None,
        limit: int = 10
    ) -> Union[Dict[str, Any], List[Dict[str, Any]], None]:
        """Search for a node by embedding similarity within a distance threshold.

        Args:
            name: Node name (not currently used, kept for compatibility).
            embedding: Vector embedding to search with.
            filters: Dictionary containing user_id, agent_id, run_id.
            threshold: Distance threshold for filtering results.
                Defaults to constants.DEFAULT_SIMILARITY_THRESHOLD if None.
            limit: Maximum number of results to return. Defaults to 10.

        Returns:
            If limit == 1: a single dict with id, name, distance, or None if no match.
            If limit > 1: a list of dicts with id, name, distance, or None if no match.
        """
        if threshold is None:
            threshold = constants.DEFAULT_SIMILARITY_THRESHOLD

        # Create a Table object to access columns for the WHERE clause.
        table = Table(constants.TABLE_ENTITIES, self.metadata, autoload_with=self.engine)
        vec_str = "[" + ",".join([str(np.float32(v)) for v in embedding]) + "]"
        distance_expr = l2_distance(table.c.embedding, vec_str)
        where_clause = [distance_expr < threshold]

        results = self.client.ann_search(
            table_name=constants.TABLE_ENTITIES,
            vec_data=embedding,
            vec_column_name="embedding",
            distance_func=l2_distance,
            with_dist=True,
            topk=limit,
            output_column_names=["id", "name"],
            where_clause=where_clause,
        )

        rows = results.fetchall()
        if rows:
            if limit == 1:
                row = rows[0]
                # Row format: (id, name, distance)
                entity_id, entity_name = row[0], row[1]
                distance = row[-1]  # Distance is always the last column.

                return {"id": entity_id, "name": entity_name, "distance": distance}
            else:
                return [{"id": row[0], "name": row[1], "distance": row[-1]} for row in rows]

        return None

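A brief sketch of the return shapes of `_search_node`, which differ by `limit` (illustrative only; `graph` is a hypothetical instance of this store and the threshold value is arbitrary):

# Illustrative only -- not part of the package diff.
filters = {"user_id": "u-123"}
embedding = graph.embedding_model.embed("alice")
best = graph._search_node(None, embedding, filters, threshold=0.7, limit=1)
# e.g. {"id": 123456789, "name": "alice", "distance": 0.12}, or None if nothing is within the threshold
candidates = graph._search_node(None, embedding, filters, limit=5)
# e.g. [{"id": ..., "name": ..., "distance": ...}, ...], or None
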
    def _create_entity(
        self,
        name: str,
        entity_type: str,
        embedding: List[float],
        filters: Dict[str, Any]
    ) -> int:
        """Create a new entity in the graph.

        Args:
            name: Entity name.
            entity_type: Type of the entity.
            embedding: Vector embedding of the entity.
            filters: Dictionary containing user_id, agent_id, run_id.

        Returns:
            Snowflake ID of the created entity.
        """
        entity_id = generate_snowflake_id()

        # Prepare data for insertion using the pyobvector API.
        record = {
            "id": entity_id,
            "name": name,
            "entity_type": entity_type,
            "embedding": embedding,
        }

        # Use the pyobvector upsert method.
        self.client.upsert(
            table_name=constants.TABLE_ENTITIES,
            data=[record],
        )

        logger.debug("Created entity: %s with id: %s", name, entity_id)
        return entity_id

    def _create_or_update_relationship(
        self,
        source_id: int,
        dest_id: int,
        relationship_type: str,
        filters: Dict[str, Any]
    ) -> Dict[str, str]:
        """Create a relationship between two entities if it does not already exist.

        Args:
            source_id: Snowflake ID of the source entity.
            dest_id: Snowflake ID of the destination entity.
            relationship_type: Type of the relationship.
            filters: Dictionary containing user_id, agent_id, run_id.

        Returns:
            Dictionary containing source, relationship, and target names.
        """
        # First, check whether the relationship already exists.
        where_clause, params = self._build_where_clause_with_filters(filters)

        # Add relationship-specific conditions to the where clause.
        additional_conditions = (
            " AND source_entity_id = :source_id"
            " AND destination_entity_id = :dest_id"
            " AND relationship_type = :rel_type"
        )

        # Rebuild the where clause with the additional conditions.
        where_str = str(where_clause[0].text) + additional_conditions

        params.update({
            "source_id": source_id,
            "dest_id": dest_id,
            "rel_type": relationship_type
        })

        where_clause_with_params = text(where_str).bindparams(**params)

        # Check if the relationship exists.
        existing_relationships = self.client.get(
            table_name=constants.TABLE_RELATIONSHIPS,
            ids=None,
            output_column_name=["id"],
            where_clause=[where_clause_with_params]
        )

        existing_rows = existing_relationships.fetchall()
        if not existing_rows:
            # Relationship doesn't exist, create a new one.
            new_record = {
                "id": generate_snowflake_id(),
                "source_entity_id": source_id,
                "relationship_type": relationship_type,
                "destination_entity_id": dest_id,
                "user_id": filters["user_id"],
                "agent_id": filters.get("agent_id"),
                "run_id": filters.get("run_id"),
            }

            self.client.insert(
                table_name=constants.TABLE_RELATIONSHIPS,
                data=[new_record],
            )

        # Fetch the entity names for the return value using the pyobvector get method.
        source_entity = self.client.get(
            table_name=constants.TABLE_ENTITIES,
            ids=[source_id],
            output_column_name=["id", "name"]
        ).fetchone()

        dest_entity = self.client.get(
            table_name=constants.TABLE_ENTITIES,
            ids=[dest_id],
            output_column_name=["id", "name"]
        ).fetchone()

        return {
            "source": source_entity[1] if source_entity else None,
            "relationship": relationship_type,
            "target": dest_entity[1] if dest_entity else None,
        }

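The filter composition above builds a raw SQL string and binds parameters with SQLAlchemy's `text(...).bindparams(...)`. A self-contained sketch of the same pattern (illustrative only; the base clause shape is an assumption about what `_build_where_clause_with_filters` produces):

from sqlalchemy import text

# Assumed shape of the base clause returned by _build_where_clause_with_filters.
where_str = "user_id = :user_id"
where_str += (
    " AND source_entity_id = :source_id"
    " AND destination_entity_id = :dest_id"
    " AND relationship_type = :rel_type"
)
clause = text(where_str).bindparams(
    user_id="u-123", source_id=1, dest_id=2, rel_type="works_at"
)
# The bound clause can then be passed as where_clause=[clause] to client.get(...).
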
    def _remove_spaces_from_entities(self, entity_list: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """Normalize entity names by lowercasing them and replacing spaces with underscores.

        Args:
            entity_list: List of dictionaries containing source, destination, relationship.

        Returns:
            The entity list with names lowercased and spaces replaced by underscores.
        """
        for item in entity_list:
            item["source"] = item["source"].lower().replace(" ", "_")
            item["relationship"] = item["relationship"].lower().replace(" ", "_")
            item["destination"] = item["destination"].lower().replace(" ", "_")
        return entity_list

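A short example of the normalization performed here (illustrative only; `graph` is a hypothetical instance of this store):

# Illustrative only.
entities = [{"source": "Alice Smith", "relationship": "Works At", "destination": "Acme Corp"}]
graph._remove_spaces_from_entities(entities)
# -> [{"source": "alice_smith", "relationship": "works_at", "destination": "acme_corp"}]
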
    def _search_source_node(
        self,
        source_embedding: List[float],
        filters: Dict[str, Any],
        threshold: Optional[float] = None,
        limit: int = 10
    ) -> Union[Dict[str, Any], List[Dict[str, Any]], None]:
        """Search for a source node by embedding similarity (compatibility method).

        Args:
            source_embedding: Vector embedding to search with.
            filters: Dictionary containing user_id, agent_id, run_id.
            threshold: Distance threshold for filtering results.
                Defaults to constants.DEFAULT_SIMILARITY_THRESHOLD if None.
            limit: Maximum number of results to return. Defaults to 10.

        Returns:
            Search results from the _search_node method.
        """
        return self._search_node("source", source_embedding, filters, threshold, limit)

    def _search_destination_node(
        self,
        destination_embedding: List[float],
        filters: Dict[str, Any],
        threshold: Optional[float] = None,
        limit: int = 10
    ) -> Union[Dict[str, Any], List[Dict[str, Any]], None]:
        """Search for a destination node by embedding similarity (compatibility method).

        Args:
            destination_embedding: Vector embedding to search with.
            filters: Dictionary containing user_id, agent_id, run_id.
            threshold: Distance threshold for filtering results.
                Defaults to constants.DEFAULT_SIMILARITY_THRESHOLD if None.
            limit: Maximum number of results to return. Defaults to 10.

        Returns:
            Search results from the _search_node method.
        """
        return self._search_node("destination", destination_embedding, filters, threshold, limit)

    def reset(self) -> None:
        """Reset the graph by clearing all nodes and relationships.

        This method drops both the entities and relationships tables and recreates them.
        """
        logger.warning("Clearing graph...")

        # Use the pyobvector API to drop the tables.
        if self.client.check_table_exists(constants.TABLE_RELATIONSHIPS):
            self.client.drop_table_if_exist(constants.TABLE_RELATIONSHIPS)
            logger.info("Dropped %s table", constants.TABLE_RELATIONSHIPS)

        if self.client.check_table_exists(constants.TABLE_ENTITIES):
            self.client.drop_table_if_exist(constants.TABLE_ENTITIES)
            logger.info("Dropped %s table", constants.TABLE_ENTITIES)

        # Recreate tables
        self._create_tables()

        logger.info("Graph reset completed")