google-adk 1.5.0__py3-none-any.whl → 1.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. google/adk/a2a/converters/event_converter.py +257 -36
  2. google/adk/a2a/converters/part_converter.py +93 -25
  3. google/adk/a2a/converters/request_converter.py +12 -32
  4. google/adk/a2a/converters/utils.py +22 -4
  5. google/adk/a2a/executor/__init__.py +13 -0
  6. google/adk/a2a/executor/a2a_agent_executor.py +260 -0
  7. google/adk/a2a/executor/task_result_aggregator.py +71 -0
  8. google/adk/a2a/logs/__init__.py +13 -0
  9. google/adk/a2a/logs/log_utils.py +349 -0
  10. google/adk/agents/base_agent.py +54 -0
  11. google/adk/agents/llm_agent.py +15 -0
  12. google/adk/agents/remote_a2a_agent.py +532 -0
  13. google/adk/artifacts/in_memory_artifact_service.py +6 -3
  14. google/adk/cli/browser/chunk-EQDQRRRY.js +1 -0
  15. google/adk/cli/browser/chunk-TXJFAAIW.js +2 -0
  16. google/adk/cli/browser/index.html +4 -3
  17. google/adk/cli/browser/main-RXDVX3K6.js +3914 -0
  18. google/adk/cli/browser/polyfills-FFHMD2TL.js +17 -0
  19. google/adk/cli/cli_deploy.py +4 -1
  20. google/adk/cli/cli_eval.py +8 -6
  21. google/adk/cli/cli_tools_click.py +30 -10
  22. google/adk/cli/fast_api.py +120 -5
  23. google/adk/cli/utils/agent_loader.py +12 -0
  24. google/adk/evaluation/agent_evaluator.py +107 -10
  25. google/adk/evaluation/base_eval_service.py +157 -0
  26. google/adk/evaluation/constants.py +20 -0
  27. google/adk/evaluation/eval_case.py +3 -3
  28. google/adk/evaluation/eval_metrics.py +39 -0
  29. google/adk/evaluation/evaluation_generator.py +1 -1
  30. google/adk/evaluation/final_response_match_v2.py +230 -0
  31. google/adk/evaluation/llm_as_judge.py +141 -0
  32. google/adk/evaluation/llm_as_judge_utils.py +48 -0
  33. google/adk/evaluation/metric_evaluator_registry.py +89 -0
  34. google/adk/evaluation/response_evaluator.py +38 -211
  35. google/adk/evaluation/safety_evaluator.py +54 -0
  36. google/adk/evaluation/trajectory_evaluator.py +16 -2
  37. google/adk/evaluation/vertex_ai_eval_facade.py +147 -0
  38. google/adk/events/event.py +2 -4
  39. google/adk/flows/llm_flows/base_llm_flow.py +2 -0
  40. google/adk/memory/in_memory_memory_service.py +3 -2
  41. google/adk/models/lite_llm.py +50 -10
  42. google/adk/runners.py +27 -10
  43. google/adk/sessions/database_session_service.py +25 -7
  44. google/adk/sessions/in_memory_session_service.py +5 -1
  45. google/adk/sessions/vertex_ai_session_service.py +67 -42
  46. google/adk/tools/bigquery/config.py +11 -1
  47. google/adk/tools/bigquery/query_tool.py +306 -12
  48. google/adk/tools/enterprise_search_tool.py +2 -2
  49. google/adk/tools/function_tool.py +7 -1
  50. google/adk/tools/google_search_tool.py +1 -1
  51. google/adk/tools/mcp_tool/mcp_session_manager.py +44 -30
  52. google/adk/tools/mcp_tool/mcp_tool.py +44 -7
  53. google/adk/version.py +1 -1
  54. {google_adk-1.5.0.dist-info → google_adk-1.6.1.dist-info}/METADATA +6 -4
  55. {google_adk-1.5.0.dist-info → google_adk-1.6.1.dist-info}/RECORD +58 -42
  56. google/adk/cli/browser/main-JAAWEV7F.js +0 -92
  57. google/adk/cli/browser/polyfills-B6TNHZQ6.js +0 -17
  58. {google_adk-1.5.0.dist-info → google_adk-1.6.1.dist-info}/WHEEL +0 -0
  59. {google_adk-1.5.0.dist-info → google_adk-1.6.1.dist-info}/entry_points.txt +0 -0
  60. {google_adk-1.5.0.dist-info → google_adk-1.6.1.dist-info}/licenses/LICENSE +0 -0

google/adk/evaluation/response_evaluator.py
@@ -14,33 +14,55 @@
 
 from __future__ import annotations
 
-from typing import Any
 from typing import Optional
 
-from google.genai import types as genai_types
-import pandas as pd
-from tabulate import tabulate
-from typing_extensions import deprecated
 from typing_extensions import override
-from vertexai.preview.evaluation import EvalTask
-from vertexai.preview.evaluation import MetricPromptTemplateExamples
+from vertexai import types as vertexai_types
 
-from .eval_case import IntermediateData
 from .eval_case import Invocation
 from .eval_metrics import EvalMetric
-from .evaluator import EvalStatus
 from .evaluator import EvaluationResult
 from .evaluator import Evaluator
-from .evaluator import PerInvocationResult
 from .final_response_match_v1 import RougeEvaluator
+from .vertex_ai_eval_facade import _VertexAiEvalFacade
 
 
 class ResponseEvaluator(Evaluator):
-  """Runs response evaluation for agents."""
+  """Evaluates Agent's responses.
+
+  This class supports two metrics:
+  1) response_evaluation_score
+  This metric evaluates how coherent agent's resposne was.
+
+  Value range of this metric is [1,5], with values closer to 5 more desirable.
+
+  2) response_match_score:
+  This metric evaluates if agent's final response matches a golden/expected
+  final response.
+
+  Value range for this metric is [0,1], with values closer to 1 more desirable.
+  """
+
+  def __init__(
+      self,
+      threshold: Optional[float] = None,
+      metric_name: Optional[str] = None,
+      eval_metric: Optional[EvalMetric] = None,
+  ):
+    if (threshold is not None and eval_metric) or (
+        metric_name is not None and eval_metric
+    ):
+      raise ValueError(
+          "Either eval_metric should be specified or both threshold and"
+          " metric_name should be specified."
+      )
+
+    if eval_metric:
+      threshold = eval_metric.threshold
+      metric_name = eval_metric.metric_name
 
-  def __init__(self, threshold: float, metric_name: str):
     if "response_evaluation_score" == metric_name:
-      self._metric_name = MetricPromptTemplateExamples.Pointwise.COHERENCE
+      self._metric_name = vertexai_types.PrebuiltMetric.COHERENCE
     elif "response_match_score" == metric_name:
       self._metric_name = "response_match_score"
     else:
@@ -63,201 +85,6 @@ class ResponseEvaluator(Evaluator):
           actual_invocations, expected_invocations
       )
 
-    total_score = 0.0
-    num_invocations = 0
-    per_invocation_results = []
-    for actual, expected in zip(actual_invocations, expected_invocations):
-      prompt = self._get_text(expected.user_content)
-      reference = self._get_text(expected.final_response)
-      response = self._get_text(actual.final_response)
-      actual_tool_use = self._get_tool_use_trajectory(actual.intermediate_data)
-      reference_trajectory = self._get_tool_use_trajectory(
-          expected.intermediate_data
-      )
-
-      eval_case = {
-          "prompt": prompt,
-          "reference": reference,
-          "response": response,
-          "actual_tool_user": actual_tool_use,
-          "reference_trajectory": reference_trajectory,
-      }
-
-      eval_case_result = ResponseEvaluator._perform_eval(
-          pd.DataFrame([eval_case]), [self._metric_name]
-      )
-      score = self._get_score(eval_case_result)
-      per_invocation_results.append(
-          PerInvocationResult(
-              actual_invocation=actual,
-              expected_invocation=expected,
-              score=score,
-              eval_status=self._get_eval_status(score),
-          )
-      )
-      total_score += score
-      num_invocations += 1
-
-    if per_invocation_results:
-      overall_score = total_score / num_invocations
-      return EvaluationResult(
-          overall_score=overall_score,
-          overall_eval_status=self._get_eval_status(overall_score),
-          per_invocation_results=per_invocation_results,
-      )
-
-    return EvaluationResult()
-
-  def _get_text(self, content: Optional[genai_types.Content]) -> str:
-    if content and content.parts:
-      return "\n".join([p.text for p in content.parts if p.text])
-
-    return ""
-
-  def _get_tool_use_trajectory(
-      self, intermediate_data: Optional[IntermediateData]
-  ) -> list[dict[str, Any]]:
-    tool_use_trajectory = []
-    if not intermediate_data:
-      return tool_use_trajectory
-
-    for function_call in intermediate_data.tool_uses:
-      tool_use_trajectory.append({
-          "tool_name": function_call.name,
-          "tool_input": function_call.args or {},
-      })
-
-    return tool_use_trajectory
-
-  def _get_score(self, eval_result) -> float:
-    return eval_result.summary_metrics[f"{self._metric_name}/mean"].item()
-
-  def _get_eval_status(self, score: float):
-    return EvalStatus.PASSED if score >= self._threshold else EvalStatus.FAILED
-
-  @staticmethod
-  @deprecated(
-      "This method has been deprecated and will be removed soon. Please use"
-      " evaluate_invocations instead."
-  )
-  def evaluate(
-      raw_eval_dataset: list[list[dict[str, Any]]],
-      evaluation_criteria: list[str],
-      *,
-      print_detailed_results: bool = False,
-  ):
-    r"""Returns the value of requested evaluation metrics.
-
-    Args:
-      raw_eval_dataset: The dataset that will be evaluated.
-      evaluation_criteria: The evaluation criteria to be used. This method
-        support two criteria, `response_evaluation_score` and
-        `response_match_score`.
-      print_detailed_results: Prints detailed results on the console. This is
-        usually helpful during debugging.
-
-    A note on evaluation_criteria:
-      `response_match_score`: This metric compares the agents final natural
-        language response with the expected final response, stored in the
-        "reference" field in test/eval files. We use Rouge metric to compare the
-        two responses.
-
-        Value Range: [0, 1]. A score closer to 0 means poor similarity between
-        response and reference. A score closer to 1 means strong similarity
-        between response and reference.
-
-      `response_evaluation_score`: Uses LLM to evalaute coherence of the
-        response, including tool use. This is pointwise metric.
-
-        Value range: [0, 5], where 0 means that the agent's response is not
-        coherent, while 5 means it is . High values are good.
-    A note on raw_eval_dataset:
-      The dataset should be a list session, where each session is represented
-        as a list of interaction that need evaluation. Each evaluation is
-        represented as a dictionary that is expected to have values for the
-        following keys:
-
-          1) query
-          2) response
-          3) acutal_tool_use
-          4) expected_tool_use
-          5) reference
-
-        Here is a sample eval_dataset value with one entry:
-        [
-          [
-            {
-              "query": "roll a die for me",
-              "response": "I rolled a 16 sided die and got 13.\n",
-              "expected_tool_use": [
-                {
-                  "tool_name": "roll_die",
-                  "tool_input": {
-                    "sides": 16
-                  }
-                }
-              ],
-              "acutal_tool_use": [
-                {
-                  "tool_name": "roll_die",
-                  "tool_input": {
-                    "sides": 16
-                  }
-                }
-              ],
-              "reference": "I rolled a 16 sided die and got 13.\n"
-            }
-          ]
-        ]
-    """
-    if not raw_eval_dataset:
-      raise ValueError("The evaluation dataset is empty.")
-
-    metrics = ResponseEvaluator._get_metrics(
-        raw_eval_dataset, evaluation_criteria
-    )
-    flattened_queries = [
-        item for sublist in raw_eval_dataset for item in sublist
-    ]
-    eval_dataset = pd.DataFrame(flattened_queries).rename(
-        columns={"query": "prompt", "expected_tool_use": "reference_trajectory"}
-    )
-
-    eval_result = ResponseEvaluator._perform_eval(
-        dataset=eval_dataset, metrics=metrics
-    )
-
-    if print_detailed_results:
-      ResponseEvaluator._print_results(eval_result)
-    return eval_result.summary_metrics
-
-  @staticmethod
-  def _get_metrics(raw_eval_dataset, criteria):
-    metrics = []
-    if (
-        "response_evaluation_score" in criteria
-        and "query" in raw_eval_dataset[0][0]
-        and "expected_tool_use" in raw_eval_dataset[0][0]
-    ):
-      metrics.append(MetricPromptTemplateExamples.Pointwise.COHERENCE)
-    if (
-        "response_match_score" in criteria
-        and "reference" in raw_eval_dataset[0][0]
-    ):
-      metrics.append("rouge_1")
-    return metrics
-
-  @staticmethod
-  def _perform_eval(dataset, metrics):
-    """This method hides away the call to external service.
-
-    Primarily helps with unit testing.
-    """
-    eval_task = EvalTask(dataset=dataset, metrics=metrics)
-
-    return eval_task.evaluate()
-
-  @staticmethod
-  def _print_results(eval_result):
-    print("Evaluation Summary Metrics:", eval_result.summary_metrics)
-    print(tabulate(eval_result.metrics_table, headers="keys", tablefmt="grid"))
+    return _VertexAiEvalFacade(
+        threshold=self._threshold, metric_name=self._metric_name
+    ).evaluate_invocations(actual_invocations, expected_invocations)
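
Usage note: the new constructor is backward compatible. Callers may keep passing threshold and metric_name, or pass a single EvalMetric, but not both. A minimal sketch (not part of the diff; the EvalMetric field names are assumed from google/adk/evaluation/eval_metrics.py):

from google.adk.evaluation.eval_metrics import EvalMetric
from google.adk.evaluation.response_evaluator import ResponseEvaluator

# Existing style: explicit threshold and metric name.
rouge_eval = ResponseEvaluator(threshold=0.8, metric_name="response_match_score")

# New style: wrap both values in a single EvalMetric.
coherence_eval = ResponseEvaluator(
    eval_metric=EvalMetric(metric_name="response_evaluation_score", threshold=4.0)
)

# Supplying eval_metric together with threshold or metric_name raises ValueError.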

google/adk/evaluation/safety_evaluator.py (new file)
@@ -0,0 +1,54 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing_extensions import override
+from vertexai import types as vertexai_types
+
+from .eval_case import Invocation
+from .eval_metrics import EvalMetric
+from .evaluator import EvaluationResult
+from .evaluator import Evaluator
+from .vertex_ai_eval_facade import _VertexAiEvalFacade
+
+
+class SafetyEvaluatorV1(Evaluator):
+  """Evaluates safety (harmlessness) of an Agent's Response.
+
+  The class delegates the responsibility to Vertex Gen AI Eval SDK. The V1
+  suffix in the class name is added to convey that there could be other versions
+  of the safety metric as well, and those metrics could use a different strategy
+  to evaluate safety.
+
+  Using this class requires a GCP project. Please set GOOGLE_CLOUD_PROJECT and
+  GOOGLE_CLOUD_LOCATION in your .env file.
+
+  Value range of the metric is [0, 1], with values closer to 1 to be more
+  desirable (safe).
+  """
+
+  def __init__(self, eval_metric: EvalMetric):
+    self._eval_metric = eval_metric
+
+  @override
+  def evaluate_invocations(
+      self,
+      actual_invocations: list[Invocation],
+      expected_invocations: list[Invocation],
+  ) -> EvaluationResult:
+    return _VertexAiEvalFacade(
+        threshold=self._eval_metric.threshold,
+        metric_name=vertexai_types.PrebuiltMetric.SAFETY,
+    ).evaluate_invocations(actual_invocations, expected_invocations)
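
Usage note: SafetyEvaluatorV1 only works when the Vertex Gen AI Eval SDK can reach a Google Cloud project. A hedged sketch of the required wiring (the project, location, and metric_name values are placeholders, not values taken from this diff):

import os

from google.adk.evaluation.eval_metrics import EvalMetric
from google.adk.evaluation.safety_evaluator import SafetyEvaluatorV1

# The Vertex-backed metric needs a project and location; a .env file works too.
os.environ["GOOGLE_CLOUD_PROJECT"] = "my-project"      # placeholder
os.environ["GOOGLE_CLOUD_LOCATION"] = "us-central1"    # placeholder

safety = SafetyEvaluatorV1(
    eval_metric=EvalMetric(metric_name="safety_v1", threshold=0.8)
)
# result = safety.evaluate_invocations(actual_invocations, expected_invocations)
# result.overall_eval_status is PASSED when the mean safety score is >= 0.8.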

google/adk/evaluation/trajectory_evaluator.py
@@ -15,7 +15,7 @@
 from __future__ import annotations
 
 from typing import Any
-from typing import cast
+from typing import Optional
 
 from google.genai import types as genai_types
 import pandas as pd
@@ -24,6 +24,7 @@ from typing_extensions import deprecated
 from typing_extensions import override
 
 from .eval_case import Invocation
+from .eval_metrics import EvalMetric
 from .evaluation_constants import EvalConstants
 from .evaluator import EvalStatus
 from .evaluator import EvaluationResult
@@ -34,7 +35,20 @@ from .evaluator import PerInvocationResult
 class TrajectoryEvaluator(Evaluator):
   """Evaluates tool use trajectories for accuracy."""
 
-  def __init__(self, threshold: float):
+  def __init__(
+      self,
+      threshold: Optional[float] = None,
+      eval_metric: Optional[EvalMetric] = None,
+  ):
+    if threshold is not None and eval_metric:
+      raise ValueError(
+          "Either eval_metric should be specified or threshold should be"
+          " specified."
+      )
+
+    if eval_metric:
+      threshold = eval_metric.threshold
+
     self._threshold = threshold
 
   @override

google/adk/evaluation/vertex_ai_eval_facade.py (new file)
@@ -0,0 +1,147 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import os
+from typing import Optional
+
+from google.genai import types as genai_types
+import pandas as pd
+from typing_extensions import override
+from vertexai import Client as VertexAiClient
+from vertexai import types as vertexai_types
+
+from .eval_case import Invocation
+from .evaluator import EvalStatus
+from .evaluator import EvaluationResult
+from .evaluator import Evaluator
+from .evaluator import PerInvocationResult
+
+_ERROR_MESSAGE_SUFFIX = """
+You should specify both project id and location. This metric uses Vertex Gen AI
+Eval SDK, and it requires google cloud credentials.
+
+If using an .env file add the values there, or explicitly set in the code using
+the template below:
+
+os.environ['GOOGLE_CLOUD_LOCATION'] = <LOCATION>
+os.environ['GOOGLE_CLOUD_PROJECT'] = <PROJECT ID>
+"""
+
+
+class _VertexAiEvalFacade(Evaluator):
+  """Simple facade for Vertex Gen AI Eval SDK.
+
+  Vertex Gen AI Eval SDK exposes quite a few metrics that are valuable for
+  agentic evals. This class helps us to access those metrics.
+
+  Using this class requires a GCP project. Please set GOOGLE_CLOUD_PROJECT and
+  GOOGLE_CLOUD_LOCATION in your .env file.
+  """
+
+  def __init__(
+      self, threshold: float, metric_name: vertexai_types.PrebuiltMetric
+  ):
+    self._threshold = threshold
+    self._metric_name = metric_name
+
+  @override
+  def evaluate_invocations(
+      self,
+      actual_invocations: list[Invocation],
+      expected_invocations: list[Invocation],
+  ) -> EvaluationResult:
+    total_score = 0.0
+    num_invocations = 0
+    per_invocation_results = []
+    for actual, expected in zip(actual_invocations, expected_invocations):
+      prompt = self._get_text(expected.user_content)
+      reference = self._get_text(expected.final_response)
+      response = self._get_text(actual.final_response)
+      eval_case = {
+          "prompt": prompt,
+          "reference": reference,
+          "response": response,
+      }
+
+      eval_case_result = _VertexAiEvalFacade._perform_eval(
+          dataset=pd.DataFrame([eval_case]), metrics=[self._metric_name]
+      )
+      score = self._get_score(eval_case_result)
+      per_invocation_results.append(
+          PerInvocationResult(
+              actual_invocation=actual,
+              expected_invocation=expected,
+              score=score,
+              eval_status=self._get_eval_status(score),
+          )
+      )
+
+      if score:
+        total_score += score
+        num_invocations += 1
+
+    if per_invocation_results:
+      overall_score = (
+          total_score / num_invocations if num_invocations > 0 else None
+      )
+      return EvaluationResult(
+          overall_score=overall_score,
+          overall_eval_status=self._get_eval_status(overall_score),
+          per_invocation_results=per_invocation_results,
+      )
+
+    return EvaluationResult()
+
+  def _get_text(self, content: Optional[genai_types.Content]) -> str:
+    if content and content.parts:
+      return "\n".join([p.text for p in content.parts if p.text])
+
+    return ""
+
+  def _get_score(self, eval_result) -> Optional[float]:
+    if eval_result and eval_result.summary_metrics:
+      return eval_result.summary_metrics[0].mean_score
+
+    return None
+
+  def _get_eval_status(self, score: Optional[float]):
+    if score:
+      return (
+          EvalStatus.PASSED if score >= self._threshold else EvalStatus.FAILED
+      )
+
+    return EvalStatus.NOT_EVALUATED
+
+  @staticmethod
+  def _perform_eval(dataset, metrics):
+    """This method hides away the call to external service.
+
+    Primarily helps with unit testing.
+    """
+    project_id = os.environ.get("GOOGLE_CLOUD_PROJECT", None)
+    location = os.environ.get("GOOGLE_CLOUD_LOCATION", None)
+
+    if not project_id:
+      raise ValueError("Missing project id." + _ERROR_MESSAGE_SUFFIX)
+    if not location:
+      raise ValueError("Missing location." + _ERROR_MESSAGE_SUFFIX)
+
+    client = VertexAiClient(project=project_id, location=location)
+
+    return client.evals.evaluate(
+        dataset=vertexai_types.EvaluationDataset(eval_dataset_df=dataset),
+        metrics=metrics,
+    )
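
Behavioral note on the facade above: per-invocation scores are aggregated behind an "if score:" check, so a score of None, and also a score of exactly 0.0, is excluded from the mean and on its own yields NOT_EVALUATED rather than FAILED. A standalone restatement of that aggregation logic with no Vertex calls (plain strings stand in for the EvalStatus enum imported above):

from typing import Optional

def _aggregate(scores: list[Optional[float]], threshold: float):
  # Mirrors _VertexAiEvalFacade: falsy scores (None or 0.0) are skipped.
  valid = [s for s in scores if s]
  overall = sum(valid) / len(valid) if valid else None
  if overall:
    return overall, "PASSED" if overall >= threshold else "FAILED"
  return overall, "NOT_EVALUATED"

print(_aggregate([0.9, None, 0.9], threshold=0.8))  # (0.9, 'PASSED')
print(_aggregate([0.0, None], threshold=0.5))       # (None, 'NOT_EVALUATED')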

google/adk/events/event.py
@@ -14,9 +14,8 @@
 from __future__ import annotations
 
 from datetime import datetime
-import random
-import string
 from typing import Optional
+import uuid
 
 from google.genai import types
 from pydantic import alias_generators
@@ -132,5 +131,4 @@ class Event(LlmResponse):
 
   @staticmethod
   def new_id():
-    characters = string.ascii_letters + string.digits
-    return ''.join(random.choice(characters) for _ in range(8))
+    return str(uuid.uuid4())
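
Event.new_id() now returns an RFC 4122 version 4 UUID instead of an 8-character alphanumeric string, which greatly reduces the chance of event id collisions. An illustrative comparison of the two formats (example outputs are made up):

import random
import string
import uuid

old_style = "".join(
    random.choice(string.ascii_letters + string.digits) for _ in range(8)
)
new_style = str(uuid.uuid4())

print(old_style)  # e.g. 'aZ3kQ9bX' (62^8 values, roughly 47 bits of randomness)
print(new_style)  # e.g. '0f8fad5b-d9cb-469f-a165-70867728950e' (122 random bits)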

google/adk/flows/llm_flows/base_llm_flow.py
@@ -16,6 +16,7 @@ from __future__ import annotations
 
 from abc import ABC
 import asyncio
+import datetime
 import inspect
 import logging
 from typing import AsyncGenerator
@@ -320,6 +321,7 @@ class BaseLlmFlow(ABC):
       ):
         # Update the mutable event id to avoid conflict
         model_response_event.id = Event.new_id()
+        model_response_event.timestamp = datetime.datetime.now().timestamp()
         yield event
 
   async def _preprocess_async(

google/adk/memory/in_memory_memory_service.py
@@ -11,8 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
 from __future__ import annotations
 
 import re
@@ -43,6 +41,9 @@ class InMemoryMemoryService(BaseMemoryService):
   """An in-memory memory service for prototyping purpose only.
 
   Uses keyword matching instead of semantic search.
+
+  It is not suitable for multi-threaded production environments. Use it for
+  testing and development only.
   """
 
   def __init__(self):

google/adk/models/lite_llm.py
@@ -23,6 +23,7 @@ from typing import cast
 from typing import Dict
 from typing import Generator
 from typing import Iterable
+from typing import List
 from typing import Literal
 from typing import Optional
 from typing import Tuple
@@ -485,16 +486,22 @@ def _message_to_generate_content_response(
 
 def _get_completion_inputs(
     llm_request: LlmRequest,
-) -> tuple[Iterable[Message], Iterable[dict]]:
-  """Converts an LlmRequest to litellm inputs.
+) -> Tuple[
+    List[Message],
+    Optional[List[Dict]],
+    Optional[types.SchemaUnion],
+    Optional[Dict],
+]:
+  """Converts an LlmRequest to litellm inputs and extracts generation params.
 
   Args:
     llm_request: The LlmRequest to convert.
 
   Returns:
-    The litellm inputs (message list, tool dictionary and response format).
+    The litellm inputs (message list, tool dictionary, response format and generation params).
   """
-  messages = []
+  # 1. Construct messages
+  messages: List[Message] = []
  for content in llm_request.contents or []:
    message_param_or_list = _content_to_message_param(content)
    if isinstance(message_param_or_list, list):
@@ -511,7 +518,8 @@
         ),
     )
 
-  tools = None
+  # 2. Convert tool declarations
+  tools: Optional[List[Dict]] = None
   if (
       llm_request.config
       and llm_request.config.tools
@@ -522,12 +530,39 @@
         for tool in llm_request.config.tools[0].function_declarations
     ]
 
-  response_format = None
-
-  if llm_request.config.response_schema:
+  # 3. Handle response format
+  response_format: Optional[types.SchemaUnion] = None
+  if llm_request.config and llm_request.config.response_schema:
     response_format = llm_request.config.response_schema
 
-  return messages, tools, response_format
+  # 4. Extract generation parameters
+  generation_params: Optional[Dict] = None
+  if llm_request.config:
+    config_dict = llm_request.config.model_dump(exclude_none=True)
+    # Generate LiteLlm parameters here,
+    # Following https://docs.litellm.ai/docs/completion/input.
+    generation_params = {}
+    param_mapping = {
+        "max_output_tokens": "max_completion_tokens",
+        "stop_sequences": "stop",
+    }
+    for key in (
+        "temperature",
+        "max_output_tokens",
+        "top_p",
+        "top_k",
+        "stop_sequences",
+        "presence_penalty",
+        "frequency_penalty",
+    ):
+      if key in config_dict:
+        mapped_key = param_mapping.get(key, key)
+        generation_params[mapped_key] = config_dict[key]
+
+    if not generation_params:
+      generation_params = None
+
+  return messages, tools, response_format, generation_params
 
 
 def _build_function_declaration_log(
@@ -664,7 +699,9 @@ class LiteLlm(BaseLlm):
     self._maybe_append_user_content(llm_request)
     logger.debug(_build_request_log(llm_request))
 
-    messages, tools, response_format = _get_completion_inputs(llm_request)
+    messages, tools, response_format, generation_params = (
+        _get_completion_inputs(llm_request)
+    )
 
     if "functions" in self._additional_args:
       # LiteLLM does not support both tools and functions together.
@@ -678,6 +715,9 @@ class LiteLlm(BaseLlm):
     }
     completion_args.update(self._additional_args)
 
+    if generation_params:
+      completion_args.update(generation_params)
+
     if stream:
       text = ""
       # Track function calls by index
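
The net effect of the lite_llm.py changes is that generation settings from LlmRequest.config now reach litellm.completion() instead of being dropped. A sketch of the mapping in isolation, reusing only the param_mapping shown above (GenerateContentConfig is the google-genai type imported as "types" in this module; the values are arbitrary):

from google.genai import types

config = types.GenerateContentConfig(
    temperature=0.2,
    max_output_tokens=1024,
    top_p=0.95,
    stop_sequences=["END"],
)

param_mapping = {
    "max_output_tokens": "max_completion_tokens",
    "stop_sequences": "stop",
}
config_dict = config.model_dump(exclude_none=True)
generation_params = {
    param_mapping.get(key, key): config_dict[key]
    for key in (
        "temperature",
        "max_output_tokens",
        "top_p",
        "top_k",
        "stop_sequences",
        "presence_penalty",
        "frequency_penalty",
    )
    if key in config_dict
}
print(generation_params)
# {'temperature': 0.2, 'max_completion_tokens': 1024, 'top_p': 0.95, 'stop': ['END']}

Because completion_args.update(generation_params) runs after completion_args.update(self._additional_args), values derived from the request config take precedence over overlapping keys passed via additional_args.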