ibm-watsonx-orchestrate-evaluation-framework 1.0.8__py3-none-any.whl → 1.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ibm-watsonx-orchestrate-evaluation-framework has been flagged as potentially problematic.

Files changed (60)
  1. {ibm_watsonx_orchestrate_evaluation_framework-1.0.8.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.0.9.dist-info}/METADATA +103 -109
  2. ibm_watsonx_orchestrate_evaluation_framework-1.0.9.dist-info/RECORD +96 -0
  3. wxo_agentic_evaluation/analytics/tools/main.py +1 -18
  4. wxo_agentic_evaluation/analyze_run.py +358 -97
  5. wxo_agentic_evaluation/arg_configs.py +28 -1
  6. wxo_agentic_evaluation/description_quality_checker.py +149 -0
  7. wxo_agentic_evaluation/evaluation_package.py +58 -17
  8. wxo_agentic_evaluation/inference_backend.py +32 -17
  9. wxo_agentic_evaluation/llm_user.py +2 -1
  10. wxo_agentic_evaluation/metrics/metrics.py +22 -1
  11. wxo_agentic_evaluation/prompt/bad_tool_descriptions_prompt.jinja2 +178 -0
  12. wxo_agentic_evaluation/prompt/llama_user_prompt.jinja2 +9 -1
  13. wxo_agentic_evaluation/prompt/off_policy_attack_generation_prompt.jinja2 +34 -0
  14. wxo_agentic_evaluation/prompt/on_policy_attack_generation_prompt.jinja2 +46 -0
  15. wxo_agentic_evaluation/prompt/template_render.py +34 -3
  16. wxo_agentic_evaluation/quick_eval.py +342 -0
  17. wxo_agentic_evaluation/red_teaming/attack_evaluator.py +113 -0
  18. wxo_agentic_evaluation/red_teaming/attack_generator.py +286 -0
  19. wxo_agentic_evaluation/red_teaming/attack_list.py +96 -0
  20. wxo_agentic_evaluation/red_teaming/attack_runner.py +128 -0
  21. wxo_agentic_evaluation/referenceless_eval/__init__.py +3 -0
  22. wxo_agentic_evaluation/referenceless_eval/function_calling/__init__.py +0 -0
  23. wxo_agentic_evaluation/referenceless_eval/function_calling/consts.py +28 -0
  24. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/__init__.py +0 -0
  25. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/base.py +27 -0
  26. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_call/__init__.py +0 -0
  27. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_call/general.py +49 -0
  28. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_call/general_metrics_runtime.json +580 -0
  29. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_selection/__init__.py +0 -0
  30. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_selection/function_selection.py +31 -0
  31. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_selection/function_selection_metrics_runtime.json +477 -0
  32. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/loader.py +237 -0
  33. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/__init__.py +0 -0
  34. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/adapters.py +101 -0
  35. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/pipeline.py +263 -0
  36. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/semantic_checker.py +455 -0
  37. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/static_checker.py +156 -0
  38. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/transformation_prompts.py +509 -0
  39. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/types.py +547 -0
  40. wxo_agentic_evaluation/referenceless_eval/metrics/__init__.py +3 -0
  41. wxo_agentic_evaluation/referenceless_eval/metrics/field.py +258 -0
  42. wxo_agentic_evaluation/referenceless_eval/metrics/metric.py +333 -0
  43. wxo_agentic_evaluation/referenceless_eval/metrics/metrics_runner.py +188 -0
  44. wxo_agentic_evaluation/referenceless_eval/metrics/prompt.py +409 -0
  45. wxo_agentic_evaluation/referenceless_eval/metrics/utils.py +42 -0
  46. wxo_agentic_evaluation/referenceless_eval/prompt/__init__.py +0 -0
  47. wxo_agentic_evaluation/referenceless_eval/prompt/runner.py +145 -0
  48. wxo_agentic_evaluation/referenceless_eval/referenceless_eval.py +116 -0
  49. wxo_agentic_evaluation/service_instance.py +2 -2
  50. wxo_agentic_evaluation/service_provider/watsonx_provider.py +118 -4
  51. wxo_agentic_evaluation/tool_planner.py +3 -1
  52. wxo_agentic_evaluation/type.py +33 -2
  53. wxo_agentic_evaluation/utils/__init__.py +0 -1
  54. wxo_agentic_evaluation/utils/open_ai_tool_extractor.py +157 -0
  55. wxo_agentic_evaluation/utils/rich_utils.py +174 -0
  56. wxo_agentic_evaluation/utils/rouge_score.py +23 -0
  57. wxo_agentic_evaluation/utils/utils.py +167 -5
  58. ibm_watsonx_orchestrate_evaluation_framework-1.0.8.dist-info/RECORD +0 -56
  59. {ibm_watsonx_orchestrate_evaluation_framework-1.0.8.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.0.9.dist-info}/WHEEL +0 -0
  60. {ibm_watsonx_orchestrate_evaluation_framework-1.0.8.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.0.9.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,145 @@
+import asyncio
+from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, TypeVar, Union
+
+from pydantic import BaseModel
+
+Prompt = Union[str, List[Dict[str, Any]]]
+PromptAndSchema = Tuple[Union[str, List[Dict[str, Any]]], Optional[Dict[str, Any]]]
+SyncGen = Callable[[Prompt], Union[str, Any]]
+BatchGen = Callable[[List[Prompt]], List[Union[str, Any]]]
+AsyncGen = Callable[[Prompt], Awaitable[Union[str, Any]]]
+AsyncBatchGen = Callable[[List[Prompt]], Awaitable[List[Union[str, Any]]]]
+
+T = TypeVar("T")
+
+
+class PromptResult(BaseModel):
+    """
+    Holds the prompt sent and the response (or error).
+    """
+
+    prompt: Prompt
+    response: Optional[Any] = None
+    error: Optional[str] = None
+
+
+class PromptRunner:
+    """
+    Runs a collection of prompts through various generation strategies.
+
+    Attributes:
+        prompts: the list of prompts to run.
+    """
+
+    def __init__(
+        self, prompts: Optional[List[Union[Prompt, PromptAndSchema]]] = None
+    ) -> None:
+        """
+        Args:
+            prompts: initial list of prompts (strings or chat messages).
+        """
+        self.prompts: List[Union[Prompt, PromptAndSchema]] = prompts or []
+
+    def add_prompt(self, prompt: Union[Prompt, PromptAndSchema]) -> None:
+        """Append a prompt to the runner."""
+        self.prompts.append(prompt)
+
+    def remove_prompt(self, prompt: Union[Prompt, PromptAndSchema]) -> None:
+        """Remove a prompt (first occurrence)."""
+        self.prompts.remove(prompt)
+
+    def clear_prompts(self) -> None:
+        """Remove all prompts."""
+        self.prompts.clear()
+
+    def get_prompt_and_schema(
+        self, prompt: Union[Prompt, PromptAndSchema]
+    ) -> Tuple[Prompt, Optional[Dict[str, Any]]]:
+        """
+        Extract the prompt and schema from a Prompt object.
+
+        Args:
+            prompt: The prompt to extract from.
+
+        Returns:
+            Tuple of (prompt, schema).
+        """
+        if isinstance(prompt, tuple):
+            return prompt[0], prompt[1]
+        return prompt, None
+
+    def run_all(
+        self,
+        gen_fn: SyncGen,
+        prompt_param_name: str = "prompt",
+        schema_param_name: Optional[str] = None,
+        **kwargs: Any,
+    ) -> List[PromptResult]:
+        """
+        Run each prompt through a synchronous single-prompt generator.
+
+        Args:
+            gen_fn: Callable taking one Prompt, returning str or Any.
+            prompt_param_name: Name of the parameter for the prompt.
+            schema_param_name: Name of the parameter for the schema.
+            kwargs: Additional arguments to pass to the function.
+
+        Returns:
+            List of PromptResult.
+        """
+        results: List[PromptResult] = []
+        for p in self.prompts:
+            try:
+                prompt, schema = self.get_prompt_and_schema(p)
+                args = {prompt_param_name: prompt, **kwargs}
+                if schema_param_name and schema:
+                    args[schema_param_name] = schema
+                resp = gen_fn(**args)
+                results.append(PromptResult(prompt=prompt, response=resp))
+            except Exception as e:
+                results.append(PromptResult(prompt=prompt, error=str(e)))
+        return results
+
+    async def run_async(
+        self,
+        async_fn: AsyncGen,
+        max_parallel: int = 10,
+        prompt_param_name: str = "prompt",
+        schema_param_name: Optional[str] = None,
+        **kwargs: Any,
+    ) -> List[PromptResult]:
+        """
+        Run each prompt through an async single-prompt generator with concurrency limit.
+        Results are returned in the same order as self.prompts.
+
+        Args:
+            async_fn: Async callable taking one Prompt, returning str or Any.
+            max_parallel: Max concurrent tasks.
+            prompt_param_name: Name of the parameter for the prompt.
+            schema_param_name: Name of the parameter for the schema.
+            kwargs: Additional arguments to pass to the async function.
+
+        Returns:
+            List of PromptResult.
+        """
+        semaphore = asyncio.Semaphore(max_parallel)
+
+        async def _run_one(index: int, p: Prompt) -> Tuple[int, PromptResult]:
+            async with semaphore:
+                try:
+                    prompt, schema = self.get_prompt_and_schema(p)
+                    args = {prompt_param_name: prompt, **kwargs}
+                    if schema_param_name and schema:
+                        args[schema_param_name] = schema
+                    resp = await async_fn(**args)
+                    return index, PromptResult(prompt=prompt, response=resp)
+                except Exception as e:
+                    return index, PromptResult(prompt=prompt, error=str(e))
+
+        tasks = [
+            asyncio.create_task(_run_one(i, p)) for i, p in enumerate(self.prompts)
+        ]
+        indexed_results = await asyncio.gather(*tasks)
+        # Sort results to match original order
+        indexed_results.sort(key=lambda x: x[0])
+        return [res for _, res in indexed_results]
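
The hunk above adds a prompt runner (by its +145 line count in the file list, this is wxo_agentic_evaluation/referenceless_eval/prompt/runner.py). A minimal usage sketch under that assumption; the echo_generate callable below is a hypothetical stand-in for a real LLM client:

from wxo_agentic_evaluation.referenceless_eval.prompt.runner import PromptRunner

def echo_generate(prompt):
    # hypothetical stand-in for a provider call; returns a canned string
    return f"echo: {prompt}"

runner = PromptRunner(prompts=["Hello", [{"role": "user", "content": "Hi"}]])
results = runner.run_all(gen_fn=echo_generate, prompt_param_name="prompt")
for r in results:
    print(r.response if r.error is None else f"failed: {r.error}")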
@@ -0,0 +1,116 @@
+import json
+import os
+from typing import Any, List, Mapping
+
+import rich
+
+from wxo_agentic_evaluation.referenceless_eval.function_calling.consts import (
+    METRIC_FUNCTION_SELECTION_APPROPRIATENESS,
+    METRIC_GENERAL_HALLUCINATION_CHECK,
+)
+from wxo_agentic_evaluation.referenceless_eval.function_calling.pipeline.pipeline import (
+    ReflectionPipeline,
+)
+from wxo_agentic_evaluation.referenceless_eval.function_calling.pipeline.types import (
+    ToolCall,
+    ToolSpec,
+)
+from wxo_agentic_evaluation.type import Message
+from wxo_agentic_evaluation.service_provider.watsonx_provider import WatsonXLLMKitWrapper
+
+class ReferencelessEvaluation:
+    """
+    Note: static.final_decison, if `True` -> then all static metrics were valid. If false, atleast one of the static metrics failed. Look at explanation for reasoning
+    Note: if static.final_decision == True, check semantic metrics. Semantic metrics **not** run if static.final_decision is False.
+    ---
+    Note: For semantic metrics, check agentic constraints. If agent-constraints == False, no point in checking others. If true, check others.
+    Note: METRIC_FUNCTION_SELECTION_APPROPRIATENESS == False, implies that the LLM should have called some other function/tool before *OR* it is a redundant call.
+    Note: When parsing the semantic metrics, check for `is_correct` field. if `false` there is some mistake that the LLMaJ found in that tool call.
+    """
+    def __init__(
+        self,
+        api_spec: List[Mapping[str, Any]],
+        messages: List[Message],
+        model_id: str,
+        task_n: str,
+        dataset_name: str,
+    ):
+        self.metrics_client = WatsonXLLMKitWrapper(
+            model_id=model_id,
+            api_key=os.getenv("WATSONX_APIKEY", ""),
+            space_id=os.getenv("WATSONX_SPACE_ID")
+        )
+
+        self.pipeline = ReflectionPipeline(
+            metrics_client=self.metrics_client,
+            general_metrics=[METRIC_GENERAL_HALLUCINATION_CHECK],
+            function_metrics=[METRIC_FUNCTION_SELECTION_APPROPRIATENESS],
+            parameter_metrics=None,
+        )
+
+        self.task_n = task_n
+        self.dataset_name = dataset_name
+
+        self.apis_specs = [ToolSpec.model_validate(spec) for spec in api_spec]
+        self.messages = messages
+
+    def _run_pipeline(self, examples: List[Mapping[str, Any]]):
+        results = []
+        for example in examples:
+            # self.pipeline.sy
+            result = self.pipeline.run_sync(
+                conversation=example["context"],
+                inventory=self.apis_specs,
+                call=example["call"],
+                continue_on_static=False,
+                retries=2,
+            )
+            result_dict = result.model_dump()
+            results.append(result_dict)
+
+        return results
+
+    def run(self):
+        examples = []
+
+        processed_data = [
+            {k: msg.model_dump().get(k) for k in ["role", "content", "type"] if k in msg.model_dump()}
+            for msg in self.messages
+        ]
+
+        for idx, message in enumerate(processed_data):
+            role = message["role"]
+            content = message["content"]
+            context = processed_data[:idx]
+
+            if role == "assistant" and message["type"] == "tool_call":
+                tool_call_msg = json.loads(content)
+                if tool_call_msg["name"].startswith("transfer_to"):
+                    continue
+
+                call = {
+                    "call": {
+                        "id": tool_call_msg.get("id", "1"),
+                        "type": "function",
+                        "function": {
+                            "name": tool_call_msg["name"],
+                            "arguments": json.dumps(tool_call_msg["args"]),
+                        },
+                    },
+                    "context": context,
+                }
+                examples.append(call)
+
+        rich.print(
+            f"[yellow][b][Task-{self.task_n}] There are {len(examples)} examples to analyze"
+        )
+        examples = [
+            {
+                "call": ToolCall.model_validate(ex["call"]),
+                "context": ex["context"],
+            }
+            for ex in examples
+        ]
+        results = self._run_pipeline(examples)
+
+        return results
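
Per the docstring notes above, consumers of run() first check static.final_decision and only then inspect the semantic metrics. A hedged sketch of that reading order; the exact result-dict keys come from ReflectionPipeline's model_dump() and are assumptions here, as are the tool_specs and conversation inputs:

evaluation = ReferencelessEvaluation(
    api_spec=tool_specs,      # list of OpenAI-style tool specs (assumed available)
    messages=conversation,    # list of Message objects (assumed available)
    model_id="<model-id>",
    task_n="1",
    dataset_name="demo",
)
for result in evaluation.run():
    static = result.get("static", {})
    if not static.get("final_decision"):
        # at least one static check failed; read its explanation and stop here
        continue
    for metric_name, metric in result.get("semantic", {}).items():
        if metric.get("is_correct") is False:
            print(f"LLM-as-judge flagged {metric_name}")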
@@ -49,10 +49,10 @@ class ServiceInstance:
     def get_user_token(self):
         try:
             if self.is_saas:
-                apikey = os.environ.get("WATSONX_IAM_SAAS_APIKEY")
+                apikey = os.environ.get("WO_API_KEY")
                 if not apikey:
                     raise RuntimeError(
-                        "WATSONX_IAM_SAAS_APIKEY not set in environment for SaaS mode"
+                        "WO_API_KEY not set in environment for SaaS mode"
                     )
             if self.is_ibm_cloud:
                 data = {
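
This hunk swaps the SaaS credential variable: get_user_token() now reads WO_API_KEY rather than WATSONX_IAM_SAAS_APIKEY. A minimal sketch of the environment a SaaS run now expects (placeholder value):

import os

# 1.0.9 expects WO_API_KEY in SaaS mode; WATSONX_IAM_SAAS_APIKEY is no longer read here
os.environ["WO_API_KEY"] = "<watsonx-orchestrate-api-key>"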
@@ -2,10 +2,12 @@ import os
 import requests
 import json
 from types import MappingProxyType
-from typing import List
+from typing import List, Mapping, Union, Optional, Any
+from functools import singledispatchmethod
 import dataclasses
 from threading import Lock
 import time
+import rich
 from wxo_agentic_evaluation.service_provider.provider import Provider
 
 ACCESS_URL = "https://iam.cloud.ibm.com/identity/token"
@@ -88,7 +90,12 @@ class WatsonXProvider(Provider):
                    "Content-Type": "application/json"}
         return headers
 
-    def generate(self, sentence: str):
+    @singledispatchmethod
+    def generate(self, sentence):
+        raise ValueError(f"Input must either be a string or a list of dictionaries")
+
+    @generate.register
+    def _(self, sentence: str):
         headers = self.prepare_header()
 
         data = {"model_id": self.model_id, "input": sentence,
@@ -100,6 +107,22 @@ class WatsonXProvider(Provider):
         else:
             resp.raise_for_status()
 
+    @generate.register
+    def _(self, sentence: list):
+        chat_url = f"{self.api_endpoint}/ml/v1/text/chat?version=2023-05-02"
+        headers = self.prepare_header()
+        data = {
+            "model_id": self.model_id,
+            "messages": sentence,
+            "parameters": self.params,
+            "space_id": self.space_id
+        }
+        resp = requests.post(url=chat_url, headers=headers, json=data)
+        if resp.status_code == 200:
+            return resp.json()
+        else:
+            resp.raise_for_status()
+
     def _refresh_token(self):
         # if we do not have a token or the current timestamp is 9 minutes away from expire.
         if not self.access_token or time.time() > self.refresh_time:
@@ -107,11 +130,18 @@ class WatsonXProvider(Provider):
         if not self.access_token or time.time() > self.refresh_time:
             self.access_token, self.refresh_time = self._get_access_token()
 
-    def query(self, sentence: str) -> str:
+    def query(self, sentence: Union[str, Mapping[str, str]]) -> str:
         if self.model_id is None:
             raise Exception("model id must be specified for text generation")
         try:
-            return self.generate(sentence)["generated_text"]
+            response = self.generate(sentence)
+            if (generated_text := response.get("generated_text")):
+                return generated_text
+            elif (message := response.get("message")):
+                return message
+            else:
+                raise ValueError(f"Unexpected response from WatsonX: {response}")
+
         except Exception as e:
             with self.lock:
                 if "authentication_token_expired" in str(e):
@@ -135,6 +165,90 @@ class WatsonXProvider(Provider):
         else:
             resp.raise_for_status()
 
+class LLMResponse:
+    """
+    NOTE: Taken from LLM-Eval-Kit
+    Response object that can contain both content and tool calls
+    """
+
+    def __init__(self, content: str, tool_calls: Optional[List[Mapping[str, Any]]] = None):
+        self.content = content
+        self.tool_calls = tool_calls or []
+
+    def __str__(self) -> str:
+        """Return the content of the response as a string."""
+        return self.content
+
+    def __repr__(self) -> str:
+        """Return a string representation of the LLMResponse object."""
+        return f"LLMResponse(content='{self.content}', tool_calls={self.tool_calls})"
+
+class WatsonXLLMKitWrapper(WatsonXProvider):
+    def generate(
+        self,
+        prompt: Union[str, List[Mapping[str, str]]],
+        *,
+        schema,
+        retries: int = 3,
+        generation_args: Optional[Any] = None,
+        **kwargs: Any
+    ):
+
+        """
+        In future, implement validation of response like in llmevalkit
+        """
+
+        for attempt in range(1, retries + 1):
+            try:
+                raw_response = super().generate(prompt)
+                response = self._parse_llm_response(raw_response)
+                return response
+            except Exception as e:
+                rich.print(f"[b][r] WatsonX generation failed with error '{str(e)}' during `quick-eval` ... Attempt ({attempt} / {retries}))")
+
+    def _parse_llm_response(self, raw: Any) -> Union[str, LLMResponse]:
+        """
+        Extract the generated text and tool calls from a watsonx response.
+
+        - For text generation: raw['results'][0]['generated_text']
+        - For chat: raw['choices'][0]['message']['content']
+        """
+        content = ""
+        tool_calls = []
+
+        if isinstance(raw, dict) and "choices" in raw:
+            choices = raw["choices"]
+            if isinstance(choices, list) and choices:
+                first = choices[0]
+                msg = first.get("message")
+                if isinstance(msg, dict):
+                    content = msg.get("content", "")
+                    # Extract tool calls if present
+                    if "tool_calls" in msg and msg["tool_calls"]:
+                        tool_calls = []
+                        for tool_call in msg["tool_calls"]:
+                            tool_call_dict = {
+                                "id": tool_call.get("id"),
+                                "type": tool_call.get("type", "function"),
+                                "function": {
+                                    "name": tool_call.get("function", {}).get("name"),
+                                    "arguments": tool_call.get("function", {}).get(
+                                        "arguments"
+                                    ),
+                                },
+                            }
+                            tool_calls.append(tool_call_dict)
+                elif "text" in first:
+                    content = first["text"]
+
+        if not content and not tool_calls:
+            raise ValueError(f"Unexpected watsonx response format: {raw!r}")
+
+        # Return LLMResponse if tool calls exist, otherwise just content
+        if tool_calls:
+            return LLMResponse(content=content, tool_calls=tool_calls)
+
+        return content
 
 if __name__ == "__main__":
     provider = WatsonXProvider(model_id="meta-llama/llama-3-2-90b-vision-instruct")
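
WatsonXLLMKitWrapper._parse_llm_response normalizes the two watsonx response shapes into either a plain string or an LLMResponse. An illustrative sketch of the payloads it distinguishes (field names taken from the hunk above; no provider is constructed here since that requires credentials):

chat_payload = {
    "choices": [{
        "message": {
            "content": "",
            "tool_calls": [{
                "id": "call_1",
                "type": "function",
                "function": {"name": "get_weather", "arguments": "{\"city\": \"Austin\"}"},
            }],
        }
    }]
}
text_payload = {"choices": [{"text": "Hello from watsonx"}]}

# _parse_llm_response(chat_payload) -> LLMResponse with one tool call ("get_weather")
# _parse_llm_response(text_payload) -> "Hello from watsonx"
# a payload with neither content nor tool calls raises ValueError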
@@ -19,6 +19,8 @@ root_dir = os.path.dirname(__file__)
 TOOL_PLANNER_PROMPT_PATH = os.path.join(root_dir, "prompt", "tool_planner.jinja2")
 ARGS_EXTRACTOR_PROMPT_PATH = os.path.join(root_dir, "prompt", "args_extractor_prompt.jinja2")
 
+MISSING_DOCSTRING_PROMPT = "No description available"
+
 class UniversalEncoder(json.JSONEncoder):
     def default(self, obj):
         if is_dataclass(obj):
@@ -131,7 +133,7 @@ def extract_tool_signatures(tools_path: Path) -> list:
                 tool_data.append({
                     "Function Name": name,
                     "Arguments": args,
-                    "Docstring": docstring or "No description available"
+                    "Docstring": docstring or MISSING_DOCSTRING_PROMPT
                 })
     except Exception as e:
         print(f"Warning: Failed to parse {file_path}: {str(e)}")
@@ -1,6 +1,11 @@
 from typing import Dict, List, Union, Any, Optional
-from pydantic import BaseModel, computed_field, ConfigDict
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field
+)
 from enum import StrEnum
+from rich.text import Text
 
 
 class EventTypes(StrEnum):
@@ -20,6 +25,11 @@ class ContentType(StrEnum):
     conversational_search = "conversational_search"
 
 
+class AttackCategory(StrEnum):
+    on_policy = "on_policy"
+    off_policy = "off_policy"
+
+
 class ConversationalSearchCitations(BaseModel):
     url: str
     body: str
@@ -93,7 +103,7 @@ class Message(BaseModel):
 
 class ExtendedMessage(BaseModel):
     message: Message
-    reason: dict | None = None
+    reason: dict | list | None = None
 
 
 class KnowledgeBaseGoalDetail(BaseModel):
@@ -110,6 +120,21 @@ class GoalDetail(BaseModel):
     keywords: List = None
     knowledge_base: KnowledgeBaseGoalDetail = KnowledgeBaseGoalDetail()
 
+class AttackData(BaseModel):
+    attack_category: AttackCategory
+    attack_type: str
+    attack_name: str
+    attack_instructions: str
+
+class AttackData(BaseModel):
+    agent: str
+    agents_path: str
+    attack_data: AttackData
+    story: str
+    starting_sentence: str
+    goals: Dict = None
+    goal_details: List[GoalDetail] = None
+
 
 
 class EvaluationData(BaseModel):
@@ -117,3 +142,9 @@ class EvaluationData(BaseModel):
     story: str
     goal_details: List[GoalDetail]
     starting_sentence: str = None
+
+class ToolDefinition(BaseModel):
+    tool_description: Optional[str]
+    tool_name: str
+    tool_params: List[str]
+
@@ -1,6 +1,5 @@
 import json
 
-
 def json_dump(output_path, object):
     with open(output_path, "w", encoding="utf-8") as f:
         json.dump(object, f, indent=4)
@@ -0,0 +1,157 @@
+import ast
+import re
+from pathlib import Path
+from typing import Union, Mapping, Any, List
+
+class PythonTypeToJsonType:
+    OPTIONAL_PARAM_EXTRACT = re.compile(r"[Oo]ptional\[(\w+)\]")
+
+    @staticmethod
+    def python_to_json_type(python_annotation: str):
+        if not python_annotation:
+            return "string"
+        python_annotation = python_annotation.lower().strip()
+        if "str" == python_annotation:
+            return "string"
+        if "int" == python_annotation:
+            return "integer"
+        if "float" == python_annotation:
+            return "number"
+        if "bool" == python_annotation:
+            return "boolean"
+        if python_annotation.startswith("list"):
+            return "array"
+        if python_annotation.startswith("dict"):
+            return "object"
+        if python_annotation.startswith("optional"):
+            # extract the type within Optional[T]
+            inner_type = PythonTypeToJsonType.OPTIONAL_PARAM_EXTRACT.search(python_annotation).group(1)
+            return PythonTypeToJsonType.python_to_json_type(inner_type)
+
+        return "string"
+
+class ToolExtractionOpenAIFormat:
+    @staticmethod
+    def get_default_arguments(node):
+        """ Returns the default arguments (if any)
+
+        The default arguments are stored in args.default array.
+        Since, in Python, the default arguments only come after positional arguments,
+        we can index the argument array starting from the last `n` arguments, where n is
+        the length of the default arguments.
+
+        ex.
+        def add(a, b=5):
+            pass
+
+        Then we have,
+        args = [a, b]
+        defaults = [Constant(value=5)]
+
+        args[-len(defaults):] = [b]
+
+        (
+        "FunctionDef(
+            name='add',
+            args=arguments(
+                posonlyargs=[],
+                args=[
+                    arg(arg='a'), "
+                    "arg(arg='b')
+                ],
+                kwonlyargs=[],
+                kw_defaults=[],
+                defaults=[Constant(value=5)]), "
+        "body=[Return(value=BinOp(left=Name(id='a', ctx=Load()), op=Add(), "
+        "right=Name(id='b', ctx=Load())))], decorator_list=[], type_params=[])")
+        """
+        default_arguments = set()
+        num_defaults = len(node.args.defaults)
+        if num_defaults > 0:
+            for arg in node.args.args[-num_defaults:]:
+                default_arguments.add(arg)
+
+        return default_arguments
+
+    @staticmethod
+    def from_file(tools_path: Union[str, Path]) -> Mapping[str, Any]:
+        """ Uses `extract_tool_signatures` function, but converts the response
+        to open-ai format
+
+        ```
+        function_spec = {
+            "type": "function",
+            "function": {
+                "name": func_name,
+                "description": description,
+                "parameters": parameters,
+            },
+        }
+        ```
+
+        """
+        tool_data = []
+        tools_path = Path(tools_path)
+
+        with tools_path.open("r", encoding="utf-8") as f:
+            code = f.read()
+
+        try:
+            parsed_code = ast.parse(code)
+            for node in parsed_code.body:
+                if isinstance(node, ast.FunctionDef):
+                    parameters = {"type": "object", "properties": {}, "required": []}
+                    function_name = node.name
+                    for arg in node.args.args:
+                        type_annotation = None
+                        if arg.arg == "self":
+                            continue
+                        if arg.annotation:
+                            type_annotation = ast.unparse(arg.annotation)
+
+                        parameter_type = PythonTypeToJsonType.python_to_json_type(type_annotation)
+                        parameters["properties"][arg.arg] = {
+                            "type": parameter_type,
+                            "description": "", # todo
+                        }
+
+                        if type_annotation and "Optional" not in type_annotation:
+                            parameters["required"].append(arg.arg)
+
+                    default_arguments = ToolExtractionOpenAIFormat.get_default_arguments(node)
+                    for arg_name in parameters["required"]:
+                        if arg_name in default_arguments:
+                            parameters.remove(arg_name)
+
+                    open_ai_format_fn = {
+                        "type": "function",
+                        "function": {
+                            "name": function_name,
+                            "parameters": parameters,
+                            "description": ast.get_docstring(node) # fix (does not do :params)
+                        }
+                    }
+                    tool_data.append(open_ai_format_fn)
+
+        except Exception as e:
+            print(f"Warning: Failed to parse {tools_path}: {str(e)}")
+
+        return tool_data
+
+    @staticmethod
+    def from_path(tools_path: Union[str, Path]) -> List[Mapping[str, Any]]:
+        tools_path = Path(tools_path)
+        files_to_parse = []
+        all_tools = []
+
+        if tools_path.is_file():
+            files_to_parse.append(tools_path)
+        elif tools_path.is_dir():
+            files_to_parse.extend(tools_path.glob("**/*.py"))
+        else:
+            raise ValueError(f"Tools path {tools_path} is neither a file nor directory")
+
+        for file_path in files_to_parse:
+            all_tools.extend(ToolExtractionOpenAIFormat.from_file(file_path))
+
+        return all_tools
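
ToolExtractionOpenAIFormat turns annotated Python tool functions into OpenAI-style function specs, using PythonTypeToJsonType for the type mapping. A hedged usage sketch (the sample file and its contents are hypothetical; the module path follows the file list above):

from pathlib import Path
from wxo_agentic_evaluation.utils.open_ai_tool_extractor import ToolExtractionOpenAIFormat

sample = Path("sample_tool.py")
sample.write_text(
    "from typing import Optional\n"
    "def add(a: int, b: Optional[int] = 5):\n"
    '    """Add two numbers."""\n'
    "    return a + (b or 0)\n"
)

specs = ToolExtractionOpenAIFormat.from_file(sample)
# roughly: name "add", parameter "a" mapped to "integer" and listed as required,
# parameter "b" mapped to "integer" but left optional (Optional[...] annotation)
print(specs[0]["function"]["name"], specs[0]["function"]["parameters"]["required"])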