nvidia-nat 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +66 -0
- nat/agent/__init__.py +0 -0
- nat/agent/base.py +256 -0
- nat/agent/dual_node.py +67 -0
- nat/agent/react_agent/__init__.py +0 -0
- nat/agent/react_agent/agent.py +363 -0
- nat/agent/react_agent/output_parser.py +104 -0
- nat/agent/react_agent/prompt.py +44 -0
- nat/agent/react_agent/register.py +149 -0
- nat/agent/reasoning_agent/__init__.py +0 -0
- nat/agent/reasoning_agent/reasoning_agent.py +225 -0
- nat/agent/register.py +23 -0
- nat/agent/rewoo_agent/__init__.py +0 -0
- nat/agent/rewoo_agent/agent.py +415 -0
- nat/agent/rewoo_agent/prompt.py +110 -0
- nat/agent/rewoo_agent/register.py +157 -0
- nat/agent/tool_calling_agent/__init__.py +0 -0
- nat/agent/tool_calling_agent/agent.py +119 -0
- nat/agent/tool_calling_agent/register.py +106 -0
- nat/authentication/__init__.py +14 -0
- nat/authentication/api_key/__init__.py +14 -0
- nat/authentication/api_key/api_key_auth_provider.py +96 -0
- nat/authentication/api_key/api_key_auth_provider_config.py +124 -0
- nat/authentication/api_key/register.py +26 -0
- nat/authentication/exceptions/__init__.py +14 -0
- nat/authentication/exceptions/api_key_exceptions.py +38 -0
- nat/authentication/http_basic_auth/__init__.py +0 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
- nat/authentication/http_basic_auth/register.py +30 -0
- nat/authentication/interfaces.py +93 -0
- nat/authentication/oauth2/__init__.py +14 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
- nat/authentication/oauth2/register.py +25 -0
- nat/authentication/register.py +21 -0
- nat/builder/__init__.py +0 -0
- nat/builder/builder.py +285 -0
- nat/builder/component_utils.py +316 -0
- nat/builder/context.py +270 -0
- nat/builder/embedder.py +24 -0
- nat/builder/eval_builder.py +161 -0
- nat/builder/evaluator.py +29 -0
- nat/builder/framework_enum.py +24 -0
- nat/builder/front_end.py +73 -0
- nat/builder/function.py +344 -0
- nat/builder/function_base.py +380 -0
- nat/builder/function_info.py +627 -0
- nat/builder/intermediate_step_manager.py +174 -0
- nat/builder/llm.py +25 -0
- nat/builder/retriever.py +25 -0
- nat/builder/user_interaction_manager.py +78 -0
- nat/builder/workflow.py +148 -0
- nat/builder/workflow_builder.py +1117 -0
- nat/cli/__init__.py +14 -0
- nat/cli/cli_utils/__init__.py +0 -0
- nat/cli/cli_utils/config_override.py +231 -0
- nat/cli/cli_utils/validation.py +37 -0
- nat/cli/commands/__init__.py +0 -0
- nat/cli/commands/configure/__init__.py +0 -0
- nat/cli/commands/configure/channel/__init__.py +0 -0
- nat/cli/commands/configure/channel/add.py +28 -0
- nat/cli/commands/configure/channel/channel.py +34 -0
- nat/cli/commands/configure/channel/remove.py +30 -0
- nat/cli/commands/configure/channel/update.py +30 -0
- nat/cli/commands/configure/configure.py +33 -0
- nat/cli/commands/evaluate.py +139 -0
- nat/cli/commands/info/__init__.py +14 -0
- nat/cli/commands/info/info.py +37 -0
- nat/cli/commands/info/list_channels.py +32 -0
- nat/cli/commands/info/list_components.py +129 -0
- nat/cli/commands/info/list_mcp.py +304 -0
- nat/cli/commands/registry/__init__.py +14 -0
- nat/cli/commands/registry/publish.py +88 -0
- nat/cli/commands/registry/pull.py +118 -0
- nat/cli/commands/registry/registry.py +36 -0
- nat/cli/commands/registry/remove.py +108 -0
- nat/cli/commands/registry/search.py +155 -0
- nat/cli/commands/sizing/__init__.py +14 -0
- nat/cli/commands/sizing/calc.py +297 -0
- nat/cli/commands/sizing/sizing.py +27 -0
- nat/cli/commands/start.py +246 -0
- nat/cli/commands/uninstall.py +81 -0
- nat/cli/commands/validate.py +47 -0
- nat/cli/commands/workflow/__init__.py +14 -0
- nat/cli/commands/workflow/templates/__init__.py.j2 +0 -0
- nat/cli/commands/workflow/templates/config.yml.j2 +16 -0
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
- nat/cli/commands/workflow/templates/register.py.j2 +5 -0
- nat/cli/commands/workflow/templates/workflow.py.j2 +36 -0
- nat/cli/commands/workflow/workflow.py +37 -0
- nat/cli/commands/workflow/workflow_commands.py +317 -0
- nat/cli/entrypoint.py +135 -0
- nat/cli/main.py +57 -0
- nat/cli/register_workflow.py +488 -0
- nat/cli/type_registry.py +1000 -0
- nat/data_models/__init__.py +14 -0
- nat/data_models/api_server.py +716 -0
- nat/data_models/authentication.py +231 -0
- nat/data_models/common.py +171 -0
- nat/data_models/component.py +58 -0
- nat/data_models/component_ref.py +168 -0
- nat/data_models/config.py +410 -0
- nat/data_models/dataset_handler.py +169 -0
- nat/data_models/discovery_metadata.py +305 -0
- nat/data_models/embedder.py +27 -0
- nat/data_models/evaluate.py +127 -0
- nat/data_models/evaluator.py +26 -0
- nat/data_models/front_end.py +26 -0
- nat/data_models/function.py +30 -0
- nat/data_models/function_dependencies.py +72 -0
- nat/data_models/interactive.py +246 -0
- nat/data_models/intermediate_step.py +302 -0
- nat/data_models/invocation_node.py +38 -0
- nat/data_models/llm.py +27 -0
- nat/data_models/logging.py +26 -0
- nat/data_models/memory.py +27 -0
- nat/data_models/object_store.py +44 -0
- nat/data_models/profiler.py +54 -0
- nat/data_models/registry_handler.py +26 -0
- nat/data_models/retriever.py +30 -0
- nat/data_models/retry_mixin.py +35 -0
- nat/data_models/span.py +190 -0
- nat/data_models/step_adaptor.py +64 -0
- nat/data_models/streaming.py +33 -0
- nat/data_models/swe_bench_model.py +54 -0
- nat/data_models/telemetry_exporter.py +26 -0
- nat/data_models/ttc_strategy.py +30 -0
- nat/embedder/__init__.py +0 -0
- nat/embedder/nim_embedder.py +59 -0
- nat/embedder/openai_embedder.py +43 -0
- nat/embedder/register.py +22 -0
- nat/eval/__init__.py +14 -0
- nat/eval/config.py +60 -0
- nat/eval/dataset_handler/__init__.py +0 -0
- nat/eval/dataset_handler/dataset_downloader.py +106 -0
- nat/eval/dataset_handler/dataset_filter.py +52 -0
- nat/eval/dataset_handler/dataset_handler.py +367 -0
- nat/eval/evaluate.py +510 -0
- nat/eval/evaluator/__init__.py +14 -0
- nat/eval/evaluator/base_evaluator.py +77 -0
- nat/eval/evaluator/evaluator_model.py +45 -0
- nat/eval/intermediate_step_adapter.py +99 -0
- nat/eval/rag_evaluator/__init__.py +0 -0
- nat/eval/rag_evaluator/evaluate.py +178 -0
- nat/eval/rag_evaluator/register.py +143 -0
- nat/eval/register.py +23 -0
- nat/eval/remote_workflow.py +133 -0
- nat/eval/runners/__init__.py +14 -0
- nat/eval/runners/config.py +39 -0
- nat/eval/runners/multi_eval_runner.py +54 -0
- nat/eval/runtime_event_subscriber.py +52 -0
- nat/eval/swe_bench_evaluator/__init__.py +0 -0
- nat/eval/swe_bench_evaluator/evaluate.py +215 -0
- nat/eval/swe_bench_evaluator/register.py +36 -0
- nat/eval/trajectory_evaluator/__init__.py +0 -0
- nat/eval/trajectory_evaluator/evaluate.py +75 -0
- nat/eval/trajectory_evaluator/register.py +40 -0
- nat/eval/tunable_rag_evaluator/__init__.py +0 -0
- nat/eval/tunable_rag_evaluator/evaluate.py +245 -0
- nat/eval/tunable_rag_evaluator/register.py +52 -0
- nat/eval/usage_stats.py +41 -0
- nat/eval/utils/__init__.py +0 -0
- nat/eval/utils/output_uploader.py +140 -0
- nat/eval/utils/tqdm_position_registry.py +40 -0
- nat/eval/utils/weave_eval.py +184 -0
- nat/experimental/__init__.py +0 -0
- nat/experimental/decorators/__init__.py +0 -0
- nat/experimental/decorators/experimental_warning_decorator.py +134 -0
- nat/experimental/test_time_compute/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
- nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
- nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
- nat/experimental/test_time_compute/functions/__init__.py +0 -0
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +224 -0
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +205 -0
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +146 -0
- nat/experimental/test_time_compute/models/__init__.py +0 -0
- nat/experimental/test_time_compute/models/editor_config.py +132 -0
- nat/experimental/test_time_compute/models/scoring_config.py +112 -0
- nat/experimental/test_time_compute/models/search_config.py +120 -0
- nat/experimental/test_time_compute/models/selection_config.py +154 -0
- nat/experimental/test_time_compute/models/stage_enums.py +43 -0
- nat/experimental/test_time_compute/models/strategy_base.py +66 -0
- nat/experimental/test_time_compute/models/tool_use_config.py +41 -0
- nat/experimental/test_time_compute/models/ttc_item.py +48 -0
- nat/experimental/test_time_compute/register.py +36 -0
- nat/experimental/test_time_compute/scoring/__init__.py +0 -0
- nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
- nat/experimental/test_time_compute/search/__init__.py +0 -0
- nat/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
- nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
- nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
- nat/experimental/test_time_compute/selection/__init__.py +0 -0
- nat/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
- nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +159 -0
- nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
- nat/experimental/test_time_compute/selection/threshold_selector.py +58 -0
- nat/front_ends/__init__.py +14 -0
- nat/front_ends/console/__init__.py +14 -0
- nat/front_ends/console/authentication_flow_handler.py +233 -0
- nat/front_ends/console/console_front_end_config.py +32 -0
- nat/front_ends/console/console_front_end_plugin.py +96 -0
- nat/front_ends/console/register.py +25 -0
- nat/front_ends/cron/__init__.py +14 -0
- nat/front_ends/fastapi/__init__.py +14 -0
- nat/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +241 -0
- nat/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +116 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1087 -0
- nat/front_ends/fastapi/html_snippets/__init__.py +14 -0
- nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
- nat/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
- nat/front_ends/fastapi/job_store.py +183 -0
- nat/front_ends/fastapi/main.py +72 -0
- nat/front_ends/fastapi/message_handler.py +320 -0
- nat/front_ends/fastapi/message_validator.py +352 -0
- nat/front_ends/fastapi/register.py +25 -0
- nat/front_ends/fastapi/response_helpers.py +195 -0
- nat/front_ends/fastapi/step_adaptor.py +319 -0
- nat/front_ends/mcp/__init__.py +14 -0
- nat/front_ends/mcp/mcp_front_end_config.py +36 -0
- nat/front_ends/mcp/mcp_front_end_plugin.py +81 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +143 -0
- nat/front_ends/mcp/register.py +27 -0
- nat/front_ends/mcp/tool_converter.py +241 -0
- nat/front_ends/register.py +22 -0
- nat/front_ends/simple_base/__init__.py +14 -0
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +54 -0
- nat/llm/__init__.py +0 -0
- nat/llm/aws_bedrock_llm.py +57 -0
- nat/llm/nim_llm.py +46 -0
- nat/llm/openai_llm.py +46 -0
- nat/llm/register.py +23 -0
- nat/llm/utils/__init__.py +14 -0
- nat/llm/utils/env_config_value.py +94 -0
- nat/llm/utils/error.py +17 -0
- nat/memory/__init__.py +20 -0
- nat/memory/interfaces.py +183 -0
- nat/memory/models.py +112 -0
- nat/meta/pypi.md +58 -0
- nat/object_store/__init__.py +20 -0
- nat/object_store/in_memory_object_store.py +76 -0
- nat/object_store/interfaces.py +84 -0
- nat/object_store/models.py +38 -0
- nat/object_store/register.py +20 -0
- nat/observability/__init__.py +14 -0
- nat/observability/exporter/__init__.py +14 -0
- nat/observability/exporter/base_exporter.py +449 -0
- nat/observability/exporter/exporter.py +78 -0
- nat/observability/exporter/file_exporter.py +33 -0
- nat/observability/exporter/processing_exporter.py +322 -0
- nat/observability/exporter/raw_exporter.py +52 -0
- nat/observability/exporter/span_exporter.py +288 -0
- nat/observability/exporter_manager.py +335 -0
- nat/observability/mixin/__init__.py +14 -0
- nat/observability/mixin/batch_config_mixin.py +26 -0
- nat/observability/mixin/collector_config_mixin.py +23 -0
- nat/observability/mixin/file_mixin.py +288 -0
- nat/observability/mixin/file_mode.py +23 -0
- nat/observability/mixin/resource_conflict_mixin.py +134 -0
- nat/observability/mixin/serialize_mixin.py +61 -0
- nat/observability/mixin/type_introspection_mixin.py +183 -0
- nat/observability/processor/__init__.py +14 -0
- nat/observability/processor/batching_processor.py +310 -0
- nat/observability/processor/callback_processor.py +42 -0
- nat/observability/processor/intermediate_step_serializer.py +28 -0
- nat/observability/processor/processor.py +71 -0
- nat/observability/register.py +96 -0
- nat/observability/utils/__init__.py +14 -0
- nat/observability/utils/dict_utils.py +236 -0
- nat/observability/utils/time_utils.py +31 -0
- nat/plugins/.namespace +1 -0
- nat/profiler/__init__.py +0 -0
- nat/profiler/calc/__init__.py +14 -0
- nat/profiler/calc/calc_runner.py +627 -0
- nat/profiler/calc/calculations.py +288 -0
- nat/profiler/calc/data_models.py +188 -0
- nat/profiler/calc/plot.py +345 -0
- nat/profiler/callbacks/__init__.py +0 -0
- nat/profiler/callbacks/agno_callback_handler.py +295 -0
- nat/profiler/callbacks/base_callback_class.py +20 -0
- nat/profiler/callbacks/langchain_callback_handler.py +290 -0
- nat/profiler/callbacks/llama_index_callback_handler.py +205 -0
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
- nat/profiler/callbacks/token_usage_base_model.py +27 -0
- nat/profiler/data_frame_row.py +51 -0
- nat/profiler/data_models.py +24 -0
- nat/profiler/decorators/__init__.py +0 -0
- nat/profiler/decorators/framework_wrapper.py +131 -0
- nat/profiler/decorators/function_tracking.py +254 -0
- nat/profiler/forecasting/__init__.py +0 -0
- nat/profiler/forecasting/config.py +18 -0
- nat/profiler/forecasting/model_trainer.py +75 -0
- nat/profiler/forecasting/models/__init__.py +22 -0
- nat/profiler/forecasting/models/forecasting_base_model.py +40 -0
- nat/profiler/forecasting/models/linear_model.py +197 -0
- nat/profiler/forecasting/models/random_forest_regressor.py +269 -0
- nat/profiler/inference_metrics_model.py +28 -0
- nat/profiler/inference_optimization/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
- nat/profiler/inference_optimization/data_models.py +386 -0
- nat/profiler/inference_optimization/experimental/__init__.py +0 -0
- nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
- nat/profiler/inference_optimization/llm_metrics.py +212 -0
- nat/profiler/inference_optimization/prompt_caching.py +163 -0
- nat/profiler/inference_optimization/token_uniqueness.py +107 -0
- nat/profiler/inference_optimization/workflow_runtimes.py +72 -0
- nat/profiler/intermediate_property_adapter.py +102 -0
- nat/profiler/profile_runner.py +473 -0
- nat/profiler/utils.py +184 -0
- nat/registry_handlers/__init__.py +0 -0
- nat/registry_handlers/local/__init__.py +0 -0
- nat/registry_handlers/local/local_handler.py +176 -0
- nat/registry_handlers/local/register_local.py +37 -0
- nat/registry_handlers/metadata_factory.py +60 -0
- nat/registry_handlers/package_utils.py +571 -0
- nat/registry_handlers/pypi/__init__.py +0 -0
- nat/registry_handlers/pypi/pypi_handler.py +251 -0
- nat/registry_handlers/pypi/register_pypi.py +40 -0
- nat/registry_handlers/register.py +21 -0
- nat/registry_handlers/registry_handler_base.py +157 -0
- nat/registry_handlers/rest/__init__.py +0 -0
- nat/registry_handlers/rest/register_rest.py +56 -0
- nat/registry_handlers/rest/rest_handler.py +237 -0
- nat/registry_handlers/schemas/__init__.py +0 -0
- nat/registry_handlers/schemas/headers.py +42 -0
- nat/registry_handlers/schemas/package.py +68 -0
- nat/registry_handlers/schemas/publish.py +68 -0
- nat/registry_handlers/schemas/pull.py +82 -0
- nat/registry_handlers/schemas/remove.py +36 -0
- nat/registry_handlers/schemas/search.py +91 -0
- nat/registry_handlers/schemas/status.py +47 -0
- nat/retriever/__init__.py +0 -0
- nat/retriever/interface.py +41 -0
- nat/retriever/milvus/__init__.py +14 -0
- nat/retriever/milvus/register.py +81 -0
- nat/retriever/milvus/retriever.py +228 -0
- nat/retriever/models.py +77 -0
- nat/retriever/nemo_retriever/__init__.py +14 -0
- nat/retriever/nemo_retriever/register.py +60 -0
- nat/retriever/nemo_retriever/retriever.py +190 -0
- nat/retriever/register.py +22 -0
- nat/runtime/__init__.py +14 -0
- nat/runtime/loader.py +220 -0
- nat/runtime/runner.py +195 -0
- nat/runtime/session.py +162 -0
- nat/runtime/user_metadata.py +130 -0
- nat/settings/__init__.py +0 -0
- nat/settings/global_settings.py +318 -0
- nat/test/.namespace +1 -0
- nat/tool/__init__.py +0 -0
- nat/tool/chat_completion.py +74 -0
- nat/tool/code_execution/README.md +151 -0
- nat/tool/code_execution/__init__.py +0 -0
- nat/tool/code_execution/code_sandbox.py +267 -0
- nat/tool/code_execution/local_sandbox/.gitignore +1 -0
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
- nat/tool/code_execution/local_sandbox/__init__.py +13 -0
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
- nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
- nat/tool/code_execution/register.py +74 -0
- nat/tool/code_execution/test_code_execution_sandbox.py +414 -0
- nat/tool/code_execution/utils.py +100 -0
- nat/tool/datetime_tools.py +42 -0
- nat/tool/document_search.py +141 -0
- nat/tool/github_tools/__init__.py +0 -0
- nat/tool/github_tools/create_github_commit.py +133 -0
- nat/tool/github_tools/create_github_issue.py +87 -0
- nat/tool/github_tools/create_github_pr.py +106 -0
- nat/tool/github_tools/get_github_file.py +106 -0
- nat/tool/github_tools/get_github_issue.py +166 -0
- nat/tool/github_tools/get_github_pr.py +256 -0
- nat/tool/github_tools/update_github_issue.py +100 -0
- nat/tool/mcp/__init__.py +14 -0
- nat/tool/mcp/exceptions.py +142 -0
- nat/tool/mcp/mcp_client.py +255 -0
- nat/tool/mcp/mcp_tool.py +96 -0
- nat/tool/memory_tools/__init__.py +0 -0
- nat/tool/memory_tools/add_memory_tool.py +79 -0
- nat/tool/memory_tools/delete_memory_tool.py +67 -0
- nat/tool/memory_tools/get_memory_tool.py +72 -0
- nat/tool/nvidia_rag.py +95 -0
- nat/tool/register.py +38 -0
- nat/tool/retriever.py +94 -0
- nat/tool/server_tools.py +66 -0
- nat/utils/__init__.py +0 -0
- nat/utils/data_models/__init__.py +0 -0
- nat/utils/data_models/schema_validator.py +58 -0
- nat/utils/debugging_utils.py +43 -0
- nat/utils/dump_distro_mapping.py +32 -0
- nat/utils/exception_handlers/__init__.py +0 -0
- nat/utils/exception_handlers/automatic_retries.py +289 -0
- nat/utils/exception_handlers/mcp.py +211 -0
- nat/utils/exception_handlers/schemas.py +114 -0
- nat/utils/io/__init__.py +0 -0
- nat/utils/io/model_processing.py +28 -0
- nat/utils/io/yaml_tools.py +119 -0
- nat/utils/log_utils.py +37 -0
- nat/utils/metadata_utils.py +74 -0
- nat/utils/optional_imports.py +142 -0
- nat/utils/producer_consumer_queue.py +178 -0
- nat/utils/reactive/__init__.py +0 -0
- nat/utils/reactive/base/__init__.py +0 -0
- nat/utils/reactive/base/observable_base.py +65 -0
- nat/utils/reactive/base/observer_base.py +55 -0
- nat/utils/reactive/base/subject_base.py +79 -0
- nat/utils/reactive/observable.py +59 -0
- nat/utils/reactive/observer.py +76 -0
- nat/utils/reactive/subject.py +131 -0
- nat/utils/reactive/subscription.py +49 -0
- nat/utils/settings/__init__.py +0 -0
- nat/utils/settings/global_settings.py +197 -0
- nat/utils/string_utils.py +38 -0
- nat/utils/type_converter.py +290 -0
- nat/utils/type_utils.py +484 -0
- nat/utils/url_utils.py +27 -0
- nvidia_nat-1.2.0.dist-info/METADATA +365 -0
- nvidia_nat-1.2.0.dist-info/RECORD +435 -0
- nvidia_nat-1.2.0.dist-info/WHEEL +5 -0
- nvidia_nat-1.2.0.dist-info/entry_points.txt +21 -0
- nvidia_nat-1.2.0.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat-1.2.0.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat-1.2.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from nat.builder.builder import Builder
|
|
17
|
+
from nat.builder.function_info import FunctionInfo
|
|
18
|
+
from nat.cli.register_workflow import register_function
|
|
19
|
+
from nat.data_models.function import FunctionBaseConfig
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class CurrentTimeToolConfig(FunctionBaseConfig, name="current_datetime"):
    """Simple tool which returns the current date and time in human readable format."""
    # No tool-specific options: everything comes from FunctionBaseConfig.
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@register_function(config_type=CurrentTimeToolConfig)
async def current_datetime(config: CurrentTimeToolConfig, builder: Builder):
    """
    Build the ``current_datetime`` tool.

    Args:
        config: Tool configuration; this tool has no options of its own.
        builder: Workflow builder; not used by this tool.

    Yields:
        FunctionInfo wrapping a coroutine that returns the current local date and time as a sentence.
    """

    import datetime

    async def _get_current_time(unused: str) -> str:
        # The single input is never read; the tool always reports the time at the moment of the call.
        now = datetime.datetime.now()  # naive local time of the host process (no timezone information)

        # 24-hour "YYYY-MM-DD HH:MM:SS" format, e.g. "2025-01-31 14:05:09".
        now_human_readable = now.strftime("%Y-%m-%d %H:%M:%S")

        return f"The current time of day is {now_human_readable}"

    yield FunctionInfo.from_fn(_get_current_time,
                               description="Returns the current date and time in human readable format.")
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import json
|
|
17
|
+
import logging
|
|
18
|
+
|
|
19
|
+
from pydantic import Field
|
|
20
|
+
|
|
21
|
+
from nat.builder.builder import Builder
|
|
22
|
+
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
23
|
+
from nat.builder.function_info import FunctionInfo
|
|
24
|
+
from nat.cli.register_workflow import register_function
|
|
25
|
+
from nat.data_models.component_ref import LLMRef
|
|
26
|
+
from nat.data_models.function import FunctionBaseConfig
|
|
27
|
+
|
|
28
|
+
# Module-level logger; the document search tool logs incoming queries (debug) and the chosen collection (info).
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class MilvusDocumentSearchToolConfig(FunctionBaseConfig, name="milvus_document_search"):
    """
    This tool retrieves relevant documents for a given user query. The input query is mapped to the most appropriate
    Milvus collection database. This will return relevant documents from the selected collection.
    """
    # Search endpoint settings; requests are POSTed to "<base_url>/search".
    base_url: str = Field(description="The base url used to connect to the milvus database.")
    top_k: int = Field(default=4, description="The number of results to return from the milvus database.")
    timeout: int = Field(default=60, description="The timeout configuration to use when sending requests.")

    # LLM that routes each query to one of the configured collections.
    llm_name: LLMRef = Field(
        description="The name of the llm client to instantiate to determine most appropriate milvus collection.")

    # Candidate collections; names and descriptions are paired by list index.
    collection_names: list = Field(
        default=["nvidia_api_catalog"],
        description="The list of available collection names.",
    )
    collection_descriptions: list = Field(
        default=["Documents about NVIDIA's product catalog"],
        description="Collection descriptions that map to collection names by index position.",
    )
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@register_function(config_type=MilvusDocumentSearchToolConfig)
async def document_search(config: MilvusDocumentSearchToolConfig, builder: Builder):
    """
    Build the ``milvus_document_search`` tool.

    Each call asks an LLM (with structured output) to pick the most relevant collection from
    ``config.collection_names``. It then POSTs the query to ``{config.base_url}/search`` and formats the
    returned chunks as a single string.

    Args:
        config: Tool configuration (search endpoint, routing LLM, candidate collections, ``top_k`` and timeout).
        builder: Workflow builder used to obtain the routing LLM as a LangChain client.

    Yields:
        FunctionInfo wrapping the search coroutine. The HTTP client it uses is closed when this generator exits.
    """
    from typing import Literal

    import httpx
    from langchain_core.messages import HumanMessage
    from langchain_core.messages import SystemMessage
    from langchain_core.pydantic_v1 import BaseModel
    from langchain_core.pydantic_v1 import Field  # pylint: disable=redefined-outer-name, reimported

    # Names and descriptions are paired by index, and zip() silently drops unpaired entries.
    # Surface a mismatch instead of letting a collection vanish from the routing prompt unnoticed.
    if len(config.collection_names) != len(config.collection_descriptions):
        logger.warning(
            "collection_names has %d entries but collection_descriptions has %d; unpaired entries are left out "
            "of the collection routing prompt.",
            len(config.collection_names),
            len(config.collection_descriptions))

    # collection_name -> collection_description, rendered into the routing prompt below.
    collection_store = dict(zip(config.collection_names, config.collection_descriptions))

    # Structured-output schema: restricts the LLM's answer to one of the configured collection names.
    class CollectionName(BaseModel):
        collection_name: Literal[tuple(
            config.collection_names)] = Field(description="The appropriate milvus collection name for the question.")

    class DocumentSearchOutput(BaseModel):
        collection_name: str
        documents: str

    # define prompt template
    prompt_template = f"""You are an agent that helps users find the right Milvus collection based on the question.
    Here are the available list of collections (formatted as collection_name: collection_description): \n
    ({collection_store})
    \nFirst, analyze the available collections and their descriptions.
    Then, select the most appropriate collection for the user's query.
    Return only the name of the predicted collection."""

    # The prompt and the routing LLM depend only on the configuration, so create them once at build time
    # rather than on every tool call. This also surfaces a misconfigured ``llm_name`` when the workflow is built.
    sys_message = SystemMessage(content=prompt_template)
    llm = await builder.get_llm(llm_name=config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
    structured_llm = llm.with_structured_output(CollectionName)

    headers = {"accept": "application/json", "Content-Type": "application/json"}
    async with httpx.AsyncClient(headers=headers, timeout=config.timeout) as client:

        async def _document_search(query: str) -> DocumentSearchOutput:
            """
            This tool retrieves relevant context for the given question
            Args:
                query (str): The question for which we need to search milvus collections.
            """
            logger.debug("Q: %s", query)

            # Ask the routing LLM which collection best matches the query.
            query_string = f"Get relevant chunks for this query: {query}"
            llm_pred = await structured_llm.ainvoke([sys_message, HumanMessage(content=query_string)])

            logger.info("Predicted LLM Collection: %s", llm_pred)

            # configure params for RAG endpoint and doc search
            url = f"{config.base_url}/search"
            payload = {"query": query, "top_k": config.top_k, "collection_name": llm_pred.collection_name}

            logger.debug("Sending request to the RAG endpoint %s", url)
            response = await client.post(url, content=json.dumps(payload))
            response.raise_for_status()
            results = response.json()

            # Render each returned chunk as a tagged block; "page" is optional in the response.
            # An empty "chunks" list yields documents="" because join() of an empty list is "".
            parsed_docs = []
            for doc in results["chunks"]:
                source = doc["filename"]
                page = doc.get("page", "")
                page_content = doc["content"]
                parsed_docs.append(f'<Document source="{source}" page="{page}"/>\n{page_content}\n</Document>')

            return DocumentSearchOutput(collection_name=llm_pred.collection_name,
                                        documents="\n\n---\n\n".join(parsed_docs))

        # Each sentence ends with a space so the concatenated description reads correctly.
        yield FunctionInfo.from_fn(
            _document_search,
            description=("This tool retrieves relevant documents for a given user query. "
                         "The input query is mapped to the most appropriate Milvus collection database. "
                         "This will return relevant documents from the selected collection."))
|
|
File without changes
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from pydantic import BaseModel
|
|
17
|
+
from pydantic import Field
|
|
18
|
+
|
|
19
|
+
from nat.builder.builder import Builder
|
|
20
|
+
from nat.builder.function_info import FunctionInfo
|
|
21
|
+
from nat.cli.register_workflow import register_function
|
|
22
|
+
from nat.data_models.function import FunctionBaseConfig
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# One file update for the commit tool: which local file to read, where it goes in the remote
# repository, which branch receives the commit, and the commit message.
# Documented with `#` comments rather than a class docstring because pydantic publishes a class
# docstring as the JSON-schema description shown to the model.
class GithubCommitCodeModel(BaseModel):
    # Branch name; the tool builds `git/refs/heads/{branch}` from it and raises if the GET fails,
    # so the branch has to exist on the remote already.
    branch: str = Field(description="The branch of the remote repo to which the code will be committed")
    # Message of the commit created for this single file.
    commit_msg: str = Field(description="Message with which the code will be committed to the remote repo")
    # Joined onto the config's `local_repo_dir` by the tool before the file is read.
    local_path: str = Field(description="Local filepath of the file that has been updated and "
                            "needs to be committed to the remote repo")
    # Used as the `path` entry of the new Git tree, i.e. the file's location in the remote repo.
    remote_path: str = Field(description="Remote filepath of the updated file in GitHub. Path is relative to "
                             "root of current repository")
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# Input schema of the commit tool. The tool's inner function takes a parameter named
# `updated_files`, presumably receiving this field's value (verify against FunctionInfo.from_fn).
class GithubCommitCodeModelList(BaseModel):
    # Each entry is committed as its own commit, in list order.
    # NOTE(review): the description mentions only filepaths and commit messages, but every entry
    # also carries `branch` and `remote_path`.
    updated_files: list[GithubCommitCodeModel] = Field(description=("A list of local filepaths and commit messages"))
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class GithubCommitCodeConfig(FunctionBaseConfig, name="github_commit_code_tool"):
    """
    Tool that commits and pushes modified code to a remote GitHub repository asynchronously.
    """
    # Target repository; interpolated into every GitHub REST API URL the tool calls.
    repo_name: str = Field(description="The repository name in the format 'owner/repo'")
    # Directory each entry's `local_path` is joined onto before the file is read.
    local_repo_dir: str = Field(description="Absolute path to the root of the repo, cloned locally")
    # Forwarded to `httpx.AsyncClient(timeout=...)`.
    timeout: int = Field(default=300, description="The timeout configuration to use when sending requests.")
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@register_function(config_type=GithubCommitCodeConfig)
async def commit_code_async(config: GithubCommitCodeConfig, builder: Builder):
    """
    Commits and pushes modified code to a remote GitHub repository asynchronously.

    Registers a tool that, for every entry of ``updated_files``, reads the file from
    ``config.local_repo_dir`` and creates one commit on the requested branch through the
    GitHub Git Data API (blob -> tree -> commit -> ref update).

    Raises:
        ValueError: If the ``GITHUB_PAT`` environment variable is not set.
    """
    import base64
    import json
    import os

    import httpx

    github_pat = os.getenv("GITHUB_PAT")
    if not github_pat:
        raise ValueError("GITHUB_PAT environment variable must be set")

    # define the headers for the payload request
    headers = {"Authorization": f"Bearer {github_pat}", "Accept": "application/vnd.github+json"}
    api_url = f'https://api.github.com/repos/{config.repo_name}'
    repo_root = os.path.realpath(config.local_repo_dir)

    def _resolve_local_path(local_path: str) -> str:
        """Return the real path of ``local_path`` inside the local repo, or raise ``ValueError``."""
        # ``local_path`` comes from the model: an absolute path makes os.path.join discard
        # repo_root, and ".." segments can climb out of it. Either would let the tool push
        # arbitrary local files to GitHub, so confine the resolved path to the repository.
        resolved = os.path.realpath(os.path.join(repo_root, local_path))
        if os.path.commonpath([repo_root, resolved]) != repo_root:
            raise ValueError(f"Refusing to commit '{local_path}': it resolves outside of {config.local_repo_dir}")
        return resolved

    async def _github_commit_code(updated_files) -> str:
        """
        Commit each updated file as a separate commit and return the API responses.

        Returns:
            A JSON-encoded list with one object per file holding the blob, tree, commit and
            ref-update responses.

        Raises:
            ValueError: If a ``local_path`` resolves outside ``local_repo_dir``.
            httpx.HTTPStatusError: If any GitHub request fails; earlier files stay committed.
        """
        results = []
        async with httpx.AsyncClient(timeout=config.timeout) as client:
            for file_ in updated_files:
                ref_url = f'{api_url}/git/refs/heads/{file_.branch}'

                # Read raw bytes and upload them base64-encoded, so non-UTF-8 or binary content is
                # committed unchanged instead of having undecodable bytes silently dropped.
                with open(_resolve_local_path(file_.local_path), 'rb') as f:
                    content = base64.b64encode(f.read()).decode('ascii')

                # Step 1. Create a blob with the updated contents of the file
                blob_data = {'content': content, 'encoding': 'base64'}
                blob_response = await client.request("POST", f'{api_url}/git/blobs', json=blob_data, headers=headers)
                blob_response.raise_for_status()
                blob_sha = blob_response.json()['sha']

                # Step 2: Look up the branch head. Its commit SHA becomes the parent of the new
                # commit and is also passed as `base_tree` for the new tree.
                ref_response = await client.request("GET", ref_url, headers=headers)
                ref_response.raise_for_status()
                base_tree_sha = ref_response.json()['object']['sha']

                # Step 3. Create an updated tree (Git graph) with the new blob
                tree_data = {
                    'base_tree': base_tree_sha,
                    'tree': [{
                        'path': file_.remote_path, 'mode': '100644', 'type': 'blob', 'sha': blob_sha
                    }]
                }
                tree_response = await client.request("POST", f'{api_url}/git/trees', json=tree_data, headers=headers)
                tree_response.raise_for_status()
                tree_sha = tree_response.json()['sha']

                # Step 4: Create a commit
                commit_data = {'message': file_.commit_msg, 'tree': tree_sha, 'parents': [base_tree_sha]}
                commit_response = await client.request("POST",
                                                       f'{api_url}/git/commits',
                                                       json=commit_data,
                                                       headers=headers)
                commit_response.raise_for_status()
                commit_sha = commit_response.json()['sha']

                # Step 5: Move the branch ref to the new commit
                update_ref_response = await client.request("PATCH",
                                                           ref_url,
                                                           json={'sha': commit_sha},
                                                           headers=headers)
                update_ref_response.raise_for_status()

                results.append({
                    'blob_resp': blob_response.json(),
                    'original_tree_ref': tree_response.json(),
                    'commit_resp': commit_response.json(),
                    'updated_tree_ref_resp': update_ref_response.json()
                })

        return json.dumps(results)

    yield FunctionInfo.from_fn(_github_commit_code,
                               description=(f"Commits and pushes modified code to a "
                                            f"GitHub repository in the repo named {config.repo_name}"),
                               input_schema=GithubCommitCodeModelList)
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from pydantic import BaseModel
|
|
17
|
+
from pydantic import Field
|
|
18
|
+
|
|
19
|
+
from nat.builder.builder import Builder
|
|
20
|
+
from nat.builder.function_info import FunctionInfo
|
|
21
|
+
from nat.cli.register_workflow import register_function
|
|
22
|
+
from nat.data_models.function import FunctionBaseConfig
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# One issue to create. The tool serializes this model (set fields only) as the POST body of the
# GitHub issues endpoint, so the field names double as the API's `title` and `body` parameters.
# Documented with `#` comments because pydantic would publish a class docstring in the schema.
class GithubCreateIssueModel(BaseModel):
    title: str = Field(description="The title of the GitHub Issue")
    body: str = Field(description="The body of the GitHub Issue")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# Input schema of the create-issue tool; the tool's inner function takes a parameter named
# `issues`, presumably receiving this field's value (verify against FunctionInfo.from_fn).
class GithubCreateIssueModelList(BaseModel):
    # Issues are created one POST at a time, in list order.
    issues: list[GithubCreateIssueModel] = Field(description=("A list of GitHub issues, "
                                                              "each with a title and a body"))
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class GithubCreateIssueToolConfig(FunctionBaseConfig, name="github_create_issue_tool"):
    """
    Tool that creates an issue in a GitHub repository asynchronously.
    """
    # Target repository; interpolated into the `/repos/{repo_name}/issues` URL.
    repo_name: str = Field(description="The repository name in the format 'owner/repo'")
    # Forwarded to `httpx.AsyncClient(timeout=...)`.
    timeout: int = Field(default=300, description="The timeout configuration to use when sending requests.")
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@register_function(config_type=GithubCreateIssueToolConfig)
async def create_github_issue_async(config: GithubCreateIssueToolConfig, builder: Builder):
    """
    Creates an issue in a GitHub repository asynchronously.

    Raises:
        ValueError: If the ``GITHUB_PAT`` environment variable is not set.
    """
    import json
    import os

    import httpx

    github_pat = os.getenv("GITHUB_PAT")
    if not github_pat:
        raise ValueError("GITHUB_PAT environment variable must be set")

    url = f"https://api.github.com/repos/{config.repo_name}/issues"

    # define the headers for the payload request
    headers = {"Authorization": f"Bearer {github_pat}", "Accept": "application/vnd.github+json"}

    async def _github_post_issue(issues) -> str:
        """
        Create one issue per entry of ``issues``.

        Returns:
            A JSON-encoded list of the created-issue responses, in input order.

        Raises:
            httpx.HTTPStatusError: If GitHub rejects a request; issues created before it are kept.
            ValueError: If a response body is not valid JSON.
        """
        results = []
        async with httpx.AsyncClient(timeout=config.timeout) as client:
            for issue in issues:
                # define the payload body (`model_dump` replaces pydantic v1's deprecated `.dict()`)
                payload = issue.model_dump(exclude_unset=True)

                response = await client.request("POST", url, json=payload, headers=headers)

                # Raise an exception for HTTP errors
                response.raise_for_status()

                # Parse and collect the response JSON
                try:
                    results.append(response.json())
                except ValueError as e:
                    raise ValueError("The API response is not valid JSON.") from e

        return json.dumps(results)

    yield FunctionInfo.from_fn(_github_post_issue,
                               description=(f"Creates a GitHub issue in the "
                                            f"repo named {config.repo_name}"),
                               input_schema=GithubCreateIssueModelList)
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from pydantic import BaseModel
|
|
17
|
+
from pydantic import Field
|
|
18
|
+
|
|
19
|
+
from nat.builder.builder import Builder
|
|
20
|
+
from nat.builder.function_info import FunctionInfo
|
|
21
|
+
from nat.cli.register_workflow import register_function
|
|
22
|
+
from nat.data_models.function import FunctionBaseConfig
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# Parameters for opening one pull request, plus optional assignees and review requests.
# Documented with `#` comments because pydantic would publish a class docstring in the schema.
class GithubCreatePullModel(BaseModel):
    title: str = Field(description="Title of the pull request")
    body: str = Field(description="Description of the pull request")
    # Sent to the GitHub pulls endpoint as `head`.
    source_branch: str = Field(description="The name of the branch containing your changes")
    # Sent to the GitHub pulls endpoint as `base`.
    target_branch: str = Field(description="The name of the branch you want to merge into")
    # An empty list (the default) or None skips the assignees request.
    assignees: list[str] | None = Field([],
                                        description="List of GitHub usernames to assign to the PR. "
                                        "Always the current user")
    # An empty list (the default) or None skips the review-request call.
    reviewers: list[str] | None = Field([], description="List of GitHub usernames to request review from")
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# Input schema of the pull-request tool. Despite the name, it wraps a single
# GithubCreatePullModel, so one tool call opens exactly one pull request.
class GithubCreatePullList(BaseModel):
    # NOTE(review): the description says "A list of params" but the field holds one model.
    pull_details: GithubCreatePullModel = Field(description=("A list of params used for creating the PR in GitHub"))
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class GithubCreatePullConfig(FunctionBaseConfig, name="github_create_pull_tool"):
    """
    Tool that creates a pull request in a GitHub repository asynchronously with assignees and reviewers.
    """
    # Target repository; interpolated into the pulls and issues API URLs.
    repo_name: str = Field(description="The repository name in the format 'owner/repo'")
    # Forwarded to `httpx.AsyncClient(timeout=...)`.
    timeout: int = Field(default=300, description="The timeout configuration to use when sending requests.")
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@register_function(config_type=GithubCreatePullConfig)
async def create_pull_request_async(config: GithubCreatePullConfig, builder: Builder):
    """
    Creates a pull request in a GitHub repository asynchronously with assignees and reviewers.

    Raises:
        ValueError: If the ``GITHUB_PAT`` environment variable is not set.
    """
    import json
    import os

    import httpx

    github_pat = os.getenv("GITHUB_PAT")
    if not github_pat:
        raise ValueError("GITHUB_PAT environment variable must be set")

    headers = {"Authorization": f"Bearer {github_pat}", "Accept": "application/vnd.github+json"}
    api_url = f'https://api.github.com/repos/{config.repo_name}'

    async def _github_create_pull(pull_details: GithubCreatePullList) -> str:
        """
        Open the pull request, then add assignees and request reviewers when provided.

        Returns:
            A JSON-encoded list with one object holding the ``pull_request``, ``assignees`` and
            ``reviewers`` API responses; the latter two are ``null`` when that step was skipped.

        Raises:
            httpx.HTTPStatusError: If any GitHub request fails (an already-opened PR is kept).
        """
        # The annotation names the wrapper schema, but the fields read below live on
        # GithubCreatePullModel; unwrap the wrapper if it is what arrives so either shape works.
        details = getattr(pull_details, "pull_details", pull_details)
        assignees_result = None
        reviewers_result = None

        async with httpx.AsyncClient(timeout=config.timeout) as client:
            # Create pull request
            pr_data = {
                'title': details.title,
                'body': details.body,
                'head': details.source_branch,
                'base': details.target_branch
            }
            pr_response = await client.request("POST", f'{api_url}/pulls', json=pr_data, headers=headers)
            pr_response.raise_for_status()
            pr_result = pr_response.json()
            pr_number = pr_result['number']

            # Add assignees if provided (assignees go through the issues endpoint for the PR number)
            if details.assignees:
                assignees_response = await client.request("POST",
                                                          f'{api_url}/issues/{pr_number}/assignees',
                                                          json={'assignees': details.assignees},
                                                          headers=headers)
                assignees_response.raise_for_status()
                assignees_result = assignees_response.json()

            # Request reviewers if provided
            if details.reviewers:
                reviewers_response = await client.request("POST",
                                                          f'{api_url}/pulls/{pr_number}/requested_reviewers',
                                                          json={'reviewers': details.reviewers},
                                                          headers=headers)
                reviewers_response.raise_for_status()
                reviewers_result = reviewers_response.json()

        results = [{'pull_request': pr_result, 'assignees': assignees_result, 'reviewers': reviewers_result}]
        return json.dumps(results)

    yield FunctionInfo.from_fn(_github_create_pull,
                               description=(f"Creates a pull request with assignees and reviewers in the "
                                            f"GitHub repository named {config.repo_name}"),
                               input_schema=GithubCreatePullList)
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from nat.builder.builder import Builder
|
|
17
|
+
from nat.builder.function_info import FunctionInfo
|
|
18
|
+
from nat.cli.register_workflow import register_function
|
|
19
|
+
from nat.data_models.function import FunctionBaseConfig
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class GithubGetFileToolConfig(FunctionBaseConfig, name="github_getfile"):
    """
    Tool that returns the text of a github file using a github url starting with https://github.com and ending
    with a specific file.
    """
    # No configurable options: the registered tool takes the GitHub URL as its only runtime input.
    pass
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@register_function(config_type=GithubGetFileToolConfig)
async def github_text_from_url(config: GithubGetFileToolConfig, builder: Builder):
    """
    Register a tool that fetches the raw text of a file referenced by a GitHub ``blob`` URL.

    The tool returns the file wrapped in a fenced code block, or a human-readable error string
    when the URL is invalid or the download fails.
    """
    import asyncio
    import re

    import requests

    # Groups: everything before the first "/blob/" (owner/repo), then "<ref>/<path>" stopping at
    # any query string, fragment or whitespace. Non-greedy so paths containing "blob" still split.
    pattern = re.compile(r"https://github\.com/(\S+?)/blob/([^?#\s]+)")

    async def _github_text_from_url(url_text: str) -> str:
        # Extract sections of the base github path
        match = pattern.search(url_text)
        if match is None:
            return ("Invalid github url. Please provide a valid github url. "
                    "Example: 'https://github.com/my_repository/blob/main/file.txt'")

        repo, ref_and_path = match.groups()
        # Construct raw content path. raw.githubusercontent.com resolves "<ref>/<path>" for
        # branches, tags and commit SHAs alike; a "refs/heads/" prefix would break permalinks.
        raw_url = f"https://raw.githubusercontent.com/{repo}/{ref_and_path}"

        # Grab raw text from github. requests blocks, so run it in a worker thread to keep the
        # event loop responsive.
        try:
            response = await asyncio.to_thread(requests.get, raw_url, timeout=60)
        except requests.exceptions.Timeout:
            return f"Timeout encountered when retrieving resource: {raw_url}"
        except requests.exceptions.RequestException as e:
            return f"Error encountered when retrieving resource: {raw_url} ({e})"

        # Without this check a 404 page body would be returned as if it were the file content.
        if not response.ok:
            return f"Failed to retrieve resource: {raw_url} (HTTP {response.status_code})"

        return f"```python\n{response.text}\n```"

    yield FunctionInfo.from_fn(_github_text_from_url,
                               description=("Returns the text of a github file using a github url starting with "
                                            "https://github.com and ending with a specific file."))
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class GithubGetFileLinesToolConfig(FunctionBaseConfig, name="github_getfilelines"):
    """
    Tool that returns the text lines of a github file using a github url starting with
    https://github.com and ending with a specific file line references. Examples of line references are
    #L409-L417 and #L166-L171.
    """
    # No configurable options: the registered tool takes the GitHub URL (with line anchor) as its
    # only runtime input.
    pass
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@register_function(config_type=GithubGetFileLinesToolConfig)
async def github_text_lines_from_url(config: GithubGetFileLinesToolConfig, builder: Builder):
    """
    Register a tool that returns a line range of a file referenced by a GitHub ``blob`` URL.

    The URL must end in a GitHub line anchor: a range (``#L409-L417``) or a single line
    (``#L409``). The selected lines are returned in a fenced code block, or a human-readable
    error string is returned when the URL is invalid or the download fails.
    """
    import asyncio
    import re

    import requests

    # Groups: owner/repo (up to the first "/blob/"), "<ref>/<path>" (stopping at any query or
    # fragment), then the 1-based start line and optional end line of "#L<start>[-L<end>]".
    # `\S*?` lets a query such as "?plain=1" sit between the path and the anchor.
    pattern = re.compile(r"https://github\.com/(\S+?)/blob/([^?#\s]+)\S*?#L(\d+)(?:-L(\d+))?")

    async def _github_text_lines_from_url(url_text: str) -> str:
        # Extract sections of the base github path
        match = pattern.search(url_text)
        if match is None:
            return ("Invalid github url. Please provide a valid github url with line information. "
                    "Example: 'https://github.com/my_repository/blob/main/file.txt#L409-L417'")

        repo, ref_and_path, start, end = match.groups()
        start_line = int(start)
        # A single-line anchor (#L409) selects just that line.
        end_line = int(end) if end is not None else start_line

        # Construct raw content path. raw.githubusercontent.com resolves "<ref>/<path>" for
        # branches, tags and commit SHAs alike; a "refs/heads/" prefix would break permalinks.
        raw_url = f"https://raw.githubusercontent.com/{repo}/{ref_and_path}"

        # Grab raw text from github. requests blocks, so run it in a worker thread to keep the
        # event loop responsive.
        try:
            response = await asyncio.to_thread(requests.get, raw_url, timeout=60)
        except requests.exceptions.Timeout:
            return f"Timeout encountered when retrieving resource: {raw_url}"
        except requests.exceptions.RequestException as e:
            return f"Error encountered when retrieving resource: {raw_url} ({e})"

        # Without this check lines of a 404 page body would be returned as file content.
        if not response.ok:
            return f"Failed to retrieve resource: {raw_url} (HTTP {response.status_code})"

        # GitHub line anchors are 1-based and inclusive (#L409-L417 is lines 409..417), so the
        # matching 0-based slice is [start_line - 1, end_line).
        file_lines = response.text.splitlines()
        selected_lines = file_lines[max(start_line - 1, 0):end_line]
        joined_selected_lines = "\n".join(selected_lines)

        return f"```python\n{joined_selected_lines}\n```"

    yield FunctionInfo.from_fn(_github_text_lines_from_url,
                               description=("Returns the text lines of a github file using a github url starting with "
                                            "https://github.com and ending with a specific file line references. "
                                            "Examples of line references are #L409-L417 and #L166-L171."))
|