ibm-watsonx-orchestrate-evaluation-framework 1.0.9__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ibm-watsonx-orchestrate-evaluation-framework might be problematic. (The "more details" link from the original page is not preserved in this text.)
- {ibm_watsonx_orchestrate_evaluation_framework-1.0.9.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.0.dist-info}/METADATA +1 -1
- {ibm_watsonx_orchestrate_evaluation_framework-1.0.9.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.0.dist-info}/RECORD +10 -9
- wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/pipeline.py +3 -3
- wxo_agentic_evaluation/referenceless_eval/referenceless_eval.py +6 -8
- wxo_agentic_evaluation/service_provider/__init__.py +15 -6
- wxo_agentic_evaluation/service_provider/model_proxy_provider.py +4 -3
- wxo_agentic_evaluation/service_provider/referenceless_provider_wrapper.py +138 -0
- wxo_agentic_evaluation/service_provider/watsonx_provider.py +3 -110
- {ibm_watsonx_orchestrate_evaluation_framework-1.0.9.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.0.dist-info}/WHEEL +0 -0
- {ibm_watsonx_orchestrate_evaluation_framework-1.0.9.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.0.dist-info}/top_level.txt +0 -0
|
@@ -53,7 +53,7 @@ wxo_agentic_evaluation/red_teaming/attack_generator.py,sha256=YQi9xoaFATBNGe_Neb
|
|
|
53
53
|
wxo_agentic_evaluation/red_teaming/attack_list.py,sha256=edphWARWqDtXFtcHTVbRXngvO0YfG5SgrfPtrBRXuFw,4734
|
|
54
54
|
wxo_agentic_evaluation/red_teaming/attack_runner.py,sha256=qBZY4GK1352NUMyED5LVjjbcvpdCcxG6mDIN1HvxKIc,4340
|
|
55
55
|
wxo_agentic_evaluation/referenceless_eval/__init__.py,sha256=lijXMgQ8nQe-9eIfade2jLfHMlXfYafMZIwXtC9KDZo,106
|
|
56
|
-
wxo_agentic_evaluation/referenceless_eval/referenceless_eval.py,sha256=
|
|
56
|
+
wxo_agentic_evaluation/referenceless_eval/referenceless_eval.py,sha256=ypEMOeAwaztGkOuDr_2JArSQWwos7XcBTwo8lFs2N5w,4262
|
|
57
57
|
wxo_agentic_evaluation/referenceless_eval/function_calling/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
58
58
|
wxo_agentic_evaluation/referenceless_eval/function_calling/consts.py,sha256=UidTaT9g5IxbcakfQqP_9c5civ1wDqY-PpPUf0uOXJo,915
|
|
59
59
|
wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -67,7 +67,7 @@ wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_sele
|
|
|
67
67
|
wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_selection/function_selection_metrics_runtime.json,sha256=o4oRur1MiXO2RYzmzj07QOBzX75DyU7T7yd-gFsgFdo,30563
|
|
68
68
|
wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
69
69
|
wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/adapters.py,sha256=kMMFq4ABX5q6cPnDdublLMVqXu4Ij-x4OlxZyePWIjc,3599
|
|
70
|
-
wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/pipeline.py,sha256=
|
|
70
|
+
wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/pipeline.py,sha256=44HNEoIt3_jKZczs1qB8WGltCG-vn3ZI5aNhucxSDeM,9272
|
|
71
71
|
wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/semantic_checker.py,sha256=z_k-qdFoUJqstkPYn9Zmhlp2YTVQKJtoDZCIdKow664,17306
|
|
72
72
|
wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/static_checker.py,sha256=_Er2KfCkc3HFmOmxZT6eb-e7qF7ukqsf6Si5CJTqPPg,6016
|
|
73
73
|
wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/transformation_prompts.py,sha256=tW1wc87WIm8BZh2lhdj1RDP6VdRLqZBWSMmemSttbGs,22034
|
|
@@ -80,17 +80,18 @@ wxo_agentic_evaluation/referenceless_eval/metrics/prompt.py,sha256=CGQ5LvhQrmxAy
|
|
|
80
80
|
wxo_agentic_evaluation/referenceless_eval/metrics/utils.py,sha256=jurmc4KFFKH4hwnvor2xg97H91b-xJc3cUKYaU2I8uM,1370
|
|
81
81
|
wxo_agentic_evaluation/referenceless_eval/prompt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
82
82
|
wxo_agentic_evaluation/referenceless_eval/prompt/runner.py,sha256=FFmcSWXQnLmylpYyj8LZuPwb6nqwQp-jj6Mv9g8zby0,5052
|
|
83
|
-
wxo_agentic_evaluation/service_provider/__init__.py,sha256=
|
|
84
|
-
wxo_agentic_evaluation/service_provider/model_proxy_provider.py,sha256=
|
|
83
|
+
wxo_agentic_evaluation/service_provider/__init__.py,sha256=yNQ-urOIdjANbpCzVAhkPHNcpBY6hndDJgPZM1C2qeo,2107
|
|
84
|
+
wxo_agentic_evaluation/service_provider/model_proxy_provider.py,sha256=EW1JIiIWoKaTTC-fqKURSsbdyo-dbVWYVrXY8-gEmvc,4081
|
|
85
85
|
wxo_agentic_evaluation/service_provider/ollama_provider.py,sha256=HMHQVUGFbLSQI1dhysAn70ozJl90yRg-CbNd4vsz-Dc,1116
|
|
86
86
|
wxo_agentic_evaluation/service_provider/provider.py,sha256=MsnRzLYAaQiU6y6xf6eId7kn6-CetQuNZl00EP-Nl28,417
|
|
87
|
-
wxo_agentic_evaluation/service_provider/
|
|
87
|
+
wxo_agentic_evaluation/service_provider/referenceless_provider_wrapper.py,sha256=aJrCz8uco6HOQwNCSjEKviwnhlyLTNAGpLtsOAegQ70,5200
|
|
88
|
+
wxo_agentic_evaluation/service_provider/watsonx_provider.py,sha256=ugXCXwrfi_XC2d9FPa96ccMKGQbTd1ElDw8RNR8TDB8,6544
|
|
88
89
|
wxo_agentic_evaluation/utils/__init__.py,sha256=ItryTgc1jVc32rB3XktTFaYGA_A6bRIDZ1Pts_JGmv8,144
|
|
89
90
|
wxo_agentic_evaluation/utils/open_ai_tool_extractor.py,sha256=Vyji_edgou2xMLbsGwFG-QI7xRBNvO3-1nbeOc8ZuFo,5646
|
|
90
91
|
wxo_agentic_evaluation/utils/rich_utils.py,sha256=J9lzL4ETQeiAJcXKsUzXh82XdKvlDY7jmcgTQlwmL9s,6252
|
|
91
92
|
wxo_agentic_evaluation/utils/rouge_score.py,sha256=WvcGh6mwF4rWH599J9_lAt3BfaHbAZKtKEJBsC61iKo,692
|
|
92
93
|
wxo_agentic_evaluation/utils/utils.py,sha256=qQR_2W5p0Rk6KSE3-llRyZrWXkO5zG9JW7H1692L4PI,11428
|
|
93
|
-
ibm_watsonx_orchestrate_evaluation_framework-1.0.
|
|
94
|
-
ibm_watsonx_orchestrate_evaluation_framework-1.0.
|
|
95
|
-
ibm_watsonx_orchestrate_evaluation_framework-1.0.
|
|
96
|
-
ibm_watsonx_orchestrate_evaluation_framework-1.0.
|
|
94
|
+
ibm_watsonx_orchestrate_evaluation_framework-1.1.0.dist-info/METADATA,sha256=4yfSRfaNQUwauYPqvTFAoaVSn_c3i5YbIC7SFK4SnDU,16105
|
|
95
|
+
ibm_watsonx_orchestrate_evaluation_framework-1.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
96
|
+
ibm_watsonx_orchestrate_evaluation_framework-1.1.0.dist-info/top_level.txt,sha256=2okpqtpxyqHoLyb2msio4pzqSg7yPSzwI7ekks96wYE,23
|
|
97
|
+
ibm_watsonx_orchestrate_evaluation_framework-1.1.0.dist-info/RECORD,,
|
|
@@ -20,7 +20,7 @@ from wxo_agentic_evaluation.referenceless_eval.function_calling.pipeline.types i
|
|
|
20
20
|
ToolSpec,
|
|
21
21
|
)
|
|
22
22
|
from wxo_agentic_evaluation.referenceless_eval.function_calling import metrics
|
|
23
|
-
from wxo_agentic_evaluation.service_provider.
|
|
23
|
+
from wxo_agentic_evaluation.service_provider.referenceless_provider_wrapper import LLMKitWrapper
|
|
24
24
|
|
|
25
25
|
def metrics_dir():
|
|
26
26
|
path = importlib.resources.files(metrics)
|
|
@@ -57,8 +57,8 @@ class ReflectionPipeline:
|
|
|
57
57
|
|
|
58
58
|
def __init__(
|
|
59
59
|
self,
|
|
60
|
-
metrics_client:
|
|
61
|
-
codegen_client: Optional[
|
|
60
|
+
metrics_client: LLMKitWrapper,
|
|
61
|
+
codegen_client: Optional[LLMKitWrapper] = None,
|
|
62
62
|
general_metrics: Optional[
|
|
63
63
|
Union[Path, List[FunctionCallMetric], List[str]]
|
|
64
64
|
] = _DEFAULT_GENERAL_RUNTIME,
|
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import json
|
|
2
|
-
import os
|
|
3
2
|
from typing import Any, List, Mapping
|
|
4
3
|
|
|
5
4
|
import rich
|
|
@@ -16,7 +15,7 @@ from wxo_agentic_evaluation.referenceless_eval.function_calling.pipeline.types i
|
|
|
16
15
|
ToolSpec,
|
|
17
16
|
)
|
|
18
17
|
from wxo_agentic_evaluation.type import Message
|
|
19
|
-
from wxo_agentic_evaluation.service_provider
|
|
18
|
+
from wxo_agentic_evaluation.service_provider import get_provider
|
|
20
19
|
|
|
21
20
|
class ReferencelessEvaluation:
|
|
22
21
|
"""
|
|
@@ -33,12 +32,12 @@ class ReferencelessEvaluation:
|
|
|
33
32
|
messages: List[Message],
|
|
34
33
|
model_id: str,
|
|
35
34
|
task_n: str,
|
|
36
|
-
dataset_name: str,
|
|
37
|
-
|
|
38
|
-
self.metrics_client =
|
|
35
|
+
dataset_name: str,):
|
|
36
|
+
|
|
37
|
+
self.metrics_client = get_provider(
|
|
39
38
|
model_id=model_id,
|
|
40
|
-
|
|
41
|
-
|
|
39
|
+
params={"min_new_tokens": 0, "decoding_method": "greedy", "max_new_tokens": 4096},
|
|
40
|
+
referenceless_eval=True
|
|
42
41
|
)
|
|
43
42
|
|
|
44
43
|
self.pipeline = ReflectionPipeline(
|
|
@@ -57,7 +56,6 @@ class ReferencelessEvaluation:
|
|
|
57
56
|
def _run_pipeline(self, examples: List[Mapping[str, Any]]):
|
|
58
57
|
results = []
|
|
59
58
|
for example in examples:
|
|
60
|
-
# self.pipeline.sy
|
|
61
59
|
result = self.pipeline.run_sync(
|
|
62
60
|
conversation=example["context"],
|
|
63
61
|
inventory=self.apis_specs,
|
|
@@ -1,21 +1,30 @@
|
|
|
1
1
|
from wxo_agentic_evaluation.service_provider.ollama_provider import OllamaProvider
|
|
2
2
|
from wxo_agentic_evaluation.service_provider.watsonx_provider import WatsonXProvider
|
|
3
3
|
from wxo_agentic_evaluation.service_provider.model_proxy_provider import ModelProxyProvider
|
|
4
|
+
from wxo_agentic_evaluation.service_provider.referenceless_provider_wrapper import ModelProxyProviderLLMKitWrapper, WatsonXLLMKitWrapper
|
|
4
5
|
from wxo_agentic_evaluation.arg_configs import ProviderConfig
|
|
5
6
|
|
|
6
7
|
import os
|
|
7
8
|
|
|
8
|
-
def _instantiate_provider(config: ProviderConfig, **kwargs):
|
|
9
|
+
def _instantiate_provider(config: ProviderConfig, is_referenceless_eval: bool = False, **kwargs):
|
|
9
10
|
if config.provider == "watsonx":
|
|
10
|
-
|
|
11
|
+
if is_referenceless_eval:
|
|
12
|
+
provider = WatsonXLLMKitWrapper
|
|
13
|
+
else:
|
|
14
|
+
provider = WatsonXProvider
|
|
15
|
+
return provider(model_id=config.model_id, **kwargs)
|
|
11
16
|
elif config.provider == "ollama":
|
|
12
17
|
return OllamaProvider(model_id=config.model_id, **kwargs)
|
|
13
18
|
elif config.provider == "model_proxy":
|
|
14
|
-
|
|
19
|
+
if is_referenceless_eval:
|
|
20
|
+
provider = ModelProxyProviderLLMKitWrapper
|
|
21
|
+
else:
|
|
22
|
+
provider = ModelProxyProvider
|
|
23
|
+
return provider(model_id=config.model_id, **kwargs)
|
|
15
24
|
else:
|
|
16
25
|
raise RuntimeError(f"target provider is not supported {config.provider}")
|
|
17
26
|
|
|
18
|
-
def get_provider(config: ProviderConfig = None, model_id: str = None, **kwargs):
|
|
27
|
+
def get_provider(config: ProviderConfig = None, model_id: str = None, referenceless_eval: bool = False, **kwargs):
|
|
19
28
|
if config:
|
|
20
29
|
return _instantiate_provider(config, **kwargs)
|
|
21
30
|
|
|
@@ -24,11 +33,11 @@ def get_provider(config: ProviderConfig = None, model_id: str = None, **kwargs):
|
|
|
24
33
|
|
|
25
34
|
if "WATSONX_APIKEY" in os.environ and "WATSONX_SPACE_ID" in os.environ:
|
|
26
35
|
config = ProviderConfig(provider="watsonx", model_id=model_id)
|
|
27
|
-
return _instantiate_provider(config, **kwargs)
|
|
36
|
+
return _instantiate_provider(config, referenceless_eval, **kwargs)
|
|
28
37
|
|
|
29
38
|
if "WO_API_KEY" in os.environ and "WO_INSTANCE" in os.environ:
|
|
30
39
|
config = ProviderConfig(provider="model_proxy", model_id=model_id)
|
|
31
|
-
return _instantiate_provider(config, **kwargs)
|
|
40
|
+
return _instantiate_provider(config, referenceless_eval, **kwargs)
|
|
32
41
|
|
|
33
42
|
raise RuntimeError(
|
|
34
43
|
"No provider found. Please either provide a config or set the required environment variables."
|
|
@@ -38,9 +38,10 @@ class ModelProxyProvider(Provider):
|
|
|
38
38
|
self.api_key = api_key
|
|
39
39
|
self.is_ibm_cloud = is_ibm_cloud_url(instance_url)
|
|
40
40
|
self.auth_url = AUTH_ENDPOINT_IBM_CLOUD if self.is_ibm_cloud else AUTH_ENDPOINT_AWS
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
self.
|
|
41
|
+
|
|
42
|
+
self.instance_url = instance_url
|
|
43
|
+
self.url = self.instance_url + "/ml/v1/text/generation?version=2024-05-01"
|
|
44
|
+
self.embedding_url = self.instance_url + "/ml/v1/text/embeddings"
|
|
44
45
|
|
|
45
46
|
self.lock = Lock()
|
|
46
47
|
self.token, self.refresh_time = self.get_token()
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
import requests
|
|
2
|
+
from typing import List, Mapping, Union, Optional, Any
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
|
|
5
|
+
import rich
|
|
6
|
+
|
|
7
|
+
from wxo_agentic_evaluation.service_provider.model_proxy_provider import ModelProxyProvider
|
|
8
|
+
from wxo_agentic_evaluation.service_provider.watsonx_provider import WatsonXProvider
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class LLMResponse:
|
|
12
|
+
"""
|
|
13
|
+
NOTE: Taken from LLM-Eval-Kit
|
|
14
|
+
Response object that can contain both content and tool calls
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
def __init__(self, content: str, tool_calls: Optional[List[Mapping[str, Any]]] = None):
|
|
18
|
+
self.content = content
|
|
19
|
+
self.tool_calls = tool_calls or []
|
|
20
|
+
|
|
21
|
+
def __str__(self) -> str:
|
|
22
|
+
"""Return the content of the response as a string."""
|
|
23
|
+
return self.content
|
|
24
|
+
|
|
25
|
+
def __repr__(self) -> str:
|
|
26
|
+
"""Return a string representation of the LLMResponse object."""
|
|
27
|
+
return f"LLMResponse(content='{self.content}', tool_calls={self.tool_calls})"
|
|
28
|
+
|
|
29
|
+
class LLMKitWrapper(ABC):
|
|
30
|
+
""" In the future this wrapper won't be neccesary.
|
|
31
|
+
Right now the referenceless code requires a `generate()` function for the metrics client.
|
|
32
|
+
In refactor, rewrite referenceless code so this wrapper is not needed.
|
|
33
|
+
"""
|
|
34
|
+
@abstractmethod
|
|
35
|
+
def chat():
|
|
36
|
+
pass
|
|
37
|
+
|
|
38
|
+
def generate(
|
|
39
|
+
self,
|
|
40
|
+
prompt: Union[str, List[Mapping[str, str]]],
|
|
41
|
+
*,
|
|
42
|
+
schema,
|
|
43
|
+
retries: int = 3,
|
|
44
|
+
generation_args: Optional[Any] = None,
|
|
45
|
+
**kwargs: Any
|
|
46
|
+
):
|
|
47
|
+
|
|
48
|
+
"""
|
|
49
|
+
In future, implement validation of response like in llmevalkit
|
|
50
|
+
"""
|
|
51
|
+
|
|
52
|
+
for attempt in range(1, retries + 1):
|
|
53
|
+
try:
|
|
54
|
+
raw_response = self.chat(prompt)
|
|
55
|
+
response = self._parse_llm_response(raw_response)
|
|
56
|
+
return response
|
|
57
|
+
except Exception as e:
|
|
58
|
+
rich.print(f"[b][r] Generation failed with error '{str(e)}' during `quick-eval` ... Attempt ({attempt} / {retries}))")
|
|
59
|
+
|
|
60
|
+
def _parse_llm_response(self, raw: Any) -> Union[str, LLMResponse]:
|
|
61
|
+
"""
|
|
62
|
+
Extract the generated text and tool calls from a watsonx response.
|
|
63
|
+
|
|
64
|
+
- For text generation: raw['results'][0]['generated_text']
|
|
65
|
+
- For chat: raw['choices'][0]['message']['content']
|
|
66
|
+
"""
|
|
67
|
+
content = ""
|
|
68
|
+
tool_calls = []
|
|
69
|
+
|
|
70
|
+
if isinstance(raw, dict) and "choices" in raw:
|
|
71
|
+
choices = raw["choices"]
|
|
72
|
+
if isinstance(choices, list) and choices:
|
|
73
|
+
first = choices[0]
|
|
74
|
+
msg = first.get("message")
|
|
75
|
+
if isinstance(msg, dict):
|
|
76
|
+
content = msg.get("content", "")
|
|
77
|
+
# Extract tool calls if present
|
|
78
|
+
if "tool_calls" in msg and msg["tool_calls"]:
|
|
79
|
+
tool_calls = []
|
|
80
|
+
for tool_call in msg["tool_calls"]:
|
|
81
|
+
tool_call_dict = {
|
|
82
|
+
"id": tool_call.get("id"),
|
|
83
|
+
"type": tool_call.get("type", "function"),
|
|
84
|
+
"function": {
|
|
85
|
+
"name": tool_call.get("function", {}).get("name"),
|
|
86
|
+
"arguments": tool_call.get("function", {}).get(
|
|
87
|
+
"arguments"
|
|
88
|
+
),
|
|
89
|
+
},
|
|
90
|
+
}
|
|
91
|
+
tool_calls.append(tool_call_dict)
|
|
92
|
+
elif "text" in first:
|
|
93
|
+
content = first["text"]
|
|
94
|
+
|
|
95
|
+
if not content and not tool_calls:
|
|
96
|
+
raise ValueError(f"Unexpected watsonx response format: {raw!r}")
|
|
97
|
+
|
|
98
|
+
# Return LLMResponse if tool calls exist, otherwise just content
|
|
99
|
+
if tool_calls:
|
|
100
|
+
return LLMResponse(content=content, tool_calls=tool_calls)
|
|
101
|
+
|
|
102
|
+
return content
|
|
103
|
+
|
|
104
|
+
class ModelProxyProviderLLMKitWrapper(ModelProxyProvider, LLMKitWrapper):
|
|
105
|
+
def chat(self, sentence: List[str]):
|
|
106
|
+
if self.model_id is None:
|
|
107
|
+
raise Exception("model id must be specified for text generation")
|
|
108
|
+
chat_url = f"{self.instance_url}/ml/v1/text/chat?version=2023-10-25"
|
|
109
|
+
self.refresh_token_if_expires()
|
|
110
|
+
headers = self.get_header()
|
|
111
|
+
data = {
|
|
112
|
+
"model_id": self.model_id,
|
|
113
|
+
"messages": sentence,
|
|
114
|
+
"parameters": self.params,
|
|
115
|
+
"space_id": "1",
|
|
116
|
+
"timeout": self.timeout
|
|
117
|
+
}
|
|
118
|
+
resp = requests.post(url=chat_url, headers=headers, json=data)
|
|
119
|
+
if resp.status_code == 200:
|
|
120
|
+
return resp.json()
|
|
121
|
+
else:
|
|
122
|
+
resp.raise_for_status()
|
|
123
|
+
|
|
124
|
+
class WatsonXLLMKitWrapper(WatsonXProvider, LLMKitWrapper):
|
|
125
|
+
def chat(self, sentence: list):
|
|
126
|
+
chat_url = f"{self.api_endpoint}/ml/v1/text/chat?version=2023-05-02"
|
|
127
|
+
headers = self.prepare_header()
|
|
128
|
+
data = {
|
|
129
|
+
"model_id": self.model_id,
|
|
130
|
+
"messages": sentence,
|
|
131
|
+
"parameters": self.params,
|
|
132
|
+
"space_id": self.space_id
|
|
133
|
+
}
|
|
134
|
+
resp = requests.post(url=chat_url, headers=headers, json=data)
|
|
135
|
+
if resp.status_code == 200:
|
|
136
|
+
return resp.json()
|
|
137
|
+
else:
|
|
138
|
+
resp.raise_for_status()
|
|
@@ -2,12 +2,10 @@ import os
|
|
|
2
2
|
import requests
|
|
3
3
|
import json
|
|
4
4
|
from types import MappingProxyType
|
|
5
|
-
from typing import List, Mapping, Union
|
|
6
|
-
from functools import singledispatchmethod
|
|
5
|
+
from typing import List, Mapping, Union
|
|
7
6
|
import dataclasses
|
|
8
7
|
from threading import Lock
|
|
9
8
|
import time
|
|
10
|
-
import rich
|
|
11
9
|
from wxo_agentic_evaluation.service_provider.provider import Provider
|
|
12
10
|
|
|
13
11
|
ACCESS_URL = "https://iam.cloud.ibm.com/identity/token"
|
|
@@ -90,12 +88,7 @@ class WatsonXProvider(Provider):
|
|
|
90
88
|
"Content-Type": "application/json"}
|
|
91
89
|
return headers
|
|
92
90
|
|
|
93
|
-
|
|
94
|
-
def generate(self, sentence):
|
|
95
|
-
raise ValueError(f"Input must either be a string or a list of dictionaries")
|
|
96
|
-
|
|
97
|
-
@generate.register
|
|
98
|
-
def _(self, sentence: str):
|
|
91
|
+
def _query(self, sentence: str):
|
|
99
92
|
headers = self.prepare_header()
|
|
100
93
|
|
|
101
94
|
data = {"model_id": self.model_id, "input": sentence,
|
|
@@ -107,22 +100,6 @@ class WatsonXProvider(Provider):
|
|
|
107
100
|
else:
|
|
108
101
|
resp.raise_for_status()
|
|
109
102
|
|
|
110
|
-
@generate.register
|
|
111
|
-
def _(self, sentence: list):
|
|
112
|
-
chat_url = f"{self.api_endpoint}/ml/v1/text/chat?version=2023-05-02"
|
|
113
|
-
headers = self.prepare_header()
|
|
114
|
-
data = {
|
|
115
|
-
"model_id": self.model_id,
|
|
116
|
-
"messages": sentence,
|
|
117
|
-
"parameters": self.params,
|
|
118
|
-
"space_id": self.space_id
|
|
119
|
-
}
|
|
120
|
-
resp = requests.post(url=chat_url, headers=headers, json=data)
|
|
121
|
-
if resp.status_code == 200:
|
|
122
|
-
return resp.json()
|
|
123
|
-
else:
|
|
124
|
-
resp.raise_for_status()
|
|
125
|
-
|
|
126
103
|
def _refresh_token(self):
|
|
127
104
|
# if we do not have a token or the current timestamp is 9 minutes away from expire.
|
|
128
105
|
if not self.access_token or time.time() > self.refresh_time:
|
|
@@ -134,7 +111,7 @@ class WatsonXProvider(Provider):
|
|
|
134
111
|
if self.model_id is None:
|
|
135
112
|
raise Exception("model id must be specified for text generation")
|
|
136
113
|
try:
|
|
137
|
-
response = self.
|
|
114
|
+
response = self._query(sentence)
|
|
138
115
|
if (generated_text := response.get("generated_text")):
|
|
139
116
|
return generated_text
|
|
140
117
|
elif (message := response.get("message")):
|
|
@@ -165,90 +142,6 @@ class WatsonXProvider(Provider):
|
|
|
165
142
|
else:
|
|
166
143
|
resp.raise_for_status()
|
|
167
144
|
|
|
168
|
-
class LLMResponse:
|
|
169
|
-
"""
|
|
170
|
-
NOTE: Taken from LLM-Eval-Kit
|
|
171
|
-
Response object that can contain both content and tool calls
|
|
172
|
-
"""
|
|
173
|
-
|
|
174
|
-
def __init__(self, content: str, tool_calls: Optional[List[Mapping[str, Any]]] = None):
|
|
175
|
-
self.content = content
|
|
176
|
-
self.tool_calls = tool_calls or []
|
|
177
|
-
|
|
178
|
-
def __str__(self) -> str:
|
|
179
|
-
"""Return the content of the response as a string."""
|
|
180
|
-
return self.content
|
|
181
|
-
|
|
182
|
-
def __repr__(self) -> str:
|
|
183
|
-
"""Return a string representation of the LLMResponse object."""
|
|
184
|
-
return f"LLMResponse(content='{self.content}', tool_calls={self.tool_calls})"
|
|
185
|
-
|
|
186
|
-
class WatsonXLLMKitWrapper(WatsonXProvider):
|
|
187
|
-
def generate(
|
|
188
|
-
self,
|
|
189
|
-
prompt: Union[str, List[Mapping[str, str]]],
|
|
190
|
-
*,
|
|
191
|
-
schema,
|
|
192
|
-
retries: int = 3,
|
|
193
|
-
generation_args: Optional[Any] = None,
|
|
194
|
-
**kwargs: Any
|
|
195
|
-
):
|
|
196
|
-
|
|
197
|
-
"""
|
|
198
|
-
In future, implement validation of response like in llmevalkit
|
|
199
|
-
"""
|
|
200
|
-
|
|
201
|
-
for attempt in range(1, retries + 1):
|
|
202
|
-
try:
|
|
203
|
-
raw_response = super().generate(prompt)
|
|
204
|
-
response = self._parse_llm_response(raw_response)
|
|
205
|
-
return response
|
|
206
|
-
except Exception as e:
|
|
207
|
-
rich.print(f"[b][r] WatsonX generation failed with error '{str(e)}' during `quick-eval` ... Attempt ({attempt} / {retries}))")
|
|
208
|
-
|
|
209
|
-
def _parse_llm_response(self, raw: Any) -> Union[str, LLMResponse]:
|
|
210
|
-
"""
|
|
211
|
-
Extract the generated text and tool calls from a watsonx response.
|
|
212
|
-
|
|
213
|
-
- For text generation: raw['results'][0]['generated_text']
|
|
214
|
-
- For chat: raw['choices'][0]['message']['content']
|
|
215
|
-
"""
|
|
216
|
-
content = ""
|
|
217
|
-
tool_calls = []
|
|
218
|
-
|
|
219
|
-
if isinstance(raw, dict) and "choices" in raw:
|
|
220
|
-
choices = raw["choices"]
|
|
221
|
-
if isinstance(choices, list) and choices:
|
|
222
|
-
first = choices[0]
|
|
223
|
-
msg = first.get("message")
|
|
224
|
-
if isinstance(msg, dict):
|
|
225
|
-
content = msg.get("content", "")
|
|
226
|
-
# Extract tool calls if present
|
|
227
|
-
if "tool_calls" in msg and msg["tool_calls"]:
|
|
228
|
-
tool_calls = []
|
|
229
|
-
for tool_call in msg["tool_calls"]:
|
|
230
|
-
tool_call_dict = {
|
|
231
|
-
"id": tool_call.get("id"),
|
|
232
|
-
"type": tool_call.get("type", "function"),
|
|
233
|
-
"function": {
|
|
234
|
-
"name": tool_call.get("function", {}).get("name"),
|
|
235
|
-
"arguments": tool_call.get("function", {}).get(
|
|
236
|
-
"arguments"
|
|
237
|
-
),
|
|
238
|
-
},
|
|
239
|
-
}
|
|
240
|
-
tool_calls.append(tool_call_dict)
|
|
241
|
-
elif "text" in first:
|
|
242
|
-
content = first["text"]
|
|
243
|
-
|
|
244
|
-
if not content and not tool_calls:
|
|
245
|
-
raise ValueError(f"Unexpected watsonx response format: {raw!r}")
|
|
246
|
-
|
|
247
|
-
# Return LLMResponse if tool calls exist, otherwise just content
|
|
248
|
-
if tool_calls:
|
|
249
|
-
return LLMResponse(content=content, tool_calls=tool_calls)
|
|
250
|
-
|
|
251
|
-
return content
|
|
252
145
|
|
|
253
146
|
if __name__ == "__main__":
|
|
254
147
|
provider = WatsonXProvider(model_id="meta-llama/llama-3-2-90b-vision-instruct")
|
|
File without changes
|