ibm-watsonx-gov 1.3.3__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ibm_watsonx_gov/__init__.py +8 -0
- ibm_watsonx_gov/agent_catalog/__init__.py +8 -0
- ibm_watsonx_gov/agent_catalog/clients/__init__.py +14 -0
- ibm_watsonx_gov/agent_catalog/clients/ai_agent_client.py +333 -0
- ibm_watsonx_gov/agent_catalog/core/__init__.py +8 -0
- ibm_watsonx_gov/agent_catalog/core/agent_loader.py +202 -0
- ibm_watsonx_gov/agent_catalog/core/agents.py +134 -0
- ibm_watsonx_gov/agent_catalog/entities/__init__.py +8 -0
- ibm_watsonx_gov/agent_catalog/entities/ai_agent.py +599 -0
- ibm_watsonx_gov/agent_catalog/utils/__init__.py +8 -0
- ibm_watsonx_gov/agent_catalog/utils/constants.py +36 -0
- ibm_watsonx_gov/agent_catalog/utils/notebook_utils.py +70 -0
- ibm_watsonx_gov/ai_experiments/__init__.py +8 -0
- ibm_watsonx_gov/ai_experiments/ai_experiments_client.py +980 -0
- ibm_watsonx_gov/ai_experiments/utils/__init__.py +8 -0
- ibm_watsonx_gov/ai_experiments/utils/ai_experiment_utils.py +139 -0
- ibm_watsonx_gov/clients/__init__.py +0 -0
- ibm_watsonx_gov/clients/api_client.py +99 -0
- ibm_watsonx_gov/clients/segment_client.py +46 -0
- ibm_watsonx_gov/clients/usage_client.cp313-win_amd64.pyd +0 -0
- ibm_watsonx_gov/clients/wx_ai_client.py +87 -0
- ibm_watsonx_gov/config/__init__.py +14 -0
- ibm_watsonx_gov/config/agentic_ai_configuration.py +225 -0
- ibm_watsonx_gov/config/gen_ai_configuration.py +129 -0
- ibm_watsonx_gov/config/model_risk_configuration.py +173 -0
- ibm_watsonx_gov/config/predictive_ai_configuration.py +20 -0
- ibm_watsonx_gov/entities/__init__.py +8 -0
- ibm_watsonx_gov/entities/agentic_app.py +209 -0
- ibm_watsonx_gov/entities/agentic_evaluation_result.py +185 -0
- ibm_watsonx_gov/entities/ai_evaluation.py +290 -0
- ibm_watsonx_gov/entities/ai_experiment.py +419 -0
- ibm_watsonx_gov/entities/base_classes.py +134 -0
- ibm_watsonx_gov/entities/container.py +54 -0
- ibm_watsonx_gov/entities/credentials.py +633 -0
- ibm_watsonx_gov/entities/criteria.py +508 -0
- ibm_watsonx_gov/entities/enums.py +274 -0
- ibm_watsonx_gov/entities/evaluation_result.py +444 -0
- ibm_watsonx_gov/entities/foundation_model.py +490 -0
- ibm_watsonx_gov/entities/llm_judge.py +44 -0
- ibm_watsonx_gov/entities/locale.py +17 -0
- ibm_watsonx_gov/entities/mapping.py +49 -0
- ibm_watsonx_gov/entities/metric.py +211 -0
- ibm_watsonx_gov/entities/metric_threshold.py +36 -0
- ibm_watsonx_gov/entities/model_provider.py +329 -0
- ibm_watsonx_gov/entities/model_risk_result.py +43 -0
- ibm_watsonx_gov/entities/monitor.py +71 -0
- ibm_watsonx_gov/entities/prompt_setup.py +40 -0
- ibm_watsonx_gov/entities/state.py +22 -0
- ibm_watsonx_gov/entities/utils.py +99 -0
- ibm_watsonx_gov/evaluators/__init__.py +26 -0
- ibm_watsonx_gov/evaluators/agentic_evaluator.py +2725 -0
- ibm_watsonx_gov/evaluators/agentic_traces_evaluator.py +115 -0
- ibm_watsonx_gov/evaluators/base_evaluator.py +22 -0
- ibm_watsonx_gov/evaluators/impl/__init__.py +0 -0
- ibm_watsonx_gov/evaluators/impl/evaluate_metrics_impl.cp313-win_amd64.pyd +0 -0
- ibm_watsonx_gov/evaluators/impl/evaluate_model_risk_impl.cp313-win_amd64.pyd +0 -0
- ibm_watsonx_gov/evaluators/metrics_evaluator.py +187 -0
- ibm_watsonx_gov/evaluators/model_risk_evaluator.py +89 -0
- ibm_watsonx_gov/evaluators/traces_evaluator.py +93 -0
- ibm_watsonx_gov/metric_groups/answer_quality/answer_quality_decorator.py +66 -0
- ibm_watsonx_gov/metric_groups/content_safety/content_safety_decorator.py +76 -0
- ibm_watsonx_gov/metric_groups/readability/readability_decorator.py +59 -0
- ibm_watsonx_gov/metric_groups/retrieval_quality/retrieval_quality_decorator.py +63 -0
- ibm_watsonx_gov/metric_groups/usage/usage_decorator.py +58 -0
- ibm_watsonx_gov/metrics/__init__.py +74 -0
- ibm_watsonx_gov/metrics/answer_relevance/__init__.py +8 -0
- ibm_watsonx_gov/metrics/answer_relevance/answer_relevance_decorator.py +63 -0
- ibm_watsonx_gov/metrics/answer_relevance/answer_relevance_metric.py +260 -0
- ibm_watsonx_gov/metrics/answer_similarity/__init__.py +0 -0
- ibm_watsonx_gov/metrics/answer_similarity/answer_similarity_decorator.py +66 -0
- ibm_watsonx_gov/metrics/answer_similarity/answer_similarity_metric.py +219 -0
- ibm_watsonx_gov/metrics/average_precision/__init__.py +0 -0
- ibm_watsonx_gov/metrics/average_precision/average_precision_decorator.py +62 -0
- ibm_watsonx_gov/metrics/average_precision/average_precision_metric.py +174 -0
- ibm_watsonx_gov/metrics/base_metric_decorator.py +193 -0
- ibm_watsonx_gov/metrics/context_relevance/__init__.py +8 -0
- ibm_watsonx_gov/metrics/context_relevance/context_relevance_decorator.py +60 -0
- ibm_watsonx_gov/metrics/context_relevance/context_relevance_metric.py +414 -0
- ibm_watsonx_gov/metrics/cost/__init__.py +8 -0
- ibm_watsonx_gov/metrics/cost/cost_decorator.py +58 -0
- ibm_watsonx_gov/metrics/cost/cost_metric.py +155 -0
- ibm_watsonx_gov/metrics/duration/__init__.py +8 -0
- ibm_watsonx_gov/metrics/duration/duration_decorator.py +59 -0
- ibm_watsonx_gov/metrics/duration/duration_metric.py +111 -0
- ibm_watsonx_gov/metrics/evasiveness/__init__.py +8 -0
- ibm_watsonx_gov/metrics/evasiveness/evasiveness_decorator.py +61 -0
- ibm_watsonx_gov/metrics/evasiveness/evasiveness_metric.py +103 -0
- ibm_watsonx_gov/metrics/faithfulness/__init__.py +8 -0
- ibm_watsonx_gov/metrics/faithfulness/faithfulness_decorator.py +65 -0
- ibm_watsonx_gov/metrics/faithfulness/faithfulness_metric.py +254 -0
- ibm_watsonx_gov/metrics/hap/__init__.py +16 -0
- ibm_watsonx_gov/metrics/hap/hap_decorator.py +58 -0
- ibm_watsonx_gov/metrics/hap/hap_metric.py +98 -0
- ibm_watsonx_gov/metrics/hap/input_hap_metric.py +104 -0
- ibm_watsonx_gov/metrics/hap/output_hap_metric.py +110 -0
- ibm_watsonx_gov/metrics/harm/__init__.py +8 -0
- ibm_watsonx_gov/metrics/harm/harm_decorator.py +60 -0
- ibm_watsonx_gov/metrics/harm/harm_metric.py +103 -0
- ibm_watsonx_gov/metrics/harm_engagement/__init__.py +8 -0
- ibm_watsonx_gov/metrics/harm_engagement/harm_engagement_decorator.py +61 -0
- ibm_watsonx_gov/metrics/harm_engagement/harm_engagement_metric.py +103 -0
- ibm_watsonx_gov/metrics/hit_rate/__init__.py +0 -0
- ibm_watsonx_gov/metrics/hit_rate/hit_rate_decorator.py +59 -0
- ibm_watsonx_gov/metrics/hit_rate/hit_rate_metric.py +167 -0
- ibm_watsonx_gov/metrics/input_token_count/__init__.py +8 -0
- ibm_watsonx_gov/metrics/input_token_count/input_token_count_decorator.py +58 -0
- ibm_watsonx_gov/metrics/input_token_count/input_token_count_metric.py +112 -0
- ibm_watsonx_gov/metrics/jailbreak/__init__.py +8 -0
- ibm_watsonx_gov/metrics/jailbreak/jailbreak_decorator.py +60 -0
- ibm_watsonx_gov/metrics/jailbreak/jailbreak_metric.py +103 -0
- ibm_watsonx_gov/metrics/keyword_detection/keyword_detection_decorator.py +58 -0
- ibm_watsonx_gov/metrics/keyword_detection/keyword_detection_metric.py +111 -0
- ibm_watsonx_gov/metrics/llm_validation/__init__.py +8 -0
- ibm_watsonx_gov/metrics/llm_validation/evaluation_criteria.py +84 -0
- ibm_watsonx_gov/metrics/llm_validation/llm_validation_constants.py +24 -0
- ibm_watsonx_gov/metrics/llm_validation/llm_validation_decorator.py +54 -0
- ibm_watsonx_gov/metrics/llm_validation/llm_validation_impl.py +525 -0
- ibm_watsonx_gov/metrics/llm_validation/llm_validation_metric.py +258 -0
- ibm_watsonx_gov/metrics/llm_validation/llm_validation_prompts.py +106 -0
- ibm_watsonx_gov/metrics/llmaj/__init__.py +0 -0
- ibm_watsonx_gov/metrics/llmaj/llmaj_metric.py +298 -0
- ibm_watsonx_gov/metrics/ndcg/__init__.py +0 -0
- ibm_watsonx_gov/metrics/ndcg/ndcg_decorator.py +61 -0
- ibm_watsonx_gov/metrics/ndcg/ndcg_metric.py +166 -0
- ibm_watsonx_gov/metrics/output_token_count/__init__.py +8 -0
- ibm_watsonx_gov/metrics/output_token_count/output_token_count_decorator.py +58 -0
- ibm_watsonx_gov/metrics/output_token_count/output_token_count_metric.py +112 -0
- ibm_watsonx_gov/metrics/pii/__init__.py +16 -0
- ibm_watsonx_gov/metrics/pii/input_pii_metric.py +102 -0
- ibm_watsonx_gov/metrics/pii/output_pii_metric.py +107 -0
- ibm_watsonx_gov/metrics/pii/pii_decorator.py +59 -0
- ibm_watsonx_gov/metrics/pii/pii_metric.py +96 -0
- ibm_watsonx_gov/metrics/profanity/__init__.py +8 -0
- ibm_watsonx_gov/metrics/profanity/profanity_decorator.py +60 -0
- ibm_watsonx_gov/metrics/profanity/profanity_metric.py +103 -0
- ibm_watsonx_gov/metrics/prompt_safety_risk/__init__.py +8 -0
- ibm_watsonx_gov/metrics/prompt_safety_risk/prompt_safety_risk_decorator.py +57 -0
- ibm_watsonx_gov/metrics/prompt_safety_risk/prompt_safety_risk_metric.py +128 -0
- ibm_watsonx_gov/metrics/reciprocal_rank/__init__.py +0 -0
- ibm_watsonx_gov/metrics/reciprocal_rank/reciprocal_rank_decorator.py +62 -0
- ibm_watsonx_gov/metrics/reciprocal_rank/reciprocal_rank_metric.py +162 -0
- ibm_watsonx_gov/metrics/regex_detection/regex_detection_decorator.py +58 -0
- ibm_watsonx_gov/metrics/regex_detection/regex_detection_metric.py +106 -0
- ibm_watsonx_gov/metrics/retrieval_precision/__init__.py +0 -0
- ibm_watsonx_gov/metrics/retrieval_precision/retrieval_precision_decorator.py +62 -0
- ibm_watsonx_gov/metrics/retrieval_precision/retrieval_precision_metric.py +170 -0
- ibm_watsonx_gov/metrics/sexual_content/__init__.py +8 -0
- ibm_watsonx_gov/metrics/sexual_content/sexual_content_decorator.py +61 -0
- ibm_watsonx_gov/metrics/sexual_content/sexual_content_metric.py +103 -0
- ibm_watsonx_gov/metrics/social_bias/__init__.py +8 -0
- ibm_watsonx_gov/metrics/social_bias/social_bias_decorator.py +62 -0
- ibm_watsonx_gov/metrics/social_bias/social_bias_metric.py +103 -0
- ibm_watsonx_gov/metrics/status/__init__.py +0 -0
- ibm_watsonx_gov/metrics/status/status_metric.py +113 -0
- ibm_watsonx_gov/metrics/text_grade_level/__init__.py +8 -0
- ibm_watsonx_gov/metrics/text_grade_level/text_grade_level_decorator.py +59 -0
- ibm_watsonx_gov/metrics/text_grade_level/text_grade_level_metric.py +127 -0
- ibm_watsonx_gov/metrics/text_reading_ease/__init__.py +8 -0
- ibm_watsonx_gov/metrics/text_reading_ease/text_reading_ease_decorator.py +59 -0
- ibm_watsonx_gov/metrics/text_reading_ease/text_reading_ease_metric.py +123 -0
- ibm_watsonx_gov/metrics/tool_call_accuracy/__init__.py +0 -0
- ibm_watsonx_gov/metrics/tool_call_accuracy/tool_call_accuracy_decorator.py +67 -0
- ibm_watsonx_gov/metrics/tool_call_accuracy/tool_call_accuracy_metric.py +162 -0
- ibm_watsonx_gov/metrics/tool_call_parameter_accuracy/__init__.py +0 -0
- ibm_watsonx_gov/metrics/tool_call_parameter_accuracy/tool_call_parameter_accuracy_decorator.py +68 -0
- ibm_watsonx_gov/metrics/tool_call_parameter_accuracy/tool_call_parameter_accuracy_metric.py +151 -0
- ibm_watsonx_gov/metrics/tool_call_relevance/__init__.py +0 -0
- ibm_watsonx_gov/metrics/tool_call_relevance/tool_call_relevance_decorator.py +71 -0
- ibm_watsonx_gov/metrics/tool_call_relevance/tool_call_relevance_metric.py +166 -0
- ibm_watsonx_gov/metrics/tool_call_syntactic_accuracy/__init__.py +0 -0
- ibm_watsonx_gov/metrics/tool_call_syntactic_accuracy/tool_call_syntactic_accuracy_decorator.py +66 -0
- ibm_watsonx_gov/metrics/tool_call_syntactic_accuracy/tool_call_syntactic_accuracy_metric.py +121 -0
- ibm_watsonx_gov/metrics/topic_relevance/__init__.py +8 -0
- ibm_watsonx_gov/metrics/topic_relevance/topic_relevance_decorator.py +57 -0
- ibm_watsonx_gov/metrics/topic_relevance/topic_relevance_metric.py +106 -0
- ibm_watsonx_gov/metrics/unethical_behavior/__init__.py +8 -0
- ibm_watsonx_gov/metrics/unethical_behavior/unethical_behavior_decorator.py +61 -0
- ibm_watsonx_gov/metrics/unethical_behavior/unethical_behavior_metric.py +103 -0
- ibm_watsonx_gov/metrics/unsuccessful_requests/__init__.py +0 -0
- ibm_watsonx_gov/metrics/unsuccessful_requests/unsuccessful_requests_decorator.py +66 -0
- ibm_watsonx_gov/metrics/unsuccessful_requests/unsuccessful_requests_metric.py +128 -0
- ibm_watsonx_gov/metrics/user_id/__init__.py +0 -0
- ibm_watsonx_gov/metrics/user_id/user_id_metric.py +111 -0
- ibm_watsonx_gov/metrics/utils.py +440 -0
- ibm_watsonx_gov/metrics/violence/__init__.py +8 -0
- ibm_watsonx_gov/metrics/violence/violence_decorator.py +60 -0
- ibm_watsonx_gov/metrics/violence/violence_metric.py +103 -0
- ibm_watsonx_gov/prompt_evaluator/__init__.py +9 -0
- ibm_watsonx_gov/prompt_evaluator/impl/__init__.py +8 -0
- ibm_watsonx_gov/prompt_evaluator/impl/prompt_evaluator_impl.py +554 -0
- ibm_watsonx_gov/prompt_evaluator/impl/pta_lifecycle_evaluator.py +2332 -0
- ibm_watsonx_gov/prompt_evaluator/prompt_evaluator.py +262 -0
- ibm_watsonx_gov/providers/__init__.py +8 -0
- ibm_watsonx_gov/providers/detectors_provider.cp313-win_amd64.pyd +0 -0
- ibm_watsonx_gov/providers/detectors_provider.py +415 -0
- ibm_watsonx_gov/providers/eval_assist_provider.cp313-win_amd64.pyd +0 -0
- ibm_watsonx_gov/providers/eval_assist_provider.py +266 -0
- ibm_watsonx_gov/providers/inference_engines/__init__.py +0 -0
- ibm_watsonx_gov/providers/inference_engines/custom_inference_engine.py +165 -0
- ibm_watsonx_gov/providers/inference_engines/portkey_inference_engine.py +57 -0
- ibm_watsonx_gov/providers/llmevalkit/__init__.py +0 -0
- ibm_watsonx_gov/providers/llmevalkit/ciso_agent/main.py +516 -0
- ibm_watsonx_gov/providers/llmevalkit/ciso_agent/preprocess_log.py +111 -0
- ibm_watsonx_gov/providers/llmevalkit/ciso_agent/utils.py +186 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/README.md +411 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/__init__.py +27 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/comparison/README.md +306 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/comparison/__init__.py +89 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/comparison/comparators/__init__.py +30 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/comparison/comparators/base.py +411 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/comparison/comparators/code_agent.py +1254 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/comparison/comparators/exact_match.py +134 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/comparison/comparators/fuzzy_string.py +104 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/comparison/comparators/hybrid.py +516 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/comparison/comparators/llm_judge.py +1882 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/comparison/pipeline.py +387 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/comparison/types.py +178 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/comparison/utils.py +298 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/consts.py +33 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/metrics/__init__.py +31 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/metrics/base.py +26 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/metrics/function_call/__init__.py +4 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/metrics/function_call/general.py +46 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/metrics/function_call/general_metrics.json +783 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/metrics/function_call/general_metrics_runtime.json +580 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/metrics/function_selection/__init__.py +6 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/metrics/function_selection/function_selection.py +28 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/metrics/function_selection/function_selection_metrics.json +599 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/metrics/function_selection/function_selection_metrics_runtime.json +477 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/metrics/loader.py +259 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/metrics/parameter/__init__.py +7 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/metrics/parameter/parameter.py +52 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/metrics/parameter/parameter_metrics.json +613 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/metrics/parameter/parameter_metrics_runtime.json +489 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/metrics/trajectory/__init__.py +7 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/metrics/trajectory/trajectory.py +43 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/metrics/trajectory/trajectory_metrics.json +161 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/pipeline/__init__.py +0 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/pipeline/adapters.py +102 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/pipeline/pipeline.py +355 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/pipeline/semantic_checker.py +816 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/pipeline/static_checker.py +297 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/pipeline/transformation_prompts.py +509 -0
- ibm_watsonx_gov/providers/llmevalkit/function_calling/pipeline/types.py +596 -0
- ibm_watsonx_gov/providers/llmevalkit/llm/README.md +375 -0
- ibm_watsonx_gov/providers/llmevalkit/llm/__init__.py +137 -0
- ibm_watsonx_gov/providers/llmevalkit/llm/base.py +426 -0
- ibm_watsonx_gov/providers/llmevalkit/llm/output_parser.py +364 -0
- ibm_watsonx_gov/providers/llmevalkit/llm/providers/__init__.py +0 -0
- ibm_watsonx_gov/providers/llmevalkit/llm/providers/consts.py +7 -0
- ibm_watsonx_gov/providers/llmevalkit/llm/providers/ibm_watsonx_ai/__init__.py +0 -0
- ibm_watsonx_gov/providers/llmevalkit/llm/providers/ibm_watsonx_ai/ibm_watsonx_ai.py +656 -0
- ibm_watsonx_gov/providers/llmevalkit/llm/providers/litellm/__init__.py +0 -0
- ibm_watsonx_gov/providers/llmevalkit/llm/providers/litellm/litellm.py +509 -0
- ibm_watsonx_gov/providers/llmevalkit/llm/providers/litellm/rits.py +224 -0
- ibm_watsonx_gov/providers/llmevalkit/llm/providers/litellm/watsonx.py +60 -0
- ibm_watsonx_gov/providers/llmevalkit/llm/providers/mock_llm_client.py +75 -0
- ibm_watsonx_gov/providers/llmevalkit/llm/providers/openai/__init__.py +0 -0
- ibm_watsonx_gov/providers/llmevalkit/llm/providers/openai/openai.py +639 -0
- ibm_watsonx_gov/providers/llmevalkit/llm/providers/wxo_ai_gateway/__init__.py +0 -0
- ibm_watsonx_gov/providers/llmevalkit/llm/providers/wxo_ai_gateway/wxo_ai_gateway.py +134 -0
- ibm_watsonx_gov/providers/llmevalkit/llm/providers/wxo_ai_gateway/wxo_ai_gateway_inference.py +214 -0
- ibm_watsonx_gov/providers/llmevalkit/llm/types.py +136 -0
- ibm_watsonx_gov/providers/llmevalkit/metrics/__init__.py +4 -0
- ibm_watsonx_gov/providers/llmevalkit/metrics/field.py +255 -0
- ibm_watsonx_gov/providers/llmevalkit/metrics/metric.py +332 -0
- ibm_watsonx_gov/providers/llmevalkit/metrics/metrics_runner.py +188 -0
- ibm_watsonx_gov/providers/llmevalkit/metrics/prompt.py +403 -0
- ibm_watsonx_gov/providers/llmevalkit/metrics/utils.py +46 -0
- ibm_watsonx_gov/providers/llmevalkit/prompt/__init__.py +0 -0
- ibm_watsonx_gov/providers/llmevalkit/prompt/runner.py +144 -0
- ibm_watsonx_gov/providers/tool_call_metric_provider.py +455 -0
- ibm_watsonx_gov/providers/unitxt_provider.cp313-win_amd64.pyd +0 -0
- ibm_watsonx_gov/tools/__init__.py +10 -0
- ibm_watsonx_gov/tools/clients/__init__.py +11 -0
- ibm_watsonx_gov/tools/clients/ai_tool_client.py +405 -0
- ibm_watsonx_gov/tools/clients/detector_client.py +82 -0
- ibm_watsonx_gov/tools/core/__init__.py +8 -0
- ibm_watsonx_gov/tools/core/tool_loader.py +237 -0
- ibm_watsonx_gov/tools/entities/__init__.py +8 -0
- ibm_watsonx_gov/tools/entities/ai_tools.py +435 -0
- ibm_watsonx_gov/tools/onboarding/create/answer_relevance_detector.json +57 -0
- ibm_watsonx_gov/tools/onboarding/create/chromadb_retrieval_tool.json +63 -0
- ibm_watsonx_gov/tools/onboarding/create/context_relevance_detector.json +57 -0
- ibm_watsonx_gov/tools/onboarding/create/duduckgo_search_tool.json +53 -0
- ibm_watsonx_gov/tools/onboarding/create/google_search_tool.json +62 -0
- ibm_watsonx_gov/tools/onboarding/create/hap_detector.json +70 -0
- ibm_watsonx_gov/tools/onboarding/create/jailbreak_detector.json +70 -0
- ibm_watsonx_gov/tools/onboarding/create/pii_detector.json +36 -0
- ibm_watsonx_gov/tools/onboarding/create/prompt_safety_risk_detector.json +69 -0
- ibm_watsonx_gov/tools/onboarding/create/topic_relevance_detector.json +57 -0
- ibm_watsonx_gov/tools/onboarding/create/weather_tool.json +39 -0
- ibm_watsonx_gov/tools/onboarding/create/webcrawler_tool.json +34 -0
- ibm_watsonx_gov/tools/onboarding/create/wikipedia_search_tool.json +53 -0
- ibm_watsonx_gov/tools/onboarding/delete/delete_tools.json +4 -0
- ibm_watsonx_gov/tools/onboarding/update/google_search_tool.json +38 -0
- ibm_watsonx_gov/tools/ootb/__init__.py +8 -0
- ibm_watsonx_gov/tools/ootb/detectors/__init__.py +8 -0
- ibm_watsonx_gov/tools/ootb/detectors/hap_detector_tool.py +109 -0
- ibm_watsonx_gov/tools/ootb/detectors/jailbreak_detector_tool.py +104 -0
- ibm_watsonx_gov/tools/ootb/detectors/pii_detector_tool.py +83 -0
- ibm_watsonx_gov/tools/ootb/detectors/prompt_safety_risk_detector_tool.py +111 -0
- ibm_watsonx_gov/tools/ootb/detectors/topic_relevance_detector_tool.py +101 -0
- ibm_watsonx_gov/tools/ootb/rag/__init__.py +8 -0
- ibm_watsonx_gov/tools/ootb/rag/answer_relevance_detector_tool.py +119 -0
- ibm_watsonx_gov/tools/ootb/rag/context_relevance_detector_tool.py +118 -0
- ibm_watsonx_gov/tools/ootb/search/__init__.py +8 -0
- ibm_watsonx_gov/tools/ootb/search/duckduckgo_search_tool.py +62 -0
- ibm_watsonx_gov/tools/ootb/search/google_search_tool.py +105 -0
- ibm_watsonx_gov/tools/ootb/search/weather_tool.py +95 -0
- ibm_watsonx_gov/tools/ootb/search/web_crawler_tool.py +69 -0
- ibm_watsonx_gov/tools/ootb/search/wikipedia_search_tool.py +63 -0
- ibm_watsonx_gov/tools/ootb/vectordb/__init__.py +8 -0
- ibm_watsonx_gov/tools/ootb/vectordb/chromadb_retriever_tool.py +111 -0
- ibm_watsonx_gov/tools/rest_api/__init__.py +10 -0
- ibm_watsonx_gov/tools/rest_api/restapi_tool.py +72 -0
- ibm_watsonx_gov/tools/schemas/__init__.py +10 -0
- ibm_watsonx_gov/tools/schemas/search_tool_schema.py +46 -0
- ibm_watsonx_gov/tools/schemas/vectordb_retrieval_schema.py +55 -0
- ibm_watsonx_gov/tools/utils/__init__.py +14 -0
- ibm_watsonx_gov/tools/utils/constants.py +69 -0
- ibm_watsonx_gov/tools/utils/display_utils.py +38 -0
- ibm_watsonx_gov/tools/utils/environment.py +108 -0
- ibm_watsonx_gov/tools/utils/package_utils.py +40 -0
- ibm_watsonx_gov/tools/utils/platform_url_mapping.cp313-win_amd64.pyd +0 -0
- ibm_watsonx_gov/tools/utils/python_utils.py +68 -0
- ibm_watsonx_gov/tools/utils/tool_utils.py +206 -0
- ibm_watsonx_gov/traces/__init__.py +8 -0
- ibm_watsonx_gov/traces/span_exporter.py +195 -0
- ibm_watsonx_gov/traces/span_node.py +251 -0
- ibm_watsonx_gov/traces/span_util.py +153 -0
- ibm_watsonx_gov/traces/trace_utils.py +1074 -0
- ibm_watsonx_gov/utils/__init__.py +8 -0
- ibm_watsonx_gov/utils/aggregation_util.py +346 -0
- ibm_watsonx_gov/utils/async_util.py +62 -0
- ibm_watsonx_gov/utils/authenticator.py +144 -0
- ibm_watsonx_gov/utils/constants.py +15 -0
- ibm_watsonx_gov/utils/errors.py +40 -0
- ibm_watsonx_gov/utils/gov_sdk_logger.py +39 -0
- ibm_watsonx_gov/utils/insights_generator.py +1285 -0
- ibm_watsonx_gov/utils/python_utils.py +425 -0
- ibm_watsonx_gov/utils/rest_util.py +73 -0
- ibm_watsonx_gov/utils/segment_batch_manager.py +162 -0
- ibm_watsonx_gov/utils/singleton_meta.py +25 -0
- ibm_watsonx_gov/utils/url_mapping.cp313-win_amd64.pyd +0 -0
- ibm_watsonx_gov/utils/validation_util.py +126 -0
- ibm_watsonx_gov/visualizations/__init__.py +13 -0
- ibm_watsonx_gov/visualizations/metric_descriptions.py +57 -0
- ibm_watsonx_gov/visualizations/model_insights.py +1304 -0
- ibm_watsonx_gov/visualizations/visualization_utils.py +75 -0
- ibm_watsonx_gov-1.3.3.dist-info/METADATA +93 -0
- ibm_watsonx_gov-1.3.3.dist-info/RECORD +353 -0
- ibm_watsonx_gov-1.3.3.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,266 @@
|
|
|
1
|
+
|
|
2
|
+
# ----------------------------------------------------------------------------------------------------
|
|
3
|
+
# IBM Confidential
|
|
4
|
+
# Licensed Materials - Property of IBM
|
|
5
|
+
# 5737-H76, 5900-A3Q
|
|
6
|
+
# © Copyright IBM Corp. 2025 All Rights Reserved.
|
|
7
|
+
# US Government Users Restricted Rights - Use, duplication or disclosure restricted by
|
|
8
|
+
# GSA ADPSchedule Contract with IBM Corp.
|
|
9
|
+
# ----------------------------------------------------------------------------------------------------
|
|
10
|
+
|
|
11
|
+
import asyncio
|
|
12
|
+
import functools
|
|
13
|
+
import re
|
|
14
|
+
|
|
15
|
+
import pandas as pd
|
|
16
|
+
from lazy_imports import LazyModule, load
|
|
17
|
+
|
|
18
|
+
from ibm_watsonx_gov.clients.usage_client import validate_usage_client
|
|
19
|
+
from ibm_watsonx_gov.entities.credentials import WxAICredentials
|
|
20
|
+
from ibm_watsonx_gov.entities.criteria import Option
|
|
21
|
+
from ibm_watsonx_gov.entities.enums import EvaluationProvider, MetricGroup
|
|
22
|
+
from ibm_watsonx_gov.entities.evaluation_result import (AggregateMetricResult,
|
|
23
|
+
RecordMetricResult)
|
|
24
|
+
from ibm_watsonx_gov.entities.foundation_model import (
|
|
25
|
+
AzureOpenAIFoundationModel, OpenAIFoundationModel, PortKeyGateway,
|
|
26
|
+
WxAIFoundationModel)
|
|
27
|
+
from ibm_watsonx_gov.entities.llm_judge import LLMJudge
|
|
28
|
+
from ibm_watsonx_gov.entities.metric_threshold import MetricThreshold
|
|
29
|
+
from ibm_watsonx_gov.providers.inference_engines.portkey_inference_engine import \
|
|
30
|
+
PortKeyInferenceEngine
|
|
31
|
+
from ibm_watsonx_gov.utils.async_util import start_event_loop_run_func
|
|
32
|
+
|
|
33
|
+
ea_imports = LazyModule(
|
|
34
|
+
"from evalassist.judges import Criteria as EACriteria",
|
|
35
|
+
"from evalassist.judges import CriteriaOption as EACriteriaOption",
|
|
36
|
+
"from evalassist.judges import Instance, DirectJudge",
|
|
37
|
+
"from unitxt.inference import CrossProviderInferenceEngine",
|
|
38
|
+
name="lazy_ea"
|
|
39
|
+
)
|
|
40
|
+
load(ea_imports)
|
|
41
|
+
|
|
42
|
+
EACriteria = ea_imports.EACriteria
|
|
43
|
+
EACriteriaOption = ea_imports.EACriteriaOption
|
|
44
|
+
Instance = ea_imports.Instance
|
|
45
|
+
DirectJudge = ea_imports.DirectJudge
|
|
46
|
+
CrossProviderInferenceEngine = ea_imports.CrossProviderInferenceEngine
|
|
47
|
+
|
|
48
|
+
# Regex for ``{name}`` placeholders whose name is an ASCII identifier (a letter or
# underscore followed by letters, digits or underscores); group 1 captures the name.
# NOTE(review): usage is not visible in this chunk — presumably applied to prompt
# templates to find the variables they reference; confirm against the rest of the file.
VARIABLES_PATTERN = r"\{([a-zA-Z_][a-zA-Z0-9_]*)\}"
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class EvalAssistProvider():
|
|
52
|
+
"""
|
|
53
|
+
The class to invoke eval assist library for computing the LLMAJ metrics.
|
|
54
|
+
"""
|
|
55
|
+
|
|
56
|
+
def __init__(self, metric_name: str,
             display_name: str,
             value_type: str,
             llm_judge: LLMJudge,
             options: list[Option],
             criteria_description: str | None = None,
             prompt_template: str | None = None,
             context_fields: list[str] | None = None,
             prediction_field: str | None = None,
             metric_group: MetricGroup | None = None,
             metric_method: str | None = None,
             thresholds: list[MetricThreshold] | None = None,
             **kwargs):
    """
    Store the configuration of an LLM-as-a-judge metric evaluated through EvalAssist.

    Args:
        metric_name: Metric name; also included in the error message raised by ``evaluate``.
        display_name: Human-readable metric name (stored as-is).
        value_type: Metric value type (stored as-is).
        llm_judge: Judge configuration; the type of its ``model`` selects the inference engine.
        options: Allowed judge answers. In prompt-template mode their ``name`` values are
            the valid judge outputs.
        criteria_description: Criteria text. When set, criteria-based evaluation is used.
        prompt_template: ``str.format`` template filled from each data row. Used when
            ``criteria_description`` is not set.
        context_fields: Data columns sent to the judge as context. Defaults to no columns.
        prediction_field: Data column holding the output to be judged.
        metric_group: Optional metric group (stored as-is).
        metric_method: Optional metric method name (stored as-is).
        thresholds: Metric thresholds. Defaults to an empty list.
        **kwargs: ``record_id_field`` (defaults to ``"record_id"``) and ``usage_client``,
            which is passed to ``validate_usage_client``.
    """
    self.metric_name = metric_name
    self.display_name = display_name
    self.value_type = value_type
    self.llm_judge = llm_judge
    self.criteria_description = criteria_description
    self.prompt_template = prompt_template
    self.options = options
    # ``None`` defaults give each instance its own list; a ``[]`` default would be a
    # single list object shared by every provider that relies on the default.
    self.context_fields = context_fields if context_fields is not None else []
    self.prediction_field = prediction_field
    self.metric_group = metric_group
    self.metric_method = metric_method
    self.thresholds = thresholds if thresholds is not None else []
    self.record_id_field = kwargs.get("record_id_field", "record_id")
    validate_usage_client(kwargs.get("usage_client"))
|
|
83
|
+
|
|
84
|
+
async def evaluate_async(self, data: pd.DataFrame) -> AggregateMetricResult:
    """
    Awaitable variant of ``evaluate``.

    The blocking ``evaluate`` call is wrapped with ``start_event_loop_run_func`` and run
    in the running loop's default executor, so the caller's event loop is not blocked
    while the judge works.

    Args:
        data: Records to evaluate, as accepted by ``evaluate``.

    Returns:
        The aggregated result returned by ``evaluate``.
    """
    # Inside a coroutine, get_running_loop() is the documented way to obtain the
    # current loop; get_event_loop() is discouraged in this context.
    loop = asyncio.get_running_loop()
    # NOTE(review): start_event_loop_run_func presumably sets up an event loop for the
    # worker thread before calling ``func`` — confirm in utils.async_util.
    job = functools.partial(
        start_event_loop_run_func,
        func=self.evaluate,
        data=data,
    )
    return await loop.run_in_executor(None, job)
|
|
95
|
+
|
|
96
|
+
def evaluate(self, data: pd.DataFrame) -> AggregateMetricResult:
    """
    Compute the metric for every row of ``data`` with an EvalAssist judge.

    Two evaluation modes are supported, checked in this order:

    * ``criteria_description`` is set: each row becomes an EvalAssist ``Instance``
      (prediction field plus context fields), and the judge scores it against the
      criteria built by ``__get_criteria``.
    * ``prompt_template`` is set: the template is ``str.format``-ed with each row's
      column values, and the judge must answer with one of the configured option names.

    Args:
        data: Records to evaluate. Must contain every column the selected mode
            references (``prediction_field``/``context_fields`` or the template's
            placeholders).

    Returns:
        The aggregated result produced by ``__post_process`` from the judge output.

    Raises:
        Exception: Any failure, including a missing evaluation mode. It is re-raised with
            the metric name in the message, and the original error is chained as the cause.
    """
    try:
        # Fail early with a clear message. Without this check ``results`` would be
        # unbound below, and the caller would only see an UnboundLocalError.
        if not self.criteria_description and not self.prompt_template:
            raise ValueError(
                "Either criteria_description or prompt_template must be provided.")

        judge = self.__get_judge()
        if self.criteria_description:
            criteria = self.__get_criteria(
                self.prediction_field, self.context_fields)

            instances = self.__get_instances(data=data,
                                             prediction_field=self.prediction_field,
                                             context_fields=self.context_fields)

            results = judge(instances=instances, criteria=criteria)
        else:
            # Fill the template from per-record dicts instead of row Series built by
            # ``data.apply(..., axis=1)``. Row Series coerce all values to a common
            # dtype (e.g. int columns become float next to float columns); records
            # keep each column's own type.
            judge_prompts = [self.prompt_template.format(**record)
                             for record in data.to_dict(orient="records")]

            # The judge must answer with one of the configured option names.
            valid_outputs = [o.name for o in self.options]

            results = judge.evaluate_with_custom_prompt(
                judge_prompts=judge_prompts,
                valid_outputs=valid_outputs)

        return self.__post_process(results=results, data=data)
    except Exception as e:
        raise Exception(
            f"Error while computing metrics: {self.metric_name}. Reason: {str(e)}") from e
|
|
126
|
+
|
|
127
|
+
def __get_judge(self):
    """
    Create the EvalAssist ``DirectJudge`` that scores the records.

    A ``PortKeyGateway`` judge model is served through ``PortKeyInferenceEngine``;
    any other model goes through ``CrossProviderInferenceEngine``. Either engine is
    built from the keyword arguments returned by ``__get_inference_engine_params``,
    and the judge is configured to generate feedback.

    Returns:
        The configured ``DirectJudge``.
    """
    uses_portkey = self.llm_judge and isinstance(self.llm_judge.model, PortKeyGateway)
    engine_cls = PortKeyInferenceEngine if uses_portkey else CrossProviderInferenceEngine
    inference_engine = engine_cls(**self.__get_inference_engine_params())
    return DirectJudge(inference_engine=inference_engine, generate_feedback=True)
|
|
142
|
+
|
|
143
|
+
def __get_instances(self, data, prediction_field, context_fields):
    """
    Turn the rows of ``data`` into EvalAssist ``Instance`` objects.

    Each instance's ``fields`` dict maps ``prediction_field`` to the row's prediction,
    followed by the row's ``context_fields`` values. A context column with the same
    name as ``prediction_field`` overwrites the prediction value.

    Args:
        data: DataFrame containing the prediction and context columns.
        prediction_field: Column holding the output to be judged.
        context_fields: Columns passed along as context.

    Returns:
        A list of ``Instance`` objects in row order.
    """
    context_rows = data[context_fields].to_dict(orient="records")
    prediction_values = data[prediction_field].tolist()

    # NOTE(review): context_rows can come back empty (presumably when no context
    # columns are selected). Fall back to prediction-only instances so no row is dropped.
    if not context_rows:
        return [Instance(fields={prediction_field: value})
                for value in prediction_values]

    return [Instance(fields={prediction_field: value, **context})
            for context, value in zip(context_rows, prediction_values)]
|
|
159
|
+
|
|
160
|
+
def __get_inference_engine_params(self):
|
|
161
|
+
params = {"seed": 36,
|
|
162
|
+
"data_classification_policy": ["public"]}
|
|
163
|
+
|
|
164
|
+
if isinstance(self.llm_judge.model, WxAIFoundationModel):
|
|
165
|
+
wxai_credentials: WxAICredentials = self.llm_judge.model.provider.credentials
|
|
166
|
+
wml_credentials = {}
|
|
167
|
+
wml_credentials["api_base"] = wxai_credentials.url
|
|
168
|
+
if wxai_credentials.api_key:
|
|
169
|
+
wml_credentials["api_key"] = wxai_credentials.api_key
|
|
170
|
+
if wxai_credentials.version: # using cpd
|
|
171
|
+
wml_credentials["username"] = wxai_credentials.username
|
|
172
|
+
wml_credentials["instance_id"] = wxai_credentials.instance_id
|
|
173
|
+
if wxai_credentials.password:
|
|
174
|
+
wml_credentials["password"] = wxai_credentials.password
|
|
175
|
+
|
|
176
|
+
if self.llm_judge.model.project_id:
|
|
177
|
+
wml_credentials["project_id"] = self.llm_judge.model.project_id
|
|
178
|
+
elif self.llm_judge.model.space_id:
|
|
179
|
+
wml_credentials["space_id"] = self.llm_judge.model.space_id
|
|
180
|
+
else:
|
|
181
|
+
raise Exception("Either project or space id is required")
|
|
182
|
+
|
|
183
|
+
params.update({
|
|
184
|
+
"credentials": wml_credentials,
|
|
185
|
+
"provider": "watsonx",
|
|
186
|
+
"model": self.llm_judge.model.model_id,
|
|
187
|
+
"provider_specific_args": {
|
|
188
|
+
"watsonx": {
|
|
189
|
+
"max_requests_per_second": 1
|
|
190
|
+
}
|
|
191
|
+
}
|
|
192
|
+
})
|
|
193
|
+
|
|
194
|
+
elif isinstance(self.llm_judge.model, OpenAIFoundationModel):
|
|
195
|
+
params.update({
|
|
196
|
+
"credentials": {
|
|
197
|
+
"api_key": self.llm_judge.model.provider.credentials.api_key
|
|
198
|
+
},
|
|
199
|
+
"provider": "open-ai",
|
|
200
|
+
"model": self.llm_judge.model.model_id,
|
|
201
|
+
"provider_specific_args": {"temperature": 0}
|
|
202
|
+
})
|
|
203
|
+
elif isinstance(self.llm_judge.model, PortKeyGateway):
|
|
204
|
+
params.update({
|
|
205
|
+
"credentials": self.llm_judge.model.provider.credentials.model_dump(),
|
|
206
|
+
"model": self.llm_judge.model.model_id
|
|
207
|
+
})
|
|
208
|
+
elif isinstance(self.llm_judge.model, AzureOpenAIFoundationModel):
|
|
209
|
+
raise Exception("Azure OpenAI Model provider is not supported.")
|
|
210
|
+
else:
|
|
211
|
+
raise Exception("LLM Model provider is not supported.")
|
|
212
|
+
|
|
213
|
+
return params
|
|
214
|
+
|
|
215
|
+
def __get_criteria(self, prediction_field, context_fields):
|
|
216
|
+
options = []
|
|
217
|
+
|
|
218
|
+
for op in self.options:
|
|
219
|
+
op_desc = op.description.replace(
|
|
220
|
+
"{"+prediction_field+"}", prediction_field)
|
|
221
|
+
op_desc = re.sub(VARIABLES_PATTERN, r"\1", op_desc)
|
|
222
|
+
options.append(EACriteriaOption(
|
|
223
|
+
name=op.name,
|
|
224
|
+
description=op_desc,
|
|
225
|
+
score=op.value
|
|
226
|
+
))
|
|
227
|
+
|
|
228
|
+
desc = self.criteria_description.replace(
|
|
229
|
+
"{"+prediction_field+"}", prediction_field)
|
|
230
|
+
desc = re.sub(VARIABLES_PATTERN, r"\1", desc)
|
|
231
|
+
|
|
232
|
+
criteria_with_options = EACriteria(name=self.metric_name,
|
|
233
|
+
description=desc,
|
|
234
|
+
to_evaluate_field=prediction_field,
|
|
235
|
+
context_fields=context_fields,
|
|
236
|
+
options=options)
|
|
237
|
+
|
|
238
|
+
return criteria_with_options
|
|
239
|
+
|
|
240
|
+
def __post_process(self, results, data: pd.DataFrame) -> AggregateMetricResult:
|
|
241
|
+
record_level_metrics: list[RecordMetricResult] = []
|
|
242
|
+
|
|
243
|
+
score_map = {o.name: o.value for o in self.options}
|
|
244
|
+
|
|
245
|
+
for record_id, result in zip(data[self.record_id_field].tolist(), results):
|
|
246
|
+
record_level_metrics.append(
|
|
247
|
+
RecordMetricResult(
|
|
248
|
+
name=self.metric_name,
|
|
249
|
+
display_name=self.display_name,
|
|
250
|
+
method=self.metric_method,
|
|
251
|
+
group=self.metric_group,
|
|
252
|
+
provider=EvaluationProvider.UNITXT.value,
|
|
253
|
+
value=score_map.get(result.selected_option),
|
|
254
|
+
label=result.selected_option,
|
|
255
|
+
record_id=record_id,
|
|
256
|
+
thresholds=self.thresholds,
|
|
257
|
+
explanation=result.explanation,
|
|
258
|
+
additional_info={
|
|
259
|
+
"feedback": result.feedback} if result.feedback else {}
|
|
260
|
+
)
|
|
261
|
+
)
|
|
262
|
+
|
|
263
|
+
aggregated_result = AggregateMetricResult.create(
|
|
264
|
+
record_level_metrics)
|
|
265
|
+
# return the aggregated result
|
|
266
|
+
return aggregated_result
|
|
File without changes
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
# ----------------------------------------------------------------------------------------------------
|
|
2
|
+
# IBM Confidential
|
|
3
|
+
# Licensed Materials - Property of IBM
|
|
4
|
+
# 5737-H76, 5900-A3Q
|
|
5
|
+
# © Copyright IBM Corp. 2025 All Rights Reserved.
|
|
6
|
+
# US Government Users Restricted Rights - Use, duplication or disclosure restricted by
|
|
7
|
+
# GSA ADPSchedule Contract with IBM Corp.
|
|
8
|
+
# ----------------------------------------------------------------------------------------------------
|
|
9
|
+
|
|
10
|
+
from multiprocessing.pool import ThreadPool
|
|
11
|
+
from typing import Annotated, Any, Callable, Dict, List, Optional, Union
|
|
12
|
+
|
|
13
|
+
from datasets import Dataset
|
|
14
|
+
from lazy_imports import LazyModule, load
|
|
15
|
+
from pydantic import Field
|
|
16
|
+
from tqdm import tqdm
|
|
17
|
+
|
|
18
|
+
unitxt_imports = LazyModule(
|
|
19
|
+
"from unitxt.artifact import Artifact",
|
|
20
|
+
"from unitxt.inference import InferenceEngine, TextGenerationInferenceOutput, get_model_and_label_id",
|
|
21
|
+
name="lazy_unitxt",
|
|
22
|
+
)
|
|
23
|
+
load(unitxt_imports)
|
|
24
|
+
|
|
25
|
+
Artifact = unitxt_imports.Artifact
|
|
26
|
+
InferenceEngine = unitxt_imports.InferenceEngine
|
|
27
|
+
TextGenerationInferenceOutput = unitxt_imports.TextGenerationInferenceOutput
|
|
28
|
+
get_model_and_label_id = unitxt_imports.get_model_and_label_id
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def run_with_imap(func):
|
|
32
|
+
"""
|
|
33
|
+
Decorator to adapt a function for use with multiprocessing's imap.
|
|
34
|
+
Ensures arguments are unpacked properly when parallelizing inference.
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
def inner(self, args):
|
|
38
|
+
return func(self, *args)
|
|
39
|
+
return inner
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class CustomFnEngineParamsMixin(Artifact):
|
|
43
|
+
"""
|
|
44
|
+
Mixin class that provides configurable parameters for the custom engine.
|
|
45
|
+
- batch_size: number of instances per batch (unused, but reserved for extension).
|
|
46
|
+
- timeout: optional timeout in seconds for inference requests.
|
|
47
|
+
- num_parallel_requests: max number of threads used for parallel inference.
|
|
48
|
+
"""
|
|
49
|
+
batch_size: Optional[int] = None
|
|
50
|
+
timeout: Optional[float] = None
|
|
51
|
+
num_parallel_requests: Optional[int] = 20
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class CustomFunctionInferenceEngine(InferenceEngine, CustomFnEngineParamsMixin):
|
|
55
|
+
"""
|
|
56
|
+
A custom inference engine that delegates prediction to a user-provided function (`scoring_fn`).
|
|
57
|
+
Supports parallel execution across multiple threads and integrates seamlessly with Unitxt.
|
|
58
|
+
"""
|
|
59
|
+
|
|
60
|
+
label: str = "custom_fn"
|
|
61
|
+
model_name: str = "custom_fn"
|
|
62
|
+
num_parallel_requests: int = 20
|
|
63
|
+
|
|
64
|
+
scoring_fn: Callable
|
|
65
|
+
context: Optional[Dict[str, Any]] = None
|
|
66
|
+
|
|
67
|
+
def get_engine_id(self) -> str:
|
|
68
|
+
"""
|
|
69
|
+
Return a unique engine identifier based on model_name and label.
|
|
70
|
+
Used internally by Unitxt to differentiate inference engines.
|
|
71
|
+
"""
|
|
72
|
+
return get_model_and_label_id(self.model_name, self.label)
|
|
73
|
+
|
|
74
|
+
def prepare_engine(self):
|
|
75
|
+
"""
|
|
76
|
+
Hook for initializing resources before inference.
|
|
77
|
+
No-op here since the custom engine delegates everything to scoring_fn.
|
|
78
|
+
"""
|
|
79
|
+
pass
|
|
80
|
+
|
|
81
|
+
def get_return_object(self, predict_result, response, return_meta_data):
|
|
82
|
+
"""
|
|
83
|
+
Return the prediction object in the format expected by Unitxt.
|
|
84
|
+
In this implementation, the prediction is returned as-is.
|
|
85
|
+
"""
|
|
86
|
+
return predict_result
|
|
87
|
+
|
|
88
|
+
def _parallel_infer(
|
|
89
|
+
self,
|
|
90
|
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
|
91
|
+
infer_func,
|
|
92
|
+
return_meta_data: bool = False,
|
|
93
|
+
) -> Union[List[str], List["TextGenerationInferenceOutput"]]:
|
|
94
|
+
"""
|
|
95
|
+
Run inference on a dataset in parallel using a thread pool.
|
|
96
|
+
Args:
|
|
97
|
+
dataset: list of instances or HuggingFace Dataset.
|
|
98
|
+
infer_func: function applied to each instance.
|
|
99
|
+
return_meta_data: if True, expects TextGenerationInferenceOutput.
|
|
100
|
+
Returns:
|
|
101
|
+
A list of predictions or metadata objects.
|
|
102
|
+
"""
|
|
103
|
+
inputs = [(instance, return_meta_data) for instance in dataset]
|
|
104
|
+
outputs: List[Union[str, "TextGenerationInferenceOutput"]] = []
|
|
105
|
+
with ThreadPool(processes=self.num_parallel_requests) as pool:
|
|
106
|
+
for output in tqdm(
|
|
107
|
+
pool.imap(infer_func, inputs),
|
|
108
|
+
total=len(inputs),
|
|
109
|
+
desc=f"Inferring with {self.__class__.__name__}",
|
|
110
|
+
):
|
|
111
|
+
outputs.append(output)
|
|
112
|
+
return outputs
|
|
113
|
+
|
|
114
|
+
def _infer(
|
|
115
|
+
self,
|
|
116
|
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
|
117
|
+
return_meta_data: bool = False,
|
|
118
|
+
) -> Union[List[str], List["TextGenerationInferenceOutput"]]:
|
|
119
|
+
"""
|
|
120
|
+
Core inference method called by Unitxt.
|
|
121
|
+
Delegates to `_parallel_infer` for concurrent execution.
|
|
122
|
+
"""
|
|
123
|
+
return self._parallel_infer(
|
|
124
|
+
dataset=dataset,
|
|
125
|
+
return_meta_data=return_meta_data,
|
|
126
|
+
infer_func=self._score_instance,
|
|
127
|
+
)
|
|
128
|
+
|
|
129
|
+
@run_with_imap
|
|
130
|
+
def _score_instance(self, instance, return_meta_data):
|
|
131
|
+
"""
|
|
132
|
+
Run inference on a single instance using the user-provided scoring_fn.
|
|
133
|
+
Handles type validation and returns a fallback object if scoring fails.
|
|
134
|
+
"""
|
|
135
|
+
try:
|
|
136
|
+
pred = self.scoring_fn(
|
|
137
|
+
instance, return_meta_data, context=self.context)
|
|
138
|
+
|
|
139
|
+
if return_meta_data and not isinstance(pred, TextGenerationInferenceOutput):
|
|
140
|
+
raise TypeError(
|
|
141
|
+
"With return_meta_data=True, scoring_fn must return TextGenerationInferenceOutput."
|
|
142
|
+
)
|
|
143
|
+
if not return_meta_data and not isinstance(pred, str):
|
|
144
|
+
raise TypeError(
|
|
145
|
+
"With return_meta_data=False, scoring_fn must return str."
|
|
146
|
+
)
|
|
147
|
+
|
|
148
|
+
return self.get_return_object(pred, response=None, return_meta_data=return_meta_data)
|
|
149
|
+
except Exception:
|
|
150
|
+
if return_meta_data:
|
|
151
|
+
return TextGenerationInferenceOutput(
|
|
152
|
+
prediction="-", generated_text="-", input_tokens=0, output_tokens=0,
|
|
153
|
+
model_name=self.model_name, inference_type=self.label,
|
|
154
|
+
)
|
|
155
|
+
return "-"
|
|
156
|
+
|
|
157
|
+
def to_dict(self, *args, **kwargs) -> Dict[str, Any]:
|
|
158
|
+
"""
|
|
159
|
+
Convert the engine configuration to a dictionary.
|
|
160
|
+
Excludes unserializable fields like `scoring_fn` and `context` to ensure cache safety.
|
|
161
|
+
"""
|
|
162
|
+
d = super().to_dict(*args, **kwargs)
|
|
163
|
+
d.pop("scoring_fn", None)
|
|
164
|
+
d.pop("context", None)
|
|
165
|
+
return d
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
# ----------------------------------------------------------------------------------------------------
|
|
2
|
+
# IBM Confidential
|
|
3
|
+
# Licensed Materials - Property of IBM
|
|
4
|
+
# 5737-H76, 5900-A3Q
|
|
5
|
+
# © Copyright IBM Corp. 2025 All Rights Reserved.
|
|
6
|
+
# US Government Users Restricted Rights - Use, duplication or disclosure restricted by
|
|
7
|
+
# GSA ADPSchedule Contract with IBM Corp.
|
|
8
|
+
# ----------------------------------------------------------------------------------------------------
|
|
9
|
+
from typing import Any, Dict, List, Optional, Union
|
|
10
|
+
|
|
11
|
+
from datasets import Dataset
|
|
12
|
+
from ibm_watsonx_gov.entities.llm_judge import LLMJudge
|
|
13
|
+
from unitxt.inference import (InferenceEngine, PackageRequirementsMixin,
|
|
14
|
+
StandardAPIParamsMixin,
|
|
15
|
+
TextGenerationInferenceOutput,
|
|
16
|
+
get_model_and_label_id)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class PortKeyInferenceEngine(
|
|
20
|
+
InferenceEngine, StandardAPIParamsMixin, PackageRequirementsMixin
|
|
21
|
+
):
|
|
22
|
+
label: str = "portkey"
|
|
23
|
+
_requirements_list = {
|
|
24
|
+
"portkey-ai": "Install portkey-ai package using 'pip install --upgrade portkey-ai"
|
|
25
|
+
}
|
|
26
|
+
model: str = None
|
|
27
|
+
credentials: Dict[str, str] = {}
|
|
28
|
+
|
|
29
|
+
def get_engine_id(self):
|
|
30
|
+
return get_model_and_label_id(self.model, self.label)
|
|
31
|
+
|
|
32
|
+
def prepare_engine(self):
|
|
33
|
+
from portkey_ai import Portkey
|
|
34
|
+
|
|
35
|
+
self.client = Portkey(
|
|
36
|
+
api_key=self.credentials["api_key"],
|
|
37
|
+
base_url=self.credentials.get("base_url", None),
|
|
38
|
+
provider=self.credentials.get("provider"),
|
|
39
|
+
Authorization="Bearer " + self.credentials["provider_api_key"],
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
def _infer(
|
|
43
|
+
self,
|
|
44
|
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
|
45
|
+
return_meta_data: bool = False,
|
|
46
|
+
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
|
47
|
+
args = self.to_dict([StandardAPIParamsMixin])
|
|
48
|
+
results = []
|
|
49
|
+
for instance in dataset:
|
|
50
|
+
messages = self.to_messages(instance)
|
|
51
|
+
response = self.client.chat.completions.create(
|
|
52
|
+
messages=messages,
|
|
53
|
+
model=self.model
|
|
54
|
+
)
|
|
55
|
+
results.append(response.choices[0].message.content)
|
|
56
|
+
|
|
57
|
+
return results
|
|
File without changes
|