powermem 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. powermem/__init__.py +103 -0
  2. powermem/agent/__init__.py +35 -0
  3. powermem/agent/abstract/__init__.py +22 -0
  4. powermem/agent/abstract/collaboration.py +259 -0
  5. powermem/agent/abstract/context.py +187 -0
  6. powermem/agent/abstract/manager.py +232 -0
  7. powermem/agent/abstract/permission.py +217 -0
  8. powermem/agent/abstract/privacy.py +267 -0
  9. powermem/agent/abstract/scope.py +199 -0
  10. powermem/agent/agent.py +791 -0
  11. powermem/agent/components/__init__.py +18 -0
  12. powermem/agent/components/collaboration_coordinator.py +645 -0
  13. powermem/agent/components/permission_controller.py +586 -0
  14. powermem/agent/components/privacy_protector.py +767 -0
  15. powermem/agent/components/scope_controller.py +685 -0
  16. powermem/agent/factories/__init__.py +16 -0
  17. powermem/agent/factories/agent_factory.py +266 -0
  18. powermem/agent/factories/config_factory.py +308 -0
  19. powermem/agent/factories/memory_factory.py +229 -0
  20. powermem/agent/implementations/__init__.py +16 -0
  21. powermem/agent/implementations/hybrid.py +728 -0
  22. powermem/agent/implementations/multi_agent.py +1040 -0
  23. powermem/agent/implementations/multi_user.py +1020 -0
  24. powermem/agent/types.py +53 -0
  25. powermem/agent/wrappers/__init__.py +14 -0
  26. powermem/agent/wrappers/agent_memory_wrapper.py +427 -0
  27. powermem/agent/wrappers/compatibility_wrapper.py +520 -0
  28. powermem/config_loader.py +318 -0
  29. powermem/configs.py +249 -0
  30. powermem/core/__init__.py +19 -0
  31. powermem/core/async_memory.py +1493 -0
  32. powermem/core/audit.py +258 -0
  33. powermem/core/base.py +165 -0
  34. powermem/core/memory.py +1567 -0
  35. powermem/core/setup.py +162 -0
  36. powermem/core/telemetry.py +215 -0
  37. powermem/integrations/__init__.py +17 -0
  38. powermem/integrations/embeddings/__init__.py +13 -0
  39. powermem/integrations/embeddings/aws_bedrock.py +100 -0
  40. powermem/integrations/embeddings/azure_openai.py +55 -0
  41. powermem/integrations/embeddings/base.py +31 -0
  42. powermem/integrations/embeddings/config/base.py +132 -0
  43. powermem/integrations/embeddings/configs.py +31 -0
  44. powermem/integrations/embeddings/factory.py +48 -0
  45. powermem/integrations/embeddings/gemini.py +39 -0
  46. powermem/integrations/embeddings/huggingface.py +41 -0
  47. powermem/integrations/embeddings/langchain.py +35 -0
  48. powermem/integrations/embeddings/lmstudio.py +29 -0
  49. powermem/integrations/embeddings/mock.py +11 -0
  50. powermem/integrations/embeddings/ollama.py +53 -0
  51. powermem/integrations/embeddings/openai.py +49 -0
  52. powermem/integrations/embeddings/qwen.py +102 -0
  53. powermem/integrations/embeddings/together.py +31 -0
  54. powermem/integrations/embeddings/vertexai.py +54 -0
  55. powermem/integrations/llm/__init__.py +18 -0
  56. powermem/integrations/llm/anthropic.py +87 -0
  57. powermem/integrations/llm/base.py +132 -0
  58. powermem/integrations/llm/config/anthropic.py +56 -0
  59. powermem/integrations/llm/config/azure.py +56 -0
  60. powermem/integrations/llm/config/base.py +62 -0
  61. powermem/integrations/llm/config/deepseek.py +56 -0
  62. powermem/integrations/llm/config/ollama.py +56 -0
  63. powermem/integrations/llm/config/openai.py +79 -0
  64. powermem/integrations/llm/config/qwen.py +68 -0
  65. powermem/integrations/llm/config/qwen_asr.py +46 -0
  66. powermem/integrations/llm/config/vllm.py +56 -0
  67. powermem/integrations/llm/configs.py +26 -0
  68. powermem/integrations/llm/deepseek.py +106 -0
  69. powermem/integrations/llm/factory.py +118 -0
  70. powermem/integrations/llm/gemini.py +201 -0
  71. powermem/integrations/llm/langchain.py +65 -0
  72. powermem/integrations/llm/ollama.py +106 -0
  73. powermem/integrations/llm/openai.py +166 -0
  74. powermem/integrations/llm/openai_structured.py +80 -0
  75. powermem/integrations/llm/qwen.py +207 -0
  76. powermem/integrations/llm/qwen_asr.py +171 -0
  77. powermem/integrations/llm/vllm.py +106 -0
  78. powermem/integrations/rerank/__init__.py +20 -0
  79. powermem/integrations/rerank/base.py +43 -0
  80. powermem/integrations/rerank/config/__init__.py +7 -0
  81. powermem/integrations/rerank/config/base.py +27 -0
  82. powermem/integrations/rerank/configs.py +23 -0
  83. powermem/integrations/rerank/factory.py +68 -0
  84. powermem/integrations/rerank/qwen.py +159 -0
  85. powermem/intelligence/__init__.py +17 -0
  86. powermem/intelligence/ebbinghaus_algorithm.py +354 -0
  87. powermem/intelligence/importance_evaluator.py +361 -0
  88. powermem/intelligence/intelligent_memory_manager.py +284 -0
  89. powermem/intelligence/manager.py +148 -0
  90. powermem/intelligence/plugin.py +229 -0
  91. powermem/prompts/__init__.py +29 -0
  92. powermem/prompts/graph/graph_prompts.py +217 -0
  93. powermem/prompts/graph/graph_tools_prompts.py +469 -0
  94. powermem/prompts/importance_evaluation.py +246 -0
  95. powermem/prompts/intelligent_memory_prompts.py +163 -0
  96. powermem/prompts/templates.py +193 -0
  97. powermem/storage/__init__.py +14 -0
  98. powermem/storage/adapter.py +896 -0
  99. powermem/storage/base.py +109 -0
  100. powermem/storage/config/base.py +13 -0
  101. powermem/storage/config/oceanbase.py +58 -0
  102. powermem/storage/config/pgvector.py +52 -0
  103. powermem/storage/config/sqlite.py +27 -0
  104. powermem/storage/configs.py +159 -0
  105. powermem/storage/factory.py +59 -0
  106. powermem/storage/migration_manager.py +438 -0
  107. powermem/storage/oceanbase/__init__.py +8 -0
  108. powermem/storage/oceanbase/constants.py +162 -0
  109. powermem/storage/oceanbase/oceanbase.py +1384 -0
  110. powermem/storage/oceanbase/oceanbase_graph.py +1441 -0
  111. powermem/storage/pgvector/__init__.py +7 -0
  112. powermem/storage/pgvector/pgvector.py +420 -0
  113. powermem/storage/sqlite/__init__.py +0 -0
  114. powermem/storage/sqlite/sqlite.py +218 -0
  115. powermem/storage/sqlite/sqlite_vector_store.py +311 -0
  116. powermem/utils/__init__.py +35 -0
  117. powermem/utils/utils.py +605 -0
  118. powermem/version.py +23 -0
  119. powermem-0.1.0.dist-info/METADATA +187 -0
  120. powermem-0.1.0.dist-info/RECORD +123 -0
  121. powermem-0.1.0.dist-info/WHEEL +5 -0
  122. powermem-0.1.0.dist-info/licenses/LICENSE +206 -0
  123. powermem-0.1.0.dist-info/top_level.txt +1 -0
powermem/storage/oceanbase/oceanbase_graph.py
@@ -0,0 +1,1441 @@
+ """
+ OceanBase graph storage implementation
+
+ This module provides OceanBase-based graph storage for memory data.
+ """
+ import json
+ import logging
+ import warnings
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import numpy as np
+ from sqlalchemy import bindparam, text, MetaData, Column, String, Index, Table, BigInteger
+ from sqlalchemy.dialects.mysql import TIMESTAMP
+ from sqlalchemy.exc import SAWarning
+
+ # Suppress SQLAlchemy warnings about unknown schema content from pyobvector
+ # These warnings occur because SQLAlchemy doesn't recognize OceanBase VECTOR index syntax
+ # This is harmless as pyobvector handles VECTOR types correctly
+ warnings.filterwarnings(
+     "ignore",
+     message="Unknown schema content",
+     category=SAWarning,
+     module="pyobvector.*"
+ )
+
+ # Suppress pkg_resources deprecation warning from jieba internal usage
+ # This warning is from jieba's internal use of deprecated pkg_resources API
+ warnings.filterwarnings(
+     "ignore",
+     message="pkg_resources is deprecated as an API",
+     category=UserWarning,
+     module="jieba.*"
+ )
+
+ from pyobvector import ObVecClient, l2_distance, VECTOR, VecIndexType
+
+ from powermem.integrations import EmbedderFactory, LLMFactory
+ from powermem.storage.base import GraphStoreBase
+ from powermem.utils.utils import format_entities, remove_code_blocks, generate_snowflake_id
+
+ try:
+     from rank_bm25 import BM25Okapi
+ except ImportError:
+     raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25")
+
+ from powermem.prompts import GraphPrompts, GraphToolsPrompts
+
+ from powermem.storage.oceanbase import constants
+
+ logger = logging.getLogger(__name__)
+
+ # Try to import jieba for better Chinese text segmentation
+ try:
+     import jieba
+ except ImportError:
+     logger.warning("jieba is not installed. Falling back to simple space-based tokenization. "
+                    "Install jieba for better Chinese text segmentation: pip install jieba")
+     jieba = None
+
+
+ class MemoryGraph(GraphStoreBase):
+     """OceanBase-based graph memory storage implementation."""
+
+     def __init__(self, config: Any) -> None:
+         """Initialize OceanBase graph memory.
+
+         Args:
+             config: Memory configuration containing graph_store, embedder, and llm configs.
+
+         Raises:
+             ValueError: If embedding_model_dims is not configured.
+         """
+         self.config = config
+
+         # Get OceanBase config
+         ob_config = self.config.graph_store.config
+
+         # Helper function to get config value (supports both dict and object)
+         def get_config_value(key: str, default: Any = None) -> Any:
+             if isinstance(ob_config, dict):
+                 return ob_config.get(key, default)
+             else:
+                 return getattr(ob_config, key, default)
+
+         # Get embedding_model_dims (required)
+         embedding_model_dims = get_config_value("embedding_model_dims")
+         if embedding_model_dims is None:
+             raise ValueError(
+                 "embedding_model_dims is required for OceanBase graph operations. "
+                 "Please configure embedding_model_dims in your OceanBaseGraphConfig."
+             )
+         self.embedding_dims = embedding_model_dims
+
+         # Get vidx parameters with defaults.
+         self.index_type = get_config_value("index_type", constants.DEFAULT_INDEX_TYPE)
+         self.vidx_metric_type = get_config_value("vidx_metric_type",
+                                                  constants.DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE)
+         self.vidx_name = get_config_value("vidx_name", constants.DEFAULT_VIDX_NAME)
+
+         # Get graph search parameters
+         self.max_hops = get_config_value("max_hops", 3)
+
+         # Set vidx_algo_params with defaults based on index_type.
+         self.vidx_algo_params = get_config_value("vidx_algo_params", None)
+         if not self.vidx_algo_params:
+             # Set default parameters based on index type.
+             self.vidx_algo_params = constants.get_default_build_params(self.index_type)
+
+         # Initialize embedding model
+         self.embedding_model = EmbedderFactory.create(
+             self.config.embedder.provider,
+             self.config.embedder.config,
+             self.config.vector_store.config,
+         )
+
+         # Initialize OceanBase client
+         host = get_config_value("host", "localhost")
+         port = get_config_value("port", "2881")
+         user = get_config_value("user", "root")
+         password = get_config_value("password", "")
+         db_name = get_config_value("db_name", "test")
+
+         self.client = ObVecClient(
+             uri=f"{host}:{port}",
+             user=user,
+             password=password,
+             db_name=db_name,
+         )
+         self.engine = self.client.engine
+         self.metadata = MetaData()
+
+         # Create tables
+         self._create_tables()
+
+         # Initialize LLM.
+         self.llm_provider = self._get_llm_provider()
+         llm_config = self._get_llm_config()
+         self.llm = LLMFactory.create(self.llm_provider, llm_config)
+
+         # Initialize graph prompts and tools with config
+         # Pass graph_store config or full config to prompts
+         graph_config = {}
+         if self.config.graph_store:
+             # Convert GraphStoreConfig to dict if needed
+             if hasattr(self.config.graph_store, 'model_dump'):
+                 graph_config = self.config.graph_store.model_dump()
+             elif isinstance(self.config.graph_store, dict):
+                 graph_config = self.config.graph_store
+         # Also include full config for fallback
+         prompts_config = {"graph_store": graph_config}
+         # Merge top-level config if it's a dict
+         if isinstance(self.config, dict):
+             prompts_config.update(self.config)
+
+         self.graph_prompts = GraphPrompts(prompts_config)
+         self.graph_tools_prompts = GraphToolsPrompts(prompts_config)
+
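For orientation, the constructor expects an attribute-style config whose `graph_store.config` holds the OceanBase connection settings (dict or object, per the `get_config_value` helper above). A minimal sketch of the expected shape, using `SimpleNamespace` and placeholder values; this is hypothetical wiring for illustration only — real configs come from powermem's config factories:

```python
from types import SimpleNamespace

ob_settings = {
    "host": "127.0.0.1",           # get_config_value("host", "localhost")
    "port": "2881",                # OceanBase MySQL-protocol port
    "user": "root",
    "password": "",
    "db_name": "test",
    "embedding_model_dims": 1536,  # required; __init__ raises ValueError if missing
    "index_type": "HNSW",          # otherwise constants.DEFAULT_INDEX_TYPE
    "max_hops": 3,                 # multi-hop traversal depth
}

config = SimpleNamespace(
    graph_store=SimpleNamespace(config=ob_settings, llm=None, custom_prompt=None),
    embedder=SimpleNamespace(provider="openai", config={}),
    vector_store=SimpleNamespace(config={}),
    llm=SimpleNamespace(provider="openai", config={}),
)

# graph = MemoryGraph(config)  # would connect to OceanBase and create both tables
```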
+     def _get_llm_provider(self) -> str:
+         """Get LLM provider from configuration with fallback.
+
+         Returns:
+             LLM provider name.
+         """
+         # Check graph_store.llm.provider first
+         if (self.config.graph_store and
+                 self.config.graph_store.llm and
+                 self.config.graph_store.llm.provider):
+             return self.config.graph_store.llm.provider
+
+         # Check config.llm.provider
+         if self.config.llm and self.config.llm.provider:
+             return self.config.llm.provider
+
+         # Default fallback
+         return constants.DEFAULT_LLM_PROVIDER
+
+     def _get_llm_config(self) -> Optional[Any]:
+         """Get LLM config from configuration.
+
+         Returns:
+             LLM configuration object or None.
+         """
+         # Check graph_store.llm.config first
+         if (self.config.graph_store and
+                 self.config.graph_store.llm and
+                 hasattr(self.config.graph_store.llm, "config")):
+             return self.config.graph_store.llm.config
+
+         # Check config.llm.config
+         if hasattr(self.config.llm, "config"):
+             return self.config.llm.config
+
+         return None
+
+     def _build_user_identity(self, filters: Dict[str, Any]) -> str:
+         """Build user identity string from filters.
+
+         Args:
+             filters: Dictionary containing user_id, agent_id, run_id.
+
+         Returns:
+             Formatted user identity string.
+         """
+         identity_parts = [f"user_id: {filters['user_id']}"]
+
+         if filters.get("agent_id"):
+             identity_parts.append(f"agent_id: {filters['agent_id']}")
+
+         if filters.get("run_id"):
+             identity_parts.append(f"run_id: {filters['run_id']}")
+
+         return ", ".join(identity_parts)
+
+     def _build_filter_conditions(
+         self,
+         filters: Dict[str, Any],
+         prefix: str = ""
+     ) -> Tuple[List[str], Dict[str, Any]]:
+         """Build SQL filter conditions and parameters from filters.
+
+         Args:
+             filters: Dictionary containing user_id, agent_id, run_id.
+             prefix: Optional prefix for column names (e.g., "r." for table alias).
+
+         Returns:
+             Tuple of (filter_conditions_list, params_dict).
+         """
+         filter_parts = [f"{prefix}user_id = :user_id"]
+         params = {"user_id": filters["user_id"]}
+
+         if filters.get("agent_id"):
+             filter_parts.append(f"{prefix}agent_id = :agent_id")
+             params["agent_id"] = filters["agent_id"]
+
+         if filters.get("run_id"):
+             filter_parts.append(f"{prefix}run_id = :run_id")
+             params["run_id"] = filters["run_id"]
+
+         return filter_parts, params
+
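To make the two filter helpers concrete, here is what they produce for a typical `filters` dict (the values are illustrative; the behavior is read directly from the methods above):

```python
filters = {"user_id": "alice", "agent_id": "support-bot"}

# _build_user_identity(filters)
# -> "user_id: alice, agent_id: support-bot"

# _build_filter_conditions(filters, prefix="r.")
# -> (["r.user_id = :user_id", "r.agent_id = :agent_id"],
#     {"user_id": "alice", "agent_id": "support-bot"})
```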
+     @staticmethod
+     def _coerce_tool_response_to_dict(response: Any) -> Dict[str, Any]:
+         """Ensure LLM tool response is a dict.
+
+         Some LLM providers may return a JSON string instead of a parsed dict. This helper
+         normalizes the response to a dictionary with safe fallbacks.
+
+         Args:
+             response: LLM tool call response, may be dict or JSON string.
+
+         Returns:
+             Normalized dictionary object, or empty dict if unparseable.
+         """
+         if isinstance(response, dict):
+             return response
+         if isinstance(response, str):
+             try:
+                 cleaned = remove_code_blocks(response)
+             except Exception:
+                 cleaned = response
+             try:
+                 parsed = json.loads(cleaned)
+                 if isinstance(parsed, dict):
+                     return parsed
+             except Exception:
+                 pass
+         # Fallback to empty dict if un-parseable
+         return {}
+
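Concretely, the normalizer accepts either form. A sketch of the three branches, assuming `remove_code_blocks` strips Markdown fences (as its name and its usage here suggest):

```python
# A dict passes through unchanged.
assert MemoryGraph._coerce_tool_response_to_dict({"tool_calls": []}) == {"tool_calls": []}

# A fenced JSON string is cleaned, then parsed with json.loads.
raw = '```json\n{"tool_calls": []}\n```'
assert MemoryGraph._coerce_tool_response_to_dict(raw) == {"tool_calls": []}

# Anything unparseable (plain prose, a JSON list, etc.) falls back to {}.
assert MemoryGraph._coerce_tool_response_to_dict("no tools were called") == {}
```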
+     def _create_tables(self) -> None:
+         """Create graph entities and relationships tables if they don't exist.
+
+         Creates two tables:
+         - graph_entities: Stores entity nodes and their vector embeddings
+         - graph_relationships: Stores relationships between entities
+         """
+
+         if not self.client.check_table_exists(constants.TABLE_ENTITIES):
+             # Define columns for entities table
+             cols = [
+                 Column("id", BigInteger, primary_key=True, autoincrement=False),
+                 Column("name", String(255), nullable=False),
+                 Column("entity_type", String(64)),
+                 Column("embedding", VECTOR(self.embedding_dims)),
+                 Column("created_at", TIMESTAMP, server_default=text("CURRENT_TIMESTAMP")),
+                 Column("updated_at", TIMESTAMP, server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")),
+             ]
+             # Define regular indexes
+             indexes = [
+                 Index("idx_name", "name"),
+             ]
+
+             # Map index_type string to VecIndexType enum
+             index_type_map = constants.OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPES
+
+             # Create vector index parameters
+             vidx_params = self.client.prepare_index_params()
+             vidx_params.add_index(
+                 field_name="embedding",
+                 index_type=index_type_map.get(self.index_type, VecIndexType.HNSW),
+                 index_name=self.vidx_name,
+                 metric_type=self.vidx_metric_type,
+                 params=self.vidx_algo_params,
+             )
+
+             # Create table with vector index
+             self.client.create_table_with_index_params(
+                 table_name=constants.TABLE_ENTITIES,
+                 columns=cols,
+                 indexes=indexes,
+                 vidxs=vidx_params,
+                 partitions=None,
+             )
+
+             logger.info("%s table created successfully", constants.TABLE_ENTITIES)
+         else:
+             logger.info("%s table already exists", constants.TABLE_ENTITIES)
+             # Check vector dimension consistency
+             existing_dim = self._get_existing_vector_dimension_for_entities()
+             if existing_dim is not None and existing_dim != self.embedding_dims:
+                 raise ValueError(
+                     f"Vector dimension mismatch: existing table '{constants.TABLE_ENTITIES}' has "
+                     f"vector dimension {existing_dim}, but requested dimension is {self.embedding_dims}. "
+                     f"Please use a different configuration or reset the graph."
+                 )
+
+         # Create relationships table using pyobvector API
+         if not self.client.check_table_exists(constants.TABLE_RELATIONSHIPS):
+
+             # Define columns for relationships table
+             cols = [
+                 Column("id", BigInteger, primary_key=True, autoincrement=False),
+                 Column("source_entity_id", BigInteger, nullable=False),
+                 Column("relationship_type", String(128), nullable=False),
+                 Column("destination_entity_id", BigInteger, nullable=False),
+                 Column("user_id", String(128)),
+                 Column("agent_id", String(128)),
+                 Column("run_id", String(128)),
+                 Column("created_at", TIMESTAMP, server_default=text("CURRENT_TIMESTAMP")),
+                 Column("updated_at", TIMESTAMP, server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")),
+             ]
+
+             # Define regular indexes
+             indexes = [
+                 Index("idx_r_covering", "user_id", "source_entity_id", "destination_entity_id", "relationship_type"),
+             ]
+
+             # Create table without vector index (relationships table has no vectors)
+             self.client.create_table_with_index_params(
+                 table_name=constants.TABLE_RELATIONSHIPS,
+                 columns=cols,
+                 indexes=indexes,
+                 vidxs=None,
+                 partitions=None,
+             )
+
+             logger.info("%s table created successfully", constants.TABLE_RELATIONSHIPS)
+         else:
+             logger.info("%s table already exists", constants.TABLE_RELATIONSHIPS)
+
+     def _get_existing_vector_dimension_for_entities(self) -> Optional[int]:
+         """Get the dimension of the existing vector field in entities table.
+
+         Returns:
+             Dimension of the vector field, or None if table doesn't exist or field not found.
+         """
+         if not self.client.check_table_exists(constants.TABLE_ENTITIES):
+             return None
+
+         try:
+             with self.engine.connect() as conn:
+                 result = conn.execute(text(f"DESCRIBE {constants.TABLE_ENTITIES}"))
+                 columns = result.fetchall()
+
+                 for col in columns:
+                     if col[0] == "embedding":
+                         col_type = col[1]
+                         if col_type.startswith("VECTOR(") and col_type.endswith(")"):
+                             dim_str = col_type[7:-1]
+                             return int(dim_str)
+                 return None
+         except Exception as e:
+             logger.warning(
+                 "Failed to get vector dimension for %s: %s",
+                 constants.TABLE_ENTITIES,
+                 e
+             )
+             return None
+
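The dimension probe depends on `DESCRIBE` reporting the column type exactly as `VECTOR(<dim>)`; the slicing above reduces to this (a worked example, not package code):

```python
col_type = "VECTOR(1536)"  # example type string returned by DESCRIBE
if col_type.startswith("VECTOR(") and col_type.endswith(")"):
    dim = int(col_type[7:-1])  # len("VECTOR(") == 7, so this yields 1536
```

Note that if the server rendered the type in lowercase (`vector(1536)`), the prefix check would not match and the probe would return `None`, silently skipping the dimension-consistency check above.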
+     def _build_where_clause_with_filters(
+         self,
+         filters: Dict[str, Any],
+         prefix: str = "",
+         additional_params: Optional[Dict[str, Any]] = None
+     ) -> Tuple[List[Any], Dict[str, Any]]:
+         """Build where clause and parameters from filters using bindparam.
+
+         Args:
+             filters: Dictionary containing user_id, agent_id, run_id.
+             prefix: Optional prefix for column names (e.g., "r." for table alias).
+             additional_params: Optional additional parameters to include.
+
+         Returns:
+             Tuple of (where_clause_list, params_dict).
+         """
+         where_conditions = []
+         params = {"user_id": filters["user_id"]}
+         where_conditions.append(f"{prefix}user_id = :user_id")
+
+         if filters.get("agent_id"):
+             where_conditions.append(f"{prefix}agent_id = :agent_id")
+             params["agent_id"] = filters["agent_id"]
+         if filters.get("run_id"):
+             where_conditions.append(f"{prefix}run_id = :run_id")
+             params["run_id"] = filters["run_id"]
+
+         # Add additional params if provided
+         if additional_params:
+             params.update(additional_params)
+
+         where_sql = " AND ".join(where_conditions)
+         where_clause_with_params = text(where_sql).bindparams(**params)
+         return [where_clause_with_params], params
+
+     def add(self, data: str, filters: Dict[str, Any]) -> Dict[str, Any]:
+         """Add data to the graph.
+
+         Args:
+             data: The data to add to the graph.
+             filters: Dictionary containing filters (user_id, agent_id, run_id).
+
+         Returns:
+             Dictionary containing deleted_entities and added_entities.
+         """
+         entity_type_map = self._retrieve_nodes_from_data(data, filters)
+         to_be_added = self._establish_nodes_relations_from_data(data, filters, entity_type_map)
+         search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
+         to_be_deleted = self._get_delete_entities_from_search_output(search_output, data, filters)
+
+         deleted_entities = self._delete_entities(to_be_deleted, filters)
+         added_entities = self._add_entities(to_be_added, filters, entity_type_map)
+         logger.debug("Deleted entities: %s, Added entities: %s", deleted_entities, added_entities)
+         return {"deleted_entities": deleted_entities, "added_entities": added_entities}
+
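A usage sketch of the add pipeline (extract entities → extract relations → search the existing graph → LLM-selected deletions → upserts), with hypothetical input and an illustrative return value whose shapes follow `_delete_entities` and `_add_entities` below:

```python
graph = MemoryGraph(config)  # config as sketched earlier

result = graph.add(
    "Alice moved to Berlin and now works at Acme.",
    filters={"user_id": "alice"},
)
# result might look like:
# {
#     "deleted_entities": [{"deleted_count": 1}],   # e.g. a stale lives_in edge removed
#     "added_entities": [
#         {"source": "alice", "relationship": "lives_in", "target": "berlin"},
#         {"source": "alice", "relationship": "works_at", "target": "acme"},
#     ],
# }
```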
+     def search(self, query: str, filters: Dict[str, Any], limit: int = 100) -> List[Dict[str, str]]:
+         """Search for memories and related graph data.
+
+         Args:
+             query: Query to search for.
+             filters: Dictionary containing filters (user_id, agent_id, run_id).
+             limit: Maximum number of nodes and relationships to retrieve. Defaults to 100.
+
+         Returns:
+             List of search results containing source, relationship, and destination.
+         """
+         entity_type_map = self._retrieve_nodes_from_data(query, filters)
+         search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters, limit=limit)
+
+         if not search_output:
+             return []
+
+         # Tokenize search outputs for BM25 with improved segmentation
+         search_outputs_sequence = []
+         for item in search_output:
+             # Combine source, relationship, destination into a single text for better tokenization
+             combined_text = f"{item['source']} {item['relationship']} {item['destination']}"
+             tokenized_item = self._tokenize_text(combined_text)
+             search_outputs_sequence.append(tokenized_item)
+
+         bm25 = BM25Okapi(search_outputs_sequence)
+
+         # Tokenize query using the same method
+         tokenized_query = self._tokenize_text(query)
+
+         # Get top N results based on BM25 scores
+         scores = bm25.get_scores(tokenized_query)
+         # Get indices sorted by score (descending)
+         sorted_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
+         top_n_indices = sorted_indices[:constants.DEFAULT_BM25_TOP_N]
+
+         # Build reranked results
+         search_results = []
+         for idx in top_n_indices:
+             if idx < len(search_output):
+                 item = search_output[idx]
+                 search_results.append({
+                     "source": item["source"],
+                     "relationship": item["relationship"],
+                     "destination": item["destination"]
+                 })
+
+         logger.info("Returned %d search results (from %d candidates)", len(search_results), len(search_output))
+
+         return search_results
+
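The rerank step is plain `rank_bm25` over the tokenized `source relationship destination` triples. A self-contained sketch of the same scoring logic, using simple whitespace tokenization in place of `_tokenize_text` (candidates and query are illustrative):

```python
from rank_bm25 import BM25Okapi

candidates = [
    {"source": "alice", "relationship": "works_at", "destination": "acme"},
    {"source": "alice", "relationship": "lives_in", "destination": "berlin"},
]
corpus = [f"{c['source']} {c['relationship']} {c['destination']}".split() for c in candidates]

bm25 = BM25Okapi(corpus)
scores = bm25.get_scores("where does alice work".lower().split())

# Sort candidate indices by BM25 score, descending, and keep the top N
ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
top = [candidates[i] for i in ranked[:2]]  # stand-in for DEFAULT_BM25_TOP_N
```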
+     def _tokenize_text(self, text: str) -> List[str]:
+         """Tokenize text using jieba for Chinese or simple split for other languages.
+
+         Args:
+             text: Text to tokenize.
+
+         Returns:
+             List of tokens.
+         """
+         if jieba is not None:
+             # Use jieba for Chinese text segmentation
+             # Convert to lowercase for better matching
+             tokens = list(jieba.cut(text.lower()))
+             # Filter out empty strings and single spaces
+             tokens = [t for t in tokens if t.strip()]
+             return tokens
+         else:
+             # Fallback to simple space-based tokenization
+             return text.lower().split()
+
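For mixed Chinese/English input the jieba branch is what makes the BM25 rerank useful, since Chinese has no whitespace word boundaries. An illustrative comparison (the exact segmentation depends on jieba's dictionary):

```python
import jieba

list(jieba.cut("我在南京工作".lower()))    # e.g. ['我', '在', '南京', '工作']
"Machine Learning rocks".lower().split()  # fallback path: ['machine', 'learning', 'rocks']
```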
+     def delete_all(self, filters: Dict[str, Any]) -> None:
+         """Delete all graph data for the given filters.
+
+         Args:
+             filters: Filters containing user_id, agent_id, run_id.
+         """
+         where_clause, _ = self._build_where_clause_with_filters(filters)
+
+         try:
+             relationships_results = self.client.get(
+                 table_name=constants.TABLE_RELATIONSHIPS,
+                 ids=None,
+                 output_column_name=["id", "source_entity_id", "destination_entity_id"],
+                 where_clause=where_clause
+             )
+
+             # Collect unique entity IDs from relationships
+             entity_ids = set()
+             for rel in relationships_results.fetchall():
+                 entity_ids.add(rel[1])  # source_entity_id
+                 entity_ids.add(rel[2])  # destination_entity_id
+
+             # Delete the relationships
+             self.client.delete(
+                 table_name=constants.TABLE_RELATIONSHIPS,
+                 where_clause=where_clause
+             )
+             logger.info("Deleted relationships for filters: %s", filters)
+
+             # Delete entities that were part of these relationships
+             if entity_ids:
+                 self.client.delete(
+                     table_name=constants.TABLE_ENTITIES,
+                     ids=list(entity_ids),
+                 )
+                 logger.info("Deleted %d entities for filters: %s", len(entity_ids), filters)
+         except Exception as e:
+             logger.warning("Error deleting graph data: %s", e)
+
+         logger.info("Deleted all graph data for filters: %s", filters)
+
+     def get_all(self, filters: Dict[str, Any], limit: int = 100) -> List[Dict[str, str]]:
+         """Retrieve all nodes and relationships from the graph database.
+
+         Args:
+             filters: Dictionary containing filters (user_id, agent_id, run_id).
+             limit: Maximum number of relationships to retrieve. Defaults to 100.
+
+         Returns:
+             List of dictionaries containing source, relationship, and target.
+         """
+
+         where_clause, params = self._build_where_clause_with_filters(filters)
+
+         relationships_results = self.client.get(
+             table_name=constants.TABLE_RELATIONSHIPS,
+             ids=None,
+             output_column_name=["id", "source_entity_id", "relationship_type", "destination_entity_id", "updated_at"],
+             where_clause=where_clause
+         )
+
+         relationships = relationships_results.fetchall()
+         if not relationships:
+             return []
+
+         # Limit results if needed
+         if len(relationships) > limit:
+             relationships = relationships[:limit]
+
+         # Extract unique entity IDs from relationships
+         entity_ids = set()
+         for rel in relationships:
+             entity_ids.add(rel[1])  # source_entity_id
+             entity_ids.add(rel[3])  # destination_entity_id
+
+         # Get all entities that are referenced in relationships
+         entities_results = self.client.get(
+             table_name=constants.TABLE_ENTITIES,
+             ids=list(entity_ids),
+             output_column_name=["id", "name"]
+         )
+
+         # Create a mapping from entity_id to entity_name
+         entity_map = {entity[0]: entity[1] for entity in entities_results.fetchall()}
+
+         # Build final results with updated_at for sorting
+         final_results = []
+         for rel in relationships:
+             rel_id, source_id, relationship_type, dest_id, updated_at = rel
+
+             source_name = entity_map.get(source_id, f"Unknown_{source_id}")
+             dest_name = entity_map.get(dest_id, f"Unknown_{dest_id}")
+
+             final_results.append({
+                 "source": source_name,
+                 "relationship": relationship_type,
+                 "target": dest_name,
+                 "_updated_at": updated_at,  # Keep for sorting
+             })
+
+         # Sort by updated_at (descending)
+         final_results.sort(key=lambda x: x["_updated_at"], reverse=True)
+
+         # Remove the temporary _updated_at field
+         for result in final_results:
+             del result["_updated_at"]
+
+         logger.info("Retrieved %d relationships", len(final_results))
+         return final_results
+
+     def _retrieve_nodes_from_data(self, data: str, filters: Dict[str, Any]) -> Dict[str, str]:
+         """Extract all the entities mentioned in the query.
+
+         Args:
+             data: Input text to extract entities from.
+             filters: Dictionary containing user_id, agent_id, run_id.
+
+         Returns:
+             Dictionary mapping entity names to entity types.
+         """
+         _tools = [self.graph_tools_prompts.get_extract_entities_tool()]
+         if constants.is_structured_llm_provider(self.llm_provider):
+             _tools = [self.graph_tools_prompts.get_extract_entities_tool(structured=True)]
+
+         search_results = self.llm.generate_response(
+             messages=[
+                 {
+                     "role": "system",
+                     "content": (
+                         f"You are a smart assistant who understands entities and their types in a given text. "
+                         f"If user message contains self reference such as 'I', 'me', 'my' etc. "
+                         f"then use {filters['user_id']} as the source entity. "
+                         f"Extract all the entities from the text. "
+                         f"***DO NOT*** answer the question itself if the given text is a question."
+                     ),
+                 },
+                 {"role": "user", "content": data},
+             ],
+             tools=_tools,
+         )
+
+         # Normalize potential string response to dict
+         search_results = self._coerce_tool_response_to_dict(search_results)
+
+         entity_type_map = {}
+
+         try:
+             for tool_call in search_results.get("tool_calls", []):
+                 if tool_call["name"] != "extract_entities":
+                     continue
+                 for item in tool_call["arguments"]["entities"]:
+                     entity_type_map[item["entity"]] = item["entity_type"]
+         except Exception as e:
+             logger.exception(
+                 "Error in search tool: %s, llm_provider=%s, search_results=%s",
+                 e,
+                 self.llm_provider,
+                 search_results
+             )
+
+         entity_type_map = {
+             k.lower().replace(" ", "_"): v.lower().replace(" ", "_")
+             for k, v in entity_type_map.items()
+         }
+         logger.debug("Entity type map: %s\n search_results=%s", entity_type_map, search_results)
+         return entity_type_map
+
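For reference, the tool-call payload this loop consumes has the following shape (field names taken from the parsing code above; the entity values are illustrative):

```python
search_results = {
    "tool_calls": [
        {
            "name": "extract_entities",
            "arguments": {
                "entities": [
                    {"entity": "Alice", "entity_type": "Person"},
                    {"entity": "San Francisco", "entity_type": "City"},
                ]
            },
        }
    ]
}
# After the lowercase/underscore normalization step:
# {"alice": "person", "san_francisco": "city"}
```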
+     def _establish_nodes_relations_from_data(
+         self,
+         data: str,
+         filters: Dict[str, Any],
+         entity_type_map: Dict[str, str]
+     ) -> List[Dict[str, str]]:
+         """Establish relations among the extracted nodes.
+
+         Args:
+             data: Input text to extract relationships from.
+             filters: Dictionary containing user_id, agent_id, run_id.
+             entity_type_map: Mapping of entity names to types.
+
+         Returns:
+             List of dictionaries containing source, destination, and relationship.
+         """
+         user_identity = self._build_user_identity(filters)
+
+         if self.config.graph_store.custom_prompt:
+             system_content = self.graph_prompts.get_system_prompt("extract_relations")
+             system_content = system_content.replace("USER_ID", user_identity)
+             system_content = system_content.replace("CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}")
+             messages = [
+                 {"role": "system", "content": system_content},
+                 {"role": "user", "content": data},
+             ]
+         else:
+             system_content = self.graph_prompts.get_system_prompt("extract_relations")
+             system_content = system_content.replace("USER_ID", user_identity)
+             messages = [
+                 {"role": "system", "content": system_content},
+                 {
+                     "role": "user",
+                     "content": f"List of entities: {list(entity_type_map.keys())}. \n\nText: {data}"
+                 },
+             ]
+
+         _tools = [self.graph_tools_prompts.get_relations_tool()]
+         if constants.is_structured_llm_provider(self.llm_provider):
+             _tools = [self.graph_tools_prompts.get_relations_tool(structured=True)]
+
+         extracted_entities = self.llm.generate_response(
+             messages=messages,
+             tools=_tools,
+         )
+
+         # Normalize to dict for consistent access
+         extracted_entities = self._coerce_tool_response_to_dict(extracted_entities)
+
+         entities = []
+         if extracted_entities.get("tool_calls"):
+             first_call = (
+                 extracted_entities["tool_calls"][0]
+                 if extracted_entities["tool_calls"]
+                 else {}
+             )
+             entities = first_call.get("arguments", {}).get("entities", [])
+
+         entities = self._remove_spaces_from_entities(entities)
+         logger.debug("Extracted entities: %s", entities)
+         return entities
+
+     def _search_graph_db(
+         self,
+         node_list: List[str],
+         filters: Dict[str, Any],
+         limit: int = 100
+     ) -> List[Dict[str, Any]]:
+         """Search similar nodes and their relationships using vector similarity with multi-hop support.
+
+         Supports multi-hop graph traversal up to ``max_hops`` (default 3) via iterative
+         per-hop queries with application-level early stopping. Results are prioritized by
+         path length (1-hop first, then 2-hop, and so on).
+
+         Args:
+             node_list: List of node names to search for.
+             filters: Dictionary containing user_id, agent_id, run_id.
+             limit: Maximum number of results to return. Defaults to 100.
+
+         Returns:
+             List of dictionaries containing source, relationship, destination and their IDs.
+         """
+         result_relations = []
+
+         for node in node_list:
+             n_embedding = self.embedding_model.embed(node)
+
+             entities = self._search_node(None, n_embedding, filters, limit=limit)
+
+             if not entities:
+                 continue
+
+             # Ensure entities is always a list
+             if isinstance(entities, dict):
+                 entities = [entities]
+
+             entity_ids = [e.get("id") for e in entities]
+
+             # Use multi-hop search with early stopping
+             multi_hop_results = self._multi_hop_search(entity_ids, filters, limit)
+             result_relations.extend(multi_hop_results)
+
+         return result_relations
+
+     def _execute_single_hop_query(
+         self,
+         source_entity_ids: List[int],
+         filters: Dict[str, Any],
+         hop_number: int,
+         visited_edges: Optional[set] = None,
+         conn=None,
+         max_edges_per_hop: int = 1000
+     ) -> List[Dict[str, Any]]:
+         """Execute a single hop query from given source entities.
+
+         Args:
+             source_entity_ids: List of source entity IDs to start from.
+             filters: Dictionary containing user_id, agent_id, run_id.
+             hop_number: Current hop number (for result annotation).
+             visited_edges: Set of visited edges (source_id, dest_id) to avoid cycles.
+             conn: Optional database connection to use. If None, creates a new connection.
+             max_edges_per_hop: Maximum number of edges to retrieve per hop. Defaults to 1000.
+
+         Returns:
+             List of relationship dictionaries with hop_count.
+
+         Note:
+             - Prevents memory explosion from high-degree nodes
+             - Results are sorted by updated_at DESC (then created_at DESC) before limiting
+             - This ensures most recently touched edges are retrieved first
+         """
+         if not source_entity_ids:
+             return []
+
+         if visited_edges is None:
+             visited_edges = set()
+
+         # Build filter conditions
+         filter_parts, params = self._build_filter_conditions(filters, prefix="")
+         filter_conditions = " AND ".join(filter_parts)
+
+         # Build query with LIMIT to prevent memory explosion from high-degree nodes
+         query = f"""
+         SELECT
+             e1.name AS source,
+             r.source_entity_id,
+             r.relationship_type,
+             r.id AS relation_id,
+             e2.name AS destination,
+             r.destination_entity_id
+         FROM
+             (
+                 SELECT
+                     id,
+                     source_entity_id,
+                     destination_entity_id,
+                     relationship_type,
+                     created_at,
+                     updated_at,
+                     user_id
+                 FROM {constants.TABLE_RELATIONSHIPS}
+                 WHERE
+                     source_entity_id IN :entity_ids
+                     AND {filter_conditions}
+                 ORDER BY updated_at DESC, created_at DESC
+                 LIMIT :max_edges_per_hop
+             ) AS r
+         JOIN {constants.TABLE_ENTITIES} e1 ON r.source_entity_id = e1.id
+         JOIN {constants.TABLE_ENTITIES} e2 ON r.destination_entity_id = e2.id;
+         """
+
+         # Add parameters
+         params["entity_ids"] = tuple(source_entity_ids)
+         params["max_edges_per_hop"] = max_edges_per_hop
+         logger.debug("Executing hop %d with max_edges_per_hop=%d\n query: %s\n params: %s",
+                      hop_number, max_edges_per_hop, query, params)
+
+         # Execute query - use provided connection or create new one
+         if conn is not None:
+             # Reuse existing connection (transactional)
+             result = conn.execute(text(query), params)
+             rows = result.fetchall()
+         else:
+             # Create new connection (backward compatibility)
+             with self.engine.connect() as new_conn:
+                 result = new_conn.execute(text(query), params)
+                 rows = result.fetchall()
+
+         # Format results and filter out cycles
+         formatted_results = []
+         for row in rows:
+             source_id = row[1]
+             dest_id = row[5]
+             edge_key = (source_id, dest_id)
+
+             # Skip if this edge was already visited (cycle detection)
+             if edge_key in visited_edges:
+                 continue
+
+             formatted_results.append({
+                 "source": row[0],
+                 "source_id": source_id,
+                 "relationship": row[2],
+                 "relation_id": row[3],
+                 "destination": row[4],
+                 "destination_id": dest_id,
+                 "hop_count": hop_number,
+             })
+
+         return formatted_results
+
+     def _multi_hop_search(
+         self,
+         entity_ids: List[int],
+         filters: Dict[str, Any],
+         limit: int
+     ) -> List[Dict[str, Any]]:
+         """Perform multi-hop graph search with application-level early stopping.
+
+         Args:
+             entity_ids: List of seed entity IDs to start traversal from.
+             filters: Dictionary containing user_id, agent_id, run_id.
+             limit: Maximum number of results to return.
+
+         Returns:
+             List of dictionaries containing source, relationship, destination and their IDs.
+
+         Note:
+             Application-level optimization strategy:
+             1. Execute 1-hop query first
+             2. Check if results satisfy limit - if yes, return immediately
+             3. If not, execute 2-hop query with cycle prevention
+             4. Accumulate results and check limit again
+             5. Continue until limit is satisfied or max_hops is reached
+         """
+         if not entity_ids:
+             return []
+
+         # Use a transaction to ensure consistent reads across all hops
+         # This prevents phantom reads and non-repeatable reads during multi-hop traversal
+         with self.engine.begin() as conn:
+             logger.debug("Started transaction for multi-hop search")
+
+             all_results = []
+             visited_edges = set()  # Track visited edges to prevent cycles
+             visited_nodes = set(entity_ids)  # Track all visited nodes (start with seed entities)
+             current_source_ids = entity_ids  # Start from seed entities
+
+             # Iteratively execute each hop until limit is satisfied or max_hops reached
+             for hop in range(1, self.max_hops + 1):
+                 # Execute single hop query within the same transaction
+                 hop_results = self._execute_single_hop_query(
+                     source_entity_ids=current_source_ids,
+                     filters=filters,
+                     hop_number=hop,
+                     visited_edges=visited_edges,
+                     conn=conn
+                 )
+
+                 # If no results at this hop, stop early
+                 if not hop_results:
+                     logger.info("STOP early: No results at hop %s", hop)
+                     break
+
+                 # Add results to accumulator
+                 all_results.extend(hop_results)
+
+                 # Update visited edges to prevent cycles in next hop
+                 for result in hop_results:
+                     edge_key = (result["source_id"], result["destination_id"])
+                     visited_edges.add(edge_key)
+
+                 # Check if we've satisfied the limit (early stopping)
+                 if len(all_results) >= limit:
+                     logger.info("STOP early: Limit satisfied at hop %s", hop)
+                     # Truncate to exact limit and return
+                     return all_results[:limit]
+
+                 # Prepare source IDs for next hop (destination entities become new sources)
+                 next_source_ids = set([r["destination_id"] for r in hop_results])
+
+                 # Check if we have any new nodes that haven't been visited
+                 new_nodes = next_source_ids - visited_nodes
+
+                 # If no new nodes, all destinations are already visited - stop early
+                 # This means we've exhausted all reachable nodes in the graph
+                 if not new_nodes:
+                     logger.info("STOP early: All destinations are already visited at hop %s", hop)
+                     break
+
+                 # Update visited nodes and prepare for next hop
+                 visited_nodes.update(next_source_ids)
+                 current_source_ids = list(next_source_ids)
+
+             # Return all accumulated results
+             logger.debug("Transaction completed for multi-hop search, returning %d results", len(all_results))
+             return all_results
+
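The hop loop above is a breadth-first expansion over the relationship table with edge-level cycle detection. A minimal in-memory sketch of the same control flow (a toy edge list instead of SQL; `multi_hop` is a hypothetical standalone helper, not part of the package):

```python
def multi_hop(edges, seeds, max_hops=3, limit=100):
    """edges: list of (source_id, dest_id) tuples; seeds: starting entity ids."""
    results, visited_edges, visited_nodes = [], set(), set(seeds)
    frontier = list(seeds)
    for hop in range(1, max_hops + 1):
        hop_edges = [(s, d) for (s, d) in edges
                     if s in frontier and (s, d) not in visited_edges]
        if not hop_edges:
            break                      # STOP early: nothing new at this hop
        results.extend(hop_edges)
        visited_edges.update(hop_edges)
        if len(results) >= limit:
            return results[:limit]     # STOP early: limit satisfied
        next_nodes = {d for (_, d) in hop_edges}
        if not next_nodes - visited_nodes:
            break                      # STOP early: all destinations already seen
        visited_nodes.update(next_nodes)
        frontier = list(next_nodes)
    return results

multi_hop([(1, 2), (2, 3), (3, 1)], seeds=[1], max_hops=3)
# -> [(1, 2), (2, 3), (3, 1)]  # the (3, 1) edge is new; the cycle ends the walk
```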
+     def _get_delete_entities_from_search_output(
+         self,
+         search_output: List[Dict[str, Any]],
+         data: str,
+         filters: Dict[str, Any]
+     ) -> List[Dict[str, str]]:
+         """Get the entities to be deleted from the search output.
+
+         Args:
+             search_output: Search results from graph database.
+             data: New input data to compare against.
+             filters: Dictionary containing user_id, agent_id, run_id.
+
+         Returns:
+             List of dictionaries containing source, destination, and relationship to delete.
+         """
+         search_output_string = format_entities(search_output)
+         user_identity = self._build_user_identity(filters)
+         system_prompt, user_prompt = self.graph_prompts.get_delete_relations_prompt(search_output_string, data, user_identity)
+
+         _tools = [self.graph_tools_prompts.get_delete_tool()]
+         if constants.is_structured_llm_provider(self.llm_provider):
+             _tools = [self.graph_tools_prompts.get_delete_tool(structured=True)]
+
+         memory_updates = self.llm.generate_response(
+             messages=[
+                 {"role": "system", "content": system_prompt},
+                 {"role": "user", "content": user_prompt},
+             ],
+             tools=_tools,
+         )
+
+         # Normalize to dict before access
+         memory_updates = self._coerce_tool_response_to_dict(memory_updates)
+
+         to_be_deleted = []
+         for item in memory_updates.get("tool_calls", []):
+             if item.get("name") == "delete_graph_memory":
+                 to_be_deleted.append(item.get("arguments"))
+
+         to_be_deleted = self._remove_spaces_from_entities(to_be_deleted)
+         logger.debug("Deleted relationships: %s", to_be_deleted)
+         return to_be_deleted
+
+     def _delete_entities(
+         self,
+         to_be_deleted: List[Dict[str, str]],
+         filters: Dict[str, Any]
+     ) -> List[Dict[str, int]]:
+         """Delete the specified relationships from the graph.
+
+         Args:
+             to_be_deleted: List of relationships to delete with source, destination, relationship.
+             filters: Dictionary containing user_id, agent_id, run_id.
+
+         Returns:
+             List of dictionaries containing deleted_count for each deletion operation.
+         """
+         results = []
+
+         for item in to_be_deleted:
+             source = item["source"]
+             destination = item["destination"]
+             relationship = item["relationship"]
+
+             # First, find the source and destination entities by name
+             source_entities = self.client.get(
+                 table_name=constants.TABLE_ENTITIES,
+                 ids=None,
+                 output_column_name=["id", "name"],
+                 where_clause=[text("name = :source_name").bindparams(
+                     bindparam("source_name", source)
+                 )]
+             )
+
+             dest_entities = self.client.get(
+                 table_name=constants.TABLE_ENTITIES,
+                 ids=None,
+                 output_column_name=["id", "name"],
+                 where_clause=[text("name = :dest_name").bindparams(
+                     bindparam("dest_name", destination)
+                 )]
+             )
+
+             # Get entity IDs
+             source_rows = source_entities.fetchall() if source_entities else []
+             dest_rows = dest_entities.fetchall() if dest_entities else []
+
+             source_ids = [e[0] for e in source_rows]
+             dest_ids = [e[0] for e in dest_rows]
+
+             # Check if we found any entities
+             if not source_ids or not dest_ids:
+                 logger.warning(
+                     "Could not find entities: source='%s' (found %d), destination='%s' (found %d)",
+                     source,
+                     len(source_ids),
+                     destination,
+                     len(dest_ids)
+                 )
+                 results.append({"deleted_count": 0})
+                 continue
+
+             # Build where clause for relationship deletion
+             where_clauses = [
+                 "relationship_type = :rel_type",
+                 "user_id = :user_id",
+             ]
+             params = {
+                 "rel_type": relationship,
+                 "user_id": filters["user_id"],
+             }
+
+             if filters.get("agent_id"):
+                 where_clauses.append("agent_id = :agent_id")
+                 params["agent_id"] = filters["agent_id"]
+             if filters.get("run_id"):
+                 where_clauses.append("run_id = :run_id")
+                 params["run_id"] = filters["run_id"]
+
+             # Add source and destination entity ID conditions.
+             source_conditions = " OR ".join([
+                 f"source_entity_id = :src_id_{i}"
+                 for i in range(len(source_ids))
+             ])
+             dest_conditions = " OR ".join([
+                 f"destination_entity_id = :dest_id_{i}"
+                 for i in range(len(dest_ids))
+             ])
+
+             where_clauses.append(f"({source_conditions}) AND ({dest_conditions})")
+
+             # Add entity ID parameters
+             for i, src_id in enumerate(source_ids):
+                 params[f"src_id_{i}"] = src_id
+             for i, dest_id in enumerate(dest_ids):
+                 params[f"dest_id_{i}"] = dest_id
+
+             where_str = " AND ".join(where_clauses)
+             where_clause = text(where_str).bindparams(**params)
+
+             # Delete relationships using pyobvector delete method.
+             try:
+                 delete_result = self.client.delete(
+                     table_name=constants.TABLE_RELATIONSHIPS,
+                     where_clause=[where_clause]
+                 )
+                 deleted_count = (
+                     delete_result.rowcount
+                     if hasattr(delete_result, "rowcount")
+                     else 1
+                 )
+                 results.append({"deleted_count": deleted_count})
+             except Exception as e:
+                 logger.warning("Error deleting relationship: %s", e)
+                 results.append({"deleted_count": 0})
+
+         return results
+
+     def _add_entities(
+         self,
+         to_be_added: List[Dict[str, str]],
+         filters: Dict[str, Any],
+         entity_type_map: Dict[str, str]
+     ) -> List[Dict[str, str]]:
+         """Add new entities and relationships to the graph.
+
+         Args:
+             to_be_added: List of relationships to add with source, destination, relationship.
+             filters: Dictionary containing user_id, agent_id, run_id.
+             entity_type_map: Mapping of entity names to types.
+
+         Returns:
+             List of dictionaries containing added source, relationship, and target.
+         """
+         results = []
+
+         for item in to_be_added:
+             source = item["source"]
+             destination = item["destination"]
+             relationship = item["relationship"]
+
+             source_embedding = self.embedding_model.embed(source)
+             dest_embedding = self.embedding_model.embed(destination)
+
+             # Search for existing similar nodes.
+             source_node = self._search_source_node(source_embedding, filters,
+                                                    threshold=constants.DEFAULT_SIMILARITY_THRESHOLD, limit=1)
+             dest_node = self._search_destination_node(dest_embedding, filters,
+                                                       threshold=constants.DEFAULT_SIMILARITY_THRESHOLD, limit=1)
+
+             # Get or create source entity
+             if source_node:
+                 source_id = source_node["id"]
+             else:
+                 source_id = self._create_entity(source, entity_type_map.get(source, "entity"),
+                                                 source_embedding, filters)
+
+             # Get or create destination entity
+             if dest_node:
+                 dest_id = dest_node["id"]
+             else:
+                 dest_id = self._create_entity(destination, entity_type_map.get(destination, "entity"),
+                                               dest_embedding, filters)
+
+             # Create or update relationship
+             rel_result = self._create_or_update_relationship(source_id, dest_id, relationship, filters)
+             results.append(rel_result)
+
+         return results
+
+     def _search_node(
+         self,
+         name: Optional[str],
+         embedding: List[float],
+         filters: Dict[str, Any],
+         threshold: Optional[float] = None,
+         limit: int = 10
+     ) -> Union[Dict[str, Any], List[Dict[str, Any]], None]:
+         """Search for a node by embedding similarity within threshold.
+
+         Args:
+             name: Node name (not currently used, kept for compatibility).
+             embedding: Vector embedding to search with.
+             filters: Dictionary containing user_id, agent_id, run_id.
+             threshold: Distance threshold for filtering results.
+                 Defaults to constants.DEFAULT_SIMILARITY_THRESHOLD if None.
+             limit: Maximum number of results to return. Defaults to 10.
+
+         Returns:
+             If limit==1: Single dict with id, name, distance, or None if no match.
+             If limit>1: List of dicts with id, name, distance, or None if no match.
+         """
+         if threshold is None:
+             threshold = constants.DEFAULT_SIMILARITY_THRESHOLD
+
+         # Create Table object to access columns for WHERE clause.
+         table = Table(constants.TABLE_ENTITIES, self.metadata, autoload_with=self.engine)
+         vec_str = "[" + ",".join([str(np.float32(v)) for v in embedding]) + "]"
+         distance_expr = l2_distance(table.c.embedding, vec_str)
+         where_clause = [distance_expr < threshold]
+
+         results = self.client.ann_search(
+             table_name=constants.TABLE_ENTITIES,
+             vec_data=embedding,
+             vec_column_name="embedding",
+             distance_func=l2_distance,
+             with_dist=True,
+             topk=limit,
+             output_column_names=["id", "name"],
+             where_clause=where_clause,
+         )
+
+         rows = results.fetchall()
+         if rows:
+             if limit == 1:
+                 row = rows[0]
+                 # Row format: (id, name, distance)
+                 entity_id, entity_name = row[0], row[1]
+                 distance = row[-1]  # Distance is always the last column
+
+                 return {"id": entity_id, "name": entity_name, "distance": distance}
+             else:
+                 return [{"id": row[0], "name": row[1], "distance": row[-1]} for row in rows]
+
+         return None
+
+     def _create_entity(
+         self,
+         name: str,
+         entity_type: str,
+         embedding: List[float],
+         filters: Dict[str, Any]
+     ) -> int:
+         """Create a new entity in the graph.
+
+         Args:
+             name: Entity name.
+             entity_type: Type of the entity.
+             embedding: Vector embedding of the entity.
+             filters: Dictionary containing user_id, agent_id, run_id.
+
+         Returns:
+             Snowflake ID of the created entity.
+         """
+         entity_id = generate_snowflake_id()
+
+         # Prepare data for insertion using pyobvector API
+         record = {
+             "id": entity_id,
+             "name": name,
+             "entity_type": entity_type,
+             "embedding": embedding,
+         }
+
+         # Use pyobvector upsert method
+         self.client.upsert(
+             table_name=constants.TABLE_ENTITIES,
+             data=[record],
+         )
+
+         logger.debug("Created entity: %s with id: %s", name, entity_id)
+         return entity_id
+
+     def _create_or_update_relationship(
+         self,
+         source_id: int,
+         dest_id: int,
+         relationship_type: str,
+         filters: Dict[str, Any]
+     ) -> Dict[str, str]:
+         """Create or update a relationship between two entities.
+
+         Args:
+             source_id: Snowflake ID of the source entity.
+             dest_id: Snowflake ID of the destination entity.
+             relationship_type: Type of the relationship.
+             filters: Dictionary containing user_id, agent_id, run_id.
+
+         Returns:
+             Dictionary containing source, relationship, and target names.
+         """
+         # First, check if relationship already exists
+         where_clause, params = self._build_where_clause_with_filters(filters)
+
+         # Add relationship-specific conditions to the where clause
+         additional_conditions = " AND source_entity_id = :source_id AND destination_entity_id = :dest_id AND relationship_type = :rel_type"
+
+         # Rebuild the where clause with additional conditions
+         where_str = str(where_clause[0].text) + additional_conditions
+
+         params.update({
+             "source_id": source_id,
+             "dest_id": dest_id,
+             "rel_type": relationship_type
+         })
+
+         where_clause_with_params = text(where_str).bindparams(**params)
+
+         # Check if relationship exists
+         existing_relationships = self.client.get(
+             table_name=constants.TABLE_RELATIONSHIPS,
+             ids=None,
+             output_column_name=["id"],
+             where_clause=[where_clause_with_params]
+         )
+
+         existing_rows = existing_relationships.fetchall()
+         if not existing_rows:
+             # Relationship doesn't exist, create new one
+             new_record = {
+                 "id": generate_snowflake_id(),
+                 "source_entity_id": source_id,
+                 "relationship_type": relationship_type,
+                 "destination_entity_id": dest_id,
+                 "user_id": filters["user_id"],
+                 "agent_id": filters.get("agent_id"),
+                 "run_id": filters.get("run_id"),
+             }
+
+             self.client.insert(
+                 table_name=constants.TABLE_RELATIONSHIPS,
+                 data=[new_record],
+             )
+
+         # Get the names for return value using pyobvector get method
+         # First get the entities
+         source_entity = self.client.get(
+             table_name=constants.TABLE_ENTITIES,
+             ids=[source_id],
+             output_column_name=["id", "name"]
+         ).fetchone()
+
+         dest_entity = self.client.get(
+             table_name=constants.TABLE_ENTITIES,
+             ids=[dest_id],
+             output_column_name=["id", "name"]
+         ).fetchone()
+
+         return {
+             "source": source_entity[1] if source_entity else None,
+             "relationship": relationship_type,
+             "target": dest_entity[1] if dest_entity else None,
+         }
+
+     def _remove_spaces_from_entities(self, entity_list: List[Dict[str, str]]) -> List[Dict[str, str]]:
+         """Clean entity names by replacing spaces with underscores.
+
+         Args:
+             entity_list: List of dictionaries containing source, destination, relationship.
+
+         Returns:
+             Cleaned entity list with spaces replaced by underscores and lowercased.
+         """
+         for item in entity_list:
+             item["source"] = item["source"].lower().replace(" ", "_")
+             item["relationship"] = item["relationship"].lower().replace(" ", "_")
+             item["destination"] = item["destination"].lower().replace(" ", "_")
+         return entity_list
+
+     def _search_source_node(
+         self,
+         source_embedding: List[float],
+         filters: Dict[str, Any],
+         threshold: Optional[float] = None,
+         limit: int = 10
+     ) -> Union[Dict[str, Any], List[Dict[str, Any]], None]:
+         """Search for a source node by embedding similarity (compatibility method).
+
+         Args:
+             source_embedding: Vector embedding to search with.
+             filters: Dictionary containing user_id, agent_id, run_id.
+             threshold: Distance threshold for filtering results.
+                 Defaults to constants.DEFAULT_SIMILARITY_THRESHOLD if None.
+             limit: Maximum number of results to return. Defaults to 10.
+
+         Returns:
+             Search results from _search_node method.
+         """
+         return self._search_node("source", source_embedding, filters, threshold, limit)
+
+     def _search_destination_node(
+         self,
+         destination_embedding: List[float],
+         filters: Dict[str, Any],
+         threshold: Optional[float] = None,
+         limit: int = 10
+     ) -> Union[Dict[str, Any], List[Dict[str, Any]], None]:
+         """Search for a destination node by embedding similarity (compatibility method).
+
+         Args:
+             destination_embedding: Vector embedding to search with.
+             filters: Dictionary containing user_id, agent_id, run_id.
+             threshold: Distance threshold for filtering results.
+                 Defaults to constants.DEFAULT_SIMILARITY_THRESHOLD if None.
+             limit: Maximum number of results to return. Defaults to 10.
+
+         Returns:
+             Search results from _search_node method.
+         """
+         return self._search_node("destination", destination_embedding, filters, threshold, limit)
+
+     def reset(self) -> None:
+         """Reset the graph by clearing all nodes and relationships.
+
+         This method drops both entities and relationships tables and recreates them.
+         """
+         logger.warning("Clearing graph...")
+
+         # Use pyobvector API to drop tables
+         if self.client.check_table_exists(constants.TABLE_RELATIONSHIPS):
+             self.client.drop_table_if_exist(constants.TABLE_RELATIONSHIPS)
+             logger.info("Dropped %s table", constants.TABLE_RELATIONSHIPS)
+
+         if self.client.check_table_exists(constants.TABLE_ENTITIES):
+             self.client.drop_table_if_exist(constants.TABLE_ENTITIES)
+             logger.info("Dropped %s table", constants.TABLE_ENTITIES)
+
+         # Recreate tables
+         self._create_tables()
+
+         logger.info("Graph reset completed")
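Putting the public surface together, a hypothetical end-to-end session (config as sketched near `__init__`; all identifiers and the search output are illustrative):

```python
graph = MemoryGraph(config)

filters = {"user_id": "alice"}
graph.add("Alice works at Acme with Bob.", filters)

hits = graph.search("who does alice work with", filters, limit=20)
# e.g. [{"source": "alice", "relationship": "works_with", "destination": "bob"}, ...]

everything = graph.get_all(filters)  # [{"source": ..., "relationship": ..., "target": ...}, ...]
graph.delete_all(filters)            # drop this user's relationships and their entities
graph.reset()                        # drop and recreate both tables
```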