ibm-watsonx-orchestrate-evaluation-framework 1.1.1__py3-none-any.whl → 1.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ibm-watsonx-orchestrate-evaluation-framework might be problematic. Click here for more details.
- ibm_watsonx_orchestrate_evaluation_framework-1.1.2.dist-info/METADATA +34 -0
- {ibm_watsonx_orchestrate_evaluation_framework-1.1.1.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.2.dist-info}/RECORD +60 -60
- wxo_agentic_evaluation/analytics/tools/analyzer.py +36 -21
- wxo_agentic_evaluation/analytics/tools/main.py +18 -7
- wxo_agentic_evaluation/analytics/tools/types.py +26 -11
- wxo_agentic_evaluation/analytics/tools/ux.py +75 -31
- wxo_agentic_evaluation/analyze_run.py +69 -48
- wxo_agentic_evaluation/annotate.py +6 -4
- wxo_agentic_evaluation/arg_configs.py +8 -2
- wxo_agentic_evaluation/batch_annotate.py +78 -25
- wxo_agentic_evaluation/data_annotator.py +18 -13
- wxo_agentic_evaluation/description_quality_checker.py +20 -14
- wxo_agentic_evaluation/evaluation_package.py +114 -70
- wxo_agentic_evaluation/external_agent/__init__.py +18 -7
- wxo_agentic_evaluation/external_agent/external_validate.py +46 -35
- wxo_agentic_evaluation/external_agent/performance_test.py +32 -20
- wxo_agentic_evaluation/external_agent/types.py +12 -5
- wxo_agentic_evaluation/inference_backend.py +158 -73
- wxo_agentic_evaluation/llm_matching.py +4 -3
- wxo_agentic_evaluation/llm_rag_eval.py +7 -4
- wxo_agentic_evaluation/llm_user.py +7 -3
- wxo_agentic_evaluation/main.py +175 -67
- wxo_agentic_evaluation/metrics/llm_as_judge.py +2 -2
- wxo_agentic_evaluation/metrics/metrics.py +26 -12
- wxo_agentic_evaluation/prompt/template_render.py +32 -11
- wxo_agentic_evaluation/quick_eval.py +49 -23
- wxo_agentic_evaluation/record_chat.py +70 -33
- wxo_agentic_evaluation/red_teaming/attack_evaluator.py +58 -18
- wxo_agentic_evaluation/red_teaming/attack_generator.py +38 -18
- wxo_agentic_evaluation/red_teaming/attack_runner.py +43 -27
- wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/base.py +3 -1
- wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/loader.py +23 -15
- wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/adapters.py +13 -8
- wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/pipeline.py +41 -13
- wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/semantic_checker.py +26 -16
- wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/static_checker.py +17 -11
- wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/types.py +44 -29
- wxo_agentic_evaluation/referenceless_eval/metrics/field.py +13 -5
- wxo_agentic_evaluation/referenceless_eval/metrics/metric.py +16 -5
- wxo_agentic_evaluation/referenceless_eval/metrics/metrics_runner.py +8 -3
- wxo_agentic_evaluation/referenceless_eval/metrics/prompt.py +6 -2
- wxo_agentic_evaluation/referenceless_eval/metrics/utils.py +5 -1
- wxo_agentic_evaluation/referenceless_eval/prompt/runner.py +16 -3
- wxo_agentic_evaluation/referenceless_eval/referenceless_eval.py +23 -12
- wxo_agentic_evaluation/resource_map.py +2 -1
- wxo_agentic_evaluation/service_instance.py +24 -11
- wxo_agentic_evaluation/service_provider/__init__.py +33 -13
- wxo_agentic_evaluation/service_provider/model_proxy_provider.py +129 -26
- wxo_agentic_evaluation/service_provider/ollama_provider.py +10 -11
- wxo_agentic_evaluation/service_provider/provider.py +0 -1
- wxo_agentic_evaluation/service_provider/referenceless_provider_wrapper.py +34 -21
- wxo_agentic_evaluation/service_provider/watsonx_provider.py +50 -22
- wxo_agentic_evaluation/tool_planner.py +128 -44
- wxo_agentic_evaluation/type.py +12 -9
- wxo_agentic_evaluation/utils/__init__.py +1 -0
- wxo_agentic_evaluation/utils/open_ai_tool_extractor.py +41 -20
- wxo_agentic_evaluation/utils/rich_utils.py +23 -9
- wxo_agentic_evaluation/utils/utils.py +83 -52
- ibm_watsonx_orchestrate_evaluation_framework-1.1.1.dist-info/METADATA +0 -386
- {ibm_watsonx_orchestrate_evaluation_framework-1.1.1.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.2.dist-info}/WHEEL +0 -0
- {ibm_watsonx_orchestrate_evaluation_framework-1.1.1.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.2.dist-info}/top_level.txt +0 -0
|
@@ -1,8 +1,7 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import requests
|
|
3
3
|
import time
|
|
4
|
-
|
|
5
|
-
from typing import List
|
|
4
|
+
from typing import List, Tuple
|
|
6
5
|
from threading import Lock
|
|
7
6
|
|
|
8
7
|
from wxo_agentic_evaluation.service_provider.provider import Provider
|
|
@@ -12,6 +11,22 @@ AUTH_ENDPOINT_AWS = "https://iam.platform.saas.ibm.com/siusermgr/api/1.0/apikeys
|
|
|
12
11
|
AUTH_ENDPOINT_IBM_CLOUD = "https://iam.cloud.ibm.com/identity/token"
|
|
13
12
|
DEFAULT_PARAM = {"min_new_tokens": 1, "decoding_method": "greedy", "max_new_tokens": 400}
|
|
14
13
|
|
|
14
|
+
def _infer_cpd_auth_url(instance_url: str) -> str:
|
|
15
|
+
inst = (instance_url or "").rstrip("/")
|
|
16
|
+
if not inst:
|
|
17
|
+
return "/icp4d-api/v1/authorize"
|
|
18
|
+
if "/orchestrate" in inst:
|
|
19
|
+
base = inst.split("/orchestrate", 1)[0].rstrip("/")
|
|
20
|
+
return base + "/icp4d-api/v1/authorize"
|
|
21
|
+
return inst + "/icp4d-api/v1/authorize"
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _normalize_cpd_auth_url(url: str) -> str:
|
|
25
|
+
u = (url or "").rstrip("/")
|
|
26
|
+
if u.endswith("/icp4d-api"):
|
|
27
|
+
return u + "/v1/authorize"
|
|
28
|
+
return url
|
|
29
|
+
|
|
15
30
|
|
|
16
31
|
class ModelProxyProvider(Provider):
|
|
17
32
|
def __init__(
|
|
@@ -26,20 +41,43 @@ class ModelProxyProvider(Provider):
|
|
|
26
41
|
super().__init__()
|
|
27
42
|
|
|
28
43
|
instance_url = os.environ.get("WO_INSTANCE", instance_url)
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
raise RuntimeError("instance url and WO apikey must be specified to use WO model proxy")
|
|
44
|
+
if not instance_url:
|
|
45
|
+
raise RuntimeError("instance url must be specified to use WO model proxy")
|
|
32
46
|
|
|
33
47
|
self.timeout = timeout
|
|
34
|
-
self.model_id = model_id
|
|
35
|
-
|
|
48
|
+
self.model_id = os.environ.get("MODEL_OVERRIDE",model_id)
|
|
36
49
|
self.embedding_model_id = embedding_model_id
|
|
37
50
|
|
|
38
|
-
self.api_key = api_key
|
|
51
|
+
self.api_key = os.environ.get("WO_API_KEY", api_key)
|
|
52
|
+
self.username = os.environ.get("WO_USERNAME", None)
|
|
53
|
+
self.password = os.environ.get("WO_PASSWORD", None)
|
|
54
|
+
self.auth_type = os.environ.get("WO_AUTH_TYPE", "").lower() # explicit override if set, otherwise inferred- match ADK values
|
|
55
|
+
explicit_auth_url = os.environ.get("AUTHORIZATION_URL", None)
|
|
56
|
+
|
|
39
57
|
self.is_ibm_cloud = is_ibm_cloud_url(instance_url)
|
|
40
|
-
self.
|
|
41
|
-
|
|
42
|
-
self.
|
|
58
|
+
self.instance_url = instance_url.rstrip("/")
|
|
59
|
+
|
|
60
|
+
self.auth_mode, self.auth_url = self._resolve_auth_mode_and_url(explicit_auth_url=explicit_auth_url)
|
|
61
|
+
self._wo_ssl_verify = os.environ.get("WO_SSL_VERIFY", "true").lower() != "false"
|
|
62
|
+
env_space_id = os.environ.get("WATSONX_SPACE_ID", None)
|
|
63
|
+
if self.auth_mode == "cpd":
|
|
64
|
+
if not env_space_id or not env_space_id.strip():
|
|
65
|
+
raise RuntimeError("CPD mode requires WATSONX_SPACE_ID environment variable to be set")
|
|
66
|
+
self.space_id = env_space_id.strip()
|
|
67
|
+
else:
|
|
68
|
+
self.space_id = (env_space_id.strip() if env_space_id and env_space_id.strip() else "1")
|
|
69
|
+
|
|
70
|
+
if self.auth_mode == "cpd":
|
|
71
|
+
if "/orchestrate" in self.instance_url:
|
|
72
|
+
self.instance_url = self.instance_url.split("/orchestrate", 1)[0].rstrip("/")
|
|
73
|
+
if not self.username:
|
|
74
|
+
raise RuntimeError("CPD auth requires WO_USERNAME to be set")
|
|
75
|
+
if not (self.password or self.api_key):
|
|
76
|
+
raise RuntimeError("CPD auth requires either WO_PASSWORD or WO_API_KEY to be set (with WO_USERNAME)")
|
|
77
|
+
else:
|
|
78
|
+
if not self.api_key:
|
|
79
|
+
raise RuntimeError("WO_API_KEY must be specified for SaaS or IBM IAM auth")
|
|
80
|
+
|
|
43
81
|
self.url = self.instance_url + "/ml/v1/text/generation?version=2024-05-01"
|
|
44
82
|
self.embedding_url = self.instance_url + "/ml/v1/text/embeddings"
|
|
45
83
|
|
|
@@ -47,20 +85,85 @@ class ModelProxyProvider(Provider):
|
|
|
47
85
|
self.token, self.refresh_time = self.get_token()
|
|
48
86
|
self.params = params if params else DEFAULT_PARAM
|
|
49
87
|
|
|
50
|
-
def
|
|
88
|
+
def _resolve_auth_mode_and_url(
|
|
89
|
+
self,
|
|
90
|
+
explicit_auth_url: str | None
|
|
91
|
+
) -> Tuple[str, str]:
|
|
92
|
+
"""
|
|
93
|
+
Returns (auth_mode, auth_url)
|
|
94
|
+
- auth_mode: "cpd" | "ibm_iam" | "saas"
|
|
95
|
+
"""
|
|
96
|
+
if explicit_auth_url:
|
|
97
|
+
if "/icp4d-api" in explicit_auth_url:
|
|
98
|
+
return "cpd", _normalize_cpd_auth_url(explicit_auth_url)
|
|
99
|
+
if self.auth_type == "ibm_iam":
|
|
100
|
+
return "ibm_iam", explicit_auth_url
|
|
101
|
+
elif self.auth_type == "saas":
|
|
102
|
+
return "saas", explicit_auth_url
|
|
103
|
+
else:
|
|
104
|
+
mode = "ibm_iam" if self.is_ibm_cloud else "saas"
|
|
105
|
+
return mode, explicit_auth_url
|
|
106
|
+
|
|
107
|
+
if self.auth_type == "cpd":
|
|
108
|
+
inferred_cpd_url = _infer_cpd_auth_url(self.instance_url)
|
|
109
|
+
return "cpd", inferred_cpd_url
|
|
110
|
+
if self.auth_type == "ibm_iam":
|
|
111
|
+
return "ibm_iam", AUTH_ENDPOINT_IBM_CLOUD
|
|
112
|
+
if self.auth_type == "saas":
|
|
113
|
+
return "saas", AUTH_ENDPOINT_AWS
|
|
114
|
+
|
|
115
|
+
if "/orchestrate" in self.instance_url:
|
|
116
|
+
inferred_cpd_url = _infer_cpd_auth_url(self.instance_url)
|
|
117
|
+
return "cpd", inferred_cpd_url
|
|
118
|
+
|
|
51
119
|
if self.is_ibm_cloud:
|
|
52
|
-
|
|
53
|
-
resp = requests.post(self.auth_url, data=payload)
|
|
54
|
-
token_key = "access_token"
|
|
120
|
+
return "ibm_iam", AUTH_ENDPOINT_IBM_CLOUD
|
|
55
121
|
else:
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
122
|
+
return "saas", AUTH_ENDPOINT_AWS
|
|
123
|
+
|
|
124
|
+
def get_token(self):
|
|
125
|
+
headers = {}
|
|
126
|
+
post_args = {}
|
|
127
|
+
timeout = 10
|
|
128
|
+
exchange_url = self.auth_url
|
|
129
|
+
|
|
130
|
+
if self.auth_mode == "ibm_iam":
|
|
131
|
+
headers = {"Accept": "application/json", "Content-Type": "application/x-www-form-urlencoded"}
|
|
132
|
+
form_data = {
|
|
133
|
+
"grant_type": "urn:ibm:params:oauth:grant-type:apikey",
|
|
134
|
+
"apikey": self.api_key
|
|
135
|
+
}
|
|
136
|
+
post_args = {"data": form_data}
|
|
137
|
+
resp = requests.post(exchange_url, headers=headers, timeout=timeout, verify=self._wo_ssl_verify, **post_args)
|
|
138
|
+
elif self.auth_mode == "cpd":
|
|
139
|
+
headers = {"Accept": "application/json", "Content-Type": "application/json"}
|
|
140
|
+
body = {"username": self.username}
|
|
141
|
+
if self.password:
|
|
142
|
+
body["password"] = self.password
|
|
143
|
+
else:
|
|
144
|
+
body["api_key"] = self.api_key
|
|
145
|
+
timeout = self.timeout
|
|
146
|
+
resp = requests.post(exchange_url, headers=headers, json=body, timeout=timeout, verify=self._wo_ssl_verify)
|
|
147
|
+
else:
|
|
148
|
+
headers = {"Accept": "application/json", "Content-Type": "application/json"}
|
|
149
|
+
post_args = {"json": {"apikey": self.api_key}}
|
|
150
|
+
resp = requests.post(exchange_url, headers=headers, timeout=timeout, verify=self._wo_ssl_verify, **post_args)
|
|
151
|
+
|
|
59
152
|
if resp.status_code == 200:
|
|
60
153
|
json_obj = resp.json()
|
|
61
|
-
token = json_obj
|
|
62
|
-
|
|
63
|
-
|
|
154
|
+
token = json_obj.get("access_token") or json_obj.get("token")
|
|
155
|
+
if not token:
|
|
156
|
+
raise RuntimeError(f"No token field found in response: {json_obj!r}")
|
|
157
|
+
|
|
158
|
+
expires_in = json_obj.get("expires_in")
|
|
159
|
+
try:
|
|
160
|
+
expires_in = int(expires_in) if expires_in is not None else None
|
|
161
|
+
except Exception:
|
|
162
|
+
expires_in = None
|
|
163
|
+
if not expires_in or expires_in <= 0:
|
|
164
|
+
expires_in = int(os.environ.get("TOKEN_DEFAULT_EXPIRES_IN", 1))
|
|
165
|
+
|
|
166
|
+
refresh_time = time.time() + int(0.8 * expires_in)
|
|
64
167
|
return token, refresh_time
|
|
65
168
|
|
|
66
169
|
resp.raise_for_status()
|
|
@@ -80,9 +183,9 @@ class ModelProxyProvider(Provider):
|
|
|
80
183
|
|
|
81
184
|
self.refresh_token_if_expires()
|
|
82
185
|
headers = self.get_header()
|
|
83
|
-
payload = {"inputs": sentences, "model_id": self.embedding_model_id, "space_id":
|
|
186
|
+
payload = {"inputs": sentences, "model_id": self.embedding_model_id, "space_id": self.space_id}
|
|
84
187
|
#"timeout": self.timeout}
|
|
85
|
-
resp = requests.post(self.embedding_url, json=payload, headers=headers)
|
|
188
|
+
resp = requests.post(self.embedding_url, json=payload, headers=headers, verify=self._wo_ssl_verify)
|
|
86
189
|
|
|
87
190
|
if resp.status_code == 200:
|
|
88
191
|
json_obj = resp.json()
|
|
@@ -95,9 +198,9 @@ class ModelProxyProvider(Provider):
|
|
|
95
198
|
raise Exception("model id must be specified for text generation")
|
|
96
199
|
self.refresh_token_if_expires()
|
|
97
200
|
headers = self.get_header()
|
|
98
|
-
payload = {"input": sentence, "model_id": self.model_id, "space_id":
|
|
201
|
+
payload = {"input": sentence, "model_id": self.model_id, "space_id": self.space_id,
|
|
99
202
|
"timeout": self.timeout, "parameters": self.params}
|
|
100
|
-
resp = requests.post(self.url, json=payload, headers=headers)
|
|
203
|
+
resp = requests.post(self.url, json=payload, headers=headers, verify=self._wo_ssl_verify)
|
|
101
204
|
if resp.status_code == 200:
|
|
102
205
|
return resp.json()["results"][0]["generated_text"]
|
|
103
206
|
|
|
@@ -106,4 +209,4 @@ class ModelProxyProvider(Provider):
|
|
|
106
209
|
|
|
107
210
|
if __name__ == "__main__":
|
|
108
211
|
provider = ModelProxyProvider(model_id="meta-llama/llama-3-3-70b-instruct", embedding_model_id="ibm/slate-30m-english-rtrvr")
|
|
109
|
-
print(provider.query("ok"))
|
|
212
|
+
print(provider.query("ok"))
|
|
@@ -1,17 +1,16 @@
|
|
|
1
|
-
import requests
|
|
2
1
|
import json
|
|
3
|
-
from wxo_agentic_evaluation.service_provider.provider import Provider
|
|
4
|
-
from typing import List
|
|
5
2
|
import os
|
|
3
|
+
from typing import List
|
|
4
|
+
|
|
5
|
+
import requests
|
|
6
|
+
|
|
7
|
+
from wxo_agentic_evaluation.service_provider.provider import Provider
|
|
6
8
|
|
|
7
9
|
OLLAMA_URL = os.environ.get("OLLAMA_HOST", "http://localhost:11434")
|
|
8
10
|
|
|
9
11
|
|
|
10
12
|
class OllamaProvider(Provider):
|
|
11
|
-
def __init__(
|
|
12
|
-
self,
|
|
13
|
-
model_id=None
|
|
14
|
-
):
|
|
13
|
+
def __init__(self, model_id=None):
|
|
15
14
|
self.url = OLLAMA_URL + "/api/generate"
|
|
16
15
|
self.model_id = model_id
|
|
17
16
|
super().__init__()
|
|
@@ -20,14 +19,14 @@ class OllamaProvider(Provider):
|
|
|
20
19
|
payload = {"model": self.model_id, "prompt": sentence}
|
|
21
20
|
resp = requests.post(self.url, json=payload, stream=True)
|
|
22
21
|
final_text = ""
|
|
23
|
-
data = b
|
|
22
|
+
data = b""
|
|
24
23
|
for chunk in resp:
|
|
25
24
|
data += chunk
|
|
26
|
-
if data.endswith(b
|
|
25
|
+
if data.endswith(b"\n"):
|
|
27
26
|
json_obj = json.loads(data)
|
|
28
27
|
if not json_obj["done"] and json_obj["response"]:
|
|
29
28
|
final_text += json_obj["response"]
|
|
30
|
-
data = b
|
|
29
|
+
data = b""
|
|
31
30
|
|
|
32
31
|
return final_text
|
|
33
32
|
|
|
@@ -37,4 +36,4 @@ class OllamaProvider(Provider):
|
|
|
37
36
|
|
|
38
37
|
if __name__ == "__main__":
|
|
39
38
|
provider = OllamaProvider(model_id="llama3.1:8b")
|
|
40
|
-
print(provider.query("ok"))
|
|
39
|
+
print(provider.query("ok"))
|
|
@@ -1,11 +1,15 @@
|
|
|
1
|
-
import requests
|
|
2
|
-
from typing import List, Mapping, Union, Optional, Any
|
|
3
1
|
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Any, List, Mapping, Optional, Union
|
|
4
3
|
|
|
4
|
+
import requests
|
|
5
5
|
import rich
|
|
6
6
|
|
|
7
|
-
from wxo_agentic_evaluation.service_provider.model_proxy_provider import
|
|
8
|
-
|
|
7
|
+
from wxo_agentic_evaluation.service_provider.model_proxy_provider import (
|
|
8
|
+
ModelProxyProvider,
|
|
9
|
+
)
|
|
10
|
+
from wxo_agentic_evaluation.service_provider.watsonx_provider import (
|
|
11
|
+
WatsonXProvider,
|
|
12
|
+
)
|
|
9
13
|
|
|
10
14
|
|
|
11
15
|
class LLMResponse:
|
|
@@ -14,7 +18,9 @@ class LLMResponse:
|
|
|
14
18
|
Response object that can contain both content and tool calls
|
|
15
19
|
"""
|
|
16
20
|
|
|
17
|
-
def __init__(
|
|
21
|
+
def __init__(
|
|
22
|
+
self, content: str, tool_calls: Optional[List[Mapping[str, Any]]] = None
|
|
23
|
+
):
|
|
18
24
|
self.content = content
|
|
19
25
|
self.tool_calls = tool_calls or []
|
|
20
26
|
|
|
@@ -26,25 +32,26 @@ class LLMResponse:
|
|
|
26
32
|
"""Return a string representation of the LLMResponse object."""
|
|
27
33
|
return f"LLMResponse(content='{self.content}', tool_calls={self.tool_calls})"
|
|
28
34
|
|
|
35
|
+
|
|
29
36
|
class LLMKitWrapper(ABC):
|
|
30
|
-
"""
|
|
37
|
+
"""In the future this wrapper won't be neccesary.
|
|
31
38
|
Right now the referenceless code requires a `generate()` function for the metrics client.
|
|
32
39
|
In refactor, rewrite referenceless code so this wrapper is not needed.
|
|
33
40
|
"""
|
|
41
|
+
|
|
34
42
|
@abstractmethod
|
|
35
43
|
def chat():
|
|
36
44
|
pass
|
|
37
45
|
|
|
38
46
|
def generate(
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
47
|
+
self,
|
|
48
|
+
prompt: Union[str, List[Mapping[str, str]]],
|
|
49
|
+
*,
|
|
50
|
+
schema,
|
|
51
|
+
retries: int = 3,
|
|
52
|
+
generation_args: Optional[Any] = None,
|
|
53
|
+
**kwargs: Any,
|
|
54
|
+
):
|
|
48
55
|
"""
|
|
49
56
|
In future, implement validation of response like in llmevalkit
|
|
50
57
|
"""
|
|
@@ -55,7 +62,9 @@ class LLMKitWrapper(ABC):
|
|
|
55
62
|
response = self._parse_llm_response(raw_response)
|
|
56
63
|
return response
|
|
57
64
|
except Exception as e:
|
|
58
|
-
rich.print(
|
|
65
|
+
rich.print(
|
|
66
|
+
f"[b][r] Generation failed with error '{str(e)}' during `quick-eval` ... Attempt ({attempt} / {retries}))"
|
|
67
|
+
)
|
|
59
68
|
|
|
60
69
|
def _parse_llm_response(self, raw: Any) -> Union[str, LLMResponse]:
|
|
61
70
|
"""
|
|
@@ -82,10 +91,12 @@ class LLMKitWrapper(ABC):
|
|
|
82
91
|
"id": tool_call.get("id"),
|
|
83
92
|
"type": tool_call.get("type", "function"),
|
|
84
93
|
"function": {
|
|
85
|
-
"name": tool_call.get("function", {}).get(
|
|
86
|
-
|
|
87
|
-
"arguments"
|
|
94
|
+
"name": tool_call.get("function", {}).get(
|
|
95
|
+
"name"
|
|
88
96
|
),
|
|
97
|
+
"arguments": tool_call.get(
|
|
98
|
+
"function", {}
|
|
99
|
+
).get("arguments"),
|
|
89
100
|
},
|
|
90
101
|
}
|
|
91
102
|
tool_calls.append(tool_call_dict)
|
|
@@ -101,6 +112,7 @@ class LLMKitWrapper(ABC):
|
|
|
101
112
|
|
|
102
113
|
return content
|
|
103
114
|
|
|
115
|
+
|
|
104
116
|
class ModelProxyProviderLLMKitWrapper(ModelProxyProvider, LLMKitWrapper):
|
|
105
117
|
def chat(self, sentence: List[str]):
|
|
106
118
|
if self.model_id is None:
|
|
@@ -113,7 +125,7 @@ class ModelProxyProviderLLMKitWrapper(ModelProxyProvider, LLMKitWrapper):
|
|
|
113
125
|
"messages": sentence,
|
|
114
126
|
"parameters": self.params,
|
|
115
127
|
"space_id": "1",
|
|
116
|
-
"timeout": self.timeout
|
|
128
|
+
"timeout": self.timeout,
|
|
117
129
|
}
|
|
118
130
|
resp = requests.post(url=chat_url, headers=headers, json=data)
|
|
119
131
|
if resp.status_code == 200:
|
|
@@ -121,6 +133,7 @@ class ModelProxyProviderLLMKitWrapper(ModelProxyProvider, LLMKitWrapper):
|
|
|
121
133
|
else:
|
|
122
134
|
resp.raise_for_status()
|
|
123
135
|
|
|
136
|
+
|
|
124
137
|
class WatsonXLLMKitWrapper(WatsonXProvider, LLMKitWrapper):
|
|
125
138
|
def chat(self, sentence: list):
|
|
126
139
|
chat_url = f"{self.api_endpoint}/ml/v1/text/chat?version=2023-05-02"
|
|
@@ -129,7 +142,7 @@ class WatsonXLLMKitWrapper(WatsonXProvider, LLMKitWrapper):
|
|
|
129
142
|
"model_id": self.model_id,
|
|
130
143
|
"messages": sentence,
|
|
131
144
|
"parameters": self.params,
|
|
132
|
-
"space_id": self.space_id
|
|
145
|
+
"space_id": self.space_id,
|
|
133
146
|
}
|
|
134
147
|
resp = requests.post(url=chat_url, headers=headers, json=data)
|
|
135
148
|
if resp.status_code == 200:
|
|
@@ -1,11 +1,13 @@
|
|
|
1
|
-
import
|
|
2
|
-
import requests
|
|
1
|
+
import dataclasses
|
|
3
2
|
import json
|
|
3
|
+
import os
|
|
4
|
+
import time
|
|
5
|
+
from threading import Lock
|
|
4
6
|
from types import MappingProxyType
|
|
5
7
|
from typing import List, Mapping, Union
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
8
|
+
|
|
9
|
+
import requests
|
|
10
|
+
|
|
9
11
|
from wxo_agentic_evaluation.service_provider.provider import Provider
|
|
10
12
|
|
|
11
13
|
ACCESS_URL = "https://iam.cloud.ibm.com/identity/token"
|
|
@@ -36,7 +38,9 @@ class WatsonXProvider(Provider):
|
|
|
36
38
|
super().__init__()
|
|
37
39
|
self.url = url
|
|
38
40
|
if (embedding_model_id is None) and (model_id is None):
|
|
39
|
-
raise Exception(
|
|
41
|
+
raise Exception(
|
|
42
|
+
"either model_id or embedding_model_id must be specified"
|
|
43
|
+
)
|
|
40
44
|
self.model_id = model_id
|
|
41
45
|
api_key = os.environ.get("WATSONX_APIKEY", api_key)
|
|
42
46
|
if not api_key:
|
|
@@ -56,7 +60,7 @@ class WatsonXProvider(Provider):
|
|
|
56
60
|
self.lock = Lock()
|
|
57
61
|
|
|
58
62
|
self.params = params if params else DEFAULT_PARAM
|
|
59
|
-
|
|
63
|
+
|
|
60
64
|
if isinstance(self.params, MappingProxyType):
|
|
61
65
|
self.params = dict(self.params)
|
|
62
66
|
if dataclasses.is_dataclass(self.params):
|
|
@@ -68,7 +72,10 @@ class WatsonXProvider(Provider):
|
|
|
68
72
|
|
|
69
73
|
def _get_access_token(self):
|
|
70
74
|
response = requests.post(
|
|
71
|
-
self.url,
|
|
75
|
+
self.url,
|
|
76
|
+
headers=ACCESS_HEADER,
|
|
77
|
+
data=self.access_data,
|
|
78
|
+
timeout=self.timeout,
|
|
72
79
|
)
|
|
73
80
|
if response.status_code == 200:
|
|
74
81
|
token_data = json.loads(response.text)
|
|
@@ -84,16 +91,24 @@ class WatsonXProvider(Provider):
|
|
|
84
91
|
)
|
|
85
92
|
|
|
86
93
|
def prepare_header(self):
|
|
87
|
-
headers = {
|
|
88
|
-
|
|
94
|
+
headers = {
|
|
95
|
+
"Authorization": f"Bearer {self.access_token}",
|
|
96
|
+
"Content-Type": "application/json",
|
|
97
|
+
}
|
|
89
98
|
return headers
|
|
90
99
|
|
|
91
100
|
def _query(self, sentence: str):
|
|
92
101
|
headers = self.prepare_header()
|
|
93
102
|
|
|
94
|
-
data = {
|
|
95
|
-
|
|
96
|
-
|
|
103
|
+
data = {
|
|
104
|
+
"model_id": self.model_id,
|
|
105
|
+
"input": sentence,
|
|
106
|
+
"parameters": self.params,
|
|
107
|
+
"space_id": self.space_id,
|
|
108
|
+
}
|
|
109
|
+
generation_url = (
|
|
110
|
+
f"{self.api_endpoint}/ml/v1/text/generation?version=2023-05-02"
|
|
111
|
+
)
|
|
97
112
|
resp = requests.post(url=generation_url, headers=headers, json=data)
|
|
98
113
|
if resp.status_code == 200:
|
|
99
114
|
return resp.json()["results"][0]
|
|
@@ -105,20 +120,25 @@ class WatsonXProvider(Provider):
|
|
|
105
120
|
if not self.access_token or time.time() > self.refresh_time:
|
|
106
121
|
with self.lock:
|
|
107
122
|
if not self.access_token or time.time() > self.refresh_time:
|
|
108
|
-
|
|
123
|
+
(
|
|
124
|
+
self.access_token,
|
|
125
|
+
self.refresh_time,
|
|
126
|
+
) = self._get_access_token()
|
|
109
127
|
|
|
110
128
|
def query(self, sentence: Union[str, Mapping[str, str]]) -> str:
|
|
111
129
|
if self.model_id is None:
|
|
112
130
|
raise Exception("model id must be specified for text generation")
|
|
113
131
|
try:
|
|
114
132
|
response = self._query(sentence)
|
|
115
|
-
if
|
|
133
|
+
if generated_text := response.get("generated_text"):
|
|
116
134
|
return generated_text
|
|
117
|
-
elif
|
|
135
|
+
elif message := response.get("message"):
|
|
118
136
|
return message
|
|
119
137
|
else:
|
|
120
|
-
raise ValueError(
|
|
121
|
-
|
|
138
|
+
raise ValueError(
|
|
139
|
+
f"Unexpected response from WatsonX: {response}"
|
|
140
|
+
)
|
|
141
|
+
|
|
122
142
|
except Exception as e:
|
|
123
143
|
with self.lock:
|
|
124
144
|
if "authentication_token_expired" in str(e):
|
|
@@ -130,12 +150,18 @@ class WatsonXProvider(Provider):
|
|
|
130
150
|
|
|
131
151
|
def encode(self, sentences: List[str]) -> List[list]:
|
|
132
152
|
if self.embedding_model_id is None:
|
|
133
|
-
raise Exception(
|
|
153
|
+
raise Exception(
|
|
154
|
+
"embedding model id must be specified for text encoding"
|
|
155
|
+
)
|
|
134
156
|
|
|
135
157
|
headers = self.prepare_header()
|
|
136
158
|
url = f"{self.api_endpoint}/ml/v1/text/embeddings?version=2023-10-25"
|
|
137
159
|
|
|
138
|
-
data = {
|
|
160
|
+
data = {
|
|
161
|
+
"inputs": sentences,
|
|
162
|
+
"model_id": self.model_id,
|
|
163
|
+
"space_id": self.space_id,
|
|
164
|
+
}
|
|
139
165
|
resp = requests.post(url=url, headers=headers, json=data)
|
|
140
166
|
if resp.status_code == 200:
|
|
141
167
|
return [entry["embedding"] for entry in resp.json()["results"]]
|
|
@@ -144,7 +170,9 @@ class WatsonXProvider(Provider):
|
|
|
144
170
|
|
|
145
171
|
|
|
146
172
|
if __name__ == "__main__":
|
|
147
|
-
provider = WatsonXProvider(
|
|
173
|
+
provider = WatsonXProvider(
|
|
174
|
+
model_id="meta-llama/llama-3-2-90b-vision-instruct"
|
|
175
|
+
)
|
|
148
176
|
|
|
149
177
|
prompt = """
|
|
150
178
|
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
|
@@ -176,4 +204,4 @@ Usernwaters did not take anytime off during the period<|eot_id|>
|
|
|
176
204
|
<|eot_id|><|start_header_id|>user<|end_header_id|>
|
|
177
205
|
"""
|
|
178
206
|
|
|
179
|
-
print(provider.query(prompt))
|
|
207
|
+
print(provider.query(prompt))
|