ibm-watsonx-orchestrate-evaluation-framework 1.0.3__py3-none-any.whl → 1.1.8b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134)
  1. ibm_watsonx_orchestrate_evaluation_framework-1.1.8b0.dist-info/METADATA +53 -0
  2. ibm_watsonx_orchestrate_evaluation_framework-1.1.8b0.dist-info/RECORD +146 -0
  3. wxo_agentic_evaluation/analytics/tools/analyzer.py +38 -21
  4. wxo_agentic_evaluation/analytics/tools/main.py +19 -25
  5. wxo_agentic_evaluation/analytics/tools/types.py +26 -11
  6. wxo_agentic_evaluation/analytics/tools/ux.py +75 -31
  7. wxo_agentic_evaluation/analyze_run.py +1184 -97
  8. wxo_agentic_evaluation/annotate.py +7 -5
  9. wxo_agentic_evaluation/arg_configs.py +97 -5
  10. wxo_agentic_evaluation/base_user.py +25 -0
  11. wxo_agentic_evaluation/batch_annotate.py +97 -27
  12. wxo_agentic_evaluation/clients.py +103 -0
  13. wxo_agentic_evaluation/compare_runs/__init__.py +0 -0
  14. wxo_agentic_evaluation/compare_runs/compare_2_runs.py +74 -0
  15. wxo_agentic_evaluation/compare_runs/diff.py +554 -0
  16. wxo_agentic_evaluation/compare_runs/model.py +193 -0
  17. wxo_agentic_evaluation/data_annotator.py +45 -19
  18. wxo_agentic_evaluation/description_quality_checker.py +178 -0
  19. wxo_agentic_evaluation/evaluation.py +50 -0
  20. wxo_agentic_evaluation/evaluation_controller/evaluation_controller.py +303 -0
  21. wxo_agentic_evaluation/evaluation_package.py +544 -107
  22. wxo_agentic_evaluation/external_agent/__init__.py +18 -7
  23. wxo_agentic_evaluation/external_agent/external_validate.py +49 -36
  24. wxo_agentic_evaluation/external_agent/performance_test.py +33 -22
  25. wxo_agentic_evaluation/external_agent/types.py +8 -7
  26. wxo_agentic_evaluation/extractors/__init__.py +3 -0
  27. wxo_agentic_evaluation/extractors/extractor_base.py +21 -0
  28. wxo_agentic_evaluation/extractors/labeled_messages.py +47 -0
  29. wxo_agentic_evaluation/hr_agent_langgraph.py +68 -0
  30. wxo_agentic_evaluation/langfuse_collection.py +60 -0
  31. wxo_agentic_evaluation/langfuse_evaluation_package.py +192 -0
  32. wxo_agentic_evaluation/llm_matching.py +108 -5
  33. wxo_agentic_evaluation/llm_rag_eval.py +7 -4
  34. wxo_agentic_evaluation/llm_safety_eval.py +64 -0
  35. wxo_agentic_evaluation/llm_user.py +12 -6
  36. wxo_agentic_evaluation/llm_user_v2.py +114 -0
  37. wxo_agentic_evaluation/main.py +128 -246
  38. wxo_agentic_evaluation/metrics/__init__.py +15 -0
  39. wxo_agentic_evaluation/metrics/dummy_metric.py +16 -0
  40. wxo_agentic_evaluation/metrics/evaluations.py +107 -0
  41. wxo_agentic_evaluation/metrics/journey_success.py +137 -0
  42. wxo_agentic_evaluation/metrics/llm_as_judge.py +28 -2
  43. wxo_agentic_evaluation/metrics/metrics.py +319 -16
  44. wxo_agentic_evaluation/metrics/tool_calling.py +93 -0
  45. wxo_agentic_evaluation/otel_parser/__init__.py +1 -0
  46. wxo_agentic_evaluation/otel_parser/langflow_parser.py +86 -0
  47. wxo_agentic_evaluation/otel_parser/langgraph_parser.py +61 -0
  48. wxo_agentic_evaluation/otel_parser/parser.py +163 -0
  49. wxo_agentic_evaluation/otel_parser/parser_types.py +38 -0
  50. wxo_agentic_evaluation/otel_parser/pydantic_parser.py +50 -0
  51. wxo_agentic_evaluation/otel_parser/utils.py +15 -0
  52. wxo_agentic_evaluation/otel_parser/wxo_parser.py +39 -0
  53. wxo_agentic_evaluation/otel_support/evaluate_tau.py +101 -0
  54. wxo_agentic_evaluation/otel_support/otel_message_conversion.py +29 -0
  55. wxo_agentic_evaluation/otel_support/tasks_test.py +1566 -0
  56. wxo_agentic_evaluation/prompt/bad_tool_descriptions_prompt.jinja2 +178 -0
  57. wxo_agentic_evaluation/prompt/derailment_prompt.jinja2 +55 -0
  58. wxo_agentic_evaluation/prompt/llama_user_prompt.jinja2 +59 -5
  59. wxo_agentic_evaluation/prompt/llmaaj_prompt.jinja2 +15 -0
  60. wxo_agentic_evaluation/prompt/off_policy_attack_generation_prompt.jinja2 +34 -0
  61. wxo_agentic_evaluation/prompt/on_policy_attack_generation_prompt.jinja2 +46 -0
  62. wxo_agentic_evaluation/prompt/semantic_matching_prompt.jinja2 +41 -9
  63. wxo_agentic_evaluation/prompt/template_render.py +163 -12
  64. wxo_agentic_evaluation/prompt/unsafe_topic_prompt.jinja2 +65 -0
  65. wxo_agentic_evaluation/quick_eval.py +384 -0
  66. wxo_agentic_evaluation/record_chat.py +132 -81
  67. wxo_agentic_evaluation/red_teaming/attack_evaluator.py +302 -0
  68. wxo_agentic_evaluation/red_teaming/attack_generator.py +329 -0
  69. wxo_agentic_evaluation/red_teaming/attack_list.py +184 -0
  70. wxo_agentic_evaluation/red_teaming/attack_runner.py +204 -0
  71. wxo_agentic_evaluation/referenceless_eval/__init__.py +3 -0
  72. wxo_agentic_evaluation/referenceless_eval/function_calling/__init__.py +0 -0
  73. wxo_agentic_evaluation/referenceless_eval/function_calling/consts.py +28 -0
  74. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/__init__.py +0 -0
  75. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/base.py +29 -0
  76. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_call/__init__.py +0 -0
  77. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_call/general.py +49 -0
  78. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_call/general_metrics.json +783 -0
  79. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_call/general_metrics_runtime.json +580 -0
  80. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_selection/__init__.py +0 -0
  81. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_selection/function_selection.py +31 -0
  82. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_selection/function_selection_metrics.json +600 -0
  83. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/function_selection/function_selection_metrics_runtime.json +477 -0
  84. wxo_agentic_evaluation/referenceless_eval/function_calling/metrics/loader.py +245 -0
  85. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/__init__.py +0 -0
  86. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/adapters.py +106 -0
  87. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/pipeline.py +291 -0
  88. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/semantic_checker.py +465 -0
  89. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/static_checker.py +162 -0
  90. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/transformation_prompts.py +509 -0
  91. wxo_agentic_evaluation/referenceless_eval/function_calling/pipeline/types.py +562 -0
  92. wxo_agentic_evaluation/referenceless_eval/metrics/__init__.py +3 -0
  93. wxo_agentic_evaluation/referenceless_eval/metrics/field.py +266 -0
  94. wxo_agentic_evaluation/referenceless_eval/metrics/metric.py +344 -0
  95. wxo_agentic_evaluation/referenceless_eval/metrics/metrics_runner.py +193 -0
  96. wxo_agentic_evaluation/referenceless_eval/metrics/prompt.py +413 -0
  97. wxo_agentic_evaluation/referenceless_eval/metrics/utils.py +46 -0
  98. wxo_agentic_evaluation/referenceless_eval/prompt/__init__.py +0 -0
  99. wxo_agentic_evaluation/referenceless_eval/prompt/runner.py +158 -0
  100. wxo_agentic_evaluation/referenceless_eval/referenceless_eval.py +191 -0
  101. wxo_agentic_evaluation/resource_map.py +6 -3
  102. wxo_agentic_evaluation/runner.py +329 -0
  103. wxo_agentic_evaluation/runtime_adapter/a2a_runtime_adapter.py +0 -0
  104. wxo_agentic_evaluation/runtime_adapter/runtime_adapter.py +14 -0
  105. wxo_agentic_evaluation/{inference_backend.py → runtime_adapter/wxo_runtime_adapter.py} +88 -150
  106. wxo_agentic_evaluation/scheduler.py +247 -0
  107. wxo_agentic_evaluation/service_instance.py +117 -26
  108. wxo_agentic_evaluation/service_provider/__init__.py +182 -17
  109. wxo_agentic_evaluation/service_provider/gateway_provider.py +707 -0
  110. wxo_agentic_evaluation/service_provider/model_proxy_provider.py +628 -45
  111. wxo_agentic_evaluation/service_provider/ollama_provider.py +392 -22
  112. wxo_agentic_evaluation/service_provider/portkey_provider.py +229 -0
  113. wxo_agentic_evaluation/service_provider/provider.py +129 -10
  114. wxo_agentic_evaluation/service_provider/referenceless_provider_wrapper.py +203 -0
  115. wxo_agentic_evaluation/service_provider/watsonx_provider.py +516 -53
  116. wxo_agentic_evaluation/simluation_runner.py +125 -0
  117. wxo_agentic_evaluation/test_prompt.py +4 -4
  118. wxo_agentic_evaluation/tool_planner.py +141 -46
  119. wxo_agentic_evaluation/type.py +217 -14
  120. wxo_agentic_evaluation/user_simulator/demo_usage_llm_user.py +100 -0
  121. wxo_agentic_evaluation/utils/__init__.py +44 -3
  122. wxo_agentic_evaluation/utils/evaluation_discovery.py +47 -0
  123. wxo_agentic_evaluation/utils/gateway_provider_utils.py +39 -0
  124. wxo_agentic_evaluation/utils/messages_parser.py +30 -0
  125. wxo_agentic_evaluation/utils/open_ai_tool_extractor.py +178 -0
  126. wxo_agentic_evaluation/utils/parsers.py +71 -0
  127. wxo_agentic_evaluation/utils/rich_utils.py +188 -0
  128. wxo_agentic_evaluation/utils/rouge_score.py +23 -0
  129. wxo_agentic_evaluation/utils/utils.py +514 -17
  130. wxo_agentic_evaluation/wxo_client.py +81 -0
  131. ibm_watsonx_orchestrate_evaluation_framework-1.0.3.dist-info/METADATA +0 -380
  132. ibm_watsonx_orchestrate_evaluation_framework-1.0.3.dist-info/RECORD +0 -56
  133. {ibm_watsonx_orchestrate_evaluation_framework-1.0.3.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.8b0.dist-info}/WHEEL +0 -0
  134. {ibm_watsonx_orchestrate_evaluation_framework-1.0.3.dist-info → ibm_watsonx_orchestrate_evaluation_framework-1.1.8b0.dist-info}/top_level.txt +0 -0
@@ -1,65 +1,284 @@
1
+ import json
2
+ import logging
1
3
  import os
2
- import requests
3
4
  import time
4
-
5
- from typing import List
5
+ import uuid
6
6
  from threading import Lock
7
+ from typing import Any, Dict, List, Optional, Sequence, Tuple
8
+
9
+ import requests
7
10
 
8
- from wxo_agentic_evaluation.service_provider.provider import Provider
11
+ from wxo_agentic_evaluation.service_provider.provider import (
12
+ ChatResult,
13
+ Provider,
14
+ )
9
15
  from wxo_agentic_evaluation.utils.utils import is_ibm_cloud_url
10
16
 
11
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Token-exchange endpoints used when no explicit authorization URL is given.
AUTH_ENDPOINT_AWS = (
    "https://iam.platform.saas.ibm.com"
    "/siusermgr/api/1.0/apikeys/token"
)
AUTH_ENDPOINT_IBM_CLOUD = "https://iam.cloud.ibm.com/identity/token"

# Generation parameters applied when the caller supplies none.
DEFAULT_PARAM = dict(
    min_new_tokens=1,
    decoding_method="greedy",
    max_new_tokens=2500,
)
28
+
29
+
30
+ def _infer_cpd_auth_url(instance_url: str) -> str:
31
+ inst = (instance_url or "").rstrip("/")
32
+ if not inst:
33
+ return "/icp4d-api/v1/authorize"
34
+ if "/orchestrate" in inst:
35
+ base = inst.split("/orchestrate", 1)[0].rstrip("/")
36
+ return base + "/icp4d-api/v1/authorize"
37
+ return inst + "/icp4d-api/v1/authorize"
38
+
39
+
40
+ def _normalize_cpd_auth_url(url: str) -> str:
41
+ u = (url or "").rstrip("/")
42
+ if u.endswith("/icp4d-api"):
43
+ return u + "/v1/authorize"
44
+ return url
45
+
46
+
47
+ def _truncate(value: Any, max_len: int = 1000) -> str:
48
+ if value is None:
49
+ return ""
50
+ s = str(value)
51
+ return (
52
+ s
53
+ if len(s) <= max_len
54
+ else s[:max_len] + f"... [truncated {len(s) - max_len} chars]"
55
+ )
56
+
57
+
58
+ def _translate_params_to_chat(params: Dict[str, Any] | None) -> Dict[str, Any]:
59
+ # Translate legacy generation params to chat.completions params.
60
+ p = params or {}
61
+ out: Dict[str, Any] = {}
62
+
63
+ passthrough = {
64
+ "temperature",
65
+ "top_p",
66
+ "n",
67
+ "stream",
68
+ "stop",
69
+ "presence_penalty",
70
+ "frequency_penalty",
71
+ "logit_bias",
72
+ "user",
73
+ "max_tokens",
74
+ "seed",
75
+ "response_format",
76
+ }
77
+ for k in passthrough:
78
+ if k in p:
79
+ out[k] = p[k]
80
+
81
+ if "max_new_tokens" in p and "max_tokens" not in out:
82
+ out["max_tokens"] = p["max_new_tokens"]
83
+
84
+ return out
16
85
 
17
86
 
18
87
  class ModelProxyProvider(Provider):
19
88
  def __init__(
20
89
  self,
21
- model_id=None,
22
- api_key=WO_API_KEY,
23
- instance_url=WO_INSTANCE,
24
- timeout=300,
25
- embedding_model_id=None,
26
- params=None
90
+ model_id: Optional[str] = None,
91
+ api_key: Optional[str] = None,
92
+ instance_url: Optional[str] = None,
93
+ timeout: int = 300,
94
+ embedding_model_id: Optional[str] = None,
95
+ params: Optional[Dict[str, Any]] = None,
96
+ use_legacy_query: Optional[
97
+ bool
98
+ ] = None, # Provider routes query() to old/new based on this
99
+ system_prompt: Optional[str] = None,
100
+ token: Optional[str] = None,
27
101
  ):
28
- super().__init__()
102
+ super().__init__(use_legacy_query=use_legacy_query)
29
103
 
30
- if not instance_url or not api_key:
31
- raise RuntimeError("instance url and WO apikey must be specified to use WO model proxy")
104
+ instance_url = os.environ.get("WO_INSTANCE", instance_url)
105
+ if not instance_url:
106
+ raise RuntimeError(
107
+ "instance url must be specified to use WO model proxy"
108
+ )
32
109
 
33
110
  self.timeout = timeout
34
- self.model_id = model_id
35
-
111
+ self.model_id = os.environ.get("MODEL_OVERRIDE", model_id)
112
+ logger.info("[d b]Using inference model %s", self.model_id)
36
113
  self.embedding_model_id = embedding_model_id
37
114
 
38
- self.api_key = api_key
115
+ self.api_key = os.environ.get("WO_API_KEY", api_key)
116
+ self.username = os.environ.get("WO_USERNAME", None)
117
+ self.password = os.environ.get("WO_PASSWORD", None)
118
+ self.auth_type = os.environ.get(
119
+ "WO_AUTH_TYPE", ""
120
+ ).lower() # explicit override if set, otherwise inferred- match ADK values
121
+ explicit_auth_url = os.environ.get("AUTHORIZATION_URL", None)
122
+
39
123
  self.is_ibm_cloud = is_ibm_cloud_url(instance_url)
40
- self.auth_url = AUTH_ENDPOINT_IBM_CLOUD if self.is_ibm_cloud else AUTH_ENDPOINT_AWS
41
- self.url = instance_url + "/ml/v1/text/generation?version=2024-05-01"
124
+ self.instance_url = instance_url.rstrip("/")
42
125
 
43
- self.embedding_url = instance_url + "/ml/v1/text/embeddings"
126
+ self.auth_mode, self.auth_url = self._resolve_auth_mode_and_url(
127
+ explicit_auth_url=explicit_auth_url
128
+ )
129
+ self._wo_ssl_verify = (
130
+ os.environ.get("WO_SSL_VERIFY", "true").lower() != "false"
131
+ )
132
+ env_space_id = os.environ.get("WATSONX_SPACE_ID", None)
133
+ if self.auth_mode == "cpd":
134
+ if not env_space_id or not env_space_id.strip():
135
+ raise RuntimeError(
136
+ "CPD mode requires WATSONX_SPACE_ID environment variable to be set"
137
+ )
138
+ self.space_id = env_space_id.strip()
139
+ else:
140
+ self.space_id = (
141
+ env_space_id.strip()
142
+ if env_space_id and env_space_id.strip()
143
+ else "1"
144
+ )
145
+
146
+ if self.auth_mode == "cpd":
147
+ if "/orchestrate" in self.instance_url:
148
+ self.instance_url = self.instance_url.split("/orchestrate", 1)[
149
+ 0
150
+ ].rstrip("/")
151
+ if not self.username:
152
+ raise RuntimeError("CPD auth requires WO_USERNAME to be set")
153
+ if not (self.password or self.api_key):
154
+ raise RuntimeError(
155
+ "CPD auth requires either WO_PASSWORD or WO_API_KEY to be set (with WO_USERNAME)"
156
+ )
157
+ else:
158
+ if not self.api_key:
159
+ raise RuntimeError(
160
+ "WO_API_KEY must be specified for SaaS or IBM IAM auth"
161
+ )
162
+
163
+ # Endpoints
164
+ self.url = (
165
+ self.instance_url + "/ml/v1/text/generation?version=2024-05-01"
166
+ ) # legacy
167
+ self.chat_url = self.instance_url + "/ml/v1/chat/completions" # chat
168
+ self.embedding_url = self.instance_url + "/ml/v1/text/embeddings"
44
169
 
45
170
  self.lock = Lock()
46
171
  self.token, self.refresh_time = self.get_token()
47
172
  self.params = params if params else DEFAULT_PARAM
173
+ self.system_prompt = system_prompt
174
+
175
+ def _resolve_auth_mode_and_url(
176
+ self, explicit_auth_url: str | None
177
+ ) -> Tuple[str, str]:
178
+ """
179
+ Returns (auth_mode, auth_url)
180
+ - auth_mode: "cpd" | "ibm_iam" | "saas"
181
+ """
182
+ if explicit_auth_url:
183
+ if "/icp4d-api" in explicit_auth_url:
184
+ return "cpd", _normalize_cpd_auth_url(explicit_auth_url)
185
+ if self.auth_type == "ibm_iam":
186
+ return "ibm_iam", explicit_auth_url
187
+ elif self.auth_type == "saas":
188
+ return "saas", explicit_auth_url
189
+ else:
190
+ mode = "ibm_iam" if self.is_ibm_cloud else "saas"
191
+ return mode, explicit_auth_url
192
+
193
+ if self.auth_type == "cpd":
194
+ inferred_cpd_url = _infer_cpd_auth_url(self.instance_url)
195
+ return "cpd", inferred_cpd_url
196
+ if self.auth_type == "ibm_iam":
197
+ return "ibm_iam", AUTH_ENDPOINT_IBM_CLOUD
198
+ if self.auth_type == "saas":
199
+ return "saas", AUTH_ENDPOINT_AWS
200
+
201
+ if "/orchestrate" in self.instance_url:
202
+ inferred_cpd_url = _infer_cpd_auth_url(self.instance_url)
203
+ return "cpd", inferred_cpd_url
48
204
 
49
- def get_token(self):
50
205
  if self.is_ibm_cloud:
51
- payload = {"grant_type": "urn:ibm:params:oauth:grant-type:apikey", "apikey": self.api_key}
52
- resp = requests.post(self.auth_url, data=payload)
53
- token_key = "access_token"
206
+ return "ibm_iam", AUTH_ENDPOINT_IBM_CLOUD
54
207
  else:
55
- payload = {"apikey": self.api_key}
56
- resp = requests.post(self.auth_url, json=payload)
57
- token_key = "token"
208
+ return "saas", AUTH_ENDPOINT_AWS
209
+
210
def get_token(self) -> Tuple[str, float]:
    """Exchange the configured credentials for an access token.

    The request shape depends on ``self.auth_mode``:

    * ``ibm_iam``: form-encoded API-key grant, 10 second timeout.
    * ``cpd``: JSON body with the username plus the password (preferred) or
      the API key; uses ``self.timeout``.
    * anything else (SaaS): JSON body ``{"apikey": ...}``, 10 second timeout.

    Returns:
        ``(token, refresh_time)`` where ``refresh_time`` is the epoch time
        at which the token should be renewed (80% of its lifetime).

    Raises:
        requests.HTTPError: On a 4xx/5xx response.
        RuntimeError: On any other non-200 status, or when the response has
            neither an ``access_token`` nor a ``token`` field.
    """
    json_headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
    }
    request_kwargs: Dict[str, Any] = {
        "headers": json_headers,
        "timeout": 10,
        "verify": self._wo_ssl_verify,
    }

    if self.auth_mode == "ibm_iam":
        request_kwargs["headers"] = {
            "Accept": "application/json",
            "Content-Type": "application/x-www-form-urlencoded",
        }
        request_kwargs["data"] = {
            "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
            "apikey": self.api_key,
        }
    elif self.auth_mode == "cpd":
        credentials = {"username": self.username}
        if self.password:
            credentials["password"] = self.password
        else:
            credentials["api_key"] = self.api_key
        request_kwargs["json"] = credentials
        request_kwargs["timeout"] = self.timeout
    else:
        request_kwargs["json"] = {"apikey": self.api_key}

    resp = requests.post(self.auth_url, **request_kwargs)

    resp.raise_for_status()
    if resp.status_code != 200:
        # raise_for_status() only raises for 4xx/5xx. Without this check
        # the method would return None and callers would fail on tuple
        # unpacking with an unrelated TypeError.
        raise RuntimeError(
            f"Unexpected status code from token endpoint: {resp.status_code}"
        )

    json_obj = resp.json()
    token = json_obj.get("access_token") or json_obj.get("token")
    if not token:
        raise RuntimeError(
            f"No token field found in response: {json_obj!r}"
        )

    expires_in = json_obj.get("expires_in")
    try:
        expires_in = int(expires_in) if expires_in is not None else None
    except (TypeError, ValueError, OverflowError):
        expires_in = None
    if not expires_in or expires_in <= 0:
        # NOTE(review): with the default of 1 second, int(0.8 * 1) == 0, so
        # the token is considered due for refresh immediately - confirm
        # this is intended for responses without "expires_in".
        expires_in = int(os.environ.get("TOKEN_DEFAULT_EXPIRES_IN", 1))

    refresh_time = time.time() + int(0.8 * expires_in)
    return token, refresh_time
@@ -75,13 +294,26 @@ class ModelProxyProvider(Provider):
75
294
 
76
295
  def encode(self, sentences: List[str]) -> List[list]:
77
296
  if self.embedding_model_id is None:
78
- raise Exception("embedding model id must be specified for text generation")
297
+ raise Exception(
298
+ "embedding model id must be specified for text generation"
299
+ )
79
300
 
80
301
  self.refresh_token_if_expires()
81
302
  headers = self.get_header()
82
- payload = {"inputs": sentences, "model_id": self.embedding_model_id, "space_id": "1"}
83
- #"timeout": self.timeout}
84
- resp = requests.post(self.embedding_url, json=payload, headers=headers)
303
+ payload = {
304
+ "inputs": sentences,
305
+ "model_id": self.embedding_model_id,
306
+ "space_id": self.space_id,
307
+ "max_token": 1000
308
+ }
309
+ # "timeout": self.timeout}
310
+ resp = requests.post(
311
+ self.embedding_url,
312
+ json=payload,
313
+ headers=headers,
314
+ verify=self._wo_ssl_verify,
315
+ timeout=self.timeout,
316
+ )
85
317
 
86
318
  if resp.status_code == 200:
87
319
  json_obj = resp.json()
@@ -89,20 +321,371 @@ class ModelProxyProvider(Provider):
89
321
 
90
322
  resp.raise_for_status()
91
323
 
92
- def query(self, sentence: str) -> str:
324
def old_query(self, sentence: str) -> str:
    """Generate text through the legacy ``/ml/v1/text/generation`` endpoint.

    Args:
        sentence: Prompt text sent as the request ``input``.

    Returns:
        The ``generated_text`` (or, failing that, ``message``) of the first
        result.

    Raises:
        Exception: If no model id is configured.
        ValueError: If the response carries no usable text.
        requests.RequestException: On transport or HTTP errors.

    On failure the error is logged and re-raised. If it looks like an
    authentication failure (HTTP 401 or ``authentication_token_expired`` in
    the message), a token refresh is attempted first so the next call can
    succeed; the request itself is not retried.
    """
    if self.model_id is None:
        raise Exception("model id must be specified for text generation")

    self.refresh_token_if_expires()
    headers = self.get_header()
    payload = {
        "input": sentence,
        "model_id": self.model_id,
        "space_id": self.space_id,
        "timeout": self.timeout,
        "parameters": self.params,
    }

    request_id = str(uuid.uuid4())
    # Monotonic clock: durations are unaffected by wall-clock adjustments.
    start_time = time.perf_counter()

    # Input logging
    logger.debug(
        "[d][b]Sending text.generation request | request_id=%s url=%s model=%s space_id=%s params=%s input_preview=%s",
        request_id,
        self.url,
        self.model_id,
        self.space_id,
        json.dumps(
            payload.get("parameters", {}),
            sort_keys=True,
            ensure_ascii=False,
        ),
        _truncate(sentence, 200),
    )

    resp = None
    try:
        resp = requests.post(
            self.url,
            json=payload,
            headers=headers,
            verify=self._wo_ssl_verify,
            timeout=self.timeout,
        )

        duration_ms = int((time.perf_counter() - start_time) * 1000)
        resp.raise_for_status()
        data = resp.json()

        # Prefer the first entry of "results"; otherwise treat the whole
        # body as the result object.
        result = (
            data["results"][0]
            if "results" in data and data["results"]
            else data
        )
        output_text = ""
        if isinstance(result, dict):
            output_text = (
                result.get("generated_text") or result.get("message") or ""
            )

        # Usage (best-effort)
        usage = data.get("usage") or {}
        if not usage and isinstance(result, dict):
            in_tok = result.get("input_token_count")
            out_tok = result.get("generated_token_count") or result.get(
                "output_token_count"
            )
            if in_tok is not None or out_tok is not None:
                usage = {
                    "prompt_tokens": in_tok,
                    "completion_tokens": out_tok,
                    "total_tokens": (in_tok or 0) + (out_tok or 0),
                }

        api_request_id = resp.headers.get(
            "x-request-id"
        ) or resp.headers.get("request-id")

        # Output logging
        logger.debug(
            "[d][b]text.generation response received | request_id=%s status_code=%s duration_ms=%s usage=%s output_preview=%s api_request_id=%s",
            request_id,
            resp.status_code,
            duration_ms,
            json.dumps(usage, sort_keys=True, ensure_ascii=False),
            _truncate(output_text, 2000),
            api_request_id,
        )

        if not output_text:
            raise ValueError(
                f"Unexpected response from legacy endpoint: {data}"
            )
        return output_text

    except Exception as e:
        duration_ms = int((time.perf_counter() - start_time) * 1000)
        status_code = getattr(resp, "status_code", None)
        resp_text_preview = (
            _truncate(getattr(resp, "text", None), 2000)
            if resp is not None
            else None
        )

        logger.exception(
            "text.generation request failed | request_id=%s status_code=%s duration_ms=%s response_text_preview=%s",
            request_id,
            status_code,
            duration_ms,
            resp_text_preview,
        )

        if "authentication_token_expired" in str(e) or status_code == 401:
            with self.lock:
                try:
                    self.token, self.refresh_time = self.get_token()
                except Exception:
                    # A failed refresh must not replace the original
                    # request error that the caller needs to see.
                    logger.exception(
                        "token refresh failed after auth error | request_id=%s",
                        request_id,
                    )
        raise
446
+
447
def new_query(self, sentence: str) -> str:
    """Send a single-turn chat request and return the assistant's reply.

    Posts to ``{instance_url}/ml/v1/text/chat?version=2024-10-08``. The
    configured ``system_prompt`` (if truthy) is sent as a system message,
    followed by ``sentence`` as a typed-text user message. ``self.params``
    is translated with ``_translate_params_to_chat``; ``time_limit`` is
    forwarded when present.

    Args:
        sentence: User message text.

    Returns:
        ``choices[0].message.content`` of the response.

    Raises:
        Exception: If no model id is configured.
        requests.RequestException: On transport or HTTP errors.
        KeyError / IndexError: If the response lacks the expected fields.

    On failure the error is logged and re-raised. If it looks like an
    authentication failure, a token refresh is attempted first; the request
    itself is not retried.
    """
    if self.model_id is None:
        raise Exception("model id must be specified for text generation")

    self.refresh_token_if_expires()
    headers = self.get_header()

    messages: List[Dict[str, Any]] = []
    if getattr(self, "system_prompt", None):
        messages.append({"role": "system", "content": self.system_prompt})
    messages.append(
        {
            "role": "user",
            "content": [{"type": "text", "text": sentence}],
        }
    )

    chat_params = _translate_params_to_chat(self.params)
    # The response is parsed with resp.json(), so streaming cannot be
    # honoured here; drop it exactly as chat() does.
    chat_params.pop("stream", None)
    if isinstance(self.params, dict) and "time_limit" in self.params:
        chat_params["time_limit"] = self.params["time_limit"]

    payload: Dict[str, Any] = {
        "model_id": self.model_id,
        "space_id": self.space_id,
        "messages": messages,
        **chat_params,
    }

    url = f"{self.instance_url}/ml/v1/text/chat?version=2024-10-08"
    request_id = str(uuid.uuid4())
    # Monotonic clock: durations are unaffected by wall-clock adjustments.
    start_time = time.perf_counter()

    logger.debug(
        "[d][b]Sending chat.completions request | request_id=%s url=%s model=%s space_id=%s params=%s input_preview=%s",
        request_id,
        url,
        self.model_id,
        self.space_id,
        json.dumps(chat_params, sort_keys=True, ensure_ascii=False),
        _truncate(sentence, 200),
    )

    resp = None
    try:
        resp = requests.post(
            url,
            json=payload,
            headers=headers,
            verify=self._wo_ssl_verify,
            timeout=self.timeout,
        )
        duration_ms = int((time.perf_counter() - start_time) * 1000)
        resp.raise_for_status()
        data = resp.json()

        choice = data["choices"][0]
        content = choice["message"]["content"]
        finish_reason = choice.get("finish_reason")
        usage = data.get("usage", {})
        api_request_id = resp.headers.get(
            "x-request-id"
        ) or resp.headers.get("request-id")

        logger.debug(
            "[d][b]chat.completions response received | request_id=%s status_code=%s duration_ms=%s finish_reason=%s usage=%s output_preview=%s api_request_id=%s",
            request_id,
            resp.status_code,
            duration_ms,
            finish_reason,
            json.dumps(usage, sort_keys=True, ensure_ascii=False),
            _truncate(content, 2000),
            api_request_id,
        )

        return content

    except Exception as e:
        duration_ms = int((time.perf_counter() - start_time) * 1000)
        status_code = getattr(resp, "status_code", None)
        resp_text_preview = (
            _truncate(getattr(resp, "text", None), 2000)
            if resp is not None
            else None
        )

        logger.exception(
            "chat.completions request failed | request_id=%s status_code=%s duration_ms=%s response_text_preview=%s",
            request_id,
            status_code,
            duration_ms,
            resp_text_preview,
        )

        if "authentication_token_expired" in str(e) or status_code == 401:
            with self.lock:
                try:
                    self.token, self.refresh_time = self.get_token()
                except Exception:
                    # A failed refresh must not replace the original
                    # request error that the caller needs to see.
                    logger.exception(
                        "token refresh failed after auth error | request_id=%s",
                        request_id,
                    )
        raise
551
+
552
def chat(
    self,
    messages: Sequence[Dict[str, str]],
    params: Optional[Dict[str, Any]] = None,
) -> ChatResult:
    """Run a non-streaming multi-turn chat request.

    Posts to ``{instance_url}/ml/v1/text/chat?version=2024-10-08``. User
    messages whose content is a plain string are converted to the
    typed-list form ``[{"type": "text", "text": ...}]``; all other messages
    are forwarded with their role and content unchanged.

    Args:
        messages: Conversation as dicts with ``role`` and ``content`` keys.
        params: Per-call parameters merged over ``self.params``. ``stream``
            is always removed; ``time_limit`` is forwarded when present.

    Returns:
        A ``ChatResult`` with the first choice's content, the usage dict,
        the finish reason and the raw response body.

    Raises:
        Exception: If no model id is configured.
        requests.RequestException: On transport or HTTP errors.
        KeyError / IndexError: If the response lacks the expected fields.

    On failure the error is logged and re-raised. If it looks like an
    authentication failure, a token refresh is attempted first; the request
    itself is not retried.
    """
    if self.model_id is None:
        raise Exception("model id must be specified for chat")

    self.refresh_token_if_expires()
    headers = self.get_header()

    # Convert messages to watsonx format: user content is typed list
    wx_messages: List[Dict[str, Any]] = []
    for m in messages:
        role = m.get("role")
        content = m.get("content", "")
        if role == "user" and isinstance(content, str):
            content = [{"type": "text", "text": content}]
        wx_messages.append({"role": role, "content": content})

    merged_params = dict(self.params or {})
    if params:
        merged_params.update(params)
    chat_params = _translate_params_to_chat(merged_params)
    chat_params.pop("stream", None)  # force non-streaming
    if "time_limit" in merged_params:
        chat_params["time_limit"] = merged_params["time_limit"]

    payload: Dict[str, Any] = {
        "model_id": self.model_id,
        "space_id": self.space_id,
        "messages": wx_messages,
        **chat_params,
    }

    url = f"{self.instance_url}/ml/v1/text/chat?version=2024-10-08"

    # Most recent user message, used only for the log preview.
    last_user = next(
        (
            m.get("content", "")
            for m in reversed(messages)
            if m.get("role") == "user"
        ),
        "",
    )
    request_id = str(uuid.uuid4())
    # Monotonic clock: durations are unaffected by wall-clock adjustments.
    start_time = time.perf_counter()

    logger.debug(
        "[d][b]Sending chat.completions request (non-streaming) | request_id=%s url=%s model=%s space_id=%s params=%s input_preview=%s",
        request_id,
        url,
        self.model_id,
        self.space_id,
        json.dumps(chat_params, sort_keys=True, ensure_ascii=False),
        _truncate(last_user, 200),
    )

    resp = None
    try:
        resp = requests.post(
            url,
            json=payload,
            headers=headers,
            verify=self._wo_ssl_verify,
            timeout=self.timeout,
        )
        duration_ms = int((time.perf_counter() - start_time) * 1000)
        resp.raise_for_status()
        data = resp.json()

        choice = data["choices"][0]
        content = choice["message"]["content"]
        finish_reason = choice.get("finish_reason")
        usage = data.get("usage", {})
        api_request_id = resp.headers.get(
            "x-request-id"
        ) or resp.headers.get("request-id")

        logger.debug(
            "[d][b]chat.completions response received (non-streaming) | request_id=%s status_code=%s duration_ms=%s finish_reason=%s usage=%s output_preview=%s api_request_id=%s",
            request_id,
            resp.status_code,
            duration_ms,
            finish_reason,
            json.dumps(usage, sort_keys=True, ensure_ascii=False),
            _truncate(content, 2000),
            api_request_id,
        )

        return ChatResult(
            text=content, usage=usage, finish_reason=finish_reason, raw=data
        )

    except Exception as e:
        duration_ms = int((time.perf_counter() - start_time) * 1000)
        status_code = getattr(resp, "status_code", None)
        resp_text_preview = (
            _truncate(getattr(resp, "text", None), 2000)
            if resp is not None
            else None
        )

        logger.exception(
            "chat.completions request failed (non-streaming) | request_id=%s status_code=%s duration_ms=%s response_text_preview=%s",
            request_id,
            status_code,
            duration_ms,
            resp_text_preview,
        )

        if "authentication_token_expired" in str(e) or status_code == 401:
            with self.lock:
                try:
                    self.token, self.refresh_time = self.get_token()
                except Exception:
                    # A failed refresh must not replace the original
                    # request error that the caller needs to see.
                    logger.exception(
                        "token refresh failed after auth error | request_id=%s",
                        request_id,
                    )
        raise
104
676
 
105
677
 
106
678
  if __name__ == "__main__":
107
- provider = ModelProxyProvider(model_id="meta-llama/llama-3-3-70b-instruct", embedding_model_id="ibm/slate-30m-english-rtrvr")
679
+ logging.basicConfig(
680
+ level=logging.INFO,
681
+ format="%(asctime)s %(levelname)s %(name)s %(message)s",
682
+ )
683
+
684
+ provider = ModelProxyProvider(
685
+ model_id="meta-llama/llama-3-3-70b-instruct",
686
+ embedding_model_id="ibm/slate-30m-english-rtrvr",
687
+ use_legacy_query=False,
688
+ system_prompt="",
689
+ )
690
+ # Base class will route .query() to new_query() by default (unless USE_LEGACY_QUERY=true)
108
691
  print(provider.query("ok"))