beekeeper-monitors-watsonx 1.0.5.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2002 @@
1
+ import datetime
2
+ import json
3
+ import logging
4
+ import os
5
+ import uuid
6
+ from typing import Any, Dict, List, Literal, Optional, Union
7
+
8
+ import certifi
9
+ from beekeeper.core.monitors import PromptMonitor
10
+ from beekeeper.core.monitors.types import PayloadRecord
11
+ from beekeeper.core.prompts.utils import extract_template_vars
12
+ from beekeeper.monitors.watsonx.instrumentation import suppress_output
13
+ from deprecated import deprecated
14
+ from pydantic.v1 import BaseModel
15
+
16
# Point `requests` at certifi's CA bundle, but only when the user has not
# already configured one: overwriting an existing REQUESTS_CA_BUNDLE at import
# time would silently discard a custom (e.g. corporate) trust store.
os.environ.setdefault("REQUESTS_CA_BUNDLE", certifi.where())

# Only surface errors from the watsonx.ai SDK loggers.
logging.getLogger("ibm_watsonx_ai.client").setLevel(logging.ERROR)
logging.getLogger("ibm_watsonx_ai.wml_resource").setLevel(logging.ERROR)

# Per-region settings for IBM Cloud:
#   "wml"       -> URL passed to the watsonx.ai `Credentials`
#   "wos"       -> service URL passed to the OpenScale `APIClient`
#   "factsheet" -> value passed as `region=` to `AIGovFactsClient`
#                  (None for us-south; presumably the client's default region
#                  -- TODO confirm against the factsheets client docs)
REGIONS_URL = {
    "us-south": {
        "wml": "https://us-south.ml.cloud.ibm.com",
        "wos": "https://api.aiopenscale.cloud.ibm.com",
        "factsheet": None,
    },
    "eu-de": {
        "wml": "https://eu-de.ml.cloud.ibm.com",
        "wos": "https://eu-de.api.aiopenscale.cloud.ibm.com",
        "factsheet": "frankfurt",
    },
    "au-syd": {
        "wml": "https://au-syd.ml.cloud.ibm.com",
        "wos": "https://au-syd.api.aiopenscale.cloud.ibm.com",
        "factsheet": "sydney",
    },
}
37
+
38
+
39
def _filter_dict(
    original_dict: Dict,
    optional_keys: List,
    required_keys: Optional[List] = None,
) -> Dict:
    """
    Filters a dictionary to keep only the specified keys and checks for required keys.

    Args:
        original_dict (Dict): The original dictionary.
        optional_keys (list): Keys to retain when they are present.
        required_keys (list, optional): Keys that must be present in the dictionary.
            Defaults to None (no required keys).

    Returns:
        Dict: A new dictionary holding the required and optional keys that exist in
        `original_dict` with a value other than None. Required keys come first,
        followed by optional keys, each in the order given.

    Raises:
        KeyError: If any of `required_keys` is absent from `original_dict`.
    """
    # A fresh list per call; a `[]` default would be shared between calls.
    required_keys = list(required_keys or [])

    # Ensure all required keys are in the source dict
    missing_keys = [key for key in required_keys if key not in original_dict]
    if missing_keys:
        raise KeyError(f"Missing required parameter: {missing_keys}")

    # dict.fromkeys de-duplicates like a set but keeps a deterministic order.
    keys_to_keep = dict.fromkeys([*required_keys, *optional_keys])

    # A key that is present but maps to None is dropped, same as an absent key.
    return {
        key: original_dict[key]
        for key in keys_to_keep
        if original_dict.get(key) is not None
    }
61
+
62
+
63
def _convert_payload_format(
    records: List[Dict],
    feature_fields: List[str],
) -> List[Dict]:
    """
    Converts flat records into the request/response structure used for payload logging.

    Args:
        records (List[Dict]): Flat records holding template variable values and,
            optionally, `generated_text`, `input_token_count`,
            `generated_token_count` and `response_time`.
        feature_fields (List[str]): Names of the fields copied into
            `request.parameters.template_variables`.

    Returns:
        List[Dict]: One payload record per input record, each with a freshly
        generated `scoring_id`.
    """
    response_fields = ("generated_text", "input_token_count", "generated_token_count")
    payload_data = []

    for record in records:
        # Template variables are always sent as strings; a missing field becomes "".
        template_variables = {
            field: str(record.get(field, "")) for field in feature_fields
        }

        # Only None counts as "missing". A plain truthiness test would also drop
        # legitimate values such as a token count of 0 or an empty generation.
        results = {
            field: record[field]
            for field in response_fields
            if record.get(field) is not None
        }

        pl_record = {
            "request": {"parameters": {"template_variables": template_variables}},
            "response": {"results": [results]},
            "scoring_id": str(uuid.uuid4()),
        }

        if "response_time" in record:
            pl_record["response_time"] = record["response_time"]

        payload_data.append(pl_record)

    return payload_data
94
+
95
+
96
+ # ===== Credentials Classes =====
97
class CloudPakforDataCredentials(BaseModel):
    """
    Encapsulates the credentials required for IBM Cloud Pak for Data.

    Attributes:
        url (str): The host URL of the Cloud Pak for Data environment.
        api_key (str, optional): The API key for the environment, if IAM is enabled.
        username (str, optional): The username for the environment.
        password (str, optional): The password for the environment.
        bedrock_url (str, optional): The Bedrock URL. Required only when IAM integration is enabled on CP4D 4.0.x clusters.
        instance_id (str, optional): The instance ID. Must be "icp" or "openshift" when set.
        version (str, optional): The version of Cloud Pak for Data.
        disable_ssl_verification (bool, optional): Indicates whether to disable SSL certificate verification.
            Defaults to `True`.
    """

    url: str
    api_key: Optional[str] = None
    username: Optional[str] = None
    password: Optional[str] = None
    bedrock_url: Optional[str] = None
    instance_id: Optional[Literal["icp", "openshift"]] = None
    version: Optional[str] = None
    disable_ssl_verification: bool = True

    def __init__(
        self,
        url: str,
        api_key: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        bedrock_url: Optional[str] = None,
        instance_id: Optional[Literal["icp", "openshift"]] = None,
        version: Optional[str] = None,
        disable_ssl_verification: bool = True,
    ) -> None:
        # Explicit signature so that `url` may be passed positionally.
        super().__init__(
            url=url,
            api_key=api_key,
            username=username,
            password=password,
            bedrock_url=bedrock_url,
            instance_id=instance_id,
            version=version,
            disable_ssl_verification=disable_ssl_verification,
        )

    def to_dict(self) -> Dict[str, Any]:
        """
        Returns the credentials as a plain dictionary.

        `instance_id` is left out when it is unset or is not one of the
        supported values ("icp", "openshift"); every other attribute is
        returned as stored, including those that are None.
        """
        cpd_creds = dict(self.__dict__)

        # `instance_id` defaults to None, so it must be checked before calling
        # `.lower()` on it; otherwise to_dict() raises AttributeError whenever
        # the caller did not set an instance id.
        instance_id = cpd_creds.get("instance_id")
        if instance_id is None or instance_id.lower() not in ("icp", "openshift"):
            cpd_creds.pop("instance_id", None)

        return cpd_creds
154
+
155
+
156
class IntegratedSystemCredentials(BaseModel):
    """
    Encapsulates the credentials for an Integrated System based on the authentication type.

    Depending on the `auth_type`, only a subset of the properties is required.

    Attributes:
        auth_type (str): The type of authentication. Currently supports "basic" and "bearer".
        username (str, optional): The username for Basic Authentication.
        password (str, optional): The password for Basic Authentication.
        token_url (str, optional): The URL of the authentication endpoint used to request a Bearer token.
        token_method (str, optional): The HTTP method (e.g., "POST", "GET") used to request the Bearer token.
            Defaults to "POST".
        token_headers (Dict, optional): Optional headers to include when requesting the Bearer token.
            Defaults to an empty dict.
        token_payload (str | dict, optional): The body or payload to send when requesting the Bearer token.
            Can be a string (e.g., raw JSON). Defaults to `None`.
    """

    auth_type: Literal["basic", "bearer"]
    username: Optional[str]  # basic
    password: Optional[str]  # basic
    token_url: Optional[str]  # bearer
    token_method: Optional[str] = "POST"  # bearer
    token_headers: Optional[Dict] = {}  # bearer
    token_payload: Optional[Union[str, Dict]] = None  # bearer

    def __init__(
        self,
        auth_type: Literal["basic", "bearer"],
        username: Optional[str] = None,
        password: Optional[str] = None,
        token_url: Optional[str] = None,
        token_method: str = "POST",
        token_headers: Optional[Dict] = None,
        token_payload: Optional[Union[str, Dict]] = None,
    ) -> None:
        """
        Validates the auth-type specific requirements and initializes the model.

        Raises:
            ValueError: If `auth_type` is "basic" without `username`/`password`,
                or "bearer" without `token_url`.
        """
        if auth_type == "basic":
            if not username or not password:
                raise ValueError(
                    "`username` and `password` are required for auth_type = 'basic'.",
                )
        elif auth_type == "bearer":
            if not token_url:
                raise ValueError(
                    "`token_url` are required for auth_type = 'bearer'.",
                )

        # A `{}` default in the signature would be one dict object shared by
        # every call; build a fresh one here instead. An explicit None is
        # treated the same as "no headers".
        if token_headers is None:
            token_headers = {}

        super().__init__(
            auth_type=auth_type,
            username=username,
            password=password,
            token_url=token_url,
            token_method=token_method,
            token_headers=token_headers,
            token_payload=token_payload,
        )

    def to_dict(self) -> Dict:
        """
        Returns only the fields relevant to the configured `auth_type`.

        "basic" yields `username` and `password`; "bearer" yields a nested
        `token_info` dict with `url`, `method`, `headers` and `payload`.
        """
        integrated_system_creds = {"auth_type": self.auth_type}

        if self.auth_type == "basic":
            integrated_system_creds["username"] = self.username
            integrated_system_creds["password"] = self.password
        elif self.auth_type == "bearer":
            integrated_system_creds["token_info"] = {
                "url": self.token_url,
                "method": self.token_method,
                "headers": self.token_headers,
                "payload": self.token_payload,
            }

        return integrated_system_creds
229
+
230
+
231
+ # ===== Monitor Classes =====
232
class WatsonxExternalPromptMonitor(PromptMonitor):
    """
    Provides functionality to interact with IBM watsonx.governance for monitoring external LLMs.

    Note:
        One of the following parameters is required to create a prompt monitor:
        `project_id` or `space_id`, but not both.

    Attributes:
        api_key (str): The API key for IBM watsonx.governance.
        space_id (str, optional): The space ID in watsonx.governance.
        project_id (str, optional): The project ID in watsonx.governance.
        region (str, optional): The region where watsonx.governance is hosted when using IBM Cloud.
            Defaults to `us-south`.
        cpd_creds (CloudPakforDataCredentials, optional): The Cloud Pak for Data environment credentials.
            A dict with the same fields is also accepted.
        subscription_id (str, optional): The subscription ID associated with the records being logged.

    Example:
        ```python
        from beekeeper.monitors.watsonx import (
            WatsonxExternalPromptMonitor,
            CloudPakforDataCredentials,
        )

        # watsonx.governance (IBM Cloud)
        wxgov_client = WatsonxExternalPromptMonitor(
            api_key="API_KEY", space_id="SPACE_ID"
        )

        # watsonx.governance (CP4D)
        cpd_creds = CloudPakforDataCredentials(
            url="CPD_URL",
            username="USERNAME",
            password="PASSWORD",
            version="5.0",
            instance_id="openshift",
        )

        wxgov_client = WatsonxExternalPromptMonitor(
            space_id="SPACE_ID", cpd_creds=cpd_creds
        )
        ```
    """

    def __init__(
        self,
        api_key: str = None,
        space_id: str = None,
        project_id: str = None,
        region: Literal["us-south", "eu-de", "au-syd"] = "us-south",
        cpd_creds: Union[CloudPakforDataCredentials, Dict] = None,
        subscription_id: str = None,
        **kwargs,
    ) -> None:
        # Unused here on purpose: importing at construction time raises
        # ImportError immediately if one of the IBM SDKs is not installed.
        import ibm_aigov_facts_client  # noqa: F401
        import ibm_cloud_sdk_core.authenticators  # noqa: F401
        import ibm_watson_openscale  # noqa: F401
        import ibm_watsonx_ai  # noqa: F401

        super().__init__(**kwargs)

        self.space_id = space_id
        self.project_id = project_id
        self.region = region
        self.subscription_id = subscription_id
        self._api_key = api_key
        self._wos_client = None

        # A space takes precedence over a project when deriving the container.
        self._container_id = space_id if space_id else project_id
        self._container_type = "space" if space_id else "project"
        self._deployment_stage = "production" if space_id else "development"

        # Always defined, so later code can test them without `hasattr`.
        self._wos_cpd_creds = None
        self._fact_cpd_creds = None
        self._wml_cpd_creds = None

        if cpd_creds:
            # The signature advertises a dict as well; normalise it to the
            # model so both forms go through the same validation and defaults.
            if isinstance(cpd_creds, dict):
                cpd_creds = CloudPakforDataCredentials(**cpd_creds)

            creds = cpd_creds.to_dict()
            login_keys = ["username", "password", "api_key"]

            # Each SDK accepts a different subset of the credential fields.
            self._wos_cpd_creds = _filter_dict(
                creds,
                [*login_keys, "disable_ssl_verification"],
                ["url"],
            )
            self._fact_cpd_creds = _filter_dict(
                creds,
                [*login_keys, "bedrock_url"],
                ["url"],
            )
            # The factsheets config names the host `service_url`, not `url`.
            self._fact_cpd_creds["service_url"] = self._fact_cpd_creds.pop("url")
            self._wml_cpd_creds = _filter_dict(
                creds,
                [*login_keys, "instance_id", "version", "bedrock_url"],
                ["url"],
            )

    def _get_wos_client(self):
        """
        Returns the OpenScale client, creating and caching it on first use.

        Uses Cloud Pak for Data authentication when CP4D credentials were
        supplied, IAM authentication against the regional endpoint otherwise.
        Connection errors are logged and re-raised.
        """
        if self._wos_client:
            return self._wos_client

        from ibm_watson_openscale import APIClient as WosAPIClient  # type: ignore

        try:
            wos_cpd_creds = getattr(self, "_wos_cpd_creds", None)
            if wos_cpd_creds:
                from ibm_cloud_sdk_core.authenticators import (
                    CloudPakForDataAuthenticator,  # type: ignore
                )

                authenticator = CloudPakForDataAuthenticator(**wos_cpd_creds)
                service_url = wos_cpd_creds["url"]
            else:
                from ibm_cloud_sdk_core.authenticators import (
                    IAMAuthenticator,  # type: ignore
                )

                authenticator = IAMAuthenticator(apikey=self._api_key)
                service_url = REGIONS_URL[self.region]["wos"]

            self._wos_client = WosAPIClient(
                authenticator=authenticator,
                service_url=service_url,
            )

        except Exception as e:
            logging.error(
                f"Error connecting to IBM watsonx.governance (openscale): {e}",
            )
            raise

        return self._wos_client

    def _create_detached_prompt(
        self,
        detached_details: Dict,
        prompt_template_details: Dict,
        detached_asset_details: Dict,
    ) -> str:
        """
        Creates the detached prompt template asset and returns its asset ID.

        Args:
            detached_details (Dict): Keyword arguments for `DetachedPromptTemplate`.
            prompt_template_details (Dict): Keyword arguments for `PromptTemplate`.
            detached_asset_details (Dict): Keyword arguments forwarded to
                `assets.create_detached_prompt`.
        """
        from ibm_aigov_facts_client import (  # type: ignore
            AIGovFactsClient,
            CloudPakforDataConfig,
            DetachedPromptTemplate,
            PromptTemplate,
        )

        try:
            fact_cpd_creds = getattr(self, "_fact_cpd_creds", None)
            if fact_cpd_creds:
                aigov_client = AIGovFactsClient(
                    container_id=self._container_id,
                    container_type=self._container_type,
                    cloud_pak_for_data_configs=CloudPakforDataConfig(**fact_cpd_creds),
                    disable_tracing=True,
                )
            else:
                aigov_client = AIGovFactsClient(
                    api_key=self._api_key,
                    container_id=self._container_id,
                    container_type=self._container_type,
                    disable_tracing=True,
                    region=REGIONS_URL[self.region]["factsheet"],
                )

        except Exception as e:
            logging.error(
                f"Error connecting to IBM watsonx.governance (factsheets): {e}",
            )
            raise

        created_detached_pta = aigov_client.assets.create_detached_prompt(
            **detached_asset_details,
            prompt_details=PromptTemplate(**prompt_template_details),
            detached_information=DetachedPromptTemplate(**detached_details),
        )

        return created_detached_pta.to_dict()["asset_id"]

    def _create_deployment_pta(self, asset_id: str, name: str, model_id: str) -> str:
        """
        Creates a detached deployment for a prompt template asset in the configured space.

        Args:
            asset_id (str): The prompt template asset ID.
            name (str): Asset name; the deployment is named "<name> deployment".
            model_id (str): The base model ID recorded on the deployment.

        Returns:
            str: The ID of the created deployment.
        """
        from ibm_watsonx_ai import APIClient, Credentials  # type: ignore

        try:
            wml_cpd_creds = getattr(self, "_wml_cpd_creds", None)
            if wml_cpd_creds:
                creds = Credentials(**wml_cpd_creds)
            else:
                creds = Credentials(
                    url=REGIONS_URL[self.region]["wml"],
                    api_key=self._api_key,
                )

            wml_client = APIClient(creds)
            wml_client.set.default_space(self.space_id)

        except Exception as e:
            logging.error(f"Error connecting to IBM watsonx.ai Runtime: {e}")
            raise

        meta_names = wml_client.deployments.ConfigurationMetaNames
        meta_props = {
            meta_names.PROMPT_TEMPLATE: {"id": asset_id},
            meta_names.DETACHED: {},
            meta_names.NAME: name + " " + "deployment",
            meta_names.BASE_MODEL_ID: model_id,
        }

        created_deployment = wml_client.deployments.create(asset_id, meta_props)

        return wml_client.deployments.get_uid(created_deployment)

    @deprecated(
        reason="'add_prompt_observer()' is deprecated and will be removed in a future version. Use 'add_prompt_monitor()' instead.",
        version="1.0.5",
        action="always",
    )
    def add_prompt_observer(
        self,
        name: str,
        model_id: str,
        task_id: Literal[
            "extraction",
            "generation",
            "question_answering",
            "retrieval_augmented_generation",
            "summarization",
        ],
        detached_model_provider: str,
        description: str = "",
        model_parameters: Dict = None,
        detached_model_name: str = None,
        detached_model_url: str = None,
        detached_prompt_url: str = None,
        detached_prompt_additional_info: Dict = None,
        prompt_variables: List[str] = None,
        locale: str = "en",
        input_text: str = None,
        context_fields: List[str] = None,
        question_field: str = None,
    ) -> Dict:
        # Deprecated alias kept for backward compatibility; forwards everything.
        return self.add_prompt_monitor(
            name=name,
            model_id=model_id,
            task_id=task_id,
            detached_model_provider=detached_model_provider,
            description=description,
            model_parameters=model_parameters,
            detached_model_name=detached_model_name,
            detached_model_url=detached_model_url,
            detached_prompt_url=detached_prompt_url,
            detached_prompt_additional_info=detached_prompt_additional_info,
            prompt_variables=prompt_variables,
            locale=locale,
            input_text=input_text,
            context_fields=context_fields,
            question_field=question_field,
        )

    def add_prompt_monitor(
        self,
        name: str,
        model_id: str,
        task_id: Literal[
            "extraction",
            "generation",
            "question_answering",
            "retrieval_augmented_generation",
            "summarization",
        ],
        detached_model_provider: str,
        description: str = "",
        model_parameters: Dict = None,
        detached_model_name: str = None,
        detached_model_url: str = None,
        detached_prompt_url: str = None,
        detached_prompt_additional_info: Dict = None,
        prompt_variables: List[str] = None,
        locale: str = "en",
        input_text: str = None,
        context_fields: List[str] = None,
        question_field: str = None,
    ) -> Dict:
        """
        Creates a Detached/External Prompt Template Asset and setup monitor for a given prompt template asset.

        Args:
            name (str): The name of the External Prompt Template Asset.
            model_id (str): The ID of the model associated with the prompt.
            task_id (str): The task identifier.
            detached_model_provider (str): The external model provider.
            description (str, optional): A description of the External Prompt Template Asset.
            model_parameters (Dict, optional): Model parameters and their respective values.
            detached_model_name (str, optional): The name of the external model.
            detached_model_url (str, optional): The URL of the external model.
            detached_prompt_url (str, optional): The URL of the external prompt.
            detached_prompt_additional_info (Dict, optional): Additional information related to the external prompt.
            prompt_variables (List[str], optional): Values for the prompt variables.
            locale (str, optional): Locale code for the input/output language. eg. "en", "pt", "es".
            input_text (str, optional): The input text for the prompt.
            context_fields (List[str], optional): A list of fields that will provide context to the prompt.
                Applicable only for "retrieval_augmented_generation" task type.
            question_field (str, optional): The field containing the question to be answered.
                Applicable only for "retrieval_augmented_generation" task type.

        Returns:
            Dict: `detached_prompt_template_asset_id`, `deployment_id` (None when
            the monitor is created in a project) and `subscription_id`.

        Raises:
            ValueError: If neither or both of `project_id`/`space_id` are set, if a
                "retrieval_augmented_generation" task lacks `context_fields` or
                `question_field`, or if no data mart is available when the
                instance mapping has to be created.

        Example:
            ```python
            wxgov_client.add_prompt_monitor(
                name="Detached prompt (model AWS Anthropic)",
                model_id="anthropic.claude-v2",
                task_id="retrieval_augmented_generation",
                detached_model_provider="AWS Bedrock",
                detached_model_name="Anthropic Claude 2.0",
                detached_model_url="https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-claude.html",
                prompt_variables=["context1", "context2", "input_query"],
                input_text="Prompt text to be given",
                context_fields=["context1", "context2"],
                question_field="input_query",
            )
            ```
        """
        # Report the two misconfigurations separately so the message states
        # which one actually occurred.
        if self.project_id and self.space_id:
            raise ValueError(
                "Invalid configuration: Both were provided: 'project_id' and 'space_id' cannot be set at the same time."
            )
        if not (self.project_id or self.space_id):
            raise ValueError(
                "Invalid configuration: Neither was provided: please set either 'project_id' or 'space_id'."
            )

        if task_id == "retrieval_augmented_generation":
            if not context_fields or not question_field:
                raise ValueError(
                    "For 'retrieval_augmented_generation' task, requires non-empty 'context_fields' and 'question_field'."
                )

        # Explicit mapping to the key names expected by the factsheets API.
        prompt_metadata = {
            "name": name,
            "model_id": model_id,
            "task_id": task_id,
            "description": description,
            "model_parameters": model_parameters,
            # The API takes a {variable: value} dict. Keep None when no
            # variables were given so that the key is filtered out below.
            "prompt_variables": (
                dict.fromkeys(prompt_variables, "")
                if prompt_variables is not None
                else None
            ),
            "input": input_text,
            "model_provider": detached_model_provider,
            "model_name": detached_model_name,
            "model_url": detached_model_url,
            "prompt_url": detached_prompt_url,
            "prompt_additional_info": detached_prompt_additional_info,
        }

        wos_client = self._get_wos_client()

        detached_details = _filter_dict(
            prompt_metadata,
            ["model_name", "model_url", "prompt_url", "prompt_additional_info"],
            ["model_id", "model_provider"],
        )
        detached_details["prompt_id"] = "detached_prompt_" + str(uuid.uuid4())

        prompt_details = _filter_dict(
            prompt_metadata,
            ["prompt_variables", "input", "model_parameters"],
        )

        detached_asset_details = _filter_dict(
            prompt_metadata,
            ["description"],
            ["name", "model_id", "task_id"],
        )

        # NOTE(review): `suppress_output` is imported from the instrumentation
        # module; presumably it runs the callable with console output silenced
        # and returns its result -- confirm against that module.
        detached_pta_id = suppress_output(
            self._create_detached_prompt,
            detached_details,
            prompt_details,
            detached_asset_details,
        )

        # A deployment is only created when the monitor lives in a space.
        deployment_id = None
        if self._container_type == "space":
            deployment_id = suppress_output(
                self._create_deployment_pta, detached_pta_id, name, model_id
            )

        monitors = {
            "generative_ai_quality": {
                "parameters": {"min_sample_size": 10, "metrics_configuration": {}},
            },
        }

        # At most two attempts: if the first one fails with the 403 "user
        # entitlement" error, map the container to the data mart and retry once.
        generative_ai_monitor_details = None
        for attempt in range(2):
            try:
                generative_ai_monitor_details = suppress_output(
                    wos_client.wos.execute_prompt_setup,
                    prompt_template_asset_id=detached_pta_id,
                    space_id=self.space_id,
                    project_id=self.project_id,
                    deployment_id=deployment_id,
                    label_column="reference_output",
                    context_fields=context_fields,
                    question_field=question_field,
                    operational_space_id=self._deployment_stage,
                    problem_type=task_id,
                    data_input_locale=[locale],
                    generated_output_locale=[locale],
                    input_data_type="unstructured_text",
                    supporting_monitors=monitors,
                    background_mode=False,
                )
                break

            except Exception as e:
                # Not every exception carries `code`/`message`; read them
                # defensively so an unrelated error is re-raised as itself
                # instead of being masked by an AttributeError.
                entitlement_missing = getattr(
                    e, "code", None
                ) == 403 and "The user entitlement does not exist" in str(
                    getattr(e, "message", "")
                )
                if attempt > 0 or not entitlement_missing:
                    raise

                data_marts = wos_client.data_marts.list().result
                if not data_marts.data_marts:
                    raise ValueError(
                        "Error retrieving IBM watsonx.governance (openscale) data mart. "
                        "Make sure the data mart are configured.",
                    )

                data_mart_id = data_marts.data_marts[0].metadata.id

                wos_client.wos.add_instance_mapping(
                    service_instance_id=data_mart_id,
                    space_id=self.space_id,
                    project_id=self.project_id,
                )

        generative_ai_monitor_details = generative_ai_monitor_details.result._to_dict()

        return {
            "detached_prompt_template_asset_id": detached_pta_id,
            "deployment_id": deployment_id,
            "subscription_id": generative_ai_monitor_details["subscription_id"],
        }

    def store_payload_records(
        self,
        request_records: List[Dict],
        subscription_id: str = None,
    ) -> List[str]:
        """
        Stores records to the payload logging system.

        Args:
            request_records (List[Dict]): A list of records to be logged, where each record is represented as a dictionary.
            subscription_id (str, optional): The subscription ID associated with the records being logged.

        Returns:
            List[str]: One identifier per stored record, built as the generated
            `scoring_id` followed by "-1".

        Raises:
            ValueError: If no subscription ID is given here or on the instance.

        Example:
            ```python
            wxgov_client.store_payload_records(
                request_records=[
                    {
                        "context1": "value_context1",
                        "context2": "value_context1",
                        "input_query": "What's Beekeeper Framework?",
                        "generated_text": "Beekeeper is a data framework to make AI easier to work with.",
                        "input_token_count": 25,
                        "generated_token_count": 150,
                    }
                ],
                subscription_id="5d62977c-a53d-4b6d-bda1-7b79b3b9d1a0",
            )
            ```
        """
        from ibm_watson_openscale.supporting_classes.enums import (
            DataSetTypes,
            TargetTypes,
        )

        # Expected behavior: Prefer using fn `subscription_id`.
        # Fallback to `self.subscription_id` if `subscription_id` None or empty.
        _subscription_id = subscription_id or self.subscription_id

        if _subscription_id is None or _subscription_id == "":
            raise ValueError(
                "Unexpected value for 'subscription_id': Cannot be None or empty string."
            )

        wos_client = self._get_wos_client()

        subscription_details = wos_client.subscriptions.get(
            _subscription_id,
        ).result
        subscription_details = json.loads(str(subscription_details))

        # The subscription defines which record fields are template variables.
        feature_fields = subscription_details["entity"]["asset_properties"][
            "feature_fields"
        ]

        payload_data_set_id = (
            wos_client.data_sets.list(
                type=DataSetTypes.PAYLOAD_LOGGING,
                target_target_id=_subscription_id,
                target_target_type=TargetTypes.SUBSCRIPTION,
            )
            .result.data_sets[0]
            .metadata.id
        )

        payload_data = _convert_payload_format(request_records, feature_fields)

        suppress_output(
            wos_client.data_sets.store_records,
            data_set_id=payload_data_set_id,
            request_body=payload_data,
            background_mode=False,
        )

        return [data["scoring_id"] + "-1" for data in payload_data]

    def __call__(self, payload: PayloadRecord) -> None:
        """
        Logs a single payload record, enriched with the template variables
        extracted from its input text when any are found.
        """
        # NOTE(review): the nesting of the original statements could not be
        # recovered. Written flat, the original would read `template_vars`
        # before assignment when no prompt template is configured; that case
        # is treated as "nothing to log" here -- confirm the intended
        # behaviour.
        if not self.prompt_template:
            return

        template_vars = extract_template_vars(
            self.prompt_template.template, payload.input_text
        )

        if not template_vars:
            self.store_payload_records([payload.model_dump()])
        else:
            self.store_payload_records([{**payload.model_dump(), **template_vars}])
804
+
805
+
806
+ class WatsonxPromptMonitor(PromptMonitor):
807
+ """
808
+ Provides functionality to interact with IBM watsonx.governance for monitoring IBM watsonx.ai LLMs.
809
+
810
+ Note:
811
+ One of the following parameters is required to create a prompt monitor:
812
+ `project_id` or `space_id`, but not both.
813
+
814
+ Attributes:
815
+ api_key (str): The API key for IBM watsonx.governance.
816
+ space_id (str, optional): The space ID in watsonx.governance.
817
+ project_id (str, optional): The project ID in watsonx.governance.
818
+ region (str, optional): The region where watsonx.governance is hosted when using IBM Cloud.
819
+ Defaults to `us-south`.
820
+ cpd_creds (CloudPakforDataCredentials, optional): The Cloud Pak for Data environment credentials.
821
+ subscription_id (str, optional): The subscription ID associated with the records being logged.
822
+
823
+ Example:
824
+ ```python
825
+ from beekeeper.monitors.watsonx import (
826
+ WatsonxPromptMonitor,
827
+ CloudPakforDataCredentials,
828
+ )
829
+
830
+ # watsonx.governance (IBM Cloud)
831
+ wxgov_client = WatsonxPromptMonitor(api_key="API_KEY", space_id="SPACE_ID")
832
+
833
+ # watsonx.governance (CP4D)
834
+ cpd_creds = CloudPakforDataCredentials(
835
+ url="CPD_URL",
836
+ username="USERNAME",
837
+ password="PASSWORD",
838
+ version="5.0",
839
+ instance_id="openshift",
840
+ )
841
+
842
+ wxgov_client = WatsonxPromptMonitor(space_id="SPACE_ID", cpd_creds=cpd_creds)
843
+ ```
844
+ """
845
+
846
def __init__(
    self,
    api_key: str = None,
    space_id: str = None,
    project_id: str = None,
    region: Literal["us-south", "eu-de", "au-syd"] = "us-south",
    cpd_creds: Union[CloudPakforDataCredentials, Dict] = None,
    subscription_id: str = None,
    **kwargs,
) -> None:
    """
    Stores the connection settings and derives the per-SDK credential sets.

    Args:
        api_key (str, optional): The API key used for IBM Cloud authentication.
        space_id (str, optional): The space ID. Takes precedence over
            `project_id` when deriving the container.
        project_id (str, optional): The project ID.
        region (str, optional): The IBM Cloud region. Defaults to `us-south`.
        cpd_creds (CloudPakforDataCredentials | Dict, optional): Cloud Pak for
            Data credentials, as a model instance or a dict with the same fields.
        subscription_id (str, optional): The default subscription ID.
        **kwargs: Forwarded to the parent class initializer.
    """
    # Unused here on purpose: importing at construction time raises
    # ImportError immediately if one of the IBM SDKs is not installed.
    import ibm_aigov_facts_client  # noqa: F401
    import ibm_cloud_sdk_core.authenticators  # noqa: F401
    import ibm_watson_openscale  # noqa: F401
    import ibm_watsonx_ai  # noqa: F401

    super().__init__(**kwargs)

    self.space_id = space_id
    self.project_id = project_id
    self.region = region
    self.subscription_id = subscription_id
    self._api_key = api_key
    self._wos_client = None

    self._container_id = space_id if space_id else project_id
    self._container_type = "space" if space_id else "project"
    self._deployment_stage = "production" if space_id else "development"

    # Always defined, even without CP4D credentials.
    self._wos_cpd_creds = None
    self._fact_cpd_creds = None
    self._wml_cpd_creds = None

    if cpd_creds:
        # The signature advertises a dict as well; normalise it to the model
        # so both forms go through the same validation and defaults.
        if isinstance(cpd_creds, dict):
            cpd_creds = CloudPakforDataCredentials(**cpd_creds)

        creds = cpd_creds.to_dict()
        login_keys = ["username", "password", "api_key"]

        # Each SDK accepts a different subset of the credential fields.
        self._wos_cpd_creds = _filter_dict(
            creds,
            [*login_keys, "disable_ssl_verification"],
            ["url"],
        )
        self._fact_cpd_creds = _filter_dict(
            creds,
            [*login_keys, "bedrock_url"],
            ["url"],
        )
        # The factsheets config names the host `service_url`, not `url`.
        self._fact_cpd_creds["service_url"] = self._fact_cpd_creds.pop("url")
        self._wml_cpd_creds = _filter_dict(
            creds,
            [*login_keys, "instance_id", "version", "bedrock_url"],
            ["url"],
        )
898
+
899
def _create_prompt_template(
    self,
    prompt_template_details: Dict,
    asset_details: Dict,
) -> str:
    """
    Creates a freeform prompt template asset and returns its asset ID.

    Args:
        prompt_template_details (Dict): Keyword arguments for `PromptTemplate`.
        asset_details (Dict): Keyword arguments forwarded to `assets.create_prompt`.
    """
    from ibm_aigov_facts_client import (
        AIGovFactsClient,
        CloudPakforDataConfig,
        PromptTemplate,
    )

    try:
        # Settings shared by the CP4D and the IBM Cloud connection.
        client_kwargs = {
            "container_id": self._container_id,
            "container_type": self._container_type,
            "disable_tracing": True,
        }

        fact_creds = getattr(self, "_fact_cpd_creds", None)
        if fact_creds:
            client_kwargs["cloud_pak_for_data_configs"] = CloudPakforDataConfig(
                **fact_creds
            )
        else:
            client_kwargs["api_key"] = self._api_key
            client_kwargs["region"] = REGIONS_URL[self.region]["factsheet"]

        aigov_client = AIGovFactsClient(**client_kwargs)

    except Exception as e:
        logging.error(
            f"Error connecting to IBM watsonx.governance (factsheets): {e}",
        )
        raise

    prompt = PromptTemplate(**prompt_template_details)
    created_pta = aigov_client.assets.create_prompt(
        **asset_details,
        input_mode="freeform",
        prompt_details=prompt,
    )

    return created_pta.to_dict()["asset_id"]
943
+
944
def _create_deployment_pta(self, asset_id: str, name: str, model_id: str) -> str:
    """
    Creates a deployment for a prompt template asset in the configured space.

    Args:
        asset_id (str): The ID of the prompt template asset to deploy.
        name (str): Base name; the deployment is named "<name> deployment".
        model_id (str): The base model ID set on the deployment.

    Returns:
        str: The value returned by `deployments.get_uid` for the created deployment.
    """
    from ibm_watsonx_ai import APIClient, Credentials  # type: ignore

    try:
        cpd_kwargs = getattr(self, "_wml_cpd_creds", None)
        if cpd_kwargs:
            credentials = Credentials(**cpd_kwargs)
        else:
            credentials = Credentials(
                url=REGIONS_URL[self.region]["wml"],
                api_key=self._api_key,
            )

        # Both credential flavours build the client the same way.
        runtime_client = APIClient(credentials)
        runtime_client.set.default_space(self.space_id)

    except Exception as e:
        logging.error(f"Error connecting to IBM watsonx.ai Runtime: {e}")
        raise

    deployments = runtime_client.deployments
    deployment_props = {
        deployments.ConfigurationMetaNames.PROMPT_TEMPLATE: {"id": asset_id},
        deployments.ConfigurationMetaNames.FOUNDATION_MODEL: {},
        deployments.ConfigurationMetaNames.NAME: name + " " + "deployment",
        deployments.ConfigurationMetaNames.BASE_MODEL_ID: model_id,
    }

    deployment = runtime_client.deployments.create(asset_id, deployment_props)

    return runtime_client.deployments.get_uid(deployment)
981
+
982
@deprecated(
    reason="'add_prompt_observer()' is deprecated and will be removed in a future version. Use 'add_prompt_monitor()' instead.",
    version="1.0.5",
    action="always",
)
def add_prompt_observer(
    self,
    name: str,
    model_id: str,
    task_id: Literal[
        "extraction",
        "generation",
        "question_answering",
        "retrieval_augmented_generation",
        "summarization",
    ],
    description: str = "",
    model_parameters: Dict = None,
    prompt_variables: List[str] = None,
    locale: str = "en",
    input_text: str = None,
    context_fields: List[str] = None,
    question_field: str = None,
) -> Dict:
    """
    Deprecated alias of `add_prompt_monitor()`.

    Every argument is forwarded unchanged by keyword and the result of
    `add_prompt_monitor()` is returned as-is.
    """
    forwarded = {
        "name": name,
        "model_id": model_id,
        "task_id": task_id,
        "description": description,
        "model_parameters": model_parameters,
        "prompt_variables": prompt_variables,
        "locale": locale,
        "input_text": input_text,
        "context_fields": context_fields,
        "question_field": question_field,
    }
    return self.add_prompt_monitor(**forwarded)
1018
+
1019
def add_prompt_monitor(
    self,
    name: str,
    model_id: str,
    task_id: Literal[
        "extraction",
        "generation",
        "question_answering",
        "retrieval_augmented_generation",
        "summarization",
    ],
    description: str = "",
    model_parameters: Dict = None,
    prompt_variables: List[str] = None,
    locale: str = "en",
    input_text: str = None,
    context_fields: List[str] = None,
    question_field: str = None,
) -> Dict:
    """
    Creates an IBM Prompt Template Asset and sets up a monitor for the given prompt template asset.

    Args:
        name (str): The name of the Prompt Template Asset.
        model_id (str): The ID of the model associated with the prompt.
        task_id (str): The task identifier.
        description (str, optional): A description of the Prompt Template Asset.
        model_parameters (Dict, optional): A dictionary of model parameters and their respective values.
        prompt_variables (List[str], optional): A list of values for prompt input variables.
        locale (str, optional): Locale code for the input/output language. eg. "en", "pt", "es".
        input_text (str, optional): The input text for the prompt.
        context_fields (List[str], optional): A list of fields that will provide context to the prompt.
            Applicable only for the `retrieval_augmented_generation` task type.
        question_field (str, optional): The field containing the question to be answered.
            Applicable only for the `retrieval_augmented_generation` task type.

    Returns:
        Dict: A dictionary with `prompt_template_asset_id`, `deployment_id`
            (`None` unless the monitor is created in a space) and `subscription_id`.

    Raises:
        ValueError: If neither or both of `project_id` / `space_id` are set, if a
            `retrieval_augmented_generation` task lacks `context_fields` or
            `question_field`, or if no data mart is available when an instance
            mapping has to be added.

    Example:
        ```python
        wxgov_client.add_prompt_monitor(
            name="IBM prompt template",
            model_id="ibm/granite-3-2b-instruct",
            task_id="retrieval_augmented_generation",
            prompt_variables=["context1", "context2", "input_query"],
            input_text="Prompt text to be given",
            context_fields=["context1", "context2"],
            question_field="input_query",
        )
        ```
    """
    if (not (self.project_id or self.space_id)) or (
        self.project_id and self.space_id
    ):
        raise ValueError(
            "Invalid configuration: Neither was provided: please set either 'project_id' or 'space_id'. "
            "Both were provided: 'project_id' and 'space_id' cannot be set at the same time."
        )

    if task_id == "retrieval_augmented_generation":
        if not context_fields or not question_field:
            raise ValueError(
                "For 'retrieval_augmented_generation' task, requires non-empty 'context_fields' and 'question_field'."
            )

    # Explicit mapping (instead of a `locals()` snapshot) of the values passed to
    # the factsheets API. `input_text` is renamed to `input`, and the list of
    # variable names becomes a dict with empty-string values.
    # An omitted `prompt_variables` stays None, like the other omitted optional
    # fields, instead of failing inside `dict.fromkeys(None, "")`.
    prompt_metadata = {
        "name": name,
        "model_id": model_id,
        "task_id": task_id,
        "description": description,
        "model_parameters": model_parameters,
        "input": input_text,
        "prompt_variables": (
            None if prompt_variables is None else dict.fromkeys(prompt_variables, "")
        ),
    }

    from ibm_cloud_sdk_core.authenticators import IAMAuthenticator  # type: ignore
    from ibm_watson_openscale import APIClient as WosAPIClient  # type: ignore

    if not self._wos_client:
        try:
            wos_cpd_creds = getattr(self, "_wos_cpd_creds", None)
            if wos_cpd_creds:
                from ibm_cloud_sdk_core.authenticators import (
                    CloudPakForDataAuthenticator,  # type: ignore
                )

                authenticator = CloudPakForDataAuthenticator(**wos_cpd_creds)
                self._wos_client = WosAPIClient(
                    authenticator=authenticator,
                    service_url=wos_cpd_creds["url"],
                )

            else:
                authenticator = IAMAuthenticator(apikey=self._api_key)
                self._wos_client = WosAPIClient(
                    authenticator=authenticator,
                    service_url=REGIONS_URL[self.region]["wos"],
                )

        except Exception as e:
            logging.error(
                f"Error connecting to IBM watsonx.governance (openscale): {e}",
            )
            raise

    prompt_details = _filter_dict(
        prompt_metadata,
        ["prompt_variables", "input", "model_parameters"],
    )

    asset_details = _filter_dict(
        prompt_metadata,
        ["description"],
        ["name", "model_id", "task_id"],
    )

    pta_id = suppress_output(
        self._create_prompt_template, prompt_details, asset_details
    )
    deployment_id = None
    if self._container_type == "space":
        deployment_id = suppress_output(
            self._create_deployment_pta, pta_id, name, model_id
        )

    monitors = {
        "generative_ai_quality": {
            "parameters": {"min_sample_size": 10, "metrics_configuration": {}},
        },
    }

    # At most two attempts: when the first one fails with the
    # "user entitlement does not exist" 403, an instance mapping is added and
    # the setup is retried once. Any other failure is re-raised immediately.
    for attempt in range(2):
        try:
            generative_ai_monitor_details = suppress_output(
                self._wos_client.wos.execute_prompt_setup,
                prompt_template_asset_id=pta_id,
                space_id=self.space_id,
                project_id=self.project_id,
                deployment_id=deployment_id,
                label_column="reference_output",
                context_fields=context_fields,
                question_field=question_field,
                operational_space_id=self._deployment_stage,
                problem_type=task_id,
                data_input_locale=[locale],
                generated_output_locale=[locale],
                input_data_type="unstructured_text",
                supporting_monitors=monitors,
                background_mode=False,
            ).result

            break

        except Exception as e:
            # Not every exception carries `code` / `message`; read them
            # defensively so the original error is never masked by an
            # AttributeError raised inside this handler.
            error_code = getattr(e, "code", None)
            error_message = str(getattr(e, "message", None) or "")
            missing_entitlement = (
                error_code == 403
                and "The user entitlement does not exist" in error_message
            )

            if not (missing_entitlement and attempt == 0):
                raise

            data_marts = self._wos_client.data_marts.list().result
            if not data_marts.data_marts:
                raise ValueError(
                    "Error retrieving IBM watsonx.governance (openscale) data mart. "
                    "Make sure the data mart are configured.",
                )

            data_mart_id = data_marts.data_marts[0].metadata.id

            self._wos_client.wos.add_instance_mapping(
                service_instance_id=data_mart_id,
                space_id=self.space_id,
                project_id=self.project_id,
            )

    generative_ai_monitor_details = generative_ai_monitor_details._to_dict()

    return {
        "prompt_template_asset_id": pta_id,
        "deployment_id": deployment_id,
        "subscription_id": generative_ai_monitor_details["subscription_id"],
    }
1215
+
1216
def store_payload_records(
    self,
    request_records: List[Dict],
    subscription_id: str = None,
) -> List[str]:
    """
    Stores records to the payload logging system.

    Args:
        request_records (List[Dict]): A list of records to be logged. Each record is represented as a dictionary.
        subscription_id (str, optional): The subscription ID associated with the records being logged.
            Falls back to the instance `subscription_id` when omitted or empty.

    Returns:
        List[str]: One entry per stored record, built as the record's `scoring_id` followed by "-1".

    Raises:
        ValueError: If no subscription ID is available, or if the subscription
            has no payload logging data set.

    Example:
        ```python
        wxgov_client.store_payload_records(
            request_records=[
                {
                    "context1": "value_context1",
                    "context2": "value_context1",
                    "input_query": "What's Beekeeper Framework?",
                    "generated_text": "Beekeeper is a data framework to make AI easier to work with.",
                    "input_token_count": 25,
                    "generated_token_count": 150,
                }
            ],
            subscription_id="5d62977c-a53d-4b6d-bda1-7b79b3b9d1a0",
        )
        ```
    """
    # Expected behavior: Prefer using fn `subscription_id`.
    # Fallback to `self.subscription_id` if `subscription_id` None or empty.
    # Validated before the SDK imports so a bad call fails fast.
    _subscription_id = subscription_id or self.subscription_id

    if _subscription_id is None or _subscription_id == "":
        raise ValueError(
            "Unexpected value for 'subscription_id': Cannot be None or empty string."
        )

    from ibm_cloud_sdk_core.authenticators import IAMAuthenticator
    from ibm_watson_openscale import APIClient as WosAPIClient
    from ibm_watson_openscale.supporting_classes.enums import (
        DataSetTypes,
        TargetTypes,
    )

    if not self._wos_client:
        try:
            wos_cpd_creds = getattr(self, "_wos_cpd_creds", None)
            if wos_cpd_creds:
                from ibm_cloud_sdk_core.authenticators import (
                    CloudPakForDataAuthenticator,  # type: ignore
                )

                authenticator = CloudPakForDataAuthenticator(**wos_cpd_creds)
                self._wos_client = WosAPIClient(
                    authenticator=authenticator,
                    service_url=wos_cpd_creds["url"],
                )

            else:
                authenticator = IAMAuthenticator(apikey=self._api_key)
                self._wos_client = WosAPIClient(
                    authenticator=authenticator,
                    service_url=REGIONS_URL[self.region]["wos"],
                )

        except Exception as e:
            logging.error(
                f"Error connecting to IBM watsonx.governance (openscale): {e}",
            )
            raise

    subscription_details = self._wos_client.subscriptions.get(
        _subscription_id,
    ).result
    # The response object is converted through its string form and parsed as JSON.
    subscription_details = json.loads(str(subscription_details))

    feature_fields = subscription_details["entity"]["asset_properties"][
        "feature_fields"
    ]

    payload_data_sets = self._wos_client.data_sets.list(
        type=DataSetTypes.PAYLOAD_LOGGING,
        target_target_id=_subscription_id,
        target_target_type=TargetTypes.SUBSCRIPTION,
    ).result.data_sets

    # Report a missing data set explicitly instead of a bare IndexError.
    if not payload_data_sets:
        raise ValueError(
            f"No payload logging data set found for subscription '{_subscription_id}'."
        )

    payload_data_set_id = payload_data_sets[0].metadata.id

    payload_data = _convert_payload_format(request_records, feature_fields)

    suppress_output(
        self._wos_client.data_sets.store_records,
        data_set_id=payload_data_set_id,
        request_body=payload_data,
        background_mode=False,
    )

    return [data["scoring_id"] + "-1" for data in payload_data]
1321
+
1322
def __call__(self, payload: PayloadRecord) -> None:
    """
    Stores a single payload record when a prompt template is set.

    Template variables are extracted from `payload.input_text` with
    `extract_template_vars`; when any are found they are merged into the dumped
    payload before it is passed to `store_payload_records`.

    Args:
        payload (PayloadRecord): The record to store.
    """
    # NOTE(review): indentation was not recoverable from the diff view this was
    # read from. The store calls are assumed to sit inside the
    # `if self.prompt_template:` branch; otherwise `template_vars` would be
    # unbound when no template is set. Confirm against the original file.
    if self.prompt_template:
        template_vars = extract_template_vars(
            self.prompt_template.template, payload.input_text
        )

        if not template_vars:
            self.store_payload_records([payload.model_dump()])
        else:
            # Extracted variables override payload keys of the same name.
            self.store_payload_records([{**payload.model_dump(), **template_vars}])
1332
+
1333
+
1334
+ # ===== Supporting Classes =====
1335
class WatsonxLocalMetric(BaseModel):
    """
    Describes one metric of an IBM watsonx.governance local monitor.

    Attributes:
        name (str): The name of the metric.
        data_type (str): The data type of the metric. One of "string", "integer", "double" or "timestamp".
        nullable (bool, optional): Indicates whether the metric can be null. Defaults to `True`.

    Example:
        ```python
        from beekeeper.monitors.watsonx import WatsonxLocalMetric

        WatsonxLocalMetric(name="context_quality", data_type="double")
        ```
    """

    name: str
    data_type: Literal["string", "integer", "double", "timestamp"]
    nullable: bool = True

    def to_dict(self) -> Dict:
        """Returns the metric as a dict, with `data_type` exposed under the key `type`."""
        return dict(name=self.name, type=self.data_type, nullable=self.nullable)
1358
+
1359
+
1360
class WatsonxMetricThreshold(BaseModel):
    """
    Describes a threshold attached to an IBM watsonx.governance metric.

    Attributes:
        threshold_type (str): The threshold type. Can be either `lower_limit` or `upper_limit`.
        default_value (float): The metric threshold value. Defaults to `None`.

    Example:
        ```python
        from beekeeper.monitors.watsonx import WatsonxMetricThreshold

        WatsonxMetricThreshold(threshold_type="lower_limit", default_value=0.8)
        ```
    """

    threshold_type: Literal["lower_limit", "upper_limit"]
    default_value: float = None

    def to_dict(self) -> Dict:
        """Returns the threshold as a dict with the keys `type` and `default`."""
        return dict(type=self.threshold_type, default=self.default_value)
1381
+
1382
+
1383
class WatsonxMetric(BaseModel):
    """
    Describes a metric of an IBM watsonx.governance global monitor.

    Attributes:
        name (str): The name of the metric.
        applies_to (List[str]): The task types the metric applies to. Supported values:
            "summarization", "generation", "question_answering", "extraction", and "retrieval_augmented_generation".
        thresholds (List[WatsonxMetricThreshold], optional): Thresholds associated with the metric.

    Example:
        ```python
        from beekeeper.monitors.watsonx import (
            WatsonxMetric,
            WatsonxMetricThreshold,
        )

        WatsonxMetric(
            name="context_quality",
            applies_to=["retrieval_augmented_generation", "summarization"],
            thresholds=[
                WatsonxMetricThreshold(threshold_type="lower_limit", default_value=0.75)
            ],
        )
        ```
    """

    name: str
    applies_to: List[
        Literal[
            "summarization",
            "generation",
            "question_answering",
            "extraction",
            "retrieval_augmented_generation",
        ]
    ]
    thresholds: Optional[List[WatsonxMetricThreshold]] = None

    def to_dict(self) -> Dict:
        """
        Returns the metric as a dict of OpenScale SDK objects.

        `applies_to` is wrapped in an `ApplicabilitySelection`; the `thresholds`
        key is only present when thresholds were provided.
        """
        from ibm_watson_openscale.base_classes.watson_open_scale_v2 import (
            ApplicabilitySelection,
            MetricThreshold,
        )

        metric_payload = dict(
            name=self.name,
            applies_to=ApplicabilitySelection(problem_type=self.applies_to),
        )

        if self.thresholds is None:
            return metric_payload

        metric_payload["thresholds"] = [
            MetricThreshold(**item.to_dict()) for item in self.thresholds
        ]
        return metric_payload
1439
+
1440
+
1441
+ # ===== Metric Classes =====
1442
+ class WatsonxCustomMetric:
1443
+ """
1444
+ Provides functionality to set up a custom metric to measure your model's performance with IBM watsonx.governance.
1445
+
1446
+ Attributes:
1447
+ api_key (str): The API key for IBM watsonx.governance.
1448
+ region (str, optional): The region where IBM watsonx.governance is hosted when using IBM Cloud.
1449
+ Defaults to `us-south`.
1450
+ cpd_creds (CloudPakforDataCredentials, optional): IBM Cloud Pak for Data environment credentials.
1451
+
1452
+ Example:
1453
+ ```python
1454
+ from beekeeper.monitors.watsonx import (
1455
+ WatsonxCustomMetric,
1456
+ CloudPakforDataCredentials,
1457
+ )
1458
+
1459
+ # watsonx.governance (IBM Cloud)
1460
+ wxgov_client = WatsonxCustomMetric(api_key="API_KEY")
1461
+
1462
+ # watsonx.governance (CP4D)
1463
+ cpd_creds = CloudPakforDataCredentials(
1464
+ url="CPD_URL",
1465
+ username="USERNAME",
1466
+ password="PASSWORD",
1467
+ version="5.0",
1468
+ instance_id="openshift",
1469
+ )
1470
+
1471
+ wxgov_client = WatsonxCustomMetric(cpd_creds=cpd_creds)
1472
+ ```
1473
+ """
1474
+
1475
def __init__(
    self,
    api_key: str = None,
    region: Literal["us-south", "eu-de", "au-syd"] = "us-south",
    cpd_creds: CloudPakforDataCredentials | Dict = None,
) -> None:
    """
    Connects to IBM watsonx.governance (OpenScale).

    Args:
        api_key (str, optional): The API key used for IBM Cloud authentication.
        region (str, optional): The IBM Cloud region. Defaults to `us-south`.
        cpd_creds (CloudPakforDataCredentials | Dict, optional): Cloud Pak for Data
            credentials, either as a credentials object or as a plain dictionary.
            When provided, they are used instead of `api_key` / `region`.

    Raises:
        Exception: Any error raised while creating the client is logged and re-raised.
    """
    from ibm_cloud_sdk_core.authenticators import IAMAuthenticator  # type: ignore
    from ibm_watson_openscale import APIClient as WosAPIClient  # type: ignore

    self.region = region
    self._api_key = api_key
    self._wos_client = None

    if cpd_creds:
        # The signature allows a plain dict, which has no `to_dict()`.
        raw_creds = cpd_creds if isinstance(cpd_creds, dict) else cpd_creds.to_dict()
        self._wos_cpd_creds = _filter_dict(
            raw_creds,
            ["username", "password", "api_key", "disable_ssl_verification"],
            ["url"],
        )

    try:
        # `_wos_cpd_creds` only exists when CP4D credentials were provided.
        wos_cpd_creds = getattr(self, "_wos_cpd_creds", None)
        if wos_cpd_creds:
            from ibm_cloud_sdk_core.authenticators import (
                CloudPakForDataAuthenticator,  # type: ignore
            )

            authenticator = CloudPakForDataAuthenticator(**wos_cpd_creds)
            self._wos_client = WosAPIClient(
                authenticator=authenticator,
                service_url=wos_cpd_creds["url"],
            )

        else:
            authenticator = IAMAuthenticator(apikey=self._api_key)
            self._wos_client = WosAPIClient(
                authenticator=authenticator,
                service_url=REGIONS_URL[self.region]["wos"],
            )

    except Exception as e:
        logging.error(
            f"Error connecting to IBM watsonx.governance (openscale): {e}",
        )
        raise
1525
+
1526
def _add_integrated_system(
    self,
    credentials: IntegratedSystemCredentials,
    name: str,
    endpoint: str,
) -> str:
    """
    Registers a `custom_metrics_provider` integrated system.

    Args:
        credentials (IntegratedSystemCredentials): Credentials; sent as `credentials.to_dict()`.
        name (str): Used as both the system name and the connection display name.
        endpoint (str): The connection endpoint.

    Returns:
        str: The metadata ID of the created integrated system.
    """
    request = {
        "name": name,
        "description": "Integrated system created by Beekeeper.",
        "type": "custom_metrics_provider",
        "credentials": credentials.to_dict(),
        "connection": {"display_name": name, "endpoint": endpoint},
    }
    response = self._wos_client.integrated_systems.add(**request)

    return response.result.metadata.id
1541
+
1542
def _add_monitor_definitions(
    self,
    name: str,
    metrics: List[WatsonxMetric],
    schedule: bool,
):
    """
    Creates a monitor definition for unstructured-text input.

    Args:
        name (str): The name of the monitor definition.
        metrics (List[WatsonxMetric]): Metrics converted to `MonitorMetricRequest` objects.
        schedule (bool): When truthy, a `custom_metrics_provider` runtime and an
            hourly schedule starting after a 30-minute relative delay are attached;
            otherwise both are sent as `None`.

    Returns:
        The metadata ID of the created monitor definition.
    """
    from ibm_watson_openscale.base_classes.watson_open_scale_v2 import (
        ApplicabilitySelection,
        MonitorInstanceSchedule,
        MonitorMetricRequest,
        MonitorRuntime,
        ScheduleStartTime,
    )

    metric_requests = [MonitorMetricRequest(**item.to_dict()) for item in metrics]

    runtime, recurrence = None, None
    if schedule:
        runtime = MonitorRuntime(type="custom_metrics_provider")
        first_run = ScheduleStartTime(
            type="relative",
            delay_unit="minute",
            delay=30,
        )
        recurrence = MonitorInstanceSchedule(
            repeat_interval=1,
            repeat_unit="hour",
            start_time=first_run,
        )

    response = self._wos_client.monitor_definitions.add(
        name=name,
        metrics=metric_requests,
        tags=[],
        schedule=recurrence,
        applies_to=ApplicabilitySelection(input_data_type=["unstructured_text"]),
        monitor_runtime=runtime,
        background_mode=False,
    )

    return response.result.metadata.id
1583
+
1584
def _get_monitor_instance(self, subscription_id: str, monitor_definition_id: str):
    """
    Looks up the monitor instance of a definition for a subscription.

    Returns:
        The single matching monitor instance, or `None` when the listing
        contains zero or more than one instance.
    """
    listing = self._wos_client.monitor_instances.list(
        monitor_definition_id=monitor_definition_id,
        target_target_id=subscription_id,
    )
    matches = listing.result.monitor_instances

    return matches[0] if len(matches) == 1 else None
1594
+
1595
def _update_monitor_instance(
    self,
    integrated_system_id: str,
    custom_monitor_id: str,
):
    """
    Replaces the `/parameters` of a monitor instance (metadata-only update).

    Args:
        integrated_system_id (str): Stored as `custom_metrics_provider_id`.
        custom_monitor_id (str): The identifier passed to `monitor_instances.update`.

    Returns:
        The `result` of the update call.
    """
    new_parameters = {
        "custom_metrics_provider_id": integrated_system_id,
        "custom_metrics_wait_time": 60,
        "enable_custom_metric_runs": True,
    }
    patch = [dict(op="replace", path="/parameters", value=new_parameters)]

    response = self._wos_client.monitor_instances.update(
        custom_monitor_id,
        patch,
        update_metadata_only=True,
    )
    return response.result
1617
+
1618
def _get_patch_request_field(
    self,
    field_path: str,
    field_value: Any,
    op_name: str = "replace",
) -> Dict:
    """
    Builds one patch operation dict.

    Args:
        field_path (str): Value for the `path` key.
        field_value (Any): Value for the `value` key.
        op_name (str, optional): Value for the `op` key. Defaults to "replace".

    Returns:
        Dict: A dict with the keys `op`, `path` and `value`.
    """
    return dict(op=op_name, path=field_path, value=field_value)
1625
+
1626
def _get_dataset_id(
    self,
    subscription_id: str,
    data_set_type: Literal["feedback", "payload_logging"],
) -> str:
    """
    Returns the ID of the first data set of the given type for a subscription.

    Returns:
        str: The metadata ID of the first listed data set, or `None` when the
        listing is empty.
    """
    listing = self._wos_client.data_sets.list(
        target_target_id=subscription_id,
        type=data_set_type,
    )
    found = listing.result.data_sets

    if len(found) > 0:
        return found[0].metadata.id
    return None
1639
+
1640
def _get_dataset_data(self, data_set_id: str):
    """
    Returns the first record of a data set.

    Returns:
        The first entry under the `records` key of the listing, or `None` when
        that key is missing or empty.
    """
    response = self._wos_client.data_sets.get_list_of_records(
        data_set_id=data_set_id,
        format="list",
    )
    listing = response.result

    if listing.get("records"):
        return listing["records"][0]
    return None
1650
+
1651
def _get_existing_data_mart(self):
    """
    Returns the metadata ID of the first listed data mart.

    Raises:
        Exception: If the data mart listing is empty.
    """
    available = self._wos_client.data_marts.list().result.data_marts

    if len(available) != 0:
        return available[0].metadata.id

    raise Exception(
        "No data marts found. Please ensure at least one data mart is available.",
    )
1659
+
1660
+ # ===== Global Custom Metrics =====
1661
def add_metric_definition(
    self,
    name: str,
    metrics: List[WatsonxMetric],
    integrated_system_url: str,
    integrated_system_credentials: IntegratedSystemCredentials,
    schedule: bool = False,
):
    """
    Creates a custom monitor definition for IBM watsonx.governance.

    This must be done before using custom metrics.

    Args:
        name (str): The name of the custom metric group.
        metrics (List[WatsonxMetric]): A list of metrics to be measured.
        integrated_system_url (str): The URL of the external metric provider.
        integrated_system_credentials (IntegratedSystemCredentials): The credentials for the integrated system.
        schedule (bool, optional): Enable or disable the scheduler. Defaults to `False`.

    Returns:
        A dict with `integrated_system_id` and `monitor_definition_id`.

    Example:
        ```python
        from beekeeper.monitors.watsonx import (
            WatsonxMetric,
            IntegratedSystemCredentials,
            WatsonxMetricThreshold,
        )

        wxgov_client.add_metric_definition(
            name="Custom Metric - Custom LLM Quality",
            metrics=[
                WatsonxMetric(
                    name="context_quality",
                    applies_to=[
                        "retrieval_augmented_generation",
                        "summarization",
                    ],
                    thresholds=[
                        WatsonxMetricThreshold(
                            threshold_type="lower_limit", default_value=0.75
                        )
                    ],
                )
            ],
            integrated_system_url="IS_URL",  # URL to the endpoint computing the metric
            integrated_system_credentials=IntegratedSystemCredentials(
                auth_type="basic", username="USERNAME", password="PASSWORD"
            ),
        )
        ```
    """
    provider_id = self._add_integrated_system(
        integrated_system_credentials,
        name,
        integrated_system_url,
    )

    definition_id = suppress_output(
        self._add_monitor_definitions,
        name,
        metrics,
        schedule,
    )

    # Associate the external monitor with the integrated system
    link_patch = [
        dict(
            op="add",
            path="/parameters",
            value={"monitor_definition_ids": [definition_id]},
        ),
    ]
    self._wos_client.integrated_systems.update(provider_id, link_patch)

    return dict(
        integrated_system_id=provider_id,
        monitor_definition_id=definition_id,
    )
1740
+
1741
@deprecated(
    reason="'add_observer_instance()' is deprecated and will be removed in a future version. Use 'add_monitor_instance()' from 'beekeeper-monitors-watsonx' instead.",
    version="1.0.5",
    action="always",
)
def add_observer_instance(
    self,
    integrated_system_id: str,
    monitor_definition_id: str,
    subscription_id: str,
):
    """
    Deprecated alias of `add_monitor_instance()`.

    The three identifiers are forwarded unchanged by keyword and the result of
    `add_monitor_instance()` is returned as-is.
    """
    forwarded = {
        "integrated_system_id": integrated_system_id,
        "monitor_definition_id": monitor_definition_id,
        "subscription_id": subscription_id,
    }
    return self.add_monitor_instance(**forwarded)
1757
+
1758
def add_monitor_instance(
    self,
    integrated_system_id: str,
    monitor_definition_id: str,
    subscription_id: str,
):
    """
    Enables a custom monitor for the specified subscription and monitor definition.

    When the subscription already has exactly one instance of the definition,
    that instance's parameters are updated; otherwise a new instance is created.

    Args:
        integrated_system_id (str): The ID of the integrated system.
        monitor_definition_id (str): The ID of the custom monitor definition.
        subscription_id (str): The ID of the subscription to associate the monitor with.

    Returns:
        The `result` of the create or update call.

    Raises:
        Exception: If no data mart is available.

    Example:
        ```python
        wxgov_client.add_monitor_instance(
            integrated_system_id="019667ca-5687-7838-8d29-4ff70c2b36b0",
            monitor_definition_id="custom_llm_quality",
            subscription_id="0195e95d-03a4-7000-b954-b607db10fe9e",
        )
        ```
    """
    # Uses the class helper rather than repeating the lookup inline. It is
    # still called first, so a missing data mart is reported on both paths.
    data_mart_id = self._get_existing_data_mart()

    existing_monitor_instance = self._get_monitor_instance(
        subscription_id,
        monitor_definition_id,
    )

    if existing_monitor_instance is not None:
        return self._update_monitor_instance(
            integrated_system_id,
            existing_monitor_instance.metadata.id,
        )

    # Only the create path needs the SDK `Target` class.
    from ibm_watson_openscale.base_classes.watson_open_scale_v2 import Target

    target = Target(target_type="subscription", target_id=subscription_id)

    parameters = {
        "custom_metrics_provider_id": integrated_system_id,
        "custom_metrics_wait_time": 60,
        "enable_custom_metric_runs": True,
    }

    return suppress_output(
        self._wos_client.monitor_instances.create,
        data_mart_id=data_mart_id,
        background_mode=False,
        monitor_definition_id=monitor_definition_id,
        target=target,
        parameters=parameters,
    ).result
1820
+
1821
def publish_metrics(
    self,
    monitor_instance_id: str,
    run_id: str,
    request_records: Dict[str, Union[float, int]],
):
    """
    Publishes computed custom metrics for a specific global monitor instance.

    The metrics are added as one measurement, then the monitoring run is patched
    to the `finished` state with a completion timestamp.

    Args:
        monitor_instance_id (str): The unique ID of the monitor instance.
        run_id (str): The ID of the monitor run that generated the metrics.
        request_records (Dict[str, float | int]): Dict containing the metrics to be published.

    Returns:
        The `result` of the run update call.

    Example:
        ```python
        wxgov_client.publish_metrics(
            monitor_instance_id="01966801-f9ee-7248-a706-41de00a8a998",
            run_id="RUN_ID",
            request_records={"context_quality": 0.914, "sensitivity": 0.85},
        )
        ```
    """
    from ibm_watson_openscale.base_classes.watson_open_scale_v2 import (
        MonitorMeasurementRequest,
        Runs,
    )

    timestamp_format = "%Y-%m-%dT%H:%M:%S.%fZ"

    measured_at = datetime.datetime.now(datetime.timezone.utc).strftime(
        timestamp_format,
    )
    measurement = MonitorMeasurementRequest(
        timestamp=measured_at,
        run_id=run_id,
        metrics=[request_records],
    )

    # The result of adding the measurement is not used.
    self._wos_client.monitor_instances.add_measurements(
        monitor_instance_id=monitor_instance_id,
        monitor_measurement_request=[measurement],
    ).result

    run_api = Runs(watson_open_scale=self._wos_client)

    # The completion time is taken after the measurement has been added.
    completed_at = datetime.datetime.now(datetime.timezone.utc).strftime(
        timestamp_format,
    )
    run_patch = [
        self._get_patch_request_field("/status/state", "finished"),
        self._get_patch_request_field("/status/completed_at", completed_at),
    ]

    return run_api.update(
        monitor_instance_id=monitor_instance_id,
        monitoring_run_id=run_id,
        json_patch_operation=run_patch,
    ).result
1879
+
1880
+ # ===== Local Custom Metrics =====
1881
def add_local_metric_definition(
    self,
    name: str,
    metrics: List[WatsonxLocalMetric],
    subscription_id: str,
) -> str:
    """
    Creates a custom metric definition to compute metrics at the local (transaction) level for IBM watsonx.governance.

    A `custom` data set is added for the subscription. Its schema holds the
    string columns `scoring_id`, `run_id` and `computed_on`, followed by one
    column per metric.

    Args:
        name (str): The name of the custom transaction metric group.
        metrics (List[WatsonxLocalMetric]): A list of metrics to be monitored at the local (transaction) level.
        subscription_id (str): The IBM watsonx.governance subscription ID associated with the metric definition.

    Returns:
        str: The metadata ID of the created data set.

    Example:
        ```python
        from beekeeper.monitors.watsonx import WatsonxLocalMetric

        wxgov_client.add_local_metric_definition(
            name="Custom LLM Local Metric",
            subscription_id="019674ca-0c38-745f-8e9b-58546e95174e",
            metrics=[
                WatsonxLocalMetric(name="context_quality", data_type="double")
            ],
        )
        ```
    """
    from ibm_watson_openscale.base_classes.watson_open_scale_v2 import (
        LocationTableName,
        SparkStruct,
        SparkStructFieldPrimitive,
        Target,
    )

    subscription_target = Target(
        target_id=subscription_id, target_type="subscription"
    )
    data_mart_id = self._get_existing_data_mart()
    metric_fields = [
        SparkStructFieldPrimitive(**item.to_dict()) for item in metrics
    ]

    # Fixed bookkeeping columns as (column name, nullable), all of type "string".
    fixed_columns = (
        ("scoring_id", False),
        ("run_id", True),
        ("computed_on", False),
    )
    schema_fields = [
        SparkStructFieldPrimitive(name=column, type="string", nullable=is_nullable)
        for column, is_nullable in fixed_columns
    ]
    schema_fields += metric_fields

    data_schema = SparkStruct(type="struct", fields=schema_fields)

    # Table name: lower-cased name with spaces replaced by underscores, plus
    # the first 8 characters of a random UUID.
    table_name = name.lower().replace(" ", "_") + "_" + str(uuid.uuid4())[:8]

    created = self._wos_client.data_sets.add(
        target=subscription_target,
        name=name,
        type="custom",
        data_schema=data_schema,
        data_mart_id=data_mart_id,
        location=LocationTableName(table_name=table_name),
        background_mode=False,
    )

    return created.result.metadata.id
1952
+
1953
def publish_local_metrics(
    self,
    metric_instance_id: str,
    request_records: List[Dict],
):
    """
    Stores computed custom metric values for transaction records.

    The records are written to the data set identified by ``metric_instance_id``,
    and the ``result`` of the store operation is returned.

    Args:
        metric_instance_id (str): The unique ID of the custom transaction metric.
        request_records (List[Dict]): A list of dictionaries containing the records to be stored.

    Example:
        ```python
        wxgov_client.publish_local_metrics(
            metric_instance_id="0196ad39-1b75-7e77-bddb-cc5393d575c2",
            request_records=[
                {
                    "scoring_id": "304a9270-44a1-4c4d-bfd4-f756541011f8",
                    "run_id": "RUN_ID",
                    "computed_on": "payload",
                    "context_quality": 0.786,
                }
            ],
        )
        ```
    """
    store_response = self._wos_client.data_sets.store_records(
        request_body=request_records,
        data_set_id=metric_instance_id,
    )
    return store_response.result
1984
+
1985
def list_local_metrics(
    self,
    metric_instance_id: str,
):
    """
    Retrieves the records stored for a custom local metric definition.

    Delegates to ``_get_dataset_data`` with the given ID and returns its result unchanged.

    Args:
        metric_instance_id (str): The unique ID of the custom transaction metric.

    Example:
        ```python
        wxgov_client.list_local_metrics(
            metric_instance_id="0196ad47-c505-73c0-9d7b-91c082b697e3"
        )
        ```
    """
    stored_records = self._get_dataset_data(metric_instance_id)
    return stored_records