nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.3.0a20250917__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100)
  1. nat/agent/base.py +9 -4
  2. nat/agent/prompt_optimizer/prompt.py +68 -0
  3. nat/agent/prompt_optimizer/register.py +149 -0
  4. nat/agent/react_agent/agent.py +1 -1
  5. nat/agent/react_agent/register.py +15 -5
  6. nat/agent/reasoning_agent/reasoning_agent.py +6 -1
  7. nat/agent/register.py +2 -0
  8. nat/agent/rewoo_agent/agent.py +4 -2
  9. nat/agent/rewoo_agent/register.py +8 -3
  10. nat/agent/router_agent/__init__.py +0 -0
  11. nat/agent/router_agent/agent.py +329 -0
  12. nat/agent/router_agent/prompt.py +48 -0
  13. nat/agent/router_agent/register.py +97 -0
  14. nat/agent/tool_calling_agent/agent.py +69 -7
  15. nat/agent/tool_calling_agent/register.py +11 -3
  16. nat/builder/builder.py +27 -4
  17. nat/builder/component_utils.py +7 -3
  18. nat/builder/function.py +167 -0
  19. nat/builder/function_info.py +1 -1
  20. nat/builder/workflow.py +5 -0
  21. nat/builder/workflow_builder.py +213 -16
  22. nat/cli/commands/optimize.py +90 -0
  23. nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
  24. nat/cli/commands/workflow/workflow_commands.py +4 -7
  25. nat/cli/entrypoint.py +2 -0
  26. nat/cli/register_workflow.py +38 -4
  27. nat/cli/type_registry.py +71 -0
  28. nat/data_models/component.py +2 -0
  29. nat/data_models/component_ref.py +11 -0
  30. nat/data_models/config.py +40 -16
  31. nat/data_models/function.py +34 -0
  32. nat/data_models/function_dependencies.py +8 -0
  33. nat/data_models/optimizable.py +119 -0
  34. nat/data_models/optimizer.py +149 -0
  35. nat/data_models/temperature_mixin.py +4 -3
  36. nat/data_models/top_p_mixin.py +4 -3
  37. nat/embedder/nim_embedder.py +1 -1
  38. nat/embedder/openai_embedder.py +1 -1
  39. nat/eval/config.py +1 -1
  40. nat/eval/evaluate.py +5 -1
  41. nat/eval/register.py +4 -0
  42. nat/eval/runtime_evaluator/__init__.py +14 -0
  43. nat/eval/runtime_evaluator/evaluate.py +123 -0
  44. nat/eval/runtime_evaluator/register.py +100 -0
  45. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +5 -1
  46. nat/front_ends/fastapi/dask_client_mixin.py +43 -0
  47. nat/front_ends/fastapi/fastapi_front_end_config.py +14 -3
  48. nat/front_ends/fastapi/fastapi_front_end_plugin.py +111 -3
  49. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +243 -228
  50. nat/front_ends/fastapi/job_store.py +518 -99
  51. nat/front_ends/fastapi/main.py +11 -19
  52. nat/front_ends/fastapi/utils.py +57 -0
  53. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +3 -2
  54. nat/llm/aws_bedrock_llm.py +14 -3
  55. nat/llm/nim_llm.py +14 -3
  56. nat/llm/openai_llm.py +8 -1
  57. nat/observability/exporter/processing_exporter.py +29 -55
  58. nat/observability/mixin/redaction_config_mixin.py +5 -4
  59. nat/observability/mixin/tagging_config_mixin.py +26 -14
  60. nat/observability/mixin/type_introspection_mixin.py +401 -107
  61. nat/observability/processor/processor.py +3 -0
  62. nat/observability/processor/redaction/__init__.py +24 -0
  63. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  64. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  65. nat/observability/processor/redaction/redaction_processor.py +177 -0
  66. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  67. nat/observability/processor/span_tagging_processor.py +21 -14
  68. nat/profiler/decorators/framework_wrapper.py +9 -6
  69. nat/profiler/parameter_optimization/__init__.py +0 -0
  70. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  71. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  72. nat/profiler/parameter_optimization/parameter_optimizer.py +149 -0
  73. nat/profiler/parameter_optimization/parameter_selection.py +108 -0
  74. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  75. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  76. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  77. nat/profiler/utils.py +3 -1
  78. nat/tool/chat_completion.py +4 -1
  79. nat/tool/github_tools.py +450 -0
  80. nat/tool/register.py +2 -7
  81. nat/utils/callable_utils.py +70 -0
  82. nat/utils/exception_handlers/automatic_retries.py +103 -48
  83. nat/utils/type_utils.py +4 -0
  84. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/METADATA +8 -1
  85. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/RECORD +91 -71
  86. nat/observability/processor/header_redaction_processor.py +0 -123
  87. nat/observability/processor/redaction_processor.py +0 -77
  88. nat/tool/github_tools/create_github_commit.py +0 -133
  89. nat/tool/github_tools/create_github_issue.py +0 -87
  90. nat/tool/github_tools/create_github_pr.py +0 -106
  91. nat/tool/github_tools/get_github_file.py +0 -106
  92. nat/tool/github_tools/get_github_issue.py +0 -166
  93. nat/tool/github_tools/get_github_pr.py +0 -256
  94. nat/tool/github_tools/update_github_issue.py +0 -100
  95. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  96. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/WHEEL +0 -0
  97. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/entry_points.txt +0 -0
  98. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  99. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/licenses/LICENSE.md +0 -0
  100. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/top_level.txt +0 -0
@@ -1,123 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import logging
17
- from collections.abc import Callable
18
- from functools import lru_cache
19
-
20
- from starlette.datastructures import Headers
21
-
22
- from nat.builder.context import Context
23
- from nat.data_models.span import Span
24
- from nat.observability.processor.redaction_processor import SpanRedactionProcessor
25
- from nat.utils.type_utils import override
26
-
27
logger = logging.getLogger(__name__)


def default_callback(_auth_key: str) -> bool:
    """Fallback redaction predicate that never asks for redaction.

    Args:
        _auth_key: The header value being evaluated; it is ignored.

    Returns:
        Always ``False``.
    """
    del _auth_key  # intentionally unused
    return False
33
-
34
-
35
class HeaderRedactionProcessor(SpanRedactionProcessor):
    """Processor that redacts the span based on auth key, span attributes, and callback.

    Uses a per-instance LRU cache (up to 128 auth keys) to avoid redundant callback executions for
    the same auth keys, providing bounded memory usage and automatic eviction of least recently used
    entries.

    Args:
        attributes: List of span attribute keys to redact.
        header: The header key to check for authentication.
        callback: Function to determine if the auth key should trigger redaction.
        enabled: Whether the processor is enabled (default: True).
        force_redact: If True, always redact regardless of header checks (default: False).
        redaction_value: The value to replace redacted attributes with (default: "[REDACTED]").
    """

    def __init__(self,
                 attributes: list[str] | None = None,
                 header: str | None = None,
                 callback: Callable[[str], bool] | None = None,
                 enabled: bool = True,
                 force_redact: bool = False,
                 redaction_value: str = "[REDACTED]"):
        self.attributes = attributes or []
        self.header = header
        self.callback = callback or default_callback
        self.enabled = enabled
        self.force_redact = force_redact
        self.redaction_value = redaction_value
        # Build the cache per instance. Decorating the method itself with @lru_cache would create one
        # class-wide cache keyed on (self, auth_key): every processor would share the same 128 slots
        # and the cache would hold strong references that keep processors alive indefinitely.
        self._cached_callback = lru_cache(maxsize=128)(self._invoke_callback)

    @override
    def should_redact(self, item: Span, context: Context) -> bool:
        """Determine if this span should be redacted based on header auth.

        Args:
            item (Span): The span to check.
            context (Context): The current context.

        Returns:
            bool: True if the span should be redacted, False otherwise.
        """
        # If force_redact is enabled, always redact regardless of other conditions
        if self.force_redact:
            return True

        if not self.enabled:
            return False

        headers: Headers | None = context.metadata.headers

        if headers is None or self.header is None:
            return False

        auth_key = headers.get(self.header, None)

        if not auth_key:
            return False

        # Use the per-instance LRU cache to determine if redaction is needed
        return self._should_redact_impl(auth_key)

    def _invoke_callback(self, auth_key: str) -> bool:
        """Call the configured callback directly (uncached)."""
        return self.callback(auth_key)

    def _should_redact_impl(self, auth_key: str) -> bool:
        """Check whether redaction should occur, reusing cached callback results.

        Args:
            auth_key (str): The authentication key to check.

        Returns:
            bool: True if the span should be redacted, False otherwise.
        """
        return self._cached_callback(auth_key)

    @override
    def redact_item(self, item: Span) -> Span:
        """Redact the span in place.

        Every configured attribute key present on ``item.attributes`` is overwritten with
        ``self.redaction_value``; keys not present are left untouched.

        Args:
            item (Span): The span to redact.

        Returns:
            Span: The same span object, now redacted.
        """
        for key in self.attributes:
            if key in item.attributes:
                item.attributes[key] = self.redaction_value

        return item
@@ -1,77 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import logging
17
- from abc import ABC
18
- from abc import abstractmethod
19
- from typing import TypeVar
20
-
21
- from nat.builder.context import Context
22
- from nat.data_models.span import Span
23
- from nat.observability.processor.processor import Processor
24
- from nat.utils.type_utils import override
25
-
26
RedactionItemT = TypeVar('RedactionItemT')

logger = logging.getLogger(__name__)


class RedactionProcessor(Processor[RedactionItemT, RedactionItemT], ABC):
    """Base for processors that conditionally rewrite items to hide sensitive data.

    Subclasses supply two hooks: ``should_redact`` (the decision) and ``redact_item``
    (the transformation). ``process`` wires them together using the current ``Context``.
    """

    @abstractmethod
    def should_redact(self, item: RedactionItemT, context: Context) -> bool:
        """Decide whether ``item`` needs redaction.

        Args:
            item (RedactionItemT): Candidate item.
            context (Context): Context active while the item is processed.

        Returns:
            bool: ``True`` to redact, ``False`` to pass the item through unchanged.
        """
        ...

    @abstractmethod
    def redact_item(self, item: RedactionItemT) -> RedactionItemT:
        """Produce the redacted form of ``item``.

        Args:
            item (RedactionItemT): Item selected for redaction.

        Returns:
            RedactionItemT: The redacted item.
        """
        ...

    @override
    async def process(self, item: RedactionItemT) -> RedactionItemT:
        """Return ``item`` redacted when ``should_redact`` says so, otherwise unchanged.

        Args:
            item (RedactionItemT): Item flowing through the processor.

        Returns:
            RedactionItemT: Either the original item or its redacted form.
        """
        current_context = Context.get()
        if not self.should_redact(item, current_context):
            return item
        return self.redact_item(item)
73
-
74
-
75
class SpanRedactionProcessor(RedactionProcessor[Span]):
    """Redaction processor specialised for ``Span`` items; subclasses implement the hooks."""
@@ -1,133 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from pydantic import BaseModel
17
- from pydantic import Field
18
-
19
- from nat.builder.builder import Builder
20
- from nat.builder.function_info import FunctionInfo
21
- from nat.cli.register_workflow import register_function
22
- from nat.data_models.function import FunctionBaseConfig
23
-
24
-
25
# One file to commit: where it lives locally, where it goes remotely, and on which branch.
# NOTE: field order and descriptions feed the tool's input JSON schema shown to the LLM.
class GithubCommitCodeModel(BaseModel):
    # Remote branch whose head ref is read and then moved to the new commit.
    branch: str = Field(description="The branch of the remote repo to which the code will be committed")
    # Used verbatim as the commit message.
    commit_msg: str = Field(description="Message with which the code will be committed to the remote repo")
    # Joined onto the configured local repository directory before the file is read.
    local_path: str = Field(description="Local filepath of the file that has been updated and "
                            "needs to be committed to the remote repo")
    # Path of the blob inside the new Git tree (relative to the repository root).
    remote_path: str = Field(description="Remote filepath of the updated file in GitHub. Path is relative to "
                             "root of current repository")
32
-
33
-
34
# Tool input wrapper: a batch of files; each entry is committed separately.
class GithubCommitCodeModelList(BaseModel):
    updated_files: list[GithubCommitCodeModel] = Field(description=("A list of local filepaths and commit messages"))
36
-
37
-
38
class GithubCommitCodeConfig(FunctionBaseConfig, name="github_commit_code_tool"):
    """
    Tool that commits and pushes modified code to a remote GitHub repository asynchronously.
    """
    # Interpolated into https://api.github.com/repos/<repo_name>/... request URLs.
    repo_name: str = Field(description="The repository name in the format 'owner/repo'")
    # Base directory that each input `local_path` is joined onto.
    local_repo_dir: str = Field(description="Absolute path to the root of the repo, cloned locally")
    # Passed straight to httpx.AsyncClient(timeout=...).
    timeout: int = Field(default=300, description="The timeout configuration to use when sending requests.")
45
-
46
-
47
@register_function(config_type=GithubCommitCodeConfig)
async def commit_code_async(config: GithubCommitCodeConfig, builder: Builder):
    """
    Commits and pushes modified code to a remote GitHub repository asynchronously.

    Each input file is committed as its own commit on its target branch using the GitHub Git
    database API: create blob -> read branch ref -> create tree -> create commit -> move ref.

    Args:
        config: Target repository, local clone directory and HTTP timeout.
        builder: Workflow builder (unused).

    Yields:
        FunctionInfo wrapping a coroutine that returns a JSON string with the API responses
        for every committed file.

    Raises:
        ValueError: If the ``GITHUB_PAT`` environment variable is not set (at registration), or
            if an input ``local_path`` resolves outside ``config.local_repo_dir`` (at call time).
    """
    import json
    import os

    import httpx

    github_pat = os.getenv("GITHUB_PAT")
    if not github_pat:
        raise ValueError("GITHUB_PAT environment variable must be set")

    # define the headers for the payload request
    headers = {"Authorization": f"Bearer {github_pat}", "Accept": "application/vnd.github+json"}
    api_base_url = f'https://api.github.com/repos/{config.repo_name}'
    repo_root = os.path.realpath(config.local_repo_dir)

    def _resolve_local_path(relative_path: str) -> str:
        """Resolve an LLM-supplied path inside the local clone, rejecting paths that escape it."""
        # os.path.join discards repo_root when relative_path is absolute, and '..' segments or
        # symlinks could point anywhere; without this check the tool could upload arbitrary
        # local files (e.g. credentials) to GitHub.
        resolved = os.path.realpath(os.path.join(repo_root, relative_path))
        if os.path.commonpath([repo_root, resolved]) != repo_root:
            raise ValueError(f"Local path '{relative_path}' resolves outside of the local repository "
                             f"directory '{config.local_repo_dir}'")
        return resolved

    async def _github_commit_code(updated_files) -> str:
        results = []
        async with httpx.AsyncClient(timeout=config.timeout) as client:
            for file_ in updated_files:
                ref_url = f'{api_base_url}/git/refs/heads/{file_.branch}'

                # Read content from the local file. NOTE: undecodable bytes are dropped
                # (errors='ignore'), so this is only lossless for UTF-8 text files.
                local_path = _resolve_local_path(file_.local_path)
                with open(local_path, 'r', encoding='utf-8', errors='ignore') as f:
                    content = f.read()

                # Step 1. Create a blob with the updated contents of the file
                blob_data = {'content': content, 'encoding': 'utf-8'}
                blob_response = await client.request("POST",
                                                     f'{api_base_url}/git/blobs',
                                                     json=blob_data,
                                                     headers=headers)
                blob_response.raise_for_status()
                blob_sha = blob_response.json()['sha']

                # Step 2: Read the branch ref; its object SHA is the current head commit, used below
                # both as base_tree and as the new commit's parent.
                ref_response = await client.request("GET", ref_url, headers=headers)
                ref_response.raise_for_status()
                head_sha = ref_response.json()['object']['sha']

                # Step 3. Create an updated tree (Git graph) with the new blob
                tree_data = {
                    'base_tree': head_sha,
                    'tree': [{
                        'path': file_.remote_path, 'mode': '100644', 'type': 'blob', 'sha': blob_sha
                    }]
                }
                tree_response = await client.request("POST",
                                                     f'{api_base_url}/git/trees',
                                                     json=tree_data,
                                                     headers=headers)
                tree_response.raise_for_status()
                tree_sha = tree_response.json()['sha']

                # Step 4: Create a commit
                commit_data = {'message': file_.commit_msg, 'tree': tree_sha, 'parents': [head_sha]}
                commit_response = await client.request("POST",
                                                       f'{api_base_url}/git/commits',
                                                       json=commit_data,
                                                       headers=headers)
                commit_response.raise_for_status()
                commit_sha = commit_response.json()['sha']

                # Step 5: Move the branch ref to the new commit
                update_ref_response = await client.request("PATCH",
                                                           ref_url,
                                                           json={'sha': commit_sha},
                                                           headers=headers)
                update_ref_response.raise_for_status()

                results.append({
                    'blob_resp': blob_response.json(),
                    'original_tree_ref': tree_response.json(),
                    'commit_resp': commit_response.json(),
                    'updated_tree_ref_resp': update_ref_response.json()
                })

        # The tool returns a JSON string, not a list.
        return json.dumps(results)

    yield FunctionInfo.from_fn(_github_commit_code,
                               description=(f"Commits and pushes modified code to a "
                                            f"GitHub repository in the repo named {config.repo_name}"),
                               input_schema=GithubCommitCodeModelList)
@@ -1,87 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from pydantic import BaseModel
17
- from pydantic import Field
18
-
19
- from nat.builder.builder import Builder
20
- from nat.builder.function_info import FunctionInfo
21
- from nat.cli.register_workflow import register_function
22
- from nat.data_models.function import FunctionBaseConfig
23
-
24
-
25
# One issue to open; both fields are sent as the JSON body of the create-issue request.
class GithubCreateIssueModel(BaseModel):
    title: str = Field(description="The title of the GitHub Issue")
    body: str = Field(description="The body of the GitHub Issue")
28
-
29
-
30
# Tool input wrapper: each entry results in a separate POST to the issues endpoint.
class GithubCreateIssueModelList(BaseModel):
    issues: list[GithubCreateIssueModel] = Field(description=("A list of GitHub issues, "
                                                              "each with a title and a body"))
33
-
34
-
35
class GithubCreateIssueToolConfig(FunctionBaseConfig, name="github_create_issue_tool"):
    """
    Tool that creates an issue in a GitHub repository asynchronously.
    """
    # Interpolated into https://api.github.com/repos/<repo_name>/issues.
    repo_name: str = Field(description="The repository name in the format 'owner/repo'")
    # Passed straight to httpx.AsyncClient(timeout=...).
    timeout: int = Field(default=300, description="The timeout configuration to use when sending requests.")
41
-
42
-
43
@register_function(config_type=GithubCreateIssueToolConfig)
async def create_github_issue_async(config: GithubCreateIssueToolConfig, builder: Builder):
    """
    Creates an issue in a GitHub repository asynchronously.

    Args:
        config: Target repository and HTTP timeout.
        builder: Workflow builder (unused).

    Yields:
        FunctionInfo wrapping a coroutine that opens one issue per input entry and returns the
        API responses as a JSON string.

    Raises:
        ValueError: If ``GITHUB_PAT`` is not set (at registration), or if an API response is not
            valid JSON (at call time).
    """
    import json
    import os

    import httpx

    github_pat = os.getenv("GITHUB_PAT")
    if not github_pat:
        raise ValueError("GITHUB_PAT environment variable must be set")

    url = f"https://api.github.com/repos/{config.repo_name}/issues"

    # define the headers for the payload request
    headers = {"Authorization": f"Bearer {github_pat}", "Accept": "application/vnd.github+json"}

    async def _github_post_issue(issues) -> str:
        results = []
        async with httpx.AsyncClient(timeout=config.timeout) as client:
            for issue in issues:
                # Pydantic v2 API; `.dict()` is the deprecated v1 spelling.
                payload = issue.model_dump(exclude_unset=True)

                response = await client.request("POST", url, json=payload, headers=headers)

                # Raise an exception for HTTP errors
                response.raise_for_status()

                try:
                    results.append(response.json())
                except ValueError as e:
                    raise ValueError("The API response is not valid JSON.") from e

        # The tool returns a JSON string, not a list.
        return json.dumps(results)

    yield FunctionInfo.from_fn(_github_post_issue,
                               description=(f"Creates a GitHub issue in the "
                                            f"repo named {config.repo_name}"),
                               input_schema=GithubCreateIssueModelList)
@@ -1,106 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from pydantic import BaseModel
17
- from pydantic import Field
18
-
19
- from nat.builder.builder import Builder
20
- from nat.builder.function_info import FunctionInfo
21
- from nat.cli.register_workflow import register_function
22
- from nat.data_models.function import FunctionBaseConfig
23
-
24
-
25
# Parameters for opening a single pull request (and optionally assigning / requesting reviews).
# NOTE: field order and descriptions feed the tool's input JSON schema shown to the LLM.
class GithubCreatePullModel(BaseModel):
    title: str = Field(description="Title of the pull request")
    body: str = Field(description="Description of the pull request")
    # Sent as the PR's `head`.
    source_branch: str = Field(description="The name of the branch containing your changes")
    # Sent as the PR's `base`.
    target_branch: str = Field(description="The name of the branch you want to merge into")
    # Empty / None skips the assignees request entirely.
    assignees: list[str] | None = Field([],
                                        description="List of GitHub usernames to assign to the PR. "
                                        "Always the current user")
    # Empty / None skips the requested-reviewers request entirely.
    reviewers: list[str] | None = Field([], description="List of GitHub usernames to request review from")
34
-
35
-
36
# Tool input wrapper holding the details of ONE pull request (despite the class name).
# The description previously claimed "A list of params", which misleads the LLM into sending a
# JSON array where a single object is required.
class GithubCreatePullList(BaseModel):
    pull_details: GithubCreatePullModel = Field(description=("The params used for creating the PR in GitHub"))
38
-
39
-
40
class GithubCreatePullConfig(FunctionBaseConfig, name="github_create_pull_tool"):
    """
    Tool that creates a pull request in a GitHub repository asynchronously with assignees and reviewers.
    """
    # Interpolated into https://api.github.com/repos/<repo_name>/... request URLs.
    repo_name: str = Field(description="The repository name in the format 'owner/repo'")
    # Passed straight to httpx.AsyncClient(timeout=...).
    timeout: int = Field(default=300, description="The timeout configuration to use when sending requests.")
46
-
47
-
48
@register_function(config_type=GithubCreatePullConfig)
async def create_pull_request_async(config: GithubCreatePullConfig, builder: Builder):
    """
    Creates a pull request in a GitHub repository asynchronously with assignees and reviewers.

    Args:
        config: Target repository and HTTP timeout.
        builder: Workflow builder (unused).

    Yields:
        FunctionInfo wrapping a coroutine that opens the PR, optionally adds assignees and
        requests reviewers, and returns a JSON string containing a one-element list with the
        ``pull_request``, ``assignees`` and ``reviewers`` responses (the latter two ``None``
        when not requested).

    Raises:
        ValueError: If ``GITHUB_PAT`` is not set (at registration).
    """
    import json
    import os

    import httpx

    github_pat = os.getenv("GITHUB_PAT")
    if not github_pat:
        raise ValueError("GITHUB_PAT environment variable must be set")

    headers = {"Authorization": f"Bearer {github_pat}", "Accept": "application/vnd.github+json"}
    api_base_url = f'https://api.github.com/repos/{config.repo_name}'

    # The tool input schema is GithubCreatePullList; its `pull_details` field (a
    # GithubCreatePullModel) is what arrives here.
    async def _github_create_pull(pull_details: GithubCreatePullModel) -> str:
        # Initialised up front so the result dict never depends on a conditionally bound name.
        assignees_response = None
        reviewers_response = None

        async with httpx.AsyncClient(timeout=config.timeout) as client:
            # Create pull request
            pr_data = {
                'title': pull_details.title,
                'body': pull_details.body,
                'head': pull_details.source_branch,
                'base': pull_details.target_branch
            }

            pr_response = await client.request("POST", f'{api_base_url}/pulls', json=pr_data, headers=headers)
            pr_response.raise_for_status()
            pr_number = pr_response.json()['number']

            # Add assignees if provided (PRs share the issues assignee endpoint)
            if pull_details.assignees:
                assignees_response = await client.request("POST",
                                                          f'{api_base_url}/issues/{pr_number}/assignees',
                                                          json={'assignees': pull_details.assignees},
                                                          headers=headers)
                assignees_response.raise_for_status()

            # Request reviewers if provided
            if pull_details.reviewers:
                reviewers_response = await client.request("POST",
                                                          f'{api_base_url}/pulls/{pr_number}/requested_reviewers',
                                                          json={'reviewers': pull_details.reviewers},
                                                          headers=headers)
                reviewers_response.raise_for_status()

            result = {
                'pull_request': pr_response.json(),
                'assignees': assignees_response.json() if assignees_response is not None else None,
                'reviewers': reviewers_response.json() if reviewers_response is not None else None
            }

        # Output stays a one-element list for compatibility with existing consumers.
        return json.dumps([result])

    yield FunctionInfo.from_fn(_github_create_pull,
                               description=(f"Creates a pull request with assignees and reviewers in the "
                                            f"GitHub repository named {config.repo_name}"),
                               input_schema=GithubCreatePullList)
@@ -1,106 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from nat.builder.builder import Builder
17
- from nat.builder.function_info import FunctionInfo
18
- from nat.cli.register_workflow import register_function
19
- from nat.data_models.function import FunctionBaseConfig
20
-
21
-
22
class GithubGetFileToolConfig(FunctionBaseConfig, name="github_getfile"):
    """
    Tool that returns the text of a github file using a github url starting with https://github.com and ending
    with a specific file.
    """
    # No options: the only input is the URL supplied when the tool is called.
    pass
28
-
29
-
30
@register_function(config_type=GithubGetFileToolConfig)
async def github_text_from_url(config: GithubGetFileToolConfig, builder: Builder):
    """Register a tool that fetches the full text of a file hosted on GitHub.

    The tool accepts ``https://github.com/<owner>/<repo>/blob/<ref>/<path>`` where ``<ref>`` may be
    a branch, tag or commit SHA, and returns the file wrapped in a fenced code block. Invalid URLs,
    timeouts and non-success HTTP responses are reported as plain strings rather than raised.

    Args:
        config: Tool configuration (no options).
        builder: Workflow builder (unused).

    Yields:
        FunctionInfo wrapping the fetch coroutine.
    """
    import asyncio
    import re

    import requests

    # Groups: "<owner>/<repo>", "<ref>/<path>". The dot is escaped so only github.com matches.
    url_pattern = re.compile(r"https://github\.com/(.*)/blob/(.*)")

    async def _github_text_from_url(url_text: str) -> str:
        match = url_pattern.search(url_text)
        if match is None:
            return ("Invalid github url. Please provide a valid github url. "
                    "Example: 'https://github.com/my_repository/blob/main/file.txt'")

        repo, ref_and_path = match.groups()
        # raw.githubusercontent.com/<owner>/<repo>/<ref>/<path> resolves branches, tags and commit
        # SHAs alike; the former "refs/heads/" prefix only worked for branch names.
        raw_url = f"https://raw.githubusercontent.com/{repo}/{ref_and_path}"
        try:
            # requests is synchronous; run it in a worker thread so the event loop is not blocked.
            response = await asyncio.to_thread(requests.get, raw_url, timeout=60)
        except requests.exceptions.Timeout:
            return f"Timeout encountered when retrieving resource: {raw_url}"

        if not response.ok:
            # Previously a 404 page body was returned as if it were file content.
            return f"Failed to retrieve resource: {raw_url} (HTTP {response.status_code})"

        return f"```python\n{response.text}\n```"

    yield FunctionInfo.from_fn(_github_text_from_url,
                               description=("Returns the text of a github file using a github url starting with "
                                            "https://github.com and ending with a specific file."))
60
-
61
-
62
class GithubGetFileLinesToolConfig(FunctionBaseConfig, name="github_getfilelines"):
    """
    Tool that returns the text lines of a github file using a github url starting with
    https://github.com and ending with a specific file line references. Examples of line references are
    #L409-L417 and #L166-L171.
    """
    # No options: the only input is the URL (with its #L anchor) supplied when the tool is called.
    pass
69
-
70
-
71
@register_function(config_type=GithubGetFileLinesToolConfig)
async def github_text_lines_from_url(config: GithubGetFileLinesToolConfig, builder: Builder):
    """Register a tool that returns a line range from a file hosted on GitHub.

    The tool accepts ``https://github.com/<owner>/<repo>/blob/<ref>/<path>#L<start>-L<end>`` (or
    ``#L<line>`` for a single line), where ``<ref>`` may be a branch, tag or commit SHA. Line
    numbers follow GitHub's anchors: 1-based and inclusive on both ends. Invalid URLs, timeouts
    and non-success HTTP responses are reported as plain strings rather than raised.

    Args:
        config: Tool configuration (no options).
        builder: Workflow builder (unused).

    Yields:
        FunctionInfo wrapping the fetch coroutine.
    """
    import asyncio
    import re

    import requests

    # Groups: "<owner>/<repo>", "<ref>/<path>", first line, optional last line.
    url_pattern = re.compile(r"https://github\.com/(.*)/blob/(.*)#L(\d+)(?:-L(\d+))?")

    async def _github_text_lines_from_url(url_text: str) -> str:
        match = url_pattern.search(url_text)
        if match is None:
            return ("Invalid github url. Please provide a valid github url with line information. "
                    "Example: 'https://github.com/my_repository/blob/main/file.txt#L409-L417'")

        repo, ref_and_path, first_line, last_line = match.groups()
        start_line = int(first_line)
        end_line = int(last_line) if last_line else start_line

        # raw.githubusercontent.com/<owner>/<repo>/<ref>/<path> resolves branches, tags and commit
        # SHAs (GitHub permalinks use SHAs); the former "refs/heads/" prefix only handled branches.
        raw_url = f"https://raw.githubusercontent.com/{repo}/{ref_and_path}"
        try:
            # requests is synchronous; run it in a worker thread so the event loop is not blocked.
            response = await asyncio.to_thread(requests.get, raw_url, timeout=60)
        except requests.exceptions.Timeout:
            return f"Timeout encountered when retrieving resource: {raw_url}"

        if not response.ok:
            return f"Failed to retrieve resource: {raw_url} (HTTP {response.status_code})"

        # #L409-L417 means lines 409..417 inclusive, i.e. list indices 408..416. The previous
        # slice [start:end] silently dropped the first requested line.
        file_lines = response.text.splitlines()
        selected_lines = file_lines[max(start_line - 1, 0):end_line]
        joined_selected_lines = "\n".join(selected_lines)

        return f"```python\n{joined_selected_lines}\n```"

    yield FunctionInfo.from_fn(_github_text_lines_from_url,
                               description=("Returns the text lines of a github file using a github url starting with "
                                            "https://github.com and ending with a specific file line references. "
                                            "Examples of line references are #L409-L417 and #L166-L171."))