ibm-watsonx-orchestrate-evaluation-framework 1.1.1__py3-none-any.whl → 1.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. ibm_watsonx_orchestrate_evaluation_framework-1.1.3.dist-info/METADATA +35 -0
  2. {ibm_watsonx_orchestrate_evaluation_framework-1.1.1.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.3.dist-info}/RECORD +65 -60
  3. wxo_agentic_evaluation/analytics/tools/analyzer.py +36 -21
  4. wxo_agentic_evaluation/analytics/tools/main.py +18 -7
  5. wxo_agentic_evaluation/analytics/tools/types.py +26 -11
  6. wxo_agentic_evaluation/analytics/tools/ux.py +75 -31
  7. wxo_agentic_evaluation/analyze_run.py +69 -48
  8. wxo_agentic_evaluation/annotate.py +6 -4
  9. wxo_agentic_evaluation/arg_configs.py +9 -3
  10. wxo_agentic_evaluation/batch_annotate.py +78 -25
  11. wxo_agentic_evaluation/data_annotator.py +18 -13
  12. wxo_agentic_evaluation/description_quality_checker.py +20 -14
  13. wxo_agentic_evaluation/evaluation.py +42 -0
  14. wxo_agentic_evaluation/evaluation_package.py +117 -70
  15. wxo_agentic_evaluation/external_agent/__init__.py +18 -7
  16. wxo_agentic_evaluation/external_agent/external_validate.py +46 -35
  17. wxo_agentic_evaluation/external_agent/performance_test.py +32 -20
  18. wxo_agentic_evaluation/external_agent/types.py +12 -5
  19. wxo_agentic_evaluation/inference_backend.py +183 -79
  20. wxo_agentic_evaluation/llm_matching.py +4 -3
  21. wxo_agentic_evaluation/llm_rag_eval.py +7 -4
  22. wxo_agentic_evaluation/llm_user.py +7 -3
  23. wxo_agentic_evaluation/main.py +175 -67
  24. wxo_agentic_evaluation/metrics/llm_as_judge.py +2 -2
  25. wxo_agentic_evaluation/metrics/metrics.py +26 -12
  26. wxo_agentic_evaluation/otel_support/evaluate_tau.py +67 -0
  27. wxo_agentic_evaluation/otel_support/evaluate_tau_traces.py +176 -0
  28. wxo_agentic_evaluation/otel_support/otel_message_conversion.py +21 -0
  29. wxo_agentic_evaluation/otel_support/tasks_test.py +1226 -0
  30. wxo_agentic_evaluation/prompt/template_render.py +32 -11
  31. wxo_agentic_evaluation/quick_eval.py +49 -23
  32. wxo_agentic_evaluation/record_chat.py +70 -33
  33. wxo_agentic_evaluation/red_teaming/attack_evaluator.py +58 -18
  34. wxo_agentic_evaluation/red_teaming/attack_generator.py +38 -18
  35. wxo_agentic_evaluation/red_teaming/attack_runner.py +43 -27
  36. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/base.py +3 -1
  37. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/loader.py +23 -15
  38. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/adapters.py +13 -8
  39. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/pipeline.py +41 -13
  40. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/semantic_checker.py +26 -16
  41. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/static_checker.py +17 -11
  42. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/types.py +44 -29
  43. wxo_agentic_evaluation/referenceless_eval/metrics/field.py +13 -5
  44. wxo_agentic_evaluation/referenceless_eval/metrics/metric.py +16 -5
  45. wxo_agentic_evaluation/referenceless_eval/metrics/metrics_runner.py +8 -3
  46. wxo_agentic_evaluation/referenceless_eval/metrics/prompt.py +6 -2
  47. wxo_agentic_evaluation/referenceless_eval/metrics/utils.py +5 -1
  48. wxo_agentic_evaluation/referenceless_eval/prompt/runner.py +16 -3
  49. wxo_agentic_evaluation/referenceless_eval/referenceless_eval.py +23 -12
  50. wxo_agentic_evaluation/resource_map.py +2 -1
  51. wxo_agentic_evaluation/service_instance.py +103 -21
  52. wxo_agentic_evaluation/service_provider/__init__.py +33 -13
  53. wxo_agentic_evaluation/service_provider/model_proxy_provider.py +216 -34
  54. wxo_agentic_evaluation/service_provider/ollama_provider.py +10 -11
  55. wxo_agentic_evaluation/service_provider/provider.py +0 -1
  56. wxo_agentic_evaluation/service_provider/referenceless_provider_wrapper.py +34 -21
  57. wxo_agentic_evaluation/service_provider/watsonx_provider.py +50 -22
  58. wxo_agentic_evaluation/tool_planner.py +128 -44
  59. wxo_agentic_evaluation/type.py +12 -9
  60. wxo_agentic_evaluation/utils/__init__.py +1 -0
  61. wxo_agentic_evaluation/utils/open_ai_tool_extractor.py +41 -20
  62. wxo_agentic_evaluation/utils/rich_utils.py +23 -9
  63. wxo_agentic_evaluation/utils/utils.py +83 -52
  64. ibm_watsonx_orchestrate_evaluation_framework-1.1.1.dist-info/METADATA +0 -386
  65. {ibm_watsonx_orchestrate_evaluation_framework-1.1.1.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.3.dist-info}/WHEEL +0 -0
  66. {ibm_watsonx_orchestrate_evaluation_framework-1.1.1.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.3.dist-info}/top_level.txt +0 -0
@@ -1,11 +1,15 @@
1
- import requests
2
- from typing import List, Mapping, Union, Optional, Any
3
1
  from abc import ABC, abstractmethod
2
+ from typing import Any, List, Mapping, Optional, Union
4
3
 
4
+ import requests
5
5
  import rich
6
6
 
7
- from wxo_agentic_evaluation.service_provider.model_proxy_provider import ModelProxyProvider
8
- from wxo_agentic_evaluation.service_provider.watsonx_provider import WatsonXProvider
7
+ from wxo_agentic_evaluation.service_provider.model_proxy_provider import (
8
+ ModelProxyProvider,
9
+ )
10
+ from wxo_agentic_evaluation.service_provider.watsonx_provider import (
11
+ WatsonXProvider,
12
+ )
9
13
 
10
14
 
11
15
  class LLMResponse:
@@ -14,7 +18,9 @@ class LLMResponse:
14
18
  Response object that can contain both content and tool calls
15
19
  """
16
20
 
17
- def __init__(self, content: str, tool_calls: Optional[List[Mapping[str, Any]]] = None):
21
+ def __init__(
22
+ self, content: str, tool_calls: Optional[List[Mapping[str, Any]]] = None
23
+ ):
18
24
  self.content = content
19
25
  self.tool_calls = tool_calls or []
20
26
 
@@ -26,25 +32,26 @@ class LLMResponse:
26
32
  """Return a string representation of the LLMResponse object."""
27
33
  return f"LLMResponse(content='{self.content}', tool_calls={self.tool_calls})"
28
34
 
35
+
29
36
  class LLMKitWrapper(ABC):
30
- """ In the future this wrapper won't be neccesary.
37
+ """In the future this wrapper won't be neccesary.
31
38
  Right now the referenceless code requires a `generate()` function for the metrics client.
32
39
  In refactor, rewrite referenceless code so this wrapper is not needed.
33
40
  """
41
+
34
42
  @abstractmethod
35
43
  def chat():
36
44
  pass
37
45
 
38
46
  def generate(
39
- self,
40
- prompt: Union[str, List[Mapping[str, str]]],
41
- *,
42
- schema,
43
- retries: int = 3,
44
- generation_args: Optional[Any] = None,
45
- **kwargs: Any
46
- ):
47
-
47
+ self,
48
+ prompt: Union[str, List[Mapping[str, str]]],
49
+ *,
50
+ schema,
51
+ retries: int = 3,
52
+ generation_args: Optional[Any] = None,
53
+ **kwargs: Any,
54
+ ):
48
55
  """
49
56
  In future, implement validation of response like in llmevalkit
50
57
  """
@@ -55,7 +62,9 @@ class LLMKitWrapper(ABC):
55
62
  response = self._parse_llm_response(raw_response)
56
63
  return response
57
64
  except Exception as e:
58
- rich.print(f"[b][r] Generation failed with error '{str(e)}' during `quick-eval` ... Attempt ({attempt} / {retries}))")
65
+ rich.print(
66
+ f"[b][r] Generation failed with error '{str(e)}' during `quick-eval` ... Attempt ({attempt} / {retries}))"
67
+ )
59
68
 
60
69
  def _parse_llm_response(self, raw: Any) -> Union[str, LLMResponse]:
61
70
  """
@@ -82,10 +91,12 @@ class LLMKitWrapper(ABC):
82
91
  "id": tool_call.get("id"),
83
92
  "type": tool_call.get("type", "function"),
84
93
  "function": {
85
- "name": tool_call.get("function", {}).get("name"),
86
- "arguments": tool_call.get("function", {}).get(
87
- "arguments"
94
+ "name": tool_call.get("function", {}).get(
95
+ "name"
88
96
  ),
97
+ "arguments": tool_call.get(
98
+ "function", {}
99
+ ).get("arguments"),
89
100
  },
90
101
  }
91
102
  tool_calls.append(tool_call_dict)
@@ -101,6 +112,7 @@ class LLMKitWrapper(ABC):
101
112
 
102
113
  return content
103
114
 
115
+
104
116
  class ModelProxyProviderLLMKitWrapper(ModelProxyProvider, LLMKitWrapper):
105
117
  def chat(self, sentence: List[str]):
106
118
  if self.model_id is None:
@@ -113,7 +125,7 @@ class ModelProxyProviderLLMKitWrapper(ModelProxyProvider, LLMKitWrapper):
113
125
  "messages": sentence,
114
126
  "parameters": self.params,
115
127
  "space_id": "1",
116
- "timeout": self.timeout
128
+ "timeout": self.timeout,
117
129
  }
118
130
  resp = requests.post(url=chat_url, headers=headers, json=data)
119
131
  if resp.status_code == 200:
@@ -121,6 +133,7 @@ class ModelProxyProviderLLMKitWrapper(ModelProxyProvider, LLMKitWrapper):
121
133
  else:
122
134
  resp.raise_for_status()
123
135
 
136
+
124
137
  class WatsonXLLMKitWrapper(WatsonXProvider, LLMKitWrapper):
125
138
  def chat(self, sentence: list):
126
139
  chat_url = f"{self.api_endpoint}/ml/v1/text/chat?version=2023-05-02"
@@ -129,7 +142,7 @@ class WatsonXLLMKitWrapper(WatsonXProvider, LLMKitWrapper):
129
142
  "model_id": self.model_id,
130
143
  "messages": sentence,
131
144
  "parameters": self.params,
132
- "space_id": self.space_id
145
+ "space_id": self.space_id,
133
146
  }
134
147
  resp = requests.post(url=chat_url, headers=headers, json=data)
135
148
  if resp.status_code == 200:
@@ -1,11 +1,13 @@
1
- import os
2
- import requests
1
+ import dataclasses
3
2
  import json
3
+ import os
4
+ import time
5
+ from threading import Lock
4
6
  from types import MappingProxyType
5
7
  from typing import List, Mapping, Union
6
- import dataclasses
7
- from threading import Lock
8
- import time
8
+
9
+ import requests
10
+
9
11
  from wxo_agentic_evaluation.service_provider.provider import Provider
10
12
 
11
13
  ACCESS_URL = "https://iam.cloud.ibm.com/identity/token"
@@ -36,7 +38,9 @@ class WatsonXProvider(Provider):
36
38
  super().__init__()
37
39
  self.url = url
38
40
  if (embedding_model_id is None) and (model_id is None):
39
- raise Exception("either model_id or embedding_model_id must be specified")
41
+ raise Exception(
42
+ "either model_id or embedding_model_id must be specified"
43
+ )
40
44
  self.model_id = model_id
41
45
  api_key = os.environ.get("WATSONX_APIKEY", api_key)
42
46
  if not api_key:
@@ -56,7 +60,7 @@ class WatsonXProvider(Provider):
56
60
  self.lock = Lock()
57
61
 
58
62
  self.params = params if params else DEFAULT_PARAM
59
-
63
+
60
64
  if isinstance(self.params, MappingProxyType):
61
65
  self.params = dict(self.params)
62
66
  if dataclasses.is_dataclass(self.params):
@@ -68,7 +72,10 @@ class WatsonXProvider(Provider):
68
72
 
69
73
  def _get_access_token(self):
70
74
  response = requests.post(
71
- self.url, headers=ACCESS_HEADER, data=self.access_data, timeout=self.timeout
75
+ self.url,
76
+ headers=ACCESS_HEADER,
77
+ data=self.access_data,
78
+ timeout=self.timeout,
72
79
  )
73
80
  if response.status_code == 200:
74
81
  token_data = json.loads(response.text)
@@ -84,16 +91,24 @@ class WatsonXProvider(Provider):
84
91
  )
85
92
 
86
93
  def prepare_header(self):
87
- headers = {"Authorization": f"Bearer {self.access_token}",
88
- "Content-Type": "application/json"}
94
+ headers = {
95
+ "Authorization": f"Bearer {self.access_token}",
96
+ "Content-Type": "application/json",
97
+ }
89
98
  return headers
90
99
 
91
100
  def _query(self, sentence: str):
92
101
  headers = self.prepare_header()
93
102
 
94
- data = {"model_id": self.model_id, "input": sentence,
95
- "parameters": self.params, "space_id": self.space_id}
96
- generation_url = f"{self.api_endpoint}/ml/v1/text/generation?version=2023-05-02"
103
+ data = {
104
+ "model_id": self.model_id,
105
+ "input": sentence,
106
+ "parameters": self.params,
107
+ "space_id": self.space_id,
108
+ }
109
+ generation_url = (
110
+ f"{self.api_endpoint}/ml/v1/text/generation?version=2023-05-02"
111
+ )
97
112
  resp = requests.post(url=generation_url, headers=headers, json=data)
98
113
  if resp.status_code == 200:
99
114
  return resp.json()["results"][0]
@@ -105,20 +120,25 @@ class WatsonXProvider(Provider):
105
120
  if not self.access_token or time.time() > self.refresh_time:
106
121
  with self.lock:
107
122
  if not self.access_token or time.time() > self.refresh_time:
108
- self.access_token, self.refresh_time = self._get_access_token()
123
+ (
124
+ self.access_token,
125
+ self.refresh_time,
126
+ ) = self._get_access_token()
109
127
 
110
128
  def query(self, sentence: Union[str, Mapping[str, str]]) -> str:
111
129
  if self.model_id is None:
112
130
  raise Exception("model id must be specified for text generation")
113
131
  try:
114
132
  response = self._query(sentence)
115
- if (generated_text := response.get("generated_text")):
133
+ if generated_text := response.get("generated_text"):
116
134
  return generated_text
117
- elif (message := response.get("message")):
135
+ elif message := response.get("message"):
118
136
  return message
119
137
  else:
120
- raise ValueError(f"Unexpected response from WatsonX: {response}")
121
-
138
+ raise ValueError(
139
+ f"Unexpected response from WatsonX: {response}"
140
+ )
141
+
122
142
  except Exception as e:
123
143
  with self.lock:
124
144
  if "authentication_token_expired" in str(e):
@@ -130,12 +150,18 @@ class WatsonXProvider(Provider):
130
150
 
131
151
  def encode(self, sentences: List[str]) -> List[list]:
132
152
  if self.embedding_model_id is None:
133
- raise Exception("embedding model id must be specified for text encoding")
153
+ raise Exception(
154
+ "embedding model id must be specified for text encoding"
155
+ )
134
156
 
135
157
  headers = self.prepare_header()
136
158
  url = f"{self.api_endpoint}/ml/v1/text/embeddings?version=2023-10-25"
137
159
 
138
- data = {"inputs": sentences, "model_id": self.model_id, "space_id": self.space_id}
160
+ data = {
161
+ "inputs": sentences,
162
+ "model_id": self.model_id,
163
+ "space_id": self.space_id,
164
+ }
139
165
  resp = requests.post(url=url, headers=headers, json=data)
140
166
  if resp.status_code == 200:
141
167
  return [entry["embedding"] for entry in resp.json()["results"]]
@@ -144,7 +170,9 @@ class WatsonXProvider(Provider):
144
170
 
145
171
 
146
172
  if __name__ == "__main__":
147
- provider = WatsonXProvider(model_id="meta-llama/llama-3-2-90b-vision-instruct")
173
+ provider = WatsonXProvider(
174
+ model_id="meta-llama/llama-3-2-90b-vision-instruct"
175
+ )
148
176
 
149
177
  prompt = """
150
178
  <|begin_of_text|><|start_header_id|>system<|end_header_id|>
@@ -176,4 +204,4 @@ Usernwaters did not take anytime off during the period<|eot_id|>
176
204
  <|eot_id|><|start_header_id|>user<|end_header_id|>
177
205
  """
178
206
 
179
- print(provider.query(prompt))
207
+ print(provider.query(prompt))
@@ -1,26 +1,35 @@
1
- import json
2
1
  import ast
3
2
  import csv
4
- from pathlib import Path
5
3
  import importlib.util
6
- import re
7
- from jsonargparse import CLI
4
+ import json
8
5
  import os
6
+ import re
9
7
  import sys
10
8
  import textwrap
11
- from dataclasses import is_dataclass, asdict
9
+ from dataclasses import asdict, is_dataclass
10
+ from pathlib import Path
11
+
12
+ from jsonargparse import CLI
12
13
 
13
- from wxo_agentic_evaluation.service_provider import get_provider
14
- from wxo_agentic_evaluation.arg_configs import BatchAnnotateConfig
15
- from wxo_agentic_evaluation.prompt.template_render import ToolPlannerTemplateRenderer, ArgsExtractorTemplateRenderer
16
14
  from wxo_agentic_evaluation import __file__
15
+ from wxo_agentic_evaluation.arg_configs import BatchAnnotateConfig
16
+ from wxo_agentic_evaluation.prompt.template_render import (
17
+ ArgsExtractorTemplateRenderer,
18
+ ToolPlannerTemplateRenderer,
19
+ )
20
+ from wxo_agentic_evaluation.service_provider import get_provider
17
21
 
18
22
  root_dir = os.path.dirname(__file__)
19
- TOOL_PLANNER_PROMPT_PATH = os.path.join(root_dir, "prompt", "tool_planner.jinja2")
20
- ARGS_EXTRACTOR_PROMPT_PATH = os.path.join(root_dir, "prompt", "args_extractor_prompt.jinja2")
23
+ TOOL_PLANNER_PROMPT_PATH = os.path.join(
24
+ root_dir, "prompt", "tool_planner.jinja2"
25
+ )
26
+ ARGS_EXTRACTOR_PROMPT_PATH = os.path.join(
27
+ root_dir, "prompt", "args_extractor_prompt.jinja2"
28
+ )
21
29
 
22
30
  MISSING_DOCSTRING_PROMPT = "No description available"
23
31
 
32
+
24
33
  class UniversalEncoder(json.JSONEncoder):
25
34
  def default(self, obj):
26
35
  if is_dataclass(obj):
@@ -29,12 +38,15 @@ class UniversalEncoder(json.JSONEncoder):
29
38
  return obj.__dict__
30
39
  return super().default(obj)
31
40
 
41
+
32
42
  def extract_first_json_list(raw: str) -> list:
33
43
  matches = re.findall(r"\[\s*{.*?}\s*]", raw, re.DOTALL)
34
44
  for match in matches:
35
45
  try:
36
46
  parsed = json.loads(match)
37
- if isinstance(parsed, list) and all("tool_name" in step for step in parsed):
47
+ if isinstance(parsed, list) and all(
48
+ "tool_name" in step for step in parsed
49
+ ):
38
50
  return parsed
39
51
  except Exception:
40
52
  continue
@@ -42,6 +54,7 @@ def extract_first_json_list(raw: str) -> list:
42
54
  print(raw)
43
55
  return []
44
56
 
57
+
45
58
  def parse_json_string(input_string):
46
59
  json_char_count = 0
47
60
  json_objects = []
@@ -79,12 +92,16 @@ def load_tools_module(tools_path: Path) -> dict:
79
92
  elif tools_path.is_dir():
80
93
  files_to_parse.extend(tools_path.glob("**/*.py"))
81
94
  else:
82
- raise ValueError(f"Tools path {tools_path} is neither a file nor directory")
95
+ raise ValueError(
96
+ f"Tools path {tools_path} is neither a file nor directory"
97
+ )
83
98
 
84
99
  for file_path in files_to_parse:
85
100
  try:
86
101
  module_name = file_path.stem
87
- spec = importlib.util.spec_from_file_location(module_name, file_path)
102
+ spec = importlib.util.spec_from_file_location(
103
+ module_name, file_path
104
+ )
88
105
  module = importlib.util.module_from_spec(spec)
89
106
  parent_dir = str(file_path.parent)
90
107
  sys_path_modified = False
@@ -99,7 +116,7 @@ def load_tools_module(tools_path: Path) -> dict:
99
116
  # Add all module's non-private functions to tools_dict
100
117
  for attr_name in dir(module):
101
118
  attr = getattr(module, attr_name)
102
- if callable(attr) and not attr_name.startswith('_'):
119
+ if callable(attr) and not attr_name.startswith("_"):
103
120
  tools_dict[attr_name] = attr
104
121
  except Exception as e:
105
122
  print(f"Warning: Failed to load {file_path}: {str(e)}")
@@ -117,7 +134,9 @@ def extract_tool_signatures(tools_path: Path) -> list:
117
134
  elif tools_path.is_dir():
118
135
  files_to_parse.extend(tools_path.glob("**/*.py"))
119
136
  else:
120
- raise ValueError(f"Tools path {tools_path} is neither a file nor directory")
137
+ raise ValueError(
138
+ f"Tools path {tools_path} is neither a file nor directory"
139
+ )
121
140
 
122
141
  for file_path in files_to_parse:
123
142
  try:
@@ -128,19 +147,24 @@ def extract_tool_signatures(tools_path: Path) -> list:
128
147
  for node in parsed_code.body:
129
148
  if isinstance(node, ast.FunctionDef):
130
149
  name = node.name
131
- args = [arg.arg for arg in node.args.args if arg.arg != "self"]
150
+ args = [
151
+ arg.arg for arg in node.args.args if arg.arg != "self"
152
+ ]
132
153
  docstring = ast.get_docstring(node)
133
- tool_data.append({
134
- "Function Name": name,
135
- "Arguments": args,
136
- "Docstring": docstring or MISSING_DOCSTRING_PROMPT
137
- })
154
+ tool_data.append(
155
+ {
156
+ "Function Name": name,
157
+ "Arguments": args,
158
+ "Docstring": docstring or MISSING_DOCSTRING_PROMPT,
159
+ }
160
+ )
138
161
  except Exception as e:
139
162
  print(f"Warning: Failed to parse {file_path}: {str(e)}")
140
163
  continue
141
164
 
142
165
  return tool_data
143
166
 
167
+
144
168
  def extract_tool_signatures_for_prompt(tools_path: Path) -> dict[str, str]:
145
169
  functions = {}
146
170
  files_to_parse = []
@@ -151,7 +175,9 @@ def extract_tool_signatures_for_prompt(tools_path: Path) -> dict[str, str]:
151
175
  elif tools_path.is_dir():
152
176
  files_to_parse.extend(tools_path.glob("**/*.py"))
153
177
  else:
154
- raise ValueError(f"Tools path {tools_path} is neither a file nor directory")
178
+ raise ValueError(
179
+ f"Tools path {tools_path} is neither a file nor directory"
180
+ )
155
181
 
156
182
  for file_path in files_to_parse:
157
183
  try:
@@ -168,23 +194,35 @@ def extract_tool_signatures_for_prompt(tools_path: Path) -> dict[str, str]:
168
194
  for arg in node.args.args:
169
195
  if arg.arg == "self":
170
196
  continue
171
- annotation = ast.unparse(arg.annotation) if arg.annotation else "Any"
197
+ annotation = (
198
+ ast.unparse(arg.annotation)
199
+ if arg.annotation
200
+ else "Any"
201
+ )
172
202
  args.append((arg.arg, annotation))
173
203
 
174
204
  # Get return type
175
- returns = ast.unparse(node.returns) if node.returns else "None"
205
+ returns = (
206
+ ast.unparse(node.returns) if node.returns else "None"
207
+ )
176
208
 
177
209
  # Get docstring
178
210
  docstring = ast.get_docstring(node)
179
- docstring = textwrap.dedent(docstring).strip() if docstring else ""
211
+ docstring = (
212
+ textwrap.dedent(docstring).strip() if docstring else ""
213
+ )
180
214
 
181
215
  # Format parameter descriptions if available in docstring
182
216
  doc_lines = docstring.splitlines()
183
217
  doc_summary = doc_lines[0] if doc_lines else ""
184
- param_descriptions = "\n".join([line for line in doc_lines[1:] if ":param" in line])
218
+ param_descriptions = "\n".join(
219
+ [line for line in doc_lines[1:] if ":param" in line]
220
+ )
185
221
 
186
222
  # Compose the final string
187
- args_str = ", ".join(f"{arg}: {type_}" for arg, type_ in args)
223
+ args_str = ", ".join(
224
+ f"{arg}: {type_}" for arg, type_ in args
225
+ )
188
226
  function_str = f"""def {name}({args_str}) -> {returns}:
189
227
  {doc_summary}"""
190
228
  if param_descriptions:
@@ -197,9 +235,18 @@ def extract_tool_signatures_for_prompt(tools_path: Path) -> dict[str, str]:
197
235
 
198
236
  return functions
199
237
 
200
- def ensure_data_available(step: dict, inputs: dict, snapshot: dict, tools_module: dict, tool_signatures_for_prompt) -> dict:
238
+
239
+ def ensure_data_available(
240
+ step: dict,
241
+ inputs: dict,
242
+ snapshot: dict,
243
+ tools_module: dict,
244
+ tool_signatures_for_prompt,
245
+ ) -> dict:
201
246
  tool_name = step["tool_name"]
202
- cache = snapshot.setdefault("input_output_examples", {}).setdefault(tool_name, [])
247
+ cache = snapshot.setdefault("input_output_examples", {}).setdefault(
248
+ tool_name, []
249
+ )
203
250
  for entry in cache:
204
251
  if entry["inputs"] == inputs:
205
252
  return entry["output"]
@@ -212,7 +259,11 @@ def ensure_data_available(step: dict, inputs: dict, snapshot: dict, tools_module
212
259
  except:
213
260
  provider = get_provider(
214
261
  model_id="meta-llama/llama-3-405b-instruct",
215
- params={"min_new_tokens": 0, "decoding_method": "greedy", "max_new_tokens": 500},
262
+ params={
263
+ "min_new_tokens": 0,
264
+ "decoding_method": "greedy",
265
+ "max_new_tokens": 500,
266
+ },
216
267
  )
217
268
  renderer = ArgsExtractorTemplateRenderer(ARGS_EXTRACTOR_PROMPT_PATH)
218
269
 
@@ -226,14 +277,19 @@ def ensure_data_available(step: dict, inputs: dict, snapshot: dict, tools_module
226
277
  try:
227
278
  output = tools_module[json_obj["tool_name"]](**json_obj["inputs"])
228
279
  except:
229
- raise ValueError(f"Failed to execute tool '{tool_name}' with inputs {inputs}")
280
+ raise ValueError(
281
+ f"Failed to execute tool '{tool_name}' with inputs {inputs}"
282
+ )
230
283
 
231
284
  cache.append({"inputs": inputs, "output": output})
232
285
  if not isinstance(output, dict):
233
286
  print(f" Tool {tool_name} returned non-dict output: {output}")
234
287
  return output
235
288
 
236
- def plan_tool_calls_with_llm(story: str, agent_name: str, tool_signatures_str: str, provider) -> list:
289
+
290
+ def plan_tool_calls_with_llm(
291
+ story: str, agent_name: str, tool_signatures_str: str, provider
292
+ ) -> list:
237
293
 
238
294
  renderer = ToolPlannerTemplateRenderer(TOOL_PLANNER_PROMPT_PATH)
239
295
 
@@ -250,7 +306,9 @@ def plan_tool_calls_with_llm(story: str, agent_name: str, tool_signatures_str: s
250
306
 
251
307
 
252
308
  # --- Tool Execution Logic ---
253
- def run_tool_chain(tool_plan: list, snapshot: dict, tools_module, tool_signatures_for_prompt) -> None:
309
+ def run_tool_chain(
310
+ tool_plan: list, snapshot: dict, tools_module, tool_signatures_for_prompt
311
+ ) -> None:
254
312
  memory = {}
255
313
 
256
314
  for step in tool_plan:
@@ -280,7 +338,9 @@ def run_tool_chain(tool_plan: list, snapshot: dict, tools_module, tool_signature
280
338
 
281
339
  if list_keys:
282
340
  if len(list_keys) > 1:
283
- raise ValueError(f"Tool '{name}' received multiple list inputs. Only one supported for now.")
341
+ raise ValueError(
342
+ f"Tool '{name}' received multiple list inputs. Only one supported for now."
343
+ )
284
344
  list_key = list_keys[0]
285
345
  value_list = resolved_inputs[list_key]
286
346
 
@@ -289,20 +349,36 @@ def run_tool_chain(tool_plan: list, snapshot: dict, tools_module, tool_signature
289
349
  item_inputs = resolved_inputs.copy()
290
350
  item_inputs[list_key] = val
291
351
  print(f" ⚙️ Running {name} with {list_key} = {val}")
292
- output = ensure_data_available(step, item_inputs, snapshot, tools_module, tool_signatures_for_prompt)
352
+ output = ensure_data_available(
353
+ step,
354
+ item_inputs,
355
+ snapshot,
356
+ tools_module,
357
+ tool_signatures_for_prompt,
358
+ )
293
359
  results.append(output)
294
360
  memory[f"{name}_{idx}"] = output
295
361
 
296
362
  memory[name] = results
297
- print(f"Stored {len(results)} outputs under '{name}' and indexed as '{name}_i'")
363
+ print(
364
+ f"Stored {len(results)} outputs under '{name}' and indexed as '{name}_i'"
365
+ )
298
366
  else:
299
- output = ensure_data_available(step, resolved_inputs, snapshot, tools_module, tool_signatures_for_prompt)
367
+ output = ensure_data_available(
368
+ step,
369
+ resolved_inputs,
370
+ snapshot,
371
+ tools_module,
372
+ tool_signatures_for_prompt,
373
+ )
300
374
  memory[name] = output
301
375
  print(f"Stored output under tool name: {name} = {output}")
302
376
 
303
377
 
304
378
  # --- Main Snapshot Builder ---
305
- def build_snapshot(agent_name: str, tools_path: Path, stories: list, output_path: Path):
379
+ def build_snapshot(
380
+ agent_name: str, tools_path: Path, stories: list, output_path: Path
381
+ ):
306
382
  agent = {"name": agent_name}
307
383
  tools_module = load_tools_module(tools_path)
308
384
  tool_signatures = extract_tool_signatures(tools_path)
@@ -310,20 +386,28 @@ def build_snapshot(agent_name: str, tools_path: Path, stories: list, output_path
310
386
 
311
387
  provider = get_provider(
312
388
  model_id="meta-llama/llama-3-405b-instruct",
313
- params={"min_new_tokens": 1, "decoding_method": "greedy", "max_new_tokens": 2048},
389
+ params={
390
+ "min_new_tokens": 1,
391
+ "decoding_method": "greedy",
392
+ "max_new_tokens": 2048,
393
+ },
314
394
  )
315
395
 
316
396
  snapshot = {
317
397
  "agent": agent,
318
398
  "tools": tool_signatures,
319
- "input_output_examples": {}
399
+ "input_output_examples": {},
320
400
  }
321
401
 
322
402
  for story in stories:
323
403
  print(f"\n📘 Planning tool calls for story: {story}")
324
- tool_plan = plan_tool_calls_with_llm(story, agent["name"], tool_signatures, provider)
404
+ tool_plan = plan_tool_calls_with_llm(
405
+ story, agent["name"], tool_signatures, provider
406
+ )
325
407
  try:
326
- run_tool_chain(tool_plan, snapshot, tools_module, tool_signatures_for_prompt)
408
+ run_tool_chain(
409
+ tool_plan, snapshot, tools_module, tool_signatures_for_prompt
410
+ )
327
411
  except ValueError as e:
328
412
  print(f"❌ Error running tool chain for story '{story}': {e}")
329
413
  continue
@@ -340,7 +424,7 @@ if __name__ == "__main__":
340
424
 
341
425
  stories = []
342
426
  agent_name = None
343
- with stories_path.open("r", encoding="utf-8", newline='') as f:
427
+ with stories_path.open("r", encoding="utf-8", newline="") as f:
344
428
  csv_reader = csv.DictReader(f)
345
429
  for row in csv_reader:
346
430
  stories.append(row["story"])
@@ -349,4 +433,4 @@ if __name__ == "__main__":
349
433
 
350
434
  snapshot_path = stories_path.parent / f"{agent_name}_snapshot_llm.json"
351
435
 
352
- build_snapshot(agent_name, tools_path, stories, snapshot_path)
436
+ build_snapshot(agent_name, tools_path, stories, snapshot_path)
@@ -1,10 +1,7 @@
1
- from typing import Dict, List, Union, Any, Optional
2
- from pydantic import (
3
- BaseModel,
4
- ConfigDict,
5
- Field
6
- )
7
1
  from enum import StrEnum
2
+ from typing import Any, Dict, List, Optional, Union
3
+
4
+ from pydantic import BaseModel, ConfigDict, Field
8
5
  from rich.text import Text
9
6
 
10
7
 
@@ -61,9 +58,13 @@ class ConversationalConfidenceThresholdScore(BaseModel):
61
58
  def table(self):
62
59
  return {
63
60
  "response_confidence": str(self.response_confidence),
64
- "response_confidence_threshold": str(self.response_confidence_threshold),
61
+ "response_confidence_threshold": str(
62
+ self.response_confidence_threshold
63
+ ),
65
64
  "retrieval_confidence": str(self.retrieval_confidence),
66
- "retrieval_confidence_threshold": str(self.retrieval_confidence_threshold),
65
+ "retrieval_confidence_threshold": str(
66
+ self.retrieval_confidence_threshold
67
+ ),
67
68
  }
68
69
 
69
70
 
@@ -120,12 +121,14 @@ class GoalDetail(BaseModel):
120
121
  keywords: List = None
121
122
  knowledge_base: KnowledgeBaseGoalDetail = KnowledgeBaseGoalDetail()
122
123
 
124
+
123
125
  class AttackData(BaseModel):
124
126
  attack_category: AttackCategory
125
127
  attack_type: str
126
128
  attack_name: str
127
129
  attack_instructions: str
128
130
 
131
+
129
132
  class AttackData(BaseModel):
130
133
  agent: str
131
134
  agents_path: str
@@ -143,8 +146,8 @@ class EvaluationData(BaseModel):
143
146
  goal_details: List[GoalDetail]
144
147
  starting_sentence: str = None
145
148
 
149
+
146
150
  class ToolDefinition(BaseModel):
147
151
  tool_description: Optional[str]
148
152
  tool_name: str
149
153
  tool_params: List[str]
150
-