aixtools 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of aixtools might be problematic.
- aixtools/__init__.py +5 -0
- aixtools/a2a/__init__.py +5 -0
- aixtools/a2a/app.py +126 -0
- aixtools/a2a/utils.py +115 -0
- aixtools/agents/__init__.py +12 -0
- aixtools/agents/agent.py +164 -0
- aixtools/agents/agent_batch.py +74 -0
- aixtools/app.py +143 -0
- aixtools/context.py +12 -0
- aixtools/db/__init__.py +17 -0
- aixtools/db/database.py +110 -0
- aixtools/db/vector_db.py +115 -0
- aixtools/log_view/__init__.py +17 -0
- aixtools/log_view/app.py +195 -0
- aixtools/log_view/display.py +285 -0
- aixtools/log_view/export.py +51 -0
- aixtools/log_view/filters.py +41 -0
- aixtools/log_view/log_utils.py +26 -0
- aixtools/log_view/node_summary.py +229 -0
- aixtools/logfilters/__init__.py +7 -0
- aixtools/logfilters/context_filter.py +67 -0
- aixtools/logging/__init__.py +30 -0
- aixtools/logging/log_objects.py +227 -0
- aixtools/logging/logging_config.py +116 -0
- aixtools/logging/mcp_log_models.py +102 -0
- aixtools/logging/mcp_logger.py +172 -0
- aixtools/logging/model_patch_logging.py +87 -0
- aixtools/logging/open_telemetry.py +36 -0
- aixtools/mcp/__init__.py +9 -0
- aixtools/mcp/example_client.py +30 -0
- aixtools/mcp/example_server.py +22 -0
- aixtools/mcp/fast_mcp_log.py +31 -0
- aixtools/mcp/faulty_mcp.py +320 -0
- aixtools/model_patch/model_patch.py +65 -0
- aixtools/server/__init__.py +23 -0
- aixtools/server/app_mounter.py +90 -0
- aixtools/server/path.py +72 -0
- aixtools/server/utils.py +70 -0
- aixtools/testing/__init__.py +9 -0
- aixtools/testing/aix_test_model.py +147 -0
- aixtools/testing/mock_tool.py +66 -0
- aixtools/testing/model_patch_cache.py +279 -0
- aixtools/tools/doctor/__init__.py +3 -0
- aixtools/tools/doctor/tool_doctor.py +61 -0
- aixtools/tools/doctor/tool_recommendation.py +44 -0
- aixtools/utils/__init__.py +35 -0
- aixtools/utils/chainlit/cl_agent_show.py +82 -0
- aixtools/utils/chainlit/cl_utils.py +168 -0
- aixtools/utils/config.py +118 -0
- aixtools/utils/config_util.py +69 -0
- aixtools/utils/enum_with_description.py +37 -0
- aixtools/utils/persisted_dict.py +99 -0
- aixtools/utils/utils.py +160 -0
- aixtools-0.1.0.dist-info/METADATA +355 -0
- aixtools-0.1.0.dist-info/RECORD +58 -0
- aixtools-0.1.0.dist-info/WHEEL +5 -0
- aixtools-0.1.0.dist-info/entry_points.txt +2 -0
- aixtools-0.1.0.dist-info/top_level.txt +1 -0

aixtools/utils/chainlit/cl_agent_show.py
ADDED
@@ -0,0 +1,82 @@
import chainlit as cl
import rich
from pydantic_ai import Agent
from pydantic_ai.messages import (
    FinalResultEvent,
    FunctionToolCallEvent,
    FunctionToolResultEvent,
    PartDeltaEvent,
    PartStartEvent,
    TextPartDelta,
    ToolCallPartDelta,
)

from aixtools.logging.log_objects import ObjectLogger


def _show_debug_info(debug, *args):
    if debug:
        rich.print(*args)


async def show_run(agent: Agent, prompt, msg: cl.Message, debug=False, verbose=True):  # noqa: PLR0912
    """Run an agent with a prompt and stream the results into a Chainlit message."""
    nodes = []
    async with agent.iter(prompt) as run:
        with ObjectLogger(debug=debug, verbose=verbose) as agent_logger:
            async for node in run:
                nodes.append(node)
                agent_logger.log(node)
                if Agent.is_user_prompt_node(node):
                    # A user prompt node => The user has provided input
                    _show_debug_info(debug, "=== UserPromptNode: ", node)
                elif Agent.is_model_request_node(node):
                    # A model request node => We can stream tokens from the model's request
                    _show_debug_info(debug, "=== ModelRequestNode: streaming partial request tokens ===")
                    async with node.stream(run.ctx) as request_stream:
                        async for event in request_stream:
                            if isinstance(event, PartStartEvent):
                                _show_debug_info(debug, f"[Request] Starting part {event.index}: ", event.part)
                            elif isinstance(event, PartDeltaEvent):
                                if isinstance(event.delta, TextPartDelta):
                                    _show_debug_info(
                                        debug,
                                        (
                                            "[ModelRequestNode / PartDeltaEvent / TextPartDelta] "
                                            f"Part {event.index}: {event.delta.content_delta}"
                                        ),
                                    )
                                    await msg.stream_token(event.delta.content_delta)
                                elif isinstance(event.delta, ToolCallPartDelta):
                                    _show_debug_info(
                                        debug,
                                        f"[ModelRequestNode / PartDeltaEvent / ToolCallPartDelta] Part {event.index}, ",
                                        event.delta,
                                    )
                            elif isinstance(event, FinalResultEvent):
                                _show_debug_info(
                                    debug, f"[Result] The model produced a final result (tool_name={event.tool_name})"
                                )
                elif Agent.is_call_tools_node(node):
                    # A handle-response node => The model returned some data, potentially calls a tool
                    _show_debug_info(debug, "=== CallToolsNode: streaming partial response & tool usage ===")
                    async with node.stream(run.ctx) as handle_stream:
                        async for event in handle_stream:
                            if isinstance(event, FunctionToolCallEvent):
                                _show_debug_info(
                                    debug,
                                    (
                                        f"[Tools] The LLM calls tool={event.part.tool_name!r} "
                                        f"with args={event.part.args} (tool_call_id={event.part.tool_call_id!r})"
                                    ),
                                )
                            elif isinstance(event, FunctionToolResultEvent):
                                _show_debug_info(
                                    debug,
                                    f"[Tools] Tool call {event.tool_call_id!r} returned => {event.result.content}",
                                )
                elif Agent.is_end_node(node):
                    assert run.result.output == node.data.output
                    # Once an End node is reached, the agent run is complete
                    _show_debug_info(debug, f"=== Final Agent Output: {run.result.output} ===")
    return run.result.output
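
A minimal usage sketch for show_run inside a Chainlit message handler. The model id and handler body are assumptions for illustration, not part of this release:

import chainlit as cl
from pydantic_ai import Agent

from aixtools.utils.chainlit.cl_agent_show import show_run

agent = Agent("openai:gpt-4o")  # hypothetical model id


@cl.on_message
async def on_message(message: cl.Message):
    msg = cl.Message(content="")
    # show_run streams the agent's answer into `msg` token by token
    await show_run(agent, message.content, msg, debug=False)
    await msg.send()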

aixtools/utils/chainlit/cl_utils.py
ADDED
@@ -0,0 +1,168 @@
"""
Utilities for Chainlit
"""

import inspect
from copy import deepcopy
from functools import wraps
from typing import Callable, List, Optional, Union

import pandas as pd
from chainlit import Step
from chainlit.context import get_context
from literalai.observability.step import TrueStepType

from aixtools.logging.logging_config import get_logger
from aixtools.utils.utils import truncate

logger = get_logger(__name__)

DEFAULT_SKIP_ARGS = ("self", "cls")

MAX_SIZE_STR = 10 * 1024
MAX_SIZE_DF_ROWS = 100


def is_chainlit() -> bool:
    """Are we running in chainlit?"""
    try:
        get_context()
        return True
    except Exception:
        return False


def flatten_args_kwargs(func, args, kwargs, skip_args=DEFAULT_SKIP_ARGS):
    signature = inspect.signature(func)
    bound_arguments = signature.bind(*args, **kwargs)
    bound_arguments.apply_defaults()
    return {k: deepcopy(v) for k, v in bound_arguments.arguments.items() if k not in skip_args}


def _step_name(func, args, kwargs):
    """
    Create a step name: class.method
    The class name is detected from the method's first argument.
    """
    if len(args) == 0:
        return func.__name__
    signature = inspect.signature(func)
    bound_arguments = signature.bind(*args, **kwargs)
    arguments = [(k, v) for k, v in bound_arguments.arguments.items()]
    arg0_name, arg0_value = arguments[0]
    if arg0_name == "self":
        return f"{arg0_value.__class__.__name__}.{func.__name__}"
    if arg0_name == "cls":
        return f"{arg0_value.__name__}.{func.__name__}"
    return func.__name__


def limit_size(data):
    """Truncate large values so they can be displayed in a step."""
    if isinstance(data, str):
        return truncate(data, max_len=MAX_SIZE_STR)
    if isinstance(data, pd.DataFrame):
        if len(data) > MAX_SIZE_DF_ROWS:
            return data.head(MAX_SIZE_DF_ROWS)
    return data


def cl_step(  # noqa: PLR0913
    original_function: Optional[Callable] = None,
    *,
    name: Optional[str] = "",
    type: TrueStepType = "undefined",
    id: Optional[str] = None,
    parent_id: Optional[str] = None,
    tags: Optional[List[str]] = None,
    language: Optional[str] = None,
    show_input: Union[bool, str] = "json",
    default_open: bool = False,
):
    """
    Step decorator for async and sync functions and methods (the `self` argument is ignored).
    It deactivates itself when not running within a Chainlit context.
    """

    def wrapper(func: Callable):
        # Handle async decorator
        if inspect.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                # Compute the step name per call: mutating the closed-over `name`
                # would cache the first call's name for every subsequent call
                step_name = name or _step_name(func, args, kwargs)
                if is_chainlit():
                    async with Step(
                        type=type,
                        name=step_name,
                        id=id,
                        parent_id=parent_id,
                        tags=tags,
                        language=language,
                        show_input=show_input,
                        default_open=default_open,
                    ) as step:
                        try:
                            step.input = flatten_args_kwargs(func, args, kwargs)
                        except Exception as e:
                            logger.exception(e)
                        result = await func(*args, **kwargs)
                        try:
                            if result and not step.output:
                                step.output = limit_size(result)
                        except Exception as e:
                            step.is_error = True
                            step.output = str(e)
                        return result
                else:
                    # If not in Chainlit, just call the function
                    result = await func(*args, **kwargs)
                    print(f"Function '{func.__name__}' called with args: {args}, kwargs: {kwargs}, result: {result}")
                    return result

            return async_wrapper
        else:
            # Handle sync decorator
            @wraps(func)
            def sync_wrapper(*args, **kwargs):
                # Same per-call name computation as the async wrapper
                step_name = name or _step_name(func, args, kwargs)
                if is_chainlit():
                    with Step(
                        type=type,
                        name=step_name,
                        id=id,
                        parent_id=parent_id,
                        tags=tags,
                        language=language,
                        show_input=show_input,
                        default_open=default_open,
                    ) as step:
                        try:
                            step.input = flatten_args_kwargs(func, args, kwargs)
                        except Exception as e:
                            logger.exception(e)
                        result = func(*args, **kwargs)
                        try:
                            if result and not step.output:
                                step.output = limit_size(result)
                        except Exception as e:
                            step.is_error = True
                            step.output = str(e)
                        return result
                else:
                    # If not in Chainlit, just call the function
                    result = func(*args, **kwargs)
                    print(f"Function '{func.__name__}' called with args: {args}, kwargs: {kwargs}, result: {result}")
                    return result

            return sync_wrapper

    func = original_function
    if not func:
        return wrapper
    else:
        return wrapper(func)
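
A short sketch of how the cl_step decorator might be applied; the decorated functions and their return values are illustrative only:

from aixtools.utils.chainlit.cl_utils import cl_step


@cl_step(type="tool", default_open=True)
async def lookup_user(user_id: str) -> dict:
    # Inside Chainlit this call renders as a step named 'lookup_user',
    # with its bound arguments as input and its return value as output
    return {"id": user_id, "name": "Ada"}  # hypothetical result


# Bare usage also works, since cl_step accepts the function positionally:
@cl_step
def add(a: int, b: int) -> int:
    return a + b

Outside a Chainlit context both wrappers fall through to a plain function call, so the decorator is safe to keep on code that is also used in scripts and tests.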

aixtools/utils/config.py
ADDED
@@ -0,0 +1,118 @@
"""
Configuration settings and environment variables for the application.
"""

import logging
import sys
from pathlib import Path

from dotenv import dotenv_values, load_dotenv

from aixtools.utils.config_util import find_env_file, get_project_root, get_variable_env

# Debug mode
LOG_LEVEL = logging.DEBUG

# Set up some environment variables (these are usually set by 'config.sh')

# This file's path
FILE_PATH = Path(__file__).resolve()

# This project's root directory (AixTools)
# if installed as a package, it will be `.venv/lib/python3.x/site-packages/aixtools`
PROJECT_DIR = FILE_PATH.parent.parent.parent.resolve()

# Get the main project directory (the one project that is using this package)
PROJECT_ROOT = get_project_root()

# From the environment variables


# Iterate over all parents of FILE_PATH to find .env files
def all_parents(path: Path):
    """Yield all parent directories of a given path."""
    while path.parent != path:
        yield path
        path = path.parent


# Set up environment search path
# Start with the most specific (current directory) and expand outward
env_dirs = [Path.cwd(), PROJECT_ROOT, FILE_PATH.parent]
env_file = find_env_file(env_dirs)

if env_file:
    logging.info("Using .env file at '%s'", env_file)
    # Load the environment variables from the found .env file
    load_dotenv(env_file)
    # Derive the main project directory from the .env file's location
    MAIN_PROJECT_DIR = Path(env_file).parent
    logging.info("Using MAIN_PROJECT_DIR='%s'", MAIN_PROJECT_DIR)
    # Expose the variables defined in '.env' as module-level globals
    env_vars = dotenv_values(env_file)
    globals().update(env_vars)
else:
    logging.error("No '.env' file found in any of the search paths, or their parents: %s", env_dirs)
    sys.exit(1)


# ---
# Directories
# ---
SCRIPTS_DIR = MAIN_PROJECT_DIR / "scripts"
DATA_DIR = Path(get_variable_env("DATA_DIR") or MAIN_PROJECT_DIR / "data")
DATA_DB_DIR = Path(get_variable_env("DATA_DB_DIR", default=DATA_DIR / "db"))
LOGS_DIR = MAIN_PROJECT_DIR / "logs"

logging.warning("Using DATA_DIR='%s'", DATA_DIR)

# Vector database
VDB_CHROMA_PATH = DATA_DB_DIR / "chroma.db"
VDB_DEFAULT_SIMILARITY_THRESHOLD = 0.85


# ---
# Variables in '.env' file
# Explicitly load specific variables
# ---

MODEL_TIMEOUT = int(get_variable_env("MODEL_TIMEOUT", default="120"))  # type: ignore

MODEL_FAMILY = get_variable_env("MODEL_FAMILY")

# Azure models
AZURE_MODEL_NAME = get_variable_env("AZURE_MODEL_NAME")
AZURE_OPENAI_ENDPOINT = get_variable_env("AZURE_OPENAI_ENDPOINT")
AZURE_OPENAI_API_KEY = get_variable_env("AZURE_OPENAI_API_KEY")
AZURE_OPENAI_API_VERSION = get_variable_env("AZURE_OPENAI_API_VERSION")

# OpenAI models
OPENAI_API_KEY = get_variable_env("OPENAI_API_KEY")
OPENAI_MODEL_NAME = get_variable_env("OPENAI_MODEL_NAME")

# Ollama models
OLLAMA_URL = get_variable_env("OLLAMA_URL")
OLLAMA_MODEL_NAME = get_variable_env("OLLAMA_MODEL_NAME")

# OpenRouter models
OPENROUTER_API_KEY = get_variable_env("OPENROUTER_API_KEY")
OPENROUTER_API_URL = get_variable_env("OPENROUTER_API_URL", default="https://openrouter.ai/api/v1")
OPENROUTER_MODEL_NAME = get_variable_env("OPENROUTER_MODEL_NAME")

# Embeddings
VDB_EMBEDDINGS_MODEL_FAMILY = get_variable_env("VDB_EMBEDDINGS_MODEL_FAMILY")
OPENAI_VDB_EMBEDDINGS_MODEL_NAME = get_variable_env("OPENAI_VDB_EMBEDDINGS_MODEL_NAME")
AZURE_VDB_EMBEDDINGS_MODEL_NAME = get_variable_env("AZURE_VDB_EMBEDDINGS_MODEL_NAME")
OLLAMA_VDB_EMBEDDINGS_MODEL_NAME = get_variable_env("OLLAMA_VDB_EMBEDDINGS_MODEL_NAME")

# Bedrock models
AWS_ACCESS_KEY_ID = get_variable_env("AWS_ACCESS_KEY_ID", allow_empty=True)
AWS_SECRET_ACCESS_KEY = get_variable_env("AWS_SECRET_ACCESS_KEY", allow_empty=True)
AWS_SESSION_TOKEN = get_variable_env("AWS_SESSION_TOKEN", allow_empty=True)
AWS_REGION = get_variable_env("AWS_REGION", allow_empty=True, default="us-east-1")
AWS_PROFILE = get_variable_env("AWS_PROFILE", allow_empty=True)
BEDROCK_MODEL_NAME = get_variable_env("BEDROCK_MODEL_NAME", allow_empty=True)

# LogFire
LOGFIRE_TOKEN = get_variable_env("LOGFIRE_TOKEN", allow_empty=True, default="")
LOGFIRE_TRACES_ENDPOINT = get_variable_env("LOGFIRE_TRACES_ENDPOINT", allow_empty=True, default="")
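
For reference, a hypothetical .env file covering a few of the variables this module reads; all values below are placeholders, not defaults shipped with the package:

# .env -- placed at the project root, where find_env_file() will locate it
MODEL_FAMILY=openai
MODEL_TIMEOUT=120
OPENAI_API_KEY=sk-placeholder
OPENAI_MODEL_NAME=gpt-4o
DATA_DIR=/var/lib/myapp/data

Note that the module calls sys.exit(1) at import time if no .env file is found anywhere on the search path, so importing aixtools.utils.config requires one to exist.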

aixtools/utils/config_util.py
ADDED
@@ -0,0 +1,69 @@
"""
Utility functions for configuration management and environment variables.
"""

import logging
import os
import sys
from pathlib import Path

from dotenv import find_dotenv


def get_project_root() -> Path:
    """
    Return the directory where the main script lives.
    Falls back to the current working directory if run interactively.
    """
    main_mod = sys.modules.get("__main__")
    main_file = getattr(main_mod, "__file__", None)
    if main_file:
        return Path(main_file).resolve().parent

    # no __file__ (e.g. interactive shell); assume cwd is the project root
    return Path.cwd()


def all_parents(path: Path):
    """Yield all parent directories of a given path."""
    while path.parent != path:
        yield path
        path = path.parent


def find_env_file(env_search_dirs: list[Path]) -> str | None:
    """Find the first .env file in the given list of paths and their parents."""
    logging.warning("Looking for '.env' file in the default location")
    env_file = find_dotenv()
    if env_file:
        return env_file
    # Search each given directory, then all of its parents
    for search_dir in env_search_dirs:
        # '.env' file in this directory?
        logging.warning("Looking for '.env' file at '%s'", search_dir)
        env_file = find_dotenv(str(search_dir / ".env"))
        if env_file:
            return env_file
        # Try all parents of this dir
        for parent_dir in all_parents(search_dir):
            logging.warning("Looking for '.env' file at '%s'", parent_dir)
            env_file = find_dotenv(str(parent_dir / ".env"))
            if env_file:
                return env_file
    return None


def get_variable_env(name: str, allow_empty=True, default=None) -> str | None:
    """Retrieve an environment variable, with optional validation and a default value."""
    val = os.environ.get(name, default)
    if not allow_empty and ((val is None) or (val == "")):
        raise ValueError(f"Environment variable {name} is not set")
    return val


def set_variable_env(name: str, val: str) -> str:
    """Set an environment variable, validating that the value is not None."""
    # Validate before assignment: os.environ rejects None with a TypeError,
    # which would make a check after assignment unreachable
    if val is None:
        raise ValueError(f"Environment variable {name} is set to None")
    os.environ[name] = val
    return val
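
A quick sketch of the helpers' behavior (illustrative invocations, not part of the diff):

from aixtools.utils.config_util import get_variable_env, set_variable_env

set_variable_env("MY_FLAG", "on")
print(get_variable_env("MY_FLAG"))                      # -> 'on'
print(get_variable_env("MISSING", default="fallback"))  # -> 'fallback'
# Raises ValueError, since the variable is unset and empty values are disallowed:
# get_variable_env("MISSING", allow_empty=False)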

aixtools/utils/enum_with_description.py
ADDED
@@ -0,0 +1,37 @@
"""
Enhanced Enum implementation that supports descriptions for enum values.
"""

from enum import Enum


class EnumWithDescription(str, Enum):
    """
    An enum with string values and descriptions.
    Each enum value has a string representation and a description.

    Example:
        class MyEnum(EnumWithDescription):
            VALUE1 = "value1", "This is a description for VALUE1"
            VALUE2 = "value2", "This is a description for VALUE2"
            VALUE3 = "value3", "This is a description for VALUE3"

        print(MyEnum.describe())
        # Output:
        # VALUE1: This is a description for VALUE1
        # VALUE2: This is a description for VALUE2
        # VALUE3: This is a description for VALUE3
    """

    @classmethod
    def describe(cls) -> str:
        """
        Get the descriptions of this enum's values, one per line.
        """
        return "\n".join([f"{field.name}: {field.__doc__}" for field in cls])

    def __new__(cls, value, doc):
        obj = str.__new__(cls, value)
        obj._value_ = value
        obj.__doc__ = doc
        return obj
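
Because EnumWithDescription subclasses str, members compare and serialize as plain strings while carrying their description in __doc__. A small sketch with an assumed enum:

from aixtools.utils.enum_with_description import EnumWithDescription


class Color(EnumWithDescription):
    RED = "red", "Warm primary color"
    BLUE = "blue", "Cool primary color"


assert Color.RED == "red"                      # plain string comparison works
assert Color.RED.__doc__ == "Warm primary color"
print(Color.describe())
# RED: Warm primary color
# BLUE: Cool primary color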

aixtools/utils/persisted_dict.py
ADDED
@@ -0,0 +1,99 @@
"""
Dictionary implementation that automatically persists its contents to disk.
"""

import json
import pickle
from pathlib import Path

from aixtools.logging.logging_config import get_logger

logger = get_logger(__name__)

DATA_KEY = "__dictionary_data__"


class PersistedDict(dict):
    """
    A dictionary that persists to a file on disk, as JSON ('.json') or pickle ('.pkl').
    Keys are always converted to strings.
    """

    def __init__(self, file_path: Path):
        self.file_path = file_path if isinstance(file_path, Path) else Path(file_path)
        self.use_pickle = None
        # Check the normalized path: the original argument may be a plain string
        if self.file_path.suffix == ".json":
            self.use_pickle = False
        elif self.file_path.suffix == ".pkl":
            self.use_pickle = True
        else:
            raise ValueError(f"Unsupported file extension '{self.file_path.suffix}' for file '{self.file_path}'")
        self.load()

    def __contains__(self, key):
        return super().__contains__(str(key))

    def __delitem__(self, key):
        super().__delitem__(str(key))
        self.save()

    def get(self, key, default=None):
        return super().get(str(key), default)

    def __getitem__(self, key):
        return super().__getitem__(str(key))

    def load(self):
        """Load dictionary data from disk using either pickle or JSON format."""
        if self.use_pickle:
            self._load_pickle()
        else:
            self._load_json()

    def _load_json(self):
        try:
            with open(self.file_path, "r", encoding="utf-8") as f:
                # Use the plain dict update so loading does not trigger a save
                super().update(json.load(f))
            logger.debug("Persistent dictionary: Loaded %d items from JSON file '%s'", len(self), self.file_path)
        except FileNotFoundError:
            pass

    def _load_pickle(self):
        try:
            with open(self.file_path, "rb") as f:
                object_data = pickle.load(f)
            for k, v in object_data[DATA_KEY].items():
                super().__setitem__(str(k), v)
            for k, v in object_data.items():
                if k != DATA_KEY:
                    self.__dict__[k] = v
            logger.debug("Persistent dictionary: Loaded %d items from pickle file '%s'", len(self), self.file_path)
        except FileNotFoundError:
            pass

    def save(self):
        """Save dictionary data to disk using either pickle or JSON format."""
        if self.use_pickle:
            self._save_pickle()
        else:
            self._save_json()

    def _save_json(self):
        self.file_path.parent.mkdir(parents=True, exist_ok=True)
        with open(self.file_path, "w", encoding="utf-8") as f:
            json.dump(self, f, indent=2)

    def _save_pickle(self):
        self.file_path.parent.mkdir(parents=True, exist_ok=True)
        with open(self.file_path, "wb") as f:
            # Store instance attributes alongside the dictionary contents
            object_data = dict(self.__dict__)
            object_data[DATA_KEY] = dict(self)
            pickle.dump(object_data, f)

    def __setitem__(self, key, value):
        super().__setitem__(str(key), value)
        self.save()

    def update(self, *args, **kwargs):
        super().update(*args, **kwargs)
        self.save()
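
A usage sketch for PersistedDict; the file path is illustrative. Every mutation writes the whole dictionary back to disk, so it suits small, infrequently updated state:

from pathlib import Path

from aixtools.utils.persisted_dict import PersistedDict

cache = PersistedDict(Path("/tmp/cache.json"))
cache["answer"] = 42   # written to disk immediately
cache[7] = "seven"     # keys are stored as strings, i.e. cache["7"]

# A fresh instance pointed at the same file sees the persisted contents
reloaded = PersistedDict(Path("/tmp/cache.json"))
assert reloaded["answer"] == 42
assert reloaded.get(7) == "seven"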