aixtools 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aixtools might be problematic. Click here for more details.
- aixtools/.chainlit/config.toml +113 -0
- aixtools/.chainlit/translations/bn.json +214 -0
- aixtools/.chainlit/translations/en-US.json +214 -0
- aixtools/.chainlit/translations/gu.json +214 -0
- aixtools/.chainlit/translations/he-IL.json +214 -0
- aixtools/.chainlit/translations/hi.json +214 -0
- aixtools/.chainlit/translations/ja.json +214 -0
- aixtools/.chainlit/translations/kn.json +214 -0
- aixtools/.chainlit/translations/ml.json +214 -0
- aixtools/.chainlit/translations/mr.json +214 -0
- aixtools/.chainlit/translations/nl.json +214 -0
- aixtools/.chainlit/translations/ta.json +214 -0
- aixtools/.chainlit/translations/te.json +214 -0
- aixtools/.chainlit/translations/zh-CN.json +214 -0
- aixtools/__init__.py +11 -0
- aixtools/_version.py +34 -0
- aixtools/a2a/app.py +126 -0
- aixtools/a2a/google_sdk/__init__.py +0 -0
- aixtools/a2a/google_sdk/card.py +27 -0
- aixtools/a2a/google_sdk/pydantic_ai_adapter/agent_executor.py +199 -0
- aixtools/a2a/google_sdk/pydantic_ai_adapter/storage.py +26 -0
- aixtools/a2a/google_sdk/remote_agent_connection.py +88 -0
- aixtools/a2a/google_sdk/utils.py +59 -0
- aixtools/a2a/utils.py +115 -0
- aixtools/agents/__init__.py +12 -0
- aixtools/agents/agent.py +164 -0
- aixtools/agents/agent_batch.py +71 -0
- aixtools/agents/prompt.py +97 -0
- aixtools/app.py +143 -0
- aixtools/chainlit.md +14 -0
- aixtools/compliance/__init__.py +9 -0
- aixtools/compliance/private_data.py +138 -0
- aixtools/context.py +17 -0
- aixtools/db/__init__.py +17 -0
- aixtools/db/database.py +110 -0
- aixtools/db/vector_db.py +115 -0
- aixtools/google/client.py +25 -0
- aixtools/log_view/__init__.py +17 -0
- aixtools/log_view/app.py +195 -0
- aixtools/log_view/display.py +285 -0
- aixtools/log_view/export.py +51 -0
- aixtools/log_view/filters.py +41 -0
- aixtools/log_view/log_utils.py +26 -0
- aixtools/log_view/node_summary.py +229 -0
- aixtools/logfilters/__init__.py +7 -0
- aixtools/logfilters/context_filter.py +67 -0
- aixtools/logging/__init__.py +30 -0
- aixtools/logging/log_objects.py +227 -0
- aixtools/logging/logging_config.py +161 -0
- aixtools/logging/mcp_log_models.py +102 -0
- aixtools/logging/mcp_logger.py +172 -0
- aixtools/logging/model_patch_logging.py +87 -0
- aixtools/logging/open_telemetry.py +36 -0
- aixtools/mcp/__init__.py +9 -0
- aixtools/mcp/client.py +375 -0
- aixtools/mcp/example_client.py +30 -0
- aixtools/mcp/example_server.py +22 -0
- aixtools/mcp/fast_mcp_log.py +31 -0
- aixtools/mcp/faulty_mcp.py +319 -0
- aixtools/model_patch/model_patch.py +63 -0
- aixtools/server/__init__.py +29 -0
- aixtools/server/app_mounter.py +90 -0
- aixtools/server/path.py +72 -0
- aixtools/server/utils.py +70 -0
- aixtools/server/workspace_privacy.py +65 -0
- aixtools/testing/__init__.py +9 -0
- aixtools/testing/aix_test_model.py +149 -0
- aixtools/testing/mock_tool.py +66 -0
- aixtools/testing/model_patch_cache.py +279 -0
- aixtools/tools/doctor/__init__.py +3 -0
- aixtools/tools/doctor/tool_doctor.py +61 -0
- aixtools/tools/doctor/tool_recommendation.py +44 -0
- aixtools/utils/__init__.py +35 -0
- aixtools/utils/chainlit/cl_agent_show.py +82 -0
- aixtools/utils/chainlit/cl_utils.py +168 -0
- aixtools/utils/config.py +131 -0
- aixtools/utils/config_util.py +69 -0
- aixtools/utils/enum_with_description.py +37 -0
- aixtools/utils/files.py +17 -0
- aixtools/utils/persisted_dict.py +99 -0
- aixtools/utils/utils.py +167 -0
- aixtools/vault/__init__.py +7 -0
- aixtools/vault/vault.py +137 -0
- aixtools-0.0.0.dist-info/METADATA +669 -0
- aixtools-0.0.0.dist-info/RECORD +88 -0
- aixtools-0.0.0.dist-info/WHEEL +5 -0
- aixtools-0.0.0.dist-info/entry_points.txt +2 -0
- aixtools-0.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
from a2a.server.agent_execution import AgentExecutor, RequestContext
|
|
4
|
+
from a2a.server.events import EventQueue
|
|
5
|
+
from a2a.types import (
|
|
6
|
+
Artifact,
|
|
7
|
+
FilePart,
|
|
8
|
+
FileWithUri,
|
|
9
|
+
Message,
|
|
10
|
+
Part,
|
|
11
|
+
TaskArtifactUpdateEvent,
|
|
12
|
+
TaskState,
|
|
13
|
+
TaskStatus,
|
|
14
|
+
TaskStatusUpdateEvent,
|
|
15
|
+
)
|
|
16
|
+
from a2a.utils import get_file_parts, get_message_text, new_agent_text_message, new_task
|
|
17
|
+
from pydantic import BaseModel
|
|
18
|
+
from pydantic_ai import Agent, BinaryContent
|
|
19
|
+
|
|
20
|
+
from aixtools.a2a.google_sdk.pydantic_ai_adapter.storage import InMemoryHistoryStorage
|
|
21
|
+
from aixtools.a2a.google_sdk.remote_agent_connection import is_in_terminal_state
|
|
22
|
+
from aixtools.a2a.google_sdk.utils import get_session_id_tuple
|
|
23
|
+
from aixtools.agents import get_agent
|
|
24
|
+
from aixtools.agents.prompt import build_user_input
|
|
25
|
+
from aixtools.context import SessionIdTuple
|
|
26
|
+
from aixtools.logging.logging_config import get_logger
|
|
27
|
+
from aixtools.mcp.client import get_configured_mcp_servers
|
|
28
|
+
|
|
29
|
+
# Module-level logger.
logger = get_logger(__name__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class AgentParameters(BaseModel):
    """Static configuration used by ``PydanticAgentExecutor`` to build its agent for each request."""

    # System prompt handed to ``get_agent``.
    system_prompt: str
    # MCP server identifiers passed to ``get_configured_mcp_servers``.
    mcp_servers: list[str]
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class RunOutput(BaseModel):
    """Structured output the agent must produce (used as the agent's ``output_type``).

    The executor inspects the flags in this order: failed, input required, in progress;
    if none is set the task is reported as completed.
    """

    # The agent judged the task as failed.
    is_task_failed: bool
    # The agent stopped before finishing the task.
    is_task_in_progress: bool
    # The agent needs more input from the user.
    is_input_required: bool
    # Text returned to the caller in the status message.
    output: str
    # Paths/URIs of files produced by the run, emitted as artifacts.
    created_artifacts_paths: list[str]
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _task_failed_event(text: str, context_id: str | None, task_id: str | None) -> TaskStatusUpdateEvent:
    """Build a final ``failed`` status-update event whose agent message carries ``text``."""
    failure_message = new_agent_text_message(text=text, context_id=context_id, task_id=task_id)
    failed_status = TaskStatus(state=TaskState.failed, message=failure_message)
    return TaskStatusUpdateEvent(
        status=failed_status,
        final=True,
        context_id=context_id,
        task_id=task_id,
    )
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class PydanticAgentExecutor(AgentExecutor):
    """A2A ``AgentExecutor`` that runs a pydantic-ai agent and reports the outcome as A2A events.

    A fresh agent is built for every request (MCP servers are resolved per session).
    The pydantic-ai message history is kept per task id in ``history_storage``, so
    follow-up messages on the same task continue the previous conversation.
    """

    def __init__(self, agent_parameters: AgentParameters):
        self._agent_parameters = agent_parameters
        # Per-task pydantic-ai message history (in-memory dict; not persisted).
        self.history_storage = InMemoryHistoryStorage()

    def _convert_message_to_pydantic_parts(
        self,
        session_tuple: SessionIdTuple,
        message: Message,
    ) -> str | list[str | BinaryContent]:
        """Convert an A2A ``Message`` into a pydantic-ai user prompt.

        Returns the plain message text when the message has no file parts. Otherwise
        delegates to ``build_user_input`` with the URIs of the ``FileWithUri`` parts;
        file parts carrying inline bytes are not forwarded.
        """
        text_prompt = get_message_text(message)
        file_parts = get_file_parts(message.parts)
        if not file_parts:
            return text_prompt
        # NOTE(review): the URI is used as a filesystem path as-is (a "file://" scheme
        # would not be stripped) — confirm what URIs callers send.
        file_paths = [Path(part.uri) for part in file_parts if isinstance(part, FileWithUri)]

        return build_user_input(session_tuple, text_prompt, file_paths)

    async def execute(
        self,
        context: RequestContext,
        event_queue: EventQueue,
    ) -> None:
        """
        Execute the agent run.
        Wraps pydantic ai agent execution with a2a protocol events.

        Enqueues a new task when the request has none. Then enqueues:
        - a final ``failed`` status if the agent raises, reports failure, or reports the task as still in progress;
        - a non-final ``input_required`` status if the agent needs more input;
        - otherwise one artifact event per created artifact followed by a final ``completed`` status.

        Args:
            context (RequestContext): The request context containing the message and task information.
            event_queue (EventQueue): The event queue to enqueue events.

        Raises:
            ValueError: If the request carries no message.
            RuntimeError: If the task is already in a terminal state.
        """
        # Validate the request before building the agent (which resolves MCP servers).
        if context.message is None:
            raise ValueError("No message provided")

        session_tuple = get_session_id_tuple(context)
        agent = self._build_agent(session_tuple)

        task = context.current_task
        message = context.message
        if not task:
            task = new_task(message)
            await event_queue.enqueue_event(task)

        if is_in_terminal_state(task):
            # Exceptions do not interpolate logging-style "%s" arguments; format eagerly.
            raise RuntimeError(f"Can not perform a task as it is in a terminal state: {task.status.state}")

        prompt = self._convert_message_to_pydantic_parts(session_tuple, message)
        history_message = self.history_storage.get(task.id)

        try:
            result = await agent.run(
                user_prompt=prompt,
                message_history=history_message,
            )
        except Exception as e:  # pylint: disable=broad-exception-caught
            # Report the failure to the client, but keep the traceback in the server logs.
            logger.exception("Agent execution failed for task %s", task.id)
            await event_queue.enqueue_event(
                _task_failed_event(
                    text=f"Agent execution error: {e}",
                    context_id=context.context_id,
                    task_id=task.id,
                )
            )
            return

        self.history_storage.store(task.id, result.all_messages())

        run_output: RunOutput = result.output
        if run_output.is_task_failed:
            await event_queue.enqueue_event(
                _task_failed_event(
                    text=f"Task failed: {run_output.output}",
                    context_id=context.context_id,
                    task_id=task.id,
                )
            )
            return

        if run_output.is_input_required:
            await event_queue.enqueue_event(
                TaskStatusUpdateEvent(
                    status=TaskStatus(
                        state=TaskState.input_required,
                        message=new_agent_text_message(
                            text=run_output.output, context_id=context.context_id, task_id=task.id
                        ),
                    ),
                    # Not final: the client is expected to reply on the same task.
                    final=False,
                    context_id=context.context_id,
                    task_id=task.id,
                )
            )
            return

        if run_output.is_task_in_progress:
            logger.error("Task hasn't been completed: %s", run_output.output)
            await event_queue.enqueue_event(
                _task_failed_event(
                    text=f"Agent didn't manage complete the task: {run_output.output}",
                    context_id=context.context_id,
                    task_id=task.id,
                )
            )
            return

        for idx, artifact in enumerate(run_output.created_artifacts_paths):
            # NOTE(review): every artifact is named "image_<idx>" regardless of its actual type.
            image_file = FileWithUri(uri=str(artifact), name=f"image_{idx}")
            await event_queue.enqueue_event(
                TaskArtifactUpdateEvent(
                    append=False,
                    context_id=task.context_id,
                    task_id=task.id,
                    last_chunk=True,
                    artifact=Artifact(
                        artifact_id=f"image_{idx}",
                        parts=[Part(root=FilePart(file=image_file))],
                    ),
                )
            )
        await event_queue.enqueue_event(
            TaskStatusUpdateEvent(
                status=TaskStatus(
                    state=TaskState.completed,
                    message=new_agent_text_message(
                        text=run_output.output, context_id=context.context_id, task_id=task.id
                    ),
                ),
                final=True,
                context_id=context.context_id,
                task_id=task.id,
            )
        )

    async def cancel(self, ctx: RequestContext, event_queue: EventQueue) -> None:
        """Cancellation is not supported; always raises."""
        raise Exception("cancel not supported")  # pylint: disable=broad-exception-raised

    def _build_agent(self, session_tuple: SessionIdTuple) -> Agent:
        """Build an agent from the configured parameters, with MCP toolsets resolved for this session."""
        params = self._agent_parameters
        mcp_servers = get_configured_mcp_servers(session_tuple, params.mcp_servers)
        return get_agent(
            system_prompt=params.system_prompt,
            toolsets=mcp_servers,
            output_type=RunOutput,
        )
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
"""Storage interface and in-memory implementation for Pydantic AI agent history."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
|
|
5
|
+
from pydantic_ai.messages import ModelRequest, ModelResponse
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class PydanticAiAgentHistoryStorage(ABC):
    """Abstract store for pydantic-ai message histories, keyed by A2A task id."""

    @abstractmethod
    def get(self, task_id: str) -> list[ModelRequest | ModelResponse] | None:
        """Return the messages stored for ``task_id``, or ``None`` when there are none."""

    @abstractmethod
    def store(self, task_id: str, messages: list[ModelRequest | ModelResponse]) -> None:
        """Persist ``messages`` as the history for ``task_id``."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class InMemoryHistoryStorage(PydanticAiAgentHistoryStorage):
    """History store backed by a per-instance dict; entries are never evicted."""

    def __init__(self):
        self.storage: dict[str, list[ModelRequest | ModelResponse]] = {}

    def get(self, task_id: str) -> list[ModelRequest | ModelResponse] | None:
        """Return the stored messages for ``task_id``, or ``None`` if nothing was stored."""
        if task_id in self.storage:
            return self.storage[task_id]
        return None

    def store(self, task_id: str, messages: list[ModelRequest | ModelResponse]) -> None:
        """Replace the stored history for ``task_id`` with ``messages``."""
        self.storage[task_id] = messages
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
|
|
3
|
+
from a2a.client import Client
|
|
4
|
+
from a2a.types import (
|
|
5
|
+
AgentCard,
|
|
6
|
+
Message,
|
|
7
|
+
Task,
|
|
8
|
+
TaskQueryParams,
|
|
9
|
+
TaskState,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
from aixtools.logging.logging_config import get_logger
|
|
13
|
+
|
|
14
|
+
# Module-level logger.
logger = get_logger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def is_in_terminal_state(task: Task) -> bool:
    """Return True if ``task`` is in a terminal A2A state.

    The A2A terminal states are completed, canceled, failed and rejected: no further
    work happens on such a task, so polling it or resuming it is pointless.
    """
    return task.status.state in (
        TaskState.completed,
        TaskState.canceled,
        TaskState.failed,
        # Previously missing: a rejected task would otherwise be polled until timeout.
        TaskState.rejected,
    )
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def is_in_terminal_or_interrupted_state(task: Task) -> bool:
    """Return True if polling ``task`` should stop.

    This is the case for terminal states, for states in which the remote agent waits
    for the client (input required, auth required), and for the ``unknown`` state.
    """
    return task.status.state in (
        TaskState.input_required,
        # Interrupted like input_required; without it polling would spin until timeout.
        TaskState.auth_required,
        TaskState.unknown,
    ) or is_in_terminal_state(task)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class RemoteAgentConnection:
    """An A2A ``Client`` for one remote agent, paired with that agent's card."""

    def __init__(self, card: AgentCard, client: Client):
        self._client = client
        self._card = card

    def get_agent_card(self) -> AgentCard:
        """
        Returns the agent card associated with this connection.
        """
        return self._card

    async def send_message(self, message: Message) -> Task | Message | None:
        """
        Sends a message to the remote agent and returns either a Task, a Message, or None.

        Consumes the client's event stream and returns:
        - the first ``Message`` event, if one arrives;
        - otherwise the first task snapshot that is terminal or interrupted;
        - otherwise the last task snapshot seen, or ``None`` if the stream was empty.

        Client exceptions are logged and re-raised unchanged.
        """
        last_task: Task | None = None
        try:
            async for event in self._client.send_message(message):
                if isinstance(event, Message):
                    return event
                # Non-message events are indexed with [0] to get the task snapshot
                # (presumably a2a's (task, update) ClientEvent tuple).
                if is_in_terminal_or_interrupted_state(event[0]):
                    return event[0]
                last_task = event[0]
        except Exception as e:
            logger.error("Exception found in send_message: %s", str(e))
            # Bare raise re-raises the active exception with its original traceback.
            raise
        return last_task

    async def send_message_with_polling(
        self,
        message: Message,
        *,
        sleep_time: float = 0.2,
        max_iter: int = 1000,
    ) -> Task | Message:
        """
        Sends a message to the remote agent and polls for the task status at regular intervals.
        If the task reaches a terminal state or is interrupted, it returns the task.
        If the task does not complete within the maximum number of iterations, it raises an exception.

        Args:
            message: Message to send.
            sleep_time: Seconds to wait before each ``get_task`` poll.
            max_iter: Maximum number of polls.

        Raises:
            ValueError: If ``send_message`` returned nothing.
            Exception: If the task is still running after ``max_iter`` polls.
        """
        last_task = await self.send_message(message)
        if not last_task:
            raise ValueError("No task or message returned from send_message")
        if isinstance(last_task, Message):
            return last_task

        if is_in_terminal_or_interrupted_state(last_task):
            return last_task
        task_id = last_task.id
        for _ in range(max_iter):
            await asyncio.sleep(sleep_time)
            task = await self._client.get_task(TaskQueryParams(id=task_id))
            if is_in_terminal_or_interrupted_state(task):
                return task

        timeout_seconds = max_iter * sleep_time
        raise Exception(f"Task did not complete in {timeout_seconds} seconds")  # pylint: disable=broad-exception-raised
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
|
|
3
|
+
import httpx
|
|
4
|
+
from a2a.client import A2ACardResolver, ClientConfig, ClientFactory
|
|
5
|
+
from a2a.server.agent_execution import RequestContext
|
|
6
|
+
from a2a.types import AgentCard
|
|
7
|
+
from a2a.utils import AGENT_CARD_WELL_KNOWN_PATH, PREV_AGENT_CARD_WELL_KNOWN_PATH
|
|
8
|
+
|
|
9
|
+
from aixtools.a2a.google_sdk.remote_agent_connection import RemoteAgentConnection
|
|
10
|
+
from aixtools.context import DEFAULT_SESSION_ID, DEFAULT_USER_ID, SessionIdTuple
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class AgentCardLoadFailedError(Exception):
    """Raised when no agent card could be loaded from any of the well-known paths of an address."""
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class _AgentCardResolver:
    """Discovers remote A2A agents by address and keeps one connection per agent name."""

    def __init__(self, client: httpx.AsyncClient):
        self._httpx_client = client
        factory_config = ClientConfig(httpx_client=client)
        self._a2a_client_factory = ClientFactory(factory_config)
        self.clients: dict[str, RemoteAgentConnection] = {}

    def register_agent_card(self, card: AgentCard):
        """Create a client for ``card`` and store it under ``card.name``, replacing any previous entry."""
        a2a_client = self._a2a_client_factory.create(card)
        self.clients[card.name] = RemoteAgentConnection(card, a2a_client)

    async def retrieve_card(self, address: str):
        """Load the agent card at ``address`` and register it.

        The current well-known path is tried first, then the previous one. The card's
        ``url`` is overwritten with ``address`` before registration.

        Raises:
            AgentCardLoadFailedError: If both paths fail.
        """
        candidate_paths = (AGENT_CARD_WELL_KNOWN_PATH, PREV_AGENT_CARD_WELL_KNOWN_PATH)
        for card_path in candidate_paths:
            try:
                resolver = A2ACardResolver(self._httpx_client, address, card_path)
                card = await resolver.get_agent_card()
                card.url = address
                self.register_agent_card(card)
            except Exception as e:  # pylint: disable=broad-exception-caught
                print(f"Error retrieving agent card from {address} at path {card_path}: {e}")
            else:
                return

        raise AgentCardLoadFailedError(f"Failed to load agent card from {address}")

    async def get_a2a_clients(self, agent_hosts: list[str]) -> dict[str, RemoteAgentConnection]:
        """Resolve every host in ``agent_hosts`` concurrently and return the name -> connection map.

        If any lookup fails, ``asyncio.TaskGroup`` cancels the rest and raises an
        ``ExceptionGroup`` wrapping the failure(s).
        """
        async with asyncio.TaskGroup() as group:
            for host in agent_hosts:
                group.create_task(self.retrieve_card(host))

        return self.clients
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
async def get_a2a_clients(ctx: SessionIdTuple, agent_hosts: list[str]) -> dict[str, RemoteAgentConnection]:
    """Connect to every agent in ``agent_hosts``, tagging all requests with the session's ids.

    ``ctx[0]`` is sent as the ``user-id`` header and ``ctx[1]`` as ``session-id``.
    """
    user_id = ctx[0]
    session_id = ctx[1]
    # NOTE(review): this AsyncClient is never closed here; the returned connections keep
    # using it, so cleanup is left to the caller — confirm this is intended.
    httpx_client = httpx.AsyncClient(
        headers={"user-id": user_id, "session-id": session_id},
        timeout=60.0,
    )
    resolver = _AgentCardResolver(httpx_client)
    return await resolver.get_a2a_clients(agent_hosts)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def get_session_id_tuple(context: RequestContext) -> SessionIdTuple:
    """Return ``(user_id, session_id)`` taken from the request headers stored in the call context.

    Falls back to ``DEFAULT_USER_ID`` / ``DEFAULT_SESSION_ID`` when a header is missing.
    The same defaults apply when the request has no server call context at all, since
    ``RequestContext.call_context`` is optional.
    """
    call_context = context.call_context
    state = call_context.state if call_context is not None else {}
    headers = state.get("headers", {})
    return headers.get("user-id", DEFAULT_USER_ID), headers.get("session-id", DEFAULT_SESSION_ID)
|
aixtools/a2a/utils.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
"""Utilities for Agent-to-Agent (A2A) communication and task management."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import uuid
|
|
5
|
+
from typing import Callable
|
|
6
|
+
|
|
7
|
+
from fasta2a import Skill
|
|
8
|
+
from fasta2a.client import A2AClient
|
|
9
|
+
from fasta2a.schema import GetTaskResponse, Message, Part, TextPart
|
|
10
|
+
from fastapi import status
|
|
11
|
+
|
|
12
|
+
from ..server import get_session_id_tuple
|
|
13
|
+
|
|
14
|
+
# Seconds between polls and maximum number of polls used by poll_task.
SLEEP_TIME: float = 0.2
MAX_ITER: int = 1000
# NOTE(review): HTTP_OK is not referenced in this module (fetch_agent_card uses
# fastapi's status.HTTP_200_OK); kept in case other modules import it.
HTTP_OK: int = 200
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def card2description(card):
    """Render an agent card dict as ``"name: description"`` plus one tab-indented line per skill."""
    lines = [f"{card['name']}: {card['description']}\n"]
    lines.extend(f"\t - {skill['name']}: {skill['description']}\n" for skill in card.get("skills", []))
    return "".join(lines)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
async def fetch_agent_card(client: A2AClient) -> dict:
    """Request the Agent's card from ``<base_url>/.well-known/agent.json`` and return the parsed JSON.

    Raises:
        Exception: If the response status is not 200.
    """
    base_url = str(client.http_client.base_url).rstrip("/")
    agent_card_url = f"{base_url}/.well-known/agent.json"
    response = await client.http_client.get(agent_card_url, timeout=10)
    if response.status_code != status.HTTP_200_OK:
        raise Exception(f"Failed to retrieve agent card from {agent_card_url}. Status code: {response.status_code}")  # pylint: disable=broad-exception-raised
    return response.json()
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def get_result_text(ret: GetTaskResponse) -> str | None:
    """Return the text of the first ``text`` part across the task's artifacts, or ``None`` if there is none."""
    if "result" not in ret or "artifacts" not in ret["result"]:
        return None
    text_parts = (
        part["text"]
        for artifact in ret["result"]["artifacts"]
        if "parts" in artifact
        for part in artifact["parts"]
        if part["kind"] == "text"
    )
    return next(text_parts, None)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
async def poll_task(client: A2AClient, task_id: str) -> GetTaskResponse:
    """Poll ``task_id`` until it completes and return the final ``get_task`` response.

    Raises:
        Exception: If the task ends in ``failed``, ``canceled`` or ``rejected``. These
            states are terminal, so waiting longer could never yield completion.
        Exception: If the task has not completed after ``MAX_ITER`` polls spaced
            ``SLEEP_TIME`` seconds apart.
    """
    for _ in range(MAX_ITER):
        ret = await client.get_task(task_id=task_id)
        # Check the state of the task
        state = ret["result"]["status"]["state"] if "result" in ret and "status" in ret["result"] else None
        if state == "completed":
            return ret
        if state == "failed":
            raise Exception("Task failed")  # pylint: disable=broad-exception-raised
        if state in ("canceled", "rejected"):
            # Terminal but not successful: fail fast instead of polling until timeout.
            raise Exception(f"Task {state}")  # pylint: disable=broad-exception-raised
        # Sleep for a while before checking again
        await asyncio.sleep(SLEEP_TIME)
    timeout_seconds = MAX_ITER * SLEEP_TIME
    raise Exception(f"Task did not complete in {timeout_seconds} seconds")  # pylint: disable=broad-exception-raised
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
async def submit_task(client: A2AClient, message: Message) -> str:
    """Send a copy of ``message`` and return the created task id (``""`` if the response has none).

    The copy's metadata is extended with ``user_id`` / ``session_id``. These come from the
    client's ``user-id`` / ``session-id`` headers when present, otherwise from
    ``get_session_id_tuple()``.
    """
    user_id, session_id = get_session_id_tuple()
    headers = client.http_client.headers
    msg = message.copy()
    msg["metadata"] = {
        **msg.get("metadata", {}),
        "user_id": headers.get("user-id", user_id),
        "session_id": headers.get("session-id", session_id),
    }
    ret = await client.send_message(message=msg)
    if "result" in ret and "id" in ret["result"]:
        return ret["result"]["id"]
    return ""
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def multipart_message(parts: list[Part]) -> Message:
    """Wrap ``parts`` in a user-role message with a freshly generated random ``message_id``."""
    return Message(kind="message", role="user", parts=parts, message_id=str(uuid.uuid4()))
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def text_message(text: str) -> Message:
    """Build a user message holding a single text part (with empty metadata)."""
    return multipart_message([TextPart(kind="text", text=text, metadata={})])
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
async def task(client: A2AClient, text: str) -> GetTaskResponse:
    """Submit ``text`` as a user message, print the task id, and wait for it via ``poll_task``."""
    task_id = await submit_task(client, text_message(text))
    print(f"Task ID: {task_id}")
    return await poll_task(client, task_id)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def tool2skill(tool: Callable) -> Skill:
    """Describe a callable as a ``Skill``: id and name from ``__name__``, description from its docstring."""
    tool_name = tool.__name__
    return Skill(id=tool_name, name=tool_name, description=tool.__doc__ or "")  # type: ignore
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""Agent utilities for running and managing AI agents."""
|
|
2
|
+
|
|
3
|
+
from .agent import get_agent, get_model, run_agent
|
|
4
|
+
from .agent_batch import AgentQueryParams, run_agent_batch
|
|
5
|
+
|
|
6
|
+
# Public API re-exported from .agent and .agent_batch.
__all__ = ["get_agent", "get_model", "run_agent", "AgentQueryParams", "run_agent_batch"]
|
aixtools/agents/agent.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Core agent implementation providing model selection and configuration for AI agents.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from types import NoneType
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from openai import AsyncAzureOpenAI
|
|
9
|
+
from pydantic_ai import Agent
|
|
10
|
+
from pydantic_ai.models.bedrock import BedrockConverseModel
|
|
11
|
+
from pydantic_ai.models.openai import OpenAIModel
|
|
12
|
+
from pydantic_ai.providers.bedrock import BedrockProvider
|
|
13
|
+
from pydantic_ai.providers.openai import OpenAIProvider
|
|
14
|
+
from pydantic_ai.settings import ModelSettings
|
|
15
|
+
from pydantic_ai.usage import UsageLimits
|
|
16
|
+
|
|
17
|
+
from aixtools.logging.log_objects import ObjectLogger
|
|
18
|
+
from aixtools.logging.logging_config import get_logger
|
|
19
|
+
from aixtools.logging.model_patch_logging import model_patch_logging
|
|
20
|
+
from aixtools.utils.config import (
|
|
21
|
+
AWS_PROFILE,
|
|
22
|
+
AWS_REGION,
|
|
23
|
+
AZURE_MODEL_NAME,
|
|
24
|
+
AZURE_OPENAI_API_KEY,
|
|
25
|
+
AZURE_OPENAI_API_VERSION,
|
|
26
|
+
AZURE_OPENAI_ENDPOINT,
|
|
27
|
+
BEDROCK_MODEL_NAME,
|
|
28
|
+
MODEL_FAMILY,
|
|
29
|
+
MODEL_TIMEOUT,
|
|
30
|
+
OLLAMA_MODEL_NAME,
|
|
31
|
+
OLLAMA_URL,
|
|
32
|
+
OPENAI_API_KEY,
|
|
33
|
+
OPENAI_MODEL_NAME,
|
|
34
|
+
OPENROUTER_API_KEY,
|
|
35
|
+
OPENROUTER_API_URL,
|
|
36
|
+
OPENROUTER_MODEL_NAME,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
# Module-level logger.
logger = get_logger(__name__)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _get_model_bedrock(model_name=BEDROCK_MODEL_NAME, aws_region=AWS_REGION):
    """Build a Bedrock Converse model.

    If no ``AWS_PROFILE`` is configured, an explicit ``BedrockProvider`` for ``aws_region``
    is used. Otherwise no provider is passed, so pydantic-ai's default Bedrock provider
    applies. ``aws_region`` must be set in both cases.
    """
    assert model_name, "BEDROCK_MODEL_NAME is not set"
    assert aws_region, "AWS_REGION is not set"

    if AWS_PROFILE is None:
        regional_provider = BedrockProvider(region_name=aws_region)
        return BedrockConverseModel(model_name=model_name, provider=regional_provider)

    return BedrockConverseModel(model_name=model_name)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _get_model_ollama(model_name=OLLAMA_MODEL_NAME, ollama_url=OLLAMA_URL):
    """Build an OpenAI-compatible model pointed at the Ollama server at ``ollama_url``."""
    assert ollama_url, "OLLAMA_URL is not set"
    assert model_name, "Model name is not set"
    return OpenAIModel(model_name=model_name, provider=OpenAIProvider(base_url=ollama_url))
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _get_model_openai(model_name=OPENAI_MODEL_NAME, openai_api_key=OPENAI_API_KEY):
    """Build an OpenAI model authenticated with ``openai_api_key``."""
    assert openai_api_key, "OPENAI_API_KEY is not set"
    assert model_name, "Model name is not set"
    return OpenAIModel(model_name=model_name, provider=OpenAIProvider(api_key=openai_api_key))
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _get_model_openai_azure(
    model_name=AZURE_MODEL_NAME,
    azure_openai_api_key=AZURE_OPENAI_API_KEY,
    azure_openai_endpoint=AZURE_OPENAI_ENDPOINT,
    azure_openai_api_version=AZURE_OPENAI_API_VERSION,
):
    """Build an OpenAI model that talks to Azure OpenAI through an ``AsyncAzureOpenAI`` client."""
    assert azure_openai_endpoint, "AZURE_OPENAI_ENDPOINT is not set"
    assert azure_openai_api_key, "AZURE_OPENAI_API_KEY is not set"
    assert azure_openai_api_version, "AZURE_OPENAI_API_VERSION is not set"
    assert model_name, "Model name is not set"
    azure_client = AsyncAzureOpenAI(
        api_key=azure_openai_api_key,
        api_version=azure_openai_api_version,
        azure_endpoint=azure_openai_endpoint,
    )
    azure_provider = OpenAIProvider(openai_client=azure_client)
    return OpenAIModel(model_name=model_name, provider=azure_provider)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _get_model_open_router(
    model_name=OPENROUTER_MODEL_NAME, openrouter_api_url=OPENROUTER_API_URL, openrouter_api_key=OPENROUTER_API_KEY
):
    """Build an OpenAI-compatible model served through OpenRouter."""
    assert openrouter_api_url, "OPENROUTER_API_URL is not set"
    assert openrouter_api_key, "OPENROUTER_API_KEY is not set"
    assert model_name, "Model name is not set, missing 'OPENROUTER_MODEL_NAME' environment variable?"
    router_provider = OpenAIProvider(base_url=openrouter_api_url, api_key=openrouter_api_key)
    return OpenAIModel(model_name=model_name, provider=router_provider)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def get_model(model_family=MODEL_FAMILY, model_name=None, **kwargs):
    """Create and return appropriate model instance based on specified family and name.

    ``model_name`` falls back to the family's configured default name when falsy;
    extra ``kwargs`` are forwarded to the family-specific factory.

    Raises:
        ValueError: If ``model_family`` is not one of the supported families.
    """
    assert model_family is not None and model_family != "", f"Model family '{model_family}' is not set"
    # family -> (factory, default model name)
    factories = {
        "azure": (_get_model_openai_azure, AZURE_MODEL_NAME),
        "bedrock": (_get_model_bedrock, BEDROCK_MODEL_NAME),
        "ollama": (_get_model_ollama, OLLAMA_MODEL_NAME),
        "openai": (_get_model_openai, OPENAI_MODEL_NAME),
        "openrouter": (_get_model_open_router, OPENROUTER_MODEL_NAME),
    }
    if model_family not in factories:
        raise ValueError(f"Model family '{model_family}' not supported")
    factory, default_name = factories[model_family]
    return factory(model_name=model_name or default_name, **kwargs)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def get_agent(  # noqa: PLR0913, pylint: disable=too-many-arguments,too-many-positional-arguments
    model=None,
    *,
    instructions=None,
    system_prompt=(),
    tools=(),
    toolsets=(),
    model_settings=None,
    output_type: Any = str,
    deps_type=NoneType,
) -> Agent:
    """Get a PydanticAI agent.

    Defaults: ``model_settings`` becomes ``ModelSettings(timeout=MODEL_TIMEOUT)`` and
    ``model`` comes from ``get_model()`` when not given. Instrumentation is always on.
    """
    settings = ModelSettings(timeout=MODEL_TIMEOUT) if model_settings is None else model_settings
    chosen_model = get_model() if model is None else model
    return Agent(
        model=chosen_model,
        output_type=output_type,
        instructions=instructions,
        system_prompt=system_prompt,
        deps_type=deps_type,
        model_settings=settings,
        tools=tools,
        toolsets=toolsets,
        instrument=True,
    )
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
async def run_agent(  # noqa: PLR0913, pylint: disable=too-many-arguments,too-many-positional-arguments
    agent: Agent,
    prompt: str | list[str],
    usage_limits: UsageLimits | None = None,
    verbose: bool = False,
    debug: bool = False,
    log_model_requests: bool = False,
    parent_logger: ObjectLogger | None = None,
):
    """Query the LLM: run ``agent`` on ``prompt`` and log every node of the run.

    Args:
        agent: The agent to run.
        prompt: User prompt passed to ``agent.iter``.
        usage_limits: Optional usage limits for the run.
        verbose: Passed to ``ObjectLogger``.
        debug: Passed to ``ObjectLogger``.
        log_model_requests: If True, wrap ``agent.model`` with ``model_patch_logging``
            so model requests are logged too.
        parent_logger: Optional parent ``ObjectLogger``.

    Returns:
        ``(output, nodes)``: the run's final output (``None`` if the run produced no
        result) and the list of nodes visited, in order.
    """
    # Results
    nodes, result = [], None
    # Create a new log file for each run
    with ObjectLogger(parent_logger=parent_logger, verbose=verbose, debug=debug) as agent_logger:
        # Patch the model BEFORE starting the run: agent.iter() resolves the model when it
        # starts, so patching it afterwards (as before) did not affect the current run.
        # NOTE(review): this mutates agent.model persistently, beyond this call.
        if log_model_requests:
            agent.model = model_patch_logging(agent.model, agent_logger)
        async with agent.iter(prompt, usage_limits=usage_limits) as agent_run:
            # Run the agent
            async for node in agent_run:
                agent_logger.log(node)
                nodes.append(node)
            result = agent_run.result
    return result.output if result else None, nodes
|