ibm-watsonx-orchestrate-evaluation-framework 1.1.5__py3-none-any.whl → 1.1.7__py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their public registries, and is provided for informational purposes only.
This version of ibm-watsonx-orchestrate-evaluation-framework has been flagged as potentially problematic.
- {ibm_watsonx_orchestrate_evaluation_framework-1.1.5.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.7.dist-info}/METADATA +4 -1
- {ibm_watsonx_orchestrate_evaluation_framework-1.1.5.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.7.dist-info}/RECORD +49 -39
- wxo_agentic_evaluation/analyze_run.py +822 -344
- wxo_agentic_evaluation/arg_configs.py +39 -2
- wxo_agentic_evaluation/data_annotator.py +22 -4
- wxo_agentic_evaluation/description_quality_checker.py +29 -4
- wxo_agentic_evaluation/evaluation_package.py +197 -18
- wxo_agentic_evaluation/external_agent/external_validate.py +3 -1
- wxo_agentic_evaluation/external_agent/types.py +1 -1
- wxo_agentic_evaluation/inference_backend.py +105 -108
- wxo_agentic_evaluation/llm_matching.py +104 -2
- wxo_agentic_evaluation/llm_user.py +2 -2
- wxo_agentic_evaluation/main.py +147 -38
- wxo_agentic_evaluation/metrics/__init__.py +5 -0
- wxo_agentic_evaluation/metrics/evaluations.py +124 -0
- wxo_agentic_evaluation/metrics/llm_as_judge.py +4 -3
- wxo_agentic_evaluation/metrics/metrics.py +64 -1
- wxo_agentic_evaluation/prompt/llmaaj_prompt.jinja2 +15 -0
- wxo_agentic_evaluation/prompt/semantic_matching_prompt.jinja2 +41 -9
- wxo_agentic_evaluation/prompt/template_render.py +20 -2
- wxo_agentic_evaluation/quick_eval.py +23 -11
- wxo_agentic_evaluation/record_chat.py +18 -10
- wxo_agentic_evaluation/red_teaming/attack_evaluator.py +169 -100
- wxo_agentic_evaluation/red_teaming/attack_generator.py +63 -40
- wxo_agentic_evaluation/red_teaming/attack_list.py +78 -8
- wxo_agentic_evaluation/red_teaming/attack_runner.py +71 -14
- wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_call/general_metrics.json +783 -0
- wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_selection/function_selection_metrics.json +600 -0
- wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/types.py +10 -10
- wxo_agentic_evaluation/referenceless_eval/referenceless_eval.py +103 -39
- wxo_agentic_evaluation/resource_map.py +3 -1
- wxo_agentic_evaluation/service_instance.py +12 -3
- wxo_agentic_evaluation/service_provider/__init__.py +129 -9
- wxo_agentic_evaluation/service_provider/gateway_provider.py +707 -0
- wxo_agentic_evaluation/service_provider/model_proxy_provider.py +415 -17
- wxo_agentic_evaluation/service_provider/ollama_provider.py +393 -22
- wxo_agentic_evaluation/service_provider/provider.py +130 -10
- wxo_agentic_evaluation/service_provider/referenceless_provider_wrapper.py +52 -0
- wxo_agentic_evaluation/service_provider/watsonx_provider.py +480 -52
- wxo_agentic_evaluation/type.py +15 -5
- wxo_agentic_evaluation/utils/__init__.py +44 -3
- wxo_agentic_evaluation/utils/evaluation_discovery.py +47 -0
- wxo_agentic_evaluation/utils/gateway_provider_utils.py +39 -0
- wxo_agentic_evaluation/utils/messages_parser.py +30 -0
- wxo_agentic_evaluation/utils/parsers.py +71 -0
- wxo_agentic_evaluation/utils/utils.py +140 -20
- wxo_agentic_evaluation/wxo_client.py +81 -0
- {ibm_watsonx_orchestrate_evaluation_framework-1.1.5.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.7.dist-info}/WHEEL +0 -0
- {ibm_watsonx_orchestrate_evaluation_framework-1.1.5.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.7.dist-info}/top_level.txt +0 -0
wxo_agentic_evaluation/service_provider/ollama_provider.py

@@ -1,39 +1,410 @@
 import json
+import logging
 import os
-
+import time
+import uuid
+from typing import Any, Dict, Iterator, List, Optional, Sequence

 import requests

-from wxo_agentic_evaluation.service_provider.provider import
+from wxo_agentic_evaluation.service_provider.provider import (
+    ChatResult,
+    Provider,
+)
+
+logger = logging.getLogger(__name__)

 OLLAMA_URL = os.environ.get("OLLAMA_HOST", "http://localhost:11434")


+def _truncate(value: Any, max_len: int = 1000) -> str:
+    if value is None:
+        return ""
+    s = str(value)
+    return (
+        s
+        if len(s) <= max_len
+        else s[:max_len] + f"... [truncated {len(s) - max_len} chars]"
+    )
+
+
+def _translate_params_to_ollama_options(
+    params: Optional[Dict[str, Any]]
+) -> Dict[str, Any]:
+    """
+    Map generic params to Ollama 'options' field.
+    Ollama options docs: https://github.com/ollama/ollama/blob/main/docs/modelfile.md#parameters
+    """
+    p = params or {}
+    out: Dict[str, Any] = {}
+
+    for key in ("temperature", "top_p", "top_k", "stop", "seed"):
+        if key in p:
+            out[key] = p[key]
+
+    if "max_new_tokens" in p:
+        out["num_predict"] = p["max_new_tokens"]
+    elif "max_tokens" in p:
+        out["num_predict"] = p["max_tokens"]
+
+    if "repeat_penalty" in p:
+        out["repeat_penalty"] = p["repeat_penalty"]
+    if "repeat_last_n" in p:
+        out["repeat_last_n"] = p["repeat_last_n"]
+
+    return out
+
+
 class OllamaProvider(Provider):
-    def __init__(
-        self
-
-
-
-
-
-
+    def __init__(
+        self,
+        model_id: Optional[str] = None,
+        params: Optional[Dict[str, Any]] = None,
+        timeout: int = 300,
+        use_legacy_query: Optional[bool] = None,
+        system_prompt: Optional[str] = None,
+        token: Optional[str] = None,
+        instance_url: Optional[str] = None,
+    ):
+        super().__init__(use_legacy_query=use_legacy_query)
+        self.generate_url = (
+            OLLAMA_URL.rstrip("/") + "/api/generate"
+        )  # legacy text generation
+        self.chat_url = OLLAMA_URL.rstrip("/") + "/api/chat"  # chat endpoint
+        self.model_id = os.environ.get("MODEL_OVERRIDE", model_id)
+        logger.info("[d b]Using inference model %s", self.model_id)
+        self.params = params or {}
+        self.timeout = timeout
+        self.system_prompt = system_prompt
+
+    def old_query(self, sentence: str) -> str:
+        # Legacy /api/generate
+        if not self.model_id:
+            raise ValueError("model_id must be specified for Ollama generation")
+
+        options = _translate_params_to_ollama_options(self.params)
+        payload: Dict[str, Any] = {
+            "model": self.model_id,
+            "prompt": sentence,
+            "stream": True,
+        }
+        if options:
+            payload["options"] = options
+
+        request_id = str(uuid.uuid4())
+        t0 = time.time()
+
+        logger.debug(
+            "[d][b]Sending Ollama generate request | request_id=%s url=%s model=%s params=%s input_preview=%s",
+            request_id,
+            self.generate_url,
+            self.model_id,
+            json.dumps(options, sort_keys=True, ensure_ascii=False),
+            _truncate(sentence, 200),
+        )
+
+        resp = None
         final_text = ""
-
-
-
-
-
-
-
-
+        usage: Dict[str, Any] = {}
+
+        try:
+            resp = requests.post(
+                self.generate_url,
+                json=payload,
+                stream=True,
+                timeout=self.timeout,
+            )
+
+            if resp.status_code != 200:
+                resp_text_preview = _truncate(getattr(resp, "text", ""), 2000)
+                duration_ms = int((time.time() - t0) * 1000)
+                logger.error(
+                    "[d b red]Ollama generate request failed (non-200) | request_id=%s status_code=%s duration_ms=%s response_text_preview=%s",
+                    request_id,
+                    resp.status_code,
+                    duration_ms,
+                    resp_text_preview,
+                )
+                resp.raise_for_status()
+
+            for line in resp.iter_lines(decode_unicode=True):
+                if not line:
+                    continue
+                try:
+                    obj = json.loads(line)
+                except Exception:
+                    logger.warning(
+                        "Skipping unparsable line from Ollama generate | request_id=%s line_preview=%s",
+                        request_id,
+                        _truncate(line, 500),
+                    )
+                    continue
+
+                if not obj.get("done"):
+                    chunk = obj.get("response", "")
+                    if chunk:
+                        final_text += chunk
+                else:
+                    # Final metrics frame
+                    usage = {
+                        "prompt_eval_count": obj.get("prompt_eval_count"),
+                        "eval_count": obj.get("eval_count"),
+                        "prompt_eval_duration_ns": obj.get(
+                            "prompt_eval_duration"
+                        ),
+                        "eval_duration_ns": obj.get("eval_duration"),
+                        "total_duration_ns": obj.get("total_duration"),
+                        "load_duration_ns": obj.get("load_duration"),
+                    }
+
+            duration_ms = int((time.time() - t0) * 1000)
+            logger.debug(
+                "[d][b]Ollama generate response received | request_id=%s status_code=%s duration_ms=%s usage=%s output_preview=%s",
+                request_id,
+                resp.status_code,
+                duration_ms,
+                json.dumps(usage, sort_keys=True, ensure_ascii=False),
+                _truncate(final_text, 2000),
+            )
+
+            return final_text
+
+        except Exception:
+            duration_ms = int((time.time() - t0) * 1000)
+            status_code = getattr(resp, "status_code", None)
+            resp_text_preview = None
+            try:
+                if resp is not None and not getattr(resp, "raw", None):
+                    resp_text_preview = _truncate(
+                        getattr(resp, "text", None), 2000
+                    )
+            except Exception:
+                pass
+
+            logger.exception(
+                "Ollama generate request encountered an error | request_id=%s status_code=%s duration_ms=%s response_text_preview=%s",
+                request_id,
+                status_code,
+                duration_ms,
+                resp_text_preview,
+            )
+            raise
+
+    def new_query(self, sentence: str) -> str:
+        """
+        /api/chat
+        Returns assistant message content.
+        """
+        if not self.model_id:
+            raise ValueError("model_id must be specified for Ollama chat")
+
+        options = _translate_params_to_ollama_options(self.params)

-
+        messages: List[Dict[str, str]] = []
+        if self.system_prompt:
+            messages.append({"role": "system", "content": self.system_prompt})
+        messages.append({"role": "user", "content": sentence})
+
+        payload: Dict[str, Any] = {
+            "model": self.model_id,
+            "messages": messages,
+            "stream": False,
+        }
+        if options:
+            payload["options"] = options
+
+        request_id = str(uuid.uuid4())
+        t0 = time.time()
+
+        logger.debug(
+            "[d][b]Sending Ollama chat request (non-streaming) | request_id=%s url=%s model=%s params=%s input_preview=%s",
+            request_id,
+            self.chat_url,
+            self.model_id,
+            json.dumps(options, sort_keys=True, ensure_ascii=False),
+            _truncate(sentence, 200),
+        )
+
+        resp = None
+        try:
+            resp = requests.post(
+                self.chat_url, json=payload, timeout=self.timeout
+            )
+            duration_ms = int((time.time() - t0) * 1000)
+            resp.raise_for_status()
+            data = resp.json()
+
+            # Non-streaming chat response: { "message": {"role": "assistant", "content": "..."} , "done": true, ... }
+            message = data.get("message") or {}
+            content = message.get("content", "") or ""
+            finish_reason = data.get("finish_reason")
+            usage = {
+                "prompt_eval_count": data.get("prompt_eval_count"),
+                "eval_count": data.get("eval_count"),
+                "prompt_eval_duration_ns": data.get("prompt_eval_duration"),
+                "eval_duration_ns": data.get("eval_duration"),
+                "total_duration_ns": data.get("total_duration"),
+                "load_duration_ns": data.get("load_duration"),
+            }
+
+            logger.debug(
+                "[d][b]Ollama chat response received | request_id=%s status_code=%s duration_ms=%s finish_reason=%s usage=%s output_preview=%s",
+                request_id,
+                resp.status_code,
+                duration_ms,
+                finish_reason,
+                json.dumps(usage, sort_keys=True, ensure_ascii=False),
+                _truncate(content, 2000),
+            )
+
+            return content
+
+        except Exception:
+            duration_ms = int((time.time() - t0) * 1000)
+            status_code = getattr(resp, "status_code", None)
+            resp_text_preview = (
+                _truncate(getattr(resp, "text", None), 2000)
+                if resp is not None
+                else None
+            )
+
+            logger.exception(
+                "Ollama chat request encountered an error | request_id=%s status_code=%s duration_ms=%s response_text_preview=%s",
+                request_id,
+                status_code,
+                duration_ms,
+                resp_text_preview,
+            )
+            raise
+
+    def chat(
+        self,
+        messages: Sequence[Dict[str, str]],
+        params: Optional[Dict[str, Any]] = None,
+    ) -> ChatResult:
+        """
+        Non-streaming chat via /api/chat.
+        """
+        if not self.model_id:
+            raise ValueError("model_id must be specified for Ollama chat")
+
+        merged_params = dict(self.params or {})
+        if params:
+            merged_params.update(params)
+        options = _translate_params_to_ollama_options(merged_params)
+
+        payload: Dict[str, Any] = {
+            "model": self.model_id,
+            "messages": list(messages),
+            "stream": False,
+        }
+        if options:
+            payload["options"] = options
+
+        last_user = next(
+            (
+                m.get("content", "")
+                for m in reversed(messages)
+                if m.get("role") == "user"
+            ),
+            "",
+        )
+        request_id = str(uuid.uuid4())
+        t0 = time.time()
+
+        logger.debug(
+            "[d][b]Sending Ollama chat request (non-streaming, multi-message) | request_id=%s url=%s model=%s params=%s input_preview=%s",
+            request_id,
+            self.chat_url,
+            self.model_id,
+            json.dumps(options, sort_keys=True, ensure_ascii=False),
+            _truncate(last_user, 200),
+        )
+
+        resp = None
+        try:
+            resp = requests.post(
+                self.chat_url, json=payload, timeout=self.timeout
+            )
+            duration_ms = int((time.time() - t0) * 1000)
+            resp.raise_for_status()
+            data = resp.json()
+
+            message = data.get("message") or {}
+            content = message.get("content", "") or ""
+            finish_reason = data.get("finish_reason")
+            usage = {
+                "prompt_eval_count": data.get("prompt_eval_count"),
+                "eval_count": data.get("eval_count"),
+                "prompt_eval_duration_ns": data.get("prompt_eval_duration"),
+                "eval_duration_ns": data.get("eval_duration"),
+                "total_duration_ns": data.get("total_duration"),
+                "load_duration_ns": data.get("load_duration"),
+            }
+
+            logger.debug(
+                "[d][b]Ollama chat response received (non-streaming, multi-message) | request_id=%s status_code=%s duration_ms=%s finish_reason=%s usage=%s output_preview=%s",
+                request_id,
+                resp.status_code,
+                duration_ms,
+                finish_reason,
+                json.dumps(usage, sort_keys=True, ensure_ascii=False),
+                _truncate(content, 2000),
+            )
+
+            return ChatResult(
+                text=content, usage=usage, finish_reason=finish_reason, raw=data
+            )
+
+        except Exception:
+            duration_ms = int((time.time() - t0) * 1000)
+            status_code = getattr(resp, "status_code", None)
+            resp_text_preview = (
+                _truncate(getattr(resp, "text", None), 2000)
+                if resp is not None
+                else None
+            )
+
+            logger.exception(
+                "Ollama chat request (non-streaming, multi-message) encountered an error | request_id=%s status_code=%s duration_ms=%s response_text_preview=%s",
+                request_id,
+                status_code,
+                duration_ms,
+                resp_text_preview,
+            )
+            raise

     def encode(self, sentences: List[str]) -> List[list]:
-
+        raise NotImplementedError(
+            "encode is not implemented for OllamaProvider"
+        )


 if __name__ == "__main__":
-
-
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%(asctime)s %(levelname)s %(name)s %(message)s",
+    )
+
+    provider = OllamaProvider(model_id="llama3.1:8b", use_legacy_query=False)
+
+    print("new_query:", provider.query("Say hello in one sentence."))
+
+    # chat API
+    messages = [
+        {"role": "system", "content": "You are concise."},
+        {"role": "user", "content": "List three fruits."},
+    ]
+    result = provider.chat(messages)
+    print("chat:", result.text)
+
+    # Streaming chat
+    print("stream_chat:")
+    assembled = []
+    for chunk in provider.stream_chat(
+        [{"role": "user", "content": "Stream a short sentence."}]
+    ):
+        if chunk.get("delta"):
+            assembled.append(chunk["delta"])
+        if chunk.get("is_final"):
+            print("".join(assembled))
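The bulk of the ollama_provider.py rewrite is the new chat path plus the parameter-translation helper shown above. A minimal sketch of the expected mapping, assuming version 1.1.7 is installed; only the pure helper is exercised, so no Ollama server is needed:

    import json

    from wxo_agentic_evaluation.service_provider.ollama_provider import (
        _translate_params_to_ollama_options,
    )

    # Generic generation params as the framework passes them around.
    generic = {"temperature": 0.2, "max_new_tokens": 256, "stop": ["\n\n"]}

    # temperature/stop pass through; max_new_tokens is renamed to Ollama's
    # num_predict (max_tokens is only used when max_new_tokens is absent).
    options = _translate_params_to_ollama_options(generic)
    assert options == {"temperature": 0.2, "stop": ["\n\n"], "num_predict": 256}
    print(json.dumps(options, indent=2))

Note that unrecognized keys are dropped rather than forwarded, so only the names the helper knows about reach Ollama's options payload.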
wxo_agentic_evaluation/service_provider/provider.py

@@ -1,18 +1,138 @@
-from
-from typing import List
+from __future__ import annotations

+import logging
+import os
+from abc import ABC, ABCMeta, abstractmethod
+from dataclasses import dataclass
+from threading import Lock
+from typing import Any, Dict, List, Optional, Sequence, Tuple

-
-
-
+from wxo_agentic_evaluation.type import ProviderInstancesCacheKey
+
+
+class SingletonProviderMeta(type):
+
+    _provider_instances: Dict[str, "Provider"] = {}
+    _instantiation_lock = Lock()
+
+    def __call__(cls, *args, **kwargs):
+
+        key_str: str = str(cls._get_key(cls.__name__, args, kwargs))
+
+        if key_str not in cls._provider_instances:
+            with cls._instantiation_lock:
+                if key_str not in cls._provider_instances:
+                    cls._provider_instances[key_str] = super().__call__(
+                        *args, **kwargs
+                    )
+
+        return cls._provider_instances[key_str]
+
+    @staticmethod
+    def _get_key(
+        provider: str, args: Tuple[Any, ...], kwargs: Dict[str, Any]
+    ) -> ProviderInstancesCacheKey:
+
+        args_str = str(args) if args else "noargs"
+        kwargs_str = str(sorted(kwargs.items())) if kwargs else "nokwargs"
+
+        return ProviderInstancesCacheKey(
+            provider=provider,
+            hashed_args=args_str,
+            hashed_kwargs=kwargs_str,
+        )
+
+
+class SingletonProviderABCMeta(ABCMeta, SingletonProviderMeta):
+    pass
+
+
+@dataclass
+class ChatResult:
+    text: str
+    usage: Optional[Dict[str, Any]] = None
+    finish_reason: Optional[str] = None
+    raw: Optional[Any] = None
+
+
+class Provider(ABC, metaclass=SingletonProviderABCMeta):
+    def __init__(
+        self,
+        use_legacy_query: Optional[bool] = None,
+        logger: Optional[logging.Logger] = None,
+    ) -> None:
+        self.logger = logger or logging.getLogger(self.__class__.__name__)
+
+        env_use_legacy = os.environ.get("USE_LEGACY_QUERY")
+        if env_use_legacy is not None:
+            self.use_legacy_query: bool = env_use_legacy.strip().lower() in (
+                "1",
+                "true",
+                "yes",
+                "on",
+            )
+        else:
+            self.use_legacy_query = (
+                bool(use_legacy_query) if use_legacy_query is not None else True
+            )
+        if self.use_legacy_query:
+            self.logger.debug("[d][b]Using legacy /text/generation queries")
+        else:
+            self.logger.debug("[d][b]Using new /chat/completions queries")

     @abstractmethod
-    def
-
+    def old_query(self, sentence: str) -> str:
+        raise NotImplementedError

-
-
+    @abstractmethod
+    def new_query(self, sentence: str) -> str:
+        raise NotImplementedError

     @abstractmethod
     def encode(self, sentences: List[str]) -> List[list]:
-
+        raise NotImplementedError
+
+    def query(self, sentence: str) -> str:
+        if self.use_legacy_query:
+            return self.old_query(sentence)
+        return self.new_query(sentence)
+
+    def chat(
+        self,
+        messages: Sequence[Dict[str, str]],
+        params: Optional[Dict[str, Any]] = None,
+    ) -> ChatResult:
+        raise NotImplementedError(
+            f"{self.__class__.__name__} does not implement chat()."
+        )
+
+    def batch_query(
+        self,
+        sentences: List[str],
+        max_workers: Optional[int] = None,
+    ) -> List[str]:
+        if not sentences:
+            return []
+
+        if not max_workers or max_workers <= 1:
+            return [self.query(sentence) for sentence in sentences]
+
+        from concurrent.futures import ThreadPoolExecutor, as_completed
+
+        results: List[Optional[str]] = [None] * len(sentences)
+        with ThreadPoolExecutor(max_workers=max_workers) as pool:
+            future_to_idx = {
+                pool.submit(self.query, s): i for i, s in enumerate(sentences)
+            }
+            for fut in as_completed(future_to_idx):
+                idx = future_to_idx[fut]
+                results[idx] = fut.result()
+
+        return [r if r is not None else "" for r in results]
+
+    def set_routing(self, use_legacy_query: Optional[bool] = None) -> None:
+        if use_legacy_query is not None:
+            self.use_legacy_query = bool(use_legacy_query)
+
+    def close(self) -> None:
+        return
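provider.py now layers a thread-safe singleton cache and legacy/chat routing under every provider. A minimal sketch of the expected behaviour, assuming 1.1.7 is installed; constructing OllamaProvider only builds URLs, so no server is contacted, and the second model id below is just an arbitrary example value:

    from wxo_agentic_evaluation.service_provider.ollama_provider import OllamaProvider

    # SingletonProviderMeta keys its cache on (class name, args, kwargs):
    # identical constructor calls should return the same instance.
    a = OllamaProvider(model_id="llama3.1:8b", use_legacy_query=False)
    b = OllamaProvider(model_id="llama3.1:8b", use_legacy_query=False)
    c = OllamaProvider(model_id="llama3.1:70b", use_legacy_query=False)  # arbitrary example id
    assert a is b
    assert a is not c

    # query() routes to old_query()/new_query() based on use_legacy_query;
    # a USE_LEGACY_QUERY env var, when set, overrides the constructor argument.
    a.set_routing(use_legacy_query=True)  # flip an existing instance back to the legacy path

Because the cache key is built from the stringified constructor arguments, two logically equivalent configurations expressed differently (for example positional versus keyword arguments) would still produce distinct instances.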
wxo_agentic_evaluation/service_provider/referenceless_provider_wrapper.py

@@ -1,9 +1,14 @@
+import uuid
 from abc import ABC, abstractmethod
 from typing import Any, List, Mapping, Optional, Union

 import requests
 import rich

+from wxo_agentic_evaluation.service_provider.gateway_provider import (
+    GatewayProvider,
+    _translate_params_to_chat,
+)
 from wxo_agentic_evaluation.service_provider.model_proxy_provider import (
     ModelProxyProvider,
 )
@@ -149,3 +154,50 @@ class WatsonXLLMKitWrapper(WatsonXProvider, LLMKitWrapper):
             return resp.json()
         else:
             resp.raise_for_status()
+
+
+class GatewayProviderLLMKitWrapper(GatewayProvider, LLMKitWrapper):
+    def chat(self, sentence: Union[str, List[Mapping[str, str]]]):
+        if isinstance(sentence, str):
+            messages = []
+            if self.system_prompt:
+                messages.append(
+                    {"role": "system", "content": self.system_prompt}
+                )
+            messages.append({"role": "user", "content": sentence})
+        else:
+            messages = sentence
+
+        if self.model_id is None:
+            raise Exception("model id must be specified for text generation")
+
+        self.refresh_token_if_expires()
+
+        merged_params = dict(self.params or {})
+        chat_params = _translate_params_to_chat(merged_params)
+        chat_params.pop("stream", None)
+
+        override_params = dict(merged_params)
+        override_params["model"] = self.model_id
+
+        payload = {
+            "model": self._payload_model_str(self.model_id),
+            "messages": list(messages),
+            **chat_params,
+        }
+
+        request_id = str(uuid.uuid4())
+        headers = self._headers(request_id, override_params)
+
+        resp = requests.post(
+            self.chat_url,
+            json=payload,
+            headers=headers,
+            verify=self._wo_ssl_verify,
+            timeout=self.timeout,
+        )
+
+        if resp.status_code == 200:
+            return resp.json()
+        else:
+            resp.raise_for_status()
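The new GatewayProviderLLMKitWrapper.chat() accepts either a bare prompt string or a ready-made message list. A standalone sketch of that normalization step, mirroring the branch above for illustration rather than importing it from the package; the prompt text and system prompt are arbitrary:

    from typing import List, Mapping, Optional, Union

    def normalize(
        sentence: Union[str, List[Mapping[str, str]]],
        system_prompt: Optional[str] = None,
    ) -> List[Mapping[str, str]]:
        # Same branching as the wrapper: a plain string is wrapped into
        # system/user messages, a message list is passed through unchanged.
        if isinstance(sentence, str):
            messages: List[Mapping[str, str]] = []
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
            messages.append({"role": "user", "content": sentence})
            return messages
        return list(sentence)

    print(normalize("What changed in 1.1.7?", system_prompt="You are concise."))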
|