aixtools 0.1.11__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of aixtools might be problematic.
- aixtools/_version.py +2 -2
- aixtools/agents/agent.py +26 -7
- aixtools/agents/print_nodes.py +54 -0
- aixtools/agents/prompt.py +2 -2
- aixtools/compliance/private_data.py +1 -1
- aixtools/evals/__init__.py +0 -0
- aixtools/evals/discovery.py +174 -0
- aixtools/evals/evals.py +74 -0
- aixtools/evals/run_evals.py +110 -0
- aixtools/logging/log_objects.py +24 -23
- aixtools/mcp/client.py +46 -1
- aixtools/server/__init__.py +0 -6
- aixtools/server/path.py +88 -31
- aixtools/testing/aix_test_model.py +7 -1
- aixtools/tools/doctor/mcp_tool_doctor.py +79 -0
- aixtools/tools/doctor/tool_doctor.py +4 -0
- aixtools/tools/doctor/tool_recommendation.py +5 -0
- aixtools/utils/config.py +0 -1
- {aixtools-0.1.11.dist-info → aixtools-0.2.1.dist-info}/METADATA +185 -30
- {aixtools-0.1.11.dist-info → aixtools-0.2.1.dist-info}/RECORD +23 -18
- aixtools-0.2.1.dist-info/entry_points.txt +4 -0
- aixtools/server/workspace_privacy.py +0 -65
- aixtools-0.1.11.dist-info/entry_points.txt +0 -2
- {aixtools-0.1.11.dist-info → aixtools-0.2.1.dist-info}/WHEEL +0 -0
- {aixtools-0.1.11.dist-info → aixtools-0.2.1.dist-info}/top_level.txt +0 -0
aixtools/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.1.11'
-__version_tuple__ = version_tuple = (0, 1, 11)
+__version__ = version = '0.2.1'
+__version_tuple__ = version_tuple = (0, 2, 1)
 
 __commit_id__ = commit_id = None
aixtools/agents/agent.py
CHANGED
@@ -5,10 +5,11 @@ Core agent implementation providing model selection and configuration for AI age
 from types import NoneType
 from typing import Any
 
+from fastmcp import Context
 from openai import AsyncAzureOpenAI
 from pydantic_ai import Agent
 from pydantic_ai.models.bedrock import BedrockConverseModel
-from pydantic_ai.models.openai import
+from pydantic_ai.models.openai import OpenAIChatModel
 from pydantic_ai.providers.bedrock import BedrockProvider
 from pydantic_ai.providers.openai import OpenAIProvider
 from pydantic_ai.settings import ModelSettings
@@ -54,14 +55,14 @@ def _get_model_ollama(model_name=OLLAMA_MODEL_NAME, ollama_url=OLLAMA_URL):
     assert ollama_url, "OLLAMA_URL is not set"
     assert model_name, "Model name is not set"
     provider = OpenAIProvider(base_url=ollama_url)
-    return
+    return OpenAIChatModel(model_name=model_name, provider=provider)
 
 
 def _get_model_openai(model_name=OPENAI_MODEL_NAME, openai_api_key=OPENAI_API_KEY):
     assert openai_api_key, "OPENAI_API_KEY is not set"
     assert model_name, "Model name is not set"
     provider = OpenAIProvider(api_key=openai_api_key)
-    return
+    return OpenAIChatModel(model_name=model_name, provider=provider)
 
 
 def _get_model_openai_azure(
@@ -77,7 +78,7 @@ def _get_model_openai_azure(
     client = AsyncAzureOpenAI(
         azure_endpoint=azure_openai_endpoint, api_version=azure_openai_api_version, api_key=azure_openai_api_key
     )
-    return
+    return OpenAIChatModel(model_name=model_name, provider=OpenAIProvider(openai_client=client))
 
 
 def _get_model_open_router(
@@ -87,7 +88,7 @@ def _get_model_open_router(
     assert openrouter_api_key, "OPENROUTER_API_KEY is not set"
     assert model_name, "Model name is not set, missing 'OPENROUTER_MODEL_NAME' environment variable?"
     provider = OpenAIProvider(base_url=openrouter_api_url, api_key=openrouter_api_key)
-    return
+    return OpenAIChatModel(model_name, provider=provider)
 
 
 def get_model(model_family=MODEL_FAMILY, model_name=None, **kwargs):
@@ -146,8 +147,22 @@ async def run_agent( # noqa: PLR0913, pylint: disable=too-many-arguments,too-ma
     debug: bool = False,
     log_model_requests: bool = False,
     parent_logger: ObjectLogger | None = None,
+    ctx: Context | None = None,
 ):
-    """
+    """
+    Run the agent with the given prompt and log the execution details.
+    Args:
+        agent (Agent): The PydanticAI agent to run.
+        prompt (str | list[str]): The input prompt(s) for the agent.
+        usage_limits (UsageLimits | None): Optional usage limits for the agent.
+        verbose (bool): If True, enables verbose logging.
+        debug (bool): If True, enables debug logging.
+        log_model_requests (bool): If True, logs model requests and responses.
+        parent_logger (ObjectLogger | None): Optional parent logger for hierarchical logging.
+        ctx (Context | None): Optional FastMCP context for logging messages to the MCP client.
+    Returns:
+        tuple[final_output, nodes]: A tuple containing the agent's final output and a list of all logged nodes.
+    """
     # Results
     nodes, result = [], None
     async with agent.iter(prompt, usage_limits=usage_limits) as agent_run:
@@ -158,7 +173,11 @@ async def run_agent( # noqa: PLR0913, pylint: disable=too-many-arguments,too-ma
         agent.model = model_patch_logging(agent.model, agent_logger)
         # Run the agent
         async for node in agent_run:
-            agent_logger.log(node)
+            await agent_logger.log(node)  # Log each node
+            if ctx:
+                # If we are executing in an MCP server, send info messages to the client for better debugging
+                server_name = ctx.fastmcp.name
+                await ctx.info(f"MCP server {server_name}: {node}")
             nodes.append(node)
         result = agent_run.result
     return result.output if result else None, nodes
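The new ctx parameter lets run_agent forward per-node progress to an MCP client via ctx.info(). As a rough usage sketch (not part of the diff; the tool name is hypothetical, the import path follows the file layout above, and run_agent's positional arguments follow its docstring):

import asyncio

from fastmcp import Context, FastMCP
from pydantic_ai import Agent

from aixtools.agents.agent import get_model, run_agent

mcp = FastMCP("demo-server")


@mcp.tool()
async def summarize(text: str, ctx: Context) -> str:
    """Hypothetical MCP tool that runs an agent and streams node info back to the client."""
    agent = Agent(get_model(), system_prompt="Summarize the input text in one sentence.")
    output, nodes = await run_agent(agent, text, ctx=ctx)  # each node is also logged through the agent's logger
    return output or ""


if __name__ == "__main__":
    mcp.run()

Each node the agent processes then shows up both in the local object log and as an info message on the connected MCP client.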
aixtools/agents/print_nodes.py
ADDED
@@ -0,0 +1,54 @@
+"""Utility functions to print nodes and their parts in a readable format."""
+
+from pydantic_ai import CallToolsNode, ModelRequestNode, UserPromptNode
+from pydantic_ai.messages import TextPart, ToolCallPart
+from pydantic_graph.nodes import End
+
+
+def tab(s, prefix: str = "\t|") -> str:
+    """ "Tab a string with a given prefix (default is tab + pipe)."""
+    return prefix + str(s).replace("\n", "\n" + prefix)
+
+
+def part2str(p, prefix: str = "\t"):
+    """Convert a Part to a string representation."""
+    match p:
+        case ToolCallPart():
+            return f"{prefix}Tool: {p.tool_name}, args: {p.args}"
+        case TextPart():
+            return f"{prefix}Text: {tab(p.content)}"
+        case _:
+            return f"{prefix}Part {type(p)}: {p}"
+
+
+def print_parts(parts, prefix: str = ""):
+    """Print a list of Parts with a given prefix."""
+    if len(parts) == 0:
+        print(f"{prefix}No parts")
+        return
+    if len(parts) == 1:
+        print(part2str(parts[0], prefix=prefix))
+        return
+    for p in parts:
+        print(f"{part2str(p, prefix=prefix)}")
+
+
+def print_node(n):
+    """Print a node in a readable format."""
+    match n:
+        case UserPromptNode():
+            print(f"Prompt:\n{tab(n.user_prompt)}")
+        case CallToolsNode():
+            print_parts(n.model_response.parts)
+        case ModelRequestNode():
+            print(f"Model request: ~ {len(str(n))} chars")
+        case End():
+            pass  # print(f"End:\n{tab(n.data.output)}")
+        case _:
+            print(f"{type(n)}: {n}")
+
+
+def print_nodes(nodes):
+    """Print a list of nodes in a readable format."""
+    for n in nodes:
+        print_node(n)
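A short sketch (assuming the import paths above; not part of the diff) of how the nodes returned by run_agent can be rendered with these helpers:

import asyncio

from pydantic_ai import Agent

from aixtools.agents.agent import get_model, run_agent
from aixtools.agents.print_nodes import print_nodes


async def demo() -> None:
    agent = Agent(get_model(), system_prompt="Answer briefly.")
    output, nodes = await run_agent(agent, "What is 2 + 2?")
    print_nodes(nodes)  # prompt, tool calls, text parts and model requests in readable form
    print(f"Final output: {output}")


asyncio.run(demo())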
aixtools/agents/prompt.py
CHANGED
@@ -1,7 +1,7 @@
 """Prompt building utilities for Pydantic AI agent, including file handling and context management."""
 
 import mimetypes
-from pathlib import Path
+from pathlib import Path
 
 from pydantic_ai import BinaryContent
 
@@ -78,7 +78,7 @@ def build_user_input(
     binary_attachments = []
 
     for workspace_path in file_paths:
-        host_path = container_to_host_path(
+        host_path = container_to_host_path(workspace_path, ctx=session_tuple)
         file_size = host_path.stat().st_size
         mime_type, _ = mimetypes.guess_type(host_path)
         mime_type = mime_type or "application/octet-stream"
aixtools/compliance/private_data.py
CHANGED
@@ -88,7 +88,7 @@ class PrivateData:
 
     def _get_private_data_path(self) -> Path:
        """Get the path to the private data file in the workspace."""
-        return get_workspace_path(
+        return get_workspace_path(ctx=self.ctx) / PRIVATE_DATA_FILE
 
    def _has_private_data_file(self) -> bool:
        """Check if the private data file exists in the workspace."""
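The context lines above show the attachment pattern build_user_input follows: resolve the workspace path to a host path, then guess a MIME type with a generic fallback. A rough standalone sketch of that pattern (hypothetical path; the exact BinaryContent construction in build_user_input is not shown in this hunk):

import mimetypes
from pathlib import Path

from pydantic_ai import BinaryContent

host_path = Path("/tmp/workspace/report.pdf")         # hypothetical resolved host path
mime_type, _ = mimetypes.guess_type(host_path)        # e.g. "application/pdf"
mime_type = mime_type or "application/octet-stream"   # generic fallback for unknown types
attachment = BinaryContent(data=host_path.read_bytes(), media_type=mime_type)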
aixtools/evals/__init__.py
ADDED
File without changes
aixtools/evals/discovery.py
ADDED
@@ -0,0 +1,174 @@
+"""
+Dataset discovery functionality for LLM evaluations.
+
+This module handles discovering and loading Dataset objects from eval_*.py files.
+"""
+
+import importlib.util
+import inspect
+import sys
+import traceback
+from pathlib import Path
+from typing import Any
+
+from pydantic_evals.dataset import Dataset
+
+
+def find_eval_files(evals_dir: Path) -> list[Path]:
+    """Find all eval_*.py files in the evals directory."""
+    if not evals_dir.exists():
+        print(f"Error: Evals directory '{evals_dir}' does not exist")
+        sys.exit(1)
+
+    eval_files = list(evals_dir.glob("eval_*.py"))
+    if not eval_files:
+        print(f"No eval_*.py files found in '{evals_dir}'")
+        sys.exit(1)
+
+    return eval_files
+
+
+def find_datasets_in_module(module: Any) -> list[tuple[str, Dataset]]:
+    """Find all Dataset objects with names matching dataset_* in a module."""
+    datasets = []
+
+    for name, obj in inspect.getmembers(module):
+        if name.startswith("dataset_") and isinstance(obj, Dataset):
+            datasets.append((name, obj))
+
+    return datasets
+
+
+def load_module_from_file(file_path: Path) -> Any:
+    """Load a Python module from a file path."""
+    module_name = file_path.stem
+    spec = importlib.util.spec_from_file_location(module_name, file_path)
+    if spec is None or spec.loader is None:
+        raise ImportError(f"Could not load module from {file_path}")
+
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    return module
+
+
+def matches_filter(module_name: str, file_name: str, dataset_name: str, name_filter: str | None) -> bool:
+    """Check if the dataset matches the name filter."""
+    if name_filter is None:
+        return True
+
+    # Check if filter matches any of: module name, file name, dataset name, or full qualified name
+    full_name = f"{module_name}.{dataset_name}"
+    return (
+        name_filter in module_name
+        or name_filter in file_name
+        or name_filter in dataset_name
+        or name_filter in full_name
+    )
+
+
+def find_target_function(module: Any) -> Any | None:
+    """Find the first async function in a module that doesn't start with underscore."""
+    for name, obj in inspect.getmembers(module):
+        if inspect.iscoroutinefunction(obj) and not name.startswith("_"):
+            return obj
+    return None
+
+
+def get_async_function_names(module: Any) -> list[str]:
+    """Get names of all async functions in a module that don't start with underscore."""
+    return [
+        name
+        for name, obj in inspect.getmembers(module)
+        if inspect.iscoroutinefunction(obj) and not name.startswith("_")
+    ]
+
+
+def process_datasets_from_module(
+    module: Any, eval_file: Path, name_filter: str | None, verbose: bool
+) -> list[tuple[str, Dataset, Any]]:
+    """Process all datasets from a single module and return valid dataset tuples."""
+    datasets = find_datasets_in_module(module)
+    if verbose:
+        print(f" Found {len(datasets)} datasets: {[name for name, _ in datasets]}")
+
+    valid_datasets = []
+
+    for dataset_name, dataset in datasets:
+        full_name = f"{eval_file.stem}.{dataset_name}"
+
+        if not matches_filter(module.__name__, eval_file.stem, dataset_name, name_filter):
+            if verbose:
+                print(f" ✗ Skipping dataset: {dataset_name} (doesn't match filter: {name_filter})")
+            continue
+
+        if verbose:
+            print(f" ✓ Including dataset: {dataset_name}")
+
+        # Find the target function
+        target_function = find_target_function(module)
+        async_functions = get_async_function_names(module)
+
+        if verbose:
+            print(f" Found async functions: {async_functions}")
+            if target_function:
+                print(f" Using target function: {target_function.__name__}")
+
+        if target_function is None:
+            if verbose:
+                print(f"Warning: No async function found in {eval_file.name} for dataset {dataset_name}")
+            continue
+
+        valid_datasets.append((full_name, dataset, target_function))
+
+    return valid_datasets
+
+
+def discover_all_datasets(
+    eval_files: list[Path], name_filter: str | None, verbose: bool
+) -> list[tuple[str, Dataset, Any]]:
+    """Discover all datasets from eval files."""
+    all_datasets = []
+
+    for eval_file in eval_files:
+        if verbose:
+            print(f"\nProcessing file: {eval_file}")
+
+        try:
+            module = load_module_from_file(eval_file)
+            if verbose:
+                print(f" Loaded module: {module.__name__}")
+
+            datasets = process_datasets_from_module(module, eval_file, name_filter, verbose)
+            all_datasets.extend(datasets)
+
+        except Exception as e:  # pylint: disable=W0718
+            if verbose:
+                print(f"Error loading {eval_file}: {e}")
+                print(f" Traceback: {traceback.format_exc()}")
+            continue
+
+    # Check if any datasets were found
+    if not all_datasets:
+        print("No datasets found to evaluate")
+        if verbose:
+            print("This could be because:")
+            print(" - No eval_*.py files contain dataset_* objects")
+            print(" - The filter excluded all datasets")
+            print(" - There were errors loading the modules")
+        sys.exit(1)
+
+    # Print summary of discovered datasets
+    if verbose:
+        print(f"\n{'=' * 60}")
+        print("Datasets to Evaluate:")
+        print(f"{'=' * 60}")
+        for i, (dataset_name, dataset, target_function) in enumerate(all_datasets, 1):
+            print(f"{i}. {dataset_name}")
+            print(f" Target function: {target_function.__name__}")
+            print(f" Cases: {len(dataset.cases)}")
+            print(f" Evaluators: {len(dataset.evaluators)}")
+        print(f"{'=' * 60}")
+    else:
+        print(f"Found {len(all_datasets)} datasets to evaluate")
+
+    return all_datasets
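The discovery logic collects module-level dataset_* objects and uses the first public async function in the file as the evaluation target. Below is a hypothetical evals/eval_math.py (not part of the package) that discover_all_datasets would pick up; the file name, dataset name, and evaluator choice are illustrative only:

from pydantic_evals import Case, Dataset
from pydantic_evals.evaluators import EqualsExpected


async def answer_question(question: str) -> str:
    """Target function: normally this would call an agent; kept trivial for the sketch."""
    return "4" if question == "What is 2 + 2?" else "unknown"


dataset_math = Dataset(
    cases=[Case(name="addition", inputs="What is 2 + 2?", expected_output="4")],
    evaluators=[EqualsExpected()],
)

Because the dataset name starts with dataset_ and answer_question is the only public coroutine, the pair would be returned as ("eval_math.dataset_math", dataset_math, answer_question).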
aixtools/evals/evals.py
ADDED
@@ -0,0 +1,74 @@
+#!/usr/bin/env python3
+"""
+Script to run all LLM evaluations.
+
+This script discovers and runs all Dataset objects from eval_*.py files in the evals directory.
+Similar to test runners but for LLM evaluations using pydantic_evals.
+"""
+
+import argparse
+import asyncio
+import sys
+from pathlib import Path
+
+from .discovery import discover_all_datasets, find_eval_files
+from .run_evals import run_all_evaluations_and_print_results
+
+
+async def main():
+    """Main function to discover and run all evaluations."""
+    parser = argparse.ArgumentParser(description="Run LLM evaluations")
+    parser.add_argument(
+        "--evals-dir", type=Path, default=Path("evals"), help="Directory containing eval_*.py files (default: evals)"
+    )
+    parser.add_argument(
+        "--filter", type=str, help="Filter to run only matching evaluations (matches module, file, or dataset names)"
+    )
+    parser.add_argument("--include-input", action="store_true", help="Include input in report output")
+    parser.add_argument("--include-output", action="store_true", help="Include output in report output")
+    parser.add_argument(
+        "--include-evaluator-failures", action="store_true", help="Include evaluator failures in report output"
+    )
+    parser.add_argument("--include-reasons", action="store_true", help="Include reasons in report output")
+    parser.add_argument(
+        "--min-assertions",
+        type=float,
+        default=1.0,
+        help="Minimum assertions average required for success (default: 1.0)",
+    )
+    parser.add_argument(
+        "--verbose", action="store_true", help="Print detailed information about discovery and processing"
+    )
+
+    args = parser.parse_args()
+
+    # Prepare print options
+    print_options = {
+        "include_input": args.include_input,
+        "include_output": args.include_output,
+        "include_evaluator_failures": args.include_evaluator_failures,
+        "include_reasons": args.include_reasons,
+    }
+
+    # Find all eval files
+    eval_files = find_eval_files(args.evals_dir)
+    if args.verbose:
+        print(f"Scanning directory: {args.evals_dir}")
+        print(f"Found {len(eval_files)} eval files:")
+        for f in eval_files:
+            print(f" - {f}")
+
+    # Discover all datasets
+    all_datasets = discover_all_datasets(eval_files, args.filter, args.verbose)
+
+    if args.filter and not args.verbose:
+        print(f"Filter applied: {args.filter}")
+
+    # Run all evaluations and print results
+    await run_all_evaluations_and_print_results(all_datasets, print_options, args.min_assertions, args.verbose)
+
+
+if __name__ == "__main__":
+    # Add the current directory to Python path so we can import modules
+    sys.path.insert(0, str(Path.cwd()))
+    asyncio.run(main())
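The same pipeline can also be driven programmatically instead of through the argparse entry point. A minimal sketch (not part of the diff) using the functions defined above, with import paths assumed from the package layout:

import asyncio
from pathlib import Path

from aixtools.evals.discovery import discover_all_datasets, find_eval_files
from aixtools.evals.run_evals import run_all_evaluations_and_print_results


async def run_my_evals() -> None:
    eval_files = find_eval_files(Path("evals"))               # all evals/eval_*.py files
    datasets = discover_all_datasets(eval_files, None, True)  # no name filter, verbose output
    await run_all_evaluations_and_print_results(
        datasets,
        {
            "include_input": False,
            "include_output": False,
            "include_evaluator_failures": False,
            "include_reasons": True,
        },
        min_assertions=1.0,  # require a perfect assertions average
        verbose=True,
    )


asyncio.run(run_my_evals())

Note that find_eval_files, discover_all_datasets, and the results printer all call sys.exit on failure, so this is best run as a standalone script; the new entry_points.txt presumably wires up a console command for the same flow.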
aixtools/evals/run_evals.py
ADDED
@@ -0,0 +1,110 @@
+"""
+Evaluation execution functionality for LLM evaluations.
+
+This module handles running evaluations and printing results.
+"""
+
+import sys
+from typing import Any
+
+from pydantic_evals.dataset import Dataset
+
+
+async def run_dataset_evaluation( # noqa: PLR0913, pylint: disable=too-many-arguments,too-many-positional-arguments
+    dataset_name: str,
+    dataset: Dataset,
+    target_function: Any,
+    print_options: dict[str, bool],
+    min_assertions: float,
+    verbose: bool = False,
+) -> tuple[str, bool]:
+    """Run evaluation for a single dataset and return (name, success)."""
+    if verbose:
+        print(f"\n{'=' * 60}")
+        print(f"Running evaluation: {dataset_name}")
+        print(f"{'=' * 60}")
+    else:
+        print(f"Running {dataset_name}...", end=" ")
+
+    try:
+        # Execute the evaluation
+        report = await dataset.evaluate(target_function)
+
+        # Print the results
+        report.print(
+            include_input=print_options["include_input"],
+            include_output=print_options["include_output"],
+            include_evaluator_failures=print_options["include_evaluator_failures"],
+            include_reasons=print_options["include_reasons"],
+        )
+
+        # Check if evaluation passed based on assertions average
+        averages = report.averages()
+        if averages and averages.assertions is not None:
+            success = averages.assertions >= min_assertions
+            if verbose:
+                print(f"\nEvaluation Summary for {dataset_name}:")
+                print(f" Assertions Average: {averages.assertions:.3f}")
+                print(f" Minimum Required: {min_assertions:.3f}")
+                print(f" Status: {'PASSED' if success else 'FAILED'}")
+            else:
+                print(f"{'PASSED' if success else 'FAILED'} ({averages.assertions:.3f})")
+        else:
+            success = False
+            if verbose:
+                print(f"\nEvaluation Summary for {dataset_name}:")
+                print(" No assertions found or evaluation failed")
+                print(f" Minimum Required: {min_assertions:.3f}")
+                print(" Status: FAILED")
+            else:
+                print("FAILED (no assertions)")
+
+        return dataset_name, success
+
+    except Exception as e:  # pylint: disable=broad-exception-caught
+        if verbose:
+            print(f"Error running evaluation {dataset_name}: {e}")
+        else:
+            print(f"ERROR ({e})")
+        return dataset_name, False
+
+
+async def run_all_evaluations_and_print_results(
+    datasets: list[tuple[str, Dataset, Any]], print_options: dict[str, bool], min_assertions: float, verbose: bool
+) -> None:
+    """Run all evaluations and print results with summary."""
+    # Run all evaluations
+    results = []
+    for dataset_name, dataset, target_function in datasets:
+        result = await run_dataset_evaluation(
+            dataset_name, dataset, target_function, print_options, min_assertions, verbose
+        )
+        results.append(result)
+
+    # Print summary
+    passed = sum(1 for _, success in results if success)
+    total = len(results)
+    failed_results = [(name, success) for name, success in results if not success]
+
+    if verbose:
+        print(f"\n{'=' * 60}")
+        print("EVALUATION SUMMARY")
+        print(f"{'=' * 60}")
+
+        for name, success in results:
+            status = "PASSED" if success else "FAILED"
+            print(f" {name}: {status}")
+
+        print(f"\nTotal: {passed}/{total} evaluations passed")
+    # Only show failed evaluations when not verbose
+    elif failed_results:
+        print("\nFailed evaluations:")
+        for name, _ in failed_results:
+            print(f" {name}: FAILED")
+
+    # Exit with non-zero code if any evaluations failed
+    if passed < total:
+        print(f"\n{total - passed} evaluation(s) failed")
+        sys.exit(1)
+    else:
+        print("\nAll evaluations passed!")
aixtools/logging/log_objects.py
CHANGED
@@ -114,7 +114,26 @@ def save_objects_to_logfile(objects: list, log_dir=LOGS_DIR):
         object_logger.log(obj)
 
 
-class ObjectLogger:
+class BaseLogger:
+    """
+    Base class for loggers.
+    A context manager for logging objects.
+    """
+
+    def __init__(self, **kwargs):
+        pass
+
+    def __enter__(self):
+        pass
+
+    async def log(self, obj):
+        """Log an object to the configured destination."""
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        pass
+
+
+class ObjectLogger(BaseLogger):
     """
     A context manager for logging objects to a file.
     It uses pickle to save the objects and handles exceptions during the save process.
@@ -161,14 +180,14 @@ class ObjectLogger:
         self.file = open(self.log_file, "ab")  # append in binary mode
         return self
 
-    def log(self, obj):
+    async def log(self, obj):
        """
        Log an object to the file.
        It uses safe_deepcopy to ensure the object is pickleable.
        """
        if self.has_parent():
            # Delegate to the parent logger
-            self.parent_logger.log(obj)
+            await self.parent_logger.log(obj)
        else:
            try:
                if self.debug:
@@ -190,25 +209,7 @@ class ObjectLogger:
             self.file.close()
 
 
-class
-    """
-    A null logger that does nothing.
-    """
-
-    def __init__(self, **kwargs):
-        pass
-
-    def __enter__(self):
-        pass
-
-    def log(self, obj):
-        """Log an object to the configured destination."""
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        pass
-
-
-class PrintObjectLogger:
+class PrintObjectLogger(BaseLogger):
     """
     Print to stdout
     """
@@ -219,7 +220,7 @@ class PrintObjectLogger:
     def __enter__(self):
         pass
 
-    def log(self, obj):
+    async def log(self, obj):
         """Log an object using rich print for formatted output."""
         rich.print(obj, flush=True)
 
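With log() now a coroutine on every logger, call sites such as run_agent must await it. A hypothetical BaseLogger subclass illustrating the new contract (not part of the package; the import path is assumed from the file layout above):

import json

from aixtools.logging.log_objects import BaseLogger


class JsonLineLogger(BaseLogger):
    """Hypothetical logger that appends one JSON line per object instead of pickling."""

    def __init__(self, path: str, **kwargs):
        super().__init__(**kwargs)
        self.path = path

    async def log(self, obj):
        """Write a minimal JSON description of the object; repr() keeps it serializable."""
        with open(self.path, "a", encoding="utf-8") as fh:
            fh.write(json.dumps({"type": type(obj).__name__, "repr": repr(obj)}) + "\n")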