aixtools 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of aixtools might be problematic. Click here for more details.

Files changed (88)
  1. aixtools/.chainlit/config.toml +113 -0
  2. aixtools/.chainlit/translations/bn.json +214 -0
  3. aixtools/.chainlit/translations/en-US.json +214 -0
  4. aixtools/.chainlit/translations/gu.json +214 -0
  5. aixtools/.chainlit/translations/he-IL.json +214 -0
  6. aixtools/.chainlit/translations/hi.json +214 -0
  7. aixtools/.chainlit/translations/ja.json +214 -0
  8. aixtools/.chainlit/translations/kn.json +214 -0
  9. aixtools/.chainlit/translations/ml.json +214 -0
  10. aixtools/.chainlit/translations/mr.json +214 -0
  11. aixtools/.chainlit/translations/nl.json +214 -0
  12. aixtools/.chainlit/translations/ta.json +214 -0
  13. aixtools/.chainlit/translations/te.json +214 -0
  14. aixtools/.chainlit/translations/zh-CN.json +214 -0
  15. aixtools/__init__.py +11 -0
  16. aixtools/_version.py +34 -0
  17. aixtools/a2a/app.py +126 -0
  18. aixtools/a2a/google_sdk/__init__.py +0 -0
  19. aixtools/a2a/google_sdk/card.py +27 -0
  20. aixtools/a2a/google_sdk/pydantic_ai_adapter/agent_executor.py +199 -0
  21. aixtools/a2a/google_sdk/pydantic_ai_adapter/storage.py +26 -0
  22. aixtools/a2a/google_sdk/remote_agent_connection.py +88 -0
  23. aixtools/a2a/google_sdk/utils.py +59 -0
  24. aixtools/a2a/utils.py +115 -0
  25. aixtools/agents/__init__.py +12 -0
  26. aixtools/agents/agent.py +164 -0
  27. aixtools/agents/agent_batch.py +71 -0
  28. aixtools/agents/prompt.py +97 -0
  29. aixtools/app.py +143 -0
  30. aixtools/chainlit.md +14 -0
  31. aixtools/compliance/__init__.py +9 -0
  32. aixtools/compliance/private_data.py +138 -0
  33. aixtools/context.py +17 -0
  34. aixtools/db/__init__.py +17 -0
  35. aixtools/db/database.py +110 -0
  36. aixtools/db/vector_db.py +115 -0
  37. aixtools/google/client.py +25 -0
  38. aixtools/log_view/__init__.py +17 -0
  39. aixtools/log_view/app.py +195 -0
  40. aixtools/log_view/display.py +285 -0
  41. aixtools/log_view/export.py +51 -0
  42. aixtools/log_view/filters.py +41 -0
  43. aixtools/log_view/log_utils.py +26 -0
  44. aixtools/log_view/node_summary.py +229 -0
  45. aixtools/logfilters/__init__.py +7 -0
  46. aixtools/logfilters/context_filter.py +67 -0
  47. aixtools/logging/__init__.py +30 -0
  48. aixtools/logging/log_objects.py +227 -0
  49. aixtools/logging/logging_config.py +161 -0
  50. aixtools/logging/mcp_log_models.py +102 -0
  51. aixtools/logging/mcp_logger.py +172 -0
  52. aixtools/logging/model_patch_logging.py +87 -0
  53. aixtools/logging/open_telemetry.py +36 -0
  54. aixtools/mcp/__init__.py +9 -0
  55. aixtools/mcp/client.py +375 -0
  56. aixtools/mcp/example_client.py +30 -0
  57. aixtools/mcp/example_server.py +22 -0
  58. aixtools/mcp/fast_mcp_log.py +31 -0
  59. aixtools/mcp/faulty_mcp.py +319 -0
  60. aixtools/model_patch/model_patch.py +63 -0
  61. aixtools/server/__init__.py +29 -0
  62. aixtools/server/app_mounter.py +90 -0
  63. aixtools/server/path.py +72 -0
  64. aixtools/server/utils.py +70 -0
  65. aixtools/server/workspace_privacy.py +65 -0
  66. aixtools/testing/__init__.py +9 -0
  67. aixtools/testing/aix_test_model.py +149 -0
  68. aixtools/testing/mock_tool.py +66 -0
  69. aixtools/testing/model_patch_cache.py +279 -0
  70. aixtools/tools/doctor/__init__.py +3 -0
  71. aixtools/tools/doctor/tool_doctor.py +61 -0
  72. aixtools/tools/doctor/tool_recommendation.py +44 -0
  73. aixtools/utils/__init__.py +35 -0
  74. aixtools/utils/chainlit/cl_agent_show.py +82 -0
  75. aixtools/utils/chainlit/cl_utils.py +168 -0
  76. aixtools/utils/config.py +131 -0
  77. aixtools/utils/config_util.py +69 -0
  78. aixtools/utils/enum_with_description.py +37 -0
  79. aixtools/utils/files.py +17 -0
  80. aixtools/utils/persisted_dict.py +99 -0
  81. aixtools/utils/utils.py +167 -0
  82. aixtools/vault/__init__.py +7 -0
  83. aixtools/vault/vault.py +137 -0
  84. aixtools-0.0.0.dist-info/METADATA +669 -0
  85. aixtools-0.0.0.dist-info/RECORD +88 -0
  86. aixtools-0.0.0.dist-info/WHEEL +5 -0
  87. aixtools-0.0.0.dist-info/entry_points.txt +2 -0
  88. aixtools-0.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,71 @@
1
+ """
2
+ Batch processing functionality for running multiple agent queries in parallel.
3
+ """
4
+
5
+ import asyncio
6
+ from typing import Any
7
+
8
+ from pydantic import BaseModel, ConfigDict
9
+
10
+ from aixtools.agents.agent import get_agent, run_agent
11
+
12
+
13
class AgentQueryParams(BaseModel):
    """Settings for one agent query inside a batch run.

    Provide a ready-made ``agent``, or leave it as ``None`` to have one built on
    demand from ``prompt``, ``model``, ``tools`` and ``output_type``.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    id: str = ""  # Unique identifier for the query
    prompt: str | list[str]
    agent: Any = None
    model: Any = None
    debug: bool = False
    output_type: Any = str
    tools: list | None = []

    async def run(self):
        """Query the LLM and return whatever ``run_agent`` returns."""
        # Only build an agent when the caller did not supply one.
        # NOTE(review): the prompt is also used as the system prompt here — confirm intent.
        chosen_agent = (
            self.agent
            if self.agent is not None
            else get_agent(
                system_prompt=self.prompt,
                model=self.model,
                tools=self.tools,
                output_type=self.output_type,
            )
        )
        return await run_agent(agent=chosen_agent, prompt=self.prompt, debug=self.debug)
34
+
35
+
36
async def run_agent_batch(query_parameters: list[AgentQueryParams], batch_size=10):
    """
    Run multiple queries simultaneously in batches of at most ``batch_size``
    and yield the results as each batch completes.

    Each ``qp.run()`` is expected to return a ``(result, nodes)`` tuple; only
    ``result`` is yielded. Results are yielded in the same order as
    ``query_parameters`` (``asyncio.gather`` preserves input order).

    Args:
        query_parameters: The queries to run.
        batch_size: Maximum number of queries awaited concurrently. A batch is
            flushed as soon as it holds at least ``batch_size`` queries, so
            values below 1 behave like 1.

    Usage example:
        query_parameters = [
            AgentQueryParams(prompt="What is the meaning of life"),
            AgentQueryParams(prompt="Who is the prime minister of Canada"),
        ]

        async for result in run_agent_batch(query_parameters):
            print(result)
    """
    pending = []  # Coroutines of the batch currently being filled
    batch_num, total = 1, len(query_parameters)
    for i, qp in enumerate(query_parameters):
        pending.append(qp.run())
        if len(pending) >= batch_size:
            # Run a full batch of tasks
            print(f"Running batch {batch_num}, {i + 1} / {total}")
            # Each gathered item is a (result, nodes) tuple
            for result, _ in await asyncio.gather(*pending):
                yield result
            pending = []
            batch_num += 1
    # Run the last (partial) batch of tasks, if any
    if pending:
        print(f"Running final batch {batch_num}")
        for result, _ in await asyncio.gather(*pending):
            yield result
    print("Done")
@@ -0,0 +1,97 @@
1
+ """Prompt building utilities for Pydantic AI agent, including file handling and context management."""
2
+
3
+ import mimetypes
4
+ from pathlib import Path, PurePosixPath
5
+
6
+ from pydantic_ai import BinaryContent
7
+
8
+ from aixtools.context import SessionIdTuple
9
+ from aixtools.server import container_to_host_path
10
+ from aixtools.utils.files import is_text_content
11
+
12
# Byte-size ceilings for binary attachments placed directly into the model context.
# Non-image files: 4 MiB, kept below the ~4.5 MB PDF limit noted for Claude.
CLAUDE_MAX_FILE_SIZE_IN_CONTEXT = 4 << 20
# Images: 5 MiB, the 5 MB image limit noted for Claude; larger images stay out of context.
CLAUDE_IMAGE_MAX_FILE_SIZE_IN_CONTEXT = 5 << 20
16
+
17
+
18
def should_be_included_into_context(
    file_content: BinaryContent | str | None,
    file_size: int,
    *,
    max_img_size_bytes: int = CLAUDE_IMAGE_MAX_FILE_SIZE_IN_CONTEXT,
    max_file_size_bytes: int = CLAUDE_MAX_FILE_SIZE_IN_CONTEXT,
) -> bool:
    """Tell whether a file's content may be attached directly to the model context.

    Only ``BinaryContent`` qualifies; decoded text (``str``) and ``None`` never do.
    ``text/*`` media types and common archive formats are rejected. An image is
    accepted when smaller than ``max_img_size_bytes``. Any other remaining
    candidate (including a larger image) is accepted when smaller than
    ``max_file_size_bytes``.
    """
    if not isinstance(file_content, BinaryContent):
        return False

    media_type = file_content.media_type
    # Archive files are rejected as they're not supported by OpenAI models
    unsupported_archives = {
        "application/zip",
        "application/x-tar",
        "application/gzip",
        "application/x-gzip",
        "application/x-rar-compressed",
        "application/x-7z-compressed",
    }
    if media_type.startswith("text/") or media_type in unsupported_archives:
        return False

    if file_content.is_image and file_size < max_img_size_bytes:
        return True
    return file_size < max_file_size_bytes
48
+
49
+
50
def file_to_binary_content(file_path: str | Path, mime_type: str = "") -> str | BinaryContent:
    """
    Read a file and return its content as either a UTF-8 string (for text files)
    or BinaryContent (for binary files).

    Args:
        file_path: Path of the file to read.
        mime_type: MIME type of the file. When empty it is guessed from the file
            name, falling back to ``application/octet-stream``.

    Returns:
        The decoded text when ``is_text_content`` classifies the data as text and
        the bytes are valid UTF-8; otherwise a ``BinaryContent`` with the raw bytes.
    """
    data = Path(file_path).read_bytes()

    if not mime_type:
        guessed, _ = mimetypes.guess_type(file_path)
        mime_type = guessed or "application/octet-stream"

    if is_text_content(data, mime_type):
        try:
            return data.decode("utf-8")
        except UnicodeDecodeError:
            # Classified as text but not valid UTF-8 (e.g. Latin-1 encoded):
            # keep the raw bytes instead of failing the whole request.
            pass

    return BinaryContent(data=data, media_type=mime_type)
66
+
67
+
68
def build_user_input(
    session_tuple: SessionIdTuple,
    user_text: str,
    file_paths: list[Path],
) -> str | list[str | BinaryContent]:
    """Build user input for the Pydantic AI agent, including file attachments if provided.

    Every attachment is listed in the prompt text with its name, size and workspace
    path. Files accepted by ``should_be_included_into_context`` are also appended as
    binary parts after the text. Their 0-based position among those binary parts is
    noted in the listing.

    Args:
        session_tuple: Session identifier used to map workspace paths to host paths.
        user_text: The user's message.
        file_paths: Workspace (container) paths of the attached files.

    Returns:
        ``user_text`` unchanged when there are no attachments; otherwise a list whose
        first element is the prompt text, followed by the included binary attachments.
    """
    if not file_paths:
        return user_text

    # With the default limits, no file at or above the larger of the two limits can
    # pass should_be_included_into_context. Such files are listed but never read into
    # memory, which avoids loading huge attachments only to discard them.
    read_limit = max(CLAUDE_MAX_FILE_SIZE_IN_CONTEXT, CLAUDE_IMAGE_MAX_FILE_SIZE_IN_CONTEXT)

    attachment_info_lines = []
    binary_attachments = []

    for workspace_path in file_paths:
        host_path = container_to_host_path(PurePosixPath(workspace_path), ctx=session_tuple)
        file_size = host_path.stat().st_size
        mime_type, _ = mimetypes.guess_type(host_path)
        mime_type = mime_type or "application/octet-stream"

        attachment_info = f"* {workspace_path.name} (file_size={file_size} bytes) (path in workspace: {workspace_path})"
        if file_size < read_limit:
            binary_content = file_to_binary_content(host_path, mime_type)
            if should_be_included_into_context(binary_content, file_size):
                binary_attachments.append(binary_content)
                attachment_info += f" -- provided to model context at index {len(binary_attachments) - 1}"

        attachment_info_lines.append(attachment_info)

    full_prompt = user_text + "\nAttachments:\n" + "\n".join(attachment_info_lines)

    return [full_prompt] + binary_attachments
aixtools/app.py ADDED
@@ -0,0 +1,143 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+ Simple Chainlit app example
5
+ """
6
+
7
+ import traceback
8
+
9
+ import chainlit as cl
10
+ from pydantic_graph import End
11
+
12
+ from aixtools.agents.agent import get_agent
13
+ from aixtools.logging.logging_config import get_logger
14
+ from aixtools.utils.chainlit import cl_agent_show
15
+
16
logger = get_logger(__name__)

# Key of the chat-history list stored in the Chainlit user session
HISTORY = "history"

# System prompt handed to the agent built in run_agent()
SYSTEM_PROMPT = "\nYou are a helpful assistant.\n"
23
+
24
+
25
@cl.step
async def greet_tool(msg: str) -> str:
    """Demo tool: greet the caller and echo back the text it was given."""
    reply = f"Hello! You said: {msg}"
    return reply
29
+
30
+
31
+ async def parse_user_message(message):
32
+ """Parse user message and check if it is a command"""
33
+ # When we type something that starts with ':', we are using a "command" (i.e. it does not go to the agent)
34
+ command = str(message.content).strip().lower()
35
+ if command.startswith(":"):
36
+ logger.debug("Received command: %s", command)
37
+ match command:
38
+ case ":clear":
39
+ # Clear the history
40
+ cl.user_session.set(HISTORY, [])
41
+ return None
42
+ case ":help":
43
+ # Show help
44
+ help_message = """
45
+ Available commands:
46
+ - :clear: Clear the chat history
47
+ - :help: Show this help message
48
+ """
49
+ await cl.Message(content=help_message).send()
50
+ return None
51
+ case ":history":
52
+ # Show history
53
+ history = cl.user_session.get(HISTORY)
54
+ if history:
55
+ history_message = "\n".join(history)
56
+ await cl.Message(content=f"Chat history:\n{history_message}").send()
57
+ else:
58
+ await cl.Message(content="No history available.").send()
59
+ return None
60
+ case _:
61
+ # Unknown command
62
+ await cl.Message(content=f"Unknown command: {command}").send()
63
+ return None
64
+ else:
65
+ user_message = message.content
66
+ logger.debug("User message: %s", user_message)
67
+ return user_message
68
+
69
+
70
async def run_agent(messages):
    """Run the agent on ``messages``, showing its progress in a Chainlit message.

    If the run raises, the error and its stack trace are attached to the message
    as elements and logged, and an error string is returned instead of raising.

    Args:
        messages: Prompt forwarded to ``cl_agent_show.show_run`` (the chat history here).

    Returns:
        The value returned by ``show_run`` on success, otherwise
        ``"Internal server error: <error>"``.
    """
    agent = get_agent(system_prompt=SYSTEM_PROMPT, tools=[greet_tool])
    msg = cl.Message(content="")
    await msg.send()
    try:
        ret = await cl_agent_show.show_run(agent=agent, prompt=messages, msg=msg, debug=False)
    except Exception as e:  # pylint: disable=broad-exception-caught
        # UI boundary: report the failure to the user instead of crashing the handler
        msg.elements.append(cl.Text(name="Error", content=f"Error: {e}", type="error"))  # pylint: disable=unexpected-keyword-arg
        logger.error("Error: %s", e)
        # Log the full stack trace (once) for debugging
        stack_trace = traceback.format_exc()
        logger.error("Stack trace:\n%s", stack_trace)
        msg.elements.append(cl.Text(name="Stack Trace", content=stack_trace, language="python"))
        ret = f"Internal server error: {e}"
    await msg.send()
    return ret
89
+
90
+
91
def update_history(history, user_message=None, run_return=None):
    """Append the user message and/or the agent's final output to ``history``.

    ``history`` is mutated in place and also returned. At least one of
    ``user_message`` (which must be a ``str``) or ``run_return`` must be given.
    A list ``run_return`` is treated as a sequence of graph nodes whose last item
    is the ``End`` node; ``str(end.data.data)`` is stored. Any other value is
    stored as ``str(run_return)``.
    """
    assert user_message is not None or run_return is not None, "Either user message or run return must be provided"

    if user_message is not None:
        logger.debug("Updating history: Got user message type %s: %s", type(user_message), user_message)
        assert isinstance(user_message, str)
        history.append(user_message)

    if run_return is None:
        return history

    logger.debug("Updating history: Got agent output type %s: %s", type(run_return), run_return)
    if isinstance(run_return, list):
        # The last node carries the final result
        end_message: End = run_return[-1]
        latest_item = str(end_message.data.data)
    else:
        latest_item = str(run_return)
    logger.debug("Updating history: Adding to history type %s: %s", type(latest_item), latest_item)
    history.append(latest_item)
    return history
112
+
113
+
114
@cl.set_starters
async def set_starters():
    """Return the starter suggestions registered with Chainlit (a single "Hello world!")."""
    hello_starter = cl.Starter(label="Message", message="Hello world!")
    return [hello_starter]
120
+
121
+
122
@cl.on_chat_start
async def on_chat_start():
    """Begin every new chat with an empty history in the user session."""
    logger.debug("On chat start")
    fresh_history: list = []
    cl.user_session.set(HISTORY, fresh_history)
128
+
129
+
130
@cl.on_message
async def on_message(message: cl.Message):
    """Process incoming chat messages and generate responses using the agent.

    Commands (``:...``) are handled by ``parse_user_message`` and stop here.
    Otherwise the user's text is appended to the session history, and the whole
    history is sent to the agent. The agent's answer is then appended and the
    history is stored back in the session.
    """
    user_message = await parse_user_message(message)
    # A None result means the message was a command that has already been handled
    if user_message is None:
        return
    # Fall back to an empty history if none is stored yet (e.g. on_chat_start did not run)
    history = cl.user_session.get(HISTORY) or []
    history = update_history(history, user_message=user_message)
    run_return = await run_agent(history)
    history = update_history(history, run_return=run_return)
    # Store the updated list explicitly rather than relying on in-place aliasing
    cl.user_session.set(HISTORY, history)
aixtools/chainlit.md ADDED
@@ -0,0 +1,14 @@
1
+ # Welcome to Chainlit! 🚀🤖
2
+
3
+ Hi there, Developer! 👋 We're excited to have you on board. Chainlit is a powerful tool designed to help you prototype, debug and share applications built on top of LLMs.
4
+
5
+ ## Useful Links 🔗
6
+
7
+ - **Documentation:** Get started with our comprehensive [Chainlit Documentation](https://docs.chainlit.io) 📚
8
+ - **Discord Community:** Join our friendly [Chainlit Discord](https://discord.gg/k73SQ3FyUh) to ask questions, share your projects, and connect with other developers! 💬
9
+
10
+ We can't wait to see what you create with Chainlit! Happy coding! 💻😊
11
+
12
+ ## Welcome screen
13
+
14
+ To modify the welcome screen, edit the `chainlit.md` file at the root of your project. If you do not want a welcome screen, just leave this file empty.
@@ -0,0 +1,9 @@
1
+ """
2
+ Compliance module for aixtools.
3
+
4
+ This module provides utilities for managing compliance-related data and operations.
5
+ """
6
+
7
+ from .private_data import PrivateData
8
+
9
# Public API of the compliance package
__all__ = [
    "PrivateData",
]
@@ -0,0 +1,138 @@
1
+ """Private data management module for aixtools compliance."""
2
+
3
+ import json
4
+ from pathlib import Path
5
+
6
+ from fastmcp import Context
7
+
8
+ from aixtools.server.path import get_workspace_path
9
+
10
PRIVATE_DATA_FILE = ".private_data"


class PrivateData:
    """
    Class to manage private data file in the workspace.

    The information is stored in a JSON file named `.private_data` within the workspace directory.
    If the file does not exist, it indicates that there is no private data.

    IMPORTANT: All modifications save the data to the file immediately.
    Loading only reads the file; it never writes or deletes it.

    FIXME: We should add some level of mutex/locking to prevent concurrent writes.
    """

    def __init__(self, ctx: Context | None = None):
        """Create the manager and load the current state from the workspace file.

        Args:
            ctx: Optional request context, passed to ``get_workspace_path`` to locate the workspace.
        """
        self.ctx: Context | None = ctx
        self._has_private_data: bool = False  # Flag indicating if private data exists
        self._private_datasets: list[str] = []  # List of private datasets
        self._idap_datasets: list[str] = []  # List of datasets with IDAP
        self.load()

    def add_private_dataset(self, dataset_name: str) -> None:
        """
        Add a private dataset to the list.
        Save the state after modification.
        """
        if dataset_name not in self._private_datasets:
            self._private_datasets.append(dataset_name)
        self._has_private_data = True
        self.save()

    def add_idap_dataset(self, dataset_name: str) -> None:
        """
        Add a dataset with IDAP to the list.
        This also adds it to the private datasets if not already present.
        Save the state after modification.
        """
        if not self.has_idap_dataset(dataset_name):
            self._idap_datasets.append(dataset_name)
        self._has_private_data = True
        # An IDAP dataset is also a private dataset
        if not self.has_private_dataset(dataset_name):
            self._private_datasets.append(dataset_name)
        self.save()

    def get_private_datasets(self) -> list[str]:
        """Get the list of private datasets as a copy (to avoid modification)."""
        return list(self._private_datasets)

    def get_idap_datasets(self) -> list[str]:
        """Get the list of datasets with IDAP as a copy (to avoid modification)."""
        return list(self._idap_datasets)

    def has_private_dataset(self, dataset_name: str) -> bool:
        """Check if a specific private dataset exists."""
        return dataset_name in self._private_datasets

    def has_idap_dataset(self, dataset_name: str) -> bool:
        """Check if a specific dataset with IDAP exists."""
        return dataset_name in self._idap_datasets

    @property
    def has_private_data(self) -> bool:
        """Check if private data exists."""
        return self._has_private_data

    @has_private_data.setter
    def has_private_data(self, value: bool) -> None:
        """
        Set the flag indicating if private data exists.
        Setting it to False also clears both dataset lists.
        Save the state after modification.
        """
        self._has_private_data = value
        if not value:
            self._private_datasets = []
            self._idap_datasets = []
        self.save()

    def _get_private_data_path(self) -> Path:
        """Get the path to the private data file in the workspace."""
        return get_workspace_path(service_name=None, ctx=self.ctx) / PRIVATE_DATA_FILE

    def _has_private_data_file(self) -> bool:
        """Check if the private data file exists in the workspace."""
        private_data_path = self._get_private_data_path()
        return private_data_path.exists()

    def save(self) -> None:
        """Save content to the private data file in the workspace.

        When there is no private data, the file is deleted instead. The JSON is
        built before the file is opened for writing, so a serialization error
        cannot leave a truncated file behind.
        """
        private_data_path = self._get_private_data_path()
        # No private data? Delete the file if it exists
        if not self.has_private_data:
            private_data_path.unlink(missing_ok=True)
            return
        # Dump instance attributes as JSON, excluding the context
        data_dict = self.__dict__.copy()
        data_dict["ctx"] = None
        json_data = json.dumps(data_dict, indent=4)
        private_data_path.parent.mkdir(parents=True, exist_ok=True)
        with open(private_data_path, "w", encoding="utf-8") as f:
            f.write(json_data)

    def load(self) -> None:
        """Load content from the private data file in the workspace.

        The backing attributes are assigned directly. Going through the
        ``has_private_data`` setter would call ``save()`` mid-load, rewriting the
        file with the not-yet-loaded (empty) dataset lists.
        """
        private_data_path = self._get_private_data_path()
        if not private_data_path.exists():
            # No private data file
            self._has_private_data = False
            self._private_datasets = []
            self._idap_datasets = []
            return
        with open(private_data_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self._has_private_data = data.get("_has_private_data", False)
        self._private_datasets = data.get("_private_datasets", [])
        self._idap_datasets = data.get("_idap_datasets", [])

    def __repr__(self) -> str:
        return (
            f"PrivateData(has_private_data={self.has_private_data}, "
            f"private_datasets={self._private_datasets}, "
            f"idap_datasets={self._idap_datasets}, "
            f"file_path={self._get_private_data_path()})"
        )

    def __str__(self) -> str:
        return self.__repr__()
aixtools/context.py ADDED
@@ -0,0 +1,17 @@
1
+ """
2
+ This module defines global context variables for request-specific information
3
+ that can be used for logging, tracing, and other purposes across applications
4
+ that use aixtools.
5
+ """
6
+
7
+ from contextvars import ContextVar
8
+
9
+ # Define context variables with default values.
10
+ # These can be populated by middleware or where they are initialized
11
+ session_id_var: ContextVar[str | None] = ContextVar("session_id", default=None)
12
+ user_id_var: ContextVar[str | None] = ContextVar("user_id", default=None)
13
+
14
+ DEFAULT_USER_ID = "default_user"
15
+ DEFAULT_SESSION_ID = "default_session"
16
+
17
+ SessionIdTuple = tuple[str, str]
@@ -0,0 +1,17 @@
1
+ """
2
+ Database module for vector storage and retrieval.
3
+ """
4
+
5
+ from aixtools.db.database import DatabaseError, SqliteDb
6
+ from aixtools.db.vector_db import get_vdb_embedding, get_vector_db, vdb_add, vdb_get_by_id, vdb_has_id, vdb_query
7
+
8
# Names re-exported as the public API of the db package
__all__ = [
    "DatabaseError", "SqliteDb",
    "get_vdb_embedding", "get_vector_db",
    "vdb_add", "vdb_get_by_id", "vdb_has_id", "vdb_query",
]
@@ -0,0 +1,110 @@
1
+ """
2
+ Database Interface for Clinical Trials Information.
3
+
4
+ This module provides a database interface for querying clinical trials data
5
+ from the SQLite database.
6
+ """
7
+
8
+ import sqlite3
9
+ from contextlib import contextmanager
10
+ from pathlib import Path
11
+ from typing import Any
12
+
13
+ import pandas as pd
14
+
15
+ from aixtools.logging.logging_config import get_logger
16
+
17
logger = get_logger(__name__)  # Module-level logger for this database module
18
+
19
+
20
class DatabaseError(Exception):
    """Raised when an SQLite operation fails; the original ``sqlite3.Error`` is chained as the cause."""
22
+
23
+
24
+ class SqliteDb:
25
+ """
26
+ Database interface.
27
+ """
28
+
29
+ def __init__(self, db_path: str | Path):
30
+ """Initialize the database interface"""
31
+ self.db_path = Path(db_path)
32
+ if not self.db_path.exists():
33
+ raise FileNotFoundError(f"Database file not found: {self.db_path}")
34
+ # Test connection
35
+ with self.connection() as conn:
36
+ logger.info("Connected to database: %s, connection: %s", self.db_path, conn)
37
+
38
+ @contextmanager
39
+ def connection(self):
40
+ """
41
+ Context manager for database connections.
42
+
43
+ Yields:
44
+ sqlite3.Connection: An active database connection
45
+ """
46
+ conn = None
47
+ try:
48
+ conn = sqlite3.connect(self.db_path)
49
+ # Enable dictionary row factory
50
+ conn.row_factory = sqlite3.Row
51
+ yield conn
52
+ except sqlite3.Error as e:
53
+ raise DatabaseError(f"Database error: {e}") from e
54
+ finally:
55
+ if conn:
56
+ conn.close()
57
+
58
+ def query(self, query: str, params: dict[str, Any] | None = None) -> list[dict[str, Any]]:
59
+ """
60
+ Execute a SQL query and return the results as a list of dictionaries.
61
+
62
+ Args:
63
+ query: SQL query to execute
64
+ params: Parameters for the query
65
+
66
+ Returns:
67
+ List of dictionaries representing the query results
68
+ """
69
+ with self.connection() as conn:
70
+ cursor = conn.cursor()
71
+ if params:
72
+ cursor.execute(query, params)
73
+ else:
74
+ cursor.execute(query)
75
+
76
+ results = cursor.fetchall()
77
+ # Convert sqlite3.Row objects to dictionaries
78
+ return [dict(row) for row in results]
79
+
80
+ def query_df(self, query: str, params: dict[str, Any] | None = None) -> pd.DataFrame:
81
+ """
82
+ Execute a SQL query and return the results as a pandas DataFrame.
83
+
84
+ Args:
85
+ query: SQL query to execute.
86
+ params: Parameters to substitute in the query.
87
+
88
+ Returns:
89
+ A pandas DataFrame containing the query results.
90
+ """
91
+ with self.connection() as conn:
92
+ if params:
93
+ df = pd.read_sql_query(query, conn, params=params)
94
+ else:
95
+ df = pd.read_sql_query(query, conn)
96
+ return df
97
+
98
+ def validate(self, query) -> str | None:
99
+ """
100
+ Validate the SQL query by executing an EXPLAIN QUERY PLAN statement.
101
+ Returns the error string if there is an issue, otherwise returns None
102
+ """
103
+ with self.connection() as conn:
104
+ try:
105
+ cursor = conn.cursor()
106
+ cursor.execute(f"EXPLAIN QUERY PLAN\n{query}")
107
+ cursor.fetchall()
108
+ return None
109
+ except Exception as e: # pylint: disable=broad-exception-caught
110
+ return str(e)