ibm-watsonx-orchestrate-evaluation-framework 1.0.3__py3-none-any.whl → 1.1.8b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134)
  1. ibm_watsonx_orchestrate_evaluation_framework-1.1.8b0.dist-info/METADATA +53 -0
  2. ibm_watsonx_orchestrate_evaluation_framework-1.1.8b0.dist-info/RECORD +146 -0
  3. wxo_agentic_evaluation/analytics/tools/analyzer.py +38 -21
  4. wxo_agentic_evaluation/analytics/tools/main.py +19 -25
  5. wxo_agentic_evaluation/analytics/tools/types.py +26 -11
  6. wxo_agentic_evaluation/analytics/tools/ux.py +75 -31
  7. wxo_agentic_evaluation/analyze_run.py +1184 -97
  8. wxo_agentic_evaluation/annotate.py +7 -5
  9. wxo_agentic_evaluation/arg_configs.py +97 -5
  10. wxo_agentic_evaluation/base_user.py +25 -0
  11. wxo_agentic_evaluation/batch_annotate.py +97 -27
  12. wxo_agentic_evaluation/clients.py +103 -0
  13. wxo_agentic_evaluation/compare_runs/__init__.py +0 -0
  14. wxo_agentic_evaluation/compare_runs/compare_2_runs.py +74 -0
  15. wxo_agentic_evaluation/compare_runs/diff.py +554 -0
  16. wxo_agentic_evaluation/compare_runs/model.py +193 -0
  17. wxo_agentic_evaluation/data_annotator.py +45 -19
  18. wxo_agentic_evaluation/description_quality_checker.py +178 -0
  19. wxo_agentic_evaluation/evaluation.py +50 -0
  20. wxo_agentic_evaluation/evaluation_controller/evaluation_controller.py +303 -0
  21. wxo_agentic_evaluation/evaluation_package.py +544 -107
  22. wxo_agentic_evaluation/external_agent/__init__.py +18 -7
  23. wxo_agentic_evaluation/external_agent/external_validate.py +49 -36
  24. wxo_agentic_evaluation/external_agent/performance_test.py +33 -22
  25. wxo_agentic_evaluation/external_agent/types.py +8 -7
  26. wxo_agentic_evaluation/extractors/__init__.py +3 -0
  27. wxo_agentic_evaluation/extractors/extractor_base.py +21 -0
  28. wxo_agentic_evaluation/extractors/labeled_messages.py +47 -0
  29. wxo_agentic_evaluation/hr_agent_langgraph.py +68 -0
  30. wxo_agentic_evaluation/langfuse_collection.py +60 -0
  31. wxo_agentic_evaluation/langfuse_evaluation_package.py +192 -0
  32. wxo_agentic_evaluation/llm_matching.py +108 -5
  33. wxo_agentic_evaluation/llm_rag_eval.py +7 -4
  34. wxo_agentic_evaluation/llm_safety_eval.py +64 -0
  35. wxo_agentic_evaluation/llm_user.py +12 -6
  36. wxo_agentic_evaluation/llm_user_v2.py +114 -0
  37. wxo_agentic_evaluation/main.py +128 -246
  38. wxo_agentic_evaluation/metrics/__init__.py +15 -0
  39. wxo_agentic_evaluation/metrics/dummy_metric.py +16 -0
  40. wxo_agentic_evaluation/metrics/evaluations.py +107 -0
  41. wxo_agentic_evaluation/metrics/journey_success.py +137 -0
  42. wxo_agentic_evaluation/metrics/llm_as_judge.py +28 -2
  43. wxo_agentic_evaluation/metrics/metrics.py +319 -16
  44. wxo_agentic_evaluation/metrics/tool_calling.py +93 -0
  45. wxo_agentic_evaluation/otel_parser/__init__.py +1 -0
  46. wxo_agentic_evaluation/otel_parser/langflow_parser.py +86 -0
  47. wxo_agentic_evaluation/otel_parser/langgraph_parser.py +61 -0
  48. wxo_agentic_evaluation/otel_parser/parser.py +163 -0
  49. wxo_agentic_evaluation/otel_parser/parser_types.py +38 -0
  50. wxo_agentic_evaluation/otel_parser/pydantic_parser.py +50 -0
  51. wxo_agentic_evaluation/otel_parser/utils.py +15 -0
  52. wxo_agentic_evaluation/otel_parser/wxo_parser.py +39 -0
  53. wxo_agentic_evaluation/otel_support/evaluate_tau.py +101 -0
  54. wxo_agentic_evaluation/otel_support/otel_message_conversion.py +29 -0
  55. wxo_agentic_evaluation/otel_support/tasks_test.py +1566 -0
  56. wxo_agentic_evaluation/prompt/bad_tool_descriptions_prompt.jinja2 +178 -0
  57. wxo_agentic_evaluation/prompt/derailment_prompt.jinja2 +55 -0
  58. wxo_agentic_evaluation/prompt/llama_user_prompt.jinja2 +59 -5
  59. wxo_agentic_evaluation/prompt/llmaaj_prompt.jinja2 +15 -0
  60. wxo_agentic_evaluation/prompt/off_policy_attack_generation_prompt.jinja2 +34 -0
  61. wxo_agentic_evaluation/prompt/on_policy_attack_generation_prompt.jinja2 +46 -0
  62. wxo_agentic_evaluation/prompt/semantic_matching_prompt.jinja2 +41 -9
  63. wxo_agentic_evaluation/prompt/template_render.py +163 -12
  64. wxo_agentic_evaluation/prompt/unsafe_topic_prompt.jinja2 +65 -0
  65. wxo_agentic_evaluation/quick_eval.py +384 -0
  66. wxo_agentic_evaluation/record_chat.py +132 -81
  67. wxo_agentic_evaluation/red_teaming/attack_evaluator.py +302 -0
  68. wxo_agentic_evaluation/red_teaming/attack_generator.py +329 -0
  69. wxo_agentic_evaluation/red_teaming/attack_list.py +184 -0
  70. wxo_agentic_evaluation/red_teaming/attack_runner.py +204 -0
  71. wxo_agentic_evaluation/referenceless_eval/__init__.py +3 -0
  72. wxo_agentic_evaluation/referenceless_eval/function_calling/__init__.py +0 -0
  73. wxo_agentic_evaluation/referenceless_eval/function_calling/consts.py +28 -0
  74. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/__init__.py +0 -0
  75. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/base.py +29 -0
  76. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_call/__init__.py +0 -0
  77. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_call/general.py +49 -0
  78. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_call/general_metrics.json +783 -0
  79. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_call/general_metrics_runtime.json +580 -0
  80. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_selection/__init__.py +0 -0
  81. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_selection/function_selection.py +31 -0
  82. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_selection/function_selection_metrics.json +600 -0
  83. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_selection/function_selection_metrics_runtime.json +477 -0
  84. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/loader.py +245 -0
  85. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/__init__.py +0 -0
  86. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/adapters.py +106 -0
  87. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/pipeline.py +291 -0
  88. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/semantic_checker.py +465 -0
  89. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/static_checker.py +162 -0
  90. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/transformation_prompts.py +509 -0
  91. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/types.py +562 -0
  92. wxo_agentic_evaluation/referenceless_eval/metrics/__init__.py +3 -0
  93. wxo_agentic_evaluation/referenceless_eval/metrics/field.py +266 -0
  94. wxo_agentic_evaluation/referenceless_eval/metrics/metric.py +344 -0
  95. wxo_agentic_evaluation/referenceless_eval/metrics/metrics_runner.py +193 -0
  96. wxo_agentic_evaluation/referenceless_eval/metrics/prompt.py +413 -0
  97. wxo_agentic_evaluation/referenceless_eval/metrics/utils.py +46 -0
  98. wxo_agentic_evaluation/referenceless_eval/prompt/__init__.py +0 -0
  99. wxo_agentic_evaluation/referenceless_eval/prompt/runner.py +158 -0
  100. wxo_agentic_evaluation/referenceless_eval/referenceless_eval.py +191 -0
  101. wxo_agentic_evaluation/resource_map.py +6 -3
  102. wxo_agentic_evaluation/runner.py +329 -0
  103. wxo_agentic_evaluation/runtime_adapter/a2a_runtime_adapter.py +0 -0
  104. wxo_agentic_evaluation/runtime_adapter/runtime_adapter.py +14 -0
  105. wxo_agentic_evaluation/{inference_backend.py → runtime_adapter/wxo_runtime_adapter.py} +88 -150
  106. wxo_agentic_evaluation/scheduler.py +247 -0
  107. wxo_agentic_evaluation/service_instance.py +117 -26
  108. wxo_agentic_evaluation/service_provider/__init__.py +182 -17
  109. wxo_agentic_evaluation/service_provider/gateway_provider.py +707 -0
  110. wxo_agentic_evaluation/service_provider/model_proxy_provider.py +628 -45
  111. wxo_agentic_evaluation/service_provider/ollama_provider.py +392 -22
  112. wxo_agentic_evaluation/service_provider/portkey_provider.py +229 -0
  113. wxo_agentic_evaluation/service_provider/provider.py +129 -10
  114. wxo_agentic_evaluation/service_provider/referenceless_provider_wrapper.py +203 -0
  115. wxo_agentic_evaluation/service_provider/watsonx_provider.py +516 -53
  116. wxo_agentic_evaluation/simluation_runner.py +125 -0
  117. wxo_agentic_evaluation/test_prompt.py +4 -4
  118. wxo_agentic_evaluation/tool_planner.py +141 -46
  119. wxo_agentic_evaluation/type.py +217 -14
  120. wxo_agentic_evaluation/user_simulator/demo_usage_llm_user.py +100 -0
  121. wxo_agentic_evaluation/utils/__init__.py +44 -3
  122. wxo_agentic_evaluation/utils/evaluation_discovery.py +47 -0
  123. wxo_agentic_evaluation/utils/gateway_provider_utils.py +39 -0
  124. wxo_agentic_evaluation/utils/messages_parser.py +30 -0
  125. wxo_agentic_evaluation/utils/open_ai_tool_extractor.py +178 -0
  126. wxo_agentic_evaluation/utils/parsers.py +71 -0
  127. wxo_agentic_evaluation/utils/rich_utils.py +188 -0
  128. wxo_agentic_evaluation/utils/rouge_score.py +23 -0
  129. wxo_agentic_evaluation/utils/utils.py +514 -17
  130. wxo_agentic_evaluation/wxo_client.py +81 -0
  131. ibm_watsonx_orchestrate_evaluation_framework-1.0.3.dist-info/METADATA +0 -380
  132. ibm_watsonx_orchestrate_evaluation_framework-1.0.3.dist-info/RECORD +0 -56
  133. {ibm_watsonx_orchestrate_evaluation_framework-1.0.3.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.8b0.dist-info}/WHEEL +0 -0
  134. {ibm_watsonx_orchestrate_evaluation_framework-1.0.3.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.8b0.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,23 @@
1
- import os
2
- import requests
3
- import json
4
- from types import MappingProxyType
5
- from typing import List
6
1
  import dataclasses
7
- from threading import Lock
2
+ import json
3
+ import logging
4
+ import os
8
5
  import time
9
- from wxo_agentic_evaluation.service_provider.provider import Provider
6
+ import uuid
7
+ from threading import Lock
8
+ from types import MappingProxyType
9
+ from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
10
+
11
+ import requests
12
+
13
+ from wxo_agentic_evaluation.service_provider.provider import (
14
+ ChatResult,
15
+ Provider,
16
+ )
17
+
18
+ logger = logging.getLogger(__name__)
10
19
 
20
+ # IAM
11
21
  ACCESS_URL = "https://iam.cloud.ibm.com/identity/token"
12
22
  ACCESS_HEADER = {
13
23
  "content-type": "application/x-www-form-urlencoded",
@@ -16,28 +26,83 @@ ACCESS_HEADER = {
16
26
 
17
27
  YPQA_URL = "https://yp-qa.ml.cloud.ibm.com"
18
28
  PROD_URL = "https://us-south.ml.cloud.ibm.com"
29
+
19
30
  DEFAULT_PARAM = MappingProxyType(
20
31
  {"min_new_tokens": 1, "decoding_method": "greedy", "max_new_tokens": 400}
21
32
  )
22
33
 
23
34
 
35
+ def _truncate(value: Any, max_len: int = 1000) -> str:
36
+ if value is None:
37
+ return ""
38
+ s = str(value)
39
+ return (
40
+ s
41
+ if len(s) <= max_len
42
+ else s[:max_len] + f"... [truncated {len(s) - max_len} chars]"
43
+ )
44
+
45
+
46
+ def _translate_params_to_chat(
47
+ params: Dict[str, Any] = {},
48
+ ) -> Dict[str, Any]:
49
+ """
50
+ Translate legacy generation params to chat.completions params.
51
+ """
52
+ translated_params: Dict[str, Any] = {}
53
+
54
+ if "max_new_tokens" in params:
55
+ translated_params["max_tokens"] = params["max_new_tokens"]
56
+
57
+ if params.get("decoding_method") == "greedy":
58
+ translated_params.setdefault("temperature", 0)
59
+ translated_params.setdefault("top_p", 1)
60
+
61
+ passthrough = {
62
+ "temperature",
63
+ "top_p",
64
+ "n",
65
+ "stream",
66
+ "stop",
67
+ "presence_penalty",
68
+ "frequency_penalty",
69
+ "logit_bias",
70
+ "user",
71
+ "seed",
72
+ "response_format",
73
+ }
74
+ for k in passthrough:
75
+ if k in params:
76
+ translated_params[k] = params[k]
77
+
78
+ return translated_params
79
+
80
+
24
81
  class WatsonXProvider(Provider):
25
82
  def __init__(
26
83
  self,
27
- model_id=None,
28
- api_key=None,
29
- space_id=None,
30
- api_endpoint=PROD_URL,
31
- url=ACCESS_URL,
32
- timeout=60,
33
- params=None,
34
- embedding_model_id=None,
84
+ model_id: Optional[str] = None,
85
+ api_key: Optional[str] = None,
86
+ space_id: Optional[str] = None,
87
+ api_endpoint: str = PROD_URL,
88
+ url: str = ACCESS_URL,
89
+ timeout: int = 60,
90
+ params: Optional[Any] = None,
91
+ embedding_model_id: Optional[str] = None,
92
+ use_legacy_query: Optional[bool] = None,
93
+ system_prompt: Optional[str] = None,
94
+ token: Optional[str] = None,
95
+ instance_url: Optional[str] = None,
35
96
  ):
36
- super().__init__()
97
+ super().__init__(use_legacy_query=use_legacy_query)
98
+
37
99
  self.url = url
38
100
  if (embedding_model_id is None) and (model_id is None):
39
- raise Exception("either model_id or embedding_model_id must be specified")
101
+ raise Exception(
102
+ "either model_id or embedding_model_id must be specified"
103
+ )
40
104
  self.model_id = model_id
105
+ logger.info("[d b]Using inference model %s", self.model_id)
41
106
  api_key = os.environ.get("WATSONX_APIKEY", api_key)
42
107
  if not api_key:
43
108
  raise Exception("apikey must be specified")
@@ -46,7 +111,7 @@ class WatsonXProvider(Provider):
46
111
  "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
47
112
  "apikey": self.api_key,
48
113
  }
49
- self.api_endpoint = api_endpoint
114
+ self.api_endpoint = (api_endpoint or PROD_URL).rstrip("/")
50
115
  space_id = os.environ.get("WATSONX_SPACE_ID", space_id)
51
116
  if not space_id:
52
117
  raise Exception("space id must be specified")
@@ -55,20 +120,32 @@ class WatsonXProvider(Provider):
55
120
  self.embedding_model_id = embedding_model_id
56
121
  self.lock = Lock()
57
122
 
58
- self.params = params if params else DEFAULT_PARAM
59
-
123
+ self.params = params if params is not None else DEFAULT_PARAM
60
124
  if isinstance(self.params, MappingProxyType):
61
125
  self.params = dict(self.params)
62
126
  if dataclasses.is_dataclass(self.params):
63
127
  self.params = dataclasses.asdict(self.params)
64
128
 
129
+ self.system_prompt = system_prompt
130
+
65
131
  self.refresh_time = None
66
132
  self.access_token = None
67
133
  self._refresh_token()
68
134
 
135
+ self.LEGACY_GEN_URL = (
136
+ f"{self.api_endpoint}/ml/v1/text/generation?version=2023-05-02"
137
+ )
138
+ self.CHAT_COMPLETIONS_URL = f"{self.api_endpoint}/ml/v1/text/chat"
139
+ self.EMBEDDINGS_URL = (
140
+ f"{self.api_endpoint}/ml/v1/text/embeddings?version=2023-10-25"
141
+ )
142
+
69
143
  def _get_access_token(self):
70
144
  response = requests.post(
71
- self.url, headers=ACCESS_HEADER, data=self.access_data, timeout=self.timeout
145
+ self.url,
146
+ headers=ACCESS_HEADER,
147
+ data=self.access_data,
148
+ timeout=self.timeout,
72
149
  )
73
150
  if response.status_code == 200:
74
151
  token_data = json.loads(response.text)
@@ -80,64 +157,450 @@ class WatsonXProvider(Provider):
80
157
  return token, refresh_time
81
158
 
82
159
  raise RuntimeError(
83
- f"try to acquire access token and get {response.status_code}"
160
+ f"Try to acquire access token and get {response.status_code}. Reason: {response.text} "
84
161
  )
85
162
 
86
163
  def prepare_header(self):
87
- headers = {"Authorization": f"Bearer {self.access_token}",
88
- "Content-Type": "application/json"}
164
+ headers = {
165
+ "Authorization": f"Bearer {self.access_token}",
166
+ "Content-Type": "application/json",
167
+ }
89
168
  return headers
90
169
 
91
- def generate(self, sentence: str):
92
- headers = self.prepare_header()
93
-
94
- data = {"model_id": self.model_id, "input": sentence,
95
- "parameters": self.params, "space_id": self.space_id}
96
- generation_url = f"{self.api_endpoint}/ml/v1/text/generation?version=2023-05-02"
97
- resp = requests.post(url=generation_url, headers=headers, json=data)
98
- if resp.status_code == 200:
99
- return resp.json()["results"][0]
100
- else:
101
- resp.raise_for_status()
102
-
103
170
  def _refresh_token(self):
104
171
  # if we do not have a token or the current timestamp is 9 minutes away from expire.
105
172
  if not self.access_token or time.time() > self.refresh_time:
106
173
  with self.lock:
107
174
  if not self.access_token or time.time() > self.refresh_time:
108
- self.access_token, self.refresh_time = self._get_access_token()
175
+ (
176
+ self.access_token,
177
+ self.refresh_time,
178
+ ) = self._get_access_token()
179
+
180
+ def old_query(self, sentence: Union[str, Mapping[str, str]]) -> str:
181
+ """
182
+ Legacy /ml/v1/text/generation
183
+ """
184
+ if self.model_id is None:
185
+ raise Exception("model id must be specified for text generation")
186
+
187
+ self._refresh_token()
188
+ headers = self.prepare_header()
189
+
190
+ payload: Dict[str, Any] = {
191
+ "model_id": self.model_id,
192
+ "input": sentence,
193
+ "parameters": self.params or {},
194
+ "space_id": self.space_id,
195
+ }
196
+
197
+ request_id = str(uuid.uuid4())
198
+ t0 = time.time()
199
+
200
+ logger.debug(
201
+ "[d][b]Sending text.generation request | request_id=%s url=%s model=%s space_id=%s params=%s input_preview=%s",
202
+ request_id,
203
+ self.LEGACY_GEN_URL,
204
+ self.model_id,
205
+ self.space_id,
206
+ json.dumps(
207
+ payload.get("parameters", {}),
208
+ sort_keys=True,
209
+ ensure_ascii=False,
210
+ ),
211
+ _truncate(sentence, 200),
212
+ )
213
+
214
+ resp = None
215
+ try:
216
+ resp = requests.post(
217
+ url=self.LEGACY_GEN_URL,
218
+ headers=headers,
219
+ json=payload,
220
+ timeout=self.timeout,
221
+ )
222
+ duration_ms = int((time.time() - t0) * 1000)
223
+ resp.raise_for_status()
224
+ data = resp.json()
225
+
226
+ if isinstance(data, dict) and "results" in data and data["results"]:
227
+ result = data["results"][0]
228
+ elif isinstance(data, dict):
229
+ result = data
230
+ else:
231
+ raise ValueError(
232
+ f"Unexpected response type from WatsonX: {type(data)}"
233
+ )
234
+
235
+ output_text = ""
236
+ if isinstance(result, dict):
237
+ output_text = (
238
+ result.get("generated_text") or result.get("message") or ""
239
+ )
240
+
241
+ usage = data.get("usage") or {}
242
+ if not usage and isinstance(result, dict):
243
+ in_tok = result.get("input_token_count")
244
+ out_tok = result.get("generated_token_count") or result.get(
245
+ "output_token_count"
246
+ )
247
+ if in_tok is not None or out_tok is not None:
248
+ usage = {
249
+ "prompt_tokens": in_tok,
250
+ "completion_tokens": out_tok,
251
+ "total_tokens": (in_tok or 0) + (out_tok or 0),
252
+ }
253
+
254
+ api_request_id = resp.headers.get(
255
+ "x-request-id"
256
+ ) or resp.headers.get("request-id")
257
+
258
+ logger.debug(
259
+ "[d][b]text.generation response received | request_id=%s status_code=%s duration_ms=%s usage=%s output_preview=%s api_request_id=%s",
260
+ request_id,
261
+ resp.status_code,
262
+ duration_ms,
263
+ json.dumps(usage, sort_keys=True, ensure_ascii=False),
264
+ _truncate(output_text, 2000),
265
+ api_request_id,
266
+ )
267
+
268
+ if output_text:
269
+ return output_text
270
+ raise ValueError(
271
+ f"Unexpected response from legacy endpoint: {data}"
272
+ )
273
+
274
+ except Exception as e:
275
+ duration_ms = int((time.time() - t0) * 1000)
276
+ status_code = getattr(resp, "status_code", None)
277
+ resp_text_preview = (
278
+ _truncate(getattr(resp, "text", None), 2000)
279
+ if resp is not None
280
+ else None
281
+ )
282
+
283
+ logger.exception(
284
+ "text.generation request failed | request_id=%s status_code=%s duration_ms=%s response_text_preview=%s",
285
+ request_id,
286
+ status_code,
287
+ duration_ms,
288
+ resp_text_preview,
289
+ )
290
+ with self.lock:
291
+ if (
292
+ "authentication_token_expired" in str(e)
293
+ or status_code == 401
294
+ ):
295
+ try:
296
+ self.access_token, self.refresh_time = (
297
+ self._get_access_token()
298
+ )
299
+ except Exception:
300
+ pass
301
+ raise
109
302
 
110
- def query(self, sentence: str) -> str:
303
+ def new_query(self, sentence: str) -> str:
304
+ """
305
+ /ml/v1/text/chat
306
+ Returns assistant content as a plain string.
307
+ """
111
308
  if self.model_id is None:
112
309
  raise Exception("model id must be specified for text generation")
310
+
311
+ self._refresh_token()
312
+ headers = self.prepare_header()
313
+
314
+ messages: List[Dict[str, Any]] = []
315
+ if getattr(self, "system_prompt", None):
316
+ messages.append({"role": "system", "content": self.system_prompt})
317
+ messages.append(
318
+ {
319
+ "role": "user",
320
+ "content": [
321
+ {
322
+ "type": "text",
323
+ "text": sentence,
324
+ }
325
+ ],
326
+ }
327
+ )
328
+
329
+ chat_params = _translate_params_to_chat(self.params)
330
+ if "time_limit" in self.params:
331
+ chat_params["time_limit"] = self.params["time_limit"]
332
+
333
+ payload: Dict[str, Any] = {
334
+ "model_id": self.model_id,
335
+ "space_id": self.space_id,
336
+ "messages": messages,
337
+ **chat_params,
338
+ }
339
+
340
+ url = f"{self.CHAT_COMPLETIONS_URL}?version=2024-10-08"
341
+ request_id = str(uuid.uuid4())
342
+ t0 = time.time()
343
+
344
+ logger.debug(
345
+ "[d][b]Sending chat.completions request | request_id=%s url=%s model=%s space_id=%s params=%s input_preview=%s",
346
+ request_id,
347
+ url,
348
+ self.model_id,
349
+ self.space_id,
350
+ json.dumps(chat_params, sort_keys=True, ensure_ascii=False),
351
+ _truncate(sentence, 200),
352
+ )
353
+
354
+ resp = None
113
355
  try:
114
- return self.generate(sentence)["generated_text"]
356
+ resp = requests.post(
357
+ url=url, headers=headers, json=payload, timeout=self.timeout
358
+ )
359
+ duration_ms = int((time.time() - t0) * 1000)
360
+ resp.raise_for_status()
361
+ data = resp.json()
362
+
363
+ choice = data["choices"][0]
364
+ content = choice["message"]["content"]
365
+ finish_reason = choice.get("finish_reason")
366
+ usage = data.get("usage", {})
367
+ api_request_id = resp.headers.get(
368
+ "x-request-id"
369
+ ) or resp.headers.get("request-id")
370
+
371
+ logger.debug(
372
+ "[d][b]chat.completions response received | request_id=%s status_code=%s duration_ms=%s finish_reason=%s usage=%s output_preview=%s api_request_id=%s",
373
+ request_id,
374
+ resp.status_code,
375
+ duration_ms,
376
+ finish_reason,
377
+ json.dumps(usage, sort_keys=True, ensure_ascii=False),
378
+ _truncate(content, 2000),
379
+ api_request_id,
380
+ )
381
+
382
+ return content
383
+
115
384
  except Exception as e:
385
+ duration_ms = int((time.time() - t0) * 1000)
386
+ status_code = getattr(resp, "status_code", None)
387
+ resp_text_preview = (
388
+ _truncate(getattr(resp, "text", None), 2000)
389
+ if resp is not None
390
+ else None
391
+ )
392
+
393
+ logger.exception(
394
+ "chat.completions request failed | request_id=%s status_code=%s duration_ms=%s response_text_preview=%s",
395
+ request_id,
396
+ status_code,
397
+ duration_ms,
398
+ resp_text_preview,
399
+ )
116
400
  with self.lock:
117
- if "authentication_token_expired" in str(e):
118
- self._refresh_token()
119
- raise e
401
+ if (
402
+ "authentication_token_expired" in str(e)
403
+ or status_code == 401
404
+ ):
405
+ try:
406
+ self.access_token, self.refresh_time = (
407
+ self._get_access_token()
408
+ )
409
+ except Exception:
410
+ pass
411
+ raise
412
+
413
+ def chat(
414
+ self,
415
+ messages: Sequence[Dict[str, str]],
416
+ params: Optional[Dict[str, Any]] = None,
417
+ ) -> ChatResult:
418
+ """
419
+ Sends a multi-message chat request to /ml/v1/text/chat
420
+ Returns ChatResult with text, usage, finish_reason, and raw response.
421
+ """
422
+ if self.model_id is None:
423
+ raise Exception("model id must be specified for chat")
424
+
425
+ self._refresh_token()
426
+ headers = self.prepare_header()
427
+
428
+ wx_messages: List[Dict[str, Any]] = []
429
+ for m in messages:
430
+ role = m.get("role")
431
+ content = m.get("content", "")
432
+ if role == "user" and isinstance(content, str):
433
+ wx_messages.append(
434
+ {
435
+ "role": "user",
436
+ "content": [{"type": "text", "text": content}],
437
+ }
438
+ )
439
+ else:
440
+ wx_messages.append({"role": role, "content": content})
120
441
 
121
- def batch_query(self, sentences: List[str]) -> List[dict]:
122
- return [self.query(sentence) for sentence in sentences]
442
+ merged_params = dict(self.params or {})
443
+ if params:
444
+ merged_params.update(params)
445
+ chat_params = _translate_params_to_chat(merged_params)
446
+ chat_params.pop("stream", None)
447
+ if "time_limit" in merged_params:
448
+ chat_params["time_limit"] = merged_params["time_limit"]
449
+
450
+ payload: Dict[str, Any] = {
451
+ "model_id": self.model_id,
452
+ "space_id": self.space_id,
453
+ "messages": wx_messages,
454
+ **chat_params,
455
+ }
456
+
457
+ url = f"{self.CHAT_COMPLETIONS_URL}?version=2024-10-08"
458
+ request_id = str(uuid.uuid4())
459
+ t0 = time.time()
460
+
461
+ last_user = next(
462
+ (
463
+ m.get("content", "")
464
+ for m in reversed(messages)
465
+ if m.get("role") == "user"
466
+ ),
467
+ "",
468
+ )
469
+ logger.debug(
470
+ "[d][b]Sending chat.completions request (non-streaming) | request_id=%s url=%s model=%s space_id=%s params=%s input_preview=%s",
471
+ request_id,
472
+ url,
473
+ self.model_id,
474
+ self.space_id,
475
+ json.dumps(chat_params, sort_keys=True, ensure_ascii=False),
476
+ _truncate(last_user, 200),
477
+ )
478
+
479
+ resp = None
480
+ try:
481
+ resp = requests.post(
482
+ url=url, headers=headers, json=payload, timeout=self.timeout
483
+ )
484
+ duration_ms = int((time.time() - t0) * 1000)
485
+ resp.raise_for_status()
486
+ data = resp.json()
487
+
488
+ choice = data["choices"][0]
489
+ content = choice["message"]["content"]
490
+ finish_reason = choice.get("finish_reason")
491
+ usage = data.get("usage", {})
492
+ api_request_id = resp.headers.get(
493
+ "x-request-id"
494
+ ) or resp.headers.get("request-id")
495
+
496
+ logger.debug(
497
+ "[d][b]chat.completions response received (non-streaming) | request_id=%s status_code=%s duration_ms=%s finish_reason=%s usage=%s output_preview=%s api_request_id=%s",
498
+ request_id,
499
+ resp.status_code,
500
+ duration_ms,
501
+ finish_reason,
502
+ json.dumps(usage, sort_keys=True, ensure_ascii=False),
503
+ _truncate(content, 2000),
504
+ api_request_id,
505
+ )
506
+
507
+ return ChatResult(
508
+ text=content, usage=usage, finish_reason=finish_reason, raw=data
509
+ )
510
+
511
+ except Exception as e:
512
+ duration_ms = int((time.time() - t0) * 1000)
513
+ status_code = getattr(resp, "status_code", None)
514
+ resp_text_preview = (
515
+ _truncate(getattr(resp, "text", None), 2000)
516
+ if resp is not None
517
+ else None
518
+ )
519
+
520
+ logger.exception(
521
+ "chat.completions request failed (non-streaming) | request_id=%s status_code=%s duration_ms=%s response_text_preview=%s",
522
+ request_id,
523
+ status_code,
524
+ duration_ms,
525
+ resp_text_preview,
526
+ )
527
+ with self.lock:
528
+ if (
529
+ "authentication_token_expired" in str(e)
530
+ or status_code == 401
531
+ ):
532
+ try:
533
+ self.access_token, self.refresh_time = (
534
+ self._get_access_token()
535
+ )
536
+ except Exception:
537
+ pass
538
+ raise
123
539
 
124
540
  def encode(self, sentences: List[str]) -> List[list]:
125
541
  if self.embedding_model_id is None:
126
- raise Exception("embedding model id must be specified for text encoding")
542
+ raise Exception(
543
+ "embedding model id must be specified for text encoding"
544
+ )
127
545
 
546
+ self._refresh_token()
128
547
  headers = self.prepare_header()
129
- url = f"{self.api_endpoint}/ml/v1/text/embeddings?version=2023-10-25"
130
548
 
131
- data = {"inputs": sentences, "model_id": self.model_id, "space_id": self.space_id}
132
- resp = requests.post(url=url, headers=headers, json=data)
549
+ # Minimal logging for embeddings
550
+ request_id = str(uuid.uuid4())
551
+ t0 = time.time()
552
+ logger.debug(
553
+ "[d][b]Sending embeddings request | request_id=%s url=%s model=%s space_id=%s num_inputs=%s",
554
+ request_id,
555
+ self.EMBEDDINGS_URL,
556
+ self.embedding_model_id,
557
+ self.space_id,
558
+ len(sentences),
559
+ )
560
+
561
+ payload = {
562
+ "inputs": sentences,
563
+ "model_id": self.embedding_model_id,
564
+ "space_id": self.space_id,
565
+ }
566
+
567
+ resp = requests.post(
568
+ url=self.EMBEDDINGS_URL,
569
+ headers=headers,
570
+ json=payload,
571
+ timeout=self.timeout,
572
+ )
573
+ duration_ms = int((time.time() - t0) * 1000)
574
+
133
575
  if resp.status_code == 200:
134
- return [entry["embedding"] for entry in resp.json()["results"]]
135
- else:
136
- resp.raise_for_status()
576
+ data = resp.json()
577
+ vectors = [entry["embedding"] for entry in data["results"]]
578
+ logger.debug(
579
+ "[d][b]Embeddings response received | request_id=%s status_code=%s duration_ms=%s num_vectors=%s",
580
+ request_id,
581
+ resp.status_code,
582
+ duration_ms,
583
+ len(vectors),
584
+ )
585
+ return vectors
586
+
587
+ logger.error(
588
+ "[d b red]Embeddings request failed | request_id=%s status_code=%s duration_ms=%s response_text_preview=%s",
589
+ request_id,
590
+ resp.status_code,
591
+ duration_ms,
592
+ _truncate(resp.text, 2000),
593
+ )
594
+ resp.raise_for_status()
137
595
 
138
596
 
139
597
  if __name__ == "__main__":
140
- provider = WatsonXProvider(model_id="meta-llama/llama-3-2-90b-vision-instruct")
598
+
599
+ provider = WatsonXProvider(
600
+ model_id="meta-llama/llama-3-2-90b-vision-instruct",
601
+ use_legacy_query=False, # set True to use legacy endpoint
602
+ system_prompt="You are a helpful assistant.",
603
+ )
141
604
 
142
605
  prompt = """
143
606
  <|begin_of_text|><|start_header_id|>system<|end_header_id|>
@@ -169,4 +632,4 @@ Usernwaters did not take anytime off during the period<|eot_id|>
169
632
  <|eot_id|><|start_header_id|>user<|end_header_id|>
170
633
  """
171
634
 
172
- print(provider.query(prompt))
635
+ print(provider.query(prompt))