ibm-watsonx-orchestrate-evaluation-framework 1.1.2__py3-none-any.whl → 1.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,6 +2,9 @@ import logging
2
2
  import os
3
3
 
4
4
  import requests
5
+ import shutil
6
+ from pathlib import Path
7
+ from typing import Optional, Any, Dict, Iterable, Tuple
5
8
  import yaml
6
9
 
7
10
  from wxo_agentic_evaluation.utils.utils import is_ibm_cloud_url, is_saas_url
@@ -156,8 +159,51 @@ class ServiceInstance:
156
159
 
157
160
  return default_tenant["id"]
158
161
 
162
+ def get_env_settings(
163
+ tenant_name: str,
164
+ env_config_path: Optional[str] = None
165
+ ) -> Dict[str, Any]:
166
+ if env_config_path is None:
167
+ env_config_path = f"{os.path.expanduser('~')}/.config/orchestrate/config.yaml"
159
168
 
160
- def tenant_setup(service_url: str, tenant_name: str):
169
+ try:
170
+ with open(env_config_path, "r", encoding="utf-8") as f:
171
+ cfg = yaml.safe_load(f) or {}
172
+ except FileNotFoundError:
173
+ return {}
174
+
175
+ tenant_env = (cfg.get("environments") or {}).get(tenant_name) or {}
176
+ cached_user_env = cfg.get("cached_user_env") or {}
177
+
178
+ merged = cached_user_env | tenant_env
179
+
180
+ return dict(merged)
181
+
182
+
183
+
184
+ def apply_env_overrides(
185
+ base: Dict[str, Any],
186
+ tenant_name: str,
187
+ keys: Optional[Iterable[str]] = None,
188
+ env_config_path: Optional[str] = None
189
+ ) -> Dict[str, Any]:
190
+ """
191
+ Returns a new dict where base is overridden by tenant-defined values.
192
+ - If keys is None, tries to override any keys present in tenant env.
193
+ - Only overrides when the tenant value is present and not None.
194
+ """
195
+ env = get_env_settings(tenant_name, env_config_path=env_config_path)
196
+ merged = dict(base)
197
+ keys_to_consider = keys if keys is not None else env.keys()
198
+
199
+ for k in keys_to_consider:
200
+ if k in env and env[k] is not None:
201
+ merged[k] = env[k]
202
+ return merged
203
+
204
+
205
+
206
+ def tenant_setup(service_url: Optional[str], tenant_name: str) -> Tuple[Optional[str], Optional[str], Dict[str, Any]]:
161
207
  # service_instance = ServiceInstance(
162
208
  # service_url=service_url,
163
209
  # tenant_name=tenant_name
@@ -175,18 +221,41 @@ def tenant_setup(service_url: str, tenant_name: str):
175
221
  f"{os.path.expanduser('~')}/.config/orchestrate/config.yaml"
176
222
  )
177
223
 
178
- # TO-DO: update SDK and use SDK to manage this
179
- with open(auth_config_path, "r") as f:
180
- auth_config = yaml.safe_load(f)
181
- # auth_config["auth"][tenant_name] = {"wxo_mcsp_token": tenant_token}
224
+ try:
225
+ with open(auth_config_path, "r", encoding="utf-8") as f:
226
+ auth_config = yaml.safe_load(f) or {}
227
+ except FileNotFoundError:
228
+ auth_config = {}
182
229
 
183
- with open(env_config_path, "r") as f:
184
- env_config = yaml.safe_load(f)
185
- env_config["environments"][tenant_name] = {"wxo_url": service_url}
186
- env_config["context"]["active_environment"] = tenant_name
230
+ try:
231
+ with open(env_config_path, "r", encoding="utf-8") as f:
232
+ env_config = yaml.safe_load(f) or {}
233
+ except FileNotFoundError:
234
+ env_config = {}
235
+
236
+ environments = env_config.setdefault("environments", {})
237
+ context = env_config.setdefault("context", {})
238
+
239
+ tenant_env = environments.setdefault(tenant_name, {})
240
+
241
+ if service_url and str(service_url).strip():
242
+ tenant_env["wxo_url"] = service_url
243
+
244
+ resolved_service_url = tenant_env.get("wxo_url")
245
+
246
+ context["active_environment"] = tenant_name
187
247
 
188
248
  with open(auth_config_path, "w") as f:
189
249
  yaml.dump(auth_config, f)
190
250
  with open(env_config_path, "w") as f:
191
251
  yaml.dump(env_config, f)
192
- return auth_config["auth"][tenant_name]["wxo_mcsp_token"]
252
+
253
+ token = (
254
+ auth_config.get("auth", {})
255
+ .get(tenant_name, {})
256
+ .get("wxo_mcsp_token")
257
+ )
258
+
259
+ env_merged = get_env_settings(tenant_name, env_config_path=env_config_path)
260
+
261
+ return token, resolved_service_url, env_merged
@@ -55,7 +55,7 @@ def get_provider(
55
55
  if "WO_INSTANCE" in os.environ:
56
56
  config = ProviderConfig(provider="model_proxy", model_id=model_id)
57
57
  return _instantiate_provider(config, referenceless_eval, **kwargs)
58
-
58
+
59
59
  if config:
60
60
  return _instantiate_provider(config, **kwargs)
61
61
 
@@ -1,15 +1,23 @@
1
1
  import os
2
- import requests
3
2
  import time
4
- from typing import List, Tuple
5
3
  from threading import Lock
4
+ from typing import List, Tuple
5
+
6
+ import requests
6
7
 
7
8
  from wxo_agentic_evaluation.service_provider.provider import Provider
8
9
  from wxo_agentic_evaluation.utils.utils import is_ibm_cloud_url
9
10
 
10
- AUTH_ENDPOINT_AWS = "https://iam.platform.saas.ibm.com/siusermgr/api/1.0/apikeys/token"
11
+ AUTH_ENDPOINT_AWS = (
12
+ "https://iam.platform.saas.ibm.com/siusermgr/api/1.0/apikeys/token"
13
+ )
11
14
  AUTH_ENDPOINT_IBM_CLOUD = "https://iam.cloud.ibm.com/identity/token"
12
- DEFAULT_PARAM = {"min_new_tokens": 1, "decoding_method": "greedy", "max_new_tokens": 400}
15
+ DEFAULT_PARAM = {
16
+ "min_new_tokens": 1,
17
+ "decoding_method": "greedy",
18
+ "max_new_tokens": 400,
19
+ }
20
+
13
21
 
14
22
  def _infer_cpd_auth_url(instance_url: str) -> str:
15
23
  inst = (instance_url or "").rstrip("/")
@@ -36,49 +44,71 @@ class ModelProxyProvider(Provider):
36
44
  instance_url=None,
37
45
  timeout=300,
38
46
  embedding_model_id=None,
39
- params=None
47
+ params=None,
40
48
  ):
41
49
  super().__init__()
42
50
 
43
51
  instance_url = os.environ.get("WO_INSTANCE", instance_url)
44
52
  if not instance_url:
45
- raise RuntimeError("instance url must be specified to use WO model proxy")
53
+ raise RuntimeError(
54
+ "instance url must be specified to use WO model proxy"
55
+ )
46
56
 
47
57
  self.timeout = timeout
48
- self.model_id = os.environ.get("MODEL_OVERRIDE",model_id)
58
+ self.model_id = os.environ.get("MODEL_OVERRIDE", model_id)
49
59
  self.embedding_model_id = embedding_model_id
50
60
 
51
61
  self.api_key = os.environ.get("WO_API_KEY", api_key)
52
62
  self.username = os.environ.get("WO_USERNAME", None)
53
63
  self.password = os.environ.get("WO_PASSWORD", None)
54
- self.auth_type = os.environ.get("WO_AUTH_TYPE", "").lower() # explicit override if set, otherwise inferred- match ADK values
64
+ self.auth_type = os.environ.get(
65
+ "WO_AUTH_TYPE", ""
66
+ ).lower() # explicit override if set, otherwise inferred- match ADK values
55
67
  explicit_auth_url = os.environ.get("AUTHORIZATION_URL", None)
56
68
 
57
69
  self.is_ibm_cloud = is_ibm_cloud_url(instance_url)
58
70
  self.instance_url = instance_url.rstrip("/")
59
71
 
60
- self.auth_mode, self.auth_url = self._resolve_auth_mode_and_url(explicit_auth_url=explicit_auth_url)
61
- self._wo_ssl_verify = os.environ.get("WO_SSL_VERIFY", "true").lower() != "false"
72
+ self.auth_mode, self.auth_url = self._resolve_auth_mode_and_url(
73
+ explicit_auth_url=explicit_auth_url
74
+ )
75
+ self._wo_ssl_verify = (
76
+ os.environ.get("WO_SSL_VERIFY", "true").lower() != "false"
77
+ )
62
78
  env_space_id = os.environ.get("WATSONX_SPACE_ID", None)
63
79
  if self.auth_mode == "cpd":
64
80
  if not env_space_id or not env_space_id.strip():
65
- raise RuntimeError("CPD mode requires WATSONX_SPACE_ID environment variable to be set")
81
+ raise RuntimeError(
82
+ "CPD mode requires WATSONX_SPACE_ID environment variable to be set"
83
+ )
66
84
  self.space_id = env_space_id.strip()
67
85
  else:
68
- self.space_id = (env_space_id.strip() if env_space_id and env_space_id.strip() else "1")
86
+ self.space_id = (
87
+ env_space_id.strip()
88
+ if env_space_id and env_space_id.strip()
89
+ else "1"
90
+ )
69
91
 
70
92
  if self.auth_mode == "cpd":
71
93
  if "/orchestrate" in self.instance_url:
72
- self.instance_url = self.instance_url.split("/orchestrate", 1)[0].rstrip("/")
94
+ self.instance_url = self.instance_url.split("/orchestrate", 1)[
95
+ 0
96
+ ].rstrip("/")
73
97
  if not self.username:
74
98
  raise RuntimeError("CPD auth requires WO_USERNAME to be set")
75
99
  if not (self.password or self.api_key):
76
- raise RuntimeError("CPD auth requires either WO_PASSWORD or WO_API_KEY to be set (with WO_USERNAME)")
100
+ raise RuntimeError(
101
+ "CPD auth requires either WO_PASSWORD or WO_API_KEY to be set (with WO_USERNAME)"
102
+ )
77
103
  else:
78
104
  if not self.api_key:
79
- raise RuntimeError("WO_API_KEY must be specified for SaaS or IBM IAM auth")
105
+ raise RuntimeError(
106
+ "WO_API_KEY must be specified for SaaS or IBM IAM auth"
107
+ )
80
108
 
81
- self.url = self.instance_url + "/ml/v1/text/generation?version=2024-05-01"
109
+ self.url = (
110
+ self.instance_url + "/ml/v1/text/generation?version=2024-05-01"
111
+ )
82
112
  self.embedding_url = self.instance_url + "/ml/v1/text/embeddings"
83
113
 
84
114
  self.lock = Lock()
@@ -86,8 +116,7 @@ class ModelProxyProvider(Provider):
86
116
  self.params = params if params else DEFAULT_PARAM
87
117
 
88
118
  def _resolve_auth_mode_and_url(
89
- self,
90
- explicit_auth_url: str | None
119
+ self, explicit_auth_url: str | None
91
120
  ) -> Tuple[str, str]:
92
121
  """
93
122
  Returns (auth_mode, auth_url)
@@ -128,32 +157,61 @@ class ModelProxyProvider(Provider):
128
157
  exchange_url = self.auth_url
129
158
 
130
159
  if self.auth_mode == "ibm_iam":
131
- headers = {"Accept": "application/json", "Content-Type": "application/x-www-form-urlencoded"}
160
+ headers = {
161
+ "Accept": "application/json",
162
+ "Content-Type": "application/x-www-form-urlencoded",
163
+ }
132
164
  form_data = {
133
165
  "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
134
- "apikey": self.api_key
166
+ "apikey": self.api_key,
135
167
  }
136
168
  post_args = {"data": form_data}
137
- resp = requests.post(exchange_url, headers=headers, timeout=timeout, verify=self._wo_ssl_verify, **post_args)
169
+ resp = requests.post(
170
+ exchange_url,
171
+ headers=headers,
172
+ timeout=timeout,
173
+ verify=self._wo_ssl_verify,
174
+ **post_args,
175
+ )
138
176
  elif self.auth_mode == "cpd":
139
- headers = {"Accept": "application/json", "Content-Type": "application/json"}
177
+ headers = {
178
+ "Accept": "application/json",
179
+ "Content-Type": "application/json",
180
+ }
140
181
  body = {"username": self.username}
141
182
  if self.password:
142
183
  body["password"] = self.password
143
184
  else:
144
185
  body["api_key"] = self.api_key
145
186
  timeout = self.timeout
146
- resp = requests.post(exchange_url, headers=headers, json=body, timeout=timeout, verify=self._wo_ssl_verify)
187
+ resp = requests.post(
188
+ exchange_url,
189
+ headers=headers,
190
+ json=body,
191
+ timeout=timeout,
192
+ verify=self._wo_ssl_verify,
193
+ )
147
194
  else:
148
- headers = {"Accept": "application/json", "Content-Type": "application/json"}
195
+ headers = {
196
+ "Accept": "application/json",
197
+ "Content-Type": "application/json",
198
+ }
149
199
  post_args = {"json": {"apikey": self.api_key}}
150
- resp = requests.post(exchange_url, headers=headers, timeout=timeout, verify=self._wo_ssl_verify, **post_args)
200
+ resp = requests.post(
201
+ exchange_url,
202
+ headers=headers,
203
+ timeout=timeout,
204
+ verify=self._wo_ssl_verify,
205
+ **post_args,
206
+ )
151
207
 
152
208
  if resp.status_code == 200:
153
209
  json_obj = resp.json()
154
210
  token = json_obj.get("access_token") or json_obj.get("token")
155
211
  if not token:
156
- raise RuntimeError(f"No token field found in response: {json_obj!r}")
212
+ raise RuntimeError(
213
+ f"No token field found in response: {json_obj!r}"
214
+ )
157
215
 
158
216
  expires_in = json_obj.get("expires_in")
159
217
  try:
@@ -179,13 +237,24 @@ class ModelProxyProvider(Provider):
179
237
 
180
238
  def encode(self, sentences: List[str]) -> List[list]:
181
239
  if self.embedding_model_id is None:
182
- raise Exception("embedding model id must be specified for text generation")
240
+ raise Exception(
241
+ "embedding model id must be specified for text generation"
242
+ )
183
243
 
184
244
  self.refresh_token_if_expires()
185
245
  headers = self.get_header()
186
- payload = {"inputs": sentences, "model_id": self.embedding_model_id, "space_id": self.space_id}
187
- #"timeout": self.timeout}
188
- resp = requests.post(self.embedding_url, json=payload, headers=headers, verify=self._wo_ssl_verify)
246
+ payload = {
247
+ "inputs": sentences,
248
+ "model_id": self.embedding_model_id,
249
+ "space_id": self.space_id,
250
+ }
251
+ # "timeout": self.timeout}
252
+ resp = requests.post(
253
+ self.embedding_url,
254
+ json=payload,
255
+ headers=headers,
256
+ verify=self._wo_ssl_verify,
257
+ )
189
258
 
190
259
  if resp.status_code == 200:
191
260
  json_obj = resp.json()
@@ -198,9 +267,16 @@ class ModelProxyProvider(Provider):
198
267
  raise Exception("model id must be specified for text generation")
199
268
  self.refresh_token_if_expires()
200
269
  headers = self.get_header()
201
- payload = {"input": sentence, "model_id": self.model_id, "space_id": self.space_id,
202
- "timeout": self.timeout, "parameters": self.params}
203
- resp = requests.post(self.url, json=payload, headers=headers, verify=self._wo_ssl_verify)
270
+ payload = {
271
+ "input": sentence,
272
+ "model_id": self.model_id,
273
+ "space_id": self.space_id,
274
+ "timeout": self.timeout,
275
+ "parameters": self.params,
276
+ }
277
+ resp = requests.post(
278
+ self.url, json=payload, headers=headers, verify=self._wo_ssl_verify
279
+ )
204
280
  if resp.status_code == 200:
205
281
  return resp.json()["results"][0]["generated_text"]
206
282
 
@@ -208,5 +284,8 @@ class ModelProxyProvider(Provider):
208
284
 
209
285
 
210
286
  if __name__ == "__main__":
211
- provider = ModelProxyProvider(model_id="meta-llama/llama-3-3-70b-instruct", embedding_model_id="ibm/slate-30m-english-rtrvr")
212
- print(provider.query("ok"))
287
+ provider = ModelProxyProvider(
288
+ model_id="meta-llama/llama-3-3-70b-instruct",
289
+ embedding_model_id="ibm/slate-30m-english-rtrvr",
290
+ )
291
+ print(provider.query("ok"))