veadk-python: veadk_python-0.2.10-py3-none-any.whl → veadk_python-0.2.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of veadk-python might be problematic; see the registry's advisory for this release for more details.

Files changed (29):
  1. veadk/agent.py +7 -3
  2. veadk/auth/veauth/ark_veauth.py +43 -51
  3. veadk/auth/veauth/utils.py +57 -0
  4. veadk/configs/model_configs.py +3 -3
  5. veadk/consts.py +9 -0
  6. veadk/knowledgebase/knowledgebase.py +19 -32
  7. veadk/memory/long_term_memory.py +39 -92
  8. veadk/memory/long_term_memory_backends/base_backend.py +4 -2
  9. veadk/memory/long_term_memory_backends/in_memory_backend.py +8 -6
  10. veadk/memory/long_term_memory_backends/mem0_backend.py +8 -8
  11. veadk/memory/long_term_memory_backends/opensearch_backend.py +40 -36
  12. veadk/memory/long_term_memory_backends/redis_backend.py +59 -46
  13. veadk/memory/long_term_memory_backends/vikingdb_memory_backend.py +53 -28
  14. veadk/memory/short_term_memory.py +9 -11
  15. veadk/runner.py +19 -11
  16. veadk/tools/builtin_tools/generate_image.py +11 -6
  17. veadk/tools/builtin_tools/image_edit.py +9 -4
  18. veadk/tools/builtin_tools/image_generate.py +9 -4
  19. veadk/tools/builtin_tools/load_knowledgebase.py +97 -0
  20. veadk/tools/builtin_tools/video_generate.py +6 -4
  21. veadk/utils/misc.py +6 -10
  22. veadk/utils/volcengine_sign.py +2 -0
  23. veadk/version.py +1 -1
  24. {veadk_python-0.2.10.dist-info → veadk_python-0.2.11.dist-info}/METADATA +2 -1
  25. {veadk_python-0.2.10.dist-info → veadk_python-0.2.11.dist-info}/RECORD +29 -27
  26. {veadk_python-0.2.10.dist-info → veadk_python-0.2.11.dist-info}/WHEEL +0 -0
  27. {veadk_python-0.2.10.dist-info → veadk_python-0.2.11.dist-info}/entry_points.txt +0 -0
  28. {veadk_python-0.2.10.dist-info → veadk_python-0.2.11.dist-info}/licenses/LICENSE +0 -0
  29. {veadk_python-0.2.10.dist-info → veadk_python-0.2.11.dist-info}/top_level.txt +0 -0
@@ -14,11 +14,7 @@
14
14
 
15
15
  import re
16
16
 
17
- from llama_index.core import (
18
- Document,
19
- StorageContext,
20
- VectorStoreIndex,
21
- )
17
+ from llama_index.core import Document, VectorStoreIndex
22
18
  from llama_index.core.schema import BaseNode
23
19
  from llama_index.embeddings.openai_like import OpenAILikeEmbedding
24
20
  from pydantic import Field
@@ -31,6 +27,7 @@ from veadk.knowledgebase.backends.utils import get_llama_index_splitter
31
27
  from veadk.memory.long_term_memory_backends.base_backend import (
32
28
  BaseLongTermMemoryBackend,
33
29
  )
30
+ from veadk.utils.logger import get_logger
34
31
 
35
32
  try:
36
33
  from llama_index.vector_stores.opensearch import (
@@ -42,6 +39,8 @@ except ImportError:
42
39
  "Please install VeADK extensions\npip install veadk-python[extensions]"
43
40
  )
44
41
 
42
+ logger = get_logger(__name__)
43
+
45
44
 
46
45
  class OpensearchLTMBackend(BaseLongTermMemoryBackend):
47
46
  opensearch_config: OpensearchConfig = Field(default_factory=OpensearchConfig)
@@ -52,19 +51,30 @@ class OpensearchLTMBackend(BaseLongTermMemoryBackend):
52
51
  )
53
52
  """Embedding model configs"""
54
53
 
55
- def precheck_index_naming(self):
54
+ def model_post_init(self, __context: Any) -> None:
55
+ self._embed_model = OpenAILikeEmbedding(
56
+ model_name=self.embedding_config.name,
57
+ api_key=self.embedding_config.api_key,
58
+ api_base=self.embedding_config.api_base,
59
+ )
60
+
61
+ def precheck_index_naming(self, index: str):
56
62
  if not (
57
- isinstance(self.index, str)
58
- and not self.index.startswith(("_", "-"))
59
- and self.index.islower()
60
- and re.match(r"^[a-z0-9_\-.]+$", self.index)
63
+ isinstance(index, str)
64
+ and not index.startswith(("_", "-"))
65
+ and index.islower()
66
+ and re.match(r"^[a-z0-9_\-.]+$", index)
61
67
  ):
62
68
  raise ValueError(
63
- "The index name does not conform to the naming rules of OpenSearch"
69
+ f"The index name {index} does not conform to the naming rules of OpenSearch"
64
70
  )
65
71
 
66
- def model_post_init(self, __context: Any) -> None:
67
- self._opensearch_client = OpensearchVectorClient(
72
+ def _create_vector_index(self, index: str) -> VectorStoreIndex:
73
+ logger.info(f"Create OpenSearch vector index with index={index}")
74
+
75
+ self.precheck_index_naming(index)
76
+
77
+ opensearch_client = OpensearchVectorClient(
68
78
  endpoint=self.opensearch_config.host,
69
79
  port=self.opensearch_config.port,
70
80
  http_auth=(
@@ -74,39 +84,33 @@ class OpensearchLTMBackend(BaseLongTermMemoryBackend):
74
84
  use_ssl=True,
75
85
  verify_certs=False,
76
86
  dim=self.embedding_config.dim,
77
- index=self.index, # collection name
87
+ index=index,
78
88
  )
79
-
80
- self._vector_store = OpensearchVectorStore(client=self._opensearch_client)
81
-
82
- self._storage_context = StorageContext.from_defaults(
83
- vector_store=self._vector_store
84
- )
85
-
86
- self._embed_model = OpenAILikeEmbedding(
87
- model_name=self.embedding_config.name,
88
- api_key=self.embedding_config.api_key,
89
- api_base=self.embedding_config.api_base,
90
- )
91
-
92
- self._vector_index = VectorStoreIndex.from_documents(
93
- documents=[],
94
- storage_context=self._storage_context,
95
- embed_model=self._embed_model,
89
+ vector_store = OpensearchVectorStore(client=opensearch_client)
90
+ return VectorStoreIndex.from_vector_store(
91
+ vector_store=vector_store, embed_model=self._embed_model
96
92
  )
97
- self._retriever = self._vector_index.as_retriever()
98
93
 
99
94
  @override
100
- def save_memory(self, event_strings: list[str], **kwargs) -> bool:
95
+ def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool:
96
+ index = f"{self.index}_{user_id}"
97
+ vector_index = self._create_vector_index(index)
98
+
101
99
  for event_string in event_strings:
102
100
  document = Document(text=event_string)
103
101
  nodes = self._split_documents([document])
104
- self._vector_index.insert_nodes(nodes)
102
+ vector_index.insert_nodes(nodes)
105
103
  return True
106
104
 
107
105
  @override
108
- def search_memory(self, query: str, top_k: int, **kwargs) -> list[str]:
109
- _retriever = self._vector_index.as_retriever(similarity_top_k=top_k)
106
+ def search_memory(
107
+ self, user_id: str, query: str, top_k: int, **kwargs
108
+ ) -> list[str]:
109
+ index = f"{self.index}_{user_id}"
110
+
111
+ vector_index = self._create_vector_index(index)
112
+
113
+ _retriever = vector_index.as_retriever(similarity_top_k=top_k)
110
114
  retrieved_nodes = _retriever.retrieve(query)
111
115
  return [node.text for node in retrieved_nodes]
112
116
 
@@ -12,11 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from llama_index.core import (
16
- Document,
17
- StorageContext,
18
- VectorStoreIndex,
19
- )
15
+ from llama_index.core import Document, VectorStoreIndex
20
16
  from llama_index.core.schema import BaseNode
21
17
  from llama_index.embeddings.openai_like import OpenAILikeEmbedding
22
18
  from pydantic import Field
@@ -29,21 +25,22 @@ from veadk.knowledgebase.backends.utils import get_llama_index_splitter
29
25
  from veadk.memory.long_term_memory_backends.base_backend import (
30
26
  BaseLongTermMemoryBackend,
31
27
  )
28
+ from veadk.utils.logger import get_logger
32
29
 
33
30
  try:
34
31
  from llama_index.vector_stores.redis import RedisVectorStore
35
- from llama_index.vector_stores.redis.schema import (
36
- RedisIndexInfo,
37
- RedisVectorStoreSchema,
38
- )
39
32
  from redis import Redis
40
- from redisvl.schema.fields import BaseVectorFieldAttributes
33
+ from redisvl.schema import IndexSchema
34
+
41
35
  except ImportError:
42
36
  raise ImportError(
43
37
  "Please install VeADK extensions\npip install veadk-python[extensions]"
44
38
  )
45
39
 
46
40
 
41
+ logger = get_logger(__name__)
42
+
43
+
47
44
  class RedisLTMBackend(BaseLongTermMemoryBackend):
48
45
  redis_config: RedisConfig = Field(default_factory=RedisConfig)
49
46
  """Redis client configs"""
@@ -53,67 +50,83 @@ class RedisLTMBackend(BaseLongTermMemoryBackend):
53
50
  )
54
51
  """Embedding model configs"""
55
52
 
56
- def precheck_index_naming(self):
53
+ def model_post_init(self, __context: Any) -> None:
54
+ self._embed_model = OpenAILikeEmbedding(
55
+ model_name=self.embedding_config.name,
56
+ api_key=self.embedding_config.api_key,
57
+ api_base=self.embedding_config.api_base,
58
+ )
59
+
60
+ def precheck_index_naming(self, index: str):
57
61
  # no checking
58
62
  pass
59
63
 
60
- def model_post_init(self, __context: Any) -> None:
64
+ def _create_vector_index(self, index: str) -> VectorStoreIndex:
65
+ logger.info(f"Create Redis vector index with index={index}")
66
+
67
+ self.precheck_index_naming(index)
68
+
61
69
  # We will use `from_url` to init Redis client once the
62
70
  # AK/SK -> STS token is ready.
63
71
  # self._redis_client = Redis.from_url(url=...)
64
-
65
- self._redis_client = Redis(
72
+ redis_client = Redis(
66
73
  host=self.redis_config.host,
67
74
  port=self.redis_config.port,
68
75
  db=self.redis_config.db,
69
76
  password=self.redis_config.password,
70
77
  )
71
78
 
72
- self._embed_model = OpenAILikeEmbedding(
73
- model_name=self.embedding_config.name,
74
- api_key=self.embedding_config.api_key,
75
- api_base=self.embedding_config.api_base,
79
+ # Create an index for each user
80
+ # Should be Optimized in the future
81
+ schema = IndexSchema.from_dict(
82
+ {
83
+ "index": {"name": index, "prefix": index, "key_separator": "_"},
84
+ "fields": [
85
+ {"name": "id", "type": "tag", "attrs": {"sortable": False}},
86
+ {"name": "doc_id", "type": "tag", "attrs": {"sortable": False}},
87
+ {"name": "text", "type": "text", "attrs": {"weight": 1.0}},
88
+ {
89
+ "name": "vector",
90
+ "type": "vector",
91
+ "attrs": {
92
+ "dims": self.embedding_config.dim,
93
+ "algorithm": "flat",
94
+ "distance_metric": "cosine",
95
+ },
96
+ },
97
+ ],
98
+ }
76
99
  )
100
+ vector_store = RedisVectorStore(schema=schema, redis_client=redis_client)
77
101
 
78
- self._schema = RedisVectorStoreSchema(
79
- index=RedisIndexInfo(name=self.index),
80
- )
81
- if "vector" in self._schema.fields:
82
- vector_field = self._schema.fields["vector"]
83
- if (
84
- vector_field
85
- and vector_field.attrs
86
- and isinstance(vector_field.attrs, BaseVectorFieldAttributes)
87
- ):
88
- vector_field.attrs.dims = self.embedding_config.dim
89
- self._vector_store = RedisVectorStore(
90
- schema=self._schema,
91
- redis_client=self._redis_client,
92
- overwrite=True,
93
- collection_name=self.index,
94
- )
95
-
96
- self._storage_context = StorageContext.from_defaults(
97
- vector_store=self._vector_store
102
+ logger.info(
103
+ f"Create vector store done, index_name={vector_store.index_name} prefix={vector_store.schema.index.prefix}"
98
104
  )
99
105
 
100
- self._vector_index = VectorStoreIndex.from_documents(
101
- documents=[],
102
- storage_context=self._storage_context,
103
- embed_model=self._embed_model,
106
+ return VectorStoreIndex.from_vector_store(
107
+ vector_store=vector_store, embed_model=self._embed_model
104
108
  )
105
109
 
106
110
  @override
107
- def save_memory(self, event_strings: list[str], **kwargs) -> bool:
111
+ def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool:
112
+ index = f"veadk-ltm/{self.index}/{user_id}"
113
+ vector_index = self._create_vector_index(index)
114
+
108
115
  for event_string in event_strings:
109
116
  document = Document(text=event_string)
110
117
  nodes = self._split_documents([document])
111
- self._vector_index.insert_nodes(nodes)
118
+ vector_index.insert_nodes(nodes)
119
+
112
120
  return True
113
121
 
114
122
  @override
115
- def search_memory(self, query: str, top_k: int, **kwargs) -> list[str]:
116
- _retriever = self._vector_index.as_retriever(similarity_top_k=top_k)
123
+ def search_memory(
124
+ self, user_id: str, query: str, top_k: int, **kwargs
125
+ ) -> list[str]:
126
+ index = f"veadk-ltm/{self.index}/{user_id}"
127
+ vector_index = self._create_vector_index(index)
128
+
129
+ _retriever = vector_index.as_retriever(similarity_top_k=top_k)
117
130
  retrieved_nodes = _retriever.retrieve(query)
118
131
  return [node.text for node in retrieved_nodes]
119
132
 
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
 
15
15
  import json
16
+ import os
16
17
  import re
17
18
  import time
18
19
  import uuid
@@ -22,7 +23,7 @@ from pydantic import Field
22
23
  from typing_extensions import override
23
24
 
24
25
  import veadk.config # noqa E401
25
- from veadk.config import getenv
26
+ from veadk.auth.veauth.utils import get_credential_from_vefaas_iam
26
27
  from veadk.integrations.ve_viking_db_memory.ve_viking_db_memory import (
27
28
  VikingDBMemoryClient,
28
29
  )
@@ -35,17 +36,24 @@ logger = get_logger(__name__)
35
36
 
36
37
 
37
38
  class VikingDBLTMBackend(BaseLongTermMemoryBackend):
38
- volcengine_access_key: str = Field(
39
- default_factory=lambda: getenv("VOLCENGINE_ACCESS_KEY")
39
+ volcengine_access_key: str | None = Field(
40
+ default_factory=lambda: os.getenv("VOLCENGINE_ACCESS_KEY")
40
41
  )
41
42
 
42
- volcengine_secret_key: str = Field(
43
- default_factory=lambda: getenv("VOLCENGINE_SECRET_KEY")
43
+ volcengine_secret_key: str | None = Field(
44
+ default_factory=lambda: os.getenv("VOLCENGINE_SECRET_KEY")
44
45
  )
45
46
 
47
+ session_token: str = ""
48
+
46
49
  region: str = "cn-beijing"
47
50
  """VikingDB memory region"""
48
51
 
52
+ def model_post_init(self, __context: Any) -> None:
53
+ # check whether collection exist, if not, create it
54
+ if not self._collection_exist():
55
+ self._create_collection()
56
+
49
57
  def precheck_index_naming(self):
50
58
  if not (
51
59
  isinstance(self.index, str)
@@ -56,37 +64,39 @@ class VikingDBLTMBackend(BaseLongTermMemoryBackend):
56
64
  "The index name does not conform to the rules: it must start with an English letter, contain only letters, numbers, and underscores, and have a length of 1-128."
57
65
  )
58
66
 
59
- def model_post_init(self, __context: Any) -> None:
60
- self._client = VikingDBMemoryClient(
61
- ak=self.volcengine_access_key,
62
- sk=self.volcengine_secret_key,
63
- region=self.region,
64
- )
65
-
66
- # check whether collection exist, if not, create it
67
- if not self._collection_exist():
68
- self._create_collection()
69
-
70
67
  def _collection_exist(self) -> bool:
71
68
  try:
72
- self._client.get_collection(collection_name=self.index)
69
+ client = self._get_client()
70
+ client.get_collection(collection_name=self.index)
73
71
  return True
74
72
  except Exception:
75
73
  return False
76
74
 
77
75
  def _create_collection(self) -> None:
78
- response = self._client.create_collection(
76
+ client = self._get_client()
77
+ response = client.create_collection(
79
78
  collection_name=self.index,
80
79
  description="Created by Volcengine Agent Development Kit VeADK",
81
80
  builtin_event_types=["sys_event_v1"],
82
81
  )
83
82
  return response
84
83
 
84
+ def _get_client(self) -> VikingDBMemoryClient:
85
+ if not (self.volcengine_access_key and self.volcengine_secret_key):
86
+ cred = get_credential_from_vefaas_iam()
87
+ self.volcengine_access_key = cred.access_key_id
88
+ self.volcengine_secret_key = cred.secret_access_key
89
+ self.session_token = cred.session_token
90
+
91
+ return VikingDBMemoryClient(
92
+ ak=self.volcengine_access_key,
93
+ sk=self.volcengine_secret_key,
94
+ sts_token=self.session_token,
95
+ region=self.region,
96
+ )
97
+
85
98
  @override
86
- def save_memory(self, event_strings: list[str], **kwargs) -> bool:
87
- user_id = kwargs.get("user_id")
88
- if user_id is None:
89
- raise ValueError("user_id is required")
99
+ def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool:
90
100
  session_id = str(uuid.uuid1())
91
101
  messages = []
92
102
  for raw_events in event_strings:
@@ -101,31 +111,46 @@ class VikingDBLTMBackend(BaseLongTermMemoryBackend):
101
111
  "default_assistant_id": "assistant",
102
112
  "time": int(time.time() * 1000),
103
113
  }
104
- response = self._client.add_messages(
114
+
115
+ logger.debug(
116
+ f"Request for add {len(messages)} memory to VikingDB: collection_name={self.index}, metadata={metadata}, session_id={session_id}"
117
+ )
118
+
119
+ client = self._get_client()
120
+ response = client.add_messages(
105
121
  collection_name=self.index,
106
122
  messages=messages,
107
123
  metadata=metadata,
108
124
  session_id=session_id,
109
125
  )
110
126
 
127
+ logger.debug(f"Response from add memory to VikingDB: {response}")
128
+
111
129
  if not response.get("code") == 0:
112
130
  raise ValueError(f"Save VikingDB memory error: {response}")
113
131
 
114
132
  return True
115
133
 
116
134
  @override
117
- def search_memory(self, query: str, top_k: int, **kwargs) -> list[str]:
118
- user_id = kwargs.get("user_id")
119
- if user_id is None:
120
- raise ValueError("user_id is required")
135
+ def search_memory(
136
+ self, user_id: str, query: str, top_k: int, **kwargs
137
+ ) -> list[str]:
121
138
  filter = {
122
139
  "user_id": user_id,
123
140
  "memory_type": ["sys_event_v1"],
124
141
  }
125
- response = self._client.search_memory(
142
+
143
+ logger.debug(
144
+ f"Request for search memory in VikingDB: filter={filter}, collection_name={self.index}, query={query}, limit={top_k}"
145
+ )
146
+
147
+ client = self._get_client()
148
+ response = client.search_memory(
126
149
  collection_name=self.index, query=query, filter=filter, limit=top_k
127
150
  )
128
151
 
152
+ logger.debug(f"Response from search memory in VikingDB: {response}")
153
+
129
154
  if not response.get("code") == 0:
130
155
  raise ValueError(f"Search VikingDB memory error: {response}")
131
156
 
@@ -117,18 +117,16 @@ class ShortTermMemory(BaseModel):
117
117
  f"Loaded {len(list_sessions_response.sessions)} sessions from db {self.db_url}."
118
118
  )
119
119
 
120
- if (
121
- await self._session_service.get_session(
122
- app_name=app_name, user_id=user_id, session_id=session_id
120
+ session = await self._session_service.get_session(
121
+ app_name=app_name, user_id=user_id, session_id=session_id
122
+ )
123
+
124
+ if session:
125
+ logger.info(
126
+ f"Session {session_id} already exists with app_name={app_name} user_id={user_id}."
123
127
  )
124
- is None
125
- ):
128
+ return session
129
+ else:
126
130
  return await self._session_service.create_session(
127
131
  app_name=app_name, user_id=user_id, session_id=session_id
128
132
  )
129
- else:
130
- logger.info(
131
- f"Session {session_id} already exists with app_name={app_name} user_id={user_id}."
132
- )
133
-
134
- return None
veadk/runner.py CHANGED
@@ -34,7 +34,7 @@ from veadk.evaluation import EvalSetRecorder
34
34
  from veadk.memory.short_term_memory import ShortTermMemory
35
35
  from veadk.types import MediaMessage
36
36
  from veadk.utils.logger import get_logger
37
- from veadk.utils.misc import formatted_timestamp, read_png_to_bytes
37
+ from veadk.utils.misc import formatted_timestamp, read_file_to_bytes
38
38
 
39
39
  logger = get_logger(__name__)
40
40
 
@@ -50,11 +50,7 @@ RunnerMessage = Union[
50
50
  async def pre_run_process(self, process_func, new_message, user_id, session_id):
51
51
  if new_message.parts:
52
52
  for part in new_message.parts:
53
- if (
54
- part.inline_data
55
- and part.inline_data.mime_type == "image/png"
56
- and self.upload_inline_data_to_tos
57
- ):
53
+ if part.inline_data and self.upload_inline_data_to_tos:
58
54
  await process_func(
59
55
  part,
60
56
  self.app_name,
@@ -105,9 +101,20 @@ def _convert_messages(
105
101
  if isinstance(messages, str):
106
102
  _messages = [types.Content(role="user", parts=[types.Part(text=messages)])]
107
103
  elif isinstance(messages, MediaMessage):
108
- assert messages.media.endswith(".png"), (
109
- "The MediaMessage only supports PNG format file for now."
104
+ import filetype
105
+
106
+ file_data = read_file_to_bytes(messages.media)
107
+
108
+ kind = filetype.guess(file_data)
109
+ if kind is None:
110
+ raise ValueError("Unsupported or unknown file type.")
111
+
112
+ mime_type = kind.mime
113
+
114
+ assert mime_type.startswith(("image/", "video/")), (
115
+ f"Unsupported media type: {mime_type}"
110
116
  )
117
+
111
118
  _messages = [
112
119
  types.Content(
113
120
  role="user",
@@ -116,8 +123,8 @@ def _convert_messages(
116
123
  types.Part(
117
124
  inline_data=Blob(
118
125
  display_name=messages.media,
119
- data=read_png_to_bytes(messages.media),
120
- mime_type="image/png",
126
+ data=file_data,
127
+ mime_type=mime_type,
121
128
  )
122
129
  ),
123
130
  ],
@@ -277,7 +284,8 @@ class Runner(ADKRunner):
277
284
  and event.content.parts[0].text is not None
278
285
  and len(event.content.parts[0].text.strip()) > 0
279
286
  ):
280
- final_output += event.content.parts[0].text
287
+ final_output = event.content.parts[0].text
288
+ logger.debug(f"Event output: {final_output}")
281
289
  except LlmCallsLimitExceededError as e:
282
290
  logger.warning(f"Max number of llm calls limit exceeded: {e}")
283
291
  final_output = ""
@@ -25,17 +25,22 @@ from opentelemetry.trace import Span
25
25
  from volcenginesdkarkruntime import Ark
26
26
  from volcenginesdkarkruntime.types.images.images import SequentialImageGenerationOptions
27
27
 
28
- from veadk.config import getenv
29
- from veadk.consts import DEFAULT_IMAGE_GENERATE_MODEL_NAME, DEFAULT_MODEL_AGENT_API_BASE
28
+ from veadk.config import getenv, settings
29
+ from veadk.consts import (
30
+ DEFAULT_IMAGE_GENERATE_MODEL_NAME,
31
+ DEFAULT_IMAGE_GENERATE_MODEL_API_BASE,
32
+ )
30
33
  from veadk.utils.logger import get_logger
31
- from veadk.utils.misc import formatted_timestamp, read_png_to_bytes
34
+ from veadk.utils.misc import formatted_timestamp, read_file_to_bytes
32
35
  from veadk.version import VERSION
33
36
 
34
37
  logger = get_logger(__name__)
35
38
 
36
39
  client = Ark(
37
- api_key=getenv("MODEL_AGENT_API_KEY"),
38
- base_url=getenv("MODEL_AGENT_API_BASE", DEFAULT_MODEL_AGENT_API_BASE),
40
+ api_key=getenv(
41
+ "MODEL_IMAGE_API_KEY", getenv("MODEL_AGENT_API_KEY", settings.model.api_key)
42
+ ),
43
+ base_url=getenv("MODEL_IMAGE_API_BASE", DEFAULT_IMAGE_GENERATE_MODEL_API_BASE),
39
44
  )
40
45
 
41
46
 
@@ -299,7 +304,7 @@ async def image_generate(
299
304
  artifact=Part(
300
305
  inline_data=Blob(
301
306
  display_name=filename,
302
- data=read_png_to_bytes(image_tos_url),
307
+ data=read_file_to_bytes(image_tos_url),
303
308
  mime_type=mimetypes.guess_type(image_tos_url)[0],
304
309
  )
305
310
  ),
@@ -15,8 +15,11 @@
15
15
  from typing import Dict
16
16
  from google.adk.tools import ToolContext
17
17
  from volcenginesdkarkruntime import Ark
18
- from veadk.config import getenv
19
- from veadk.consts import DEFAULT_MODEL_AGENT_API_BASE, DEFAULT_IMAGE_EDIT_MODEL_NAME
18
+ from veadk.config import getenv, settings
19
+ from veadk.consts import (
20
+ DEFAULT_IMAGE_EDIT_MODEL_API_BASE,
21
+ DEFAULT_IMAGE_EDIT_MODEL_NAME,
22
+ )
20
23
  import base64
21
24
  from opentelemetry import trace
22
25
  import traceback
@@ -28,8 +31,10 @@ from veadk.utils.logger import get_logger
28
31
  logger = get_logger(__name__)
29
32
 
30
33
  client = Ark(
31
- api_key=getenv("MODEL_AGENT_API_KEY"),
32
- base_url=getenv("MODEL_AGENT_API_BASE", DEFAULT_MODEL_AGENT_API_BASE),
34
+ api_key=getenv(
35
+ "MODEL_EDIT_API_KEY", getenv("MODEL_AGENT_API_KEY", settings.model.api_key)
36
+ ),
37
+ base_url=getenv("MODEL_EDIT_API_BASE", DEFAULT_IMAGE_EDIT_MODEL_API_BASE),
33
38
  )
34
39
 
35
40
 
@@ -15,8 +15,11 @@
15
15
  from typing import Dict
16
16
 
17
17
  from google.adk.tools import ToolContext
18
- from veadk.config import getenv
19
- from veadk.consts import DEFAULT_TEXT_TO_IMAGE_MODEL_NAME, DEFAULT_MODEL_AGENT_API_BASE
18
+ from veadk.config import getenv, settings
19
+ from veadk.consts import (
20
+ DEFAULT_TEXT_TO_IMAGE_MODEL_NAME,
21
+ DEFAULT_TEXT_TO_IMAGE_MODEL_API_BASE,
22
+ )
20
23
  import base64
21
24
  from volcenginesdkarkruntime import Ark
22
25
  from opentelemetry import trace
@@ -29,8 +32,10 @@ from veadk.utils.logger import get_logger
29
32
  logger = get_logger(__name__)
30
33
 
31
34
  client = Ark(
32
- api_key=getenv("MODEL_AGENT_API_KEY"),
33
- base_url=getenv("MODEL_AGENT_API_BASE", DEFAULT_MODEL_AGENT_API_BASE),
35
+ api_key=getenv(
36
+ "MODEL_IMAGE_API_KEY", getenv("MODEL_AGENT_API_KEY", settings.model.api_key)
37
+ ),
38
+ base_url=getenv("MODEL_IMAGE_API_BASE", DEFAULT_TEXT_TO_IMAGE_MODEL_API_BASE),
34
39
  )
35
40
 
36
41