ibm-watsonx-orchestrate-evaluation-framework 1.1.3__py3-none-any.whl → 1.1.8b0__py3-none-any.whl
This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- {ibm_watsonx_orchestrate_evaluation_framework-1.1.3.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.8b0.dist-info}/METADATA +19 -1
- ibm_watsonx_orchestrate_evaluation_framework-1.1.8b0.dist-info/RECORD +146 -0
- wxo_agentic_evaluation/analytics/tools/analyzer.py +4 -2
- wxo_agentic_evaluation/analyze_run.py +1025 -220
- wxo_agentic_evaluation/annotate.py +2 -2
- wxo_agentic_evaluation/arg_configs.py +60 -2
- wxo_agentic_evaluation/base_user.py +25 -0
- wxo_agentic_evaluation/batch_annotate.py +19 -2
- wxo_agentic_evaluation/clients.py +103 -0
- wxo_agentic_evaluation/compare_runs/__init__.py +0 -0
- wxo_agentic_evaluation/compare_runs/compare_2_runs.py +74 -0
- wxo_agentic_evaluation/compare_runs/diff.py +554 -0
- wxo_agentic_evaluation/compare_runs/model.py +193 -0
- wxo_agentic_evaluation/data_annotator.py +25 -7
- wxo_agentic_evaluation/description_quality_checker.py +29 -6
- wxo_agentic_evaluation/evaluation.py +16 -8
- wxo_agentic_evaluation/evaluation_controller/evaluation_controller.py +303 -0
- wxo_agentic_evaluation/evaluation_package.py +414 -69
- wxo_agentic_evaluation/external_agent/__init__.py +1 -1
- wxo_agentic_evaluation/external_agent/external_validate.py +7 -5
- wxo_agentic_evaluation/external_agent/types.py +3 -9
- wxo_agentic_evaluation/extractors/__init__.py +3 -0
- wxo_agentic_evaluation/extractors/extractor_base.py +21 -0
- wxo_agentic_evaluation/extractors/labeled_messages.py +47 -0
- wxo_agentic_evaluation/hr_agent_langgraph.py +68 -0
- wxo_agentic_evaluation/langfuse_collection.py +60 -0
- wxo_agentic_evaluation/langfuse_evaluation_package.py +192 -0
- wxo_agentic_evaluation/llm_matching.py +104 -2
- wxo_agentic_evaluation/llm_safety_eval.py +64 -0
- wxo_agentic_evaluation/llm_user.py +5 -4
- wxo_agentic_evaluation/llm_user_v2.py +114 -0
- wxo_agentic_evaluation/main.py +112 -343
- wxo_agentic_evaluation/metrics/__init__.py +15 -0
- wxo_agentic_evaluation/metrics/dummy_metric.py +16 -0
- wxo_agentic_evaluation/metrics/evaluations.py +107 -0
- wxo_agentic_evaluation/metrics/journey_success.py +137 -0
- wxo_agentic_evaluation/metrics/llm_as_judge.py +26 -0
- wxo_agentic_evaluation/metrics/metrics.py +276 -8
- wxo_agentic_evaluation/metrics/tool_calling.py +93 -0
- wxo_agentic_evaluation/otel_parser/__init__.py +1 -0
- wxo_agentic_evaluation/otel_parser/langflow_parser.py +86 -0
- wxo_agentic_evaluation/otel_parser/langgraph_parser.py +61 -0
- wxo_agentic_evaluation/otel_parser/parser.py +163 -0
- wxo_agentic_evaluation/otel_parser/parser_types.py +38 -0
- wxo_agentic_evaluation/otel_parser/pydantic_parser.py +50 -0
- wxo_agentic_evaluation/otel_parser/utils.py +15 -0
- wxo_agentic_evaluation/otel_parser/wxo_parser.py +39 -0
- wxo_agentic_evaluation/otel_support/evaluate_tau.py +44 -10
- wxo_agentic_evaluation/otel_support/otel_message_conversion.py +12 -4
- wxo_agentic_evaluation/otel_support/tasks_test.py +456 -116
- wxo_agentic_evaluation/prompt/derailment_prompt.jinja2 +55 -0
- wxo_agentic_evaluation/prompt/llama_user_prompt.jinja2 +50 -4
- wxo_agentic_evaluation/prompt/llmaaj_prompt.jinja2 +15 -0
- wxo_agentic_evaluation/prompt/off_policy_attack_generation_prompt.jinja2 +1 -1
- wxo_agentic_evaluation/prompt/semantic_matching_prompt.jinja2 +41 -9
- wxo_agentic_evaluation/prompt/template_render.py +103 -4
- wxo_agentic_evaluation/prompt/unsafe_topic_prompt.jinja2 +65 -0
- wxo_agentic_evaluation/quick_eval.py +33 -17
- wxo_agentic_evaluation/record_chat.py +38 -32
- wxo_agentic_evaluation/red_teaming/attack_evaluator.py +211 -62
- wxo_agentic_evaluation/red_teaming/attack_generator.py +63 -40
- wxo_agentic_evaluation/red_teaming/attack_list.py +95 -7
- wxo_agentic_evaluation/red_teaming/attack_runner.py +77 -17
- wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_call/general_metrics.json +783 -0
- wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_selection/function_selection_metrics.json +600 -0
- wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/types.py +10 -10
- wxo_agentic_evaluation/referenceless_eval/referenceless_eval.py +105 -39
- wxo_agentic_evaluation/resource_map.py +3 -1
- wxo_agentic_evaluation/runner.py +329 -0
- wxo_agentic_evaluation/runtime_adapter/a2a_runtime_adapter.py +0 -0
- wxo_agentic_evaluation/runtime_adapter/runtime_adapter.py +14 -0
- wxo_agentic_evaluation/{inference_backend.py → runtime_adapter/wxo_runtime_adapter.py} +24 -293
- wxo_agentic_evaluation/scheduler.py +247 -0
- wxo_agentic_evaluation/service_instance.py +26 -17
- wxo_agentic_evaluation/service_provider/__init__.py +145 -9
- wxo_agentic_evaluation/service_provider/gateway_provider.py +707 -0
- wxo_agentic_evaluation/service_provider/model_proxy_provider.py +417 -17
- wxo_agentic_evaluation/service_provider/ollama_provider.py +393 -22
- wxo_agentic_evaluation/service_provider/portkey_provider.py +229 -0
- wxo_agentic_evaluation/service_provider/provider.py +130 -10
- wxo_agentic_evaluation/service_provider/referenceless_provider_wrapper.py +52 -0
- wxo_agentic_evaluation/service_provider/watsonx_provider.py +481 -53
- wxo_agentic_evaluation/simluation_runner.py +125 -0
- wxo_agentic_evaluation/test_prompt.py +4 -4
- wxo_agentic_evaluation/type.py +185 -16
- wxo_agentic_evaluation/user_simulator/demo_usage_llm_user.py +100 -0
- wxo_agentic_evaluation/utils/__init__.py +44 -3
- wxo_agentic_evaluation/utils/evaluation_discovery.py +47 -0
- wxo_agentic_evaluation/utils/gateway_provider_utils.py +39 -0
- wxo_agentic_evaluation/utils/messages_parser.py +30 -0
- wxo_agentic_evaluation/utils/parsers.py +71 -0
- wxo_agentic_evaluation/utils/utils.py +313 -9
- wxo_agentic_evaluation/wxo_client.py +81 -0
- ibm_watsonx_orchestrate_evaluation_framework-1.1.3.dist-info/RECORD +0 -102
- wxo_agentic_evaluation/otel_support/evaluate_tau_traces.py +0 -176
- {ibm_watsonx_orchestrate_evaluation_framework-1.1.3.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.8b0.dist-info}/WHEEL +0 -0
- {ibm_watsonx_orchestrate_evaluation_framework-1.1.3.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.8b0.dist-info}/top_level.txt +0 -0
wxo_agentic_evaluation/service_provider/model_proxy_provider.py

@@ -1,13 +1,21 @@
+import json
+import logging
 import os
 import time
+import uuid
 from threading import Lock
-from typing import List, Tuple
+from typing import Any, Dict, List, Optional, Sequence, Tuple

 import requests

-from wxo_agentic_evaluation.service_provider.provider import Provider
+from wxo_agentic_evaluation.service_provider.provider import (
+    ChatResult,
+    Provider,
+)
 from wxo_agentic_evaluation.utils.utils import is_ibm_cloud_url

+logger = logging.getLogger(__name__)
+
 AUTH_ENDPOINT_AWS = (
     "https://iam.platform.saas.ibm.com/siusermgr/api/1.0/apikeys/token"
 )
@@ -15,7 +23,7 @@ AUTH_ENDPOINT_IBM_CLOUD = "https://iam.cloud.ibm.com/identity/token"
 DEFAULT_PARAM = {
     "min_new_tokens": 1,
     "decoding_method": "greedy",
-    "max_new_tokens":
+    "max_new_tokens": 2500
 }


@@ -36,17 +44,62 @@ def _normalize_cpd_auth_url(url: str) -> str:
     return url


+def _truncate(value: Any, max_len: int = 1000) -> str:
+    if value is None:
+        return ""
+    s = str(value)
+    return (
+        s
+        if len(s) <= max_len
+        else s[:max_len] + f"... [truncated {len(s) - max_len} chars]"
+    )
+
+
+def _translate_params_to_chat(params: Dict[str, Any] | None) -> Dict[str, Any]:
+    # Translate legacy generation params to chat.completions params.
+    p = params or {}
+    out: Dict[str, Any] = {}
+
+    passthrough = {
+        "temperature",
+        "top_p",
+        "n",
+        "stream",
+        "stop",
+        "presence_penalty",
+        "frequency_penalty",
+        "logit_bias",
+        "user",
+        "max_tokens",
+        "seed",
+        "response_format",
+    }
+    for k in passthrough:
+        if k in p:
+            out[k] = p[k]
+
+    if "max_new_tokens" in p and "max_tokens" not in out:
+        out["max_tokens"] = p["max_new_tokens"]
+
+    return out
+
+
 class ModelProxyProvider(Provider):
     def __init__(
         self,
-        model_id=None,
-        api_key=None,
-        instance_url=None,
-        timeout=300,
-        embedding_model_id=None,
-        params=None,
+        model_id: Optional[str] = None,
+        api_key: Optional[str] = None,
+        instance_url: Optional[str] = None,
+        timeout: int = 300,
+        embedding_model_id: Optional[str] = None,
+        params: Optional[Dict[str, Any]] = None,
+        use_legacy_query: Optional[
+            bool
+        ] = None,  # Provider routes query() to old/new based on this
+        system_prompt: Optional[str] = None,
+        token: Optional[str] = None,
     ):
-        super().__init__()
+        super().__init__(use_legacy_query=use_legacy_query)

         instance_url = os.environ.get("WO_INSTANCE", instance_url)
         if not instance_url:
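The `_translate_params_to_chat()` mapping added above is easiest to see by example. Below is a standalone restatement of the helper (copied from the hunk so it runs without installing the package), applied to the new `DEFAULT_PARAM`:

```python
# Standalone restatement of _translate_params_to_chat() from the hunk above,
# copied here so the legacy -> chat parameter mapping runs without the package.
from typing import Any, Dict, Optional

_PASSTHROUGH = {
    "temperature", "top_p", "n", "stream", "stop", "presence_penalty",
    "frequency_penalty", "logit_bias", "user", "max_tokens", "seed",
    "response_format",
}


def translate_params_to_chat(params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    p = params or {}
    out = {k: p[k] for k in _PASSTHROUGH if k in p}
    if "max_new_tokens" in p and "max_tokens" not in out:
        out["max_tokens"] = p["max_new_tokens"]  # legacy name -> chat name
    return out


# The new DEFAULT_PARAM reduces to a plain token cap; legacy-only keys
# (min_new_tokens, decoding_method) have no chat equivalent and are dropped.
print(translate_params_to_chat(
    {"min_new_tokens": 1, "decoding_method": "greedy", "max_new_tokens": 2500}
))  # -> {'max_tokens': 2500}
```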
@@ -56,6 +109,7 @@ class ModelProxyProvider(Provider):

         self.timeout = timeout
         self.model_id = os.environ.get("MODEL_OVERRIDE", model_id)
+        logger.info("[d b]Using inference model %s", self.model_id)
         self.embedding_model_id = embedding_model_id

         self.api_key = os.environ.get("WO_API_KEY", api_key)
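The `MODEL_OVERRIDE` lookup added above follows the same environment-first pattern already used for `WO_INSTANCE` and `WO_API_KEY`. A small sketch of that precedence (both model names here are illustrative, not defaults from this package):

```python
import os

# MODEL_OVERRIDE, when set, wins over the constructor argument.
os.environ["MODEL_OVERRIDE"] = "ibm/granite-13b-instruct-v2"
model_id = os.environ.get("MODEL_OVERRIDE", "meta-llama/llama-3-3-70b-instruct")
assert model_id == "ibm/granite-13b-instruct-v2"
```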
@@ -106,14 +160,17 @@ class ModelProxyProvider(Provider):
                 "WO_API_KEY must be specified for SaaS or IBM IAM auth"
             )

+        # Endpoints
         self.url = (
             self.instance_url + "/ml/v1/text/generation?version=2024-05-01"
-        )
+        )  # legacy
+        self.chat_url = self.instance_url + "/ml/v1/chat/completions"  # chat
         self.embedding_url = self.instance_url + "/ml/v1/text/embeddings"

         self.lock = Lock()
         self.token, self.refresh_time = self.get_token()
         self.params = params if params else DEFAULT_PARAM
+        self.system_prompt = system_prompt

     def _resolve_auth_mode_and_url(
         self, explicit_auth_url: str | None
@@ -247,6 +304,7 @@ class ModelProxyProvider(Provider):
             "inputs": sentences,
             "model_id": self.embedding_model_id,
             "space_id": self.space_id,
+            "max_token": 1000
         }
         # "timeout": self.timeout}
         resp = requests.post(
@@ -254,6 +312,7 @@ class ModelProxyProvider(Provider):
             json=payload,
             headers=headers,
             verify=self._wo_ssl_verify,
+            timeout=self.timeout,
         )

         if resp.status_code == 200:
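With the two embeddings changes above applied, the request body gains an explicit token cap and the HTTP call gains a timeout. For reference, the payload now has roughly this shape (values are illustrative placeholders):

```python
# Illustrative embeddings request body after this change; "inputs" and
# "space_id" values are placeholders. The cap key is spelled "max_token"
# (singular) in the source.
payload = {
    "inputs": ["first sentence", "second sentence"],
    "model_id": "ibm/slate-30m-english-rtrvr",
    "space_id": "<your-space-id>",
    "max_token": 1000,
}
```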
@@ -262,9 +321,11 @@ class ModelProxyProvider(Provider):

         resp.raise_for_status()

-    def query(self, sentence: str) -> str:
+    def old_query(self, sentence: str) -> str:
+        # Legacy /ml/v1/text/generation
         if self.model_id is None:
             raise Exception("model id must be specified for text generation")
+
         self.refresh_token_if_expires()
         headers = self.get_header()
         payload = {
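The rename to `old_query()` pairs with the `use_legacy_query` flag threaded through `__init__` earlier and the `new_query()` method added in the next hunk. The dispatch itself lives in the `Provider` base class (`service_provider/provider.py`, `+130 -10` in the file list above), which is not shown in this diff. The sketch below is a hypothetical reconstruction inferred only from the comments in this file; the real implementation may differ:

```python
import os
from typing import Optional


class Provider:
    """Hypothetical dispatch sketch; the real base class lives in
    service_provider/provider.py and may differ."""

    def __init__(self, use_legacy_query: Optional[bool] = None):
        if use_legacy_query is None:
            # Fall back to the USE_LEGACY_QUERY env toggle referenced in the
            # __main__ comment at the end of this diff.
            use_legacy_query = (
                os.environ.get("USE_LEGACY_QUERY", "false").lower() == "true"
            )
        self.use_legacy_query = use_legacy_query

    def query(self, sentence: str) -> str:
        # Subclasses (e.g. ModelProxyProvider) supply old_query()/new_query().
        if self.use_legacy_query:
            return self.old_query(sentence)
        return self.new_query(sentence)
```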
@@ -274,18 +335,357 @@ class ModelProxyProvider(Provider):
             "timeout": self.timeout,
             "parameters": self.params,
         }
-
-
+
+        request_id = str(uuid.uuid4())
+        start_time = time.time()
+
+        # Input logging
+        logger.debug(
+            "[d][b]Sending text.generation request | request_id=%s url=%s model=%s space_id=%s params=%s input_preview=%s",
+            request_id,
+            self.url,
+            self.model_id,
+            self.space_id,
+            json.dumps(
+                payload.get("parameters", {}),
+                sort_keys=True,
+                ensure_ascii=False,
+            ),
+            _truncate(sentence, 200),
         )
-        if resp.status_code == 200:
-            return resp.json()["results"][0]["generated_text"]

-        resp.raise_for_status()
+        resp = None
+        try:
+            resp = requests.post(
+                self.url,
+                json=payload,
+                headers=headers,
+                verify=self._wo_ssl_verify,
+                timeout=self.timeout,
+            )
+
+            duration_ms = int((time.time() - start_time) * 1000)
+            resp.raise_for_status()
+            data = resp.json()
+
+            result = (
+                data["results"][0]
+                if "results" in data and data["results"]
+                else data
+            )
+            output_text = (
+                (
+                    result.get("generated_text")
+                    if isinstance(result, dict)
+                    else None
+                )
+                or (result.get("message") if isinstance(result, dict) else None)
+                or ""
+            )
+
+            # Usage (best-effort)
+            usage = data.get("usage") or {}
+            if not usage and isinstance(result, dict):
+                in_tok = result.get("input_token_count")
+                out_tok = result.get("generated_token_count") or result.get(
+                    "output_token_count"
+                )
+                if in_tok is not None or out_tok is not None:
+                    usage = {
+                        "prompt_tokens": in_tok,
+                        "completion_tokens": out_tok,
+                        "total_tokens": (in_tok or 0) + (out_tok or 0),
+                    }
+
+            api_request_id = resp.headers.get(
+                "x-request-id"
+            ) or resp.headers.get("request-id")
+
+            # Output logging
+            logger.debug(
+                "[d][b]text.generation response received | request_id=%s status_code=%s duration_ms=%s usage=%s output_preview=%s api_request_id=%s",
+                request_id,
+                resp.status_code,
+                duration_ms,
+                json.dumps(usage, sort_keys=True, ensure_ascii=False),
+                _truncate(output_text, 2000),
+                api_request_id,
+            )
+
+            if output_text:
+                return output_text
+            else:
+                raise ValueError(
+                    f"Unexpected response from legacy endpoint: {data}"
+                )
+
+        except Exception as e:
+            duration_ms = int((time.time() - start_time) * 1000)
+            status_code = getattr(resp, "status_code", None)
+            resp_text_preview = (
+                _truncate(getattr(resp, "text", None), 2000)
+                if resp is not None
+                else None
+            )
+
+            logger.exception(
+                "text.generation request failed | request_id=%s status_code=%s duration_ms=%s response_text_preview=%s",
+                request_id,
+                status_code,
+                duration_ms,
+                resp_text_preview,
+            )
+
+            with self.lock:
+                if (
+                    "authentication_token_expired" in str(e)
+                    or status_code == 401
+                ):
+                    self.token, self.refresh_time = self.get_token()
+            raise
+
+    def new_query(self, sentence: str) -> str:
+        """
+        New /ml/v1/chat/completions
+        Returns assistant message content of the first choice.
+        """
+        if self.model_id is None:
+            raise Exception("model id must be specified for text generation")
+
+        self.refresh_token_if_expires()
+        headers = self.get_header()
+
+        messages: List[Dict[str, Any]] = []
+        if getattr(self, "system_prompt", None):
+            messages.append({"role": "system", "content": self.system_prompt})
+        messages.append(
+            {
+                "role": "user",
+                "content": [{"type": "text", "text": sentence}],
+            }
+        )
+
+        chat_params = _translate_params_to_chat(self.params)
+        if isinstance(self.params, dict) and "time_limit" in self.params:
+            chat_params["time_limit"] = self.params["time_limit"]
+
+        payload: Dict[str, Any] = {
+            "model_id": self.model_id,
+            "space_id": self.space_id,
+            "messages": messages,
+            **chat_params,
+        }
+
+        url = f"{self.instance_url}/ml/v1/text/chat?version=2024-10-08"
+        request_id = str(uuid.uuid4())
+        start_time = time.time()
+
+        logger.debug(
+            "[d][b]Sending chat.completions request | request_id=%s url=%s model=%s space_id=%s params=%s input_preview=%s",
+            request_id,
+            url,
+            self.model_id,
+            self.space_id,
+            json.dumps(chat_params, sort_keys=True, ensure_ascii=False),
+            _truncate(sentence, 200),
+        )
+
+        resp = None
+        try:
+            resp = requests.post(
+                url,
+                json=payload,
+                headers=headers,
+                verify=self._wo_ssl_verify,
+                timeout=self.timeout,
+            )
+            duration_ms = int((time.time() - start_time) * 1000)
+            resp.raise_for_status()
+            data = resp.json()
+
+            choice = data["choices"][0]
+            content = choice["message"]["content"]
+            finish_reason = choice.get("finish_reason")
+            usage = data.get("usage", {})
+            api_request_id = resp.headers.get(
+                "x-request-id"
+            ) or resp.headers.get("request-id")
+
+            logger.debug(
+                "[d][b]chat.completions response received | request_id=%s status_code=%s duration_ms=%s finish_reason=%s usage=%s output_preview=%s api_request_id=%s",
+                request_id,
+                resp.status_code,
+                duration_ms,
+                finish_reason,
+                json.dumps(usage, sort_keys=True, ensure_ascii=False),
+                _truncate(content, 2000),
+                api_request_id,
+            )
+
+            return content
+
+        except Exception as e:
+            duration_ms = int((time.time() - start_time) * 1000)
+            status_code = getattr(resp, "status_code", None)
+            resp_text_preview = (
+                _truncate(getattr(resp, "text", None), 2000)
+                if resp is not None
+                else None
+            )
+
+            logger.exception(
+                "chat.completions request failed | request_id=%s status_code=%s duration_ms=%s response_text_preview=%s",
+                request_id,
+                status_code,
+                duration_ms,
+                resp_text_preview,
+            )
+
+            with self.lock:
+                if (
+                    "authentication_token_expired" in str(e)
+                    or status_code == 401
+                ):
+                    self.token, self.refresh_time = self.get_token()
+            raise
+
+    def chat(
+        self,
+        messages: Sequence[Dict[str, str]],
+        params: Optional[Dict[str, Any]] = None,
+    ) -> ChatResult:
+        # Non-streaming chat using /ml/v1/chat/completions.
+        if self.model_id is None:
+            raise Exception("model id must be specified for chat")
+
+        self.refresh_token_if_expires()
+        headers = self.get_header()
+
+        # Convert messages to watsonx format: user content is typed list
+        wx_messages: List[Dict[str, Any]] = []
+        for m in messages:
+            role = m.get("role")
+            content = m.get("content", "")
+            if role == "user" and isinstance(content, str):
+                wx_messages.append(
+                    {
+                        "role": "user",
+                        "content": [{"type": "text", "text": content}],
+                    }
+                )
+            else:
+                wx_messages.append({"role": role, "content": content})
+
+        merged_params = dict(self.params or {})
+        if params:
+            merged_params.update(params)
+        chat_params = _translate_params_to_chat(merged_params)
+        chat_params.pop("stream", None)  # force non-streaming
+        if "time_limit" in merged_params:
+            chat_params["time_limit"] = merged_params["time_limit"]
+
+        payload: Dict[str, Any] = {
+            "model_id": self.model_id,
+            "space_id": self.space_id,
+            "messages": wx_messages,
+            **chat_params,
+        }
+
+        url = f"{self.instance_url}/ml/v1/text/chat?version=2024-10-08"
+
+        last_user = next(
+            (
+                m.get("content", "")
+                for m in reversed(messages)
+                if m.get("role") == "user"
+            ),
+            "",
+        )
+        request_id = str(uuid.uuid4())
+        start_time = time.time()
+
+        logger.debug(
+            "[d][b]Sending chat.completions request (non-streaming) | request_id=%s url=%s model=%s space_id=%s params=%s input_preview=%s",
+            request_id,
+            url,
+            self.model_id,
+            self.space_id,
+            json.dumps(chat_params, sort_keys=True, ensure_ascii=False),
+            _truncate(last_user, 200),
+        )
+
+        resp = None
+        try:
+            resp = requests.post(
+                url,
+                json=payload,
+                headers=headers,
+                verify=self._wo_ssl_verify,
+                timeout=self.timeout,
+            )
+            duration_ms = int((time.time() - start_time) * 1000)
+            resp.raise_for_status()
+            data = resp.json()
+
+            choice = data["choices"][0]
+            content = choice["message"]["content"]
+            finish_reason = choice.get("finish_reason")
+            usage = data.get("usage", {})
+            api_request_id = resp.headers.get(
+                "x-request-id"
+            ) or resp.headers.get("request-id")
+
+            logger.debug(
+                "[d][b]chat.completions response received (non-streaming) | request_id=%s status_code=%s duration_ms=%s finish_reason=%s usage=%s output_preview=%s api_request_id=%s",
+                request_id,
+                resp.status_code,
+                duration_ms,
+                finish_reason,
+                json.dumps(usage, sort_keys=True, ensure_ascii=False),
+                _truncate(content, 2000),
+                api_request_id,
+            )
+
+            return ChatResult(
+                text=content, usage=usage, finish_reason=finish_reason, raw=data
+            )
+
+        except Exception as e:
+            duration_ms = int((time.time() - start_time) * 1000)
+            status_code = getattr(resp, "status_code", None)
+            resp_text_preview = (
+                _truncate(getattr(resp, "text", None), 2000)
+                if resp is not None
+                else None
+            )
+
+            logger.exception(
+                "chat.completions request failed (non-streaming) | request_id=%s status_code=%s duration_ms=%s response_text_preview=%s",
+                request_id,
+                status_code,
+                duration_ms,
+                resp_text_preview,
+            )
+            with self.lock:
+                if (
+                    "authentication_token_expired" in str(e)
+                    or status_code == 401
+                ):
+                    self.token, self.refresh_time = self.get_token()
+            raise


 if __name__ == "__main__":
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%(asctime)s %(levelname)s %(name)s %(message)s",
+    )
+
     provider = ModelProxyProvider(
         model_id="meta-llama/llama-3-3-70b-instruct",
         embedding_model_id="ibm/slate-30m-english-rtrvr",
+        use_legacy_query=False,
+        system_prompt="",
     )
+    # Base class will route .query() to new_query() by default (unless USE_LEGACY_QUERY=true)
     print(provider.query("ok"))