azure-ai-evaluation 1.9.0__py3-none-any.whl → 1.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (85)
  1. azure/ai/evaluation/__init__.py +46 -12
  2. azure/ai/evaluation/_aoai/python_grader.py +84 -0
  3. azure/ai/evaluation/_aoai/score_model_grader.py +1 -0
  4. azure/ai/evaluation/_common/onedp/models/_models.py +5 -0
  5. azure/ai/evaluation/_common/rai_service.py +3 -3
  6. azure/ai/evaluation/_common/utils.py +74 -17
  7. azure/ai/evaluation/_converters/_ai_services.py +60 -10
  8. azure/ai/evaluation/_converters/_models.py +75 -26
  9. azure/ai/evaluation/_evaluate/_batch_run/_run_submitter_client.py +70 -22
  10. azure/ai/evaluation/_evaluate/_eval_run.py +14 -1
  11. azure/ai/evaluation/_evaluate/_evaluate.py +163 -44
  12. azure/ai/evaluation/_evaluate/_evaluate_aoai.py +79 -33
  13. azure/ai/evaluation/_evaluate/_utils.py +5 -2
  14. azure/ai/evaluation/_evaluators/_bleu/_bleu.py +1 -1
  15. azure/ai/evaluation/_evaluators/_code_vulnerability/_code_vulnerability.py +8 -1
  16. azure/ai/evaluation/_evaluators/_coherence/_coherence.py +3 -2
  17. azure/ai/evaluation/_evaluators/_common/_base_eval.py +143 -25
  18. azure/ai/evaluation/_evaluators/_common/_base_prompty_eval.py +7 -2
  19. azure/ai/evaluation/_evaluators/_common/_base_rai_svc_eval.py +19 -9
  20. azure/ai/evaluation/_evaluators/_content_safety/_content_safety.py +15 -5
  21. azure/ai/evaluation/_evaluators/_content_safety/_hate_unfairness.py +4 -1
  22. azure/ai/evaluation/_evaluators/_content_safety/_self_harm.py +4 -1
  23. azure/ai/evaluation/_evaluators/_content_safety/_sexual.py +5 -2
  24. azure/ai/evaluation/_evaluators/_content_safety/_violence.py +4 -1
  25. azure/ai/evaluation/_evaluators/_document_retrieval/_document_retrieval.py +3 -0
  26. azure/ai/evaluation/_evaluators/_eci/_eci.py +3 -0
  27. azure/ai/evaluation/_evaluators/_f1_score/_f1_score.py +1 -1
  28. azure/ai/evaluation/_evaluators/_fluency/_fluency.py +3 -2
  29. azure/ai/evaluation/_evaluators/_gleu/_gleu.py +1 -1
  30. azure/ai/evaluation/_evaluators/_groundedness/_groundedness.py +114 -4
  31. azure/ai/evaluation/_evaluators/_intent_resolution/_intent_resolution.py +9 -3
  32. azure/ai/evaluation/_evaluators/_meteor/_meteor.py +1 -1
  33. azure/ai/evaluation/_evaluators/_protected_material/_protected_material.py +8 -1
  34. azure/ai/evaluation/_evaluators/_qa/_qa.py +1 -1
  35. azure/ai/evaluation/_evaluators/_relevance/_relevance.py +56 -3
  36. azure/ai/evaluation/_evaluators/_relevance/relevance.prompty +140 -59
  37. azure/ai/evaluation/_evaluators/_response_completeness/_response_completeness.py +11 -3
  38. azure/ai/evaluation/_evaluators/_retrieval/_retrieval.py +3 -2
  39. azure/ai/evaluation/_evaluators/_rouge/_rouge.py +1 -1
  40. azure/ai/evaluation/_evaluators/_service_groundedness/_service_groundedness.py +2 -1
  41. azure/ai/evaluation/_evaluators/_similarity/_similarity.py +3 -2
  42. azure/ai/evaluation/_evaluators/_task_adherence/_task_adherence.py +24 -12
  43. azure/ai/evaluation/_evaluators/_task_adherence/task_adherence.prompty +354 -66
  44. azure/ai/evaluation/_evaluators/_tool_call_accuracy/_tool_call_accuracy.py +214 -187
  45. azure/ai/evaluation/_evaluators/_tool_call_accuracy/tool_call_accuracy.prompty +126 -31
  46. azure/ai/evaluation/_evaluators/_ungrounded_attributes/_ungrounded_attributes.py +8 -1
  47. azure/ai/evaluation/_evaluators/_xpia/xpia.py +4 -1
  48. azure/ai/evaluation/_exceptions.py +1 -0
  49. azure/ai/evaluation/_legacy/_batch_engine/_config.py +6 -3
  50. azure/ai/evaluation/_legacy/_batch_engine/_engine.py +115 -30
  51. azure/ai/evaluation/_legacy/_batch_engine/_result.py +2 -0
  52. azure/ai/evaluation/_legacy/_batch_engine/_run.py +2 -2
  53. azure/ai/evaluation/_legacy/_batch_engine/_run_submitter.py +28 -31
  54. azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py +2 -0
  55. azure/ai/evaluation/_version.py +1 -1
  56. azure/ai/evaluation/red_team/__init__.py +4 -3
  57. azure/ai/evaluation/red_team/_attack_objective_generator.py +17 -0
  58. azure/ai/evaluation/red_team/_callback_chat_target.py +14 -1
  59. azure/ai/evaluation/red_team/_evaluation_processor.py +376 -0
  60. azure/ai/evaluation/red_team/_mlflow_integration.py +322 -0
  61. azure/ai/evaluation/red_team/_orchestrator_manager.py +661 -0
  62. azure/ai/evaluation/red_team/_red_team.py +655 -2665
  63. azure/ai/evaluation/red_team/_red_team_result.py +6 -0
  64. azure/ai/evaluation/red_team/_result_processor.py +610 -0
  65. azure/ai/evaluation/red_team/_utils/__init__.py +34 -0
  66. azure/ai/evaluation/red_team/_utils/_rai_service_eval_chat_target.py +11 -4
  67. azure/ai/evaluation/red_team/_utils/_rai_service_true_false_scorer.py +6 -0
  68. azure/ai/evaluation/red_team/_utils/constants.py +0 -2
  69. azure/ai/evaluation/red_team/_utils/exception_utils.py +345 -0
  70. azure/ai/evaluation/red_team/_utils/file_utils.py +266 -0
  71. azure/ai/evaluation/red_team/_utils/formatting_utils.py +115 -13
  72. azure/ai/evaluation/red_team/_utils/metric_mapping.py +24 -4
  73. azure/ai/evaluation/red_team/_utils/progress_utils.py +252 -0
  74. azure/ai/evaluation/red_team/_utils/retry_utils.py +218 -0
  75. azure/ai/evaluation/red_team/_utils/strategy_utils.py +17 -4
  76. azure/ai/evaluation/simulator/_adversarial_simulator.py +14 -2
  77. azure/ai/evaluation/simulator/_indirect_attack_simulator.py +13 -1
  78. azure/ai/evaluation/simulator/_model_tools/_generated_rai_client.py +21 -7
  79. azure/ai/evaluation/simulator/_model_tools/_proxy_completion_model.py +24 -5
  80. azure/ai/evaluation/simulator/_simulator.py +12 -0
  81. {azure_ai_evaluation-1.9.0.dist-info → azure_ai_evaluation-1.11.0.dist-info}/METADATA +63 -4
  82. {azure_ai_evaluation-1.9.0.dist-info → azure_ai_evaluation-1.11.0.dist-info}/RECORD +85 -76
  83. {azure_ai_evaluation-1.9.0.dist-info → azure_ai_evaluation-1.11.0.dist-info}/WHEEL +1 -1
  84. {azure_ai_evaluation-1.9.0.dist-info → azure_ai_evaluation-1.11.0.dist-info/licenses}/NOTICE.txt +0 -0
  85. {azure_ai_evaluation-1.9.0.dist-info → azure_ai_evaluation-1.11.0.dist-info}/top_level.txt +0 -0
azure/ai/evaluation/_evaluators/_tool_call_accuracy/tool_call_accuracy.prompty
@@ -5,20 +5,20 @@ model:
   api: chat
   parameters:
     temperature: 0.0
-    max_tokens: 800
+    max_tokens: 3000
     top_p: 1.0
     presence_penalty: 0
     frequency_penalty: 0
     response_format:
-      type: text
+      type: json_object
 
 inputs:
   query:
-    type: array
-  tool_call:
-    type: object
-  tool_definition:
-    type: object
+    type: List
+  tool_calls:
+    type: List
+  tool_definitions:
+    type: Dict
 
 ---
 system:
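
The reshaped inputs above (a single tool_call/tool_definition object becoming tool_calls/tool_definitions collections) surface in the evaluator's public call signature. Below is a minimal sketch, not part of the diff; the endpoint values are placeholders, and the exact message and tool shapes accepted can vary between versions:

    from azure.ai.evaluation import ToolCallAccuracyEvaluator

    # Placeholder model configuration; fill in your own endpoint, deployment,
    # and credentials as appropriate.
    model_config = {
        "azure_endpoint": "https://<your-endpoint>.openai.azure.com",
        "azure_deployment": "<your-deployment>",
    }

    evaluator = ToolCallAccuracyEvaluator(model_config=model_config)
    result = evaluator(
        query="How is the weather in Seattle?",
        tool_calls=[  # now a list of calls, not a single tool_call object
            {
                "type": "tool_call",
                "tool_call_id": "call_1",
                "name": "fetch_weather",
                "arguments": {"location": "Seattle"},
            }
        ],
        tool_definitions=[  # now the full set of definitions available to the agent
            {
                "name": "fetch_weather",
                "description": "Fetches the weather information for the specified location.",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                },
            }
        ],
    )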
azure/ai/evaluation/_evaluators/_tool_call_accuracy/tool_call_accuracy.prompty (continued)
@@ -27,45 +27,140 @@ system:
 ### Your are an expert in evaluating the accuracy of a tool call considering relevance and potential usefulness including syntactic and semantic correctness of a proposed tool call from an intelligent system based on provided definition and data. Your goal will involve answering the questions below using the information provided.
 - **Definition**: You are given a definition of the communication trait that is being evaluated to help guide your Score.
 - **Data**: Your input data include CONVERSATION , TOOL CALL and TOOL DEFINITION.
-- **Tasks**: To complete your evaluation you will be asked to evaluate the Data in different ways.
+- **Tasks**: To complete your evaluation you will be asked to evaluate the Data in different ways, and you need to be very precise in your evaluation.
 
 user:
 # Definition
-**Tool Call Accuracy** refers to the relevance and potential usefulness of a TOOL CALL in the context of an ongoing CONVERSATION and EXTRACTION of RIGHT PARAMETER VALUES from the CONVERSATION.It assesses how likely the TOOL CALL is to contribute meaningfully to the CONVERSATION and help address the user's needs. Focus on evaluating the potential value of the TOOL CALL within the specific context of the given CONVERSATION, without making assumptions beyond the provided information.
-Consider the following factors in your evaluation:
+# Definition
+**Tool Call Accuracy** refers to the overall effectiveness of ALL TOOL CALLS made by an agent in response to a user's query within an ongoing CONVERSATION.
+
+# EVALUATION CRITERIA
+Evaluate based on these factors:
+
+1. **Collective Relevance**: Do the tool calls, taken together, appropriately address the user's query?
+2. **Parameter Correctness**: Are all parameter values extracted from or reasonably inferred from the CONVERSATION?
+   - *Fabricated parameters automatically result in Level 2*
+3. **Completeness**: Did the agent make all necessary tool calls available in the tool definitions?
+   - *Failed calls don't count as missing*
+4. **Efficiency**: Did the agent avoid unnecessary duplicate tool calls with identical parameters?
+   - *Don't penalize single tools returning multiple results (like file_search)*
+5. **Execution Success**: Were tool calls executed successfully or recovered from errors appropriately?
+6. **Scope Limitation**: ONLY evaluate tool calls in the "TOOL CALLS TO BE EVALUATED" section.
+   - Tool calls in the CONVERSATION section are for context only
+   - Focus exclusively on the agent's response to the user's LAST query
+   - Use conversation history only to verify parameter correctness and context
+
+**Success Criteria**: Tools should retrieve relevant data to help answer the query. Complete final answers are not required from individual tools.
 
-1. Relevance: How well does the proposed tool call align with the current topic and flow of the conversation?
-2. Parameter Appropriateness: Do the parameters used in the TOOL CALL match the TOOL DEFINITION and are the parameters relevant to the latest user's query?
-3. Parameter Value Correctness: Are the parameters values used in the TOOL CALL present or inferred by CONVERSATION and relevant to the latest user's query?
-4. Potential Value: Is the information this tool call might provide likely to be useful in advancing the conversation or addressing the user expressed or implied needs?
-5. Context Appropriateness: Does the tool call make sense at this point in the conversation, given what has been discussed so far?
+**Tool Assessment**: Focus solely on appropriate use of available tools, not on capabilities beyond what tools can provide.
 
 
 # Ratings
-## [Tool Call Accuracy: 0] (Irrelevant)
+## [Tool Call Accuracy: 1] (Irrelevant)
 **Definition:**
-1. The TOOL CALL is not relevant and will not help resolve the user's need.
-2. TOOL CALL include parameters values that are not present or inferred from CONVERSATION.
-3. TOOL CALL has parameters that is not present in TOOL DEFINITION.
+Tool calls were not relevant to the user's query, resulting in an irrelevant or unhelpful final output.
 
-## [Tool Call Accuracy: 1] (Relevant)
+**Example:**
+User asks for distance between two cities -> Agent calls a weather function to get the weather in the two cities.
+
+
+## [Tool Call Accuracy: 2] (Partially Relevant - Wrong Execution)
 **Definition:**
-1. The TOOL CALL is directly relevant and very likely to help resolve the user's need.
-2. TOOL CALL include parameters values that are present or inferred from CONVERSATION.
-3. TOOL CALL has parameters that is present in TOOL DEFINITION.
+Tool calls were somewhat related to the user's query, but the agent was not able to reach information that helps address the user query due to one or more of the following:
+• Parameters passed to the tool were incorrect.
+• Not enough tools (available in the tool definitions) were called to fully help address the query (missing tool calls).
+• Tools returned errors, and no retrials for the tool call were successful.
+
+
+**Example:**
+The user asks for the coordinates of Chicago. The agent calls the tool that gets the coordinates but passes 'New York' instead of Chicago as parameter.
+
+**Example:**
+The user asks for the coordinates of Chicago. The agent calls the tool that gets the coordinates and passes 'Chicago' as the tool parameter, but the tool returns an error.
+
+**Example:**
+The user asks a question that needs 3 tool calls for it to be answered. The agent calls only one of the three required tool calls. So this case is a Level 2.
+
+
+## [Tool Call Accuracy: 3] (Relevant but Inefficient)
+**Definition:**
+Tool calls were relevant, correct and grounded parameters were passed so that led to a correct output. However, multiple excessive, unnecessary tool calls were made.
+
+**Important**: Do NOT penalize built-in tools like file_search that naturally return multiple results in a single call. Only penalize when there are actually multiple separate tool call objects.
+
+**Example:**
+The user asked to do a modification in the database. The agent called the tool multiple times, resulting in multiple modifications in the database instead of one.
+
+**Example:**
+The user asked for popular hotels in a certain place. The agent calls the same tool with the same parameters multiple times, even though a single tool call that returns an output is sufficient. So there were unnecessary tool calls.
+
+
+## [Tool Call Accuracy: 4] (Correct with Retrials)
+**Definition:**
+Tool calls were fully relevant and efficient:
+• Correct tools were called with the correct and grounded parameters, whether they are extracted from the conversation history or the current user query.
+• A tool returned an error, but the agent retried calling the tool and successfully got an output.
+
+**Example:**
+The user asks for the weather forecast in a certain place. The agent calls the correct tool that retrieves the weather forecast, but the tool returns an error. The agent re-calls the tool once again and it returns the correct output. This is a Level 4.
+
+
+## [Tool Call Accuracy: 5] (Optimal Solution)
+**Definition:**
+Tool calls were fully relevant and efficient:
+• Correct tools were called with the correct and grounded parameters, whether they are extracted from the conversation history or the current user query.
+• No unnecessary or excessive tool calls were made.
+• No errors occurred in any of the tools.
+• The tool calls made helped the agent address the user's query without facing any issues.
+
+**Example:**
+The user asks for the distance between two places. The agent correctly calls the tools that retrieve the coordinates for the two places respectively, then calls the tool that calculates the distance between the two sets of coordinates, passing the correct arguments to all the tools, without calling other tools excessively or unnecessarily. This is the optimal solution for the user's query.
+
+**Example:**
+The user asks for the distance between two places. The agent retrieves the needed coordinates from the outputs of the tool calls in the conversation history, and then correctly passes these coordinates to the tool that calculates the distance to output it to the user. This is also an optimal solution for the user's query.
+
+**Example:**
+The user asked to summarize a file on their SharePoint. The agent calls the sharepoint_grounding tool to retrieve the file. This retrieved file will help the agent fulfill the task of summarization. This is a Level 5.
+
+
+## Chain of Thought Structure
+Structure your reasoning as follows:
+1. **Start with the user's last query**: Understand well what the last message that is sent by the user is.
+2. **Identify relevant available tools**: Look into the TOOL DEFINITIONS and analyze which tools could help answer the user's last query in the conversation.
+3. **Analyze the actual tool calls made**: Compare what was done in the TOOL CALLS TO BE EVALUATED section vs. What should've been done by the agent.
+4. **Check parameter grounding** - Ensure all parameters are grounded from the CONVERSATION section and are not hallucinated.
+5. **Determine the appropriate level** - Be VERY precise and follow the level definitions exactly.
 
 
 # Data
 CONVERSATION : {{query}}
-TOOL CALL: {{tool_call}}
-TOOL DEFINITION: {{tool_definition}}
+TOOL CALLS TO BE EVALUATED: {{tool_calls}}
+TOOL DEFINITIONS: {{tool_definitions}}
 
 
 # Tasks
-## Please provide your assessment Score for the previous CONVERSATION , TOOL CALL and TOOL DEFINITION based on the Definitions above. Your output should include the following information:
-- **ThoughtChain**: To improve the reasoning process, think step by step and include a step-by-step explanation of your thought process as you analyze the data based on the definitions. Keep it brief and start your ThoughtChain with "Let's think step by step:".
-- **Explanation**: a very short explanation of why you think the input Data should get that Score.
-- **Score**: based on your previous analysis, provide your Score. The Score you give MUST be a integer score (i.e., "0", "1") based on the levels of the definitions.
-
+## Please provide your evaluation for the assistant RESPONSE in relation to the user QUERY and tool definitions based on the Definitions and examples above.
+Your output should consist only of a JSON object, as provided in the examples, that has the following keys:
+- chain_of_thought: a string that explains your thought process to decide on the tool call accuracy level, based on the Chain of Thought structure. Start this string with 'Let's think step by step:'.
+- tool_calls_success_level: a integer value between 1 and 5 that represents the level of tool call success, based on the level definitions mentioned before. You need to be very precise when deciding on this level. Ensure you are correctly following the rating system based on the description of each level.
+- details: a dictionary that contains the following keys:
+  - tool_calls_made_by_agent: total number of tool calls made by the agent
+  - correct_tool_calls_made_by_agent: total number of correct tool calls made by the agent
+  - per_tool_call_details: a list of dictionaries, each containing:
+    - tool_name: name of the tool
+    - total_calls_required: total number of calls required for the tool
+    - correct_calls_made_by_agent: number of correct calls made by the agent
+    - correct_tool_percentage: percentage of correct calls made by the agent for this tool. It is a value between 0.0 and 1.0
+    - tool_call_errors: number of errors encountered during the tool call
+    - tool_success_result: 'pass' or 'fail' based on the evaluation of the tool call accuracy for this tool
+  - excess_tool_calls: a dictionary with the following keys:
+    - total: total number of excess, unnecessary tool calls made by the agent
+    - details: a list of dictionaries, each containing:
+      - tool_name: name of the tool
+      - excess_count: number of excess calls made for this query
+  - missing_tool_calls: a dictionary with the following keys:
+    - total: total number of missing tool calls that should have been made by the agent to be able to answer the query, but were not made by the agent at all.
+    - details: a list of dictionaries, each containing:
+      - tool_name: name of the tool
+      - missing_count: number of missing calls for this query
 
-## Please provide your answers between the tags: <S0>your chain of thoughts</S0>, <S1>your explanation</S1>, <S2>your Score</S2>.
 # Output
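
Because the response format switched from tagged free text (<S0>/<S1>/<S2>) to response_format: json_object, the grader now returns one structured object. An illustrative instance of the schema the prompt describes, written as a Python literal with invented values:

    example_grader_output = {
        "chain_of_thought": "Let's think step by step: the user asked for ...",
        "tool_calls_success_level": 5,  # integer between 1 and 5 per the level definitions
        "details": {
            "tool_calls_made_by_agent": 2,
            "correct_tool_calls_made_by_agent": 2,
            "per_tool_call_details": [
                {
                    "tool_name": "fetch_weather",
                    "total_calls_required": 1,
                    "correct_calls_made_by_agent": 1,
                    "correct_tool_percentage": 1.0,
                    "tool_call_errors": 0,
                    "tool_success_result": "pass",
                }
            ],
            "excess_tool_calls": {"total": 0, "details": []},
            "missing_tool_calls": {"total": 0, "details": []},
        },
    }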
azure/ai/evaluation/_evaluators/_ungrounded_attributes/_ungrounded_attributes.py
@@ -58,19 +58,26 @@ class UngroundedAttributesEvaluator(RaiServiceEvaluatorBase[Union[str, bool]]):
     for the ungrounded attributes will be "ungrounded_attributes_label".
     """
 
-    id = "ungrounded_attributes"
+    id = "azureai://built-in/evaluators/ungrounded_attributes"
     """Evaluator identifier, experimental and to be used only with evaluation in cloud."""
+    _OPTIONAL_PARAMS = ["query"]
 
     @override
     def __init__(
         self,
         credential,
         azure_ai_project,
+        **kwargs,
     ):
+        # Set default for evaluate_query if not provided
+        if "evaluate_query" not in kwargs:
+            kwargs["evaluate_query"] = True
+
         super().__init__(
             eval_metric=EvaluationMetrics.UNGROUNDED_ATTRIBUTES,
             azure_ai_project=azure_ai_project,
             credential=credential,
+            **kwargs,
         )
 
     @overload
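
A short sketch of what the new **kwargs plumbing means in practice: query evaluation is now on by default and can be switched off explicitly. The project endpoint below is a placeholder:

    from azure.identity import DefaultAzureCredential
    from azure.ai.evaluation import UngroundedAttributesEvaluator

    evaluator = UngroundedAttributesEvaluator(
        credential=DefaultAzureCredential(),
        azure_ai_project="https://<resource>.services.ai.azure.com/api/projects/<project>",
        # evaluate_query=True is injected automatically when not supplied;
        # pass evaluate_query=False to exclude the query from the evaluated content.
    )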
azure/ai/evaluation/_evaluators/_xpia/xpia.py
@@ -67,19 +67,22 @@ class IndirectAttackEvaluator(RaiServiceEvaluatorBase[Union[str, bool]]):
 
     """
 
-    id = "azureml://registries/azureml/models/Indirect-Attack-Evaluator/versions/3"
+    id = "azureai://built-in/evaluators/indirect_attack"
     """Evaluator identifier, experimental and to be used only with evaluation in cloud."""
+    _OPTIONAL_PARAMS = ["query"]
 
     @override
     def __init__(
         self,
         credential,
         azure_ai_project,
+        **kwargs,
     ):
         super().__init__(
             eval_metric=EvaluationMetrics.XPIA,
             azure_ai_project=azure_ai_project,
             credential=credential,
+            **kwargs,
         )
 
     @overload
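
The identifier migration (azureml:// registry URI to the azureai://built-in scheme) applies to this evaluator and several others in this release; reading the class attribute avoids hard-coding either form. A minimal check:

    from azure.ai.evaluation import IndirectAttackEvaluator

    # id is a class attribute, so it can be read without instantiating the
    # evaluator (instantiation requires a credential and a project).
    print(IndirectAttackEvaluator.id)  # azureai://built-in/evaluators/indirect_attack
    assert not IndirectAttackEvaluator.id.startswith("azureml://")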
azure/ai/evaluation/_exceptions.py
@@ -48,6 +48,7 @@ class ErrorCategory(Enum):
     PROJECT_ACCESS_ERROR = "PROJECT ACCESS ERROR"
     UNKNOWN = "UNKNOWN"
     UPLOAD_ERROR = "UPLOAD ERROR"
+    NOT_APPLICABLE = "NOT APPLICABLE"
 
 
 class ErrorBlame(Enum):
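
The new member is reachable through the usual enum accessors (a trivial sketch; _exceptions is an internal module):

    from azure.ai.evaluation._exceptions import ErrorCategory

    # New member added in this release.
    assert ErrorCategory.NOT_APPLICABLE.value == "NOT APPLICABLE"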
azure/ai/evaluation/_legacy/_batch_engine/_config.py
@@ -19,7 +19,7 @@ class BatchEngineConfig:
     batch_timeout_seconds: int = PF_BATCH_TIMEOUT_SEC_DEFAULT
     """The maximum amount of time to wait for all evaluations in the batch to complete."""
 
-    run_timeout_seconds: int = 600
+    line_timeout_seconds: int = 600
     """The maximum amount of time to wait for an evaluation to run against a single entry
     in the data input to complete."""
 
@@ -32,13 +32,16 @@ class BatchEngineConfig:
     default_num_results: int = 100
     """The default number of results to return if you don't ask for all results."""
 
+    raise_on_error: bool = True
+    """Whether to raise an error if an evaluation fails."""
+
     def __post_init__(self):
         if self.logger is None:
             raise ValueError("logger cannot be None")
         if self.batch_timeout_seconds <= 0:
             raise ValueError("batch_timeout_seconds must be greater than 0")
-        if self.run_timeout_seconds <= 0:
-            raise ValueError("run_timeout_seconds must be greater than 0")
+        if self.line_timeout_seconds <= 0:
+            raise ValueError("line_timeout_seconds must be greater than 0")
         if self.max_concurrency <= 0:
             raise ValueError("max_concurrency must be greater than 0")
         if self.default_num_results <= 0:
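
For downstream code this rename is breaking: run_timeout_seconds no longer exists. A sketch of constructing the dataclass under the new names (this is an internal API, so treat field order and defaults as unstable):

    import logging

    from azure.ai.evaluation._legacy._batch_engine._config import BatchEngineConfig

    config = BatchEngineConfig(
        logger=logging.getLogger("batch"),
        line_timeout_seconds=600,  # renamed from run_timeout_seconds
        raise_on_error=False,      # new flag; defaults to True
    )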
azure/ai/evaluation/_legacy/_batch_engine/_engine.py
@@ -20,15 +20,31 @@ from concurrent.futures import Executor
 from functools import partial
 from contextlib import contextmanager
 from datetime import datetime, timezone
-from typing import Any, Callable, Dict, Final, Generator, Mapping, MutableMapping, Optional, Sequence, Set, Tuple, cast
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Final,
+    Generator,
+    List,
+    Mapping,
+    MutableMapping,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    cast,
+    Literal,
+)
 from uuid import uuid4
 
+from ._config import BatchEngineConfig
 from ._utils import DEFAULTS_KEY, get_int_env_var, get_value_from_path, is_async_callable
 from ._status import BatchStatus
 from ._result import BatchResult, BatchRunDetails, BatchRunError, TokenMetrics
 from ._run_storage import AbstractRunStorage, NoOpRunStorage
-from .._common._logging import log_progress, NodeLogManager
-from ..._exceptions import ErrorBlame
+from .._common._logging import log_progress, logger, NodeLogManager
+from ..._exceptions import ErrorBlame, EvaluationException
 from ._exceptions import (
     BatchEngineCanceledError,
     BatchEngineError,
@@ -54,30 +70,25 @@ class BatchEngine:
         self,
         func: Callable,
         *,
+        config: BatchEngineConfig,
         storage: Optional[AbstractRunStorage] = None,
-        batch_timeout_sec: Optional[int] = None,
-        line_timeout_sec: Optional[int] = None,
-        max_worker_count: Optional[int] = None,
         executor: Optional[Executor] = None,
     ):
         """Create a new batch engine instance
 
         :param Callable func: The function to run the flow
+        :param BatchEngineConfig config: The configuration for the batch engine
         :param Optional[AbstractRunStorage] storage: The storage to store execution results
-        :param Optional[int] batch_timeout_sec: The timeout of batch run in seconds
-        :param Optional[int] line_timeout_sec: The timeout of each line in seconds
-        :param Optional[int] max_worker_count: The concurrency limit of batch run
         :param Optional[Executor] executor: The executor to run the flow (if needed)
         """
 
         self._func: Callable = func
+        self._config: BatchEngineConfig = config
         self._storage: AbstractRunStorage = storage or NoOpRunStorage()
 
-        # TODO ralphe: Consume these from the batch context/config instead of from
-        #              kwargs or (even worse) environment variables
-        self._batch_timeout_sec = batch_timeout_sec or get_int_env_var("PF_BATCH_TIMEOUT_SEC")
-        self._line_timeout_sec = line_timeout_sec or get_int_env_var("PF_LINE_TIMEOUT_SEC", 600)
-        self._max_worker_count = max_worker_count or get_int_env_var("PF_WORKER_COUNT") or MAX_WORKER_COUNT
+        self._batch_timeout_sec = self._config.batch_timeout_seconds
+        self._line_timeout_sec = self._config.line_timeout_seconds
+        self._max_worker_count = self._config.max_concurrency
 
         self._executor: Optional[Executor] = executor
         self._is_canceled: bool = False
@@ -85,15 +96,13 @@ class BatchEngine:
     async def run(
         self,
         data: Sequence[Mapping[str, Any]],
-        column_mapping: Mapping[str, str],
+        column_mapping: Optional[Mapping[str, str]],
         *,
         id: Optional[str] = None,
         max_lines: Optional[int] = None,
     ) -> BatchResult:
         if not data:
             raise BatchEngineValidationError("Please provide a non-empty data mapping.")
-        if not column_mapping:
-            raise BatchEngineValidationError("The column mapping is required.")
 
         start_time = datetime.now(timezone.utc)
azure/ai/evaluation/_legacy/_batch_engine/_engine.py (continued)
@@ -105,6 +114,8 @@ class BatchEngine:
             id = id or str(uuid4())
             result: BatchResult = await self._exec_in_task(id, batch_inputs, start_time)
             return result
+        except EvaluationException:
+            raise
         except Exception as ex:
             raise BatchEngineError(
                 "Unexpected error while running the batch run.", blame=ErrorBlame.SYSTEM_ERROR
@@ -114,20 +125,58 @@ class BatchEngine:
         # TODO ralphe: Make sure this works
         self._is_canceled = True
 
-    @staticmethod
     def _apply_column_mapping(
+        self,
         data: Sequence[Mapping[str, Any]],
-        column_mapping: Mapping[str, str],
+        column_mapping: Optional[Mapping[str, str]],
         max_lines: Optional[int],
     ) -> Sequence[Mapping[str, str]]:
+
+        resolved_column_mapping: Mapping[str, str] = self._resolve_column_mapping(column_mapping)
+        resolved_column_mapping.update(self._generate_defaults_for_column_mapping())
+        return self._apply_column_mapping_to_lines(data, resolved_column_mapping, max_lines)
+
+    def _resolve_column_mapping(
+        self,
+        column_mapping: Optional[Mapping[str, str]],
+    ) -> Mapping[str, str]:
+        parameters = inspect.signature(self._func).parameters
+        default_column_mapping: Dict[str, str] = {
+            name: f"${{data.{name}}}"
+            for name, value in parameters.items()
+            if name not in ["self", "cls", "args", "kwargs"]
+        }
+        resolved_mapping: Dict[str, str] = default_column_mapping.copy()
+
+        for name, value in parameters.items():
+            if value and value.default is not inspect.Parameter.empty:
+                resolved_mapping.pop(name)
+
+        resolved_mapping.update(column_mapping or {})
+        return resolved_mapping
+
+    def _generate_defaults_for_column_mapping(self) -> Mapping[Literal["$defaults$"], Any]:
+
+        return {
+            DEFAULTS_KEY: {
+                name: value.default
+                for name, value in inspect.signature(self._func).parameters.items()
+                if value.default is not inspect.Parameter.empty
+            }
+        }
+
+    @staticmethod
+    def _apply_column_mapping_to_lines(
+        data: Sequence[Mapping[str, Any]],
+        column_mapping: Mapping[str, str],
+        max_lines: Optional[int],
+    ) -> Sequence[Mapping[str, Any]]:
         data = data[:max_lines] if max_lines else data
 
         inputs: Sequence[Mapping[str, Any]] = []
-        line: int = 0
         defaults = cast(Mapping[str, Any], column_mapping.get(DEFAULTS_KEY, {}))
 
-        for input in data:
-            line += 1
+        for line_number, input in enumerate(data, start=1):
             mapped: Dict[str, Any] = {}
             missing_inputs: Set[str] = set()
 
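The resolution logic above can be illustrated standalone: required parameters map to ${data.<name>} references, while defaulted parameters are dropped from the mapping and collected under the $defaults$ key as fallback values. A self-contained mirror of that behavior (simplified; it ignores the user-supplied overrides that the real method merges in last):

    import inspect

    DEFAULTS_KEY = "$defaults$"

    def resolve_default_mapping(func):
        params = inspect.signature(func).parameters
        # Required parameters become "${data.<name>}" column references.
        mapping = {
            name: f"${{data.{name}}}"
            for name, p in params.items()
            if name not in ("self", "cls", "args", "kwargs")
            and p.default is inspect.Parameter.empty
        }
        # Defaulted parameters are carried separately as fallback values.
        mapping[DEFAULTS_KEY] = {
            name: p.default
            for name, p in params.items()
            if p.default is not inspect.Parameter.empty
        }
        return mapping

    def my_eval_func(query, response, threshold=3):
        ...

    assert resolve_default_mapping(my_eval_func) == {
        "query": "${data.query}",
        "response": "${data.response}",
        "$defaults$": {"threshold": 3},
    }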
azure/ai/evaluation/_legacy/_batch_engine/_engine.py (continued)
@@ -148,18 +197,18 @@ class BatchEngine:
                     continue
 
                 dict_path = match.group(1)
-                found, value = get_value_from_path(dict_path, input)
+                found, mapped_value = get_value_from_path(dict_path, input)
                 if not found:  # try default value
-                    found, value = get_value_from_path(dict_path, defaults)
+                    found, mapped_value = get_value_from_path(dict_path, defaults)
 
                 if found:
-                    mapped[key] = value
+                    mapped[key] = mapped_value
                 else:
                     missing_inputs.add(dict_path)
 
             if missing_inputs:
                 missing = ", ".join(missing_inputs)
-                raise BatchEngineValidationError(f"Missing inputs for line {line}: '{missing}'")
+                raise BatchEngineValidationError(f"Missing inputs for line {line_number}: '{missing}'")
 
             inputs.append(mapped)
 
@@ -212,10 +261,12 @@ class BatchEngine:
                     end_time=None,
                     tokens=TokenMetrics(0, 0, 0),
                     error=BatchRunError("The line run is not completed.", None),
+                    index=i,
                 )
             )
             for i in range(len(batch_inputs))
         ]
+        self.handle_line_failures(result_details)
 
         for line_result in result_details:
             # Indicate the worst status of the batch run. This works because
@@ -229,9 +280,15 @@ class BatchEngine:
             metrics.total_tokens += line_result.tokens.total_tokens
 
         if failed_lines and not error:
-            error = BatchEngineRunFailedError(
-                str(floor(failed_lines / len(batch_inputs) * 100)) + f"% of the batch run failed."
+            error_message = f"{floor(failed_lines / len(batch_inputs) * 100)}% of the batch run failed."
+            first_exception: Optional[Exception] = next(
+                (result.error.exception for result in result_details if result.error and result.error.exception),
+                None,
             )
+            if first_exception is not None:
+                error_message += f" {first_exception}"
+
+            error = BatchEngineRunFailedError(error_message)
 
         return BatchResult(
             status=status,
@@ -283,6 +340,13 @@ class BatchEngine:
             # TODO ralphe: set logger to use here
         )
 
+    def __preprocess_inputs(self, inputs: Mapping[str, Any]) -> Mapping[str, Any]:
+
+        func_params = inspect.signature(self._func).parameters
+
+        filtered_params = {key: value for key, value in inputs.items() if key in func_params}
+        return filtered_params
+
     async def _exec_line_async(
         self,
         run_id: str,
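
The new private helper quietly changes failure behavior: mapped columns that the target function does not declare are filtered out before invocation, instead of surfacing as a TypeError for an unexpected keyword argument. A standalone mirror of the same filtering rule, outside the class:

    import inspect

    def preprocess_inputs(func, inputs):
        # Keep only the keys the target function actually declares.
        func_params = inspect.signature(func).parameters
        return {key: value for key, value in inputs.items() if key in func_params}

    def my_eval_func(query, response):
        ...

    assert preprocess_inputs(my_eval_func, {"query": "q", "response": "r", "id": 7}) == {
        "query": "q",
        "response": "r",
    }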
azure/ai/evaluation/_legacy/_batch_engine/_engine.py (continued)
@@ -298,6 +362,7 @@ class BatchEngine:
             end_time=None,
             tokens=TokenMetrics(0, 0, 0),
             error=None,
+            index=index,
         )
 
         try:
@@ -313,13 +378,15 @@ class BatchEngine:
             # For now we will just run the function in the current process, but in the future we may
             # want to consider running the function in a separate process for isolation reasons.
             output: Any
+
+            processed_inputs = self.__preprocess_inputs(inputs)
             if is_async_callable(self._func):
-                output = await self._func(**inputs)
+                output = await self._func(**processed_inputs)
             else:
                 # to maximize the parallelism, we run the synchronous function in a separate thread
                 # and await its result
                 output = await asyncio.get_event_loop().run_in_executor(
-                    self._executor, partial(self._func, **inputs)
+                    self._executor, partial(self._func, **processed_inputs)
                 )
 
             # This should in theory never happen but as an extra precaution, let's check if the output
@@ -340,6 +407,24 @@ class BatchEngine:
 
         return index, details
 
+    @staticmethod
+    def handle_line_failures(run_infos: List[BatchRunDetails], raise_on_line_failure: bool = False):
+        """Handle line failures in batch run"""
+        failed_run_infos: List[BatchRunDetails] = [r for r in run_infos if r.status == BatchStatus.Failed]
+        failed_msg: Optional[str] = None
+        if len(failed_run_infos) > 0:
+            failed_indexes = ",".join([str(r.index) for r in failed_run_infos])
+            first_fail_exception: str = failed_run_infos[0].error.details
+            if raise_on_line_failure:
+                failed_msg = "Flow run failed due to the error: " + first_fail_exception
+                raise Exception(failed_msg)
+
+            failed_msg = (
+                f"{len(failed_run_infos)}/{len(run_infos)} flow run failed, indexes: [{failed_indexes}],"
+                f" exception of index {failed_run_infos[0].index}: {first_fail_exception}"
+            )
+            logger.error(failed_msg)
+
     def _persist_run_info(self, line_results: Sequence[BatchRunDetails]):
         # TODO ralphe: implement?
         pass
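
For reference, a sketch of what the two paths above produce, given the f-strings in the hunk (all values invented):

    # With 5 lines where indexes 1 and 3 failed and raise_on_line_failure=False,
    # handle_line_failures logs:
    #   "2/5 flow run failed, indexes: [1,3], exception of index 1: <error details>"
    # With raise_on_line_failure=True it raises before logging:
    #   Exception("Flow run failed due to the error: <error details>")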
azure/ai/evaluation/_legacy/_batch_engine/_result.py
@@ -55,6 +55,8 @@ class BatchRunDetails:
     """The token metrics of the line run."""
     error: Optional[BatchRunError]
     """The error of the line run. This will only be set if the status is Failed."""
+    index: int
+    """The line run index."""
 
     @property
     def duration(self) -> timedelta:
azure/ai/evaluation/_legacy/_batch_engine/_run.py
@@ -58,7 +58,7 @@ class Run:
         dynamic_callable: Callable,
         name_prefix: Optional[str],
         inputs: Sequence[Mapping[str, Any]],
-        column_mapping: Mapping[str, str],
+        column_mapping: Optional[Mapping[str, str]] = None,
         created_on: Optional[datetime] = None,
         run: Optional["Run"] = None,
     ):
@@ -70,7 +70,7 @@ class Run:
         self.dynamic_callable = dynamic_callable
         self.name = self._generate_run_name(name_prefix, self._created_on)
         self.inputs = inputs
-        self.column_mapping = column_mapping
+        self.column_mapping: Optional[Mapping[str, str]] = column_mapping
         self.result: Optional[BatchResult] = None
         self.metrics: Mapping[str, Any] = {}
         self._run = run