veadk-python: veadk_python-0.2.9-py3-none-any.whl → veadk_python-0.2.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of veadk-python might be problematic. Click here for more details.

Files changed (46)
  1. veadk/a2a/remote_ve_agent.py +63 -6
  2. veadk/agent.py +10 -3
  3. veadk/agent_builder.py +2 -3
  4. veadk/auth/veauth/ark_veauth.py +43 -51
  5. veadk/auth/veauth/utils.py +57 -0
  6. veadk/cli/cli.py +2 -0
  7. veadk/cli/cli_kb.py +75 -0
  8. veadk/cli/cli_web.py +4 -0
  9. veadk/configs/model_configs.py +3 -3
  10. veadk/consts.py +9 -0
  11. veadk/integrations/__init__.py +13 -0
  12. veadk/integrations/ve_viking_db_memory/__init__.py +13 -0
  13. veadk/integrations/ve_viking_db_memory/ve_viking_db_memory.py +293 -0
  14. veadk/knowledgebase/knowledgebase.py +19 -32
  15. veadk/memory/__init__.py +1 -1
  16. veadk/memory/long_term_memory.py +40 -68
  17. veadk/memory/long_term_memory_backends/base_backend.py +4 -2
  18. veadk/memory/long_term_memory_backends/in_memory_backend.py +8 -6
  19. veadk/memory/long_term_memory_backends/mem0_backend.py +25 -10
  20. veadk/memory/long_term_memory_backends/opensearch_backend.py +40 -36
  21. veadk/memory/long_term_memory_backends/redis_backend.py +59 -46
  22. veadk/memory/long_term_memory_backends/vikingdb_memory_backend.py +56 -35
  23. veadk/memory/short_term_memory.py +12 -8
  24. veadk/memory/short_term_memory_backends/postgresql_backend.py +3 -1
  25. veadk/runner.py +42 -19
  26. veadk/tools/builtin_tools/generate_image.py +56 -17
  27. veadk/tools/builtin_tools/image_edit.py +17 -7
  28. veadk/tools/builtin_tools/image_generate.py +17 -7
  29. veadk/tools/builtin_tools/load_knowledgebase.py +97 -0
  30. veadk/tools/builtin_tools/video_generate.py +11 -9
  31. veadk/tools/builtin_tools/web_search.py +10 -3
  32. veadk/tools/load_knowledgebase_tool.py +12 -0
  33. veadk/tracing/telemetry/attributes/extractors/llm_attributes_extractors.py +5 -0
  34. veadk/tracing/telemetry/attributes/extractors/tool_attributes_extractors.py +7 -0
  35. veadk/tracing/telemetry/exporters/apmplus_exporter.py +82 -2
  36. veadk/tracing/telemetry/exporters/inmemory_exporter.py +8 -2
  37. veadk/tracing/telemetry/telemetry.py +41 -5
  38. veadk/utils/misc.py +6 -10
  39. veadk/utils/volcengine_sign.py +2 -0
  40. veadk/version.py +1 -1
  41. {veadk_python-0.2.9.dist-info → veadk_python-0.2.11.dist-info}/METADATA +4 -3
  42. {veadk_python-0.2.9.dist-info → veadk_python-0.2.11.dist-info}/RECORD +46 -40
  43. {veadk_python-0.2.9.dist-info → veadk_python-0.2.11.dist-info}/WHEEL +0 -0
  44. {veadk_python-0.2.9.dist-info → veadk_python-0.2.11.dist-info}/entry_points.txt +0 -0
  45. {veadk_python-0.2.9.dist-info → veadk_python-0.2.11.dist-info}/licenses/LICENSE +0 -0
  46. {veadk_python-0.2.9.dist-info → veadk_python-0.2.11.dist-info}/top_level.txt +0 -0
veadk/memory/long_term_memory_backends/mem0_backend.py +25 -10
@@ -13,12 +13,11 @@
13
13
  # limitations under the License.
14
14
 
15
15
  from typing import Any
16
- from typing_extensions import override
16
+
17
17
  from pydantic import Field
18
+ from typing_extensions import override
18
19
 
19
20
  from veadk.configs.database_configs import Mem0Config
20
-
21
-
22
21
  from veadk.memory.long_term_memory_backends.base_backend import (
23
22
  BaseLongTermMemoryBackend,
24
23
  )
@@ -46,12 +45,17 @@ class Mem0LTMBackend(BaseLongTermMemoryBackend):
46
45
 
47
46
  try:
48
47
  self._mem0_client = MemoryClient(
49
- # base_url=self.mem0_config.base_url, # mem0 endpoint
48
+ host=self.mem0_config.base_url, # mem0 endpoint
50
49
  api_key=self.mem0_config.api_key, # mem0 API key
51
50
  )
51
+ logger.info(
52
+ f"Initialized Mem0 client for host: {self.mem0_config.base_url}"
53
+ )
52
54
  logger.info(f"Initialized Mem0 client for index: {self.index}")
53
55
  except Exception as e:
54
- logger.error(f"Failed to initialize Mem0 client: {str(e)}")
56
+ logger.error(
57
+ f"Failed to initialize Mem0 client for host {self.mem0_config.base_url} : {str(e)}"
58
+ )
55
59
  raise
56
60
 
57
61
  def precheck_index_naming(self):
@@ -61,7 +65,9 @@ class Mem0LTMBackend(BaseLongTermMemoryBackend):
61
65
  pass
62
66
 
63
67
  @override
64
- def save_memory(self, event_strings: list[str], **kwargs) -> bool:
68
+ def save_memory(
69
+ self, event_strings: list[str], user_id: str = "default_user", **kwargs
70
+ ) -> bool:
65
71
  """Save memory to Mem0
66
72
 
67
73
  Args:
@@ -71,8 +77,6 @@ class Mem0LTMBackend(BaseLongTermMemoryBackend):
71
77
  Returns:
72
78
  bool: True if saved successfully, False otherwise
73
79
  """
74
- user_id = kwargs.get("user_id", "default_user")
75
-
76
80
  try:
77
81
  logger.info(
78
82
  f"Saving {len(event_strings)} events to Mem0 for user: {user_id}"
@@ -84,6 +88,7 @@ class Mem0LTMBackend(BaseLongTermMemoryBackend):
84
88
  [{"role": "user", "content": event_string}],
85
89
  user_id=user_id,
86
90
  output_format="v1.1",
91
+ async_mode=True,
87
92
  )
88
93
  logger.debug(f"Saved memory result: {result}")
89
94
 
@@ -94,7 +99,9 @@ class Mem0LTMBackend(BaseLongTermMemoryBackend):
94
99
  return False
95
100
 
96
101
  @override
97
- def search_memory(self, query: str, top_k: int, **kwargs) -> list[str]:
102
+ def search_memory(
103
+ self, query: str, top_k: int, user_id: str = "default_user", **kwargs
104
+ ) -> list[str]:
98
105
  """Search memory from Mem0
99
106
 
100
107
  Args:
@@ -105,7 +112,6 @@ class Mem0LTMBackend(BaseLongTermMemoryBackend):
105
112
  Returns:
106
113
  list[str]: List of memory strings
107
114
  """
108
- user_id = kwargs.get("user_id", "default_user")
109
115
 
110
116
  try:
111
117
  logger.info(
@@ -116,7 +122,16 @@ class Mem0LTMBackend(BaseLongTermMemoryBackend):
116
122
  query, user_id=user_id, output_format="v1.1", top_k=top_k
117
123
  )
118
124
 
125
+ logger.debug(f"return relevant memories: {memories}")
126
+
119
127
  memory_list = []
128
+ # 如果 memories 是列表,直接返回
129
+ if isinstance(memories, list):
130
+ for mem in memories:
131
+ if "memory" in mem:
132
+ memory_list.append(mem["memory"])
133
+ return memory_list
134
+
120
135
  if memories.get("results", []):
121
136
  for mem in memories["results"]:
122
137
  if "memory" in mem:
veadk/memory/long_term_memory_backends/opensearch_backend.py +40 -36
@@ -14,11 +14,7 @@
14
14
 
15
15
  import re
16
16
 
17
- from llama_index.core import (
18
- Document,
19
- StorageContext,
20
- VectorStoreIndex,
21
- )
17
+ from llama_index.core import Document, VectorStoreIndex
22
18
  from llama_index.core.schema import BaseNode
23
19
  from llama_index.embeddings.openai_like import OpenAILikeEmbedding
24
20
  from pydantic import Field
@@ -31,6 +27,7 @@ from veadk.knowledgebase.backends.utils import get_llama_index_splitter
31
27
  from veadk.memory.long_term_memory_backends.base_backend import (
32
28
  BaseLongTermMemoryBackend,
33
29
  )
30
+ from veadk.utils.logger import get_logger
34
31
 
35
32
  try:
36
33
  from llama_index.vector_stores.opensearch import (
@@ -42,6 +39,8 @@ except ImportError:
42
39
  "Please install VeADK extensions\npip install veadk-python[extensions]"
43
40
  )
44
41
 
42
+ logger = get_logger(__name__)
43
+
45
44
 
46
45
  class OpensearchLTMBackend(BaseLongTermMemoryBackend):
47
46
  opensearch_config: OpensearchConfig = Field(default_factory=OpensearchConfig)
@@ -52,19 +51,30 @@ class OpensearchLTMBackend(BaseLongTermMemoryBackend):
52
51
  )
53
52
  """Embedding model configs"""
54
53
 
55
- def precheck_index_naming(self):
54
+ def model_post_init(self, __context: Any) -> None:
55
+ self._embed_model = OpenAILikeEmbedding(
56
+ model_name=self.embedding_config.name,
57
+ api_key=self.embedding_config.api_key,
58
+ api_base=self.embedding_config.api_base,
59
+ )
60
+
61
+ def precheck_index_naming(self, index: str):
56
62
  if not (
57
- isinstance(self.index, str)
58
- and not self.index.startswith(("_", "-"))
59
- and self.index.islower()
60
- and re.match(r"^[a-z0-9_\-.]+$", self.index)
63
+ isinstance(index, str)
64
+ and not index.startswith(("_", "-"))
65
+ and index.islower()
66
+ and re.match(r"^[a-z0-9_\-.]+$", index)
61
67
  ):
62
68
  raise ValueError(
63
- "The index name does not conform to the naming rules of OpenSearch"
69
+ f"The index name {index} does not conform to the naming rules of OpenSearch"
64
70
  )
65
71
 
66
- def model_post_init(self, __context: Any) -> None:
67
- self._opensearch_client = OpensearchVectorClient(
72
+ def _create_vector_index(self, index: str) -> VectorStoreIndex:
73
+ logger.info(f"Create OpenSearch vector index with index={index}")
74
+
75
+ self.precheck_index_naming(index)
76
+
77
+ opensearch_client = OpensearchVectorClient(
68
78
  endpoint=self.opensearch_config.host,
69
79
  port=self.opensearch_config.port,
70
80
  http_auth=(
@@ -74,39 +84,33 @@ class OpensearchLTMBackend(BaseLongTermMemoryBackend):
74
84
  use_ssl=True,
75
85
  verify_certs=False,
76
86
  dim=self.embedding_config.dim,
77
- index=self.index, # collection name
87
+ index=index,
78
88
  )
79
-
80
- self._vector_store = OpensearchVectorStore(client=self._opensearch_client)
81
-
82
- self._storage_context = StorageContext.from_defaults(
83
- vector_store=self._vector_store
84
- )
85
-
86
- self._embed_model = OpenAILikeEmbedding(
87
- model_name=self.embedding_config.name,
88
- api_key=self.embedding_config.api_key,
89
- api_base=self.embedding_config.api_base,
90
- )
91
-
92
- self._vector_index = VectorStoreIndex.from_documents(
93
- documents=[],
94
- storage_context=self._storage_context,
95
- embed_model=self._embed_model,
89
+ vector_store = OpensearchVectorStore(client=opensearch_client)
90
+ return VectorStoreIndex.from_vector_store(
91
+ vector_store=vector_store, embed_model=self._embed_model
96
92
  )
97
- self._retriever = self._vector_index.as_retriever()
98
93
 
99
94
  @override
100
- def save_memory(self, event_strings: list[str], **kwargs) -> bool:
95
+ def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool:
96
+ index = f"{self.index}_{user_id}"
97
+ vector_index = self._create_vector_index(index)
98
+
101
99
  for event_string in event_strings:
102
100
  document = Document(text=event_string)
103
101
  nodes = self._split_documents([document])
104
- self._vector_index.insert_nodes(nodes)
102
+ vector_index.insert_nodes(nodes)
105
103
  return True
106
104
 
107
105
  @override
108
- def search_memory(self, query: str, top_k: int, **kwargs) -> list[str]:
109
- _retriever = self._vector_index.as_retriever(similarity_top_k=top_k)
106
+ def search_memory(
107
+ self, user_id: str, query: str, top_k: int, **kwargs
108
+ ) -> list[str]:
109
+ index = f"{self.index}_{user_id}"
110
+
111
+ vector_index = self._create_vector_index(index)
112
+
113
+ _retriever = vector_index.as_retriever(similarity_top_k=top_k)
110
114
  retrieved_nodes = _retriever.retrieve(query)
111
115
  return [node.text for node in retrieved_nodes]
112
116
 
veadk/memory/long_term_memory_backends/redis_backend.py +59 -46
@@ -12,11 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from llama_index.core import (
16
- Document,
17
- StorageContext,
18
- VectorStoreIndex,
19
- )
15
+ from llama_index.core import Document, VectorStoreIndex
20
16
  from llama_index.core.schema import BaseNode
21
17
  from llama_index.embeddings.openai_like import OpenAILikeEmbedding
22
18
  from pydantic import Field
@@ -29,21 +25,22 @@ from veadk.knowledgebase.backends.utils import get_llama_index_splitter
29
25
  from veadk.memory.long_term_memory_backends.base_backend import (
30
26
  BaseLongTermMemoryBackend,
31
27
  )
28
+ from veadk.utils.logger import get_logger
32
29
 
33
30
  try:
34
31
  from llama_index.vector_stores.redis import RedisVectorStore
35
- from llama_index.vector_stores.redis.schema import (
36
- RedisIndexInfo,
37
- RedisVectorStoreSchema,
38
- )
39
32
  from redis import Redis
40
- from redisvl.schema.fields import BaseVectorFieldAttributes
33
+ from redisvl.schema import IndexSchema
34
+
41
35
  except ImportError:
42
36
  raise ImportError(
43
37
  "Please install VeADK extensions\npip install veadk-python[extensions]"
44
38
  )
45
39
 
46
40
 
41
+ logger = get_logger(__name__)
42
+
43
+
47
44
  class RedisLTMBackend(BaseLongTermMemoryBackend):
48
45
  redis_config: RedisConfig = Field(default_factory=RedisConfig)
49
46
  """Redis client configs"""
@@ -53,67 +50,83 @@ class RedisLTMBackend(BaseLongTermMemoryBackend):
53
50
  )
54
51
  """Embedding model configs"""
55
52
 
56
- def precheck_index_naming(self):
53
+ def model_post_init(self, __context: Any) -> None:
54
+ self._embed_model = OpenAILikeEmbedding(
55
+ model_name=self.embedding_config.name,
56
+ api_key=self.embedding_config.api_key,
57
+ api_base=self.embedding_config.api_base,
58
+ )
59
+
60
+ def precheck_index_naming(self, index: str):
57
61
  # no checking
58
62
  pass
59
63
 
60
- def model_post_init(self, __context: Any) -> None:
64
+ def _create_vector_index(self, index: str) -> VectorStoreIndex:
65
+ logger.info(f"Create Redis vector index with index={index}")
66
+
67
+ self.precheck_index_naming(index)
68
+
61
69
  # We will use `from_url` to init Redis client once the
62
70
  # AK/SK -> STS token is ready.
63
71
  # self._redis_client = Redis.from_url(url=...)
64
-
65
- self._redis_client = Redis(
72
+ redis_client = Redis(
66
73
  host=self.redis_config.host,
67
74
  port=self.redis_config.port,
68
75
  db=self.redis_config.db,
69
76
  password=self.redis_config.password,
70
77
  )
71
78
 
72
- self._embed_model = OpenAILikeEmbedding(
73
- model_name=self.embedding_config.name,
74
- api_key=self.embedding_config.api_key,
75
- api_base=self.embedding_config.api_base,
79
+ # Create an index for each user
80
+ # Should be Optimized in the future
81
+ schema = IndexSchema.from_dict(
82
+ {
83
+ "index": {"name": index, "prefix": index, "key_separator": "_"},
84
+ "fields": [
85
+ {"name": "id", "type": "tag", "attrs": {"sortable": False}},
86
+ {"name": "doc_id", "type": "tag", "attrs": {"sortable": False}},
87
+ {"name": "text", "type": "text", "attrs": {"weight": 1.0}},
88
+ {
89
+ "name": "vector",
90
+ "type": "vector",
91
+ "attrs": {
92
+ "dims": self.embedding_config.dim,
93
+ "algorithm": "flat",
94
+ "distance_metric": "cosine",
95
+ },
96
+ },
97
+ ],
98
+ }
76
99
  )
100
+ vector_store = RedisVectorStore(schema=schema, redis_client=redis_client)
77
101
 
78
- self._schema = RedisVectorStoreSchema(
79
- index=RedisIndexInfo(name=self.index),
80
- )
81
- if "vector" in self._schema.fields:
82
- vector_field = self._schema.fields["vector"]
83
- if (
84
- vector_field
85
- and vector_field.attrs
86
- and isinstance(vector_field.attrs, BaseVectorFieldAttributes)
87
- ):
88
- vector_field.attrs.dims = self.embedding_config.dim
89
- self._vector_store = RedisVectorStore(
90
- schema=self._schema,
91
- redis_client=self._redis_client,
92
- overwrite=True,
93
- collection_name=self.index,
94
- )
95
-
96
- self._storage_context = StorageContext.from_defaults(
97
- vector_store=self._vector_store
102
+ logger.info(
103
+ f"Create vector store done, index_name={vector_store.index_name} prefix={vector_store.schema.index.prefix}"
98
104
  )
99
105
 
100
- self._vector_index = VectorStoreIndex.from_documents(
101
- documents=[],
102
- storage_context=self._storage_context,
103
- embed_model=self._embed_model,
106
+ return VectorStoreIndex.from_vector_store(
107
+ vector_store=vector_store, embed_model=self._embed_model
104
108
  )
105
109
 
106
110
  @override
107
- def save_memory(self, event_strings: list[str], **kwargs) -> bool:
111
+ def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool:
112
+ index = f"veadk-ltm/{self.index}/{user_id}"
113
+ vector_index = self._create_vector_index(index)
114
+
108
115
  for event_string in event_strings:
109
116
  document = Document(text=event_string)
110
117
  nodes = self._split_documents([document])
111
- self._vector_index.insert_nodes(nodes)
118
+ vector_index.insert_nodes(nodes)
119
+
112
120
  return True
113
121
 
114
122
  @override
115
- def search_memory(self, query: str, top_k: int, **kwargs) -> list[str]:
116
- _retriever = self._vector_index.as_retriever(similarity_top_k=top_k)
123
+ def search_memory(
124
+ self, user_id: str, query: str, top_k: int, **kwargs
125
+ ) -> list[str]:
126
+ index = f"veadk-ltm/{self.index}/{user_id}"
127
+ vector_index = self._create_vector_index(index)
128
+
129
+ _retriever = vector_index.as_retriever(similarity_top_k=top_k)
117
130
  retrieved_nodes = _retriever.retrieve(query)
118
131
  return [node.text for node in retrieved_nodes]
119
132
 
veadk/memory/long_term_memory_backends/vikingdb_memory_backend.py +56 -35
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
 
15
15
  import json
16
+ import os
16
17
  import re
17
18
  import time
18
19
  import uuid
@@ -22,34 +23,37 @@ from pydantic import Field
22
23
  from typing_extensions import override
23
24
 
24
25
  import veadk.config # noqa E401
25
- from veadk.config import getenv
26
+ from veadk.auth.veauth.utils import get_credential_from_vefaas_iam
27
+ from veadk.integrations.ve_viking_db_memory.ve_viking_db_memory import (
28
+ VikingDBMemoryClient,
29
+ )
26
30
  from veadk.memory.long_term_memory_backends.base_backend import (
27
31
  BaseLongTermMemoryBackend,
28
32
  )
29
33
  from veadk.utils.logger import get_logger
30
34
 
31
- try:
32
- from mcp_server_vikingdb_memory.common.memory_client import VikingDBMemoryService
33
- except ImportError:
34
- raise ImportError(
35
- "Please install VeADK extensions\npip install veadk-python[extensions]"
36
- )
37
-
38
35
  logger = get_logger(__name__)
39
36
 
40
37
 
41
38
  class VikingDBLTMBackend(BaseLongTermMemoryBackend):
42
- volcengine_access_key: str = Field(
43
- default_factory=lambda: getenv("VOLCENGINE_ACCESS_KEY")
39
+ volcengine_access_key: str | None = Field(
40
+ default_factory=lambda: os.getenv("VOLCENGINE_ACCESS_KEY")
44
41
  )
45
42
 
46
- volcengine_secret_key: str = Field(
47
- default_factory=lambda: getenv("VOLCENGINE_SECRET_KEY")
43
+ volcengine_secret_key: str | None = Field(
44
+ default_factory=lambda: os.getenv("VOLCENGINE_SECRET_KEY")
48
45
  )
49
46
 
47
+ session_token: str = ""
48
+
50
49
  region: str = "cn-beijing"
51
50
  """VikingDB memory region"""
52
51
 
52
+ def model_post_init(self, __context: Any) -> None:
53
+ # check whether collection exist, if not, create it
54
+ if not self._collection_exist():
55
+ self._create_collection()
56
+
53
57
  def precheck_index_naming(self):
54
58
  if not (
55
59
  isinstance(self.index, str)
@@ -60,37 +64,39 @@ class VikingDBLTMBackend(BaseLongTermMemoryBackend):
60
64
  "The index name does not conform to the rules: it must start with an English letter, contain only letters, numbers, and underscores, and have a length of 1-128."
61
65
  )
62
66
 
63
- def model_post_init(self, __context: Any) -> None:
64
- self._client = VikingDBMemoryService(
65
- ak=self.volcengine_access_key,
66
- sk=self.volcengine_secret_key,
67
- region=self.region,
68
- )
69
-
70
- # check whether collection exist, if not, create it
71
- if not self._collection_exist():
72
- self._create_collection()
73
-
74
67
  def _collection_exist(self) -> bool:
75
68
  try:
76
- self._client.get_collection(collection_name=self.index)
69
+ client = self._get_client()
70
+ client.get_collection(collection_name=self.index)
77
71
  return True
78
72
  except Exception:
79
73
  return False
80
74
 
81
75
  def _create_collection(self) -> None:
82
- response = self._client.create_collection(
76
+ client = self._get_client()
77
+ response = client.create_collection(
83
78
  collection_name=self.index,
84
79
  description="Created by Volcengine Agent Development Kit VeADK",
85
80
  builtin_event_types=["sys_event_v1"],
86
81
  )
87
82
  return response
88
83
 
84
+ def _get_client(self) -> VikingDBMemoryClient:
85
+ if not (self.volcengine_access_key and self.volcengine_secret_key):
86
+ cred = get_credential_from_vefaas_iam()
87
+ self.volcengine_access_key = cred.access_key_id
88
+ self.volcengine_secret_key = cred.secret_access_key
89
+ self.session_token = cred.session_token
90
+
91
+ return VikingDBMemoryClient(
92
+ ak=self.volcengine_access_key,
93
+ sk=self.volcengine_secret_key,
94
+ sts_token=self.session_token,
95
+ region=self.region,
96
+ )
97
+
89
98
  @override
90
- def save_memory(self, event_strings: list[str], **kwargs) -> bool:
91
- user_id = kwargs.get("user_id")
92
- if user_id is None:
93
- raise ValueError("user_id is required")
99
+ def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool:
94
100
  session_id = str(uuid.uuid1())
95
101
  messages = []
96
102
  for raw_events in event_strings:
@@ -105,31 +111,46 @@ class VikingDBLTMBackend(BaseLongTermMemoryBackend):
105
111
  "default_assistant_id": "assistant",
106
112
  "time": int(time.time() * 1000),
107
113
  }
108
- response = self._client.add_messages(
114
+
115
+ logger.debug(
116
+ f"Request for add {len(messages)} memory to VikingDB: collection_name={self.index}, metadata={metadata}, session_id={session_id}"
117
+ )
118
+
119
+ client = self._get_client()
120
+ response = client.add_messages(
109
121
  collection_name=self.index,
110
122
  messages=messages,
111
123
  metadata=metadata,
112
124
  session_id=session_id,
113
125
  )
114
126
 
127
+ logger.debug(f"Response from add memory to VikingDB: {response}")
128
+
115
129
  if not response.get("code") == 0:
116
130
  raise ValueError(f"Save VikingDB memory error: {response}")
117
131
 
118
132
  return True
119
133
 
120
134
  @override
121
- def search_memory(self, query: str, top_k: int, **kwargs) -> list[str]:
122
- user_id = kwargs.get("user_id")
123
- if user_id is None:
124
- raise ValueError("user_id is required")
135
+ def search_memory(
136
+ self, user_id: str, query: str, top_k: int, **kwargs
137
+ ) -> list[str]:
125
138
  filter = {
126
139
  "user_id": user_id,
127
140
  "memory_type": ["sys_event_v1"],
128
141
  }
129
- response = self._client.search_memory(
142
+
143
+ logger.debug(
144
+ f"Request for search memory in VikingDB: filter={filter}, collection_name={self.index}, query={query}, limit={top_k}"
145
+ )
146
+
147
+ client = self._get_client()
148
+ response = client.search_memory(
130
149
  collection_name=self.index, query=query, filter=filter, limit=top_k
131
150
  )
132
151
 
152
+ logger.debug(f"Response from search memory in VikingDB: {response}")
153
+
133
154
  if not response.get("code") == 0:
134
155
  raise ValueError(f"Search VikingDB memory error: {response}")
135
156
 
veadk/memory/short_term_memory.py +12 -8
@@ -19,6 +19,7 @@ from google.adk.sessions import (
19
19
  BaseSessionService,
20
20
  DatabaseSessionService,
21
21
  InMemorySessionService,
22
+ Session,
22
23
  )
23
24
  from pydantic import BaseModel, Field, PrivateAttr
24
25
 
@@ -106,7 +107,7 @@ class ShortTermMemory(BaseModel):
106
107
  app_name: str,
107
108
  user_id: str,
108
109
  session_id: str,
109
- ) -> None:
110
+ ) -> Session | None:
110
111
  if isinstance(self._session_service, DatabaseSessionService):
111
112
  list_sessions_response = await self._session_service.list_sessions(
112
113
  app_name=app_name, user_id=user_id
@@ -116,13 +117,16 @@ class ShortTermMemory(BaseModel):
116
117
  f"Loaded {len(list_sessions_response.sessions)} sessions from db {self.db_url}."
117
118
  )
118
119
 
119
- if (
120
- await self._session_service.get_session(
121
- app_name=app_name, user_id=user_id, session_id=session_id
120
+ session = await self._session_service.get_session(
121
+ app_name=app_name, user_id=user_id, session_id=session_id
122
+ )
123
+
124
+ if session:
125
+ logger.info(
126
+ f"Session {session_id} already exists with app_name={app_name} user_id={user_id}."
122
127
  )
123
- is None
124
- ):
125
- # create a new session for this running
126
- await self._session_service.create_session(
128
+ return session
129
+ else:
130
+ return await self._session_service.create_session(
127
131
  app_name=app_name, user_id=user_id, session_id=session_id
128
132
  )
veadk/memory/short_term_memory_backends/postgresql_backend.py +3 -1
@@ -14,6 +14,7 @@
14
14
 
15
15
  from functools import cached_property
16
16
  from typing import Any
17
+ from venv import logger
17
18
 
18
19
  from google.adk.sessions import (
19
20
  BaseSessionService,
@@ -33,7 +34,8 @@ class PostgreSqlSTMBackend(BaseShortTermMemoryBackend):
33
34
  postgresql_config: PostgreSqlConfig = Field(default_factory=PostgreSqlConfig)
34
35
 
35
36
  def model_post_init(self, context: Any) -> None:
36
- self._db_url = f"postgresql+psycopg2://{self.postgresql_config.user}:{self.postgresql_config.password}@{self.postgresql_config.host}:{self.postgresql_config.port}/{self.postgresql_config.database}"
37
+ self._db_url = f"postgresql://{self.postgresql_config.user}:{self.postgresql_config.password}@{self.postgresql_config.host}:{self.postgresql_config.port}/{self.postgresql_config.database}"
38
+ logger.debug(self._db_url)
37
39
 
38
40
  @cached_property
39
41
  @override