ibm-watsonx-orchestrate-evaluation-framework 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. {ibm_watsonx_orchestrate_evaluation_framework-1.0.1.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.0.3.dist-info}/METADATA +70 -12
  2. ibm_watsonx_orchestrate_evaluation_framework-1.0.3.dist-info/RECORD +56 -0
  3. wxo_agentic_evaluation/analytics/tools/analyzer.py +3 -3
  4. wxo_agentic_evaluation/analytics/tools/ux.py +1 -1
  5. wxo_agentic_evaluation/analyze_run.py +10 -10
  6. wxo_agentic_evaluation/arg_configs.py +8 -1
  7. wxo_agentic_evaluation/batch_annotate.py +4 -10
  8. wxo_agentic_evaluation/data_annotator.py +50 -36
  9. wxo_agentic_evaluation/evaluation_package.py +102 -85
  10. wxo_agentic_evaluation/external_agent/__init__.py +37 -0
  11. wxo_agentic_evaluation/external_agent/external_validate.py +74 -31
  12. wxo_agentic_evaluation/external_agent/performance_test.py +66 -0
  13. wxo_agentic_evaluation/external_agent/types.py +8 -2
  14. wxo_agentic_evaluation/inference_backend.py +45 -50
  15. wxo_agentic_evaluation/llm_matching.py +6 -6
  16. wxo_agentic_evaluation/llm_rag_eval.py +4 -4
  17. wxo_agentic_evaluation/llm_user.py +3 -3
  18. wxo_agentic_evaluation/main.py +63 -23
  19. wxo_agentic_evaluation/metrics/metrics.py +72 -5
  20. wxo_agentic_evaluation/prompt/args_extractor_prompt.jinja2 +23 -0
  21. wxo_agentic_evaluation/prompt/batch_testcase_prompt.jinja2 +2 -0
  22. wxo_agentic_evaluation/prompt/examples/data_simple.json +1 -2
  23. wxo_agentic_evaluation/prompt/starting_sentence_generation_prompt.jinja2 +195 -0
  24. wxo_agentic_evaluation/prompt/story_generation_prompt.jinja2 +154 -0
  25. wxo_agentic_evaluation/prompt/template_render.py +17 -0
  26. wxo_agentic_evaluation/prompt/tool_planner.jinja2 +13 -7
  27. wxo_agentic_evaluation/record_chat.py +59 -18
  28. wxo_agentic_evaluation/resource_map.py +47 -0
  29. wxo_agentic_evaluation/service_provider/__init__.py +35 -0
  30. wxo_agentic_evaluation/service_provider/model_proxy_provider.py +108 -0
  31. wxo_agentic_evaluation/service_provider/ollama_provider.py +40 -0
  32. wxo_agentic_evaluation/service_provider/provider.py +19 -0
  33. wxo_agentic_evaluation/{watsonx_provider.py → service_provider/watsonx_provider.py} +54 -57
  34. wxo_agentic_evaluation/test_prompt.py +94 -0
  35. wxo_agentic_evaluation/tool_planner.py +130 -17
  36. wxo_agentic_evaluation/type.py +0 -57
  37. wxo_agentic_evaluation/utils/utils.py +6 -54
  38. ibm_watsonx_orchestrate_evaluation_framework-1.0.1.dist-info/RECORD +0 -46
  39. ibm_watsonx_orchestrate_evaluation_framework-1.0.1.dist-info/licenses/LICENSE +0 -22
  40. {ibm_watsonx_orchestrate_evaluation_framework-1.0.1.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.0.3.dist-info}/WHEEL +0 -0
  41. {ibm_watsonx_orchestrate_evaluation_framework-1.0.1.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.0.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,108 @@
1
+ import os
2
+ import requests
3
+ import time
4
+
5
+ from typing import List
6
+ from threading import Lock
7
+
8
+ from wxo_agentic_evaluation.service_provider.provider import Provider
9
+ from wxo_agentic_evaluation.utils.utils import is_ibm_cloud_url
10
+
11
+ AUTH_ENDPOINT_AWS = "https://iam.platform.saas.ibm.com/siusermgr/api/1.0/apikeys/token"
12
+ AUTH_ENDPOINT_IBM_CLOUD = "https://iam.cloud.ibm.com/identity/token"
13
+ WO_INSTANCE = os.environ.get("WO_INSTANCE")
14
+ WO_API_KEY = os.environ.get("WO_API_KEY")
15
+ DEFAULT_PARAM = {"min_new_tokens": 1, "decoding_method": "greedy", "max_new_tokens": 400}
16
+
17
+
18
class ModelProxyProvider(Provider):
    """Provider that sends generation and embedding requests to the WO model proxy.

    A bearer token is obtained from an IAM endpoint at construction time and is
    re-fetched lazily once 80% of its lifetime has elapsed.
    """

    def __init__(
        self,
        model_id=None,
        api_key=WO_API_KEY,
        instance_url=WO_INSTANCE,
        timeout=300,
        embedding_model_id=None,
        params=None
    ):
        """Configure endpoints and fetch the initial token.

        :param model_id: model used by ``query``; may be None if only ``encode`` is used.
        :param api_key: WO API key (defaults to the ``WO_API_KEY`` environment variable).
        :param instance_url: WO instance base URL (defaults to ``WO_INSTANCE``).
        :param timeout: sent in the generation payload and also used as the
            client-side HTTP timeout (seconds) for every request.
        :param embedding_model_id: model used by ``encode``.
        :param params: generation parameters; ``DEFAULT_PARAM`` is used when falsy.
        :raises RuntimeError: if the instance URL or API key is missing.
        """
        super().__init__()

        if not instance_url or not api_key:
            raise RuntimeError("instance url and WO apikey must be specified to use WO model proxy")

        self.timeout = timeout
        self.model_id = model_id
        self.embedding_model_id = embedding_model_id

        self.api_key = api_key
        # is_ibm_cloud_url decides which auth endpoint and payload format is used.
        self.is_ibm_cloud = is_ibm_cloud_url(instance_url)
        self.auth_url = AUTH_ENDPOINT_IBM_CLOUD if self.is_ibm_cloud else AUTH_ENDPOINT_AWS

        # Tolerate a trailing slash on the configured instance URL.
        base_url = instance_url.rstrip("/")
        self.url = base_url + "/ml/v1/text/generation?version=2024-05-01"
        self.embedding_url = base_url + "/ml/v1/text/embeddings"

        self.lock = Lock()
        self.token, self.refresh_time = self.get_token()
        # Copy the default so per-instance changes never leak into the module constant.
        self.params = params if params else dict(DEFAULT_PARAM)

    @staticmethod
    def _json_or_raise(resp):
        """Return the decoded JSON body of a 200 response, otherwise raise."""
        resp.raise_for_status()
        if resp.status_code != 200:
            # raise_for_status() only covers 4xx/5xx; do not silently return None.
            raise RuntimeError(f"unexpected status code {resp.status_code} from {resp.url}")
        return resp.json()

    def get_token(self):
        """Request a new token.

        :return: ``(token, refresh_time)`` where ``refresh_time`` is the epoch
            time after which the token should be renewed (80% of its lifetime).
        :raises requests.HTTPError: on a 4xx/5xx answer from the auth endpoint.
        :raises RuntimeError: on any other non-200 answer.
        """
        if self.is_ibm_cloud:
            payload = {"grant_type": "urn:ibm:params:oauth:grant-type:apikey", "apikey": self.api_key}
            resp = requests.post(self.auth_url, data=payload, timeout=self.timeout)
            token_key = "access_token"
        else:
            payload = {"apikey": self.api_key}
            resp = requests.post(self.auth_url, json=payload, timeout=self.timeout)
            token_key = "token"

        json_obj = self._json_or_raise(resp)
        token = json_obj[token_key]
        expires_in = json_obj["expires_in"]
        refresh_time = time.time() + int(0.8 * expires_in)
        return token, refresh_time

    def refresh_token_if_expires(self):
        """Renew the token once ``refresh_time`` has passed."""
        if time.time() > self.refresh_time:
            with self.lock:
                # Re-check under the lock so concurrent callers fetch only once.
                if time.time() > self.refresh_time:
                    self.token, self.refresh_time = self.get_token()

    def get_header(self):
        """Return the authorization header for the current token."""
        return {"Authorization": f"Bearer {self.token}"}

    def encode(self, sentences: List[str]) -> List[list]:
        """Send ``sentences`` to the embeddings endpoint.

        :raises Exception: if no embedding model id was configured.
        """
        if self.embedding_model_id is None:
            raise Exception("embedding model id must be specified for text encoding")

        self.refresh_token_if_expires()
        headers = self.get_header()
        payload = {"inputs": sentences, "model_id": self.embedding_model_id, "space_id": "1"}
        resp = requests.post(self.embedding_url, json=payload, headers=headers, timeout=self.timeout)

        json_obj = self._json_or_raise(resp)
        # NOTE(review): the result is read from "generated_text" exactly as before;
        # confirm this key against the proxy's embeddings response schema.
        return json_obj["generated_text"]

    def query(self, sentence: str) -> str:
        """Generate text for ``sentence`` and return the first result's text.

        :raises Exception: if no model id was configured.
        """
        if self.model_id is None:
            raise Exception("model id must be specified for text generation")
        self.refresh_token_if_expires()
        headers = self.get_header()
        payload = {"input": sentence, "model_id": self.model_id, "space_id": "1",
                   "timeout": self.timeout, "parameters": self.params}
        resp = requests.post(self.url, json=payload, headers=headers, timeout=self.timeout)
        return self._json_or_raise(resp)["results"][0]["generated_text"]
104
+
105
+
106
+ if __name__ == "__main__":
107
+ provider = ModelProxyProvider(model_id="meta-llama/llama-3-3-70b-instruct", embedding_model_id="ibm/slate-30m-english-rtrvr")
108
+ print(provider.query("ok"))
@@ -0,0 +1,40 @@
1
+ import requests
2
+ import json
3
+ from wxo_agentic_evaluation.service_provider.provider import Provider
4
+ from typing import List
5
+ import os
6
+
7
+ OLLAMA_URL = os.environ.get("OLLAMA_HOST", "http://localhost:11434")
8
+
9
+
10
class OllamaProvider(Provider):
    """Provider that generates text through an Ollama server's ``/api/generate`` endpoint."""

    def __init__(
        self,
        model_id=None
    ):
        """Store the target model and the generate URL derived from ``OLLAMA_URL``."""
        self.url = OLLAMA_URL + "/api/generate"
        self.model_id = model_id
        super().__init__()

    def query(self, sentence: str) -> str:
        """Stream a completion for ``sentence`` and return the concatenated text.

        The response body is consumed as newline-delimited JSON objects; the
        ``response`` field of every object that is not marked ``done`` is kept.

        :raises requests.HTTPError: if the server answers with a 4xx/5xx status.
        :raises RuntimeError: if a streamed object carries an ``error`` field.
        """
        payload = {"model": self.model_id, "prompt": sentence}
        pieces = []
        # The context manager releases the streamed connection even on errors.
        with requests.post(self.url, json=payload, stream=True) as resp:
            resp.raise_for_status()
            # iter_lines() re-frames the byte stream on newlines, so an object
            # split across chunks (or two objects in one chunk) is parsed correctly.
            for line in resp.iter_lines():
                if not line:
                    continue
                json_obj = json.loads(line)
                if "error" in json_obj:
                    raise RuntimeError(f"Ollama returned an error: {json_obj['error']}")
                if not json_obj.get("done") and json_obj.get("response"):
                    pieces.append(json_obj["response"])

        return "".join(pieces)

    def encode(self, sentences: List[str]) -> List[list]:
        """Embeddings are not implemented for this provider; returns None."""
        return None
36
+
37
+
38
+ if __name__ == "__main__":
39
+ provider = OllamaProvider(model_id="llama3.1:8b")
40
+ print(provider.query("ok"))
@@ -0,0 +1,19 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import List
3
+
4
+
5
class Provider(ABC):
    """Abstract base for model back-ends that can generate text and embed sentences."""

    def __init__(self):
        """Nothing to initialise at the base level."""

    @abstractmethod
    def query(self, sentence: str) -> str:
        """Return the model's text answer for a single prompt."""

    def batch_query(self, sentences: List[str]) -> List[str]:
        """Run ``query`` on each prompt in order and collect the answers."""
        return list(map(self.query, sentences))

    @abstractmethod
    def encode(self, sentences: List[str]) -> List[list]:
        """Return one embedding per input sentence."""
19
+
@@ -4,10 +4,9 @@ import json
4
4
  from types import MappingProxyType
5
5
  from typing import List
6
6
  import dataclasses
7
- from ibm_watsonx_ai.foundation_models import ModelInference, Embeddings
8
- from ibm_watsonx_ai.credentials import Credentials
9
7
  from threading import Lock
10
-
8
+ import time
9
+ from wxo_agentic_evaluation.service_provider.provider import Provider
11
10
 
12
11
  ACCESS_URL = "https://iam.cloud.ibm.com/identity/token"
13
12
  ACCESS_HEADER = {
@@ -18,11 +17,11 @@ ACCESS_HEADER = {
18
17
  YPQA_URL = "https://yp-qa.ml.cloud.ibm.com"
19
18
  PROD_URL = "https://us-south.ml.cloud.ibm.com"
20
19
  DEFAULT_PARAM = MappingProxyType(
21
- {"min_new_tokens": 0, "decoding_method": "greedy", "max_new_tokens": 100}
20
+ {"min_new_tokens": 1, "decoding_method": "greedy", "max_new_tokens": 400}
22
21
  )
23
22
 
24
23
 
25
- class WatsonXProvider:
24
+ class WatsonXProvider(Provider):
26
25
  def __init__(
27
26
  self,
28
27
  model_id=None,
@@ -31,7 +30,7 @@ class WatsonXProvider:
31
30
  api_endpoint=PROD_URL,
32
31
  url=ACCESS_URL,
33
32
  timeout=60,
34
- llm_decode_parameter=DEFAULT_PARAM,
33
+ params=None,
35
34
  embedding_model_id=None,
36
35
  ):
37
36
  super().__init__()
@@ -56,12 +55,15 @@ class WatsonXProvider:
56
55
  self.embedding_model_id = embedding_model_id
57
56
  self.lock = Lock()
58
57
 
59
- if isinstance(llm_decode_parameter, MappingProxyType):
60
- llm_decode_parameter = dict(llm_decode_parameter)
61
- if dataclasses.is_dataclass(llm_decode_parameter):
62
- llm_decode_parameter = dataclasses.asdict(llm_decode_parameter)
58
+ self.params = params if params else DEFAULT_PARAM
59
+
60
+ if isinstance(self.params, MappingProxyType):
61
+ self.params = dict(self.params)
62
+ if dataclasses.is_dataclass(self.params):
63
+ self.params = dataclasses.asdict(self.params)
63
64
 
64
- self.decode_param = llm_decode_parameter
65
+ self.refresh_time = None
66
+ self.access_token = None
65
67
  self._refresh_token()
66
68
 
67
69
  def _get_access_token(self):
@@ -71,75 +73,70 @@ class WatsonXProvider:
71
73
  if response.status_code == 200:
72
74
  token_data = json.loads(response.text)
73
75
  token = token_data["access_token"]
74
-
75
- return token
76
+ expiration = token_data["expiration"]
77
+ expires_in = token_data["expires_in"]
78
+ # 9 minutes before expire
79
+ refresh_time = expiration - int(0.15 * expires_in)
80
+ return token, refresh_time
76
81
 
77
82
  raise RuntimeError(
78
83
  f"try to acquire access token and get {response.status_code}"
79
84
  )
80
85
 
81
- def _refresh_token(self):
82
- self.access_token = self._get_access_token()
83
-
84
- if self.embedding_model_id is not None:
85
- self.embedding_client = Embeddings(
86
- model_id=self.embedding_model_id,
87
- credentials=Credentials(token=self.access_token, url=self.api_endpoint),
88
- space_id=self.space_id,
89
- )
90
- else:
91
- self.embedding_client = None
92
-
93
- if self.model_id is not None:
94
- self.client = ModelInference(
95
- model_id=self.model_id,
96
- params=self.decode_param,
97
- credentials=Credentials(token=self.access_token, url=self.api_endpoint),
98
- space_id=self.space_id,
99
- )
86
def prepare_header(self):
    """Build the JSON request headers carrying the current bearer token."""
    return {
        "Authorization": f"Bearer {self.access_token}",
        "Content-Type": "application/json",
    }
90
+
91
def generate(self, sentence: str):
    """POST ``sentence`` to the text-generation endpoint and return the first result.

    Returns the first entry of the response's ``results`` list on HTTP 200;
    otherwise ``raise_for_status()`` raises for 4xx/5xx answers.
    """
    request_headers = self.prepare_header()
    generation_url = f"{self.api_endpoint}/ml/v1/text/generation?version=2023-05-02"
    body = {
        "model_id": self.model_id,
        "input": sentence,
        "parameters": self.params,
        "space_id": self.space_id,
    }

    resp = requests.post(url=generation_url, headers=request_headers, json=body)
    if resp.status_code != 200:
        resp.raise_for_status()
        return None
    return resp.json()["results"][0]
102
102
 
103
- def query(self, sentence: str) -> dict:
104
- if self.model_id is None:
105
- raise Exception("model id must be specified for text generation")
106
- try:
107
- return self.client.generate([sentence])[0][
108
- "results"
109
- ][ # pylint: disable=E1136
110
- 0
111
- ]
112
- except Exception as e:
103
+ def _refresh_token(self):
104
+ # if we do not have a token or the current timestamp is 9 minutes away from expire.
105
+ if not self.access_token or time.time() > self.refresh_time:
113
106
  with self.lock:
114
- if "authentication_token_expired" in str(e):
115
- self._refresh_token()
116
- raise e
107
+ if not self.access_token or time.time() > self.refresh_time:
108
+ self.access_token, self.refresh_time = self._get_access_token()
117
109
 
118
- def batch_query(self, sentences: List[str]) -> List[dict]:
110
def query(self, sentence: str) -> str:
    """Generate text for ``sentence`` and return the generated string.

    The token is refreshed before the request when it is due. If the request
    fails with an "authentication_token_expired" error, a refresh is attempted
    so the next call can succeed, and the original error is re-raised.

    :raises Exception: if no model id was configured, or whatever ``generate`` raised.
    """
    if self.model_id is None:
        raise Exception("model id must be specified for text generation")

    # Renew proactively; _refresh_token() is a no-op while the token is fresh.
    self._refresh_token()
    try:
        return self.generate(sentence)["generated_text"]
    except Exception as e:
        # _refresh_token() takes self.lock itself; holding the (non-reentrant)
        # lock here as well would block forever.
        if "authentication_token_expired" in str(e):
            self._refresh_token()
        raise
130
- # pylint: disable=E1133
131
- return []
120
+
121
def batch_query(self, sentences: List[str]) -> List[str]:
    """Run ``query`` on each sentence in order and return the generated strings."""
    return [self.query(sentence) for sentence in sentences]
132
123
 
133
124
def encode(self, sentences: List[str]) -> List[list]:
    """Embed ``sentences`` with the configured embedding model.

    :return: one embedding per entry of the response's ``results`` list.
    :raises Exception: if no embedding model id was configured.
    """
    if self.embedding_model_id is None:
        raise Exception("embedding model id must be specified for text encoding")

    # Renew the token when due; a no-op while it is still fresh.
    self._refresh_token()
    headers = self.prepare_header()
    url = f"{self.api_endpoint}/ml/v1/text/embeddings?version=2023-10-25"

    # The request must name the embedding model, not the generation model.
    data = {"inputs": sentences, "model_id": self.embedding_model_id, "space_id": self.space_id}
    resp = requests.post(url=url, headers=headers, json=data)
    if resp.status_code == 200:
        return [entry["embedding"] for entry in resp.json()["results"]]
    else:
        resp.raise_for_status()
142
137
 
138
+
139
+ if __name__ == "__main__":
143
140
  provider = WatsonXProvider(model_id="meta-llama/llama-3-2-90b-vision-instruct")
144
141
 
145
142
  prompt = """
@@ -172,4 +169,4 @@ Usernwaters did not take anytime off during the period<|eot_id|>
172
169
  <|eot_id|><|start_header_id|>user<|end_header_id|>
173
170
  """
174
171
 
175
- print(provider.query(prompt))
172
+ print(provider.query(prompt))
@@ -0,0 +1,94 @@
1
+ from wxo_agentic_evaluation.watsonx_provider import WatsonXProvider
2
+
3
+
4
+
5
def parse_json_string(input_string):
    """Extract every top-level ``{...}`` JSON object embedded in ``input_string``.

    The text is scanned for balanced brace spans; each span is decoded with
    ``json.loads``. Spans that are not valid JSON are reported on stdout and
    skipped.

    :param input_string: arbitrary text that may contain JSON objects.
    :return: list of decoded objects, in order of appearance.
    """
    json_objects = []
    brace_level = 0
    start_index = None

    for index, char in enumerate(input_string):
        if char == "{":
            if brace_level == 0:
                start_index = index
            brace_level += 1
        elif char == "}" and brace_level > 0:
            # A "}" seen outside any object is ignored; otherwise the depth
            # would go negative and later objects would never close.
            brace_level -= 1
            if brace_level == 0:
                candidate = input_string[start_index:index + 1]
                try:
                    json_objects.append(json.loads(candidate))
                except json.JSONDecodeError as e:
                    print(f"Error decoding JSON: {e}")
    return json_objects
33
+
34
+ wai_client = WatsonXProvider(model_id="meta-llama/llama-3-405b-instruct")
35
+
36
+ prompt = """
37
+ <|begin_of_text|><|start_header_id|>system<|end_header_id|>
38
+ You are trying to make tool calls. Given a raw input and tool output. Try to extract the information to make the tool call
39
+
40
+ Example:
41
+ Tool description:
42
+ def get_payslips(user_id: str) -> PayslipsResponse:
43
+ Gets a user's payslips from Workday.
44
+
45
+ :param user_id: The user's id uniquely identifying them within the Workday API.
46
+ :return: The user's payslips.
47
+
48
+ Raw inputs:\{"tool_name": "get_payslips", "args": {"user_id": '$get_user_workday_ids'}}
49
+ tool output: {'user_id': UserWorkdayIDs(person_id='', user_id='6dcb8106e8b74b5aabb1fc3ab8ef2b92')}
50
+ <|start_header_id|>ipython<|end_header_id|>
51
+ {"tool_name": "get_payslips", "args": {"user_id": "6dcb8106e8b74b5aabb1fc3ab8ef2b92"}}
52
+ <|eot_id|>
53
+
54
+ """
55
+
56
+ test_sample1 = """
57
+ <|start_header_id|>assistant<|end_header_id|>
58
+ Tool description:
59
+ def update_direct_reports(email_id: str, members: List[str], notification:bool) -> PayslipsResponse:
60
+ update direct reports for a given user
61
+ :param email_id: The user's email-id uniquely identifying them within the Workday API.
62
+ :param members: a list of user ids to be added as direct reports
63
+ :param notification: do we send the notification to all members
64
+
65
+ Raw inputs: {"tool_name": "update_direct_reports", "args": {"email_id": '$get_email_id', 'members': $get_user_by_dvision]}}
66
+ tool output: {"email_id": 'jalenm3@163.com'}
67
+ {'members': [UserProfile(name="Lan Smith", user_id="46873f8i93", email="lan_smith@gmail.com"), UserProfile(name="Mary Rubic", user_id="34sss31", email="MaryRobic@gmail.com"), UserProfile(name="Jason Dai", user_id="8e8ewer3", email="jd@gmail.com"])}
68
+ <|start_header_id|>ipython<|end_header_id|>"""
69
+
70
+
71
+ test_sample2 = """
72
+ <|start_header_id|>assistant<|end_header_id|>
73
+ Tool description:
74
+ def book_meeting(location: str, date: str, time: str) -> bool:
75
+ update direct reports for a given user
76
+ :param email_id: The user's email-id uniquely identifying them within the Workday API.
77
+ :param members: a list of user ids to be added as direct reports
78
+ :param notification: do we send the notification to all members
79
+
80
+ Raw inputs: {"tool_name": "book_meeting", "args": {"email_id": '$get_email_id', 'members': $get_user_by_dvision]}}
81
+ tool output: {"email_id": 'jalenm3@163.com'}
82
+ {'members': [UserProfile(name="Lan Smith", user_id="46873f8i93", email="lan_smith@gmail.com"), UserProfile(name="Mary Rubic", user_id="34sss31", email="MaryRobic@gmail.com"), UserProfile(name="Jason Dai", user_id="8e8ewer3", email="jd@gmail.com"])}
83
+ <|start_header_id|>ipython<|end_header_id|>"""
84
+
85
+
86
+
87
import json

# query() returns the generated text itself (a str), so the result is used
# directly instead of being indexed with "generated_text".
outputs = wai_client.query(prompt + test_sample1)

print(outputs)

# First JSON object found in the generated text.
json_obj = parse_json_string(outputs)[0]

print(json_obj)
@@ -6,15 +6,25 @@ import importlib.util
6
6
  import re
7
7
  from jsonargparse import CLI
8
8
  import os
9
+ import textwrap
10
+ from dataclasses import is_dataclass, asdict
9
11
 
10
- from wxo_agentic_evaluation.watsonx_provider import WatsonXProvider
12
+ from wxo_agentic_evaluation.service_provider import get_provider
11
13
  from wxo_agentic_evaluation.arg_configs import BatchAnnotateConfig
12
- from wxo_agentic_evaluation.prompt.template_render import ToolPlannerTemplateRenderer, ToolChainAgentTemplateRenderer
14
+ from wxo_agentic_evaluation.prompt.template_render import ToolPlannerTemplateRenderer, ArgsExtractorTemplateRenderer
13
15
  from wxo_agentic_evaluation import __file__
14
16
 
15
17
  root_dir = os.path.dirname(__file__)
16
18
  TOOL_PLANNER_PROMPT_PATH = os.path.join(root_dir, "prompt", "tool_planner.jinja2")
19
+ ARGS_EXTRACTOR_PROMPT_PATH = os.path.join(root_dir, "prompt", "args_extractor_prompt.jinja2")
17
20
 
21
class UniversalEncoder(json.JSONEncoder):
    """JSON encoder that also serialises dataclasses and plain attribute-bearing objects."""

    def default(self, obj):
        """Convert ``obj`` to something JSON can encode.

        Dataclasses become dicts via ``asdict``; any other object exposing
        ``__dict__`` is encoded as that dict; everything else is deferred to
        the base encoder, which raises ``TypeError``.
        """
        if not is_dataclass(obj):
            if hasattr(obj, "__dict__"):
                return obj.__dict__
            return super().default(obj)
        return asdict(obj)
18
28
 
19
29
  def extract_first_json_list(raw: str) -> list:
20
30
  matches = re.findall(r"\[\s*{.*?}\s*]", raw, re.DOTALL)
@@ -29,6 +39,33 @@ def extract_first_json_list(raw: str) -> list:
29
39
  print(raw)
30
40
  return []
31
41
 
42
def parse_json_string(input_string):
    """Extract every top-level ``{...}`` JSON object embedded in ``input_string``.

    The text is scanned for balanced brace spans; each span is decoded with
    ``json.loads``. Spans that are not valid JSON are reported on stdout and
    skipped.

    :param input_string: arbitrary text that may contain JSON objects.
    :return: list of decoded objects, in order of appearance.
    """
    json_objects = []
    brace_level = 0
    start_index = None

    for index, char in enumerate(input_string):
        if char == "{":
            if brace_level == 0:
                start_index = index
            brace_level += 1
        elif char == "}" and brace_level > 0:
            # A "}" seen outside any object is ignored; otherwise the depth
            # would go negative and later objects would never close.
            brace_level -= 1
            if brace_level == 0:
                candidate = input_string[start_index:index + 1]
                try:
                    json_objects.append(json.loads(candidate))
                except json.JSONDecodeError as e:
                    print(f"Error decoding JSON: {e}")
    return json_objects
68
+
32
69
 
33
70
  def load_tools_module(tools_path: Path) -> dict:
34
71
  tools_dict = {}
@@ -93,8 +130,64 @@ def extract_tool_signatures(tools_path: Path) -> list:
93
130
 
94
131
  return tool_data
95
132
 
133
+ def extract_tool_signatures_for_prompt(tools_path: Path) -> dict[str, str]:
134
+ functions = {}
135
+ files_to_parse = []
136
+
137
+ # Handle both single file and directory cases
138
+ if tools_path.is_file():
139
+ files_to_parse.append(tools_path)
140
+ elif tools_path.is_dir():
141
+ files_to_parse.extend(tools_path.glob("**/*.py"))
142
+ else:
143
+ raise ValueError(f"Tools path {tools_path} is neither a file nor directory")
144
+
145
+ for file_path in files_to_parse:
146
+ try:
147
+ with file_path.open("r", encoding="utf-8") as f:
148
+ code = f.read()
149
+ parsed_code = ast.parse(code)
150
+
151
+ for node in parsed_code.body:
152
+ if isinstance(node, ast.FunctionDef):
153
+ name = node.name
154
+
155
+ # Get args and type annotations
156
+ args = []
157
+ for arg in node.args.args:
158
+ if arg.arg == "self":
159
+ continue
160
+ annotation = ast.unparse(arg.annotation) if arg.annotation else "Any"
161
+ args.append((arg.arg, annotation))
162
+
163
+ # Get return type
164
+ returns = ast.unparse(node.returns) if node.returns else "None"
165
+
166
+ # Get docstring
167
+ docstring = ast.get_docstring(node)
168
+ docstring = textwrap.dedent(docstring).strip() if docstring else ""
169
+
170
+ # Format parameter descriptions if available in docstring
171
+ doc_lines = docstring.splitlines()
172
+ doc_summary = doc_lines[0] if doc_lines else ""
173
+ param_descriptions = "\n".join([line for line in doc_lines[1:] if ":param" in line])
174
+
175
+ # Compose the final string
176
+ args_str = ", ".join(f"{arg}: {type_}" for arg, type_ in args)
177
+ function_str = f"""def {name}({args_str}) -> {returns}:
178
+ {doc_summary}"""
179
+ if param_descriptions:
180
+ function_str += f"\n {param_descriptions}"
96
181
 
97
- def ensure_data_available(tool_name: str, inputs: dict, snapshot: dict, tools_module: dict) -> dict:
182
+ functions[name] = function_str
183
+ except Exception as e:
184
+ print(f"Warning: Failed to parse {file_path}: {str(e)}")
185
+ continue
186
+
187
+ return functions
188
+
189
+ def ensure_data_available(step: dict, inputs: dict, snapshot: dict, tools_module: dict, tool_signatures_for_prompt) -> dict:
190
+ tool_name = step["tool_name"]
98
191
  cache = snapshot.setdefault("input_output_examples", {}).setdefault(tool_name, [])
99
192
  for entry in cache:
100
193
  if entry["inputs"] == inputs:
@@ -103,7 +196,27 @@ def ensure_data_available(tool_name: str, inputs: dict, snapshot: dict, tools_mo
103
196
  if tool_name not in tools_module:
104
197
  raise ValueError(f"Tool '{tool_name}' not found")
105
198
 
106
- output = tools_module[tool_name](**inputs)
199
+ try:
200
+ output = tools_module[tool_name](**inputs)
201
+ except:
202
+ provider = get_provider(
203
+ model_id="meta-llama/llama-3-405b-instruct",
204
+ params={"min_new_tokens": 0, "decoding_method": "greedy", "max_new_tokens": 500},
205
+ )
206
+ renderer = ArgsExtractorTemplateRenderer(ARGS_EXTRACTOR_PROMPT_PATH)
207
+
208
+ prompt = renderer.render(
209
+ tool_signature=tool_signatures_for_prompt[tool_name],
210
+ step=step,
211
+ inputs=inputs,
212
+ )
213
+ response = provider.query(prompt)
214
+ json_obj = parse_json_string(response)[0]
215
+ try:
216
+ output = tools_module[json_obj["tool_name"]](**json_obj["inputs"])
217
+ except:
218
+ raise ValueError(f"Failed to execute tool '{tool_name}' with inputs {inputs}")
219
+
107
220
  cache.append({"inputs": inputs, "output": output})
108
221
  if not isinstance(output, dict):
109
222
  print(f" Tool {tool_name} returned non-dict output: {output}")
@@ -119,15 +232,14 @@ def plan_tool_calls_with_llm(story: str, agent_name: str, tool_signatures_str: s
119
232
  available_tools=tool_signatures_str,
120
233
  )
121
234
  response = provider.query(prompt)
122
- raw = response.get("generated_text", "")
123
- parsed = extract_first_json_list(raw)
235
+ parsed = extract_first_json_list(response)
124
236
  print("\n LLM Tool Plan:")
125
237
  print(json.dumps(parsed, indent=2))
126
238
  return parsed
127
239
 
128
240
 
129
241
  # --- Tool Execution Logic ---
130
- def run_tool_chain(tool_plan: list, snapshot: dict, tools_module) -> None:
242
+ def run_tool_chain(tool_plan: list, snapshot: dict, tools_module, tool_signatures_for_prompt) -> None:
131
243
  memory = {}
132
244
 
133
245
  for step in tool_plan:
@@ -166,14 +278,14 @@ def run_tool_chain(tool_plan: list, snapshot: dict, tools_module) -> None:
166
278
  item_inputs = resolved_inputs.copy()
167
279
  item_inputs[list_key] = val
168
280
  print(f" ⚙️ Running {name} with {list_key} = {val}")
169
- output = ensure_data_available(name, item_inputs, snapshot, tools_module)
281
+ output = ensure_data_available(step, item_inputs, snapshot, tools_module, tool_signatures_for_prompt)
170
282
  results.append(output)
171
283
  memory[f"{name}_{idx}"] = output
172
284
 
173
285
  memory[name] = results
174
286
  print(f"Stored {len(results)} outputs under '{name}' and indexed as '{name}_i'")
175
287
  else:
176
- output = ensure_data_available(name, resolved_inputs, snapshot, tools_module)
288
+ output = ensure_data_available(step, resolved_inputs, snapshot, tools_module, tool_signatures_for_prompt)
177
289
  memory[name] = output
178
290
  print(f"Stored output under tool name: {name} = {output}")
179
291
 
@@ -183,14 +295,11 @@ def build_snapshot(agent_name: str, tools_path: Path, stories: list, output_path
183
295
  agent = {"name": agent_name}
184
296
  tools_module = load_tools_module(tools_path)
185
297
  tool_signatures = extract_tool_signatures(tools_path)
298
+ tool_signatures_for_prompt = extract_tool_signatures_for_prompt(tools_path)
186
299
 
187
- provider = WatsonXProvider(
300
+ provider = get_provider(
188
301
  model_id="meta-llama/llama-3-405b-instruct",
189
- llm_decode_parameter={
190
- "min_new_tokens": 50,
191
- "decoding_method": "greedy",
192
- "max_new_tokens": 200
193
- }
302
+ params={"min_new_tokens": 1, "decoding_method": "greedy", "max_new_tokens": 2048},
194
303
  )
195
304
 
196
305
  snapshot = {
@@ -202,10 +311,14 @@ def build_snapshot(agent_name: str, tools_path: Path, stories: list, output_path
202
311
  for story in stories:
203
312
  print(f"\n📘 Planning tool calls for story: {story}")
204
313
  tool_plan = plan_tool_calls_with_llm(story, agent["name"], tool_signatures, provider)
205
- run_tool_chain(tool_plan, snapshot, tools_module)
314
+ try:
315
+ run_tool_chain(tool_plan, snapshot, tools_module, tool_signatures_for_prompt)
316
+ except ValueError as e:
317
+ print(f"❌ Error running tool chain for story '{story}': {e}")
318
+ continue
206
319
 
207
320
  with output_path.open("w", encoding="utf-8") as f:
208
- json.dump(snapshot, f, indent=2)
321
+ json.dump(snapshot, f, indent=2, cls=UniversalEncoder)
209
322
  print(f"\n✅ Snapshot saved to {output_path}")
210
323
 
211
324