ibm-watsonx-orchestrate-evaluation-framework 1.1.5__py3-none-any.whl → 1.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
This version of ibm-watsonx-orchestrate-evaluation-framework has been flagged as potentially problematic.
- {ibm_watsonx_orchestrate_evaluation_framework-1.1.5.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.7.dist-info}/METADATA +4 -1
- {ibm_watsonx_orchestrate_evaluation_framework-1.1.5.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.7.dist-info}/RECORD +49 -39
- wxo_agentic_evaluation/analyze_run.py +822 -344
- wxo_agentic_evaluation/arg_configs.py +39 -2
- wxo_agentic_evaluation/data_annotator.py +22 -4
- wxo_agentic_evaluation/description_quality_checker.py +29 -4
- wxo_agentic_evaluation/evaluation_package.py +197 -18
- wxo_agentic_evaluation/external_agent/external_validate.py +3 -1
- wxo_agentic_evaluation/external_agent/types.py +1 -1
- wxo_agentic_evaluation/inference_backend.py +105 -108
- wxo_agentic_evaluation/llm_matching.py +104 -2
- wxo_agentic_evaluation/llm_user.py +2 -2
- wxo_agentic_evaluation/main.py +147 -38
- wxo_agentic_evaluation/metrics/__init__.py +5 -0
- wxo_agentic_evaluation/metrics/evaluations.py +124 -0
- wxo_agentic_evaluation/metrics/llm_as_judge.py +4 -3
- wxo_agentic_evaluation/metrics/metrics.py +64 -1
- wxo_agentic_evaluation/prompt/llmaaj_prompt.jinja2 +15 -0
- wxo_agentic_evaluation/prompt/semantic_matching_prompt.jinja2 +41 -9
- wxo_agentic_evaluation/prompt/template_render.py +20 -2
- wxo_agentic_evaluation/quick_eval.py +23 -11
- wxo_agentic_evaluation/record_chat.py +18 -10
- wxo_agentic_evaluation/red_teaming/attack_evaluator.py +169 -100
- wxo_agentic_evaluation/red_teaming/attack_generator.py +63 -40
- wxo_agentic_evaluation/red_teaming/attack_list.py +78 -8
- wxo_agentic_evaluation/red_teaming/attack_runner.py +71 -14
- wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_call/general_metrics.json +783 -0
- wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_selection/function_selection_metrics.json +600 -0
- wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/types.py +10 -10
- wxo_agentic_evaluation/referenceless_eval/referenceless_eval.py +103 -39
- wxo_agentic_evaluation/resource_map.py +3 -1
- wxo_agentic_evaluation/service_instance.py +12 -3
- wxo_agentic_evaluation/service_provider/__init__.py +129 -9
- wxo_agentic_evaluation/service_provider/gateway_provider.py +707 -0
- wxo_agentic_evaluation/service_provider/model_proxy_provider.py +415 -17
- wxo_agentic_evaluation/service_provider/ollama_provider.py +393 -22
- wxo_agentic_evaluation/service_provider/provider.py +130 -10
- wxo_agentic_evaluation/service_provider/referenceless_provider_wrapper.py +52 -0
- wxo_agentic_evaluation/service_provider/watsonx_provider.py +480 -52
- wxo_agentic_evaluation/type.py +15 -5
- wxo_agentic_evaluation/utils/__init__.py +44 -3
- wxo_agentic_evaluation/utils/evaluation_discovery.py +47 -0
- wxo_agentic_evaluation/utils/gateway_provider_utils.py +39 -0
- wxo_agentic_evaluation/utils/messages_parser.py +30 -0
- wxo_agentic_evaluation/utils/parsers.py +71 -0
- wxo_agentic_evaluation/utils/utils.py +140 -20
- wxo_agentic_evaluation/wxo_client.py +81 -0
- {ibm_watsonx_orchestrate_evaluation_framework-1.1.5.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.7.dist-info}/WHEEL +0 -0
- {ibm_watsonx_orchestrate_evaluation_framework-1.1.5.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.7.dist-info}/top_level.txt +0 -0
wxo_agentic_evaluation/inference_backend.py

@@ -3,21 +3,15 @@ import os
 import time
 from collections import deque
 from enum import Enum
-from typing import Any, Dict, Generator, List, Mapping,
+from typing import Any, Dict, Generator, List, Mapping, Tuple

 import requests
 import rich
-import urllib3
 import yaml
 from pydantic import BaseModel
-from urllib3.exceptions import InsecureRequestWarning

 from wxo_agentic_evaluation.arg_configs import TestConfig
 from wxo_agentic_evaluation.llm_user import LLMUser
-from wxo_agentic_evaluation.service_instance import (
-    get_env_settings,
-    tenant_setup,
-)
 from wxo_agentic_evaluation.service_provider.watsonx_provider import (
     WatsonXProvider,
 )
@@ -36,6 +30,7 @@ from wxo_agentic_evaluation.utils.utils import (
     is_saas_url,
     safe_divide,
 )
+from wxo_agentic_evaluation.wxo_client import WXOClient

 tokenizer = Tokenizer()

@@ -76,67 +71,40 @@ def is_transfer_response(step_detail: Dict):
     return False


-
-
-
-
-
+def _generate_user_input(
+    user_turn: int,
+    story: str,
+    conversation_history: list[Message],
+    llm_user: LLMUser,
+    enable_manual_user_input: bool = False,
+    starting_user_input: str | None = None,
+    attack_instructions: str | None = None,
+) -> Message:
+    """Generates the user input for the current turn."""
+
+    if user_turn == 0 and starting_user_input is not None:
+        return Message(
+            role="user",
+            content=starting_user_input,
+            type=ContentType.text,
+        )

-
-
-
-    ):
-        self.service_url = service_url
-        self.api_key = api_key
+    if enable_manual_user_input:
+        content = input("[medium_orchid1]Enter your input[/medium_orchid1] ✍️: ")
+        return Message(role="user", content=content, type=ContentType.text)

-
-
-
-
-
-
-        )
-        self._verify_ssl = (
-            False
-            if (
-                (bs is True)
-                or (isinstance(bs, str) and bs.strip().lower() == "true")
-                or (v is None)
-                or (
-                    isinstance(v, str)
-                    and v.strip().lower() in {"none", "null"}
-                )
-            )
-            else (v if isinstance(v, bool) else True)
-        )
+    # llm generated user input
+    return llm_user.generate_user_input(
+        story,
+        conversation_history,
+        attack_instructions=attack_instructions,
+    )

-        if not self._verify_ssl:
-            urllib3.disable_warnings(InsecureRequestWarning)
-
-    def _get_headers(self) -> dict:
-        headers = {}
-        if self.api_key:
-            headers["Authorization"] = f"Bearer {self.api_key}"
-        return headers
-
-    def post(self, payload: dict, path: str, stream=False):
-        url = f"{self.service_url}/{path}"
-        return requests.post(
-            url=url,
-            headers=self._get_headers(),
-            json=payload,
-            stream=stream,
-            verify=self._verify_ssl,
-        )

-
-
-
-
-            params=params,
-            headers=self._get_headers(),
-            verify=self._verify_ssl,
-        )
+class CallTracker(BaseModel):
+    tool_call: List = []
+    tool_response: List = []
+    generic: List = []


 class WXOInferenceBackend:
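The removed WXOClient class (its surviving fragments are the bare `-` lines above) moves to the new wxo_agentic_evaluation/wxo_client.py module listed in the summary (+81 lines), and the added `_generate_user_input` helper centralizes the three user-input sources the controller loop previously handled inline. A minimal runnable sketch of that selection order, with stand-ins for the package's own Message type and LLM user (illustrative names only, not the package API):

# Illustrative sketch of the _generate_user_input priority order.
# Message here is a stand-in for the package's own Message type.
from dataclasses import dataclass

@dataclass
class Message:
    role: str
    content: str

def pick_user_input(user_turn: int,
                    starting_user_input: str | None = None,
                    enable_manual_user_input: bool = False,
                    llm_generate=lambda: "simulated turn") -> Message:
    # 1) Turn 0 may be seeded with a scripted opening utterance.
    if user_turn == 0 and starting_user_input is not None:
        return Message("user", starting_user_input)
    # 2) Interactive runs read the turn from stdin instead.
    if enable_manual_user_input:
        return Message("user", input("Enter your input: "))
    # 3) Default: the simulated LLM user produces the next turn.
    return Message("user", llm_generate())

print(pick_user_input(0, starting_user_input="Hi, I need help"))  # seeded first turn
print(pick_user_input(1))  # later turns fall through to the LLM user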
@@ -273,7 +241,6 @@ class WXOInferenceBackend:

         start_time = time.time()
         for chunk in self._stream_events(user_input, agent_name, thread_id):
-
             event = chunk.get("event", "")
             if _thread_id := chunk.get("data", {}).get("thread_id"):
                 thread_id = _thread_id
@@ -484,7 +451,6 @@ class WXOInferenceBackend:

         messages = []
         for entry in result:
-
             tool_call_id = None
             if step_history := entry.get("step_history"):
                 for step_message in step_history:
@@ -613,7 +579,6 @@ class WXOInferenceBackend:


 class EvaluationController:
-
     MAX_CONVERSATION_STEPS = int(os.getenv("MAX_CONVERSATION_STEPS", 20))
     MESSAGE_SIMILARITY_THRESHOLD = float(
         os.getenv("MESSAGE_SIMILARITY_THRESHOLD", 0.98)
@@ -647,37 +612,32 @@ class EvaluationController:
         task_n,
         story,
         agent_name: str,
-        starting_user_input: str = None,
-        attack_instructions: str = None,
+        starting_user_input: str | None = None,
+        attack_instructions: str | None = None,
+        max_user_turns: int | None = None,
     ) -> Tuple[List[Message], List[CallTracker], List[ConversationalSearch]]:
-        step = 0
         thread_id = None
         conversation_history: List[Message] = []
         conversational_search_history_data = []
         call_tracker = CallTracker()

-
-
-        if
-
-
-
-
-
-
-
-
-
-
-
-
-
-        user_input = self.llm_user.generate_user_input(
-            story,
-            conversation_history,
-            attack_instructions=attack_instructions,
-        )
+        max_turns = (
+            self.MAX_CONVERSATION_STEPS
+            if max_user_turns is None
+            else max_user_turns
+        )
+
+        for user_turn in range(max_turns):
+            user_input = _generate_user_input(
+                user_turn=user_turn,
+                story=story,
+                conversation_history=conversation_history,
+                llm_user=self.llm_user,
+                enable_manual_user_input=self.config.enable_manual_user_input,
+                starting_user_input=starting_user_input,
+                attack_instructions=attack_instructions,
+            )
+
             if self.config.enable_verbose_logging:
                 rich.print(
                     f"[dark_khaki][Task-{task_n}][/dark_khaki] 👤[bold blue] User:[/bold blue]",
@@ -721,18 +681,37 @@ class EvaluationController:
                     message.content,
                 )

+            # hook for subclasses
+            if self._post_message_hook(
+                task_n=task_n,
+                step=user_turn,
+                message=message,
+                conversation_history=conversation_history,
+            ):
+                return (
+                    conversation_history,
+                    call_tracker,
+                    conversational_search_history_data,
+                )
+
             conversation_history.extend(messages)
             conversational_search_history_data.extend(
                 conversational_search_data
             )

-            step += 1
         return (
             conversation_history,
             call_tracker,
             conversational_search_history_data,
         )

+    def _post_message_hook(self, **kwargs) -> bool:
+        """
+        Hook for subclasses to extend behavior.
+        Return True to break the loop early.
+        """
+        return False
+
     def _is_looping(self, messages: deque) -> bool:
         """Checks whether the user or assistant is stuck in a loop.
         Args:
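With the loop now driven by `range(max_turns)`, the manual `step += 1` bookkeeping is gone, and `_post_message_hook` turns the controller into a template method: a subclass can inspect each message and return True to end the simulation early. A toy sketch of that contract (hypothetical class names, not the package's API; the real consumer is the AttackEvaluationController in the next hunk):

# Toy illustration of the _post_message_hook contract: True means stop early.
class Controller:
    def run(self, incoming: list[str]) -> list[str]:
        history: list[str] = []
        for step, message in enumerate(incoming):
            history.append(message)
            if self._post_message_hook(step=step, message=message):
                break  # subclass requested early termination
        return history

    def _post_message_hook(self, **kwargs) -> bool:
        return False  # base behavior: never stop early

class StopOnGoodbye(Controller):
    def _post_message_hook(self, step: int, message: str) -> bool:
        return "goodbye" in message.lower()

print(StopOnGoodbye().run(["hello", "goodbye", "never reached"]))  # ['hello', 'goodbye']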
@@ -784,23 +763,40 @@ class EvaluationController:
         ):
             return True

-
-
+        # Final fallback for termination is in the main inference loop, which defines MAX_CONVERSATION_STEPS
+        return False

-def get_wxo_client(
-    service_url: Optional[str], tenant_name: str, token: Optional[str] = None
-) -> WXOClient:

-
-
-
-
-
-
-
-
-
-
+class AttackEvaluationController(EvaluationController):
+    def __init__(
+        self, *args, attack_data=None, attack_evaluator=None, **kwargs
+    ):
+        super().__init__(*args, **kwargs)
+        self.attack_data = attack_data
+        self.attack_evaluator = attack_evaluator
+
+    def _post_message_hook(
+        self, task_n, step, message, conversation_history
+    ) -> bool:
+        """Override hook to add live attack evaluation."""
+        if self.attack_evaluator and self.attack_data:
+            success = self.attack_evaluator.evaluate(
+                self.attack_data, conversation_history + [message]
+            )
+            if success:
+                rich.print(
+                    f"[bold green]Attack for [Task-{task_n}] succeeded early at step {step}! Stopping simulation.[/bold green]"
+                )
+                # persist the live result so the aggregator can pick it up later
+                try:
+                    self.attack_evaluator.save_evaluation_result(
+                        self.attack_data, True
+                    )
+                except Exception:
+                    pass
+                conversation_history.append(message)
+                return True
+        return False


 if __name__ == "__main__":
@@ -810,6 +806,7 @@ if __name__ == "__main__":
     )
     with open(auth_config_path, "r") as f:
         auth_config = yaml.safe_load(f)
+
     tenant_name = "local"
     token = auth_config["auth"][tenant_name]["wxo_mcsp_token"]

wxo_agentic_evaluation/llm_matching.py

@@ -1,10 +1,22 @@
+"""
+LLM Matching Module with Cosine Similarity Support
+
+This module provides functionality for matching text using:
+1. LLM-based matching (using a language model to determine semantic equivalence)
+2. Embedding-based matching (using cosine similarity between text embeddings)
+"""
+
+import math
 from typing import List

+from fuzzywuzzy import fuzz
+
 from wxo_agentic_evaluation.prompt.template_render import (
     KeywordMatchingTemplateRenderer,
     SemanticMatchingTemplateRenderer,
 )
 from wxo_agentic_evaluation.service_provider.watsonx_provider import Provider
+from wxo_agentic_evaluation.utils.utils import safe_divide


 class LLMMatcher:
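The imports pull in `math` for the cosine path and `fuzzywuzzy` for the lexical fallback the new module docstring describes. For reference, `fuzz.WRatio` returns a 0-100 similarity score (a quick sketch, assuming the fuzzywuzzy package is installed):

from fuzzywuzzy import fuzz

# WRatio blends several ratio heuristics into one 0-100 score;
# higher means more lexically similar.
score = fuzz.WRatio("your order has shipped", "the order shipped yesterday")
print(score)  # an integer in [0, 100]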
@@ -13,10 +25,18 @@ class LLMMatcher:
         llm_client: Provider,
         keyword_template: KeywordMatchingTemplateRenderer,
         semantic_template: SemanticMatchingTemplateRenderer,
+        use_llm_for_semantic: bool = True,
+        embedding_model_id: str = "sentence-transformers/all-minilm-l6-v2",
+        similarity_threshold: float = 0.8,
+        enable_fuzzy_matching: bool = False,
     ):
         self.llm_client = llm_client
         self.keyword_template = keyword_template
         self.semantic_template = semantic_template
+        self.embedding_model_id = embedding_model_id
+        self.use_llm_for_semantic = use_llm_for_semantic
+        self.similarity_threshold = similarity_threshold
+        self.enable_fuzzy_matching = enable_fuzzy_matching

     def keywords_match(self, response_text: str, keywords: List[str]) -> bool:
         if len(keywords) == 0:
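The constructor now carries the matching strategy as state: which semantic matcher to try first, which embedding model backs the cosine path, the similarity cut-off, and whether the fuzzy fallback is allowed. A hedged construction sketch, assuming the 1.1.7 wheel is installed; the provider and renderer classes below are fakes so the wiring can run without a live backend:

from wxo_agentic_evaluation.llm_matching import LLMMatcher

# Fakes standing in for the package's Provider and template renderers.
class FakeProvider:
    def query(self, prompt: str) -> str:
        return "true"  # pretend the LLM judged the texts equivalent
    def encode(self, texts):
        return [[1.0, 0.0], [1.0, 0.0]]  # pretend embeddings

class FakeRenderer:
    def render(self, **kwargs) -> str:
        return f"do these match? {kwargs}"

matcher = LLMMatcher(
    llm_client=FakeProvider(),
    keyword_template=FakeRenderer(),
    semantic_template=FakeRenderer(),
    use_llm_for_semantic=True,       # try the LLM judge first
    similarity_threshold=0.8,        # cosine cut-off for a match
    enable_fuzzy_matching=False,     # opt in to the fuzz.WRatio fallback
)
print(matcher.semantic_match("ctx", "The order shipped", "Your order has shipped"))  # True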
@@ -31,10 +51,92 @@ class LLMMatcher:
         result = output.strip().lower()
         return result.startswith("true")

-    def
+    def generate_embeddings(
+        self, prediction: str, ground_truth: str
+    ) -> List[List[float]]:
+
+        embeddings = self.llm_client.encode([prediction, ground_truth])
+
+        return embeddings
+
+    def compute_cosine_similarity(
+        self, vec1: List[float], vec2: List[float]
+    ) -> float:
+        """Calculate cosine similarity between two vectors using pure Python"""
+
+        # Manual dot product calculation
+        dot_product = sum(a * b for a, b in zip(vec1, vec2))
+
+        # Manual magnitude calculations
+        magnitude1 = math.sqrt(sum(a * a for a in vec1))
+        magnitude2 = math.sqrt(sum(b * b for b in vec2))
+
+        return safe_divide(dot_product, (magnitude1 * magnitude2))
+
+    def cosine_similarity_semantic_match(
+        self, prediction: str, ground_truth: str
+    ) -> bool:
+        embeddings = self.generate_embeddings(prediction, ground_truth)
+        cosine_similarity = self.compute_cosine_similarity(
+            embeddings[0], embeddings[1]
+        )
+
+        return cosine_similarity >= self.similarity_threshold
+
+    def llm_semantic_match(
+        self, context, prediction: str, ground_truth: str
+    ) -> bool:
+        """Performs semantic matching between the agent's final response and the expected response, using the starting sentence of the conversation as the context.
+
+        Args:
+            context: The starting sentence of the conversation. TODO: could also consider using the LLM user's story.
+            prediction: the predicted string
+            ground_truth: the expected string
+
+        Returns:
+            a boolean indicating whether the sentences match.
+        """
+
         prompt = self.semantic_template.render(
-            expected_text=ground_truth, actual_text=prediction
+            context=context, expected_text=ground_truth, actual_text=prediction
         )
         output: str = self.llm_client.query(prompt)
         result = output.strip().lower()
+
         return result.startswith("true")
+
+    def fuzzywuzzy_semantic_match(
+        self, prediction: str, ground_truth: str
+    ) -> bool:
+
+        similarity_score = fuzz.WRatio(prediction, ground_truth)
+
+        return similarity_score > self.similarity_threshold
+
+    def semantic_match(
+        self,
+        context: str,
+        prediction: str,
+        ground_truth: str,
+        enable_fuzzy_matching: bool = False,
+    ) -> bool:
+        ## TODO arjun-gupta1 10/06/2025: revisit retry with exponential backoff. Opted for direct fallback to cosine similarity to avoid latency for now.
+        try:
+            return self.llm_semantic_match(context, prediction, ground_truth)
+        except Exception as e:
+            print(f"LLM semantic match failed: {e}")

+            if enable_fuzzy_matching:
+                print("falling back to fuzzy matching")
+            # Fallback to cosine similarity if LLM matching is not used or failed
+            try:
+                return self.cosine_similarity_semantic_match(
+                    prediction, ground_truth
+                )
+            except Exception as e:
+                print(
+                    f"Cosine similarity failed: {e}. Falling back to fuzzywuzzy."
+                )

+                # Final fallback to fuzzywuzzy
+                return self.fuzzywuzzy_semantic_match(prediction, ground_truth)
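The cosine path deliberately avoids numpy: the dot product and magnitudes are computed by hand, and `safe_divide` (imported from utils at the top of the file) guards the zero-magnitude case. A standalone sketch of the same arithmetic, with a local stand-in for safe_divide:

import math

def safe_divide(num, den):  # stand-in for the package's safe_divide helper
    return num / den if den else 0.0

def cosine_similarity(vec1, vec2):
    dot = sum(a * b for a, b in zip(vec1, vec2))
    mag1 = math.sqrt(sum(a * a for a in vec1))
    mag2 = math.sqrt(sum(b * b for b in vec2))
    return safe_divide(dot, mag1 * mag2)

# Same direction -> 1.0; orthogonal -> 0.0; threshold 0.8 decides a "match".
print(cosine_similarity([1.0, 2.0], [2.0, 4.0]))  # 1.0
print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))  # 0.0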
wxo_agentic_evaluation/llm_user.py

@@ -21,8 +21,8 @@ class LLMUser:
         self,
         user_story,
         conversation_history: List[Message],
-        attack_instructions: str = None,
-    ) -> Message
+        attack_instructions: str | None = None,
+    ) -> Message:
         # the tool response is already summarized, we don't need that to take over the chat history context window
         prompt_input = self.prompt_template.render(
             conversation_history=[