ibm-watsonx-orchestrate-evaluation-framework 1.0.2__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ibm_watsonx_orchestrate_evaluation_framework-1.0.2.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.0.3.dist-info}/METADATA +70 -7
- ibm_watsonx_orchestrate_evaluation_framework-1.0.3.dist-info/RECORD +56 -0
- wxo_agentic_evaluation/analytics/tools/analyzer.py +3 -3
- wxo_agentic_evaluation/analytics/tools/ux.py +1 -1
- wxo_agentic_evaluation/analyze_run.py +10 -10
- wxo_agentic_evaluation/arg_configs.py +8 -1
- wxo_agentic_evaluation/batch_annotate.py +3 -9
- wxo_agentic_evaluation/data_annotator.py +50 -36
- wxo_agentic_evaluation/evaluation_package.py +102 -85
- wxo_agentic_evaluation/external_agent/__init__.py +37 -0
- wxo_agentic_evaluation/external_agent/external_validate.py +74 -29
- wxo_agentic_evaluation/external_agent/performance_test.py +66 -0
- wxo_agentic_evaluation/external_agent/types.py +8 -2
- wxo_agentic_evaluation/inference_backend.py +45 -50
- wxo_agentic_evaluation/llm_matching.py +6 -6
- wxo_agentic_evaluation/llm_rag_eval.py +4 -4
- wxo_agentic_evaluation/llm_user.py +3 -3
- wxo_agentic_evaluation/main.py +63 -23
- wxo_agentic_evaluation/metrics/metrics.py +59 -0
- wxo_agentic_evaluation/prompt/args_extractor_prompt.jinja2 +23 -0
- wxo_agentic_evaluation/prompt/batch_testcase_prompt.jinja2 +2 -0
- wxo_agentic_evaluation/prompt/examples/data_simple.json +1 -2
- wxo_agentic_evaluation/prompt/starting_sentence_generation_prompt.jinja2 +195 -0
- wxo_agentic_evaluation/prompt/story_generation_prompt.jinja2 +154 -0
- wxo_agentic_evaluation/prompt/template_render.py +17 -0
- wxo_agentic_evaluation/prompt/tool_planner.jinja2 +13 -7
- wxo_agentic_evaluation/record_chat.py +59 -18
- wxo_agentic_evaluation/resource_map.py +47 -0
- wxo_agentic_evaluation/service_provider/__init__.py +35 -0
- wxo_agentic_evaluation/service_provider/model_proxy_provider.py +108 -0
- wxo_agentic_evaluation/service_provider/ollama_provider.py +40 -0
- wxo_agentic_evaluation/service_provider/provider.py +19 -0
- wxo_agentic_evaluation/{watsonx_provider.py → service_provider/watsonx_provider.py} +27 -18
- wxo_agentic_evaluation/test_prompt.py +94 -0
- wxo_agentic_evaluation/tool_planner.py +130 -17
- wxo_agentic_evaluation/type.py +0 -57
- wxo_agentic_evaluation/utils/utils.py +6 -54
- ibm_watsonx_orchestrate_evaluation_framework-1.0.2.dist-info/RECORD +0 -46
- ibm_watsonx_orchestrate_evaluation_framework-1.0.2.dist-info/licenses/LICENSE +0 -22
- {ibm_watsonx_orchestrate_evaluation_framework-1.0.2.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.0.3.dist-info}/WHEEL +0 -0
- {ibm_watsonx_orchestrate_evaluation_framework-1.0.2.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.0.3.dist-info}/top_level.txt +0 -0
wxo_agentic_evaluation/evaluation_package.py

```diff
@@ -9,15 +9,17 @@ from wxo_agentic_evaluation.type import (
     ContentType,
     Message,
     EvaluationData,
-    ToolCallAndRoutingMetrics,
     EventTypes,
     ConversationalSearch,
     ExtendedMessage,
 )
-from wxo_agentic_evaluation.
+from wxo_agentic_evaluation.resource_map import ResourceMap
+from wxo_agentic_evaluation.service_provider import get_provider
 from wxo_agentic_evaluation.metrics.metrics import (
     KnowledgeBaseMetrics,
     KeywordSemanticSearchMetric,
+    ToolCallAndRoutingMetrics,
+    TextMatchType
 )
 from wxo_agentic_evaluation.prompt.template_render import (
     KeywordMatchingTemplateRenderer,
```
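The import changes above replace a direct watsonx client import with the new `wxo_agentic_evaluation.service_provider` package and its `get_provider` factory. Based only on the call patterns visible in this diff (`get_provider(model_id=..., params=...)`, `get_provider(config=ProviderConfig(), params=...)`, and `client.query(prompt)`), a minimal usage sketch looks like the following; the prompt text is invented, and the model and decoding parameters simply mirror values that appear in the changed code.

```python
# Hedged sketch of the new provider factory, as implied by the call sites in this diff.
# The prompt string is illustrative; model_id and params mirror values shown below.
from wxo_agentic_evaluation.service_provider import get_provider

client = get_provider(
    model_id="meta-llama/llama-3-405b-instruct",
    params={"min_new_tokens": 0, "decoding_method": "greedy", "max_new_tokens": 10},
)
answer = client.query("Answer YES or NO: does the response mention a refund?")
print(answer)
```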
wxo_agentic_evaluation/evaluation_package.py (continued)

```diff
@@ -35,6 +37,13 @@ SEMANTIC_MATCHING_PROMPT_PATH = os.path.join(root_dir, "prompt", "semantic_match
 FAITHFULNESS_PROMPT_PATH = os.path.join(root_dir, "prompt", "faithfulness_prompt.jinja2")
 ANSWER_RELEVANCY_PROMPT_PATH = os.path.join(root_dir, "prompt", "answer_relevancy_prompt.jinja2")

+"""
+- hyphens are not allowed in python function names, so it is safe to use as a dummy function name
+- purpose behind `DUMMY_GRAPH_NODE_NAME` is to append
+a dummy node to the ground truth and the labelled messages to take into account
+single, summary step goals.
+"""
+DUMMY_GRAPH_NODE_NAME = "dummy-goal"

 class EvaluationPackage:
     def __init__(
@@ -44,6 +53,7 @@ class EvaluationPackage:
         messages,
         conversational_search_data: List[ConversationalSearch] = None,
         is_analyze_run=False,
+        resource_map: ResourceMap = None,
     ):
         self.tool_dictionary = {
             goal_detail.name: goal_detail
@@ -63,13 +73,9 @@ class EvaluationPackage:
         self.is_analyze_run = is_analyze_run

         self.matcher = LLMMatcher(
-            llm_client=
+            llm_client=get_provider(
                 model_id="meta-llama/llama-3-405b-instruct",
-
-                    "min_new_tokens": 0,
-                    "decoding_method": "greedy",
-                    "max_new_tokens": 10,
-                },
+                params={"min_new_tokens": 0, "decoding_method": "greedy", "max_new_tokens": 10},
             ),
             keyword_template=KeywordMatchingTemplateRenderer(
                 KEYWORD_MATCHING_PROMPT_PATH
@@ -79,23 +85,55 @@ class EvaluationPackage:
             ),
         )
         self.rag_llm_as_a_judge = LLMJudge(
-            llm_client=
-
-
-
-                "decoding_method": "greedy",
-                "max_new_tokens": 4096,
-            },
-            ),
+            llm_client=get_provider(
+                model_id="meta-llama/llama-3-405b-instruct",
+                params={"min_new_tokens": 0, "decoding_method": "greedy", "max_new_tokens": 4096},
+            ),
             faithfulness=FaithfulnessTemplateRenderer(FAITHFULNESS_PROMPT_PATH),
             answer_relevancy=AnswerRelevancyTemplateRenderer(
                 ANSWER_RELEVANCY_PROMPT_PATH
             ),
         )

+        self.resource_map = resource_map
+
+    @staticmethod
+    def find_ground_node(graph, start_node):
+        """ Simple implementation. Should be fixed in the future
+
+        Assumes that there is a single graph node that does not have children
+        """
+
+        stack = [start_node]
+        visited_set = set()
+
+        while stack:
+            node = stack.pop()
+            if node not in visited_set:
+                visited_set.add(node)
+
+                # check for children
+                # improvement for future: add the ground nodes here
+                # right now, just return the first one
+                if not graph.get(node):
+                    return node
+
+                stack.extend(graph[node])
+
+        return None
+
     @staticmethod
     def is_topological_sort(graph, ordering):
         position = {node: i for i, node in enumerate(ordering)}
+        ground_node = EvaluationPackage.find_ground_node(graph, list(graph.keys())[0])
+
+        if ground_node is not None:
+            graph[ground_node] = [DUMMY_GRAPH_NODE_NAME]
+            graph[DUMMY_GRAPH_NODE_NAME] = []
+
+            next_idx = len(position)
+            position[DUMMY_GRAPH_NODE_NAME] = next_idx
+
         for u in graph:
             for v in graph[u]:
                 if u not in position or v not in position:
```
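The module docstring added above explains that `DUMMY_GRAPH_NODE_NAME` exists so that a test case with a single, summary-only goal still yields a checkable ordering. A small worked example (hypothetical goal name, not taken from the test data) shows the effect inside `is_topological_sort`:

```python
# Worked example of the dummy-node handling described above (goal name invented).
DUMMY_GRAPH_NODE_NAME = "dummy-goal"

graph = {"summarize": []}     # a single, summary-only goal: no outgoing edges
ordering = ["summarize"]      # the labelled messages observed for the run

position = {node: i for i, node in enumerate(ordering)}

# find_ground_node returns "summarize" (the only node without children), so the
# dummy node is appended to the graph and to the position map:
graph["summarize"] = [DUMMY_GRAPH_NODE_NAME]
graph[DUMMY_GRAPH_NODE_NAME] = []
position[DUMMY_GRAPH_NODE_NAME] = len(position)

# Every edge u -> v now satisfies position[u] < position[v], so the single-goal
# ordering passes the topological-sort check.
assert all(position[u] < position[v] for u in graph for v in graph[u])
```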
wxo_agentic_evaluation/evaluation_package.py (continued)

```diff
@@ -143,7 +181,7 @@ class EvaluationPackage:
                     f"Goal detail '{goal_detail.name}' does not match any goals: {goals}. test_case_name: {test_case_name}"
                 )
             if goal_detail.name == "summarize":
-                if len(goal_detail.keywords) == 0 and len(goal_detail.response) == 0:
+                if (not goal_detail.keywords or len(goal_detail.keywords) == 0) and (not goal_detail.response or len(goal_detail.response) == 0):
                     rich.print(
                         f"Summarize goal should have keywords or final response. test_case_name: {test_case_name}"
                     )
@@ -178,23 +216,35 @@ class EvaluationPackage:
         labelled_messages_without_text_step = []
         # Counters for tool-calling related metrics
         tool_call_and_routing_metrics = ToolCallAndRoutingMetrics(
-            total_tool_calls=0,
-            expected_tool_calls=0,
-            relevant_tool_calls=0,
-            correct_tool_calls=0,
-            total_routing_calls=0,
-            expected_routing_calls=0,
         )
         tool_call_and_routing_metrics.expected_tool_calls = len(self.tool_dictionary)

         for message in self.messages:
             if message.type == ContentType.tool_call:
-                tool_call_and_routing_metrics.total_tool_calls += 1
-                msg_tool_call = json.loads(message.content)

-
-                if msg_tool_call["name"].
+                msg_tool_call = json.loads(message.content)
+                if self.resource_map and msg_tool_call["name"] in self.resource_map.agent2tools:
                     tool_call_and_routing_metrics.total_routing_calls += 1
+                    relevant = False
+                    for tool in self.resource_map.agent2tools[msg_tool_call["name"]]:
+                        for goal_detail in self.tool_dictionary.values():
+                            if goal_detail.tool_name == tool:
+                                relevant = True
+                                break
+                        if relevant:
+                            break
+
+                    if relevant:
+                        tool_call_and_routing_metrics.relevant_routing_calls += 1
+                    else:
+                        message_outcome = ExtendedMessage(message=message)
+                        message_outcome.reason = {
+                            "reason": "irrelevant routing call",
+                        }
+
+                    continue
+
+                tool_call_and_routing_metrics.total_tool_calls += 1

                 # evaluating more than once is fine
                 # agent could make repeated calls with the same function signature
```
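The new branch above treats a tool call whose name appears in `ResourceMap.agent2tools` as a routing call, and counts it as relevant only when one of that agent's tools matches a `tool_name` in the ground-truth goal details. A self-contained sketch of that check, using an invented agent/tool layout and a `GoalDetail` stand-in reduced to the one field the check reads:

```python
# Hedged sketch of the routing-relevance check above; agent and tool names are invented.
from dataclasses import dataclass

@dataclass
class GoalDetail:           # stand-in for the real goal-detail type; only tool_name is used here
    tool_name: str

agent2tools = {"hr_agent": ["get_timeoff_schedule", "update_address"]}  # ResourceMap.agent2tools
tool_dictionary = {"step_1": GoalDetail(tool_name="get_timeoff_schedule")}

call_name = "hr_agent"                       # msg_tool_call["name"] routed to an agent
is_routing_call = call_name in agent2tools   # True -> total_routing_calls += 1

relevant = any(
    goal_detail.tool_name == tool
    for tool in agent2tools[call_name]
    for goal_detail in tool_dictionary.values()
)
print(is_routing_call, relevant)             # True True -> relevant_routing_calls += 1
```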
wxo_agentic_evaluation/evaluation_package.py (continued)

```diff
@@ -207,57 +257,41 @@ class EvaluationPackage:
                 if len(matching_goal_details) > 0:
                     tool_call_and_routing_metrics.relevant_tool_calls += 1  # tool name matches one of the expected tool names, as defined in the ground truth
                     found = False
-
+                    possible_ground_truth_for_analysis = []
                     for goal_detail in matching_goal_details:
-                        if
-                            is_transfer := msg_tool_call["name"].startswith(
-                                "transfer_to_"
-                            )
-                        ) or msg_tool_call["args"] == goal_detail.args:
+                        if msg_tool_call["args"] == goal_detail.args:
                             labelled_messages.append(goal_detail.name)
                             labelled_messages_without_text_step.append(goal_detail.name)
-
-
-                                1
-                            )
-                            else:
-                                tool_call_and_routing_metrics.correct_tool_calls += 1  # correct tool call (no erroneous response) + expected arguments, as defined in the ground truth
+
+                            tool_call_and_routing_metrics.correct_tool_calls += 1  # correct tool call (no erroneous response) + expected arguments, as defined in the ground truth
                             found = True
                             message_outcome = ExtendedMessage(message=message)
                             message_outcomes.append(message_outcome)
                             break
                         else:
-
+                            possible_ground_truth_for_analysis.append(goal_detail.args)

                     if not found:
                         message_outcome = ExtendedMessage(message=message)
                         message_outcome.reason = {
                             "reason": "incorrect parameter",
                             "actual": msg_tool_call["args"],
-                            "expected":
+                            "expected": possible_ground_truth_for_analysis,
                         }
                         message_outcomes.append(message_outcome)
                         rich.print(
                             f"[red][ERROR] Wrong parameters for function: {msg_tool_call['name']}. "
                             f"Expected one of {[g.args for g in matching_goal_details]}, Received={msg_tool_call['args']}[/red]"
                         )
-                        labelled_messages.append(
-                            msg_tool_call["name"] + "_WRONG_PARAMETERS"
-                        )
                 else:
-
-
-
-
-
-
-
-
-                        msg_tool_call["name"] + "_WRONG_FUNCTION_CALL"
-                    )
-                    message_outcome = ExtendedMessage(message=message)
-                    message_outcome.reason = {"reason": "irrelevant tool call"}
-                    message_outcomes.append(message_outcome)
+
+                    rich.print(
+                        f"[yellow][WARNING] Unexpected function call: {msg_tool_call['name']}[/yellow]"
+                    )
+                    # note: this is incorrect after the 1.6 change
+                    message_outcome = ExtendedMessage(message=message)
+                    message_outcome.reason = {"reason": "irrelevant tool call"}
+                    message_outcomes.append(message_outcome)

             elif message.type == ContentType.tool_response:
                 found = False
@@ -272,7 +306,6 @@ class EvaluationPackage:
                     message_outcome = ExtendedMessage(message=message)
                     message_outcomes.append(message_outcome)
             else:
-
                 message_outcome = ExtendedMessage(message=message)
                 message_outcomes.append(message_outcome)
         assistant_responses = [
@@ -318,15 +351,16 @@ class EvaluationPackage:
     ):

         if len(self.text_list) == 0:
-            return
+            return TextMatchType.na.value
         elif len(self.text_list) == len(keyword_semantic_match_list):
-            return
+            return TextMatchType.text_match.value
         else:
-            return
+            return TextMatchType.text_mismatch.value

     def generate_summary(self):
         llm_steps = 0
         total_step = 0
+        metrics: ToolCallAndRoutingMetrics
         (
             labelled_messages,
             labelled_messages_without_text_step,
@@ -336,9 +370,7 @@ class EvaluationPackage:
         ) = self.traverse()
         if self.is_analyze_run:
             print(labelled_messages)
-
-            1 for msg in labelled_messages if "_WRONG_FUNCTION_CALL" in msg
-        )
+
         is_success = self.is_topological_sort(
             self.ground_truth.goals, labelled_messages
         )
@@ -359,28 +391,13 @@ class EvaluationPackage:
         knowledge_base_metric_summary = self.generate_knowledge_base_metric_summary()
         # TO-DO: the table is not printing properly anymore with the new columns introduced
         # we need to introduce a separate table for these.
-
-
-
-
-
-            "Wrong Function Calls": wrong_call_count,
-            # "Bad Calls": 0,
-            "Wrong Parameters": sum(
-                1 for msg in labelled_messages if "_WRONG_PARAMETERS" in msg
-            ),
-            "Wrong Routing Calls": sum(
-                1 for msg in labelled_messages if "_WRONG_ROUTING_CALL" in msg
-            ),
-            "Text Match": match,
-            "Journey Success": is_success,
-            # "Tool Call Accuracy": metrics.tool_call_accuracy,
-            # "Tool Call Relevancy": metrics.tool_call_relevancy,
-            # "Agent Routing Accuracy": metrics.agent_routing_accuracy
-        }
+
+        metrics.total_steps = total_step
+        metrics.llm_step = llm_steps
+        metrics.text_match = match
+        metrics.is_success = is_success

         return (
-            data,
             matches,
             knowledge_base_metric_summary,
             message_with_reasons,
@@ -512,7 +529,7 @@ if __name__ == "__main__":
     evaluate_package = EvaluationPackage(
         test_case_name="data1.messages.json",
         ground_truth=ground_truth,
-        messages=messages
+        messages=messages
     )
     print(evaluate_package.generate_summary())
     # print(evaluate_package.traverse())
```
wxo_agentic_evaluation/external_agent/__init__.py (new file)

```diff
@@ -0,0 +1,37 @@
+import importlib.resources
+import json
+import rich
+
+from wxo_agentic_evaluation.prompt.template_render import StoryGenerationTemplateRenderer
+from wxo_agentic_evaluation.service_provider import get_provider, ProviderConfig
+from wxo_agentic_evaluation import prompt
+
+console = rich.console.Console()
+
+def starting_sentence_generation_prompt():
+    with importlib.resources.path(prompt, "starting_sentence_generation_prompt.jinja2") as fp:
+        # reuse the StoryGenerationTemplateRenderer class, even though we are generating a "starting_sentence" instead of a "story"
+        # the starting sentence generation prompts uses the same input variable
+        render = StoryGenerationTemplateRenderer(str(fp))
+
+    return render
+
+def generate_starting_sentence(annotated_data: dict):
+    renderer = starting_sentence_generation_prompt()
+    llm_decode_parameter = {
+        "min_new_tokens": 0,
+        "decoding_method": "greedy",
+        "max_new_tokens": 4096,
+    }
+    wai_client = get_provider(config=ProviderConfig(), params=llm_decode_parameter)
+    prompt = renderer.render(input_data=json.dumps(annotated_data, indent=4))
+    res = wai_client.query(prompt)
+    res = res.strip()
+
+    try:
+        # ideally the LLM outputted a dictionary like: {"starting_sentence": "lorem ipsum"}
+        res = json.loads(res)
+        return res["starting_sentence"]
+    except Exception:
+        console.log(f"The generated `starting_sentence` had incorrect format: '{res}'")
+        return res
```
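`generate_starting_sentence` renders the starting-sentence prompt from an annotated test case and expects the model to answer with a JSON object carrying a `starting_sentence` key, falling back to the raw model output otherwise. A hedged usage sketch follows; the annotated case below is invented and only mirrors the goal-template shape used elsewhere in this diff.

```python
# Hedged usage sketch for generate_starting_sentence; the annotated case is invented.
from wxo_agentic_evaluation.external_agent import generate_starting_sentence

annotated_case = {
    "agent": "hr_agent",
    "goals": {"summarize": []},
    "goal_details": [
        {"name": "summarize", "type": "text",
         "response": "You have 12 vacation days left.", "keywords": ["12", "vacation"]}
    ],
    "story": "An employee wants to check their remaining vacation days.",
}

sentence = generate_starting_sentence(annotated_case)
print(sentence)  # on a JSON parse failure the raw model output is returned instead
```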
wxo_agentic_evaluation/external_agent/external_validate.py

```diff
@@ -1,12 +1,16 @@
-from wxo_agentic_evaluation.external_agent.types import UniversalData
-import requests
 from typing import Generator
+import requests
 import json
+import rich
+
+from wxo_agentic_evaluation.external_agent.types import UniversalData, SchemaValidationResults


-MESSAGES = [
-
-
+MESSAGES = [
+    {"role": "user", "content": "what's the holiday is June 13th in us?"},
+    {"role": "assistant", "content": "tool_name: calendar_lookup, args {\"location\": \"USA\", \"data\": \"06-13-2025\"}}"},
+    {"role": "assistant", "content":"it's National Sewing Machine Day"}
+]


 class ExternalAgentValidation:
@@ -14,20 +18,20 @@ class ExternalAgentValidation:
         self.credential = credential
         self.auth_scheme = auth_scheme
         self.service_url = service_url
-
-
+
+    @property
+    def header(self):
+        header = {"Content-Type": "application/json"}
         if self.auth_scheme == "API_KEY":
-            header = {"
-
+            header = {"X-API-Key": self.credential}
         elif self.auth_scheme == "BEARER_TOKEN":
             header = {"Authorization": f"Bearer {self.credential}"}
-
         else:
             raise Exception(f"Auth scheme: {self.auth_scheme} is not supported")

         return header

-    def
+    def _parse_streaming_events(self, resp: Generator[bytes, None, None]):
         data = b''
         for chunk in resp:
             for line in chunk.splitlines(True):
@@ -37,31 +41,72 @@ class ExternalAgentValidation:
                     return
                 data += line
                 if data.endswith((b'\r\r', b'\n\n', b'\r\n\r\n')):
+                    # NOTE: edge case, "data" can be sent in two different chunks
+                    if data.startswith(b'data:'):
+                        data = data.replace(b'data:', b'')
                     yield data
                     data = b''
         if data:
             yield data
-
-    def
-
-
-
-        new_messages = []
-        new_messages.extend(MESSAGES)
-        new_messages.append({"role": "user", "content": input})
-
-        payload = {"messages": new_messages}
-
-        resp = requests.post(url=self.service_url, headers=header, json=payload, stream=True)
-        results = []
-        for json_str in self._parse_streaming_evenst(resp):
+
+    def _validate_streaming_response(self, resp):
+        success = True
+        logged_events = []
+        for json_str in self._parse_streaming_events(resp):
             json_dict = None
+            logged_events.append(json_str)
             try:
                 json_dict = json.loads(json_str)
                 UniversalData(**json_dict)
-                results.append(json_dict)
             except Exception as e:
-
-
+                success = False
+                break
+
+        return success, logged_events
+
+    def _validate_schema_compliance(self, messages):
+        payload = {"stream": True}
+        payload["messages"] = messages
+        resp = requests.post(url=self.service_url, headers=self.header, json=payload)
+        success, logged_events = self._validate_streaming_response(resp)
+
+        msg = ", ".join([msg["content"] for msg in payload["messages"]])
+
+        if success:
+            rich.print(f":white_check_mark: External Agent streaming response validation succeeded for '{msg}'.")
+        else:
+            rich.print(f":heavy_exclamation_mark:Schema validation failed for messages: '{msg}':heavy_exclamation_mark:\n The last logged event was {logged_events[-1]}.\n")
+
+        return success, logged_events
+
+    def call_validation(self, input_str: str, add_context: bool = False) -> SchemaValidationResults:
+        if add_context:
+            return self.block_validation(input_str)
+
+        msg = {
+            "role": "user",
+            "content": input_str
+        }
+
+        success, logged_events = self._validate_schema_compliance([msg])
+        results = SchemaValidationResults(success=success, logged_events=logged_events, messages=[msg])
+
+        return results.model_dump()
+
+    def block_validation(self, input_str: str) -> SchemaValidationResults:
+        """ Tests a block of messages
+        """
+        rich.print(
+            f"[gold3]The following prebuilt messages, '{MESSAGES}' is prepended to the input message, '{input_str}'"
+        )
+
+        msg = {
+            "role": "user",
+            "content": input_str
+        }
+
+        messages = MESSAGES + [msg]
+        success, logged_events = self._validate_schema_compliance(messages)
+        results = SchemaValidationResults(success=success, logged_events=logged_events, messages=messages)

-        return results
+        return results.model_dump()
```
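The validator can be exercised directly. Its constructor signature is not shown in this hunk, so the keyword arguments below are assumed from the attributes it sets (`credential`, `auth_scheme`, `service_url`), and the endpoint and key are invented.

```python
# Hedged usage sketch; argument names are assumed from the attributes set in __init__,
# and the credential/URL values are invented.
validator = ExternalAgentValidation(
    credential="my-api-key",
    auth_scheme="API_KEY",                          # or "BEARER_TOKEN"
    service_url="https://example.com/external-agent/chat",
)

# Validate a single user message against the UniversalData streaming schema...
result = validator.call_validation("What is my leave balance?")

# ...or prepend the prebuilt MESSAGES context before validating.
result_with_context = validator.call_validation("What is my leave balance?", add_context=True)

print(result["success"], len(result["logged_events"]))
```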
wxo_agentic_evaluation/external_agent/performance_test.py (new file)

```diff
@@ -0,0 +1,66 @@
+from typing import List, Mapping, Any
+from rich.console import Console
+
+from wxo_agentic_evaluation.external_agent import generate_starting_sentence
+from wxo_agentic_evaluation.arg_configs import KeywordsGenerationConfig
+from wxo_agentic_evaluation.service_provider import get_provider, ProviderConfig
+from wxo_agentic_evaluation.data_annotator import KeywordsGenerationLLM, LlamaKeywordsGenerationTemplateRenderer
+
+class ExternalAgentPerformanceTest:
+    def __init__(self, agent_name: str, test_data: List[str]):
+        self.test_data = test_data
+        self.goal_template = {
+            "agent": agent_name,
+            "goals": {"summarize": []},
+            "goal_details": [
+            ],
+            "story": "<placeholder>",
+        }
+
+        kw_gen_config = KeywordsGenerationConfig()
+
+        provider_config = ProviderConfig(model_id=kw_gen_config.model_id)
+        llm_decode_parameter = {
+            "min_new_tokens": 0,
+            "decoding_method": "greedy",
+            "max_new_tokens": 256,
+        }
+        wai_client = get_provider(config=provider_config, params=llm_decode_parameter)
+
+        self.kw_gen = KeywordsGenerationLLM(
+            provider=wai_client,
+            template=LlamaKeywordsGenerationTemplateRenderer(
+                kw_gen_config.prompt_config
+            ),
+        )
+
+    def generate_tests(self) -> List[Mapping[str, Any]]:
+        console = Console()
+        goal_templates = []
+
+        with console.status("[gold3]Creating starting sentence for user story from input file for performance testing") as status:
+            for sentence, response in self.test_data:
+                goal_temp = self.goal_template.copy()
+                goal_temp["story"] = sentence
+
+                keywords = self.kw_gen.genereate_keywords(response)
+                summarize_step = {
+                    "name": "summarize",
+                    "type": "text",
+                    "response": response,
+                    "keywords": keywords
+                }
+                goal_temp["goal_details"] = [summarize_step]
+                goal_temp["starting_sentence"] = generate_starting_sentence(goal_temp)
+
+                goal_templates.append(goal_temp)
+
+            status.stop()
+        console.print("[bold green]Done creating starting sentence from provided input data")
+
+        return goal_templates
+
+
+if __name__ == "__main__":
+    t = ExternalAgentPerformanceTest("test")
+    t.generate_tests()
```
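`generate_tests` iterates `test_data` as `(sentence, response)` pairs, so a pair list is assumed in the sketch below even though the annotation says `List[str]` and the `__main__` block above passes only an agent name; the data itself is invented.

```python
# Hedged usage sketch; test_data is iterated as (sentence, response) pairs in generate_tests,
# so tuples are assumed here. The agent name and conversation are invented.
test_data = [
    ("An employee asks about national holidays.", "June 13th is National Sewing Machine Day."),
]

perf_test = ExternalAgentPerformanceTest("external_hr_agent", test_data)
test_cases = perf_test.generate_tests()   # one goal template per (sentence, response) pair
print(test_cases[0]["starting_sentence"])
```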
wxo_agentic_evaluation/external_agent/types.py

```diff
@@ -1,5 +1,5 @@
 from pydantic import BaseModel
-from typing import List, Union, Literal
+from typing import List, Union, Literal, Mapping, Any


 class ThinkingStepDetails(BaseModel):
@@ -62,4 +62,10 @@ class UniversalData(BaseEventData):
     object: Union[Literal["thread.message.delta"], Literal["thread.run.step.delta"],
                   Literal["thread.run.step.created"], Literal["thread.run.step.completed"]]
     choices: List[ThreadMessageDeltaChoice]
-    choices: List[Union[ThreadMessageDeltaChoice, dict]]
+    choices: List[Union[ThreadMessageDeltaChoice, dict]]
+
+
+class SchemaValidationResults(BaseModel):
+    success: bool
+    logged_events: List[str]
+    messages: List[Mapping[Any, Any]]
```