ibm-watsonx-orchestrate-evaluation-framework 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ibm-watsonx-orchestrate-evaluation-framework might be problematic. Click here for more details.
- ibm_watsonx_orchestrate_evaluation_framework-1.0.0.dist-info/METADATA +322 -0
- ibm_watsonx_orchestrate_evaluation_framework-1.0.0.dist-info/RECORD +46 -0
- ibm_watsonx_orchestrate_evaluation_framework-1.0.0.dist-info/WHEEL +5 -0
- ibm_watsonx_orchestrate_evaluation_framework-1.0.0.dist-info/licenses/LICENSE +22 -0
- ibm_watsonx_orchestrate_evaluation_framework-1.0.0.dist-info/top_level.txt +1 -0
- wxo_agentic_evaluation/__init__.py +0 -0
- wxo_agentic_evaluation/analytics/tools/analyzer.py +405 -0
- wxo_agentic_evaluation/analytics/tools/main.py +163 -0
- wxo_agentic_evaluation/analytics/tools/types.py +130 -0
- wxo_agentic_evaluation/analytics/tools/ux.py +428 -0
- wxo_agentic_evaluation/analyze_run.py +123 -0
- wxo_agentic_evaluation/annotate.py +40 -0
- wxo_agentic_evaluation/arg_configs.py +78 -0
- wxo_agentic_evaluation/batch_annotate.py +181 -0
- wxo_agentic_evaluation/data_annotator.py +253 -0
- wxo_agentic_evaluation/evaluation_package.py +518 -0
- wxo_agentic_evaluation/external_agent/external_validate.py +69 -0
- wxo_agentic_evaluation/external_agent/types.py +65 -0
- wxo_agentic_evaluation/inference_backend.py +601 -0
- wxo_agentic_evaluation/llm_matching.py +39 -0
- wxo_agentic_evaluation/llm_rag_eval.py +47 -0
- wxo_agentic_evaluation/llm_user.py +38 -0
- wxo_agentic_evaluation/main.py +231 -0
- wxo_agentic_evaluation/metrics/__init__.py +0 -0
- wxo_agentic_evaluation/metrics/llm_as_judge.py +46 -0
- wxo_agentic_evaluation/metrics/metrics.py +101 -0
- wxo_agentic_evaluation/prompt/__init__.py +0 -0
- wxo_agentic_evaluation/prompt/answer_relevancy_prompt.jinja2 +120 -0
- wxo_agentic_evaluation/prompt/batch_testcase_prompt.jinja2 +51 -0
- wxo_agentic_evaluation/prompt/examples/__init__.py +0 -0
- wxo_agentic_evaluation/prompt/examples/data_simple.json +93 -0
- wxo_agentic_evaluation/prompt/faithfulness_prompt.jinja2 +59 -0
- wxo_agentic_evaluation/prompt/keyword_matching_prompt.jinja2 +75 -0
- wxo_agentic_evaluation/prompt/keywords_generation_prompt.jinja2 +20 -0
- wxo_agentic_evaluation/prompt/llama_user_prompt.jinja2 +22 -0
- wxo_agentic_evaluation/prompt/semantic_matching_prompt.jinja2 +114 -0
- wxo_agentic_evaluation/prompt/template_render.py +90 -0
- wxo_agentic_evaluation/prompt/tool_chain_agent.jinja2 +11 -0
- wxo_agentic_evaluation/prompt/tool_planner.jinja2 +40 -0
- wxo_agentic_evaluation/record_chat.py +165 -0
- wxo_agentic_evaluation/service_instance.py +179 -0
- wxo_agentic_evaluation/tool_planner.py +228 -0
- wxo_agentic_evaluation/type.py +176 -0
- wxo_agentic_evaluation/utils/__init__.py +6 -0
- wxo_agentic_evaluation/utils/utils.py +233 -0
- wxo_agentic_evaluation/watsonx_provider.py +175 -0
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from wxo_agentic_evaluation.type import Message, EvaluationData
|
|
2
|
+
from wxo_agentic_evaluation.arg_configs import TestCaseGenerationConfig
|
|
3
|
+
from wxo_agentic_evaluation.data_annotator import DataAnnotator
|
|
4
|
+
import json
|
|
5
|
+
from pprint import pprint
|
|
6
|
+
from jsonargparse import CLI
|
|
7
|
+
import os
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def main(config: TestCaseGenerationConfig):
    """Annotate a recorded chat log against seed test data and save the result.

    Reads the message log at ``config.log_path`` (a JSON list of messages) and
    the seed ``EvaluationData`` at ``config.seed_data_path``, runs
    ``DataAnnotator`` to derive goals and goal details, writes the dataset to
    ``<config.output_dir>/<seed file stem>_annotated.json`` and pretty-prints it.
    """
    with open(config.log_path, "r", encoding="utf-8") as f:
        messages = [Message.model_validate(entry) for entry in json.load(f)]

    with open(config.seed_data_path, "r", encoding="utf-8") as f:
        evaluation_data = EvaluationData(**json.load(f))

    # Generate annotated dataset
    annotator = DataAnnotator(
        messages=messages,
        keywords_generation_config=config.keywords_generation_config,
        initial_data=evaluation_data,
    )
    dataset = annotator.generate()

    # Save dataset. basename handles OS-specific separators, and splitext only
    # strips the final extension, so "a.b.json" -> "a.b_annotated.json"
    # (splitting on the first "." would have produced "a_annotated.json").
    core_name = os.path.splitext(os.path.basename(config.seed_data_path))[0]
    new_filename = f"{core_name}_annotated.json"

    os.makedirs(config.output_dir, exist_ok=True)
    output_path = os.path.join(config.output_dir, new_filename)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(dataset, f, indent=4)

    pprint(dataset)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
if __name__ == "__main__":
    # Build a TestCaseGenerationConfig from command-line flags, then run.
    parsed_config = CLI(TestCaseGenerationConfig, as_positional=False)
    main(parsed_config)
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import List, Optional
|
|
4
|
+
from wxo_agentic_evaluation import __file__
|
|
5
|
+
|
|
6
|
+
# Directory containing the package's __init__ module; the bundled prompt
# templates live in its "prompt" sub-directory.
root_dir = os.path.dirname(__file__)
_PROMPT_DIR = os.path.join(root_dir, "prompt")

LLAMA_USER_PROMPT_PATH = os.path.join(_PROMPT_DIR, "llama_user_prompt.jinja2")
KEYWORDS_GENERATION_PROMPT_PATH = os.path.join(
    _PROMPT_DIR, "keywords_generation_prompt.jinja2"
)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class AuthConfig:
    """Service URL plus tenant/token credentials.

    Attributes:
        url: Service URL (required).
        tenant_name: Tenant name; defaults to ``"local"``.
        token: Optional token; ``None`` when not supplied.
    """

    url: str
    tenant_name: str = "local"
    # Declared Optional so the None default matches the annotation that
    # type-driven parsers (jsonargparse CLI) and type checkers read.
    token: Optional[str] = None
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class LLMUserConfig:
    """Settings for the LLM-driven user: model id, prompt template path and
    optional response-style hints (empty list by default)."""

    model_id: str = "meta-llama/llama-3-405b-instruct"
    prompt_config: str = LLAMA_USER_PROMPT_PATH
    # Mutable default needs a factory so instances do not share one list.
    user_response_style: List[str] = field(default_factory=list)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class TestConfig:
    """Options for an evaluation run.

    ``test_paths``, ``output_dir``, ``auth_config`` and ``wxo_lite_version``
    are required; every other option has a default.
    """

    test_paths: List[str]
    output_dir: str
    auth_config: AuthConfig
    wxo_lite_version: str
    llm_user_config: LLMUserConfig = field(default_factory=LLMUserConfig)
    enable_verbose_logging: bool = field(default=True)
    enable_manual_user_input: bool = field(default=False)
    skip_available_results: bool = field(default=False)
    data_annotation_run: bool = field(default=False)
    num_workers: int = field(default=2)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass
class AnalyzeConfig:
    """Options for analysing a run; only the data location is needed."""

    data_path: str
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
class KeywordsGenerationConfig:
    """Model id and prompt template path used for keyword generation."""

    model_id: str = "meta-llama/llama-3-405b-instruct"
    prompt_config: str = KEYWORDS_GENERATION_PROMPT_PATH
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@dataclass
class TestCaseGenerationConfig:
    """Options for annotating a chat log into a test case.

    Requires the chat-log path, the seed-data path and an output directory.
    """

    log_path: str
    seed_data_path: str
    output_dir: str
    keywords_generation_config: KeywordsGenerationConfig = field(default_factory=KeywordsGenerationConfig)
    enable_verbose_logging: bool = field(default=True)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@dataclass
class ChatRecordingConfig:
    """Options for chat recording.

    Covers the output location, keyword-generation settings and service
    connection details (URL, tenant, optional token).
    """

    output_dir: str
    keywords_generation_config: KeywordsGenerationConfig = field(
        default_factory=KeywordsGenerationConfig
    )
    service_url: str = "http://localhost:4321"
    tenant_name: str = "wxo-dev"
    # Declared Optional so the None default matches the annotation read by
    # type-driven parsers (jsonargparse CLI) and type checkers.
    token: Optional[str] = None
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@dataclass
class BatchAnnotateConfig:
    """Options for batch synthetic test-case generation.

    Covers the allowed tool names, the tools source path, the stories CSV
    path, the output directory and how many variants to request per story.
    """

    allowed_tools: List[str]
    tools_path: str
    stories_path: str
    output_dir: str
    num_variants: int = field(default=2)
|
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import ast
|
|
3
|
+
import csv
|
|
4
|
+
import os
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from jsonargparse import CLI
|
|
7
|
+
|
|
8
|
+
from wxo_agentic_evaluation.watsonx_provider import WatsonXProvider
|
|
9
|
+
from wxo_agentic_evaluation.prompt.template_render import BatchTestCaseGeneratorTemplateRenderer
|
|
10
|
+
from wxo_agentic_evaluation.arg_configs import BatchAnnotateConfig
|
|
11
|
+
from wxo_agentic_evaluation import __file__
|
|
12
|
+
|
|
13
|
+
# Directory containing the package's __init__ module; prompt templates and
# the example test case ship under its "prompt" sub-directory.
root_dir = os.path.dirname(__file__)
_PROMPT_DIR = os.path.join(root_dir, "prompt")

BATCH_TEST_CASE_GENERATOR_PROMPT_PATH = os.path.join(_PROMPT_DIR, "batch_testcase_prompt.jinja2")
EXAMPLE_PATH = os.path.join(_PROMPT_DIR, "examples", "data_simple.json")
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def parse_tools_with_filter(
    agent_name: str, tools_path: Path, allowed_tool_names: list[str]
) -> tuple[dict, list[dict]]:
    """Collect module-level tool functions from Python source and keep the allowed ones.

    Args:
        agent_name: Name placed in the returned agent descriptor.
        tools_path: A ``.py`` file, or a directory searched recursively for
            ``*.py`` files. A ``str`` path is also accepted.
        allowed_tool_names: Names of the functions to keep; must be non-empty.

    Returns:
        ``({"name": agent_name}, tools)``. Each tool is a dict with these keys:
        ``"Function Name"``, ``"Arguments"`` (positional-only, regular and
        keyword-only parameter names) and ``"Docstring"`` (``None`` if absent).
        Tools appear in file order, then definition order.

    Raises:
        ValueError: if ``allowed_tool_names`` is empty or ``tools_path`` is
            neither a file nor a directory.
    """
    if not allowed_tool_names:
        raise ValueError("Allowed tool list cannot be empty.")

    tools_path = Path(tools_path)

    # Handle both single file and directory cases
    if tools_path.is_file():
        files_to_parse = [tools_path]
    elif tools_path.is_dir():
        # Sorted so the result does not depend on filesystem iteration order.
        files_to_parse = sorted(tools_path.glob("**/*.py"))
    else:
        raise ValueError(f"Tools path {tools_path} is neither a file nor directory")

    tool_data = []
    for file_path in files_to_parse:
        try:
            parsed_code = ast.parse(file_path.read_text(encoding="utf-8"))
        except (OSError, SyntaxError, ValueError) as e:
            # Best effort: one unreadable or invalid file must not abort the scan.
            # (ValueError also covers UnicodeDecodeError.)
            print(f"Warning: Failed to parse {file_path}: {str(e)}")
            continue

        # Process only module-level functions, sync and async alike.
        for node in parsed_code.body:
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                params = node.args.posonlyargs + node.args.args + node.args.kwonlyargs
                tool_data.append({
                    "Function Name": node.name,
                    "Arguments": [arg.arg for arg in params],
                    "Docstring": ast.get_docstring(node),
                })

    # Filter tools based on allowed names
    allowed = set(allowed_tool_names)
    filtered_tools = [tool for tool in tool_data if tool["Function Name"] in allowed]

    if not filtered_tools:
        print(f"Warning: No matching tools found. Available tools: {[t['Function Name'] for t in tool_data]}")

    return {"name": agent_name}, filtered_tools
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
# Step 2: Extract tool input/output examples from snapshot
|
|
64
|
+
def extract_inputs_from_snapshot(snapshot_path: Path) -> dict:
    """Return the ``input_output_examples`` entry of a JSON snapshot file.

    An empty dict is returned when the snapshot has no such key.
    """
    snapshot = json.loads(snapshot_path.read_text(encoding="utf-8"))
    examples = snapshot.get("input_output_examples", {})
    return examples
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
# Step 3: Load a single example test case just for structure
|
|
71
|
+
def load_example(example_path: Path):
    """Load an example test case from JSON and drop its ``mine_fields`` key, if any."""
    example = json.loads(example_path.read_text(encoding="utf-8"))
    example.pop("mine_fields", None)
    return example
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
# Step 4: Prompt builder for N test cases from a given story
|
|
79
|
+
def build_prompt_for_story(agent, tools, tool_inputs, example_case: dict, story: str, num_variants: int = 2):
    """Render the batch test-case generation prompt for a single story.

    ``tools`` are the dicts produced by ``parse_tools_with_filter``; each is
    described as a "- Tool / Description / Args" text block. ``tool_inputs``
    and ``example_case`` are embedded as indented JSON.
    """
    def _describe(tool):
        arg_list = ', '.join(tool['Arguments']) or 'None'
        return f"- Tool: {tool['Function Name']}\n Description: {tool['Docstring']}\n Args: {arg_list}"

    template = BatchTestCaseGeneratorTemplateRenderer(BATCH_TEST_CASE_GENERATOR_PROMPT_PATH)
    return template.render(
        agent_name=agent["name"],
        tool_blocks="\n".join(_describe(tool) for tool in tools),
        tool_inputs_str=json.dumps(tool_inputs, indent=2),
        story=story,
        num_variants=num_variants,
        example_str=json.dumps(example_case, indent=2),
    )
|
|
96
|
+
|
|
97
|
+
# Step 5: Send prompt to LLM and save test cases
|
|
98
|
+
def generate_multiple_in_one(prompt, output_dir, starting_index, model_id="meta-llama/llama-3-405b-instruct"):
    """Query the LLM once and write every returned test case to its own file.

    Args:
        prompt: Fully rendered generation prompt.
        output_dir: Destination directory (``Path`` or ``str``); created if needed.
        starting_index: Index of the first file,
            ``synthetic_test_case_<starting_index>.json``; later cases count up.
        model_id: WatsonX model identifier.

    Returns:
        The number of test cases written (0 when the output could not be
        parsed). Failures are reported on stdout rather than raised.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    provider = WatsonXProvider(
        model_id=model_id,
        llm_decode_parameter={
            "min_new_tokens": 50,
            "decoding_method": "greedy",
            "max_new_tokens": 3000,
        },
    )

    response = provider.query(prompt)

    # Bound before the try so the error report can always print it (a failing
    # response.get used to leave it unbound and raise NameError in the handler).
    raw_text = ""
    written = 0
    try:
        raw_text = response.get("generated_text", "")
        # The model may wrap the JSON array in prose; keep the outermost [...] span.
        json_start = raw_text.find("[")
        json_end = raw_text.rfind("]") + 1
        if json_start == -1 or json_end <= json_start:
            raise ValueError("No JSON array found in model output")
        test_cases = json.loads(raw_text[json_start:json_end].strip())

        # Explicit checks rather than assert (asserts are stripped under -O).
        # Validate everything before writing so a bad case leaves no partial output.
        if not isinstance(test_cases, list):
            raise ValueError("Expected list of test cases")
        if not all(isinstance(case, dict) for case in test_cases):
            raise ValueError("Expected every test case to be a JSON object")

        for i, case in enumerate(test_cases, start=starting_index):
            case["mine_fields"] = []
            out_file = output_dir / f"synthetic_test_case_{i}.json"
            with out_file.open("w", encoding="utf-8") as f:
                json.dump(case, f, indent=2)
            written += 1
            print(f"✅ Test case {i} written to {out_file}")

    except Exception as e:
        # Best effort per story: report the problem and let the caller continue.
        print("⚠️ Failed to parse or validate test case output.")
        print("Raw text:\n", raw_text)
        print("Error:", str(e))

    return written
|
|
132
|
+
|
|
133
|
+
def generate_test_cases_from_stories(agent_name: str, stories: list[str], tools_path: Path, snapshot_path: Path, output_dir: Path, allowed_tools: list[str], num_variants: int = 2):
    """Generate ``num_variants`` synthetic test cases for each story.

    Tools, snapshot examples and the example case are loaded once. Each story
    then gets one prompt and one LLM call. Story k (1-based) writes files
    starting at index ``1 + (k - 1) * num_variants``.
    """
    agent, tools = parse_tools_with_filter(agent_name, tools_path, allowed_tools)
    tool_inputs = extract_inputs_from_snapshot(snapshot_path)
    example_json = load_example(Path(EXAMPLE_PATH))

    for story_no, story in enumerate(stories, start=1):
        print(f"\n Generating test cases for story {story_no}: {story}")
        first_index = 1 + (story_no - 1) * num_variants
        generate_multiple_in_one(
            prompt=build_prompt_for_story(
                agent, tools, tool_inputs, example_json, story, num_variants=num_variants
            ),
            output_dir=output_dir,
            starting_index=first_index,
        )
|
|
153
|
+
|
|
154
|
+
def main(config: BatchAnnotateConfig):
    """Generate synthetic test cases for every story in a CSV file.

    The CSV at ``config.stories_path`` is read with ``csv.DictReader``. Each
    row supplies a ``story``, and the ``agent`` of the first row names the
    agent. The snapshot is read from
    ``<stories dir>/<agent>_snapshot_llm.json``. Cases are written under
    ``<config.output_dir>/<agent>_test_cases``.

    Raises:
        ValueError: if the CSV contains no story rows.
    """
    stories_path = Path(config.stories_path)

    stories = []
    agent_name = None
    with stories_path.open("r", encoding="utf-8", newline='') as f:
        for row in csv.DictReader(f):
            stories.append(row["story"])
            if agent_name is None:
                agent_name = row["agent"]

    # With no rows agent_name would stay None and paths such as
    # "None_snapshot_llm.json" would be built; fail early and clearly instead.
    if not stories:
        raise ValueError(f"No stories found in {stories_path}")

    tools_path = Path(config.tools_path)
    snapshot_path = stories_path.parent / f"{agent_name}_snapshot_llm.json"
    output_dir = Path(config.output_dir) / f"{agent_name}_test_cases"

    generate_test_cases_from_stories(
        agent_name,
        stories,
        tools_path,
        snapshot_path,
        output_dir,
        config.allowed_tools,
        num_variants=config.num_variants,
    )
|
|
179
|
+
|
|
180
|
+
if __name__ == "__main__":
    # Build a BatchAnnotateConfig from command-line flags, then run.
    parsed_config = CLI(BatchAnnotateConfig, as_positional=False)
    main(parsed_config)
|
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
from wxo_agentic_evaluation.type import Message, EvaluationData
|
|
2
|
+
from wxo_agentic_evaluation.watsonx_provider import WatsonXProvider
|
|
3
|
+
from wxo_agentic_evaluation.prompt.template_render import (
|
|
4
|
+
LlamaKeywordsGenerationTemplateRenderer,
|
|
5
|
+
)
|
|
6
|
+
from wxo_agentic_evaluation.arg_configs import KeywordsGenerationConfig
|
|
7
|
+
|
|
8
|
+
import ast
|
|
9
|
+
import json
|
|
10
|
+
import collections
|
|
11
|
+
from typing import Dict, List, Optional
|
|
12
|
+
|
|
13
|
+
# Lower-case fragments that DataAnnotator._is_error_in_message looks for in a
# lower-cased tool response to decide that the tool call failed.
# NOTE(review): matching is a plain substring test, so it is broad. For
# example, "oom" also matches "room"/"zoom", and "error" matches every
# "...error" name.
ERROR_KEYWORDS = [
    # generic failure wording
    "error", "erroneous", "exception", "traceback", "failed", "fail",
    "fatal", "panic", "abort", "not found", "notfound", "cannot", "can't",
    "unable", "unsuccessful", "invalid", "incorrect", "illegal", "unknown",
    "unexpected",
    # access, availability and limits
    "unauthorized", "permission denied", "denied", "forbidden",
    "forbidden request", "unavailable", "unreachable", "missing",
    "exceeded", "exceeds limit", "timed out", "timeout",
    # tracebacks and Python exception names
    "stack trace", "syntax error", "runtime error", "indexerror",
    "keyerror", "valueerror", "typeerror", "zerodivisionerror",
    # process / memory faults
    "segmentation fault", "segfault", "core dumped", "memory error",
    "out of memory", "oom", "overflow", "underflow", "crash",
    # HTTP failures
    "bad request",
    "http_code=400", "http_code=401", "http_code=403", "http_code=404",
    "http_code=405", "http_code=408", "http_code=409", "http_code=429",
    "http_code=500", "http_code=503", "http_code=504",
    # connectivity, I/O and backend failures
    "connection refused", "connection error", "broken pipe", "bus error",
    "catastrophic failure", "unresolved", "infinite recursion", "overrun",
    "overwrite", "no such file or directory", "invalid argument",
    "server is down", "server error", "sql error", "db error",
    "database error",
]
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class KeywordsGenerationLLM:
    """Generate keywords for a response by prompting an LLM through WatsonXProvider."""

    def __init__(
        self,
        wai_client: WatsonXProvider,
        template: LlamaKeywordsGenerationTemplateRenderer,
    ):
        self.wai_client = wai_client
        self.prompt_template = template

    def genereate_keywords(self, response):
        """Render the prompt for ``response``, query the model and parse its reply.

        The reply is parsed with ``ast.literal_eval``, which accepts Python
        literals only and executes no code. If the whole reply is not a
        literal, the outermost ``[...]`` span is tried instead, to tolerate
        prose around the list.

        Returns:
            The parsed Python literal. It is presumably a list of keyword
            strings — TODO confirm against the keywords prompt template.

        Raises:
            ValueError: if no literal can be parsed from the model output.
        """
        prompt = self.prompt_template.render(response=response)
        res = self.wai_client.query(prompt)
        raw_text = res["generated_text"].strip()
        try:
            return ast.literal_eval(raw_text)
        except (ValueError, SyntaxError, TypeError):
            start, end = raw_text.find("["), raw_text.rfind("]")
            if start != -1 and end > start:
                try:
                    return ast.literal_eval(raw_text[start : end + 1])
                except (ValueError, SyntaxError, TypeError):
                    pass
            raise ValueError(
                f"Could not parse keywords from model output: {raw_text!r}"
            ) from None

    # Correctly spelled alias; the misspelled name is kept for existing callers.
    generate_keywords = genereate_keywords
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class DataAnnotator:
    """Turn a recorded conversation into goals and goal_details for a dataset.

    Successful tool calls become a chain of ``tool_call`` goals named
    ``<tool>-<n>``. The last assistant message becomes a ``summarize`` text
    goal, whose keywords are generated by an LLM.
    """

    def __init__(
        self,
        messages: List[Message],
        keywords_generation_config: KeywordsGenerationConfig,
        initial_data: Optional[EvaluationData] = None,
    ):
        """Store the inputs.

        Without ``initial_data``, a blank ``EvaluationData`` is built whose
        ``starting_sentence`` is the first message's content ("" if none).
        """
        self.messages = messages
        self.keywords_generation_config = keywords_generation_config
        self.initial_data = initial_data or EvaluationData(
            agent="",
            story="",
            starting_sentence=messages[0].content if messages else "",
            mine_fields=[],
            goals={},
            goal_details=[],
        )

    @staticmethod
    def _is_error_in_message(message: str) -> bool:
        """Heuristic to catch tool calls that fail.

        Returns True if any ``ERROR_KEYWORDS`` entry is a substring of the
        lower-cased message.
        """
        message = message.lower()
        return any(keyword in message for keyword in ERROR_KEYWORDS)

    def _get_failed_tool_responses(self) -> list[str]:
        """Get the ``tool_call_id`` of every tool_response that looks like an error.

        The response content is parsed as JSON. Responses without a
        ``tool_call_id`` key are skipped rather than raising KeyError.
        """
        wrong_tool_response_id = []
        for message in self.messages:
            if message.type == "tool_response" and self._is_error_in_message(message.content):
                tool_call_id = json.loads(message.content).get("tool_call_id")
                if tool_call_id is not None:
                    wrong_tool_response_id.append(tool_call_id)
        return wrong_tool_response_id

    def _process_tool_call_order(self, wrong_tool_response_id: list[str]) -> list[str]:
        """Return canonical signatures of the successful tool calls, in order.

        A signature is the tool-call JSON with ``tool_call_id``/``id`` removed,
        re-serialised with sorted keys. Calls whose id is listed in
        ``wrong_tool_response_id`` are skipped. A repeated signature keeps only
        its most recent position.
        """
        failed_ids = set(wrong_tool_response_id)
        # A dict used as an ordered set: pop + re-insert moves a key to the end
        # in O(1), replacing the former list.index/slice dance (O(n^2) overall).
        order: Dict[str, None] = {}
        for message in self.messages:
            if message.type != "tool_call":
                continue
            content = json.loads(message.content)
            # skip all the tool calls that fail
            if (
                content.get("tool_call_id", "") in failed_ids
                or content.get("id", "") in failed_ids
            ):
                continue

            content.pop("tool_call_id", None)
            content.pop("id", None)

            # for a given tool call signature (function name + args) keep only the most recent one
            signature = json.dumps(content, sort_keys=True)
            order.pop(signature, None)
            order[signature] = None
        return list(order)

    def _process_tool_calls(self) -> tuple[Dict, List, Optional[str]]:
        """Build the tool-call part of the goal graph.

        Returns:
            ``(goals, goal_details, previous)``:
            - ``goals`` maps each goal except the last to a one-element list
              holding the next goal's name.
            - ``goal_details`` holds one ``tool_call`` entry per call.
            - ``previous`` is the last goal name, or ``None`` when there were
              no successful tool calls.
        """
        wrong_tool_response_id = self._get_failed_tool_responses()
        order = self._process_tool_call_order(wrong_tool_response_id)

        goals = {}
        goal_details = []
        function_count = collections.defaultdict(int)
        previous = None

        for tool_call in order:
            call = json.loads(tool_call)
            funct_name = call["name"]
            function_count[funct_name] += 1
            # Repeated tools get numbered goals: "<tool>-1", "<tool>-2", ...
            goal_name = funct_name + f"-{function_count[funct_name]}"

            if previous:
                goals[previous] = [goal_name]

            goal_details.append(
                {
                    "type": "tool_call",
                    "name": goal_name,
                    "tool_name": funct_name,
                    "args": call["args"],
                }
            )
            previous = goal_name

        return goals, goal_details, previous

    def _process_summarization(
        self, previous: Optional[str], goals: Dict, goal_details: List
    ) -> None:
        """Append a ``summarize`` goal built from the last assistant message.

        Mutates ``goals`` and ``goal_details`` in place. The last tool-call
        goal (``previous``) points at ``["summarize"]`` when a summary was
        found, otherwise at ``[]``. When there were no tool calls
        (``previous`` is None), ``summarize`` is recorded as a stand-alone goal.
        Previously the entry went under a ``None`` key, which ``json.dump``
        writes out as ``"null"``.
        """
        summarize_step = None
        # we assume single summary step at the end
        for message in reversed(self.messages):
            if message.role == "assistant":
                wai_client = WatsonXProvider(
                    model_id=self.keywords_generation_config.model_id,
                    llm_decode_parameter={
                        "min_new_tokens": 0,
                        "decoding_method": "greedy",
                        "max_new_tokens": 256,
                    },
                )
                kw_generator = KeywordsGenerationLLM(
                    wai_client=wai_client,
                    template=LlamaKeywordsGenerationTemplateRenderer(
                        self.keywords_generation_config.prompt_config
                    ),
                )
                keywords = kw_generator.genereate_keywords(message.content)
                summarize_step = {
                    "name": "summarize",
                    "type": "text",
                    "response": message.content,
                    "keywords": keywords,
                }
                goal_details.append(summarize_step)
                break

        if previous is None:
            if summarize_step:
                goals["summarize"] = []
        elif summarize_step:
            goals[previous] = ["summarize"]
        else:
            goals[previous] = []

    def generate(self) -> Dict:
        """Generate the final dataset dict (agent, goals, goal_details, mine_fields, story, starting_sentence)."""
        goals, goal_details, previous = self._process_tool_calls()
        self._process_summarization(previous, goals, goal_details)

        return {
            "agent": self.initial_data.agent,
            "goals": goals,
            "goal_details": goal_details,
            "mine_fields": [],
            "story": self.initial_data.story,
            "starting_sentence": self.initial_data.starting_sentence,
        }
|