ibm-watsonx-orchestrate-evaluation-framework 1.1.1__py3-none-any.whl → 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ibm-watsonx-orchestrate-evaluation-framework might be problematic. Click here for more details.

Files changed (61) hide show
  1. ibm_watsonx_orchestrate_evaluation_framework-1.1.2.dist-info/METADATA +34 -0
  2. {ibm_watsonx_orchestrate_evaluation_framework-1.1.1.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.2.dist-info}/RECORD +60 -60
  3. wxo_agentic_evaluation/analytics/tools/analyzer.py +36 -21
  4. wxo_agentic_evaluation/analytics/tools/main.py +18 -7
  5. wxo_agentic_evaluation/analytics/tools/types.py +26 -11
  6. wxo_agentic_evaluation/analytics/tools/ux.py +75 -31
  7. wxo_agentic_evaluation/analyze_run.py +69 -48
  8. wxo_agentic_evaluation/annotate.py +6 -4
  9. wxo_agentic_evaluation/arg_configs.py +8 -2
  10. wxo_agentic_evaluation/batch_annotate.py +78 -25
  11. wxo_agentic_evaluation/data_annotator.py +18 -13
  12. wxo_agentic_evaluation/description_quality_checker.py +20 -14
  13. wxo_agentic_evaluation/evaluation_package.py +114 -70
  14. wxo_agentic_evaluation/external_agent/__init__.py +18 -7
  15. wxo_agentic_evaluation/external_agent/external_validate.py +46 -35
  16. wxo_agentic_evaluation/external_agent/performance_test.py +32 -20
  17. wxo_agentic_evaluation/external_agent/types.py +12 -5
  18. wxo_agentic_evaluation/inference_backend.py +158 -73
  19. wxo_agentic_evaluation/llm_matching.py +4 -3
  20. wxo_agentic_evaluation/llm_rag_eval.py +7 -4
  21. wxo_agentic_evaluation/llm_user.py +7 -3
  22. wxo_agentic_evaluation/main.py +175 -67
  23. wxo_agentic_evaluation/metrics/llm_as_judge.py +2 -2
  24. wxo_agentic_evaluation/metrics/metrics.py +26 -12
  25. wxo_agentic_evaluation/prompt/template_render.py +32 -11
  26. wxo_agentic_evaluation/quick_eval.py +49 -23
  27. wxo_agentic_evaluation/record_chat.py +70 -33
  28. wxo_agentic_evaluation/red_teaming/attack_evaluator.py +58 -18
  29. wxo_agentic_evaluation/red_teaming/attack_generator.py +38 -18
  30. wxo_agentic_evaluation/red_teaming/attack_runner.py +43 -27
  31. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/base.py +3 -1
  32. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/loader.py +23 -15
  33. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/adapters.py +13 -8
  34. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/pipeline.py +41 -13
  35. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/semantic_checker.py +26 -16
  36. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/static_checker.py +17 -11
  37. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/types.py +44 -29
  38. wxo_agentic_evaluation/referenceless_eval/metrics/field.py +13 -5
  39. wxo_agentic_evaluation/referenceless_eval/metrics/metric.py +16 -5
  40. wxo_agentic_evaluation/referenceless_eval/metrics/metrics_runner.py +8 -3
  41. wxo_agentic_evaluation/referenceless_eval/metrics/prompt.py +6 -2
  42. wxo_agentic_evaluation/referenceless_eval/metrics/utils.py +5 -1
  43. wxo_agentic_evaluation/referenceless_eval/prompt/runner.py +16 -3
  44. wxo_agentic_evaluation/referenceless_eval/referenceless_eval.py +23 -12
  45. wxo_agentic_evaluation/resource_map.py +2 -1
  46. wxo_agentic_evaluation/service_instance.py +24 -11
  47. wxo_agentic_evaluation/service_provider/__init__.py +33 -13
  48. wxo_agentic_evaluation/service_provider/model_proxy_provider.py +129 -26
  49. wxo_agentic_evaluation/service_provider/ollama_provider.py +10 -11
  50. wxo_agentic_evaluation/service_provider/provider.py +0 -1
  51. wxo_agentic_evaluation/service_provider/referenceless_provider_wrapper.py +34 -21
  52. wxo_agentic_evaluation/service_provider/watsonx_provider.py +50 -22
  53. wxo_agentic_evaluation/tool_planner.py +128 -44
  54. wxo_agentic_evaluation/type.py +12 -9
  55. wxo_agentic_evaluation/utils/__init__.py +1 -0
  56. wxo_agentic_evaluation/utils/open_ai_tool_extractor.py +41 -20
  57. wxo_agentic_evaluation/utils/rich_utils.py +23 -9
  58. wxo_agentic_evaluation/utils/utils.py +83 -52
  59. ibm_watsonx_orchestrate_evaluation_framework-1.1.1.dist-info/METADATA +0 -386
  60. {ibm_watsonx_orchestrate_evaluation_framework-1.1.1.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.2.dist-info}/WHEEL +0 -0
  61. {ibm_watsonx_orchestrate_evaluation_framework-1.1.1.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.2.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,7 @@
1
1
  import os
2
2
  import requests
3
3
  import time
4
-
5
- from typing import List
4
+ from typing import List, Tuple
6
5
  from threading import Lock
7
6
 
8
7
  from wxo_agentic_evaluation.service_provider.provider import Provider
@@ -12,6 +11,22 @@ AUTH_ENDPOINT_AWS = "https://iam.platform.saas.ibm.com/siusermgr/api/1.0/apikeys
12
11
  AUTH_ENDPOINT_IBM_CLOUD = "https://iam.cloud.ibm.com/identity/token"
13
12
  DEFAULT_PARAM = {"min_new_tokens": 1, "decoding_method": "greedy", "max_new_tokens": 400}
14
13
 
14
+ def _infer_cpd_auth_url(instance_url: str) -> str:
15
+ inst = (instance_url or "").rstrip("/")
16
+ if not inst:
17
+ return "/icp4d-api/v1/authorize"
18
+ if "/orchestrate" in inst:
19
+ base = inst.split("/orchestrate", 1)[0].rstrip("/")
20
+ return base + "/icp4d-api/v1/authorize"
21
+ return inst + "/icp4d-api/v1/authorize"
22
+
23
+
24
+ def _normalize_cpd_auth_url(url: str) -> str:
25
+ u = (url or "").rstrip("/")
26
+ if u.endswith("/icp4d-api"):
27
+ return u + "/v1/authorize"
28
+ return url
29
+
15
30
 
16
31
  class ModelProxyProvider(Provider):
17
32
  def __init__(
@@ -26,20 +41,43 @@ class ModelProxyProvider(Provider):
26
41
  super().__init__()
27
42
 
28
43
  instance_url = os.environ.get("WO_INSTANCE", instance_url)
29
- api_key = os.environ.get("WO_API_KEY", api_key)
30
- if not instance_url or not api_key:
31
- raise RuntimeError("instance url and WO apikey must be specified to use WO model proxy")
44
+ if not instance_url:
45
+ raise RuntimeError("instance url must be specified to use WO model proxy")
32
46
 
33
47
  self.timeout = timeout
34
- self.model_id = model_id
35
-
48
+ self.model_id = os.environ.get("MODEL_OVERRIDE",model_id)
36
49
  self.embedding_model_id = embedding_model_id
37
50
 
38
- self.api_key = api_key
51
+ self.api_key = os.environ.get("WO_API_KEY", api_key)
52
+ self.username = os.environ.get("WO_USERNAME", None)
53
+ self.password = os.environ.get("WO_PASSWORD", None)
54
+ self.auth_type = os.environ.get("WO_AUTH_TYPE", "").lower() # explicit override if set, otherwise inferred- match ADK values
55
+ explicit_auth_url = os.environ.get("AUTHORIZATION_URL", None)
56
+
39
57
  self.is_ibm_cloud = is_ibm_cloud_url(instance_url)
40
- self.auth_url = AUTH_ENDPOINT_IBM_CLOUD if self.is_ibm_cloud else AUTH_ENDPOINT_AWS
41
-
42
- self.instance_url = instance_url
58
+ self.instance_url = instance_url.rstrip("/")
59
+
60
+ self.auth_mode, self.auth_url = self._resolve_auth_mode_and_url(explicit_auth_url=explicit_auth_url)
61
+ self._wo_ssl_verify = os.environ.get("WO_SSL_VERIFY", "true").lower() != "false"
62
+ env_space_id = os.environ.get("WATSONX_SPACE_ID", None)
63
+ if self.auth_mode == "cpd":
64
+ if not env_space_id or not env_space_id.strip():
65
+ raise RuntimeError("CPD mode requires WATSONX_SPACE_ID environment variable to be set")
66
+ self.space_id = env_space_id.strip()
67
+ else:
68
+ self.space_id = (env_space_id.strip() if env_space_id and env_space_id.strip() else "1")
69
+
70
+ if self.auth_mode == "cpd":
71
+ if "/orchestrate" in self.instance_url:
72
+ self.instance_url = self.instance_url.split("/orchestrate", 1)[0].rstrip("/")
73
+ if not self.username:
74
+ raise RuntimeError("CPD auth requires WO_USERNAME to be set")
75
+ if not (self.password or self.api_key):
76
+ raise RuntimeError("CPD auth requires either WO_PASSWORD or WO_API_KEY to be set (with WO_USERNAME)")
77
+ else:
78
+ if not self.api_key:
79
+ raise RuntimeError("WO_API_KEY must be specified for SaaS or IBM IAM auth")
80
+
43
81
  self.url = self.instance_url + "/ml/v1/text/generation?version=2024-05-01"
44
82
  self.embedding_url = self.instance_url + "/ml/v1/text/embeddings"
45
83
 
@@ -47,20 +85,85 @@ class ModelProxyProvider(Provider):
47
85
  self.token, self.refresh_time = self.get_token()
48
86
  self.params = params if params else DEFAULT_PARAM
49
87
 
50
- def get_token(self):
88
+ def _resolve_auth_mode_and_url(
89
+ self,
90
+ explicit_auth_url: str | None
91
+ ) -> Tuple[str, str]:
92
+ """
93
+ Returns (auth_mode, auth_url)
94
+ - auth_mode: "cpd" | "ibm_iam" | "saas"
95
+ """
96
+ if explicit_auth_url:
97
+ if "/icp4d-api" in explicit_auth_url:
98
+ return "cpd", _normalize_cpd_auth_url(explicit_auth_url)
99
+ if self.auth_type == "ibm_iam":
100
+ return "ibm_iam", explicit_auth_url
101
+ elif self.auth_type == "saas":
102
+ return "saas", explicit_auth_url
103
+ else:
104
+ mode = "ibm_iam" if self.is_ibm_cloud else "saas"
105
+ return mode, explicit_auth_url
106
+
107
+ if self.auth_type == "cpd":
108
+ inferred_cpd_url = _infer_cpd_auth_url(self.instance_url)
109
+ return "cpd", inferred_cpd_url
110
+ if self.auth_type == "ibm_iam":
111
+ return "ibm_iam", AUTH_ENDPOINT_IBM_CLOUD
112
+ if self.auth_type == "saas":
113
+ return "saas", AUTH_ENDPOINT_AWS
114
+
115
+ if "/orchestrate" in self.instance_url:
116
+ inferred_cpd_url = _infer_cpd_auth_url(self.instance_url)
117
+ return "cpd", inferred_cpd_url
118
+
51
119
  if self.is_ibm_cloud:
52
- payload = {"grant_type": "urn:ibm:params:oauth:grant-type:apikey", "apikey": self.api_key}
53
- resp = requests.post(self.auth_url, data=payload)
54
- token_key = "access_token"
120
+ return "ibm_iam", AUTH_ENDPOINT_IBM_CLOUD
55
121
  else:
56
- payload = {"apikey": self.api_key}
57
- resp = requests.post(self.auth_url, json=payload)
58
- token_key = "token"
122
+ return "saas", AUTH_ENDPOINT_AWS
123
+
124
+ def get_token(self):
125
+ headers = {}
126
+ post_args = {}
127
+ timeout = 10
128
+ exchange_url = self.auth_url
129
+
130
+ if self.auth_mode == "ibm_iam":
131
+ headers = {"Accept": "application/json", "Content-Type": "application/x-www-form-urlencoded"}
132
+ form_data = {
133
+ "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
134
+ "apikey": self.api_key
135
+ }
136
+ post_args = {"data": form_data}
137
+ resp = requests.post(exchange_url, headers=headers, timeout=timeout, verify=self._wo_ssl_verify, **post_args)
138
+ elif self.auth_mode == "cpd":
139
+ headers = {"Accept": "application/json", "Content-Type": "application/json"}
140
+ body = {"username": self.username}
141
+ if self.password:
142
+ body["password"] = self.password
143
+ else:
144
+ body["api_key"] = self.api_key
145
+ timeout = self.timeout
146
+ resp = requests.post(exchange_url, headers=headers, json=body, timeout=timeout, verify=self._wo_ssl_verify)
147
+ else:
148
+ headers = {"Accept": "application/json", "Content-Type": "application/json"}
149
+ post_args = {"json": {"apikey": self.api_key}}
150
+ resp = requests.post(exchange_url, headers=headers, timeout=timeout, verify=self._wo_ssl_verify, **post_args)
151
+
59
152
  if resp.status_code == 200:
60
153
  json_obj = resp.json()
61
- token = json_obj[token_key]
62
- expires_in = json_obj["expires_in"]
63
- refresh_time = time.time() + int(0.8*expires_in)
154
+ token = json_obj.get("access_token") or json_obj.get("token")
155
+ if not token:
156
+ raise RuntimeError(f"No token field found in response: {json_obj!r}")
157
+
158
+ expires_in = json_obj.get("expires_in")
159
+ try:
160
+ expires_in = int(expires_in) if expires_in is not None else None
161
+ except Exception:
162
+ expires_in = None
163
+ if not expires_in or expires_in <= 0:
164
+ expires_in = int(os.environ.get("TOKEN_DEFAULT_EXPIRES_IN", 1))
165
+
166
+ refresh_time = time.time() + int(0.8 * expires_in)
64
167
  return token, refresh_time
65
168
 
66
169
  resp.raise_for_status()
@@ -80,9 +183,9 @@ class ModelProxyProvider(Provider):
80
183
 
81
184
  self.refresh_token_if_expires()
82
185
  headers = self.get_header()
83
- payload = {"inputs": sentences, "model_id": self.embedding_model_id, "space_id": "1"}
186
+ payload = {"inputs": sentences, "model_id": self.embedding_model_id, "space_id": self.space_id}
84
187
  #"timeout": self.timeout}
85
- resp = requests.post(self.embedding_url, json=payload, headers=headers)
188
+ resp = requests.post(self.embedding_url, json=payload, headers=headers, verify=self._wo_ssl_verify)
86
189
 
87
190
  if resp.status_code == 200:
88
191
  json_obj = resp.json()
@@ -95,9 +198,9 @@ class ModelProxyProvider(Provider):
95
198
  raise Exception("model id must be specified for text generation")
96
199
  self.refresh_token_if_expires()
97
200
  headers = self.get_header()
98
- payload = {"input": sentence, "model_id": self.model_id, "space_id": "1",
201
+ payload = {"input": sentence, "model_id": self.model_id, "space_id": self.space_id,
99
202
  "timeout": self.timeout, "parameters": self.params}
100
- resp = requests.post(self.url, json=payload, headers=headers)
203
+ resp = requests.post(self.url, json=payload, headers=headers, verify=self._wo_ssl_verify)
101
204
  if resp.status_code == 200:
102
205
  return resp.json()["results"][0]["generated_text"]
103
206
 
@@ -106,4 +209,4 @@ class ModelProxyProvider(Provider):
106
209
 
107
210
  if __name__ == "__main__":
108
211
  provider = ModelProxyProvider(model_id="meta-llama/llama-3-3-70b-instruct", embedding_model_id="ibm/slate-30m-english-rtrvr")
109
- print(provider.query("ok"))
212
+ print(provider.query("ok"))
@@ -1,17 +1,16 @@
1
- import requests
2
1
  import json
3
- from wxo_agentic_evaluation.service_provider.provider import Provider
4
- from typing import List
5
2
  import os
3
+ from typing import List
4
+
5
+ import requests
6
+
7
+ from wxo_agentic_evaluation.service_provider.provider import Provider
6
8
 
7
9
  OLLAMA_URL = os.environ.get("OLLAMA_HOST", "http://localhost:11434")
8
10
 
9
11
 
10
12
  class OllamaProvider(Provider):
11
- def __init__(
12
- self,
13
- model_id=None
14
- ):
13
+ def __init__(self, model_id=None):
15
14
  self.url = OLLAMA_URL + "/api/generate"
16
15
  self.model_id = model_id
17
16
  super().__init__()
@@ -20,14 +19,14 @@ class OllamaProvider(Provider):
20
19
  payload = {"model": self.model_id, "prompt": sentence}
21
20
  resp = requests.post(self.url, json=payload, stream=True)
22
21
  final_text = ""
23
- data = b''
22
+ data = b""
24
23
  for chunk in resp:
25
24
  data += chunk
26
- if data.endswith(b'\n'):
25
+ if data.endswith(b"\n"):
27
26
  json_obj = json.loads(data)
28
27
  if not json_obj["done"] and json_obj["response"]:
29
28
  final_text += json_obj["response"]
30
- data = b''
29
+ data = b""
31
30
 
32
31
  return final_text
33
32
 
@@ -37,4 +36,4 @@ class OllamaProvider(Provider):
37
36
 
38
37
  if __name__ == "__main__":
39
38
  provider = OllamaProvider(model_id="llama3.1:8b")
40
- print(provider.query("ok"))
39
+ print(provider.query("ok"))
@@ -16,4 +16,3 @@ class Provider(ABC):
16
16
  @abstractmethod
17
17
  def encode(self, sentences: List[str]) -> List[list]:
18
18
  pass
19
-
@@ -1,11 +1,15 @@
1
- import requests
2
- from typing import List, Mapping, Union, Optional, Any
3
1
  from abc import ABC, abstractmethod
2
+ from typing import Any, List, Mapping, Optional, Union
4
3
 
4
+ import requests
5
5
  import rich
6
6
 
7
- from wxo_agentic_evaluation.service_provider.model_proxy_provider import ModelProxyProvider
8
- from wxo_agentic_evaluation.service_provider.watsonx_provider import WatsonXProvider
7
+ from wxo_agentic_evaluation.service_provider.model_proxy_provider import (
8
+ ModelProxyProvider,
9
+ )
10
+ from wxo_agentic_evaluation.service_provider.watsonx_provider import (
11
+ WatsonXProvider,
12
+ )
9
13
 
10
14
 
11
15
  class LLMResponse:
@@ -14,7 +18,9 @@ class LLMResponse:
14
18
  Response object that can contain both content and tool calls
15
19
  """
16
20
 
17
- def __init__(self, content: str, tool_calls: Optional[List[Mapping[str, Any]]] = None):
21
+ def __init__(
22
+ self, content: str, tool_calls: Optional[List[Mapping[str, Any]]] = None
23
+ ):
18
24
  self.content = content
19
25
  self.tool_calls = tool_calls or []
20
26
 
@@ -26,25 +32,26 @@ class LLMResponse:
26
32
  """Return a string representation of the LLMResponse object."""
27
33
  return f"LLMResponse(content='{self.content}', tool_calls={self.tool_calls})"
28
34
 
35
+
29
36
  class LLMKitWrapper(ABC):
30
- """ In the future this wrapper won't be neccesary.
37
+ """In the future this wrapper won't be neccesary.
31
38
  Right now the referenceless code requires a `generate()` function for the metrics client.
32
39
  In refactor, rewrite referenceless code so this wrapper is not needed.
33
40
  """
41
+
34
42
  @abstractmethod
35
43
  def chat():
36
44
  pass
37
45
 
38
46
  def generate(
39
- self,
40
- prompt: Union[str, List[Mapping[str, str]]],
41
- *,
42
- schema,
43
- retries: int = 3,
44
- generation_args: Optional[Any] = None,
45
- **kwargs: Any
46
- ):
47
-
47
+ self,
48
+ prompt: Union[str, List[Mapping[str, str]]],
49
+ *,
50
+ schema,
51
+ retries: int = 3,
52
+ generation_args: Optional[Any] = None,
53
+ **kwargs: Any,
54
+ ):
48
55
  """
49
56
  In future, implement validation of response like in llmevalkit
50
57
  """
@@ -55,7 +62,9 @@ class LLMKitWrapper(ABC):
55
62
  response = self._parse_llm_response(raw_response)
56
63
  return response
57
64
  except Exception as e:
58
- rich.print(f"[b][r] Generation failed with error '{str(e)}' during `quick-eval` ... Attempt ({attempt} / {retries}))")
65
+ rich.print(
66
+ f"[b][r] Generation failed with error '{str(e)}' during `quick-eval` ... Attempt ({attempt} / {retries}))"
67
+ )
59
68
 
60
69
  def _parse_llm_response(self, raw: Any) -> Union[str, LLMResponse]:
61
70
  """
@@ -82,10 +91,12 @@ class LLMKitWrapper(ABC):
82
91
  "id": tool_call.get("id"),
83
92
  "type": tool_call.get("type", "function"),
84
93
  "function": {
85
- "name": tool_call.get("function", {}).get("name"),
86
- "arguments": tool_call.get("function", {}).get(
87
- "arguments"
94
+ "name": tool_call.get("function", {}).get(
95
+ "name"
88
96
  ),
97
+ "arguments": tool_call.get(
98
+ "function", {}
99
+ ).get("arguments"),
89
100
  },
90
101
  }
91
102
  tool_calls.append(tool_call_dict)
@@ -101,6 +112,7 @@ class LLMKitWrapper(ABC):
101
112
 
102
113
  return content
103
114
 
115
+
104
116
  class ModelProxyProviderLLMKitWrapper(ModelProxyProvider, LLMKitWrapper):
105
117
  def chat(self, sentence: List[str]):
106
118
  if self.model_id is None:
@@ -113,7 +125,7 @@ class ModelProxyProviderLLMKitWrapper(ModelProxyProvider, LLMKitWrapper):
113
125
  "messages": sentence,
114
126
  "parameters": self.params,
115
127
  "space_id": "1",
116
- "timeout": self.timeout
128
+ "timeout": self.timeout,
117
129
  }
118
130
  resp = requests.post(url=chat_url, headers=headers, json=data)
119
131
  if resp.status_code == 200:
@@ -121,6 +133,7 @@ class ModelProxyProviderLLMKitWrapper(ModelProxyProvider, LLMKitWrapper):
121
133
  else:
122
134
  resp.raise_for_status()
123
135
 
136
+
124
137
  class WatsonXLLMKitWrapper(WatsonXProvider, LLMKitWrapper):
125
138
  def chat(self, sentence: list):
126
139
  chat_url = f"{self.api_endpoint}/ml/v1/text/chat?version=2023-05-02"
@@ -129,7 +142,7 @@ class WatsonXLLMKitWrapper(WatsonXProvider, LLMKitWrapper):
129
142
  "model_id": self.model_id,
130
143
  "messages": sentence,
131
144
  "parameters": self.params,
132
- "space_id": self.space_id
145
+ "space_id": self.space_id,
133
146
  }
134
147
  resp = requests.post(url=chat_url, headers=headers, json=data)
135
148
  if resp.status_code == 200:
@@ -1,11 +1,13 @@
1
- import os
2
- import requests
1
+ import dataclasses
3
2
  import json
3
+ import os
4
+ import time
5
+ from threading import Lock
4
6
  from types import MappingProxyType
5
7
  from typing import List, Mapping, Union
6
- import dataclasses
7
- from threading import Lock
8
- import time
8
+
9
+ import requests
10
+
9
11
  from wxo_agentic_evaluation.service_provider.provider import Provider
10
12
 
11
13
  ACCESS_URL = "https://iam.cloud.ibm.com/identity/token"
@@ -36,7 +38,9 @@ class WatsonXProvider(Provider):
36
38
  super().__init__()
37
39
  self.url = url
38
40
  if (embedding_model_id is None) and (model_id is None):
39
- raise Exception("either model_id or embedding_model_id must be specified")
41
+ raise Exception(
42
+ "either model_id or embedding_model_id must be specified"
43
+ )
40
44
  self.model_id = model_id
41
45
  api_key = os.environ.get("WATSONX_APIKEY", api_key)
42
46
  if not api_key:
@@ -56,7 +60,7 @@ class WatsonXProvider(Provider):
56
60
  self.lock = Lock()
57
61
 
58
62
  self.params = params if params else DEFAULT_PARAM
59
-
63
+
60
64
  if isinstance(self.params, MappingProxyType):
61
65
  self.params = dict(self.params)
62
66
  if dataclasses.is_dataclass(self.params):
@@ -68,7 +72,10 @@ class WatsonXProvider(Provider):
68
72
 
69
73
  def _get_access_token(self):
70
74
  response = requests.post(
71
- self.url, headers=ACCESS_HEADER, data=self.access_data, timeout=self.timeout
75
+ self.url,
76
+ headers=ACCESS_HEADER,
77
+ data=self.access_data,
78
+ timeout=self.timeout,
72
79
  )
73
80
  if response.status_code == 200:
74
81
  token_data = json.loads(response.text)
@@ -84,16 +91,24 @@ class WatsonXProvider(Provider):
84
91
  )
85
92
 
86
93
  def prepare_header(self):
87
- headers = {"Authorization": f"Bearer {self.access_token}",
88
- "Content-Type": "application/json"}
94
+ headers = {
95
+ "Authorization": f"Bearer {self.access_token}",
96
+ "Content-Type": "application/json",
97
+ }
89
98
  return headers
90
99
 
91
100
  def _query(self, sentence: str):
92
101
  headers = self.prepare_header()
93
102
 
94
- data = {"model_id": self.model_id, "input": sentence,
95
- "parameters": self.params, "space_id": self.space_id}
96
- generation_url = f"{self.api_endpoint}/ml/v1/text/generation?version=2023-05-02"
103
+ data = {
104
+ "model_id": self.model_id,
105
+ "input": sentence,
106
+ "parameters": self.params,
107
+ "space_id": self.space_id,
108
+ }
109
+ generation_url = (
110
+ f"{self.api_endpoint}/ml/v1/text/generation?version=2023-05-02"
111
+ )
97
112
  resp = requests.post(url=generation_url, headers=headers, json=data)
98
113
  if resp.status_code == 200:
99
114
  return resp.json()["results"][0]
@@ -105,20 +120,25 @@ class WatsonXProvider(Provider):
105
120
  if not self.access_token or time.time() > self.refresh_time:
106
121
  with self.lock:
107
122
  if not self.access_token or time.time() > self.refresh_time:
108
- self.access_token, self.refresh_time = self._get_access_token()
123
+ (
124
+ self.access_token,
125
+ self.refresh_time,
126
+ ) = self._get_access_token()
109
127
 
110
128
  def query(self, sentence: Union[str, Mapping[str, str]]) -> str:
111
129
  if self.model_id is None:
112
130
  raise Exception("model id must be specified for text generation")
113
131
  try:
114
132
  response = self._query(sentence)
115
- if (generated_text := response.get("generated_text")):
133
+ if generated_text := response.get("generated_text"):
116
134
  return generated_text
117
- elif (message := response.get("message")):
135
+ elif message := response.get("message"):
118
136
  return message
119
137
  else:
120
- raise ValueError(f"Unexpected response from WatsonX: {response}")
121
-
138
+ raise ValueError(
139
+ f"Unexpected response from WatsonX: {response}"
140
+ )
141
+
122
142
  except Exception as e:
123
143
  with self.lock:
124
144
  if "authentication_token_expired" in str(e):
@@ -130,12 +150,18 @@ class WatsonXProvider(Provider):
130
150
 
131
151
  def encode(self, sentences: List[str]) -> List[list]:
132
152
  if self.embedding_model_id is None:
133
- raise Exception("embedding model id must be specified for text encoding")
153
+ raise Exception(
154
+ "embedding model id must be specified for text encoding"
155
+ )
134
156
 
135
157
  headers = self.prepare_header()
136
158
  url = f"{self.api_endpoint}/ml/v1/text/embeddings?version=2023-10-25"
137
159
 
138
- data = {"inputs": sentences, "model_id": self.model_id, "space_id": self.space_id}
160
+ data = {
161
+ "inputs": sentences,
162
+ "model_id": self.model_id,
163
+ "space_id": self.space_id,
164
+ }
139
165
  resp = requests.post(url=url, headers=headers, json=data)
140
166
  if resp.status_code == 200:
141
167
  return [entry["embedding"] for entry in resp.json()["results"]]
@@ -144,7 +170,9 @@ class WatsonXProvider(Provider):
144
170
 
145
171
 
146
172
  if __name__ == "__main__":
147
- provider = WatsonXProvider(model_id="meta-llama/llama-3-2-90b-vision-instruct")
173
+ provider = WatsonXProvider(
174
+ model_id="meta-llama/llama-3-2-90b-vision-instruct"
175
+ )
148
176
 
149
177
  prompt = """
150
178
  <|begin_of_text|><|start_header_id|>system<|end_header_id|>
@@ -176,4 +204,4 @@ Usernwaters did not take anytime off during the period<|eot_id|>
176
204
  <|eot_id|><|start_header_id|>user<|end_header_id|>
177
205
  """
178
206
 
179
- print(provider.query(prompt))
207
+ print(provider.query(prompt))