google-adk 1.1.1__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff compares the contents of two publicly available versions of a package as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (67)
  1. google/adk/agents/base_agent.py +0 -2
  2. google/adk/agents/invocation_context.py +3 -3
  3. google/adk/agents/parallel_agent.py +17 -7
  4. google/adk/agents/sequential_agent.py +8 -8
  5. google/adk/auth/auth_preprocessor.py +18 -17
  6. google/adk/cli/agent_graph.py +165 -23
  7. google/adk/cli/browser/assets/ADK-512-color.svg +9 -0
  8. google/adk/cli/browser/index.html +2 -2
  9. google/adk/cli/browser/{main-PKDNKWJE.js → main-CS5OLUMF.js} +59 -59
  10. google/adk/cli/browser/polyfills-FFHMD2TL.js +17 -0
  11. google/adk/cli/cli.py +9 -9
  12. google/adk/cli/cli_deploy.py +157 -0
  13. google/adk/cli/cli_tools_click.py +228 -99
  14. google/adk/cli/fast_api.py +119 -34
  15. google/adk/cli/utils/agent_loader.py +60 -44
  16. google/adk/cli/utils/envs.py +1 -1
  17. google/adk/code_executors/unsafe_local_code_executor.py +11 -0
  18. google/adk/errors/__init__.py +13 -0
  19. google/adk/errors/not_found_error.py +28 -0
  20. google/adk/evaluation/agent_evaluator.py +1 -1
  21. google/adk/evaluation/eval_sets_manager.py +36 -6
  22. google/adk/evaluation/evaluation_generator.py +5 -4
  23. google/adk/evaluation/local_eval_sets_manager.py +101 -6
  24. google/adk/flows/llm_flows/agent_transfer.py +2 -2
  25. google/adk/flows/llm_flows/base_llm_flow.py +19 -0
  26. google/adk/flows/llm_flows/contents.py +4 -4
  27. google/adk/flows/llm_flows/functions.py +140 -127
  28. google/adk/memory/vertex_ai_rag_memory_service.py +2 -2
  29. google/adk/models/anthropic_llm.py +7 -10
  30. google/adk/models/google_llm.py +46 -18
  31. google/adk/models/lite_llm.py +63 -26
  32. google/adk/py.typed +0 -0
  33. google/adk/sessions/_session_util.py +10 -16
  34. google/adk/sessions/database_session_service.py +81 -66
  35. google/adk/sessions/vertex_ai_session_service.py +32 -6
  36. google/adk/telemetry.py +91 -24
  37. google/adk/tools/_automatic_function_calling_util.py +31 -25
  38. google/adk/tools/{function_parameter_parse_util.py → _function_parameter_parse_util.py} +9 -3
  39. google/adk/tools/_gemini_schema_util.py +158 -0
  40. google/adk/tools/apihub_tool/apihub_toolset.py +3 -2
  41. google/adk/tools/application_integration_tool/clients/connections_client.py +7 -0
  42. google/adk/tools/application_integration_tool/integration_connector_tool.py +5 -7
  43. google/adk/tools/base_tool.py +4 -8
  44. google/adk/tools/bigquery/bigquery_credentials.py +7 -3
  45. google/adk/tools/function_tool.py +4 -4
  46. google/adk/tools/langchain_tool.py +20 -13
  47. google/adk/tools/load_memory_tool.py +1 -0
  48. google/adk/tools/mcp_tool/conversion_utils.py +4 -2
  49. google/adk/tools/mcp_tool/mcp_session_manager.py +63 -5
  50. google/adk/tools/mcp_tool/mcp_tool.py +3 -2
  51. google/adk/tools/mcp_tool/mcp_toolset.py +15 -8
  52. google/adk/tools/openapi_tool/common/common.py +4 -43
  53. google/adk/tools/openapi_tool/openapi_spec_parser/__init__.py +0 -2
  54. google/adk/tools/openapi_tool/openapi_spec_parser/openapi_spec_parser.py +4 -2
  55. google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +4 -2
  56. google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +7 -127
  57. google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +2 -7
  58. google/adk/tools/transfer_to_agent_tool.py +8 -1
  59. google/adk/tools/vertex_ai_search_tool.py +8 -1
  60. google/adk/utils/variant_utils.py +51 -0
  61. google/adk/version.py +1 -1
  62. {google_adk-1.1.1.dist-info → google_adk-1.2.0.dist-info}/METADATA +7 -7
  63. {google_adk-1.1.1.dist-info → google_adk-1.2.0.dist-info}/RECORD +66 -60
  64. google/adk/cli/browser/polyfills-B6TNHZQ6.js +0 -17
  65. {google_adk-1.1.1.dist-info → google_adk-1.2.0.dist-info}/WHEEL +0 -0
  66. {google_adk-1.1.1.dist-info → google_adk-1.2.0.dist-info}/entry_points.txt +0 -0
  67. {google_adk-1.1.1.dist-info → google_adk-1.2.0.dist-info}/licenses/LICENSE +0 -0
@@ -12,10 +12,13 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+ from __future__ import annotations
16
+
15
17
  import asyncio
16
18
  import collections
17
19
  from contextlib import asynccontextmanager
18
20
  from datetime import datetime
21
+ import functools
19
22
  import logging
20
23
  import os
21
24
  import tempfile
@@ -58,6 +61,19 @@ class HelpfulCommand(click.Command):
58
61
  def __init__(self, *args, **kwargs):
59
62
  super().__init__(*args, **kwargs)
60
63
 
64
+ @staticmethod
65
+ def _format_missing_arg_error(click_exception):
66
+ """Format the missing argument error with uppercase parameter name.
67
+
68
+ Args:
69
+ click_exception: The MissingParameter exception from Click.
70
+
71
+ Returns:
72
+ str: Formatted error message with uppercase parameter name.
73
+ """
74
+ name = click_exception.param.name
75
+ return f"Missing required argument: {name.upper()}"
76
+
61
77
  def parse_args(self, ctx, args):
62
78
  """Override the parse_args method to show help text on error.
63
79
 
@@ -75,8 +91,10 @@ class HelpfulCommand(click.Command):
75
91
  try:
76
92
  return super().parse_args(ctx, args)
77
93
  except click.MissingParameter as exc:
94
+ error_message = self._format_missing_arg_error(exc)
95
+
78
96
  click.echo(ctx.get_help())
79
- click.secho(f"\nError: {str(exc)}", fg="red", err=True)
97
+ click.secho(f"\nError: {error_message}", fg="red", err=True)
80
98
  ctx.exit(2)
81
99
 
82
100
 
@@ -84,6 +102,7 @@ logger = logging.getLogger("google_adk." + __name__)
84
102
 
85
103
 
86
104
  @click.group(context_settings={"max_content_width": 240})
105
+ @click.version_option(version.__version__)
87
106
  def main():
88
107
  """Agent Development Kit CLI tools."""
89
108
  pass
@@ -398,57 +417,78 @@ def cli_eval(
398
417
  print(eval_result.model_dump_json(indent=2))
399
418
 
400
419
 
401
- @main.command("web")
402
- @click.option(
403
- "--session_db_url",
404
- help=(
405
- """Optional. The database URL to store the session.
420
+ def fast_api_common_options():
421
+ """Decorator to add common fast api options to click commands."""
406
422
 
407
- - Use 'agentengine://<agent_engine_resource_id>' to connect to Agent Engine sessions.
423
+ def decorator(func):
424
+ @click.option(
425
+ "--session_db_url",
426
+ help=(
427
+ """Optional. The database URL to store the session.
428
+ - Use 'agentengine://<agent_engine_resource_id>' to connect to Agent Engine sessions.
429
+ - Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.
430
+ - See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
431
+ ),
432
+ )
433
+ @click.option(
434
+ "--artifact_storage_uri",
435
+ type=str,
436
+ help=(
437
+ "Optional. The artifact storage URI to store the artifacts,"
438
+ " supported URIs: gs://<bucket name> for GCS artifact service."
439
+ ),
440
+ default=None,
441
+ )
442
+ @click.option(
443
+ "--host",
444
+ type=str,
445
+ help="Optional. The binding host of the server",
446
+ default="127.0.0.1",
447
+ show_default=True,
448
+ )
449
+ @click.option(
450
+ "--port",
451
+ type=int,
452
+ help="Optional. The port of the server",
453
+ default=8000,
454
+ )
455
+ @click.option(
456
+ "--allow_origins",
457
+ help="Optional. Any additional origins to allow for CORS.",
458
+ multiple=True,
459
+ )
460
+ @click.option(
461
+ "--log_level",
462
+ type=click.Choice(
463
+ ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
464
+ case_sensitive=False,
465
+ ),
466
+ default="INFO",
467
+ help="Optional. Set the logging level",
468
+ )
469
+ @click.option(
470
+ "--trace_to_cloud",
471
+ is_flag=True,
472
+ show_default=True,
473
+ default=False,
474
+ help="Optional. Whether to enable cloud trace for telemetry.",
475
+ )
476
+ @click.option(
477
+ "--reload/--no-reload",
478
+ default=True,
479
+ help="Optional. Whether to enable auto reload for server.",
480
+ )
481
+ @functools.wraps(func)
482
+ def wrapper(*args, **kwargs):
483
+ return func(*args, **kwargs)
408
484
 
409
- - Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.
485
+ return wrapper
410
486
 
411
- - See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
412
- ),
413
- )
414
- @click.option(
415
- "--host",
416
- type=str,
417
- help="Optional. The binding host of the server",
418
- default="127.0.0.1",
419
- show_default=True,
420
- )
421
- @click.option(
422
- "--port",
423
- type=int,
424
- help="Optional. The port of the server",
425
- default=8000,
426
- )
427
- @click.option(
428
- "--allow_origins",
429
- help="Optional. Any additional origins to allow for CORS.",
430
- multiple=True,
431
- )
432
- @click.option(
433
- "--log_level",
434
- type=click.Choice(
435
- ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
436
- ),
437
- default="INFO",
438
- help="Optional. Set the logging level",
439
- )
440
- @click.option(
441
- "--trace_to_cloud",
442
- is_flag=True,
443
- show_default=True,
444
- default=False,
445
- help="Optional. Whether to enable cloud trace for telemetry.",
446
- )
447
- @click.option(
448
- "--reload/--no-reload",
449
- default=True,
450
- help="Optional. Whether to enable auto reload for server.",
451
- )
487
+ return decorator
488
+
489
+
490
+ @main.command("web")
491
+ @fast_api_common_options()
452
492
  @click.argument(
453
493
  "agents_dir",
454
494
  type=click.Path(
@@ -459,6 +499,7 @@ def cli_eval(
459
499
  def cli_web(
460
500
  agents_dir: str,
461
501
  session_db_url: str = "",
502
+ artifact_storage_uri: Optional[str] = None,
462
503
  log_level: str = "INFO",
463
504
  allow_origins: Optional[list[str]] = None,
464
505
  host: str = "127.0.0.1",
@@ -502,6 +543,7 @@ def cli_web(
502
543
  app = get_fast_api_app(
503
544
  agents_dir=agents_dir,
504
545
  session_db_url=session_db_url,
546
+ artifact_storage_uri=artifact_storage_uri,
505
547
  allow_origins=allow_origins,
506
548
  web=True,
507
549
  trace_to_cloud=trace_to_cloud,
@@ -519,56 +561,6 @@ def cli_web(
519
561
 
520
562
 
521
563
  @main.command("api_server")
522
- @click.option(
523
- "--session_db_url",
524
- help=(
525
- """Optional. The database URL to store the session.
526
-
527
- - Use 'agentengine://<agent_engine_resource_id>' to connect to Agent Engine sessions.
528
-
529
- - Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.
530
-
531
- - See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
532
- ),
533
- )
534
- @click.option(
535
- "--host",
536
- type=str,
537
- help="Optional. The binding host of the server",
538
- default="127.0.0.1",
539
- show_default=True,
540
- )
541
- @click.option(
542
- "--port",
543
- type=int,
544
- help="Optional. The port of the server",
545
- default=8000,
546
- )
547
- @click.option(
548
- "--allow_origins",
549
- help="Optional. Any additional origins to allow for CORS.",
550
- multiple=True,
551
- )
552
- @click.option(
553
- "--log_level",
554
- type=click.Choice(
555
- ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
556
- ),
557
- default="INFO",
558
- help="Optional. Set the logging level",
559
- )
560
- @click.option(
561
- "--trace_to_cloud",
562
- is_flag=True,
563
- show_default=True,
564
- default=False,
565
- help="Optional. Whether to enable cloud trace for telemetry.",
566
- )
567
- @click.option(
568
- "--reload/--no-reload",
569
- default=True,
570
- help="Optional. Whether to enable auto reload for server.",
571
- )
572
564
  # The directory of agents, where each sub-directory is a single agent.
573
565
  # By default, it is the current working directory
574
566
  @click.argument(
@@ -578,9 +570,11 @@ def cli_web(
578
570
  ),
579
571
  default=os.getcwd(),
580
572
  )
573
+ @fast_api_common_options()
581
574
  def cli_api_server(
582
575
  agents_dir: str,
583
576
  session_db_url: str = "",
577
+ artifact_storage_uri: Optional[str] = None,
584
578
  log_level: str = "INFO",
585
579
  allow_origins: Optional[list[str]] = None,
586
580
  host: str = "127.0.0.1",
@@ -603,6 +597,7 @@ def cli_api_server(
603
597
  get_fast_api_app(
604
598
  agents_dir=agents_dir,
605
599
  session_db_url=session_db_url,
600
+ artifact_storage_uri=artifact_storage_uri,
606
601
  allow_origins=allow_origins,
607
602
  web=False,
608
603
  trace_to_cloud=trace_to_cloud,
@@ -706,6 +701,15 @@ def cli_api_server(
706
701
  - See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
707
702
  ),
708
703
  )
704
+ @click.option(
705
+ "--artifact_storage_uri",
706
+ type=str,
707
+ help=(
708
+ "Optional. The artifact storage URI to store the artifacts, supported"
709
+ " URIs: gs://<bucket name> for GCS artifact service."
710
+ ),
711
+ default=None,
712
+ )
709
713
  @click.argument(
710
714
  "agent",
711
715
  type=click.Path(
@@ -734,6 +738,7 @@ def cli_deploy_cloud_run(
734
738
  with_ui: bool,
735
739
  verbosity: str,
736
740
  session_db_url: str,
741
+ artifact_storage_uri: Optional[str],
737
742
  adk_version: str,
738
743
  ):
739
744
  """Deploys an agent to Cloud Run.
@@ -757,7 +762,131 @@ def cli_deploy_cloud_run(
757
762
  with_ui=with_ui,
758
763
  verbosity=verbosity,
759
764
  session_db_url=session_db_url,
765
+ artifact_storage_uri=artifact_storage_uri,
760
766
  adk_version=adk_version,
761
767
  )
762
768
  except Exception as e:
763
769
  click.secho(f"Deploy failed: {e}", fg="red", err=True)
770
+
771
+
772
+ @deploy.command("agent_engine")
773
+ @click.option(
774
+ "--project",
775
+ type=str,
776
+ help="Required. Google Cloud project to deploy the agent.",
777
+ )
778
+ @click.option(
779
+ "--region",
780
+ type=str,
781
+ help="Required. Google Cloud region to deploy the agent.",
782
+ )
783
+ @click.option(
784
+ "--staging_bucket",
785
+ type=str,
786
+ help="Required. GCS bucket for staging the deployment artifacts.",
787
+ )
788
+ @click.option(
789
+ "--trace_to_cloud",
790
+ type=bool,
791
+ is_flag=True,
792
+ show_default=True,
793
+ default=False,
794
+ help="Optional. Whether to enable Cloud Trace for Agent Engine.",
795
+ )
796
+ @click.option(
797
+ "--adk_app",
798
+ type=str,
799
+ default="agent_engine_app",
800
+ help=(
801
+ "Optional. Python file for defining the ADK application"
802
+ " (default: a file named agent_engine_app.py)"
803
+ ),
804
+ )
805
+ @click.option(
806
+ "--temp_folder",
807
+ type=str,
808
+ default=os.path.join(
809
+ tempfile.gettempdir(),
810
+ "agent_engine_deploy_src",
811
+ datetime.now().strftime("%Y%m%d_%H%M%S"),
812
+ ),
813
+ help=(
814
+ "Optional. Temp folder for the generated Agent Engine source files."
815
+ " If the folder already exists, its contents will be removed."
816
+ " (default: a timestamped folder in the system temp directory)."
817
+ ),
818
+ )
819
+ @click.option(
820
+ "--env_file",
821
+ type=str,
822
+ default="",
823
+ help=(
824
+ "Optional. The filepath to the `.env` file for environment variables."
825
+ " (default: the `.env` file in the `agent` directory, if any.)"
826
+ ),
827
+ )
828
+ @click.option(
829
+ "--requirements_file",
830
+ type=str,
831
+ default="",
832
+ help=(
833
+ "Optional. The filepath to the `requirements.txt` file to use."
834
+ " (default: the `requirements.txt` file in the `agent` directory, if"
835
+ " any.)"
836
+ ),
837
+ )
838
+ @click.argument(
839
+ "agent",
840
+ type=click.Path(
841
+ exists=True, dir_okay=True, file_okay=False, resolve_path=True
842
+ ),
843
+ )
844
+ def cli_deploy_agent_engine(
845
+ agent: str,
846
+ project: str,
847
+ region: str,
848
+ staging_bucket: str,
849
+ trace_to_cloud: bool,
850
+ adk_app: str,
851
+ temp_folder: str,
852
+ env_file: str,
853
+ requirements_file: str,
854
+ ):
855
+ """Deploys an agent to Agent Engine.
856
+
857
+ Args:
858
+ agent (str): Required. The path to the agent to be deloyed.
859
+ project (str): Required. Google Cloud project to deploy the agent.
860
+ region (str): Required. Google Cloud region to deploy the agent.
861
+ staging_bucket (str): Required. GCS bucket for staging the deployment
862
+ artifacts.
863
+ trace_to_cloud (bool): Required. Whether to enable Cloud Trace.
864
+ adk_app (str): Required. Python file for defining the ADK application.
865
+ temp_folder (str): Required. The folder for the generated Agent Engine
866
+ files. If the folder already exists, its contents will be replaced.
867
+ env_file (str): Required. The filepath to the `.env` file for environment
868
+ variables. If it is an empty string, the `.env` file in the `agent`
869
+ directory will be used if it exists.
870
+ requirements_file (str): Required. The filepath to the `requirements.txt`
871
+ file to use. If it is an empty string, the `requirements.txt` file in the
872
+ `agent` directory will be used if exists.
873
+
874
+ Example:
875
+
876
+ adk deploy agent_engine --project=[project] --region=[region]
877
+ --staging_bucket=[staging_bucket] path/to/my_agent
878
+ """
879
+ try:
880
+ cli_deploy.to_agent_engine(
881
+ agent_folder=agent,
882
+ project=project,
883
+ region=region,
884
+ staging_bucket=staging_bucket,
885
+ trace_to_cloud=trace_to_cloud,
886
+ adk_app=adk_app,
887
+ temp_folder=temp_folder,
888
+ env_file=env_file,
889
+ requirements_file=requirements_file,
890
+ )
891
+ except Exception as e:
892
+ click.secho(f"Deploy failed: {e}", fg="red", err=True)
@@ -12,7 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
-
16
15
  from __future__ import annotations
17
16
 
18
17
  import asyncio
@@ -56,7 +55,9 @@ from ..agents.live_request_queue import LiveRequest
56
55
  from ..agents.live_request_queue import LiveRequestQueue
57
56
  from ..agents.llm_agent import Agent
58
57
  from ..agents.run_config import StreamingMode
58
+ from ..artifacts.gcs_artifact_service import GcsArtifactService
59
59
  from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
60
+ from ..errors.not_found_error import NotFoundError
60
61
  from ..evaluation.eval_case import EvalCase
61
62
  from ..evaluation.eval_case import SessionInput
62
63
  from ..evaluation.eval_metrics import EvalMetric
@@ -98,7 +99,7 @@ class ApiServerSpanExporter(export.SpanExporter):
98
99
  if (
99
100
  span.name == "call_llm"
100
101
  or span.name == "send_data"
101
- or span.name.startswith("tool_response")
102
+ or span.name.startswith("execute_tool")
102
103
  ):
103
104
  attributes = dict(span.attributes)
104
105
  attributes["trace_id"] = span.get_span_context().trace_id
@@ -193,6 +194,7 @@ def get_fast_api_app(
193
194
  *,
194
195
  agents_dir: str,
195
196
  session_db_url: str = "",
197
+ artifact_storage_uri: Optional[str] = None,
196
198
  allow_origins: Optional[list[str]] = None,
197
199
  web: bool,
198
200
  trace_to_cloud: bool = False,
@@ -251,13 +253,12 @@ def get_fast_api_app(
251
253
 
252
254
  runner_dict = {}
253
255
 
254
- # Build the Artifact service
255
- artifact_service = InMemoryArtifactService()
256
- memory_service = InMemoryMemoryService()
257
-
258
256
  eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir)
259
257
  eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir)
260
258
 
259
+ # Build the Memory service
260
+ memory_service = InMemoryMemoryService()
261
+
261
262
  # Build the Session service
262
263
  agent_engine_id = ""
263
264
  if session_db_url:
@@ -276,6 +277,18 @@ def get_fast_api_app(
276
277
  else:
277
278
  session_service = InMemorySessionService()
278
279
 
280
+ # Build the Artifact service
281
+ if artifact_storage_uri:
282
+ if artifact_storage_uri.startswith("gs://"):
283
+ gcs_bucket = artifact_storage_uri.split("://")[1]
284
+ artifact_service = GcsArtifactService(bucket_name=gcs_bucket)
285
+ else:
286
+ raise click.ClickException(
287
+ "Unsupported artifact storage URI: %s" % artifact_storage_uri
288
+ )
289
+ else:
290
+ artifact_service = InMemoryArtifactService()
291
+
279
292
  # initialize Agent Loader
280
293
  agent_loader = AgentLoader(agents_dir)
281
294
 
@@ -475,8 +488,66 @@ def get_fast_api_app(
475
488
  """Lists all evals in an eval set."""
476
489
  eval_set_data = eval_sets_manager.get_eval_set(app_name, eval_set_id)
477
490
 
491
+ if not eval_set_data:
492
+ raise HTTPException(
493
+ status_code=400, detail=f"Eval set `{eval_set_id}` not found."
494
+ )
495
+
478
496
  return sorted([x.eval_id for x in eval_set_data.eval_cases])
479
497
 
498
+ @app.get(
499
+ "/apps/{app_name}/eval_sets/{eval_set_id}/evals/{eval_case_id}",
500
+ response_model_exclude_none=True,
501
+ )
502
+ def get_eval(app_name: str, eval_set_id: str, eval_case_id: str) -> EvalCase:
503
+ """Gets an eval case in an eval set."""
504
+ eval_case_to_find = eval_sets_manager.get_eval_case(
505
+ app_name, eval_set_id, eval_case_id
506
+ )
507
+
508
+ if eval_case_to_find:
509
+ return eval_case_to_find
510
+
511
+ raise HTTPException(
512
+ status_code=404,
513
+ detail=f"Eval set `{eval_set_id}` or Eval `{eval_case_id}` not found.",
514
+ )
515
+
516
+ @app.put(
517
+ "/apps/{app_name}/eval_sets/{eval_set_id}/evals/{eval_case_id}",
518
+ response_model_exclude_none=True,
519
+ )
520
+ def update_eval(
521
+ app_name: str,
522
+ eval_set_id: str,
523
+ eval_case_id: str,
524
+ updated_eval_case: EvalCase,
525
+ ):
526
+ if updated_eval_case.eval_id and updated_eval_case.eval_id != eval_case_id:
527
+ raise HTTPException(
528
+ status_code=400,
529
+ detail=(
530
+ "Eval id in EvalCase should match the eval id in the API route."
531
+ ),
532
+ )
533
+
534
+ # Overwrite the value. We are either overwriting the same value or an empty
535
+ # field.
536
+ updated_eval_case.eval_id = eval_case_id
537
+ try:
538
+ eval_sets_manager.update_eval_case(
539
+ app_name, eval_set_id, updated_eval_case
540
+ )
541
+ except NotFoundError as nfe:
542
+ raise HTTPException(status_code=404, detail=str(nfe)) from nfe
543
+
544
+ @app.delete("/apps/{app_name}/eval_sets/{eval_set_id}/evals/{eval_case_id}")
545
+ def delete_eval(app_name: str, eval_set_id: str, eval_case_id: str):
546
+ try:
547
+ eval_sets_manager.delete_eval_case(app_name, eval_set_id, eval_case_id)
548
+ except NotFoundError as nfe:
549
+ raise HTTPException(status_code=404, detail=str(nfe)) from nfe
550
+
480
551
  @app.post(
481
552
  "/apps/{app_name}/eval_sets/{eval_set_id}/run_eval",
482
553
  response_model_exclude_none=True,
@@ -491,6 +562,11 @@ def get_fast_api_app(
491
562
  # run.
492
563
  eval_set = eval_sets_manager.get_eval_set(app_name, eval_set_id)
493
564
 
565
+ if not eval_set:
566
+ raise HTTPException(
567
+ status_code=400, detail=f"Eval set `{eval_set_id}` not found."
568
+ )
569
+
494
570
  if req.eval_ids:
495
571
  eval_cases = [e for e in eval_set.eval_cases if e.eval_id in req.eval_ids]
496
572
  eval_set_to_evals = {eval_set_id: eval_cases}
@@ -501,34 +577,38 @@ def get_fast_api_app(
501
577
  root_agent = agent_loader.load_agent(app_name)
502
578
  run_eval_results = []
503
579
  eval_case_results = []
504
- async for eval_case_result in run_evals(
505
- eval_set_to_evals,
506
- root_agent,
507
- getattr(root_agent, "reset_data", None),
508
- req.eval_metrics,
509
- session_service=session_service,
510
- artifact_service=artifact_service,
511
- ):
512
- run_eval_results.append(
513
- RunEvalResult(
514
- app_name=app_name,
515
- eval_set_file=eval_case_result.eval_set_file,
516
- eval_set_id=eval_set_id,
517
- eval_id=eval_case_result.eval_id,
518
- final_eval_status=eval_case_result.final_eval_status,
519
- eval_metric_results=eval_case_result.eval_metric_results,
520
- overall_eval_metric_results=eval_case_result.overall_eval_metric_results,
521
- eval_metric_result_per_invocation=eval_case_result.eval_metric_result_per_invocation,
522
- user_id=eval_case_result.user_id,
523
- session_id=eval_case_result.session_id,
524
- )
525
- )
526
- eval_case_result.session_details = await session_service.get_session(
527
- app_name=app_name,
528
- user_id=eval_case_result.user_id,
529
- session_id=eval_case_result.session_id,
530
- )
531
- eval_case_results.append(eval_case_result)
580
+ try:
581
+ async for eval_case_result in run_evals(
582
+ eval_set_to_evals,
583
+ root_agent,
584
+ getattr(root_agent, "reset_data", None),
585
+ req.eval_metrics,
586
+ session_service=session_service,
587
+ artifact_service=artifact_service,
588
+ ):
589
+ run_eval_results.append(
590
+ RunEvalResult(
591
+ app_name=app_name,
592
+ eval_set_file=eval_case_result.eval_set_file,
593
+ eval_set_id=eval_set_id,
594
+ eval_id=eval_case_result.eval_id,
595
+ final_eval_status=eval_case_result.final_eval_status,
596
+ eval_metric_results=eval_case_result.eval_metric_results,
597
+ overall_eval_metric_results=eval_case_result.overall_eval_metric_results,
598
+ eval_metric_result_per_invocation=eval_case_result.eval_metric_result_per_invocation,
599
+ user_id=eval_case_result.user_id,
600
+ session_id=eval_case_result.session_id,
601
+ )
602
+ )
603
+ eval_case_result.session_details = await session_service.get_session(
604
+ app_name=app_name,
605
+ user_id=eval_case_result.user_id,
606
+ session_id=eval_case_result.session_id,
607
+ )
608
+ eval_case_results.append(eval_case_result)
609
+ except ModuleNotFoundError as e:
610
+ logger.exception("%s", e)
611
+ raise HTTPException(status_code=400, detail=str(e)) from e
532
612
 
533
613
  eval_set_results_manager.save_eval_set_result(
534
614
  app_name, eval_set_id, eval_case_results
@@ -854,6 +934,11 @@ def get_fast_api_app(
854
934
  return runner
855
935
 
856
936
  if web:
937
+ import mimetypes
938
+
939
+ mimetypes.add_type("application/javascript", ".js", True)
940
+ mimetypes.add_type("text/javascript", ".js", True)
941
+
857
942
  BASE_DIR = Path(__file__).parent.resolve()
858
943
  ANGULAR_DIST_PATH = BASE_DIR / "browser"
859
944