aixtools-0.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of aixtools might be problematic.

Files changed (88)
  1. aixtools/.chainlit/config.toml +113 -0
  2. aixtools/.chainlit/translations/bn.json +214 -0
  3. aixtools/.chainlit/translations/en-US.json +214 -0
  4. aixtools/.chainlit/translations/gu.json +214 -0
  5. aixtools/.chainlit/translations/he-IL.json +214 -0
  6. aixtools/.chainlit/translations/hi.json +214 -0
  7. aixtools/.chainlit/translations/ja.json +214 -0
  8. aixtools/.chainlit/translations/kn.json +214 -0
  9. aixtools/.chainlit/translations/ml.json +214 -0
  10. aixtools/.chainlit/translations/mr.json +214 -0
  11. aixtools/.chainlit/translations/nl.json +214 -0
  12. aixtools/.chainlit/translations/ta.json +214 -0
  13. aixtools/.chainlit/translations/te.json +214 -0
  14. aixtools/.chainlit/translations/zh-CN.json +214 -0
  15. aixtools/__init__.py +11 -0
  16. aixtools/_version.py +34 -0
  17. aixtools/a2a/app.py +126 -0
  18. aixtools/a2a/google_sdk/__init__.py +0 -0
  19. aixtools/a2a/google_sdk/card.py +27 -0
  20. aixtools/a2a/google_sdk/pydantic_ai_adapter/agent_executor.py +199 -0
  21. aixtools/a2a/google_sdk/pydantic_ai_adapter/storage.py +26 -0
  22. aixtools/a2a/google_sdk/remote_agent_connection.py +88 -0
  23. aixtools/a2a/google_sdk/utils.py +59 -0
  24. aixtools/a2a/utils.py +115 -0
  25. aixtools/agents/__init__.py +12 -0
  26. aixtools/agents/agent.py +164 -0
  27. aixtools/agents/agent_batch.py +71 -0
  28. aixtools/agents/prompt.py +97 -0
  29. aixtools/app.py +143 -0
  30. aixtools/chainlit.md +14 -0
  31. aixtools/compliance/__init__.py +9 -0
  32. aixtools/compliance/private_data.py +138 -0
  33. aixtools/context.py +17 -0
  34. aixtools/db/__init__.py +17 -0
  35. aixtools/db/database.py +110 -0
  36. aixtools/db/vector_db.py +115 -0
  37. aixtools/google/client.py +25 -0
  38. aixtools/log_view/__init__.py +17 -0
  39. aixtools/log_view/app.py +195 -0
  40. aixtools/log_view/display.py +285 -0
  41. aixtools/log_view/export.py +51 -0
  42. aixtools/log_view/filters.py +41 -0
  43. aixtools/log_view/log_utils.py +26 -0
  44. aixtools/log_view/node_summary.py +229 -0
  45. aixtools/logfilters/__init__.py +7 -0
  46. aixtools/logfilters/context_filter.py +67 -0
  47. aixtools/logging/__init__.py +30 -0
  48. aixtools/logging/log_objects.py +227 -0
  49. aixtools/logging/logging_config.py +161 -0
  50. aixtools/logging/mcp_log_models.py +102 -0
  51. aixtools/logging/mcp_logger.py +172 -0
  52. aixtools/logging/model_patch_logging.py +87 -0
  53. aixtools/logging/open_telemetry.py +36 -0
  54. aixtools/mcp/__init__.py +9 -0
  55. aixtools/mcp/client.py +375 -0
  56. aixtools/mcp/example_client.py +30 -0
  57. aixtools/mcp/example_server.py +22 -0
  58. aixtools/mcp/fast_mcp_log.py +31 -0
  59. aixtools/mcp/faulty_mcp.py +319 -0
  60. aixtools/model_patch/model_patch.py +63 -0
  61. aixtools/server/__init__.py +29 -0
  62. aixtools/server/app_mounter.py +90 -0
  63. aixtools/server/path.py +72 -0
  64. aixtools/server/utils.py +70 -0
  65. aixtools/server/workspace_privacy.py +65 -0
  66. aixtools/testing/__init__.py +9 -0
  67. aixtools/testing/aix_test_model.py +149 -0
  68. aixtools/testing/mock_tool.py +66 -0
  69. aixtools/testing/model_patch_cache.py +279 -0
  70. aixtools/tools/doctor/__init__.py +3 -0
  71. aixtools/tools/doctor/tool_doctor.py +61 -0
  72. aixtools/tools/doctor/tool_recommendation.py +44 -0
  73. aixtools/utils/__init__.py +35 -0
  74. aixtools/utils/chainlit/cl_agent_show.py +82 -0
  75. aixtools/utils/chainlit/cl_utils.py +168 -0
  76. aixtools/utils/config.py +131 -0
  77. aixtools/utils/config_util.py +69 -0
  78. aixtools/utils/enum_with_description.py +37 -0
  79. aixtools/utils/files.py +17 -0
  80. aixtools/utils/persisted_dict.py +99 -0
  81. aixtools/utils/utils.py +167 -0
  82. aixtools/vault/__init__.py +7 -0
  83. aixtools/vault/vault.py +137 -0
  84. aixtools-0.0.0.dist-info/METADATA +669 -0
  85. aixtools-0.0.0.dist-info/RECORD +88 -0
  86. aixtools-0.0.0.dist-info/WHEEL +5 -0
  87. aixtools-0.0.0.dist-info/entry_points.txt +2 -0
  88. aixtools-0.0.0.dist-info/top_level.txt +1 -0

aixtools/a2a/google_sdk/pydantic_ai_adapter/agent_executor.py ADDED
@@ -0,0 +1,199 @@
+ from pathlib import Path
+
+ from a2a.server.agent_execution import AgentExecutor, RequestContext
+ from a2a.server.events import EventQueue
+ from a2a.types import (
+     Artifact,
+     FilePart,
+     FileWithUri,
+     Message,
+     Part,
+     TaskArtifactUpdateEvent,
+     TaskState,
+     TaskStatus,
+     TaskStatusUpdateEvent,
+ )
+ from a2a.utils import get_file_parts, get_message_text, new_agent_text_message, new_task
+ from pydantic import BaseModel
+ from pydantic_ai import Agent, BinaryContent
+
+ from aixtools.a2a.google_sdk.pydantic_ai_adapter.storage import InMemoryHistoryStorage
+ from aixtools.a2a.google_sdk.remote_agent_connection import is_in_terminal_state
+ from aixtools.a2a.google_sdk.utils import get_session_id_tuple
+ from aixtools.agents import get_agent
+ from aixtools.agents.prompt import build_user_input
+ from aixtools.context import SessionIdTuple
+ from aixtools.logging.logging_config import get_logger
+ from aixtools.mcp.client import get_configured_mcp_servers
+
+ logger = get_logger(__name__)
+
+
+ class AgentParameters(BaseModel):
+     system_prompt: str
+     mcp_servers: list[str]
+
+
+ class RunOutput(BaseModel):
+     is_task_failed: bool
+     is_task_in_progress: bool
+     is_input_required: bool
+     output: str
+     created_artifacts_paths: list[str]
+
+
+ def _task_failed_event(text: str, context_id: str | None, task_id: str | None) -> TaskStatusUpdateEvent:
+     """Create a TaskStatusUpdateEvent indicating task failure."""
+     return TaskStatusUpdateEvent(
+         status=TaskStatus(
+             state=TaskState.failed, message=new_agent_text_message(text=text, context_id=context_id, task_id=task_id)
+         ),
+         final=True,
+         context_id=context_id,
+         task_id=task_id,
+     )
+
+
+ class PydanticAgentExecutor(AgentExecutor):
+     def __init__(self, agent_parameters: AgentParameters):
+         self._agent_parameters = agent_parameters
+         self.history_storage = InMemoryHistoryStorage()
+
+     def _convert_message_to_pydantic_parts(
+         self,
+         session_tuple: SessionIdTuple,
+         message: Message,
+     ) -> str | list[str | BinaryContent]:
+         """Convert an A2A Message into the Pydantic AI agent input format."""
+         text_prompt = get_message_text(message)
+         file_parts = get_file_parts(message.parts)
+         if not file_parts:
+             return text_prompt
+         file_paths = [Path(part.uri) for part in file_parts if isinstance(part, FileWithUri)]
+
+         return build_user_input(session_tuple, text_prompt, file_paths)
+
+     async def execute(
+         self,
+         context: RequestContext,
+         event_queue: EventQueue,
+     ) -> None:
+         """
+         Execute the agent run, wrapping Pydantic AI agent execution with A2A protocol events.
+
+         Args:
+             context (RequestContext): The request context containing the message and task information.
+             event_queue (EventQueue): The event queue to enqueue events.
+         """
+         session_tuple = get_session_id_tuple(context)
+         agent = self._build_agent(session_tuple)
+         if context.message is None:
+             raise ValueError("No message provided")
+
+         task = context.current_task
+         message = context.message
+         if not task:
+             task = new_task(message)
+             await event_queue.enqueue_event(task)
+
+         if is_in_terminal_state(task):
+             raise RuntimeError(f"Cannot perform a task that is in a terminal state: {task.status.state}")
+
+         prompt = self._convert_message_to_pydantic_parts(session_tuple, message)
+         history_message = self.history_storage.get(task.id)
+
+         try:
+             result = await agent.run(
+                 user_prompt=prompt,
+                 message_history=history_message,
+             )
+         except Exception as e:
+             await event_queue.enqueue_event(
+                 _task_failed_event(
+                     text=f"Agent execution error: {e}",
+                     context_id=context.context_id,
+                     task_id=task.id,
+                 )
+             )
+             return
+
+         self.history_storage.store(task.id, result.all_messages())
+
+         run_output: RunOutput = result.output
+         if run_output.is_task_failed:
+             await event_queue.enqueue_event(
+                 _task_failed_event(
+                     text=f"Task failed: {run_output.output}",
+                     context_id=context.context_id,
+                     task_id=task.id,
+                 )
+             )
+             return
+
+         if run_output.is_input_required:
+             await event_queue.enqueue_event(
+                 TaskStatusUpdateEvent(
+                     status=TaskStatus(
+                         state=TaskState.input_required,
+                         message=new_agent_text_message(
+                             text=run_output.output, context_id=context.context_id, task_id=task.id
+                         ),
+                     ),
+                     final=False,
+                     context_id=context.context_id,
+                     task_id=task.id,
+                 )
+             )
+             return
+
+         if run_output.is_task_in_progress:
+             logger.error("Task hasn't been completed: %s", run_output.output)
+             await event_queue.enqueue_event(
+                 _task_failed_event(
+                     text=f"Agent didn't manage to complete the task: {run_output.output}",
+                     context_id=context.context_id,
+                     task_id=task.id,
+                 )
+             )
+             return
+
+         for idx, artifact in enumerate(run_output.created_artifacts_paths):
+             image_file = FileWithUri(uri=str(artifact), name=f"image_{idx}")
+             await event_queue.enqueue_event(
+                 TaskArtifactUpdateEvent(
+                     append=False,
+                     context_id=task.context_id,
+                     task_id=task.id,
+                     last_chunk=True,
+                     artifact=Artifact(
+                         artifact_id=f"image_{idx}",
+                         parts=[Part(root=FilePart(file=image_file))],
+                     ),
+                 )
+             )
+         await event_queue.enqueue_event(
+             TaskStatusUpdateEvent(
+                 status=TaskStatus(
+                     state=TaskState.completed,
+                     message=new_agent_text_message(
+                         text=run_output.output, context_id=context.context_id, task_id=task.id
+                     ),
+                 ),
+                 final=True,
+                 context_id=context.context_id,
+                 task_id=task.id,
+             )
+         )
+
+     async def cancel(self, ctx: RequestContext, event_queue: EventQueue) -> None:
+         """Cancel the task (not supported)."""
+         raise Exception("cancel not supported")
+
+     def _build_agent(self, session_tuple: SessionIdTuple) -> Agent:
+         params = self._agent_parameters
+         mcp_servers = get_configured_mcp_servers(session_tuple, params.mcp_servers)
+         return get_agent(
+             system_prompt=params.system_prompt,
+             toolsets=mcp_servers,
+             output_type=RunOutput,
+         )
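
A minimal construction sketch (not part of the package diff), using only the classes defined above; the prompt text and MCP server names are placeholders, and the A2A server wiring that supplies RequestContext and EventQueue is assumed to exist elsewhere.

# Hypothetical usage sketch, not part of the aixtools distribution.
from aixtools.a2a.google_sdk.pydantic_ai_adapter.agent_executor import (
    AgentParameters,
    PydanticAgentExecutor,
)

params = AgentParameters(
    system_prompt="You are a helpful assistant.",  # placeholder prompt
    mcp_servers=["docs"],                          # placeholder MCP server name
)
executor = PydanticAgentExecutor(params)
# The A2A server framework is expected to call executor.execute(context, event_queue)
# for each incoming message and to consume the enqueued status/artifact events.
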
aixtools/a2a/google_sdk/pydantic_ai_adapter/storage.py ADDED
@@ -0,0 +1,26 @@
+ """Storage interface and in-memory implementation for Pydantic AI agent history."""
+
+ from abc import ABC, abstractmethod
+
+ from pydantic_ai.messages import ModelRequest, ModelResponse
+
+
+ class PydanticAiAgentHistoryStorage(ABC):
+     @abstractmethod
+     def get(self, task_id: str) -> list[ModelRequest | ModelResponse] | None:
+         pass
+
+     @abstractmethod
+     def store(self, task_id: str, messages: list[ModelRequest | ModelResponse]) -> None:
+         pass
+
+
+ class InMemoryHistoryStorage(PydanticAiAgentHistoryStorage):
+     def __init__(self):
+         self.storage: dict[str, list[ModelRequest | ModelResponse]] = {}
+
+     def get(self, task_id: str) -> list[ModelRequest | ModelResponse] | None:
+         return self.storage.get(task_id, None)
+
+     def store(self, task_id: str, messages: list[ModelRequest | ModelResponse]) -> None:
+         self.storage[task_id] = messages
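
An illustrative sketch of how the in-memory store behaves: unknown task ids return None, and store() simply overwrites the history list for a task id (normally the list returned by result.all_messages() in the executor above).

# Illustrative sketch of InMemoryHistoryStorage behaviour.
from aixtools.a2a.google_sdk.pydantic_ai_adapter.storage import InMemoryHistoryStorage

storage = InMemoryHistoryStorage()
assert storage.get("task-1") is None  # unknown task ids return None
storage.store("task-1", [])           # normally the messages from result.all_messages()
assert storage.get("task-1") == []
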
aixtools/a2a/google_sdk/remote_agent_connection.py ADDED
@@ -0,0 +1,88 @@
+ import asyncio
+
+ from a2a.client import Client
+ from a2a.types import (
+     AgentCard,
+     Message,
+     Task,
+     TaskQueryParams,
+     TaskState,
+ )
+
+ from aixtools.logging.logging_config import get_logger
+
+ logger = get_logger(__name__)
+
+
+ def is_in_terminal_state(task: Task) -> bool:
+     return task.status.state in [
+         TaskState.completed,
+         TaskState.canceled,
+         TaskState.failed,
+     ]
+
+
+ def is_in_terminal_or_interrupted_state(task: Task) -> bool:
+     return task.status.state in [
+         TaskState.input_required,
+         TaskState.unknown,
+     ] or is_in_terminal_state(task)
+
+
+ class RemoteAgentConnection:
+     def __init__(self, card: AgentCard, client: Client):
+         self._client = client
+         self._card = card
+
+     def get_agent_card(self) -> AgentCard:
+         """Return the agent card associated with this connection."""
+         return self._card
+
+     async def send_message(self, message: Message) -> Task | Message | None:
+         """Send a message to the remote agent and return either a Task, a Message, or None."""
+         last_task: Task | None = None
+         try:
+             async for event in self._client.send_message(message):
+                 if isinstance(event, Message):
+                     return event
+                 if is_in_terminal_or_interrupted_state(event[0]):
+                     return event[0]
+                 last_task = event[0]
+         except Exception as e:
+             logger.error("Exception found in send_message: %s", str(e))
+             raise e
+         return last_task
+
+     async def send_message_with_polling(
+         self,
+         message: Message,
+         *,
+         sleep_time: float = 0.2,
+         max_iter: int = 1000,
+     ) -> Task | Message:
+         """
+         Send a message to the remote agent and poll the task status at regular intervals.
+
+         If the task reaches a terminal or interrupted state, return the task.
+         If the task does not complete within the maximum number of iterations, raise an exception.
+         """
+         last_task = await self.send_message(message)
+         if not last_task:
+             raise ValueError("No task or message returned from send_message")
+         if isinstance(last_task, Message):
+             return last_task
+
+         if is_in_terminal_or_interrupted_state(last_task):
+             return last_task
+         task_id = last_task.id
+         for _ in range(max_iter):
+             await asyncio.sleep(sleep_time)
+             task = await self._client.get_task(TaskQueryParams(id=task_id))
+             if is_in_terminal_or_interrupted_state(task):
+                 return task
+
+         timeout_seconds = max_iter * sleep_time
+         raise Exception(f"Task did not complete in {timeout_seconds} seconds")  # pylint: disable=broad-exception-raised
aixtools/a2a/google_sdk/utils.py ADDED
@@ -0,0 +1,59 @@
+ import asyncio
+
+ import httpx
+ from a2a.client import A2ACardResolver, ClientConfig, ClientFactory
+ from a2a.server.agent_execution import RequestContext
+ from a2a.types import AgentCard
+ from a2a.utils import AGENT_CARD_WELL_KNOWN_PATH, PREV_AGENT_CARD_WELL_KNOWN_PATH
+
+ from aixtools.a2a.google_sdk.remote_agent_connection import RemoteAgentConnection
+ from aixtools.context import DEFAULT_SESSION_ID, DEFAULT_USER_ID, SessionIdTuple
+
+
+ class AgentCardLoadFailedError(Exception):
+     pass
+
+
+ class _AgentCardResolver:
+     def __init__(self, client: httpx.AsyncClient):
+         self._httpx_client = client
+         self._a2a_client_factory = ClientFactory(ClientConfig(httpx_client=self._httpx_client))
+         self.clients: dict[str, RemoteAgentConnection] = {}
+
+     def register_agent_card(self, card: AgentCard):
+         remote_connection = RemoteAgentConnection(card, self._a2a_client_factory.create(card))
+         self.clients[card.name] = remote_connection
+
+     async def retrieve_card(self, address: str):
+         for card_path in [AGENT_CARD_WELL_KNOWN_PATH, PREV_AGENT_CARD_WELL_KNOWN_PATH]:
+             try:
+                 card_resolver = A2ACardResolver(self._httpx_client, address, card_path)
+                 card = await card_resolver.get_agent_card()
+                 card.url = address
+                 self.register_agent_card(card)
+                 return
+             except Exception as e:
+                 print(f"Error retrieving agent card from {address} at path {card_path}: {e}")
+
+         raise AgentCardLoadFailedError(f"Failed to load agent card from {address}")
+
+     async def get_a2a_clients(self, agent_hosts: list[str]) -> dict[str, RemoteAgentConnection]:
+         async with asyncio.TaskGroup() as task_group:
+             for address in agent_hosts:
+                 task_group.create_task(self.retrieve_card(address))
+
+         return self.clients
+
+
+ async def get_a2a_clients(ctx: SessionIdTuple, agent_hosts: list[str]) -> dict[str, RemoteAgentConnection]:
+     headers = {
+         "user-id": ctx[0],
+         "session-id": ctx[1],
+     }
+     httpx_client = httpx.AsyncClient(headers=headers, timeout=60.0)
+     return await _AgentCardResolver(httpx_client).get_a2a_clients(agent_hosts)
+
+
+ def get_session_id_tuple(context: RequestContext) -> SessionIdTuple:
+     headers = context.call_context.state.get("headers", {})
+     return headers.get("user-id", DEFAULT_USER_ID), headers.get("session-id", DEFAULT_SESSION_ID)
aixtools/a2a/utils.py ADDED
@@ -0,0 +1,115 @@
+ """Utilities for Agent-to-Agent (A2A) communication and task management."""
+
+ import asyncio
+ import uuid
+ from typing import Callable
+
+ from fasta2a import Skill
+ from fasta2a.client import A2AClient
+ from fasta2a.schema import GetTaskResponse, Message, Part, TextPart
+ from fastapi import status
+
+ from ..server import get_session_id_tuple
+
+ SLEEP_TIME = 0.2
+ MAX_ITER = 1000
+ HTTP_OK = 200
+
+
+ def card2description(card):
+     """Convert an agent card to a description string."""
+     descr = f"{card['name']}: {card['description']}\n"
+     skills = card.get("skills", [])
+     for skill in skills:
+         descr += f"\t - {skill['name']}: {skill['description']}\n"
+     return descr
+
+
+ async def fetch_agent_card(client: A2AClient) -> dict:
+     """Request the agent's card."""
+     server_url = str(client.http_client.base_url).rstrip("/")
+     agent_card_url = f"{server_url}/.well-known/agent.json"
+     response = await client.http_client.get(agent_card_url, timeout=10)
+     if response.status_code == status.HTTP_200_OK:
+         card_data = response.json()
+         return card_data
+     raise Exception(f"Failed to retrieve agent card from {agent_card_url}. Status code: {response.status_code}")  # pylint: disable=broad-exception-raised
+
+
+ def get_result_text(ret: GetTaskResponse) -> str | None:
+     """Extract the result text from the task result."""
+     if "result" not in ret:
+         return None
+     result = ret["result"]
+     if "artifacts" not in result:
+         return None
+     artifacts = result["artifacts"]
+     for artifact in artifacts:
+         if "parts" not in artifact:
+             continue
+         parts = artifact["parts"]
+         for part in parts:
+             if part["kind"] == "text":
+                 return part["text"]
+     return None
+
+
+ async def poll_task(client: A2AClient, task_id: str) -> GetTaskResponse:
+     """Poll the task status until it is completed or failed."""
+     state = None
+     for _ in range(MAX_ITER):
+         ret = await client.get_task(task_id=task_id)
+         # Check the state of the task
+         state = ret["result"]["status"]["state"] if "result" in ret and "status" in ret["result"] else None
+         if state == "completed":
+             return ret
+         if state == "failed":
+             raise Exception("Task failed")  # pylint: disable=broad-exception-raised
+         # Sleep for a while before checking again
+         await asyncio.sleep(SLEEP_TIME)
+     timeout_seconds = MAX_ITER * SLEEP_TIME
+     raise Exception(f"Task did not complete in {timeout_seconds} seconds")  # pylint: disable=broad-exception-raised
+
+
+ async def submit_task(client: A2AClient, message: Message) -> str:
+     """Send a message to the client and return the task id."""
+     user_id, session_id = get_session_id_tuple()
+     msg = message.copy()
+     msg["metadata"] = {
+         **msg.get("metadata", {}),
+         "user_id": client.http_client.headers.get("user-id", user_id),
+         "session_id": client.http_client.headers.get("session-id", session_id),
+     }
+     ret = await client.send_message(message=msg)
+     task_id = ret["result"]["id"] if "result" in ret and "id" in ret["result"] else ""
+     return task_id
+
+
+ def multipart_message(parts: list[Part]) -> Message:
+     """Create a message object."""
+     message = Message(kind="message", role="user", parts=parts, message_id=str(uuid.uuid4()))
+     return message
+
+
+ def text_message(text: str) -> Message:
+     """Create a message object with a text part."""
+     text_part = TextPart(kind="text", text=text, metadata={})
+     return multipart_message([text_part])
+
+
+ async def task(client: A2AClient, text: str) -> GetTaskResponse:
+     """Send a text message to the client and wait for task completion."""
+     msg = text_message(text)
+     task_id = await submit_task(client, msg)
+     print(f"Task ID: {task_id}")
+     ret = await poll_task(client, task_id)
+     return ret
+
+
+ def tool2skill(tool: Callable) -> Skill:
+     """Convert a tool to a skill."""
+     return Skill(
+         id=tool.__name__,
+         name=tool.__name__,
+         description=tool.__doc__ or "",
+     )  # type: ignore
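
A hedged end-to-end sketch of the helpers above; the A2AClient constructor arguments and server URL are assumptions, not confirmed by this diff:

# Hypothetical sketch: submit a text task and read back its text result.
import asyncio

from fasta2a.client import A2AClient

from aixtools.a2a.utils import get_result_text, task

async def main():
    client = A2AClient(base_url="http://localhost:8000")   # assumed constructor and URL
    ret = await task(client, "Summarize the latest logs")  # submits, then polls until completion
    print(get_result_text(ret))

asyncio.run(main())
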
aixtools/agents/__init__.py ADDED
@@ -0,0 +1,12 @@
+ """Agent utilities for running and managing AI agents."""
+
+ from .agent import get_agent, get_model, run_agent
+ from .agent_batch import AgentQueryParams, run_agent_batch
+
+ __all__ = [
+     "get_agent",
+     "get_model",
+     "run_agent",
+     "AgentQueryParams",
+     "run_agent_batch",
+ ]

aixtools/agents/agent.py ADDED
@@ -0,0 +1,164 @@
+ """
+ Core agent implementation providing model selection and configuration for AI agents.
+ """
+
+ from types import NoneType
+ from typing import Any
+
+ from openai import AsyncAzureOpenAI
+ from pydantic_ai import Agent
+ from pydantic_ai.models.bedrock import BedrockConverseModel
+ from pydantic_ai.models.openai import OpenAIModel
+ from pydantic_ai.providers.bedrock import BedrockProvider
+ from pydantic_ai.providers.openai import OpenAIProvider
+ from pydantic_ai.settings import ModelSettings
+ from pydantic_ai.usage import UsageLimits
+
+ from aixtools.logging.log_objects import ObjectLogger
+ from aixtools.logging.logging_config import get_logger
+ from aixtools.logging.model_patch_logging import model_patch_logging
+ from aixtools.utils.config import (
+     AWS_PROFILE,
+     AWS_REGION,
+     AZURE_MODEL_NAME,
+     AZURE_OPENAI_API_KEY,
+     AZURE_OPENAI_API_VERSION,
+     AZURE_OPENAI_ENDPOINT,
+     BEDROCK_MODEL_NAME,
+     MODEL_FAMILY,
+     MODEL_TIMEOUT,
+     OLLAMA_MODEL_NAME,
+     OLLAMA_URL,
+     OPENAI_API_KEY,
+     OPENAI_MODEL_NAME,
+     OPENROUTER_API_KEY,
+     OPENROUTER_API_URL,
+     OPENROUTER_MODEL_NAME,
+ )
+
+ logger = get_logger(__name__)
+
+
+ def _get_model_bedrock(model_name=BEDROCK_MODEL_NAME, aws_region=AWS_REGION):
+     assert model_name, "BEDROCK_MODEL_NAME is not set"
+     assert aws_region, "AWS_REGION is not set"
+
+     if AWS_PROFILE is not None:
+         return BedrockConverseModel(model_name=model_name)
+
+     provider = BedrockProvider(region_name=aws_region)
+     return BedrockConverseModel(model_name=model_name, provider=provider)
+
+
+ def _get_model_ollama(model_name=OLLAMA_MODEL_NAME, ollama_url=OLLAMA_URL):
+     assert ollama_url, "OLLAMA_URL is not set"
+     assert model_name, "Model name is not set"
+     provider = OpenAIProvider(base_url=ollama_url)
+     return OpenAIModel(model_name=model_name, provider=provider)
+
+
+ def _get_model_openai(model_name=OPENAI_MODEL_NAME, openai_api_key=OPENAI_API_KEY):
+     assert openai_api_key, "OPENAI_API_KEY is not set"
+     assert model_name, "Model name is not set"
+     provider = OpenAIProvider(api_key=openai_api_key)
+     return OpenAIModel(model_name=model_name, provider=provider)
+
+
+ def _get_model_openai_azure(
+     model_name=AZURE_MODEL_NAME,
+     azure_openai_api_key=AZURE_OPENAI_API_KEY,
+     azure_openai_endpoint=AZURE_OPENAI_ENDPOINT,
+     azure_openai_api_version=AZURE_OPENAI_API_VERSION,
+ ):
+     assert azure_openai_endpoint, "AZURE_OPENAI_ENDPOINT is not set"
+     assert azure_openai_api_key, "AZURE_OPENAI_API_KEY is not set"
+     assert azure_openai_api_version, "AZURE_OPENAI_API_VERSION is not set"
+     assert model_name, "Model name is not set"
+     client = AsyncAzureOpenAI(
+         azure_endpoint=azure_openai_endpoint, api_version=azure_openai_api_version, api_key=azure_openai_api_key
+     )
+     return OpenAIModel(model_name=model_name, provider=OpenAIProvider(openai_client=client))
+
+
+ def _get_model_open_router(
+     model_name=OPENROUTER_MODEL_NAME, openrouter_api_url=OPENROUTER_API_URL, openrouter_api_key=OPENROUTER_API_KEY
+ ):
+     assert openrouter_api_url, "OPENROUTER_API_URL is not set"
+     assert openrouter_api_key, "OPENROUTER_API_KEY is not set"
+     assert model_name, "Model name is not set, missing 'OPENROUTER_MODEL_NAME' environment variable?"
+     provider = OpenAIProvider(base_url=openrouter_api_url, api_key=openrouter_api_key)
+     return OpenAIModel(model_name, provider=provider)
+
+
+ def get_model(model_family=MODEL_FAMILY, model_name=None, **kwargs):
+     """Create and return the appropriate model instance for the specified family and name."""
+     assert model_family is not None and model_family != "", f"Model family '{model_family}' is not set"
+     match model_family:
+         case "azure":
+             return _get_model_openai_azure(model_name=model_name or AZURE_MODEL_NAME, **kwargs)
+         case "bedrock":
+             return _get_model_bedrock(model_name=model_name or BEDROCK_MODEL_NAME, **kwargs)
+         case "ollama":
+             return _get_model_ollama(model_name=model_name or OLLAMA_MODEL_NAME, **kwargs)
+         case "openai":
+             return _get_model_openai(model_name=model_name or OPENAI_MODEL_NAME, **kwargs)
+         case "openrouter":
+             return _get_model_open_router(model_name=model_name or OPENROUTER_MODEL_NAME, **kwargs)
+         case _:
+             raise ValueError(f"Model family '{model_family}' not supported")
+
+
+ def get_agent(  # noqa: PLR0913, pylint: disable=too-many-arguments,too-many-positional-arguments
+     model=None,
+     *,
+     instructions=None,
+     system_prompt=(),
+     tools=(),
+     toolsets=(),
+     model_settings=None,
+     output_type: Any = str,
+     deps_type=NoneType,
+ ) -> Agent:
+     """Get a Pydantic AI agent."""
+     if model_settings is None:
+         model_settings = ModelSettings(timeout=MODEL_TIMEOUT)
+     if model is None:
+         model = get_model()
+     agent = Agent(
+         model=model,
+         output_type=output_type,
+         instructions=instructions,
+         system_prompt=system_prompt,
+         deps_type=deps_type,
+         model_settings=model_settings,
+         tools=tools,
+         toolsets=toolsets,
+         instrument=True,
+     )
+     return agent
+
+
+ async def run_agent(  # noqa: PLR0913, pylint: disable=too-many-arguments,too-many-positional-arguments
+     agent: Agent,
+     prompt: str | list[str],
+     usage_limits: UsageLimits | None = None,
+     verbose: bool = False,
+     debug: bool = False,
+     log_model_requests: bool = False,
+     parent_logger: ObjectLogger | None = None,
+ ):
+     """Query the LLM."""
+     # Results
+     nodes, result = [], None
+     async with agent.iter(prompt, usage_limits=usage_limits) as agent_run:
+         # Create a new log file for each run
+         with ObjectLogger(parent_logger=parent_logger, verbose=verbose, debug=debug) as agent_logger:
+             # Patch the model with the logger
+             if log_model_requests:
+                 agent.model = model_patch_logging(agent.model, agent_logger)
+             # Run the agent
+             async for node in agent_run:
+                 agent_logger.log(node)
+                 nodes.append(node)
+             result = agent_run.result
+     return result.output if result else None, nodes
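
A hedged usage sketch of get_agent/run_agent, assuming MODEL_FAMILY and the matching provider settings (for example OPENAI_API_KEY and OPENAI_MODEL_NAME) are present in the environment-backed config:

# Hypothetical sketch: build a default agent and run a single prompt through it.
import asyncio

from aixtools.agents import get_agent, run_agent

async def main():
    agent = get_agent(system_prompt="Answer in one sentence.")
    output, nodes = await run_agent(agent, "What does this package do?", verbose=True)
    print(output)
    print(f"Logged {len(nodes)} graph nodes")

asyncio.run(main())
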