ibm-watsonx-orchestrate-evaluation-framework 1.1.6__py3-none-any.whl → 1.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ibm-watsonx-orchestrate-evaluation-framework might be problematic.

Files changed (42)
  1. {ibm_watsonx_orchestrate_evaluation_framework-1.1.6.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.7.dist-info}/METADATA +4 -1
  2. {ibm_watsonx_orchestrate_evaluation_framework-1.1.6.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.7.dist-info}/RECORD +42 -36
  3. wxo_agentic_evaluation/analyze_run.py +49 -32
  4. wxo_agentic_evaluation/arg_configs.py +30 -2
  5. wxo_agentic_evaluation/data_annotator.py +22 -4
  6. wxo_agentic_evaluation/description_quality_checker.py +20 -4
  7. wxo_agentic_evaluation/evaluation_package.py +189 -15
  8. wxo_agentic_evaluation/external_agent/external_validate.py +3 -1
  9. wxo_agentic_evaluation/external_agent/types.py +1 -1
  10. wxo_agentic_evaluation/inference_backend.py +64 -34
  11. wxo_agentic_evaluation/llm_matching.py +92 -2
  12. wxo_agentic_evaluation/llm_user.py +2 -2
  13. wxo_agentic_evaluation/main.py +147 -38
  14. wxo_agentic_evaluation/metrics/__init__.py +5 -1
  15. wxo_agentic_evaluation/metrics/evaluations.py +124 -0
  16. wxo_agentic_evaluation/metrics/metrics.py +24 -3
  17. wxo_agentic_evaluation/prompt/llmaaj_prompt.jinja2 +15 -0
  18. wxo_agentic_evaluation/prompt/template_render.py +16 -0
  19. wxo_agentic_evaluation/quick_eval.py +17 -3
  20. wxo_agentic_evaluation/record_chat.py +17 -6
  21. wxo_agentic_evaluation/red_teaming/attack_evaluator.py +44 -14
  22. wxo_agentic_evaluation/red_teaming/attack_generator.py +31 -12
  23. wxo_agentic_evaluation/red_teaming/attack_list.py +23 -24
  24. wxo_agentic_evaluation/red_teaming/attack_runner.py +36 -19
  25. wxo_agentic_evaluation/referenceless_eval/referenceless_eval.py +42 -16
  26. wxo_agentic_evaluation/service_instance.py +5 -3
  27. wxo_agentic_evaluation/service_provider/__init__.py +129 -9
  28. wxo_agentic_evaluation/service_provider/gateway_provider.py +707 -0
  29. wxo_agentic_evaluation/service_provider/model_proxy_provider.py +415 -17
  30. wxo_agentic_evaluation/service_provider/ollama_provider.py +393 -22
  31. wxo_agentic_evaluation/service_provider/provider.py +130 -10
  32. wxo_agentic_evaluation/service_provider/referenceless_provider_wrapper.py +52 -0
  33. wxo_agentic_evaluation/service_provider/watsonx_provider.py +480 -52
  34. wxo_agentic_evaluation/type.py +14 -4
  35. wxo_agentic_evaluation/utils/__init__.py +43 -5
  36. wxo_agentic_evaluation/utils/evaluation_discovery.py +47 -0
  37. wxo_agentic_evaluation/utils/gateway_provider_utils.py +39 -0
  38. wxo_agentic_evaluation/utils/messages_parser.py +30 -0
  39. wxo_agentic_evaluation/utils/utils.py +14 -9
  40. wxo_agentic_evaluation/wxo_client.py +2 -1
  41. {ibm_watsonx_orchestrate_evaluation_framework-1.1.6.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.7.dist-info}/WHEEL +0 -0
  42. {ibm_watsonx_orchestrate_evaluation_framework-1.1.6.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.7.dist-info}/top_level.txt +0 -0
@@ -1,39 +1,410 @@
 import json
+import logging
 import os
-from typing import List
+import time
+import uuid
+from typing import Any, Dict, Iterator, List, Optional, Sequence
 
 import requests
 
-from wxo_agentic_evaluation.service_provider.provider import Provider
+from wxo_agentic_evaluation.service_provider.provider import (
+    ChatResult,
+    Provider,
+)
+
+logger = logging.getLogger(__name__)
 
 OLLAMA_URL = os.environ.get("OLLAMA_HOST", "http://localhost:11434")
 
 
+def _truncate(value: Any, max_len: int = 1000) -> str:
+    if value is None:
+        return ""
+    s = str(value)
+    return (
+        s
+        if len(s) <= max_len
+        else s[:max_len] + f"... [truncated {len(s) - max_len} chars]"
+    )
+
+
+def _translate_params_to_ollama_options(
+    params: Optional[Dict[str, Any]]
+) -> Dict[str, Any]:
+    """
+    Map generic params to Ollama 'options' field.
+    Ollama options docs: https://github.com/ollama/ollama/blob/main/docs/modelfile.md#parameters
+    """
+    p = params or {}
+    out: Dict[str, Any] = {}
+
+    for key in ("temperature", "top_p", "top_k", "stop", "seed"):
+        if key in p:
+            out[key] = p[key]
+
+    if "max_new_tokens" in p:
+        out["num_predict"] = p["max_new_tokens"]
+    elif "max_tokens" in p:
+        out["num_predict"] = p["max_tokens"]
+
+    if "repeat_penalty" in p:
+        out["repeat_penalty"] = p["repeat_penalty"]
+    if "repeat_last_n" in p:
+        out["repeat_last_n"] = p["repeat_last_n"]
+
+    return out
+
+
 class OllamaProvider(Provider):
-    def __init__(self, model_id=None):
-        self.url = OLLAMA_URL + "/api/generate"
-        self.model_id = model_id
-        super().__init__()
-
-    def query(self, sentence: str) -> str:
-        payload = {"model": self.model_id, "prompt": sentence}
-        resp = requests.post(self.url, json=payload, stream=True)
+    def __init__(
+        self,
+        model_id: Optional[str] = None,
+        params: Optional[Dict[str, Any]] = None,
+        timeout: int = 300,
+        use_legacy_query: Optional[bool] = None,
+        system_prompt: Optional[str] = None,
+        token: Optional[str] = None,
+        instance_url: Optional[str] = None,
+    ):
+        super().__init__(use_legacy_query=use_legacy_query)
+        self.generate_url = (
+            OLLAMA_URL.rstrip("/") + "/api/generate"
+        )  # legacy text generation
+        self.chat_url = OLLAMA_URL.rstrip("/") + "/api/chat"  # chat endpoint
+        self.model_id = os.environ.get("MODEL_OVERRIDE", model_id)
+        logger.info("[d b]Using inference model %s", self.model_id)
+        self.params = params or {}
+        self.timeout = timeout
+        self.system_prompt = system_prompt
+
+    def old_query(self, sentence: str) -> str:
+        # Legacy /api/generate
+        if not self.model_id:
+            raise ValueError("model_id must be specified for Ollama generation")
+
+        options = _translate_params_to_ollama_options(self.params)
+        payload: Dict[str, Any] = {
+            "model": self.model_id,
+            "prompt": sentence,
+            "stream": True,
+        }
+        if options:
+            payload["options"] = options
+
+        request_id = str(uuid.uuid4())
+        t0 = time.time()
+
+        logger.debug(
+            "[d][b]Sending Ollama generate request | request_id=%s url=%s model=%s params=%s input_preview=%s",
+            request_id,
+            self.generate_url,
+            self.model_id,
+            json.dumps(options, sort_keys=True, ensure_ascii=False),
+            _truncate(sentence, 200),
+        )
+
+        resp = None
         final_text = ""
-        data = b""
-        for chunk in resp:
-            data += chunk
-            if data.endswith(b"\n"):
-                json_obj = json.loads(data)
-                if not json_obj["done"] and json_obj["response"]:
-                    final_text += json_obj["response"]
-                data = b""
+        usage: Dict[str, Any] = {}
+
+        try:
+            resp = requests.post(
+                self.generate_url,
+                json=payload,
+                stream=True,
+                timeout=self.timeout,
+            )
+
+            if resp.status_code != 200:
+                resp_text_preview = _truncate(getattr(resp, "text", ""), 2000)
+                duration_ms = int((time.time() - t0) * 1000)
+                logger.error(
+                    "[d b red]Ollama generate request failed (non-200) | request_id=%s status_code=%s duration_ms=%s response_text_preview=%s",
+                    request_id,
+                    resp.status_code,
+                    duration_ms,
+                    resp_text_preview,
+                )
+                resp.raise_for_status()
+
+            for line in resp.iter_lines(decode_unicode=True):
+                if not line:
+                    continue
+                try:
+                    obj = json.loads(line)
+                except Exception:
+                    logger.warning(
+                        "Skipping unparsable line from Ollama generate | request_id=%s line_preview=%s",
+                        request_id,
+                        _truncate(line, 500),
+                    )
+                    continue
+
+                if not obj.get("done"):
+                    chunk = obj.get("response", "")
+                    if chunk:
+                        final_text += chunk
+                else:
+                    # Final metrics frame
+                    usage = {
+                        "prompt_eval_count": obj.get("prompt_eval_count"),
+                        "eval_count": obj.get("eval_count"),
+                        "prompt_eval_duration_ns": obj.get(
+                            "prompt_eval_duration"
+                        ),
+                        "eval_duration_ns": obj.get("eval_duration"),
+                        "total_duration_ns": obj.get("total_duration"),
+                        "load_duration_ns": obj.get("load_duration"),
+                    }
+
+            duration_ms = int((time.time() - t0) * 1000)
+            logger.debug(
+                "[d][b]Ollama generate response received | request_id=%s status_code=%s duration_ms=%s usage=%s output_preview=%s",
+                request_id,
+                resp.status_code,
+                duration_ms,
+                json.dumps(usage, sort_keys=True, ensure_ascii=False),
+                _truncate(final_text, 2000),
+            )
+
+            return final_text
+
+        except Exception:
+            duration_ms = int((time.time() - t0) * 1000)
+            status_code = getattr(resp, "status_code", None)
+            resp_text_preview = None
+            try:
+                if resp is not None and not getattr(resp, "raw", None):
+                    resp_text_preview = _truncate(
+                        getattr(resp, "text", None), 2000
+                    )
+            except Exception:
+                pass
+
+            logger.exception(
+                "Ollama generate request encountered an error | request_id=%s status_code=%s duration_ms=%s response_text_preview=%s",
+                request_id,
+                status_code,
+                duration_ms,
+                resp_text_preview,
+            )
+            raise
+
+    def new_query(self, sentence: str) -> str:
+        """
+        /api/chat
+        Returns assistant message content.
+        """
+        if not self.model_id:
+            raise ValueError("model_id must be specified for Ollama chat")
+
+        options = _translate_params_to_ollama_options(self.params)
 
-        return final_text
+        messages: List[Dict[str, str]] = []
+        if self.system_prompt:
+            messages.append({"role": "system", "content": self.system_prompt})
+        messages.append({"role": "user", "content": sentence})
+
+        payload: Dict[str, Any] = {
+            "model": self.model_id,
+            "messages": messages,
+            "stream": False,
+        }
+        if options:
+            payload["options"] = options
+
+        request_id = str(uuid.uuid4())
+        t0 = time.time()
+
+        logger.debug(
+            "[d][b]Sending Ollama chat request (non-streaming) | request_id=%s url=%s model=%s params=%s input_preview=%s",
+            request_id,
+            self.chat_url,
+            self.model_id,
+            json.dumps(options, sort_keys=True, ensure_ascii=False),
+            _truncate(sentence, 200),
+        )
+
+        resp = None
+        try:
+            resp = requests.post(
+                self.chat_url, json=payload, timeout=self.timeout
+            )
+            duration_ms = int((time.time() - t0) * 1000)
+            resp.raise_for_status()
+            data = resp.json()
+
+            # Non-streaming chat response: { "message": {"role": "assistant", "content": "..."}, "done": true, ... }
+            message = data.get("message") or {}
+            content = message.get("content", "") or ""
+            finish_reason = data.get("finish_reason")
+            usage = {
+                "prompt_eval_count": data.get("prompt_eval_count"),
+                "eval_count": data.get("eval_count"),
+                "prompt_eval_duration_ns": data.get("prompt_eval_duration"),
+                "eval_duration_ns": data.get("eval_duration"),
+                "total_duration_ns": data.get("total_duration"),
+                "load_duration_ns": data.get("load_duration"),
+            }
+
+            logger.debug(
+                "[d][b]Ollama chat response received | request_id=%s status_code=%s duration_ms=%s finish_reason=%s usage=%s output_preview=%s",
+                request_id,
+                resp.status_code,
+                duration_ms,
+                finish_reason,
+                json.dumps(usage, sort_keys=True, ensure_ascii=False),
+                _truncate(content, 2000),
+            )
+
+            return content
+
+        except Exception:
+            duration_ms = int((time.time() - t0) * 1000)
+            status_code = getattr(resp, "status_code", None)
+            resp_text_preview = (
+                _truncate(getattr(resp, "text", None), 2000)
+                if resp is not None
+                else None
+            )
+
+            logger.exception(
+                "Ollama chat request encountered an error | request_id=%s status_code=%s duration_ms=%s response_text_preview=%s",
+                request_id,
+                status_code,
+                duration_ms,
+                resp_text_preview,
+            )
+            raise
+
+    def chat(
+        self,
+        messages: Sequence[Dict[str, str]],
+        params: Optional[Dict[str, Any]] = None,
+    ) -> ChatResult:
+        """
+        Non-streaming chat via /api/chat.
+        """
+        if not self.model_id:
+            raise ValueError("model_id must be specified for Ollama chat")
+
+        merged_params = dict(self.params or {})
+        if params:
+            merged_params.update(params)
+        options = _translate_params_to_ollama_options(merged_params)
+
+        payload: Dict[str, Any] = {
+            "model": self.model_id,
+            "messages": list(messages),
+            "stream": False,
+        }
+        if options:
+            payload["options"] = options
+
+        last_user = next(
+            (
+                m.get("content", "")
+                for m in reversed(messages)
+                if m.get("role") == "user"
+            ),
+            "",
+        )
+        request_id = str(uuid.uuid4())
+        t0 = time.time()
+
+        logger.debug(
+            "[d][b]Sending Ollama chat request (non-streaming, multi-message) | request_id=%s url=%s model=%s params=%s input_preview=%s",
+            request_id,
+            self.chat_url,
+            self.model_id,
+            json.dumps(options, sort_keys=True, ensure_ascii=False),
+            _truncate(last_user, 200),
+        )
+
+        resp = None
+        try:
+            resp = requests.post(
+                self.chat_url, json=payload, timeout=self.timeout
+            )
+            duration_ms = int((time.time() - t0) * 1000)
+            resp.raise_for_status()
+            data = resp.json()
+
+            message = data.get("message") or {}
+            content = message.get("content", "") or ""
+            finish_reason = data.get("finish_reason")
+            usage = {
+                "prompt_eval_count": data.get("prompt_eval_count"),
+                "eval_count": data.get("eval_count"),
+                "prompt_eval_duration_ns": data.get("prompt_eval_duration"),
+                "eval_duration_ns": data.get("eval_duration"),
+                "total_duration_ns": data.get("total_duration"),
+                "load_duration_ns": data.get("load_duration"),
+            }
+
+            logger.debug(
+                "[d][b]Ollama chat response received (non-streaming, multi-message) | request_id=%s status_code=%s duration_ms=%s finish_reason=%s usage=%s output_preview=%s",
+                request_id,
+                resp.status_code,
+                duration_ms,
+                finish_reason,
+                json.dumps(usage, sort_keys=True, ensure_ascii=False),
+                _truncate(content, 2000),
+            )
+
+            return ChatResult(
+                text=content, usage=usage, finish_reason=finish_reason, raw=data
+            )
+
+        except Exception:
+            duration_ms = int((time.time() - t0) * 1000)
+            status_code = getattr(resp, "status_code", None)
+            resp_text_preview = (
+                _truncate(getattr(resp, "text", None), 2000)
+                if resp is not None
+                else None
+            )
+
+            logger.exception(
+                "Ollama chat request (non-streaming, multi-message) encountered an error | request_id=%s status_code=%s duration_ms=%s response_text_preview=%s",
+                request_id,
+                status_code,
+                duration_ms,
+                resp_text_preview,
+            )
+            raise
 
     def encode(self, sentences: List[str]) -> List[list]:
-        pass
+        raise NotImplementedError(
+            "encode is not implemented for OllamaProvider"
+        )
 
 
 if __name__ == "__main__":
-    provider = OllamaProvider(model_id="llama3.1:8b")
-    print(provider.query("ok"))
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%(asctime)s %(levelname)s %(name)s %(message)s",
+    )
+
+    provider = OllamaProvider(model_id="llama3.1:8b", use_legacy_query=False)
+
+    print("new_query:", provider.query("Say hello in one sentence."))
+
+    # chat API
+    messages = [
+        {"role": "system", "content": "You are concise."},
+        {"role": "user", "content": "List three fruits."},
+    ]
+    result = provider.chat(messages)
+    print("chat:", result.text)
+
+    # Streaming chat
+    print("stream_chat:")
+    assembled = []
+    for chunk in provider.stream_chat(
+        [{"role": "user", "content": "Stream a short sentence."}]
+    ):
+        if chunk.get("delta"):
+            assembled.append(chunk["delta"])
+        if chunk.get("is_final"):
+            print("".join(assembled))
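A note on the streaming loop in old_query above: /api/generate returns newline-delimited JSON frames, where intermediate frames carry a response text chunk and the final frame (done: true) carries the token counters and durations that the code folds into usage. A minimal, self-contained sketch of that parsing step, using made-up frames rather than captured output:

import json

# Illustrative NDJSON frames in the shape /api/generate streams back
# (values are invented; the real final frame also includes duration fields).
frames = [
    '{"response": "Hel", "done": false}',
    '{"response": "lo", "done": false}',
    '{"done": true, "prompt_eval_count": 12, "eval_count": 2}',
]

final_text = ""
usage = {}
for line in frames:
    obj = json.loads(line)
    if not obj.get("done"):
        final_text += obj.get("response", "")
    else:
        usage = {k: obj.get(k) for k in ("prompt_eval_count", "eval_count")}

print(final_text)  # Hello
print(usage)       # {'prompt_eval_count': 12, 'eval_count': 2}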
@@ -1,18 +1,138 @@
-from abc import ABC, abstractmethod
-from typing import List
+from __future__ import annotations
 
+import logging
+import os
+from abc import ABC, ABCMeta, abstractmethod
+from dataclasses import dataclass
+from threading import Lock
+from typing import Any, Dict, List, Optional, Sequence, Tuple
 
-class Provider(ABC):
-    def __init__(self):
-        pass
+from wxo_agentic_evaluation.type import ProviderInstancesCacheKey
+
+
+class SingletonProviderMeta(type):
+
+    _provider_instances: Dict[str, "Provider"] = {}
+    _instantiation_lock = Lock()
+
+    def __call__(cls, *args, **kwargs):
+
+        key_str: str = str(cls._get_key(cls.__name__, args, kwargs))
+
+        if key_str not in cls._provider_instances:
+            with cls._instantiation_lock:
+                if key_str not in cls._provider_instances:
+                    cls._provider_instances[key_str] = super().__call__(
+                        *args, **kwargs
+                    )
+
+        return cls._provider_instances[key_str]
+
+    @staticmethod
+    def _get_key(
+        provider: str, args: Tuple[Any, ...], kwargs: Dict[str, Any]
+    ) -> ProviderInstancesCacheKey:
+
+        args_str = str(args) if args else "noargs"
+        kwargs_str = str(sorted(kwargs.items())) if kwargs else "nokwargs"
+
+        return ProviderInstancesCacheKey(
+            provider=provider,
+            hashed_args=args_str,
+            hashed_kwargs=kwargs_str,
+        )
+
+
+class SingletonProviderABCMeta(ABCMeta, SingletonProviderMeta):
+    pass
+
+
+@dataclass
+class ChatResult:
+    text: str
+    usage: Optional[Dict[str, Any]] = None
+    finish_reason: Optional[str] = None
+    raw: Optional[Any] = None
+
+
+class Provider(ABC, metaclass=SingletonProviderABCMeta):
+    def __init__(
+        self,
+        use_legacy_query: Optional[bool] = None,
+        logger: Optional[logging.Logger] = None,
+    ) -> None:
+        self.logger = logger or logging.getLogger(self.__class__.__name__)
+
+        env_use_legacy = os.environ.get("USE_LEGACY_QUERY")
+        if env_use_legacy is not None:
+            self.use_legacy_query: bool = env_use_legacy.strip().lower() in (
+                "1",
+                "true",
+                "yes",
+                "on",
+            )
+        else:
+            self.use_legacy_query = (
+                bool(use_legacy_query) if use_legacy_query is not None else True
+            )
+        if self.use_legacy_query:
+            self.logger.debug("[d][b]Using legacy /text/generation queries")
+        else:
+            self.logger.debug("[d][b]Using new /chat/completions queries")
 
     @abstractmethod
-    def query(self, sentence: str) -> str:
-        pass
+    def old_query(self, sentence: str) -> str:
+        raise NotImplementedError
 
-    def batch_query(self, sentences: List[str]) -> List[str]:
-        return [self.query(sentence) for sentence in sentences]
+    @abstractmethod
+    def new_query(self, sentence: str) -> str:
+        raise NotImplementedError
 
     @abstractmethod
     def encode(self, sentences: List[str]) -> List[list]:
-        pass
+        raise NotImplementedError
+
+    def query(self, sentence: str) -> str:
+        if self.use_legacy_query:
+            return self.old_query(sentence)
+        return self.new_query(sentence)
+
+    def chat(
+        self,
+        messages: Sequence[Dict[str, str]],
+        params: Optional[Dict[str, Any]] = None,
+    ) -> ChatResult:
+        raise NotImplementedError(
+            f"{self.__class__.__name__} does not implement chat()."
+        )
+
+    def batch_query(
+        self,
+        sentences: List[str],
+        max_workers: Optional[int] = None,
+    ) -> List[str]:
+        if not sentences:
+            return []
+
+        if not max_workers or max_workers <= 1:
+            return [self.query(sentence) for sentence in sentences]
+
+        from concurrent.futures import ThreadPoolExecutor, as_completed
+
+        results: List[Optional[str]] = [None] * len(sentences)
+        with ThreadPoolExecutor(max_workers=max_workers) as pool:
+            future_to_idx = {
+                pool.submit(self.query, s): i for i, s in enumerate(sentences)
+            }
+            for fut in as_completed(future_to_idx):
+                idx = future_to_idx[fut]
+                results[idx] = fut.result()
+
+        return [r if r is not None else "" for r in results]
+
+    def set_routing(self, use_legacy_query: Optional[bool] = None) -> None:
+        if use_legacy_query is not None:
+            self.use_legacy_query = bool(use_legacy_query)
+
+    def close(self) -> None:
+        return
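The SingletonProviderMeta metaclass introduced above caches provider instances keyed by class name plus constructor arguments, so repeated construction with identical arguments returns the same object, and Provider.query() routes to old_query or new_query depending on use_legacy_query (which the USE_LEGACY_QUERY environment variable can override). A minimal sketch of the caching behavior, using a plain string key instead of the package's ProviderInstancesCacheKey:

from threading import Lock


class SingletonMeta(type):
    # Same idea as SingletonProviderMeta: one instance per (class, args, kwargs) key.
    _instances = {}
    _lock = Lock()

    def __call__(cls, *args, **kwargs):
        key = f"{cls.__name__}|{args!r}|{sorted(kwargs.items())!r}"
        if key not in cls._instances:
            with cls._lock:
                if key not in cls._instances:
                    cls._instances[key] = super().__call__(*args, **kwargs)
        return cls._instances[key]


class DemoProvider(metaclass=SingletonMeta):
    # Hypothetical provider used only to illustrate the caching behavior.
    def __init__(self, model_id=None):
        self.model_id = model_id


a = DemoProvider(model_id="llama3.1:8b")
b = DemoProvider(model_id="llama3.1:8b")
c = DemoProvider(model_id="another-model")
print(a is b, a is c)  # True False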
@@ -1,9 +1,14 @@
+import uuid
 from abc import ABC, abstractmethod
 from typing import Any, List, Mapping, Optional, Union
 
 import requests
 import rich
 
+from wxo_agentic_evaluation.service_provider.gateway_provider import (
+    GatewayProvider,
+    _translate_params_to_chat,
+)
 from wxo_agentic_evaluation.service_provider.model_proxy_provider import (
     ModelProxyProvider,
 )
@@ -149,3 +154,50 @@ class WatsonXLLMKitWrapper(WatsonXProvider, LLMKitWrapper):
             return resp.json()
         else:
             resp.raise_for_status()
+
+
+class GatewayProviderLLMKitWrapper(GatewayProvider, LLMKitWrapper):
+    def chat(self, sentence: Union[str, List[Mapping[str, str]]]):
+        if isinstance(sentence, str):
+            messages = []
+            if self.system_prompt:
+                messages.append(
+                    {"role": "system", "content": self.system_prompt}
+                )
+            messages.append({"role": "user", "content": sentence})
+        else:
+            messages = sentence
+
+        if self.model_id is None:
+            raise Exception("model id must be specified for text generation")
+
+        self.refresh_token_if_expires()
+
+        merged_params = dict(self.params or {})
+        chat_params = _translate_params_to_chat(merged_params)
+        chat_params.pop("stream", None)
+
+        override_params = dict(merged_params)
+        override_params["model"] = self.model_id
+
+        payload = {
+            "model": self._payload_model_str(self.model_id),
+            "messages": list(messages),
+            **chat_params,
+        }
+
+        request_id = str(uuid.uuid4())
+        headers = self._headers(request_id, override_params)
+
+        resp = requests.post(
+            self.chat_url,
+            json=payload,
+            headers=headers,
+            verify=self._wo_ssl_verify,
+            timeout=self.timeout,
+        )
+
+        if resp.status_code == 200:
+            return resp.json()
+        else:
+            resp.raise_for_status()
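The GatewayProviderLLMKitWrapper.chat method added above accepts either a plain string or an OpenAI-style message list; strings are wrapped into a messages array with the configured system prompt prepended when one is set. A small standalone sketch of just that normalization step (the helper name is hypothetical, not part of the package):

from typing import List, Mapping, Optional, Union


def normalize_to_messages(
    sentence: Union[str, List[Mapping[str, str]]],
    system_prompt: Optional[str] = None,
) -> List[Mapping[str, str]]:
    # Mirrors the branching at the top of GatewayProviderLLMKitWrapper.chat.
    if isinstance(sentence, str):
        messages: List[Mapping[str, str]] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": sentence})
        return messages
    return list(sentence)


print(normalize_to_messages("List three fruits.", system_prompt="You are concise."))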