kagent-adk 0.7.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kagent/adk/__init__.py +8 -0
- kagent/adk/_a2a.py +178 -0
- kagent/adk/_agent_executor.py +335 -0
- kagent/adk/_lifespan.py +36 -0
- kagent/adk/_session_service.py +178 -0
- kagent/adk/_token.py +80 -0
- kagent/adk/artifacts/__init__.py +13 -0
- kagent/adk/artifacts/artifacts_toolset.py +56 -0
- kagent/adk/artifacts/return_artifacts_tool.py +160 -0
- kagent/adk/artifacts/session_path.py +106 -0
- kagent/adk/artifacts/stage_artifacts_tool.py +170 -0
- kagent/adk/cli.py +249 -0
- kagent/adk/converters/__init__.py +0 -0
- kagent/adk/converters/error_mappings.py +60 -0
- kagent/adk/converters/event_converter.py +322 -0
- kagent/adk/converters/part_converter.py +206 -0
- kagent/adk/converters/request_converter.py +35 -0
- kagent/adk/models/__init__.py +3 -0
- kagent/adk/models/_openai.py +564 -0
- kagent/adk/models/_ssl.py +245 -0
- kagent/adk/sandbox_code_executer.py +77 -0
- kagent/adk/skill_fetcher.py +103 -0
- kagent/adk/tools/README.md +217 -0
- kagent/adk/tools/__init__.py +15 -0
- kagent/adk/tools/bash_tool.py +74 -0
- kagent/adk/tools/file_tools.py +192 -0
- kagent/adk/tools/skill_tool.py +104 -0
- kagent/adk/tools/skills_plugin.py +49 -0
- kagent/adk/tools/skills_toolset.py +68 -0
- kagent/adk/types.py +268 -0
- kagent_adk-0.7.11.dist-info/METADATA +35 -0
- kagent_adk-0.7.11.dist-info/RECORD +34 -0
- kagent_adk-0.7.11.dist-info/WHEEL +4 -0
- kagent_adk-0.7.11.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import mimetypes
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, List
|
|
7
|
+
|
|
8
|
+
from google.adk.tools import BaseTool, ToolContext
|
|
9
|
+
from google.genai import types
|
|
10
|
+
from typing_extensions import override
|
|
11
|
+
|
|
12
|
+
from .session_path import get_session_path
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger("kagent_adk." + __name__)
|
|
15
|
+
|
|
16
|
+
# Maximum file size for staging (100 MB)
|
|
17
|
+
MAX_ARTIFACT_SIZE_BYTES = 100 * 1024 * 1024
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class StageArtifactsTool(BaseTool):
    """A tool to stage artifacts from the artifact service to the local filesystem.

    This tool enables working with user-uploaded files by staging them from the
    artifact store to a local working directory where they can be accessed by
    scripts, commands, and other tools.

    Workflow:
    1. Stage: Copy artifacts from artifact store to local 'uploads/' directory
    2. Access: Use the staged files in commands, scripts, or other processing
    """

    def __init__(self):
        super().__init__(
            name="stage_artifacts",
            description=(
                "Stage artifacts from the artifact store to a local filesystem path, "
                "making them available for processing by tools and scripts.\n\n"
                "WORKFLOW:\n"
                "1. When a user uploads a file, it's stored as an artifact with a name\n"
                "2. Use this tool to copy the artifact to your local 'uploads/' directory\n"
                "3. Then reference the staged file path in commands or scripts\n\n"
                "USAGE EXAMPLE:\n"
                "- stage_artifacts(artifact_names=['data.csv'])\n"
                "  Returns: 'Successfully staged 1 file(s): uploads/data.csv (1.2 MB)'\n"
                "- Then use: bash('python scripts/process.py uploads/data.csv')\n\n"
                "PARAMETERS:\n"
                "- artifact_names: List of artifact names to stage (required)\n"
                "- destination_path: Target directory within session (default: 'uploads/')\n\n"
                "BEST PRACTICES:\n"
                "- Always stage artifacts before using them\n"
                "- Use default 'uploads/' destination for consistency\n"
                "- Stage all artifacts at the start of your workflow\n"
                "- Check returned paths to confirm successful staging"
            ),
        )

    def _get_declaration(self) -> types.FunctionDeclaration | None:
        """Build the function-calling schema exposed to the model."""
        return types.FunctionDeclaration(
            name=self.name,
            description=self.description,
            parameters=types.Schema(
                type=types.Type.OBJECT,
                properties={
                    "artifact_names": types.Schema(
                        type=types.Type.ARRAY,
                        description=(
                            "List of artifact names to stage. These are artifact identifiers "
                            "provided by the system when files are uploaded. "
                            "The tool will copy each artifact from the artifact store to the destination directory."
                        ),
                        items=types.Schema(type=types.Type.STRING),
                    ),
                    "destination_path": types.Schema(
                        type=types.Type.STRING,
                        description=(
                            "Relative path within the session directory to save the files. "
                            "Default is 'uploads/' where user-uploaded files are conventionally stored. "
                            "Path must be within the session directory for security. "
                            "Useful for organizing different types of artifacts (e.g., 'uploads/input/', 'uploads/processed/')."
                        ),
                        default="uploads/",
                    ),
                },
                required=["artifact_names"],
            ),
        )

    @override
    async def run_async(self, *, args: dict[str, Any], tool_context: ToolContext) -> str:
        """Copy the requested artifacts into the session's staging directory.

        Args:
            args: Tool arguments; expects 'artifact_names' (list of str, required)
                and 'destination_path' (str, default 'uploads/').
            tool_context: ADK tool context providing session id, artifact
                service access, and artifact loading.

        Returns:
            Human-readable summary of the staged files, or an error message.
        """
        artifact_names: List[str] = args.get("artifact_names", [])
        destination_path_str: str = args.get("destination_path", "uploads/")

        if not tool_context._invocation_context.artifact_service:
            return "Error: Artifact service is not available in this context."

        try:
            staging_root = get_session_path(session_id=tool_context.session.id)
            destination_dir = (staging_root / destination_path_str).resolve()

            # Security: Ensure the destination is within the staging path
            if staging_root not in destination_dir.parents and destination_dir != staging_root:
                return f"Error: Invalid destination path '{destination_path_str}'."

            destination_dir.mkdir(parents=True, exist_ok=True)

            staged_files = []
            for name in artifact_names:
                artifact = await tool_context.load_artifact(name)
                if artifact is None or artifact.inline_data is None:
                    logger.warning('Artifact "%s" not found or has no data, skipping', name)
                    continue

                # Enforce the size cap before touching the filesystem.
                data_size = len(artifact.inline_data.data)
                if data_size > MAX_ARTIFACT_SIZE_BYTES:
                    logger.warning('Artifact "%s" exceeds size limit: %.1f MB', name, data_size / (1024 * 1024))
                    continue

                # Use artifact name as filename (frontend should provide meaningful names)
                # If name has no extension, try to infer from MIME type
                filename = self._ensure_proper_extension(name, artifact.inline_data.mime_type)
                output_file = (destination_dir / filename).resolve()

                # Security: artifact names are external input; skip any name
                # (e.g. containing '..') whose resolved path escapes the
                # destination directory.
                if destination_dir not in output_file.parents:
                    logger.warning('Artifact "%s" resolves outside destination, skipping', name)
                    continue

                # Write file to disk
                output_file.write_bytes(artifact.inline_data.data)

                relative_path = output_file.relative_to(staging_root)
                size_kb = data_size / 1024
                staged_files.append(f"{relative_path} ({size_kb:.1f} KB)")

                logger.info("Staged artifact: %s -> %s (%.1f KB)", name, relative_path, size_kb)

            if not staged_files:
                return "No valid artifacts were staged."

            return f"Successfully staged {len(staged_files)} file(s):\n" + "\n".join(
                f" • {file}" for file in staged_files
            )

        except Exception as e:
            logger.error("Error staging artifacts: %s", e, exc_info=True)
            return f"An error occurred while staging artifacts: {e}"

    def _ensure_proper_extension(self, filename: str, mime_type: str) -> str:
        """Ensure filename has proper extension based on MIME type.

        If filename already has an extension, keep it.
        If not, append an extension inferred from the MIME type.

        Args:
            filename: Original filename from artifact
            mime_type: MIME type of the file

        Returns:
            Filename with proper extension
        """
        if not filename or not mime_type:
            return filename

        # If filename already has an extension, use it
        if Path(filename).suffix:
            return filename

        # Try to infer extension from MIME type.
        # Bug fix: the previous implementation returned f"(unknown){extension}",
        # discarding the artifact name entirely; keep the name and append the
        # inferred extension, matching the documented contract.
        extension = mimetypes.guess_extension(mime_type)
        if extension:
            return f"{filename}{extension}"

        return filename
|
kagent/adk/cli.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import importlib
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
from typing import Annotated, Optional
|
|
7
|
+
|
|
8
|
+
import typer
|
|
9
|
+
import uvicorn
|
|
10
|
+
from a2a.types import AgentCard
|
|
11
|
+
from agentsts.adk import ADKSTSIntegration, ADKTokenPropagationPlugin
|
|
12
|
+
from google.adk.agents import BaseAgent
|
|
13
|
+
from google.adk.cli.utils.agent_loader import AgentLoader
|
|
14
|
+
|
|
15
|
+
from kagent.core import KAgentConfig, configure_logging, configure_tracing
|
|
16
|
+
|
|
17
|
+
from . import AgentConfig, KAgentApp
|
|
18
|
+
from .skill_fetcher import fetch_skill
|
|
19
|
+
from .tools import add_skills_tool_to_agent
|
|
20
|
+
|
|
21
|
+
# Module logger for the CLI entry points.
logger = logging.getLogger(__name__)
# Raise the ADK authenticated-tool logger's threshold to ERROR to suppress
# its lower-severity output.
logging.getLogger("google_adk.google.adk.tools.base_authenticated_tool").setLevel(logging.ERROR)

# Typer application exposing the CLI commands defined below.
app = typer.Typer()


# Environment-driven configuration, read once at import time.
kagent_url_override = os.getenv("KAGENT_URL")  # NOTE(review): not referenced elsewhere in this module — confirm external use
sts_well_known_uri = os.getenv("STS_WELL_KNOWN_URI")  # STS discovery endpoint; enables full STS integration
propagate_token = os.getenv("KAGENT_PROPAGATE_TOKEN")  # enables token propagation even without an STS endpoint
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def create_sts_integration() -> Optional[ADKTokenPropagationPlugin]:
    """Build the STS token-propagation plugin if configured via environment.

    Returns a plugin wired to an ``ADKSTSIntegration`` when STS_WELL_KNOWN_URI
    is set, a plugin with no integration when only KAGENT_PROPAGATE_TOKEN is
    set, and None when neither variable is present.
    """
    if not (sts_well_known_uri or propagate_token):
        return None
    integration = ADKSTSIntegration(sts_well_known_uri) if sts_well_known_uri else None
    return ADKTokenPropagationPlugin(integration)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def maybe_add_skills(root_agent: BaseAgent):
    """Attach the skills tool to *root_agent* when KAGENT_SKILLS_FOLDER is set.

    No-op when the environment variable is absent or empty.
    """
    skills_directory = os.getenv("KAGENT_SKILLS_FOLDER", None)
    if not skills_directory:
        return
    logger.info(f"Adding skills from directory: {skills_directory}")
    add_skills_tool_to_agent(skills_directory, root_agent)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@app.command()
def static(
    host: str = "127.0.0.1",
    port: int = 8080,
    workers: int = 1,
    filepath: str = "/config",
    reload: Annotated[bool, typer.Option("--reload")] = False,
):
    """Serve an agent defined by static config files under *filepath*.

    Reads ``config.json`` (AgentConfig) and ``agent-card.json`` (AgentCard)
    from *filepath*, builds a KAgentApp around them, and serves it with
    uvicorn on host:port.
    """
    app_cfg = KAgentConfig()

    # Load and validate the agent configuration and its public agent card.
    with open(os.path.join(filepath, "config.json"), "r") as f:
        config = json.load(f)
    agent_config = AgentConfig.model_validate(config)
    with open(os.path.join(filepath, "agent-card.json"), "r") as f:
        agent_card = json.load(f)
    agent_card = AgentCard.model_validate(agent_card)
    # Optional STS token-propagation plugin, driven by environment variables.
    plugins = None
    sts_integration = create_sts_integration()
    if sts_integration:
        plugins = [sts_integration]

    def root_agent_factory() -> BaseAgent:
        # Build the agent from config; skills are added only if configured.
        root_agent = agent_config.to_agent(app_cfg.name, sts_integration)

        maybe_add_skills(root_agent)

        return root_agent

    kagent_app = KAgentApp(
        root_agent_factory,
        agent_card,
        app_cfg.url,
        app_cfg.app_name,
        plugins=plugins,
        # Streaming defaults to False when the config leaves it unset.
        stream=agent_config.stream if agent_config.stream is not None else False,
    )

    server = kagent_app.build()
    configure_tracing(server)

    # NOTE(review): uvicorn only honors `workers` and `reload` when given an
    # import string, not an app instance — confirm whether these options are
    # expected to take effect here.
    uvicorn.run(
        server,
        host=host,
        port=port,
        workers=workers,
        reload=reload,
    )
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
@app.command()
def pull_skills(
    skills: Annotated[list[str], typer.Argument()],
    insecure: Annotated[
        bool,
        typer.Option("--insecure", help="Allow insecure connections to registries"),
    ] = False,
):
    """Fetch each named skill into the skills folder.

    The target directory comes from KAGENT_SKILLS_FOLDER, defaulting to the
    current directory.
    """
    skill_dir = os.getenv("KAGENT_SKILLS_FOLDER", ".")
    logger.info("Pulling skills")
    for skill_ref in skills:
        fetch_skill(skill_ref, skill_dir, insecure)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def add_to_agent(sts_integration: ADKTokenPropagationPlugin, agent: BaseAgent):
    """
    Add the plugin to an ADK LLM agent by updating its MCP toolset
    Call this once when setting up the agent; do not call it at runtime.
    """
    from google.adk.agents import LlmAgent
    from google.adk.tools.mcp_tool.mcp_toolset import McpToolset

    # Only LLM agents with a non-empty tool list can carry an MCP toolset.
    if not isinstance(agent, LlmAgent) or not agent.tools:
        return

    for tool in agent.tools:
        if not isinstance(tool, McpToolset):
            continue
        # Route MCP requests through the STS header provider so outgoing
        # calls carry the propagated access token.
        tool._header_provider = sts_integration.header_provider
        logger.debug("Updated tool connection params to include access token from STS server")
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
@app.command()
def run(
    name: Annotated[str, typer.Argument(help="The name of the agent to run")],
    working_dir: str = ".",
    host: str = "127.0.0.1",
    port: int = 8080,
    workers: int = 1,
    local: Annotated[
        bool, typer.Option("--local", help="Run with in-memory session service (for local development)")
    ] = False,
):
    """Load agent *name* from *working_dir* and serve it with uvicorn.

    Optional per-agent files: ``<working_dir>/<name>/config.json`` (stream
    setting) and the required ``agent-card.json``. The agent package may also
    export a ``lifespan`` callable picked up for the server app.
    """
    app_cfg = KAgentConfig()

    # Optional STS token-propagation plugin, driven by environment variables.
    plugins = None
    sts_integration = create_sts_integration()
    if sts_integration:
        plugins = [sts_integration]

    agent_loader = AgentLoader(agents_dir=working_dir)

    def root_agent_factory() -> BaseAgent:
        # Load the agent package and wire in STS headers / skills as configured.
        root_agent = agent_loader.load_agent(name)

        if sts_integration:
            add_to_agent(sts_integration, root_agent)

        maybe_add_skills(root_agent)

        return root_agent

    # Load agent config to get stream setting
    agent_config = None
    config_path = os.path.join(working_dir, name, "config.json")
    try:
        with open(config_path, "r") as f:
            config = json.load(f)
            agent_config = AgentConfig.model_validate(config)
    except FileNotFoundError:
        # config.json is optional; fall back to defaults (stream=False).
        logger.debug(f"No config.json found at {config_path}, using defaults")

    # agent-card.json is required — a missing file raises here.
    with open(os.path.join(working_dir, name, "agent-card.json"), "r") as f:
        agent_card = json.load(f)
    agent_card = AgentCard.model_validate(agent_card)

    # Attempt to import optional user-defined lifespan(app) from the agent package
    lifespan = None
    try:
        module_candidate = importlib.import_module(name)
        if hasattr(module_candidate, "lifespan"):
            lifespan = module_candidate.lifespan
    except Exception:
        # Best-effort: a broken agent module should not prevent startup.
        logger.exception(f"Failed to load agent module '{name}' for lifespan")

    kagent_app = KAgentApp(
        root_agent_factory,
        agent_card,
        app_cfg.url,
        app_cfg.app_name,
        lifespan=lifespan,
        plugins=plugins,
        # Streaming defaults to False when config is absent or leaves it unset.
        stream=agent_config.stream if agent_config and agent_config.stream is not None else False,
    )

    if local:
        # Local mode swaps in an in-memory session service for development.
        logger.info("Running in local mode with InMemorySessionService")
        server = kagent_app.build(local=True)
    else:
        server = kagent_app.build()

    configure_tracing(server)

    # NOTE(review): uvicorn only honors `workers` when given an import string,
    # not an app instance — confirm whether this option is expected to apply.
    uvicorn.run(
        server,
        host=host,
        port=port,
        workers=workers,
    )
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
async def test_agent(agent_config: AgentConfig, agent_card: AgentCard, task: str):
    """Run *task* once against an agent built from the given config and card.

    Uses a throwaway KAgentConfig (fake URL, fixed name/namespace) and the
    KAgentApp test harness instead of serving HTTP.
    """
    app_cfg = KAgentConfig(url="http://fake-url.example.com", name="test-agent", namespace="kagent")
    sts_integration = create_sts_integration()
    plugins = [sts_integration] if sts_integration else None

    def root_agent_factory() -> BaseAgent:
        agent = agent_config.to_agent(app_cfg.name, sts_integration)
        maybe_add_skills(agent)
        return agent

    harness = KAgentApp(root_agent_factory, agent_card, app_cfg.url, app_cfg.app_name, plugins=plugins)
    await harness.test(task)
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
@app.command()
def test(
    task: Annotated[str, typer.Option("--task", help="The task to test the agent with")],
    filepath: Annotated[str, typer.Option("--filepath", help="The path to the agent config file")],
):
    """Validate the agent config/card under *filepath* and run *task* against it."""
    with open(os.path.join(filepath, "config.json"), "r") as f:
        config = json.load(f)

    with open(os.path.join(filepath, "agent-card.json"), "r") as f:
        agent_card = json.load(f)

    agent_card = AgentCard.model_validate(agent_card)
    agent_config = AgentConfig.model_validate(config)
    asyncio.run(test_agent(agent_config, agent_card, task))
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def run_cli():
    """CLI entry point: set up logging, then hand control to Typer."""
    configure_logging()
    logger.info("Starting KAgent")
    app()


if __name__ == "__main__":
    run_cli()
|
|
File without changes
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""Error code to user-friendly message mappings for ADK events.
|
|
2
|
+
|
|
3
|
+
This module provides mappings from Google GenAI finish reasons to user-friendly
|
|
4
|
+
error messages, excluding STOP which is a normal completion reason.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Dict, Optional
|
|
8
|
+
|
|
9
|
+
from google.genai import types as genai_types
|
|
10
|
+
|
|
11
|
+
# Error code to user-friendly message mappings
|
|
12
|
+
# Based on Google GenAI types.py FinishReason enum (excluding STOP)
|
|
13
|
+
# Error code to user-friendly message mappings.
# Based on Google GenAI types.py FinishReason enum (excluding STOP).
# NOTE(review): the annotation says Dict[str, str] but the keys are
# FinishReason enum members — correct only if FinishReason is a str-based
# enum; confirm against google.genai types.
ERROR_CODE_MESSAGES: Dict[str, str] = {
    # Length and token limits
    genai_types.FinishReason.MAX_TOKENS: "Response was truncated due to maximum token limit. Try asking a shorter question or breaking it into parts.",
    # Safety and content filtering
    genai_types.FinishReason.SAFETY: "Response was blocked due to safety concerns. Please rephrase your request to avoid potentially harmful content.",
    genai_types.FinishReason.RECITATION: "Response was blocked due to unauthorized citations. Please rephrase your request.",
    genai_types.FinishReason.BLOCKLIST: "Response was blocked due to restricted terminology. Please rephrase your request using different words.",
    genai_types.FinishReason.PROHIBITED_CONTENT: "Response was blocked due to prohibited content. Please rephrase your request.",
    genai_types.FinishReason.SPII: "Response was blocked due to sensitive personal information concerns. Please avoid including personal details.",
    # Function calling errors
    genai_types.FinishReason.MALFORMED_FUNCTION_CALL: "The agent generated an invalid function call. This may be due to complex input data. Try rephrasing your request or breaking it into simpler steps.",
    # Generic fallback
    genai_types.FinishReason.OTHER: "An unexpected error occurred during processing. Please try again or rephrase your request.",
}

# Normal completion reasons that should not be treated as errors
NORMAL_COMPLETION_REASONS = {
    genai_types.FinishReason.STOP,  # Normal completion
}

# Default error message when no specific mapping exists
DEFAULT_ERROR_MESSAGE = "An error occurred during processing"
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _get_error_message(error_code: Optional[str]) -> str:
    """Get a user-friendly error message for the given error code.

    Args:
        error_code: The error code from the ADK event (e.g., finish_reason)

    Returns:
        User-friendly error message string
    """
    if error_code in ERROR_CODE_MESSAGES:
        return ERROR_CODE_MESSAGES[error_code]
    # Unknown or unmapped codes fall back to the generic message.
    return DEFAULT_ERROR_MESSAGE
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _is_normal_completion(error_code: Optional[str]) -> bool:
    """Check if the error code represents normal completion rather than an error.

    Args:
        error_code: The error code to check

    Returns:
        True if this is a normal completion reason, False otherwise
    """
    is_normal = error_code in NORMAL_COMPLETION_REASONS
    return is_normal
|