beekeeper-monitors-watsonx 1.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of beekeeper-monitors-watsonx might be problematic. Click here for more details.

@@ -0,0 +1,1900 @@
1
+ import datetime
2
+ import json
3
+ import logging
4
+ import os
5
+ import uuid
6
+ from typing import Any, Dict, List, Literal, Optional, Union
7
+
8
+ import certifi
9
+ from beekeeper.core.monitors import PromptMonitor
10
+ from beekeeper.core.monitors.types import PayloadRecord
11
+ from beekeeper.core.prompts.utils import extract_template_vars
12
+ from beekeeper.monitors.watsonx.instrumentation import suppress_output
13
+ from pydantic.v1 import BaseModel
14
+
15
# Export certifi's CA bundle path through the REQUESTS_CA_BUNDLE environment
# variable for the rest of the process.
os.environ["REQUESTS_CA_BUNDLE"] = certifi.where()

# Raise the level of the IBM watsonx.ai SDK loggers so only errors are emitted.
for _sdk_logger_name in ("ibm_watsonx_ai.client", "ibm_watsonx_ai.wml_resource"):
    logging.getLogger(_sdk_logger_name).setLevel(logging.ERROR)
del _sdk_logger_name
18
+
19
# Per-region endpoints used when connecting to IBM Cloud (non-CP4D):
#   "wml"       -> base URL passed to the watsonx.ai `Credentials`
#   "wos"       -> service URL passed to the OpenScale `APIClient`
#   "factsheet" -> value passed as `region` to `AIGovFactsClient` (None for us-south)
REGIONS_URL = {
    region: {"wml": wml_url, "wos": wos_url, "factsheet": factsheet_region}
    for region, wml_url, wos_url, factsheet_region in (
        (
            "us-south",
            "https://us-south.ml.cloud.ibm.com",
            "https://api.aiopenscale.cloud.ibm.com",
            None,
        ),
        (
            "eu-de",
            "https://eu-de.ml.cloud.ibm.com",
            "https://eu-de.api.aiopenscale.cloud.ibm.com",
            "frankfurt",
        ),
        (
            "au-syd",
            "https://au-syd.ml.cloud.ibm.com",
            "https://au-syd.api.aiopenscale.cloud.ibm.com",
            "sydney",
        ),
    )
}
36
+
37
+
38
+ def _filter_dict(original_dict: Dict, optional_keys: List, required_keys: List = []):
39
+ """
40
+ Filters a dictionary to keep only the specified keys and checks for required keys.
41
+
42
+ Args:
43
+ original_dict (Dict): The original dictionary.
44
+ optional_keys (list): A list of keys to retain.
45
+ required_keys (list, optional): A list of keys that must be present in the dictionary. Defaults to None.
46
+ """
47
+ # Ensure all required keys are in the source dict
48
+ missing_keys = [key for key in required_keys if key not in original_dict]
49
+ if missing_keys:
50
+ raise KeyError(f"Missing required parameter: {missing_keys}")
51
+
52
+ all_keys_to_keep = set(required_keys + optional_keys)
53
+
54
+ # Create a new dictionary with only the key-value pairs where the key is in 'keys' and value is not None
55
+ return {
56
+ key: original_dict[key]
57
+ for key in all_keys_to_keep
58
+ if key in original_dict and original_dict[key] is not None
59
+ }
60
+
61
+
62
+ def _convert_payload_format(
63
+ records: List[Dict],
64
+ feature_fields: List[str],
65
+ ) -> List[Dict]:
66
+ payload_data = []
67
+ response_fields = ["generated_text", "input_token_count", "generated_token_count"]
68
+
69
+ for record in records:
70
+ request = {"parameters": {"template_variables": {}}}
71
+ results = {}
72
+
73
+ request["parameters"]["template_variables"] = {
74
+ field: str(record.get(field, "")) for field in feature_fields
75
+ }
76
+
77
+ results = {
78
+ field: record.get(field) for field in response_fields if record.get(field)
79
+ }
80
+
81
+ pl_record = {
82
+ "request": request,
83
+ "response": {"results": [results]},
84
+ "scoring_id": str(uuid.uuid4()),
85
+ }
86
+
87
+ if "response_time" in record:
88
+ pl_record["response_time"] = record["response_time"]
89
+
90
+ payload_data.append(pl_record)
91
+
92
+ return payload_data
93
+
94
+
95
+ # ===== Credentials Classes =====
96
class CloudPakforDataCredentials(BaseModel):
    """
    Encapsulates the credentials required for IBM Cloud Pak for Data.

    Attributes:
        url (str): The host URL of the Cloud Pak for Data environment.
        api_key (str, optional): The API key for the environment, if IAM is enabled.
        username (str, optional): The username for the environment.
        password (str, optional): The password for the environment.
        bedrock_url (str, optional): The Bedrock URL. Required only when IAM integration is enabled on CP4D 4.0.x clusters.
        instance_id (str, optional): The instance ID. Must be "icp" or "openshift" when set.
        version (str, optional): The version of Cloud Pak for Data.
        disable_ssl_verification (bool, optional): Indicates whether to disable SSL certificate verification.
            Defaults to `True`.
    """

    url: str
    api_key: Optional[str] = None
    username: Optional[str] = None
    password: Optional[str] = None
    bedrock_url: Optional[str] = None
    instance_id: Optional[Literal["icp", "openshift"]] = None
    version: Optional[str] = None
    disable_ssl_verification: bool = True

    def __init__(
        self,
        url: str,
        api_key: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        bedrock_url: Optional[str] = None,
        instance_id: Optional[Literal["icp", "openshift"]] = None,
        version: Optional[str] = None,
        disable_ssl_verification: bool = True,
    ) -> None:
        # Explicit signature so the fields can also be passed positionally;
        # validation itself is delegated to the pydantic base class.
        super().__init__(
            url=url,
            api_key=api_key,
            username=username,
            password=password,
            bedrock_url=bedrock_url,
            instance_id=instance_id,
            version=version,
            disable_ssl_verification=disable_ssl_verification,
        )

    def to_dict(self) -> Dict[str, Any]:
        """
        Returns the credentials as a plain dictionary.

        Unset optional fields are kept with a `None` value, except
        `instance_id`, which is removed when it is unset or not one of
        "icp" / "openshift".

        Returns:
            Dict[str, Any]: A shallow copy of the model's field values.
        """
        cpd_creds = dict(self.__dict__)

        # `instance_id` defaults to None, so guard before calling `.lower()`;
        # previously an unset instance_id raised AttributeError here.
        instance_id = cpd_creds.get("instance_id")
        if instance_id is None or instance_id.lower() not in ("icp", "openshift"):
            cpd_creds.pop("instance_id", None)

        return cpd_creds
153
+
154
+
155
class IntegratedSystemCredentials(BaseModel):
    """
    Encapsulates the credentials for an Integrated System based on the authentication type.

    Depending on the `auth_type`, only a subset of the properties is required.

    Attributes:
        auth_type (str): The type of authentication. Currently supports "basic" and "bearer".
        username (str, optional): The username for Basic Authentication.
        password (str, optional): The password for Basic Authentication.
        token_url (str, optional): The URL of the authentication endpoint used to request a Bearer token.
        token_method (str, optional): The HTTP method (e.g., "POST", "GET") used to request the Bearer token.
            Defaults to "POST".
        token_headers (Dict, optional): Optional headers to include when requesting the Bearer token.
            Defaults to an empty dictionary.
        token_payload (str | dict, optional): The body or payload to send when requesting the Bearer token.
            Can be a string (e.g., raw JSON). Defaults to `None`.
    """

    auth_type: Literal["basic", "bearer"]
    username: Optional[str]  # basic
    password: Optional[str]  # basic
    token_url: Optional[str]  # bearer
    token_method: Optional[str] = "POST"  # bearer
    token_headers: Optional[Dict] = {}  # bearer
    token_payload: Optional[Union[str, Dict]] = None  # bearer

    def __init__(
        self,
        auth_type: Literal["basic", "bearer"],
        username: Optional[str] = None,
        password: Optional[str] = None,
        token_url: Optional[str] = None,
        token_method: str = "POST",
        token_headers: Optional[Dict] = None,
        token_payload: Optional[Union[str, Dict]] = None,
    ) -> None:
        """
        Validates the fields required by `auth_type` and builds the model.

        Raises:
            ValueError: If `auth_type` is "basic" without `username`/`password`,
                or "bearer" without `token_url`.
        """
        if auth_type == "basic":
            if not username or not password:
                raise ValueError(
                    "`username` and `password` are required for auth_type = 'basic'.",
                )
        elif auth_type == "bearer":
            if not token_url:
                raise ValueError(
                    "`token_url` are required for auth_type = 'bearer'.",
                )

        # A fresh dict per instance: a `{}` default in the signature would be
        # one shared object reused (and potentially mutated) across instances.
        if token_headers is None:
            token_headers = {}

        super().__init__(
            auth_type=auth_type,
            username=username,
            password=password,
            token_url=token_url,
            token_method=token_method,
            token_headers=token_headers,
            token_payload=token_payload,
        )

    def to_dict(self) -> Dict:
        """
        Returns only the fields relevant to the configured `auth_type`.

        Returns:
            Dict: `auth_type` plus `username`/`password` for "basic", or a
            nested `token_info` dictionary for "bearer".
        """
        integrated_system_creds = {"auth_type": self.auth_type}

        if self.auth_type == "basic":
            integrated_system_creds["username"] = self.username
            integrated_system_creds["password"] = self.password
        elif self.auth_type == "bearer":
            integrated_system_creds["token_info"] = {
                "url": self.token_url,
                "method": self.token_method,
                "headers": self.token_headers,
                "payload": self.token_payload,
            }

        return integrated_system_creds
228
+
229
+
230
+ # ===== Monitor Classes =====
231
class WatsonxExternalPromptMonitor(PromptMonitor):
    """
    Provides functionality to interact with IBM watsonx.governance for monitoring external LLMs.

    Note:
        One of the following parameters is required to create a prompt monitor:
        `project_id` or `space_id`, but not both.

    Attributes:
        api_key (str): The API key for IBM watsonx.governance.
        space_id (str, optional): The space ID in watsonx.governance.
        project_id (str, optional): The project ID in watsonx.governance.
        region (str, optional): The region where watsonx.governance is hosted when using IBM Cloud.
            Defaults to `us-south`.
        cpd_creds (CloudPakforDataCredentials | Dict, optional): The Cloud Pak for Data environment credentials.
        subscription_id (str, optional): The subscription ID associated with the records being logged.

    Example:
        ```python
        from beekeeper.monitors.watsonx import (
            WatsonxExternalPromptMonitor,
            CloudPakforDataCredentials,
        )

        # watsonx.governance (IBM Cloud)
        wxgov_client = WatsonxExternalPromptMonitor(
            api_key="API_KEY", space_id="SPACE_ID"
        )

        # watsonx.governance (CP4D)
        cpd_creds = CloudPakforDataCredentials(
            url="CPD_URL",
            username="USERNAME",
            password="PASSWORD",
            version="5.0",
            instance_id="openshift",
        )

        wxgov_client = WatsonxExternalPromptMonitor(
            space_id="SPACE_ID", cpd_creds=cpd_creds
        )
        ```
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        space_id: Optional[str] = None,
        project_id: Optional[str] = None,
        region: Literal["us-south", "eu-de", "au-syd"] = "us-south",
        cpd_creds: Optional[Union[CloudPakforDataCredentials, Dict]] = None,
        subscription_id: Optional[str] = None,
        **kwargs,
    ) -> None:
        # Imported only to fail early when an SDK dependency is not installed.
        import ibm_aigov_facts_client  # noqa: F401
        import ibm_cloud_sdk_core.authenticators  # noqa: F401
        import ibm_watson_openscale  # noqa: F401
        import ibm_watsonx_ai  # noqa: F401

        super().__init__(**kwargs)

        self.space_id = space_id
        self.project_id = project_id
        self.region = region
        self.subscription_id = subscription_id
        self._api_key = api_key
        self._wos_client = None

        # A space takes precedence over a project when choosing the container.
        self._container_id = space_id if space_id else project_id
        self._container_type = "space" if space_id else "project"
        self._deployment_stage = "production" if space_id else "development"

        # Always defined so later checks do not depend on `hasattr`.
        self._wos_cpd_creds = None
        self._fact_cpd_creds = None
        self._wml_cpd_creds = None

        if cpd_creds:
            # The signature accepts a plain dict as well as the credentials model.
            creds_dict = (
                cpd_creds if isinstance(cpd_creds, dict) else cpd_creds.to_dict()
            )

            self._wos_cpd_creds = _filter_dict(
                creds_dict,
                ["username", "password", "api_key", "disable_ssl_verification"],
                ["url"],
            )
            self._fact_cpd_creds = _filter_dict(
                creds_dict,
                ["username", "password", "api_key", "bedrock_url"],
                ["url"],
            )
            # The factsheets config expects the host under `service_url`.
            self._fact_cpd_creds["service_url"] = self._fact_cpd_creds.pop("url")
            self._wml_cpd_creds = _filter_dict(
                creds_dict,
                [
                    "username",
                    "password",
                    "api_key",
                    "instance_id",
                    "version",
                    "bedrock_url",
                ],
                ["url"],
            )

    def _get_wos_client(self):
        """Returns the cached OpenScale client, creating it on first use."""
        if self._wos_client:
            return self._wos_client

        from ibm_watson_openscale import APIClient as WosAPIClient  # type: ignore

        try:
            wos_cpd_creds = getattr(self, "_wos_cpd_creds", None)
            if wos_cpd_creds:
                from ibm_cloud_sdk_core.authenticators import (
                    CloudPakForDataAuthenticator,  # type: ignore
                )

                authenticator = CloudPakForDataAuthenticator(**wos_cpd_creds)
                service_url = wos_cpd_creds["url"]
            else:
                from ibm_cloud_sdk_core.authenticators import (
                    IAMAuthenticator,  # type: ignore
                )

                authenticator = IAMAuthenticator(apikey=self._api_key)
                service_url = REGIONS_URL[self.region]["wos"]

            self._wos_client = WosAPIClient(
                authenticator=authenticator,
                service_url=service_url,
            )

        except Exception as e:
            logging.error(
                f"Error connecting to IBM watsonx.governance (openscale): {e}",
            )
            raise

        return self._wos_client

    @staticmethod
    def _is_missing_entitlement(error: Exception) -> bool:
        """Tells whether `error` is the 403 "user entitlement does not exist" failure."""
        # Not every exception carries `code`/`message`; read them defensively
        # so an unrelated error is re-raised as-is instead of being masked.
        message = getattr(error, "message", None) or ""
        return (
            getattr(error, "code", None) == 403
            and "The user entitlement does not exist" in str(message)
        )

    def _add_instance_mapping(self, wos_client) -> None:
        """Maps the first available data mart to this monitor's space/project."""
        data_marts = wos_client.data_marts.list().result
        if not data_marts.data_marts:
            raise ValueError(
                "Error retrieving IBM watsonx.governance (openscale) data mart. "
                "Make sure the data mart is configured.",
            )

        wos_client.wos.add_instance_mapping(
            service_instance_id=data_marts.data_marts[0].metadata.id,
            space_id=self.space_id,
            project_id=self.project_id,
        )

    def _create_detached_prompt(
        self,
        detached_details: Dict,
        prompt_template_details: Dict,
        detached_asset_details: Dict,
    ) -> str:
        """Creates the detached prompt template asset and returns its asset ID."""
        from ibm_aigov_facts_client import (  # type: ignore
            AIGovFactsClient,
            CloudPakforDataConfig,
            DetachedPromptTemplate,
            PromptTemplate,
        )

        try:
            fact_cpd_creds = getattr(self, "_fact_cpd_creds", None)
            if fact_cpd_creds:
                aigov_client = AIGovFactsClient(
                    container_id=self._container_id,
                    container_type=self._container_type,
                    cloud_pak_for_data_configs=CloudPakforDataConfig(
                        **fact_cpd_creds
                    ),
                    disable_tracing=True,
                )
            else:
                aigov_client = AIGovFactsClient(
                    api_key=self._api_key,
                    container_id=self._container_id,
                    container_type=self._container_type,
                    disable_tracing=True,
                    region=REGIONS_URL[self.region]["factsheet"],
                )

        except Exception as e:
            logging.error(
                f"Error connecting to IBM watsonx.governance (factsheets): {e}",
            )
            raise

        created_detached_pta = aigov_client.assets.create_detached_prompt(
            **detached_asset_details,
            prompt_details=PromptTemplate(**prompt_template_details),
            detached_information=DetachedPromptTemplate(**detached_details),
        )

        return created_detached_pta.to_dict()["asset_id"]

    def _create_deployment_pta(self, asset_id: str, name: str, model_id: str) -> str:
        """Creates a detached deployment for the prompt asset and returns its ID."""
        from ibm_watsonx_ai import APIClient, Credentials  # type: ignore

        try:
            wml_cpd_creds = getattr(self, "_wml_cpd_creds", None)
            if wml_cpd_creds:
                creds = Credentials(**wml_cpd_creds)
            else:
                creds = Credentials(
                    url=REGIONS_URL[self.region]["wml"],
                    api_key=self._api_key,
                )

            wml_client = APIClient(creds)
            wml_client.set.default_space(self.space_id)

        except Exception as e:
            logging.error(f"Error connecting to IBM watsonx.ai Runtime: {e}")
            raise

        meta_names = wml_client.deployments.ConfigurationMetaNames
        meta_props = {
            meta_names.PROMPT_TEMPLATE: {"id": asset_id},
            meta_names.DETACHED: {},
            meta_names.NAME: name + " " + "deployment",
            meta_names.BASE_MODEL_ID: model_id,
        }

        created_deployment = wml_client.deployments.create(asset_id, meta_props)

        return wml_client.deployments.get_uid(created_deployment)

    def add_prompt_monitor(
        self,
        name: str,
        model_id: str,
        task_id: Literal[
            "extraction",
            "generation",
            "question_answering",
            "retrieval_augmented_generation",
            "summarization",
        ],
        detached_model_provider: str,
        description: str = "",
        model_parameters: Optional[Dict] = None,
        detached_model_name: Optional[str] = None,
        detached_model_url: Optional[str] = None,
        detached_prompt_url: Optional[str] = None,
        detached_prompt_additional_info: Optional[Dict] = None,
        prompt_variables: Optional[List[str]] = None,
        locale: str = "en",
        input_text: Optional[str] = None,
        context_fields: Optional[List[str]] = None,
        question_field: Optional[str] = None,
    ) -> Dict:
        """
        Creates a Detached/External Prompt Template Asset and setup monitor for a given prompt template asset.

        Args:
            name (str): The name of the External Prompt Template Asset.
            model_id (str): The ID of the model associated with the prompt.
            task_id (str): The task identifier.
            detached_model_provider (str): The external model provider.
            description (str, optional): A description of the External Prompt Template Asset.
            model_parameters (Dict, optional): Model parameters and their respective values.
            detached_model_name (str, optional): The name of the external model.
            detached_model_url (str, optional): The URL of the external model.
            detached_prompt_url (str, optional): The URL of the external prompt.
            detached_prompt_additional_info (Dict, optional): Additional information related to the external prompt.
            prompt_variables (List[str], optional): Names of the prompt variables.
            locale (str, optional): Locale code for the input/output language. eg. "en", "pt", "es".
            input_text (str, optional): The input text for the prompt.
            context_fields (List[str], optional): A list of fields that will provide context to the prompt.
                Applicable only for "retrieval_augmented_generation" task type.
            question_field (str, optional): The field containing the question to be answered.
                Applicable only for "retrieval_augmented_generation" task type.

        Returns:
            Dict: The keys `detached_prompt_template_asset_id`, `deployment_id`
            (`None` when monitoring a project) and `subscription_id`.

        Raises:
            ValueError: If neither or both of `project_id`/`space_id` are set,
                if a RAG task lacks `context_fields`/`question_field`, or if
                no data mart is available when an instance mapping is needed.

        Example:
            ```python
            wxgov_client.add_prompt_monitor(
                name="Detached prompt (model AWS Anthropic)",
                model_id="anthropic.claude-v2",
                task_id="retrieval_augmented_generation",
                detached_model_provider="AWS Bedrock",
                detached_model_name="Anthropic Claude 2.0",
                detached_model_url="https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-claude.html",
                prompt_variables=["context1", "context2", "input_query"],
                input_text="Prompt text to be given",
                context_fields=["context1", "context2"],
                question_field="input_query",
            )
            ```
        """
        if not (self.project_id or self.space_id):
            raise ValueError(
                "Invalid configuration: please set either 'project_id' or 'space_id'."
            )
        if self.project_id and self.space_id:
            raise ValueError(
                "Invalid configuration: 'project_id' and 'space_id' cannot be set at the same time."
            )

        if task_id == "retrieval_augmented_generation":
            if not context_fields or not question_field:
                raise ValueError(
                    "For 'retrieval_augmented_generation' task, requires non-empty 'context_fields' and 'question_field'."
                )

        # Explicit mapping from this method's arguments to the key names
        # consumed by the factsheets calls below.
        prompt_metadata = {
            "name": name,
            "model_id": model_id,
            "task_id": task_id,
            "description": description,
            "model_parameters": model_parameters,
            "input": input_text,
            "model_provider": detached_model_provider,
            "model_name": detached_model_name,
            "model_url": detached_model_url,
            "prompt_url": detached_prompt_url,
            "prompt_additional_info": detached_prompt_additional_info,
            # Variable names become a dict with empty values; an omitted list
            # stays None (and is then filtered out) instead of crashing.
            "prompt_variables": (
                dict.fromkeys(prompt_variables, "")
                if prompt_variables is not None
                else None
            ),
        }

        wos_client = self._get_wos_client()

        detached_details = _filter_dict(
            prompt_metadata,
            ["model_name", "model_url", "prompt_url", "prompt_additional_info"],
            ["model_id", "model_provider"],
        )
        detached_details["prompt_id"] = "detached_prompt_" + str(uuid.uuid4())

        prompt_details = _filter_dict(
            prompt_metadata,
            ["prompt_variables", "input", "model_parameters"],
        )

        detached_asset_details = _filter_dict(
            prompt_metadata,
            ["description"],
            ["name", "model_id", "task_id"],
        )

        detached_pta_id = suppress_output(
            self._create_detached_prompt,
            detached_details,
            prompt_details,
            detached_asset_details,
        )

        # A deployment is only created when monitoring a space.
        deployment_id = None
        if self._container_type == "space":
            deployment_id = suppress_output(
                self._create_deployment_pta, detached_pta_id, name, model_id
            )

        monitors = {
            "generative_ai_quality": {
                "parameters": {"min_sample_size": 10, "metrics_configuration": {}},
            },
        }

        # Two attempts at most: if the first fails with the missing-entitlement
        # error, add the instance mapping and retry once. Anything else, or a
        # failure on the retry, propagates unchanged.
        for attempt in range(2):
            try:
                setup_response = suppress_output(
                    wos_client.wos.execute_prompt_setup,
                    prompt_template_asset_id=detached_pta_id,
                    space_id=self.space_id,
                    project_id=self.project_id,
                    deployment_id=deployment_id,
                    label_column="reference_output",
                    context_fields=context_fields,
                    question_field=question_field,
                    operational_space_id=self._deployment_stage,
                    problem_type=task_id,
                    data_input_locale=[locale],
                    generated_output_locale=[locale],
                    input_data_type="unstructured_text",
                    supporting_monitors=monitors,
                    background_mode=False,
                )
                break
            except Exception as e:
                if attempt > 0 or not self._is_missing_entitlement(e):
                    raise
                self._add_instance_mapping(wos_client)

        generative_ai_monitor_details = setup_response.result._to_dict()

        return {
            "detached_prompt_template_asset_id": detached_pta_id,
            "deployment_id": deployment_id,
            "subscription_id": generative_ai_monitor_details["subscription_id"],
        }

    def store_payload_records(
        self,
        request_records: List[Dict],
        subscription_id: Optional[str] = None,
    ) -> List[str]:
        """
        Stores records to the payload logging system.

        Args:
            request_records (List[Dict]): A list of records to be logged, where each record is represented as a dictionary.
            subscription_id (str, optional): The subscription ID associated with the records being logged.
                Falls back to the monitor's own `subscription_id` when omitted or empty.

        Returns:
            List[str]: One identifier per stored record, formed from the
            generated `scoring_id` with a "-1" suffix.

        Raises:
            ValueError: If no subscription ID is available.

        Example:
            ```python
            wxgov_client.store_payload_records(
                request_records=[
                    {
                        "context1": "value_context1",
                        "context2": "value_context1",
                        "input_query": "What's Beekeeper Framework?",
                        "generated_text": "Beekeeper is a data framework to make AI easier to work with.",
                        "input_token_count": 25,
                        "generated_token_count": 150,
                    }
                ],
                subscription_id="5d62977c-a53d-4b6d-bda1-7b79b3b9d1a0",
            )
            ```
        """
        from ibm_watson_openscale.supporting_classes.enums import (
            DataSetTypes,
            TargetTypes,
        )

        # Expected behavior: Prefer using fn `subscription_id`.
        # Fallback to `self.subscription_id` if `subscription_id` None or empty.
        _subscription_id = subscription_id or self.subscription_id

        if _subscription_id is None or _subscription_id == "":
            raise ValueError(
                "Unexpected value for 'subscription_id': Cannot be None or empty string."
            )

        wos_client = self._get_wos_client()

        subscription_details = wos_client.subscriptions.get(
            _subscription_id,
        ).result
        subscription_details = json.loads(str(subscription_details))

        feature_fields = subscription_details["entity"]["asset_properties"][
            "feature_fields"
        ]

        payload_data_sets = wos_client.data_sets.list(
            type=DataSetTypes.PAYLOAD_LOGGING,
            target_target_id=_subscription_id,
            target_target_type=TargetTypes.SUBSCRIPTION,
        ).result.data_sets
        # The first payload-logging data set of the subscription is used.
        payload_data_set_id = payload_data_sets[0].metadata.id

        payload_data = _convert_payload_format(request_records, feature_fields)

        suppress_output(
            wos_client.data_sets.store_records,
            data_set_id=payload_data_set_id,
            request_body=payload_data,
            background_mode=False,
        )

        return [data["scoring_id"] + "-1" for data in payload_data]

    def __call__(self, payload: PayloadRecord) -> None:
        """Stores `payload`, merged with any variables extracted from the prompt template."""
        # Initialised up front: without a prompt template the name was
        # previously unbound and the check below raised UnboundLocalError.
        template_vars = None
        if self.prompt_template:
            template_vars = extract_template_vars(
                self.prompt_template.template, payload.input_text
            )

        record = payload.model_dump()
        if template_vars:
            record = {**record, **template_vars}

        self.store_payload_records([record])
756
+
757
+
758
+ class WatsonxPromptMonitor(PromptMonitor):
759
+ """
760
+ Provides functionality to interact with IBM watsonx.governance for monitoring IBM watsonx.ai LLMs.
761
+
762
+ Note:
763
+ One of the following parameters is required to create a prompt monitor:
764
+ `project_id` or `space_id`, but not both.
765
+
766
+ Attributes:
767
+ api_key (str): The API key for IBM watsonx.governance.
768
+ space_id (str, optional): The space ID in watsonx.governance.
769
+ project_id (str, optional): The project ID in watsonx.governance.
770
+ region (str, optional): The region where watsonx.governance is hosted when using IBM Cloud.
771
+ Defaults to `us-south`.
772
+ cpd_creds (CloudPakforDataCredentials, optional): The Cloud Pak for Data environment credentials.
773
+ subscription_id (str, optional): The subscription ID associated with the records being logged.
774
+
775
+ Example:
776
+ ```python
777
+ from beekeeper.monitors.watsonx import (
778
+ WatsonxPromptMonitor,
779
+ CloudPakforDataCredentials,
780
+ )
781
+
782
+ # watsonx.governance (IBM Cloud)
783
+ wxgov_client = WatsonxPromptMonitor(api_key="API_KEY", space_id="SPACE_ID")
784
+
785
+ # watsonx.governance (CP4D)
786
+ cpd_creds = CloudPakforDataCredentials(
787
+ url="CPD_URL",
788
+ username="USERNAME",
789
+ password="PASSWORD",
790
+ version="5.0",
791
+ instance_id="openshift",
792
+ )
793
+
794
+ wxgov_client = WatsonxPromptMonitor(space_id="SPACE_ID", cpd_creds=cpd_creds)
795
+ ```
796
+ """
797
+
798
def __init__(
    self,
    api_key: Optional[str] = None,
    space_id: Optional[str] = None,
    project_id: Optional[str] = None,
    region: Literal["us-south", "eu-de", "au-syd"] = "us-south",
    cpd_creds: Optional[Union[CloudPakforDataCredentials, Dict]] = None,
    subscription_id: Optional[str] = None,
    **kwargs,
) -> None:
    """
    Stores the connection settings and derives per-service CP4D credentials.

    Args:
        api_key (str, optional): The API key for IBM watsonx.governance.
        space_id (str, optional): The space ID in watsonx.governance.
        project_id (str, optional): The project ID in watsonx.governance.
        region (str, optional): The IBM Cloud region. Defaults to `us-south`.
        cpd_creds (CloudPakforDataCredentials | Dict, optional): The Cloud Pak for Data
            credentials, as a model instance or an equivalent dictionary.
        subscription_id (str, optional): The subscription ID associated with the records being logged.
        **kwargs: Forwarded to the parent class initializer.

    Raises:
        KeyError: If `cpd_creds` is given without a `url` entry.
    """
    # Imported only to fail early when an SDK dependency is not installed.
    import ibm_aigov_facts_client  # noqa: F401
    import ibm_cloud_sdk_core.authenticators  # noqa: F401
    import ibm_watson_openscale  # noqa: F401
    import ibm_watsonx_ai  # noqa: F401

    super().__init__(**kwargs)

    self.space_id = space_id
    self.project_id = project_id
    self.region = region
    self.subscription_id = subscription_id
    self._api_key = api_key
    self._wos_client = None

    # A space takes precedence over a project when choosing the container.
    self._container_id = space_id if space_id else project_id
    self._container_type = "space" if space_id else "project"
    self._deployment_stage = "production" if space_id else "development"

    # Always defined (falsy when unused) so the attributes exist even
    # without CP4D credentials.
    self._wos_cpd_creds = None
    self._fact_cpd_creds = None
    self._wml_cpd_creds = None

    if cpd_creds:
        # The signature accepts a plain dict as well as the credentials
        # model; only the model has `to_dict()`.
        creds_dict = (
            cpd_creds if isinstance(cpd_creds, dict) else cpd_creds.to_dict()
        )

        self._wos_cpd_creds = _filter_dict(
            creds_dict,
            ["username", "password", "api_key", "disable_ssl_verification"],
            ["url"],
        )
        self._fact_cpd_creds = _filter_dict(
            creds_dict,
            ["username", "password", "api_key", "bedrock_url"],
            ["url"],
        )
        # The factsheets config expects the host under `service_url`.
        self._fact_cpd_creds["service_url"] = self._fact_cpd_creds.pop("url")
        self._wml_cpd_creds = _filter_dict(
            creds_dict,
            [
                "username",
                "password",
                "api_key",
                "instance_id",
                "version",
                "bedrock_url",
            ],
            ["url"],
        )
850
+
851
def _create_prompt_template(
    self,
    prompt_template_details: Dict,
    asset_details: Dict,
) -> str:
    """
    Creates a freeform prompt template asset and returns its asset ID.

    Args:
        prompt_template_details (Dict): Keyword arguments for `PromptTemplate`.
        asset_details (Dict): Keyword arguments forwarded to `assets.create_prompt`.

    Returns:
        str: The `asset_id` of the created prompt template asset.
    """
    from ibm_aigov_facts_client import (
        AIGovFactsClient,
        CloudPakforDataConfig,
        PromptTemplate,
    )

    try:
        # Arguments shared by both connection modes.
        client_kwargs = {
            "container_id": self._container_id,
            "container_type": self._container_type,
            "disable_tracing": True,
        }

        fact_creds = getattr(self, "_fact_cpd_creds", None)
        if fact_creds:
            # Cloud Pak for Data: authenticate through the CP4D config.
            client_kwargs["cloud_pak_for_data_configs"] = CloudPakforDataConfig(
                **fact_creds
            )
        else:
            # IBM Cloud: authenticate with the API key against the region.
            client_kwargs["api_key"] = self._api_key
            client_kwargs["region"] = REGIONS_URL[self.region]["factsheet"]

        facts_client = AIGovFactsClient(**client_kwargs)

    except Exception as e:
        logging.error(
            f"Error connecting to IBM watsonx.governance (factsheets): {e}",
        )
        raise

    template = PromptTemplate(**prompt_template_details)
    prompt_asset = facts_client.assets.create_prompt(
        **asset_details,
        input_mode="freeform",
        prompt_details=template,
    )

    return prompt_asset.to_dict()["asset_id"]
895
+
896
def _create_deployment_pta(self, asset_id: str, name: str, model_id: str) -> str:
    """
    Creates a deployment for a prompt template asset in the monitor's space.

    Args:
        asset_id (str): The ID of the prompt template asset to deploy.
        name (str): Base name; the deployment is named "<name> deployment".
        model_id (str): The base model ID recorded on the deployment.

    Returns:
        str: The ID of the created deployment.
    """
    from ibm_watsonx_ai import APIClient, Credentials  # type: ignore

    try:
        cpd_settings = getattr(self, "_wml_cpd_creds", None)
        if cpd_settings:
            # Cloud Pak for Data credentials.
            credentials = Credentials(**cpd_settings)
        else:
            # IBM Cloud credentials for the configured region.
            credentials = Credentials(
                url=REGIONS_URL[self.region]["wml"],
                api_key=self._api_key,
            )

        runtime_client = APIClient(credentials)
        runtime_client.set.default_space(self.space_id)

    except Exception as e:
        logging.error(f"Error connecting to IBM watsonx.ai Runtime: {e}")
        raise

    meta_names = runtime_client.deployments.ConfigurationMetaNames
    deployment_props = {
        meta_names.PROMPT_TEMPLATE: {"id": asset_id},
        meta_names.FOUNDATION_MODEL: {},
        meta_names.NAME: name + " " + "deployment",
        meta_names.BASE_MODEL_ID: model_id,
    }

    deployment = runtime_client.deployments.create(asset_id, deployment_props)

    return runtime_client.deployments.get_uid(deployment)
933
+
934
def add_prompt_monitor(
    self,
    name: str,
    model_id: str,
    task_id: Literal[
        "extraction",
        "generation",
        "question_answering",
        "retrieval_augmented_generation",
        "summarization",
    ],
    description: str = "",
    model_parameters: Dict = None,
    prompt_variables: List[str] = None,
    locale: str = "en",
    input_text: str = None,
    context_fields: List[str] = None,
    question_field: str = None,
) -> Dict:
    """
    Creates an IBM Prompt Template Asset and sets up a monitor for the given prompt template asset.

    Args:
        name (str): The name of the Prompt Template Asset.
        model_id (str): The ID of the model associated with the prompt.
        task_id (str): The task identifier.
        description (str, optional): A description of the Prompt Template Asset.
        model_parameters (Dict, optional): A dictionary of model parameters and their respective values.
        prompt_variables (List[str], optional): A list of names of the prompt input variables.
        locale (str, optional): Locale code for the input/output language. eg. "en", "pt", "es".
        input_text (str, optional): The input text for the prompt.
        context_fields (List[str], optional): A list of fields that will provide context to the prompt.
            Applicable only for the `retrieval_augmented_generation` task type.
        question_field (str, optional): The field containing the question to be answered.
            Applicable only for the `retrieval_augmented_generation` task type.

    Returns:
        Dict: The created identifiers under the keys `prompt_template_asset_id`,
            `deployment_id` (`None` unless the container type is "space") and `subscription_id`.

    Raises:
        ValueError: If neither or both of `project_id` and `space_id` are set, if a
            `retrieval_augmented_generation` task is missing `context_fields` or
            `question_field`, or if no data mart is available when an instance
            mapping has to be added.

    Example:
        ```python
        wxgov_client.add_prompt_monitor(
            name="IBM prompt template",
            model_id="ibm/granite-3-2b-instruct",
            task_id="retrieval_augmented_generation",
            prompt_variables=["context1", "context2", "input_query"],
            input_text="Prompt text to be given",
            context_fields=["context1", "context2"],
            question_field="input_query",
        )
        ```
    """
    # Exactly one of project_id / space_id must be configured; report each
    # misconfiguration with its own message.
    has_project = bool(self.project_id)
    has_space = bool(self.space_id)
    if not has_project and not has_space:
        raise ValueError(
            "Invalid configuration: Neither was provided: please set either 'project_id' or 'space_id'."
        )
    if has_project and has_space:
        raise ValueError(
            "Invalid configuration: Both were provided: 'project_id' and 'space_id' cannot be set at the same time."
        )

    if task_id == "retrieval_augmented_generation" and not (
        context_fields and question_field
    ):
        raise ValueError(
            "For 'retrieval_augmented_generation' task, requires non-empty 'context_fields' and 'question_field'."
        )

    # Explicit mapping of the arguments forwarded to the factsheets API
    # (`input_text` is sent under the key `input`).
    prompt_metadata = {
        "name": name,
        "model_id": model_id,
        "task_id": task_id,
        "description": description,
        "model_parameters": model_parameters,
        "input": input_text,
        # The API takes a mapping of variable name -> default value. An unset list
        # stays None so it is handled like the other unset optional fields.
        "prompt_variables": (
            dict.fromkeys(prompt_variables, "")
            if prompt_variables is not None
            else None
        ),
    }

    from ibm_watson_openscale import APIClient as WosAPIClient  # type: ignore

    if not self._wos_client:
        try:
            cpd_creds = getattr(self, "_wos_cpd_creds", None)
            if cpd_creds:
                from ibm_cloud_sdk_core.authenticators import (
                    CloudPakForDataAuthenticator,  # type: ignore
                )

                # The authenticator's keyword is `apikey`; the stored credentials
                # use `api_key`, so rename it before unpacking.
                auth_kwargs = dict(cpd_creds)
                if "api_key" in auth_kwargs:
                    auth_kwargs["apikey"] = auth_kwargs.pop("api_key")

                authenticator = CloudPakForDataAuthenticator(**auth_kwargs)
                service_url = cpd_creds["url"]
            else:
                from ibm_cloud_sdk_core.authenticators import (
                    IAMAuthenticator,  # type: ignore
                )

                authenticator = IAMAuthenticator(apikey=self._api_key)
                service_url = REGIONS_URL[self.region]["wos"]

            self._wos_client = WosAPIClient(
                authenticator=authenticator,
                service_url=service_url,
            )
        except Exception as e:
            logging.error(
                f"Error connecting to IBM watsonx.governance (openscale): {e}",
            )
            raise

    prompt_details = _filter_dict(
        prompt_metadata,
        ["prompt_variables", "input", "model_parameters"],
    )

    asset_details = _filter_dict(
        prompt_metadata,
        ["description"],
        ["name", "model_id", "task_id"],
    )

    pta_id = suppress_output(
        self._create_prompt_template, prompt_details, asset_details
    )

    # A deployment is only created for space containers.
    deployment_id = None
    if self._container_type == "space":
        deployment_id = suppress_output(
            self._create_deployment_pta, pta_id, name, model_id
        )

    monitors = {
        "generative_ai_quality": {
            "parameters": {"min_sample_size": 10, "metrics_configuration": {}},
        },
    }

    # The setup is attempted at most twice: a first failure caused by a missing
    # user entitlement (HTTP 403) triggers an instance mapping and one retry.
    mapping_attempted = False
    while True:
        try:
            generative_ai_monitor_details = suppress_output(
                self._wos_client.wos.execute_prompt_setup,
                prompt_template_asset_id=pta_id,
                space_id=self.space_id,
                project_id=self.project_id,
                deployment_id=deployment_id,
                label_column="reference_output",
                context_fields=context_fields,
                question_field=question_field,
                operational_space_id=self._deployment_stage,
                problem_type=task_id,
                data_input_locale=[locale],
                generated_output_locale=[locale],
                input_data_type="unstructured_text",
                supporting_monitors=monitors,
                background_mode=False,
            ).result
            break
        except Exception as e:
            # Not every exception carries `code` / `message`; read them defensively
            # so the original error is never masked by an AttributeError.
            error_message = str(getattr(e, "message", "") or "")
            entitlement_missing = (
                getattr(e, "code", None) == 403
                and "The user entitlement does not exist" in error_message
            )
            if mapping_attempted or not entitlement_missing:
                raise

            mapping_attempted = True

            data_marts = self._wos_client.data_marts.list().result
            if not data_marts.data_marts:
                raise ValueError(
                    "Error retrieving IBM watsonx.governance (openscale) data mart. "
                    "Make sure the data mart is configured.",
                ) from e

            data_mart_id = data_marts.data_marts[0].metadata.id

            self._wos_client.wos.add_instance_mapping(
                service_instance_id=data_mart_id,
                space_id=self.space_id,
                project_id=self.project_id,
            )

    generative_ai_monitor_details = generative_ai_monitor_details._to_dict()

    return {
        "prompt_template_asset_id": pta_id,
        "deployment_id": deployment_id,
        "subscription_id": generative_ai_monitor_details["subscription_id"],
    }
1130
+
1131
def store_payload_records(
    self,
    request_records: List[Dict],
    subscription_id: str = None,
) -> List[str]:
    """
    Stores records to the payload logging system.

    Args:
        request_records (List[Dict]): A list of records to be logged. Each record is represented as a dictionary.
        subscription_id (str, optional): The subscription ID associated with the records being logged.
            Falls back to `self.subscription_id` when omitted or empty.

    Returns:
        List[str]: One entry per stored record: the record's `scoring_id` with "-1" appended.

    Raises:
        ValueError: If no subscription ID is available, or if the subscription has
            no payload logging data set.

    Example:
        ```python
        wxgov_client.store_payload_records(
            request_records=[
                {
                    "context1": "value_context1",
                    "context2": "value_context1",
                    "input_query": "What's Beekeeper Framework?",
                    "generated_text": "Beekeeper is a data framework to make AI easier to work with.",
                    "input_token_count": 25,
                    "generated_token_count": 150,
                }
            ],
            subscription_id="5d62977c-a53d-4b6d-bda1-7b79b3b9d1a0",
        )
        ```
    """
    # Expected behavior: Prefer using fn `subscription_id`.
    # Fallback to `self.subscription_id` if `subscription_id` None or empty.
    # Validated first so a bad call fails before any SDK import or network access.
    _subscription_id = subscription_id or self.subscription_id

    if not _subscription_id:
        raise ValueError(
            "Unexpected value for 'subscription_id': Cannot be None or empty string."
        )

    from ibm_watson_openscale import APIClient as WosAPIClient
    from ibm_watson_openscale.supporting_classes.enums import (
        DataSetTypes,
        TargetTypes,
    )

    if not self._wos_client:
        try:
            cpd_creds = getattr(self, "_wos_cpd_creds", None)
            if cpd_creds:
                from ibm_cloud_sdk_core.authenticators import (
                    CloudPakForDataAuthenticator,  # type: ignore
                )

                # The authenticator's keyword is `apikey`; the stored credentials
                # use `api_key`, so rename it before unpacking.
                auth_kwargs = dict(cpd_creds)
                if "api_key" in auth_kwargs:
                    auth_kwargs["apikey"] = auth_kwargs.pop("api_key")

                authenticator = CloudPakForDataAuthenticator(**auth_kwargs)
                service_url = cpd_creds["url"]
            else:
                from ibm_cloud_sdk_core.authenticators import (
                    IAMAuthenticator,  # type: ignore
                )

                authenticator = IAMAuthenticator(apikey=self._api_key)
                service_url = REGIONS_URL[self.region]["wos"]

            self._wos_client = WosAPIClient(
                authenticator=authenticator,
                service_url=service_url,
            )
        except Exception as e:
            logging.error(
                f"Error connecting to IBM watsonx.governance (openscale): {e}",
            )
            raise

    subscription_details = self._wos_client.subscriptions.get(
        _subscription_id,
    ).result
    # The SDK result is converted to plain dicts through its string form.
    subscription_details = json.loads(str(subscription_details))

    feature_fields = subscription_details["entity"]["asset_properties"][
        "feature_fields"
    ]

    payload_data_sets = self._wos_client.data_sets.list(
        type=DataSetTypes.PAYLOAD_LOGGING,
        target_target_id=_subscription_id,
        target_target_type=TargetTypes.SUBSCRIPTION,
    ).result.data_sets

    # Fail with an explicit message instead of an IndexError on an empty listing.
    if not payload_data_sets:
        raise ValueError(
            f"No payload logging data set found for subscription '{_subscription_id}'."
        )

    payload_data_set_id = payload_data_sets[0].metadata.id

    payload_data = _convert_payload_format(request_records, feature_fields)

    suppress_output(
        self._wos_client.data_sets.store_records,
        data_set_id=payload_data_set_id,
        request_body=payload_data,
        background_mode=False,
    )

    # NOTE(review): the "-1" suffix presumably mirrors the record IDs assigned by
    # the service - confirm against the payload logging API.
    return [data["scoring_id"] + "-1" for data in payload_data]
1236
+
1237
def __call__(self, payload: PayloadRecord) -> None:
    """
    Stores a single payload record, enriched with the prompt template variables.

    The variables are extracted by matching the configured prompt template against
    `payload.input_text`; when any are found they are merged over the dumped payload.

    Args:
        payload (PayloadRecord): The record to be logged.
    """
    # NOTE(review): nothing is stored when no prompt template is configured -
    # confirm this is the intended behavior.
    if not self.prompt_template:
        return

    extracted_vars = extract_template_vars(
        self.prompt_template.template, payload.input_text
    )

    record = payload.model_dump()
    if extracted_vars:
        # Extracted variables win over payload fields with the same name.
        record = {**record, **extracted_vars}

    self.store_payload_records([record])
1247
+
1248
+
1249
+ # ===== Supporting Classes =====
1250
class WatsonxLocalMetric(BaseModel):
    """
    Describes one metric column of an IBM watsonx.governance local (transaction-level) monitor.

    Attributes:
        name (str): The name of the metric.
        data_type (str): The data type of the metric. Currently supports "string", "integer", "double", and "timestamp".
        nullable (bool, optional): Indicates whether the metric can be null. Defaults to `True`.

    Example:
        ```python
        from beekeeper.monitors.watsonx import WatsonxLocalMetric

        WatsonxLocalMetric(name="context_quality", data_type="double")
        ```
    """

    name: str
    data_type: Literal["string", "integer", "double", "timestamp"]
    nullable: bool = True

    def to_dict(self) -> Dict:
        """Returns the metric as a dict with the keys `name`, `type` and `nullable`."""
        return dict(name=self.name, type=self.data_type, nullable=self.nullable)
1273
+
1274
+
1275
class WatsonxMetricThreshold(BaseModel):
    """
    Describes a threshold attached to an IBM watsonx.governance metric.

    Attributes:
        threshold_type (str): The threshold type. Can be either `lower_limit` or `upper_limit`.
        default_value (float): The metric threshold value. Defaults to `None`.

    Example:
        ```python
        from beekeeper.monitors.watsonx import WatsonxMetricThreshold

        WatsonxMetricThreshold(threshold_type="lower_limit", default_value=0.8)
        ```
    """

    threshold_type: Literal["lower_limit", "upper_limit"]
    default_value: float = None

    def to_dict(self) -> Dict:
        """Returns the threshold as a dict with the keys `type` and `default`."""
        return dict(type=self.threshold_type, default=self.default_value)
1296
+
1297
+
1298
class WatsonxMetric(BaseModel):
    """
    Describes a metric of an IBM watsonx.governance global monitor.

    Attributes:
        name (str): The name of the metric.
        applies_to (List[str]): A list of task types that the metric applies to. Currently supports:
            "summarization", "generation", "question_answering", "extraction", and "retrieval_augmented_generation".
        thresholds (List[WatsonxMetricThreshold], optional): A list of metric thresholds associated
            with the metric. Defaults to `None`.

    Example:
        ```python
        from beekeeper.monitors.watsonx import (
            WatsonxMetric,
            WatsonxMetricThreshold,
        )

        WatsonxMetric(
            name="context_quality",
            applies_to=["retrieval_augmented_generation", "summarization"],
            thresholds=[
                WatsonxMetricThreshold(threshold_type="lower_limit", default_value=0.75)
            ],
        )
        ```
    """

    name: str
    applies_to: List[
        Literal[
            "summarization",
            "generation",
            "question_answering",
            "extraction",
            "retrieval_augmented_generation",
        ]
    ]
    thresholds: Optional[List[WatsonxMetricThreshold]] = None

    def to_dict(self) -> Dict:
        """
        Returns the metric as keyword arguments built from OpenScale SDK objects.

        The result always holds `name` and `applies_to`; `thresholds` is only
        present when thresholds were configured.
        """
        from ibm_watson_openscale.base_classes.watson_open_scale_v2 import (
            ApplicabilitySelection,
            MetricThreshold,
        )

        definition = {
            "name": self.name,
            "applies_to": ApplicabilitySelection(problem_type=self.applies_to),
        }

        if self.thresholds is None:
            return definition

        definition["thresholds"] = [
            MetricThreshold(**limit.to_dict()) for limit in self.thresholds
        ]
        return definition
1354
+
1355
+
1356
+ # ===== Metric Classes =====
1357
+ class WatsonxCustomMetric:
1358
+ """
1359
+ Provides functionality to set up a custom metric to measure your model's performance with IBM watsonx.governance.
1360
+
1361
+ Attributes:
1362
+ api_key (str): The API key for IBM watsonx.governance.
1363
+ region (str, optional): The region where IBM watsonx.governance is hosted when using IBM Cloud.
1364
+ Defaults to `us-south`.
1365
+ cpd_creds (CloudPakforDataCredentials, optional): IBM Cloud Pak for Data environment credentials.
1366
+
1367
+ Example:
1368
+ ```python
1369
+ from beekeeper.monitors.watsonx import (
1370
+ WatsonxCustomMetric,
1371
+ CloudPakforDataCredentials,
1372
+ )
1373
+
1374
+ # watsonx.governance (IBM Cloud)
1375
+ wxgov_client = WatsonxCustomMetric(api_key="API_KEY")
1376
+
1377
+ # watsonx.governance (CP4D)
1378
+ cpd_creds = CloudPakforDataCredentials(
1379
+ url="CPD_URL",
1380
+ username="USERNAME",
1381
+ password="PASSWORD",
1382
+ version="5.0",
1383
+ instance_id="openshift",
1384
+ )
1385
+
1386
+ wxgov_client = WatsonxCustomMetric(cpd_creds=cpd_creds)
1387
+ ```
1388
+ """
1389
+
1390
def __init__(
    self,
    api_key: str = None,
    region: Literal["us-south", "eu-de", "au-syd"] = "us-south",
    cpd_creds: CloudPakforDataCredentials | Dict = None,
) -> None:
    """
    Builds the IBM watsonx.governance (OpenScale) client used by every method.

    Args:
        api_key (str, optional): The API key used for IBM Cloud authentication.
        region (str, optional): The IBM Cloud region of the service. Defaults to `us-south`.
        cpd_creds (CloudPakforDataCredentials | Dict, optional): IBM Cloud Pak for Data
            credentials, either as a credentials object or as a plain dict. When
            given, they are used instead of `api_key` and `region`.

    Raises:
        Exception: Any error raised while creating the client is logged and re-raised.
    """
    from ibm_watson_openscale import APIClient as WosAPIClient  # type: ignore

    self.region = region
    self._api_key = api_key
    self._wos_client = None
    # Always defined so later code can test it without `hasattr`.
    self._wos_cpd_creds = None

    if cpd_creds:
        # The signature accepts a plain dict as well as a credentials object.
        raw_creds = cpd_creds if isinstance(cpd_creds, dict) else cpd_creds.to_dict()
        self._wos_cpd_creds = _filter_dict(
            raw_creds,
            ["username", "password", "api_key", "disable_ssl_verification"],
            ["url"],
        )

    try:
        if self._wos_cpd_creds:
            from ibm_cloud_sdk_core.authenticators import (
                CloudPakForDataAuthenticator,  # type: ignore
            )

            # The authenticator's keyword is `apikey`; the stored credentials
            # use `api_key`, so rename it before unpacking.
            auth_kwargs = dict(self._wos_cpd_creds)
            if "api_key" in auth_kwargs:
                auth_kwargs["apikey"] = auth_kwargs.pop("api_key")

            authenticator = CloudPakForDataAuthenticator(**auth_kwargs)
            service_url = self._wos_cpd_creds["url"]
        else:
            from ibm_cloud_sdk_core.authenticators import (
                IAMAuthenticator,  # type: ignore
            )

            authenticator = IAMAuthenticator(apikey=self._api_key)
            service_url = REGIONS_URL[self.region]["wos"]

        self._wos_client = WosAPIClient(
            authenticator=authenticator,
            service_url=service_url,
        )
    except Exception as e:
        logging.error(
            f"Error connecting to IBM watsonx.governance (openscale): {e}",
        )
        raise
1440
+
1441
def _add_integrated_system(
    self,
    credentials: IntegratedSystemCredentials,
    name: str,
    endpoint: str,
) -> str:
    """
    Registers an integrated system of type `custom_metrics_provider`.

    Args:
        credentials (IntegratedSystemCredentials): Credentials sent as `credentials.to_dict()`.
        name (str): Used both as the system name and as the connection display name.
        endpoint (str): The endpoint stored in the connection settings.

    Returns:
        str: The metadata ID of the created integrated system.
    """
    connection_settings = {"display_name": name, "endpoint": endpoint}

    response = self._wos_client.integrated_systems.add(
        name=name,
        description="Integrated system created by Beekeeper.",
        type="custom_metrics_provider",
        credentials=credentials.to_dict(),
        connection=connection_settings,
    )

    return response.result.metadata.id
1456
+
1457
def _add_monitor_definitions(
    self,
    name: str,
    metrics: List[WatsonxMetric],
    schedule: bool,
):
    """
    Creates a custom monitor definition for unstructured-text inputs.

    Args:
        name (str): The name of the monitor definition.
        metrics (List[WatsonxMetric]): The metrics to register; each one is converted
            through its `to_dict()` output.
        schedule (bool): When true, the definition gets a `custom_metrics_provider`
            runtime and an hourly schedule that starts after a 30-minute delay.
            Otherwise both are left as `None`.

    Returns:
        The metadata ID of the created monitor definition.
    """
    from ibm_watson_openscale.base_classes.watson_open_scale_v2 import (
        ApplicabilitySelection,
        MonitorInstanceSchedule,
        MonitorMetricRequest,
        MonitorRuntime,
        ScheduleStartTime,
    )

    metric_requests = [MonitorMetricRequest(**item.to_dict()) for item in metrics]

    runtime = MonitorRuntime(type="custom_metrics_provider") if schedule else None
    run_schedule = None
    if schedule:
        first_run = ScheduleStartTime(
            type="relative",
            delay_unit="minute",
            delay=30,
        )
        run_schedule = MonitorInstanceSchedule(
            repeat_interval=1,
            repeat_unit="hour",
            start_time=first_run,
        )

    response = self._wos_client.monitor_definitions.add(
        name=name,
        metrics=metric_requests,
        tags=[],
        schedule=run_schedule,
        applies_to=ApplicabilitySelection(input_data_type=["unstructured_text"]),
        monitor_runtime=runtime,
        background_mode=False,
    )

    return response.result.metadata.id
1498
+
1499
+ def _get_monitor_instance(self, subscription_id: str, monitor_definition_id: str):
1500
+ monitor_instances = self._wos_client.monitor_instances.list(
1501
+ monitor_definition_id=monitor_definition_id,
1502
+ target_target_id=subscription_id,
1503
+ ).result.monitor_instances
1504
+
1505
+ if len(monitor_instances) == 1:
1506
+ return monitor_instances[0]
1507
+ else:
1508
+ return None
1509
+
1510
+ def _update_monitor_instance(
1511
+ self,
1512
+ integrated_system_id: str,
1513
+ custom_monitor_id: str,
1514
+ ):
1515
+ payload = [
1516
+ {
1517
+ "op": "replace",
1518
+ "path": "/parameters",
1519
+ "value": {
1520
+ "custom_metrics_provider_id": integrated_system_id,
1521
+ "custom_metrics_wait_time": 60,
1522
+ "enable_custom_metric_runs": True,
1523
+ },
1524
+ },
1525
+ ]
1526
+
1527
+ return self._wos_client.monitor_instances.update(
1528
+ custom_monitor_id,
1529
+ payload,
1530
+ update_metadata_only=True,
1531
+ ).result
1532
+
1533
+ def _get_patch_request_field(
1534
+ self,
1535
+ field_path: str,
1536
+ field_value: Any,
1537
+ op_name: str = "replace",
1538
+ ) -> Dict:
1539
+ return {"op": op_name, "path": field_path, "value": field_value}
1540
+
1541
+ def _get_dataset_id(
1542
+ self,
1543
+ subscription_id: str,
1544
+ data_set_type: Literal["feedback", "payload_logging"],
1545
+ ) -> str:
1546
+ data_sets = self._wos_client.data_sets.list(
1547
+ target_target_id=subscription_id,
1548
+ type=data_set_type,
1549
+ ).result.data_sets
1550
+ data_set_id = None
1551
+ if len(data_sets) > 0:
1552
+ data_set_id = data_sets[0].metadata.id
1553
+ return data_set_id
1554
+
1555
+ def _get_dataset_data(self, data_set_id: str):
1556
+ json_data = self._wos_client.data_sets.get_list_of_records(
1557
+ data_set_id=data_set_id,
1558
+ format="list",
1559
+ ).result
1560
+
1561
+ if not json_data.get("records"):
1562
+ return None
1563
+
1564
+ return json_data["records"][0]
1565
+
1566
+ def _get_existing_data_mart(self):
1567
+ data_marts = self._wos_client.data_marts.list().result.data_marts
1568
+ if len(data_marts) == 0:
1569
+ raise Exception(
1570
+ "No data marts found. Please ensure at least one data mart is available.",
1571
+ )
1572
+
1573
+ return data_marts[0].metadata.id
1574
+
1575
+ # ===== Global Custom Metrics =====
1576
def add_metric_definition(
    self,
    name: str,
    metrics: List[WatsonxMetric],
    integrated_system_url: str,
    integrated_system_credentials: IntegratedSystemCredentials,
    schedule: bool = False,
):
    """
    Creates a custom monitor definition for IBM watsonx.governance.

    This must be done before using custom metrics. The call registers an integrated
    system for the external metric provider, creates the monitor definition, and
    then links the definition to the integrated system.

    Args:
        name (str): The name of the custom metric group.
        metrics (List[WatsonxMetric]): A list of metrics to be measured.
        integrated_system_url (str): The URL of the external metric provider.
        integrated_system_credentials (IntegratedSystemCredentials): The credentials for the integrated system.
        schedule (bool, optional): Enable or disable the scheduler. Defaults to `False`.

    Returns:
        Dict: The created IDs under the keys `integrated_system_id` and `monitor_definition_id`.

    Example:
        ```python
        from beekeeper.monitors.watsonx import (
            WatsonxMetric,
            IntegratedSystemCredentials,
            WatsonxMetricThreshold,
        )

        wxgov_client.add_metric_definition(
            name="Custom Metric - Custom LLM Quality",
            metrics=[
                WatsonxMetric(
                    name="context_quality",
                    applies_to=[
                        "retrieval_augmented_generation",
                        "summarization",
                    ],
                    thresholds=[
                        WatsonxMetricThreshold(
                            threshold_type="lower_limit", default_value=0.75
                        )
                    ],
                )
            ],
            integrated_system_url="IS_URL",  # URL to the endpoint computing the metric
            integrated_system_credentials=IntegratedSystemCredentials(
                auth_type="basic", username="USERNAME", password="PASSWORD"
            ),
        )
        ```
    """
    system_id = self._add_integrated_system(
        integrated_system_credentials,
        name,
        integrated_system_url,
    )

    definition_id = suppress_output(
        self._add_monitor_definitions,
        name,
        metrics,
        schedule,
    )

    # Associate the external monitor with the integrated system
    link_patch = [
        self._get_patch_request_field(
            "/parameters",
            {"monitor_definition_ids": [definition_id]},
            "add",
        ),
    ]
    self._wos_client.integrated_systems.update(system_id, link_patch)

    return {
        "integrated_system_id": system_id,
        "monitor_definition_id": definition_id,
    }
1655
+
1656
def add_monitor_instance(
    self,
    integrated_system_id: str,
    monitor_definition_id: str,
    subscription_id: str,
):
    """
    Enables a custom monitor for the specified subscription and monitor definition.

    When the subscription has no monitor instance for the definition yet, a new one
    is created; when exactly one exists, its parameters are replaced instead.

    Args:
        integrated_system_id (str): The ID of the integrated system.
        monitor_definition_id (str): The ID of the custom monitor definition.
        subscription_id (str): The ID of the subscription to associate the monitor with.

    Returns:
        The details of the created or updated monitor instance, as returned by the service.

    Raises:
        Exception: If no data mart is available.

    Example:
        ```python
        wxgov_client.add_monitor_instance(
            integrated_system_id="019667ca-5687-7838-8d29-4ff70c2b36b0",
            monitor_definition_id="custom_llm_quality",
            subscription_id="0195e95d-03a4-7000-b954-b607db10fe9e",
        )
        ```
    """
    from ibm_watson_openscale.base_classes.watson_open_scale_v2 import Target

    # Shared lookup; raises when no data mart is available.
    data_mart_id = self._get_existing_data_mart()

    existing_monitor_instance = self._get_monitor_instance(
        subscription_id,
        monitor_definition_id,
    )

    if existing_monitor_instance is not None:
        return self._update_monitor_instance(
            integrated_system_id,
            existing_monitor_instance.metadata.id,
        )

    target = Target(target_type="subscription", target_id=subscription_id)

    parameters = {
        "custom_metrics_provider_id": integrated_system_id,
        "custom_metrics_wait_time": 60,
        "enable_custom_metric_runs": True,
    }

    return suppress_output(
        self._wos_client.monitor_instances.create,
        data_mart_id=data_mart_id,
        background_mode=False,
        monitor_definition_id=monitor_definition_id,
        target=target,
        parameters=parameters,
    ).result
1718
+
1719
def publish_metrics(
    self,
    monitor_instance_id: str,
    run_id: str,
    request_records: Dict[str, Union[float, int]],
):
    """
    Publishes computed custom metrics for a specific global monitor instance.

    The metrics are added as one measurement stamped with the current UTC time, and
    the monitoring run is then patched to the `finished` state with its completion time.

    Args:
        monitor_instance_id (str): The unique ID of the monitor instance.
        run_id (str): The ID of the monitor run that generated the metrics.
        request_records (Dict[str, float | int]): Dict mapping metric names to the values to be published.

    Returns:
        The `result` of the monitoring run update.

    Example:
        ```python
        wxgov_client.publish_metrics(
            monitor_instance_id="01966801-f9ee-7248-a706-41de00a8a998",
            run_id="RUN_ID",
            request_records={"context_quality": 0.914, "sensitivity": 0.85},
        )
        ```
    """
    from ibm_watson_openscale.base_classes.watson_open_scale_v2 import (
        MonitorMeasurementRequest,
        Runs,
    )

    timestamp_format = "%Y-%m-%dT%H:%M:%S.%fZ"

    measured_at = datetime.datetime.now(datetime.timezone.utc).strftime(
        timestamp_format,
    )
    measurement = MonitorMeasurementRequest(
        timestamp=measured_at,
        run_id=run_id,
        metrics=[request_records],
    )

    self._wos_client.monitor_instances.add_measurements(
        monitor_instance_id=monitor_instance_id,
        monitor_measurement_request=[measurement],
    ).result

    runs_api = Runs(watson_open_scale=self._wos_client)

    # The completion time is taken after the measurement was sent.
    completed_at = datetime.datetime.now(datetime.timezone.utc).strftime(
        timestamp_format,
    )
    run_patch = [
        self._get_patch_request_field("/status/state", "finished"),
        self._get_patch_request_field("/status/completed_at", completed_at),
    ]

    return runs_api.update(
        monitor_instance_id=monitor_instance_id,
        monitoring_run_id=run_id,
        json_patch_operation=run_patch,
    ).result
1777
+
1778
+ # ===== Local Custom Metrics =====
1779
def add_local_metric_definition(
    self,
    name: str,
    metrics: List[WatsonxLocalMetric],
    subscription_id: str,
) -> str:
    """
    Creates a custom metric definition to compute metrics at the local (transaction) level for IBM watsonx.governance.

    The definition is stored as a `custom` data set whose schema holds the fixed
    columns `scoring_id`, `run_id` and `computed_on`, followed by the given metrics.

    Args:
        name (str): The name of the custom transaction metric group.
        metrics (List[WatsonxLocalMetric]): A list of metrics to be monitored at the local (transaction) level.
        subscription_id (str): The IBM watsonx.governance subscription ID associated with the metric definition.

    Returns:
        str: The metadata ID of the created data set.

    Example:
        ```python
        from beekeeper.monitors.watsonx import WatsonxLocalMetric

        wxgov_client.add_local_metric_definition(
            name="Custom LLM Local Metric",
            subscription_id="019674ca-0c38-745f-8e9b-58546e95174e",
            metrics=[
                WatsonxLocalMetric(name="context_quality", data_type="double")
            ],
        )
        ```
    """
    from ibm_watson_openscale.base_classes.watson_open_scale_v2 import (
        LocationTableName,
        SparkStruct,
        SparkStructFieldPrimitive,
        Target,
    )

    subscription_target = Target(
        target_id=subscription_id, target_type="subscription"
    )
    data_mart_id = self._get_existing_data_mart()
    metric_columns = [
        SparkStructFieldPrimitive(**item.to_dict()) for item in metrics
    ]

    # Fixed bookkeeping columns as (name, nullable); all of them are strings.
    base_columns = [
        SparkStructFieldPrimitive(name=column, type="string", nullable=is_nullable)
        for column, is_nullable in (
            ("scoring_id", False),
            ("run_id", True),
            ("computed_on", False),
        )
    ]

    data_schema = SparkStruct(type="struct", fields=base_columns + metric_columns)

    # Table name: lower-cased name with spaces replaced by underscores, plus the
    # first 8 characters of a random UUID.
    table_name = name.lower().replace(" ", "_") + "_" + str(uuid.uuid4())[:8]

    created = self._wos_client.data_sets.add(
        target=subscription_target,
        name=name,
        type="custom",
        data_schema=data_schema,
        data_mart_id=data_mart_id,
        location=LocationTableName(table_name=table_name),
        background_mode=False,
    )

    return created.result.metadata.id
1850
+
1851
def publish_local_metrics(
    self,
    metric_instance_id: str,
    request_records: List[Dict],
):
    """
    Publishes computed custom metrics for a specific transaction record.

    Args:
        metric_instance_id (str): The unique ID of the custom transaction metric.
        request_records (List[Dict]): A list of dictionaries containing the records to be stored.

    Returns:
        The `result` of the store-records call.

    Example:
        ```python
        wxgov_client.publish_local_metrics(
            metric_instance_id="0196ad39-1b75-7e77-bddb-cc5393d575c2",
            request_records=[
                {
                    "scoring_id": "304a9270-44a1-4c4d-bfd4-f756541011f8",
                    "run_id": "RUN_ID",
                    "computed_on": "payload",
                    "context_quality": 0.786,
                }
            ],
        )
        ```
    """
    response = self._wos_client.data_sets.store_records(
        data_set_id=metric_instance_id,
        request_body=request_records,
    )
    return response.result
1882
+
1883
def list_local_metrics(
    self,
    metric_instance_id: str,
):
    """
    Lists records from a custom local metric definition.

    Args:
        metric_instance_id (str): The unique ID of the custom transaction metric.

    Returns:
        Whatever `_get_dataset_data` returns for the given ID.

    Example:
        ```python
        wxgov_client.list_local_metrics(
            metric_instance_id="0196ad47-c505-73c0-9d7b-91c082b697e3"
        )
        ```
    """
    stored_records = self._get_dataset_data(metric_instance_id)
    return stored_records