nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.4.0a20251112__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. nat/agent/base.py +13 -8
  2. nat/agent/prompt_optimizer/prompt.py +68 -0
  3. nat/agent/prompt_optimizer/register.py +149 -0
  4. nat/agent/react_agent/agent.py +6 -5
  5. nat/agent/react_agent/register.py +49 -39
  6. nat/agent/reasoning_agent/reasoning_agent.py +17 -15
  7. nat/agent/register.py +2 -0
  8. nat/agent/responses_api_agent/__init__.py +14 -0
  9. nat/agent/responses_api_agent/register.py +126 -0
  10. nat/agent/rewoo_agent/agent.py +304 -117
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +51 -38
  13. nat/agent/tool_calling_agent/agent.py +75 -17
  14. nat/agent/tool_calling_agent/register.py +46 -23
  15. nat/authentication/api_key/api_key_auth_provider.py +6 -11
  16. nat/authentication/api_key/api_key_auth_provider_config.py +8 -5
  17. nat/authentication/credential_validator/__init__.py +14 -0
  18. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  19. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  20. nat/authentication/interfaces.py +5 -2
  21. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
  22. nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +2 -1
  23. nat/authentication/oauth2/oauth2_resource_server_config.py +125 -0
  24. nat/builder/builder.py +55 -23
  25. nat/builder/component_utils.py +9 -5
  26. nat/builder/context.py +54 -15
  27. nat/builder/eval_builder.py +14 -9
  28. nat/builder/framework_enum.py +1 -0
  29. nat/builder/front_end.py +1 -1
  30. nat/builder/function.py +370 -0
  31. nat/builder/function_info.py +1 -1
  32. nat/builder/intermediate_step_manager.py +38 -2
  33. nat/builder/workflow.py +5 -0
  34. nat/builder/workflow_builder.py +306 -54
  35. nat/cli/cli_utils/config_override.py +1 -1
  36. nat/cli/commands/info/info.py +16 -6
  37. nat/cli/commands/mcp/__init__.py +14 -0
  38. nat/cli/commands/mcp/mcp.py +986 -0
  39. nat/cli/commands/optimize.py +90 -0
  40. nat/cli/commands/start.py +1 -1
  41. nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
  42. nat/cli/commands/workflow/templates/register.py.j2 +2 -2
  43. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  44. nat/cli/commands/workflow/workflow_commands.py +60 -18
  45. nat/cli/entrypoint.py +15 -11
  46. nat/cli/main.py +3 -0
  47. nat/cli/register_workflow.py +38 -4
  48. nat/cli/type_registry.py +72 -1
  49. nat/control_flow/__init__.py +0 -0
  50. nat/control_flow/register.py +20 -0
  51. nat/control_flow/router_agent/__init__.py +0 -0
  52. nat/control_flow/router_agent/agent.py +329 -0
  53. nat/control_flow/router_agent/prompt.py +48 -0
  54. nat/control_flow/router_agent/register.py +91 -0
  55. nat/control_flow/sequential_executor.py +166 -0
  56. nat/data_models/agent.py +34 -0
  57. nat/data_models/api_server.py +199 -69
  58. nat/data_models/authentication.py +23 -9
  59. nat/data_models/common.py +47 -0
  60. nat/data_models/component.py +2 -0
  61. nat/data_models/component_ref.py +11 -0
  62. nat/data_models/config.py +41 -17
  63. nat/data_models/dataset_handler.py +4 -3
  64. nat/data_models/function.py +34 -0
  65. nat/data_models/function_dependencies.py +8 -0
  66. nat/data_models/intermediate_step.py +9 -1
  67. nat/data_models/llm.py +15 -1
  68. nat/data_models/openai_mcp.py +46 -0
  69. nat/data_models/optimizable.py +208 -0
  70. nat/data_models/optimizer.py +161 -0
  71. nat/data_models/span.py +41 -3
  72. nat/data_models/thinking_mixin.py +2 -2
  73. nat/embedder/azure_openai_embedder.py +2 -1
  74. nat/embedder/nim_embedder.py +3 -2
  75. nat/embedder/openai_embedder.py +3 -2
  76. nat/eval/config.py +1 -1
  77. nat/eval/dataset_handler/dataset_downloader.py +3 -2
  78. nat/eval/dataset_handler/dataset_filter.py +34 -2
  79. nat/eval/evaluate.py +10 -3
  80. nat/eval/evaluator/base_evaluator.py +1 -1
  81. nat/eval/rag_evaluator/evaluate.py +7 -4
  82. nat/eval/register.py +4 -0
  83. nat/eval/runtime_evaluator/__init__.py +14 -0
  84. nat/eval/runtime_evaluator/evaluate.py +123 -0
  85. nat/eval/runtime_evaluator/register.py +100 -0
  86. nat/eval/swe_bench_evaluator/evaluate.py +1 -1
  87. nat/eval/trajectory_evaluator/register.py +1 -1
  88. nat/eval/tunable_rag_evaluator/evaluate.py +1 -1
  89. nat/eval/usage_stats.py +2 -0
  90. nat/eval/utils/output_uploader.py +3 -2
  91. nat/eval/utils/weave_eval.py +17 -3
  92. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  93. nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
  94. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  95. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
  96. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
  97. nat/experimental/test_time_compute/models/strategy_base.py +2 -2
  98. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
  99. nat/front_ends/console/authentication_flow_handler.py +82 -30
  100. nat/front_ends/console/console_front_end_plugin.py +19 -7
  101. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
  102. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  103. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  104. nat/front_ends/fastapi/fastapi_front_end_config.py +25 -3
  105. nat/front_ends/fastapi/fastapi_front_end_plugin.py +140 -3
  106. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +445 -265
  107. nat/front_ends/fastapi/job_store.py +518 -99
  108. nat/front_ends/fastapi/main.py +11 -19
  109. nat/front_ends/fastapi/message_handler.py +69 -44
  110. nat/front_ends/fastapi/message_validator.py +8 -7
  111. nat/front_ends/fastapi/utils.py +57 -0
  112. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  113. nat/front_ends/mcp/mcp_front_end_config.py +71 -3
  114. nat/front_ends/mcp/mcp_front_end_plugin.py +85 -21
  115. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +248 -29
  116. nat/front_ends/mcp/memory_profiler.py +320 -0
  117. nat/front_ends/mcp/tool_converter.py +78 -25
  118. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  119. nat/llm/aws_bedrock_llm.py +21 -8
  120. nat/llm/azure_openai_llm.py +14 -5
  121. nat/llm/litellm_llm.py +80 -0
  122. nat/llm/nim_llm.py +23 -9
  123. nat/llm/openai_llm.py +19 -7
  124. nat/llm/register.py +4 -0
  125. nat/llm/utils/thinking.py +1 -1
  126. nat/observability/exporter/base_exporter.py +1 -1
  127. nat/observability/exporter/processing_exporter.py +29 -55
  128. nat/observability/exporter/span_exporter.py +43 -15
  129. nat/observability/exporter_manager.py +2 -2
  130. nat/observability/mixin/redaction_config_mixin.py +5 -4
  131. nat/observability/mixin/tagging_config_mixin.py +26 -14
  132. nat/observability/mixin/type_introspection_mixin.py +420 -107
  133. nat/observability/processor/batching_processor.py +1 -1
  134. nat/observability/processor/processor.py +3 -0
  135. nat/observability/processor/redaction/__init__.py +24 -0
  136. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  137. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  138. nat/observability/processor/redaction/redaction_processor.py +177 -0
  139. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  140. nat/observability/processor/span_tagging_processor.py +21 -14
  141. nat/observability/register.py +16 -0
  142. nat/profiler/callbacks/langchain_callback_handler.py +32 -7
  143. nat/profiler/callbacks/llama_index_callback_handler.py +36 -2
  144. nat/profiler/callbacks/token_usage_base_model.py +2 -0
  145. nat/profiler/decorators/framework_wrapper.py +61 -9
  146. nat/profiler/decorators/function_tracking.py +35 -3
  147. nat/profiler/forecasting/models/linear_model.py +1 -1
  148. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  149. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  150. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
  151. nat/profiler/parameter_optimization/__init__.py +0 -0
  152. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  153. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  154. nat/profiler/parameter_optimization/parameter_optimizer.py +189 -0
  155. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  156. nat/profiler/parameter_optimization/pareto_visualizer.py +460 -0
  157. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  158. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  159. nat/profiler/utils.py +3 -1
  160. nat/registry_handlers/pypi/register_pypi.py +5 -3
  161. nat/registry_handlers/rest/register_rest.py +5 -3
  162. nat/retriever/milvus/retriever.py +1 -1
  163. nat/retriever/nemo_retriever/register.py +2 -1
  164. nat/runtime/loader.py +1 -1
  165. nat/runtime/runner.py +111 -6
  166. nat/runtime/session.py +49 -3
  167. nat/settings/global_settings.py +2 -2
  168. nat/tool/chat_completion.py +4 -1
  169. nat/tool/code_execution/code_sandbox.py +3 -6
  170. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +19 -32
  171. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +6 -1
  172. nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +2 -0
  173. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +10 -4
  174. nat/tool/datetime_tools.py +1 -1
  175. nat/tool/github_tools.py +450 -0
  176. nat/tool/memory_tools/add_memory_tool.py +3 -3
  177. nat/tool/memory_tools/delete_memory_tool.py +3 -4
  178. nat/tool/memory_tools/get_memory_tool.py +4 -4
  179. nat/tool/register.py +2 -7
  180. nat/tool/server_tools.py +15 -2
  181. nat/utils/__init__.py +76 -0
  182. nat/utils/callable_utils.py +70 -0
  183. nat/utils/data_models/schema_validator.py +1 -1
  184. nat/utils/decorators.py +210 -0
  185. nat/utils/exception_handlers/automatic_retries.py +278 -72
  186. nat/utils/io/yaml_tools.py +73 -3
  187. nat/utils/log_levels.py +25 -0
  188. nat/utils/responses_api.py +26 -0
  189. nat/utils/string_utils.py +16 -0
  190. nat/utils/type_converter.py +12 -3
  191. nat/utils/type_utils.py +6 -2
  192. nvidia_nat-1.4.0a20251112.dist-info/METADATA +197 -0
  193. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/RECORD +199 -165
  194. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/entry_points.txt +1 -0
  195. nat/cli/commands/info/list_mcp.py +0 -461
  196. nat/data_models/temperature_mixin.py +0 -43
  197. nat/data_models/top_p_mixin.py +0 -43
  198. nat/observability/processor/header_redaction_processor.py +0 -123
  199. nat/observability/processor/redaction_processor.py +0 -77
  200. nat/tool/code_execution/test_code_execution_sandbox.py +0 -414
  201. nat/tool/github_tools/create_github_commit.py +0 -133
  202. nat/tool/github_tools/create_github_issue.py +0 -87
  203. nat/tool/github_tools/create_github_pr.py +0 -106
  204. nat/tool/github_tools/get_github_file.py +0 -106
  205. nat/tool/github_tools/get_github_issue.py +0 -166
  206. nat/tool/github_tools/get_github_pr.py +0 -256
  207. nat/tool/github_tools/update_github_issue.py +0 -100
  208. nvidia_nat-1.3.0a20250910.dist-info/METADATA +0 -373
  209. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  210. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/WHEEL +0 -0
  211. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  212. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE.md +0 -0
  213. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/top_level.txt +0 -0
@@ -1,77 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import logging
17
- from abc import ABC
18
- from abc import abstractmethod
19
- from typing import TypeVar
20
-
21
- from nat.builder.context import Context
22
- from nat.data_models.span import Span
23
- from nat.observability.processor.processor import Processor
24
- from nat.utils.type_utils import override
25
-
26
- RedactionItemT = TypeVar('RedactionItemT')
27
-
28
- logger = logging.getLogger(__name__)
29
-
30
-
31
- class RedactionProcessor(Processor[RedactionItemT, RedactionItemT], ABC):
32
- """Abstract base class for redaction processors."""
33
-
34
- @abstractmethod
35
- def should_redact(self, item: RedactionItemT, context: Context) -> bool:
36
- """Determine if this item should be redacted.
37
-
38
- Args:
39
- item (RedactionItemT): The item to check.
40
- context (Context): The current context.
41
-
42
- Returns:
43
- bool: True if the item should be redacted, False otherwise.
44
- """
45
- pass
46
-
47
- @abstractmethod
48
- def redact_item(self, item: RedactionItemT) -> RedactionItemT:
49
- """Redact the item.
50
-
51
- Args:
52
- item (RedactionItemT): The item to redact.
53
-
54
- Returns:
55
- RedactionItemT: The redacted item.
56
- """
57
- pass
58
-
59
- @override
60
- async def process(self, item: RedactionItemT) -> RedactionItemT:
61
- """Perform redaction on the item if it should be redacted.
62
-
63
- Args:
64
- item (RedactionItemT): The item to process.
65
-
66
- Returns:
67
- RedactionItemT: The processed item.
68
- """
69
- context = Context.get()
70
- if self.should_redact(item, context):
71
- return self.redact_item(item)
72
- return item
73
-
74
-
75
- class SpanRedactionProcessor(RedactionProcessor[Span]):
76
- """Abstract base class for span redaction processors."""
77
- pass
@@ -1,414 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """
16
- Test suite for Code Execution Sandbox using pytest.
17
-
18
- This module provides comprehensive testing for the code execution sandbox service,
19
- replacing the original bash script with a more maintainable Python implementation.
20
- """
21
-
22
- import os
23
- from typing import Any
24
-
25
- import pytest
26
- import requests
27
- from requests.exceptions import ConnectionError
28
- from requests.exceptions import RequestException
29
- from requests.exceptions import Timeout
30
-
31
-
32
- class TestCodeExecutionSandbox:
33
- """Test suite for the Code Execution Sandbox service."""
34
-
35
- @pytest.fixture(scope="class")
36
- def sandbox_config(self):
37
- """Configuration for sandbox testing."""
38
- return {
39
- "url": os.environ.get("SANDBOX_URL", "http://127.0.0.1:6000/execute"),
40
- "timeout": int(os.environ.get("SANDBOX_TIMEOUT", "30")),
41
- "connection_timeout": 5
42
- }
43
-
44
- @pytest.fixture(scope="class", autouse=True)
45
- def check_sandbox_running(self, sandbox_config):
46
- """Check if sandbox server is running before running tests."""
47
- try:
48
- _ = requests.get(sandbox_config["url"], timeout=sandbox_config["connection_timeout"])
49
- print(f"✓ Sandbox server is running at {sandbox_config['url']}")
50
- except (ConnectionError, Timeout, RequestException):
51
- pytest.skip(
52
- f"Sandbox server is not running at {sandbox_config['url']}. "
53
- "Please start it with: cd src/nat/tool/code_execution/local_sandbox && ./start_local_sandbox.sh")
54
-
55
- def execute_code(self, sandbox_config: dict[str, Any], code: str, language: str = "python") -> dict[str, Any]:
56
- """
57
- Execute code in the sandbox and return the response.
58
-
59
- Args:
60
- sandbox_config: Configuration dictionary
61
- code: Code to execute
62
- language: Programming language (default: python)
63
-
64
- Returns:
65
- dictionary containing the response from the sandbox
66
- """
67
- payload = {"generated_code": code, "timeout": sandbox_config["timeout"], "language": language}
68
-
69
- response = requests.post(
70
- sandbox_config["url"],
71
- json=payload,
72
- timeout=sandbox_config["timeout"] + 5 # Add buffer to request timeout
73
- )
74
-
75
- # Ensure we got a response
76
- response.raise_for_status()
77
- return response.json()
78
-
79
- def test_simple_print(self, sandbox_config):
80
- """Test simple print statement execution."""
81
- code = "print('Hello, World!')"
82
- result = self.execute_code(sandbox_config, code)
83
-
84
- assert result["process_status"] == "completed"
85
- assert "Hello, World!" in result["stdout"]
86
- assert result["stderr"] == ""
87
-
88
- def test_basic_arithmetic(self, sandbox_config):
89
- """Test basic arithmetic operations."""
90
- code = """
91
- result = 2 + 3
92
- print(f'Result: {result}')
93
- """
94
- result = self.execute_code(sandbox_config, code)
95
-
96
- assert result["process_status"] == "completed"
97
- assert "Result: 5" in result["stdout"]
98
- assert result["stderr"] == ""
99
-
100
- def test_numpy_operations(self, sandbox_config):
101
- """Test numpy dependency availability and operations."""
102
- code = """
103
- import numpy as np
104
- arr = np.array([1, 2, 3, 4, 5])
105
- print(f'Array: {arr}')
106
- print(f'Mean: {np.mean(arr)}')
107
- """
108
- result = self.execute_code(sandbox_config, code)
109
-
110
- assert result["process_status"] == "completed"
111
- assert "Array: [1 2 3 4 5]" in result["stdout"]
112
- assert "Mean: 3.0" in result["stdout"]
113
- assert result["stderr"] == ""
114
-
115
- def test_pandas_operations(self, sandbox_config):
116
- """Test pandas dependency availability and operations."""
117
- code = """
118
- import pandas as pd
119
- df = pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]})
120
- print(df)
121
- print(f'Sum of column A: {df["A"].sum()}')
122
- """
123
- result = self.execute_code(sandbox_config, code)
124
-
125
- assert result["process_status"] == "completed"
126
- assert "Sum of column A: 6" in result["stdout"]
127
- assert result["stderr"] == ""
128
-
129
- def test_plotly_import(self, sandbox_config):
130
- """Test plotly dependency availability."""
131
- code = """
132
- import plotly.graph_objects as go
133
- print('Plotly imported successfully')
134
- fig = go.Figure()
135
- fig.add_trace(go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
136
- print('Plot created successfully')
137
- """
138
- result = self.execute_code(sandbox_config, code)
139
-
140
- assert result["process_status"] == "completed"
141
- assert "Plotly imported successfully" in result["stdout"]
142
- assert "Plot created successfully" in result["stdout"]
143
- assert result["stderr"] == ""
144
-
145
- def test_syntax_error_handling(self, sandbox_config):
146
- """Test handling of syntax errors."""
147
- code = """
148
- print('Hello World'
149
- # Missing closing parenthesis
150
- """
151
- result = self.execute_code(sandbox_config, code)
152
-
153
- assert result["process_status"] == "error"
154
- assert "SyntaxError" in result["stderr"] or "SyntaxError" in result["stdout"]
155
-
156
- def test_runtime_error_handling(self, sandbox_config):
157
- """Test handling of runtime errors."""
158
- code = """
159
- x = 1 / 0
160
- print('This should not print')
161
- """
162
- result = self.execute_code(sandbox_config, code)
163
-
164
- assert result["process_status"] == "error"
165
- assert "ZeroDivisionError" in result["stderr"] or "ZeroDivisionError" in result["stdout"]
166
-
167
- def test_import_error_handling(self, sandbox_config):
168
- """Test handling of import errors."""
169
- code = """
170
- import nonexistent_module
171
- print('This should not print')
172
- """
173
- result = self.execute_code(sandbox_config, code)
174
-
175
- assert result["process_status"] == "error"
176
- assert "ModuleNotFoundError" in result["stderr"] or "ImportError" in result["stderr"]
177
-
178
- def test_mixed_output(self, sandbox_config):
179
- """Test code that produces both stdout and stderr output."""
180
- code = """
181
- import sys
182
- print('This goes to stdout')
183
- print('This goes to stderr', file=sys.stderr)
184
- print('Back to stdout')
185
- """
186
- result = self.execute_code(sandbox_config, code)
187
-
188
- assert result["process_status"] == "completed"
189
- assert "This goes to stdout" in result["stdout"]
190
- assert "Back to stdout" in result["stdout"]
191
- assert "This goes to stderr" in result["stderr"]
192
-
193
- def test_long_running_code(self, sandbox_config):
194
- """Test code that takes some time to execute but completes within timeout."""
195
- code = """
196
- import time
197
- for i in range(3):
198
- print(f'Iteration {i}')
199
- time.sleep(0.5)
200
- print('Completed')
201
- """
202
- result = self.execute_code(sandbox_config, code)
203
-
204
- assert result["process_status"] == "completed"
205
- assert "Iteration 0" in result["stdout"]
206
- assert "Iteration 1" in result["stdout"]
207
- assert "Iteration 2" in result["stdout"]
208
- assert "Completed" in result["stdout"]
209
- assert result["stderr"] == ""
210
-
211
- def test_file_operations(self, sandbox_config):
212
- """Test basic file operations in the sandbox."""
213
- code = """
214
- import os
215
- print(f'Current directory: {os.getcwd()}')
216
- with open('test_file.txt', 'w') as f:
217
- f.write('Hello, World!')
218
- with open('test_file.txt', 'r') as f:
219
- content = f.read()
220
- print(f'File content: {content}')
221
- os.remove('test_file.txt')
222
- print('File operations completed')
223
- """
224
- result = self.execute_code(sandbox_config, code)
225
-
226
- assert result["process_status"] == "completed"
227
- assert "File content: Hello, World!" in result["stdout"]
228
- assert "File operations completed" in result["stdout"]
229
- assert result["stderr"] == ""
230
-
231
- def test_file_persistence_create(self, sandbox_config):
232
- """Test file persistence - create various file types."""
233
- code = """
234
- import os
235
- import pandas as pd
236
- import numpy as np
237
- print('Current directory:', os.getcwd())
238
- print('Directory contents:', os.listdir('.'))
239
-
240
- # Create a test file
241
- with open('persistence_test.txt', 'w') as f:
242
- f.write('Hello from sandbox persistence test!')
243
-
244
- # Create a CSV file
245
- df = pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]})
246
- df.to_csv('persistence_test.csv', index=False)
247
-
248
- # Create a numpy array file
249
- arr = np.array([1, 2, 3, 4, 5])
250
- np.save('persistence_test.npy', arr)
251
-
252
- print('Files created:')
253
- for file in os.listdir('.'):
254
- if 'persistence_test' in file:
255
- print(' -', file)
256
- """
257
- result = self.execute_code(sandbox_config, code)
258
-
259
- assert result["process_status"] == "completed"
260
- assert "persistence_test.txt" in result["stdout"]
261
- assert "persistence_test.csv" in result["stdout"]
262
- assert "persistence_test.npy" in result["stdout"]
263
- assert result["stderr"] == ""
264
-
265
- def test_file_persistence_read(self, sandbox_config):
266
- """Test file persistence - read back created files."""
267
- code = """
268
- import pandas as pd
269
- import numpy as np
270
-
271
- # Read back the files we created
272
- print('=== Reading persistence_test.txt ===')
273
- with open('persistence_test.txt', 'r') as f:
274
- content = f.read()
275
- print(f'Content: {content}')
276
-
277
- print('\\n=== Reading persistence_test.csv ===')
278
- df = pd.read_csv('persistence_test.csv')
279
- print(df)
280
- print(f'DataFrame shape: {df.shape}')
281
-
282
- print('\\n=== Reading persistence_test.npy ===')
283
- arr = np.load('persistence_test.npy')
284
- print(f'Array: {arr}')
285
- print(f'Array sum: {np.sum(arr)}')
286
-
287
- print('\\n=== File persistence test PASSED! ===')
288
- """
289
- result = self.execute_code(sandbox_config, code)
290
-
291
- assert result["process_status"] == "completed"
292
- assert "Content: Hello from sandbox persistence test!" in result["stdout"]
293
- assert "DataFrame shape: (3, 2)" in result["stdout"]
294
- assert "Array: [1 2 3 4 5]" in result["stdout"]
295
- assert "Array sum: 15" in result["stdout"]
296
- assert "File persistence test PASSED!" in result["stdout"]
297
- assert result["stderr"] == ""
298
-
299
- def test_json_operations(self, sandbox_config):
300
- """Test JSON file operations for persistence."""
301
- code = """
302
- import json
303
- import os
304
-
305
- # Create a complex JSON file
306
- data = {
307
- 'test_name': 'sandbox_persistence',
308
- 'timestamp': '2024-07-03',
309
- 'results': {
310
- 'numpy_test': True,
311
- 'pandas_test': True,
312
- 'file_operations': True
313
- },
314
- 'metrics': [1.5, 2.3, 3.7, 4.1],
315
- 'metadata': {
316
- 'working_dir': os.getcwd(),
317
- 'python_version': '3.x'
318
- }
319
- }
320
-
321
- # Save JSON file
322
- with open('persistence_test.json', 'w') as f:
323
- json.dump(data, f, indent=2)
324
-
325
- # Read it back
326
- with open('persistence_test.json', 'r') as f:
327
- loaded_data = json.load(f)
328
-
329
- print('JSON file created and loaded successfully')
330
- print(f'Test name: {loaded_data["test_name"]}')
331
- print(f'Results count: {len(loaded_data["results"])}')
332
- print(f'Metrics: {loaded_data["metrics"]}')
333
- print('JSON persistence test completed!')
334
- """
335
- result = self.execute_code(sandbox_config, code)
336
-
337
- assert result["process_status"] == "completed"
338
- assert "JSON file created and loaded successfully" in result["stdout"]
339
- assert "Test name: sandbox_persistence" in result["stdout"]
340
- assert "Results count: 3" in result["stdout"]
341
- assert "JSON persistence test completed!" in result["stdout"]
342
- assert result["stderr"] == ""
343
-
344
- def test_missing_generated_code_field(self, sandbox_config):
345
- """Test request missing the generated_code field."""
346
- payload = {"timeout": 10, "language": "python"}
347
-
348
- response = requests.post(sandbox_config["url"], json=payload)
349
-
350
- # Should return an error status code or error in response
351
- assert response.status_code != 200 or "error" in response.json()
352
-
353
- def test_missing_timeout_field(self, sandbox_config):
354
- """Test request missing the timeout field."""
355
- payload = {"generated_code": "print('test')", "language": "python"}
356
-
357
- response = requests.post(sandbox_config["url"], json=payload)
358
-
359
- # Should return error for missing timeout field
360
- result = response.json()
361
- assert response.status_code == 400 and result["process_status"] == "error"
362
-
363
- def test_invalid_json(self, sandbox_config):
364
- """Test request with invalid JSON."""
365
- invalid_json = '{"generated_code": "print("test")", "timeout": 10}'
366
-
367
- response = requests.post(sandbox_config["url"], data=invalid_json, headers={"Content-Type": "application/json"})
368
-
369
- # Should return error for invalid JSON
370
- assert response.status_code != 200
371
-
372
- def test_non_json_request(self, sandbox_config):
373
- """Test request with non-JSON content."""
374
- response = requests.post(sandbox_config["url"], data="This is not JSON", headers={"Content-Type": "text/plain"})
375
-
376
- # Should return error for non-JSON content
377
- assert response.status_code != 200
378
-
379
- def test_timeout_too_low(self, sandbox_config):
380
- """Test request with timeout too low."""
381
- code = """
382
- import time
383
- time.sleep(2.0)
384
- """
385
- payload = {"generated_code": code, "timeout": 1, "language": "python"}
386
- response = requests.post(sandbox_config["url"], json=payload)
387
- assert response.json()["process_status"] == "timeout"
388
- assert response.status_code == 200
389
-
390
-
391
- # Pytest configuration and fixtures for command-line options
392
- def pytest_addoption(parser):
393
- """Add custom command-line options for pytest."""
394
- parser.addoption("--sandbox-url",
395
- action="store",
396
- default="http://127.0.0.1:6000/execute",
397
- help="Sandbox URL for testing")
398
- parser.addoption("--sandbox-timeout",
399
- action="store",
400
- type=int,
401
- default=30,
402
- help="Timeout in seconds for sandbox operations")
403
-
404
-
405
- @pytest.fixture(scope="session", autouse=True)
406
- def setup_environment(request):
407
- """Setup environment variables from command-line options."""
408
- os.environ["SANDBOX_URL"] = request.config.getoption("--sandbox-url", "http://127.0.0.1:6000/execute")
409
- os.environ["SANDBOX_TIMEOUT"] = str(request.config.getoption("--sandbox-timeout", 30))
410
-
411
-
412
- if __name__ == "__main__":
413
- # Allow running as a script
414
- pytest.main([__file__, "-v"])
@@ -1,133 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from pydantic import BaseModel
17
- from pydantic import Field
18
-
19
- from nat.builder.builder import Builder
20
- from nat.builder.function_info import FunctionInfo
21
- from nat.cli.register_workflow import register_function
22
- from nat.data_models.function import FunctionBaseConfig
23
-
24
-
25
- class GithubCommitCodeModel(BaseModel):
26
- branch: str = Field(description="The branch of the remote repo to which the code will be committed")
27
- commit_msg: str = Field(description="Message with which the code will be committed to the remote repo")
28
- local_path: str = Field(description="Local filepath of the file that has been updated and "
29
- "needs to be committed to the remote repo")
30
- remote_path: str = Field(description="Remote filepath of the updated file in GitHub. Path is relative to "
31
- "root of current repository")
32
-
33
-
34
- class GithubCommitCodeModelList(BaseModel):
35
- updated_files: list[GithubCommitCodeModel] = Field(description=("A list of local filepaths and commit messages"))
36
-
37
-
38
- class GithubCommitCodeConfig(FunctionBaseConfig, name="github_commit_code_tool"):
39
- """
40
- Tool that commits and pushes modified code to a remote GitHub repository asynchronously.
41
- """
42
- repo_name: str = Field(description="The repository name in the format 'owner/repo'")
43
- local_repo_dir: str = Field(description="Absolute path to the root of the repo, cloned locally")
44
- timeout: int = Field(default=300, description="The timeout configuration to use when sending requests.")
45
-
46
-
47
- @register_function(config_type=GithubCommitCodeConfig)
48
- async def commit_code_async(config: GithubCommitCodeConfig, builder: Builder):
49
- """
50
- Commits and pushes modified code to a remote GitHub repository asynchronously.
51
-
52
- """
53
- import json
54
- import os
55
-
56
- import httpx
57
-
58
- github_pat = os.getenv("GITHUB_PAT")
59
- if not github_pat:
60
- raise ValueError("GITHUB_PAT environment variable must be set")
61
-
62
- # define the headers for the payload request
63
- headers = {"Authorization": f"Bearer {github_pat}", "Accept": "application/vnd.github+json"}
64
-
65
- async def _github_commit_code(updated_files) -> list:
66
- results = []
67
- async with httpx.AsyncClient(timeout=config.timeout) as client:
68
- for file_ in updated_files:
69
- branch = file_.branch
70
- commit_msg = file_.commit_msg
71
- local_path = file_.local_path
72
- remote_path = file_.remote_path
73
-
74
- # Read content from the local file
75
- local_path = os.path.join(config.local_repo_dir, local_path)
76
- with open(local_path, 'r', encoding='utf-8', errors='ignore') as f:
77
- content = f.read()
78
-
79
- # Step 1. Create a blob with the updated contents of the file
80
- blob_url = f'https://api.github.com/repos/{config.repo_name}/git/blobs'
81
- blob_data = {'content': content, 'encoding': 'utf-8'}
82
- blob_response = await client.request("POST", blob_url, json=blob_data, headers=headers)
83
- blob_response.raise_for_status()
84
- blob_sha = blob_response.json()['sha']
85
-
86
- # Step 2: Get the base tree SHA. The commit will be pushed to this ref node in the Git graph
87
- ref_url = f'https://api.github.com/repos/{config.repo_name}/git/refs/heads/{branch}'
88
- ref_response = await client.request("GET", ref_url, headers=headers)
89
- ref_response.raise_for_status()
90
- base_tree_sha = ref_response.json()['object']['sha']
91
-
92
- # Step 3. Create an updated tree (Git graph) with the new blob
93
- tree_url = f'https://api.github.com/repos/{config.repo_name}/git/trees'
94
- tree_data = {
95
- 'base_tree': base_tree_sha,
96
- 'tree': [{
97
- 'path': remote_path, 'mode': '100644', 'type': 'blob', 'sha': blob_sha
98
- }]
99
- }
100
- tree_response = await client.request("POST", tree_url, json=tree_data, headers=headers)
101
- tree_response.raise_for_status()
102
- tree_sha = tree_response.json()['sha']
103
-
104
- # Step 4: Create a commit
105
- commit_url = f'https://api.github.com/repos/{config.repo_name}/git/commits'
106
- commit_data = {'message': commit_msg, 'tree': tree_sha, 'parents': [base_tree_sha]}
107
- commit_response = await client.request("POST", commit_url, json=commit_data, headers=headers)
108
- commit_response.raise_for_status()
109
- commit_sha = commit_response.json()['sha']
110
-
111
- # Step 5: Update the reference in the Git graph
112
- update_ref_url = f'https://api.github.com/repos/{config.repo_name}/git/refs/heads/{branch}'
113
- update_ref_data = {'sha': commit_sha}
114
- update_ref_response = await client.request("PATCH",
115
- update_ref_url,
116
- json=update_ref_data,
117
- headers=headers)
118
- update_ref_response.raise_for_status()
119
-
120
- payload_responses = {
121
- 'blob_resp': blob_response.json(),
122
- 'original_tree_ref': tree_response.json(),
123
- 'commit_resp': commit_response.json(),
124
- 'updated_tree_ref_resp': update_ref_response.json()
125
- }
126
- results.append(payload_responses)
127
-
128
- return json.dumps(results)
129
-
130
- yield FunctionInfo.from_fn(_github_commit_code,
131
- description=(f"Commits and pushes modified code to a "
132
- f"GitHub repository in the repo named {config.repo_name}"),
133
- input_schema=GithubCommitCodeModelList)