aiqtoolkit 1.1.0a20250429__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aiqtoolkit might be problematic. Click here for more details.
- aiq/agent/__init__.py +0 -0
- aiq/agent/base.py +76 -0
- aiq/agent/dual_node.py +67 -0
- aiq/agent/react_agent/__init__.py +0 -0
- aiq/agent/react_agent/agent.py +322 -0
- aiq/agent/react_agent/output_parser.py +104 -0
- aiq/agent/react_agent/prompt.py +46 -0
- aiq/agent/react_agent/register.py +148 -0
- aiq/agent/reasoning_agent/__init__.py +0 -0
- aiq/agent/reasoning_agent/reasoning_agent.py +224 -0
- aiq/agent/register.py +23 -0
- aiq/agent/rewoo_agent/__init__.py +0 -0
- aiq/agent/rewoo_agent/agent.py +410 -0
- aiq/agent/rewoo_agent/prompt.py +108 -0
- aiq/agent/rewoo_agent/register.py +158 -0
- aiq/agent/tool_calling_agent/__init__.py +0 -0
- aiq/agent/tool_calling_agent/agent.py +123 -0
- aiq/agent/tool_calling_agent/register.py +105 -0
- aiq/builder/__init__.py +0 -0
- aiq/builder/builder.py +223 -0
- aiq/builder/component_utils.py +303 -0
- aiq/builder/context.py +198 -0
- aiq/builder/embedder.py +24 -0
- aiq/builder/eval_builder.py +116 -0
- aiq/builder/evaluator.py +29 -0
- aiq/builder/framework_enum.py +24 -0
- aiq/builder/front_end.py +73 -0
- aiq/builder/function.py +297 -0
- aiq/builder/function_base.py +372 -0
- aiq/builder/function_info.py +627 -0
- aiq/builder/intermediate_step_manager.py +125 -0
- aiq/builder/llm.py +25 -0
- aiq/builder/retriever.py +25 -0
- aiq/builder/user_interaction_manager.py +71 -0
- aiq/builder/workflow.py +134 -0
- aiq/builder/workflow_builder.py +733 -0
- aiq/cli/__init__.py +14 -0
- aiq/cli/cli_utils/__init__.py +0 -0
- aiq/cli/cli_utils/config_override.py +233 -0
- aiq/cli/cli_utils/validation.py +37 -0
- aiq/cli/commands/__init__.py +0 -0
- aiq/cli/commands/configure/__init__.py +0 -0
- aiq/cli/commands/configure/channel/__init__.py +0 -0
- aiq/cli/commands/configure/channel/add.py +28 -0
- aiq/cli/commands/configure/channel/channel.py +34 -0
- aiq/cli/commands/configure/channel/remove.py +30 -0
- aiq/cli/commands/configure/channel/update.py +30 -0
- aiq/cli/commands/configure/configure.py +33 -0
- aiq/cli/commands/evaluate.py +139 -0
- aiq/cli/commands/info/__init__.py +14 -0
- aiq/cli/commands/info/info.py +37 -0
- aiq/cli/commands/info/list_channels.py +32 -0
- aiq/cli/commands/info/list_components.py +129 -0
- aiq/cli/commands/registry/__init__.py +14 -0
- aiq/cli/commands/registry/publish.py +88 -0
- aiq/cli/commands/registry/pull.py +118 -0
- aiq/cli/commands/registry/registry.py +36 -0
- aiq/cli/commands/registry/remove.py +108 -0
- aiq/cli/commands/registry/search.py +155 -0
- aiq/cli/commands/start.py +250 -0
- aiq/cli/commands/uninstall.py +83 -0
- aiq/cli/commands/validate.py +47 -0
- aiq/cli/commands/workflow/__init__.py +14 -0
- aiq/cli/commands/workflow/templates/__init__.py.j2 +0 -0
- aiq/cli/commands/workflow/templates/config.yml.j2 +16 -0
- aiq/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
- aiq/cli/commands/workflow/templates/register.py.j2 +5 -0
- aiq/cli/commands/workflow/templates/workflow.py.j2 +36 -0
- aiq/cli/commands/workflow/workflow.py +37 -0
- aiq/cli/commands/workflow/workflow_commands.py +307 -0
- aiq/cli/entrypoint.py +133 -0
- aiq/cli/main.py +44 -0
- aiq/cli/register_workflow.py +408 -0
- aiq/cli/type_registry.py +869 -0
- aiq/data_models/__init__.py +14 -0
- aiq/data_models/api_server.py +550 -0
- aiq/data_models/common.py +143 -0
- aiq/data_models/component.py +46 -0
- aiq/data_models/component_ref.py +135 -0
- aiq/data_models/config.py +349 -0
- aiq/data_models/dataset_handler.py +122 -0
- aiq/data_models/discovery_metadata.py +269 -0
- aiq/data_models/embedder.py +26 -0
- aiq/data_models/evaluate.py +101 -0
- aiq/data_models/evaluator.py +26 -0
- aiq/data_models/front_end.py +26 -0
- aiq/data_models/function.py +30 -0
- aiq/data_models/function_dependencies.py +64 -0
- aiq/data_models/interactive.py +237 -0
- aiq/data_models/intermediate_step.py +269 -0
- aiq/data_models/invocation_node.py +38 -0
- aiq/data_models/llm.py +26 -0
- aiq/data_models/logging.py +26 -0
- aiq/data_models/memory.py +26 -0
- aiq/data_models/profiler.py +53 -0
- aiq/data_models/registry_handler.py +26 -0
- aiq/data_models/retriever.py +30 -0
- aiq/data_models/step_adaptor.py +64 -0
- aiq/data_models/streaming.py +33 -0
- aiq/data_models/swe_bench_model.py +54 -0
- aiq/data_models/telemetry_exporter.py +26 -0
- aiq/embedder/__init__.py +0 -0
- aiq/embedder/langchain_client.py +41 -0
- aiq/embedder/nim_embedder.py +58 -0
- aiq/embedder/openai_embedder.py +42 -0
- aiq/embedder/register.py +24 -0
- aiq/eval/__init__.py +14 -0
- aiq/eval/config.py +42 -0
- aiq/eval/dataset_handler/__init__.py +0 -0
- aiq/eval/dataset_handler/dataset_downloader.py +106 -0
- aiq/eval/dataset_handler/dataset_filter.py +52 -0
- aiq/eval/dataset_handler/dataset_handler.py +164 -0
- aiq/eval/evaluate.py +322 -0
- aiq/eval/evaluator/__init__.py +14 -0
- aiq/eval/evaluator/evaluator_model.py +44 -0
- aiq/eval/intermediate_step_adapter.py +93 -0
- aiq/eval/rag_evaluator/__init__.py +0 -0
- aiq/eval/rag_evaluator/evaluate.py +138 -0
- aiq/eval/rag_evaluator/register.py +138 -0
- aiq/eval/register.py +22 -0
- aiq/eval/remote_workflow.py +128 -0
- aiq/eval/runtime_event_subscriber.py +52 -0
- aiq/eval/swe_bench_evaluator/__init__.py +0 -0
- aiq/eval/swe_bench_evaluator/evaluate.py +215 -0
- aiq/eval/swe_bench_evaluator/register.py +36 -0
- aiq/eval/trajectory_evaluator/__init__.py +0 -0
- aiq/eval/trajectory_evaluator/evaluate.py +118 -0
- aiq/eval/trajectory_evaluator/register.py +40 -0
- aiq/eval/utils/__init__.py +0 -0
- aiq/eval/utils/output_uploader.py +131 -0
- aiq/eval/utils/tqdm_position_registry.py +40 -0
- aiq/front_ends/__init__.py +14 -0
- aiq/front_ends/console/__init__.py +14 -0
- aiq/front_ends/console/console_front_end_config.py +32 -0
- aiq/front_ends/console/console_front_end_plugin.py +107 -0
- aiq/front_ends/console/register.py +25 -0
- aiq/front_ends/cron/__init__.py +14 -0
- aiq/front_ends/fastapi/__init__.py +14 -0
- aiq/front_ends/fastapi/fastapi_front_end_config.py +150 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin.py +103 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin_worker.py +574 -0
- aiq/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
- aiq/front_ends/fastapi/job_store.py +161 -0
- aiq/front_ends/fastapi/main.py +70 -0
- aiq/front_ends/fastapi/message_handler.py +279 -0
- aiq/front_ends/fastapi/message_validator.py +345 -0
- aiq/front_ends/fastapi/register.py +25 -0
- aiq/front_ends/fastapi/response_helpers.py +181 -0
- aiq/front_ends/fastapi/step_adaptor.py +315 -0
- aiq/front_ends/fastapi/websocket.py +148 -0
- aiq/front_ends/mcp/__init__.py +14 -0
- aiq/front_ends/mcp/mcp_front_end_config.py +32 -0
- aiq/front_ends/mcp/mcp_front_end_plugin.py +93 -0
- aiq/front_ends/mcp/register.py +27 -0
- aiq/front_ends/mcp/tool_converter.py +242 -0
- aiq/front_ends/register.py +22 -0
- aiq/front_ends/simple_base/__init__.py +14 -0
- aiq/front_ends/simple_base/simple_front_end_plugin_base.py +52 -0
- aiq/llm/__init__.py +0 -0
- aiq/llm/nim_llm.py +45 -0
- aiq/llm/openai_llm.py +45 -0
- aiq/llm/register.py +22 -0
- aiq/llm/utils/__init__.py +14 -0
- aiq/llm/utils/env_config_value.py +94 -0
- aiq/llm/utils/error.py +17 -0
- aiq/memory/__init__.py +20 -0
- aiq/memory/interfaces.py +183 -0
- aiq/memory/models.py +102 -0
- aiq/meta/module_to_distro.json +3 -0
- aiq/meta/pypi.md +59 -0
- aiq/observability/__init__.py +0 -0
- aiq/observability/async_otel_listener.py +270 -0
- aiq/observability/register.py +97 -0
- aiq/plugins/.namespace +1 -0
- aiq/profiler/__init__.py +0 -0
- aiq/profiler/callbacks/__init__.py +0 -0
- aiq/profiler/callbacks/agno_callback_handler.py +295 -0
- aiq/profiler/callbacks/base_callback_class.py +20 -0
- aiq/profiler/callbacks/langchain_callback_handler.py +278 -0
- aiq/profiler/callbacks/llama_index_callback_handler.py +205 -0
- aiq/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
- aiq/profiler/callbacks/token_usage_base_model.py +27 -0
- aiq/profiler/data_frame_row.py +51 -0
- aiq/profiler/decorators/__init__.py +0 -0
- aiq/profiler/decorators/framework_wrapper.py +131 -0
- aiq/profiler/decorators/function_tracking.py +254 -0
- aiq/profiler/forecasting/__init__.py +0 -0
- aiq/profiler/forecasting/config.py +18 -0
- aiq/profiler/forecasting/model_trainer.py +75 -0
- aiq/profiler/forecasting/models/__init__.py +22 -0
- aiq/profiler/forecasting/models/forecasting_base_model.py +40 -0
- aiq/profiler/forecasting/models/linear_model.py +196 -0
- aiq/profiler/forecasting/models/random_forest_regressor.py +268 -0
- aiq/profiler/inference_metrics_model.py +25 -0
- aiq/profiler/inference_optimization/__init__.py +0 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +452 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
- aiq/profiler/inference_optimization/data_models.py +386 -0
- aiq/profiler/inference_optimization/experimental/__init__.py +0 -0
- aiq/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
- aiq/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
- aiq/profiler/inference_optimization/llm_metrics.py +212 -0
- aiq/profiler/inference_optimization/prompt_caching.py +163 -0
- aiq/profiler/inference_optimization/token_uniqueness.py +107 -0
- aiq/profiler/inference_optimization/workflow_runtimes.py +72 -0
- aiq/profiler/intermediate_property_adapter.py +102 -0
- aiq/profiler/profile_runner.py +433 -0
- aiq/profiler/utils.py +184 -0
- aiq/registry_handlers/__init__.py +0 -0
- aiq/registry_handlers/local/__init__.py +0 -0
- aiq/registry_handlers/local/local_handler.py +176 -0
- aiq/registry_handlers/local/register_local.py +37 -0
- aiq/registry_handlers/metadata_factory.py +60 -0
- aiq/registry_handlers/package_utils.py +198 -0
- aiq/registry_handlers/pypi/__init__.py +0 -0
- aiq/registry_handlers/pypi/pypi_handler.py +251 -0
- aiq/registry_handlers/pypi/register_pypi.py +40 -0
- aiq/registry_handlers/register.py +21 -0
- aiq/registry_handlers/registry_handler_base.py +157 -0
- aiq/registry_handlers/rest/__init__.py +0 -0
- aiq/registry_handlers/rest/register_rest.py +56 -0
- aiq/registry_handlers/rest/rest_handler.py +237 -0
- aiq/registry_handlers/schemas/__init__.py +0 -0
- aiq/registry_handlers/schemas/headers.py +42 -0
- aiq/registry_handlers/schemas/package.py +68 -0
- aiq/registry_handlers/schemas/publish.py +63 -0
- aiq/registry_handlers/schemas/pull.py +81 -0
- aiq/registry_handlers/schemas/remove.py +36 -0
- aiq/registry_handlers/schemas/search.py +91 -0
- aiq/registry_handlers/schemas/status.py +47 -0
- aiq/retriever/__init__.py +0 -0
- aiq/retriever/interface.py +37 -0
- aiq/retriever/milvus/__init__.py +14 -0
- aiq/retriever/milvus/register.py +81 -0
- aiq/retriever/milvus/retriever.py +228 -0
- aiq/retriever/models.py +74 -0
- aiq/retriever/nemo_retriever/__init__.py +14 -0
- aiq/retriever/nemo_retriever/register.py +60 -0
- aiq/retriever/nemo_retriever/retriever.py +190 -0
- aiq/retriever/register.py +22 -0
- aiq/runtime/__init__.py +14 -0
- aiq/runtime/loader.py +188 -0
- aiq/runtime/runner.py +176 -0
- aiq/runtime/session.py +116 -0
- aiq/settings/__init__.py +0 -0
- aiq/settings/global_settings.py +318 -0
- aiq/test/.namespace +1 -0
- aiq/tool/__init__.py +0 -0
- aiq/tool/code_execution/__init__.py +0 -0
- aiq/tool/code_execution/code_sandbox.py +188 -0
- aiq/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
- aiq/tool/code_execution/local_sandbox/__init__.py +13 -0
- aiq/tool/code_execution/local_sandbox/local_sandbox_server.py +79 -0
- aiq/tool/code_execution/local_sandbox/sandbox.requirements.txt +4 -0
- aiq/tool/code_execution/local_sandbox/start_local_sandbox.sh +25 -0
- aiq/tool/code_execution/register.py +70 -0
- aiq/tool/code_execution/utils.py +100 -0
- aiq/tool/datetime_tools.py +42 -0
- aiq/tool/document_search.py +141 -0
- aiq/tool/github_tools/__init__.py +0 -0
- aiq/tool/github_tools/create_github_commit.py +133 -0
- aiq/tool/github_tools/create_github_issue.py +87 -0
- aiq/tool/github_tools/create_github_pr.py +106 -0
- aiq/tool/github_tools/get_github_file.py +106 -0
- aiq/tool/github_tools/get_github_issue.py +166 -0
- aiq/tool/github_tools/get_github_pr.py +256 -0
- aiq/tool/github_tools/update_github_issue.py +100 -0
- aiq/tool/mcp/__init__.py +14 -0
- aiq/tool/mcp/mcp_client.py +220 -0
- aiq/tool/mcp/mcp_tool.py +75 -0
- aiq/tool/memory_tools/__init__.py +0 -0
- aiq/tool/memory_tools/add_memory_tool.py +67 -0
- aiq/tool/memory_tools/delete_memory_tool.py +67 -0
- aiq/tool/memory_tools/get_memory_tool.py +72 -0
- aiq/tool/nvidia_rag.py +95 -0
- aiq/tool/register.py +36 -0
- aiq/tool/retriever.py +89 -0
- aiq/utils/__init__.py +0 -0
- aiq/utils/data_models/__init__.py +0 -0
- aiq/utils/data_models/schema_validator.py +58 -0
- aiq/utils/debugging_utils.py +43 -0
- aiq/utils/exception_handlers/__init__.py +0 -0
- aiq/utils/exception_handlers/schemas.py +114 -0
- aiq/utils/io/__init__.py +0 -0
- aiq/utils/io/yaml_tools.py +50 -0
- aiq/utils/metadata_utils.py +74 -0
- aiq/utils/producer_consumer_queue.py +178 -0
- aiq/utils/reactive/__init__.py +0 -0
- aiq/utils/reactive/base/__init__.py +0 -0
- aiq/utils/reactive/base/observable_base.py +65 -0
- aiq/utils/reactive/base/observer_base.py +55 -0
- aiq/utils/reactive/base/subject_base.py +79 -0
- aiq/utils/reactive/observable.py +59 -0
- aiq/utils/reactive/observer.py +76 -0
- aiq/utils/reactive/subject.py +131 -0
- aiq/utils/reactive/subscription.py +49 -0
- aiq/utils/settings/__init__.py +0 -0
- aiq/utils/settings/global_settings.py +197 -0
- aiq/utils/type_converter.py +232 -0
- aiq/utils/type_utils.py +397 -0
- aiq/utils/url_utils.py +27 -0
- aiqtoolkit-1.1.0a20250429.dist-info/METADATA +326 -0
- aiqtoolkit-1.1.0a20250429.dist-info/RECORD +309 -0
- aiqtoolkit-1.1.0a20250429.dist-info/WHEEL +5 -0
- aiqtoolkit-1.1.0a20250429.dist-info/entry_points.txt +17 -0
- aiqtoolkit-1.1.0a20250429.dist-info/licenses/LICENSE-3rd-party.txt +3686 -0
- aiqtoolkit-1.1.0a20250429.dist-info/licenses/LICENSE.md +201 -0
- aiqtoolkit-1.1.0a20250429.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright (c) 2024-2025, NVIDIA CORPORATION. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
# Copyright (c) 2024-2025, NVIDIA CORPORATION. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import logging
|
|
16
|
+
import multiprocessing
|
|
17
|
+
import resource
|
|
18
|
+
import sys
|
|
19
|
+
from io import StringIO
|
|
20
|
+
|
|
21
|
+
from flask import Flask
|
|
22
|
+
from flask import request
|
|
23
|
+
|
|
24
|
+
# Flask application serving the local code-execution sandbox endpoints.
app = Flask(__name__)
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@app.after_request
def add_hsts_header(response):
    """Attach an HSTS header to every outgoing response."""
    hsts_value = 'max-age=31536000; includeSubDomains'
    response.headers['Strict-Transport-Security'] = hsts_value
    return response
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def execute_python(generated_code, timeout):
    """Run *generated_code* in a fresh subprocess, enforcing *timeout* seconds.

    A separate process is used so that crashes in the user code (segfaults,
    OOM kills, ...) cannot take down the server itself.
    """
    result_queue = multiprocessing.Queue()
    worker = multiprocessing.Process(target=execute_code_subprocess, args=(generated_code, result_queue))
    worker.start()
    worker.join(timeout=timeout)

    if not worker.is_alive():
        return result_queue.get()

    # Still running past the deadline: kill it and report a timeout.
    worker.kill()
    return {"process_status": "timeout", "stdout": "", "stderr": "Timed out\n"}
|
45
|
+
|
|
46
|
+
|
|
47
|
+
# need to memory-limit to avoid common errors of allocating too much
# but this has to be done in a subprocess to not crush server itself
def execute_code_subprocess(generated_code, queue):
    """Execute *generated_code* under a memory cap and put a result dict on *queue*.

    The result always has the shape ``{"process_status", "stdout", "stderr"}``
    so that callers (and the HTTP client) can rely on a single schema.  The
    original implementation put the bare stdout string on success, which was
    inconsistent with the timeout/error paths.
    """
    limit = 1024 * 1024 * 1024 * 10  # 10gb - somehow with a smaller limit the server dies when numpy is used
    resource.setrlimit(resource.RLIMIT_AS, (limit, limit))
    resource.setrlimit(resource.RLIMIT_DATA, (limit, limit))

    # this can be overriden inside generated code, so it's not a guaranteed protection
    sys.stdout = StringIO()
    try:
        exec(generated_code, {})  # pylint: disable=W0122
        queue.put({"process_status": "completed", "stdout": sys.stdout.getvalue(), "stderr": ""})
    except Exception as e:
        # No print here: stdout is redirected to the StringIO above, so a
        # print would be silently discarded anyway.
        queue.put({"process_status": "error", "stdout": "", "stderr": str(e) + "\n"})
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# Main Flask endpoint to handle execution requests
@app.route("/execute", methods=["POST"])
def execute():
    """Dispatch a code-execution request to the matching language runner."""
    payload = request.json
    generated_code = payload['generated_code']
    timeout = payload['timeout']
    language = payload.get('language', 'python')

    # Only Python is implemented; everything else is rejected up front.
    if language != 'python':
        return {"process_status": "error", "stdout": "", "stderr": "Only python execution is supported"}
    return execute_python(generated_code, timeout)
|
74
|
+
|
|
75
|
+
|
|
76
|
+
if __name__ == '__main__':
    # Quieten werkzeug's per-request logging; only warnings and above.
    werkzeug_logger = logging.getLogger('werkzeug')
    werkzeug_logger.setLevel(logging.WARNING)
    app.run(port=6000)
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
#!/bin/bash
|
|
2
|
+
|
|
3
|
+
# Copyright (c) 2024-2025, NVIDIA CORPORATION. All rights reserved.
|
|
4
|
+
#
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
|
|
17
|
+
# NOTE: needs to run from the root of the repo!

# Image/container name; defaults to 'local-sandbox' but can be overridden by
# the first positional argument.
SANDBOX_NAME=${1:-'local-sandbox'}
NUM_THREADS=10


docker build --tag=${SANDBOX_NAME} --build-arg="UWSGI_PROCESSES=$((${NUM_THREADS} * 10))" --build-arg="UWSGI_CHEAPER=${NUM_THREADS}" -f Dockerfile.sandbox .

# Use ${SANDBOX_NAME} for the container name as well; it was previously
# hard-coded to 'local-sandbox', which silently ignored the override above.
docker run --network=host --rm --name=${SANDBOX_NAME} ${SANDBOX_NAME}
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
from typing import Literal
|
|
18
|
+
|
|
19
|
+
from pydantic import BaseModel
|
|
20
|
+
from pydantic import Field
|
|
21
|
+
from pydantic import HttpUrl
|
|
22
|
+
|
|
23
|
+
from aiq.builder.builder import Builder
|
|
24
|
+
from aiq.builder.function_info import FunctionInfo
|
|
25
|
+
from aiq.cli.register_workflow import register_function
|
|
26
|
+
from aiq.data_models.function import FunctionBaseConfig
|
|
27
|
+
|
|
28
|
+
# Module-level logger for the code-execution tool.
logger = logging.getLogger(__name__)
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class CodeExecutionToolConfig(FunctionBaseConfig, name="code_execution"):
    """
    Tool for executing python code in a remotely hosted sandbox environment.
    """
    # Base URI of the sandbox HTTP server that will run the submitted code.
    uri: HttpUrl = Field(default=HttpUrl("http://127.0.0.1:6000"),
                         description="URI for the code execution sandbox server")
    # Sandbox backend to talk to; "local" is the bundled Flask sandbox server.
    sandbox_type: Literal["local", "piston"] = Field(default="local", description="The type of code execution sandbox")
    # Per-request execution deadline, in seconds.
    timeout: float = Field(default=10.0, description="Number of seconds to wait for a code execution request")
    # Cap on the size of the returned output.
    max_output_characters: int = Field(default=1000, description="Maximum number of characters that can be returned")
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@register_function(config_type=CodeExecutionToolConfig)
async def code_execution_tool(config: CodeExecutionToolConfig, builder: Builder):
    """Register the code-execution function backed by a remote sandbox.

    Builds a sandbox client from *config* and yields a FunctionInfo whose
    callable submits code to the sandbox and returns the execution result.
    """
    from aiq.tool.code_execution.code_sandbox import get_sandbox

    class CodeExecutionInputSchema(BaseModel):
        generated_code: str = Field(description="String containing the code to be executed")

    sandbox = get_sandbox(sandbox_type=config.sandbox_type, uri=config.uri)

    async def _execute_code(generated_code: str) -> dict:
        logger.info("Executing code in the sandbox at %s", config.uri)
        try:
            output = await sandbox.execute_code(
                generated_code=generated_code,
                language="python",
                timeout=config.timeout,
                max_output_characters=config.max_output_characters,
            )
        except Exception as e:
            logger.exception("Error when executing code in the sandbox, %s", e)
            # Stringify the exception: the other result paths carry string
            # stderr values and an Exception instance is not JSON-serializable.
            return {"process_status": "error", "stdout": "", "stderr": str(e)}
        return output

    # NOTE(review): the description mentions a session_id; whether the sandbox
    # response actually contains one depends on the sandbox implementation —
    # confirm against code_sandbox.py.  ("provied" typo fixed to "provided".)
    yield FunctionInfo.from_fn(
        fn=_execute_code,
        input_schema=CodeExecutionInputSchema,
        description="""Executes the provided 'generated_code' in a python sandbox environment and returns
        a dictionary containing stdout, stderr, and the execution status, as well as a session_id. The
        session_id can be used to append to code that was previously executed.""")
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
# Copyright (c) 2024-2025, NVIDIA CORPORATION. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import glob
|
|
16
|
+
import logging
|
|
17
|
+
import re
|
|
18
|
+
|
|
19
|
+
# Module-level logger (this module uses the LOG naming convention).
LOG = logging.getLogger(__name__)
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def format_code_output(execution_dict: dict[str, str],
                       code_output_begin: str,
                       code_output_end: str,
                       code_output_format: str = 'llama'):
    """Formatting code output to be displayed as an llm expects it."""
    stdout = execution_dict['stdout']
    stderr = execution_dict['stderr']

    if code_output_format == 'llama':
        # Status line, then optional [stdout]/[stderr] sections, wrapped in
        # the caller-provided separators.
        sections = [execution_dict["process_status"]]
        if stdout:
            sections.append(f"\n[stdout]\n{stdout}[/stdout]")
        if stderr:
            sections.append(f"\n[stderr]\n{stderr}[/stderr]")
        body = "".join(sections)
        return f"{code_output_begin}\n\n{body}{code_output_end}\n\n"

    if code_output_format == 'qwen':
        body = ""
        if stdout:
            body += f"{stdout}"
        if stderr:
            body += f"{stderr}"
        if stderr and stdout:
            LOG.warning("Both stdout and stderr are not empty. This shouldn't normally happen! %s", execution_dict)
        return f"{code_output_begin}{body}{code_output_end}"

    raise ValueError(f"Unknown code_output_format: {code_output_format}")
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _extract_between_separators(generation: str, separators: tuple[str, str], extract_all: bool = False):
|
|
51
|
+
"""Extracting all text between last occurrence of separators[0] and [1].
|
|
52
|
+
|
|
53
|
+
If extract_all is True, returning a list with all occurrences of text between separators.
|
|
54
|
+
"""
|
|
55
|
+
if extract_all:
|
|
56
|
+
separators = [re.escape(sp) for sp in separators]
|
|
57
|
+
pattern = f'{separators[0]}(.*?){separators[1]}'
|
|
58
|
+
return re.findall(pattern, generation, re.DOTALL)
|
|
59
|
+
return generation.split(separators[0])[-1].split(separators[1])[0]
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def extract_code_to_execute(generation: str, code_begin: str, code_end: str, extract_all: bool = False):
    """Pull the code between ``code_begin``/``code_end`` markers out of *generation*."""
    return _extract_between_separators(generation, (code_begin, code_end), extract_all)
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def extract_code_output(generation: str, code_output_begin: str, code_output_end: str, extract_all: bool = False):
    """Pull previously-rendered code output from between its separator markers."""
    return _extract_between_separators(generation, (code_output_begin, code_output_end), extract_all)
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def unroll_files(input_files):
    """Yield every file matching the given glob patterns, sorted within each pattern.

    Raises ValueError when no patterns are supplied or when nothing matches.
    Note: as a generator, the errors surface only once iteration starts.
    """
    if not input_files:
        raise ValueError("No files found with the given pattern.")
    matched_any = False
    for pattern in input_files:
        for path in sorted(glob.glob(pattern, recursive=True)):
            matched_any = True
            yield path
    if not matched_any:
        raise ValueError("No files found with the given pattern.")
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def python_doc_to_cmd_help(doc_class, docs_prefix="", arg_prefix=""):
    """Converts python doc to cmd help format.

    Will color the args and change the format to match what we use in cmd help.

    Takes the text after "Args:" in ``doc_class.__doc__``, wraps each argument
    name in ANSI green, optionally prefixes it with ``arg_prefix``, and returns
    the recolored text (without the trailing newline).
    """
    # NOTE(review): the space-run string literals below appear collapsed by the
    # extraction that produced this view (several replace() calls show
    # identical-looking arguments, which would be no-ops) — verify the literal
    # widths against the upstream source before relying on this code.
    all_args = docs_prefix
    all_args += doc_class.__doc__.split("Args:")[1].rstrip()
    # \033[92m ... \033[0m - green in terminal
    colored_args = ""
    for line in all_args.split("\n"):
        if " " in line and " - " in line:
            # add colors
            line = line.replace(" ", " \033[92m").replace(" - ", "\033[0m - ")
            # fixing arg format
            line = line.replace(' \033[92m', f' \033[92m{arg_prefix}')
            # fixing indent
            line = line.replace(" ", " ").replace(" ", " ")
        colored_args += line + '\n'
    # Drop the final newline added by the loop.
    return colored_args[:-1]
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from aiq.builder.builder import Builder
|
|
17
|
+
from aiq.builder.function_info import FunctionInfo
|
|
18
|
+
from aiq.cli.register_workflow import register_function
|
|
19
|
+
from aiq.data_models.function import FunctionBaseConfig
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class CurrentTimeToolConfig(FunctionBaseConfig, name="current_datetime"):
    """
    Simple tool which returns the current date and time in human readable format.
    """
    # No configurable options.  NOTE(review): the docstring above may surface
    # as the tool description in component discovery — keep it accurate.
    pass
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@register_function(config_type=CurrentTimeToolConfig)
async def current_datetime(config: CurrentTimeToolConfig, builder: Builder):
    """Yield a FunctionInfo exposing the current local date and time."""

    import datetime

    async def _get_current_time(unused: str) -> str:
        # Local (naive) time, rendered as YYYY-MM-DD HH:MM:SS.
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        return f"The current time of day is {timestamp}"

    yield FunctionInfo.from_fn(_get_current_time,
                               description="Returns the current date and time in human readable format.")
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import json
|
|
17
|
+
import logging
|
|
18
|
+
|
|
19
|
+
from pydantic import Field
|
|
20
|
+
|
|
21
|
+
from aiq.builder.builder import Builder
|
|
22
|
+
from aiq.builder.framework_enum import LLMFrameworkEnum
|
|
23
|
+
from aiq.builder.function_info import FunctionInfo
|
|
24
|
+
from aiq.cli.register_workflow import register_function
|
|
25
|
+
from aiq.data_models.component_ref import LLMRef
|
|
26
|
+
from aiq.data_models.function import FunctionBaseConfig
|
|
27
|
+
|
|
28
|
+
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class MilvusDocumentSearchToolConfig(FunctionBaseConfig, name="milvus_document_search"):
    """
    This tool retrieves relevant documents for a given user query. The input query is mapped to the most appropriate
    Milvus collection database. This will return relevant documents from the selected collection.
    """
    # Base URL of the RAG/chain server fronting Milvus; requests are sent to f"{base_url}/search".
    base_url: str = Field(description="The base url used to connect to the milvus database.")
    # Number of chunks requested per search.
    top_k: int = Field(default=4, description="The number of results to return from the milvus database.")
    # HTTP timeout (seconds, per httpx semantics) for the search request.
    timeout: int = Field(default=60, description="The timeout configuration to use when sending requests.")
    llm_name: LLMRef = Field(description=("The name of the llm client to instantiate to determine most appropriate "
                                          "milvus collection."))
    # Paired by index position with collection_descriptions.
    collection_names: list[str] = Field(default=["nvidia_api_catalog"],
                                        description="The list of available collection names.")
    collection_descriptions: list[str] = Field(default=["Documents about NVIDIA's product catalog"],
                                               description=("Collection descriptions that map to collection names by "
                                                            "index position."))
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@register_function(config_type=MilvusDocumentSearchToolConfig)
async def document_search(config: MilvusDocumentSearchToolConfig, builder: Builder):
    """Register a document-search tool backed by a Milvus RAG endpoint.

    An LLM first selects the most appropriate collection for the incoming
    query, then the query is forwarded to the configured ``/search`` endpoint
    and the returned chunks are formatted into a single document string.

    Args:
        config: Connection, collection, and LLM settings for the tool.
        builder: Used to instantiate the configured LLM client.

    Yields:
        FunctionInfo wrapping the inner ``_document_search`` coroutine.
    """
    from typing import Literal

    import httpx
    from langchain_core.messages import HumanMessage
    from langchain_core.messages import SystemMessage
    # NOTE(review): langchain_core.pydantic_v1 is deprecated in recent
    # langchain releases; migrate to plain pydantic when upgrading.
    from langchain_core.pydantic_v1 import BaseModel
    from langchain_core.pydantic_v1 import Field  # pylint: disable=redefined-outer-name, reimported

    # Map each collection name to its description (paired by index position).
    collection_store = dict(zip(config.collection_names, config.collection_descriptions))

    # Structured-output schema: only the configured collection names are valid.
    class CollectionName(BaseModel):
        collection_name: Literal[tuple(
            config.collection_names)] = Field(description="The appropriate milvus collection name for the question.")

    class DocumentSearchOutput(BaseModel):
        collection_name: str
        documents: str

    # System prompt instructing the LLM to pick exactly one collection.
    prompt_template = f"""You are an agent that helps users find the right Milvus collection based on the question.
    Here are the available list of collections (formatted as collection_name: collection_description): \n
    ({collection_store})
    \nFirst, analyze the available collections and their descriptions.
    Then, select the most appropriate collection for the user's query.
    Return only the name of the predicted collection."""

    # Instantiate the LLM once at registration time instead of on every call.
    llm = await builder.get_llm(llm_name=config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
    structured_llm = llm.with_structured_output(CollectionName)

    async with httpx.AsyncClient(headers={
            "accept": "application/json", "Content-Type": "application/json"
    },
                                 timeout=config.timeout) as client:

        async def _document_search(query: str) -> DocumentSearchOutput:
            """
            Retrieve relevant context for the given question.

            Args:
                query (str): The question for which we need to search milvus collections.
            """
            # log query
            logger.debug("Q: %s", query)

            # Ask the LLM which collection the query belongs to.
            sys_message = SystemMessage(content=prompt_template)
            query_string = f"Get relevant chunks for this query: {query}"
            llm_pred = await structured_llm.ainvoke([sys_message] + [HumanMessage(content=query_string)])

            logger.info("Predicted LLM Collection: %s", llm_pred)

            # configure params for RAG endpoint and doc search
            url = f"{config.base_url}/search"
            payload = {"query": query, "top_k": config.top_k, "collection_name": llm_pred.collection_name}

            # send configured payload to running chain server
            logger.debug("Sending request to the RAG endpoint %s", url)
            response = await client.post(url, content=json.dumps(payload))

            response.raise_for_status()
            results = response.json()

            if len(results["chunks"]) == 0:
                return DocumentSearchOutput(collection_name=llm_pred.collection_name, documents="")

            # parse docs from the chunk dicts into a single string
            parsed_docs = []
            for doc in results["chunks"]:
                source = doc["filename"]
                page = doc.get("page", "")
                page_content = doc["content"]
                parsed_document = f'<Document source="{source}" page="{page}"/>\n{page_content}\n</Document>'
                parsed_docs.append(parsed_document)

            # combine parsed documents into a single string
            internal_search_docs = "\n\n---\n\n".join(parsed_docs)
            return DocumentSearchOutput(collection_name=llm_pred.collection_name, documents=internal_search_docs)

        yield FunctionInfo.from_fn(
            _document_search,
            description=("This tool retrieves relevant documents for a given user query. "
                         "The input query is mapped to the most appropriate Milvus collection database. "
                         "This will return relevant documents from the selected collection."))
|
|
File without changes
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from pydantic import BaseModel
|
|
17
|
+
from pydantic import Field
|
|
18
|
+
|
|
19
|
+
from aiq.builder.builder import Builder
|
|
20
|
+
from aiq.builder.function_info import FunctionInfo
|
|
21
|
+
from aiq.cli.register_workflow import register_function
|
|
22
|
+
from aiq.data_models.function import FunctionBaseConfig
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class GithubCommitCodeModel(BaseModel):
    """A single file to commit: target branch, commit message, and its local/remote paths."""
    branch: str = Field(description="The branch of the remote repo to which the code will be committed")
    commit_msg: str = Field(description="Message with which the code will be committed to the remote repo")
    local_path: str = Field(description="Local filepath of the file that has been updated and "
                            "needs to be committed to the remote repo")
    remote_path: str = Field(description="Remote filepath of the updated file in GitHub. Path is relative to "
                             "root of current repository")
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class GithubCommitCodeModelList(BaseModel):
    """Input schema for the commit tool: the batch of files to commit and push."""
    updated_files: list[GithubCommitCodeModel] = Field(description=("A list of local filepaths and commit messages"))
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class GithubCommitCodeConfig(FunctionBaseConfig, name="github_commit_code_tool"):
    """
    Tool that commits and pushes modified code to a remote GitHub repository asynchronously.
    """
    # Target repository in "owner/repo" form, used to build the API URLs.
    repo_name: str = Field(description="The repository name in the format 'owner/repo'")
    # Root of the local clone; per-file local paths are resolved relative to this.
    local_repo_dir: str = Field(description="Absolute path to the root of the repo, cloned locally")
    # HTTP timeout (seconds, per httpx semantics) for GitHub API requests.
    timeout: int = Field(default=300, description="The timeout configuration to use when sending requests.")
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@register_function(config_type=GithubCommitCodeConfig)
async def commit_code_async(config: GithubCommitCodeConfig, builder: Builder):
    """
    Commits and pushes modified code to a remote GitHub repository asynchronously.

    Uses the low-level GitHub Git data API (blob -> tree -> commit -> ref
    update) so files can be pushed without a local git client.

    Raises:
        ValueError: If the GITHUB_PAT environment variable is not set.
    """
    import json
    import os

    import httpx

    github_pat = os.getenv("GITHUB_PAT")
    if not github_pat:
        raise ValueError("GITHUB_PAT environment variable must be set")

    # Headers shared by every GitHub API request below.
    headers = {"Authorization": f"Bearer {github_pat}", "Accept": "application/vnd.github+json"}

    async def _github_commit_code(updated_files) -> str:
        """Commit each updated file to its branch; return the raw API responses as a JSON string."""
        results = []
        async with httpx.AsyncClient(timeout=config.timeout) as client:
            for file_ in updated_files:
                branch = file_.branch
                commit_msg = file_.commit_msg
                local_path = file_.local_path
                remote_path = file_.remote_path

                # Read content from the local file
                local_path = os.path.join(config.local_repo_dir, local_path)
                with open(local_path, 'r', encoding='utf-8', errors='ignore') as f:
                    content = f.read()

                # Step 1. Create a blob with the updated contents of the file
                blob_url = f'https://api.github.com/repos/{config.repo_name}/git/blobs'
                blob_data = {'content': content, 'encoding': 'utf-8'}
                blob_response = await client.post(blob_url, json=blob_data, headers=headers)
                blob_response.raise_for_status()
                blob_sha = blob_response.json()['sha']

                # Step 2: Resolve the branch head. The new commit is created on top of this ref.
                ref_url = f'https://api.github.com/repos/{config.repo_name}/git/refs/heads/{branch}'
                ref_response = await client.get(ref_url, headers=headers)
                ref_response.raise_for_status()
                # SHA of the branch's current head *commit* (refs endpoint returns a commit object).
                head_sha = ref_response.json()['object']['sha']

                # Step 3. Create an updated tree (Git graph) with the new blob.
                # NOTE(review): 'base_tree' officially expects a tree SHA; GitHub resolves
                # a commit SHA here in practice — confirm against the API docs on upgrade.
                tree_url = f'https://api.github.com/repos/{config.repo_name}/git/trees'
                tree_data = {
                    'base_tree': head_sha,
                    'tree': [{
                        'path': remote_path, 'mode': '100644', 'type': 'blob', 'sha': blob_sha
                    }]
                }
                tree_response = await client.post(tree_url, json=tree_data, headers=headers)
                tree_response.raise_for_status()
                tree_sha = tree_response.json()['sha']

                # Step 4: Create a commit whose parent is the current branch head
                commit_url = f'https://api.github.com/repos/{config.repo_name}/git/commits'
                commit_data = {'message': commit_msg, 'tree': tree_sha, 'parents': [head_sha]}
                commit_response = await client.post(commit_url, json=commit_data, headers=headers)
                commit_response.raise_for_status()
                commit_sha = commit_response.json()['sha']

                # Step 5: Point the branch reference at the new commit
                update_ref_url = f'https://api.github.com/repos/{config.repo_name}/git/refs/heads/{branch}'
                update_ref_data = {'sha': commit_sha}
                update_ref_response = await client.patch(update_ref_url,
                                                         json=update_ref_data,
                                                         headers=headers)
                update_ref_response.raise_for_status()

                payload_responses = {
                    'blob_resp': blob_response.json(),
                    'original_tree_ref': tree_response.json(),
                    'commit_resp': commit_response.json(),
                    'updated_tree_ref_resp': update_ref_response.json()
                }
                results.append(payload_responses)

        return json.dumps(results)

    yield FunctionInfo.from_fn(_github_commit_code,
                               description=(f"Commits and pushes modified code to a "
                                            f"GitHub repository in the repo named {config.repo_name}"),
                               input_schema=GithubCommitCodeModelList)