ibm-watsonx-orchestrate-evaluation-framework 1.0.3__py3-none-any.whl → 1.1.8b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134)
  1. ibm_watsonx_orchestrate_evaluation_framework-1.1.8b0.dist-info/METADATA +53 -0
  2. ibm_watsonx_orchestrate_evaluation_framework-1.1.8b0.dist-info/RECORD +146 -0
  3. wxo_agentic_evaluation/analytics/tools/analyzer.py +38 -21
  4. wxo_agentic_evaluation/analytics/tools/main.py +19 -25
  5. wxo_agentic_evaluation/analytics/tools/types.py +26 -11
  6. wxo_agentic_evaluation/analytics/tools/ux.py +75 -31
  7. wxo_agentic_evaluation/analyze_run.py +1184 -97
  8. wxo_agentic_evaluation/annotate.py +7 -5
  9. wxo_agentic_evaluation/arg_configs.py +97 -5
  10. wxo_agentic_evaluation/base_user.py +25 -0
  11. wxo_agentic_evaluation/batch_annotate.py +97 -27
  12. wxo_agentic_evaluation/clients.py +103 -0
  13. wxo_agentic_evaluation/compare_runs/__init__.py +0 -0
  14. wxo_agentic_evaluation/compare_runs/compare_2_runs.py +74 -0
  15. wxo_agentic_evaluation/compare_runs/diff.py +554 -0
  16. wxo_agentic_evaluation/compare_runs/model.py +193 -0
  17. wxo_agentic_evaluation/data_annotator.py +45 -19
  18. wxo_agentic_evaluation/description_quality_checker.py +178 -0
  19. wxo_agentic_evaluation/evaluation.py +50 -0
  20. wxo_agentic_evaluation/evaluation_controller/evaluation_controller.py +303 -0
  21. wxo_agentic_evaluation/evaluation_package.py +544 -107
  22. wxo_agentic_evaluation/external_agent/__init__.py +18 -7
  23. wxo_agentic_evaluation/external_agent/external_validate.py +49 -36
  24. wxo_agentic_evaluation/external_agent/performance_test.py +33 -22
  25. wxo_agentic_evaluation/external_agent/types.py +8 -7
  26. wxo_agentic_evaluation/extractors/__init__.py +3 -0
  27. wxo_agentic_evaluation/extractors/extractor_base.py +21 -0
  28. wxo_agentic_evaluation/extractors/labeled_messages.py +47 -0
  29. wxo_agentic_evaluation/hr_agent_langgraph.py +68 -0
  30. wxo_agentic_evaluation/langfuse_collection.py +60 -0
  31. wxo_agentic_evaluation/langfuse_evaluation_package.py +192 -0
  32. wxo_agentic_evaluation/llm_matching.py +108 -5
  33. wxo_agentic_evaluation/llm_rag_eval.py +7 -4
  34. wxo_agentic_evaluation/llm_safety_eval.py +64 -0
  35. wxo_agentic_evaluation/llm_user.py +12 -6
  36. wxo_agentic_evaluation/llm_user_v2.py +114 -0
  37. wxo_agentic_evaluation/main.py +128 -246
  38. wxo_agentic_evaluation/metrics/__init__.py +15 -0
  39. wxo_agentic_evaluation/metrics/dummy_metric.py +16 -0
  40. wxo_agentic_evaluation/metrics/evaluations.py +107 -0
  41. wxo_agentic_evaluation/metrics/journey_success.py +137 -0
  42. wxo_agentic_evaluation/metrics/llm_as_judge.py +28 -2
  43. wxo_agentic_evaluation/metrics/metrics.py +319 -16
  44. wxo_agentic_evaluation/metrics/tool_calling.py +93 -0
  45. wxo_agentic_evaluation/otel_parser/__init__.py +1 -0
  46. wxo_agentic_evaluation/otel_parser/langflow_parser.py +86 -0
  47. wxo_agentic_evaluation/otel_parser/langgraph_parser.py +61 -0
  48. wxo_agentic_evaluation/otel_parser/parser.py +163 -0
  49. wxo_agentic_evaluation/otel_parser/parser_types.py +38 -0
  50. wxo_agentic_evaluation/otel_parser/pydantic_parser.py +50 -0
  51. wxo_agentic_evaluation/otel_parser/utils.py +15 -0
  52. wxo_agentic_evaluation/otel_parser/wxo_parser.py +39 -0
  53. wxo_agentic_evaluation/otel_support/evaluate_tau.py +101 -0
  54. wxo_agentic_evaluation/otel_support/otel_message_conversion.py +29 -0
  55. wxo_agentic_evaluation/otel_support/tasks_test.py +1566 -0
  56. wxo_agentic_evaluation/prompt/bad_tool_descriptions_prompt.jinja2 +178 -0
  57. wxo_agentic_evaluation/prompt/derailment_prompt.jinja2 +55 -0
  58. wxo_agentic_evaluation/prompt/llama_user_prompt.jinja2 +59 -5
  59. wxo_agentic_evaluation/prompt/llmaaj_prompt.jinja2 +15 -0
  60. wxo_agentic_evaluation/prompt/off_policy_attack_generation_prompt.jinja2 +34 -0
  61. wxo_agentic_evaluation/prompt/on_policy_attack_generation_prompt.jinja2 +46 -0
  62. wxo_agentic_evaluation/prompt/semantic_matching_prompt.jinja2 +41 -9
  63. wxo_agentic_evaluation/prompt/template_render.py +163 -12
  64. wxo_agentic_evaluation/prompt/unsafe_topic_prompt.jinja2 +65 -0
  65. wxo_agentic_evaluation/quick_eval.py +384 -0
  66. wxo_agentic_evaluation/record_chat.py +132 -81
  67. wxo_agentic_evaluation/red_teaming/attack_evaluator.py +302 -0
  68. wxo_agentic_evaluation/red_teaming/attack_generator.py +329 -0
  69. wxo_agentic_evaluation/red_teaming/attack_list.py +184 -0
  70. wxo_agentic_evaluation/red_teaming/attack_runner.py +204 -0
  71. wxo_agentic_evaluation/referenceless_eval/__init__.py +3 -0
  72. wxo_agentic_evaluation/referenceless_eval/function_calling/__init__.py +0 -0
  73. wxo_agentic_evaluation/referenceless_eval/function_calling/consts.py +28 -0
  74. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/__init__.py +0 -0
  75. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/base.py +29 -0
  76. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_call/__init__.py +0 -0
  77. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_call/general.py +49 -0
  78. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_call/general_metrics.json +783 -0
  79. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_call/general_metrics_runtime.json +580 -0
  80. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_selection/__init__.py +0 -0
  81. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_selection/function_selection.py +31 -0
  82. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_selection/function_selection_metrics.json +600 -0
  83. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_selection/function_selection_metrics_runtime.json +477 -0
  84. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/loader.py +245 -0
  85. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/__init__.py +0 -0
  86. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/adapters.py +106 -0
  87. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/pipeline.py +291 -0
  88. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/semantic_checker.py +465 -0
  89. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/static_checker.py +162 -0
  90. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/transformation_prompts.py +509 -0
  91. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/types.py +562 -0
  92. wxo_agentic_evaluation/referenceless_eval/metrics/__init__.py +3 -0
  93. wxo_agentic_evaluation/referenceless_eval/metrics/field.py +266 -0
  94. wxo_agentic_evaluation/referenceless_eval/metrics/metric.py +344 -0
  95. wxo_agentic_evaluation/referenceless_eval/metrics/metrics_runner.py +193 -0
  96. wxo_agentic_evaluation/referenceless_eval/metrics/prompt.py +413 -0
  97. wxo_agentic_evaluation/referenceless_eval/metrics/utils.py +46 -0
  98. wxo_agentic_evaluation/referenceless_eval/prompt/__init__.py +0 -0
  99. wxo_agentic_evaluation/referenceless_eval/prompt/runner.py +158 -0
  100. wxo_agentic_evaluation/referenceless_eval/referenceless_eval.py +191 -0
  101. wxo_agentic_evaluation/resource_map.py +6 -3
  102. wxo_agentic_evaluation/runner.py +329 -0
  103. wxo_agentic_evaluation/runtime_adapter/a2a_runtime_adapter.py +0 -0
  104. wxo_agentic_evaluation/runtime_adapter/runtime_adapter.py +14 -0
  105. wxo_agentic_evaluation/{inference_backend.py → runtime_adapter/wxo_runtime_adapter.py} +88 -150
  106. wxo_agentic_evaluation/scheduler.py +247 -0
  107. wxo_agentic_evaluation/service_instance.py +117 -26
  108. wxo_agentic_evaluation/service_provider/__init__.py +182 -17
  109. wxo_agentic_evaluation/service_provider/gateway_provider.py +707 -0
  110. wxo_agentic_evaluation/service_provider/model_proxy_provider.py +628 -45
  111. wxo_agentic_evaluation/service_provider/ollama_provider.py +392 -22
  112. wxo_agentic_evaluation/service_provider/portkey_provider.py +229 -0
  113. wxo_agentic_evaluation/service_provider/provider.py +129 -10
  114. wxo_agentic_evaluation/service_provider/referenceless_provider_wrapper.py +203 -0
  115. wxo_agentic_evaluation/service_provider/watsonx_provider.py +516 -53
  116. wxo_agentic_evaluation/simluation_runner.py +125 -0
  117. wxo_agentic_evaluation/test_prompt.py +4 -4
  118. wxo_agentic_evaluation/tool_planner.py +141 -46
  119. wxo_agentic_evaluation/type.py +217 -14
  120. wxo_agentic_evaluation/user_simulator/demo_usage_llm_user.py +100 -0
  121. wxo_agentic_evaluation/utils/__init__.py +44 -3
  122. wxo_agentic_evaluation/utils/evaluation_discovery.py +47 -0
  123. wxo_agentic_evaluation/utils/gateway_provider_utils.py +39 -0
  124. wxo_agentic_evaluation/utils/messages_parser.py +30 -0
  125. wxo_agentic_evaluation/utils/open_ai_tool_extractor.py +178 -0
  126. wxo_agentic_evaluation/utils/parsers.py +71 -0
  127. wxo_agentic_evaluation/utils/rich_utils.py +188 -0
  128. wxo_agentic_evaluation/utils/rouge_score.py +23 -0
  129. wxo_agentic_evaluation/utils/utils.py +514 -17
  130. wxo_agentic_evaluation/wxo_client.py +81 -0
  131. ibm_watsonx_orchestrate_evaluation_framework-1.0.3.dist-info/METADATA +0 -380
  132. ibm_watsonx_orchestrate_evaluation_framework-1.0.3.dist-info/RECORD +0 -56
  133. {ibm_watsonx_orchestrate_evaluation_framework-1.0.3.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.8b0.dist-info}/WHEEL +0 -0
  134. {ibm_watsonx_orchestrate_evaluation_framework-1.0.3.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.8b0.dist-info}/top_level.txt +0 -0
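The file listing above can be verified independently by downloading both wheels and comparing their archives. The sketch below is illustrative only: it assumes `pip` is available on PATH and otherwise uses only the standard library (`zipfile`, `difflib`); none of it is part of the package being diffed.

```python
# Illustrative sketch: reproduce the file-level comparison between the two wheels.
# Assumes pip is available on PATH; everything else is standard library.
import difflib
import subprocess
import zipfile
from pathlib import Path

PKG = "ibm-watsonx-orchestrate-evaluation-framework"
OLD, NEW = "1.0.3", "1.1.8b0"
workdir = Path("wheel_diff")
workdir.mkdir(exist_ok=True)

for version in (OLD, NEW):
    # --no-deps keeps the download to the single wheel artifact
    subprocess.run(
        ["pip", "download", f"{PKG}=={version}", "--no-deps", "-d", str(workdir)],
        check=True,
    )

old_wheel = next(workdir.glob(f"*{OLD}-py3*.whl"))
new_wheel = next(workdir.glob(f"*{NEW}-py3*.whl"))

with zipfile.ZipFile(old_wheel) as old, zipfile.ZipFile(new_wheel) as new:
    old_names, new_names = set(old.namelist()), set(new.namelist())
    print("added:", sorted(new_names - old_names))
    print("removed:", sorted(old_names - new_names))
    for name in sorted(old_names & new_names):
        if not name.endswith(".py"):
            continue
        diff = difflib.unified_diff(
            old.read(name).decode("utf-8").splitlines(),
            new.read(name).decode("utf-8").splitlines(),
            fromfile=f"{OLD}/{name}",
            tofile=f"{NEW}/{name}",
            lineterm="",
        )
        print("\n".join(diff))
```

Renamed paths (such as the `.dist-info` directories) show up as an add/remove pair in this simple comparison rather than as a rename.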
wxo_agentic_evaluation/simluation_runner.py (new file)
@@ -0,0 +1,125 @@
+ from wxo_agentic_evaluation.evaluation_controller.evaluation_controller import EvaluationController
+ from langfuse import get_client
+
+ from wxo_agentic_evaluation.runtime_adapter.runtime_adapter import RuntimeAdapter
+ from wxo_agentic_evaluation.runtime_adapter.wxo_runtime_adapter import WXORuntimeAdapter
+ from wxo_agentic_evaluation.type import Message, RuntimeResponse
+ from wxo_agentic_evaluation.llm_user import LLMUser
+ from wxo_agentic_evaluation.llm_user_v2 import LLMUserV2
+ from wxo_agentic_evaluation.arg_configs import ControllerConfig
+ from wxo_agentic_evaluation.hr_agent_langgraph import agent
+
+ from dotenv import load_dotenv
+ load_dotenv()
+ import os
+ import base64
+
+ os.environ["USE_PORTKEY_PROVIDER"] = "true"
+
+ lf_public = os.getenv("LANGFUSE_PUBLIC_KEY")
+ lf_secret = os.getenv("LANGFUSE_SECRET_KEY")
+ auth_bytes = f"{lf_public}:{lf_secret}".encode("utf-8")
+ auth_b64 = base64.b64encode(auth_bytes).decode("ascii")
+ HEADERS = {"Authorization": f"Basic {auth_b64}"}
+
+ lf_base_url = os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com").rstrip("/")
+ OTEL_ENDPOINT = f"{lf_base_url}/api/public/otel/v1/traces"
+
+ from phoenix.otel import register
+ register(endpoint=OTEL_ENDPOINT, headers=HEADERS, auto_instrument=True)
+
+
+ context = {"session_id": "1", "chat_history": []}
+
+
+ class MyAgentWrapper(RuntimeAdapter):
+     def run(
+         self,
+         user_message: Message,
+         context: dict,
+         thread_id=None,
+     ) -> RuntimeResponse:
+
+         message_json = user_message.model_dump()
+         messages = {"messages": [ message_json ]}
+         result = agent.invoke(messages)
+         # print(result)
+         message = Message(role="assistant", content=result["messages"][-1].content)
+         # messages = [Message(role="assistant", content=msg.content, type="tool_call") for msg in result["messages"]]
+         return RuntimeResponse(messages=[message])
+
+
+
+ agent_wrapper = MyAgentWrapper()
+ from openinference.instrumentation import using_session
+
+
+ class SimulationRunner:
+     def __init__(self, user_agent: LLMUser,
+                  agent: RuntimeAdapter,
+                  config: ControllerConfig):
+         self.evaluation_controller = EvaluationController(
+             runtime=agent,
+             llm_user=user_agent,
+             config=config,
+         )
+         self.counter = 0
+
+
+     def run_wrapper(self, session_id = 'session-id-test-00'):
+         def run_task(*, item, **kwargs):
+             """
+             Task function for Langfuse experiment.
+             Item input should be: {"persona": "...", "scenario": "..."}
+             """
+             # print(item)
+             with using_session(session_id + "-" + self.counter.__str__()):
+                 input = item.input
+                 user_story = input.get("story")
+                 starting_sentence = input.get("starting_sentence")
+                 agent_name = input.get("agent")
+                 _, _, _, thread_id = self.evaluation_controller.run(self.counter, agent_name=agent_name, story=user_story, starting_user_input=starting_sentence)
+                 self.counter += 1
+                 if isinstance(self.evaluation_controller.runtime, WXORuntimeAdapter):
+                     return thread_id
+                 return session_id
+
+
+         return run_task
+
+ if __name__ == "__main__":
+     import json
+     with open("benchmarks/hr_sample/data_simple.json") as f:
+         data = json.load(f)
+     langfuse = get_client()
+     langfuse.create_dataset(name="dataset-test-00")
+     # Upload to Langfuse
+
+     langfuse.create_dataset_item(
+         dataset_name="dataset-test-00",
+         # any python object or value
+         input={"story": data["story"], "starting_sentence": data["starting_sentence"]},
+         # any python object or value, optional
+         expected_output={"goals": data["goals"], "goal_details": data["goal_details"]},
+     )
+     from wxo_agentic_evaluation.service_provider import get_provider
+
+     model_id = "gpt-4o-mini"
+     provider = get_provider(provider="openai", model_id=model_id, api_key=os.getenv("OPENAI_API_KEY"),
+                             use_portkey_provider=True)
+     llm_user = LLMUserV2(llm_client=provider, user_prompt_path="src/wxo_agentic_evaluation/prompt/universal_user_template.jinja2")
+     config = ControllerConfig()
+     simluation_runner = SimulationRunner(agent = agent_wrapper, user_agent=llm_user, config=config)
+     dataset = langfuse.get_dataset("dataset-test-00")
+
+     result = dataset.run_experiment(
+         name="experiment-test-00",
+         description="Synthetic conversations from persona/scenario pairs",
+         task=simluation_runner.run_wrapper()
+     )
+
+     get_client().flush()
+     session_id = "dummy-1"
+     with using_session(session_id):
+         result = agent_wrapper.run(Message(role="user", content="hi"), context={})
+         print(result)
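The new module above wires a LangGraph agent into the framework through the RuntimeAdapter interface and then drives it from a Langfuse dataset experiment. For readers who only need the adapter contract, here is a minimal sketch without the Langfuse/Phoenix wiring; `toy_agent` is a placeholder, while the import paths and the run() signature are taken from the diff above.

```python
# Minimal sketch of the RuntimeAdapter contract exercised by the new module above.
# `toy_agent` is a placeholder for any callable that maps a user utterance to a reply.
from wxo_agentic_evaluation.runtime_adapter.runtime_adapter import RuntimeAdapter
from wxo_agentic_evaluation.type import Message, RuntimeResponse


def toy_agent(text: str) -> str:
    # Stand-in for agent.invoke(...) in the module above.
    return f"echo: {text}"


class ToyAdapter(RuntimeAdapter):
    def run(self, user_message: Message, context: dict, thread_id=None) -> RuntimeResponse:
        # Translate the framework Message into the wrapped agent's input,
        # then wrap the reply back into a RuntimeResponse.
        reply = toy_agent(user_message.content)
        return RuntimeResponse(messages=[Message(role="assistant", content=reply)])
```

An instance of such an adapter is what SimulationRunner hands to EvaluationController as its runtime.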
wxo_agentic_evaluation/test_prompt.py
@@ -1,7 +1,6 @@
  from wxo_agentic_evaluation.watsonx_provider import WatsonXProvider


-
  def parse_json_string(input_string):
      json_char_count = 0
      json_objects = []
@@ -31,9 +30,10 @@ def parse_json_string(input_string):
      is_thinking_step = len(input_string) - json_char_count > 10
      return json_objects

+
  wai_client = WatsonXProvider(model_id="meta-llama/llama-3-405b-instruct")

- prompt = """
+ prompt = """
  <|begin_of_text|><|start_header_id|>system<|end_header_id|>
  You are trying to make tool calls. Given a raw input and tool output. Try to extract the information to make the tool call

@@ -83,12 +83,12 @@ test_sample2 = """
  <|start_header_id|>ipython<|end_header_id|>"""


-
  outputs = wai_client.query(prompt + test_sample1)

  import json
+
  print(outputs["generated_text"])

  json_obj = parse_json_string(outputs["generated_text"])[0]

- print(json_obj)
+ print(json_obj)
wxo_agentic_evaluation/tool_planner.py
@@ -1,22 +1,34 @@
- import json
  import ast
  import csv
- from pathlib import Path
  import importlib.util
- import re
- from jsonargparse import CLI
+ import json
  import os
+ import re
+ import sys
  import textwrap
- from dataclasses import is_dataclass, asdict
+ from dataclasses import asdict, is_dataclass
+ from pathlib import Path
+
+ from jsonargparse import CLI

- from wxo_agentic_evaluation.service_provider import get_provider
- from wxo_agentic_evaluation.arg_configs import BatchAnnotateConfig
- from wxo_agentic_evaluation.prompt.template_render import ToolPlannerTemplateRenderer, ArgsExtractorTemplateRenderer
  from wxo_agentic_evaluation import __file__
+ from wxo_agentic_evaluation.arg_configs import BatchAnnotateConfig
+ from wxo_agentic_evaluation.prompt.template_render import (
+     ArgsExtractorTemplateRenderer,
+     ToolPlannerTemplateRenderer,
+ )
+ from wxo_agentic_evaluation.service_provider import get_provider

  root_dir = os.path.dirname(__file__)
- TOOL_PLANNER_PROMPT_PATH = os.path.join(root_dir, "prompt", "tool_planner.jinja2")
- ARGS_EXTRACTOR_PROMPT_PATH = os.path.join(root_dir, "prompt", "args_extractor_prompt.jinja2")
+ TOOL_PLANNER_PROMPT_PATH = os.path.join(
+     root_dir, "prompt", "tool_planner.jinja2"
+ )
+ ARGS_EXTRACTOR_PROMPT_PATH = os.path.join(
+     root_dir, "prompt", "args_extractor_prompt.jinja2"
+ )
+
+ MISSING_DOCSTRING_PROMPT = "No description available"
+

  class UniversalEncoder(json.JSONEncoder):
      def default(self, obj):
@@ -26,12 +38,15 @@ class UniversalEncoder(json.JSONEncoder):
              return obj.__dict__
          return super().default(obj)

+
  def extract_first_json_list(raw: str) -> list:
      matches = re.findall(r"\[\s*{.*?}\s*]", raw, re.DOTALL)
      for match in matches:
          try:
              parsed = json.loads(match)
-             if isinstance(parsed, list) and all("tool_name" in step for step in parsed):
+             if isinstance(parsed, list) and all(
+                 "tool_name" in step for step in parsed
+             ):
                  return parsed
          except Exception:
              continue
@@ -39,6 +54,7 @@ def extract_first_json_list(raw: str) -> list:
      print(raw)
      return []

+
  def parse_json_string(input_string):
      json_char_count = 0
      json_objects = []
@@ -76,19 +92,31 @@ def load_tools_module(tools_path: Path) -> dict:
      elif tools_path.is_dir():
          files_to_parse.extend(tools_path.glob("**/*.py"))
      else:
-         raise ValueError(f"Tools path {tools_path} is neither a file nor directory")
+         raise ValueError(
+             f"Tools path {tools_path} is neither a file nor directory"
+         )

      for file_path in files_to_parse:
          try:
              module_name = file_path.stem
-             spec = importlib.util.spec_from_file_location(module_name, file_path)
+             spec = importlib.util.spec_from_file_location(
+                 module_name, file_path
+             )
              module = importlib.util.module_from_spec(spec)
-             spec.loader.exec_module(module)
-
+             parent_dir = str(file_path.parent)
+             sys_path_modified = False
+             if parent_dir not in sys.path:
+                 sys.path.append(parent_dir)
+                 sys_path_modified = True
+             try:
+                 spec.loader.exec_module(module)
+             finally:
+                 if sys_path_modified:
+                     sys.path.pop()
              # Add all module's non-private functions to tools_dict
              for attr_name in dir(module):
                  attr = getattr(module, attr_name)
-                 if callable(attr) and not attr_name.startswith('_'):
+                 if callable(attr) and not attr_name.startswith("_"):
                      tools_dict[attr_name] = attr
          except Exception as e:
              print(f"Warning: Failed to load {file_path}: {str(e)}")
@@ -106,7 +134,9 @@ def extract_tool_signatures(tools_path: Path) -> list:
      elif tools_path.is_dir():
          files_to_parse.extend(tools_path.glob("**/*.py"))
      else:
-         raise ValueError(f"Tools path {tools_path} is neither a file nor directory")
+         raise ValueError(
+             f"Tools path {tools_path} is neither a file nor directory"
+         )

      for file_path in files_to_parse:
          try:
@@ -117,19 +147,24 @@ def extract_tool_signatures(tools_path: Path) -> list:
              for node in parsed_code.body:
                  if isinstance(node, ast.FunctionDef):
                      name = node.name
-                     args = [arg.arg for arg in node.args.args if arg.arg != "self"]
+                     args = [
+                         arg.arg for arg in node.args.args if arg.arg != "self"
+                     ]
                      docstring = ast.get_docstring(node)
-                     tool_data.append({
-                         "Function Name": name,
-                         "Arguments": args,
-                         "Docstring": docstring or "No description available"
-                     })
+                     tool_data.append(
+                         {
+                             "Function Name": name,
+                             "Arguments": args,
+                             "Docstring": docstring or MISSING_DOCSTRING_PROMPT,
+                         }
+                     )
          except Exception as e:
              print(f"Warning: Failed to parse {file_path}: {str(e)}")
              continue

      return tool_data

+
  def extract_tool_signatures_for_prompt(tools_path: Path) -> dict[str, str]:
      functions = {}
      files_to_parse = []
@@ -140,7 +175,9 @@ def extract_tool_signatures_for_prompt(tools_path: Path) -> dict[str, str]:
      elif tools_path.is_dir():
          files_to_parse.extend(tools_path.glob("**/*.py"))
      else:
-         raise ValueError(f"Tools path {tools_path} is neither a file nor directory")
+         raise ValueError(
+             f"Tools path {tools_path} is neither a file nor directory"
+         )

      for file_path in files_to_parse:
          try:
@@ -157,23 +194,35 @@ def extract_tool_signatures_for_prompt(tools_path: Path) -> dict[str, str]:
                      for arg in node.args.args:
                          if arg.arg == "self":
                              continue
-                         annotation = ast.unparse(arg.annotation) if arg.annotation else "Any"
+                         annotation = (
+                             ast.unparse(arg.annotation)
+                             if arg.annotation
+                             else "Any"
+                         )
                          args.append((arg.arg, annotation))

                      # Get return type
-                     returns = ast.unparse(node.returns) if node.returns else "None"
+                     returns = (
+                         ast.unparse(node.returns) if node.returns else "None"
+                     )

                      # Get docstring
                      docstring = ast.get_docstring(node)
-                     docstring = textwrap.dedent(docstring).strip() if docstring else ""
+                     docstring = (
+                         textwrap.dedent(docstring).strip() if docstring else ""
+                     )

                      # Format parameter descriptions if available in docstring
                      doc_lines = docstring.splitlines()
                      doc_summary = doc_lines[0] if doc_lines else ""
-                     param_descriptions = "\n".join([line for line in doc_lines[1:] if ":param" in line])
+                     param_descriptions = "\n".join(
+                         [line for line in doc_lines[1:] if ":param" in line]
+                     )

                      # Compose the final string
-                     args_str = ", ".join(f"{arg}: {type_}" for arg, type_ in args)
+                     args_str = ", ".join(
+                         f"{arg}: {type_}" for arg, type_ in args
+                     )
                      function_str = f"""def {name}({args_str}) -> {returns}:
  {doc_summary}"""
                      if param_descriptions:
@@ -186,9 +235,18 @@ def extract_tool_signatures_for_prompt(tools_path: Path) -> dict[str, str]:

      return functions

- def ensure_data_available(step: dict, inputs: dict, snapshot: dict, tools_module: dict, tool_signatures_for_prompt) -> dict:
+
+ def ensure_data_available(
+     step: dict,
+     inputs: dict,
+     snapshot: dict,
+     tools_module: dict,
+     tool_signatures_for_prompt,
+ ) -> dict:
      tool_name = step["tool_name"]
-     cache = snapshot.setdefault("input_output_examples", {}).setdefault(tool_name, [])
+     cache = snapshot.setdefault("input_output_examples", {}).setdefault(
+         tool_name, []
+     )
      for entry in cache:
          if entry["inputs"] == inputs:
              return entry["output"]
@@ -201,7 +259,11 @@ def ensure_data_available(step: dict, inputs: dict, snapshot: dict, tools_module
      except:
          provider = get_provider(
              model_id="meta-llama/llama-3-405b-instruct",
-             params={"min_new_tokens": 0, "decoding_method": "greedy", "max_new_tokens": 500},
+             params={
+                 "min_new_tokens": 0,
+                 "decoding_method": "greedy",
+                 "max_new_tokens": 500,
+             },
          )
          renderer = ArgsExtractorTemplateRenderer(ARGS_EXTRACTOR_PROMPT_PATH)

@@ -215,14 +277,19 @@ def ensure_data_available(step: dict, inputs: dict, snapshot: dict, tools_module
      try:
          output = tools_module[json_obj["tool_name"]](**json_obj["inputs"])
      except:
-         raise ValueError(f"Failed to execute tool '{tool_name}' with inputs {inputs}")
+         raise ValueError(
+             f"Failed to execute tool '{tool_name}' with inputs {inputs}"
+         )

      cache.append({"inputs": inputs, "output": output})
      if not isinstance(output, dict):
          print(f" Tool {tool_name} returned non-dict output: {output}")
      return output

- def plan_tool_calls_with_llm(story: str, agent_name: str, tool_signatures_str: str, provider) -> list:
+
+ def plan_tool_calls_with_llm(
+     story: str, agent_name: str, tool_signatures_str: str, provider
+ ) -> list:

      renderer = ToolPlannerTemplateRenderer(TOOL_PLANNER_PROMPT_PATH)

@@ -239,7 +306,9 @@ def plan_tool_calls_with_llm(story: str, agent_name: str, tool_signatures_str: s


  # --- Tool Execution Logic ---
- def run_tool_chain(tool_plan: list, snapshot: dict, tools_module, tool_signatures_for_prompt) -> None:
+ def run_tool_chain(
+     tool_plan: list, snapshot: dict, tools_module, tool_signatures_for_prompt
+ ) -> None:
      memory = {}

      for step in tool_plan:
@@ -269,7 +338,9 @@ def run_tool_chain(tool_plan: list, snapshot: dict, tools_module, tool_signature

          if list_keys:
              if len(list_keys) > 1:
-                 raise ValueError(f"Tool '{name}' received multiple list inputs. Only one supported for now.")
+                 raise ValueError(
+                     f"Tool '{name}' received multiple list inputs. Only one supported for now."
+                 )
              list_key = list_keys[0]
              value_list = resolved_inputs[list_key]

@@ -278,20 +349,36 @@ def run_tool_chain(tool_plan: list, snapshot: dict, tools_module, tool_signature
                  item_inputs = resolved_inputs.copy()
                  item_inputs[list_key] = val
                  print(f" ⚙️ Running {name} with {list_key} = {val}")
-                 output = ensure_data_available(step, item_inputs, snapshot, tools_module, tool_signatures_for_prompt)
+                 output = ensure_data_available(
+                     step,
+                     item_inputs,
+                     snapshot,
+                     tools_module,
+                     tool_signatures_for_prompt,
+                 )
                  results.append(output)
                  memory[f"{name}_{idx}"] = output

              memory[name] = results
-             print(f"Stored {len(results)} outputs under '{name}' and indexed as '{name}_i'")
+             print(
+                 f"Stored {len(results)} outputs under '{name}' and indexed as '{name}_i'"
+             )
          else:
-             output = ensure_data_available(step, resolved_inputs, snapshot, tools_module, tool_signatures_for_prompt)
+             output = ensure_data_available(
+                 step,
+                 resolved_inputs,
+                 snapshot,
+                 tools_module,
+                 tool_signatures_for_prompt,
+             )
              memory[name] = output
              print(f"Stored output under tool name: {name} = {output}")


  # --- Main Snapshot Builder ---
- def build_snapshot(agent_name: str, tools_path: Path, stories: list, output_path: Path):
+ def build_snapshot(
+     agent_name: str, tools_path: Path, stories: list, output_path: Path
+ ):
      agent = {"name": agent_name}
      tools_module = load_tools_module(tools_path)
      tool_signatures = extract_tool_signatures(tools_path)
@@ -299,20 +386,28 @@ def build_snapshot(agent_name: str, tools_path: Path, stories: list, output_path

      provider = get_provider(
          model_id="meta-llama/llama-3-405b-instruct",
-         params={"min_new_tokens": 1, "decoding_method": "greedy", "max_new_tokens": 2048},
+         params={
+             "min_new_tokens": 1,
+             "decoding_method": "greedy",
+             "max_new_tokens": 2048,
+         },
      )

      snapshot = {
          "agent": agent,
          "tools": tool_signatures,
-         "input_output_examples": {}
+         "input_output_examples": {},
      }

      for story in stories:
          print(f"\n📘 Planning tool calls for story: {story}")
-         tool_plan = plan_tool_calls_with_llm(story, agent["name"], tool_signatures, provider)
+         tool_plan = plan_tool_calls_with_llm(
+             story, agent["name"], tool_signatures, provider
+         )
          try:
-             run_tool_chain(tool_plan, snapshot, tools_module, tool_signatures_for_prompt)
+             run_tool_chain(
+                 tool_plan, snapshot, tools_module, tool_signatures_for_prompt
+             )
          except ValueError as e:
              print(f"❌ Error running tool chain for story '{story}': {e}")
              continue
@@ -329,7 +424,7 @@ if __name__ == "__main__":

      stories = []
      agent_name = None
-     with stories_path.open("r", encoding="utf-8", newline='') as f:
+     with stories_path.open("r", encoding="utf-8", newline="") as f:
          csv_reader = csv.DictReader(f)
          for row in csv_reader:
              stories.append(row["story"])
@@ -338,4 +433,4 @@ if __name__ == "__main__":

      snapshot_path = stories_path.parent / f"{agent_name}_snapshot_llm.json"

-     build_snapshot(agent_name, tools_path, stories, snapshot_path)
+     build_snapshot(agent_name, tools_path, stories, snapshot_path)