aixtools 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aixtools might be problematic. Click here for more details.
- aixtools/__init__.py +5 -0
- aixtools/a2a/__init__.py +5 -0
- aixtools/a2a/app.py +126 -0
- aixtools/a2a/utils.py +115 -0
- aixtools/agents/__init__.py +12 -0
- aixtools/agents/agent.py +164 -0
- aixtools/agents/agent_batch.py +74 -0
- aixtools/app.py +143 -0
- aixtools/context.py +12 -0
- aixtools/db/__init__.py +17 -0
- aixtools/db/database.py +110 -0
- aixtools/db/vector_db.py +115 -0
- aixtools/log_view/__init__.py +17 -0
- aixtools/log_view/app.py +195 -0
- aixtools/log_view/display.py +285 -0
- aixtools/log_view/export.py +51 -0
- aixtools/log_view/filters.py +41 -0
- aixtools/log_view/log_utils.py +26 -0
- aixtools/log_view/node_summary.py +229 -0
- aixtools/logfilters/__init__.py +7 -0
- aixtools/logfilters/context_filter.py +67 -0
- aixtools/logging/__init__.py +30 -0
- aixtools/logging/log_objects.py +227 -0
- aixtools/logging/logging_config.py +116 -0
- aixtools/logging/mcp_log_models.py +102 -0
- aixtools/logging/mcp_logger.py +172 -0
- aixtools/logging/model_patch_logging.py +87 -0
- aixtools/logging/open_telemetry.py +36 -0
- aixtools/mcp/__init__.py +9 -0
- aixtools/mcp/example_client.py +30 -0
- aixtools/mcp/example_server.py +22 -0
- aixtools/mcp/fast_mcp_log.py +31 -0
- aixtools/mcp/faulty_mcp.py +320 -0
- aixtools/model_patch/model_patch.py +65 -0
- aixtools/server/__init__.py +23 -0
- aixtools/server/app_mounter.py +90 -0
- aixtools/server/path.py +72 -0
- aixtools/server/utils.py +70 -0
- aixtools/testing/__init__.py +9 -0
- aixtools/testing/aix_test_model.py +147 -0
- aixtools/testing/mock_tool.py +66 -0
- aixtools/testing/model_patch_cache.py +279 -0
- aixtools/tools/doctor/__init__.py +3 -0
- aixtools/tools/doctor/tool_doctor.py +61 -0
- aixtools/tools/doctor/tool_recommendation.py +44 -0
- aixtools/utils/__init__.py +35 -0
- aixtools/utils/chainlit/cl_agent_show.py +82 -0
- aixtools/utils/chainlit/cl_utils.py +168 -0
- aixtools/utils/config.py +118 -0
- aixtools/utils/config_util.py +69 -0
- aixtools/utils/enum_with_description.py +37 -0
- aixtools/utils/persisted_dict.py +99 -0
- aixtools/utils/utils.py +160 -0
- aixtools-0.1.0.dist-info/METADATA +355 -0
- aixtools-0.1.0.dist-info/RECORD +58 -0
- aixtools-0.1.0.dist-info/WHEEL +5 -0
- aixtools-0.1.0.dist-info/entry_points.txt +2 -0
- aixtools-0.1.0.dist-info/top_level.txt +1 -0
aixtools/__init__.py
ADDED
aixtools/a2a/__init__.py
ADDED
aixtools/a2a/app.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module provides functionality to convert a Pydantic AI Agent into a FastA2A application
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from functools import partial
|
|
7
|
+
from typing import assert_never
|
|
8
|
+
|
|
9
|
+
from fasta2a.applications import FastA2A
|
|
10
|
+
from fasta2a.broker import InMemoryBroker
|
|
11
|
+
from fasta2a.schema import Part, TaskSendParams
|
|
12
|
+
from fasta2a.storage import InMemoryStorage
|
|
13
|
+
from pydantic_ai import Agent
|
|
14
|
+
from pydantic_ai._a2a import AgentWorker, worker_lifespan
|
|
15
|
+
from pydantic_ai.messages import (
|
|
16
|
+
AudioUrl,
|
|
17
|
+
BinaryContent,
|
|
18
|
+
DocumentUrl,
|
|
19
|
+
ImageUrl,
|
|
20
|
+
ModelRequestPart,
|
|
21
|
+
UserPromptPart,
|
|
22
|
+
VideoUrl,
|
|
23
|
+
)
|
|
24
|
+
from starlette.applications import Starlette
|
|
25
|
+
from starlette.exceptions import HTTPException
|
|
26
|
+
from starlette.requests import Request
|
|
27
|
+
from starlette.responses import RedirectResponse
|
|
28
|
+
|
|
29
|
+
from aixtools.context import session_id_var, user_id_var
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class AgentWorkerWithMetadataParser(AgentWorker):
    """Custom AgentWorker class that extracts the session metadata from message metadata.

    The ``user_id`` and ``session_id`` found in the metadata of the last message
    of the task history are exposed through ``user_id_var`` / ``session_id_var``
    while the task runs.
    """

    async def run_task(self, params: TaskSendParams) -> None:
        """
        Extract session metadata from message and store them in context variables,
        then call the parent class's run_task method.

        The context variables are restored to their previous values when the task
        finishes, so ids from one task cannot leak into a later task that runs in
        the same context but carries no metadata.

        Args:
            params: task parameters; only ``params["id"]`` is read here.
        """
        reset_tokens = []
        # Load the task to extract metadata
        task = await self.storage.load_task(params["id"])
        if task:
            # Last message of the history; `or [None]` covers a missing or empty history.
            if message := (task.get("history") or [None])[-1]:
                # `or {}` also covers a metadata key that is present but set to None.
                metadata = message.get("metadata") or {}
                # Store in context variables, remembering tokens so they can be reset.
                reset_tokens.append((user_id_var, user_id_var.set(metadata.get("user_id", ""))))
                reset_tokens.append((session_id_var, session_id_var.set(metadata.get("session_id", ""))))
        try:
            # Call the parent class's run_task method
            return await super().run_task(params)
        finally:
            # Undo the assignments in reverse order.
            for var, token in reversed(reset_tokens):
                var.reset(token)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class AgentWorkerWithDataPartSupport(AgentWorkerWithMetadataParser):
    """Agent worker that additionally accepts A2A ``data`` parts in incoming messages."""

    @staticmethod
    def _file_content_from_a2a(file_content):
        """Turn an A2A file payload into a ``BinaryContent`` or a typed URL object.

        Raises:
            ValueError: when the URI matches none of the supported URL content types.
        """
        if "bytes" in file_content:
            # NOTE(review): A2A sends `bytes` as a base64 string; this keeps the encoded
            # text instead of decoding it — confirm that is what downstream models expect.
            return BinaryContent(
                data=file_content["bytes"].encode("utf-8"),
                media_type=file_content.get("mime_type", "application/octet-stream"),
            )
        uri = file_content["uri"]
        # Pick the first URL type able to infer a media type from the URI.
        for candidate_cls in (DocumentUrl, AudioUrl, ImageUrl, VideoUrl):
            candidate = candidate_cls(url=uri)
            try:
                candidate.media_type  # raises ValueError when the type cannot be inferred
            except ValueError:  # pragma: no cover
                continue
            return candidate
        raise ValueError(f"Unsupported file type: {uri}")  # pragma: no cover

    def _request_parts_from_a2a(self, parts: list[Part]) -> list[ModelRequestPart]:
        """
        Convert A2A message parts into Pydantic AI request parts.

        Clones the underlying method (text and file parts) and adds ``data`` parts,
        which are passed to the model as a JSON string.
        TODO: remove once pydantic-ai supports data parts natively.
        """
        converted: list[ModelRequestPart] = []
        for part in parts:
            kind = part["kind"]
            if kind == "text":
                converted.append(UserPromptPart(content=part["text"]))
            elif kind == "file":
                converted.append(UserPromptPart(content=[self._file_content_from_a2a(part["file"])]))
            elif kind == "data":
                converted.append(UserPromptPart(content=[json.dumps(part["data"])]))
            else:
                assert_never(part)
        return converted
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def agent_to_a2a(
    agent: Agent, name: str, description: str, skills: list[dict], worker_class=AgentWorkerWithMetadataParser
) -> FastA2A:
    """Expose ``agent`` as a FastA2A application backed by in-memory storage and broker.

    Args:
        agent: the Pydantic AI agent that executes tasks.
        name: application name advertised by FastA2A.
        description: application description advertised by FastA2A.
        skills: skill descriptors advertised by FastA2A.
        worker_class: worker type instantiated to run tasks; the default extracts
            session metadata from incoming messages.

    Returns:
        The FastA2A app, whose lifespan runs the worker.
    """
    shared_storage = InMemoryStorage()
    shared_broker = InMemoryBroker()
    task_worker = worker_class(broker=shared_broker, storage=shared_storage, agent=agent)
    app_lifespan = partial(worker_lifespan, worker=task_worker, agent=agent)
    return FastA2A(
        storage=shared_storage,
        broker=shared_broker,
        name=name,
        description=description,
        skills=skills,
        url=None,
        lifespan=app_lifespan,
    )
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def fix_a2a_docs_pages(app: Starlette) -> None:
    """
    Make the FastA2A documentation page work when the A2A app is mounted under a sub-path.

    The docs page requests ``/.well-known/agent.json`` (GET) and ``/`` (POST) from
    the root. Both routes are registered on ``app`` and redirected to the sub-app
    whose ``/docs`` page made the request, as identified by the Referer header.
    """

    async def redirect_to_sub_agent(request: Request):
        """Redirect to the mount prefix taken from a Referer ending in ``/docs``; 404 otherwise."""
        referer = request.headers.get("referer", "")
        if not referer.endswith("/docs"):
            raise HTTPException(status_code=404)
        # Drop the trailing "/docs" segment to get the mount prefix.
        mount_prefix = referer.rsplit("/", 1)[0]
        # NOTE(review): the Referer header is client-controlled, so the redirect target
        # is not restricted to this host — confirm that is acceptable.
        return RedirectResponse(url=f"{mount_prefix}{request.url.path}")

    for route_path, http_method in (("/.well-known/agent.json", "GET"), ("/", "POST")):
        app.router.add_route(route_path, redirect_to_sub_agent, methods=[http_method])
|
aixtools/a2a/utils.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
"""Utilities for Agent-to-Agent (A2A) communication and task management."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import uuid
|
|
5
|
+
from typing import Callable
|
|
6
|
+
|
|
7
|
+
from fasta2a import Skill
|
|
8
|
+
from fasta2a.client import A2AClient
|
|
9
|
+
from fasta2a.schema import GetTaskResponse, Message, Part, TextPart
|
|
10
|
+
from fastapi import status
|
|
11
|
+
|
|
12
|
+
from ..server import get_session_id_tuple
|
|
13
|
+
|
|
14
|
+
# Delay between two consecutive task-status polls in `poll_task`, in seconds.
SLEEP_TIME = 0.2
# Maximum number of polls in `poll_task` before it gives up (roughly MAX_ITER * SLEEP_TIME seconds).
MAX_ITER = 1000
# NOTE(review): not referenced by the code shown here (fetch_agent_card uses status.HTTP_200_OK) — confirm it is still needed.
HTTP_OK = 200
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def card2description(card):
    """Render an agent card as text: a ``name: description`` header plus one tab-indented line per skill."""
    lines = [f"{card['name']}: {card['description']}\n"]
    lines.extend(f"\t - {skill['name']}: {skill['description']}\n" for skill in card.get("skills", []))
    return "".join(lines)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
async def fetch_agent_card(client: A2AClient) -> dict:
    """Download the agent card from ``<base_url>/.well-known/agent.json``.

    Returns:
        The decoded JSON body of the card.

    Raises:
        Exception: when the server does not answer with HTTP 200.
    """
    http = client.http_client
    base_url = str(http.base_url).rstrip("/")
    agent_card_url = f"{base_url}/.well-known/agent.json"
    response = await http.get(agent_card_url, timeout=10)
    if response.status_code != status.HTTP_200_OK:
        raise Exception(f"Failed to retrieve agent card from {agent_card_url}. Status code: {response.status_code}")  # pylint: disable=broad-exception-raised
    return response.json()
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def get_result_text(ret: GetTaskResponse) -> str | None:
    """Return the text of the first ``text`` part among the task's artifacts, or None if there is none."""
    if "result" not in ret or "artifacts" not in ret["result"]:
        return None
    # Lazily scan artifacts (skipping those without parts) and stop at the first text part.
    texts = (
        part["text"]
        for artifact in ret["result"]["artifacts"]
        if "parts" in artifact
        for part in artifact["parts"]
        if part["kind"] == "text"
    )
    return next(texts, None)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
async def poll_task(client: A2AClient, task_id: str) -> GetTaskResponse:
    """Poll the task status until it reaches a terminal state or the polling budget runs out.

    ``client.get_task`` is called up to ``MAX_ITER`` times, sleeping ``SLEEP_TIME``
    seconds after each unfinished poll.

    Returns:
        The ``get_task`` response in which the task state is ``completed``.

    Raises:
        Exception: ``"Task failed"`` for a failed task; ``"Task canceled"`` or
            ``"Task rejected"`` for the other unsuccessful terminal states; or a
            timeout message if the task has not completed after ``MAX_ITER`` polls.
    """
    for _ in range(MAX_ITER):
        ret = await client.get_task(task_id=task_id)
        # Check the state of the task (None while the response carries no status)
        state = ret["result"]["status"]["state"] if "result" in ret and "status" in ret["result"] else None
        if state == "completed":
            return ret
        if state == "failed":
            raise Exception("Task failed")  # pylint: disable=broad-exception-raised
        if state in ("canceled", "rejected"):
            # Terminal states: the task can never complete, so stop polling right away.
            raise Exception(f"Task {state}")  # pylint: disable=broad-exception-raised
        # Sleep for a while before checking again
        await asyncio.sleep(SLEEP_TIME)
    timeout_seconds = MAX_ITER * SLEEP_TIME
    raise Exception(f"Task did not complete in {timeout_seconds} seconds")  # pylint: disable=broad-exception-raised
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
async def submit_task(client: A2AClient, message: Message) -> str:
    """Send ``message`` to the A2A server and return the id of the created task ("" if none is returned).

    The outgoing copy of the message gets ``user_id`` / ``session_id`` metadata.
    The client's ``user-id`` / ``session-id`` HTTP headers take precedence; otherwise
    the values from ``get_session_id_tuple()`` are used. Existing metadata entries
    are kept, and the caller's message object is not modified.
    """
    default_user_id, default_session_id = get_session_id_tuple()
    headers = client.http_client.headers
    outgoing = {
        **message,
        "metadata": {
            **message.get("metadata", {}),
            "user_id": headers.get("user-id", default_user_id),
            "session_id": headers.get("session-id", default_session_id),
        },
    }
    response = await client.send_message(message=outgoing)
    if "result" in response and "id" in response["result"]:
        return response["result"]["id"]
    return ""
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def multipart_message(parts: list[Part]) -> Message:
    """Build a user ``Message`` holding ``parts`` under a freshly generated UUID message id."""
    new_id = str(uuid.uuid4())
    return Message(kind="message", role="user", parts=parts, message_id=new_id)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def text_message(text: str) -> Message:
    """Wrap ``text`` in a single text part and return it as a user message."""
    return multipart_message([TextPart(kind="text", text=text, metadata={})])
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
async def task(client: A2AClient, text: str) -> GetTaskResponse:
    """Submit ``text`` as a user message, print the task id, and block until the task completes."""
    task_id = await submit_task(client, text_message(text))
    print(f"Task ID: {task_id}")
    return await poll_task(client, task_id)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def tool2skill(tool: Callable) -> Skill:
    """Describe a tool function as an A2A skill: its ``__name__`` is both id and name, its docstring the description."""
    tool_name = tool.__name__
    return Skill(
        id=tool_name,
        name=tool_name,
        description=tool.__doc__ or "",
    )  # type: ignore
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""Agent utilities for running and managing AI agents."""
|
|
2
|
+
|
|
3
|
+
from .agent import get_agent, get_model, run_agent
|
|
4
|
+
from .agent_batch import AgentQueryParams, run_agent_batch
|
|
5
|
+
|
|
6
|
+
# Public names re-exported by this package (see the imports above).
__all__ = [
    "get_agent",
    "get_model",
    "run_agent",
    "AgentQueryParams",
    "run_agent_batch",
]
|
aixtools/agents/agent.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Core agent implementation providing model selection and configuration for AI agents.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from types import NoneType
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from openai import AsyncAzureOpenAI
|
|
9
|
+
from pydantic_ai import Agent
|
|
10
|
+
from pydantic_ai.models.bedrock import BedrockConverseModel
|
|
11
|
+
from pydantic_ai.models.openai import OpenAIModel
|
|
12
|
+
from pydantic_ai.providers.bedrock import BedrockProvider
|
|
13
|
+
from pydantic_ai.providers.openai import OpenAIProvider
|
|
14
|
+
from pydantic_ai.settings import ModelSettings
|
|
15
|
+
from pydantic_ai.usage import UsageLimits
|
|
16
|
+
|
|
17
|
+
from aixtools.logging.log_objects import ObjectLogger
|
|
18
|
+
from aixtools.logging.logging_config import get_logger
|
|
19
|
+
from aixtools.logging.model_patch_logging import model_patch_logging
|
|
20
|
+
from aixtools.utils.config import (
|
|
21
|
+
AWS_PROFILE,
|
|
22
|
+
AWS_REGION,
|
|
23
|
+
AZURE_MODEL_NAME,
|
|
24
|
+
AZURE_OPENAI_API_KEY,
|
|
25
|
+
AZURE_OPENAI_API_VERSION,
|
|
26
|
+
AZURE_OPENAI_ENDPOINT,
|
|
27
|
+
BEDROCK_MODEL_NAME,
|
|
28
|
+
MODEL_FAMILY,
|
|
29
|
+
MODEL_TIMEOUT,
|
|
30
|
+
OLLAMA_MODEL_NAME,
|
|
31
|
+
OLLAMA_URL,
|
|
32
|
+
OPENAI_API_KEY,
|
|
33
|
+
OPENAI_MODEL_NAME,
|
|
34
|
+
OPENROUTER_API_KEY,
|
|
35
|
+
OPENROUTER_API_URL,
|
|
36
|
+
OPENROUTER_MODEL_NAME,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
# Module-level logger obtained from the aixtools logging configuration.
logger = get_logger(__name__)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _get_model_bedrock(model_name=BEDROCK_MODEL_NAME, aws_region=AWS_REGION):
    """Build a Bedrock Converse model.

    When ``AWS_PROFILE`` is configured, the model is created without an explicit
    provider. Otherwise a ``BedrockProvider`` for ``aws_region`` is used.
    """
    assert model_name, "BEDROCK_MODEL_NAME is not set"
    assert aws_region, "AWS_REGION is not set"

    if AWS_PROFILE is None:
        return BedrockConverseModel(model_name=model_name, provider=BedrockProvider(region_name=aws_region))
    # NOTE(review): `aws_region` is validated but not passed on in this branch — confirm the
    # default provider resolves the intended region from the profile/environment.
    return BedrockConverseModel(model_name=model_name)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _get_model_ollama(model_name=OLLAMA_MODEL_NAME, ollama_url=OLLAMA_URL):
    """Build an ``OpenAIModel`` that talks to the Ollama server at ``ollama_url`` through an OpenAI provider."""
    assert ollama_url, "OLLAMA_URL is not set"
    assert model_name, "Model name is not set"
    return OpenAIModel(model_name=model_name, provider=OpenAIProvider(base_url=ollama_url))
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _get_model_openai(model_name=OPENAI_MODEL_NAME, openai_api_key=OPENAI_API_KEY):
    """Build an ``OpenAIModel`` authenticated with ``openai_api_key``."""
    assert openai_api_key, "OPENAI_API_KEY is not set"
    assert model_name, "Model name is not set"
    return OpenAIModel(model_name=model_name, provider=OpenAIProvider(api_key=openai_api_key))
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _get_model_openai_azure(
    model_name=AZURE_MODEL_NAME,
    azure_openai_api_key=AZURE_OPENAI_API_KEY,
    azure_openai_endpoint=AZURE_OPENAI_ENDPOINT,
    azure_openai_api_version=AZURE_OPENAI_API_VERSION,
):
    """Build an ``OpenAIModel`` backed by an ``AsyncAzureOpenAI`` client for the given endpoint/version/key."""
    # Validate the settings in a fixed order so the first missing one is reported.
    required = (
        (azure_openai_endpoint, "AZURE_OPENAI_ENDPOINT is not set"),
        (azure_openai_api_key, "AZURE_OPENAI_API_KEY is not set"),
        (azure_openai_api_version, "AZURE_OPENAI_API_VERSION is not set"),
        (model_name, "Model name is not set"),
    )
    for value, error_message in required:
        assert value, error_message
    azure_client = AsyncAzureOpenAI(
        api_key=azure_openai_api_key,
        api_version=azure_openai_api_version,
        azure_endpoint=azure_openai_endpoint,
    )
    azure_provider = OpenAIProvider(openai_client=azure_client)
    return OpenAIModel(model_name=model_name, provider=azure_provider)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _get_model_open_router(
    model_name=OPENROUTER_MODEL_NAME, openrouter_api_url=OPENROUTER_API_URL, openrouter_api_key=OPENROUTER_API_KEY
):
    """Build an ``OpenAIModel`` that sends requests to OpenRouter at ``openrouter_api_url``."""
    assert openrouter_api_url, "OPENROUTER_API_URL is not set"
    assert openrouter_api_key, "OPENROUTER_API_KEY is not set"
    assert model_name, "Model name is not set, missing 'OPENROUTER_MODEL_NAME' environment variable?"
    router_provider = OpenAIProvider(base_url=openrouter_api_url, api_key=openrouter_api_key)
    return OpenAIModel(model_name, provider=router_provider)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def get_model(model_family=MODEL_FAMILY, model_name=None, **kwargs):
|
|
94
|
+
"""Create and return appropriate model instance based on specified family and name."""
|
|
95
|
+
assert model_family is not None and model_family != "", f"Model family '{model_family}' is not set"
|
|
96
|
+
match model_family:
|
|
97
|
+
case "azure":
|
|
98
|
+
return _get_model_openai_azure(model_name=model_name or AZURE_MODEL_NAME, **kwargs)
|
|
99
|
+
case "bedrock":
|
|
100
|
+
return _get_model_bedrock(model_name=model_name or BEDROCK_MODEL_NAME, **kwargs)
|
|
101
|
+
case "ollama":
|
|
102
|
+
return _get_model_ollama(model_name=model_name or OLLAMA_MODEL_NAME, **kwargs)
|
|
103
|
+
case "openai":
|
|
104
|
+
return _get_model_openai(model_name=model_name or OPENAI_MODEL_NAME, **kwargs)
|
|
105
|
+
case "openrouter":
|
|
106
|
+
return _get_model_open_router(model_name=model_name or OPENROUTER_MODEL_NAME, **kwargs)
|
|
107
|
+
case _:
|
|
108
|
+
raise ValueError(f"Model family '{model_family}' not supported")
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def get_agent(  # noqa: PLR0913, pylint: disable=too-many-arguments,too-many-positional-arguments
    model=None,
    *,
    instructions=None,
    system_prompt=(),
    tools=(),
    toolsets=(),
    model_settings=None,
    output_type: Any = str,
    deps_type=NoneType,
) -> Agent:
    """Build an instrumented Pydantic AI ``Agent``.

    Defaults:
        ``model_settings`` becomes ``ModelSettings(timeout=MODEL_TIMEOUT)``.
        ``model`` comes from ``get_model()`` (the configured model family).
    All other arguments are forwarded to ``Agent`` unchanged.
    """
    settings = ModelSettings(timeout=MODEL_TIMEOUT) if model_settings is None else model_settings
    resolved_model = get_model() if model is None else model
    return Agent(
        model=resolved_model,
        output_type=output_type,
        instructions=instructions,
        system_prompt=system_prompt,
        deps_type=deps_type,
        model_settings=settings,
        tools=tools,
        toolsets=toolsets,
        instrument=True,
    )
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
async def run_agent(  # noqa: PLR0913, pylint: disable=too-many-arguments,too-many-positional-arguments
    agent: Agent,
    prompt: str | list[str],
    usage_limits: UsageLimits | None = None,
    verbose: bool = False,
    debug: bool = False,
    log_model_requests: bool = False,
    parent_logger: ObjectLogger | None = None,
):
    """Query the LLM and collect every node of the agent run.

    Args:
        agent: the agent to run.
        prompt: user prompt (a string or a list of strings).
        usage_limits: optional limits forwarded to ``agent.iter``.
        verbose, debug, parent_logger: forwarded to the per-run ``ObjectLogger``.
        log_model_requests: when True, ``agent.model`` is replaced with
            ``model_patch_logging(agent.model, logger)`` before the run starts.

    Returns:
        ``(output, nodes)``: the final output (None when the run produced no
        result) and the list of logged nodes.
    """
    nodes, result = [], None
    # Create a new log file for each run
    with ObjectLogger(parent_logger=parent_logger, verbose=verbose, debug=debug) as agent_logger:
        # Patch the model *before* starting the run. `agent.iter` may resolve the model
        # when it is entered, so patching afterwards might not affect this run.
        # NOTE(review): the patched model is stored on the agent, so repeated calls wrap it again.
        if log_model_requests:
            agent.model = model_patch_logging(agent.model, agent_logger)
        async with agent.iter(prompt, usage_limits=usage_limits) as agent_run:
            # Run the agent
            async for node in agent_run:
                agent_logger.log(node)
                nodes.append(node)
            result = agent_run.result
    return result.output if result else None, nodes
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Batch processing functionality for running multiple agent queries in parallel.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel
|
|
9
|
+
|
|
10
|
+
from aixtools.agents.agent import get_agent, run_agent
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class AgentQueryParams(BaseModel):
    """Parameters for configuring agent queries in batch processing.

    Either supply a ready-made ``agent``, or leave it None and one is built from
    ``model``, ``tools`` and ``output_type``.
    """

    # Pydantic v2 configuration (replaces the deprecated class-based `Config`).
    # Arbitrary types are needed because `agent` and `model` hold non-pydantic objects.
    model_config = {"arbitrary_types_allowed": True}

    id: str = ""  # Unique identifier for the query
    prompt: str | list[str]  # User prompt sent to the agent
    agent: Any = None  # Pre-built agent; when None one is created in `run`
    model: Any = None  # Model used when the agent is built here (None -> configured default)
    debug: bool = False  # Forwarded to `run_agent`
    output_type: Any = str  # Output type used when the agent is built here
    tools: list | None = []  # Tools used when the agent is built here (pydantic copies this default)

    async def run(self):
        """Query the LLM and return ``(output, nodes)`` as produced by ``run_agent``."""
        agent = self.agent
        if agent is None:
            # The prompt is sent once, as the user prompt. It is not duplicated as a system prompt.
            # `or ()` keeps `tools=None` from reaching Agent as a non-iterable.
            agent = get_agent(model=self.model, tools=self.tools or (), output_type=self.output_type)
        return await run_agent(agent=agent, prompt=self.prompt, debug=self.debug)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
async def run_agent_batch(query_parameters: list[AgentQueryParams], batch_size=10):
    """
    Run multiple queries simultaneously in batches of at most ``batch_size``
    and yield the results as each batch completes.

    Only the output of each query is yielded; the node lists are discarded.
    Results are yielded in the same order as ``query_parameters``. A
    ``batch_size`` of 0 or less effectively runs one query per batch.

    Usage example:
        query_parameters = [
            AgentQueryParams(prompt="What is the meaning of life"),
            AgentQueryParams(prompt="Who is the prime minister of Canada"),
        ]

        async for result in run_agent_batch(query_parameters):
            print(result)
    """
    pending = []
    batch_num, total = 1, len(query_parameters)
    for i, qp in enumerate(query_parameters):
        pending.append(qp.run())
        if len(pending) >= batch_size:
            # Run a batch of tasks
            print(f"Running batch {batch_num}, {i + 1} / {total}")
            # gather keeps input order; each result is a tuple (output, nodes)
            for output, _nodes in await asyncio.gather(*pending):
                yield output
            pending = []
            batch_num += 1
    # Run the last, partially filled batch (if any)
    if pending:
        print(f"Running final batch {batch_num}")
        for output, _nodes in await asyncio.gather(*pending):
            yield output
    print("Done")
|
aixtools/app.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
Simple Chainlit app example
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import traceback
|
|
8
|
+
|
|
9
|
+
import chainlit as cl
|
|
10
|
+
from pydantic_graph import End
|
|
11
|
+
|
|
12
|
+
from aixtools.agents.agent import get_agent
|
|
13
|
+
from aixtools.logging.logging_config import get_logger
|
|
14
|
+
from aixtools.utils.chainlit import cl_agent_show
|
|
15
|
+
|
|
16
|
+
logger = get_logger(__name__)

# Chainlit user-session key under which the chat history (a list of strings) is stored.
HISTORY = "history"

# System prompt of the agent created in `run_agent`.
SYSTEM_PROMPT = """
You are a helpful assistant.
"""
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@cl.step
async def greet_tool(msg: str) -> str:
    """Example tool that echoes ``msg`` back behind a greeting (shown as a Chainlit step)."""
    return "Hello! You said: {}".format(msg)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
async def parse_user_message(message):
    """Return the text to send to the agent, or None when the message is a ':' command.

    Commands are matched after stripping whitespace and lower-casing, and their
    output is sent straight to the chat:
    ``:clear`` resets the session history, ``:help`` lists the commands,
    ``:history`` shows the stored history, and any other ':' input is reported
    as unknown.
    """
    command = str(message.content).strip().lower()
    if not command.startswith(":"):
        # Regular chat message: hand the original (un-normalised) content to the agent
        user_message = message.content
        logger.debug("User message: %s", user_message)
        return user_message

    logger.debug("Received command: %s", command)
    if command == ":clear":
        cl.user_session.set(HISTORY, [])
    elif command == ":help":
        help_message = """
Available commands:
- :clear: Clear the chat history
- :help: Show this help message
"""
        await cl.Message(content=help_message).send()
    elif command == ":history":
        history = cl.user_session.get(HISTORY)
        if history:
            history_message = "\n".join(history)
            reply = f"Chat history:\n{history_message}"
        else:
            reply = "No history available."
        await cl.Message(content=reply).send()
    else:
        await cl.Message(content=f"Unknown command: {command}").send()
    return None
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
async def run_agent(messages):
    """Run the agent on ``messages`` and display the run in a new Chainlit message.

    Returns:
        The value returned by ``cl_agent_show.show_run``. If the run raised, it
        returns an ``"Internal server error: ..."`` string instead; the error and
        the stack trace are then attached to the message.
    """
    agent = get_agent(system_prompt=SYSTEM_PROMPT, tools=[greet_tool])
    ret = ""
    msg = cl.Message(content="")
    await msg.send()
    try:
        ret = await cl_agent_show.show_run(agent=agent, prompt=messages, msg=msg, debug=False)
    except Exception as e:  # pylint: disable=broad-exception-caught
        # UI boundary: report the failure in the chat instead of crashing the handler.
        msg.elements.append(cl.Text(name="Error", content=f"Error: {e}", type="error"))  # pylint: disable=unexpected-keyword-arg
        logger.error("Error: %s", e)
        # Log the full stack trace (once) for debugging and attach it to the message
        stack_trace = traceback.format_exc()
        logger.error("Stack trace:\n%s", stack_trace)
        msg.elements.append(cl.Text(name="Stack Trace", content=stack_trace, language="python"))
        ret = f"Internal server error: {e}"
    await msg.send()
    return ret
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def update_history(history, user_message=None, run_return=None):
    """Append the user message and/or the agent's final output to ``history`` and return it.

    Args:
        history: list of strings, mutated in place.
        user_message: text typed by the user; must be a ``str`` when given.
        run_return: result of the agent run. A list is treated as graph nodes
            whose last element is the ``End`` node carrying the final result; an
            empty list adds an empty string. Any other value is stored via ``str()``.

    Returns:
        The same ``history`` list.
    """
    assert user_message is not None or run_return is not None, "Either user message or run return must be provided"
    if user_message is not None:
        logger.debug("Updating history: Got user message type %s: %s", type(user_message), user_message)
        assert isinstance(user_message, str)
        history.append(user_message)
    if run_return is not None:
        logger.debug("Updating history: Got agent output type %s: %s", type(run_return), run_return)
        latest_item = ""
        if isinstance(run_return, list):
            if run_return:
                # If it is a list of 'node' items, the last element is the 'end_message' with the final result
                end_message: End = run_return[-1]
                final_result = end_message.data
                # Prefer `output` (current FinalResult field); fall back to the older `data` attribute.
                output = final_result.output if hasattr(final_result, "output") else final_result.data
                latest_item = str(output)
        else:
            latest_item = str(run_return)
        # Update history and store it
        logger.debug("Updating history: Adding to history type %s: %s", type(latest_item), latest_item)
        history.append(latest_item)
    return history
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@cl.set_starters
async def set_starters():
    """Provide the conversation starters shown on the Chainlit welcome screen."""
    hello_starter = cl.Starter(label="Message", message="Hello world!")
    return [hello_starter]
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
@cl.on_chat_start
async def on_chat_start():
    """Give every new chat session an empty history."""
    logger.debug("On chat start")
    empty_history = []
    cl.user_session.set(HISTORY, empty_history)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
@cl.on_message
async def on_message(message: cl.Message):
    """Process an incoming chat message.

    Commands are handled by ``parse_user_message``. Otherwise the message is
    appended to the history, the agent runs on the full history, and its output
    is appended and saved back to the session.
    """
    user_message = await parse_user_message(message)  # Parse user message
    # Check if user message is None (e.g. if it is a command)
    if user_message is None:
        return
    # `or []` guards against a session whose history was never initialised.
    history = cl.user_session.get(HISTORY) or []
    history = update_history(history, user_message=user_message)  # Update history with user message
    # Run the agent
    run_return = await run_agent(history)
    # Update history and store it
    history = update_history(history, run_return=run_return)
    cl.user_session.set(HISTORY, history)
|
aixtools/context.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module defines global context variables for request-specific information
|
|
3
|
+
that can be used for logging, tracing, and other purposes across applications
|
|
4
|
+
that use aixtools.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from contextvars import ContextVar
|
|
8
|
+
|
|
9
|
+
# Define context variables with default values.
# These can be populated by middleware or wherever a request is initialized.
# Both default to None when nothing has set them in the current context.
session_id_var: ContextVar[str | None] = ContextVar("session_id", default=None)
user_id_var: ContextVar[str | None] = ContextVar("user_id", default=None)
|
aixtools/db/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Database module for vector storage and retrieval.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from aixtools.db.database import DatabaseError, SqliteDb
|
|
6
|
+
from aixtools.db.vector_db import get_vdb_embedding, get_vector_db, vdb_add, vdb_get_by_id, vdb_has_id, vdb_query
|
|
7
|
+
|
|
8
|
+
# Public names re-exported by the `aixtools.db` package (see the imports above).
__all__ = [
    "DatabaseError",
    "SqliteDb",
    "get_vdb_embedding",
    "get_vector_db",
    "vdb_add",
    "vdb_get_by_id",
    "vdb_has_id",
    "vdb_query",
]
|