google-adk 0.4.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (129)
  1. google/adk/agents/active_streaming_tool.py +1 -0
  2. google/adk/agents/base_agent.py +91 -47
  3. google/adk/agents/base_agent.py.orig +330 -0
  4. google/adk/agents/callback_context.py +4 -9
  5. google/adk/agents/invocation_context.py +1 -0
  6. google/adk/agents/langgraph_agent.py +1 -0
  7. google/adk/agents/live_request_queue.py +1 -0
  8. google/adk/agents/llm_agent.py +172 -35
  9. google/adk/agents/loop_agent.py +1 -1
  10. google/adk/agents/parallel_agent.py +7 -0
  11. google/adk/agents/readonly_context.py +7 -1
  12. google/adk/agents/run_config.py +5 -1
  13. google/adk/agents/sequential_agent.py +31 -0
  14. google/adk/agents/transcription_entry.py +5 -2
  15. google/adk/artifacts/base_artifact_service.py +5 -10
  16. google/adk/artifacts/gcs_artifact_service.py +9 -9
  17. google/adk/artifacts/in_memory_artifact_service.py +6 -6
  18. google/adk/auth/auth_credential.py +9 -5
  19. google/adk/auth/auth_preprocessor.py +7 -1
  20. google/adk/auth/auth_tool.py +3 -4
  21. google/adk/cli/agent_graph.py +5 -5
  22. google/adk/cli/browser/index.html +2 -2
  23. google/adk/cli/browser/{main-HWIBUY2R.js → main-QOEMUXM4.js} +58 -58
  24. google/adk/cli/cli.py +7 -7
  25. google/adk/cli/cli_deploy.py +7 -2
  26. google/adk/cli/cli_eval.py +181 -106
  27. google/adk/cli/cli_tools_click.py +147 -62
  28. google/adk/cli/fast_api.py +340 -158
  29. google/adk/cli/fast_api.py.orig +822 -0
  30. google/adk/cli/utils/common.py +23 -0
  31. google/adk/cli/utils/evals.py +83 -1
  32. google/adk/cli/utils/logs.py +13 -5
  33. google/adk/code_executors/__init__.py +3 -1
  34. google/adk/code_executors/built_in_code_executor.py +52 -0
  35. google/adk/evaluation/__init__.py +1 -1
  36. google/adk/evaluation/agent_evaluator.py +168 -128
  37. google/adk/evaluation/eval_case.py +102 -0
  38. google/adk/evaluation/eval_set.py +37 -0
  39. google/adk/evaluation/eval_sets_manager.py +42 -0
  40. google/adk/evaluation/evaluation_constants.py +1 -0
  41. google/adk/evaluation/evaluation_generator.py +89 -114
  42. google/adk/evaluation/evaluator.py +56 -0
  43. google/adk/evaluation/local_eval_sets_manager.py +264 -0
  44. google/adk/evaluation/response_evaluator.py +107 -3
  45. google/adk/evaluation/trajectory_evaluator.py +83 -2
  46. google/adk/events/event.py +7 -1
  47. google/adk/events/event_actions.py +7 -1
  48. google/adk/examples/example.py +1 -0
  49. google/adk/examples/example_util.py +3 -2
  50. google/adk/flows/__init__.py +0 -1
  51. google/adk/flows/llm_flows/_code_execution.py +19 -11
  52. google/adk/flows/llm_flows/audio_transcriber.py +4 -3
  53. google/adk/flows/llm_flows/base_llm_flow.py +86 -22
  54. google/adk/flows/llm_flows/basic.py +3 -0
  55. google/adk/flows/llm_flows/functions.py +10 -9
  56. google/adk/flows/llm_flows/instructions.py +28 -9
  57. google/adk/flows/llm_flows/single_flow.py +1 -1
  58. google/adk/memory/__init__.py +1 -1
  59. google/adk/memory/_utils.py +23 -0
  60. google/adk/memory/base_memory_service.py +25 -21
  61. google/adk/memory/base_memory_service.py.orig +76 -0
  62. google/adk/memory/in_memory_memory_service.py +59 -27
  63. google/adk/memory/memory_entry.py +37 -0
  64. google/adk/memory/vertex_ai_rag_memory_service.py +40 -17
  65. google/adk/models/anthropic_llm.py +36 -11
  66. google/adk/models/base_llm.py +45 -4
  67. google/adk/models/gemini_llm_connection.py +15 -2
  68. google/adk/models/google_llm.py +9 -44
  69. google/adk/models/google_llm.py.orig +305 -0
  70. google/adk/models/lite_llm.py +94 -38
  71. google/adk/models/llm_request.py +1 -1
  72. google/adk/models/llm_response.py +15 -3
  73. google/adk/models/registry.py +1 -1
  74. google/adk/runners.py +68 -44
  75. google/adk/sessions/__init__.py +1 -1
  76. google/adk/sessions/_session_util.py +14 -0
  77. google/adk/sessions/base_session_service.py +8 -32
  78. google/adk/sessions/database_session_service.py +58 -61
  79. google/adk/sessions/in_memory_session_service.py +108 -26
  80. google/adk/sessions/session.py +4 -0
  81. google/adk/sessions/vertex_ai_session_service.py +23 -45
  82. google/adk/telemetry.py +3 -0
  83. google/adk/tools/__init__.py +4 -7
  84. google/adk/tools/{built_in_code_execution_tool.py → _built_in_code_execution_tool.py} +11 -0
  85. google/adk/tools/_memory_entry_utils.py +30 -0
  86. google/adk/tools/agent_tool.py +16 -13
  87. google/adk/tools/apihub_tool/apihub_toolset.py +55 -74
  88. google/adk/tools/application_integration_tool/application_integration_toolset.py +107 -85
  89. google/adk/tools/application_integration_tool/clients/connections_client.py +29 -25
  90. google/adk/tools/application_integration_tool/clients/integration_client.py +6 -6
  91. google/adk/tools/application_integration_tool/integration_connector_tool.py +69 -26
  92. google/adk/tools/base_toolset.py +58 -0
  93. google/adk/tools/enterprise_search_tool.py +65 -0
  94. google/adk/tools/function_parameter_parse_util.py +2 -2
  95. google/adk/tools/google_api_tool/__init__.py +18 -70
  96. google/adk/tools/google_api_tool/google_api_tool.py +11 -5
  97. google/adk/tools/google_api_tool/google_api_toolset.py +126 -0
  98. google/adk/tools/google_api_tool/google_api_toolsets.py +102 -0
  99. google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +40 -42
  100. google/adk/tools/langchain_tool.py +96 -49
  101. google/adk/tools/load_artifacts_tool.py +4 -4
  102. google/adk/tools/load_memory_tool.py +16 -5
  103. google/adk/tools/mcp_tool/__init__.py +3 -2
  104. google/adk/tools/mcp_tool/conversion_utils.py +1 -1
  105. google/adk/tools/mcp_tool/mcp_session_manager.py +167 -16
  106. google/adk/tools/mcp_tool/mcp_session_manager.py.orig +322 -0
  107. google/adk/tools/mcp_tool/mcp_tool.py +12 -12
  108. google/adk/tools/mcp_tool/mcp_toolset.py +155 -195
  109. google/adk/tools/openapi_tool/common/common.py +2 -5
  110. google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +32 -7
  111. google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +43 -33
  112. google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +1 -1
  113. google/adk/tools/preload_memory_tool.py +27 -18
  114. google/adk/tools/retrieval/__init__.py +1 -1
  115. google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +1 -1
  116. google/adk/tools/tool_context.py +4 -4
  117. google/adk/tools/toolbox_toolset.py +79 -0
  118. google/adk/tools/transfer_to_agent_tool.py +0 -1
  119. google/adk/version.py +1 -1
  120. {google_adk-0.4.0.dist-info → google_adk-1.0.0.dist-info}/METADATA +7 -5
  121. google_adk-1.0.0.dist-info/RECORD +195 -0
  122. google/adk/agents/remote_agent.py +0 -50
  123. google/adk/tools/google_api_tool/google_api_tool_set.py +0 -110
  124. google/adk/tools/google_api_tool/google_api_tool_sets.py +0 -112
  125. google/adk/tools/toolbox_tool.py +0 -46
  126. google_adk-0.4.0.dist-info/RECORD +0 -179
  127. {google_adk-0.4.0.dist-info → google_adk-1.0.0.dist-info}/WHEEL +0 -0
  128. {google_adk-0.4.0.dist-info → google_adk-1.0.0.dist-info}/entry_points.txt +0 -0
  129. {google_adk-0.4.0.dist-info → google_adk-1.0.0.dist-info}/licenses/LICENSE +0 -0
google/adk/cli/cli.py CHANGED
@@ -55,7 +55,7 @@ async def run_input_file(
   input_file = InputFile.model_validate_json(f.read())
   input_file.state['_time'] = datetime.now()

-  session = session_service.create_session(
+  session = await session_service.create_session(
       app_name=app_name, user_id=user_id, state=input_file.state
   )
   for query in input_file.queries:
@@ -105,6 +105,7 @@ async def run_cli(
     input_file: Optional[str] = None,
     saved_session_file: Optional[str] = None,
     save_session: bool,
+    session_id: Optional[str] = None,
 ) -> None:
   """Runs an interactive CLI for a certain agent.

@@ -118,6 +119,7 @@ async def run_cli(
     saved_session_file: Optional[str], the absolute path to the json file that
       contains a previously saved session, exclusive with input_file.
     save_session: bool, whether to save the session on exit.
+    session_id: Optional[str], the session ID to save the session to on exit.
   """
   if agent_parent_dir not in sys.path:
     sys.path.append(agent_parent_dir)
@@ -128,7 +130,7 @@ async def run_cli(
   agent_module_path = os.path.join(agent_parent_dir, agent_folder_name)
   agent_module = importlib.import_module(agent_folder_name)
   user_id = 'test_user'
-  session = session_service.create_session(
+  session = await session_service.create_session(
       app_name=agent_folder_name, user_id=user_id
   )
   root_agent = agent_module.agent.root_agent
@@ -143,14 +145,12 @@ async def run_cli(
         input_path=input_file,
     )
   elif saved_session_file:
-
-    loaded_session = None
     with open(saved_session_file, 'r') as f:
       loaded_session = Session.model_validate_json(f.read())

     if loaded_session:
       for event in loaded_session.events:
-        session_service.append_event(session, event)
+        await session_service.append_event(session, event)
         content = event.content
         if not content or not content.parts or not content.parts[0].text:
           continue
@@ -175,11 +175,11 @@ async def run_cli(
     )

   if save_session:
-    session_id = input('Session ID to save: ')
+    session_id = session_id or input('Session ID to save: ')
     session_path = f'{agent_module_path}/{session_id}.session.json'

     # Fetch the session again to get all the details.
-    session = session_service.get_session(
+    session = await session_service.get_session(
        app_name=session.app_name,
        user_id=session.user_id,
        session_id=session.id,
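The thread running through this file's hunks is the 1.0.0 move to an async session API: create_session, get_session, and append_event on the session service are now coroutines and must be awaited. A minimal sketch of a caller updated for the new API, assuming InMemorySessionService is exported from google.adk.sessions; the keyword arguments mirror the hunks above:

import asyncio

from google.adk.sessions import InMemorySessionService


async def main() -> None:
  session_service = InMemorySessionService()

  # In 1.0.0, create_session is a coroutine and must be awaited.
  session = await session_service.create_session(
      app_name='my_app', user_id='test_user'
  )

  # get_session is awaited the same way; the keyword arguments match the
  # save_session path shown above.
  session = await session_service.get_session(
      app_name=session.app_name,
      user_id=session.user_id,
      session_id=session.id,
  )


asyncio.run(main())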
google/adk/cli/cli_deploy.py CHANGED
@@ -42,7 +42,7 @@ ENV GOOGLE_CLOUD_LOCATION={gcp_region}
 # Set up environment variables - End

 # Install ADK - Start
-RUN pip install google-adk
+RUN pip install google-adk=={adk_version}
 # Install ADK - End

 # Copy agent - Start
@@ -54,7 +54,7 @@ COPY "agents/{app_name}/" "/app/agents/{app_name}/"

 EXPOSE {port}

-CMD adk {command} --port={port} {session_db_option} {trace_to_cloud_option} "/app/agents"
+CMD adk {command} --port={port} {host_option} {session_db_option} {trace_to_cloud_option} "/app/agents"
 """


@@ -86,6 +86,7 @@ def to_cloud_run(
     with_ui: bool,
     verbosity: str,
     session_db_url: str,
+    adk_version: str,
 ):
   """Deploys an agent to Google Cloud Run.

@@ -114,6 +115,7 @@ def to_cloud_run(
     with_ui: Whether to deploy with UI.
     verbosity: The verbosity level of the CLI.
     session_db_url: The database URL to connect the session.
+    adk_version: The ADK version to use in Cloud Run.
   """
   app_name = app_name or os.path.basename(agent_folder)

@@ -139,6 +141,7 @@ def to_cloud_run(

   # create Dockerfile
   click.echo('Creating Dockerfile...')
+  host_option = '--host=0.0.0.0' if adk_version > '0.5.0' else ''
   dockerfile_content = _DOCKERFILE_TEMPLATE.format(
       gcp_project_id=project,
       gcp_region=region,
@@ -150,6 +153,8 @@ def to_cloud_run(
       if session_db_url
       else '',
       trace_to_cloud_option='--trace_to_cloud' if trace_to_cloud else '',
+      adk_version=adk_version,
+      host_option=host_option,
   )
   dockerfile_path = os.path.join(temp_folder, 'Dockerfile')
   os.makedirs(temp_folder, exist_ok=True)
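Two deploy changes land together here: the generated Dockerfile now pins the ADK version, and for versions newer than 0.5.0 the rendered CMD passes --host=0.0.0.0 so the server listens on all interfaces inside the Cloud Run container. A condensed sketch of how the new template fields render; the template string below is an abbreviated stand-in with hypothetical values, not the full _DOCKERFILE_TEMPLATE:

# Abbreviated stand-in for _DOCKERFILE_TEMPLATE, showing only the new fields.
_TEMPLATE = (
    'RUN pip install google-adk=={adk_version}\n'
    'CMD adk web --port={port} {host_option} "/app/agents"\n'
)

adk_version = '1.0.0'
# As in the hunk above, the comparison is lexicographic on the version string.
host_option = '--host=0.0.0.0' if adk_version > '0.5.0' else ''

print(_TEMPLATE.format(adk_version=adk_version, port=8000, host_option=host_option))
# RUN pip install google-adk==1.0.0
# CMD adk web --port=8000 --host=0.0.0.0 "/app/agents"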
google/adk/cli/cli_eval.py CHANGED
@@ -12,47 +12,107 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from enum import Enum
 import importlib.util
 import json
 import logging
 import os
 import sys
-import traceback
 from typing import Any
-from typing import Generator
+from typing import AsyncGenerator
 from typing import Optional
 import uuid

 from pydantic import BaseModel
+from pydantic import Field

 from ..agents import Agent
+from ..artifacts.base_artifact_service import BaseArtifactService
+from ..evaluation.eval_case import EvalCase
+from ..evaluation.eval_case import Invocation
+from ..evaluation.evaluator import EvalStatus
+from ..evaluation.evaluator import Evaluator
+from ..sessions.base_session_service import BaseSessionService
+from ..sessions.session import Session
+from .utils import common

-logger = logging.getLogger(__name__)
+logger = logging.getLogger("google_adk." + __name__)


-class EvalStatus(Enum):
-  PASSED = 1
-  FAILED = 2
-  NOT_EVALUATED = 3
+class EvalMetric(common.BaseModel):
+  """A metric used to evaluate a particular aspect of an eval case."""

-
-class EvalMetric(BaseModel):
   metric_name: str
+  """The name of the metric."""
+
   threshold: float
+  """A threshold value. Each metric decides how to interpret this threshold."""
+

+class EvalMetricResult(EvalMetric):
+  """The actual computed score/value of a particular EvalMetric."""

-class EvalMetricResult(BaseModel):
-  score: Optional[float]
+  score: Optional[float] = None
   eval_status: EvalStatus


-class EvalResult(BaseModel):
-  eval_set_file: str
-  eval_id: str
+class EvalMetricResultPerInvocation(common.BaseModel):
+  """Eval metric results per invocation."""
+
+  actual_invocation: Invocation
+  """The actual invocation, usually obtained by inferencing the agent."""
+
+  expected_invocation: Invocation
+  """The expected invocation, usually the reference or golden invocation."""
+
+  eval_metric_results: list[EvalMetricResult] = []
+  """Eval results for each applicable metric."""
+
+
+class EvalCaseResult(common.BaseModel):
+  """Case-level evaluation results."""
+
+  eval_set_file: str = Field(
+      deprecated=True,
+      description="This field is deprecated, use eval_set_id instead.",
+  )
+  eval_set_id: str = ""
+  """The eval set id."""
+
+  eval_id: str = ""
+  """The eval case id."""
+
   final_eval_status: EvalStatus
-  eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
+  """Final eval status for this eval case."""
+
+  eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]] = Field(
+      deprecated=True,
+      description=(
+          "This field is deprecated, use overall_eval_metric_results instead."
+      ),
+  )
+
+  overall_eval_metric_results: list[EvalMetricResult]
+  """Overall result for each metric for the entire eval case."""
+
+  eval_metric_result_per_invocation: list[EvalMetricResultPerInvocation]
+  """Result for each metric on a per invocation basis."""
+
   session_id: str
+  """Session id of the session generated as a result of the inferencing/scraping stage of the eval."""
+
+  session_details: Optional[Session] = None
+  """Session generated as a result of the inferencing/scraping stage of the eval."""
+
+  user_id: Optional[str] = None
+  """User id used during the inferencing/scraping stage of the eval."""
+
+
+class EvalSetResult(common.BaseModel):
+  eval_set_result_id: str
+  eval_set_result_name: str
+  eval_set_id: str
+  eval_case_results: list[EvalCaseResult] = Field(default_factory=list)
+  creation_timestamp: float = 0.0


 MISSING_EVAL_DEPENDENCIES_MESSAGE = (
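For orientation before the run_evals hunks below: EvalMetric pairs a metric name with a pass threshold, and EvalMetricResult extends it with the computed score and status. A small sketch with hypothetical values; the metric-name string here stands in for constants such as TOOL_TRAJECTORY_SCORE_KEY used elsewhere in this module:

from google.adk.cli.cli_eval import EvalMetric, EvalMetricResult
from google.adk.evaluation.evaluator import EvalStatus

# A metric to apply, plus the threshold a score must meet to pass.
metric = EvalMetric(metric_name='tool_trajectory_avg_score', threshold=1.0)

# The computed outcome for that metric on one eval case or invocation.
result = EvalMetricResult(
    metric_name=metric.metric_name,
    threshold=metric.threshold,
    score=0.5,
    eval_status=EvalStatus.FAILED,
)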
@@ -146,15 +206,26 @@ def parse_and_get_evals_to_run(
   return eval_set_to_evals


-def run_evals(
-    eval_set_to_evals: dict[str, list[str]],
+async def run_evals(
+    eval_cases_by_eval_set_id: dict[str, list[EvalCase]],
     root_agent: Agent,
     reset_func: Optional[Any],
     eval_metrics: list[EvalMetric],
-    session_service=None,
-    artifact_service=None,
-    print_detailed_results=False,
-) -> Generator[EvalResult, None, None]:
+    session_service: Optional[BaseSessionService] = None,
+    artifact_service: Optional[BaseArtifactService] = None,
+) -> AsyncGenerator[EvalCaseResult, None]:
+  """Returns a stream of EvalCaseResult for each eval case that was evaluated.
+
+  Args:
+    eval_cases_by_eval_set_id: Eval cases categorized by eval set id to which
+      they belong.
+    root_agent: Agent to use for inferencing.
+    reset_func: If present, this will be called before invoking the agent for
+      every inferencing step.
+    eval_metrics: A list of metrics that should be used during evaluation.
+    session_service: The session service to use during inferencing.
+    artifact_service: The artifact service to use during inferencing.
+  """
   try:
     from ..evaluation.agent_evaluator import EvaluationGenerator
     from ..evaluation.response_evaluator import ResponseEvaluator
@@ -162,97 +233,96 @@ def run_evals(
   except ModuleNotFoundError as e:
     raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e

-  """Returns a summary of eval runs."""
-  for eval_set_file, evals_to_run in eval_set_to_evals.items():
-    with open(eval_set_file, "r", encoding="utf-8") as file:
-      eval_items = json.load(file)  # Load JSON into a list
-
-    assert eval_items, f"No eval data found in eval set file: {eval_set_file}"
-
-    for eval_item in eval_items:
-      eval_name = eval_item["name"]
-      eval_data = eval_item["data"]
-      initial_session = eval_item.get("initial_session", {})
-
-      if evals_to_run and eval_name not in evals_to_run:
-        continue
+  for eval_set_id, eval_cases in eval_cases_by_eval_set_id.items():
+    for eval_case in eval_cases:
+      eval_name = eval_case.eval_id
+      initial_session = eval_case.session_input
+      user_id = initial_session.user_id if initial_session else "test_user_id"

       try:
-        print(f"Running Eval: {eval_set_file}:{eval_name}")
+        print(f"Running Eval: {eval_set_id}:{eval_name}")
         session_id = f"{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}"

-        scrape_result = EvaluationGenerator._process_query_with_root_agent(
-            data=eval_data,
-            root_agent=root_agent,
-            reset_func=reset_func,
-            initial_session=initial_session,
-            session_id=session_id,
-            session_service=session_service,
-            artifact_service=artifact_service,
+        inference_result = (
+            await EvaluationGenerator._generate_inferences_from_root_agent(
+                invocations=eval_case.conversation,
+                root_agent=root_agent,
+                reset_func=reset_func,
+                initial_session=initial_session,
+                session_id=session_id,
+                session_service=session_service,
+                artifact_service=artifact_service,
+            )
         )

-        eval_metric_results = []
+        # Initialize the per-invocation metric results to an empty list.
+        # We will fill this as we evaluate each metric.
+        eval_metric_result_per_invocation = []
+        for actual, expected in zip(inference_result, eval_case.conversation):
+          eval_metric_result_per_invocation.append(
+              EvalMetricResultPerInvocation(
+                  actual_invocation=actual,
+                  expected_invocation=expected,
+                  eval_metric_results=[],
+              )
+          )
+
+        overall_eval_metric_results = []
+
         for eval_metric in eval_metrics:
-          eval_metric_result = None
-          if eval_metric.metric_name == TOOL_TRAJECTORY_SCORE_KEY:
-            score = TrajectoryEvaluator.evaluate(
-                [scrape_result], print_detailed_results=print_detailed_results
-            )
-            eval_metric_result = _get_eval_metric_result(eval_metric, score)
-          elif eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY:
-            score = ResponseEvaluator.evaluate(
-                [scrape_result],
-                [RESPONSE_MATCH_SCORE_KEY],
-                print_detailed_results=print_detailed_results,
-            )
-            eval_metric_result = _get_eval_metric_result(
-                eval_metric, score["rouge_1/mean"].item()
-            )
-          elif eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY:
-            score = ResponseEvaluator.evaluate(
-                [scrape_result],
-                [RESPONSE_EVALUATION_SCORE_KEY],
-                print_detailed_results=print_detailed_results,
-            )
-            eval_metric_result = _get_eval_metric_result(
-                eval_metric, score["coherence/mean"].item()
+          metric_evaluator = _get_evaluator(eval_metric)
+
+          evaluation_result = metric_evaluator.evaluate_invocations(
+              actual_invocations=inference_result,
+              expected_invocations=eval_case.conversation,
+          )
+
+          overall_eval_metric_results.append(
+              EvalMetricResult(
+                  metric_name=eval_metric.metric_name,
+                  threshold=eval_metric.threshold,
+                  score=evaluation_result.overall_score,
+                  eval_status=evaluation_result.overall_eval_status,
+              )
+          )
+          for index, per_invocation_result in enumerate(
+              evaluation_result.per_invocation_results
+          ):
+            eval_metric_result_per_invocation[index].eval_metric_results.append(
+                EvalMetricResult(
+                    metric_name=eval_metric.metric_name,
+                    threshold=eval_metric.threshold,
+                    score=per_invocation_result.score,
+                    eval_status=per_invocation_result.eval_status,
+                )
             )
-          else:
-            logger.warning("`%s` is not supported.", eval_metric.metric_name)
-            eval_metric_results.append((
-                eval_metric,
-                EvalMetricResult(eval_status=EvalStatus.NOT_EVALUATED),
-            ))
-
-          eval_metric_results.append((
-              eval_metric,
-              eval_metric_result,
-          ))
-          _print_eval_metric_result(eval_metric, eval_metric_result)

         final_eval_status = EvalStatus.NOT_EVALUATED
-
         # Go over all the eval statuses and mark the final eval status as
         # passed if all of them pass, otherwise mark the final eval status to
         # failed.
-        for eval_metric_result in eval_metric_results:
-          eval_status = eval_metric_result[1].eval_status
-          if eval_status == EvalStatus.PASSED:
+        for overall_eval_metric_result in overall_eval_metric_results:
+          overall_eval_status = overall_eval_metric_result.eval_status
+          if overall_eval_status == EvalStatus.PASSED:
             final_eval_status = EvalStatus.PASSED
-          elif eval_status == EvalStatus.NOT_EVALUATED:
+          elif overall_eval_status == EvalStatus.NOT_EVALUATED:
             continue
-          elif eval_status == EvalStatus.FAILED:
+          elif overall_eval_status == EvalStatus.FAILED:
             final_eval_status = EvalStatus.FAILED
             break
           else:
             raise ValueError("Unknown eval status.")

-        yield EvalResult(
-            eval_set_file=eval_set_file,
+        yield EvalCaseResult(
+            eval_set_file=eval_set_id,
+            eval_set_id=eval_set_id,
             eval_id=eval_name,
             final_eval_status=final_eval_status,
-            eval_metric_results=eval_metric_results,
+            eval_metric_results=[],
+            overall_eval_metric_results=overall_eval_metric_results,
+            eval_metric_result_per_invocation=eval_metric_result_per_invocation,
             session_id=session_id,
+            user_id=user_id,
         )

         if final_eval_status == EvalStatus.PASSED:
@@ -262,21 +332,26 @@ def run_evals(
262
332
 
263
333
  print(f"Result: {result}\n")
264
334
 
265
- except Exception as e:
266
- print(f"Error: {e}")
267
- logger.info("Error: %s", str(traceback.format_exc()))
268
-
269
-
270
- def _get_eval_metric_result(eval_metric, score):
271
- eval_status = (
272
- EvalStatus.PASSED if score >= eval_metric.threshold else EvalStatus.FAILED
273
- )
274
- return EvalMetricResult(score=score, eval_status=eval_status)
335
+ except Exception:
336
+ # Catching the general exception, so that we don't block other eval
337
+ # cases.
338
+ logger.exception(f"Eval failed for `{eval_set_id}:{eval_name}`")
275
339
 
276
340
 
277
- def _print_eval_metric_result(eval_metric, eval_metric_result):
278
- print(
279
- f"Metric: {eval_metric.metric_name}\tStatus:"
280
- f" {eval_metric_result.eval_status}\tScore:"
281
- f" {eval_metric_result.score}\tThreshold: {eval_metric.threshold}"
282
- )
341
+ def _get_evaluator(eval_metric: EvalMetric) -> Evaluator:
342
+ try:
343
+ from ..evaluation.response_evaluator import ResponseEvaluator
344
+ from ..evaluation.trajectory_evaluator import TrajectoryEvaluator
345
+ except ModuleNotFoundError as e:
346
+ raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e
347
+ if eval_metric.metric_name == TOOL_TRAJECTORY_SCORE_KEY:
348
+ return TrajectoryEvaluator(threshold=eval_metric.threshold)
349
+ elif (
350
+ eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY
351
+ or eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY
352
+ ):
353
+ return ResponseEvaluator(
354
+ threshold=eval_metric.threshold, metric_name=eval_metric.metric_name
355
+ )
356
+
357
+ raise ValueError(f"Unsupported eval metric: {eval_metric}")