google-adk 1.1.1__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/adk/agents/base_agent.py +0 -2
- google/adk/agents/invocation_context.py +3 -3
- google/adk/agents/parallel_agent.py +17 -7
- google/adk/agents/sequential_agent.py +8 -8
- google/adk/auth/auth_preprocessor.py +18 -17
- google/adk/cli/agent_graph.py +165 -23
- google/adk/cli/browser/assets/ADK-512-color.svg +9 -0
- google/adk/cli/browser/index.html +2 -2
- google/adk/cli/browser/{main-PKDNKWJE.js → main-CS5OLUMF.js} +59 -59
- google/adk/cli/browser/polyfills-FFHMD2TL.js +17 -0
- google/adk/cli/cli.py +9 -9
- google/adk/cli/cli_deploy.py +157 -0
- google/adk/cli/cli_tools_click.py +228 -99
- google/adk/cli/fast_api.py +119 -34
- google/adk/cli/utils/agent_loader.py +60 -44
- google/adk/cli/utils/envs.py +1 -1
- google/adk/code_executors/unsafe_local_code_executor.py +11 -0
- google/adk/errors/__init__.py +13 -0
- google/adk/errors/not_found_error.py +28 -0
- google/adk/evaluation/agent_evaluator.py +1 -1
- google/adk/evaluation/eval_sets_manager.py +36 -6
- google/adk/evaluation/evaluation_generator.py +5 -4
- google/adk/evaluation/local_eval_sets_manager.py +101 -6
- google/adk/flows/llm_flows/agent_transfer.py +2 -2
- google/adk/flows/llm_flows/base_llm_flow.py +19 -0
- google/adk/flows/llm_flows/contents.py +4 -4
- google/adk/flows/llm_flows/functions.py +140 -127
- google/adk/memory/vertex_ai_rag_memory_service.py +2 -2
- google/adk/models/anthropic_llm.py +7 -10
- google/adk/models/google_llm.py +46 -18
- google/adk/models/lite_llm.py +63 -26
- google/adk/py.typed +0 -0
- google/adk/sessions/_session_util.py +10 -16
- google/adk/sessions/database_session_service.py +81 -66
- google/adk/sessions/vertex_ai_session_service.py +32 -6
- google/adk/telemetry.py +91 -24
- google/adk/tools/_automatic_function_calling_util.py +31 -25
- google/adk/tools/{function_parameter_parse_util.py → _function_parameter_parse_util.py} +9 -3
- google/adk/tools/_gemini_schema_util.py +158 -0
- google/adk/tools/apihub_tool/apihub_toolset.py +3 -2
- google/adk/tools/application_integration_tool/clients/connections_client.py +7 -0
- google/adk/tools/application_integration_tool/integration_connector_tool.py +5 -7
- google/adk/tools/base_tool.py +4 -8
- google/adk/tools/bigquery/bigquery_credentials.py +7 -3
- google/adk/tools/function_tool.py +4 -4
- google/adk/tools/langchain_tool.py +20 -13
- google/adk/tools/load_memory_tool.py +1 -0
- google/adk/tools/mcp_tool/conversion_utils.py +4 -2
- google/adk/tools/mcp_tool/mcp_session_manager.py +63 -5
- google/adk/tools/mcp_tool/mcp_tool.py +3 -2
- google/adk/tools/mcp_tool/mcp_toolset.py +15 -8
- google/adk/tools/openapi_tool/common/common.py +4 -43
- google/adk/tools/openapi_tool/openapi_spec_parser/__init__.py +0 -2
- google/adk/tools/openapi_tool/openapi_spec_parser/openapi_spec_parser.py +4 -2
- google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +4 -2
- google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +7 -127
- google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +2 -7
- google/adk/tools/transfer_to_agent_tool.py +8 -1
- google/adk/tools/vertex_ai_search_tool.py +8 -1
- google/adk/utils/variant_utils.py +51 -0
- google/adk/version.py +1 -1
- {google_adk-1.1.1.dist-info → google_adk-1.2.0.dist-info}/METADATA +7 -7
- {google_adk-1.1.1.dist-info → google_adk-1.2.0.dist-info}/RECORD +66 -60
- google/adk/cli/browser/polyfills-B6TNHZQ6.js +0 -17
- {google_adk-1.1.1.dist-info → google_adk-1.2.0.dist-info}/WHEEL +0 -0
- {google_adk-1.1.1.dist-info → google_adk-1.2.0.dist-info}/entry_points.txt +0 -0
- {google_adk-1.1.1.dist-info → google_adk-1.2.0.dist-info}/licenses/LICENSE +0 -0
@@ -12,10 +12,13 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
+
from __future__ import annotations
|
16
|
+
|
15
17
|
import asyncio
|
16
18
|
import collections
|
17
19
|
from contextlib import asynccontextmanager
|
18
20
|
from datetime import datetime
|
21
|
+
import functools
|
19
22
|
import logging
|
20
23
|
import os
|
21
24
|
import tempfile
|
@@ -58,6 +61,19 @@ class HelpfulCommand(click.Command):
|
|
58
61
|
def __init__(self, *args, **kwargs):
|
59
62
|
super().__init__(*args, **kwargs)
|
60
63
|
|
64
|
+
@staticmethod
|
65
|
+
def _format_missing_arg_error(click_exception):
|
66
|
+
"""Format the missing argument error with uppercase parameter name.
|
67
|
+
|
68
|
+
Args:
|
69
|
+
click_exception: The MissingParameter exception from Click.
|
70
|
+
|
71
|
+
Returns:
|
72
|
+
str: Formatted error message with uppercase parameter name.
|
73
|
+
"""
|
74
|
+
name = click_exception.param.name
|
75
|
+
return f"Missing required argument: {name.upper()}"
|
76
|
+
|
61
77
|
def parse_args(self, ctx, args):
|
62
78
|
"""Override the parse_args method to show help text on error.
|
63
79
|
|
@@ -75,8 +91,10 @@ class HelpfulCommand(click.Command):
|
|
75
91
|
try:
|
76
92
|
return super().parse_args(ctx, args)
|
77
93
|
except click.MissingParameter as exc:
|
94
|
+
error_message = self._format_missing_arg_error(exc)
|
95
|
+
|
78
96
|
click.echo(ctx.get_help())
|
79
|
-
click.secho(f"\nError: {
|
97
|
+
click.secho(f"\nError: {error_message}", fg="red", err=True)
|
80
98
|
ctx.exit(2)
|
81
99
|
|
82
100
|
|
@@ -84,6 +102,7 @@ logger = logging.getLogger("google_adk." + __name__)
|
|
84
102
|
|
85
103
|
|
86
104
|
@click.group(context_settings={"max_content_width": 240})
|
105
|
+
@click.version_option(version.__version__)
|
87
106
|
def main():
|
88
107
|
"""Agent Development Kit CLI tools."""
|
89
108
|
pass
|
@@ -398,57 +417,78 @@ def cli_eval(
|
|
398
417
|
print(eval_result.model_dump_json(indent=2))
|
399
418
|
|
400
419
|
|
401
|
-
|
402
|
-
|
403
|
-
"--session_db_url",
|
404
|
-
help=(
|
405
|
-
"""Optional. The database URL to store the session.
|
420
|
+
def fast_api_common_options():
|
421
|
+
"""Decorator to add common fast api options to click commands."""
|
406
422
|
|
407
|
-
|
423
|
+
def decorator(func):
|
424
|
+
@click.option(
|
425
|
+
"--session_db_url",
|
426
|
+
help=(
|
427
|
+
"""Optional. The database URL to store the session.
|
428
|
+
- Use 'agentengine://<agent_engine_resource_id>' to connect to Agent Engine sessions.
|
429
|
+
- Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.
|
430
|
+
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
|
431
|
+
),
|
432
|
+
)
|
433
|
+
@click.option(
|
434
|
+
"--artifact_storage_uri",
|
435
|
+
type=str,
|
436
|
+
help=(
|
437
|
+
"Optional. The artifact storage URI to store the artifacts,"
|
438
|
+
" supported URIs: gs://<bucket name> for GCS artifact service."
|
439
|
+
),
|
440
|
+
default=None,
|
441
|
+
)
|
442
|
+
@click.option(
|
443
|
+
"--host",
|
444
|
+
type=str,
|
445
|
+
help="Optional. The binding host of the server",
|
446
|
+
default="127.0.0.1",
|
447
|
+
show_default=True,
|
448
|
+
)
|
449
|
+
@click.option(
|
450
|
+
"--port",
|
451
|
+
type=int,
|
452
|
+
help="Optional. The port of the server",
|
453
|
+
default=8000,
|
454
|
+
)
|
455
|
+
@click.option(
|
456
|
+
"--allow_origins",
|
457
|
+
help="Optional. Any additional origins to allow for CORS.",
|
458
|
+
multiple=True,
|
459
|
+
)
|
460
|
+
@click.option(
|
461
|
+
"--log_level",
|
462
|
+
type=click.Choice(
|
463
|
+
["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
464
|
+
case_sensitive=False,
|
465
|
+
),
|
466
|
+
default="INFO",
|
467
|
+
help="Optional. Set the logging level",
|
468
|
+
)
|
469
|
+
@click.option(
|
470
|
+
"--trace_to_cloud",
|
471
|
+
is_flag=True,
|
472
|
+
show_default=True,
|
473
|
+
default=False,
|
474
|
+
help="Optional. Whether to enable cloud trace for telemetry.",
|
475
|
+
)
|
476
|
+
@click.option(
|
477
|
+
"--reload/--no-reload",
|
478
|
+
default=True,
|
479
|
+
help="Optional. Whether to enable auto reload for server.",
|
480
|
+
)
|
481
|
+
@functools.wraps(func)
|
482
|
+
def wrapper(*args, **kwargs):
|
483
|
+
return func(*args, **kwargs)
|
408
484
|
|
409
|
-
|
485
|
+
return wrapper
|
410
486
|
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
@
|
415
|
-
|
416
|
-
type=str,
|
417
|
-
help="Optional. The binding host of the server",
|
418
|
-
default="127.0.0.1",
|
419
|
-
show_default=True,
|
420
|
-
)
|
421
|
-
@click.option(
|
422
|
-
"--port",
|
423
|
-
type=int,
|
424
|
-
help="Optional. The port of the server",
|
425
|
-
default=8000,
|
426
|
-
)
|
427
|
-
@click.option(
|
428
|
-
"--allow_origins",
|
429
|
-
help="Optional. Any additional origins to allow for CORS.",
|
430
|
-
multiple=True,
|
431
|
-
)
|
432
|
-
@click.option(
|
433
|
-
"--log_level",
|
434
|
-
type=click.Choice(
|
435
|
-
["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
|
436
|
-
),
|
437
|
-
default="INFO",
|
438
|
-
help="Optional. Set the logging level",
|
439
|
-
)
|
440
|
-
@click.option(
|
441
|
-
"--trace_to_cloud",
|
442
|
-
is_flag=True,
|
443
|
-
show_default=True,
|
444
|
-
default=False,
|
445
|
-
help="Optional. Whether to enable cloud trace for telemetry.",
|
446
|
-
)
|
447
|
-
@click.option(
|
448
|
-
"--reload/--no-reload",
|
449
|
-
default=True,
|
450
|
-
help="Optional. Whether to enable auto reload for server.",
|
451
|
-
)
|
487
|
+
return decorator
|
488
|
+
|
489
|
+
|
490
|
+
@main.command("web")
|
491
|
+
@fast_api_common_options()
|
452
492
|
@click.argument(
|
453
493
|
"agents_dir",
|
454
494
|
type=click.Path(
|
@@ -459,6 +499,7 @@ def cli_eval(
|
|
459
499
|
def cli_web(
|
460
500
|
agents_dir: str,
|
461
501
|
session_db_url: str = "",
|
502
|
+
artifact_storage_uri: Optional[str] = None,
|
462
503
|
log_level: str = "INFO",
|
463
504
|
allow_origins: Optional[list[str]] = None,
|
464
505
|
host: str = "127.0.0.1",
|
@@ -502,6 +543,7 @@ def cli_web(
|
|
502
543
|
app = get_fast_api_app(
|
503
544
|
agents_dir=agents_dir,
|
504
545
|
session_db_url=session_db_url,
|
546
|
+
artifact_storage_uri=artifact_storage_uri,
|
505
547
|
allow_origins=allow_origins,
|
506
548
|
web=True,
|
507
549
|
trace_to_cloud=trace_to_cloud,
|
@@ -519,56 +561,6 @@ def cli_web(
|
|
519
561
|
|
520
562
|
|
521
563
|
@main.command("api_server")
|
522
|
-
@click.option(
|
523
|
-
"--session_db_url",
|
524
|
-
help=(
|
525
|
-
"""Optional. The database URL to store the session.
|
526
|
-
|
527
|
-
- Use 'agentengine://<agent_engine_resource_id>' to connect to Agent Engine sessions.
|
528
|
-
|
529
|
-
- Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.
|
530
|
-
|
531
|
-
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
|
532
|
-
),
|
533
|
-
)
|
534
|
-
@click.option(
|
535
|
-
"--host",
|
536
|
-
type=str,
|
537
|
-
help="Optional. The binding host of the server",
|
538
|
-
default="127.0.0.1",
|
539
|
-
show_default=True,
|
540
|
-
)
|
541
|
-
@click.option(
|
542
|
-
"--port",
|
543
|
-
type=int,
|
544
|
-
help="Optional. The port of the server",
|
545
|
-
default=8000,
|
546
|
-
)
|
547
|
-
@click.option(
|
548
|
-
"--allow_origins",
|
549
|
-
help="Optional. Any additional origins to allow for CORS.",
|
550
|
-
multiple=True,
|
551
|
-
)
|
552
|
-
@click.option(
|
553
|
-
"--log_level",
|
554
|
-
type=click.Choice(
|
555
|
-
["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
|
556
|
-
),
|
557
|
-
default="INFO",
|
558
|
-
help="Optional. Set the logging level",
|
559
|
-
)
|
560
|
-
@click.option(
|
561
|
-
"--trace_to_cloud",
|
562
|
-
is_flag=True,
|
563
|
-
show_default=True,
|
564
|
-
default=False,
|
565
|
-
help="Optional. Whether to enable cloud trace for telemetry.",
|
566
|
-
)
|
567
|
-
@click.option(
|
568
|
-
"--reload/--no-reload",
|
569
|
-
default=True,
|
570
|
-
help="Optional. Whether to enable auto reload for server.",
|
571
|
-
)
|
572
564
|
# The directory of agents, where each sub-directory is a single agent.
|
573
565
|
# By default, it is the current working directory
|
574
566
|
@click.argument(
|
@@ -578,9 +570,11 @@ def cli_web(
|
|
578
570
|
),
|
579
571
|
default=os.getcwd(),
|
580
572
|
)
|
573
|
+
@fast_api_common_options()
|
581
574
|
def cli_api_server(
|
582
575
|
agents_dir: str,
|
583
576
|
session_db_url: str = "",
|
577
|
+
artifact_storage_uri: Optional[str] = None,
|
584
578
|
log_level: str = "INFO",
|
585
579
|
allow_origins: Optional[list[str]] = None,
|
586
580
|
host: str = "127.0.0.1",
|
@@ -603,6 +597,7 @@ def cli_api_server(
|
|
603
597
|
get_fast_api_app(
|
604
598
|
agents_dir=agents_dir,
|
605
599
|
session_db_url=session_db_url,
|
600
|
+
artifact_storage_uri=artifact_storage_uri,
|
606
601
|
allow_origins=allow_origins,
|
607
602
|
web=False,
|
608
603
|
trace_to_cloud=trace_to_cloud,
|
@@ -706,6 +701,15 @@ def cli_api_server(
|
|
706
701
|
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
|
707
702
|
),
|
708
703
|
)
|
704
|
+
@click.option(
|
705
|
+
"--artifact_storage_uri",
|
706
|
+
type=str,
|
707
|
+
help=(
|
708
|
+
"Optional. The artifact storage URI to store the artifacts, supported"
|
709
|
+
" URIs: gs://<bucket name> for GCS artifact service."
|
710
|
+
),
|
711
|
+
default=None,
|
712
|
+
)
|
709
713
|
@click.argument(
|
710
714
|
"agent",
|
711
715
|
type=click.Path(
|
@@ -734,6 +738,7 @@ def cli_deploy_cloud_run(
|
|
734
738
|
with_ui: bool,
|
735
739
|
verbosity: str,
|
736
740
|
session_db_url: str,
|
741
|
+
artifact_storage_uri: Optional[str],
|
737
742
|
adk_version: str,
|
738
743
|
):
|
739
744
|
"""Deploys an agent to Cloud Run.
|
@@ -757,7 +762,131 @@ def cli_deploy_cloud_run(
|
|
757
762
|
with_ui=with_ui,
|
758
763
|
verbosity=verbosity,
|
759
764
|
session_db_url=session_db_url,
|
765
|
+
artifact_storage_uri=artifact_storage_uri,
|
760
766
|
adk_version=adk_version,
|
761
767
|
)
|
762
768
|
except Exception as e:
|
763
769
|
click.secho(f"Deploy failed: {e}", fg="red", err=True)
|
770
|
+
|
771
|
+
|
772
|
+
@deploy.command("agent_engine")
|
773
|
+
@click.option(
|
774
|
+
"--project",
|
775
|
+
type=str,
|
776
|
+
help="Required. Google Cloud project to deploy the agent.",
|
777
|
+
)
|
778
|
+
@click.option(
|
779
|
+
"--region",
|
780
|
+
type=str,
|
781
|
+
help="Required. Google Cloud region to deploy the agent.",
|
782
|
+
)
|
783
|
+
@click.option(
|
784
|
+
"--staging_bucket",
|
785
|
+
type=str,
|
786
|
+
help="Required. GCS bucket for staging the deployment artifacts.",
|
787
|
+
)
|
788
|
+
@click.option(
|
789
|
+
"--trace_to_cloud",
|
790
|
+
type=bool,
|
791
|
+
is_flag=True,
|
792
|
+
show_default=True,
|
793
|
+
default=False,
|
794
|
+
help="Optional. Whether to enable Cloud Trace for Agent Engine.",
|
795
|
+
)
|
796
|
+
@click.option(
|
797
|
+
"--adk_app",
|
798
|
+
type=str,
|
799
|
+
default="agent_engine_app",
|
800
|
+
help=(
|
801
|
+
"Optional. Python file for defining the ADK application"
|
802
|
+
" (default: a file named agent_engine_app.py)"
|
803
|
+
),
|
804
|
+
)
|
805
|
+
@click.option(
|
806
|
+
"--temp_folder",
|
807
|
+
type=str,
|
808
|
+
default=os.path.join(
|
809
|
+
tempfile.gettempdir(),
|
810
|
+
"agent_engine_deploy_src",
|
811
|
+
datetime.now().strftime("%Y%m%d_%H%M%S"),
|
812
|
+
),
|
813
|
+
help=(
|
814
|
+
"Optional. Temp folder for the generated Agent Engine source files."
|
815
|
+
" If the folder already exists, its contents will be removed."
|
816
|
+
" (default: a timestamped folder in the system temp directory)."
|
817
|
+
),
|
818
|
+
)
|
819
|
+
@click.option(
|
820
|
+
"--env_file",
|
821
|
+
type=str,
|
822
|
+
default="",
|
823
|
+
help=(
|
824
|
+
"Optional. The filepath to the `.env` file for environment variables."
|
825
|
+
" (default: the `.env` file in the `agent` directory, if any.)"
|
826
|
+
),
|
827
|
+
)
|
828
|
+
@click.option(
|
829
|
+
"--requirements_file",
|
830
|
+
type=str,
|
831
|
+
default="",
|
832
|
+
help=(
|
833
|
+
"Optional. The filepath to the `requirements.txt` file to use."
|
834
|
+
" (default: the `requirements.txt` file in the `agent` directory, if"
|
835
|
+
" any.)"
|
836
|
+
),
|
837
|
+
)
|
838
|
+
@click.argument(
|
839
|
+
"agent",
|
840
|
+
type=click.Path(
|
841
|
+
exists=True, dir_okay=True, file_okay=False, resolve_path=True
|
842
|
+
),
|
843
|
+
)
|
844
|
+
def cli_deploy_agent_engine(
|
845
|
+
agent: str,
|
846
|
+
project: str,
|
847
|
+
region: str,
|
848
|
+
staging_bucket: str,
|
849
|
+
trace_to_cloud: bool,
|
850
|
+
adk_app: str,
|
851
|
+
temp_folder: str,
|
852
|
+
env_file: str,
|
853
|
+
requirements_file: str,
|
854
|
+
):
|
855
|
+
"""Deploys an agent to Agent Engine.
|
856
|
+
|
857
|
+
Args:
|
858
|
+
agent (str): Required. The path to the agent to be deloyed.
|
859
|
+
project (str): Required. Google Cloud project to deploy the agent.
|
860
|
+
region (str): Required. Google Cloud region to deploy the agent.
|
861
|
+
staging_bucket (str): Required. GCS bucket for staging the deployment
|
862
|
+
artifacts.
|
863
|
+
trace_to_cloud (bool): Required. Whether to enable Cloud Trace.
|
864
|
+
adk_app (str): Required. Python file for defining the ADK application.
|
865
|
+
temp_folder (str): Required. The folder for the generated Agent Engine
|
866
|
+
files. If the folder already exists, its contents will be replaced.
|
867
|
+
env_file (str): Required. The filepath to the `.env` file for environment
|
868
|
+
variables. If it is an empty string, the `.env` file in the `agent`
|
869
|
+
directory will be used if it exists.
|
870
|
+
requirements_file (str): Required. The filepath to the `requirements.txt`
|
871
|
+
file to use. If it is an empty string, the `requirements.txt` file in the
|
872
|
+
`agent` directory will be used if exists.
|
873
|
+
|
874
|
+
Example:
|
875
|
+
|
876
|
+
adk deploy agent_engine --project=[project] --region=[region]
|
877
|
+
--staging_bucket=[staging_bucket] path/to/my_agent
|
878
|
+
"""
|
879
|
+
try:
|
880
|
+
cli_deploy.to_agent_engine(
|
881
|
+
agent_folder=agent,
|
882
|
+
project=project,
|
883
|
+
region=region,
|
884
|
+
staging_bucket=staging_bucket,
|
885
|
+
trace_to_cloud=trace_to_cloud,
|
886
|
+
adk_app=adk_app,
|
887
|
+
temp_folder=temp_folder,
|
888
|
+
env_file=env_file,
|
889
|
+
requirements_file=requirements_file,
|
890
|
+
)
|
891
|
+
except Exception as e:
|
892
|
+
click.secho(f"Deploy failed: {e}", fg="red", err=True)
|
google/adk/cli/fast_api.py
CHANGED
@@ -12,7 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
|
16
15
|
from __future__ import annotations
|
17
16
|
|
18
17
|
import asyncio
|
@@ -56,7 +55,9 @@ from ..agents.live_request_queue import LiveRequest
|
|
56
55
|
from ..agents.live_request_queue import LiveRequestQueue
|
57
56
|
from ..agents.llm_agent import Agent
|
58
57
|
from ..agents.run_config import StreamingMode
|
58
|
+
from ..artifacts.gcs_artifact_service import GcsArtifactService
|
59
59
|
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
|
60
|
+
from ..errors.not_found_error import NotFoundError
|
60
61
|
from ..evaluation.eval_case import EvalCase
|
61
62
|
from ..evaluation.eval_case import SessionInput
|
62
63
|
from ..evaluation.eval_metrics import EvalMetric
|
@@ -98,7 +99,7 @@ class ApiServerSpanExporter(export.SpanExporter):
|
|
98
99
|
if (
|
99
100
|
span.name == "call_llm"
|
100
101
|
or span.name == "send_data"
|
101
|
-
or span.name.startswith("
|
102
|
+
or span.name.startswith("execute_tool")
|
102
103
|
):
|
103
104
|
attributes = dict(span.attributes)
|
104
105
|
attributes["trace_id"] = span.get_span_context().trace_id
|
@@ -193,6 +194,7 @@ def get_fast_api_app(
|
|
193
194
|
*,
|
194
195
|
agents_dir: str,
|
195
196
|
session_db_url: str = "",
|
197
|
+
artifact_storage_uri: Optional[str] = None,
|
196
198
|
allow_origins: Optional[list[str]] = None,
|
197
199
|
web: bool,
|
198
200
|
trace_to_cloud: bool = False,
|
@@ -251,13 +253,12 @@ def get_fast_api_app(
|
|
251
253
|
|
252
254
|
runner_dict = {}
|
253
255
|
|
254
|
-
# Build the Artifact service
|
255
|
-
artifact_service = InMemoryArtifactService()
|
256
|
-
memory_service = InMemoryMemoryService()
|
257
|
-
|
258
256
|
eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir)
|
259
257
|
eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir)
|
260
258
|
|
259
|
+
# Build the Memory service
|
260
|
+
memory_service = InMemoryMemoryService()
|
261
|
+
|
261
262
|
# Build the Session service
|
262
263
|
agent_engine_id = ""
|
263
264
|
if session_db_url:
|
@@ -276,6 +277,18 @@ def get_fast_api_app(
|
|
276
277
|
else:
|
277
278
|
session_service = InMemorySessionService()
|
278
279
|
|
280
|
+
# Build the Artifact service
|
281
|
+
if artifact_storage_uri:
|
282
|
+
if artifact_storage_uri.startswith("gs://"):
|
283
|
+
gcs_bucket = artifact_storage_uri.split("://")[1]
|
284
|
+
artifact_service = GcsArtifactService(bucket_name=gcs_bucket)
|
285
|
+
else:
|
286
|
+
raise click.ClickException(
|
287
|
+
"Unsupported artifact storage URI: %s" % artifact_storage_uri
|
288
|
+
)
|
289
|
+
else:
|
290
|
+
artifact_service = InMemoryArtifactService()
|
291
|
+
|
279
292
|
# initialize Agent Loader
|
280
293
|
agent_loader = AgentLoader(agents_dir)
|
281
294
|
|
@@ -475,8 +488,66 @@ def get_fast_api_app(
|
|
475
488
|
"""Lists all evals in an eval set."""
|
476
489
|
eval_set_data = eval_sets_manager.get_eval_set(app_name, eval_set_id)
|
477
490
|
|
491
|
+
if not eval_set_data:
|
492
|
+
raise HTTPException(
|
493
|
+
status_code=400, detail=f"Eval set `{eval_set_id}` not found."
|
494
|
+
)
|
495
|
+
|
478
496
|
return sorted([x.eval_id for x in eval_set_data.eval_cases])
|
479
497
|
|
498
|
+
@app.get(
|
499
|
+
"/apps/{app_name}/eval_sets/{eval_set_id}/evals/{eval_case_id}",
|
500
|
+
response_model_exclude_none=True,
|
501
|
+
)
|
502
|
+
def get_eval(app_name: str, eval_set_id: str, eval_case_id: str) -> EvalCase:
|
503
|
+
"""Gets an eval case in an eval set."""
|
504
|
+
eval_case_to_find = eval_sets_manager.get_eval_case(
|
505
|
+
app_name, eval_set_id, eval_case_id
|
506
|
+
)
|
507
|
+
|
508
|
+
if eval_case_to_find:
|
509
|
+
return eval_case_to_find
|
510
|
+
|
511
|
+
raise HTTPException(
|
512
|
+
status_code=404,
|
513
|
+
detail=f"Eval set `{eval_set_id}` or Eval `{eval_case_id}` not found.",
|
514
|
+
)
|
515
|
+
|
516
|
+
@app.put(
|
517
|
+
"/apps/{app_name}/eval_sets/{eval_set_id}/evals/{eval_case_id}",
|
518
|
+
response_model_exclude_none=True,
|
519
|
+
)
|
520
|
+
def update_eval(
|
521
|
+
app_name: str,
|
522
|
+
eval_set_id: str,
|
523
|
+
eval_case_id: str,
|
524
|
+
updated_eval_case: EvalCase,
|
525
|
+
):
|
526
|
+
if updated_eval_case.eval_id and updated_eval_case.eval_id != eval_case_id:
|
527
|
+
raise HTTPException(
|
528
|
+
status_code=400,
|
529
|
+
detail=(
|
530
|
+
"Eval id in EvalCase should match the eval id in the API route."
|
531
|
+
),
|
532
|
+
)
|
533
|
+
|
534
|
+
# Overwrite the value. We are either overwriting the same value or an empty
|
535
|
+
# field.
|
536
|
+
updated_eval_case.eval_id = eval_case_id
|
537
|
+
try:
|
538
|
+
eval_sets_manager.update_eval_case(
|
539
|
+
app_name, eval_set_id, updated_eval_case
|
540
|
+
)
|
541
|
+
except NotFoundError as nfe:
|
542
|
+
raise HTTPException(status_code=404, detail=str(nfe)) from nfe
|
543
|
+
|
544
|
+
@app.delete("/apps/{app_name}/eval_sets/{eval_set_id}/evals/{eval_case_id}")
|
545
|
+
def delete_eval(app_name: str, eval_set_id: str, eval_case_id: str):
|
546
|
+
try:
|
547
|
+
eval_sets_manager.delete_eval_case(app_name, eval_set_id, eval_case_id)
|
548
|
+
except NotFoundError as nfe:
|
549
|
+
raise HTTPException(status_code=404, detail=str(nfe)) from nfe
|
550
|
+
|
480
551
|
@app.post(
|
481
552
|
"/apps/{app_name}/eval_sets/{eval_set_id}/run_eval",
|
482
553
|
response_model_exclude_none=True,
|
@@ -491,6 +562,11 @@ def get_fast_api_app(
|
|
491
562
|
# run.
|
492
563
|
eval_set = eval_sets_manager.get_eval_set(app_name, eval_set_id)
|
493
564
|
|
565
|
+
if not eval_set:
|
566
|
+
raise HTTPException(
|
567
|
+
status_code=400, detail=f"Eval set `{eval_set_id}` not found."
|
568
|
+
)
|
569
|
+
|
494
570
|
if req.eval_ids:
|
495
571
|
eval_cases = [e for e in eval_set.eval_cases if e.eval_id in req.eval_ids]
|
496
572
|
eval_set_to_evals = {eval_set_id: eval_cases}
|
@@ -501,34 +577,38 @@ def get_fast_api_app(
|
|
501
577
|
root_agent = agent_loader.load_agent(app_name)
|
502
578
|
run_eval_results = []
|
503
579
|
eval_case_results = []
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
580
|
+
try:
|
581
|
+
async for eval_case_result in run_evals(
|
582
|
+
eval_set_to_evals,
|
583
|
+
root_agent,
|
584
|
+
getattr(root_agent, "reset_data", None),
|
585
|
+
req.eval_metrics,
|
586
|
+
session_service=session_service,
|
587
|
+
artifact_service=artifact_service,
|
588
|
+
):
|
589
|
+
run_eval_results.append(
|
590
|
+
RunEvalResult(
|
591
|
+
app_name=app_name,
|
592
|
+
eval_set_file=eval_case_result.eval_set_file,
|
593
|
+
eval_set_id=eval_set_id,
|
594
|
+
eval_id=eval_case_result.eval_id,
|
595
|
+
final_eval_status=eval_case_result.final_eval_status,
|
596
|
+
eval_metric_results=eval_case_result.eval_metric_results,
|
597
|
+
overall_eval_metric_results=eval_case_result.overall_eval_metric_results,
|
598
|
+
eval_metric_result_per_invocation=eval_case_result.eval_metric_result_per_invocation,
|
599
|
+
user_id=eval_case_result.user_id,
|
600
|
+
session_id=eval_case_result.session_id,
|
601
|
+
)
|
602
|
+
)
|
603
|
+
eval_case_result.session_details = await session_service.get_session(
|
604
|
+
app_name=app_name,
|
605
|
+
user_id=eval_case_result.user_id,
|
606
|
+
session_id=eval_case_result.session_id,
|
607
|
+
)
|
608
|
+
eval_case_results.append(eval_case_result)
|
609
|
+
except ModuleNotFoundError as e:
|
610
|
+
logger.exception("%s", e)
|
611
|
+
raise HTTPException(status_code=400, detail=str(e)) from e
|
532
612
|
|
533
613
|
eval_set_results_manager.save_eval_set_result(
|
534
614
|
app_name, eval_set_id, eval_case_results
|
@@ -854,6 +934,11 @@ def get_fast_api_app(
|
|
854
934
|
return runner
|
855
935
|
|
856
936
|
if web:
|
937
|
+
import mimetypes
|
938
|
+
|
939
|
+
mimetypes.add_type("application/javascript", ".js", True)
|
940
|
+
mimetypes.add_type("text/javascript", ".js", True)
|
941
|
+
|
857
942
|
BASE_DIR = Path(__file__).parent.resolve()
|
858
943
|
ANGULAR_DIST_PATH = BASE_DIR / "browser"
|
859
944
|
|