ibm-watsonx-orchestrate-evaluation-framework 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ibm_watsonx_orchestrate_evaluation_framework-1.0.1.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.0.3.dist-info}/METADATA +70 -12
- ibm_watsonx_orchestrate_evaluation_framework-1.0.3.dist-info/RECORD +56 -0
- wxo_agentic_evaluation/analytics/tools/analyzer.py +3 -3
- wxo_agentic_evaluation/analytics/tools/ux.py +1 -1
- wxo_agentic_evaluation/analyze_run.py +10 -10
- wxo_agentic_evaluation/arg_configs.py +8 -1
- wxo_agentic_evaluation/batch_annotate.py +4 -10
- wxo_agentic_evaluation/data_annotator.py +50 -36
- wxo_agentic_evaluation/evaluation_package.py +102 -85
- wxo_agentic_evaluation/external_agent/__init__.py +37 -0
- wxo_agentic_evaluation/external_agent/external_validate.py +74 -31
- wxo_agentic_evaluation/external_agent/performance_test.py +66 -0
- wxo_agentic_evaluation/external_agent/types.py +8 -2
- wxo_agentic_evaluation/inference_backend.py +45 -50
- wxo_agentic_evaluation/llm_matching.py +6 -6
- wxo_agentic_evaluation/llm_rag_eval.py +4 -4
- wxo_agentic_evaluation/llm_user.py +3 -3
- wxo_agentic_evaluation/main.py +63 -23
- wxo_agentic_evaluation/metrics/metrics.py +72 -5
- wxo_agentic_evaluation/prompt/args_extractor_prompt.jinja2 +23 -0
- wxo_agentic_evaluation/prompt/batch_testcase_prompt.jinja2 +2 -0
- wxo_agentic_evaluation/prompt/examples/data_simple.json +1 -2
- wxo_agentic_evaluation/prompt/starting_sentence_generation_prompt.jinja2 +195 -0
- wxo_agentic_evaluation/prompt/story_generation_prompt.jinja2 +154 -0
- wxo_agentic_evaluation/prompt/template_render.py +17 -0
- wxo_agentic_evaluation/prompt/tool_planner.jinja2 +13 -7
- wxo_agentic_evaluation/record_chat.py +59 -18
- wxo_agentic_evaluation/resource_map.py +47 -0
- wxo_agentic_evaluation/service_provider/__init__.py +35 -0
- wxo_agentic_evaluation/service_provider/model_proxy_provider.py +108 -0
- wxo_agentic_evaluation/service_provider/ollama_provider.py +40 -0
- wxo_agentic_evaluation/service_provider/provider.py +19 -0
- wxo_agentic_evaluation/{watsonx_provider.py → service_provider/watsonx_provider.py} +54 -57
- wxo_agentic_evaluation/test_prompt.py +94 -0
- wxo_agentic_evaluation/tool_planner.py +130 -17
- wxo_agentic_evaluation/type.py +0 -57
- wxo_agentic_evaluation/utils/utils.py +6 -54
- ibm_watsonx_orchestrate_evaluation_framework-1.0.1.dist-info/RECORD +0 -46
- ibm_watsonx_orchestrate_evaluation_framework-1.0.1.dist-info/licenses/LICENSE +0 -22
- {ibm_watsonx_orchestrate_evaluation_framework-1.0.1.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.0.3.dist-info}/WHEEL +0 -0
- {ibm_watsonx_orchestrate_evaluation_framework-1.0.1.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.0.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import requests
|
|
3
|
+
import time
|
|
4
|
+
|
|
5
|
+
from typing import List
|
|
6
|
+
from threading import Lock
|
|
7
|
+
|
|
8
|
+
from wxo_agentic_evaluation.service_provider.provider import Provider
|
|
9
|
+
from wxo_agentic_evaluation.utils.utils import is_ibm_cloud_url
|
|
10
|
+
|
|
11
|
+
AUTH_ENDPOINT_AWS = "https://iam.platform.saas.ibm.com/siusermgr/api/1.0/apikeys/token"
|
|
12
|
+
AUTH_ENDPOINT_IBM_CLOUD = "https://iam.cloud.ibm.com/identity/token"
|
|
13
|
+
WO_INSTANCE = os.environ.get("WO_INSTANCE")
|
|
14
|
+
WO_API_KEY = os.environ.get("WO_API_KEY")
|
|
15
|
+
DEFAULT_PARAM = {"min_new_tokens": 1, "decoding_method": "greedy", "max_new_tokens": 400}
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ModelProxyProvider(Provider):
    """Provider that sends generation and embedding requests to a WO model proxy.

    A bearer token is obtained from the IBM Cloud IAM endpoint or the AWS/SaaS
    endpoint (chosen by ``is_ibm_cloud_url(instance_url)``) and is refreshed
    once 80% of its advertised lifetime has elapsed.
    """

    def __init__(
        self,
        model_id=None,
        api_key=WO_API_KEY,
        instance_url=WO_INSTANCE,
        timeout=300,
        embedding_model_id=None,
        params=None
    ):
        """Store the configuration and fetch the first access token.

        :param model_id: model used by ``query``; may be None if only ``encode`` is used.
        :param api_key: WO API key (defaults to the ``WO_API_KEY`` environment variable).
        :param instance_url: base URL of the WO instance (defaults to ``WO_INSTANCE``).
        :param timeout: seconds; sent in the generation payload and used as the HTTP timeout.
        :param embedding_model_id: model used by ``encode``.
        :param params: generation parameters; ``DEFAULT_PARAM`` is used when falsy.
        :raises RuntimeError: if the instance URL or API key is missing.
        """
        super().__init__()

        if not instance_url or not api_key:
            raise RuntimeError("instance url and WO apikey must be specified to use WO model proxy")

        self.timeout = timeout
        self.model_id = model_id
        self.embedding_model_id = embedding_model_id

        self.api_key = api_key
        self.is_ibm_cloud = is_ibm_cloud_url(instance_url)
        self.auth_url = AUTH_ENDPOINT_IBM_CLOUD if self.is_ibm_cloud else AUTH_ENDPOINT_AWS

        # Tolerate a trailing slash so the joined URLs never contain "//ml/...".
        base_url = instance_url.rstrip("/")
        self.url = base_url + "/ml/v1/text/generation?version=2024-05-01"
        self.embedding_url = base_url + "/ml/v1/text/embeddings"

        self.lock = Lock()
        self.token, self.refresh_time = self.get_token()
        # Copy the defaults so one instance can never mutate the shared module-level dict.
        self.params = params if params else dict(DEFAULT_PARAM)

    def get_token(self):
        """Request a new access token.

        :return: ``(token, refresh_time)`` where ``refresh_time`` is the epoch
            second after which the token should be renewed.
        :raises requests.HTTPError: if the auth endpoint answers with an error status.
        """
        if self.is_ibm_cloud:
            payload = {"grant_type": "urn:ibm:params:oauth:grant-type:apikey", "apikey": self.api_key}
            resp = requests.post(self.auth_url, data=payload, timeout=self.timeout)
            token_key = "access_token"
        else:
            payload = {"apikey": self.api_key}
            resp = requests.post(self.auth_url, json=payload, timeout=self.timeout)
            token_key = "token"

        # Fail loudly here instead of returning None, which the callers would
        # then try to unpack into (token, refresh_time).
        resp.raise_for_status()
        json_obj = resp.json()
        token = json_obj[token_key]
        expires_in = json_obj["expires_in"]
        refresh_time = time.time() + int(0.8 * expires_in)
        return token, refresh_time

    def refresh_token_if_expires(self):
        """Renew the token when the refresh deadline has passed."""
        if time.time() > self.refresh_time:
            with self.lock:
                # Re-check under the lock: another thread may have refreshed already.
                if time.time() > self.refresh_time:
                    self.token, self.refresh_time = self.get_token()

    def get_header(self):
        """Return the Authorization header carrying the current token."""
        return {"Authorization": f"Bearer {self.token}"}

    def encode(self, sentences: List[str]) -> List[list]:
        """Embed ``sentences`` with ``embedding_model_id``.

        :return: one embedding per input sentence, read from ``results[*].embedding``.
        :raises Exception: if no embedding model id was configured.
        :raises requests.HTTPError: if the proxy answers with an error status.
        """
        if self.embedding_model_id is None:
            raise Exception("embedding model id must be specified for text encoding")

        self.refresh_token_if_expires()
        headers = self.get_header()
        payload = {"inputs": sentences, "model_id": self.embedding_model_id, "space_id": "1"}
        resp = requests.post(self.embedding_url, json=payload, headers=headers, timeout=self.timeout)
        resp.raise_for_status()
        # NOTE(review): response shape assumed to match the watsonx embeddings
        # API ({"results": [{"embedding": [...]}, ...]}) - confirm against the proxy.
        return [entry["embedding"] for entry in resp.json()["results"]]

    def query(self, sentence: str) -> str:
        """Generate text for ``sentence`` with ``model_id``.

        :return: the ``generated_text`` of the first result.
        :raises Exception: if no model id was configured.
        :raises requests.HTTPError: if the proxy answers with an error status.
        """
        if self.model_id is None:
            raise Exception("model id must be specified for text generation")
        self.refresh_token_if_expires()
        headers = self.get_header()
        payload = {"input": sentence, "model_id": self.model_id, "space_id": "1",
                   "timeout": self.timeout, "parameters": self.params}
        resp = requests.post(self.url, json=payload, headers=headers, timeout=self.timeout)
        resp.raise_for_status()
        return resp.json()["results"][0]["generated_text"]
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
if __name__ == "__main__":
    # Manual smoke test: needs WO_INSTANCE / WO_API_KEY and network access.
    provider = ModelProxyProvider(
        model_id="meta-llama/llama-3-3-70b-instruct",
        embedding_model_id="ibm/slate-30m-english-rtrvr",
    )
    reply = provider.query("ok")
    print(reply)
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
import requests
|
|
2
|
+
import json
|
|
3
|
+
from wxo_agentic_evaluation.service_provider.provider import Provider
|
|
4
|
+
from typing import List
|
|
5
|
+
import os
|
|
6
|
+
|
|
7
|
+
OLLAMA_URL = os.environ.get("OLLAMA_HOST", "http://localhost:11434")
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class OllamaProvider(Provider):
    """Provider that generates text through an Ollama server's ``/api/generate`` endpoint."""

    def __init__(
        self,
        model_id=None
    ):
        """Remember the model to use and build the generate URL from ``OLLAMA_URL``."""
        self.url = OLLAMA_URL + "/api/generate"
        self.model_id = model_id
        super().__init__()

    def query(self, sentence: str) -> str:
        """Send ``sentence`` as the prompt and return the concatenated streamed text.

        The response is consumed line by line (one JSON document per line), so
        the result does not depend on how the bytes are split into chunks.

        :raises requests.HTTPError: if the server answers with an error status.
        :raises RuntimeError: if a streamed document carries an ``error`` field.
        """
        payload = {"model": self.model_id, "prompt": sentence}
        pieces = []
        # The context manager releases the connection even if parsing fails midway.
        with requests.post(self.url, json=payload, stream=True) as resp:
            resp.raise_for_status()
            for line in resp.iter_lines():
                if not line:
                    # iter_lines may yield empty keep-alive lines.
                    continue
                json_obj = json.loads(line)
                if "error" in json_obj:
                    raise RuntimeError(f"Ollama generation failed: {json_obj['error']}")
                if not json_obj.get("done") and json_obj.get("response"):
                    pieces.append(json_obj["response"])

        return "".join(pieces)

    def encode(self, sentences: List[str]) -> List[list]:
        """Embeddings are not implemented for this provider; returns None."""
        pass
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
if __name__ == "__main__":
    # Manual smoke test: needs a reachable Ollama server.
    provider = OllamaProvider(model_id="llama3.1:8b")
    reply = provider.query("ok")
    print(reply)
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import List
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class Provider(ABC):
    """Abstract interface every model service provider implements."""

    def __init__(self):
        """Nothing to set up in the base class."""

    @abstractmethod
    def query(self, sentence: str) -> str:
        """Return the generated text for a single prompt."""

    def batch_query(self, sentences: List[str]) -> List[str]:
        """Run ``query`` on each prompt in order and collect the answers."""
        answers = []
        for prompt in sentences:
            answers.append(self.query(prompt))
        return answers

    @abstractmethod
    def encode(self, sentences: List[str]) -> List[list]:
        """Return one embedding per input sentence."""
|
|
19
|
+
|
|
@@ -4,10 +4,9 @@ import json
|
|
|
4
4
|
from types import MappingProxyType
|
|
5
5
|
from typing import List
|
|
6
6
|
import dataclasses
|
|
7
|
-
from ibm_watsonx_ai.foundation_models import ModelInference, Embeddings
|
|
8
|
-
from ibm_watsonx_ai.credentials import Credentials
|
|
9
7
|
from threading import Lock
|
|
10
|
-
|
|
8
|
+
import time
|
|
9
|
+
from wxo_agentic_evaluation.service_provider.provider import Provider
|
|
11
10
|
|
|
12
11
|
ACCESS_URL = "https://iam.cloud.ibm.com/identity/token"
|
|
13
12
|
ACCESS_HEADER = {
|
|
@@ -18,11 +17,11 @@ ACCESS_HEADER = {
|
|
|
18
17
|
YPQA_URL = "https://yp-qa.ml.cloud.ibm.com"
|
|
19
18
|
PROD_URL = "https://us-south.ml.cloud.ibm.com"
|
|
20
19
|
DEFAULT_PARAM = MappingProxyType(
|
|
21
|
-
{"min_new_tokens":
|
|
20
|
+
{"min_new_tokens": 1, "decoding_method": "greedy", "max_new_tokens": 400}
|
|
22
21
|
)
|
|
23
22
|
|
|
24
23
|
|
|
25
|
-
class WatsonXProvider:
|
|
24
|
+
class WatsonXProvider(Provider):
|
|
26
25
|
def __init__(
|
|
27
26
|
self,
|
|
28
27
|
model_id=None,
|
|
@@ -31,7 +30,7 @@ class WatsonXProvider:
|
|
|
31
30
|
api_endpoint=PROD_URL,
|
|
32
31
|
url=ACCESS_URL,
|
|
33
32
|
timeout=60,
|
|
34
|
-
|
|
33
|
+
params=None,
|
|
35
34
|
embedding_model_id=None,
|
|
36
35
|
):
|
|
37
36
|
super().__init__()
|
|
@@ -56,12 +55,15 @@ class WatsonXProvider:
|
|
|
56
55
|
self.embedding_model_id = embedding_model_id
|
|
57
56
|
self.lock = Lock()
|
|
58
57
|
|
|
59
|
-
if
|
|
60
|
-
|
|
61
|
-
if
|
|
62
|
-
|
|
58
|
+
self.params = params if params else DEFAULT_PARAM
|
|
59
|
+
|
|
60
|
+
if isinstance(self.params, MappingProxyType):
|
|
61
|
+
self.params = dict(self.params)
|
|
62
|
+
if dataclasses.is_dataclass(self.params):
|
|
63
|
+
self.params = dataclasses.asdict(self.params)
|
|
63
64
|
|
|
64
|
-
self.
|
|
65
|
+
self.refresh_time = None
|
|
66
|
+
self.access_token = None
|
|
65
67
|
self._refresh_token()
|
|
66
68
|
|
|
67
69
|
def _get_access_token(self):
|
|
@@ -71,75 +73,70 @@ class WatsonXProvider:
|
|
|
71
73
|
if response.status_code == 200:
|
|
72
74
|
token_data = json.loads(response.text)
|
|
73
75
|
token = token_data["access_token"]
|
|
74
|
-
|
|
75
|
-
|
|
76
|
+
expiration = token_data["expiration"]
|
|
77
|
+
expires_in = token_data["expires_in"]
|
|
78
|
+
# 9 minutes before expire
|
|
79
|
+
refresh_time = expiration - int(0.15 * expires_in)
|
|
80
|
+
return token, refresh_time
|
|
76
81
|
|
|
77
82
|
raise RuntimeError(
|
|
78
83
|
f"try to acquire access token and get {response.status_code}"
|
|
79
84
|
)
|
|
80
85
|
|
|
81
|
-
def
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
if
|
|
94
|
-
|
|
95
|
-
model_id=self.model_id,
|
|
96
|
-
params=self.decode_param,
|
|
97
|
-
credentials=Credentials(token=self.access_token, url=self.api_endpoint),
|
|
98
|
-
space_id=self.space_id,
|
|
99
|
-
)
|
|
86
|
+
def prepare_header(self):
    """Build the HTTP headers for a JSON request authorised with the current access token."""
    return {
        "Authorization": f"Bearer {self.access_token}",
        "Content-Type": "application/json",
    }
|
|
90
|
+
|
|
91
|
+
def generate(self, sentence: str):
    """POST ``sentence`` to the text generation endpoint and return the first result entry.

    An error status raises through ``raise_for_status``; any other non-200
    status yields None.
    """
    generation_url = f"{self.api_endpoint}/ml/v1/text/generation?version=2023-05-02"
    request_body = {
        "model_id": self.model_id,
        "input": sentence,
        "parameters": self.params,
        "space_id": self.space_id,
    }
    resp = requests.post(url=generation_url, headers=self.prepare_header(), json=request_body)
    if resp.status_code != 200:
        resp.raise_for_status()
        return None
    return resp.json()["results"][0]
|
|
102
102
|
|
|
103
|
-
def
|
|
104
|
-
if
|
|
105
|
-
|
|
106
|
-
try:
|
|
107
|
-
return self.client.generate([sentence])[0][
|
|
108
|
-
"results"
|
|
109
|
-
][ # pylint: disable=E1136
|
|
110
|
-
0
|
|
111
|
-
]
|
|
112
|
-
except Exception as e:
|
|
103
|
+
def _refresh_token(self):
|
|
104
|
+
# if we do not have a token or the current timestamp is 9 minutes away from expire.
|
|
105
|
+
if not self.access_token or time.time() > self.refresh_time:
|
|
113
106
|
with self.lock:
|
|
114
|
-
if
|
|
115
|
-
self.
|
|
116
|
-
raise e
|
|
107
|
+
if not self.access_token or time.time() > self.refresh_time:
|
|
108
|
+
self.access_token, self.refresh_time = self._get_access_token()
|
|
117
109
|
|
|
118
|
-
def
|
|
110
|
+
def query(self, sentence: str) -> str:
    """Generate text for ``sentence`` and return the ``generated_text`` field.

    The token is refreshed up front when it is due. If the call still fails
    with an "authentication_token_expired" error, a new token is fetched so
    the next call can succeed, and the original error is re-raised.

    :raises Exception: if no model id was configured, or whatever ``generate`` raises.
    """
    if self.model_id is None:
        raise Exception("model id must be specified for text generation")

    # Proactive refresh; _refresh_token takes self.lock itself.
    self._refresh_token()
    try:
        return self.generate(sentence)["generated_text"]
    except Exception as e:
        if "authentication_token_expired" in str(e):
            # Fetch directly instead of calling _refresh_token() under the lock:
            # threading.Lock is not re-entrant, so nesting would deadlock, and
            # the local clock check could skip a refresh the server asked for.
            with self.lock:
                self.access_token, self.refresh_time = self._get_access_token()
        raise
|
|
130
|
-
|
|
131
|
-
|
|
120
|
+
|
|
121
|
+
def batch_query(self, sentences: List[str]) -> List[dict]:
    """Call ``query`` on every sentence in order and return the collected results."""
    collected = []
    for prompt in sentences:
        collected.append(self.query(prompt))
    return collected
|
|
132
123
|
|
|
133
124
|
def encode(self, sentences: List[str]) -> List[list]:
    """Embed ``sentences`` with ``embedding_model_id``.

    :return: one embedding per input, read from ``results[*].embedding``.
    :raises Exception: if no embedding model id was configured.
    :raises requests.HTTPError: if the service answers with an error status.
    """
    if self.embedding_model_id is None:
        raise Exception("embedding model id must be specified for text encoding")

    # Renew the token first if it is due.
    self._refresh_token()
    headers = self.prepare_header()
    url = f"{self.api_endpoint}/ml/v1/text/embeddings?version=2023-10-25"

    # Send the embedding model, not the text generation model.
    data = {"inputs": sentences, "model_id": self.embedding_model_id, "space_id": self.space_id}
    resp = requests.post(url=url, headers=headers, json=data)
    if resp.status_code == 200:
        return [entry["embedding"] for entry in resp.json()["results"]]
    else:
        resp.raise_for_status()
|
|
142
137
|
|
|
138
|
+
|
|
139
|
+
if __name__ == "__main__":
|
|
143
140
|
provider = WatsonXProvider(model_id="meta-llama/llama-3-2-90b-vision-instruct")
|
|
144
141
|
|
|
145
142
|
prompt = """
|
|
@@ -172,4 +169,4 @@ Usernwaters did not take anytime off during the period<|eot_id|>
|
|
|
172
169
|
<|eot_id|><|start_header_id|>user<|end_header_id|>
|
|
173
170
|
"""
|
|
174
171
|
|
|
175
|
-
print(provider.query(prompt))
|
|
172
|
+
print(provider.query(prompt))
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
from wxo_agentic_evaluation.watsonx_provider import WatsonXProvider
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def parse_json_string(input_string):
    """Extract every top-level JSON object embedded in ``input_string``.

    The text is scanned for brace-balanced ``{...}`` spans. Braces inside
    double-quoted strings are ignored, and a stray ``}`` outside any object is
    skipped. Each span is decoded with ``json.loads``; a span that fails to
    decode is reported on stdout and dropped.

    :param input_string: arbitrary text that may contain JSON objects.
    :return: list of decoded objects, in order of appearance.
    """
    json_objects = []
    start = None        # index of the "{" opening the current candidate
    depth = 0
    in_string = False
    escaped = False

    for idx, char in enumerate(input_string):
        if start is None:
            # Outside any object: only an opening brace matters.
            if char == "{":
                start = idx
                depth = 1
            continue

        if in_string:
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == '"':
                in_string = False
            continue

        if char == '"':
            in_string = True
        elif char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                try:
                    json_objects.append(json.loads(input_string[start:idx + 1]))
                except json.JSONDecodeError as e:
                    print(f"Error decoding JSON: {e}")
                start = None
    return json_objects
|
|
33
|
+
|
|
34
|
+
wai_client = WatsonXProvider(model_id="meta-llama/llama-3-405b-instruct")
|
|
35
|
+
|
|
36
|
+
prompt = """
|
|
37
|
+
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
|
38
|
+
You are trying to make tool calls. Given a raw input and tool output. Try to extract the information to make the tool call
|
|
39
|
+
|
|
40
|
+
Example:
|
|
41
|
+
Tool description:
|
|
42
|
+
def get_payslips(user_id: str) -> PayslipsResponse:
|
|
43
|
+
Gets a user's payslips from Workday.
|
|
44
|
+
|
|
45
|
+
:param user_id: The user's id uniquely identifying them within the Workday API.
|
|
46
|
+
:return: The user's payslips.
|
|
47
|
+
|
|
48
|
+
Raw inputs:\{"tool_name": "get_payslips", "args": {"user_id": '$get_user_workday_ids'}}
|
|
49
|
+
tool output: {'user_id': UserWorkdayIDs(person_id='', user_id='6dcb8106e8b74b5aabb1fc3ab8ef2b92')}
|
|
50
|
+
<|start_header_id|>ipython<|end_header_id|>
|
|
51
|
+
{"tool_name": "get_payslips", "args": {"user_id": "6dcb8106e8b74b5aabb1fc3ab8ef2b92"}}
|
|
52
|
+
<|eot_id|>
|
|
53
|
+
|
|
54
|
+
"""
|
|
55
|
+
|
|
56
|
+
test_sample1 = """
|
|
57
|
+
<|start_header_id|>assistant<|end_header_id|>
|
|
58
|
+
Tool description:
|
|
59
|
+
def update_direct_reports(email_id: str, members: List[str], notification:bool) -> PayslipsResponse:
|
|
60
|
+
update direct reports for a given user
|
|
61
|
+
:param email_id: The user's email-id uniquely identifying them within the Workday API.
|
|
62
|
+
:param members: a list of user ids to be added as direct reports
|
|
63
|
+
:param notification: do we send the notification to all members
|
|
64
|
+
|
|
65
|
+
Raw inputs: {"tool_name": "update_direct_reports", "args": {"email_id": '$get_email_id', 'members': $get_user_by_dvision]}}
|
|
66
|
+
tool output: {"email_id": 'jalenm3@163.com'}
|
|
67
|
+
{'members': [UserProfile(name="Lan Smith", user_id="46873f8i93", email="lan_smith@gmail.com"), UserProfile(name="Mary Rubic", user_id="34sss31", email="MaryRobic@gmail.com"), UserProfile(name="Jason Dai", user_id="8e8ewer3", email="jd@gmail.com"])}
|
|
68
|
+
<|start_header_id|>ipython<|end_header_id|>"""
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
test_sample2 = """
|
|
72
|
+
<|start_header_id|>assistant<|end_header_id|>
|
|
73
|
+
Tool description:
|
|
74
|
+
def book_meeting(location: str, date: str, time: str) -> bool:
|
|
75
|
+
update direct reports for a given user
|
|
76
|
+
:param email_id: The user's email-id uniquely identifying them within the Workday API.
|
|
77
|
+
:param members: a list of user ids to be added as direct reports
|
|
78
|
+
:param notification: do we send the notification to all members
|
|
79
|
+
|
|
80
|
+
Raw inputs: {"tool_name": "book_meeting", "args": {"email_id": '$get_email_id', 'members': $get_user_by_dvision]}}
|
|
81
|
+
tool output: {"email_id": 'jalenm3@163.com'}
|
|
82
|
+
{'members': [UserProfile(name="Lan Smith", user_id="46873f8i93", email="lan_smith@gmail.com"), UserProfile(name="Mary Rubic", user_id="34sss31", email="MaryRobic@gmail.com"), UserProfile(name="Jason Dai", user_id="8e8ewer3", email="jd@gmail.com"])}
|
|
83
|
+
<|start_header_id|>ipython<|end_header_id|>"""
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
outputs = wai_client.query(prompt + test_sample1)
|
|
88
|
+
|
|
89
|
+
import json
|
|
90
|
+
print(outputs["generated_text"])
|
|
91
|
+
|
|
92
|
+
json_obj = parse_json_string(outputs["generated_text"])[0]
|
|
93
|
+
|
|
94
|
+
print(json_obj)
|
|
@@ -6,15 +6,25 @@ import importlib.util
|
|
|
6
6
|
import re
|
|
7
7
|
from jsonargparse import CLI
|
|
8
8
|
import os
|
|
9
|
+
import textwrap
|
|
10
|
+
from dataclasses import is_dataclass, asdict
|
|
9
11
|
|
|
10
|
-
from wxo_agentic_evaluation.
|
|
12
|
+
from wxo_agentic_evaluation.service_provider import get_provider
|
|
11
13
|
from wxo_agentic_evaluation.arg_configs import BatchAnnotateConfig
|
|
12
|
-
from wxo_agentic_evaluation.prompt.template_render import ToolPlannerTemplateRenderer,
|
|
14
|
+
from wxo_agentic_evaluation.prompt.template_render import ToolPlannerTemplateRenderer, ArgsExtractorTemplateRenderer
|
|
13
15
|
from wxo_agentic_evaluation import __file__
|
|
14
16
|
|
|
15
17
|
root_dir = os.path.dirname(__file__)
|
|
16
18
|
TOOL_PLANNER_PROMPT_PATH = os.path.join(root_dir, "prompt", "tool_planner.jinja2")
|
|
19
|
+
ARGS_EXTRACTOR_PROMPT_PATH = os.path.join(root_dir, "prompt", "args_extractor_prompt.jinja2")
|
|
17
20
|
|
|
21
|
+
class UniversalEncoder(json.JSONEncoder):
    """JSON encoder that also serialises dataclasses and objects that have a ``__dict__``."""

    def default(self, obj):
        """Convert ``obj`` to a dict, or defer to the base encoder when that is not possible."""
        if is_dataclass(obj):
            return asdict(obj)
        if not hasattr(obj, "__dict__"):
            return super().default(obj)
        return obj.__dict__
|
|
18
28
|
|
|
19
29
|
def extract_first_json_list(raw: str) -> list:
|
|
20
30
|
matches = re.findall(r"\[\s*{.*?}\s*]", raw, re.DOTALL)
|
|
@@ -29,6 +39,33 @@ def extract_first_json_list(raw: str) -> list:
|
|
|
29
39
|
print(raw)
|
|
30
40
|
return []
|
|
31
41
|
|
|
42
|
+
def parse_json_string(input_string):
    """Extract every top-level JSON object embedded in ``input_string``.

    The text is scanned for brace-balanced ``{...}`` spans. Braces inside
    double-quoted strings are ignored, and a stray ``}`` outside any object is
    skipped. Each span is decoded with ``json.loads``; a span that fails to
    decode is reported on stdout and dropped.

    :param input_string: arbitrary text that may contain JSON objects.
    :return: list of decoded objects, in order of appearance.
    """
    json_objects = []
    start = None        # index of the "{" opening the current candidate
    depth = 0
    in_string = False
    escaped = False

    for idx, char in enumerate(input_string):
        if start is None:
            # Outside any object: only an opening brace matters.
            if char == "{":
                start = idx
                depth = 1
            continue

        if in_string:
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == '"':
                in_string = False
            continue

        if char == '"':
            in_string = True
        elif char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                try:
                    json_objects.append(json.loads(input_string[start:idx + 1]))
                except json.JSONDecodeError as e:
                    print(f"Error decoding JSON: {e}")
                start = None
    return json_objects
|
|
68
|
+
|
|
32
69
|
|
|
33
70
|
def load_tools_module(tools_path: Path) -> dict:
|
|
34
71
|
tools_dict = {}
|
|
@@ -93,8 +130,64 @@ def extract_tool_signatures(tools_path: Path) -> list:
|
|
|
93
130
|
|
|
94
131
|
return tool_data
|
|
95
132
|
|
|
133
|
+
def extract_tool_signatures_for_prompt(tools_path: Path) -> dict[str, str]:
|
|
134
|
+
functions = {}
|
|
135
|
+
files_to_parse = []
|
|
136
|
+
|
|
137
|
+
# Handle both single file and directory cases
|
|
138
|
+
if tools_path.is_file():
|
|
139
|
+
files_to_parse.append(tools_path)
|
|
140
|
+
elif tools_path.is_dir():
|
|
141
|
+
files_to_parse.extend(tools_path.glob("**/*.py"))
|
|
142
|
+
else:
|
|
143
|
+
raise ValueError(f"Tools path {tools_path} is neither a file nor directory")
|
|
144
|
+
|
|
145
|
+
for file_path in files_to_parse:
|
|
146
|
+
try:
|
|
147
|
+
with file_path.open("r", encoding="utf-8") as f:
|
|
148
|
+
code = f.read()
|
|
149
|
+
parsed_code = ast.parse(code)
|
|
150
|
+
|
|
151
|
+
for node in parsed_code.body:
|
|
152
|
+
if isinstance(node, ast.FunctionDef):
|
|
153
|
+
name = node.name
|
|
154
|
+
|
|
155
|
+
# Get args and type annotations
|
|
156
|
+
args = []
|
|
157
|
+
for arg in node.args.args:
|
|
158
|
+
if arg.arg == "self":
|
|
159
|
+
continue
|
|
160
|
+
annotation = ast.unparse(arg.annotation) if arg.annotation else "Any"
|
|
161
|
+
args.append((arg.arg, annotation))
|
|
162
|
+
|
|
163
|
+
# Get return type
|
|
164
|
+
returns = ast.unparse(node.returns) if node.returns else "None"
|
|
165
|
+
|
|
166
|
+
# Get docstring
|
|
167
|
+
docstring = ast.get_docstring(node)
|
|
168
|
+
docstring = textwrap.dedent(docstring).strip() if docstring else ""
|
|
169
|
+
|
|
170
|
+
# Format parameter descriptions if available in docstring
|
|
171
|
+
doc_lines = docstring.splitlines()
|
|
172
|
+
doc_summary = doc_lines[0] if doc_lines else ""
|
|
173
|
+
param_descriptions = "\n".join([line for line in doc_lines[1:] if ":param" in line])
|
|
174
|
+
|
|
175
|
+
# Compose the final string
|
|
176
|
+
args_str = ", ".join(f"{arg}: {type_}" for arg, type_ in args)
|
|
177
|
+
function_str = f"""def {name}({args_str}) -> {returns}:
|
|
178
|
+
{doc_summary}"""
|
|
179
|
+
if param_descriptions:
|
|
180
|
+
function_str += f"\n {param_descriptions}"
|
|
96
181
|
|
|
97
|
-
|
|
182
|
+
functions[name] = function_str
|
|
183
|
+
except Exception as e:
|
|
184
|
+
print(f"Warning: Failed to parse {file_path}: {str(e)}")
|
|
185
|
+
continue
|
|
186
|
+
|
|
187
|
+
return functions
|
|
188
|
+
|
|
189
|
+
def ensure_data_available(step: dict, inputs: dict, snapshot: dict, tools_module: dict, tool_signatures_for_prompt) -> dict:
|
|
190
|
+
tool_name = step["tool_name"]
|
|
98
191
|
cache = snapshot.setdefault("input_output_examples", {}).setdefault(tool_name, [])
|
|
99
192
|
for entry in cache:
|
|
100
193
|
if entry["inputs"] == inputs:
|
|
@@ -103,7 +196,27 @@ def ensure_data_available(tool_name: str, inputs: dict, snapshot: dict, tools_mo
|
|
|
103
196
|
if tool_name not in tools_module:
|
|
104
197
|
raise ValueError(f"Tool '{tool_name}' not found")
|
|
105
198
|
|
|
106
|
-
|
|
199
|
+
try:
|
|
200
|
+
output = tools_module[tool_name](**inputs)
|
|
201
|
+
except:
|
|
202
|
+
provider = get_provider(
|
|
203
|
+
model_id="meta-llama/llama-3-405b-instruct",
|
|
204
|
+
params={"min_new_tokens": 0, "decoding_method": "greedy", "max_new_tokens": 500},
|
|
205
|
+
)
|
|
206
|
+
renderer = ArgsExtractorTemplateRenderer(ARGS_EXTRACTOR_PROMPT_PATH)
|
|
207
|
+
|
|
208
|
+
prompt = renderer.render(
|
|
209
|
+
tool_signature=tool_signatures_for_prompt[tool_name],
|
|
210
|
+
step=step,
|
|
211
|
+
inputs=inputs,
|
|
212
|
+
)
|
|
213
|
+
response = provider.query(prompt)
|
|
214
|
+
json_obj = parse_json_string(response)[0]
|
|
215
|
+
try:
|
|
216
|
+
output = tools_module[json_obj["tool_name"]](**json_obj["inputs"])
|
|
217
|
+
except:
|
|
218
|
+
raise ValueError(f"Failed to execute tool '{tool_name}' with inputs {inputs}")
|
|
219
|
+
|
|
107
220
|
cache.append({"inputs": inputs, "output": output})
|
|
108
221
|
if not isinstance(output, dict):
|
|
109
222
|
print(f" Tool {tool_name} returned non-dict output: {output}")
|
|
@@ -119,15 +232,14 @@ def plan_tool_calls_with_llm(story: str, agent_name: str, tool_signatures_str: s
|
|
|
119
232
|
available_tools=tool_signatures_str,
|
|
120
233
|
)
|
|
121
234
|
response = provider.query(prompt)
|
|
122
|
-
|
|
123
|
-
parsed = extract_first_json_list(raw)
|
|
235
|
+
parsed = extract_first_json_list(response)
|
|
124
236
|
print("\n LLM Tool Plan:")
|
|
125
237
|
print(json.dumps(parsed, indent=2))
|
|
126
238
|
return parsed
|
|
127
239
|
|
|
128
240
|
|
|
129
241
|
# --- Tool Execution Logic ---
|
|
130
|
-
def run_tool_chain(tool_plan: list, snapshot: dict, tools_module) -> None:
|
|
242
|
+
def run_tool_chain(tool_plan: list, snapshot: dict, tools_module, tool_signatures_for_prompt) -> None:
|
|
131
243
|
memory = {}
|
|
132
244
|
|
|
133
245
|
for step in tool_plan:
|
|
@@ -166,14 +278,14 @@ def run_tool_chain(tool_plan: list, snapshot: dict, tools_module) -> None:
|
|
|
166
278
|
item_inputs = resolved_inputs.copy()
|
|
167
279
|
item_inputs[list_key] = val
|
|
168
280
|
print(f" ⚙️ Running {name} with {list_key} = {val}")
|
|
169
|
-
output = ensure_data_available(
|
|
281
|
+
output = ensure_data_available(step, item_inputs, snapshot, tools_module, tool_signatures_for_prompt)
|
|
170
282
|
results.append(output)
|
|
171
283
|
memory[f"{name}_{idx}"] = output
|
|
172
284
|
|
|
173
285
|
memory[name] = results
|
|
174
286
|
print(f"Stored {len(results)} outputs under '{name}' and indexed as '{name}_i'")
|
|
175
287
|
else:
|
|
176
|
-
output = ensure_data_available(
|
|
288
|
+
output = ensure_data_available(step, resolved_inputs, snapshot, tools_module, tool_signatures_for_prompt)
|
|
177
289
|
memory[name] = output
|
|
178
290
|
print(f"Stored output under tool name: {name} = {output}")
|
|
179
291
|
|
|
@@ -183,14 +295,11 @@ def build_snapshot(agent_name: str, tools_path: Path, stories: list, output_path
|
|
|
183
295
|
agent = {"name": agent_name}
|
|
184
296
|
tools_module = load_tools_module(tools_path)
|
|
185
297
|
tool_signatures = extract_tool_signatures(tools_path)
|
|
298
|
+
tool_signatures_for_prompt = extract_tool_signatures_for_prompt(tools_path)
|
|
186
299
|
|
|
187
|
-
provider =
|
|
300
|
+
provider = get_provider(
|
|
188
301
|
model_id="meta-llama/llama-3-405b-instruct",
|
|
189
|
-
|
|
190
|
-
"min_new_tokens": 50,
|
|
191
|
-
"decoding_method": "greedy",
|
|
192
|
-
"max_new_tokens": 200
|
|
193
|
-
}
|
|
302
|
+
params={"min_new_tokens": 1, "decoding_method": "greedy", "max_new_tokens": 2048},
|
|
194
303
|
)
|
|
195
304
|
|
|
196
305
|
snapshot = {
|
|
@@ -202,10 +311,14 @@ def build_snapshot(agent_name: str, tools_path: Path, stories: list, output_path
|
|
|
202
311
|
for story in stories:
|
|
203
312
|
print(f"\n📘 Planning tool calls for story: {story}")
|
|
204
313
|
tool_plan = plan_tool_calls_with_llm(story, agent["name"], tool_signatures, provider)
|
|
205
|
-
|
|
314
|
+
try:
|
|
315
|
+
run_tool_chain(tool_plan, snapshot, tools_module, tool_signatures_for_prompt)
|
|
316
|
+
except ValueError as e:
|
|
317
|
+
print(f"❌ Error running tool chain for story '{story}': {e}")
|
|
318
|
+
continue
|
|
206
319
|
|
|
207
320
|
with output_path.open("w", encoding="utf-8") as f:
|
|
208
|
-
json.dump(snapshot, f, indent=2)
|
|
321
|
+
json.dump(snapshot, f, indent=2, cls=UniversalEncoder)
|
|
209
322
|
print(f"\n✅ Snapshot saved to {output_path}")
|
|
210
323
|
|
|
211
324
|
|