ibm-watsonx-orchestrate-evaluation-framework 1.0.2__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ibm-watsonx-orchestrate-evaluation-framework might be problematic. Click here for more details.

Files changed (41)
  1. {ibm_watsonx_orchestrate_evaluation_framework-1.0.2.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.0.4.dist-info}/METADATA +70 -7
  2. ibm_watsonx_orchestrate_evaluation_framework-1.0.4.dist-info/RECORD +56 -0
  3. wxo_agentic_evaluation/analytics/tools/analyzer.py +3 -3
  4. wxo_agentic_evaluation/analytics/tools/ux.py +1 -1
  5. wxo_agentic_evaluation/analyze_run.py +10 -10
  6. wxo_agentic_evaluation/arg_configs.py +8 -1
  7. wxo_agentic_evaluation/batch_annotate.py +3 -9
  8. wxo_agentic_evaluation/data_annotator.py +50 -36
  9. wxo_agentic_evaluation/evaluation_package.py +102 -85
  10. wxo_agentic_evaluation/external_agent/__init__.py +37 -0
  11. wxo_agentic_evaluation/external_agent/external_validate.py +74 -29
  12. wxo_agentic_evaluation/external_agent/performance_test.py +66 -0
  13. wxo_agentic_evaluation/external_agent/types.py +8 -2
  14. wxo_agentic_evaluation/inference_backend.py +45 -50
  15. wxo_agentic_evaluation/llm_matching.py +6 -6
  16. wxo_agentic_evaluation/llm_rag_eval.py +4 -4
  17. wxo_agentic_evaluation/llm_user.py +3 -3
  18. wxo_agentic_evaluation/main.py +63 -23
  19. wxo_agentic_evaluation/metrics/metrics.py +59 -0
  20. wxo_agentic_evaluation/prompt/args_extractor_prompt.jinja2 +23 -0
  21. wxo_agentic_evaluation/prompt/batch_testcase_prompt.jinja2 +2 -0
  22. wxo_agentic_evaluation/prompt/examples/data_simple.json +1 -2
  23. wxo_agentic_evaluation/prompt/starting_sentence_generation_prompt.jinja2 +195 -0
  24. wxo_agentic_evaluation/prompt/story_generation_prompt.jinja2 +154 -0
  25. wxo_agentic_evaluation/prompt/template_render.py +17 -0
  26. wxo_agentic_evaluation/prompt/tool_planner.jinja2 +13 -7
  27. wxo_agentic_evaluation/record_chat.py +74 -26
  28. wxo_agentic_evaluation/resource_map.py +47 -0
  29. wxo_agentic_evaluation/service_provider/__init__.py +35 -0
  30. wxo_agentic_evaluation/service_provider/model_proxy_provider.py +108 -0
  31. wxo_agentic_evaluation/service_provider/ollama_provider.py +40 -0
  32. wxo_agentic_evaluation/service_provider/provider.py +19 -0
  33. wxo_agentic_evaluation/{watsonx_provider.py → service_provider/watsonx_provider.py} +27 -18
  34. wxo_agentic_evaluation/test_prompt.py +94 -0
  35. wxo_agentic_evaluation/tool_planner.py +130 -17
  36. wxo_agentic_evaluation/type.py +0 -57
  37. wxo_agentic_evaluation/utils/utils.py +6 -54
  38. ibm_watsonx_orchestrate_evaluation_framework-1.0.2.dist-info/RECORD +0 -46
  39. ibm_watsonx_orchestrate_evaluation_framework-1.0.2.dist-info/licenses/LICENSE +0 -22
  40. {ibm_watsonx_orchestrate_evaluation_framework-1.0.2.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.0.4.dist-info}/WHEEL +0 -0
  41. {ibm_watsonx_orchestrate_evaluation_framework-1.0.2.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.0.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,108 @@
1
+ import os
2
+ import requests
3
+ import time
4
+
5
+ from typing import List
6
+ from threading import Lock
7
+
8
+ from wxo_agentic_evaluation.service_provider.provider import Provider
9
+ from wxo_agentic_evaluation.utils.utils import is_ibm_cloud_url
10
+
11
# IAM token endpoints: IBM Cloud instance URLs authenticate against IBM Cloud
# IAM; any other instance URL uses the IBM SaaS platform API-key exchange.
AUTH_ENDPOINT_AWS = "https://iam.platform.saas.ibm.com/siusermgr/api/1.0/apikeys/token"
AUTH_ENDPOINT_IBM_CLOUD = "https://iam.cloud.ibm.com/identity/token"

# Default connection settings, captured once when the module is imported.
WO_INSTANCE = os.getenv("WO_INSTANCE")
WO_API_KEY = os.getenv("WO_API_KEY")

# Text-generation parameters used when a provider is built without `params`.
DEFAULT_PARAM = {
    "min_new_tokens": 1,
    "decoding_method": "greedy",
    "max_new_tokens": 400,
}
16
+
17
+
18
class ModelProxyProvider(Provider):
    """Provider that sends generation/embedding requests to the WO model proxy.

    Requests go to ``<instance_url>/ml/v1/text/generation`` and
    ``<instance_url>/ml/v1/text/embeddings`` with a bearer token. The token
    comes from IBM Cloud IAM when ``instance_url`` is an IBM Cloud URL, and
    from the IBM SaaS platform API-key endpoint otherwise. It is refreshed
    once 80% of its reported ``expires_in`` lifetime has passed.
    """

    def __init__(
        self,
        model_id=None,
        api_key=WO_API_KEY,
        instance_url=WO_INSTANCE,
        timeout=300,
        embedding_model_id=None,
        params=None
    ):
        """Create the provider and fetch an initial token.

        Args:
            model_id: Text-generation model; required for :meth:`query`.
            api_key: WO API key (defaults to ``WO_API_KEY`` read at import).
            instance_url: WO instance base URL (defaults to ``WO_INSTANCE``).
            timeout: Seconds; sent in the generation payload and also used as
                the client-side HTTP timeout for every request.
            embedding_model_id: Embedding model; required for :meth:`encode`.
            params: Generation parameters; ``DEFAULT_PARAM`` when falsy.

        Raises:
            RuntimeError: if ``instance_url`` or ``api_key`` is missing.
            requests.HTTPError: if the initial token request fails.
        """
        super().__init__()

        if not instance_url or not api_key:
            raise RuntimeError("instance url and WO apikey must be specified to use WO model proxy")

        self.timeout = timeout
        self.model_id = model_id
        self.embedding_model_id = embedding_model_id

        self.api_key = api_key
        self.is_ibm_cloud = is_ibm_cloud_url(instance_url)
        self.auth_url = AUTH_ENDPOINT_IBM_CLOUD if self.is_ibm_cloud else AUTH_ENDPOINT_AWS
        # Tolerate a trailing slash so we never build ".../ /ml/..." URLs.
        base_url = instance_url.rstrip("/")
        self.url = base_url + "/ml/v1/text/generation?version=2024-05-01"
        self.embedding_url = base_url + "/ml/v1/text/embeddings"

        self.lock = Lock()
        self.token, self.refresh_time = self.get_token()
        # Copy so instances never share (or mutate) the module-level default.
        self.params = dict(params) if params else dict(DEFAULT_PARAM)

    def get_token(self):
        """Request a fresh bearer token from the configured IAM endpoint.

        Returns:
            tuple: ``(token, refresh_time)``; ``refresh_time`` is a Unix
            timestamp at 80% of the token's ``expires_in`` lifetime.

        Raises:
            requests.HTTPError: if the endpoint returns an error status.
        """
        if self.is_ibm_cloud:
            # IBM Cloud IAM expects a form-encoded API-key grant.
            payload = {"grant_type": "urn:ibm:params:oauth:grant-type:apikey", "apikey": self.api_key}
            resp = requests.post(self.auth_url, data=payload, timeout=self.timeout)
            token_key = "access_token"
        else:
            payload = {"apikey": self.api_key}
            resp = requests.post(self.auth_url, json=payload, timeout=self.timeout)
            token_key = "token"
        # Raise first: previously a non-200 status that raise_for_status()
        # accepts made this return None, breaking callers' tuple unpacking.
        resp.raise_for_status()
        json_obj = resp.json()
        token = json_obj[token_key]
        expires_in = json_obj["expires_in"]
        refresh_time = time.time() + int(0.8 * expires_in)
        return token, refresh_time

    def refresh_token_if_expires(self):
        """Replace the token once ``refresh_time`` has passed.

        The time check is repeated under ``self.lock`` so that concurrent
        callers trigger only one refresh.
        """
        if time.time() > self.refresh_time:
            with self.lock:
                if time.time() > self.refresh_time:
                    self.token, self.refresh_time = self.get_token()

    def get_header(self):
        """Return the Authorization header for the current token."""
        return {"Authorization": f"Bearer {self.token}"}

    def encode(self, sentences: List[str]) -> List[list]:
        """Embed ``sentences`` with ``self.embedding_model_id``.

        Returns:
            One embedding vector per input sentence.

        Raises:
            Exception: if no embedding model id was configured.
            requests.HTTPError: on an error response.
        """
        if self.embedding_model_id is None:
            raise Exception("embedding model id must be specified for text embedding")

        self.refresh_token_if_expires()
        headers = self.get_header()
        payload = {"inputs": sentences, "model_id": self.embedding_model_id, "space_id": "1"}
        resp = requests.post(self.embedding_url, json=payload, headers=headers, timeout=self.timeout)
        resp.raise_for_status()

        json_obj = resp.json()
        # watsonx.ai-style embedding responses carry vectors under
        # results[i]["embedding"]; keep the previous key as a fallback.
        if "results" in json_obj:
            return [item["embedding"] for item in json_obj["results"]]
        return json_obj["generated_text"]

    def query(self, sentence: str) -> str:
        """Generate text for ``sentence`` with ``self.model_id``.

        Returns:
            The ``generated_text`` of the first result.

        Raises:
            Exception: if no model id was configured.
            requests.HTTPError: on an error response.
        """
        if self.model_id is None:
            raise Exception("model id must be specified for text generation")
        self.refresh_token_if_expires()
        headers = self.get_header()
        payload = {"input": sentence, "model_id": self.model_id, "space_id": "1",
                   "timeout": self.timeout, "parameters": self.params}
        resp = requests.post(self.url, json=payload, headers=headers, timeout=self.timeout)
        resp.raise_for_status()
        return resp.json()["results"][0]["generated_text"]
104
+
105
+
106
if __name__ == "__main__":
    # Manual smoke test; needs WO_INSTANCE and WO_API_KEY in the environment.
    provider = ModelProxyProvider(
        model_id="meta-llama/llama-3-3-70b-instruct",
        embedding_model_id="ibm/slate-30m-english-rtrvr",
    )
    reply = provider.query("ok")
    print(reply)
@@ -0,0 +1,40 @@
1
+ import requests
2
+ import json
3
+ from wxo_agentic_evaluation.service_provider.provider import Provider
4
+ from typing import List
5
+ import os
6
+
7
# Base URL of the Ollama server. OLLAMA_HOST is often given as a bare
# "host:port", which requests cannot use without a scheme, so add "http://"
# when one is missing and drop any trailing slash before paths are appended.
OLLAMA_URL = os.environ.get("OLLAMA_HOST", "http://localhost:11434").rstrip("/")
if "://" not in OLLAMA_URL:
    OLLAMA_URL = "http://" + OLLAMA_URL
8
+
9
+
10
class OllamaProvider(Provider):
    """Provider backed by an Ollama server's ``/api/generate`` endpoint."""

    def __init__(
        self,
        model_id=None
    ):
        self.url = OLLAMA_URL + "/api/generate"
        self.model_id = model_id
        super().__init__()

    def query(self, sentence: str) -> str:
        """Stream a completion for ``sentence`` and return the joined text.

        The response is read as newline-delimited JSON objects. The
        ``response`` fragments are concatenated until an object reports
        ``done``.

        Raises:
            requests.HTTPError: on an error HTTP status.
            RuntimeError: if the stream contains an ``error`` object.
        """
        payload = {"model": self.model_id, "prompt": sentence}
        pieces = []
        # Use a context manager so the streamed connection is always released.
        with requests.post(self.url, json=payload, stream=True) as resp:
            resp.raise_for_status()
            # iter_lines() splits on real line boundaries. The previous
            # 128-byte chunk loop could buffer several JSON lines at once and
            # make json.loads fail.
            for line in resp.iter_lines():
                if not line:
                    continue
                json_obj = json.loads(line)
                if "error" in json_obj:
                    raise RuntimeError(f"Ollama error: {json_obj['error']}")
                pieces.append(json_obj.get("response") or "")
                if json_obj.get("done"):
                    break

        return "".join(pieces)

    def encode(self, sentences: List[str]) -> List[list]:
        """Not supported by this provider.

        Raises:
            NotImplementedError: always. Previously this silently returned
            None, so the failure only showed up later.
        """
        raise NotImplementedError("OllamaProvider does not support embeddings")
36
+
37
+
38
if __name__ == "__main__":
    # Manual smoke test against a running Ollama server.
    provider = OllamaProvider(model_id="llama3.1:8b")
    reply = provider.query("ok")
    print(reply)
@@ -0,0 +1,19 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import List
3
+
4
+
5
class Provider(ABC):
    """Abstract interface shared by the LLM service providers.

    Subclasses implement single-prompt generation (:meth:`query`) and text
    embedding (:meth:`encode`). :meth:`batch_query` is built on
    :meth:`query`.
    """

    def __init__(self):
        # Holds no state; exists so subclasses can call super().__init__().
        pass

    @abstractmethod
    def query(self, sentence: str) -> str:
        """Return the provider's text response for ``sentence``."""

    def batch_query(self, sentences: List[str]) -> List[str]:
        """Call :meth:`query` on each sentence in order and collect the results."""
        return list(map(self.query, sentences))

    @abstractmethod
    def encode(self, sentences: List[str]) -> List[list]:
        """Return one embedding vector per input sentence."""
19
+
@@ -5,7 +5,8 @@ from types import MappingProxyType
5
5
  from typing import List
6
6
  import dataclasses
7
7
  from threading import Lock
8
-
8
+ import time
9
+ from wxo_agentic_evaluation.service_provider.provider import Provider
9
10
 
10
11
  ACCESS_URL = "https://iam.cloud.ibm.com/identity/token"
11
12
  ACCESS_HEADER = {
@@ -16,11 +17,11 @@ ACCESS_HEADER = {
16
17
  YPQA_URL = "https://yp-qa.ml.cloud.ibm.com"
17
18
  PROD_URL = "https://us-south.ml.cloud.ibm.com"
18
19
  DEFAULT_PARAM = MappingProxyType(
19
- {"min_new_tokens": 0, "decoding_method": "greedy", "max_new_tokens": 100}
20
+ {"min_new_tokens": 1, "decoding_method": "greedy", "max_new_tokens": 400}
20
21
  )
21
22
 
22
23
 
23
- class WatsonXProvider:
24
+ class WatsonXProvider(Provider):
24
25
  def __init__(
25
26
  self,
26
27
  model_id=None,
@@ -29,7 +30,7 @@ class WatsonXProvider:
29
30
  api_endpoint=PROD_URL,
30
31
  url=ACCESS_URL,
31
32
  timeout=60,
32
- llm_decode_parameter=DEFAULT_PARAM,
33
+ params=None,
33
34
  embedding_model_id=None,
34
35
  ):
35
36
  super().__init__()
@@ -54,12 +55,15 @@ class WatsonXProvider:
54
55
  self.embedding_model_id = embedding_model_id
55
56
  self.lock = Lock()
56
57
 
57
- if isinstance(llm_decode_parameter, MappingProxyType):
58
- llm_decode_parameter = dict(llm_decode_parameter)
59
- if dataclasses.is_dataclass(llm_decode_parameter):
60
- llm_decode_parameter = dataclasses.asdict(llm_decode_parameter)
58
+ self.params = params if params else DEFAULT_PARAM
59
+
60
+ if isinstance(self.params, MappingProxyType):
61
+ self.params = dict(self.params)
62
+ if dataclasses.is_dataclass(self.params):
63
+ self.params = dataclasses.asdict(self.params)
61
64
 
62
- self.decode_param = llm_decode_parameter
65
+ self.refresh_time = None
66
+ self.access_token = None
63
67
  self._refresh_token()
64
68
 
65
69
  def _get_access_token(self):
@@ -69,8 +73,11 @@ class WatsonXProvider:
69
73
  if response.status_code == 200:
70
74
  token_data = json.loads(response.text)
71
75
  token = token_data["access_token"]
72
-
73
- return token
76
+ expiration = token_data["expiration"]
77
+ expires_in = token_data["expires_in"]
78
+ # 9 minutes before expire
79
+ refresh_time = expiration - int(0.15 * expires_in)
80
+ return token, refresh_time
74
81
 
75
82
  raise RuntimeError(
76
83
  f"try to acquire access token and get {response.status_code}"
@@ -85,7 +92,7 @@ class WatsonXProvider:
85
92
  headers = self.prepare_header()
86
93
 
87
94
  data = {"model_id": self.model_id, "input": sentence,
88
- "parameters": self.decode_param, "space_id": self.space_id}
95
+ "parameters": self.params, "space_id": self.space_id}
89
96
  generation_url = f"{self.api_endpoint}/ml/v1/text/generation?version=2023-05-02"
90
97
  resp = requests.post(url=generation_url, headers=headers, json=data)
91
98
  if resp.status_code == 200:
@@ -94,13 +101,17 @@ class WatsonXProvider:
94
101
  resp.raise_for_status()
95
102
 
96
103
  def _refresh_token(self):
97
- self.access_token = self._get_access_token()
104
+ # if we do not have a token or the current timestamp is 9 minutes away from expire.
105
+ if not self.access_token or time.time() > self.refresh_time:
106
+ with self.lock:
107
+ if not self.access_token or time.time() > self.refresh_time:
108
+ self.access_token, self.refresh_time = self._get_access_token()
98
109
 
99
- def query(self, sentence: str) -> dict:
110
+ def query(self, sentence: str) -> str:
100
111
  if self.model_id is None:
101
112
  raise Exception("model id must be specified for text generation")
102
113
  try:
103
- return self.generate(sentence)
114
+ return self.generate(sentence)["generated_text"]
104
115
  except Exception as e:
105
116
  with self.lock:
106
117
  if "authentication_token_expired" in str(e):
@@ -126,8 +137,6 @@ class WatsonXProvider:
126
137
 
127
138
 
128
139
  if __name__ == "__main__":
129
- import os
130
-
131
140
  provider = WatsonXProvider(model_id="meta-llama/llama-3-2-90b-vision-instruct")
132
141
 
133
142
  prompt = """
@@ -160,4 +169,4 @@ Usernwaters did not take anytime off during the period<|eot_id|>
160
169
  <|eot_id|><|start_header_id|>user<|end_header_id|>
161
170
  """
162
171
 
163
- print(provider.batch_query([prompt]))
172
+ print(provider.query(prompt))
@@ -0,0 +1,94 @@
1
import json

try:
    from wxo_agentic_evaluation.service_provider.watsonx_provider import WatsonXProvider
except ImportError:  # layout used before the provider moved into service_provider/
    from wxo_agentic_evaluation.watsonx_provider import WatsonXProvider
2
+
3
+
4
+
5
def parse_json_string(input_string):
    """Extract every top-level ``{...}`` JSON object embedded in free text.

    The text is scanned for balanced curly braces, and each balanced segment
    is decoded with :func:`json.loads`. Segments that are not valid JSON are
    reported on stdout and skipped.

    Braces inside double-quoted JSON strings are ignored. A stray ``}`` that
    appears outside any object no longer knocks the brace counter negative
    (which used to drop every later object).

    Args:
        input_string: Raw text, typically an LLM response.

    Returns:
        list: Decoded objects in order of appearance (possibly empty).
    """
    json_objects = []
    current_json = []
    brace_level = 0
    in_string = False
    escaped = False

    for char in input_string:
        if brace_level == 0:
            # Outside any object: only an opening brace starts a segment.
            if char == "{":
                brace_level = 1
                current_json = [char]
            continue

        current_json.append(char)
        if in_string:
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == '"':
                in_string = False
        elif char == '"':
            in_string = True
        elif char == "{":
            brace_level += 1
        elif char == "}":
            brace_level -= 1
            if brace_level == 0:
                try:
                    json_objects.append(json.loads("".join(current_json)))
                except json.JSONDecodeError as e:
                    print(f"Error decoding JSON: {e}")
    return json_objects
33
+
34
# Few-shot prompt: shows the model how to turn a planned raw tool call plus
# the upstream tool output into a concrete tool call.
prompt = """
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are trying to make tool calls. Given a raw input and tool output. Try to extract the information to make the tool call

Example:
Tool description:
def get_payslips(user_id: str) -> PayslipsResponse:
Gets a user's payslips from Workday.

:param user_id: The user's id uniquely identifying them within the Workday API.
:return: The user's payslips.

Raw inputs: {"tool_name": "get_payslips", "args": {"user_id": '$get_user_workday_ids'}}
tool output: {'user_id': UserWorkdayIDs(person_id='', user_id='6dcb8106e8b74b5aabb1fc3ab8ef2b92')}
<|start_header_id|>ipython<|end_header_id|>
{"tool_name": "get_payslips", "args": {"user_id": "6dcb8106e8b74b5aabb1fc3ab8ef2b92"}}
<|eot_id|>

"""

test_sample1 = """
<|start_header_id|>assistant<|end_header_id|>
Tool description:
def update_direct_reports(email_id: str, members: List[str], notification:bool) -> PayslipsResponse:
update direct reports for a given user
:param email_id: The user's email-id uniquely identifying them within the Workday API.
:param members: a list of user ids to be added as direct reports
:param notification: do we send the notification to all members

Raw inputs: {"tool_name": "update_direct_reports", "args": {"email_id": '$get_email_id', 'members': $get_user_by_dvision]}}
tool output: {"email_id": 'jalenm3@163.com'}
{'members': [UserProfile(name="Lan Smith", user_id="46873f8i93", email="lan_smith@gmail.com"), UserProfile(name="Mary Rubic", user_id="34sss31", email="MaryRobic@gmail.com"), UserProfile(name="Jason Dai", user_id="8e8ewer3", email="jd@gmail.com"])}
<|start_header_id|>ipython<|end_header_id|>"""


test_sample2 = """
<|start_header_id|>assistant<|end_header_id|>
Tool description:
def book_meeting(location: str, date: str, time: str) -> bool:
update direct reports for a given user
:param email_id: The user's email-id uniquely identifying them within the Workday API.
:param members: a list of user ids to be added as direct reports
:param notification: do we send the notification to all members

Raw inputs: {"tool_name": "book_meeting", "args": {"email_id": '$get_email_id', 'members': $get_user_by_dvision]}}
tool output: {"email_id": 'jalenm3@163.com'}
{'members': [UserProfile(name="Lan Smith", user_id="46873f8i93", email="lan_smith@gmail.com"), UserProfile(name="Mary Rubic", user_id="34sss31", email="MaryRobic@gmail.com"), UserProfile(name="Jason Dai", user_id="8e8ewer3", email="jd@gmail.com"])}
<|start_header_id|>ipython<|end_header_id|>"""


if __name__ == "__main__":
    # Live LLM call. Guarded so that importing this module (it is named
    # test_*.py, so test collectors may import it) does not hit the network.
    wai_client = WatsonXProvider(model_id="meta-llama/llama-3-405b-instruct")

    # WatsonXProvider.query now returns the generated text itself, not a dict.
    outputs = wai_client.query(prompt + test_sample1)
    print(outputs)

    json_objects = parse_json_string(outputs)
    if json_objects:
        print(json_objects[0])
    else:
        print("No JSON object found in the model output")
@@ -6,15 +6,25 @@ import importlib.util
6
6
  import re
7
7
  from jsonargparse import CLI
8
8
  import os
9
+ import textwrap
10
+ from dataclasses import is_dataclass, asdict
9
11
 
10
- from wxo_agentic_evaluation.watsonx_provider import WatsonXProvider
12
+ from wxo_agentic_evaluation.service_provider import get_provider
11
13
  from wxo_agentic_evaluation.arg_configs import BatchAnnotateConfig
12
- from wxo_agentic_evaluation.prompt.template_render import ToolPlannerTemplateRenderer, ToolChainAgentTemplateRenderer
14
+ from wxo_agentic_evaluation.prompt.template_render import ToolPlannerTemplateRenderer, ArgsExtractorTemplateRenderer
13
15
  from wxo_agentic_evaluation import __file__
14
16
 
15
17
  root_dir = os.path.dirname(__file__)
16
18
  TOOL_PLANNER_PROMPT_PATH = os.path.join(root_dir, "prompt", "tool_planner.jinja2")
19
+ ARGS_EXTRACTOR_PROMPT_PATH = os.path.join(root_dir, "prompt", "args_extractor_prompt.jinja2")
17
20
 
21
class UniversalEncoder(json.JSONEncoder):
    """JSON encoder that can also serialize dataclasses and plain objects.

    Dataclass instances are converted with :func:`dataclasses.asdict`. Any
    other object that has a ``__dict__`` is emitted as that attribute
    mapping. Everything else falls through to :class:`json.JSONEncoder`,
    which raises ``TypeError``.
    """

    def default(self, obj):
        if not is_dataclass(obj):
            if not hasattr(obj, "__dict__"):
                return super().default(obj)
            return vars(obj)
        return asdict(obj)
18
28
 
19
29
  def extract_first_json_list(raw: str) -> list:
20
30
  matches = re.findall(r"\[\s*{.*?}\s*]", raw, re.DOTALL)
@@ -29,6 +39,33 @@ def extract_first_json_list(raw: str) -> list:
29
39
  print(raw)
30
40
  return []
31
41
 
42
def parse_json_string(input_string):
    """Extract every top-level ``{...}`` JSON object embedded in free text.

    The text is scanned for balanced curly braces, and each balanced segment
    is decoded with :func:`json.loads`. Segments that are not valid JSON are
    reported on stdout and skipped.

    Braces inside double-quoted JSON strings are ignored. A stray ``}`` that
    appears outside any object no longer knocks the brace counter negative
    (which used to drop every later object).

    Args:
        input_string: Raw text, typically an LLM response.

    Returns:
        list: Decoded objects in order of appearance (possibly empty).
    """
    json_objects = []
    current_json = []
    brace_level = 0
    in_string = False
    escaped = False

    for char in input_string:
        if brace_level == 0:
            # Outside any object: only an opening brace starts a segment.
            if char == "{":
                brace_level = 1
                current_json = [char]
            continue

        current_json.append(char)
        if in_string:
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == '"':
                in_string = False
        elif char == '"':
            in_string = True
        elif char == "{":
            brace_level += 1
        elif char == "}":
            brace_level -= 1
            if brace_level == 0:
                try:
                    json_objects.append(json.loads("".join(current_json)))
                except json.JSONDecodeError as e:
                    print(f"Error decoding JSON: {e}")
    return json_objects
68
+
32
69
 
33
70
  def load_tools_module(tools_path: Path) -> dict:
34
71
  tools_dict = {}
@@ -93,8 +130,64 @@ def extract_tool_signatures(tools_path: Path) -> list:
93
130
 
94
131
  return tool_data
95
132
 
133
def extract_tool_signatures_for_prompt(tools_path: Path) -> dict[str, str]:
    """Render each module-level function of the tools source as prompt text.

    ``tools_path`` is either a single ``.py`` file or a directory searched
    recursively for ``.py`` files. For every top-level ``def`` the result is::

        def name(arg: Type, ...) -> ReturnType:
            <first docstring line>
            <each docstring line containing ":param">

    Details:
    - Unannotated parameters are shown as ``Any``.
    - A missing return annotation is shown as ``None``.
    - A ``self`` parameter is omitted.
    - Keyword-only, ``*args`` and ``**kwargs`` parameters are rendered
      Python-style.
    - Files that cannot be read or parsed are reported on stdout and
      skipped (best effort).

    Args:
        tools_path: A Python file or a directory of Python files.

    Returns:
        Mapping of function name to rendered signature text. If several
        files define the same name, the last one parsed wins.

    Raises:
        ValueError: if ``tools_path`` is neither a file nor a directory.
    """
    indent = "    "

    def _render_arg(arg: ast.arg) -> str:
        # Prompt convention: unannotated parameters are shown as "Any".
        annotation = ast.unparse(arg.annotation) if arg.annotation else "Any"
        return f"{arg.arg}: {annotation}"

    # Handle both single file and directory cases
    if tools_path.is_file():
        files_to_parse = [tools_path]
    elif tools_path.is_dir():
        files_to_parse = list(tools_path.glob("**/*.py"))
    else:
        raise ValueError(f"Tools path {tools_path} is neither a file nor directory")

    functions = {}
    for file_path in files_to_parse:
        try:
            parsed_code = ast.parse(file_path.read_text(encoding="utf-8"))

            for node in parsed_code.body:
                if not isinstance(node, ast.FunctionDef):
                    continue

                arguments = node.args
                params = [
                    _render_arg(arg)
                    for arg in (*arguments.posonlyargs, *arguments.args)
                    if arg.arg != "self"
                ]
                # Keyword-only parameters used to be dropped, which hid
                # required inputs from the LLM.
                if arguments.vararg:
                    params.append("*" + _render_arg(arguments.vararg))
                elif arguments.kwonlyargs:
                    params.append("*")
                params.extend(_render_arg(arg) for arg in arguments.kwonlyargs)
                if arguments.kwarg:
                    params.append("**" + _render_arg(arguments.kwarg))

                returns = ast.unparse(node.returns) if node.returns else "None"

                docstring = ast.get_docstring(node) or ""
                doc_lines = textwrap.dedent(docstring).strip().splitlines()
                doc_summary = doc_lines[0] if doc_lines else ""
                param_lines = [line.strip() for line in doc_lines[1:] if ":param" in line]

                # Indent every docstring-derived line; previously only the
                # first ":param" line received the indent.
                rendered = [f"def {node.name}({', '.join(params)}) -> {returns}:", indent + doc_summary]
                rendered.extend(indent + line for line in param_lines)
                functions[node.name] = "\n".join(rendered)
        except Exception as e:
            # Best effort: one unreadable/unparsable file must not abort the rest.
            print(f"Warning: Failed to parse {file_path}: {str(e)}")

    return functions
188
+
189
+ def ensure_data_available(step: dict, inputs: dict, snapshot: dict, tools_module: dict, tool_signatures_for_prompt) -> dict:
190
+ tool_name = step["tool_name"]
98
191
  cache = snapshot.setdefault("input_output_examples", {}).setdefault(tool_name, [])
99
192
  for entry in cache:
100
193
  if entry["inputs"] == inputs:
@@ -103,7 +196,27 @@ def ensure_data_available(tool_name: str, inputs: dict, snapshot: dict, tools_mo
103
196
  if tool_name not in tools_module:
104
197
  raise ValueError(f"Tool '{tool_name}' not found")
105
198
 
106
- output = tools_module[tool_name](**inputs)
199
+ try:
200
+ output = tools_module[tool_name](**inputs)
201
+ except:
202
+ provider = get_provider(
203
+ model_id="meta-llama/llama-3-405b-instruct",
204
+ params={"min_new_tokens": 0, "decoding_method": "greedy", "max_new_tokens": 500},
205
+ )
206
+ renderer = ArgsExtractorTemplateRenderer(ARGS_EXTRACTOR_PROMPT_PATH)
207
+
208
+ prompt = renderer.render(
209
+ tool_signature=tool_signatures_for_prompt[tool_name],
210
+ step=step,
211
+ inputs=inputs,
212
+ )
213
+ response = provider.query(prompt)
214
+ json_obj = parse_json_string(response)[0]
215
+ try:
216
+ output = tools_module[json_obj["tool_name"]](**json_obj["inputs"])
217
+ except:
218
+ raise ValueError(f"Failed to execute tool '{tool_name}' with inputs {inputs}")
219
+
107
220
  cache.append({"inputs": inputs, "output": output})
108
221
  if not isinstance(output, dict):
109
222
  print(f" Tool {tool_name} returned non-dict output: {output}")
@@ -119,15 +232,14 @@ def plan_tool_calls_with_llm(story: str, agent_name: str, tool_signatures_str: s
119
232
  available_tools=tool_signatures_str,
120
233
  )
121
234
  response = provider.query(prompt)
122
- raw = response.get("generated_text", "")
123
- parsed = extract_first_json_list(raw)
235
+ parsed = extract_first_json_list(response)
124
236
  print("\n LLM Tool Plan:")
125
237
  print(json.dumps(parsed, indent=2))
126
238
  return parsed
127
239
 
128
240
 
129
241
  # --- Tool Execution Logic ---
130
- def run_tool_chain(tool_plan: list, snapshot: dict, tools_module) -> None:
242
+ def run_tool_chain(tool_plan: list, snapshot: dict, tools_module, tool_signatures_for_prompt) -> None:
131
243
  memory = {}
132
244
 
133
245
  for step in tool_plan:
@@ -166,14 +278,14 @@ def run_tool_chain(tool_plan: list, snapshot: dict, tools_module) -> None:
166
278
  item_inputs = resolved_inputs.copy()
167
279
  item_inputs[list_key] = val
168
280
  print(f" ⚙️ Running {name} with {list_key} = {val}")
169
- output = ensure_data_available(name, item_inputs, snapshot, tools_module)
281
+ output = ensure_data_available(step, item_inputs, snapshot, tools_module, tool_signatures_for_prompt)
170
282
  results.append(output)
171
283
  memory[f"{name}_{idx}"] = output
172
284
 
173
285
  memory[name] = results
174
286
  print(f"Stored {len(results)} outputs under '{name}' and indexed as '{name}_i'")
175
287
  else:
176
- output = ensure_data_available(name, resolved_inputs, snapshot, tools_module)
288
+ output = ensure_data_available(step, resolved_inputs, snapshot, tools_module, tool_signatures_for_prompt)
177
289
  memory[name] = output
178
290
  print(f"Stored output under tool name: {name} = {output}")
179
291
 
@@ -183,14 +295,11 @@ def build_snapshot(agent_name: str, tools_path: Path, stories: list, output_path
183
295
  agent = {"name": agent_name}
184
296
  tools_module = load_tools_module(tools_path)
185
297
  tool_signatures = extract_tool_signatures(tools_path)
298
+ tool_signatures_for_prompt = extract_tool_signatures_for_prompt(tools_path)
186
299
 
187
- provider = WatsonXProvider(
300
+ provider = get_provider(
188
301
  model_id="meta-llama/llama-3-405b-instruct",
189
- llm_decode_parameter={
190
- "min_new_tokens": 50,
191
- "decoding_method": "greedy",
192
- "max_new_tokens": 200
193
- }
302
+ params={"min_new_tokens": 1, "decoding_method": "greedy", "max_new_tokens": 2048},
194
303
  )
195
304
 
196
305
  snapshot = {
@@ -202,10 +311,14 @@ def build_snapshot(agent_name: str, tools_path: Path, stories: list, output_path
202
311
  for story in stories:
203
312
  print(f"\n📘 Planning tool calls for story: {story}")
204
313
  tool_plan = plan_tool_calls_with_llm(story, agent["name"], tool_signatures, provider)
205
- run_tool_chain(tool_plan, snapshot, tools_module)
314
+ try:
315
+ run_tool_chain(tool_plan, snapshot, tools_module, tool_signatures_for_prompt)
316
+ except ValueError as e:
317
+ print(f"❌ Error running tool chain for story '{story}': {e}")
318
+ continue
206
319
 
207
320
  with output_path.open("w", encoding="utf-8") as f:
208
- json.dump(snapshot, f, indent=2)
321
+ json.dump(snapshot, f, indent=2, cls=UniversalEncoder)
209
322
  print(f"\n✅ Snapshot saved to {output_path}")
210
323
 
211
324
 
@@ -111,66 +111,9 @@ class GoalDetail(BaseModel):
111
111
  knowledge_base: KnowledgeBaseGoalDetail = KnowledgeBaseGoalDetail()
112
112
 
113
113
 
114
- class MineField(BaseModel):
115
- type: ContentType
116
- name: str
117
-
118
-
119
114
  class EvaluationData(BaseModel):
120
115
  agent: str
121
116
  goals: Dict
122
117
  story: str
123
- mine_fields: List[MineField]
124
118
  goal_details: List[GoalDetail]
125
119
  starting_sentence: str = None
126
-
127
-
128
- class ToolCallAndRoutingMetrics(BaseModel):
129
- total_tool_calls: int
130
- expected_tool_calls: int
131
- relevant_tool_calls: int
132
- correct_tool_calls: int
133
- total_routing_calls: int
134
- expected_routing_calls: int
135
-
136
- @computed_field
137
- @property
138
- def non_transfer_tool_calls(self) -> int:
139
- return self.total_tool_calls - self.total_routing_calls
140
-
141
- @computed_field
142
- @property
143
- def tool_call_accuracy(self) -> float:
144
- return round(
145
- (
146
- self.correct_tool_calls / self.non_transfer_tool_calls
147
- if self.non_transfer_tool_calls > 0
148
- else 0.0
149
- ),
150
- 2,
151
- )
152
-
153
- @computed_field
154
- @property
155
- def tool_call_relevancy(self) -> float:
156
- return round(
157
- (
158
- (self.relevant_tool_calls - self.expected_routing_calls)
159
- / self.non_transfer_tool_calls
160
- if self.non_transfer_tool_calls > 0
161
- else 0.0
162
- ),
163
- 2,
164
- )
165
-
166
- @computed_field
167
- @property
168
- def agent_routing_accuracy(self) -> float:
169
- return round(
170
- (
171
- self.expected_routing_calls / self.total_routing_calls
172
- if self.total_routing_calls > 0
173
- else 0.0
174
- ),
175
- 2,
176
- )