ibm-watsonx-orchestrate-evaluation-framework 1.1.6__py3-none-any.whl → 1.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ibm-watsonx-orchestrate-evaluation-framework might be problematic. Click here for more details.
- {ibm_watsonx_orchestrate_evaluation_framework-1.1.6.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.7.dist-info}/METADATA +4 -1
- {ibm_watsonx_orchestrate_evaluation_framework-1.1.6.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.7.dist-info}/RECORD +42 -36
- wxo_agentic_evaluation/analyze_run.py +49 -32
- wxo_agentic_evaluation/arg_configs.py +30 -2
- wxo_agentic_evaluation/data_annotator.py +22 -4
- wxo_agentic_evaluation/description_quality_checker.py +20 -4
- wxo_agentic_evaluation/evaluation_package.py +189 -15
- wxo_agentic_evaluation/external_agent/external_validate.py +3 -1
- wxo_agentic_evaluation/external_agent/types.py +1 -1
- wxo_agentic_evaluation/inference_backend.py +64 -34
- wxo_agentic_evaluation/llm_matching.py +92 -2
- wxo_agentic_evaluation/llm_user.py +2 -2
- wxo_agentic_evaluation/main.py +147 -38
- wxo_agentic_evaluation/metrics/__init__.py +5 -1
- wxo_agentic_evaluation/metrics/evaluations.py +124 -0
- wxo_agentic_evaluation/metrics/metrics.py +24 -3
- wxo_agentic_evaluation/prompt/llmaaj_prompt.jinja2 +15 -0
- wxo_agentic_evaluation/prompt/template_render.py +16 -0
- wxo_agentic_evaluation/quick_eval.py +17 -3
- wxo_agentic_evaluation/record_chat.py +17 -6
- wxo_agentic_evaluation/red_teaming/attack_evaluator.py +44 -14
- wxo_agentic_evaluation/red_teaming/attack_generator.py +31 -12
- wxo_agentic_evaluation/red_teaming/attack_list.py +23 -24
- wxo_agentic_evaluation/red_teaming/attack_runner.py +36 -19
- wxo_agentic_evaluation/referenceless_eval/referenceless_eval.py +42 -16
- wxo_agentic_evaluation/service_instance.py +5 -3
- wxo_agentic_evaluation/service_provider/__init__.py +129 -9
- wxo_agentic_evaluation/service_provider/gateway_provider.py +707 -0
- wxo_agentic_evaluation/service_provider/model_proxy_provider.py +415 -17
- wxo_agentic_evaluation/service_provider/ollama_provider.py +393 -22
- wxo_agentic_evaluation/service_provider/provider.py +130 -10
- wxo_agentic_evaluation/service_provider/referenceless_provider_wrapper.py +52 -0
- wxo_agentic_evaluation/service_provider/watsonx_provider.py +480 -52
- wxo_agentic_evaluation/type.py +14 -4
- wxo_agentic_evaluation/utils/__init__.py +43 -5
- wxo_agentic_evaluation/utils/evaluation_discovery.py +47 -0
- wxo_agentic_evaluation/utils/gateway_provider_utils.py +39 -0
- wxo_agentic_evaluation/utils/messages_parser.py +30 -0
- wxo_agentic_evaluation/utils/utils.py +14 -9
- wxo_agentic_evaluation/wxo_client.py +2 -1
- {ibm_watsonx_orchestrate_evaluation_framework-1.1.6.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.7.dist-info}/WHEEL +0 -0
- {ibm_watsonx_orchestrate_evaluation_framework-1.1.6.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.7.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,707 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
import time
|
|
7
|
+
import uuid
|
|
8
|
+
from threading import Lock
|
|
9
|
+
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
|
10
|
+
|
|
11
|
+
import requests
|
|
12
|
+
|
|
13
|
+
from wxo_agentic_evaluation.service_instance import tenant_setup
|
|
14
|
+
from wxo_agentic_evaluation.service_provider.provider import (
|
|
15
|
+
ChatResult,
|
|
16
|
+
Provider,
|
|
17
|
+
)
|
|
18
|
+
from wxo_agentic_evaluation.utils.utils import is_ibm_cloud_url
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)

# Token endpoint used by the "saas" auth mode (API-key exchange).
AUTH_ENDPOINT_AWS = "https://iam.platform.saas.ibm.com/siusermgr/api/1.0/apikeys/token"
# Token endpoint used by the "ibm_iam" auth mode.
AUTH_ENDPOINT_IBM_CLOUD = "https://iam.cloud.ibm.com/identity/token"

# Generation parameters applied when a provider is created without params.
DEFAULT_PARAM = dict(
    min_new_tokens=1,
    decoding_method="greedy",
)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _truncate(value: Any, max_len: int = 1000) -> str:
|
|
34
|
+
if value is None:
|
|
35
|
+
return ""
|
|
36
|
+
s = str(value)
|
|
37
|
+
return (
|
|
38
|
+
s
|
|
39
|
+
if len(s) <= max_len
|
|
40
|
+
else s[:max_len] + f"... [truncated {len(s) - max_len} chars]"
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _translate_params_to_chat(params: Dict[str, Any] | None) -> Dict[str, Any]:
|
|
45
|
+
# Translate legacy generation params to chat.completions params.
|
|
46
|
+
p = params or {}
|
|
47
|
+
out: Dict[str, Any] = {}
|
|
48
|
+
|
|
49
|
+
passthrough = {
|
|
50
|
+
"temperature",
|
|
51
|
+
"top_p",
|
|
52
|
+
"n",
|
|
53
|
+
"stream",
|
|
54
|
+
"stop",
|
|
55
|
+
"presence_penalty",
|
|
56
|
+
"frequency_penalty",
|
|
57
|
+
"logit_bias",
|
|
58
|
+
"user",
|
|
59
|
+
# "max_tokens", #reasoning frequently uses up max_tokens so not passing for now
|
|
60
|
+
"seed",
|
|
61
|
+
"response_format",
|
|
62
|
+
}
|
|
63
|
+
for k in passthrough:
|
|
64
|
+
if k in p:
|
|
65
|
+
out[k] = p[k]
|
|
66
|
+
|
|
67
|
+
# reasoning frequently uses up max_tokens so not passing for now
|
|
68
|
+
# if "max_new_tokens" in p and "max_completion_tokens" not in out:
|
|
69
|
+
# out["max_completion_tokens"] = p["max_new_tokens"]
|
|
70
|
+
|
|
71
|
+
return out
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _infer_cpd_auth_url(instance_url: str) -> str:
|
|
75
|
+
inst = (instance_url or "").rstrip("/")
|
|
76
|
+
if not inst:
|
|
77
|
+
return "/icp4d-api/v1/authorize"
|
|
78
|
+
if "/orchestrate" in inst:
|
|
79
|
+
base = inst.split("/orchestrate", 1)[0].rstrip("/")
|
|
80
|
+
return base + "/icp4d-api/v1/authorize"
|
|
81
|
+
return inst + "/icp4d-api/v1/authorize"
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _normalize_cpd_auth_url(url: str) -> str:
|
|
85
|
+
u = (url or "").rstrip("/")
|
|
86
|
+
if u.endswith("/icp4d-api"):
|
|
87
|
+
return u + "/v1/authorize"
|
|
88
|
+
return url
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class GatewayProvider(Provider):
    """Provider that talks to the watsonx Orchestrate model gateway.

    Chat requests go to ``<instance_url><chat path>`` and embedding requests
    to ``<instance_url><embeddings path>``. Both carry a bearer token and an
    ``x-gateway-config`` header describing the upstream target.

    Authentication is one of:

    * a static bearer token (``token`` argument, overridden by ``WO_TOKEN``),
      which is never refreshed, or
    * a token obtained by exchanging credentials at an endpoint chosen by
      :meth:`_resolve_auth_mode_and_url` (``"cpd"``, ``"ibm_iam"`` or
      ``"saas"``) and refreshed once its refresh time has passed.

    Environment variables win over the matching argument for ``WO_INSTANCE``,
    ``MODEL_OVERRIDE``, ``WO_API_KEY`` and ``WO_TOKEN``. Arguments win over
    ``GATEWAY_CHAT_PATH``, ``GATEWAY_EMBEDDINGS_PATH``, ``GATEWAY_PROVIDER``
    and ``GATEWAY_API_KEY_LABEL``. Also read: ``WO_USERNAME``,
    ``WO_PASSWORD``, ``WO_AUTH_TYPE``, ``AUTHORIZATION_URL``,
    ``WO_SSL_VERIFY``, ``WATSONX_SPACE_ID``, ``GATEWAY_MODEL_PREFIX`` and
    ``TOKEN_DEFAULT_EXPIRES_IN``.
    """

    def __init__(
        self,
        model_id: Optional[str] = None,
        api_key: Optional[str] = None,
        instance_url: Optional[str] = None,
        timeout: int = 300,
        embedding_model_id: Optional[str] = None,
        params: Optional[Dict[str, Any]] = None,
        use_legacy_query: Optional[bool] = None,
        system_prompt: Optional[str] = None,
        chat_path: Optional[str] = None,
        embeddings_path: Optional[str] = None,
        gateway_provider: Optional[str] = None,
        gateway_api_key_label: Optional[str] = None,
        x_gateway_config: Optional[Dict[str, Any]] = None,
        # Static bearer token; WO_TOKEN takes precedence when set.
        token: Optional[str] = None,
    ):
        """Resolve configuration, validate credentials and obtain a token.

        Raises:
            RuntimeError: if no instance URL can be determined, or the
                credentials required by the resolved auth mode are missing.
        """
        super().__init__(use_legacy_query=use_legacy_query)

        instance_url = os.environ.get("WO_INSTANCE", instance_url)
        if not instance_url:
            # tenant_setup also returns a token, used below as the static
            # token unless WO_TOKEN is set.
            logger.info("[d b]Gateway provider defaulting to local tenant")
            token, instance_url, _ = tenant_setup(
                service_url=None, tenant_name="local"
            )
        if not instance_url:
            raise RuntimeError(
                "instance url must be specified for gateway provider"
            )

        self.timeout = timeout
        self.model_id = os.environ.get("MODEL_OVERRIDE", model_id)
        logger.info("[d b]Using inference model %s", self.model_id)
        self.embedding_model_id = embedding_model_id

        self.api_key = os.environ.get("WO_API_KEY", api_key)
        self.username = os.environ.get("WO_USERNAME", None)
        self.password = os.environ.get("WO_PASSWORD", None)
        self.auth_type = os.environ.get("WO_AUTH_TYPE", "").lower()
        explicit_auth_url = os.environ.get("AUTHORIZATION_URL", None)

        self.is_ibm_cloud = is_ibm_cloud_url(instance_url)
        self.instance_url = instance_url.rstrip("/")

        # Any value other than "false" (case-insensitive) keeps TLS
        # verification enabled.
        self._wo_ssl_verify = (
            os.environ.get("WO_SSL_VERIFY", "true").lower() != "false"
        )

        # Decide: static token vs. credential exchange + refresh.
        token_from_env = os.environ.get("WO_TOKEN", None)
        static_token = token_from_env if token_from_env is not None else token
        self._use_static_token = bool(static_token)

        if self._use_static_token:
            self.auth_mode, self.auth_url = ("static", "")
        else:
            self.auth_mode, self.auth_url = self._resolve_auth_mode_and_url(
                explicit_auth_url=explicit_auth_url
            )

        env_space_id = (os.environ.get("WATSONX_SPACE_ID") or "").strip()
        if not self._use_static_token and self.auth_mode == "cpd":
            if not env_space_id:
                raise RuntimeError(
                    "CPD mode requires WATSONX_SPACE_ID environment variable to be set"
                )
            self.space_id = env_space_id
            # Strip any /orchestrate suffix so the gateway paths below are
            # appended to the CPD host root.
            if "/orchestrate" in self.instance_url:
                self.instance_url = self.instance_url.split(
                    "/orchestrate", 1
                )[0].rstrip("/")
            if not self.username:
                raise RuntimeError("CPD auth requires WO_USERNAME to be set")
            if not (self.password or self.api_key):
                raise RuntimeError(
                    "CPD auth requires either WO_PASSWORD or WO_API_KEY to be set (with WO_USERNAME)"
                )
        else:
            # "1" is the fallback when WATSONX_SPACE_ID is unset or blank.
            # space_id is not referenced elsewhere in this class.
            self.space_id = env_space_id or "1"
            if not self._use_static_token and not self.api_key:
                raise RuntimeError(
                    "WO_API_KEY must be specified for SaaS or IBM IAM auth"
                )

        default_chat_path = os.environ.get(
            "GATEWAY_CHAT_PATH",
            "/v1/orchestrate/gateway/model/chat/completions",
        )
        default_embeddings_path = os.environ.get(
            "GATEWAY_EMBEDDINGS_PATH",
            "/v1/orchestrate/gateway/model/embeddings",
        )
        self.chat_url = self.instance_url + (chat_path or default_chat_path)
        self.embeddings_url = self.instance_url + (
            embeddings_path or default_embeddings_path
        )

        self.gateway_provider = gateway_provider or os.environ.get(
            "GATEWAY_PROVIDER", "watsonx"
        )
        self.gateway_api_key_label = gateway_api_key_label or os.environ.get(
            "GATEWAY_API_KEY_LABEL", "gateway"
        )
        # When set, serialized verbatim as the x-gateway-config header.
        self.x_gateway_config_override = x_gateway_config

        self.payload_model_prefix = os.environ.get(
            "GATEWAY_MODEL_PREFIX", "watsonx/"
        )

        # Serializes token refreshes triggered by concurrent requests.
        self.lock = Lock()

        if self._use_static_token:
            # Use the provided/env token as-is; refresh_time=inf disables
            # refreshing.
            self.token = static_token  # type: ignore[assignment]
            self.refresh_time = float("inf")
        else:
            self.token, self.refresh_time = self.get_token()

        # Copy the module default so mutating self.params cannot leak into
        # other provider instances.
        self.params = params if params else dict(DEFAULT_PARAM)
        self.system_prompt = system_prompt

    def _resolve_auth_mode_and_url(
        self, explicit_auth_url: Optional[str]
    ) -> Tuple[str, str]:
        """Choose the token-exchange mode and endpoint.

        Returns ``(auth_mode, auth_url)`` where ``auth_mode`` is one of
        ``"cpd"``, ``"ibm_iam"`` or ``"saas"``.

        Resolution order:

        1. An explicit ``AUTHORIZATION_URL``. It is treated as CPD when it
           contains ``/icp4d-api`` or ``WO_AUTH_TYPE`` is ``cpd``. Otherwise
           the mode comes from ``WO_AUTH_TYPE`` or, failing that, from
           whether the instance is an IBM Cloud URL.
        2. ``WO_AUTH_TYPE`` alone selects that mode's default endpoint.
        3. An instance URL containing ``/orchestrate`` implies CPD.
        4. IBM Cloud instance URLs use IBM IAM; anything else uses SaaS.
        """
        if explicit_auth_url:
            # Honour WO_AUTH_TYPE=cpd even when the URL lacks "/icp4d-api";
            # otherwise the CPD credentials would go through the wrong
            # exchange flow and CPD validation would be skipped.
            if "/icp4d-api" in explicit_auth_url or self.auth_type == "cpd":
                return "cpd", _normalize_cpd_auth_url(explicit_auth_url)
            if self.auth_type in ("ibm_iam", "saas"):
                return self.auth_type, explicit_auth_url
            mode = "ibm_iam" if self.is_ibm_cloud else "saas"
            return mode, explicit_auth_url

        if self.auth_type == "cpd":
            return "cpd", _infer_cpd_auth_url(self.instance_url)
        if self.auth_type == "ibm_iam":
            return "ibm_iam", AUTH_ENDPOINT_IBM_CLOUD
        if self.auth_type == "saas":
            return "saas", AUTH_ENDPOINT_AWS

        if "/orchestrate" in self.instance_url:
            return "cpd", _infer_cpd_auth_url(self.instance_url)

        if self.is_ibm_cloud:
            return "ibm_iam", AUTH_ENDPOINT_IBM_CLOUD
        return "saas", AUTH_ENDPOINT_AWS

    def get_token(self) -> Tuple[str, float]:
        """Exchange the configured credentials for a bearer token.

        The request shape depends on ``self.auth_mode``:

        * ``"ibm_iam"``: form-encoded API-key grant, 10 s timeout.
        * ``"cpd"``: JSON ``username`` plus ``password`` (or ``api_key`` when
          no password is set), using ``self.timeout``.
        * anything else (``"saas"``): JSON ``{"apikey": ...}``, 10 s timeout.

        Returns:
            ``(token, refresh_time)``. ``refresh_time`` is an epoch time at
            80% of the reported (or default) token lifetime.

        Raises:
            requests.HTTPError: for 4xx/5xx responses.
            RuntimeError: for any other non-200 status, or a 200 response
                without an ``access_token``/``token`` field.
        """
        json_headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
        }
        timeout = 10

        if self.auth_mode == "ibm_iam":
            headers = {
                "Accept": "application/json",
                "Content-Type": "application/x-www-form-urlencoded",
            }
            post_args: Dict[str, Any] = {
                "data": {
                    "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
                    "apikey": self.api_key,
                }
            }
        elif self.auth_mode == "cpd":
            headers = json_headers
            body = {"username": self.username}
            if self.password:
                body["password"] = self.password
            else:
                body["api_key"] = self.api_key
            post_args = {"json": body}
            timeout = self.timeout
        else:
            headers = json_headers
            post_args = {"json": {"apikey": self.api_key}}

        resp = requests.post(
            self.auth_url,
            headers=headers,
            timeout=timeout,
            verify=self._wo_ssl_verify,
            **post_args,
        )

        if resp.status_code != 200:
            resp.raise_for_status()
            # raise_for_status() only raises for 4xx/5xx. Without this, a
            # 1xx/2xx/3xx status made the method return None and callers
            # failed while unpacking the result. The body is deliberately
            # omitted because it may contain credentials.
            raise RuntimeError(
                f"Token exchange at {self.auth_url} returned unexpected "
                f"status {resp.status_code}"
            )

        json_obj = resp.json()
        token = json_obj.get("access_token") or json_obj.get("token")
        if not token:
            raise RuntimeError(
                f"No token field found in response: {json_obj!r}"
            )

        expires_in = json_obj.get("expires_in")
        try:
            expires_in = int(expires_in) if expires_in is not None else None
        except (TypeError, ValueError, OverflowError):
            expires_in = None
        if not expires_in or expires_in <= 0:
            # NOTE(review): with the fallback default of 1 second,
            # int(0.8 * 1) == 0, so a token response without "expires_in"
            # is re-fetched on essentially every request unless
            # TOKEN_DEFAULT_EXPIRES_IN is set. Confirm this is intended.
            expires_in = int(os.environ.get("TOKEN_DEFAULT_EXPIRES_IN", 1))

        refresh_time = time.time() + int(0.8 * expires_in)
        return token, refresh_time

    def refresh_token_if_expires(self):
        """Re-fetch the token once ``refresh_time`` has passed.

        No-op for static tokens. The expiry check is repeated under
        ``self.lock`` so only one of several concurrent callers performs the
        exchange.
        """
        if self._use_static_token:
            return
        if time.time() > self.refresh_time:
            with self.lock:
                if time.time() > self.refresh_time:
                    self.token, self.refresh_time = self.get_token()

    def _reauth_after_unauthorized(self, request_id: str) -> None:
        """Best-effort token re-acquisition after an HTTP 401.

        No-op in static-token mode. A failed exchange is logged and
        swallowed, so the caller can still re-raise its original error.
        """
        if self._use_static_token:
            return
        with self.lock:
            try:
                self.token, self.refresh_time = self.get_token()
            except Exception:
                logger.warning(
                    "Gateway token re-authentication after 401 failed | request_id=%s",
                    request_id,
                    exc_info=True,
                )

    def _log_request_failure(
        self, message: str, request_id: str, resp: Any, t0: float
    ) -> Optional[int]:
        """Log a failed gateway request with timing and a response preview.

        Must be called from inside an ``except`` block, because
        ``logger.exception`` attaches the exception currently being handled.

        Args:
            message: %-style format taking request_id, status_code,
                duration_ms and response_text_preview, in that order.
            request_id: The x-request-id sent with the request.
            resp: The ``requests`` response, or ``None`` if none arrived.
            t0: ``time.time()`` captured just before sending.

        Returns:
            The HTTP status code, or ``None`` when there was no response.
        """
        duration_ms = int((time.time() - t0) * 1000)
        status_code = getattr(resp, "status_code", None)
        resp_text_preview = None
        if resp is not None:
            try:
                resp_text_preview = _truncate(
                    getattr(resp, "text", None), 2000
                )
            except Exception:
                # Reading the body is best-effort; never mask the real error.
                pass

        logger.exception(
            message,
            request_id,
            status_code,
            duration_ms,
            resp_text_preview,
        )
        return status_code

    def _auth_header(self) -> Dict[str, str]:
        """Return the bearer ``Authorization`` header for the current token."""
        return {"Authorization": f"Bearer {self.token}"}

    def _build_x_gateway_config(self, override_params: Dict[str, Any]) -> str:
        """Build the x-gateway-config header JSON string.

        A configured ``x_gateway_config`` override is serialized verbatim.
        Otherwise a single-target config is built from ``gateway_provider``,
        ``gateway_api_key_label`` and *override_params*, in compact form.
        """
        if self.x_gateway_config_override:
            return json.dumps(self.x_gateway_config_override)

        config = {
            "strategy": {"mode": "single"},
            "targets": [
                {
                    "provider": self.gateway_provider,
                    "api_key": self.gateway_api_key_label,
                    "override_params": override_params or {},
                }
            ],
        }
        return json.dumps(config, separators=(",", ":"))

    def _headers(
        self, request_id: str, override_params: Dict[str, Any]
    ) -> Dict[str, str]:
        """Assemble JSON, request-id, gateway-config and auth headers."""
        h = {
            "Accept": "application/json",
            "Content-Type": "application/json",
            "x-request-id": request_id,
            "x-gateway-config": self._build_x_gateway_config(override_params),
        }
        h.update(self._auth_header())
        return h

    def _payload_model_str(self, model_id: str) -> str:
        """Return the model string for the request body.

        *model_id* is returned unchanged when it already starts with
        ``payload_model_prefix`` or with ``"virtual-model"``. Otherwise the
        prefix (possibly empty) is prepended.
        """
        prefix = self.payload_model_prefix or ""
        already_qualified = prefix and (
            model_id.startswith(prefix) or model_id.startswith("virtual-model")
        )
        return model_id if already_qualified else f"{prefix}{model_id}"

    # -------------------- Provider API --------------------

    def old_query(self, sentence: str) -> str:
        """Legacy entry point.

        The gateway has no /text/generation route, so this delegates to
        :meth:`new_query` (chat with a single user turn).
        """
        return self.new_query(sentence)

    def new_query(self, sentence: str) -> str:
        """Send *sentence* as one user turn to gateway chat/completions.

        The configured ``system_prompt`` (if any) is sent first as a system
        message. The request is non-streaming.

        Returns:
            The assistant text, or ``""`` if the response carries none.

        Raises:
            Exception: if no ``model_id`` is configured.
            Any error from the HTTP call or response parsing is logged and
            re-raised. A 401 first triggers a best-effort token refresh
            (non-static-token mode only).
        """
        if self.model_id is None:
            raise Exception("model id must be specified for text generation")

        self.refresh_token_if_expires()

        messages: List[Dict[str, str]] = []
        if getattr(self, "system_prompt", None):
            messages.append({"role": "system", "content": self.system_prompt})
        messages.append({"role": "user", "content": sentence})

        chat_params = _translate_params_to_chat(self.params)
        # The gateway header carries the untranslated params plus the model.
        override_params = dict(self.params or {})
        override_params["model"] = self.model_id

        payload: Dict[str, Any] = {
            "model": self._payload_model_str(self.model_id),
            "messages": messages,
            # This method never streams.
            **{k: v for k, v in chat_params.items() if k != "stream"},
        }

        request_id = str(uuid.uuid4())
        headers = self._headers(request_id, override_params)

        t0 = time.time()
        logger.debug(
            "[d][b]Sending gateway chat.completions request | request_id=%s url=%s model=%s params=%s input_preview=%s",
            request_id,
            self.chat_url,
            self.model_id,
            json.dumps(chat_params, sort_keys=True, ensure_ascii=False),
            _truncate(sentence, 200),
        )

        resp = None
        try:
            resp = requests.post(
                self.chat_url,
                json=payload,
                headers=headers,
                verify=self._wo_ssl_verify,
                timeout=self.timeout,
            )
            duration_ms = int((time.time() - t0) * 1000)
            resp.raise_for_status()
            data = resp.json()

            choice = (data.get("choices") or [{}])[0]
            content = None
            if isinstance(choice, dict):
                if "message" in choice and isinstance(choice["message"], dict):
                    content = choice["message"].get("content")
                if content is None and "text" in choice:
                    content = choice.get("text")

            # Fall back to top-level fields for non-chat-shaped responses.
            if content is None:
                content = data.get("output") or data.get("text") or ""

            finish_reason = (
                choice.get("finish_reason")
                if isinstance(choice, dict)
                else None
            )
            usage = data.get("usage", {})
            api_request_id = resp.headers.get(
                "x-request-id"
            ) or resp.headers.get("request-id")

            logger.debug(
                "[d][b]Gateway chat.completions response received | request_id=%s status_code=%s duration_ms=%s finish_reason=%s usage=%s output_preview=%s api_request_id=%s",
                request_id,
                resp.status_code,
                duration_ms,
                finish_reason,
                json.dumps(usage, sort_keys=True, ensure_ascii=False),
                _truncate(content, 2000),
                api_request_id,
            )

            return content or ""

        except Exception:
            status_code = self._log_request_failure(
                "Gateway chat.completions request failed | request_id=%s status_code=%s duration_ms=%s response_text_preview=%s",
                request_id,
                resp,
                t0,
            )
            if status_code == 401:
                self._reauth_after_unauthorized(request_id)
            raise

    def chat(
        self,
        messages: Sequence[Dict[str, str]],
        params: Optional[Dict[str, Any]] = None,
    ) -> ChatResult:
        """Send a multi-turn conversation to gateway chat/completions.

        *params* are merged over the provider's default params. ``stream``
        is always dropped, so the request is non-streaming.

        Returns:
            ChatResult with text, usage, finish_reason and the raw response.

        Raises:
            Exception: if no ``model_id`` is configured.
            Any error from the HTTP call or response parsing is logged and
            re-raised. A 401 first triggers a best-effort token refresh
            (non-static-token mode only).
        """
        if self.model_id is None:
            raise Exception("model id must be specified for chat")

        self.refresh_token_if_expires()

        merged_params = dict(self.params or {})
        if params:
            merged_params.update(params)
        chat_params = _translate_params_to_chat(merged_params)
        chat_params.pop("stream", None)

        override_params = dict(merged_params)
        override_params["model"] = self.model_id

        payload: Dict[str, Any] = {
            "model": self._payload_model_str(self.model_id),
            "messages": list(messages),
            **chat_params,
        }

        request_id = str(uuid.uuid4())
        headers = self._headers(request_id, override_params)

        # Most recent user message, used only for the debug preview.
        last_user = next(
            (
                m.get("content", "")
                for m in reversed(messages)
                if m.get("role") == "user"
            ),
            "",
        )
        t0 = time.time()
        logger.debug(
            "[d][b]Sending gateway chat.completions request (non-streaming) | request_id=%s url=%s model=%s params=%s input_preview=%s",
            request_id,
            self.chat_url,
            self.model_id,
            json.dumps(chat_params, sort_keys=True, ensure_ascii=False),
            _truncate(last_user, 200),
        )

        resp = None
        try:
            resp = requests.post(
                self.chat_url,
                json=payload,
                headers=headers,
                verify=self._wo_ssl_verify,
                timeout=self.timeout,
            )
            duration_ms = int((time.time() - t0) * 1000)
            resp.raise_for_status()
            data = resp.json()

            choice = (data.get("choices") or [{}])[0]
            content = ""
            if isinstance(choice, dict):
                content = (
                    (choice.get("message", {}) or {}).get("content")
                    or choice.get("text")
                    or ""
                )
            finish_reason = (
                choice.get("finish_reason")
                if isinstance(choice, dict)
                else None
            )
            usage = data.get("usage", {})
            api_request_id = resp.headers.get(
                "x-request-id"
            ) or resp.headers.get("request-id")

            logger.debug(
                "[d][b]Gateway chat.completions response received (non-streaming) | request_id=%s status_code=%s duration_ms=%s finish_reason=%s usage=%s output_preview=%s api_request_id=%s",
                request_id,
                resp.status_code,
                duration_ms,
                finish_reason,
                json.dumps(usage, sort_keys=True, ensure_ascii=False),
                _truncate(content, 2000),
                api_request_id,
            )

            return ChatResult(
                text=content, usage=usage, finish_reason=finish_reason, raw=data
            )

        except Exception:
            status_code = self._log_request_failure(
                "Gateway chat.completions request failed (non-streaming) | request_id=%s status_code=%s duration_ms=%s response_text_preview=%s",
                request_id,
                resp,
                t0,
            )
            if status_code == 401:
                self._reauth_after_unauthorized(request_id)
            raise

    def encode(self, sentences: List[str]) -> List[list]:
        """Embed *sentences* through the gateway embeddings endpoint.

        Returns:
            The ``embedding`` field of each entry in the response's ``data``
            list, in response order. Empty input returns ``[]`` without
            making a request.

        Raises:
            Exception: if no ``embedding_model_id`` is configured.
            requests.HTTPError: for 4xx/5xx responses. A 401 first triggers
                a best-effort token refresh (non-static-token mode only).
            RuntimeError: if the response lacks a non-empty ``data`` list.
        """
        if self.embedding_model_id is None:
            raise Exception(
                "embedding model id must be specified for text encoding"
            )
        if not sentences:
            # Nothing to embed. A request would come back empty and be
            # rejected below as an unexpected response.
            return []

        self.refresh_token_if_expires()

        override_params = {"model": self.embedding_model_id}
        request_id = str(uuid.uuid4())
        headers = self._headers(request_id, override_params)

        payload = {
            "input": sentences,
            "model": self._payload_model_str(self.embedding_model_id),
        }

        t0 = time.time()
        logger.debug(
            "[d][b]Sending gateway embeddings request | request_id=%s url=%s model=%s num_inputs=%s",
            request_id,
            self.embeddings_url,
            self.embedding_model_id,
            len(sentences),
        )

        resp = requests.post(
            self.embeddings_url,
            json=payload,
            headers=headers,
            verify=self._wo_ssl_verify,
            timeout=self.timeout,
        )
        duration_ms = int((time.time() - t0) * 1000)

        if resp.status_code != 200:
            logger.error(
                "[d b red]Gateway embeddings request failed | request_id=%s status_code=%s duration_ms=%s response_text_preview=%s",
                request_id,
                resp.status_code,
                duration_ms,
                _truncate(resp.text, 2000),
            )
            # Same recovery as chat: refresh the token so the next call can
            # succeed.
            if resp.status_code == 401:
                self._reauth_after_unauthorized(request_id)
            resp.raise_for_status()

        data = resp.json()

        if "data" in data and isinstance(data["data"], list) and data["data"]:
            vectors = [entry.get("embedding") for entry in data["data"]]
            logger.debug(
                "[d][b]Gateway embeddings response received | request_id=%s status_code=%s duration_ms=%s num_vectors=%s",
                request_id,
                resp.status_code,
                duration_ms,
                len(vectors),
            )
            return vectors

        raise RuntimeError(
            f"Unexpected embeddings response: {json.dumps(data)[:500]}"
        )
|