aixtools 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of aixtools might be problematic. Click here for more details.

Files changed (58) hide show
  1. aixtools/__init__.py +5 -0
  2. aixtools/a2a/__init__.py +5 -0
  3. aixtools/a2a/app.py +126 -0
  4. aixtools/a2a/utils.py +115 -0
  5. aixtools/agents/__init__.py +12 -0
  6. aixtools/agents/agent.py +164 -0
  7. aixtools/agents/agent_batch.py +74 -0
  8. aixtools/app.py +143 -0
  9. aixtools/context.py +12 -0
  10. aixtools/db/__init__.py +17 -0
  11. aixtools/db/database.py +110 -0
  12. aixtools/db/vector_db.py +115 -0
  13. aixtools/log_view/__init__.py +17 -0
  14. aixtools/log_view/app.py +195 -0
  15. aixtools/log_view/display.py +285 -0
  16. aixtools/log_view/export.py +51 -0
  17. aixtools/log_view/filters.py +41 -0
  18. aixtools/log_view/log_utils.py +26 -0
  19. aixtools/log_view/node_summary.py +229 -0
  20. aixtools/logfilters/__init__.py +7 -0
  21. aixtools/logfilters/context_filter.py +67 -0
  22. aixtools/logging/__init__.py +30 -0
  23. aixtools/logging/log_objects.py +227 -0
  24. aixtools/logging/logging_config.py +116 -0
  25. aixtools/logging/mcp_log_models.py +102 -0
  26. aixtools/logging/mcp_logger.py +172 -0
  27. aixtools/logging/model_patch_logging.py +87 -0
  28. aixtools/logging/open_telemetry.py +36 -0
  29. aixtools/mcp/__init__.py +9 -0
  30. aixtools/mcp/example_client.py +30 -0
  31. aixtools/mcp/example_server.py +22 -0
  32. aixtools/mcp/fast_mcp_log.py +31 -0
  33. aixtools/mcp/faulty_mcp.py +320 -0
  34. aixtools/model_patch/model_patch.py +65 -0
  35. aixtools/server/__init__.py +23 -0
  36. aixtools/server/app_mounter.py +90 -0
  37. aixtools/server/path.py +72 -0
  38. aixtools/server/utils.py +70 -0
  39. aixtools/testing/__init__.py +9 -0
  40. aixtools/testing/aix_test_model.py +147 -0
  41. aixtools/testing/mock_tool.py +66 -0
  42. aixtools/testing/model_patch_cache.py +279 -0
  43. aixtools/tools/doctor/__init__.py +3 -0
  44. aixtools/tools/doctor/tool_doctor.py +61 -0
  45. aixtools/tools/doctor/tool_recommendation.py +44 -0
  46. aixtools/utils/__init__.py +35 -0
  47. aixtools/utils/chainlit/cl_agent_show.py +82 -0
  48. aixtools/utils/chainlit/cl_utils.py +168 -0
  49. aixtools/utils/config.py +118 -0
  50. aixtools/utils/config_util.py +69 -0
  51. aixtools/utils/enum_with_description.py +37 -0
  52. aixtools/utils/persisted_dict.py +99 -0
  53. aixtools/utils/utils.py +160 -0
  54. aixtools-0.1.0.dist-info/METADATA +355 -0
  55. aixtools-0.1.0.dist-info/RECORD +58 -0
  56. aixtools-0.1.0.dist-info/WHEEL +5 -0
  57. aixtools-0.1.0.dist-info/entry_points.txt +2 -0
  58. aixtools-0.1.0.dist-info/top_level.txt +1 -0
aixtools/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ """
2
+ AiXtools - Tools for AI exploration and debugging
3
+ """
4
+
5
+ __version__ = "0.1.0"
@@ -0,0 +1,5 @@
1
+ """A2A (Agent-to-Agent) communication utilities."""
2
+
3
+ from .utils import fetch_agent_card, task, text_message, tool2skill
4
+
5
+ __all__ = ["fetch_agent_card", "task", "text_message", "tool2skill"]
aixtools/a2a/app.py ADDED
@@ -0,0 +1,126 @@
1
+ """
2
+ This module provides functionality to convert a Pydantic AI Agent into a FastA2A application
3
+ """
4
+
5
+ import json
6
+ from functools import partial
7
+ from typing import assert_never
8
+
9
+ from fasta2a.applications import FastA2A
10
+ from fasta2a.broker import InMemoryBroker
11
+ from fasta2a.schema import Part, TaskSendParams
12
+ from fasta2a.storage import InMemoryStorage
13
+ from pydantic_ai import Agent
14
+ from pydantic_ai._a2a import AgentWorker, worker_lifespan
15
+ from pydantic_ai.messages import (
16
+ AudioUrl,
17
+ BinaryContent,
18
+ DocumentUrl,
19
+ ImageUrl,
20
+ ModelRequestPart,
21
+ UserPromptPart,
22
+ VideoUrl,
23
+ )
24
+ from starlette.applications import Starlette
25
+ from starlette.exceptions import HTTPException
26
+ from starlette.requests import Request
27
+ from starlette.responses import RedirectResponse
28
+
29
+ from aixtools.context import session_id_var, user_id_var
30
+
31
+
32
class AgentWorkerWithMetadataParser(AgentWorker):
    """AgentWorker that copies user/session ids from message metadata into context variables."""

    async def run_task(self, params: TaskSendParams) -> None:
        """
        Publish the session metadata of the task's latest message, then run the task.

        The task is loaded from storage by ``params["id"]``. When it has a non-empty
        history, the ``user_id`` and ``session_id`` entries of the last message's
        metadata are written to ``user_id_var`` / ``session_id_var`` (missing keys
        become an empty string). The parent ``run_task`` is always called afterwards.
        """
        stored_task = await self.storage.load_task(params["id"])
        history = stored_task.get("history") if stored_task else None
        latest_message = history[-1] if history else None
        if latest_message:
            meta = latest_message.get("metadata", {})
            user_id_var.set(meta.get("user_id", ""))
            session_id_var.set(meta.get("session_id", ""))
        return await super().run_task(params)
51
+
52
+
53
class AgentWorkerWithDataPartSupport(AgentWorkerWithMetadataParser):
    """Agent worker that additionally accepts ``data`` parts in incoming messages."""

    @staticmethod
    def _content_from_file(file_content):
        """Turn the ``file`` payload of an A2A part into binary or URL content."""
        if "bytes" in file_content:
            raw = file_content["bytes"].encode("utf-8")
            media_type = file_content.get("mime_type", "application/octet-stream")
            return BinaryContent(data=raw, media_type=media_type)

        url = file_content["uri"]
        # Use the first URL class whose ``media_type`` lookup does not raise ValueError.
        for url_cls in (DocumentUrl, AudioUrl, ImageUrl, VideoUrl):
            candidate = url_cls(url=url)
            try:
                candidate.media_type
            except ValueError:  # pragma: no cover
                continue
            return candidate
        raise ValueError(f"Unsupported file type: {url}")  # pragma: no cover

    def _request_parts_from_a2a(self, parts: list[Part]) -> list[ModelRequestPart]:
        """
        Convert A2A message parts into model request parts.

        Clones the underlying method and adds support for ``data`` parts, which are
        serialized to a JSON string.
        TODO: remove once pydantic-ai supports data parts natively.
        """
        converted: list[ModelRequestPart] = []
        for part in parts:
            kind = part["kind"]
            if kind == "text":
                converted.append(UserPromptPart(content=part["text"]))
            elif kind == "file":
                converted.append(UserPromptPart(content=[self._content_from_file(part["file"])]))
            elif kind == "data":
                converted.append(UserPromptPart(content=[json.dumps(part["data"])]))
            else:
                assert_never(part)
        return converted
91
+
92
+
93
def agent_to_a2a(
    agent: Agent, name: str, description: str, skills: list[dict], worker_class=AgentWorkerWithMetadataParser
) -> FastA2A:
    """
    Wrap ``agent`` in a FastA2A application.

    A fresh in-memory storage and broker are shared between the worker (an instance
    of ``worker_class``) and the returned application; the worker is started through
    the application's lifespan.
    """
    shared = {"storage": InMemoryStorage(), "broker": InMemoryBroker()}
    worker = worker_class(agent=agent, **shared)
    lifespan = partial(worker_lifespan, worker=worker, agent=agent)
    return FastA2A(
        name=name,
        description=description,
        skills=skills,
        url=None,
        lifespan=lifespan,
        **shared,
    )
109
+
110
+
111
def fix_a2a_docs_pages(app: Starlette) -> None:
    """
    Register root-level redirect routes so FastA2A docs pages work under a sub-path.

    Workaround for FastA2A docs not being served correctly when the app is mounted
    as a sub-path: requests that arrive at the root are redirected to the sub-app
    the docs page was loaded from.
    """

    async def redirect_to_sub_agent(request: Request):
        """Redirect to the sub-app whose ``/docs`` page is named in the Referer header."""
        referer = request.headers.get("referer", "")
        if not referer.endswith("/docs"):
            raise HTTPException(status_code=404)
        # Everything before the trailing "/docs" is the mount prefix of the sub-app.
        prefix, _, _ = referer.rpartition("/")
        return RedirectResponse(url=f"{prefix}{request.url.path}")

    for path, methods in (("/.well-known/agent.json", ["GET"]), ("/", ["POST"])):
        app.router.add_route(path, redirect_to_sub_agent, methods=methods)
aixtools/a2a/utils.py ADDED
@@ -0,0 +1,115 @@
1
+ """Utilities for Agent-to-Agent (A2A) communication and task management."""
2
+
3
+ import asyncio
4
+ import uuid
5
+ from typing import Callable
6
+
7
+ from fasta2a import Skill
8
+ from fasta2a.client import A2AClient
9
+ from fasta2a.schema import GetTaskResponse, Message, Part, TextPart
10
+ from fastapi import status
11
+
12
+ from ..server import get_session_id_tuple
13
+
14
+ SLEEP_TIME = 0.2
15
+ MAX_ITER = 1000
16
+ HTTP_OK = 200
17
+
18
+
19
def card2description(card):
    """
    Convert an agent card to a description string.

    The first line is ``"<name>: <description>"``; each skill follows on its own
    tab-indented bullet line. A missing or ``None`` ``skills`` entry yields only
    the first line. Every line, including the last, ends with a newline.
    """
    lines = [f"{card['name']}: {card['description']}\n"]
    lines.extend(f"\t - {skill['name']}: {skill['description']}\n" for skill in card.get("skills") or ())
    return "".join(lines)
26
+
27
+
28
async def fetch_agent_card(client: A2AClient) -> dict:
    """
    Fetch the agent card from ``<base_url>/.well-known/agent.json``.

    Returns the decoded JSON body on HTTP 200 and raises ``Exception`` for any
    other status code.
    """
    http = client.http_client
    base = str(http.base_url).rstrip("/")
    agent_card_url = f"{base}/.well-known/agent.json"
    response = await http.get(agent_card_url, timeout=10)
    if response.status_code != status.HTTP_200_OK:
        raise Exception(f"Failed to retrieve agent card from {agent_card_url}. Status code: {response.status_code}")  # pylint: disable=broad-exception-raised
    return response.json()
37
+
38
+
39
def get_result_text(ret: GetTaskResponse) -> str | None:
    """
    Extract the result text from the task result.

    Returns the ``text`` of the first part with ``kind == "text"`` found while
    scanning the artifacts in order, or ``None`` when there is no such part.
    Missing or ``None`` ``result`` / ``artifacts`` / ``parts`` entries and parts
    without a ``kind`` are skipped rather than raising.
    """
    artifacts = (ret.get("result") or {}).get("artifacts") or ()
    for artifact in artifacts:
        for part in artifact.get("parts") or ():
            if part.get("kind") == "text" and "text" in part:
                return part["text"]
    return None
55
+
56
+
57
async def poll_task(client: A2AClient, task_id: str) -> GetTaskResponse:
    """
    Poll the task status until it reaches a terminal state.

    The task is fetched up to ``MAX_ITER`` times, sleeping ``SLEEP_TIME`` seconds
    between attempts.

    Returns:
        The ``get_task`` response once the task state is ``"completed"``.

    Raises:
        Exception: when the task ends as ``"failed"``, ``"canceled"`` or
            ``"rejected"``, or when it does not complete within the polling budget.
    """
    # States that can never turn into "completed"; waiting on them only burns the timeout.
    terminal_failures = ("failed", "canceled", "rejected")
    for _ in range(MAX_ITER):
        ret = await client.get_task(task_id=task_id)
        # A response without result/status/state is treated as "not finished yet".
        state = ((ret.get("result") or {}).get("status") or {}).get("state")
        if state == "completed":
            return ret
        if state in terminal_failures:
            raise Exception(f"Task {state}")  # pylint: disable=broad-exception-raised
        await asyncio.sleep(SLEEP_TIME)
    timeout_seconds = MAX_ITER * SLEEP_TIME
    raise Exception(f"Task did not complete in {timeout_seconds} seconds")  # pylint: disable=broad-exception-raised
72
+
73
+
74
async def submit_task(client: A2AClient, message: Message) -> str:
    """
    Send ``message`` through ``client`` and return the id of the created task.

    A shallow copy of the message is sent with ``user_id`` / ``session_id`` added
    to its metadata. Values come from the client's ``user-id`` / ``session-id``
    HTTP headers, falling back to ``get_session_id_tuple()``. An empty string is
    returned when the response carries no task id.
    """
    user_id, session_id = get_session_id_tuple()
    outgoing = message.copy()
    existing_metadata = outgoing.get("metadata", {})
    headers = client.http_client.headers
    outgoing["metadata"] = {
        **existing_metadata,
        "user_id": headers.get("user-id", user_id),
        "session_id": headers.get("session-id", session_id),
    }
    response = await client.send_message(message=outgoing)
    if "result" in response and "id" in response["result"]:
        return response["result"]["id"]
    return ""
86
+
87
+
88
def multipart_message(parts: list[Part]) -> Message:
    """Build a user-role message carrying ``parts`` under a freshly generated UUID4 message id."""
    return Message(kind="message", role="user", parts=parts, message_id=str(uuid.uuid4()))
92
+
93
+
94
def text_message(text: str) -> Message:
    """Build a user message whose only part is a text part holding ``text``."""
    return multipart_message([TextPart(kind="text", text=text, metadata={})])
98
+
99
+
100
async def task(client: A2AClient, text: str) -> GetTaskResponse:
    """Submit ``text`` as a message, print the task id, and poll until the task finishes."""
    task_id = await submit_task(client, text_message(text))
    print(f"Task ID: {task_id}")
    return await poll_task(client, task_id)
107
+
108
+
109
def tool2skill(tool: Callable) -> Skill:
    """Describe ``tool`` as a skill: id and name are its ``__name__``, description its docstring (or "")."""
    tool_name = tool.__name__
    summary = tool.__doc__ or ""
    return Skill(id=tool_name, name=tool_name, description=summary)  # type: ignore
@@ -0,0 +1,12 @@
1
+ """Agent utilities for running and managing AI agents."""
2
+
3
+ from .agent import get_agent, get_model, run_agent
4
+ from .agent_batch import AgentQueryParams, run_agent_batch
5
+
6
+ __all__ = [
7
+ "get_agent",
8
+ "get_model",
9
+ "run_agent",
10
+ "AgentQueryParams",
11
+ "run_agent_batch",
12
+ ]
@@ -0,0 +1,164 @@
1
+ """
2
+ Core agent implementation providing model selection and configuration for AI agents.
3
+ """
4
+
5
+ from types import NoneType
6
+ from typing import Any
7
+
8
+ from openai import AsyncAzureOpenAI
9
+ from pydantic_ai import Agent
10
+ from pydantic_ai.models.bedrock import BedrockConverseModel
11
+ from pydantic_ai.models.openai import OpenAIModel
12
+ from pydantic_ai.providers.bedrock import BedrockProvider
13
+ from pydantic_ai.providers.openai import OpenAIProvider
14
+ from pydantic_ai.settings import ModelSettings
15
+ from pydantic_ai.usage import UsageLimits
16
+
17
+ from aixtools.logging.log_objects import ObjectLogger
18
+ from aixtools.logging.logging_config import get_logger
19
+ from aixtools.logging.model_patch_logging import model_patch_logging
20
+ from aixtools.utils.config import (
21
+ AWS_PROFILE,
22
+ AWS_REGION,
23
+ AZURE_MODEL_NAME,
24
+ AZURE_OPENAI_API_KEY,
25
+ AZURE_OPENAI_API_VERSION,
26
+ AZURE_OPENAI_ENDPOINT,
27
+ BEDROCK_MODEL_NAME,
28
+ MODEL_FAMILY,
29
+ MODEL_TIMEOUT,
30
+ OLLAMA_MODEL_NAME,
31
+ OLLAMA_URL,
32
+ OPENAI_API_KEY,
33
+ OPENAI_MODEL_NAME,
34
+ OPENROUTER_API_KEY,
35
+ OPENROUTER_API_URL,
36
+ OPENROUTER_MODEL_NAME,
37
+ )
38
+
39
+ logger = get_logger(__name__)
40
+
41
+
42
def _get_model_bedrock(model_name=BEDROCK_MODEL_NAME, aws_region=AWS_REGION):
    """
    Build a Bedrock Converse model.

    When ``AWS_PROFILE`` is set the model is created without an explicit provider;
    otherwise a ``BedrockProvider`` bound to ``aws_region`` is supplied.
    """
    assert model_name, "BEDROCK_MODEL_NAME is not set"
    assert aws_region, "AWS_REGION is not set"

    model_kwargs = {"model_name": model_name}
    if AWS_PROFILE is None:
        model_kwargs["provider"] = BedrockProvider(region_name=aws_region)
    return BedrockConverseModel(**model_kwargs)
51
+
52
+
53
def _get_model_ollama(model_name=OLLAMA_MODEL_NAME, ollama_url=OLLAMA_URL):
    """Build an OpenAI-compatible model that talks to the server at ``ollama_url``."""
    assert ollama_url, "OLLAMA_URL is not set"
    assert model_name, "Model name is not set"
    return OpenAIModel(model_name=model_name, provider=OpenAIProvider(base_url=ollama_url))
58
+
59
+
60
def _get_model_openai(model_name=OPENAI_MODEL_NAME, openai_api_key=OPENAI_API_KEY):
    """Build an OpenAI model authenticated with ``openai_api_key``."""
    assert openai_api_key, "OPENAI_API_KEY is not set"
    assert model_name, "Model name is not set"
    return OpenAIModel(model_name=model_name, provider=OpenAIProvider(api_key=openai_api_key))
65
+
66
+
67
def _get_model_openai_azure(
    model_name=AZURE_MODEL_NAME,
    azure_openai_api_key=AZURE_OPENAI_API_KEY,
    azure_openai_endpoint=AZURE_OPENAI_ENDPOINT,
    azure_openai_api_version=AZURE_OPENAI_API_VERSION,
):
    """Build an OpenAI model backed by an ``AsyncAzureOpenAI`` client."""
    # Checked in this order so the first missing setting decides the assertion message.
    required_settings = (
        (azure_openai_endpoint, "AZURE_OPENAI_ENDPOINT is not set"),
        (azure_openai_api_key, "AZURE_OPENAI_API_KEY is not set"),
        (azure_openai_api_version, "AZURE_OPENAI_API_VERSION is not set"),
        (model_name, "Model name is not set"),
    )
    for value, message in required_settings:
        assert value, message
    azure_client = AsyncAzureOpenAI(
        azure_endpoint=azure_openai_endpoint,
        api_version=azure_openai_api_version,
        api_key=azure_openai_api_key,
    )
    provider = OpenAIProvider(openai_client=azure_client)
    return OpenAIModel(model_name=model_name, provider=provider)
81
+
82
+
83
def _get_model_open_router(
    model_name=OPENROUTER_MODEL_NAME, openrouter_api_url=OPENROUTER_API_URL, openrouter_api_key=OPENROUTER_API_KEY
):
    """Build an OpenAI-compatible model that talks to the OpenRouter endpoint."""
    assert openrouter_api_url, "OPENROUTER_API_URL is not set"
    assert openrouter_api_key, "OPENROUTER_API_KEY is not set"
    assert model_name, "Model name is not set, missing 'OPENROUTER_MODEL_NAME' environment variable?"
    return OpenAIModel(
        model_name,
        provider=OpenAIProvider(base_url=openrouter_api_url, api_key=openrouter_api_key),
    )
91
+
92
+
93
def get_model(model_family=MODEL_FAMILY, model_name=None, **kwargs):
    """
    Create the model instance for ``model_family``.

    ``model_name`` falls back to the family's configured default name when falsy.
    Extra keyword arguments are forwarded to the family-specific factory.

    Raises:
        ValueError: if ``model_family`` is not one of the supported families.
    """
    assert model_family is not None and model_family != "", f"Model family '{model_family}' is not set"
    # (family, factory, default model name)
    registry = (
        ("azure", _get_model_openai_azure, AZURE_MODEL_NAME),
        ("bedrock", _get_model_bedrock, BEDROCK_MODEL_NAME),
        ("ollama", _get_model_ollama, OLLAMA_MODEL_NAME),
        ("openai", _get_model_openai, OPENAI_MODEL_NAME),
        ("openrouter", _get_model_open_router, OPENROUTER_MODEL_NAME),
    )
    for family, factory, default_name in registry:
        if model_family == family:
            return factory(model_name=model_name or default_name, **kwargs)
    raise ValueError(f"Model family '{model_family}' not supported")
109
+
110
+
111
def get_agent(  # noqa: PLR0913, pylint: disable=too-many-arguments,too-many-positional-arguments
    model=None,
    *,
    instructions=None,
    system_prompt=(),
    tools=(),
    toolsets=(),
    model_settings=None,
    output_type: Any = str,
    deps_type=NoneType,
) -> Agent:
    """
    Build a PydanticAI agent with instrumentation enabled.

    ``model`` defaults to ``get_model()`` and ``model_settings`` to settings that
    carry only the configured ``MODEL_TIMEOUT``; every other argument is passed to
    ``Agent`` unchanged.
    """
    effective_settings = ModelSettings(timeout=MODEL_TIMEOUT) if model_settings is None else model_settings
    effective_model = get_model() if model is None else model
    return Agent(
        model=effective_model,
        output_type=output_type,
        instructions=instructions,
        system_prompt=system_prompt,
        deps_type=deps_type,
        model_settings=effective_settings,
        tools=tools,
        toolsets=toolsets,
        instrument=True,
    )
139
+
140
+
141
async def run_agent(  # noqa: PLR0913, pylint: disable=too-many-arguments,too-many-positional-arguments
    agent: Agent,
    prompt: str | list[str],
    usage_limits: UsageLimits | None = None,
    verbose: bool = False,
    debug: bool = False,
    log_model_requests: bool = False,
    parent_logger: ObjectLogger | None = None,
):
    """
    Run ``agent`` on ``prompt`` and log every node it produces.

    Returns:
        A tuple ``(output, nodes)``: the run result's ``output`` (``None`` when the
        run produced no result) and the list of nodes in iteration order.
    """
    collected_nodes = []
    final_result = None
    async with agent.iter(prompt, usage_limits=usage_limits) as agent_run:
        # One ObjectLogger per run.
        with ObjectLogger(parent_logger=parent_logger, verbose=verbose, debug=debug) as agent_logger:
            if log_model_requests:
                # Replaces the agent's model with the logging-patched one.
                agent.model = model_patch_logging(agent.model, agent_logger)
            async for node in agent_run:
                agent_logger.log(node)
                collected_nodes.append(node)
            final_result = agent_run.result
    output = final_result.output if final_result else None
    return output, collected_nodes
@@ -0,0 +1,74 @@
1
+ """
2
+ Batch processing functionality for running multiple agent queries in parallel.
3
+ """
4
+
5
+ import asyncio
6
+ from typing import Any
7
+
8
+ from pydantic import BaseModel
9
+
10
+ from aixtools.agents.agent import get_agent, run_agent
11
+
12
+
13
class AgentQueryParams(BaseModel):
    """Parameters for configuring agent queries in batch processing."""

    class Config:  # pylint: disable=too-few-public-methods
        """Configuration for the model."""

        arbitrary_types_allowed = True

    id: str = ""  # Unique identifier for the query
    prompt: str | list[str]  # Prompt passed to the agent run
    agent: Any = None  # Agent to run; when None, run() builds one with get_agent()
    model: Any = None  # Model handed to get_agent() when no agent is given
    debug: bool = False  # Forwarded to run_agent()
    output_type: Any = str  # Output type of the agent built by run()
    tools: list | None = []  # Tools for the agent built by run(); None means no tools

    async def run(self):
        """
        Query the LLM.

        Uses ``self.agent`` when provided, otherwise builds an agent from the other
        fields. Returns whatever ``run_agent`` returns, i.e. ``(output, nodes)``.
        """
        agent = self.agent
        if agent is None:
            agent = get_agent(
                # NOTE(review): the prompt is also used as the system prompt here; confirm this is intended.
                system_prompt=self.prompt,
                model=self.model,
                # The field allows None, but get_agent() expects a sequence of tools.
                tools=self.tools or (),
                output_type=self.output_type,
            )
        return await run_agent(agent=agent, prompt=self.prompt, debug=self.debug)
37
+
38
+
39
async def run_agent_batch(query_parameters: list[AgentQueryParams], batch_size=10):
    """
    Run multiple queries simultaneously in batches of at most batch_size
    and yield the results batch by batch.

    Each batch is awaited as a whole; its results are then yielded in input order.
    Only the first element of each ``(result, nodes)`` tuple returned by
    ``AgentQueryParams.run`` is yielded. A ``batch_size`` below 1 runs one query
    per batch.

    Usage example:
        query_parameters = [
            AgentQueryParams(prompt="What is the meaning of life"),
            AgentQueryParams(prompt="Who is the prime minister of Canada"),
        ]

        async for result in run_agent_batch(query_parameters):
            print(result)
    """
    total = len(query_parameters)
    step = max(batch_size, 1)
    for batch_num, start in enumerate(range(0, total, step), start=1):
        batch = query_parameters[start : start + step]
        if len(batch) >= batch_size:
            print(f"Running batch {batch_num}, {start + len(batch)} / {total}")
        else:
            # Only the trailing, partially filled batch takes this branch.
            print(f"Running final batch {batch_num}")
        # Coroutines are created per batch, so later batches do not start early.
        batch_results = await asyncio.gather(*(qp.run() for qp in batch))
        for result, _ in batch_results:
            yield result
    print("Done")
aixtools/app.py ADDED
@@ -0,0 +1,143 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+ Simple Chainlit app example
5
+ """
6
+
7
+ import traceback
8
+
9
+ import chainlit as cl
10
+ from pydantic_graph import End
11
+
12
+ from aixtools.agents.agent import get_agent
13
+ from aixtools.logging.logging_config import get_logger
14
+ from aixtools.utils.chainlit import cl_agent_show
15
+
16
+ logger = get_logger(__name__)
17
+
18
+ HISTORY = "history"
19
+
20
+ SYSTEM_PROMPT = """
21
+ You are a helpful assistant.
22
+ """
23
+
24
+
25
@cl.step
async def greet_tool(msg: str) -> str:
    """Echo ``msg`` back inside a fixed greeting."""
    greeting = f"Hello! You said: {msg}"
    return greeting
29
+
30
+
31
+ async def parse_user_message(message):
32
+ """Parse user message and check if it is a command"""
33
+ # When we type something that starts with ':', we are using a "command" (i.e. it does not go to the agent)
34
+ command = str(message.content).strip().lower()
35
+ if command.startswith(":"):
36
+ logger.debug("Received command: %s", command)
37
+ match command:
38
+ case ":clear":
39
+ # Clear the history
40
+ cl.user_session.set(HISTORY, [])
41
+ return None
42
+ case ":help":
43
+ # Show help
44
+ help_message = """
45
+ Available commands:
46
+ - :clear: Clear the chat history
47
+ - :help: Show this help message
48
+ """
49
+ await cl.Message(content=help_message).send()
50
+ return None
51
+ case ":history":
52
+ # Show history
53
+ history = cl.user_session.get(HISTORY)
54
+ if history:
55
+ history_message = "\n".join(history)
56
+ await cl.Message(content=f"Chat history:\n{history_message}").send()
57
+ else:
58
+ await cl.Message(content="No history available.").send()
59
+ return None
60
+ case _:
61
+ # Unknown command
62
+ await cl.Message(content=f"Unknown command: {command}").send()
63
+ return None
64
+ else:
65
+ user_message = message.content
66
+ logger.debug("User message: %s", user_message)
67
+ return user_message
68
+
69
+
70
async def run_agent(messages):
    """
    Run the agent with the given messages.

    An empty chat message is sent first and handed to ``cl_agent_show.show_run``.
    If the run raises, the error and its stack trace are logged and attached to the
    message, and an "Internal server error" string is returned instead of the run
    result.
    """
    agent = get_agent(system_prompt=SYSTEM_PROMPT, tools=[greet_tool])
    ret = ""
    msg = cl.Message(content="")
    await msg.send()
    try:
        ret = await cl_agent_show.show_run(agent=agent, prompt=messages, msg=msg, debug=False)
    except Exception as e:  # pylint: disable=broad-exception-caught
        msg.elements.append(cl.Text(name="Error", content=f"Error: {e}", type="error"))  # pylint: disable=unexpected-keyword-arg
        logger.error("Error: %s", e)
        # Log the full stack trace once for debugging.
        stack_trace = traceback.format_exc()
        logger.error("Stack trace:\n%s", stack_trace)
        msg.elements.append(cl.Text(name="Stack Trace", content=stack_trace, language="python"))
        ret = f"Internal server error: {e}"
    await msg.send()
    return ret
89
+
90
+
91
def update_history(history, user_message=None, run_return=None):
    """
    Append the user message and/or the agent output to ``history`` in place.

    At least one of ``user_message`` / ``run_return`` must be given. A list
    ``run_return`` is treated as a list of nodes whose last element is the end node
    holding the final result; any other value is stored as ``str(run_return)``.
    Returns the same ``history`` list.
    """
    assert user_message is not None or run_return is not None, "Either user message or run return must be provided"
    if user_message is not None:
        logger.debug("Updating history: Got user message type %s: %s", type(user_message), user_message)
        assert isinstance(user_message, str)
        history.append(user_message)
    if run_return is None:
        return history

    logger.debug("Updating history: Got agent output type %s: %s", type(run_return), run_return)
    if isinstance(run_return, list):
        end_message: End = run_return[-1]
        latest_item = str(end_message.data.data)
    else:
        latest_item = str(run_return)
    logger.debug("Updating history: Adding to history type %s: %s", type(latest_item), latest_item)
    history.append(latest_item)
    return history
112
+
113
+
114
@cl.set_starters
async def set_starters():
    """Return the single starter offered for a new chat."""
    hello_starter = cl.Starter(label="Message", message="Hello world!")
    return [hello_starter]
120
+
121
+
122
@cl.on_chat_start
async def on_chat_start():
    """Give every new chat an empty history list in the user session."""
    logger.debug("On chat start")
    fresh_history = []
    cl.user_session.set(HISTORY, fresh_history)
128
+
129
+
130
@cl.on_message
async def on_message(message: cl.Message):
    """
    Handle one incoming chat message.

    Commands are consumed by ``parse_user_message`` and end the handler early. Any
    other message is appended to the session history, the agent is run on that
    history, and its output is appended before the history is stored again.
    """
    history = cl.user_session.get(HISTORY)
    user_message = await parse_user_message(message)
    if user_message is None:
        # The input was a command; nothing goes to the agent.
        return
    messages = update_history(history, user_message=user_message)
    run_return = await run_agent(messages)
    # update_history mutates the list in place, so ``messages`` already holds the result.
    update_history(history, run_return=run_return)
    cl.user_session.set(HISTORY, messages)
aixtools/context.py ADDED
@@ -0,0 +1,12 @@
1
+ """
2
+ This module defines global context variables for request-specific information
3
+ that can be used for logging, tracing, and other purposes across applications
4
+ that use aixtools.
5
+ """
6
+
7
+ from contextvars import ContextVar
8
+
9
+ # Define context variables with default values.
10
+ # These can be populated by middleware or where they are initialized
11
+ session_id_var: ContextVar[str | None] = ContextVar("session_id", default=None)
12
+ user_id_var: ContextVar[str | None] = ContextVar("user_id", default=None)
@@ -0,0 +1,17 @@
1
+ """
2
+ Database module for vector storage and retrieval.
3
+ """
4
+
5
+ from aixtools.db.database import DatabaseError, SqliteDb
6
+ from aixtools.db.vector_db import get_vdb_embedding, get_vector_db, vdb_add, vdb_get_by_id, vdb_has_id, vdb_query
7
+
8
+ __all__ = [
9
+ "DatabaseError",
10
+ "SqliteDb",
11
+ "get_vdb_embedding",
12
+ "get_vector_db",
13
+ "vdb_add",
14
+ "vdb_get_by_id",
15
+ "vdb_has_id",
16
+ "vdb_query",
17
+ ]