nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.4.0a20251112__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. nat/agent/base.py +13 -8
  2. nat/agent/prompt_optimizer/prompt.py +68 -0
  3. nat/agent/prompt_optimizer/register.py +149 -0
  4. nat/agent/react_agent/agent.py +6 -5
  5. nat/agent/react_agent/register.py +49 -39
  6. nat/agent/reasoning_agent/reasoning_agent.py +17 -15
  7. nat/agent/register.py +2 -0
  8. nat/agent/responses_api_agent/__init__.py +14 -0
  9. nat/agent/responses_api_agent/register.py +126 -0
  10. nat/agent/rewoo_agent/agent.py +304 -117
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +51 -38
  13. nat/agent/tool_calling_agent/agent.py +75 -17
  14. nat/agent/tool_calling_agent/register.py +46 -23
  15. nat/authentication/api_key/api_key_auth_provider.py +6 -11
  16. nat/authentication/api_key/api_key_auth_provider_config.py +8 -5
  17. nat/authentication/credential_validator/__init__.py +14 -0
  18. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  19. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  20. nat/authentication/interfaces.py +5 -2
  21. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
  22. nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +2 -1
  23. nat/authentication/oauth2/oauth2_resource_server_config.py +125 -0
  24. nat/builder/builder.py +55 -23
  25. nat/builder/component_utils.py +9 -5
  26. nat/builder/context.py +54 -15
  27. nat/builder/eval_builder.py +14 -9
  28. nat/builder/framework_enum.py +1 -0
  29. nat/builder/front_end.py +1 -1
  30. nat/builder/function.py +370 -0
  31. nat/builder/function_info.py +1 -1
  32. nat/builder/intermediate_step_manager.py +38 -2
  33. nat/builder/workflow.py +5 -0
  34. nat/builder/workflow_builder.py +306 -54
  35. nat/cli/cli_utils/config_override.py +1 -1
  36. nat/cli/commands/info/info.py +16 -6
  37. nat/cli/commands/mcp/__init__.py +14 -0
  38. nat/cli/commands/mcp/mcp.py +986 -0
  39. nat/cli/commands/optimize.py +90 -0
  40. nat/cli/commands/start.py +1 -1
  41. nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
  42. nat/cli/commands/workflow/templates/register.py.j2 +2 -2
  43. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  44. nat/cli/commands/workflow/workflow_commands.py +60 -18
  45. nat/cli/entrypoint.py +15 -11
  46. nat/cli/main.py +3 -0
  47. nat/cli/register_workflow.py +38 -4
  48. nat/cli/type_registry.py +72 -1
  49. nat/control_flow/__init__.py +0 -0
  50. nat/control_flow/register.py +20 -0
  51. nat/control_flow/router_agent/__init__.py +0 -0
  52. nat/control_flow/router_agent/agent.py +329 -0
  53. nat/control_flow/router_agent/prompt.py +48 -0
  54. nat/control_flow/router_agent/register.py +91 -0
  55. nat/control_flow/sequential_executor.py +166 -0
  56. nat/data_models/agent.py +34 -0
  57. nat/data_models/api_server.py +199 -69
  58. nat/data_models/authentication.py +23 -9
  59. nat/data_models/common.py +47 -0
  60. nat/data_models/component.py +2 -0
  61. nat/data_models/component_ref.py +11 -0
  62. nat/data_models/config.py +41 -17
  63. nat/data_models/dataset_handler.py +4 -3
  64. nat/data_models/function.py +34 -0
  65. nat/data_models/function_dependencies.py +8 -0
  66. nat/data_models/intermediate_step.py +9 -1
  67. nat/data_models/llm.py +15 -1
  68. nat/data_models/openai_mcp.py +46 -0
  69. nat/data_models/optimizable.py +208 -0
  70. nat/data_models/optimizer.py +161 -0
  71. nat/data_models/span.py +41 -3
  72. nat/data_models/thinking_mixin.py +2 -2
  73. nat/embedder/azure_openai_embedder.py +2 -1
  74. nat/embedder/nim_embedder.py +3 -2
  75. nat/embedder/openai_embedder.py +3 -2
  76. nat/eval/config.py +1 -1
  77. nat/eval/dataset_handler/dataset_downloader.py +3 -2
  78. nat/eval/dataset_handler/dataset_filter.py +34 -2
  79. nat/eval/evaluate.py +10 -3
  80. nat/eval/evaluator/base_evaluator.py +1 -1
  81. nat/eval/rag_evaluator/evaluate.py +7 -4
  82. nat/eval/register.py +4 -0
  83. nat/eval/runtime_evaluator/__init__.py +14 -0
  84. nat/eval/runtime_evaluator/evaluate.py +123 -0
  85. nat/eval/runtime_evaluator/register.py +100 -0
  86. nat/eval/swe_bench_evaluator/evaluate.py +1 -1
  87. nat/eval/trajectory_evaluator/register.py +1 -1
  88. nat/eval/tunable_rag_evaluator/evaluate.py +1 -1
  89. nat/eval/usage_stats.py +2 -0
  90. nat/eval/utils/output_uploader.py +3 -2
  91. nat/eval/utils/weave_eval.py +17 -3
  92. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  93. nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
  94. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  95. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
  96. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
  97. nat/experimental/test_time_compute/models/strategy_base.py +2 -2
  98. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
  99. nat/front_ends/console/authentication_flow_handler.py +82 -30
  100. nat/front_ends/console/console_front_end_plugin.py +19 -7
  101. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
  102. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  103. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  104. nat/front_ends/fastapi/fastapi_front_end_config.py +25 -3
  105. nat/front_ends/fastapi/fastapi_front_end_plugin.py +140 -3
  106. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +445 -265
  107. nat/front_ends/fastapi/job_store.py +518 -99
  108. nat/front_ends/fastapi/main.py +11 -19
  109. nat/front_ends/fastapi/message_handler.py +69 -44
  110. nat/front_ends/fastapi/message_validator.py +8 -7
  111. nat/front_ends/fastapi/utils.py +57 -0
  112. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  113. nat/front_ends/mcp/mcp_front_end_config.py +71 -3
  114. nat/front_ends/mcp/mcp_front_end_plugin.py +85 -21
  115. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +248 -29
  116. nat/front_ends/mcp/memory_profiler.py +320 -0
  117. nat/front_ends/mcp/tool_converter.py +78 -25
  118. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  119. nat/llm/aws_bedrock_llm.py +21 -8
  120. nat/llm/azure_openai_llm.py +14 -5
  121. nat/llm/litellm_llm.py +80 -0
  122. nat/llm/nim_llm.py +23 -9
  123. nat/llm/openai_llm.py +19 -7
  124. nat/llm/register.py +4 -0
  125. nat/llm/utils/thinking.py +1 -1
  126. nat/observability/exporter/base_exporter.py +1 -1
  127. nat/observability/exporter/processing_exporter.py +29 -55
  128. nat/observability/exporter/span_exporter.py +43 -15
  129. nat/observability/exporter_manager.py +2 -2
  130. nat/observability/mixin/redaction_config_mixin.py +5 -4
  131. nat/observability/mixin/tagging_config_mixin.py +26 -14
  132. nat/observability/mixin/type_introspection_mixin.py +420 -107
  133. nat/observability/processor/batching_processor.py +1 -1
  134. nat/observability/processor/processor.py +3 -0
  135. nat/observability/processor/redaction/__init__.py +24 -0
  136. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  137. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  138. nat/observability/processor/redaction/redaction_processor.py +177 -0
  139. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  140. nat/observability/processor/span_tagging_processor.py +21 -14
  141. nat/observability/register.py +16 -0
  142. nat/profiler/callbacks/langchain_callback_handler.py +32 -7
  143. nat/profiler/callbacks/llama_index_callback_handler.py +36 -2
  144. nat/profiler/callbacks/token_usage_base_model.py +2 -0
  145. nat/profiler/decorators/framework_wrapper.py +61 -9
  146. nat/profiler/decorators/function_tracking.py +35 -3
  147. nat/profiler/forecasting/models/linear_model.py +1 -1
  148. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  149. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  150. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
  151. nat/profiler/parameter_optimization/__init__.py +0 -0
  152. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  153. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  154. nat/profiler/parameter_optimization/parameter_optimizer.py +189 -0
  155. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  156. nat/profiler/parameter_optimization/pareto_visualizer.py +460 -0
  157. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  158. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  159. nat/profiler/utils.py +3 -1
  160. nat/registry_handlers/pypi/register_pypi.py +5 -3
  161. nat/registry_handlers/rest/register_rest.py +5 -3
  162. nat/retriever/milvus/retriever.py +1 -1
  163. nat/retriever/nemo_retriever/register.py +2 -1
  164. nat/runtime/loader.py +1 -1
  165. nat/runtime/runner.py +111 -6
  166. nat/runtime/session.py +49 -3
  167. nat/settings/global_settings.py +2 -2
  168. nat/tool/chat_completion.py +4 -1
  169. nat/tool/code_execution/code_sandbox.py +3 -6
  170. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +19 -32
  171. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +6 -1
  172. nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +2 -0
  173. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +10 -4
  174. nat/tool/datetime_tools.py +1 -1
  175. nat/tool/github_tools.py +450 -0
  176. nat/tool/memory_tools/add_memory_tool.py +3 -3
  177. nat/tool/memory_tools/delete_memory_tool.py +3 -4
  178. nat/tool/memory_tools/get_memory_tool.py +4 -4
  179. nat/tool/register.py +2 -7
  180. nat/tool/server_tools.py +15 -2
  181. nat/utils/__init__.py +76 -0
  182. nat/utils/callable_utils.py +70 -0
  183. nat/utils/data_models/schema_validator.py +1 -1
  184. nat/utils/decorators.py +210 -0
  185. nat/utils/exception_handlers/automatic_retries.py +278 -72
  186. nat/utils/io/yaml_tools.py +73 -3
  187. nat/utils/log_levels.py +25 -0
  188. nat/utils/responses_api.py +26 -0
  189. nat/utils/string_utils.py +16 -0
  190. nat/utils/type_converter.py +12 -3
  191. nat/utils/type_utils.py +6 -2
  192. nvidia_nat-1.4.0a20251112.dist-info/METADATA +197 -0
  193. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/RECORD +199 -165
  194. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/entry_points.txt +1 -0
  195. nat/cli/commands/info/list_mcp.py +0 -461
  196. nat/data_models/temperature_mixin.py +0 -43
  197. nat/data_models/top_p_mixin.py +0 -43
  198. nat/observability/processor/header_redaction_processor.py +0 -123
  199. nat/observability/processor/redaction_processor.py +0 -77
  200. nat/tool/code_execution/test_code_execution_sandbox.py +0 -414
  201. nat/tool/github_tools/create_github_commit.py +0 -133
  202. nat/tool/github_tools/create_github_issue.py +0 -87
  203. nat/tool/github_tools/create_github_pr.py +0 -106
  204. nat/tool/github_tools/get_github_file.py +0 -106
  205. nat/tool/github_tools/get_github_issue.py +0 -166
  206. nat/tool/github_tools/get_github_pr.py +0 -256
  207. nat/tool/github_tools/update_github_issue.py +0 -100
  208. nvidia_nat-1.3.0a20250910.dist-info/METADATA +0 -373
  209. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  210. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/WHEEL +0 -0
  211. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  212. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE.md +0 -0
  213. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,450 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from datetime import datetime
17
+ from typing import Literal
18
+
19
+ from pydantic import BaseModel
20
+ from pydantic import Field
21
+ from pydantic import PositiveInt
22
+ from pydantic import computed_field
23
+ from pydantic import field_validator
24
+
25
+ from nat.builder.builder import Builder
26
+ from nat.builder.function import FunctionGroup
27
+ from nat.builder.function_info import FunctionInfo
28
+ from nat.cli.register_workflow import register_function
29
+ from nat.cli.register_workflow import register_function_group
30
+ from nat.data_models.function import FunctionBaseConfig
31
+ from nat.data_models.function import FunctionGroupBaseConfig
32
+
33
+
34
class GithubCreateIssueModel(BaseModel):
    # A single issue to open. Both fields are required (no default is given).
    title: str = Field(..., description="The title of the GitHub Issue")
    body: str = Field(..., description="The body of the GitHub Issue")
37
+
38
+
39
class GithubCreateIssueModelList(BaseModel):
    # Batch wrapper: one tool call may carry several issues; defaults to an empty batch.
    issues: list[GithubCreateIssueModel] = Field(
        default_factory=list,
        description="A list of GitHub issues, each with a title and a body",
    )
43
+
44
+
45
class GithubGetIssueModel(BaseModel):
    # Query filters for listing repository issues. Every field is optional.
    state: Literal["open", "closed", "all"] | None = Field(default="open",
                                                           description="Issue state used in issue query filter")
    assignee: str | None = Field(default=None, description="Assignee name used in issue query filter")
    creator: str | None = Field(default=None, description="Creator name used in issue query filter")
    mentioned: str | None = Field(default=None, description="Name of person mentioned in issue")
    labels: list[str] | None = Field(default=None, description="A list of labels that are assigned to the issue")
    since: str | None = Field(default=None,
                              description="Only show results that were last updated after the given time.")

    # ``@field_validator`` must be the OUTERMOST decorator. With ``@classmethod`` stacked on
    # top, pydantic never registers the validator and malformed ``since`` values pass through.
    @field_validator('since', mode='before')
    @classmethod
    def validate_since(cls, v):
        """Validate ``since`` and normalize it to ``YYYY-MM-DDTHH:MM:SSZ``.

        Args:
            v: The raw (pre-validation) value supplied for ``since``.

        Returns:
            ``None`` unchanged, otherwise the normalized timestamp string.

        Raises:
            ValueError: if ``v`` is not a string in exactly that format.
        """
        if v is None:
            return v
        try:
            # Parse the string to a datetime object
            parsed_date = datetime.strptime(v, "%Y-%m-%dT%H:%M:%SZ")
        except (TypeError, ValueError) as e:
            # TypeError covers non-string input. Re-raise as ValueError so pydantic reports a
            # normal validation error instead of letting a TypeError escape.
            raise ValueError("since must be in ISO 8601 format: YYYY-MM-DDTHH:MM:SSZ") from e
        # Return the formatted string
        return parsed_date.isoformat() + 'Z'
67
+
68
+
69
class GithubGetIssueModelList(BaseModel):
    # Several independent issue queries may be batched in one tool call.
    filter_parameters: list[GithubGetIssueModel] = Field(
        default_factory=list,
        description="A list of query params when fetching issues each of type GithubGetIssueModel",
    )
73
+
74
+
75
class GithubUpdateIssueModel(BaseModel):
    # Selects which issue to edit; this is the only required field.
    issue_number: str = Field(..., description="The issue number that will be updated")

    # Optional edits; each defaults to None.
    title: str | None = Field(None, description="The title of the GitHub Issue")
    body: str | None = Field(None, description="The body of the GitHub Issue")
    state: Literal["open", "closed"] | None = Field(None, description="The new state of the issue")
    state_reason: Literal["completed", "not_planned", "reopened"] | None = Field(
        None,
        description="The reason for changing the state of the issue",
    )
    labels: list[str] | None = Field(None, description="A list of labels to assign to the issue")
    assignees: list[str] | None = Field(None, description="A list of assignees to assign to the issue")
86
+
87
+
88
class GithubUpdateIssueModelList(BaseModel):
    # Batch of issue edits; defaults to an empty batch.
    issues: list[GithubUpdateIssueModel] = Field(
        default_factory=list,
        description="A list of GitHub issues each of type GithubUpdateIssueModel",
    )
92
+
93
+
94
class GithubCreatePullModel(BaseModel):
    # Core pull-request fields (all required).
    title: str = Field(..., description="Title of the pull request")
    body: str = Field(..., description="Description of the pull request")
    # Serialized under the aliases "head"/"base" when dumped with by_alias=True.
    source_branch: str = Field(..., description="The name of the branch containing your changes", serialization_alias="head")
    target_branch: str = Field(..., description="The name of the branch you want to merge into", serialization_alias="base")

    # Optional people to attach to the PR after it is created.
    assignees: list[str] | None = Field(
        None,
        description="List of GitHub usernames to assign to the PR. Always the current user",
    )
    reviewers: list[str] | None = Field(None, description="List of GitHub usernames to request review from")
103
+
104
+
105
class GithubCreatePullList(BaseModel):
    # Batch of pull requests to open; defaults to an empty batch.
    pull_details: list[GithubCreatePullModel] = Field(
        default_factory=list,
        description="A list of params used for creating the PR in GitHub",
    )
108
+
109
+
110
class GithubGetPullsModel(BaseModel):
    # Query filters for listing pull requests. These descriptions are exposed to the
    # calling LLM via the tool's input schema, so they must describe *pull requests*
    # (they were previously copy-pasted from the issue model).
    state: Literal["open", "closed", "all"] | None = Field(
        default="open", description="Pull request state used in pull request query filter")
    head: str | None = Field(default=None,
                             description="Filters pulls by head user or head organization and branch name")
    base: str | None = Field(default=None, description="Filters pulls by base branch name")
116
+
117
+
118
class GithubGetPullsModelList(BaseModel):
    # Several independent pull-request queries may be batched in one tool call.
    filter_parameters: list[GithubGetPullsModel] = Field(
        default_factory=list,
        description="A list of query params when fetching pull requests each of type GithubGetPullsModel",
    )
123
+
124
+
125
class GithubCommitCodeModel(BaseModel):
    # Where the commit goes and what it says.
    branch: str = Field(..., description="The branch of the remote repo to which the code will be committed")
    commit_msg: str = Field(..., description="Message with which the code will be committed to the remote repo")

    # Source file on disk and its destination path inside the repository.
    local_path: str = Field(
        ...,
        description="Local filepath of the file that has been updated and needs to be committed to the remote repo",
    )
    remote_path: str = Field(
        ...,
        description="Remote filepath of the updated file in GitHub. Path is relative to root of current repository",
    )
132
+
133
+
134
class GithubCommitCodeModelList(BaseModel):
    # Batch of files to commit; defaults to an empty batch.
    updated_files: list[GithubCommitCodeModel] = Field(
        default_factory=list,
        description="A list of local filepaths and commit messages",
    )
137
+
138
+
139
class GithubGroupConfig(FunctionGroupBaseConfig, name="github"):
    """Function group for GitHub repository operations.

    Exposes issue, pull request, and commit operations with shared configuration.
    """
    # Target repository in "owner/repo" form (required).
    repo_name: str = Field(..., description="The repository name in the format 'owner/repo'")
    # Request timeout for the shared HTTP client, in seconds.
    timeout: int = Field(300, description="Timeout in seconds for GitHub API requests")
    # Only needed by the 'commit' function, which reads files from this directory.
    local_repo_dir: str | None = Field(
        None,
        description="Absolute path to the local clone. Required for 'commit' function",
    )
149
+
150
+
151
@register_function_group(config_type=GithubGroupConfig)
async def github_tool(config: GithubGroupConfig, _builder: Builder):
    """Register the `github` function group with shared configuration.

    Implements:
    - create_issue, get_issue, update_issue
    - create_pull, get_pull
    - commit

    All functions share one authenticated ``httpx.AsyncClient`` that stays open while
    this async generator is suspended at ``yield``. Each function processes its list of
    requests sequentially, raises on any non-2xx response (``raise_for_status``), and
    returns the collected JSON responses serialized as a JSON string.

    Raises:
        ValueError: if none of GITHUB_TOKEN, GITHUB_PAT or GH_TOKEN is set.
    """
    import base64
    import json
    import os

    import httpx

    token: str | None = None
    for env_var in ["GITHUB_TOKEN", "GITHUB_PAT", "GH_TOKEN"]:
        token = os.getenv(env_var)
        if token:
            break

    if not token:
        # Note the trailing space: the two literals are concatenated.
        raise ValueError("No GitHub token found in environment variables. Please set one of the following "
                         "environment variables: GITHUB_TOKEN, GITHUB_PAT, GH_TOKEN")

    headers = {
        "Authorization": f"Bearer {token}",
        "Accept": "application/vnd.github+json",
        "User-Agent": "NeMo-Agent-Toolkit",
    }

    # Common prefix for every repository-scoped REST endpoint used below.
    api_base = f"https://api.github.com/repos/{config.repo_name}"

    async with httpx.AsyncClient(timeout=config.timeout, headers=headers) as client:

        # Issues
        async def create_issue(issues_list: GithubCreateIssueModelList) -> str:
            url = f"{api_base}/issues"
            results = []
            for issue in issues_list.issues:
                payload = issue.model_dump(exclude_unset=True)
                response = await client.post(url, json=payload)
                response.raise_for_status()
                results.append(response.json())
            return json.dumps(results)

        async def get_issue(issues_list: GithubGetIssueModelList) -> str:
            url = f"{api_base}/issues"
            results = []
            for issue in issues_list.filter_parameters:
                params = issue.model_dump(exclude_unset=True, exclude_none=True)
                # GitHub takes ``labels`` as ONE comma-separated value; passing the list
                # through would make httpx emit repeated ``labels=`` keys instead.
                labels = params.pop("labels", None)
                if labels:
                    params["labels"] = ",".join(labels)
                response = await client.get(url, params=params)
                response.raise_for_status()
                results.append(response.json())
            return json.dumps(results)

        async def update_issue(issues_list: GithubUpdateIssueModelList) -> str:
            url = f"{api_base}/issues"
            results = []
            for issue in issues_list.issues:
                payload = issue.model_dump(exclude_unset=True, exclude_none=True)
                # issue_number addresses the resource; it is not part of the PATCH body.
                issue_number = payload.pop("issue_number")
                response = await client.patch(f"{url}/{issue_number}", json=payload)
                response.raise_for_status()
                results.append(response.json())
            return json.dumps(results)

        # Pull requests
        async def create_pull(pull_list: GithubCreatePullList) -> str:
            results = []
            pr_url = f"{api_base}/pulls"

            for pull_detail in pull_list.pull_details:
                # by_alias=True renames source_branch/target_branch to "head"/"base".
                pr_data = pull_detail.model_dump(
                    include={"title", "body", "source_branch", "target_branch"},
                    by_alias=True,
                )
                pr_response = await client.post(pr_url, json=pr_data)
                pr_response.raise_for_status()
                pull_request = pr_response.json()
                pr_number = pull_request["number"]

                result = {"pull_request": pull_request}

                if pull_detail.assignees:
                    # Assignees are set through the issues endpoint, keyed by the PR number.
                    assignees_url = f"{api_base}/issues/{pr_number}/assignees"
                    assignees_response = await client.post(assignees_url,
                                                           json={"assignees": pull_detail.assignees})
                    assignees_response.raise_for_status()
                    result["assignees"] = assignees_response.json()

                if pull_detail.reviewers:
                    reviewers_url = f"{api_base}/pulls/{pr_number}/requested_reviewers"
                    reviewers_response = await client.post(reviewers_url,
                                                           json={"reviewers": pull_detail.reviewers})
                    reviewers_response.raise_for_status()
                    result["reviewers"] = reviewers_response.json()

                results.append(result)

            return json.dumps(results)

        async def get_pull(pull_list: GithubGetPullsModelList) -> str:
            url = f"{api_base}/pulls"
            results = []
            for pull_params in pull_list.filter_parameters:
                params = pull_params.model_dump(exclude_unset=True, exclude_none=True)
                response = await client.get(url, params=params)
                response.raise_for_status()
                results.append(response.json())

            return json.dumps(results)

        # Commits (commit updated files)
        async def commit(updated_file_list: GithubCommitCodeModelList) -> str:
            if not config.local_repo_dir:
                raise ValueError("'local_repo_dir' must be set in the github function group config to use 'commit'")

            # Resolved once: realpath collapses ".." and symlinks so the containment check
            # below compares canonical paths.
            safe_root = os.path.realpath(config.local_repo_dir)

            results = []
            for updated_file in updated_file_list.updated_files:
                branch = updated_file.branch
                commit_msg = updated_file.commit_msg
                local_path = updated_file.local_path
                # remote_path is documented as repo-relative; drop a leading "/" so an
                # absolute-looking path is still treated as relative to the repo root.
                remote_path = updated_file.remote_path.lstrip("/")

                # Read content from the local file (secure + binary-safe)
                candidate = os.path.realpath(os.path.join(safe_root, local_path))
                # commonpath (unlike a startswith prefix check) also works when safe_root is "/".
                if candidate == safe_root or os.path.commonpath([safe_root, candidate]) != safe_root:
                    raise ValueError(f"local_path '{local_path}' resolves outside local_repo_dir")
                if not os.path.isfile(candidate):
                    raise FileNotFoundError(f"File not found: {candidate}")
                with open(candidate, "rb") as f:
                    content_b64 = base64.b64encode(f.read()).decode("ascii")

                # 1) Create blob
                blob_response = await client.post(f"{api_base}/git/blobs",
                                                  json={"content": content_b64, "encoding": "base64"})
                blob_response.raise_for_status()
                blob_sha = blob_response.json()["sha"]

                # 2) Get current ref (parent commit SHA)
                ref_url = f"{api_base}/git/refs/heads/{branch}"
                ref_response = await client.get(ref_url)
                ref_response.raise_for_status()
                parent_commit_sha = ref_response.json()["object"]["sha"]

                # 3) Get parent commit to retrieve its tree SHA
                parent_commit_resp = await client.get(f"{api_base}/git/commits/{parent_commit_sha}")
                parent_commit_resp.raise_for_status()
                base_tree_sha = parent_commit_resp.json()["tree"]["sha"]

                # 4) Create tree (base_tree keeps every other file unchanged)
                tree_data = {
                    "base_tree": base_tree_sha,
                    "tree": [{
                        "path": remote_path, "mode": "100644", "type": "blob", "sha": blob_sha
                    }],
                }
                tree_response = await client.post(f"{api_base}/git/trees", json=tree_data)
                tree_response.raise_for_status()
                tree_sha = tree_response.json()["sha"]

                # 5) Create commit
                commit_data = {"message": commit_msg, "tree": tree_sha, "parents": [parent_commit_sha]}
                commit_response = await client.post(f"{api_base}/git/commits", json=commit_data)
                commit_response.raise_for_status()
                commit_sha = commit_response.json()["sha"]

                # 6) Update ref; force=False means the update is not forced
                update_ref_response = await client.patch(ref_url, json={"sha": commit_sha, "force": False})
                update_ref_response.raise_for_status()

                results.append({
                    "blob_resp": blob_response.json(),
                    "parent_commit": parent_commit_resp.json(),
                    "new_tree": tree_response.json(),
                    "commit_resp": commit_response.json(),
                    "update_ref_resp": update_ref_response.json(),
                })

            return json.dumps(results)

        group = FunctionGroup(config=config)

        group.add_function("create_issue",
                           create_issue,
                           description=f"Creates a GitHub issue in the repo named {config.repo_name}",
                           input_schema=GithubCreateIssueModelList)
        group.add_function("get_issue",
                           get_issue,
                           description=f"Fetches GitHub issues matching the given filters in the repo named "
                           f"{config.repo_name}",
                           input_schema=GithubGetIssueModelList)
        group.add_function("update_issue",
                           update_issue,
                           description=f"Updates a GitHub issue in the repo named {config.repo_name}",
                           input_schema=GithubUpdateIssueModelList)
        group.add_function("create_pull",
                           create_pull,
                           description="Creates a pull request with assignees and reviewers in "
                           f"the GitHub repository named {config.repo_name}",
                           input_schema=GithubCreatePullList)
        group.add_function("get_pull",
                           get_pull,
                           description="Fetches pull requests matching the given filters "
                           f"in the repo named {config.repo_name}",
                           input_schema=GithubGetPullsModelList)
        group.add_function("commit",
                           commit,
                           description="Commits and pushes modified code "
                           f"to the GitHub repository named {config.repo_name}",
                           input_schema=GithubCommitCodeModelList)

        yield group
371
+
372
+
373
class GithubFilesGroupConfig(FunctionBaseConfig, name="github_files_tool"):
    # Per-request timeout for the file-fetching HTTP client, in seconds.
    timeout: int = Field(5, description="Timeout in seconds for HTTP requests")
375
+
376
+
377
@register_function(config_type=GithubFilesGroupConfig)
async def github_files_tool(config: GithubFilesGroupConfig, _builder: Builder):
    """Register a tool that returns the text (or a line range) of a file hosted on github.com.

    The file is downloaded from ``raw.githubusercontent.com`` through one ``httpx.AsyncClient``
    that stays open while this async generator is suspended at ``yield``. Problems (an
    unparseable URL, timeouts, HTTP/transport errors, bad line references) are returned as
    plain strings rather than raised, so the calling agent can read them.
    """

    import re

    import httpx

    class FileMetadata(BaseModel):
        repo_path: str
        file_path: str
        start: str | None = Field(default=None)
        end: str | None = Field(default=None)

        @computed_field
        @property
        def start_line(self) -> PositiveInt | None:
            return int(self.start) if self.start else None

        @computed_field
        @property
        def end_line(self) -> PositiveInt | None:
            return int(self.end) if self.end else None

    # Compiled once rather than on every call. Groups: "owner/repo", "<branch>/<path>",
    # and the optional "#L<start>" / "-L<end>" line anchors.
    url_pattern = re.compile(r"https://github\.com/(?P<repo_path>[^/]*/[^/]*)/blob/(?P<file_path>[^?#]*)"
                             r"(?:#L(?P<start>\d+)(?:-L(?P<end>\d+))?)?")

    def _fence(text: str) -> str:
        # Content starts on its own line; otherwise the file's first line would be read
        # as the code fence's info string.
        return f"```\n{text}\n```"

    async with httpx.AsyncClient(timeout=config.timeout) as client:

        async def get(url_text: str) -> str:
            """
            Returns the text of a github file using a github url starting with https://github.com and ending
            with a specific file. If a line reference is provided (#L409), the text of the line is returned.
            If a range of lines is provided (#L409-L417), the text of the lines is returned.

            Examples:
            - https://github.com/org/repo/blob/main/README.md -> Returns full text of the README.md file
            - https://github.com/org/repo/blob/main/README.md#L409 -> Returns the 409th line of the README.md file
            - https://github.com/org/repo/blob/main/README.md#L409-L417 -> Returns lines 409-417 of the README.md file
            """

            match = url_pattern.match(url_text)
            if not match:
                return ("Invalid github url. Please provide a valid github url. "
                        "Example: 'https://github.com/org/repo/blob/main/README.md' "
                        "or 'https://github.com/org/repo/blob/main/README.md#L409' "
                        "or 'https://github.com/org/repo/blob/main/README.md#L409-L417'")

            file_metadata = FileMetadata(**match.groupdict())

            # The following URL is the raw URL of the file. refs/heads/ always points to the top commit of the branch
            raw_url = f"https://raw.githubusercontent.com/{file_metadata.repo_path}/refs/heads/{file_metadata.file_path}"
            try:
                response = await client.get(raw_url)
                response.raise_for_status()
            except httpx.TimeoutException:
                return f"Timeout encountered when retrieving resource: {raw_url}"
            except httpx.HTTPStatusError as e:
                # e.g. 404 for a wrong branch or path; report it like the other failures.
                return f"Error: HTTP {e.response.status_code} when retrieving resource: {raw_url}"
            except httpx.RequestError as e:
                return f"Error: request failed when retrieving resource: {raw_url} ({e})"

            start_line = file_metadata.start_line
            end_line = file_metadata.end_line

            if start_line is None:
                return _fence(response.text)

            lines = response.text.splitlines()

            # Line numbers are 1-based; "#L0" would otherwise index lines[-1] (the LAST line).
            if start_line < 1 or start_line > len(lines):
                return f"Error: Line {start_line} is out of range for the file {file_metadata.file_path}"

            if end_line is None:
                return _fence(lines[start_line - 1])

            if end_line > len(lines):
                return f"Error: Line {end_line} is out of range for the file {file_metadata.file_path}"

            # A reversed range would otherwise silently yield an empty slice.
            if end_line < start_line:
                return (f"Error: Invalid line range L{start_line}-L{end_line} "
                        f"for the file {file_metadata.file_path}")

            return _fence("\n".join(lines[start_line - 1:end_line]))

        yield FunctionInfo.from_fn(get, description=get.__doc__)
@@ -30,10 +30,10 @@ logger = logging.getLogger(__name__)
30
30
  class AddToolConfig(FunctionBaseConfig, name="add_memory"):
31
31
  """Function to add memory to a hosted memory platform."""
32
32
 
33
- description: str = Field(default=("Tool to add memory about a user's interactions to a system "
33
+ description: str = Field(default=("Tool to add a memory about a user's interactions to a system "
34
34
  "for retrieval later."),
35
35
  description="The description of this function's use for tool calling agents.")
36
- memory: MemoryRef = Field(default="saas_memory",
36
+ memory: MemoryRef = Field(default=MemoryRef("saas_memory"),
37
37
  description=("Instance name of the memory client instance from the workflow "
38
38
  "configuration object."))
39
39
 
@@ -46,7 +46,7 @@ async def add_memory_tool(config: AddToolConfig, builder: Builder):
46
46
  from langchain_core.tools import ToolException
47
47
 
48
48
  # First, retrieve the memory client
49
- memory_editor = builder.get_memory_client(config.memory)
49
+ memory_editor = await builder.get_memory_client(config.memory)
50
50
 
51
51
  async def _arun(item: MemoryItem) -> str:
52
52
  """
@@ -30,10 +30,9 @@ logger = logging.getLogger(__name__)
30
30
  class DeleteToolConfig(FunctionBaseConfig, name="delete_memory"):
31
31
  """Function to delete memory from a hosted memory platform."""
32
32
 
33
- description: str = Field(default=("Tool to retrieve memory about a user's "
34
- "interactions to help answer questions in a personalized way."),
33
+ description: str = Field(default="Tool to delete a memory from a hosted memory platform.",
35
34
  description="The description of this function's use for tool calling agents.")
36
- memory: MemoryRef = Field(default="saas_memory",
35
+ memory: MemoryRef = Field(default=MemoryRef("saas_memory"),
37
36
  description=("Instance name of the memory client instance from the workflow "
38
37
  "configuration object."))
39
38
 
@@ -47,7 +46,7 @@ async def delete_memory_tool(config: DeleteToolConfig, builder: Builder):
47
46
  from langchain_core.tools import ToolException
48
47
 
49
48
  # First, retrieve the memory client
50
- memory_editor = builder.get_memory_client(config.memory)
49
+ memory_editor = await builder.get_memory_client(config.memory)
51
50
 
52
51
  async def _arun(user_id: str) -> str:
53
52
  """
@@ -30,10 +30,10 @@ logger = logging.getLogger(__name__)
30
30
  class GetToolConfig(FunctionBaseConfig, name="get_memory"):
31
31
  """Function to get memory to a hosted memory platform."""
32
32
 
33
- description: str = Field(default=("Tool to retrieve memory about a user's "
33
+ description: str = Field(default=("Tool to retrieve a memory about a user's "
34
34
  "interactions to help answer questions in a personalized way."),
35
35
  description="The description of this function's use for tool calling agents.")
36
- memory: MemoryRef = Field(default="saas_memory",
36
+ memory: MemoryRef = Field(default=MemoryRef("saas_memory"),
37
37
  description=("Instance name of the memory client instance from the workflow "
38
38
  "configuration object."))
39
39
 
@@ -49,7 +49,7 @@ async def get_memory_tool(config: GetToolConfig, builder: Builder):
49
49
  from langchain_core.tools import ToolException
50
50
 
51
51
  # First, retrieve the memory client
52
- memory_editor = builder.get_memory_client(config.memory)
52
+ memory_editor = await builder.get_memory_client(config.memory)
53
53
 
54
54
  async def _arun(search_input: SearchMemoryInput) -> str:
55
55
  """
@@ -67,6 +67,6 @@ async def get_memory_tool(config: GetToolConfig, builder: Builder):
67
67
 
68
68
  except Exception as e:
69
69
 
70
- raise ToolException(f"Error retreiving memory: {e}") from e
70
+ raise ToolException(f"Error retrieving memory: {e}") from e
71
71
 
72
72
  yield FunctionInfo.from_fn(_arun, description=config.description)
nat/tool/register.py CHANGED
@@ -24,13 +24,8 @@ from . import nvidia_rag
24
24
  from . import retriever
25
25
  from . import server_tools
26
26
  from .code_execution import register
27
- from .github_tools import create_github_commit
28
- from .github_tools import create_github_issue
29
- from .github_tools import create_github_pr
30
- from .github_tools import get_github_file
31
- from .github_tools import get_github_issue
32
- from .github_tools import get_github_pr
33
- from .github_tools import update_github_issue
27
+ from .github_tools import github_tool
28
+ from .github_tools import github_files_tool
34
29
  from .memory_tools import add_memory_tool
35
30
  from .memory_tools import delete_memory_tool
36
31
  from .memory_tools import get_memory_tool
nat/tool/server_tools.py CHANGED
@@ -32,14 +32,23 @@ class RequestAttributesTool(FunctionBaseConfig, name="current_request_attributes
32
32
  @register_function(config_type=RequestAttributesTool)
33
33
  async def current_request_attributes(config: RequestAttributesTool, builder: Builder):
34
34
 
35
+ from pydantic import RootModel
36
+ from pydantic.types import JsonValue
35
37
  from starlette.datastructures import Headers
36
38
  from starlette.datastructures import QueryParams
37
39
 
38
- async def _get_request_attributes(unused: str) -> str:
40
+ class RequestBody(RootModel[JsonValue]):
41
+ """
42
+ Data model that accepts a request body of any valid JSON type.
43
+ """
44
+ root: JsonValue
45
+
46
+ async def _get_request_attributes(request_body: RequestBody) -> str:
39
47
 
40
48
  from nat.builder.context import Context
41
49
  nat_context = Context.get()
42
50
 
51
+ # Access request attributes from context
43
52
  method: str | None = nat_context.metadata.method
44
53
  url_path: str | None = nat_context.metadata.url_path
45
54
  url_scheme: str | None = nat_context.metadata.url_scheme
@@ -51,6 +60,9 @@ async def current_request_attributes(config: RequestAttributesTool, builder: Bui
51
60
  cookies: dict[str, str] | None = nat_context.metadata.cookies
52
61
  conversation_id: str | None = nat_context.conversation_id
53
62
 
63
+ # Access the request body data - can be any valid JSON type
64
+ request_body_data: JsonValue = request_body.root
65
+
54
66
  return (f"Method: {method}, "
55
67
  f"URL Path: {url_path}, "
56
68
  f"URL Scheme: {url_scheme}, "
@@ -60,7 +72,8 @@ async def current_request_attributes(config: RequestAttributesTool, builder: Bui
60
72
  f"Client Host: {client_host}, "
61
73
  f"Client Port: {client_port}, "
62
74
  f"Cookies: {cookies}, "
63
- f"Conversation Id: {conversation_id}")
75
+ f"Conversation Id: {conversation_id}, "
76
+ f"Request Body: {request_body_data}")
64
77
 
65
78
  yield FunctionInfo.from_fn(_get_request_attributes,
66
79
  description="Returns the acquired user defined request attributes.")
nat/utils/__init__.py CHANGED
@@ -0,0 +1,76 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import typing
17
+ from pathlib import Path
18
+
19
+ if typing.TYPE_CHECKING:
20
+
21
+ from nat.data_models.config import Config
22
+
23
+ from .type_utils import StrPath
24
+
25
+ _T = typing.TypeVar("_T")
26
+
27
+
28
+ async def run_workflow(*,
29
+ config: "Config | None" = None,
30
+ config_file: "StrPath | None" = None,
31
+ prompt: str,
32
+ to_type: type[_T] = str,
33
+ session_kwargs: dict[str, typing.Any] | None = None) -> _T:
34
+ """
35
+ Wrapper to run a workflow given either a config or a config file path and a prompt, returning the result in the
36
+ type specified by the `to_type`.
37
+
38
+ Parameters
39
+ ----------
40
+ config : Config | None
41
+ The configuration object to use for the workflow. If None, config_file must be provided.
42
+ config_file : StrPath | None
43
+ The path to the configuration file. If None, config must be provided. Can be either a str or a Path object.
44
+ prompt : str
45
+ The prompt to run the workflow with.
46
+ to_type : type[_T]
47
+ The type to convert the result to. Default is str.
48
+
49
+ Returns
50
+ -------
51
+ _T
52
+ The result of the workflow converted to the specified type.
53
+ """
54
+ from nat.builder.workflow_builder import WorkflowBuilder
55
+ from nat.runtime.loader import load_config
56
+ from nat.runtime.session import SessionManager
57
+
58
+ if config is not None and config_file is not None:
59
+ raise ValueError("Only one of config or config_file should be provided")
60
+
61
+ if config is None:
62
+ if config_file is None:
63
+ raise ValueError("Either config_file or config must be provided")
64
+
65
+ if not Path(config_file).exists():
66
+ raise ValueError(f"Config file {config_file} does not exist")
67
+
68
+ config = load_config(config_file)
69
+
70
+ session_kwargs = session_kwargs or {}
71
+
72
+ async with WorkflowBuilder.from_config(config=config) as workflow_builder:
73
+ session_manager = SessionManager(await workflow_builder.build())
74
+ async with session_manager.session(**session_kwargs) as session:
75
+ async with session.run(prompt) as runner:
76
+ return await runner.result(to_type=to_type)