aiqtoolkit 1.1.0a20250429__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aiqtoolkit might be problematic. Click here for more details.
- aiq/agent/__init__.py +0 -0
- aiq/agent/base.py +76 -0
- aiq/agent/dual_node.py +67 -0
- aiq/agent/react_agent/__init__.py +0 -0
- aiq/agent/react_agent/agent.py +322 -0
- aiq/agent/react_agent/output_parser.py +104 -0
- aiq/agent/react_agent/prompt.py +46 -0
- aiq/agent/react_agent/register.py +148 -0
- aiq/agent/reasoning_agent/__init__.py +0 -0
- aiq/agent/reasoning_agent/reasoning_agent.py +224 -0
- aiq/agent/register.py +23 -0
- aiq/agent/rewoo_agent/__init__.py +0 -0
- aiq/agent/rewoo_agent/agent.py +410 -0
- aiq/agent/rewoo_agent/prompt.py +108 -0
- aiq/agent/rewoo_agent/register.py +158 -0
- aiq/agent/tool_calling_agent/__init__.py +0 -0
- aiq/agent/tool_calling_agent/agent.py +123 -0
- aiq/agent/tool_calling_agent/register.py +105 -0
- aiq/builder/__init__.py +0 -0
- aiq/builder/builder.py +223 -0
- aiq/builder/component_utils.py +303 -0
- aiq/builder/context.py +198 -0
- aiq/builder/embedder.py +24 -0
- aiq/builder/eval_builder.py +116 -0
- aiq/builder/evaluator.py +29 -0
- aiq/builder/framework_enum.py +24 -0
- aiq/builder/front_end.py +73 -0
- aiq/builder/function.py +297 -0
- aiq/builder/function_base.py +372 -0
- aiq/builder/function_info.py +627 -0
- aiq/builder/intermediate_step_manager.py +125 -0
- aiq/builder/llm.py +25 -0
- aiq/builder/retriever.py +25 -0
- aiq/builder/user_interaction_manager.py +71 -0
- aiq/builder/workflow.py +134 -0
- aiq/builder/workflow_builder.py +733 -0
- aiq/cli/__init__.py +14 -0
- aiq/cli/cli_utils/__init__.py +0 -0
- aiq/cli/cli_utils/config_override.py +233 -0
- aiq/cli/cli_utils/validation.py +37 -0
- aiq/cli/commands/__init__.py +0 -0
- aiq/cli/commands/configure/__init__.py +0 -0
- aiq/cli/commands/configure/channel/__init__.py +0 -0
- aiq/cli/commands/configure/channel/add.py +28 -0
- aiq/cli/commands/configure/channel/channel.py +34 -0
- aiq/cli/commands/configure/channel/remove.py +30 -0
- aiq/cli/commands/configure/channel/update.py +30 -0
- aiq/cli/commands/configure/configure.py +33 -0
- aiq/cli/commands/evaluate.py +139 -0
- aiq/cli/commands/info/__init__.py +14 -0
- aiq/cli/commands/info/info.py +37 -0
- aiq/cli/commands/info/list_channels.py +32 -0
- aiq/cli/commands/info/list_components.py +129 -0
- aiq/cli/commands/registry/__init__.py +14 -0
- aiq/cli/commands/registry/publish.py +88 -0
- aiq/cli/commands/registry/pull.py +118 -0
- aiq/cli/commands/registry/registry.py +36 -0
- aiq/cli/commands/registry/remove.py +108 -0
- aiq/cli/commands/registry/search.py +155 -0
- aiq/cli/commands/start.py +250 -0
- aiq/cli/commands/uninstall.py +83 -0
- aiq/cli/commands/validate.py +47 -0
- aiq/cli/commands/workflow/__init__.py +14 -0
- aiq/cli/commands/workflow/templates/__init__.py.j2 +0 -0
- aiq/cli/commands/workflow/templates/config.yml.j2 +16 -0
- aiq/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
- aiq/cli/commands/workflow/templates/register.py.j2 +5 -0
- aiq/cli/commands/workflow/templates/workflow.py.j2 +36 -0
- aiq/cli/commands/workflow/workflow.py +37 -0
- aiq/cli/commands/workflow/workflow_commands.py +307 -0
- aiq/cli/entrypoint.py +133 -0
- aiq/cli/main.py +44 -0
- aiq/cli/register_workflow.py +408 -0
- aiq/cli/type_registry.py +869 -0
- aiq/data_models/__init__.py +14 -0
- aiq/data_models/api_server.py +550 -0
- aiq/data_models/common.py +143 -0
- aiq/data_models/component.py +46 -0
- aiq/data_models/component_ref.py +135 -0
- aiq/data_models/config.py +349 -0
- aiq/data_models/dataset_handler.py +122 -0
- aiq/data_models/discovery_metadata.py +269 -0
- aiq/data_models/embedder.py +26 -0
- aiq/data_models/evaluate.py +101 -0
- aiq/data_models/evaluator.py +26 -0
- aiq/data_models/front_end.py +26 -0
- aiq/data_models/function.py +30 -0
- aiq/data_models/function_dependencies.py +64 -0
- aiq/data_models/interactive.py +237 -0
- aiq/data_models/intermediate_step.py +269 -0
- aiq/data_models/invocation_node.py +38 -0
- aiq/data_models/llm.py +26 -0
- aiq/data_models/logging.py +26 -0
- aiq/data_models/memory.py +26 -0
- aiq/data_models/profiler.py +53 -0
- aiq/data_models/registry_handler.py +26 -0
- aiq/data_models/retriever.py +30 -0
- aiq/data_models/step_adaptor.py +64 -0
- aiq/data_models/streaming.py +33 -0
- aiq/data_models/swe_bench_model.py +54 -0
- aiq/data_models/telemetry_exporter.py +26 -0
- aiq/embedder/__init__.py +0 -0
- aiq/embedder/langchain_client.py +41 -0
- aiq/embedder/nim_embedder.py +58 -0
- aiq/embedder/openai_embedder.py +42 -0
- aiq/embedder/register.py +24 -0
- aiq/eval/__init__.py +14 -0
- aiq/eval/config.py +42 -0
- aiq/eval/dataset_handler/__init__.py +0 -0
- aiq/eval/dataset_handler/dataset_downloader.py +106 -0
- aiq/eval/dataset_handler/dataset_filter.py +52 -0
- aiq/eval/dataset_handler/dataset_handler.py +164 -0
- aiq/eval/evaluate.py +322 -0
- aiq/eval/evaluator/__init__.py +14 -0
- aiq/eval/evaluator/evaluator_model.py +44 -0
- aiq/eval/intermediate_step_adapter.py +93 -0
- aiq/eval/rag_evaluator/__init__.py +0 -0
- aiq/eval/rag_evaluator/evaluate.py +138 -0
- aiq/eval/rag_evaluator/register.py +138 -0
- aiq/eval/register.py +22 -0
- aiq/eval/remote_workflow.py +128 -0
- aiq/eval/runtime_event_subscriber.py +52 -0
- aiq/eval/swe_bench_evaluator/__init__.py +0 -0
- aiq/eval/swe_bench_evaluator/evaluate.py +215 -0
- aiq/eval/swe_bench_evaluator/register.py +36 -0
- aiq/eval/trajectory_evaluator/__init__.py +0 -0
- aiq/eval/trajectory_evaluator/evaluate.py +118 -0
- aiq/eval/trajectory_evaluator/register.py +40 -0
- aiq/eval/utils/__init__.py +0 -0
- aiq/eval/utils/output_uploader.py +131 -0
- aiq/eval/utils/tqdm_position_registry.py +40 -0
- aiq/front_ends/__init__.py +14 -0
- aiq/front_ends/console/__init__.py +14 -0
- aiq/front_ends/console/console_front_end_config.py +32 -0
- aiq/front_ends/console/console_front_end_plugin.py +107 -0
- aiq/front_ends/console/register.py +25 -0
- aiq/front_ends/cron/__init__.py +14 -0
- aiq/front_ends/fastapi/__init__.py +14 -0
- aiq/front_ends/fastapi/fastapi_front_end_config.py +150 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin.py +103 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin_worker.py +574 -0
- aiq/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
- aiq/front_ends/fastapi/job_store.py +161 -0
- aiq/front_ends/fastapi/main.py +70 -0
- aiq/front_ends/fastapi/message_handler.py +279 -0
- aiq/front_ends/fastapi/message_validator.py +345 -0
- aiq/front_ends/fastapi/register.py +25 -0
- aiq/front_ends/fastapi/response_helpers.py +181 -0
- aiq/front_ends/fastapi/step_adaptor.py +315 -0
- aiq/front_ends/fastapi/websocket.py +148 -0
- aiq/front_ends/mcp/__init__.py +14 -0
- aiq/front_ends/mcp/mcp_front_end_config.py +32 -0
- aiq/front_ends/mcp/mcp_front_end_plugin.py +93 -0
- aiq/front_ends/mcp/register.py +27 -0
- aiq/front_ends/mcp/tool_converter.py +242 -0
- aiq/front_ends/register.py +22 -0
- aiq/front_ends/simple_base/__init__.py +14 -0
- aiq/front_ends/simple_base/simple_front_end_plugin_base.py +52 -0
- aiq/llm/__init__.py +0 -0
- aiq/llm/nim_llm.py +45 -0
- aiq/llm/openai_llm.py +45 -0
- aiq/llm/register.py +22 -0
- aiq/llm/utils/__init__.py +14 -0
- aiq/llm/utils/env_config_value.py +94 -0
- aiq/llm/utils/error.py +17 -0
- aiq/memory/__init__.py +20 -0
- aiq/memory/interfaces.py +183 -0
- aiq/memory/models.py +102 -0
- aiq/meta/module_to_distro.json +3 -0
- aiq/meta/pypi.md +59 -0
- aiq/observability/__init__.py +0 -0
- aiq/observability/async_otel_listener.py +270 -0
- aiq/observability/register.py +97 -0
- aiq/plugins/.namespace +1 -0
- aiq/profiler/__init__.py +0 -0
- aiq/profiler/callbacks/__init__.py +0 -0
- aiq/profiler/callbacks/agno_callback_handler.py +295 -0
- aiq/profiler/callbacks/base_callback_class.py +20 -0
- aiq/profiler/callbacks/langchain_callback_handler.py +278 -0
- aiq/profiler/callbacks/llama_index_callback_handler.py +205 -0
- aiq/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
- aiq/profiler/callbacks/token_usage_base_model.py +27 -0
- aiq/profiler/data_frame_row.py +51 -0
- aiq/profiler/decorators/__init__.py +0 -0
- aiq/profiler/decorators/framework_wrapper.py +131 -0
- aiq/profiler/decorators/function_tracking.py +254 -0
- aiq/profiler/forecasting/__init__.py +0 -0
- aiq/profiler/forecasting/config.py +18 -0
- aiq/profiler/forecasting/model_trainer.py +75 -0
- aiq/profiler/forecasting/models/__init__.py +22 -0
- aiq/profiler/forecasting/models/forecasting_base_model.py +40 -0
- aiq/profiler/forecasting/models/linear_model.py +196 -0
- aiq/profiler/forecasting/models/random_forest_regressor.py +268 -0
- aiq/profiler/inference_metrics_model.py +25 -0
- aiq/profiler/inference_optimization/__init__.py +0 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +452 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
- aiq/profiler/inference_optimization/data_models.py +386 -0
- aiq/profiler/inference_optimization/experimental/__init__.py +0 -0
- aiq/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
- aiq/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
- aiq/profiler/inference_optimization/llm_metrics.py +212 -0
- aiq/profiler/inference_optimization/prompt_caching.py +163 -0
- aiq/profiler/inference_optimization/token_uniqueness.py +107 -0
- aiq/profiler/inference_optimization/workflow_runtimes.py +72 -0
- aiq/profiler/intermediate_property_adapter.py +102 -0
- aiq/profiler/profile_runner.py +433 -0
- aiq/profiler/utils.py +184 -0
- aiq/registry_handlers/__init__.py +0 -0
- aiq/registry_handlers/local/__init__.py +0 -0
- aiq/registry_handlers/local/local_handler.py +176 -0
- aiq/registry_handlers/local/register_local.py +37 -0
- aiq/registry_handlers/metadata_factory.py +60 -0
- aiq/registry_handlers/package_utils.py +198 -0
- aiq/registry_handlers/pypi/__init__.py +0 -0
- aiq/registry_handlers/pypi/pypi_handler.py +251 -0
- aiq/registry_handlers/pypi/register_pypi.py +40 -0
- aiq/registry_handlers/register.py +21 -0
- aiq/registry_handlers/registry_handler_base.py +157 -0
- aiq/registry_handlers/rest/__init__.py +0 -0
- aiq/registry_handlers/rest/register_rest.py +56 -0
- aiq/registry_handlers/rest/rest_handler.py +237 -0
- aiq/registry_handlers/schemas/__init__.py +0 -0
- aiq/registry_handlers/schemas/headers.py +42 -0
- aiq/registry_handlers/schemas/package.py +68 -0
- aiq/registry_handlers/schemas/publish.py +63 -0
- aiq/registry_handlers/schemas/pull.py +81 -0
- aiq/registry_handlers/schemas/remove.py +36 -0
- aiq/registry_handlers/schemas/search.py +91 -0
- aiq/registry_handlers/schemas/status.py +47 -0
- aiq/retriever/__init__.py +0 -0
- aiq/retriever/interface.py +37 -0
- aiq/retriever/milvus/__init__.py +14 -0
- aiq/retriever/milvus/register.py +81 -0
- aiq/retriever/milvus/retriever.py +228 -0
- aiq/retriever/models.py +74 -0
- aiq/retriever/nemo_retriever/__init__.py +14 -0
- aiq/retriever/nemo_retriever/register.py +60 -0
- aiq/retriever/nemo_retriever/retriever.py +190 -0
- aiq/retriever/register.py +22 -0
- aiq/runtime/__init__.py +14 -0
- aiq/runtime/loader.py +188 -0
- aiq/runtime/runner.py +176 -0
- aiq/runtime/session.py +116 -0
- aiq/settings/__init__.py +0 -0
- aiq/settings/global_settings.py +318 -0
- aiq/test/.namespace +1 -0
- aiq/tool/__init__.py +0 -0
- aiq/tool/code_execution/__init__.py +0 -0
- aiq/tool/code_execution/code_sandbox.py +188 -0
- aiq/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
- aiq/tool/code_execution/local_sandbox/__init__.py +13 -0
- aiq/tool/code_execution/local_sandbox/local_sandbox_server.py +79 -0
- aiq/tool/code_execution/local_sandbox/sandbox.requirements.txt +4 -0
- aiq/tool/code_execution/local_sandbox/start_local_sandbox.sh +25 -0
- aiq/tool/code_execution/register.py +70 -0
- aiq/tool/code_execution/utils.py +100 -0
- aiq/tool/datetime_tools.py +42 -0
- aiq/tool/document_search.py +141 -0
- aiq/tool/github_tools/__init__.py +0 -0
- aiq/tool/github_tools/create_github_commit.py +133 -0
- aiq/tool/github_tools/create_github_issue.py +87 -0
- aiq/tool/github_tools/create_github_pr.py +106 -0
- aiq/tool/github_tools/get_github_file.py +106 -0
- aiq/tool/github_tools/get_github_issue.py +166 -0
- aiq/tool/github_tools/get_github_pr.py +256 -0
- aiq/tool/github_tools/update_github_issue.py +100 -0
- aiq/tool/mcp/__init__.py +14 -0
- aiq/tool/mcp/mcp_client.py +220 -0
- aiq/tool/mcp/mcp_tool.py +75 -0
- aiq/tool/memory_tools/__init__.py +0 -0
- aiq/tool/memory_tools/add_memory_tool.py +67 -0
- aiq/tool/memory_tools/delete_memory_tool.py +67 -0
- aiq/tool/memory_tools/get_memory_tool.py +72 -0
- aiq/tool/nvidia_rag.py +95 -0
- aiq/tool/register.py +36 -0
- aiq/tool/retriever.py +89 -0
- aiq/utils/__init__.py +0 -0
- aiq/utils/data_models/__init__.py +0 -0
- aiq/utils/data_models/schema_validator.py +58 -0
- aiq/utils/debugging_utils.py +43 -0
- aiq/utils/exception_handlers/__init__.py +0 -0
- aiq/utils/exception_handlers/schemas.py +114 -0
- aiq/utils/io/__init__.py +0 -0
- aiq/utils/io/yaml_tools.py +50 -0
- aiq/utils/metadata_utils.py +74 -0
- aiq/utils/producer_consumer_queue.py +178 -0
- aiq/utils/reactive/__init__.py +0 -0
- aiq/utils/reactive/base/__init__.py +0 -0
- aiq/utils/reactive/base/observable_base.py +65 -0
- aiq/utils/reactive/base/observer_base.py +55 -0
- aiq/utils/reactive/base/subject_base.py +79 -0
- aiq/utils/reactive/observable.py +59 -0
- aiq/utils/reactive/observer.py +76 -0
- aiq/utils/reactive/subject.py +131 -0
- aiq/utils/reactive/subscription.py +49 -0
- aiq/utils/settings/__init__.py +0 -0
- aiq/utils/settings/global_settings.py +197 -0
- aiq/utils/type_converter.py +232 -0
- aiq/utils/type_utils.py +397 -0
- aiq/utils/url_utils.py +27 -0
- aiqtoolkit-1.1.0a20250429.dist-info/METADATA +326 -0
- aiqtoolkit-1.1.0a20250429.dist-info/RECORD +309 -0
- aiqtoolkit-1.1.0a20250429.dist-info/WHEEL +5 -0
- aiqtoolkit-1.1.0a20250429.dist-info/entry_points.txt +17 -0
- aiqtoolkit-1.1.0a20250429.dist-info/licenses/LICENSE-3rd-party.txt +3686 -0
- aiqtoolkit-1.1.0a20250429.dist-info/licenses/LICENSE.md +201 -0
- aiqtoolkit-1.1.0a20250429.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from typing import Literal
|
|
17
|
+
|
|
18
|
+
from pydantic import BaseModel
|
|
19
|
+
from pydantic import Field
|
|
20
|
+
|
|
21
|
+
from aiq.builder.builder import Builder
|
|
22
|
+
from aiq.builder.function_info import FunctionInfo
|
|
23
|
+
from aiq.cli.register_workflow import register_function
|
|
24
|
+
from aiq.data_models.function import FunctionBaseConfig
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class GithubListPullsModel(BaseModel):
    # Query-string filters for the GitHub "list pull requests" endpoint.
    # The tool dumps this model with exclude_unset=True, so only fields the caller
    # explicitly provided are sent as query parameters.
    state: Literal["open", "closed", "all"] | None = Field('open',
                                                          description="Pull request state used in the query filter")
    head: str | None = Field(None, description="Filters pulls by head user or head organization and branch name")
    base: str | None = Field(None, description="Filters pulls by base branch name")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class GithubListPullsModelList(BaseModel):
    # Input schema of the list-pulls tool. The single field `filter_params` carries the
    # query filters and matches the parameter name of the wrapped `_github_list_pulls`.
    filter_params: GithubListPullsModel = Field(
        description=("The query params used when fetching pull requests, "
                     "of type GithubListPullsModel"))
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class GithubListPullsToolConfig(FunctionBaseConfig, name="github_list_pulls_tool"):
    """
    Tool that lists GitHub Pull Requests based on various filter parameters
    """
    # Required setting: the target repository, written as "owner/repo".
    repo_name: str = Field(..., description="The repository name in the format 'owner/repo'")
    # Forwarded to httpx.AsyncClient(timeout=...) for every request the tool makes.
    timeout: int = Field(300, description="The timeout configuration to use when sending requests.")
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@register_function(config_type=GithubListPullsToolConfig)
async def list_github_pulls_async(config: GithubListPullsToolConfig, builder: Builder):
    """
    Lists GitHub Pull Requests based on various filter parameters

    Args:
        config: Tool configuration providing ``repo_name`` and ``timeout``.
        builder: Workflow builder (not used by this tool).

    Raises:
        ValueError: If the ``GITHUB_PAT`` environment variable is not set.
    """
    import json
    import os

    import httpx

    github_pat = os.getenv("GITHUB_PAT")
    if not github_pat:
        raise ValueError("GITHUB_PAT environment variable must be set")

    url = f"https://api.github.com/repos/{config.repo_name}/pulls"

    # define the headers for the payload request
    headers = {"Authorization": f"Bearer {github_pat}", "Accept": "application/vnd.github+json"}

    async def _github_list_pulls(filter_params) -> str:
        """
        List pull requests matching ``filter_params``.

        Returns the decoded JSON response re-serialized as a string.

        Raises:
            httpx.HTTPStatusError: If GitHub responds with an error status.
            ValueError: If the response body is not valid JSON.
        """
        # Only send filters the caller explicitly set, and drop any that were
        # explicitly set to None so they are not sent at all.
        query = {k: v for k, v in filter_params.model_dump(exclude_unset=True).items() if v is not None}

        async with httpx.AsyncClient(timeout=config.timeout) as client:
            response = await client.request("GET", url, params=query, headers=headers)

            # Raise an exception for HTTP errors
            response.raise_for_status()

            # Parse the response JSON
            try:
                result = response.json()
            except ValueError as e:
                raise ValueError("The API response is not valid JSON.") from e

        return json.dumps(result)

    yield FunctionInfo.from_fn(_github_list_pulls,
                               description=(f"Lists GitHub PRs based on filter params "
                                            f"in the repo named {config.repo_name}"),
                               input_schema=GithubListPullsModelList)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class GithubGetPullModel(BaseModel):
    # Input schema shared by the get-pull, get-pull-commits and get-pull-files tools.
    # The value is a string, and those tools place it in the request URL path.
    pull_number: str = Field(..., description="The number of the pull request that needs to be fetched")
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class GithubGetPullToolConfig(FunctionBaseConfig, name="github_get_pull_tool"):
    """
    Tool that fetches a particular pull request in a GitHub repository asynchronously.
    """
    # Required: there is no meaningful default repository. The previous default was the
    # description text itself, which produced a bogus request URL when omitted.
    repo_name: str = Field(description="The repository name in the format 'owner/repo'")
    # Forwarded to httpx.AsyncClient(timeout=...) for every request the tool makes.
    timeout: int = Field(default=300, description="The timeout configuration to use when sending requests.")
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@register_function(config_type=GithubGetPullToolConfig)
async def get_github_pull_async(config: GithubGetPullToolConfig, builder: Builder):
    """
    Fetches a particular pull request in a GitHub repository asynchronously.

    Args:
        config: Tool configuration providing ``repo_name`` and ``timeout``.
        builder: Workflow builder (not used by this tool).

    Raises:
        ValueError: If the ``GITHUB_PAT`` environment variable is not set.
    """
    import json
    import os
    from urllib.parse import quote

    import httpx

    github_pat = os.getenv("GITHUB_PAT")
    if not github_pat:
        raise ValueError("GITHUB_PAT environment variable must be set")

    url = f"https://api.github.com/repos/{config.repo_name}/pulls"

    # define the headers for the payload request
    headers = {"Authorization": f"Bearer {github_pat}", "Accept": "application/vnd.github+json"}

    async def _github_get_pull(pull_number) -> str:
        """
        Fetch one pull request.

        Returns the decoded JSON response re-serialized as a string.

        Raises:
            httpx.HTTPStatusError: If GitHub responds with an error status.
            ValueError: If the response body is not valid JSON.
        """
        # Join URL segments with "/" explicitly. os.path.join uses "\\" on Windows and
        # discards the base for a segment starting with "/". The model-supplied number
        # is escaped so it stays a single path segment.
        pull_url = f"{url}/{quote(str(pull_number), safe='')}"

        async with httpx.AsyncClient(timeout=config.timeout) as client:
            response = await client.request("GET", pull_url, headers=headers)

            # Raise an exception for HTTP errors
            response.raise_for_status()

            # Parse the response JSON
            try:
                result = response.json()
            except ValueError as e:
                raise ValueError("The API response is not valid JSON.") from e

        return json.dumps(result)

    yield FunctionInfo.from_fn(_github_get_pull,
                               description=(f"Fetches a particular GitHub pull request "
                                            f"in the repo named {config.repo_name}"),
                               input_schema=GithubGetPullModel)
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class GithubGetPullCommitsToolConfig(FunctionBaseConfig, name="github_get_pull_commits_tool"):
    """
    Configuration for the GitHub Get Pull Commits Tool.
    """
    # Required: there is no meaningful default repository. The previous default was the
    # description text itself, which produced a bogus request URL when omitted.
    repo_name: str = Field(description="The repository name in the format 'owner/repo'")
    # Forwarded to httpx.AsyncClient(timeout=...) for every request the tool makes.
    timeout: int = Field(default=300, description="The timeout configuration to use when sending requests.")
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
@register_function(config_type=GithubGetPullCommitsToolConfig)
async def get_github_pull_commits_async(config: GithubGetPullCommitsToolConfig, builder: Builder):
    """
    Fetches the commits associated with a particular pull request in a GitHub repository asynchronously.

    Args:
        config: Tool configuration providing ``repo_name`` and ``timeout``.
        builder: Workflow builder (not used by this tool).

    Raises:
        ValueError: If the ``GITHUB_PAT`` environment variable is not set.
    """
    import json
    import os
    from urllib.parse import quote

    import httpx

    github_pat = os.getenv("GITHUB_PAT")
    if not github_pat:
        raise ValueError("GITHUB_PAT environment variable must be set")

    url = f"https://api.github.com/repos/{config.repo_name}/pulls"

    # define the headers for the payload request
    headers = {"Authorization": f"Bearer {github_pat}", "Accept": "application/vnd.github+json"}

    async def _github_get_pull(pull_number) -> str:
        """
        Fetch the commits of one pull request.

        Returns the decoded JSON response re-serialized as a string.

        Raises:
            httpx.HTTPStatusError: If GitHub responds with an error status.
            ValueError: If the response body is not valid JSON.
        """
        # Join URL segments with "/" explicitly. os.path.join uses "\\" on Windows and
        # discards the base for a segment starting with "/". The model-supplied number
        # is escaped so it stays a single path segment.
        pull_commits_url = f"{url}/{quote(str(pull_number), safe='')}/commits"

        async with httpx.AsyncClient(timeout=config.timeout) as client:
            response = await client.request("GET", pull_commits_url, headers=headers)

            # Raise an exception for HTTP errors
            response.raise_for_status()

            # Parse the response JSON
            try:
                result = response.json()
            except ValueError as e:
                raise ValueError("The API response is not valid JSON.") from e

        return json.dumps(result)

    yield FunctionInfo.from_fn(_github_get_pull,
                               description=("Fetches the commits for a particular GitHub pull request "
                                            f" in the repo named {config.repo_name}"),
                               input_schema=GithubGetPullModel)
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
class GithubGetPullFilesToolConfig(FunctionBaseConfig, name="github_get_pull_files_tool"):
    """
    Configuration for the GitHub Get Pull Files Tool.
    """
    # Required: there is no meaningful default repository. The previous default was the
    # description text itself, which produced a bogus request URL when omitted.
    repo_name: str = Field(description="The repository name in the format 'owner/repo'")
    # Forwarded to httpx.AsyncClient(timeout=...) for every request the tool makes.
    timeout: int = Field(default=300, description="The timeout configuration to use when sending requests.")
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
@register_function(config_type=GithubGetPullFilesToolConfig)
async def get_github_pull_files_async(config: GithubGetPullFilesToolConfig, builder: Builder):
    """
    Fetches the files associated with a particular pull request in a GitHub repository asynchronously.

    Args:
        config: Tool configuration providing ``repo_name`` and ``timeout``.
        builder: Workflow builder (not used by this tool).

    Raises:
        ValueError: If the ``GITHUB_PAT`` environment variable is not set.
    """
    import json
    import os
    from urllib.parse import quote

    import httpx

    github_pat = os.getenv("GITHUB_PAT")
    if not github_pat:
        raise ValueError("GITHUB_PAT environment variable must be set")

    url = f"https://api.github.com/repos/{config.repo_name}/pulls"

    # define the headers for the payload request
    headers = {"Authorization": f"Bearer {github_pat}", "Accept": "application/vnd.github+json"}

    async def _github_get_pull(pull_number) -> str:
        """
        Fetch the changed files of one pull request.

        Returns the decoded JSON response re-serialized as a string.

        Raises:
            httpx.HTTPStatusError: If GitHub responds with an error status.
            ValueError: If the response body is not valid JSON.
        """
        # Join URL segments with "/" explicitly. os.path.join uses "\\" on Windows and
        # discards the base for a segment starting with "/". The model-supplied number
        # is escaped so it stays a single path segment.
        pull_files_url = f"{url}/{quote(str(pull_number), safe='')}/files"

        async with httpx.AsyncClient(timeout=config.timeout) as client:
            response = await client.request("GET", pull_files_url, headers=headers)

            # Raise an exception for HTTP errors
            response.raise_for_status()

            # Parse the response JSON
            try:
                result = response.json()
            except ValueError as e:
                raise ValueError("The API response is not valid JSON.") from e

        return json.dumps(result)

    yield FunctionInfo.from_fn(_github_get_pull,
                               description=("Fetches the files for a particular GitHub pull request "
                                            f" in the repo named {config.repo_name}"),
                               input_schema=GithubGetPullModel)
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from typing import Literal
|
|
17
|
+
|
|
18
|
+
from pydantic import BaseModel
|
|
19
|
+
from pydantic import Field
|
|
20
|
+
|
|
21
|
+
from aiq.builder.builder import Builder
|
|
22
|
+
from aiq.builder.function_info import FunctionInfo
|
|
23
|
+
from aiq.cli.register_workflow import register_function
|
|
24
|
+
from aiq.data_models.function import FunctionBaseConfig
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class GithubUpdateIssueModel(BaseModel):
    # One issue update. `issue_number` is required and selects the issue.
    # Every other field is optional; the update tool dumps with exclude_unset=True,
    # so only fields that were explicitly provided end up in the PATCH body.
    issue_number: str = Field(..., description="The issue number that will be updated")
    title: str | None = Field(default=None, description="The title of the GitHub Issue")
    body: str | None = Field(default=None, description="The body of the GitHub Issue")
    state: Literal["open", "closed"] | None = Field(default=None, description="The new state of the issue")

    state_reason: Literal["completed", "not_planned", "reopened", None] | None = Field(
        default=None, description="The reason for changing the state of the issue")

    labels: list[str] | None = Field(default=None, description="A list of labels to assign to the issue")
    assignees: list[str] | None = Field(default=None, description="A list of assignees to assign to the issue")
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class GithubUpdateIssueModelList(BaseModel):
    # Input schema of the update-issue tool: a batch of updates applied one after another.
    # The field name `issues` matches the parameter name of `_github_update_issue`.
    issues: list[GithubUpdateIssueModel] = Field(
        ..., description="A list of GitHub issues each of type GithubUpdateIssueModel")
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class GithubUpdateIssueToolConfig(FunctionBaseConfig, name="github_update_issue_tool"):
    """
    Tool that updates an issue in a GitHub repository asynchronously.
    """
    # Required: there is no meaningful default repository. The previous default was the
    # description text itself, which produced a bogus request URL when omitted.
    repo_name: str = Field(description="The repository name in the format 'owner/repo'")
    # Forwarded to httpx.AsyncClient(timeout=...) for every request the tool makes.
    timeout: int = Field(default=300, description="The timeout configuration to use when sending requests.")
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@register_function(config_type=GithubUpdateIssueToolConfig)
async def update_github_issue_async(config: GithubUpdateIssueToolConfig, builder: Builder):
    """
    Updates an issue in a GitHub repository asynchronously.

    Args:
        config: Tool configuration providing ``repo_name`` and ``timeout``.
        builder: Workflow builder (not used by this tool).

    Raises:
        ValueError: If the ``GITHUB_PAT`` environment variable is not set.
    """
    import json
    import os
    from urllib.parse import quote

    import httpx

    github_pat = os.getenv("GITHUB_PAT")
    if not github_pat:
        raise ValueError("GITHUB_PAT environment variable must be set")

    url = f"https://api.github.com/repos/{config.repo_name}/issues"

    # define the headers for the payload request
    headers = {"Authorization": f"Bearer {github_pat}", "Accept": "application/vnd.github+json"}

    async def _github_update_issue(issues) -> str:
        """
        PATCH each issue in ``issues``, in order, over a single HTTP client.

        Returns a JSON string holding the list of decoded responses.

        Raises:
            httpx.HTTPStatusError: On the first update GitHub rejects; later issues are not sent.
            ValueError: If a response body is not valid JSON.
        """
        results = []
        async with httpx.AsyncClient(timeout=config.timeout) as client:
            for issue in issues:
                # Only fields the caller explicitly set are included in the PATCH body.
                payload = issue.model_dump(exclude_unset=True)

                # The issue number addresses the resource and is not part of the body.
                # Build the URL with "/" explicitly (os.path.join uses "\\" on Windows) and
                # escape the model-supplied number so it stays a single path segment.
                issue_number = payload.pop("issue_number")
                issue_url = f"{url}/{quote(str(issue_number), safe='')}"

                response = await client.request("PATCH", issue_url, json=payload, headers=headers)

                # Raise an exception for HTTP errors
                response.raise_for_status()

                # Parse the response JSON
                try:
                    result = response.json()
                except ValueError as e:
                    raise ValueError("The API response is not valid JSON.") from e
                results.append(result)

        return json.dumps(results)

    yield FunctionInfo.from_fn(_github_update_issue,
                               description=(f"Updates a GitHub issue in the "
                                            f"repo named {config.repo_name}"),
                               input_schema=GithubUpdateIssueModelList)
|
aiq/tool/mcp/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
@@ -0,0 +1,220 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import logging
|
|
19
|
+
from contextlib import asynccontextmanager
|
|
20
|
+
from enum import Enum
|
|
21
|
+
from typing import Any
|
|
22
|
+
|
|
23
|
+
from mcp import ClientSession
|
|
24
|
+
from mcp.client.sse import sse_client
|
|
25
|
+
from mcp.types import TextContent
|
|
26
|
+
from pydantic import BaseModel
|
|
27
|
+
from pydantic import Field
|
|
28
|
+
from pydantic import create_model
|
|
29
|
+
|
|
30
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def model_from_mcp_schema(name: str, mcp_input_schema: dict) -> type[BaseModel]:
    """
    Create a pydantic model from the input schema of the MCP tool.

    Only the ``properties`` mapping of ``mcp_input_schema`` is read. Each property becomes a
    model field. Nested objects (with ``properties``) and arrays of objects become nested
    models built recursively; enums become ``Enum`` classes.

    Args:
        name: Name of the tool (or nested field). It is used to derive the model class name.
        mcp_input_schema: JSON-schema-style dict describing the tool input.

    Returns:
        A dynamically created pydantic model class named ``<Name>InputSchema``.
    """
    _type_map = {
        "string": str,
        "number": float,
        "integer": int,
        "boolean": bool,
        "array": list,
        # NoneType rather than None so that ``field_type | None`` below stays valid.
        "null": type(None),
        "object": dict,
    }

    properties = mcp_input_schema.get("properties", {})
    schema_dict = {}

    def _generate_valid_classname(class_name: str):
        # e.g. "my_tool-name" -> "MyToolName"
        return class_name.replace('_', ' ').replace('-', ' ').title().replace(' ', '')

    def _resolve_json_type(raw_type) -> tuple:
        # JSON Schema allows "type" to be a list such as ["string", "null"].
        # The presence of "null" marks the value nullable, and the first remaining
        # entry is used as the concrete type. Returns (json_type, is_nullable).
        if isinstance(raw_type, list):
            non_null = [t for t in raw_type if t != "null"]
            return (non_null[0] if non_null else "null"), len(non_null) != len(raw_type)
        return raw_type, False

    def _generate_field(field_name: str, field_properties: dict[str, Any]) -> tuple:
        # Returns the (annotation, FieldInfo) pair expected by pydantic.create_model.
        json_type, type_is_nullable = _resolve_json_type(field_properties.get("type", "string"))
        enum_vals = field_properties.get("enum")

        if enum_vals:
            enum_name = f"{field_name.capitalize()}Enum"
            # Enum member names must be strings; member values keep their original type.
            field_type = Enum(enum_name, {str(item): item for item in enum_vals})

        elif json_type == "object" and "properties" in field_properties:
            field_type = model_from_mcp_schema(name=field_name, mcp_input_schema=field_properties)
        elif json_type == "array" and "items" in field_properties:
            item_properties = field_properties.get("items", {})
            if not isinstance(item_properties, dict):
                item_properties = {}
            item_json_type, _ = _resolve_json_type(item_properties.get("type"))
            if item_json_type == "object":
                # The element model must come from the item schema, not the array schema.
                item_type = model_from_mcp_schema(name=field_name, mcp_input_schema=item_properties)
            else:
                # Look up the element's type, not the array's own "array" type.
                item_type = _type_map.get(item_json_type, Any)
            field_type = list[item_type]
        else:
            field_type = _type_map.get(json_type, Any)

        # Ellipsis marks the field as required when no default is given.
        default_value = field_properties.get("default", ...)
        nullable = field_properties.get("nullable", False) or type_is_nullable
        description = field_properties.get("description", "")

        field_type = field_type | None if nullable else field_type

        return field_type, Field(default=default_value, description=description)

    for field_name, field_props in properties.items():
        schema_dict[field_name] = _generate_field(field_name=field_name, field_properties=field_props)
    return create_model(f"{_generate_valid_classname(name)}InputSchema", **schema_dict)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class MCPSSEClient:
    """
    Minimal client that opens sessions with an MCP server over Server-Sent Events (SSE).

    Args:
        url (str): Address of the MCP server's SSE endpoint.
    """

    def __init__(self, url: str):
        self.url = url

    @asynccontextmanager
    async def connect_to_sse_server(self):
        """
        Async context manager yielding an initialized MCP ``ClientSession``.

        Opens the SSE transport for ``self.url``, starts a ``ClientSession`` on its
        read/write streams, and awaits ``initialize()`` before yielding the session.
        Both the session context and the transport context are exited when the
        ``async with`` block ends.
        """
        async with sse_client(url=self.url) as (read_stream, write_stream), \
                ClientSession(read_stream, write_stream) as mcp_session:
            await mcp_session.initialize()
            yield mcp_session
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class MCPBuilder(MCPSSEClient):
    """
    Builder class used to connect to an MCP server and generate MCPToolClient instances.

    Args:
        url (str): The url of the MCP server
    """

    def __init__(self, url):
        super().__init__(url)
        # Lazily populated cache of tool name -> MCPToolClient, filled by get_tool().
        # None means "not fetched yet"; an empty dict means "server serves no tools".
        self._tools = None

    async def get_tools(self):
        """
        Retrieve a dictionary of all tools served by the MCP server.

        Opens a new session, lists the server's tools, and wraps each one in an
        MCPToolClient that points at the same URL. This method does not use or
        update the cache used by ``get_tool``.

        Returns:
            dict[str, MCPToolClient]: Mapping of tool name to its client wrapper.
        """
        async with self.connect_to_sse_server() as session:
            response = await session.list_tools()

            return {
                tool.name: MCPToolClient(self.url, tool.name, tool.description, tool_input_schema=tool.inputSchema)
                for tool in response.tools
            }

    async def get_tool(self, tool_name: str) -> MCPToolClient:
        """
        Get an MCP Tool by name.

        The tool list is fetched from the server once and cached on this builder.

        Args:
            tool_name (str): Name of the tool to load.

        Returns:
            MCPToolClient for the configured tool.

        Raises:
            ValueError: If no tool is available with that name.
        """
        # Compare against None so an empty tool list is cached instead of being re-fetched on every call.
        if self._tools is None:
            self._tools = await self.get_tools()

        tool = self._tools.get(tool_name)
        if tool is None:
            raise ValueError(f"Tool {tool_name} not available at {self.url}")
        return tool

    async def call_tool(self, tool_name: str, tool_args: dict | None):
        """
        Call a tool on the MCP server by name, using a new session.

        Args:
            tool_name (str): Name of the tool to call.
            tool_args (dict | None): Arguments forwarded unchanged to ``session.call_tool``.

        Returns:
            The raw result object returned by ``session.call_tool``.
        """
        async with self.connect_to_sse_server() as session:
            result = await session.call_tool(tool_name, tool_args)
            return result
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class MCPToolClient(MCPSSEClient):
    """
    Client wrapper used to call an MCP tool.

    Args:
        url (str): The url of the MCP server
        tool_name (str): The name of the tool to wrap
        tool_description (str): The description of the tool provided by the MCP server.
        tool_input_schema (dict): The input schema for the tool.
    """

    def __init__(self, url: str, tool_name: str, tool_description: str | None, tool_input_schema: dict | None = None):
        super().__init__(url)
        self._tool_name = tool_name
        self._tool_description = tool_description
        # A pydantic model built from the server's JSON schema, or None when no (or an empty) schema was given.
        self._input_schema = model_from_mcp_schema(self._tool_name, tool_input_schema) if tool_input_schema else None

    @property
    def name(self):
        """Returns the name of the tool."""
        return self._tool_name

    @property
    def description(self):
        """
        Returns the tool's description. If none was provided, returns a simple description
        built from the tool's name.
        """
        if not self._tool_description:
            return f"MCP Tool {self._tool_name}"
        return self._tool_description

    @property
    def input_schema(self):
        """
        Returns the tool's input schema model, or None if the server provided no schema.
        """
        return self._input_schema

    def set_description(self, description: str):
        """
        Manually define the tool's description using the provided string.
        """
        self._tool_description = description

    async def acall(self, tool_args: dict) -> str:
        """
        Call the MCP tool with the provided arguments.

        Each call opens a new session with the server.

        Args:
            tool_args (dict[str, Any]): A dictionary of key value pairs to serve as inputs for the MCP tool.

        Returns:
            str: The text content items of the result, joined with newlines. Non-text
            content is logged and dropped.
        """
        async with self.connect_to_sse_server() as session:
            result = await session.call_tool(self._tool_name, tool_args)

        if getattr(result, "isError", False):
            # The server flagged a tool-level failure. Per the MCP spec, the error details are
            # carried in the result content, which is still returned so the caller can see them.
            logger.warning("MCP tool %s reported an error", self.name)

        output = []
        for res in result.content:
            if isinstance(res, TextContent):
                output.append(res.text)
            else:
                # Log non-text content for now
                logger.warning("Got not-text output from %s of type %s", self.name, type(res))
        return "\n".join(output)
|
aiq/tool/mcp/mcp_tool.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
|
|
18
|
+
from pydantic import BaseModel
|
|
19
|
+
from pydantic import Field
|
|
20
|
+
from pydantic import HttpUrl
|
|
21
|
+
|
|
22
|
+
from aiq.builder.builder import Builder
|
|
23
|
+
from aiq.builder.function_info import FunctionInfo
|
|
24
|
+
from aiq.cli.register_workflow import register_function
|
|
25
|
+
from aiq.data_models.function import FunctionBaseConfig
|
|
26
|
+
|
|
27
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class MCPToolConfig(FunctionBaseConfig, name="mcp_tool_wrapper"):
    """
    Function which connects to a Model Context Protocol (MCP) server and wraps the selected tool as an AgentIQ function.
    """
    url: HttpUrl = Field(description="The URL of the MCP server")
    mcp_tool_name: str = Field(description="The name of the tool served by the MCP Server that you want to use")
    # Implicit string concatenation keeps the help text free of embedded newlines and indentation.
    description: str | None = Field(default=None,
                                    description=("Description for the tool that will override the description "
                                                 "provided by the MCP server. Should only be used if the "
                                                 "description provided by the server is poor or nonexistent."))
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@register_function(config_type=MCPToolConfig)
async def mcp_tool(config: MCPToolConfig, builder: Builder):
    """
    Generate an AgentIQ Function that wraps a tool provided by the MCP server.

    Connects to ``config.url`` and looks up ``config.mcp_tool_name``. The lookup raises
    ValueError from ``MCPBuilder.get_tool`` if the server does not serve that tool. The
    tool's description is optionally overridden, and a FunctionInfo is yielded whose
    single function forwards calls to the remote tool.

    Args:
        config: Tool wrapper configuration.
        builder: Workflow builder (not used by this function).
    """

    # Imported inside the function, presumably so the MCP client dependency is only
    # loaded when this tool is actually built.
    from aiq.tool.mcp.mcp_client import MCPBuilder
    from aiq.tool.mcp.mcp_client import MCPToolClient

    client = MCPBuilder(url=str(config.url))

    tool: MCPToolClient = await client.get_tool(config.mcp_tool_name)
    if config.description:
        tool.set_description(description=config.description)

    logger.info("Configured to use tool: %s from MCP server at %s", tool.name, str(config.url))

    def _convert_from_str(input_str: str) -> tool.input_schema:
        # Parse a JSON string into the tool's input model.
        return tool.input_schema.model_validate_json(input_str)

    async def _response_fn(tool_input: BaseModel | None = None, **kwargs) -> str:
        # Callers may pass either a ready-made input model or raw keyword arguments.
        if tool_input is not None:
            args = tool_input.model_dump()
            return await tool.acall(args)

        # Validate raw keyword arguments locally when the server provided a schema, so
        # malformed calls fail before reaching the server. The original kwargs are sent.
        if tool.input_schema is not None:
            tool.input_schema.model_validate(kwargs)
        return await tool.acall(kwargs)

    yield FunctionInfo.create(single_fn=_response_fn,
                              description=tool.description,
                              input_schema=tool.input_schema,
                              converters=[_convert_from_str])
|
|
File without changes
|