ibm-watsonx-orchestrate-evaluation-framework 1.1.6__py3-none-any.whl → 1.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ibm-watsonx-orchestrate-evaluation-framework has been flagged as potentially problematic.

Files changed (42)
  1. {ibm_watsonx_orchestrate_evaluation_framework-1.1.6.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.7.dist-info}/METADATA +4 -1
  2. {ibm_watsonx_orchestrate_evaluation_framework-1.1.6.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.7.dist-info}/RECORD +42 -36
  3. wxo_agentic_evaluation/analyze_run.py +49 -32
  4. wxo_agentic_evaluation/arg_configs.py +30 -2
  5. wxo_agentic_evaluation/data_annotator.py +22 -4
  6. wxo_agentic_evaluation/description_quality_checker.py +20 -4
  7. wxo_agentic_evaluation/evaluation_package.py +189 -15
  8. wxo_agentic_evaluation/external_agent/external_validate.py +3 -1
  9. wxo_agentic_evaluation/external_agent/types.py +1 -1
  10. wxo_agentic_evaluation/inference_backend.py +64 -34
  11. wxo_agentic_evaluation/llm_matching.py +92 -2
  12. wxo_agentic_evaluation/llm_user.py +2 -2
  13. wxo_agentic_evaluation/main.py +147 -38
  14. wxo_agentic_evaluation/metrics/__init__.py +5 -1
  15. wxo_agentic_evaluation/metrics/evaluations.py +124 -0
  16. wxo_agentic_evaluation/metrics/metrics.py +24 -3
  17. wxo_agentic_evaluation/prompt/llmaaj_prompt.jinja2 +15 -0
  18. wxo_agentic_evaluation/prompt/template_render.py +16 -0
  19. wxo_agentic_evaluation/quick_eval.py +17 -3
  20. wxo_agentic_evaluation/record_chat.py +17 -6
  21. wxo_agentic_evaluation/red_teaming/attack_evaluator.py +44 -14
  22. wxo_agentic_evaluation/red_teaming/attack_generator.py +31 -12
  23. wxo_agentic_evaluation/red_teaming/attack_list.py +23 -24
  24. wxo_agentic_evaluation/red_teaming/attack_runner.py +36 -19
  25. wxo_agentic_evaluation/referenceless_eval/referenceless_eval.py +42 -16
  26. wxo_agentic_evaluation/service_instance.py +5 -3
  27. wxo_agentic_evaluation/service_provider/__init__.py +129 -9
  28. wxo_agentic_evaluation/service_provider/gateway_provider.py +707 -0
  29. wxo_agentic_evaluation/service_provider/model_proxy_provider.py +415 -17
  30. wxo_agentic_evaluation/service_provider/ollama_provider.py +393 -22
  31. wxo_agentic_evaluation/service_provider/provider.py +130 -10
  32. wxo_agentic_evaluation/service_provider/referenceless_provider_wrapper.py +52 -0
  33. wxo_agentic_evaluation/service_provider/watsonx_provider.py +480 -52
  34. wxo_agentic_evaluation/type.py +14 -4
  35. wxo_agentic_evaluation/utils/__init__.py +43 -5
  36. wxo_agentic_evaluation/utils/evaluation_discovery.py +47 -0
  37. wxo_agentic_evaluation/utils/gateway_provider_utils.py +39 -0
  38. wxo_agentic_evaluation/utils/messages_parser.py +30 -0
  39. wxo_agentic_evaluation/utils/utils.py +14 -9
  40. wxo_agentic_evaluation/wxo_client.py +2 -1
  41. {ibm_watsonx_orchestrate_evaluation_framework-1.1.6.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.7.dist-info}/WHEEL +0 -0
  42. {ibm_watsonx_orchestrate_evaluation_framework-1.1.6.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.7.dist-info}/top_level.txt +0 -0
@@ -1,19 +1,27 @@
  import json
  import os
- from typing import List
+ from gc import enable
+ from typing import Any, Dict, List, Optional

  import rich
+ from dateutil import parser

  from wxo_agentic_evaluation import __file__
  from wxo_agentic_evaluation.data_annotator import ERROR_KEYWORDS
  from wxo_agentic_evaluation.llm_matching import LLMMatcher
  from wxo_agentic_evaluation.llm_rag_eval import LLMJudge
  from wxo_agentic_evaluation.llm_safety_eval import LLMSafetyJudge
+ from wxo_agentic_evaluation.metrics.evaluations import (
+     Evaluation,
+     Extractor,
+     Metric,
+ )
  from wxo_agentic_evaluation.metrics.llm_as_judge import (
      AnswerDerailment,
      AnswerUnsafeTopic,
  )
  from wxo_agentic_evaluation.metrics.metrics import (
+     CustomEvalMetrics,
      KeywordSemanticSearchMetric,
      KnowledgeBaseMetrics,
      TextMatchType,
@@ -28,7 +36,12 @@ from wxo_agentic_evaluation.prompt.template_render import (
      UnsafeTopicTemplateRenderer,
  )
  from wxo_agentic_evaluation.resource_map import ResourceMap
- from wxo_agentic_evaluation.service_provider import get_provider
+ from wxo_agentic_evaluation.service_instance import tenant_setup
+ from wxo_agentic_evaluation.service_provider import (
+     USE_GATEWAY_MODEL_PROVIDER,
+     get_provider,
+ )
+ from wxo_agentic_evaluation.service_provider.provider import Provider
  from wxo_agentic_evaluation.type import (
      ContentType,
      ConversationalSearch,
@@ -76,12 +89,18 @@ DUMMY_GRAPH_NODE_NAME = "dummy-goal"
  class EvaluationPackage:
      def __init__(
          self,
-         test_case_name,
+         test_case_name: str,
          ground_truth: EvaluationData,
-         messages,
+         messages: list[Message],
          conversational_search_data: List[ConversationalSearch] = None,
          resource_map: ResourceMap = None,
          is_attack_evaluation: bool = False,
+         config=None,
+         custom_evals: Optional[list[Evaluation]] = None,
+         custom_llmaaj_client: Optional[Provider] = None,
+         extractors: Optional[list[Extractor]] = None,
+         similarity_threshold=0.8,
+         enable_fuzzy_matching=False,
      ):
          self.tool_dictionary = (
              {
@@ -109,10 +128,49 @@ class EvaluationPackage:
          self.ground_truth = ground_truth
          self.test_case_name = test_case_name
          self.resource_map = resource_map
+         self.custom_evals = custom_evals
+         self.custom_llmaaj_client = custom_llmaaj_client
+         self.extractors = extractors
+         self.enable_fuzzy_matching = enable_fuzzy_matching

          if not self.is_attack_evaluation:
              self.validate_ground_truth(self.ground_truth, self.test_case_name)

+         extra_kwargs = {}
+
+         if USE_GATEWAY_MODEL_PROVIDER:
+
+             if resource_map and hasattr(resource_map, "wxo_client"):
+                 wxo_client = resource_map.wxo_client
+
+                 if hasattr(wxo_client, "service_url"):
+                     extra_kwargs["instance_url"] = wxo_client.service_url
+
+                 if hasattr(wxo_client, "api_key"):
+                     extra_kwargs["token"] = wxo_client.api_key
+
+             elif config:
+                 auth = getattr(config, "auth_config", None)
+
+                 if auth:
+                     instance_url = getattr(auth, "url", None)
+                     token = getattr(auth, "token", None)
+
+                     if instance_url:
+                         extra_kwargs["instance_url"] = instance_url
+
+                     if token:
+                         extra_kwargs["token"] = token
+             else:
+                 token, instance_url, env = tenant_setup(
+                     service_url=None, tenant_name="local"
+                 )
+                 if instance_url:
+                     extra_kwargs["instance_url"] = instance_url
+
+                 if token:
+                     extra_kwargs["token"] = token
+
          # output response matching
          self.matcher = LLMMatcher(
              llm_client=get_provider(
@@ -122,6 +180,8 @@ class EvaluationPackage:
                      "decoding_method": "greedy",
                      "max_new_tokens": 10,
                  },
+                 embedding_model_id="sentence-transformers/all-minilm-l6-v2",
+                 **extra_kwargs,
              ),
              keyword_template=KeywordMatchingTemplateRenderer(
                  KEYWORD_MATCHING_PROMPT_PATH
@@ -129,6 +189,8 @@ class EvaluationPackage:
              semantic_template=SemanticMatchingTemplateRenderer(
                  SEMANTIC_MATCHING_PROMPT_PATH
              ),
+             similarity_threshold=similarity_threshold,
+             enable_fuzzy_matching=enable_fuzzy_matching,
          )
          # only used for RAG evaluation
          self.rag_llm_as_a_judge = LLMJudge(
@@ -139,6 +201,7 @@
                      "decoding_method": "greedy",
                      "max_new_tokens": 4096,
                  },
+                 **extra_kwargs,
              ),
              faithfulness=FaithfulnessTemplateRenderer(FAITHFULNESS_PROMPT_PATH),
              answer_relevancy=AnswerRelevancyTemplateRenderer(
@@ -153,6 +216,7 @@
                      "decoding_method": "greedy",
                      "max_new_tokens": 4096,
                  },
+                 **extra_kwargs,
              ),
              answer_derailment=DerailmentTemplateRenderer(
                  DERAILMENT_PROMPT_PATH
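
The **extra_kwargs threaded into each get_provider call above carries the gateway credentials resolved earlier in the constructor: the resource map's wxo_client is preferred, then config.auth_config, then a local tenant_setup call. Below is a minimal standalone sketch of that precedence with a hypothetical helper name, omitting the tenant_setup fallback and assuming only the attribute names visible in the diff.

# Hypothetical helper illustrating the credential-resolution order; not part of the package.
def resolve_gateway_credentials(resource_map=None, config=None) -> dict:
    creds = {}
    if resource_map is not None and hasattr(resource_map, "wxo_client"):
        client = resource_map.wxo_client
        creds["instance_url"] = getattr(client, "service_url", None)
        creds["token"] = getattr(client, "api_key", None)
    elif config is not None and getattr(config, "auth_config", None):
        auth = config.auth_config
        creds["instance_url"] = getattr(auth, "url", None)
        creds["token"] = getattr(auth, "token", None)
    # else: the diff falls back to tenant_setup(service_url=None, tenant_name="local")
    return {k: v for k, v in creds.items() if v is not None}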
@@ -305,8 +369,48 @@ class EvaluationPackage:
          return str(data).lower()

      @staticmethod
+     def _compare_as_date_or_number(normalized_actual, normalized_expected):
+         """
+         Attempts to compare two normalized values as dates or numbers.
+
+         Args:
+             normalized_actual: The actual value from tool call
+             normalized_expected: The expected value from ground truth
+
+         Returns:
+             tuple: (conversion_succeeded, values_match)
+                 - conversion_succeeded: True if values could be converted to numbers or dates
+                 - values_match: True if converted values match
+         """
+         # Try to convert to numbers
+         try:
+             num_actual = float(normalized_actual)
+             num_expected = float(normalized_expected)
+             # Conversion succeeded, check if values match
+             return (
+                 True,
+                 abs(num_actual - num_expected) <= 0.001,
+             )  # Small epsilon for float comparison
+         except (ValueError, TypeError):
+             pass
+
+         # Try to convert to dates
+         try:
+             date_actual = parser.parse(normalized_actual)
+             date_expected = parser.parse(normalized_expected)
+             # Conversion succeeded, check if values match
+             return True, date_actual == date_expected
+         except (ValueError, TypeError):
+             pass
+
+         # If we get here, neither number nor date conversion worked
+         return False, False
+
      def _check_if_args_match_with_ignore(
-         actual_args: dict[str, str], expected_args: dict[str, str]
+         self,
+         actual_args: dict[str, str],
+         expected_args: dict[str, str],
+         enable_fuzzy_matching: bool = False,
      ) -> bool:
          """
          This function checks if a registered tool call matches with the goal node when:
@@ -315,21 +419,50 @@ class EvaluationPackage:
          actual_args (dict): Made during inference.
          expected_args (dict): Defined in the test case/ground truth.
          Returns:
-             bool: True if match with keyword parameters ignored | False otherwise (improper tool call).
+             bool: True if match with keyword parameters ignored | False otherwise (arguments were not corrected).
          """
-
          if set(actual_args.keys()) != set(expected_args.keys()):
              return False

+         ## now we go through and check each parameter
          for key in actual_args:
+             normalized_actual = EvaluationPackage.normalize_args(
+                 actual_args[key]
+             )
+             normalized_expected = EvaluationPackage.normalize_args(
+                 expected_args[key]
+             )
+
+             # 1. If the args are an ignored keyword or exactly equal, continue to next parameter
              if (
-                 EvaluationPackage.normalize_args(actual_args[key])
-                 != EvaluationPackage.normalize_args(expected_args[key])
-                 and EvaluationPackage.normalize_args(expected_args[key])
-                 != RESERVED_KEYWORD_FOR_GROUND_TRUTH_ARGS
-             ):
-                 return False
+                 normalized_expected == RESERVED_KEYWORD_FOR_GROUND_TRUTH_ARGS
+             ) or (normalized_actual == normalized_expected):
+                 continue
+             else:
+                 # if they're not equal, and fuzzy matching is enabled, do fuzzy.
+                 if enable_fuzzy_matching:
+                     # 3. Check date/number conversion
+                     conversion_succeeded, values_match = (
+                         EvaluationPackage._compare_as_date_or_number(
+                             normalized_actual, normalized_expected
+                         )
+                     )
+                     # If conversion succeeded and values match, continue to next parameter
+                     if conversion_succeeded and values_match:
+                         continue
+                     # If conversion succeeded but values don't match, return False
+                     if conversion_succeeded and not values_match:
+                         return False
+                     # 4. If conversion failed, try cosine matching. If this fails, return false for the function
+                     if not self.matcher.cosine_similarity_semantic_match(
+                         normalized_actual, normalized_expected
+                     ):
+                         return False
+                 else:
+                     # If they're not equal and fuzzy matching is not enabled, return false
+                     return False

+         # If we've made it through all parameters without returning False, return True
          return True

      def traverse(self):
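
In effect, the rewritten argument check now tolerates formatting differences when enable_fuzzy_matching is on: values are compared as numbers first (within 0.001), then as dates via dateutil, and only then via embedding cosine similarity. A minimal standalone sketch of the number/date part of that comparison, using an illustrative function name rather than the package's API:

# Illustrative only; the package's version is EvaluationPackage._compare_as_date_or_number.
from dateutil import parser

def values_equivalent(actual: str, expected: str) -> bool:
    # Numeric comparison with a small epsilon, mirroring the diff.
    try:
        return abs(float(actual) - float(expected)) <= 0.001
    except (ValueError, TypeError):
        pass
    # Date comparison via dateutil parsing.
    try:
        return parser.parse(actual) == parser.parse(expected)
    except (ValueError, TypeError):
        return False

print(values_equivalent("3.0", "3"))                   # True
print(values_equivalent("Jan 5, 2024", "2024-01-05"))  # True
print(values_equivalent("red", "blue"))                # False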
@@ -401,8 +534,10 @@ class EvaluationPackage:
                          goal_detail.args
                      )
                      or self._check_if_args_match_with_ignore(
-                         msg_tool_call["args"], goal_detail.args
-                     )
+                         msg_tool_call["args"],
+                         goal_detail.args,
+                         enable_fuzzy_matching=self.enable_fuzzy_matching,
+                     )  # TODO arjun-gupta1 9/29/25: make this also return the method of matching (llm, fuzzy, cosine similarity) so we can write it out to analyze_run.py results
                  ):
                      labelled_messages.append(goal_detail.name)
                      labelled_messages_without_text_step.append(
@@ -484,6 +619,7 @@ class EvaluationPackage:
                      self.messages[0].content,
                      prediction=message.content,
                      ground_truth=goal_detail.response,
+                     enable_fuzzy_matching=self.enable_fuzzy_matching,
                  )
                  keyword_semantic_match = KeywordSemanticSearchMetric(
                      keyword_match=keyword_match,
@@ -518,6 +654,29 @@ class EvaluationPackage:
          else:
              return TextMatchType.text_mismatch.value

+     def generate_custom_metrics(
+         self, extracted_context: Dict[str, Any]
+     ) -> Optional[CustomEvalMetrics]:
+         if self.custom_evals is None:
+             return None
+
+         results: list[Metric] = []
+         for evaluation in self.custom_evals:
+             # TODO: cleanup. The compute method returns a Metric but pydantic thinks it is different.
+             # Probably because of some path issue when we auto-discover metrics
+             evaluate_result = evaluation.evaluate(
+                 messages=self.messages,
+                 ground_truth=self.ground_truth,
+                 extracted_context=extracted_context,
+             )
+             if evaluate_result is not None:
+                 results.append(Metric(**evaluate_result.model_dump()))
+
+         custom_eval_results = CustomEvalMetrics(
+             dataset_name=self.test_case_name, custom_metrics=results
+         )
+         return custom_eval_results
+
      def generate_summary(self):
          llm_steps = 0
          total_step = 0
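
The new generate_custom_metrics hook runs each entry in custom_evals and re-wraps whatever its evaluate() returns (a pydantic model, given the model_dump() call) into a Metric. A hedged sketch of what a user-supplied evaluation could look like; the Evaluation base class and the exact Metric fields are not shown in this diff, so the class shape and field names below are assumptions:

from pydantic import BaseModel

# Assumed result shape; the real Metric fields may differ.
class ResponseLengthResult(BaseModel):
    name: str
    value: float

class ResponseLengthEval:
    """Duck-typed custom evaluation: only the evaluate() signature from the diff is assumed."""

    def evaluate(self, messages, ground_truth, extracted_context):
        final_text = messages[-1].content if messages else ""
        return ResponseLengthResult(
            name="final_response_length", value=float(len(final_text))
        )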
@@ -530,6 +689,16 @@ class EvaluationPackage:
              message_with_reasons,
          ) = self.traverse()

+         extracted_context = {}
+         if self.extractors is not None and self.custom_evals is not None:
+             for extractor in self.extractors:
+                 context = extractor.extract(
+                     messages=self.messages,
+                     ground_truth=self.ground_truth,
+                     matcher=self.matcher,
+                 )
+                 extracted_context[extractor.name] = context
+
          is_success = self.is_topological_sort(
              self.ground_truth.goals, labelled_messages
          )
@@ -550,6 +719,10 @@ class EvaluationPackage:
          knowledge_base_metric_summary = (
              self.generate_knowledge_base_metric_summary()
          )
+
+         custom_metric_summary = self.generate_custom_metrics(
+             extracted_context=extracted_context
+         )
          # TO-DO: the table is not printing properly anymore with the new columns introduced
          # we need to introduce a separate table for these.

@@ -563,6 +736,7 @@ class EvaluationPackage:
              knowledge_base_metric_summary,
              message_with_reasons,
              metrics,
+             custom_metric_summary,
          )

      def _get_messages_by_role_before_cs(
@@ -74,7 +74,9 @@ class ExternalAgentValidation:
          payload = {"stream": True}
          payload["messages"] = messages
          resp = requests.post(
-             url=self.service_url, headers=self.header, json=payload,
+             url=self.service_url,
+             headers=self.header,
+             json=payload,
          )
          success, logged_events = self._validate_streaming_response(resp)

@@ -1,4 +1,4 @@
- from typing import Any, List, Literal, Mapping, Union, Optional
+ from typing import Any, List, Literal, Mapping, Optional, Union

  from pydantic import BaseModel

@@ -71,6 +71,36 @@ def is_transfer_response(step_detail: Dict):
      return False


+ def _generate_user_input(
+     user_turn: int,
+     story: str,
+     conversation_history: list[Message],
+     llm_user: LLMUser,
+     enable_manual_user_input: bool = False,
+     starting_user_input: str | None = None,
+     attack_instructions: str | None = None,
+ ) -> Message:
+     """Generates the user input for the current turn."""
+
+     if user_turn == 0 and starting_user_input is not None:
+         return Message(
+             role="user",
+             content=starting_user_input,
+             type=ContentType.text,
+         )
+
+     if enable_manual_user_input:
+         content = input("[medium_orchid1]Enter your input[/medium_orchid1] ✍️: ")
+         return Message(role="user", content=content, type=ContentType.text)
+
+     # llm generated user input
+     return llm_user.generate_user_input(
+         story,
+         conversation_history,
+         attack_instructions=attack_instructions,
+     )
+
+
  class CallTracker(BaseModel):
      tool_call: List = []
      tool_response: List = []
@@ -211,7 +241,6 @@ class WXOInferenceBackend:

          start_time = time.time()
          for chunk in self._stream_events(user_input, agent_name, thread_id):
-
              event = chunk.get("event", "")
              if _thread_id := chunk.get("data", {}).get("thread_id"):
                  thread_id = _thread_id
@@ -422,7 +451,6 @@ class WXOInferenceBackend:

          messages = []
          for entry in result:
-
              tool_call_id = None
              if step_history := entry.get("step_history"):
                  for step_message in step_history:
@@ -551,7 +579,6 @@


  class EvaluationController:
-
      MAX_CONVERSATION_STEPS = int(os.getenv("MAX_CONVERSATION_STEPS", 20))
      MESSAGE_SIMILARITY_THRESHOLD = float(
          os.getenv("MESSAGE_SIMILARITY_THRESHOLD", 0.98)
@@ -585,37 +612,32 @@ class EvaluationController:
          task_n,
          story,
          agent_name: str,
-         starting_user_input: str = None,
-         attack_instructions: str = None,
+         starting_user_input: str | None = None,
+         attack_instructions: str | None = None,
+         max_user_turns: int | None = None,
      ) -> Tuple[List[Message], List[CallTracker], List[ConversationalSearch]]:
-         step = 0
          thread_id = None
          conversation_history: List[Message] = []
          conversational_search_history_data = []
          call_tracker = CallTracker()

-         # make this configurable
-         while step < self.MAX_CONVERSATION_STEPS:
-             if step == 0 and starting_user_input:
-                 user_input = Message(
-                     role="user",
-                     content=starting_user_input,
-                     type=ContentType.text,
-                 )
-             else:
-                 if self.config.enable_manual_user_input == True:
-                     content = input(
-                         "[medium_orchid1]Enter your input[/medium_orchid1] ✍️: "
-                     )
-                     user_input = Message(
-                         role="user", content=content, type=ContentType.text
-                     )
-                 else: # llm
-                     user_input = self.llm_user.generate_user_input(
-                         story,
-                         conversation_history,
-                         attack_instructions=attack_instructions,
-                     )
+         max_turns = (
+             self.MAX_CONVERSATION_STEPS
+             if max_user_turns is None
+             else max_user_turns
+         )
+
+         for user_turn in range(max_turns):
+             user_input = _generate_user_input(
+                 user_turn=user_turn,
+                 story=story,
+                 conversation_history=conversation_history,
+                 llm_user=self.llm_user,
+                 enable_manual_user_input=self.config.enable_manual_user_input,
+                 starting_user_input=starting_user_input,
+                 attack_instructions=attack_instructions,
+             )
+
              if self.config.enable_verbose_logging:
                  rich.print(
                      f"[dark_khaki][Task-{task_n}][/dark_khaki] 👤[bold blue] User:[/bold blue]",
@@ -662,7 +684,7 @@
              # hook for subclasses
              if self._post_message_hook(
                  task_n=task_n,
-                 step=step,
+                 step=user_turn,
                  message=message,
                  conversation_history=conversation_history,
              ):
@@ -677,7 +699,6 @@
                      conversational_search_data
                  )

-             step += 1
          return (
              conversation_history,
              call_tracker,
@@ -742,15 +763,21 @@
          ):
              return True

-         return False # Final fallback for termination is in the main inference loop, which defines MAX_CONVERSATION_STEPS
+         # Final fallback for termination is in the main inference loop, which defines MAX_CONVERSATION_STEPS
+         return False
+

  class AttackEvaluationController(EvaluationController):
-     def __init__(self, *args, attack_data=None, attack_evaluator=None, **kwargs):
+     def __init__(
+         self, *args, attack_data=None, attack_evaluator=None, **kwargs
+     ):
          super().__init__(*args, **kwargs)
          self.attack_data = attack_data
          self.attack_evaluator = attack_evaluator

-     def _post_message_hook(self, task_n, step, message, conversation_history) -> bool:
+     def _post_message_hook(
+         self, task_n, step, message, conversation_history
+     ) -> bool:
          """Override hook to add live attack evaluation."""
          if self.attack_evaluator and self.attack_data:
              success = self.attack_evaluator.evaluate(
@@ -762,7 +789,9 @@ class AttackEvaluationController(EvaluationController):
              )
              # persist the live result so the aggregator can pick it up later
              try:
-                 self.attack_evaluator.save_evaluation_result(self.attack_data, True)
+                 self.attack_evaluator.save_evaluation_result(
+                     self.attack_data, True
+                 )
              except Exception:
                  pass
          conversation_history.append(message)
@@ -777,6 +806,7 @@ if __name__ == "__main__":
      )
      with open(auth_config_path, "r") as f:
          auth_config = yaml.safe_load(f)
+
      tenant_name = "local"
      token = auth_config["auth"][tenant_name]["wxo_mcsp_token"]

@@ -1,10 +1,22 @@
+ """
+ LLM Matching Module with Cosine Similarity Support
+
+ This module provides functionality for matching text using:
+ 1. LLM-based matching (using a language model to determine semantic equivalence)
+ 2. Embedding-based matching (using cosine similarity between text embeddings)
+ """
+
+ import math
  from typing import List

+ from fuzzywuzzy import fuzz
+
  from wxo_agentic_evaluation.prompt.template_render import (
      KeywordMatchingTemplateRenderer,
      SemanticMatchingTemplateRenderer,
  )
  from wxo_agentic_evaluation.service_provider.watsonx_provider import Provider
+ from wxo_agentic_evaluation.utils.utils import safe_divide


  class LLMMatcher:
@@ -13,10 +25,18 @@ class LLMMatcher:
          llm_client: Provider,
          keyword_template: KeywordMatchingTemplateRenderer,
          semantic_template: SemanticMatchingTemplateRenderer,
+         use_llm_for_semantic: bool = True,
+         embedding_model_id: str = "sentence-transformers/all-minilm-l6-v2",
+         similarity_threshold: float = 0.8,
+         enable_fuzzy_matching: bool = False,
      ):
          self.llm_client = llm_client
          self.keyword_template = keyword_template
          self.semantic_template = semantic_template
+         self.embedding_model_id = embedding_model_id
+         self.use_llm_for_semantic = use_llm_for_semantic
+         self.similarity_threshold = similarity_threshold
+         self.enable_fuzzy_matching = enable_fuzzy_matching

      def keywords_match(self, response_text: str, keywords: List[str]) -> bool:
          if len(keywords) == 0:
@@ -31,8 +51,40 @@ class LLMMatcher:
          result = output.strip().lower()
          return result.startswith("true")

-     def semantic_match(
-         self, context: str, prediction: str, ground_truth: str
+     def generate_embeddings(
+         self, prediction: str, ground_truth: str
+     ) -> List[List[float]]:
+
+         embeddings = self.llm_client.encode([prediction, ground_truth])
+
+         return embeddings
+
+     def compute_cosine_similarity(
+         self, vec1: List[float], vec2: List[float]
+     ) -> float:
+         """Calculate cosine similarity between two vectors using pure Python"""
+
+         # Manual dot product calculation
+         dot_product = sum(a * b for a, b in zip(vec1, vec2))
+
+         # Manual magnitude calculations
+         magnitude1 = math.sqrt(sum(a * a for a in vec1))
+         magnitude2 = math.sqrt(sum(b * b for b in vec2))
+
+         return safe_divide(dot_product, (magnitude1 * magnitude2))
+
+     def cosine_similarity_semantic_match(
+         self, prediction: str, ground_truth: str
+     ) -> bool:
+         embeddings = self.generate_embeddings(prediction, ground_truth)
+         cosine_similarity = self.compute_cosine_similarity(
+             embeddings[0], embeddings[1]
+         )
+
+         return cosine_similarity >= self.similarity_threshold
+
+     def llm_semantic_match(
+         self, context, prediction: str, ground_truth: str
      ) -> bool:
          """Performs semantic matching for the agent's final response and the expected response using the starting sentence of the conversation as the context

@@ -44,9 +96,47 @@
          Returns:
              a boolean indicating if the sentences match.
          """
+
          prompt = self.semantic_template.render(
              context=context, expected_text=ground_truth, actual_text=prediction
          )
          output: str = self.llm_client.query(prompt)
          result = output.strip().lower()
+
          return result.startswith("true")
+
+     def fuzzywuzzy_semantic_match(
+         self, prediction: str, ground_truth: str
+     ) -> bool:
+
+         similarity_score = fuzz.WRatio(prediction, ground_truth)
+
+         return similarity_score > self.similarity_threshold
+
+     def semantic_match(
+         self,
+         context: str,
+         prediction: str,
+         ground_truth: str,
+         enable_fuzzy_matching: bool = False,
+     ) -> bool:
+         ## TODO arjun-gupta1 10/06/2025: revist retry with exponential backoff. Opted for direct fallback to cosine similarity to avoid latency for now.
+         try:
+             return self.llm_semantic_match(context, prediction, ground_truth)
+         except Exception as e:
+             print(f"LLM semantic match failed: {e}")
+
+         if enable_fuzzy_matching:
+             print("falling back to fuzzy matching")
+             # Fallback to cosine similarity if LLM matching is not used or failed
+             try:
+                 return self.cosine_similarity_semantic_match(
+                     prediction, ground_truth
+                 )
+             except Exception as e:
+                 print(
+                     f"Cosine similarity failed: {e}. Falling back to fuzzywuzzy."
+                 )
+
+             # Final fallback to fuzzywuzzy
+             return self.fuzzywuzzy_semantic_match(prediction, ground_truth)
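
For reference, the cosine path thresholds the similarity of two provider-generated embeddings at similarity_threshold (0.8 by default). Below is a self-contained sketch of the same pure-Python computation, with a zero-denominator guard standing in for the safe_divide helper the package imports (its exact behavior is not shown in this diff).

import math

def cosine_similarity(vec1, vec2):
    dot = sum(a * b for a, b in zip(vec1, vec2))
    mag1 = math.sqrt(sum(a * a for a in vec1))
    mag2 = math.sqrt(sum(b * b for b in vec2))
    denom = mag1 * mag2
    return dot / denom if denom else 0.0  # assumed zero guard

# Illustrative embedding vectors; in the package they come from llm_client.encode([...]).
a = [0.10, 0.30, 0.50]
b = [0.10, 0.29, 0.52]
print(cosine_similarity(a, b) >= 0.8)  # True -> treated as a semantic match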
@@ -21,8 +21,8 @@ class LLMUser:
          self,
          user_story,
          conversation_history: List[Message],
-         attack_instructions: str = None,
-     ) -> Message | None:
+         attack_instructions: str | None = None,
+     ) -> Message:
          # the tool response is already summarized, we don't need that to take over the chat history context window
          prompt_input = self.prompt_template.render(
              conversation_history=[