ibm-watsonx-orchestrate-evaluation-framework 1.1.3__py3-none-any.whl → 1.1.8b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97) hide show
  1. {ibm_watsonx_orchestrate_evaluation_framework-1.1.3.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.8b0.dist-info}/METADATA +19 -1
  2. ibm_watsonx_orchestrate_evaluation_framework-1.1.8b0.dist-info/RECORD +146 -0
  3. wxo_agentic_evaluation/analytics/tools/analyzer.py +4 -2
  4. wxo_agentic_evaluation/analyze_run.py +1025 -220
  5. wxo_agentic_evaluation/annotate.py +2 -2
  6. wxo_agentic_evaluation/arg_configs.py +60 -2
  7. wxo_agentic_evaluation/base_user.py +25 -0
  8. wxo_agentic_evaluation/batch_annotate.py +19 -2
  9. wxo_agentic_evaluation/clients.py +103 -0
  10. wxo_agentic_evaluation/compare_runs/__init__.py +0 -0
  11. wxo_agentic_evaluation/compare_runs/compare_2_runs.py +74 -0
  12. wxo_agentic_evaluation/compare_runs/diff.py +554 -0
  13. wxo_agentic_evaluation/compare_runs/model.py +193 -0
  14. wxo_agentic_evaluation/data_annotator.py +25 -7
  15. wxo_agentic_evaluation/description_quality_checker.py +29 -6
  16. wxo_agentic_evaluation/evaluation.py +16 -8
  17. wxo_agentic_evaluation/evaluation_controller/evaluation_controller.py +303 -0
  18. wxo_agentic_evaluation/evaluation_package.py +414 -69
  19. wxo_agentic_evaluation/external_agent/__init__.py +1 -1
  20. wxo_agentic_evaluation/external_agent/external_validate.py +7 -5
  21. wxo_agentic_evaluation/external_agent/types.py +3 -9
  22. wxo_agentic_evaluation/extractors/__init__.py +3 -0
  23. wxo_agentic_evaluation/extractors/extractor_base.py +21 -0
  24. wxo_agentic_evaluation/extractors/labeled_messages.py +47 -0
  25. wxo_agentic_evaluation/hr_agent_langgraph.py +68 -0
  26. wxo_agentic_evaluation/langfuse_collection.py +60 -0
  27. wxo_agentic_evaluation/langfuse_evaluation_package.py +192 -0
  28. wxo_agentic_evaluation/llm_matching.py +104 -2
  29. wxo_agentic_evaluation/llm_safety_eval.py +64 -0
  30. wxo_agentic_evaluation/llm_user.py +5 -4
  31. wxo_agentic_evaluation/llm_user_v2.py +114 -0
  32. wxo_agentic_evaluation/main.py +112 -343
  33. wxo_agentic_evaluation/metrics/__init__.py +15 -0
  34. wxo_agentic_evaluation/metrics/dummy_metric.py +16 -0
  35. wxo_agentic_evaluation/metrics/evaluations.py +107 -0
  36. wxo_agentic_evaluation/metrics/journey_success.py +137 -0
  37. wxo_agentic_evaluation/metrics/llm_as_judge.py +26 -0
  38. wxo_agentic_evaluation/metrics/metrics.py +276 -8
  39. wxo_agentic_evaluation/metrics/tool_calling.py +93 -0
  40. wxo_agentic_evaluation/otel_parser/__init__.py +1 -0
  41. wxo_agentic_evaluation/otel_parser/langflow_parser.py +86 -0
  42. wxo_agentic_evaluation/otel_parser/langgraph_parser.py +61 -0
  43. wxo_agentic_evaluation/otel_parser/parser.py +163 -0
  44. wxo_agentic_evaluation/otel_parser/parser_types.py +38 -0
  45. wxo_agentic_evaluation/otel_parser/pydantic_parser.py +50 -0
  46. wxo_agentic_evaluation/otel_parser/utils.py +15 -0
  47. wxo_agentic_evaluation/otel_parser/wxo_parser.py +39 -0
  48. wxo_agentic_evaluation/otel_support/evaluate_tau.py +44 -10
  49. wxo_agentic_evaluation/otel_support/otel_message_conversion.py +12 -4
  50. wxo_agentic_evaluation/otel_support/tasks_test.py +456 -116
  51. wxo_agentic_evaluation/prompt/derailment_prompt.jinja2 +55 -0
  52. wxo_agentic_evaluation/prompt/llama_user_prompt.jinja2 +50 -4
  53. wxo_agentic_evaluation/prompt/llmaaj_prompt.jinja2 +15 -0
  54. wxo_agentic_evaluation/prompt/off_policy_attack_generation_prompt.jinja2 +1 -1
  55. wxo_agentic_evaluation/prompt/semantic_matching_prompt.jinja2 +41 -9
  56. wxo_agentic_evaluation/prompt/template_render.py +103 -4
  57. wxo_agentic_evaluation/prompt/unsafe_topic_prompt.jinja2 +65 -0
  58. wxo_agentic_evaluation/quick_eval.py +33 -17
  59. wxo_agentic_evaluation/record_chat.py +38 -32
  60. wxo_agentic_evaluation/red_teaming/attack_evaluator.py +211 -62
  61. wxo_agentic_evaluation/red_teaming/attack_generator.py +63 -40
  62. wxo_agentic_evaluation/red_teaming/attack_list.py +95 -7
  63. wxo_agentic_evaluation/red_teaming/attack_runner.py +77 -17
  64. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_call/general_metrics.json +783 -0
  65. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_selection/function_selection_metrics.json +600 -0
  66. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/types.py +10 -10
  67. wxo_agentic_evaluation/referenceless_eval/referenceless_eval.py +105 -39
  68. wxo_agentic_evaluation/resource_map.py +3 -1
  69. wxo_agentic_evaluation/runner.py +329 -0
  70. wxo_agentic_evaluation/runtime_adapter/a2a_runtime_adapter.py +0 -0
  71. wxo_agentic_evaluation/runtime_adapter/runtime_adapter.py +14 -0
  72. wxo_agentic_evaluation/{inference_backend.py → runtime_adapter/wxo_runtime_adapter.py} +24 -293
  73. wxo_agentic_evaluation/scheduler.py +247 -0
  74. wxo_agentic_evaluation/service_instance.py +26 -17
  75. wxo_agentic_evaluation/service_provider/__init__.py +145 -9
  76. wxo_agentic_evaluation/service_provider/gateway_provider.py +707 -0
  77. wxo_agentic_evaluation/service_provider/model_proxy_provider.py +417 -17
  78. wxo_agentic_evaluation/service_provider/ollama_provider.py +393 -22
  79. wxo_agentic_evaluation/service_provider/portkey_provider.py +229 -0
  80. wxo_agentic_evaluation/service_provider/provider.py +130 -10
  81. wxo_agentic_evaluation/service_provider/referenceless_provider_wrapper.py +52 -0
  82. wxo_agentic_evaluation/service_provider/watsonx_provider.py +481 -53
  83. wxo_agentic_evaluation/simluation_runner.py +125 -0
  84. wxo_agentic_evaluation/test_prompt.py +4 -4
  85. wxo_agentic_evaluation/type.py +185 -16
  86. wxo_agentic_evaluation/user_simulator/demo_usage_llm_user.py +100 -0
  87. wxo_agentic_evaluation/utils/__init__.py +44 -3
  88. wxo_agentic_evaluation/utils/evaluation_discovery.py +47 -0
  89. wxo_agentic_evaluation/utils/gateway_provider_utils.py +39 -0
  90. wxo_agentic_evaluation/utils/messages_parser.py +30 -0
  91. wxo_agentic_evaluation/utils/parsers.py +71 -0
  92. wxo_agentic_evaluation/utils/utils.py +313 -9
  93. wxo_agentic_evaluation/wxo_client.py +81 -0
  94. ibm_watsonx_orchestrate_evaluation_framework-1.1.3.dist-info/RECORD +0 -102
  95. wxo_agentic_evaluation/otel_support/evaluate_tau_traces.py +0 -176
  96. {ibm_watsonx_orchestrate_evaluation_framework-1.1.3.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.8b0.dist-info}/WHEEL +0 -0
  97. {ibm_watsonx_orchestrate_evaluation_framework-1.1.3.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.8b0.dist-info}/top_level.txt +0 -0
@@ -1,15 +1,23 @@
1
1
  import dataclasses
2
2
  import json
3
+ import logging
3
4
  import os
4
5
  import time
6
+ import uuid
5
7
  from threading import Lock
6
8
  from types import MappingProxyType
7
- from typing import List, Mapping, Union
9
+ from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
8
10
 
9
11
  import requests
10
12
 
11
- from wxo_agentic_evaluation.service_provider.provider import Provider
13
+ from wxo_agentic_evaluation.service_provider.provider import (
14
+ ChatResult,
15
+ Provider,
16
+ )
17
+
18
+ logger = logging.getLogger(__name__)
12
19
 
20
+ # IAM
13
21
  ACCESS_URL = "https://iam.cloud.ibm.com/identity/token"
14
22
  ACCESS_HEADER = {
15
23
  "content-type": "application/x-www-form-urlencoded",
@@ -18,30 +26,83 @@ ACCESS_HEADER = {
18
26
 
19
27
  YPQA_URL = "https://yp-qa.ml.cloud.ibm.com"
20
28
  PROD_URL = "https://us-south.ml.cloud.ibm.com"
29
+
21
30
  DEFAULT_PARAM = MappingProxyType(
22
31
  {"min_new_tokens": 1, "decoding_method": "greedy", "max_new_tokens": 400}
23
32
  )
24
33
 
25
34
 
35
+ def _truncate(value: Any, max_len: int = 1000) -> str:
36
+ if value is None:
37
+ return ""
38
+ s = str(value)
39
+ return (
40
+ s
41
+ if len(s) <= max_len
42
+ else s[:max_len] + f"... [truncated {len(s) - max_len} chars]"
43
+ )
44
+
45
+
46
+ def _translate_params_to_chat(
47
+ params: Dict[str, Any] = {},
48
+ ) -> Dict[str, Any]:
49
+ """
50
+ Translate legacy generation params to chat.completions params.
51
+ """
52
+ translated_params: Dict[str, Any] = {}
53
+
54
+ if "max_new_tokens" in params:
55
+ translated_params["max_tokens"] = params["max_new_tokens"]
56
+
57
+ if params.get("decoding_method") == "greedy":
58
+ translated_params.setdefault("temperature", 0)
59
+ translated_params.setdefault("top_p", 1)
60
+
61
+ passthrough = {
62
+ "temperature",
63
+ "top_p",
64
+ "n",
65
+ "stream",
66
+ "stop",
67
+ "presence_penalty",
68
+ "frequency_penalty",
69
+ "logit_bias",
70
+ "user",
71
+ "seed",
72
+ "response_format",
73
+ }
74
+ for k in passthrough:
75
+ if k in params:
76
+ translated_params[k] = params[k]
77
+
78
+ return translated_params
79
+
80
+
26
81
  class WatsonXProvider(Provider):
27
82
  def __init__(
28
83
  self,
29
- model_id=None,
30
- api_key=None,
31
- space_id=None,
32
- api_endpoint=PROD_URL,
33
- url=ACCESS_URL,
34
- timeout=60,
35
- params=None,
36
- embedding_model_id=None,
84
+ model_id: Optional[str] = None,
85
+ api_key: Optional[str] = None,
86
+ space_id: Optional[str] = None,
87
+ api_endpoint: str = PROD_URL,
88
+ url: str = ACCESS_URL,
89
+ timeout: int = 60,
90
+ params: Optional[Any] = None,
91
+ embedding_model_id: Optional[str] = None,
92
+ use_legacy_query: Optional[bool] = None,
93
+ system_prompt: Optional[str] = None,
94
+ token: Optional[str] = None,
95
+ instance_url: Optional[str] = None,
37
96
  ):
38
- super().__init__()
97
+ super().__init__(use_legacy_query=use_legacy_query)
98
+
39
99
  self.url = url
40
100
  if (embedding_model_id is None) and (model_id is None):
41
101
  raise Exception(
42
102
  "either model_id or embedding_model_id must be specified"
43
103
  )
44
104
  self.model_id = model_id
105
+ logger.info("[d b]Using inference model %s", self.model_id)
45
106
  api_key = os.environ.get("WATSONX_APIKEY", api_key)
46
107
  if not api_key:
47
108
  raise Exception("apikey must be specified")
@@ -50,7 +111,7 @@ class WatsonXProvider(Provider):
50
111
  "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
51
112
  "apikey": self.api_key,
52
113
  }
53
- self.api_endpoint = api_endpoint
114
+ self.api_endpoint = (api_endpoint or PROD_URL).rstrip("/")
54
115
  space_id = os.environ.get("WATSONX_SPACE_ID", space_id)
55
116
  if not space_id:
56
117
  raise Exception("space id must be specified")
@@ -59,17 +120,26 @@ class WatsonXProvider(Provider):
59
120
  self.embedding_model_id = embedding_model_id
60
121
  self.lock = Lock()
61
122
 
62
- self.params = params if params else DEFAULT_PARAM
63
-
123
+ self.params = params if params is not None else DEFAULT_PARAM
64
124
  if isinstance(self.params, MappingProxyType):
65
125
  self.params = dict(self.params)
66
126
  if dataclasses.is_dataclass(self.params):
67
127
  self.params = dataclasses.asdict(self.params)
68
128
 
129
+ self.system_prompt = system_prompt
130
+
69
131
  self.refresh_time = None
70
132
  self.access_token = None
71
133
  self._refresh_token()
72
134
 
135
+ self.LEGACY_GEN_URL = (
136
+ f"{self.api_endpoint}/ml/v1/text/generation?version=2023-05-02"
137
+ )
138
+ self.CHAT_COMPLETIONS_URL = f"{self.api_endpoint}/ml/v1/text/chat"
139
+ self.EMBEDDINGS_URL = (
140
+ f"{self.api_endpoint}/ml/v1/text/embeddings?version=2023-10-25"
141
+ )
142
+
73
143
  def _get_access_token(self):
74
144
  response = requests.post(
75
145
  self.url,
@@ -87,7 +157,7 @@ class WatsonXProvider(Provider):
87
157
  return token, refresh_time
88
158
 
89
159
  raise RuntimeError(
90
- f"try to acquire access token and get {response.status_code}"
160
+ f"Try to acquire access token and get {response.status_code}. Reason: {response.text} "
91
161
  )
92
162
 
93
163
  def prepare_header(self):
@@ -97,24 +167,6 @@ class WatsonXProvider(Provider):
97
167
  }
98
168
  return headers
99
169
 
100
- def _query(self, sentence: str):
101
- headers = self.prepare_header()
102
-
103
- data = {
104
- "model_id": self.model_id,
105
- "input": sentence,
106
- "parameters": self.params,
107
- "space_id": self.space_id,
108
- }
109
- generation_url = (
110
- f"{self.api_endpoint}/ml/v1/text/generation?version=2023-05-02"
111
- )
112
- resp = requests.post(url=generation_url, headers=headers, json=data)
113
- if resp.status_code == 200:
114
- return resp.json()["results"][0]
115
- else:
116
- resp.raise_for_status()
117
-
118
170
  def _refresh_token(self):
119
171
  # if we do not have a token or the current timestamp is 9 minutes away from expire.
120
172
  if not self.access_token or time.time() > self.refresh_time:
@@ -125,28 +177,365 @@ class WatsonXProvider(Provider):
125
177
  self.refresh_time,
126
178
  ) = self._get_access_token()
127
179
 
128
- def query(self, sentence: Union[str, Mapping[str, str]]) -> str:
180
def old_query(self, sentence: Union[str, Mapping[str, str]]) -> str:
    """
    Legacy /ml/v1/text/generation

    Posts ``sentence`` as the request ``input`` together with the provider's
    generation parameters and returns the generated text.

    Args:
        sentence: Prompt sent as the ``input`` field of the request.

    Returns:
        The ``generated_text`` (or, failing that, ``message``) of the first
        entry in ``results``; when the response has no ``results`` the
        top-level object is inspected instead.

    Raises:
        Exception: If no ``model_id`` is configured.
        ValueError: If the response is not a JSON object or carries no
            output text.
        Exception: Any error raised by the HTTP call or response parsing is
            logged and re-raised unchanged.
    """
    if self.model_id is None:
        raise Exception("model id must be specified for text generation")

    self._refresh_token()
    headers = self.prepare_header()

    payload: Dict[str, Any] = {
        "model_id": self.model_id,
        "input": sentence,
        "parameters": self.params or {},
        "space_id": self.space_id,
    }

    # Client-side correlation id so request/response/failure logs line up.
    request_id = str(uuid.uuid4())
    t0 = time.time()

    logger.debug(
        "[d][b]Sending text.generation request | request_id=%s url=%s model=%s space_id=%s params=%s input_preview=%s",
        request_id,
        self.LEGACY_GEN_URL,
        self.model_id,
        self.space_id,
        json.dumps(
            payload.get("parameters", {}),
            sort_keys=True,
            ensure_ascii=False,
        ),
        _truncate(sentence, 200),
    )

    resp = None
    try:
        resp = requests.post(
            url=self.LEGACY_GEN_URL,
            headers=headers,
            json=payload,
            timeout=self.timeout,
        )
        duration_ms = int((time.time() - t0) * 1000)
        resp.raise_for_status()
        data = resp.json()

        if isinstance(data, dict) and "results" in data and data["results"]:
            result = data["results"][0]
        elif isinstance(data, dict):
            result = data
        else:
            raise ValueError(
                f"Unexpected response type from WatsonX: {type(data)}"
            )

        output_text = ""
        if isinstance(result, dict):
            output_text = (
                result.get("generated_text") or result.get("message") or ""
            )

        # Prefer a top-level usage object; otherwise rebuild one from the
        # per-result token counters when they are present.
        usage = data.get("usage") or {}
        if not usage and isinstance(result, dict):
            in_tok = result.get("input_token_count")
            out_tok = result.get("generated_token_count") or result.get(
                "output_token_count"
            )
            if in_tok is not None or out_tok is not None:
                usage = {
                    "prompt_tokens": in_tok,
                    "completion_tokens": out_tok,
                    "total_tokens": (in_tok or 0) + (out_tok or 0),
                }

        api_request_id = resp.headers.get(
            "x-request-id"
        ) or resp.headers.get("request-id")

        logger.debug(
            "[d][b]text.generation response received | request_id=%s status_code=%s duration_ms=%s usage=%s output_preview=%s api_request_id=%s",
            request_id,
            resp.status_code,
            duration_ms,
            json.dumps(usage, sort_keys=True, ensure_ascii=False),
            _truncate(output_text, 2000),
            api_request_id,
        )

        if output_text:
            return output_text
        raise ValueError(
            f"Unexpected response from legacy endpoint: {data}"
        )

    except Exception as e:
        duration_ms = int((time.time() - t0) * 1000)
        status_code = getattr(resp, "status_code", None)
        resp_text_preview = (
            _truncate(getattr(resp, "text", None), 2000)
            if resp is not None
            else None
        )

        logger.exception(
            "text.generation request failed | request_id=%s status_code=%s duration_ms=%s response_text_preview=%s",
            request_id,
            status_code,
            duration_ms,
            resp_text_preview,
        )
        with self.lock:
            if (
                "authentication_token_expired" in str(e)
                or status_code == 401
            ):
                # Best-effort token renewal so the next call can succeed;
                # the current request is not retried.
                try:
                    self.access_token, self.refresh_time = (
                        self._get_access_token()
                    )
                except Exception:
                    # Keep the refresh best-effort, but leave a trace
                    # instead of hiding the failure entirely.
                    logger.warning(
                        "Access token refresh failed after authentication error | request_id=%s",
                        request_id,
                        exc_info=True,
                    )
        raise
302
+
303
def new_query(self, sentence: str) -> str:
    """
    /ml/v1/text/chat
    Returns assistant content as a plain string.

    Builds a single-turn conversation (optional system prompt followed by
    one user message containing ``sentence``), translates the provider's
    legacy generation params to chat params and posts the request.

    Args:
        sentence: Text of the user message.

    Returns:
        The ``content`` of the first choice's message.

    Raises:
        Exception: If no ``model_id`` is configured.
        Exception: Any error raised by the HTTP call or response parsing is
            logged and re-raised unchanged.
    """
    if self.model_id is None:
        raise Exception("model id must be specified for text generation")

    self._refresh_token()
    headers = self.prepare_header()

    messages: List[Dict[str, Any]] = []
    if getattr(self, "system_prompt", None):
        messages.append({"role": "system", "content": self.system_prompt})
    messages.append(
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": sentence,
                }
            ],
        }
    )

    chat_params = _translate_params_to_chat(self.params)
    # The response is parsed as one JSON document, so a streaming flag
    # must never reach the endpoint (chat() drops it the same way).
    chat_params.pop("stream", None)
    if "time_limit" in self.params:
        chat_params["time_limit"] = self.params["time_limit"]

    payload: Dict[str, Any] = {
        "model_id": self.model_id,
        "space_id": self.space_id,
        "messages": messages,
        **chat_params,
    }

    url = f"{self.CHAT_COMPLETIONS_URL}?version=2024-10-08"
    # Client-side correlation id so request/response/failure logs line up.
    request_id = str(uuid.uuid4())
    t0 = time.time()

    logger.debug(
        "[d][b]Sending chat.completions request | request_id=%s url=%s model=%s space_id=%s params=%s input_preview=%s",
        request_id,
        url,
        self.model_id,
        self.space_id,
        json.dumps(chat_params, sort_keys=True, ensure_ascii=False),
        _truncate(sentence, 200),
    )

    resp = None
    try:
        resp = requests.post(
            url=url, headers=headers, json=payload, timeout=self.timeout
        )
        duration_ms = int((time.time() - t0) * 1000)
        resp.raise_for_status()
        data = resp.json()

        choice = data["choices"][0]
        content = choice["message"]["content"]
        finish_reason = choice.get("finish_reason")
        usage = data.get("usage", {})
        api_request_id = resp.headers.get(
            "x-request-id"
        ) or resp.headers.get("request-id")

        logger.debug(
            "[d][b]chat.completions response received | request_id=%s status_code=%s duration_ms=%s finish_reason=%s usage=%s output_preview=%s api_request_id=%s",
            request_id,
            resp.status_code,
            duration_ms,
            finish_reason,
            json.dumps(usage, sort_keys=True, ensure_ascii=False),
            _truncate(content, 2000),
            api_request_id,
        )

        return content

    except Exception as e:
        duration_ms = int((time.time() - t0) * 1000)
        status_code = getattr(resp, "status_code", None)
        resp_text_preview = (
            _truncate(getattr(resp, "text", None), 2000)
            if resp is not None
            else None
        )

        logger.exception(
            "chat.completions request failed | request_id=%s status_code=%s duration_ms=%s response_text_preview=%s",
            request_id,
            status_code,
            duration_ms,
            resp_text_preview,
        )
        with self.lock:
            if (
                "authentication_token_expired" in str(e)
                or status_code == 401
            ):
                # Best-effort token renewal so the next call can succeed;
                # the current request is not retried.
                try:
                    self.access_token, self.refresh_time = (
                        self._get_access_token()
                    )
                except Exception:
                    # Keep the refresh best-effort, but leave a trace
                    # instead of hiding the failure entirely.
                    logger.warning(
                        "Access token refresh failed after authentication error | request_id=%s",
                        request_id,
                        exc_info=True,
                    )
        raise
412
+
413
def chat(
    self,
    messages: Sequence[Dict[str, str]],
    params: Optional[Dict[str, Any]] = None,
) -> ChatResult:
    """
    Sends a multi-message chat request to /ml/v1/text/chat
    Returns ChatResult with text, usage, finish_reason, and raw response.

    Args:
        messages: Conversation turns, each a dict with ``role`` and
            ``content``. User turns whose content is a plain string are
            wrapped into a one-element list of text parts; every other turn
            is forwarded with its role and content as given.
        params: Optional per-call overrides merged on top of the provider's
            own params before translation to chat params. Any ``stream``
            setting is dropped because the response is read as one JSON
            document.

    Raises:
        Exception: If no ``model_id`` is configured.
        Exception: Any error raised by the HTTP call or response parsing is
            logged and re-raised unchanged.
    """
    if self.model_id is None:
        raise Exception("model id must be specified for chat")

    self._refresh_token()
    headers = self.prepare_header()

    wx_messages: List[Dict[str, Any]] = []
    for m in messages:
        role = m.get("role")
        content = m.get("content", "")
        if role == "user" and isinstance(content, str):
            wx_messages.append(
                {
                    "role": "user",
                    "content": [{"type": "text", "text": content}],
                }
            )
        else:
            wx_messages.append({"role": role, "content": content})

    # Per-call params win over the provider-level defaults.
    merged_params = dict(self.params or {})
    if params:
        merged_params.update(params)
    chat_params = _translate_params_to_chat(merged_params)
    chat_params.pop("stream", None)
    if "time_limit" in merged_params:
        chat_params["time_limit"] = merged_params["time_limit"]

    payload: Dict[str, Any] = {
        "model_id": self.model_id,
        "space_id": self.space_id,
        "messages": wx_messages,
        **chat_params,
    }

    url = f"{self.CHAT_COMPLETIONS_URL}?version=2024-10-08"
    # Client-side correlation id so request/response/failure logs line up.
    request_id = str(uuid.uuid4())
    t0 = time.time()

    # Only the most recent user turn is previewed in the request log.
    last_user = next(
        (
            m.get("content", "")
            for m in reversed(messages)
            if m.get("role") == "user"
        ),
        "",
    )
    logger.debug(
        "[d][b]Sending chat.completions request (non-streaming) | request_id=%s url=%s model=%s space_id=%s params=%s input_preview=%s",
        request_id,
        url,
        self.model_id,
        self.space_id,
        json.dumps(chat_params, sort_keys=True, ensure_ascii=False),
        _truncate(last_user, 200),
    )

    resp = None
    try:
        resp = requests.post(
            url=url, headers=headers, json=payload, timeout=self.timeout
        )
        duration_ms = int((time.time() - t0) * 1000)
        resp.raise_for_status()
        data = resp.json()

        choice = data["choices"][0]
        content = choice["message"]["content"]
        finish_reason = choice.get("finish_reason")
        usage = data.get("usage", {})
        api_request_id = resp.headers.get(
            "x-request-id"
        ) or resp.headers.get("request-id")

        logger.debug(
            "[d][b]chat.completions response received (non-streaming) | request_id=%s status_code=%s duration_ms=%s finish_reason=%s usage=%s output_preview=%s api_request_id=%s",
            request_id,
            resp.status_code,
            duration_ms,
            finish_reason,
            json.dumps(usage, sort_keys=True, ensure_ascii=False),
            _truncate(content, 2000),
            api_request_id,
        )

        return ChatResult(
            text=content, usage=usage, finish_reason=finish_reason, raw=data
        )

    except Exception as e:
        duration_ms = int((time.time() - t0) * 1000)
        status_code = getattr(resp, "status_code", None)
        resp_text_preview = (
            _truncate(getattr(resp, "text", None), 2000)
            if resp is not None
            else None
        )

        logger.exception(
            "chat.completions request failed (non-streaming) | request_id=%s status_code=%s duration_ms=%s response_text_preview=%s",
            request_id,
            status_code,
            duration_ms,
            resp_text_preview,
        )
        with self.lock:
            if (
                "authentication_token_expired" in str(e)
                or status_code == 401
            ):
                # Best-effort token renewal so the next call can succeed;
                # the current request is not retried.
                try:
                    self.access_token, self.refresh_time = (
                        self._get_access_token()
                    )
                except Exception:
                    # Keep the refresh best-effort, but leave a trace
                    # instead of hiding the failure entirely.
                    logger.warning(
                        "Access token refresh failed after authentication error | request_id=%s",
                        request_id,
                        exc_info=True,
                    )
        raise
150
539
 
151
540
  def encode(self, sentences: List[str]) -> List[list]:
152
541
  if self.embedding_model_id is None:
@@ -154,24 +543,63 @@ class WatsonXProvider(Provider):
154
543
  "embedding model id must be specified for text encoding"
155
544
  )
156
545
 
546
+ self._refresh_token()
157
547
  headers = self.prepare_header()
158
- url = f"{self.api_endpoint}/ml/v1/text/embeddings?version=2023-10-25"
159
548
 
160
- data = {
549
+ # Minimal logging for embeddings
550
+ request_id = str(uuid.uuid4())
551
+ t0 = time.time()
552
+ logger.debug(
553
+ "[d][b]Sending embeddings request | request_id=%s url=%s model=%s space_id=%s num_inputs=%s",
554
+ request_id,
555
+ self.EMBEDDINGS_URL,
556
+ self.embedding_model_id,
557
+ self.space_id,
558
+ len(sentences),
559
+ )
560
+
561
+ payload = {
161
562
  "inputs": sentences,
162
- "model_id": self.model_id,
563
+ "model_id": self.embedding_model_id,
163
564
  "space_id": self.space_id,
164
565
  }
165
- resp = requests.post(url=url, headers=headers, json=data)
566
+
567
+ resp = requests.post(
568
+ url=self.EMBEDDINGS_URL,
569
+ headers=headers,
570
+ json=payload,
571
+ timeout=self.timeout,
572
+ )
573
+ duration_ms = int((time.time() - t0) * 1000)
574
+
166
575
  if resp.status_code == 200:
167
- return [entry["embedding"] for entry in resp.json()["results"]]
168
- else:
169
- resp.raise_for_status()
576
+ data = resp.json()
577
+ vectors = [entry["embedding"] for entry in data["results"]]
578
+ logger.debug(
579
+ "[d][b]Embeddings response received | request_id=%s status_code=%s duration_ms=%s num_vectors=%s",
580
+ request_id,
581
+ resp.status_code,
582
+ duration_ms,
583
+ len(vectors),
584
+ )
585
+ return vectors
586
+
587
+ logger.error(
588
+ "[d b red]Embeddings request failed | request_id=%s status_code=%s duration_ms=%s response_text_preview=%s",
589
+ request_id,
590
+ resp.status_code,
591
+ duration_ms,
592
+ _truncate(resp.text, 2000),
593
+ )
594
+ resp.raise_for_status()
170
595
 
171
596
 
172
597
  if __name__ == "__main__":
598
+
173
599
  provider = WatsonXProvider(
174
- model_id="meta-llama/llama-3-2-90b-vision-instruct"
600
+ model_id="meta-llama/llama-3-2-90b-vision-instruct",
601
+ use_legacy_query=False, # set True to use legacy endpoint
602
+ system_prompt="You are a helpful assistant.",
175
603
  )
176
604
 
177
605
  prompt = """