aixtools 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aixtools might be problematic.
- aixtools/.chainlit/config.toml +113 -0
- aixtools/.chainlit/translations/bn.json +214 -0
- aixtools/.chainlit/translations/en-US.json +214 -0
- aixtools/.chainlit/translations/gu.json +214 -0
- aixtools/.chainlit/translations/he-IL.json +214 -0
- aixtools/.chainlit/translations/hi.json +214 -0
- aixtools/.chainlit/translations/ja.json +214 -0
- aixtools/.chainlit/translations/kn.json +214 -0
- aixtools/.chainlit/translations/ml.json +214 -0
- aixtools/.chainlit/translations/mr.json +214 -0
- aixtools/.chainlit/translations/nl.json +214 -0
- aixtools/.chainlit/translations/ta.json +214 -0
- aixtools/.chainlit/translations/te.json +214 -0
- aixtools/.chainlit/translations/zh-CN.json +214 -0
- aixtools/__init__.py +11 -0
- aixtools/_version.py +34 -0
- aixtools/a2a/app.py +126 -0
- aixtools/a2a/google_sdk/__init__.py +0 -0
- aixtools/a2a/google_sdk/card.py +27 -0
- aixtools/a2a/google_sdk/pydantic_ai_adapter/agent_executor.py +199 -0
- aixtools/a2a/google_sdk/pydantic_ai_adapter/storage.py +26 -0
- aixtools/a2a/google_sdk/remote_agent_connection.py +88 -0
- aixtools/a2a/google_sdk/utils.py +59 -0
- aixtools/a2a/utils.py +115 -0
- aixtools/agents/__init__.py +12 -0
- aixtools/agents/agent.py +164 -0
- aixtools/agents/agent_batch.py +71 -0
- aixtools/agents/prompt.py +97 -0
- aixtools/app.py +143 -0
- aixtools/chainlit.md +14 -0
- aixtools/compliance/__init__.py +9 -0
- aixtools/compliance/private_data.py +138 -0
- aixtools/context.py +17 -0
- aixtools/db/__init__.py +17 -0
- aixtools/db/database.py +110 -0
- aixtools/db/vector_db.py +115 -0
- aixtools/google/client.py +25 -0
- aixtools/log_view/__init__.py +17 -0
- aixtools/log_view/app.py +195 -0
- aixtools/log_view/display.py +285 -0
- aixtools/log_view/export.py +51 -0
- aixtools/log_view/filters.py +41 -0
- aixtools/log_view/log_utils.py +26 -0
- aixtools/log_view/node_summary.py +229 -0
- aixtools/logfilters/__init__.py +7 -0
- aixtools/logfilters/context_filter.py +67 -0
- aixtools/logging/__init__.py +30 -0
- aixtools/logging/log_objects.py +227 -0
- aixtools/logging/logging_config.py +161 -0
- aixtools/logging/mcp_log_models.py +102 -0
- aixtools/logging/mcp_logger.py +172 -0
- aixtools/logging/model_patch_logging.py +87 -0
- aixtools/logging/open_telemetry.py +36 -0
- aixtools/mcp/__init__.py +9 -0
- aixtools/mcp/client.py +375 -0
- aixtools/mcp/example_client.py +30 -0
- aixtools/mcp/example_server.py +22 -0
- aixtools/mcp/fast_mcp_log.py +31 -0
- aixtools/mcp/faulty_mcp.py +319 -0
- aixtools/model_patch/model_patch.py +63 -0
- aixtools/server/__init__.py +29 -0
- aixtools/server/app_mounter.py +90 -0
- aixtools/server/path.py +72 -0
- aixtools/server/utils.py +70 -0
- aixtools/server/workspace_privacy.py +65 -0
- aixtools/testing/__init__.py +9 -0
- aixtools/testing/aix_test_model.py +149 -0
- aixtools/testing/mock_tool.py +66 -0
- aixtools/testing/model_patch_cache.py +279 -0
- aixtools/tools/doctor/__init__.py +3 -0
- aixtools/tools/doctor/tool_doctor.py +61 -0
- aixtools/tools/doctor/tool_recommendation.py +44 -0
- aixtools/utils/__init__.py +35 -0
- aixtools/utils/chainlit/cl_agent_show.py +82 -0
- aixtools/utils/chainlit/cl_utils.py +168 -0
- aixtools/utils/config.py +131 -0
- aixtools/utils/config_util.py +69 -0
- aixtools/utils/enum_with_description.py +37 -0
- aixtools/utils/files.py +17 -0
- aixtools/utils/persisted_dict.py +99 -0
- aixtools/utils/utils.py +167 -0
- aixtools/vault/__init__.py +7 -0
- aixtools/vault/vault.py +137 -0
- aixtools-0.0.0.dist-info/METADATA +669 -0
- aixtools-0.0.0.dist-info/RECORD +88 -0
- aixtools-0.0.0.dist-info/WHEEL +5 -0
- aixtools-0.0.0.dist-info/entry_points.txt +2 -0
- aixtools-0.0.0.dist-info/top_level.txt +1 -0
aixtools/agents/agent_batch.py
ADDED
@@ -0,0 +1,71 @@
"""
Batch processing functionality for running multiple agent queries in parallel.
"""

import asyncio
from typing import Any

from pydantic import BaseModel, ConfigDict

from aixtools.agents.agent import get_agent, run_agent


class AgentQueryParams(BaseModel):
    """Parameters for configuring agent queries in batch processing."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    id: str = ""  # Unique identifier for the query
    prompt: str | list[str]
    agent: Any = None
    model: Any = None
    debug: bool = False
    output_type: Any = str
    tools: list | None = []

    async def run(self):
        """Query the LLM"""
        agent = self.agent
        if agent is None:
            agent = get_agent(
                system_prompt=self.prompt, model=self.model, tools=self.tools, output_type=self.output_type
            )
        return await run_agent(agent=agent, prompt=self.prompt, debug=self.debug)


async def run_agent_batch(query_parameters: list[AgentQueryParams], batch_size=10):
    """
    Run multiple queries simultaneously in batches of at most batch_size
    and yield the results as they come in.

    Usage example:
        query_parameters = [
            AgentQueryParams(prompt="What is the meaning of life"),
            AgentQueryParams(prompt="Who is the prime minister of Canada"),
        ]

        async for result in run_agent_batch(query_parameters):
            print(result)
    """
    tasks = []
    batch_num, total = 1, len(query_parameters)
    for i, qp in enumerate(query_parameters):
        tasks.append(qp.run())
        if len(tasks) >= batch_size:
            # Run a batch of tasks
            print(f"Running batch {batch_num}, {i + 1} / {total}")
            tasks_results = await asyncio.gather(
                *tasks
            )  # Returns a list of results, each one is a tuple (result, nodes)
            # Yield the results
            for r, _ in tasks_results:
                yield r
            tasks = []
            batch_num += 1
    # Run the last batch of tasks
    if tasks:
        print(f"Running final batch {batch_num}")
        tasks_results = await asyncio.gather(*tasks)
        for r, _ in tasks_results:
            yield r
    print("Done")
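The docstring above shows the intended call pattern. As a minimal sketch (assuming the package is importable and get_agent has a working model backend configured; the prompts and ids are hypothetical), a driver script could look like this:

import asyncio

from aixtools.agents.agent_batch import AgentQueryParams, run_agent_batch


async def main() -> None:
    # Hypothetical queries; run_agent_batch yields results as each batch of at most batch_size completes.
    queries = [
        AgentQueryParams(id="q1", prompt="What is the meaning of life?"),
        AgentQueryParams(id="q2", prompt="Who is the prime minister of Canada?"),
    ]
    async for result in run_agent_batch(queries, batch_size=2):
        print(result)


if __name__ == "__main__":
    asyncio.run(main())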
aixtools/agents/prompt.py
ADDED
@@ -0,0 +1,97 @@
"""Prompt building utilities for Pydantic AI agent, including file handling and context management."""

import mimetypes
from pathlib import Path, PurePosixPath

from pydantic_ai import BinaryContent

from aixtools.context import SessionIdTuple
from aixtools.server import container_to_host_path
from aixtools.utils.files import is_text_content

CLAUDE_MAX_FILE_SIZE_IN_CONTEXT = 4 * 1024 * 1024  # Claude limit 4.5 MB for PDF files
CLAUDE_IMAGE_MAX_FILE_SIZE_IN_CONTEXT = (
    5 * 1024 * 1024
)  # Claude limit 5 MB for images, to avoid large image files in context


def should_be_included_into_context(
    file_content: BinaryContent | str | None,
    file_size: int,
    *,
    max_img_size_bytes: int = CLAUDE_IMAGE_MAX_FILE_SIZE_IN_CONTEXT,
    max_file_size_bytes: int = CLAUDE_MAX_FILE_SIZE_IN_CONTEXT,
) -> bool:
    """Decide whether a file content should be included into the model context based on its type and size."""
    if not isinstance(file_content, BinaryContent):
        return False

    if file_content.media_type.startswith("text/"):
        return False

    # Exclude archive files as they're not supported by OpenAI models
    archive_types = {
        "application/zip",
        "application/x-tar",
        "application/gzip",
        "application/x-gzip",
        "application/x-rar-compressed",
        "application/x-7z-compressed",
    }
    if file_content.media_type in archive_types:
        return False

    if file_content.is_image and file_size < max_img_size_bytes:
        return True

    return file_size < max_file_size_bytes


def file_to_binary_content(file_path: str | Path, mime_type: str = "") -> str | BinaryContent:
    """
    Read a file and return its content as either a UTF-8 string (for text files)
    or BinaryContent (for binary files).
    """
    with open(file_path, "rb") as f:
        data = f.read()

    if not mime_type:
        mime_type, _ = mimetypes.guess_type(file_path)
        mime_type = mime_type or "application/octet-stream"

    if is_text_content(data, mime_type):
        return data.decode("utf-8")

    return BinaryContent(data=data, media_type=mime_type)


def build_user_input(
    session_tuple: SessionIdTuple,
    user_text: str,
    file_paths: list[Path],
) -> str | list[str | BinaryContent]:
    """Build user input for the Pydantic AI agent, including file attachments if provided."""
    if not file_paths:
        return user_text

    attachment_info_lines = []
    binary_attachments = []

    for workspace_path in file_paths:
        host_path = container_to_host_path(PurePosixPath(workspace_path), ctx=session_tuple)
        file_size = host_path.stat().st_size
        mime_type, _ = mimetypes.guess_type(host_path)
        mime_type = mime_type or "application/octet-stream"

        attachment_info = f"* {workspace_path.name} (file_size={file_size} bytes) (path in workspace: {workspace_path})"
        binary_content = file_to_binary_content(host_path, mime_type)

        if should_be_included_into_context(binary_content, file_size):
            binary_attachments.append(binary_content)
            attachment_info += f" -- provided to model context at index {len(binary_attachments) - 1}"

        attachment_info_lines.append(attachment_info)

    full_prompt = user_text + "\nAttachments:\n" + "\n".join(attachment_info_lines)

    return [full_prompt] + binary_attachments
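A small usage sketch for the two helpers above (the file name is hypothetical; build_user_input itself additionally needs a workspace-to-host path mapping via container_to_host_path, so only the size/type gate is shown here):

from pathlib import Path

from aixtools.agents.prompt import file_to_binary_content, should_be_included_into_context

# Hypothetical local file: text files come back as str, everything else as BinaryContent.
path = Path("report.pdf")
content = file_to_binary_content(path)
size = path.stat().st_size

if should_be_included_into_context(content, size):
    print(f"{path.name} ({size} bytes) would be attached to the model context")
else:
    print(f"{path.name} is text, an archive, or over the size limit -- only its metadata would be listed")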
aixtools/app.py
ADDED
@@ -0,0 +1,143 @@
#!/usr/bin/env python3

"""
Simple Chainlit app example
"""

import traceback

import chainlit as cl
from pydantic_graph import End

from aixtools.agents.agent import get_agent
from aixtools.logging.logging_config import get_logger
from aixtools.utils.chainlit import cl_agent_show

logger = get_logger(__name__)

HISTORY = "history"

SYSTEM_PROMPT = """
You are a helpful assistant.
"""


@cl.step
async def greet_tool(msg: str) -> str:
    """A simple greeting tool"""
    return f"Hello! You said: {msg}"


async def parse_user_message(message):
    """Parse user message and check if it is a command"""
    # When we type something that starts with ':', we are using a "command" (i.e. it does not go to the agent)
    command = str(message.content).strip().lower()
    if command.startswith(":"):
        logger.debug("Received command: %s", command)
        match command:
            case ":clear":
                # Clear the history
                cl.user_session.set(HISTORY, [])
                return None
            case ":help":
                # Show help
                help_message = """
                Available commands:
                - :clear: Clear the chat history
                - :help: Show this help message
                """
                await cl.Message(content=help_message).send()
                return None
            case ":history":
                # Show history
                history = cl.user_session.get(HISTORY)
                if history:
                    history_message = "\n".join(history)
                    await cl.Message(content=f"Chat history:\n{history_message}").send()
                else:
                    await cl.Message(content="No history available.").send()
                return None
            case _:
                # Unknown command
                await cl.Message(content=f"Unknown command: {command}").send()
                return None
    else:
        user_message = message.content
        logger.debug("User message: %s", user_message)
        return user_message


async def run_agent(messages):
    """Run the agent with the given messages"""
    agent = get_agent(system_prompt=SYSTEM_PROMPT, tools=[greet_tool])
    ret = ""
    msg = cl.Message(content="")
    await msg.send()
    try:
        ret = await cl_agent_show.show_run(agent=agent, prompt=messages, msg=msg, debug=False)
    except Exception as e:  # pylint: disable=broad-exception-caught
        msg.elements.append(cl.Text(name="Error", content=f"Error: {e}", type="error"))  # pylint: disable=unexpected-keyword-arg
        logger.error("Error: %s", e)
        # Log the full stack trace for debugging
        stack_trace = traceback.format_exc()
        logger.error("Stack trace:\n%s", stack_trace)
        msg.elements.append(cl.Text(name="Stack Trace", content=stack_trace, language="python"))
        ret = f"Internal server error: {e}"
    await msg.send()
    return ret


def update_history(history, user_message=None, run_return=None):
    """Update history with user message and model run output"""
    assert user_message is not None or run_return is not None, "Either user message or run return must be provided"
    if user_message is not None:
        logger.debug("Updating history: Got user message type %s: %s", type(user_message), user_message)
        assert isinstance(user_message, str)
        history.append(user_message)
    if run_return is not None:
        logger.debug("Updating history: Got agent output type %s: %s", type(run_return), run_return)
        latest_item = ""
        if isinstance(run_return, list):
            # If it is a list of 'node' items, the last element is the 'end_message' with the final result
            end_message: End = run_return[-1]
            final_result = end_message.data
            latest_item = str(final_result.data)
        else:
            latest_item = str(run_return)
        # Update history and store it
        logger.debug("Updating history: Adding to history type %s: %s", type(latest_item), latest_item)
        history.append(latest_item)
    return history


@cl.set_starters
async def set_starters():
    """Set the starters"""
    return [
        cl.Starter(label="Message", message="Hello world!"),
    ]


@cl.on_chat_start
async def on_chat_start():
    """Initialize chat session by resetting history when a new chat starts."""
    # Reset history
    logger.debug("On chat start")
    cl.user_session.set(HISTORY, [])


@cl.on_message
async def on_message(message: cl.Message):
    """Process incoming chat messages and generate responses using the agent."""
    history = cl.user_session.get(HISTORY)  # Get user message and history
    user_message = await parse_user_message(message)  # Parse user message
    # Check if user message is None (e.g. if it is a command)
    if user_message is None:
        return
    messages = update_history(history, user_message=user_message)  # Update history with user message
    # Run the agent
    run_return = await run_agent(messages)
    # Update history and store it
    history = update_history(history, run_return=run_return)
    cl.user_session.set(HISTORY, messages)
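A minimal sketch of the history round-trip used in on_message above (pure Python; it assumes aixtools.app imports cleanly outside a running Chainlit server, which may not hold in every environment):

from aixtools.app import update_history

history: list[str] = []
history = update_history(history, user_message="Hello world!")       # appends the user turn
history = update_history(history, run_return="Hi! How can I help?")  # appends str(run_return)
print(history)  # ['Hello world!', 'Hi! How can I help?']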
aixtools/chainlit.md
ADDED
@@ -0,0 +1,14 @@
# Welcome to Chainlit! 🚀🤖

Hi there, Developer! 👋 We're excited to have you on board. Chainlit is a powerful tool designed to help you prototype, debug and share applications built on top of LLMs.

## Useful Links 🔗

- **Documentation:** Get started with our comprehensive [Chainlit Documentation](https://docs.chainlit.io) 📚
- **Discord Community:** Join our friendly [Chainlit Discord](https://discord.gg/k73SQ3FyUh) to ask questions, share your projects, and connect with other developers! 💬

We can't wait to see what you create with Chainlit! Happy coding! 💻😊

## Welcome screen

To modify the welcome screen, edit the `chainlit.md` file at the root of your project. If you do not want a welcome screen, just leave this file empty.
aixtools/compliance/private_data.py
ADDED
@@ -0,0 +1,138 @@
"""Private data management module for aixtools compliance."""

import json
from pathlib import Path

from fastmcp import Context

from aixtools.server.path import get_workspace_path

PRIVATE_DATA_FILE = ".private_data"


class PrivateData:
    """
    Class to manage private data file in the workspace.

    The information is stored in a JSON file named `.private_data` within the workspace directory.
    If the file does not exist, it indicates that there is no private data.

    IMPORTANT: All modifications save the data to the file immediately.

    FIXME: We should add some level of mutex/locking to prevent concurrent writes.
    """

    def __init__(self, ctx: Context | None = None):
        self.ctx: Context | None = ctx
        self._has_private_data: bool = False  # Flag indicating if private data exists
        self._private_datasets: list[str] = []  # List of private datasets
        self._idap_datasets: list[str] = []  # List of datasets with IDAP
        self.load()

    def add_private_dataset(self, dataset_name: str) -> None:
        """
        Add a private dataset to the list.
        Save the state after modification.
        """
        if dataset_name not in self._private_datasets:
            self._private_datasets.append(dataset_name)
        self._has_private_data = True
        self.save()

    def add_idap_dataset(self, dataset_name: str) -> None:
        """
        Add a dataset with IDAP to the list.
        This also adds it to the private datasets if not already present.
        Save the state after modification.
        """
        if not self.has_idap_dataset(dataset_name):
            self._idap_datasets.append(dataset_name)
        self._has_private_data = True
        # An IDAP dataset is also a private dataset
        if not self.has_private_dataset(dataset_name):
            self._private_datasets.append(dataset_name)
        self.save()

    def get_private_datasets(self) -> list[str]:
        """Get the list of private datasets as a copy (to avoid modification)."""
        return list(self._private_datasets)

    def get_idap_datasets(self) -> list[str]:
        """Get the list of datasets with IDAP as a copy (to avoid modification)."""
        return list(self._idap_datasets)

    def has_private_dataset(self, dataset_name: str) -> bool:
        """Check if a specific private dataset exists."""
        return dataset_name in self._private_datasets

    def has_idap_dataset(self, dataset_name: str) -> bool:
        """Check if a specific dataset with IDAP exists."""
        return dataset_name in self._idap_datasets

    @property
    def has_private_data(self) -> bool:
        """Check if private data exists."""
        return self._has_private_data

    @has_private_data.setter
    def has_private_data(self, value: bool) -> None:
        """
        Set the flag indicating if private data exists.
        Save the state after modification.
        """
        self._has_private_data = value
        if not value:
            self._private_datasets = []
            self._idap_datasets = []
        self.save()

    def _get_private_data_path(self) -> Path:
        """Get the path to the private data file in the workspace."""
        return get_workspace_path(service_name=None, ctx=self.ctx) / PRIVATE_DATA_FILE

    def _has_private_data_file(self) -> bool:
        """Check if the private data file exists in the workspace."""
        private_data_path = self._get_private_data_path()
        return private_data_path.exists()

    def save(self) -> None:
        """Save content to the private data file in the workspace."""
        private_data_path = self._get_private_data_path()
        # No private data? Delete the file if it exists
        if not self.has_private_data:
            private_data_path.unlink(missing_ok=True)
            return
        # If there is private data, serialize this object as JSON
        private_data_path.parent.mkdir(parents=True, exist_ok=True)
        with open(private_data_path, "w", encoding="utf-8") as f:
            # Dump class as JSON, excluding the context
            data_dict = self.__dict__.copy()
            data_dict["ctx"] = None
            json_data = json.dumps(data_dict, indent=4)
            f.write(json_data)

    def load(self) -> None:
        """Load content from the private data file in the workspace."""
        private_data_path = self._get_private_data_path()
        if not private_data_path.exists():
            # No private data file
            self.has_private_data = False
            self._private_datasets = []
            self._idap_datasets = []
            return
        with open(private_data_path, "r", encoding="utf-8") as f:
            data = json.load(f)
            self.has_private_data = data.get("_has_private_data", False)
            self._private_datasets = data.get("_private_datasets", [])
            self._idap_datasets = data.get("_idap_datasets", [])

    def __repr__(self) -> str:
        return (
            f"PrivateData(has_private_data={self.has_private_data}, "
            f"private_datasets={self._private_datasets}, "
            f"idap_datasets={self._idap_datasets}, "
            f"file_path={self._get_private_data_path()})"
        )

    def __str__(self) -> str:
        return self.__repr__()
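A hedged usage sketch of the lifecycle above (it assumes get_workspace_path can resolve a workspace directory when ctx is None, e.g. in a local development setup; the dataset name is hypothetical):

from aixtools.compliance.private_data import PrivateData

private_data = PrivateData(ctx=None)
private_data.add_idap_dataset("study_123")   # also registered as a private dataset, then saved to .private_data
print(private_data.has_private_data)         # True
print(private_data.get_private_datasets())   # ['study_123']
private_data.has_private_data = False        # clears both lists and deletes .private_data from the workspace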
aixtools/context.py
ADDED
@@ -0,0 +1,17 @@
"""
This module defines global context variables for request-specific information
that can be used for logging, tracing, and other purposes across applications
that use aixtools.
"""

from contextvars import ContextVar

# Define context variables with default values.
# These can be populated by middleware or where they are initialized
session_id_var: ContextVar[str | None] = ContextVar("session_id", default=None)
user_id_var: ContextVar[str | None] = ContextVar("user_id", default=None)

DEFAULT_USER_ID = "default_user"
DEFAULT_SESSION_ID = "default_session"

SessionIdTuple = tuple[str, str]
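These are plain contextvars from the standard library, so request-scoped code (for example, middleware) can set and later reset them; a small sketch with hypothetical identifiers:

from aixtools.context import session_id_var, user_id_var

# Set the variables at the start of a request ...
session_token = session_id_var.set("session-abc")
user_token = user_id_var.set("user-42")

# ... read them anywhere downstream, e.g. in a log filter ...
print(session_id_var.get(), user_id_var.get())

# ... and restore the previous values when the request ends.
session_id_var.reset(session_token)
user_id_var.reset(user_token)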
aixtools/db/__init__.py
ADDED
@@ -0,0 +1,17 @@
"""
Database module for vector storage and retrieval.
"""

from aixtools.db.database import DatabaseError, SqliteDb
from aixtools.db.vector_db import get_vdb_embedding, get_vector_db, vdb_add, vdb_get_by_id, vdb_has_id, vdb_query

__all__ = [
    "DatabaseError",
    "SqliteDb",
    "get_vdb_embedding",
    "get_vector_db",
    "vdb_add",
    "vdb_get_by_id",
    "vdb_has_id",
    "vdb_query",
]
aixtools/db/database.py
ADDED
@@ -0,0 +1,110 @@
"""
Database Interface for Clinical Trials Information.

This module provides a database interface for querying clinical trials data
from the SQLite database.
"""

import sqlite3
from contextlib import contextmanager
from pathlib import Path
from typing import Any

import pandas as pd

from aixtools.logging.logging_config import get_logger

logger = get_logger(__name__)


class DatabaseError(Exception):
    """Exception raised for database-related errors."""


class SqliteDb:
    """
    Database interface.
    """

    def __init__(self, db_path: str | Path):
        """Initialize the database interface"""
        self.db_path = Path(db_path)
        if not self.db_path.exists():
            raise FileNotFoundError(f"Database file not found: {self.db_path}")
        # Test connection
        with self.connection() as conn:
            logger.info("Connected to database: %s, connection: %s", self.db_path, conn)

    @contextmanager
    def connection(self):
        """
        Context manager for database connections.

        Yields:
            sqlite3.Connection: An active database connection
        """
        conn = None
        try:
            conn = sqlite3.connect(self.db_path)
            # Enable dictionary row factory
            conn.row_factory = sqlite3.Row
            yield conn
        except sqlite3.Error as e:
            raise DatabaseError(f"Database error: {e}") from e
        finally:
            if conn:
                conn.close()

    def query(self, query: str, params: dict[str, Any] | None = None) -> list[dict[str, Any]]:
        """
        Execute a SQL query and return the results as a list of dictionaries.

        Args:
            query: SQL query to execute
            params: Parameters for the query

        Returns:
            List of dictionaries representing the query results
        """
        with self.connection() as conn:
            cursor = conn.cursor()
            if params:
                cursor.execute(query, params)
            else:
                cursor.execute(query)

            results = cursor.fetchall()
            # Convert sqlite3.Row objects to dictionaries
            return [dict(row) for row in results]

    def query_df(self, query: str, params: dict[str, Any] | None = None) -> pd.DataFrame:
        """
        Execute a SQL query and return the results as a pandas DataFrame.

        Args:
            query: SQL query to execute.
            params: Parameters to substitute in the query.

        Returns:
            A pandas DataFrame containing the query results.
        """
        with self.connection() as conn:
            if params:
                df = pd.read_sql_query(query, conn, params=params)
            else:
                df = pd.read_sql_query(query, conn)
            return df

    def validate(self, query) -> str | None:
        """
        Validate the SQL query by executing an EXPLAIN QUERY PLAN statement.
        Returns the error string if there is an issue, otherwise returns None
        """
        with self.connection() as conn:
            try:
                cursor = conn.cursor()
                cursor.execute(f"EXPLAIN QUERY PLAN\n{query}")
                cursor.fetchall()
                return None
            except Exception as e:  # pylint: disable=broad-exception-caught
                return str(e)