ibm-watsonx-orchestrate-evaluation-framework 1.0.3__py3-none-any.whl → 1.1.8b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ibm_watsonx_orchestrate_evaluation_framework-1.1.8b0.dist-info/METADATA +53 -0
- ibm_watsonx_orchestrate_evaluation_framework-1.1.8b0.dist-info/RECORD +146 -0
- wxo_agentic_evaluation/analytics/tools/analyzer.py +38 -21
- wxo_agentic_evaluation/analytics/tools/main.py +19 -25
- wxo_agentic_evaluation/analytics/tools/types.py +26 -11
- wxo_agentic_evaluation/analytics/tools/ux.py +75 -31
- wxo_agentic_evaluation/analyze_run.py +1184 -97
- wxo_agentic_evaluation/annotate.py +7 -5
- wxo_agentic_evaluation/arg_configs.py +97 -5
- wxo_agentic_evaluation/base_user.py +25 -0
- wxo_agentic_evaluation/batch_annotate.py +97 -27
- wxo_agentic_evaluation/clients.py +103 -0
- wxo_agentic_evaluation/compare_runs/__init__.py +0 -0
- wxo_agentic_evaluation/compare_runs/compare_2_runs.py +74 -0
- wxo_agentic_evaluation/compare_runs/diff.py +554 -0
- wxo_agentic_evaluation/compare_runs/model.py +193 -0
- wxo_agentic_evaluation/data_annotator.py +45 -19
- wxo_agentic_evaluation/description_quality_checker.py +178 -0
- wxo_agentic_evaluation/evaluation.py +50 -0
- wxo_agentic_evaluation/evaluation_controller/evaluation_controller.py +303 -0
- wxo_agentic_evaluation/evaluation_package.py +544 -107
- wxo_agentic_evaluation/external_agent/__init__.py +18 -7
- wxo_agentic_evaluation/external_agent/external_validate.py +49 -36
- wxo_agentic_evaluation/external_agent/performance_test.py +33 -22
- wxo_agentic_evaluation/external_agent/types.py +8 -7
- wxo_agentic_evaluation/extractors/__init__.py +3 -0
- wxo_agentic_evaluation/extractors/extractor_base.py +21 -0
- wxo_agentic_evaluation/extractors/labeled_messages.py +47 -0
- wxo_agentic_evaluation/hr_agent_langgraph.py +68 -0
- wxo_agentic_evaluation/langfuse_collection.py +60 -0
- wxo_agentic_evaluation/langfuse_evaluation_package.py +192 -0
- wxo_agentic_evaluation/llm_matching.py +108 -5
- wxo_agentic_evaluation/llm_rag_eval.py +7 -4
- wxo_agentic_evaluation/llm_safety_eval.py +64 -0
- wxo_agentic_evaluation/llm_user.py +12 -6
- wxo_agentic_evaluation/llm_user_v2.py +114 -0
- wxo_agentic_evaluation/main.py +128 -246
- wxo_agentic_evaluation/metrics/__init__.py +15 -0
- wxo_agentic_evaluation/metrics/dummy_metric.py +16 -0
- wxo_agentic_evaluation/metrics/evaluations.py +107 -0
- wxo_agentic_evaluation/metrics/journey_success.py +137 -0
- wxo_agentic_evaluation/metrics/llm_as_judge.py +28 -2
- wxo_agentic_evaluation/metrics/metrics.py +319 -16
- wxo_agentic_evaluation/metrics/tool_calling.py +93 -0
- wxo_agentic_evaluation/otel_parser/__init__.py +1 -0
- wxo_agentic_evaluation/otel_parser/langflow_parser.py +86 -0
- wxo_agentic_evaluation/otel_parser/langgraph_parser.py +61 -0
- wxo_agentic_evaluation/otel_parser/parser.py +163 -0
- wxo_agentic_evaluation/otel_parser/parser_types.py +38 -0
- wxo_agentic_evaluation/otel_parser/pydantic_parser.py +50 -0
- wxo_agentic_evaluation/otel_parser/utils.py +15 -0
- wxo_agentic_evaluation/otel_parser/wxo_parser.py +39 -0
- wxo_agentic_evaluation/otel_support/evaluate_tau.py +101 -0
- wxo_agentic_evaluation/otel_support/otel_message_conversion.py +29 -0
- wxo_agentic_evaluation/otel_support/tasks_test.py +1566 -0
- wxo_agentic_evaluation/prompt/bad_tool_descriptions_prompt.jinja2 +178 -0
- wxo_agentic_evaluation/prompt/derailment_prompt.jinja2 +55 -0
- wxo_agentic_evaluation/prompt/llama_user_prompt.jinja2 +59 -5
- wxo_agentic_evaluation/prompt/llmaaj_prompt.jinja2 +15 -0
- wxo_agentic_evaluation/prompt/off_policy_attack_generation_prompt.jinja2 +34 -0
- wxo_agentic_evaluation/prompt/on_policy_attack_generation_prompt.jinja2 +46 -0
- wxo_agentic_evaluation/prompt/semantic_matching_prompt.jinja2 +41 -9
- wxo_agentic_evaluation/prompt/template_render.py +163 -12
- wxo_agentic_evaluation/prompt/unsafe_topic_prompt.jinja2 +65 -0
- wxo_agentic_evaluation/quick_eval.py +384 -0
- wxo_agentic_evaluation/record_chat.py +132 -81
- wxo_agentic_evaluation/red_teaming/attack_evaluator.py +302 -0
- wxo_agentic_evaluation/red_teaming/attack_generator.py +329 -0
- wxo_agentic_evaluation/red_teaming/attack_list.py +184 -0
- wxo_agentic_evaluation/red_teaming/attack_runner.py +204 -0
- wxo_agentic_evaluation/referenceless_eval/__init__.py +3 -0
- wxo_agentic_evaluation/referenceless_eval/function_calling/__init__.py +0 -0
- wxo_agentic_evaluation/referenceless_eval/function_calling/consts.py +28 -0
- wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/__init__.py +0 -0
- wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/base.py +29 -0
- wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_call/__init__.py +0 -0
- wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_call/general.py +49 -0
- wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_call/general_metrics.json +783 -0
- wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_call/general_metrics_runtime.json +580 -0
- wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_selection/__init__.py +0 -0
- wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_selection/function_selection.py +31 -0
- wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_selection/function_selection_metrics.json +600 -0
- wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_selection/function_selection_metrics_runtime.json +477 -0
- wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/loader.py +245 -0
- wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/__init__.py +0 -0
- wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/adapters.py +106 -0
- wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/pipeline.py +291 -0
- wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/semantic_checker.py +465 -0
- wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/static_checker.py +162 -0
- wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/transformation_prompts.py +509 -0
- wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/types.py +562 -0
- wxo_agentic_evaluation/referenceless_eval/metrics/__init__.py +3 -0
- wxo_agentic_evaluation/referenceless_eval/metrics/field.py +266 -0
- wxo_agentic_evaluation/referenceless_eval/metrics/metric.py +344 -0
- wxo_agentic_evaluation/referenceless_eval/metrics/metrics_runner.py +193 -0
- wxo_agentic_evaluation/referenceless_eval/metrics/prompt.py +413 -0
- wxo_agentic_evaluation/referenceless_eval/metrics/utils.py +46 -0
- wxo_agentic_evaluation/referenceless_eval/prompt/__init__.py +0 -0
- wxo_agentic_evaluation/referenceless_eval/prompt/runner.py +158 -0
- wxo_agentic_evaluation/referenceless_eval/referenceless_eval.py +191 -0
- wxo_agentic_evaluation/resource_map.py +6 -3
- wxo_agentic_evaluation/runner.py +329 -0
- wxo_agentic_evaluation/runtime_adapter/a2a_runtime_adapter.py +0 -0
- wxo_agentic_evaluation/runtime_adapter/runtime_adapter.py +14 -0
- wxo_agentic_evaluation/{inference_backend.py → runtime_adapter/wxo_runtime_adapter.py} +88 -150
- wxo_agentic_evaluation/scheduler.py +247 -0
- wxo_agentic_evaluation/service_instance.py +117 -26
- wxo_agentic_evaluation/service_provider/__init__.py +182 -17
- wxo_agentic_evaluation/service_provider/gateway_provider.py +707 -0
- wxo_agentic_evaluation/service_provider/model_proxy_provider.py +628 -45
- wxo_agentic_evaluation/service_provider/ollama_provider.py +392 -22
- wxo_agentic_evaluation/service_provider/portkey_provider.py +229 -0
- wxo_agentic_evaluation/service_provider/provider.py +129 -10
- wxo_agentic_evaluation/service_provider/referenceless_provider_wrapper.py +203 -0
- wxo_agentic_evaluation/service_provider/watsonx_provider.py +516 -53
- wxo_agentic_evaluation/simluation_runner.py +125 -0
- wxo_agentic_evaluation/test_prompt.py +4 -4
- wxo_agentic_evaluation/tool_planner.py +141 -46
- wxo_agentic_evaluation/type.py +217 -14
- wxo_agentic_evaluation/user_simulator/demo_usage_llm_user.py +100 -0
- wxo_agentic_evaluation/utils/__init__.py +44 -3
- wxo_agentic_evaluation/utils/evaluation_discovery.py +47 -0
- wxo_agentic_evaluation/utils/gateway_provider_utils.py +39 -0
- wxo_agentic_evaluation/utils/messages_parser.py +30 -0
- wxo_agentic_evaluation/utils/open_ai_tool_extractor.py +178 -0
- wxo_agentic_evaluation/utils/parsers.py +71 -0
- wxo_agentic_evaluation/utils/rich_utils.py +188 -0
- wxo_agentic_evaluation/utils/rouge_score.py +23 -0
- wxo_agentic_evaluation/utils/utils.py +514 -17
- wxo_agentic_evaluation/wxo_client.py +81 -0
- ibm_watsonx_orchestrate_evaluation_framework-1.0.3.dist-info/METADATA +0 -380
- ibm_watsonx_orchestrate_evaluation_framework-1.0.3.dist-info/RECORD +0 -56
- {ibm_watsonx_orchestrate_evaluation_framework-1.0.3.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.8b0.dist-info}/WHEEL +0 -0
- {ibm_watsonx_orchestrate_evaluation_framework-1.0.3.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.8b0.dist-info}/top_level.txt +0 -0
|
@@ -1,13 +1,23 @@
|
|
|
1
|
-
import os
|
|
2
|
-
import requests
|
|
3
|
-
import json
|
|
4
|
-
from types import MappingProxyType
|
|
5
|
-
from typing import List
|
|
6
1
|
import dataclasses
|
|
7
|
-
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
8
5
|
import time
|
|
9
|
-
|
|
6
|
+
import uuid
|
|
7
|
+
from threading import Lock
|
|
8
|
+
from types import MappingProxyType
|
|
9
|
+
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
|
|
10
|
+
|
|
11
|
+
import requests
|
|
12
|
+
|
|
13
|
+
from wxo_agentic_evaluation.service_provider.provider import (
|
|
14
|
+
ChatResult,
|
|
15
|
+
Provider,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
10
19
|
|
|
20
|
+
# IAM
|
|
11
21
|
ACCESS_URL = "https://iam.cloud.ibm.com/identity/token"
|
|
12
22
|
ACCESS_HEADER = {
|
|
13
23
|
"content-type": "application/x-www-form-urlencoded",
|
|
@@ -16,28 +26,83 @@ ACCESS_HEADER = {
|
|
|
16
26
|
|
|
17
27
|
YPQA_URL = "https://yp-qa.ml.cloud.ibm.com"
|
|
18
28
|
PROD_URL = "https://us-south.ml.cloud.ibm.com"
|
|
29
|
+
|
|
19
30
|
DEFAULT_PARAM = MappingProxyType(
|
|
20
31
|
{"min_new_tokens": 1, "decoding_method": "greedy", "max_new_tokens": 400}
|
|
21
32
|
)
|
|
22
33
|
|
|
23
34
|
|
|
35
|
+
def _truncate(value: Any, max_len: int = 1000) -> str:
|
|
36
|
+
if value is None:
|
|
37
|
+
return ""
|
|
38
|
+
s = str(value)
|
|
39
|
+
return (
|
|
40
|
+
s
|
|
41
|
+
if len(s) <= max_len
|
|
42
|
+
else s[:max_len] + f"... [truncated {len(s) - max_len} chars]"
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _translate_params_to_chat(
|
|
47
|
+
params: Dict[str, Any] = {},
|
|
48
|
+
) -> Dict[str, Any]:
|
|
49
|
+
"""
|
|
50
|
+
Translate legacy generation params to chat.completions params.
|
|
51
|
+
"""
|
|
52
|
+
translated_params: Dict[str, Any] = {}
|
|
53
|
+
|
|
54
|
+
if "max_new_tokens" in params:
|
|
55
|
+
translated_params["max_tokens"] = params["max_new_tokens"]
|
|
56
|
+
|
|
57
|
+
if params.get("decoding_method") == "greedy":
|
|
58
|
+
translated_params.setdefault("temperature", 0)
|
|
59
|
+
translated_params.setdefault("top_p", 1)
|
|
60
|
+
|
|
61
|
+
passthrough = {
|
|
62
|
+
"temperature",
|
|
63
|
+
"top_p",
|
|
64
|
+
"n",
|
|
65
|
+
"stream",
|
|
66
|
+
"stop",
|
|
67
|
+
"presence_penalty",
|
|
68
|
+
"frequency_penalty",
|
|
69
|
+
"logit_bias",
|
|
70
|
+
"user",
|
|
71
|
+
"seed",
|
|
72
|
+
"response_format",
|
|
73
|
+
}
|
|
74
|
+
for k in passthrough:
|
|
75
|
+
if k in params:
|
|
76
|
+
translated_params[k] = params[k]
|
|
77
|
+
|
|
78
|
+
return translated_params
|
|
79
|
+
|
|
80
|
+
|
|
24
81
|
class WatsonXProvider(Provider):
|
|
25
82
|
def __init__(
|
|
26
83
|
self,
|
|
27
|
-
model_id=None,
|
|
28
|
-
api_key=None,
|
|
29
|
-
space_id=None,
|
|
30
|
-
api_endpoint=PROD_URL,
|
|
31
|
-
url=ACCESS_URL,
|
|
32
|
-
timeout=60,
|
|
33
|
-
params=None,
|
|
34
|
-
embedding_model_id=None,
|
|
84
|
+
model_id: Optional[str] = None,
|
|
85
|
+
api_key: Optional[str] = None,
|
|
86
|
+
space_id: Optional[str] = None,
|
|
87
|
+
api_endpoint: str = PROD_URL,
|
|
88
|
+
url: str = ACCESS_URL,
|
|
89
|
+
timeout: int = 60,
|
|
90
|
+
params: Optional[Any] = None,
|
|
91
|
+
embedding_model_id: Optional[str] = None,
|
|
92
|
+
use_legacy_query: Optional[bool] = None,
|
|
93
|
+
system_prompt: Optional[str] = None,
|
|
94
|
+
token: Optional[str] = None,
|
|
95
|
+
instance_url: Optional[str] = None,
|
|
35
96
|
):
|
|
36
|
-
super().__init__()
|
|
97
|
+
super().__init__(use_legacy_query=use_legacy_query)
|
|
98
|
+
|
|
37
99
|
self.url = url
|
|
38
100
|
if (embedding_model_id is None) and (model_id is None):
|
|
39
|
-
raise Exception(
|
|
101
|
+
raise Exception(
|
|
102
|
+
"either model_id or embedding_model_id must be specified"
|
|
103
|
+
)
|
|
40
104
|
self.model_id = model_id
|
|
105
|
+
logger.info("[d b]Using inference model %s", self.model_id)
|
|
41
106
|
api_key = os.environ.get("WATSONX_APIKEY", api_key)
|
|
42
107
|
if not api_key:
|
|
43
108
|
raise Exception("apikey must be specified")
|
|
@@ -46,7 +111,7 @@ class WatsonXProvider(Provider):
|
|
|
46
111
|
"grant_type": "urn:ibm:params:oauth:grant-type:apikey",
|
|
47
112
|
"apikey": self.api_key,
|
|
48
113
|
}
|
|
49
|
-
self.api_endpoint = api_endpoint
|
|
114
|
+
self.api_endpoint = (api_endpoint or PROD_URL).rstrip("/")
|
|
50
115
|
space_id = os.environ.get("WATSONX_SPACE_ID", space_id)
|
|
51
116
|
if not space_id:
|
|
52
117
|
raise Exception("space id must be specified")
|
|
@@ -55,20 +120,32 @@ class WatsonXProvider(Provider):
|
|
|
55
120
|
self.embedding_model_id = embedding_model_id
|
|
56
121
|
self.lock = Lock()
|
|
57
122
|
|
|
58
|
-
self.params = params if params else DEFAULT_PARAM
|
|
59
|
-
|
|
123
|
+
self.params = params if params is not None else DEFAULT_PARAM
|
|
60
124
|
if isinstance(self.params, MappingProxyType):
|
|
61
125
|
self.params = dict(self.params)
|
|
62
126
|
if dataclasses.is_dataclass(self.params):
|
|
63
127
|
self.params = dataclasses.asdict(self.params)
|
|
64
128
|
|
|
129
|
+
self.system_prompt = system_prompt
|
|
130
|
+
|
|
65
131
|
self.refresh_time = None
|
|
66
132
|
self.access_token = None
|
|
67
133
|
self._refresh_token()
|
|
68
134
|
|
|
135
|
+
self.LEGACY_GEN_URL = (
|
|
136
|
+
f"{self.api_endpoint}/ml/v1/text/generation?version=2023-05-02"
|
|
137
|
+
)
|
|
138
|
+
self.CHAT_COMPLETIONS_URL = f"{self.api_endpoint}/ml/v1/text/chat"
|
|
139
|
+
self.EMBEDDINGS_URL = (
|
|
140
|
+
f"{self.api_endpoint}/ml/v1/text/embeddings?version=2023-10-25"
|
|
141
|
+
)
|
|
142
|
+
|
|
69
143
|
def _get_access_token(self):
|
|
70
144
|
response = requests.post(
|
|
71
|
-
self.url,
|
|
145
|
+
self.url,
|
|
146
|
+
headers=ACCESS_HEADER,
|
|
147
|
+
data=self.access_data,
|
|
148
|
+
timeout=self.timeout,
|
|
72
149
|
)
|
|
73
150
|
if response.status_code == 200:
|
|
74
151
|
token_data = json.loads(response.text)
|
|
@@ -80,64 +157,450 @@ class WatsonXProvider(Provider):
|
|
|
80
157
|
return token, refresh_time
|
|
81
158
|
|
|
82
159
|
raise RuntimeError(
|
|
83
|
-
f"
|
|
160
|
+
f"Try to acquire access token and get {response.status_code}. Reason: {response.text} "
|
|
84
161
|
)
|
|
85
162
|
|
|
86
163
|
def prepare_header(self):
|
|
87
|
-
headers = {
|
|
88
|
-
|
|
164
|
+
headers = {
|
|
165
|
+
"Authorization": f"Bearer {self.access_token}",
|
|
166
|
+
"Content-Type": "application/json",
|
|
167
|
+
}
|
|
89
168
|
return headers
|
|
90
169
|
|
|
91
|
-
def generate(self, sentence: str):
|
|
92
|
-
headers = self.prepare_header()
|
|
93
|
-
|
|
94
|
-
data = {"model_id": self.model_id, "input": sentence,
|
|
95
|
-
"parameters": self.params, "space_id": self.space_id}
|
|
96
|
-
generation_url = f"{self.api_endpoint}/ml/v1/text/generation?version=2023-05-02"
|
|
97
|
-
resp = requests.post(url=generation_url, headers=headers, json=data)
|
|
98
|
-
if resp.status_code == 200:
|
|
99
|
-
return resp.json()["results"][0]
|
|
100
|
-
else:
|
|
101
|
-
resp.raise_for_status()
|
|
102
|
-
|
|
103
170
|
def _refresh_token(self):
|
|
104
171
|
# if we do not have a token or the current timestamp is 9 minutes away from expire.
|
|
105
172
|
if not self.access_token or time.time() > self.refresh_time:
|
|
106
173
|
with self.lock:
|
|
107
174
|
if not self.access_token or time.time() > self.refresh_time:
|
|
108
|
-
|
|
175
|
+
(
|
|
176
|
+
self.access_token,
|
|
177
|
+
self.refresh_time,
|
|
178
|
+
) = self._get_access_token()
|
|
179
|
+
|
|
180
|
+
def old_query(self, sentence: Union[str, Mapping[str, str]]) -> str:
|
|
181
|
+
"""
|
|
182
|
+
Legacy /ml/v1/text/generation
|
|
183
|
+
"""
|
|
184
|
+
if self.model_id is None:
|
|
185
|
+
raise Exception("model id must be specified for text generation")
|
|
186
|
+
|
|
187
|
+
self._refresh_token()
|
|
188
|
+
headers = self.prepare_header()
|
|
189
|
+
|
|
190
|
+
payload: Dict[str, Any] = {
|
|
191
|
+
"model_id": self.model_id,
|
|
192
|
+
"input": sentence,
|
|
193
|
+
"parameters": self.params or {},
|
|
194
|
+
"space_id": self.space_id,
|
|
195
|
+
}
|
|
196
|
+
|
|
197
|
+
request_id = str(uuid.uuid4())
|
|
198
|
+
t0 = time.time()
|
|
199
|
+
|
|
200
|
+
logger.debug(
|
|
201
|
+
"[d][b]Sending text.generation request | request_id=%s url=%s model=%s space_id=%s params=%s input_preview=%s",
|
|
202
|
+
request_id,
|
|
203
|
+
self.LEGACY_GEN_URL,
|
|
204
|
+
self.model_id,
|
|
205
|
+
self.space_id,
|
|
206
|
+
json.dumps(
|
|
207
|
+
payload.get("parameters", {}),
|
|
208
|
+
sort_keys=True,
|
|
209
|
+
ensure_ascii=False,
|
|
210
|
+
),
|
|
211
|
+
_truncate(sentence, 200),
|
|
212
|
+
)
|
|
213
|
+
|
|
214
|
+
resp = None
|
|
215
|
+
try:
|
|
216
|
+
resp = requests.post(
|
|
217
|
+
url=self.LEGACY_GEN_URL,
|
|
218
|
+
headers=headers,
|
|
219
|
+
json=payload,
|
|
220
|
+
timeout=self.timeout,
|
|
221
|
+
)
|
|
222
|
+
duration_ms = int((time.time() - t0) * 1000)
|
|
223
|
+
resp.raise_for_status()
|
|
224
|
+
data = resp.json()
|
|
225
|
+
|
|
226
|
+
if isinstance(data, dict) and "results" in data and data["results"]:
|
|
227
|
+
result = data["results"][0]
|
|
228
|
+
elif isinstance(data, dict):
|
|
229
|
+
result = data
|
|
230
|
+
else:
|
|
231
|
+
raise ValueError(
|
|
232
|
+
f"Unexpected response type from WatsonX: {type(data)}"
|
|
233
|
+
)
|
|
234
|
+
|
|
235
|
+
output_text = ""
|
|
236
|
+
if isinstance(result, dict):
|
|
237
|
+
output_text = (
|
|
238
|
+
result.get("generated_text") or result.get("message") or ""
|
|
239
|
+
)
|
|
240
|
+
|
|
241
|
+
usage = data.get("usage") or {}
|
|
242
|
+
if not usage and isinstance(result, dict):
|
|
243
|
+
in_tok = result.get("input_token_count")
|
|
244
|
+
out_tok = result.get("generated_token_count") or result.get(
|
|
245
|
+
"output_token_count"
|
|
246
|
+
)
|
|
247
|
+
if in_tok is not None or out_tok is not None:
|
|
248
|
+
usage = {
|
|
249
|
+
"prompt_tokens": in_tok,
|
|
250
|
+
"completion_tokens": out_tok,
|
|
251
|
+
"total_tokens": (in_tok or 0) + (out_tok or 0),
|
|
252
|
+
}
|
|
253
|
+
|
|
254
|
+
api_request_id = resp.headers.get(
|
|
255
|
+
"x-request-id"
|
|
256
|
+
) or resp.headers.get("request-id")
|
|
257
|
+
|
|
258
|
+
logger.debug(
|
|
259
|
+
"[d][b]text.generation response received | request_id=%s status_code=%s duration_ms=%s usage=%s output_preview=%s api_request_id=%s",
|
|
260
|
+
request_id,
|
|
261
|
+
resp.status_code,
|
|
262
|
+
duration_ms,
|
|
263
|
+
json.dumps(usage, sort_keys=True, ensure_ascii=False),
|
|
264
|
+
_truncate(output_text, 2000),
|
|
265
|
+
api_request_id,
|
|
266
|
+
)
|
|
267
|
+
|
|
268
|
+
if output_text:
|
|
269
|
+
return output_text
|
|
270
|
+
raise ValueError(
|
|
271
|
+
f"Unexpected response from legacy endpoint: {data}"
|
|
272
|
+
)
|
|
273
|
+
|
|
274
|
+
except Exception as e:
|
|
275
|
+
duration_ms = int((time.time() - t0) * 1000)
|
|
276
|
+
status_code = getattr(resp, "status_code", None)
|
|
277
|
+
resp_text_preview = (
|
|
278
|
+
_truncate(getattr(resp, "text", None), 2000)
|
|
279
|
+
if resp is not None
|
|
280
|
+
else None
|
|
281
|
+
)
|
|
282
|
+
|
|
283
|
+
logger.exception(
|
|
284
|
+
"text.generation request failed | request_id=%s status_code=%s duration_ms=%s response_text_preview=%s",
|
|
285
|
+
request_id,
|
|
286
|
+
status_code,
|
|
287
|
+
duration_ms,
|
|
288
|
+
resp_text_preview,
|
|
289
|
+
)
|
|
290
|
+
with self.lock:
|
|
291
|
+
if (
|
|
292
|
+
"authentication_token_expired" in str(e)
|
|
293
|
+
or status_code == 401
|
|
294
|
+
):
|
|
295
|
+
try:
|
|
296
|
+
self.access_token, self.refresh_time = (
|
|
297
|
+
self._get_access_token()
|
|
298
|
+
)
|
|
299
|
+
except Exception:
|
|
300
|
+
pass
|
|
301
|
+
raise
|
|
109
302
|
|
|
110
|
-
def
|
|
303
|
+
def new_query(self, sentence: str) -> str:
|
|
304
|
+
"""
|
|
305
|
+
/ml/v1/text/chat
|
|
306
|
+
Returns assistant content as a plain string.
|
|
307
|
+
"""
|
|
111
308
|
if self.model_id is None:
|
|
112
309
|
raise Exception("model id must be specified for text generation")
|
|
310
|
+
|
|
311
|
+
self._refresh_token()
|
|
312
|
+
headers = self.prepare_header()
|
|
313
|
+
|
|
314
|
+
messages: List[Dict[str, Any]] = []
|
|
315
|
+
if getattr(self, "system_prompt", None):
|
|
316
|
+
messages.append({"role": "system", "content": self.system_prompt})
|
|
317
|
+
messages.append(
|
|
318
|
+
{
|
|
319
|
+
"role": "user",
|
|
320
|
+
"content": [
|
|
321
|
+
{
|
|
322
|
+
"type": "text",
|
|
323
|
+
"text": sentence,
|
|
324
|
+
}
|
|
325
|
+
],
|
|
326
|
+
}
|
|
327
|
+
)
|
|
328
|
+
|
|
329
|
+
chat_params = _translate_params_to_chat(self.params)
|
|
330
|
+
if "time_limit" in self.params:
|
|
331
|
+
chat_params["time_limit"] = self.params["time_limit"]
|
|
332
|
+
|
|
333
|
+
payload: Dict[str, Any] = {
|
|
334
|
+
"model_id": self.model_id,
|
|
335
|
+
"space_id": self.space_id,
|
|
336
|
+
"messages": messages,
|
|
337
|
+
**chat_params,
|
|
338
|
+
}
|
|
339
|
+
|
|
340
|
+
url = f"{self.CHAT_COMPLETIONS_URL}?version=2024-10-08"
|
|
341
|
+
request_id = str(uuid.uuid4())
|
|
342
|
+
t0 = time.time()
|
|
343
|
+
|
|
344
|
+
logger.debug(
|
|
345
|
+
"[d][b]Sending chat.completions request | request_id=%s url=%s model=%s space_id=%s params=%s input_preview=%s",
|
|
346
|
+
request_id,
|
|
347
|
+
url,
|
|
348
|
+
self.model_id,
|
|
349
|
+
self.space_id,
|
|
350
|
+
json.dumps(chat_params, sort_keys=True, ensure_ascii=False),
|
|
351
|
+
_truncate(sentence, 200),
|
|
352
|
+
)
|
|
353
|
+
|
|
354
|
+
resp = None
|
|
113
355
|
try:
|
|
114
|
-
|
|
356
|
+
resp = requests.post(
|
|
357
|
+
url=url, headers=headers, json=payload, timeout=self.timeout
|
|
358
|
+
)
|
|
359
|
+
duration_ms = int((time.time() - t0) * 1000)
|
|
360
|
+
resp.raise_for_status()
|
|
361
|
+
data = resp.json()
|
|
362
|
+
|
|
363
|
+
choice = data["choices"][0]
|
|
364
|
+
content = choice["message"]["content"]
|
|
365
|
+
finish_reason = choice.get("finish_reason")
|
|
366
|
+
usage = data.get("usage", {})
|
|
367
|
+
api_request_id = resp.headers.get(
|
|
368
|
+
"x-request-id"
|
|
369
|
+
) or resp.headers.get("request-id")
|
|
370
|
+
|
|
371
|
+
logger.debug(
|
|
372
|
+
"[d][b]chat.completions response received | request_id=%s status_code=%s duration_ms=%s finish_reason=%s usage=%s output_preview=%s api_request_id=%s",
|
|
373
|
+
request_id,
|
|
374
|
+
resp.status_code,
|
|
375
|
+
duration_ms,
|
|
376
|
+
finish_reason,
|
|
377
|
+
json.dumps(usage, sort_keys=True, ensure_ascii=False),
|
|
378
|
+
_truncate(content, 2000),
|
|
379
|
+
api_request_id,
|
|
380
|
+
)
|
|
381
|
+
|
|
382
|
+
return content
|
|
383
|
+
|
|
115
384
|
except Exception as e:
|
|
385
|
+
duration_ms = int((time.time() - t0) * 1000)
|
|
386
|
+
status_code = getattr(resp, "status_code", None)
|
|
387
|
+
resp_text_preview = (
|
|
388
|
+
_truncate(getattr(resp, "text", None), 2000)
|
|
389
|
+
if resp is not None
|
|
390
|
+
else None
|
|
391
|
+
)
|
|
392
|
+
|
|
393
|
+
logger.exception(
|
|
394
|
+
"chat.completions request failed | request_id=%s status_code=%s duration_ms=%s response_text_preview=%s",
|
|
395
|
+
request_id,
|
|
396
|
+
status_code,
|
|
397
|
+
duration_ms,
|
|
398
|
+
resp_text_preview,
|
|
399
|
+
)
|
|
116
400
|
with self.lock:
|
|
117
|
-
if
|
|
118
|
-
|
|
119
|
-
|
|
401
|
+
if (
|
|
402
|
+
"authentication_token_expired" in str(e)
|
|
403
|
+
or status_code == 401
|
|
404
|
+
):
|
|
405
|
+
try:
|
|
406
|
+
self.access_token, self.refresh_time = (
|
|
407
|
+
self._get_access_token()
|
|
408
|
+
)
|
|
409
|
+
except Exception:
|
|
410
|
+
pass
|
|
411
|
+
raise
|
|
412
|
+
|
|
413
|
+
def chat(
|
|
414
|
+
self,
|
|
415
|
+
messages: Sequence[Dict[str, str]],
|
|
416
|
+
params: Optional[Dict[str, Any]] = None,
|
|
417
|
+
) -> ChatResult:
|
|
418
|
+
"""
|
|
419
|
+
Sends a multi-message chat request to /ml/v1/text/chat
|
|
420
|
+
Returns ChatResult with text, usage, finish_reason, and raw response.
|
|
421
|
+
"""
|
|
422
|
+
if self.model_id is None:
|
|
423
|
+
raise Exception("model id must be specified for chat")
|
|
424
|
+
|
|
425
|
+
self._refresh_token()
|
|
426
|
+
headers = self.prepare_header()
|
|
427
|
+
|
|
428
|
+
wx_messages: List[Dict[str, Any]] = []
|
|
429
|
+
for m in messages:
|
|
430
|
+
role = m.get("role")
|
|
431
|
+
content = m.get("content", "")
|
|
432
|
+
if role == "user" and isinstance(content, str):
|
|
433
|
+
wx_messages.append(
|
|
434
|
+
{
|
|
435
|
+
"role": "user",
|
|
436
|
+
"content": [{"type": "text", "text": content}],
|
|
437
|
+
}
|
|
438
|
+
)
|
|
439
|
+
else:
|
|
440
|
+
wx_messages.append({"role": role, "content": content})
|
|
120
441
|
|
|
121
|
-
|
|
122
|
-
|
|
442
|
+
merged_params = dict(self.params or {})
|
|
443
|
+
if params:
|
|
444
|
+
merged_params.update(params)
|
|
445
|
+
chat_params = _translate_params_to_chat(merged_params)
|
|
446
|
+
chat_params.pop("stream", None)
|
|
447
|
+
if "time_limit" in merged_params:
|
|
448
|
+
chat_params["time_limit"] = merged_params["time_limit"]
|
|
449
|
+
|
|
450
|
+
payload: Dict[str, Any] = {
|
|
451
|
+
"model_id": self.model_id,
|
|
452
|
+
"space_id": self.space_id,
|
|
453
|
+
"messages": wx_messages,
|
|
454
|
+
**chat_params,
|
|
455
|
+
}
|
|
456
|
+
|
|
457
|
+
url = f"{self.CHAT_COMPLETIONS_URL}?version=2024-10-08"
|
|
458
|
+
request_id = str(uuid.uuid4())
|
|
459
|
+
t0 = time.time()
|
|
460
|
+
|
|
461
|
+
last_user = next(
|
|
462
|
+
(
|
|
463
|
+
m.get("content", "")
|
|
464
|
+
for m in reversed(messages)
|
|
465
|
+
if m.get("role") == "user"
|
|
466
|
+
),
|
|
467
|
+
"",
|
|
468
|
+
)
|
|
469
|
+
logger.debug(
|
|
470
|
+
"[d][b]Sending chat.completions request (non-streaming) | request_id=%s url=%s model=%s space_id=%s params=%s input_preview=%s",
|
|
471
|
+
request_id,
|
|
472
|
+
url,
|
|
473
|
+
self.model_id,
|
|
474
|
+
self.space_id,
|
|
475
|
+
json.dumps(chat_params, sort_keys=True, ensure_ascii=False),
|
|
476
|
+
_truncate(last_user, 200),
|
|
477
|
+
)
|
|
478
|
+
|
|
479
|
+
resp = None
|
|
480
|
+
try:
|
|
481
|
+
resp = requests.post(
|
|
482
|
+
url=url, headers=headers, json=payload, timeout=self.timeout
|
|
483
|
+
)
|
|
484
|
+
duration_ms = int((time.time() - t0) * 1000)
|
|
485
|
+
resp.raise_for_status()
|
|
486
|
+
data = resp.json()
|
|
487
|
+
|
|
488
|
+
choice = data["choices"][0]
|
|
489
|
+
content = choice["message"]["content"]
|
|
490
|
+
finish_reason = choice.get("finish_reason")
|
|
491
|
+
usage = data.get("usage", {})
|
|
492
|
+
api_request_id = resp.headers.get(
|
|
493
|
+
"x-request-id"
|
|
494
|
+
) or resp.headers.get("request-id")
|
|
495
|
+
|
|
496
|
+
logger.debug(
|
|
497
|
+
"[d][b]chat.completions response received (non-streaming) | request_id=%s status_code=%s duration_ms=%s finish_reason=%s usage=%s output_preview=%s api_request_id=%s",
|
|
498
|
+
request_id,
|
|
499
|
+
resp.status_code,
|
|
500
|
+
duration_ms,
|
|
501
|
+
finish_reason,
|
|
502
|
+
json.dumps(usage, sort_keys=True, ensure_ascii=False),
|
|
503
|
+
_truncate(content, 2000),
|
|
504
|
+
api_request_id,
|
|
505
|
+
)
|
|
506
|
+
|
|
507
|
+
return ChatResult(
|
|
508
|
+
text=content, usage=usage, finish_reason=finish_reason, raw=data
|
|
509
|
+
)
|
|
510
|
+
|
|
511
|
+
except Exception as e:
|
|
512
|
+
duration_ms = int((time.time() - t0) * 1000)
|
|
513
|
+
status_code = getattr(resp, "status_code", None)
|
|
514
|
+
resp_text_preview = (
|
|
515
|
+
_truncate(getattr(resp, "text", None), 2000)
|
|
516
|
+
if resp is not None
|
|
517
|
+
else None
|
|
518
|
+
)
|
|
519
|
+
|
|
520
|
+
logger.exception(
|
|
521
|
+
"chat.completions request failed (non-streaming) | request_id=%s status_code=%s duration_ms=%s response_text_preview=%s",
|
|
522
|
+
request_id,
|
|
523
|
+
status_code,
|
|
524
|
+
duration_ms,
|
|
525
|
+
resp_text_preview,
|
|
526
|
+
)
|
|
527
|
+
with self.lock:
|
|
528
|
+
if (
|
|
529
|
+
"authentication_token_expired" in str(e)
|
|
530
|
+
or status_code == 401
|
|
531
|
+
):
|
|
532
|
+
try:
|
|
533
|
+
self.access_token, self.refresh_time = (
|
|
534
|
+
self._get_access_token()
|
|
535
|
+
)
|
|
536
|
+
except Exception:
|
|
537
|
+
pass
|
|
538
|
+
raise
|
|
123
539
|
|
|
124
540
|
def encode(self, sentences: List[str]) -> List[list]:
|
|
125
541
|
if self.embedding_model_id is None:
|
|
126
|
-
raise Exception(
|
|
542
|
+
raise Exception(
|
|
543
|
+
"embedding model id must be specified for text encoding"
|
|
544
|
+
)
|
|
127
545
|
|
|
546
|
+
self._refresh_token()
|
|
128
547
|
headers = self.prepare_header()
|
|
129
|
-
url = f"{self.api_endpoint}/ml/v1/text/embeddings?version=2023-10-25"
|
|
130
548
|
|
|
131
|
-
|
|
132
|
-
|
|
549
|
+
# Minimal logging for embeddings
|
|
550
|
+
request_id = str(uuid.uuid4())
|
|
551
|
+
t0 = time.time()
|
|
552
|
+
logger.debug(
|
|
553
|
+
"[d][b]Sending embeddings request | request_id=%s url=%s model=%s space_id=%s num_inputs=%s",
|
|
554
|
+
request_id,
|
|
555
|
+
self.EMBEDDINGS_URL,
|
|
556
|
+
self.embedding_model_id,
|
|
557
|
+
self.space_id,
|
|
558
|
+
len(sentences),
|
|
559
|
+
)
|
|
560
|
+
|
|
561
|
+
payload = {
|
|
562
|
+
"inputs": sentences,
|
|
563
|
+
"model_id": self.embedding_model_id,
|
|
564
|
+
"space_id": self.space_id,
|
|
565
|
+
}
|
|
566
|
+
|
|
567
|
+
resp = requests.post(
|
|
568
|
+
url=self.EMBEDDINGS_URL,
|
|
569
|
+
headers=headers,
|
|
570
|
+
json=payload,
|
|
571
|
+
timeout=self.timeout,
|
|
572
|
+
)
|
|
573
|
+
duration_ms = int((time.time() - t0) * 1000)
|
|
574
|
+
|
|
133
575
|
if resp.status_code == 200:
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
576
|
+
data = resp.json()
|
|
577
|
+
vectors = [entry["embedding"] for entry in data["results"]]
|
|
578
|
+
logger.debug(
|
|
579
|
+
"[d][b]Embeddings response received | request_id=%s status_code=%s duration_ms=%s num_vectors=%s",
|
|
580
|
+
request_id,
|
|
581
|
+
resp.status_code,
|
|
582
|
+
duration_ms,
|
|
583
|
+
len(vectors),
|
|
584
|
+
)
|
|
585
|
+
return vectors
|
|
586
|
+
|
|
587
|
+
logger.error(
|
|
588
|
+
"[d b red]Embeddings request failed | request_id=%s status_code=%s duration_ms=%s response_text_preview=%s",
|
|
589
|
+
request_id,
|
|
590
|
+
resp.status_code,
|
|
591
|
+
duration_ms,
|
|
592
|
+
_truncate(resp.text, 2000),
|
|
593
|
+
)
|
|
594
|
+
resp.raise_for_status()
|
|
137
595
|
|
|
138
596
|
|
|
139
597
|
if __name__ == "__main__":
|
|
140
|
-
|
|
598
|
+
|
|
599
|
+
provider = WatsonXProvider(
|
|
600
|
+
model_id="meta-llama/llama-3-2-90b-vision-instruct",
|
|
601
|
+
use_legacy_query=False, # set True to use legacy endpoint
|
|
602
|
+
system_prompt="You are a helpful assistant.",
|
|
603
|
+
)
|
|
141
604
|
|
|
142
605
|
prompt = """
|
|
143
606
|
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
|
@@ -169,4 +632,4 @@ Usernwaters did not take anytime off during the period<|eot_id|>
|
|
|
169
632
|
<|eot_id|><|start_header_id|>user<|end_header_id|>
|
|
170
633
|
"""
|
|
171
634
|
|
|
172
|
-
print(provider.query(prompt))
|
|
635
|
+
print(provider.query(prompt))
|