aiqtoolkit 1.1.0a20250429__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of aiqtoolkit might be problematic. Click here for more details.

Files changed (309) hide show
  1. aiq/agent/__init__.py +0 -0
  2. aiq/agent/base.py +76 -0
  3. aiq/agent/dual_node.py +67 -0
  4. aiq/agent/react_agent/__init__.py +0 -0
  5. aiq/agent/react_agent/agent.py +322 -0
  6. aiq/agent/react_agent/output_parser.py +104 -0
  7. aiq/agent/react_agent/prompt.py +46 -0
  8. aiq/agent/react_agent/register.py +148 -0
  9. aiq/agent/reasoning_agent/__init__.py +0 -0
  10. aiq/agent/reasoning_agent/reasoning_agent.py +224 -0
  11. aiq/agent/register.py +23 -0
  12. aiq/agent/rewoo_agent/__init__.py +0 -0
  13. aiq/agent/rewoo_agent/agent.py +410 -0
  14. aiq/agent/rewoo_agent/prompt.py +108 -0
  15. aiq/agent/rewoo_agent/register.py +158 -0
  16. aiq/agent/tool_calling_agent/__init__.py +0 -0
  17. aiq/agent/tool_calling_agent/agent.py +123 -0
  18. aiq/agent/tool_calling_agent/register.py +105 -0
  19. aiq/builder/__init__.py +0 -0
  20. aiq/builder/builder.py +223 -0
  21. aiq/builder/component_utils.py +303 -0
  22. aiq/builder/context.py +198 -0
  23. aiq/builder/embedder.py +24 -0
  24. aiq/builder/eval_builder.py +116 -0
  25. aiq/builder/evaluator.py +29 -0
  26. aiq/builder/framework_enum.py +24 -0
  27. aiq/builder/front_end.py +73 -0
  28. aiq/builder/function.py +297 -0
  29. aiq/builder/function_base.py +372 -0
  30. aiq/builder/function_info.py +627 -0
  31. aiq/builder/intermediate_step_manager.py +125 -0
  32. aiq/builder/llm.py +25 -0
  33. aiq/builder/retriever.py +25 -0
  34. aiq/builder/user_interaction_manager.py +71 -0
  35. aiq/builder/workflow.py +134 -0
  36. aiq/builder/workflow_builder.py +733 -0
  37. aiq/cli/__init__.py +14 -0
  38. aiq/cli/cli_utils/__init__.py +0 -0
  39. aiq/cli/cli_utils/config_override.py +233 -0
  40. aiq/cli/cli_utils/validation.py +37 -0
  41. aiq/cli/commands/__init__.py +0 -0
  42. aiq/cli/commands/configure/__init__.py +0 -0
  43. aiq/cli/commands/configure/channel/__init__.py +0 -0
  44. aiq/cli/commands/configure/channel/add.py +28 -0
  45. aiq/cli/commands/configure/channel/channel.py +34 -0
  46. aiq/cli/commands/configure/channel/remove.py +30 -0
  47. aiq/cli/commands/configure/channel/update.py +30 -0
  48. aiq/cli/commands/configure/configure.py +33 -0
  49. aiq/cli/commands/evaluate.py +139 -0
  50. aiq/cli/commands/info/__init__.py +14 -0
  51. aiq/cli/commands/info/info.py +37 -0
  52. aiq/cli/commands/info/list_channels.py +32 -0
  53. aiq/cli/commands/info/list_components.py +129 -0
  54. aiq/cli/commands/registry/__init__.py +14 -0
  55. aiq/cli/commands/registry/publish.py +88 -0
  56. aiq/cli/commands/registry/pull.py +118 -0
  57. aiq/cli/commands/registry/registry.py +36 -0
  58. aiq/cli/commands/registry/remove.py +108 -0
  59. aiq/cli/commands/registry/search.py +155 -0
  60. aiq/cli/commands/start.py +250 -0
  61. aiq/cli/commands/uninstall.py +83 -0
  62. aiq/cli/commands/validate.py +47 -0
  63. aiq/cli/commands/workflow/__init__.py +14 -0
  64. aiq/cli/commands/workflow/templates/__init__.py.j2 +0 -0
  65. aiq/cli/commands/workflow/templates/config.yml.j2 +16 -0
  66. aiq/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
  67. aiq/cli/commands/workflow/templates/register.py.j2 +5 -0
  68. aiq/cli/commands/workflow/templates/workflow.py.j2 +36 -0
  69. aiq/cli/commands/workflow/workflow.py +37 -0
  70. aiq/cli/commands/workflow/workflow_commands.py +307 -0
  71. aiq/cli/entrypoint.py +133 -0
  72. aiq/cli/main.py +44 -0
  73. aiq/cli/register_workflow.py +408 -0
  74. aiq/cli/type_registry.py +869 -0
  75. aiq/data_models/__init__.py +14 -0
  76. aiq/data_models/api_server.py +550 -0
  77. aiq/data_models/common.py +143 -0
  78. aiq/data_models/component.py +46 -0
  79. aiq/data_models/component_ref.py +135 -0
  80. aiq/data_models/config.py +349 -0
  81. aiq/data_models/dataset_handler.py +122 -0
  82. aiq/data_models/discovery_metadata.py +269 -0
  83. aiq/data_models/embedder.py +26 -0
  84. aiq/data_models/evaluate.py +101 -0
  85. aiq/data_models/evaluator.py +26 -0
  86. aiq/data_models/front_end.py +26 -0
  87. aiq/data_models/function.py +30 -0
  88. aiq/data_models/function_dependencies.py +64 -0
  89. aiq/data_models/interactive.py +237 -0
  90. aiq/data_models/intermediate_step.py +269 -0
  91. aiq/data_models/invocation_node.py +38 -0
  92. aiq/data_models/llm.py +26 -0
  93. aiq/data_models/logging.py +26 -0
  94. aiq/data_models/memory.py +26 -0
  95. aiq/data_models/profiler.py +53 -0
  96. aiq/data_models/registry_handler.py +26 -0
  97. aiq/data_models/retriever.py +30 -0
  98. aiq/data_models/step_adaptor.py +64 -0
  99. aiq/data_models/streaming.py +33 -0
  100. aiq/data_models/swe_bench_model.py +54 -0
  101. aiq/data_models/telemetry_exporter.py +26 -0
  102. aiq/embedder/__init__.py +0 -0
  103. aiq/embedder/langchain_client.py +41 -0
  104. aiq/embedder/nim_embedder.py +58 -0
  105. aiq/embedder/openai_embedder.py +42 -0
  106. aiq/embedder/register.py +24 -0
  107. aiq/eval/__init__.py +14 -0
  108. aiq/eval/config.py +42 -0
  109. aiq/eval/dataset_handler/__init__.py +0 -0
  110. aiq/eval/dataset_handler/dataset_downloader.py +106 -0
  111. aiq/eval/dataset_handler/dataset_filter.py +52 -0
  112. aiq/eval/dataset_handler/dataset_handler.py +164 -0
  113. aiq/eval/evaluate.py +322 -0
  114. aiq/eval/evaluator/__init__.py +14 -0
  115. aiq/eval/evaluator/evaluator_model.py +44 -0
  116. aiq/eval/intermediate_step_adapter.py +93 -0
  117. aiq/eval/rag_evaluator/__init__.py +0 -0
  118. aiq/eval/rag_evaluator/evaluate.py +138 -0
  119. aiq/eval/rag_evaluator/register.py +138 -0
  120. aiq/eval/register.py +22 -0
  121. aiq/eval/remote_workflow.py +128 -0
  122. aiq/eval/runtime_event_subscriber.py +52 -0
  123. aiq/eval/swe_bench_evaluator/__init__.py +0 -0
  124. aiq/eval/swe_bench_evaluator/evaluate.py +215 -0
  125. aiq/eval/swe_bench_evaluator/register.py +36 -0
  126. aiq/eval/trajectory_evaluator/__init__.py +0 -0
  127. aiq/eval/trajectory_evaluator/evaluate.py +118 -0
  128. aiq/eval/trajectory_evaluator/register.py +40 -0
  129. aiq/eval/utils/__init__.py +0 -0
  130. aiq/eval/utils/output_uploader.py +131 -0
  131. aiq/eval/utils/tqdm_position_registry.py +40 -0
  132. aiq/front_ends/__init__.py +14 -0
  133. aiq/front_ends/console/__init__.py +14 -0
  134. aiq/front_ends/console/console_front_end_config.py +32 -0
  135. aiq/front_ends/console/console_front_end_plugin.py +107 -0
  136. aiq/front_ends/console/register.py +25 -0
  137. aiq/front_ends/cron/__init__.py +14 -0
  138. aiq/front_ends/fastapi/__init__.py +14 -0
  139. aiq/front_ends/fastapi/fastapi_front_end_config.py +150 -0
  140. aiq/front_ends/fastapi/fastapi_front_end_plugin.py +103 -0
  141. aiq/front_ends/fastapi/fastapi_front_end_plugin_worker.py +574 -0
  142. aiq/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
  143. aiq/front_ends/fastapi/job_store.py +161 -0
  144. aiq/front_ends/fastapi/main.py +70 -0
  145. aiq/front_ends/fastapi/message_handler.py +279 -0
  146. aiq/front_ends/fastapi/message_validator.py +345 -0
  147. aiq/front_ends/fastapi/register.py +25 -0
  148. aiq/front_ends/fastapi/response_helpers.py +181 -0
  149. aiq/front_ends/fastapi/step_adaptor.py +315 -0
  150. aiq/front_ends/fastapi/websocket.py +148 -0
  151. aiq/front_ends/mcp/__init__.py +14 -0
  152. aiq/front_ends/mcp/mcp_front_end_config.py +32 -0
  153. aiq/front_ends/mcp/mcp_front_end_plugin.py +93 -0
  154. aiq/front_ends/mcp/register.py +27 -0
  155. aiq/front_ends/mcp/tool_converter.py +242 -0
  156. aiq/front_ends/register.py +22 -0
  157. aiq/front_ends/simple_base/__init__.py +14 -0
  158. aiq/front_ends/simple_base/simple_front_end_plugin_base.py +52 -0
  159. aiq/llm/__init__.py +0 -0
  160. aiq/llm/nim_llm.py +45 -0
  161. aiq/llm/openai_llm.py +45 -0
  162. aiq/llm/register.py +22 -0
  163. aiq/llm/utils/__init__.py +14 -0
  164. aiq/llm/utils/env_config_value.py +94 -0
  165. aiq/llm/utils/error.py +17 -0
  166. aiq/memory/__init__.py +20 -0
  167. aiq/memory/interfaces.py +183 -0
  168. aiq/memory/models.py +102 -0
  169. aiq/meta/module_to_distro.json +3 -0
  170. aiq/meta/pypi.md +59 -0
  171. aiq/observability/__init__.py +0 -0
  172. aiq/observability/async_otel_listener.py +270 -0
  173. aiq/observability/register.py +97 -0
  174. aiq/plugins/.namespace +1 -0
  175. aiq/profiler/__init__.py +0 -0
  176. aiq/profiler/callbacks/__init__.py +0 -0
  177. aiq/profiler/callbacks/agno_callback_handler.py +295 -0
  178. aiq/profiler/callbacks/base_callback_class.py +20 -0
  179. aiq/profiler/callbacks/langchain_callback_handler.py +278 -0
  180. aiq/profiler/callbacks/llama_index_callback_handler.py +205 -0
  181. aiq/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
  182. aiq/profiler/callbacks/token_usage_base_model.py +27 -0
  183. aiq/profiler/data_frame_row.py +51 -0
  184. aiq/profiler/decorators/__init__.py +0 -0
  185. aiq/profiler/decorators/framework_wrapper.py +131 -0
  186. aiq/profiler/decorators/function_tracking.py +254 -0
  187. aiq/profiler/forecasting/__init__.py +0 -0
  188. aiq/profiler/forecasting/config.py +18 -0
  189. aiq/profiler/forecasting/model_trainer.py +75 -0
  190. aiq/profiler/forecasting/models/__init__.py +22 -0
  191. aiq/profiler/forecasting/models/forecasting_base_model.py +40 -0
  192. aiq/profiler/forecasting/models/linear_model.py +196 -0
  193. aiq/profiler/forecasting/models/random_forest_regressor.py +268 -0
  194. aiq/profiler/inference_metrics_model.py +25 -0
  195. aiq/profiler/inference_optimization/__init__.py +0 -0
  196. aiq/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
  197. aiq/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +452 -0
  198. aiq/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
  199. aiq/profiler/inference_optimization/data_models.py +386 -0
  200. aiq/profiler/inference_optimization/experimental/__init__.py +0 -0
  201. aiq/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
  202. aiq/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
  203. aiq/profiler/inference_optimization/llm_metrics.py +212 -0
  204. aiq/profiler/inference_optimization/prompt_caching.py +163 -0
  205. aiq/profiler/inference_optimization/token_uniqueness.py +107 -0
  206. aiq/profiler/inference_optimization/workflow_runtimes.py +72 -0
  207. aiq/profiler/intermediate_property_adapter.py +102 -0
  208. aiq/profiler/profile_runner.py +433 -0
  209. aiq/profiler/utils.py +184 -0
  210. aiq/registry_handlers/__init__.py +0 -0
  211. aiq/registry_handlers/local/__init__.py +0 -0
  212. aiq/registry_handlers/local/local_handler.py +176 -0
  213. aiq/registry_handlers/local/register_local.py +37 -0
  214. aiq/registry_handlers/metadata_factory.py +60 -0
  215. aiq/registry_handlers/package_utils.py +198 -0
  216. aiq/registry_handlers/pypi/__init__.py +0 -0
  217. aiq/registry_handlers/pypi/pypi_handler.py +251 -0
  218. aiq/registry_handlers/pypi/register_pypi.py +40 -0
  219. aiq/registry_handlers/register.py +21 -0
  220. aiq/registry_handlers/registry_handler_base.py +157 -0
  221. aiq/registry_handlers/rest/__init__.py +0 -0
  222. aiq/registry_handlers/rest/register_rest.py +56 -0
  223. aiq/registry_handlers/rest/rest_handler.py +237 -0
  224. aiq/registry_handlers/schemas/__init__.py +0 -0
  225. aiq/registry_handlers/schemas/headers.py +42 -0
  226. aiq/registry_handlers/schemas/package.py +68 -0
  227. aiq/registry_handlers/schemas/publish.py +63 -0
  228. aiq/registry_handlers/schemas/pull.py +81 -0
  229. aiq/registry_handlers/schemas/remove.py +36 -0
  230. aiq/registry_handlers/schemas/search.py +91 -0
  231. aiq/registry_handlers/schemas/status.py +47 -0
  232. aiq/retriever/__init__.py +0 -0
  233. aiq/retriever/interface.py +37 -0
  234. aiq/retriever/milvus/__init__.py +14 -0
  235. aiq/retriever/milvus/register.py +81 -0
  236. aiq/retriever/milvus/retriever.py +228 -0
  237. aiq/retriever/models.py +74 -0
  238. aiq/retriever/nemo_retriever/__init__.py +14 -0
  239. aiq/retriever/nemo_retriever/register.py +60 -0
  240. aiq/retriever/nemo_retriever/retriever.py +190 -0
  241. aiq/retriever/register.py +22 -0
  242. aiq/runtime/__init__.py +14 -0
  243. aiq/runtime/loader.py +188 -0
  244. aiq/runtime/runner.py +176 -0
  245. aiq/runtime/session.py +116 -0
  246. aiq/settings/__init__.py +0 -0
  247. aiq/settings/global_settings.py +318 -0
  248. aiq/test/.namespace +1 -0
  249. aiq/tool/__init__.py +0 -0
  250. aiq/tool/code_execution/__init__.py +0 -0
  251. aiq/tool/code_execution/code_sandbox.py +188 -0
  252. aiq/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
  253. aiq/tool/code_execution/local_sandbox/__init__.py +13 -0
  254. aiq/tool/code_execution/local_sandbox/local_sandbox_server.py +79 -0
  255. aiq/tool/code_execution/local_sandbox/sandbox.requirements.txt +4 -0
  256. aiq/tool/code_execution/local_sandbox/start_local_sandbox.sh +25 -0
  257. aiq/tool/code_execution/register.py +70 -0
  258. aiq/tool/code_execution/utils.py +100 -0
  259. aiq/tool/datetime_tools.py +42 -0
  260. aiq/tool/document_search.py +141 -0
  261. aiq/tool/github_tools/__init__.py +0 -0
  262. aiq/tool/github_tools/create_github_commit.py +133 -0
  263. aiq/tool/github_tools/create_github_issue.py +87 -0
  264. aiq/tool/github_tools/create_github_pr.py +106 -0
  265. aiq/tool/github_tools/get_github_file.py +106 -0
  266. aiq/tool/github_tools/get_github_issue.py +166 -0
  267. aiq/tool/github_tools/get_github_pr.py +256 -0
  268. aiq/tool/github_tools/update_github_issue.py +100 -0
  269. aiq/tool/mcp/__init__.py +14 -0
  270. aiq/tool/mcp/mcp_client.py +220 -0
  271. aiq/tool/mcp/mcp_tool.py +75 -0
  272. aiq/tool/memory_tools/__init__.py +0 -0
  273. aiq/tool/memory_tools/add_memory_tool.py +67 -0
  274. aiq/tool/memory_tools/delete_memory_tool.py +67 -0
  275. aiq/tool/memory_tools/get_memory_tool.py +72 -0
  276. aiq/tool/nvidia_rag.py +95 -0
  277. aiq/tool/register.py +36 -0
  278. aiq/tool/retriever.py +89 -0
  279. aiq/utils/__init__.py +0 -0
  280. aiq/utils/data_models/__init__.py +0 -0
  281. aiq/utils/data_models/schema_validator.py +58 -0
  282. aiq/utils/debugging_utils.py +43 -0
  283. aiq/utils/exception_handlers/__init__.py +0 -0
  284. aiq/utils/exception_handlers/schemas.py +114 -0
  285. aiq/utils/io/__init__.py +0 -0
  286. aiq/utils/io/yaml_tools.py +50 -0
  287. aiq/utils/metadata_utils.py +74 -0
  288. aiq/utils/producer_consumer_queue.py +178 -0
  289. aiq/utils/reactive/__init__.py +0 -0
  290. aiq/utils/reactive/base/__init__.py +0 -0
  291. aiq/utils/reactive/base/observable_base.py +65 -0
  292. aiq/utils/reactive/base/observer_base.py +55 -0
  293. aiq/utils/reactive/base/subject_base.py +79 -0
  294. aiq/utils/reactive/observable.py +59 -0
  295. aiq/utils/reactive/observer.py +76 -0
  296. aiq/utils/reactive/subject.py +131 -0
  297. aiq/utils/reactive/subscription.py +49 -0
  298. aiq/utils/settings/__init__.py +0 -0
  299. aiq/utils/settings/global_settings.py +197 -0
  300. aiq/utils/type_converter.py +232 -0
  301. aiq/utils/type_utils.py +397 -0
  302. aiq/utils/url_utils.py +27 -0
  303. aiqtoolkit-1.1.0a20250429.dist-info/METADATA +326 -0
  304. aiqtoolkit-1.1.0a20250429.dist-info/RECORD +309 -0
  305. aiqtoolkit-1.1.0a20250429.dist-info/WHEEL +5 -0
  306. aiqtoolkit-1.1.0a20250429.dist-info/entry_points.txt +17 -0
  307. aiqtoolkit-1.1.0a20250429.dist-info/licenses/LICENSE-3rd-party.txt +3686 -0
  308. aiqtoolkit-1.1.0a20250429.dist-info/licenses/LICENSE.md +201 -0
  309. aiqtoolkit-1.1.0a20250429.dist-info/top_level.txt +1 -0
@@ -0,0 +1,256 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Literal
17
+
18
+ from pydantic import BaseModel
19
+ from pydantic import Field
20
+
21
+ from aiq.builder.builder import Builder
22
+ from aiq.builder.function_info import FunctionInfo
23
+ from aiq.cli.register_workflow import register_function
24
+ from aiq.data_models.function import FunctionBaseConfig
25
+
26
+
27
class GithubListPullsModel(BaseModel):
    """Query-string filters for listing the pull requests of a repository."""
    # Descriptions are shown to the LLM, so they must describe pull requests (not issues).
    state: Literal["open", "closed", "all"] | None = Field(
        'open', description="Pull request state used in the pull request query filter")
    head: str | None = Field(None,
                             description=("Filters pulls by head user or head organization and branch name, "
                                          "in the format 'user:ref-name' or 'organization:ref-name'"))
    base: str | None = Field(None, description="Filters pulls by base branch name")
31
+
32
+
33
class GithubListPullsModelList(BaseModel):
    """Input schema for the list-pulls tool: one set of query filters."""
    # A single filter object, not a list; the description must match so the LLM builds the right shape.
    filter_params: GithubListPullsModel = Field(description=("The query params used when fetching pull requests, "
                                                             "of type GithubListPullsModel"))
36
+
37
+
38
class GithubListPullsToolConfig(FunctionBaseConfig, name="github_list_pulls_tool"):
    """
    Tool that lists GitHub Pull Requests based on various filter parameters
    """
    # Required: the repository whose pull requests are listed.
    repo_name: str = Field(..., description="The repository name in the format 'owner/repo'")
    # Handed to httpx.AsyncClient(timeout=...) when the tool sends its request.
    timeout: int = Field(300, description="The timeout configuration to use when sending requests.")
44
+
45
+
46
@register_function(config_type=GithubListPullsToolConfig)
async def list_github_pulls_async(config: GithubListPullsToolConfig, builder: Builder):
    """
    Lists GitHub Pull Requests based on various filter parameters.

    Reads a GitHub token from the ``GITHUB_PAT`` environment variable and yields a
    ``FunctionInfo`` whose coroutine queries ``GET /repos/{owner}/{repo}/pulls`` for
    ``config.repo_name``.

    Raises:
        ValueError: If ``GITHUB_PAT`` is not set.
    """
    import json
    import os

    import httpx

    github_pat = os.getenv("GITHUB_PAT")
    if not github_pat:
        raise ValueError("GITHUB_PAT environment variable must be set")

    url = f"https://api.github.com/repos/{config.repo_name}/pulls"

    # define the headers for the payload request
    headers = {"Authorization": f"Bearer {github_pat}", "Accept": "application/vnd.github+json"}

    async def _github_list_pulls(filter_params) -> str:
        """
        Fetch the pull requests matching ``filter_params`` and return them as a JSON string.

        Raises:
            httpx.HTTPStatusError: If GitHub responds with an error status.
            ValueError: If the response body is not valid JSON.
        """
        # Only forward fields the caller set; unset fields (including the model's 'open'
        # default for state) are left out of the query string.
        params = filter_params.model_dump(exclude_unset=True)

        # Filter out None values that are explicitly set in the request body.
        params = {k: v for k, v in params.items() if v is not None}

        async with httpx.AsyncClient(timeout=config.timeout) as client:
            response = await client.request("GET", url, params=params, headers=headers)

            # Raise an exception for HTTP errors
            response.raise_for_status()

            # Parse and return the response JSON
            try:
                result = response.json()
            except ValueError as e:
                raise ValueError("The API response is not valid JSON.") from e

        return json.dumps(result)

    yield FunctionInfo.from_fn(_github_list_pulls,
                               description=(f"Lists GitHub PRs based on filter params "
                                            f"in the repo named {config.repo_name}"),
                               input_schema=GithubListPullsModelList)
92
+
93
+
94
class GithubGetPullModel(BaseModel):
    """Identifies a single pull request by its number."""
    pull_number: str = Field(..., description="The number of the pull request that needs to be fetched")
96
+
97
+
98
class GithubGetPullToolConfig(FunctionBaseConfig, name="github_get_pull_tool"):
    """
    Tool that fetches a particular pull request in a GitHub repository asynchronously.
    """
    # The text belongs in the field description, not the default value: with no sensible
    # fallback repository, repo_name is required.
    repo_name: str = Field(description="The repository name in the format 'owner/repo'")
    # Forwarded to httpx.AsyncClient(timeout=...) for every request this tool sends.
    timeout: int = Field(default=300, description="The timeout configuration to use when sending requests.")
104
+
105
+
106
@register_function(config_type=GithubGetPullToolConfig)
async def get_github_pull_async(config: GithubGetPullToolConfig, builder: Builder):
    """
    Fetches a particular pull request in a GitHub repository asynchronously.

    Reads a GitHub token from the ``GITHUB_PAT`` environment variable and yields a
    ``FunctionInfo`` whose coroutine calls ``GET /repos/{owner}/{repo}/pulls/{pull_number}``
    for ``config.repo_name``.

    Raises:
        ValueError: If ``GITHUB_PAT`` is not set.
    """
    import json
    import os

    import httpx

    github_pat = os.getenv("GITHUB_PAT")
    if not github_pat:
        raise ValueError("GITHUB_PAT environment variable must be set")

    url = f"https://api.github.com/repos/{config.repo_name}/pulls"

    # define the headers for the payload request
    headers = {"Authorization": f"Bearer {github_pat}", "Accept": "application/vnd.github+json"}

    async def _github_get_pull(pull_number) -> str:
        """
        Fetch one pull request and return the GitHub response as a JSON string.

        Raises:
            httpx.HTTPStatusError: If GitHub responds with an error status.
            ValueError: If the response body is not valid JSON.
        """
        # URLs always use "/": os.path.join would insert "\\" on Windows and rejects ints.
        pull_url = f"{url}/{pull_number}"

        async with httpx.AsyncClient(timeout=config.timeout) as client:
            response = await client.request("GET", pull_url, headers=headers)

            # Raise an exception for HTTP errors
            response.raise_for_status()

            # Parse and return the response JSON
            try:
                result = response.json()
            except ValueError as e:
                raise ValueError("The API response is not valid JSON.") from e

        return json.dumps(result)

    yield FunctionInfo.from_fn(_github_get_pull,
                               description=(f"Fetches a particular GitHub pull request "
                                            f"in the repo named {config.repo_name}"),
                               input_schema=GithubGetPullModel)
149
+
150
+
151
class GithubGetPullCommitsToolConfig(FunctionBaseConfig, name="github_get_pull_commits_tool"):
    """
    Configuration for the GitHub Get Pull Commits Tool.
    """
    # The text belongs in the field description, not the default value: with no sensible
    # fallback repository, repo_name is required.
    repo_name: str = Field(description="The repository name in the format 'owner/repo'")
    # Forwarded to httpx.AsyncClient(timeout=...) for every request this tool sends.
    timeout: int = Field(default=300, description="The timeout configuration to use when sending requests.")
157
+
158
+
159
@register_function(config_type=GithubGetPullCommitsToolConfig)
async def get_github_pull_commits_async(config: GithubGetPullCommitsToolConfig, builder: Builder):
    """
    Fetches the commits associated with a particular pull request in a GitHub repository asynchronously.

    Reads a GitHub token from the ``GITHUB_PAT`` environment variable and yields a
    ``FunctionInfo`` whose coroutine calls
    ``GET /repos/{owner}/{repo}/pulls/{pull_number}/commits`` for ``config.repo_name``.

    Raises:
        ValueError: If ``GITHUB_PAT`` is not set.
    """
    import json
    import os

    import httpx

    github_pat = os.getenv("GITHUB_PAT")
    if not github_pat:
        raise ValueError("GITHUB_PAT environment variable must be set")

    url = f"https://api.github.com/repos/{config.repo_name}/pulls"

    # define the headers for the payload request
    headers = {"Authorization": f"Bearer {github_pat}", "Accept": "application/vnd.github+json"}

    async def _github_get_pull_commits(pull_number) -> str:
        """
        Fetch the commits of one pull request and return the GitHub response as a JSON string.

        Raises:
            httpx.HTTPStatusError: If GitHub responds with an error status.
            ValueError: If the response body is not valid JSON.
        """
        # URLs always use "/": os.path.join would insert "\\" on Windows and rejects ints.
        pull_commits_url = f"{url}/{pull_number}/commits"

        async with httpx.AsyncClient(timeout=config.timeout) as client:
            response = await client.request("GET", pull_commits_url, headers=headers)

            # Raise an exception for HTTP errors
            response.raise_for_status()

            # Parse and return the response JSON
            try:
                result = response.json()
            except ValueError as e:
                raise ValueError("The API response is not valid JSON.") from e

        return json.dumps(result)

    yield FunctionInfo.from_fn(_github_get_pull_commits,
                               description=("Fetches the commits for a particular GitHub pull request "
                                            f"in the repo named {config.repo_name}"),
                               input_schema=GithubGetPullModel)
203
+
204
+
205
class GithubGetPullFilesToolConfig(FunctionBaseConfig, name="github_get_pull_files_tool"):
    """
    Configuration for the GitHub Get Pull Files Tool.
    """
    # The text belongs in the field description, not the default value: with no sensible
    # fallback repository, repo_name is required.
    repo_name: str = Field(description="The repository name in the format 'owner/repo'")
    # Forwarded to httpx.AsyncClient(timeout=...) for every request this tool sends.
    timeout: int = Field(default=300, description="The timeout configuration to use when sending requests.")
211
+
212
+
213
@register_function(config_type=GithubGetPullFilesToolConfig)
async def get_github_pull_files_async(config: GithubGetPullFilesToolConfig, builder: Builder):
    """
    Fetches the files associated with a particular pull request in a GitHub repository asynchronously.

    Reads a GitHub token from the ``GITHUB_PAT`` environment variable and yields a
    ``FunctionInfo`` whose coroutine calls
    ``GET /repos/{owner}/{repo}/pulls/{pull_number}/files`` for ``config.repo_name``.

    Raises:
        ValueError: If ``GITHUB_PAT`` is not set.
    """
    import json
    import os

    import httpx

    github_pat = os.getenv("GITHUB_PAT")
    if not github_pat:
        raise ValueError("GITHUB_PAT environment variable must be set")

    url = f"https://api.github.com/repos/{config.repo_name}/pulls"

    # define the headers for the payload request
    headers = {"Authorization": f"Bearer {github_pat}", "Accept": "application/vnd.github+json"}

    async def _github_get_pull_files(pull_number) -> str:
        """
        Fetch the changed files of one pull request and return the GitHub response as a JSON string.

        Raises:
            httpx.HTTPStatusError: If GitHub responds with an error status.
            ValueError: If the response body is not valid JSON.
        """
        # URLs always use "/": os.path.join would insert "\\" on Windows and rejects ints.
        pull_files_url = f"{url}/{pull_number}/files"

        async with httpx.AsyncClient(timeout=config.timeout) as client:
            response = await client.request("GET", pull_files_url, headers=headers)

            # Raise an exception for HTTP errors
            response.raise_for_status()

            # Parse and return the response JSON
            try:
                result = response.json()
            except ValueError as e:
                raise ValueError("The API response is not valid JSON.") from e

        return json.dumps(result)

    yield FunctionInfo.from_fn(_github_get_pull_files,
                               description=("Fetches the files for a particular GitHub pull request "
                                            f"in the repo named {config.repo_name}"),
                               input_schema=GithubGetPullModel)
@@ -0,0 +1,100 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Literal
17
+
18
+ from pydantic import BaseModel
19
+ from pydantic import Field
20
+
21
+ from aiq.builder.builder import Builder
22
+ from aiq.builder.function_info import FunctionInfo
23
+ from aiq.cli.register_workflow import register_function
24
+ from aiq.data_models.function import FunctionBaseConfig
25
+
26
+
27
class GithubUpdateIssueModel(BaseModel):
    """
    A single issue update: ``issue_number`` selects the issue and the remaining
    fields carry the new values.
    """
    issue_number: str = Field(description="The issue number that will be updated")
    title: str | None = Field(None, description="The title of the GitHub Issue")
    body: str | None = Field(None, description="The body of the GitHub Issue")
    state: Literal["open", "closed"] | None = Field(None, description="The new state of the issue")

    # `| None` already permits null; listing None inside the Literal as well was redundant.
    state_reason: Literal["completed", "not_planned", "reopened"] | None = Field(
        None, description="The reason for changing the state of the issue")

    labels: list[str] | None = Field(None, description="A list of labels to assign to the issue")
    assignees: list[str] | None = Field(None, description="A list of assignees to assign to the issue")
38
+
39
+
40
class GithubUpdateIssueModelList(BaseModel):
    """Input schema for the update-issue tool: a batch of issue updates."""
    issues: list[GithubUpdateIssueModel] = Field(
        ...,
        description="A list of GitHub issues each of type GithubUpdateIssueModel",
    )
43
+
44
+
45
class GithubUpdateIssueToolConfig(FunctionBaseConfig, name="github_update_issue_tool"):
    """
    Tool that updates an issue in a GitHub repository asynchronously.
    """
    # The text belongs in the field description, not the default value: with no sensible
    # fallback repository, repo_name is required.
    repo_name: str = Field(description="The repository name in the format 'owner/repo'")
    # Forwarded to httpx.AsyncClient(timeout=...) for every request this tool sends.
    timeout: int = Field(default=300, description="The timeout configuration to use when sending requests.")
51
+
52
+
53
@register_function(config_type=GithubUpdateIssueToolConfig)
async def update_github_issue_async(config: GithubUpdateIssueToolConfig, builder: Builder):
    """
    Updates an issue in a GitHub repository asynchronously.

    Reads a GitHub token from the ``GITHUB_PAT`` environment variable and yields a
    ``FunctionInfo`` whose coroutine sends one ``PATCH /repos/{owner}/{repo}/issues/{number}``
    per requested update for ``config.repo_name``.

    Raises:
        ValueError: If ``GITHUB_PAT`` is not set.
    """
    import json
    import os

    import httpx

    github_pat = os.getenv("GITHUB_PAT")
    if not github_pat:
        raise ValueError("GITHUB_PAT environment variable must be set")

    url = f"https://api.github.com/repos/{config.repo_name}/issues"

    # define the headers for the payload request
    headers = {"Authorization": f"Bearer {github_pat}", "Accept": "application/vnd.github+json"}

    async def _github_update_issue(issues) -> str:
        """
        Apply each issue update in order and return the GitHub responses as a JSON list string.

        Updates are sent sequentially. If one fails, the exception propagates and
        the updates already applied are not rolled back.

        Raises:
            httpx.HTTPStatusError: If GitHub responds with an error status.
            ValueError: If a response body is not valid JSON.
        """
        results = []
        async with httpx.AsyncClient(timeout=config.timeout) as client:
            for issue in issues:
                # Only send fields the caller explicitly set, so untouched attributes keep their values.
                payload = issue.model_dump(exclude_unset=True)

                # update the url with the issue number that needs to be updated
                # ("/" explicitly: os.path.join would insert "\\" on Windows)
                issue_number = payload.pop("issue_number")
                issue_url = f"{url}/{issue_number}"

                response = await client.request("PATCH", issue_url, json=payload, headers=headers)

                # Raise an exception for HTTP errors
                response.raise_for_status()

                # Parse the response JSON
                try:
                    result = response.json()
                except ValueError as e:
                    raise ValueError("The API response is not valid JSON.") from e

                results.append(result)

        return json.dumps(results)

    yield FunctionInfo.from_fn(_github_update_issue,
                               description=(f"Updates a GitHub issue in the "
                                            f"repo named {config.repo_name}"),
                               input_schema=GithubUpdateIssueModelList)
@@ -0,0 +1,14 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
@@ -0,0 +1,220 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ import logging
19
+ from contextlib import asynccontextmanager
20
+ from enum import Enum
21
+ from typing import Any
22
+
23
+ from mcp import ClientSession
24
+ from mcp.client.sse import sse_client
25
+ from mcp.types import TextContent
26
+ from pydantic import BaseModel
27
+ from pydantic import Field
28
+ from pydantic import create_model
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
def model_from_mcp_schema(name: str, mcp_input_schema: dict) -> type[BaseModel]:
    """
    Create a pydantic model from the input schema of the MCP tool.

    Each entry of ``mcp_input_schema["properties"]`` becomes one model field. Per property the
    following keys are read: ``type`` (defaults to ``"string"``), ``enum``, ``properties`` (for
    nested objects), ``items`` (for arrays), ``default`` (a missing default makes the field
    required), ``nullable`` and ``description``.

    NOTE(review): the schema's top-level ``required`` list is not consulted; requiredness is
    derived only from the presence of ``default``.

    Args:
        name (str): Name used to derive the generated class name (``<Name>InputSchema``).
        mcp_input_schema (dict): The JSON-schema style input description served by the MCP server.

    Returns:
        type[BaseModel]: The dynamically created pydantic model class.
    """
    _type_map = {
        "string": str,
        "number": float,
        "integer": int,
        "boolean": bool,
        "array": list,
        "null": None,
        "object": dict,
    }

    properties = mcp_input_schema.get("properties", {})
    schema_dict = {}

    def _generate_valid_classname(class_name: str):
        # "my_tool-name" -> "MyToolName"
        return class_name.replace('_', ' ').replace('-', ' ').title().replace(' ', '')

    def _generate_field(field_name: str, field_properties: dict[str, Any]) -> tuple:
        json_type = field_properties.get("type", "string")
        enum_vals = field_properties.get("enum")

        if enum_vals:
            enum_name = f"{field_name.capitalize()}Enum"
            # The Enum functional API requires string member names; the values keep their
            # original JSON type so integer enums no longer crash.
            field_type = Enum(enum_name, {str(item): item for item in enum_vals})

        elif json_type == "object" and "properties" in field_properties:
            field_type = model_from_mcp_schema(name=field_name, mcp_input_schema=field_properties)
        elif json_type == "array" and "items" in field_properties:
            item_properties = field_properties.get("items", {})
            item_json_type = item_properties.get("type")
            if item_json_type == "object":
                # Element model comes from the *items* schema, not from the array schema itself.
                item_type = model_from_mcp_schema(name=field_name, mcp_input_schema=item_properties)
            else:
                # Element type comes from the items' own "type"; an untyped item is unconstrained.
                item_type = _type_map.get(item_json_type, Any)
            field_type = list[item_type]
        else:
            field_type = _type_map.get(json_type, Any)

        # Ellipsis marks the field as required for pydantic.
        default_value = field_properties.get("default", ...)
        nullable = field_properties.get("nullable", False)
        description = field_properties.get("description", "")

        field_type = field_type | None if nullable else field_type

        return field_type, Field(default=default_value, description=description)

    for field_name, field_props in properties.items():
        schema_dict[field_name] = _generate_field(field_name=field_name, field_properties=field_props)
    return create_model(f"{_generate_valid_classname(name)}InputSchema", **schema_dict)
84
+
85
+
86
class MCPSSEClient:
    """
    Minimal client that opens Server-Sent-Events backed sessions against an MCP server.

    Args:
        url (str): The url of the MCP server
    """

    def __init__(self, url: str):
        self.url = url

    @asynccontextmanager
    async def connect_to_sse_server(self):
        """
        Open an SSE transport to ``self.url``, start an MCP ``ClientSession`` on top of it,
        perform the ``initialize`` handshake and yield the ready session.

        The session and the transport are both closed when the caller's ``async with`` exits.
        """
        async with sse_client(url=self.url) as streams:
            read_stream, write_stream = streams
            async with ClientSession(read_stream, write_stream) as session:
                await session.initialize()
                yield session
106
+
107
+
108
class MCPBuilder(MCPSSEClient):
    """
    Builder class used to connect to an MCP Server and generate ToolClients

    Args:
        url (str): The url of the MCP server
    """

    def __init__(self, url):
        super().__init__(url)
        # Cache of tool name -> MCPToolClient, filled lazily by the first get_tool() call.
        # None means "not fetched yet"; an empty dict is a valid (fetched) result.
        self._tools = None

    async def get_tools(self):
        """
        Retrieve a dictionary of all tools served by the MCP server.

        Always queries the server (the cache used by ``get_tool`` is not consulted or updated).

        Returns:
            dict[str, MCPToolClient]: Tool clients keyed by tool name.
        """
        async with self.connect_to_sse_server() as session:
            response = await session.list_tools()

        return {
            tool.name: MCPToolClient(self.url, tool.name, tool.description, tool_input_schema=tool.inputSchema)
            for tool in response.tools
        }

    async def get_tool(self, tool_name: str) -> MCPToolClient:
        """
        Get an MCP Tool by name.

        Args:
            tool_name (str): Name of the tool to load.

        Returns:
            MCPToolClient for the configured tool.

        Raise:
            ValueError if no tool is available with that name.
        """
        # Compare against None so a server exposing zero tools is not re-queried on every call.
        if self._tools is None:
            self._tools = await self.get_tools()

        tool = self._tools.get(tool_name)
        if tool is None:
            raise ValueError(f"Tool {tool_name} not available at {self.url}")
        return tool

    async def call_tool(self, tool_name: str, tool_args: dict | None):
        """
        Call a tool on the MCP server in a fresh session and return the raw result.

        Args:
            tool_name (str): Name of the tool to invoke.
            tool_args (dict | None): Arguments forwarded unchanged to ``session.call_tool``.

        Returns:
            The result object returned by ``ClientSession.call_tool``.
        """
        async with self.connect_to_sse_server() as session:
            result = await session.call_tool(tool_name, tool_args)
        return result
157
+
158
+
159
class MCPToolClient(MCPSSEClient):
    """
    Wrapper around a single tool exposed by an MCP server.

    Args:
        url (str): The url of the MCP server
        tool_name (str): The name of the tool to wrap
        tool_description (str): The description of the tool provided by the MCP server.
        tool_input_schema (dict): The input schema for the tool.
    """

    def __init__(self, url: str, tool_name: str, tool_description: str | None, tool_input_schema: dict | None = None):
        super().__init__(url)
        self._tool_name = tool_name
        self._tool_description = tool_description
        # Only build a pydantic model when the server supplied a non-empty schema.
        if tool_input_schema:
            self._input_schema = model_from_mcp_schema(self._tool_name, tool_input_schema)
        else:
            self._input_schema = None

    @property
    def name(self):
        """The name of the wrapped tool."""
        return self._tool_name

    @property
    def description(self):
        """
        The tool's description, or a generic ``"MCP Tool <name>"`` fallback when none (or an empty
        one) was provided.
        """
        return self._tool_description or f"MCP Tool {self._tool_name}"

    @property
    def input_schema(self):
        """The pydantic model generated from the tool's input schema, or ``None`` if there was none."""
        return self._input_schema

    def set_description(self, description: str):
        """Override the tool's description with the given string."""
        self._tool_description = description

    async def acall(self, tool_args: dict) -> str:
        """
        Invoke the MCP tool in a fresh session and return its text output.

        Args:
            tool_args (dict[str, Any]): Key/value inputs passed to the MCP tool.

        Returns:
            str: All text content items joined with newlines. Non-text items are skipped.
        """
        async with self.connect_to_sse_server() as session:
            call_result = await session.call_tool(self._tool_name, tool_args)

        text_parts = []
        for item in call_result.content:
            if not isinstance(item, TextContent):
                # Non-text content is not forwarded; it is only logged for now.
                logger.warning("Got not-text output from %s of type %s", self.name, type(item))
                continue
            text_parts.append(item.text)
        return "\n".join(text_parts)
@@ -0,0 +1,75 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+
18
+ from pydantic import BaseModel
19
+ from pydantic import Field
20
+ from pydantic import HttpUrl
21
+
22
+ from aiq.builder.builder import Builder
23
+ from aiq.builder.function_info import FunctionInfo
24
+ from aiq.cli.register_workflow import register_function
25
+ from aiq.data_models.function import FunctionBaseConfig
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
class MCPToolConfig(FunctionBaseConfig, name="mcp_tool_wrapper"):
    """
    Function which connects to a Model Context Protocol (MCP) server and wraps the selected tool as an AgentIQ function.
    """
    # Address of the MCP server; validated as an HTTP(S) URL by pydantic.
    url: HttpUrl = Field(description="The URL of the MCP server")
    # Name of the remote tool to wrap; must match a tool listed by the server.
    mcp_tool_name: str = Field(description="The name of the tool served by the MCP Server that you want to use")
    # Optional override for the server-provided tool description.
    # Built from adjacent string literals so no stray newlines/indentation leak into the schema text.
    description: str | None = Field(default=None,
                                    description=("Description for the tool that will override the description "
                                                 "provided by the MCP server. Should only be used if the description "
                                                 "provided by the server is poor or nonexistent"))
42
+
43
+
44
@register_function(config_type=MCPToolConfig)
async def mcp_tool(config: MCPToolConfig, builder: Builder):
    """
    Generate an AgentIQ Function that wraps a tool provided by the MCP server.

    Looks up ``config.mcp_tool_name`` on the server at ``config.url``, applies the optional
    description override, and yields a ``FunctionInfo`` whose function forwards calls to the
    remote tool. ``builder`` is not used by this function.
    """
    # Imported lazily, inside the registered function.
    from aiq.tool.mcp.mcp_client import MCPBuilder
    from aiq.tool.mcp.mcp_client import MCPToolClient

    server_url = str(config.url)
    mcp_builder = MCPBuilder(url=server_url)

    tool: MCPToolClient = await mcp_builder.get_tool(config.mcp_tool_name)
    if config.description:
        tool.set_description(description=config.description)

    logger.info("Configured to use tool: %s from MCP server at %s", tool.name, server_url)

    def _convert_from_str(input_str: str) -> tool.input_schema:
        # Parse a raw JSON string into the tool's generated input model.
        return tool.input_schema.model_validate_json(input_str)

    async def _response_fn(tool_input: BaseModel | None = None, **kwargs) -> str:
        if not tool_input:
            # Called with keyword arguments rather than a model: validate them against the
            # schema (pydantic raises on invalid input), then forward the raw kwargs unchanged.
            tool.input_schema.model_validate(kwargs)
            return await tool.acall(kwargs)
        return await tool.acall(tool_input.model_dump())

    yield FunctionInfo.create(single_fn=_response_fn,
                              description=tool.description,
                              input_schema=tool.input_schema,
                              converters=[_convert_from_str])
File without changes