ibm-watsonx-orchestrate-evaluation-framework 1.0.2__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ibm-watsonx-orchestrate-evaluation-framework might be problematic. Click here for more details.
- {ibm_watsonx_orchestrate_evaluation_framework-1.0.2.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.0.4.dist-info}/METADATA +70 -7
- ibm_watsonx_orchestrate_evaluation_framework-1.0.4.dist-info/RECORD +56 -0
- wxo_agentic_evaluation/analytics/tools/analyzer.py +3 -3
- wxo_agentic_evaluation/analytics/tools/ux.py +1 -1
- wxo_agentic_evaluation/analyze_run.py +10 -10
- wxo_agentic_evaluation/arg_configs.py +8 -1
- wxo_agentic_evaluation/batch_annotate.py +3 -9
- wxo_agentic_evaluation/data_annotator.py +50 -36
- wxo_agentic_evaluation/evaluation_package.py +102 -85
- wxo_agentic_evaluation/external_agent/__init__.py +37 -0
- wxo_agentic_evaluation/external_agent/external_validate.py +74 -29
- wxo_agentic_evaluation/external_agent/performance_test.py +66 -0
- wxo_agentic_evaluation/external_agent/types.py +8 -2
- wxo_agentic_evaluation/inference_backend.py +45 -50
- wxo_agentic_evaluation/llm_matching.py +6 -6
- wxo_agentic_evaluation/llm_rag_eval.py +4 -4
- wxo_agentic_evaluation/llm_user.py +3 -3
- wxo_agentic_evaluation/main.py +63 -23
- wxo_agentic_evaluation/metrics/metrics.py +59 -0
- wxo_agentic_evaluation/prompt/args_extractor_prompt.jinja2 +23 -0
- wxo_agentic_evaluation/prompt/batch_testcase_prompt.jinja2 +2 -0
- wxo_agentic_evaluation/prompt/examples/data_simple.json +1 -2
- wxo_agentic_evaluation/prompt/starting_sentence_generation_prompt.jinja2 +195 -0
- wxo_agentic_evaluation/prompt/story_generation_prompt.jinja2 +154 -0
- wxo_agentic_evaluation/prompt/template_render.py +17 -0
- wxo_agentic_evaluation/prompt/tool_planner.jinja2 +13 -7
- wxo_agentic_evaluation/record_chat.py +74 -26
- wxo_agentic_evaluation/resource_map.py +47 -0
- wxo_agentic_evaluation/service_provider/__init__.py +35 -0
- wxo_agentic_evaluation/service_provider/model_proxy_provider.py +108 -0
- wxo_agentic_evaluation/service_provider/ollama_provider.py +40 -0
- wxo_agentic_evaluation/service_provider/provider.py +19 -0
- wxo_agentic_evaluation/{watsonx_provider.py → service_provider/watsonx_provider.py} +27 -18
- wxo_agentic_evaluation/test_prompt.py +94 -0
- wxo_agentic_evaluation/tool_planner.py +130 -17
- wxo_agentic_evaluation/type.py +0 -57
- wxo_agentic_evaluation/utils/utils.py +6 -54
- ibm_watsonx_orchestrate_evaluation_framework-1.0.2.dist-info/RECORD +0 -46
- ibm_watsonx_orchestrate_evaluation_framework-1.0.2.dist-info/licenses/LICENSE +0 -22
- {ibm_watsonx_orchestrate_evaluation_framework-1.0.2.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.0.4.dist-info}/WHEEL +0 -0
- {ibm_watsonx_orchestrate_evaluation_framework-1.0.2.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.0.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import requests
|
|
3
|
+
import time
|
|
4
|
+
|
|
5
|
+
from typing import List
|
|
6
|
+
from threading import Lock
|
|
7
|
+
|
|
8
|
+
from wxo_agentic_evaluation.service_provider.provider import Provider
|
|
9
|
+
from wxo_agentic_evaluation.utils.utils import is_ibm_cloud_url
|
|
10
|
+
|
|
11
|
+
AUTH_ENDPOINT_AWS = "https://iam.platform.saas.ibm.com/siusermgr/api/1.0/apikeys/token"
|
|
12
|
+
AUTH_ENDPOINT_IBM_CLOUD = "https://iam.cloud.ibm.com/identity/token"
|
|
13
|
+
WO_INSTANCE = os.environ.get("WO_INSTANCE")
|
|
14
|
+
WO_API_KEY = os.environ.get("WO_API_KEY")
|
|
15
|
+
DEFAULT_PARAM = {"min_new_tokens": 1, "decoding_method": "greedy", "max_new_tokens": 400}
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ModelProxyProvider(Provider):
    """Provider that calls the text-generation and text-embedding endpoints of an instance URL.

    An API key is exchanged for a bearer token at construction time; the token
    is re-fetched lazily (see ``refresh_token_if_expires``) before each request
    once its refresh time has passed.
    """

    def __init__(
        self,
        model_id=None,
        api_key=WO_API_KEY,
        instance_url=WO_INSTANCE,
        timeout=300,
        embedding_model_id=None,
        params=None
    ):
        """Store the configuration and fetch the first bearer token.

        :param model_id: model used by ``query``; may be None if only ``encode`` is used.
        :param api_key: API key exchanged for a bearer token (defaults to ``WO_API_KEY``).
        :param instance_url: base URL of the instance (defaults to ``WO_INSTANCE``).
        :param timeout: sent as ``"timeout"`` in the generation payload and also used
            as the client-side timeout of every HTTP call made by this class.
        :param embedding_model_id: model used by ``encode``; may be None if only ``query`` is used.
        :param params: generation parameters; falls back to a copy of ``DEFAULT_PARAM``.
        :raises RuntimeError: if ``instance_url`` or ``api_key`` is missing.
        """
        super().__init__()

        if not instance_url or not api_key:
            raise RuntimeError("instance url and WO apikey must be specified to use WO model proxy")

        self.timeout = timeout
        self.model_id = model_id

        self.embedding_model_id = embedding_model_id

        self.api_key = api_key
        self.is_ibm_cloud = is_ibm_cloud_url(instance_url)
        self.auth_url = AUTH_ENDPOINT_IBM_CLOUD if self.is_ibm_cloud else AUTH_ENDPOINT_AWS

        # Drop a trailing slash so the joined endpoints never contain "//ml/...".
        base_url = instance_url.rstrip("/")
        self.url = base_url + "/ml/v1/text/generation?version=2024-05-01"
        self.embedding_url = base_url + "/ml/v1/text/embeddings"

        self.lock = Lock()
        self.token, self.refresh_time = self.get_token()
        # Copy the module-level default so mutating one instance's params
        # cannot leak into other instances.
        self.params = params if params else dict(DEFAULT_PARAM)

    @staticmethod
    def _json_or_raise(resp):
        """Return the decoded JSON body of a 200 response; raise for any other status."""
        resp.raise_for_status()
        if resp.status_code != 200:
            # raise_for_status() only raises for 4xx/5xx; without this check an
            # unexpected 2xx/3xx would silently yield None to the caller.
            raise RuntimeError(f"unexpected HTTP status {resp.status_code} from {resp.url}")
        return resp.json()

    def get_token(self):
        """Exchange the API key for a bearer token.

        :return: ``(token, refresh_time)`` where ``refresh_time`` is
            ``time.time()`` plus 80% of the reported ``expires_in``.
        :raises requests.HTTPError: if the auth endpoint answers with 4xx/5xx.
        :raises RuntimeError: if it answers with any other non-200 status.
        """
        if self.is_ibm_cloud:
            payload = {"grant_type": "urn:ibm:params:oauth:grant-type:apikey", "apikey": self.api_key}
            resp = requests.post(self.auth_url, data=payload, timeout=self.timeout)
            token_key = "access_token"
        else:
            payload = {"apikey": self.api_key}
            resp = requests.post(self.auth_url, json=payload, timeout=self.timeout)
            token_key = "token"

        json_obj = self._json_or_raise(resp)
        token = json_obj[token_key]
        expires_in = json_obj["expires_in"]  # presumably seconds — TODO confirm
        refresh_time = time.time() + int(0.8 * expires_in)
        return token, refresh_time

    def refresh_token_if_expires(self):
        """Fetch a new token once the refresh time has passed.

        The condition is re-checked after acquiring the lock so that only one
        caller performs the refresh when several arrive at the same time.
        """
        if time.time() > self.refresh_time:
            with self.lock:
                if time.time() > self.refresh_time:
                    self.token, self.refresh_time = self.get_token()

    def get_header(self):
        """Return the Authorization header carrying the current bearer token."""
        return {"Authorization": f"Bearer {self.token}"}

    def encode(self, sentences: List[str]) -> List[list]:
        """Send ``sentences`` to the embedding endpoint and return its result.

        :raises Exception: if no embedding model id was configured.
        :raises requests.HTTPError: if the endpoint answers with 4xx/5xx.
        """
        if self.embedding_model_id is None:
            raise Exception("embedding model id must be specified for text embedding")

        self.refresh_token_if_expires()
        headers = self.get_header()
        payload = {"inputs": sentences, "model_id": self.embedding_model_id, "space_id": "1"}
        resp = requests.post(self.embedding_url, json=payload, headers=headers, timeout=self.timeout)

        json_obj = self._json_or_raise(resp)
        # NOTE(review): "generated_text" is an unusual key for an embeddings
        # response — confirm against the endpoint's actual schema.
        return json_obj["generated_text"]

    def query(self, sentence: str) -> str:
        """Generate text for ``sentence`` and return the first result's ``generated_text``.

        :raises Exception: if no model id was configured.
        :raises requests.HTTPError: if the endpoint answers with 4xx/5xx.
        """
        if self.model_id is None:
            raise Exception("model id must be specified for text generation")
        self.refresh_token_if_expires()
        headers = self.get_header()
        payload = {"input": sentence, "model_id": self.model_id, "space_id": "1",
                   "timeout": self.timeout, "parameters": self.params}
        resp = requests.post(self.url, json=payload, headers=headers, timeout=self.timeout)
        return self._json_or_raise(resp)["results"][0]["generated_text"]
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
if __name__ == "__main__":
|
|
107
|
+
provider = ModelProxyProvider(model_id="meta-llama/llama-3-3-70b-instruct", embedding_model_id="ibm/slate-30m-english-rtrvr")
|
|
108
|
+
print(provider.query("ok"))
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
import requests
|
|
2
|
+
import json
|
|
3
|
+
from wxo_agentic_evaluation.service_provider.provider import Provider
|
|
4
|
+
from typing import List
|
|
5
|
+
import os
|
|
6
|
+
|
|
7
|
+
OLLAMA_URL = os.environ.get("OLLAMA_HOST", "http://localhost:11434")
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class OllamaProvider(Provider):
    """Provider that generates text through the ``/api/generate`` endpoint under ``OLLAMA_URL``."""

    def __init__(
        self,
        model_id=None
    ):
        """Remember the model name and the generate endpoint URL."""
        self.url = OLLAMA_URL + "/api/generate"
        self.model_id = model_id
        super().__init__()

    def query(self, sentence: str) -> str:
        """Send ``sentence`` as the prompt and return the concatenated streamed response.

        The response body is read as newline-delimited JSON objects; the
        ``"response"`` field of every object whose ``"done"`` flag is false is
        appended to the result.

        :raises requests.HTTPError: if the server answers with 4xx/5xx.
        """
        payload = {"model": self.model_id, "prompt": sentence}
        pieces = []
        # The context manager releases the streamed connection even if parsing fails.
        with requests.post(self.url, json=payload, stream=True) as resp:
            resp.raise_for_status()
            # iter_lines() splits on line boundaries regardless of how the
            # bytes were chunked on the wire, so each item is exactly one JSON
            # document (a fixed-size chunk may contain several, or a partial one).
            for line in resp.iter_lines():
                if not line:
                    continue
                json_obj = json.loads(line)
                if not json_obj["done"] and json_obj["response"]:
                    pieces.append(json_obj["response"])

        return "".join(pieces)

    def encode(self, sentences: List[str]) -> List[list]:
        """Not implemented for this provider; currently returns None."""
        pass
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
if __name__ == "__main__":
|
|
39
|
+
provider = OllamaProvider(model_id="llama3.1:8b")
|
|
40
|
+
print(provider.query("ok"))
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import List
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class Provider(ABC):
    """Abstract base for model providers: text generation via ``query``, embeddings via ``encode``."""

    def __init__(self):
        """Nothing to initialise; present so subclasses can call ``super().__init__()``."""

    @abstractmethod
    def query(self, sentence: str) -> str:
        """Return the provider's answer for a single input string."""

    def batch_query(self, sentences: List[str]) -> List[str]:
        """Call ``query`` on each sentence in order and return the answers as a list."""
        answers = []
        for sentence in sentences:
            answers.append(self.query(sentence))
        return answers

    @abstractmethod
    def encode(self, sentences: List[str]) -> List[list]:
        """Return one encoding per input sentence."""
|
|
19
|
+
|
|
@@ -5,7 +5,8 @@ from types import MappingProxyType
|
|
|
5
5
|
from typing import List
|
|
6
6
|
import dataclasses
|
|
7
7
|
from threading import Lock
|
|
8
|
-
|
|
8
|
+
import time
|
|
9
|
+
from wxo_agentic_evaluation.service_provider.provider import Provider
|
|
9
10
|
|
|
10
11
|
ACCESS_URL = "https://iam.cloud.ibm.com/identity/token"
|
|
11
12
|
ACCESS_HEADER = {
|
|
@@ -16,11 +17,11 @@ ACCESS_HEADER = {
|
|
|
16
17
|
YPQA_URL = "https://yp-qa.ml.cloud.ibm.com"
|
|
17
18
|
PROD_URL = "https://us-south.ml.cloud.ibm.com"
|
|
18
19
|
DEFAULT_PARAM = MappingProxyType(
|
|
19
|
-
{"min_new_tokens":
|
|
20
|
+
{"min_new_tokens": 1, "decoding_method": "greedy", "max_new_tokens": 400}
|
|
20
21
|
)
|
|
21
22
|
|
|
22
23
|
|
|
23
|
-
class WatsonXProvider:
|
|
24
|
+
class WatsonXProvider(Provider):
|
|
24
25
|
def __init__(
|
|
25
26
|
self,
|
|
26
27
|
model_id=None,
|
|
@@ -29,7 +30,7 @@ class WatsonXProvider:
|
|
|
29
30
|
api_endpoint=PROD_URL,
|
|
30
31
|
url=ACCESS_URL,
|
|
31
32
|
timeout=60,
|
|
32
|
-
|
|
33
|
+
params=None,
|
|
33
34
|
embedding_model_id=None,
|
|
34
35
|
):
|
|
35
36
|
super().__init__()
|
|
@@ -54,12 +55,15 @@ class WatsonXProvider:
|
|
|
54
55
|
self.embedding_model_id = embedding_model_id
|
|
55
56
|
self.lock = Lock()
|
|
56
57
|
|
|
57
|
-
if
|
|
58
|
-
|
|
59
|
-
if
|
|
60
|
-
|
|
58
|
+
self.params = params if params else DEFAULT_PARAM
|
|
59
|
+
|
|
60
|
+
if isinstance(self.params, MappingProxyType):
|
|
61
|
+
self.params = dict(self.params)
|
|
62
|
+
if dataclasses.is_dataclass(self.params):
|
|
63
|
+
self.params = dataclasses.asdict(self.params)
|
|
61
64
|
|
|
62
|
-
self.
|
|
65
|
+
self.refresh_time = None
|
|
66
|
+
self.access_token = None
|
|
63
67
|
self._refresh_token()
|
|
64
68
|
|
|
65
69
|
def _get_access_token(self):
|
|
@@ -69,8 +73,11 @@ class WatsonXProvider:
|
|
|
69
73
|
if response.status_code == 200:
|
|
70
74
|
token_data = json.loads(response.text)
|
|
71
75
|
token = token_data["access_token"]
|
|
72
|
-
|
|
73
|
-
|
|
76
|
+
expiration = token_data["expiration"]
|
|
77
|
+
expires_in = token_data["expires_in"]
|
|
78
|
+
# 9 minutes before expire
|
|
79
|
+
refresh_time = expiration - int(0.15 * expires_in)
|
|
80
|
+
return token, refresh_time
|
|
74
81
|
|
|
75
82
|
raise RuntimeError(
|
|
76
83
|
f"try to acquire access token and get {response.status_code}"
|
|
@@ -85,7 +92,7 @@ class WatsonXProvider:
|
|
|
85
92
|
headers = self.prepare_header()
|
|
86
93
|
|
|
87
94
|
data = {"model_id": self.model_id, "input": sentence,
|
|
88
|
-
"parameters": self.
|
|
95
|
+
"parameters": self.params, "space_id": self.space_id}
|
|
89
96
|
generation_url = f"{self.api_endpoint}/ml/v1/text/generation?version=2023-05-02"
|
|
90
97
|
resp = requests.post(url=generation_url, headers=headers, json=data)
|
|
91
98
|
if resp.status_code == 200:
|
|
@@ -94,13 +101,17 @@ class WatsonXProvider:
|
|
|
94
101
|
resp.raise_for_status()
|
|
95
102
|
|
|
96
103
|
def _refresh_token(self):
|
|
97
|
-
|
|
104
|
+
# if we do not have a token or the current timestamp is 9 minutes away from expire.
|
|
105
|
+
if not self.access_token or time.time() > self.refresh_time:
|
|
106
|
+
with self.lock:
|
|
107
|
+
if not self.access_token or time.time() > self.refresh_time:
|
|
108
|
+
self.access_token, self.refresh_time = self._get_access_token()
|
|
98
109
|
|
|
99
|
-
def query(self, sentence: str) ->
|
|
110
|
+
def query(self, sentence: str) -> str:
|
|
100
111
|
if self.model_id is None:
|
|
101
112
|
raise Exception("model id must be specified for text generation")
|
|
102
113
|
try:
|
|
103
|
-
return self.generate(sentence)
|
|
114
|
+
return self.generate(sentence)["generated_text"]
|
|
104
115
|
except Exception as e:
|
|
105
116
|
with self.lock:
|
|
106
117
|
if "authentication_token_expired" in str(e):
|
|
@@ -126,8 +137,6 @@ class WatsonXProvider:
|
|
|
126
137
|
|
|
127
138
|
|
|
128
139
|
if __name__ == "__main__":
|
|
129
|
-
import os
|
|
130
|
-
|
|
131
140
|
provider = WatsonXProvider(model_id="meta-llama/llama-3-2-90b-vision-instruct")
|
|
132
141
|
|
|
133
142
|
prompt = """
|
|
@@ -160,4 +169,4 @@ Usernwaters did not take anytime off during the period<|eot_id|>
|
|
|
160
169
|
<|eot_id|><|start_header_id|>user<|end_header_id|>
|
|
161
170
|
"""
|
|
162
171
|
|
|
163
|
-
print(provider.
|
|
172
|
+
print(provider.query(prompt))
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
from wxo_agentic_evaluation.watsonx_provider import WatsonXProvider
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def parse_json_string(input_string):
    """Extract every top-level JSON object embedded in ``input_string``.

    The text is scanned left to right. At each ``{`` a JSON object is decoded
    in place; text between objects (including stray ``}`` characters) is
    ignored, and braces inside JSON string values do not confuse the scan.
    If the text at a ``{`` is not valid JSON, the brace-balanced span starting
    there is skipped and an error is printed. If that span never closes,
    scanning stops.

    :param input_string: arbitrary text, e.g. raw model output.
    :return: list of decoded objects (dicts) in order of appearance.
    """

    def balanced_end(start):
        # Index just past the '}' that balances the '{' at `start`, or -1.
        depth = 0
        for offset, char in enumerate(input_string[start:], start):
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    return offset + 1
        return -1

    decoder = json.JSONDecoder()
    json_objects = []
    cursor = 0
    while True:
        start = input_string.find("{", cursor)
        if start == -1:
            break
        try:
            # raw_decode parses exactly one JSON value and reports where it
            # ended, so trailing text does not need to be trimmed first.
            obj, cursor = decoder.raw_decode(input_string, start)
        except json.JSONDecodeError as e:
            end = balanced_end(start)
            if end == -1:
                # Unterminated span: nothing after it can be delimited reliably.
                break
            print(f"Error decoding JSON: {e}")
            cursor = end
        else:
            json_objects.append(obj)
    return json_objects
|
|
33
|
+
|
|
34
|
+
wai_client = WatsonXProvider(model_id="meta-llama/llama-3-405b-instruct")
|
|
35
|
+
|
|
36
|
+
prompt = """
|
|
37
|
+
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
|
38
|
+
You are trying to make tool calls. Given a raw input and tool output. Try to extract the information to make the tool call
|
|
39
|
+
|
|
40
|
+
Example:
|
|
41
|
+
Tool description:
|
|
42
|
+
def get_payslips(user_id: str) -> PayslipsResponse:
|
|
43
|
+
Gets a user's payslips from Workday.
|
|
44
|
+
|
|
45
|
+
:param user_id: The user's id uniquely identifying them within the Workday API.
|
|
46
|
+
:return: The user's payslips.
|
|
47
|
+
|
|
48
|
+
Raw inputs:\{"tool_name": "get_payslips", "args": {"user_id": '$get_user_workday_ids'}}
|
|
49
|
+
tool output: {'user_id': UserWorkdayIDs(person_id='', user_id='6dcb8106e8b74b5aabb1fc3ab8ef2b92')}
|
|
50
|
+
<|start_header_id|>ipython<|end_header_id|>
|
|
51
|
+
{"tool_name": "get_payslips", "args": {"user_id": "6dcb8106e8b74b5aabb1fc3ab8ef2b92"}}
|
|
52
|
+
<|eot_id|>
|
|
53
|
+
|
|
54
|
+
"""
|
|
55
|
+
|
|
56
|
+
test_sample1 = """
|
|
57
|
+
<|start_header_id|>assistant<|end_header_id|>
|
|
58
|
+
Tool description:
|
|
59
|
+
def update_direct_reports(email_id: str, members: List[str], notification:bool) -> PayslipsResponse:
|
|
60
|
+
update direct reports for a given user
|
|
61
|
+
:param email_id: The user's email-id uniquely identifying them within the Workday API.
|
|
62
|
+
:param members: a list of user ids to be added as direct reports
|
|
63
|
+
:param notification: do we send the notification to all members
|
|
64
|
+
|
|
65
|
+
Raw inputs: {"tool_name": "update_direct_reports", "args": {"email_id": '$get_email_id', 'members': $get_user_by_dvision]}}
|
|
66
|
+
tool output: {"email_id": 'jalenm3@163.com'}
|
|
67
|
+
{'members': [UserProfile(name="Lan Smith", user_id="46873f8i93", email="lan_smith@gmail.com"), UserProfile(name="Mary Rubic", user_id="34sss31", email="MaryRobic@gmail.com"), UserProfile(name="Jason Dai", user_id="8e8ewer3", email="jd@gmail.com"])}
|
|
68
|
+
<|start_header_id|>ipython<|end_header_id|>"""
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
test_sample2 = """
|
|
72
|
+
<|start_header_id|>assistant<|end_header_id|>
|
|
73
|
+
Tool description:
|
|
74
|
+
def book_meeting(location: str, date: str, time: str) -> bool:
|
|
75
|
+
update direct reports for a given user
|
|
76
|
+
:param email_id: The user's email-id uniquely identifying them within the Workday API.
|
|
77
|
+
:param members: a list of user ids to be added as direct reports
|
|
78
|
+
:param notification: do we send the notification to all members
|
|
79
|
+
|
|
80
|
+
Raw inputs: {"tool_name": "book_meeting", "args": {"email_id": '$get_email_id', 'members': $get_user_by_dvision]}}
|
|
81
|
+
tool output: {"email_id": 'jalenm3@163.com'}
|
|
82
|
+
{'members': [UserProfile(name="Lan Smith", user_id="46873f8i93", email="lan_smith@gmail.com"), UserProfile(name="Mary Rubic", user_id="34sss31", email="MaryRobic@gmail.com"), UserProfile(name="Jason Dai", user_id="8e8ewer3", email="jd@gmail.com"])}
|
|
83
|
+
<|start_header_id|>ipython<|end_header_id|>"""
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
outputs = wai_client.query(prompt + test_sample1)

import json

# NOTE(review): WatsonXProvider.query() is annotated `-> str` and returns the
# generated text itself, so indexing the result with "generated_text" would
# raise TypeError. A dict result is still accepted in case an older provider
# is in use.
generated_text = outputs["generated_text"] if isinstance(outputs, dict) else outputs
print(generated_text)

json_obj = parse_json_string(generated_text)[0]

print(json_obj)
|
|
@@ -6,15 +6,25 @@ import importlib.util
|
|
|
6
6
|
import re
|
|
7
7
|
from jsonargparse import CLI
|
|
8
8
|
import os
|
|
9
|
+
import textwrap
|
|
10
|
+
from dataclasses import is_dataclass, asdict
|
|
9
11
|
|
|
10
|
-
from wxo_agentic_evaluation.
|
|
12
|
+
from wxo_agentic_evaluation.service_provider import get_provider
|
|
11
13
|
from wxo_agentic_evaluation.arg_configs import BatchAnnotateConfig
|
|
12
|
-
from wxo_agentic_evaluation.prompt.template_render import ToolPlannerTemplateRenderer,
|
|
14
|
+
from wxo_agentic_evaluation.prompt.template_render import ToolPlannerTemplateRenderer, ArgsExtractorTemplateRenderer
|
|
13
15
|
from wxo_agentic_evaluation import __file__
|
|
14
16
|
|
|
15
17
|
root_dir = os.path.dirname(__file__)
|
|
16
18
|
TOOL_PLANNER_PROMPT_PATH = os.path.join(root_dir, "prompt", "tool_planner.jinja2")
|
|
19
|
+
ARGS_EXTRACTOR_PROMPT_PATH = os.path.join(root_dir, "prompt", "args_extractor_prompt.jinja2")
|
|
17
20
|
|
|
21
|
+
class UniversalEncoder(json.JSONEncoder):
    """JSON encoder that also serialises dataclasses and plain attribute-bearing objects."""

    def default(self, obj):
        """Convert ``obj`` into something the base encoder can serialise.

        Dataclasses become ``asdict(obj)``; any other object exposing
        ``__dict__`` becomes that attribute dict; everything else is deferred
        to ``json.JSONEncoder.default`` (which raises ``TypeError``).
        """
        if is_dataclass(obj):
            return asdict(obj)
        if hasattr(obj, "__dict__"):
            return vars(obj)
        return super().default(obj)
|
|
18
28
|
|
|
19
29
|
def extract_first_json_list(raw: str) -> list:
|
|
20
30
|
matches = re.findall(r"\[\s*{.*?}\s*]", raw, re.DOTALL)
|
|
@@ -29,6 +39,33 @@ def extract_first_json_list(raw: str) -> list:
|
|
|
29
39
|
print(raw)
|
|
30
40
|
return []
|
|
31
41
|
|
|
42
|
+
def parse_json_string(input_string):
    """Extract every top-level JSON object embedded in ``input_string``.

    The text is scanned left to right. At each ``{`` a JSON object is decoded
    in place; text between objects (including stray ``}`` characters) is
    ignored, and braces inside JSON string values do not confuse the scan.
    If the text at a ``{`` is not valid JSON, the brace-balanced span starting
    there is skipped and an error is printed. If that span never closes,
    scanning stops.

    :param input_string: arbitrary text, e.g. raw model output.
    :return: list of decoded objects (dicts) in order of appearance.
    """

    def balanced_end(start):
        # Index just past the '}' that balances the '{' at `start`, or -1.
        depth = 0
        for offset, char in enumerate(input_string[start:], start):
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    return offset + 1
        return -1

    decoder = json.JSONDecoder()
    json_objects = []
    cursor = 0
    while True:
        start = input_string.find("{", cursor)
        if start == -1:
            break
        try:
            # raw_decode parses exactly one JSON value and reports where it
            # ended, so trailing text does not need to be trimmed first.
            obj, cursor = decoder.raw_decode(input_string, start)
        except json.JSONDecodeError as e:
            end = balanced_end(start)
            if end == -1:
                # Unterminated span: nothing after it can be delimited reliably.
                break
            print(f"Error decoding JSON: {e}")
            cursor = end
        else:
            json_objects.append(obj)
    return json_objects
|
|
68
|
+
|
|
32
69
|
|
|
33
70
|
def load_tools_module(tools_path: Path) -> dict:
|
|
34
71
|
tools_dict = {}
|
|
@@ -93,8 +130,64 @@ def extract_tool_signatures(tools_path: Path) -> list:
|
|
|
93
130
|
|
|
94
131
|
return tool_data
|
|
95
132
|
|
|
133
|
+
def extract_tool_signatures_for_prompt(tools_path: Path) -> dict[str, str]:
|
|
134
|
+
functions = {}
|
|
135
|
+
files_to_parse = []
|
|
136
|
+
|
|
137
|
+
# Handle both single file and directory cases
|
|
138
|
+
if tools_path.is_file():
|
|
139
|
+
files_to_parse.append(tools_path)
|
|
140
|
+
elif tools_path.is_dir():
|
|
141
|
+
files_to_parse.extend(tools_path.glob("**/*.py"))
|
|
142
|
+
else:
|
|
143
|
+
raise ValueError(f"Tools path {tools_path} is neither a file nor directory")
|
|
144
|
+
|
|
145
|
+
for file_path in files_to_parse:
|
|
146
|
+
try:
|
|
147
|
+
with file_path.open("r", encoding="utf-8") as f:
|
|
148
|
+
code = f.read()
|
|
149
|
+
parsed_code = ast.parse(code)
|
|
150
|
+
|
|
151
|
+
for node in parsed_code.body:
|
|
152
|
+
if isinstance(node, ast.FunctionDef):
|
|
153
|
+
name = node.name
|
|
154
|
+
|
|
155
|
+
# Get args and type annotations
|
|
156
|
+
args = []
|
|
157
|
+
for arg in node.args.args:
|
|
158
|
+
if arg.arg == "self":
|
|
159
|
+
continue
|
|
160
|
+
annotation = ast.unparse(arg.annotation) if arg.annotation else "Any"
|
|
161
|
+
args.append((arg.arg, annotation))
|
|
162
|
+
|
|
163
|
+
# Get return type
|
|
164
|
+
returns = ast.unparse(node.returns) if node.returns else "None"
|
|
165
|
+
|
|
166
|
+
# Get docstring
|
|
167
|
+
docstring = ast.get_docstring(node)
|
|
168
|
+
docstring = textwrap.dedent(docstring).strip() if docstring else ""
|
|
169
|
+
|
|
170
|
+
# Format parameter descriptions if available in docstring
|
|
171
|
+
doc_lines = docstring.splitlines()
|
|
172
|
+
doc_summary = doc_lines[0] if doc_lines else ""
|
|
173
|
+
param_descriptions = "\n".join([line for line in doc_lines[1:] if ":param" in line])
|
|
174
|
+
|
|
175
|
+
# Compose the final string
|
|
176
|
+
args_str = ", ".join(f"{arg}: {type_}" for arg, type_ in args)
|
|
177
|
+
function_str = f"""def {name}({args_str}) -> {returns}:
|
|
178
|
+
{doc_summary}"""
|
|
179
|
+
if param_descriptions:
|
|
180
|
+
function_str += f"\n {param_descriptions}"
|
|
96
181
|
|
|
97
|
-
|
|
182
|
+
functions[name] = function_str
|
|
183
|
+
except Exception as e:
|
|
184
|
+
print(f"Warning: Failed to parse {file_path}: {str(e)}")
|
|
185
|
+
continue
|
|
186
|
+
|
|
187
|
+
return functions
|
|
188
|
+
|
|
189
|
+
def ensure_data_available(step: dict, inputs: dict, snapshot: dict, tools_module: dict, tool_signatures_for_prompt) -> dict:
|
|
190
|
+
tool_name = step["tool_name"]
|
|
98
191
|
cache = snapshot.setdefault("input_output_examples", {}).setdefault(tool_name, [])
|
|
99
192
|
for entry in cache:
|
|
100
193
|
if entry["inputs"] == inputs:
|
|
@@ -103,7 +196,27 @@ def ensure_data_available(tool_name: str, inputs: dict, snapshot: dict, tools_mo
|
|
|
103
196
|
if tool_name not in tools_module:
|
|
104
197
|
raise ValueError(f"Tool '{tool_name}' not found")
|
|
105
198
|
|
|
106
|
-
|
|
199
|
+
try:
|
|
200
|
+
output = tools_module[tool_name](**inputs)
|
|
201
|
+
except:
|
|
202
|
+
provider = get_provider(
|
|
203
|
+
model_id="meta-llama/llama-3-405b-instruct",
|
|
204
|
+
params={"min_new_tokens": 0, "decoding_method": "greedy", "max_new_tokens": 500},
|
|
205
|
+
)
|
|
206
|
+
renderer = ArgsExtractorTemplateRenderer(ARGS_EXTRACTOR_PROMPT_PATH)
|
|
207
|
+
|
|
208
|
+
prompt = renderer.render(
|
|
209
|
+
tool_signature=tool_signatures_for_prompt[tool_name],
|
|
210
|
+
step=step,
|
|
211
|
+
inputs=inputs,
|
|
212
|
+
)
|
|
213
|
+
response = provider.query(prompt)
|
|
214
|
+
json_obj = parse_json_string(response)[0]
|
|
215
|
+
try:
|
|
216
|
+
output = tools_module[json_obj["tool_name"]](**json_obj["inputs"])
|
|
217
|
+
except:
|
|
218
|
+
raise ValueError(f"Failed to execute tool '{tool_name}' with inputs {inputs}")
|
|
219
|
+
|
|
107
220
|
cache.append({"inputs": inputs, "output": output})
|
|
108
221
|
if not isinstance(output, dict):
|
|
109
222
|
print(f" Tool {tool_name} returned non-dict output: {output}")
|
|
@@ -119,15 +232,14 @@ def plan_tool_calls_with_llm(story: str, agent_name: str, tool_signatures_str: s
|
|
|
119
232
|
available_tools=tool_signatures_str,
|
|
120
233
|
)
|
|
121
234
|
response = provider.query(prompt)
|
|
122
|
-
|
|
123
|
-
parsed = extract_first_json_list(raw)
|
|
235
|
+
parsed = extract_first_json_list(response)
|
|
124
236
|
print("\n LLM Tool Plan:")
|
|
125
237
|
print(json.dumps(parsed, indent=2))
|
|
126
238
|
return parsed
|
|
127
239
|
|
|
128
240
|
|
|
129
241
|
# --- Tool Execution Logic ---
|
|
130
|
-
def run_tool_chain(tool_plan: list, snapshot: dict, tools_module) -> None:
|
|
242
|
+
def run_tool_chain(tool_plan: list, snapshot: dict, tools_module, tool_signatures_for_prompt) -> None:
|
|
131
243
|
memory = {}
|
|
132
244
|
|
|
133
245
|
for step in tool_plan:
|
|
@@ -166,14 +278,14 @@ def run_tool_chain(tool_plan: list, snapshot: dict, tools_module) -> None:
|
|
|
166
278
|
item_inputs = resolved_inputs.copy()
|
|
167
279
|
item_inputs[list_key] = val
|
|
168
280
|
print(f" ⚙️ Running {name} with {list_key} = {val}")
|
|
169
|
-
output = ensure_data_available(
|
|
281
|
+
output = ensure_data_available(step, item_inputs, snapshot, tools_module, tool_signatures_for_prompt)
|
|
170
282
|
results.append(output)
|
|
171
283
|
memory[f"{name}_{idx}"] = output
|
|
172
284
|
|
|
173
285
|
memory[name] = results
|
|
174
286
|
print(f"Stored {len(results)} outputs under '{name}' and indexed as '{name}_i'")
|
|
175
287
|
else:
|
|
176
|
-
output = ensure_data_available(
|
|
288
|
+
output = ensure_data_available(step, resolved_inputs, snapshot, tools_module, tool_signatures_for_prompt)
|
|
177
289
|
memory[name] = output
|
|
178
290
|
print(f"Stored output under tool name: {name} = {output}")
|
|
179
291
|
|
|
@@ -183,14 +295,11 @@ def build_snapshot(agent_name: str, tools_path: Path, stories: list, output_path
|
|
|
183
295
|
agent = {"name": agent_name}
|
|
184
296
|
tools_module = load_tools_module(tools_path)
|
|
185
297
|
tool_signatures = extract_tool_signatures(tools_path)
|
|
298
|
+
tool_signatures_for_prompt = extract_tool_signatures_for_prompt(tools_path)
|
|
186
299
|
|
|
187
|
-
provider =
|
|
300
|
+
provider = get_provider(
|
|
188
301
|
model_id="meta-llama/llama-3-405b-instruct",
|
|
189
|
-
|
|
190
|
-
"min_new_tokens": 50,
|
|
191
|
-
"decoding_method": "greedy",
|
|
192
|
-
"max_new_tokens": 200
|
|
193
|
-
}
|
|
302
|
+
params={"min_new_tokens": 1, "decoding_method": "greedy", "max_new_tokens": 2048},
|
|
194
303
|
)
|
|
195
304
|
|
|
196
305
|
snapshot = {
|
|
@@ -202,10 +311,14 @@ def build_snapshot(agent_name: str, tools_path: Path, stories: list, output_path
|
|
|
202
311
|
for story in stories:
|
|
203
312
|
print(f"\n📘 Planning tool calls for story: {story}")
|
|
204
313
|
tool_plan = plan_tool_calls_with_llm(story, agent["name"], tool_signatures, provider)
|
|
205
|
-
|
|
314
|
+
try:
|
|
315
|
+
run_tool_chain(tool_plan, snapshot, tools_module, tool_signatures_for_prompt)
|
|
316
|
+
except ValueError as e:
|
|
317
|
+
print(f"❌ Error running tool chain for story '{story}': {e}")
|
|
318
|
+
continue
|
|
206
319
|
|
|
207
320
|
with output_path.open("w", encoding="utf-8") as f:
|
|
208
|
-
json.dump(snapshot, f, indent=2)
|
|
321
|
+
json.dump(snapshot, f, indent=2, cls=UniversalEncoder)
|
|
209
322
|
print(f"\n✅ Snapshot saved to {output_path}")
|
|
210
323
|
|
|
211
324
|
|
wxo_agentic_evaluation/type.py
CHANGED
|
@@ -111,66 +111,9 @@ class GoalDetail(BaseModel):
|
|
|
111
111
|
knowledge_base: KnowledgeBaseGoalDetail = KnowledgeBaseGoalDetail()
|
|
112
112
|
|
|
113
113
|
|
|
114
|
-
class MineField(BaseModel):
|
|
115
|
-
type: ContentType
|
|
116
|
-
name: str
|
|
117
|
-
|
|
118
|
-
|
|
119
114
|
class EvaluationData(BaseModel):
|
|
120
115
|
agent: str
|
|
121
116
|
goals: Dict
|
|
122
117
|
story: str
|
|
123
|
-
mine_fields: List[MineField]
|
|
124
118
|
goal_details: List[GoalDetail]
|
|
125
119
|
starting_sentence: str = None
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
class ToolCallAndRoutingMetrics(BaseModel):
|
|
129
|
-
total_tool_calls: int
|
|
130
|
-
expected_tool_calls: int
|
|
131
|
-
relevant_tool_calls: int
|
|
132
|
-
correct_tool_calls: int
|
|
133
|
-
total_routing_calls: int
|
|
134
|
-
expected_routing_calls: int
|
|
135
|
-
|
|
136
|
-
@computed_field
|
|
137
|
-
@property
|
|
138
|
-
def non_transfer_tool_calls(self) -> int:
|
|
139
|
-
return self.total_tool_calls - self.total_routing_calls
|
|
140
|
-
|
|
141
|
-
@computed_field
|
|
142
|
-
@property
|
|
143
|
-
def tool_call_accuracy(self) -> float:
|
|
144
|
-
return round(
|
|
145
|
-
(
|
|
146
|
-
self.correct_tool_calls / self.non_transfer_tool_calls
|
|
147
|
-
if self.non_transfer_tool_calls > 0
|
|
148
|
-
else 0.0
|
|
149
|
-
),
|
|
150
|
-
2,
|
|
151
|
-
)
|
|
152
|
-
|
|
153
|
-
@computed_field
|
|
154
|
-
@property
|
|
155
|
-
def tool_call_relevancy(self) -> float:
|
|
156
|
-
return round(
|
|
157
|
-
(
|
|
158
|
-
(self.relevant_tool_calls - self.expected_routing_calls)
|
|
159
|
-
/ self.non_transfer_tool_calls
|
|
160
|
-
if self.non_transfer_tool_calls > 0
|
|
161
|
-
else 0.0
|
|
162
|
-
),
|
|
163
|
-
2,
|
|
164
|
-
)
|
|
165
|
-
|
|
166
|
-
@computed_field
|
|
167
|
-
@property
|
|
168
|
-
def agent_routing_accuracy(self) -> float:
|
|
169
|
-
return round(
|
|
170
|
-
(
|
|
171
|
-
self.expected_routing_calls / self.total_routing_calls
|
|
172
|
-
if self.total_routing_calls > 0
|
|
173
|
-
else 0.0
|
|
174
|
-
),
|
|
175
|
-
2,
|
|
176
|
-
)
|