nvidia-nat 1.1.0a20251020__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +66 -0
- nat/agent/__init__.py +0 -0
- nat/agent/base.py +265 -0
- nat/agent/dual_node.py +72 -0
- nat/agent/prompt_optimizer/__init__.py +0 -0
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/__init__.py +0 -0
- nat/agent/react_agent/agent.py +394 -0
- nat/agent/react_agent/output_parser.py +104 -0
- nat/agent/react_agent/prompt.py +44 -0
- nat/agent/react_agent/register.py +168 -0
- nat/agent/reasoning_agent/__init__.py +0 -0
- nat/agent/reasoning_agent/reasoning_agent.py +227 -0
- nat/agent/register.py +23 -0
- nat/agent/rewoo_agent/__init__.py +0 -0
- nat/agent/rewoo_agent/agent.py +593 -0
- nat/agent/rewoo_agent/prompt.py +107 -0
- nat/agent/rewoo_agent/register.py +175 -0
- nat/agent/tool_calling_agent/__init__.py +0 -0
- nat/agent/tool_calling_agent/agent.py +246 -0
- nat/agent/tool_calling_agent/register.py +129 -0
- nat/authentication/__init__.py +14 -0
- nat/authentication/api_key/__init__.py +14 -0
- nat/authentication/api_key/api_key_auth_provider.py +96 -0
- nat/authentication/api_key/api_key_auth_provider_config.py +124 -0
- nat/authentication/api_key/register.py +26 -0
- nat/authentication/credential_validator/__init__.py +14 -0
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/exceptions/__init__.py +14 -0
- nat/authentication/exceptions/api_key_exceptions.py +38 -0
- nat/authentication/http_basic_auth/__init__.py +0 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
- nat/authentication/http_basic_auth/register.py +30 -0
- nat/authentication/interfaces.py +96 -0
- nat/authentication/oauth2/__init__.py +14 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +140 -0
- nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/oauth2/register.py +25 -0
- nat/authentication/register.py +20 -0
- nat/builder/__init__.py +0 -0
- nat/builder/builder.py +317 -0
- nat/builder/component_utils.py +320 -0
- nat/builder/context.py +321 -0
- nat/builder/embedder.py +24 -0
- nat/builder/eval_builder.py +166 -0
- nat/builder/evaluator.py +29 -0
- nat/builder/framework_enum.py +25 -0
- nat/builder/front_end.py +73 -0
- nat/builder/function.py +714 -0
- nat/builder/function_base.py +380 -0
- nat/builder/function_info.py +625 -0
- nat/builder/intermediate_step_manager.py +206 -0
- nat/builder/llm.py +25 -0
- nat/builder/retriever.py +25 -0
- nat/builder/user_interaction_manager.py +78 -0
- nat/builder/workflow.py +160 -0
- nat/builder/workflow_builder.py +1365 -0
- nat/cli/__init__.py +14 -0
- nat/cli/cli_utils/__init__.py +0 -0
- nat/cli/cli_utils/config_override.py +231 -0
- nat/cli/cli_utils/validation.py +37 -0
- nat/cli/commands/__init__.py +0 -0
- nat/cli/commands/configure/__init__.py +0 -0
- nat/cli/commands/configure/channel/__init__.py +0 -0
- nat/cli/commands/configure/channel/add.py +28 -0
- nat/cli/commands/configure/channel/channel.py +34 -0
- nat/cli/commands/configure/channel/remove.py +30 -0
- nat/cli/commands/configure/channel/update.py +30 -0
- nat/cli/commands/configure/configure.py +33 -0
- nat/cli/commands/evaluate.py +139 -0
- nat/cli/commands/info/__init__.py +14 -0
- nat/cli/commands/info/info.py +47 -0
- nat/cli/commands/info/list_channels.py +32 -0
- nat/cli/commands/info/list_components.py +128 -0
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/__init__.py +14 -0
- nat/cli/commands/registry/publish.py +88 -0
- nat/cli/commands/registry/pull.py +118 -0
- nat/cli/commands/registry/registry.py +36 -0
- nat/cli/commands/registry/remove.py +108 -0
- nat/cli/commands/registry/search.py +153 -0
- nat/cli/commands/sizing/__init__.py +14 -0
- nat/cli/commands/sizing/calc.py +297 -0
- nat/cli/commands/sizing/sizing.py +27 -0
- nat/cli/commands/start.py +257 -0
- nat/cli/commands/uninstall.py +81 -0
- nat/cli/commands/validate.py +47 -0
- nat/cli/commands/workflow/__init__.py +14 -0
- nat/cli/commands/workflow/templates/__init__.py.j2 +0 -0
- nat/cli/commands/workflow/templates/config.yml.j2 +17 -0
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +25 -0
- nat/cli/commands/workflow/templates/register.py.j2 +4 -0
- nat/cli/commands/workflow/templates/workflow.py.j2 +50 -0
- nat/cli/commands/workflow/workflow.py +37 -0
- nat/cli/commands/workflow/workflow_commands.py +403 -0
- nat/cli/entrypoint.py +141 -0
- nat/cli/main.py +60 -0
- nat/cli/register_workflow.py +522 -0
- nat/cli/type_registry.py +1069 -0
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/__init__.py +14 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +843 -0
- nat/data_models/authentication.py +245 -0
- nat/data_models/common.py +171 -0
- nat/data_models/component.py +60 -0
- nat/data_models/component_ref.py +179 -0
- nat/data_models/config.py +434 -0
- nat/data_models/dataset_handler.py +169 -0
- nat/data_models/discovery_metadata.py +305 -0
- nat/data_models/embedder.py +27 -0
- nat/data_models/evaluate.py +130 -0
- nat/data_models/evaluator.py +26 -0
- nat/data_models/front_end.py +26 -0
- nat/data_models/function.py +64 -0
- nat/data_models/function_dependencies.py +80 -0
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/interactive.py +246 -0
- nat/data_models/intermediate_step.py +302 -0
- nat/data_models/invocation_node.py +38 -0
- nat/data_models/llm.py +27 -0
- nat/data_models/logging.py +26 -0
- nat/data_models/memory.py +27 -0
- nat/data_models/object_store.py +44 -0
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/profiler.py +54 -0
- nat/data_models/registry_handler.py +26 -0
- nat/data_models/retriever.py +30 -0
- nat/data_models/retry_mixin.py +35 -0
- nat/data_models/span.py +228 -0
- nat/data_models/step_adaptor.py +64 -0
- nat/data_models/streaming.py +33 -0
- nat/data_models/swe_bench_model.py +54 -0
- nat/data_models/telemetry_exporter.py +26 -0
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/data_models/ttc_strategy.py +30 -0
- nat/embedder/__init__.py +0 -0
- nat/embedder/azure_openai_embedder.py +46 -0
- nat/embedder/nim_embedder.py +59 -0
- nat/embedder/openai_embedder.py +42 -0
- nat/embedder/register.py +22 -0
- nat/eval/__init__.py +14 -0
- nat/eval/config.py +62 -0
- nat/eval/dataset_handler/__init__.py +0 -0
- nat/eval/dataset_handler/dataset_downloader.py +106 -0
- nat/eval/dataset_handler/dataset_filter.py +52 -0
- nat/eval/dataset_handler/dataset_handler.py +431 -0
- nat/eval/evaluate.py +565 -0
- nat/eval/evaluator/__init__.py +14 -0
- nat/eval/evaluator/base_evaluator.py +77 -0
- nat/eval/evaluator/evaluator_model.py +58 -0
- nat/eval/intermediate_step_adapter.py +99 -0
- nat/eval/rag_evaluator/__init__.py +0 -0
- nat/eval/rag_evaluator/evaluate.py +178 -0
- nat/eval/rag_evaluator/register.py +143 -0
- nat/eval/register.py +26 -0
- nat/eval/remote_workflow.py +133 -0
- nat/eval/runners/__init__.py +14 -0
- nat/eval/runners/config.py +39 -0
- nat/eval/runners/multi_eval_runner.py +54 -0
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/runtime_event_subscriber.py +52 -0
- nat/eval/swe_bench_evaluator/__init__.py +0 -0
- nat/eval/swe_bench_evaluator/evaluate.py +215 -0
- nat/eval/swe_bench_evaluator/register.py +36 -0
- nat/eval/trajectory_evaluator/__init__.py +0 -0
- nat/eval/trajectory_evaluator/evaluate.py +75 -0
- nat/eval/trajectory_evaluator/register.py +40 -0
- nat/eval/tunable_rag_evaluator/__init__.py +0 -0
- nat/eval/tunable_rag_evaluator/evaluate.py +242 -0
- nat/eval/tunable_rag_evaluator/register.py +52 -0
- nat/eval/usage_stats.py +41 -0
- nat/eval/utils/__init__.py +0 -0
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/output_uploader.py +140 -0
- nat/eval/utils/tqdm_position_registry.py +40 -0
- nat/eval/utils/weave_eval.py +193 -0
- nat/experimental/__init__.py +0 -0
- nat/experimental/decorators/__init__.py +0 -0
- nat/experimental/decorators/experimental_warning_decorator.py +154 -0
- nat/experimental/test_time_compute/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/__init__.py +0 -0
- nat/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
- nat/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
- nat/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
- nat/experimental/test_time_compute/functions/__init__.py +0 -0
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +228 -0
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +205 -0
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +146 -0
- nat/experimental/test_time_compute/models/__init__.py +0 -0
- nat/experimental/test_time_compute/models/editor_config.py +132 -0
- nat/experimental/test_time_compute/models/scoring_config.py +112 -0
- nat/experimental/test_time_compute/models/search_config.py +120 -0
- nat/experimental/test_time_compute/models/selection_config.py +154 -0
- nat/experimental/test_time_compute/models/stage_enums.py +43 -0
- nat/experimental/test_time_compute/models/strategy_base.py +67 -0
- nat/experimental/test_time_compute/models/tool_use_config.py +41 -0
- nat/experimental/test_time_compute/models/ttc_item.py +48 -0
- nat/experimental/test_time_compute/register.py +35 -0
- nat/experimental/test_time_compute/scoring/__init__.py +0 -0
- nat/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
- nat/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
- nat/experimental/test_time_compute/search/__init__.py +0 -0
- nat/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
- nat/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
- nat/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
- nat/experimental/test_time_compute/selection/__init__.py +0 -0
- nat/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
- nat/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +157 -0
- nat/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
- nat/experimental/test_time_compute/selection/threshold_selector.py +58 -0
- nat/front_ends/__init__.py +14 -0
- nat/front_ends/console/__init__.py +14 -0
- nat/front_ends/console/authentication_flow_handler.py +285 -0
- nat/front_ends/console/console_front_end_config.py +32 -0
- nat/front_ends/console/console_front_end_plugin.py +108 -0
- nat/front_ends/console/register.py +25 -0
- nat/front_ends/cron/__init__.py +14 -0
- nat/front_ends/fastapi/__init__.py +14 -0
- nat/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +142 -0
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +272 -0
- nat/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +247 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1257 -0
- nat/front_ends/fastapi/html_snippets/__init__.py +14 -0
- nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
- nat/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
- nat/front_ends/fastapi/job_store.py +602 -0
- nat/front_ends/fastapi/main.py +64 -0
- nat/front_ends/fastapi/message_handler.py +344 -0
- nat/front_ends/fastapi/message_validator.py +351 -0
- nat/front_ends/fastapi/register.py +25 -0
- nat/front_ends/fastapi/response_helpers.py +195 -0
- nat/front_ends/fastapi/step_adaptor.py +319 -0
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/__init__.py +14 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +90 -0
- nat/front_ends/mcp/mcp_front_end_plugin.py +113 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +268 -0
- nat/front_ends/mcp/memory_profiler.py +320 -0
- nat/front_ends/mcp/register.py +27 -0
- nat/front_ends/mcp/tool_converter.py +290 -0
- nat/front_ends/register.py +21 -0
- nat/front_ends/simple_base/__init__.py +14 -0
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +56 -0
- nat/llm/__init__.py +0 -0
- nat/llm/aws_bedrock_llm.py +69 -0
- nat/llm/azure_openai_llm.py +57 -0
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +58 -0
- nat/llm/openai_llm.py +54 -0
- nat/llm/register.py +27 -0
- nat/llm/utils/__init__.py +14 -0
- nat/llm/utils/env_config_value.py +93 -0
- nat/llm/utils/error.py +17 -0
- nat/llm/utils/thinking.py +215 -0
- nat/memory/__init__.py +20 -0
- nat/memory/interfaces.py +183 -0
- nat/memory/models.py +112 -0
- nat/meta/pypi.md +58 -0
- nat/object_store/__init__.py +20 -0
- nat/object_store/in_memory_object_store.py +76 -0
- nat/object_store/interfaces.py +84 -0
- nat/object_store/models.py +38 -0
- nat/object_store/register.py +19 -0
- nat/observability/__init__.py +14 -0
- nat/observability/exporter/__init__.py +14 -0
- nat/observability/exporter/base_exporter.py +449 -0
- nat/observability/exporter/exporter.py +78 -0
- nat/observability/exporter/file_exporter.py +33 -0
- nat/observability/exporter/processing_exporter.py +550 -0
- nat/observability/exporter/raw_exporter.py +52 -0
- nat/observability/exporter/span_exporter.py +308 -0
- nat/observability/exporter_manager.py +335 -0
- nat/observability/mixin/__init__.py +14 -0
- nat/observability/mixin/batch_config_mixin.py +26 -0
- nat/observability/mixin/collector_config_mixin.py +23 -0
- nat/observability/mixin/file_mixin.py +288 -0
- nat/observability/mixin/file_mode.py +23 -0
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/resource_conflict_mixin.py +134 -0
- nat/observability/mixin/serialize_mixin.py +61 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +496 -0
- nat/observability/processor/__init__.py +14 -0
- nat/observability/processor/batching_processor.py +308 -0
- nat/observability/processor/callback_processor.py +42 -0
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/intermediate_step_serializer.py +28 -0
- nat/observability/processor/processor.py +74 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +114 -0
- nat/observability/utils/__init__.py +14 -0
- nat/observability/utils/dict_utils.py +236 -0
- nat/observability/utils/time_utils.py +31 -0
- nat/plugins/.namespace +1 -0
- nat/profiler/__init__.py +0 -0
- nat/profiler/calc/__init__.py +14 -0
- nat/profiler/calc/calc_runner.py +626 -0
- nat/profiler/calc/calculations.py +288 -0
- nat/profiler/calc/data_models.py +188 -0
- nat/profiler/calc/plot.py +345 -0
- nat/profiler/callbacks/__init__.py +0 -0
- nat/profiler/callbacks/agno_callback_handler.py +295 -0
- nat/profiler/callbacks/base_callback_class.py +20 -0
- nat/profiler/callbacks/langchain_callback_handler.py +297 -0
- nat/profiler/callbacks/llama_index_callback_handler.py +205 -0
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
- nat/profiler/callbacks/token_usage_base_model.py +27 -0
- nat/profiler/data_frame_row.py +51 -0
- nat/profiler/data_models.py +24 -0
- nat/profiler/decorators/__init__.py +0 -0
- nat/profiler/decorators/framework_wrapper.py +180 -0
- nat/profiler/decorators/function_tracking.py +411 -0
- nat/profiler/forecasting/__init__.py +0 -0
- nat/profiler/forecasting/config.py +18 -0
- nat/profiler/forecasting/model_trainer.py +75 -0
- nat/profiler/forecasting/models/__init__.py +22 -0
- nat/profiler/forecasting/models/forecasting_base_model.py +42 -0
- nat/profiler/forecasting/models/linear_model.py +197 -0
- nat/profiler/forecasting/models/random_forest_regressor.py +269 -0
- nat/profiler/inference_metrics_model.py +28 -0
- nat/profiler/inference_optimization/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
- nat/profiler/inference_optimization/data_models.py +386 -0
- nat/profiler/inference_optimization/experimental/__init__.py +0 -0
- nat/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +404 -0
- nat/profiler/inference_optimization/llm_metrics.py +212 -0
- nat/profiler/inference_optimization/prompt_caching.py +163 -0
- nat/profiler/inference_optimization/token_uniqueness.py +107 -0
- nat/profiler/inference_optimization/workflow_runtimes.py +72 -0
- nat/profiler/intermediate_property_adapter.py +102 -0
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +478 -0
- nat/profiler/utils.py +186 -0
- nat/registry_handlers/__init__.py +0 -0
- nat/registry_handlers/local/__init__.py +0 -0
- nat/registry_handlers/local/local_handler.py +176 -0
- nat/registry_handlers/local/register_local.py +37 -0
- nat/registry_handlers/metadata_factory.py +60 -0
- nat/registry_handlers/package_utils.py +570 -0
- nat/registry_handlers/pypi/__init__.py +0 -0
- nat/registry_handlers/pypi/pypi_handler.py +248 -0
- nat/registry_handlers/pypi/register_pypi.py +40 -0
- nat/registry_handlers/register.py +20 -0
- nat/registry_handlers/registry_handler_base.py +157 -0
- nat/registry_handlers/rest/__init__.py +0 -0
- nat/registry_handlers/rest/register_rest.py +56 -0
- nat/registry_handlers/rest/rest_handler.py +236 -0
- nat/registry_handlers/schemas/__init__.py +0 -0
- nat/registry_handlers/schemas/headers.py +42 -0
- nat/registry_handlers/schemas/package.py +68 -0
- nat/registry_handlers/schemas/publish.py +68 -0
- nat/registry_handlers/schemas/pull.py +82 -0
- nat/registry_handlers/schemas/remove.py +36 -0
- nat/registry_handlers/schemas/search.py +91 -0
- nat/registry_handlers/schemas/status.py +47 -0
- nat/retriever/__init__.py +0 -0
- nat/retriever/interface.py +41 -0
- nat/retriever/milvus/__init__.py +14 -0
- nat/retriever/milvus/register.py +81 -0
- nat/retriever/milvus/retriever.py +228 -0
- nat/retriever/models.py +77 -0
- nat/retriever/nemo_retriever/__init__.py +14 -0
- nat/retriever/nemo_retriever/register.py +60 -0
- nat/retriever/nemo_retriever/retriever.py +190 -0
- nat/retriever/register.py +21 -0
- nat/runtime/__init__.py +14 -0
- nat/runtime/loader.py +220 -0
- nat/runtime/runner.py +292 -0
- nat/runtime/session.py +223 -0
- nat/runtime/user_metadata.py +130 -0
- nat/settings/__init__.py +0 -0
- nat/settings/global_settings.py +329 -0
- nat/test/.namespace +1 -0
- nat/tool/__init__.py +0 -0
- nat/tool/chat_completion.py +77 -0
- nat/tool/code_execution/README.md +151 -0
- nat/tool/code_execution/__init__.py +0 -0
- nat/tool/code_execution/code_sandbox.py +267 -0
- nat/tool/code_execution/local_sandbox/.gitignore +1 -0
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
- nat/tool/code_execution/local_sandbox/__init__.py +13 -0
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
- nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
- nat/tool/code_execution/register.py +74 -0
- nat/tool/code_execution/test_code_execution_sandbox.py +414 -0
- nat/tool/code_execution/utils.py +100 -0
- nat/tool/datetime_tools.py +82 -0
- nat/tool/document_search.py +141 -0
- nat/tool/github_tools.py +450 -0
- nat/tool/memory_tools/__init__.py +0 -0
- nat/tool/memory_tools/add_memory_tool.py +79 -0
- nat/tool/memory_tools/delete_memory_tool.py +66 -0
- nat/tool/memory_tools/get_memory_tool.py +72 -0
- nat/tool/nvidia_rag.py +95 -0
- nat/tool/register.py +31 -0
- nat/tool/retriever.py +95 -0
- nat/tool/server_tools.py +66 -0
- nat/utils/__init__.py +0 -0
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/__init__.py +0 -0
- nat/utils/data_models/schema_validator.py +58 -0
- nat/utils/debugging_utils.py +43 -0
- nat/utils/decorators.py +210 -0
- nat/utils/dump_distro_mapping.py +32 -0
- nat/utils/exception_handlers/__init__.py +0 -0
- nat/utils/exception_handlers/automatic_retries.py +342 -0
- nat/utils/exception_handlers/schemas.py +114 -0
- nat/utils/io/__init__.py +0 -0
- nat/utils/io/model_processing.py +28 -0
- nat/utils/io/yaml_tools.py +119 -0
- nat/utils/log_levels.py +25 -0
- nat/utils/log_utils.py +37 -0
- nat/utils/metadata_utils.py +74 -0
- nat/utils/optional_imports.py +142 -0
- nat/utils/producer_consumer_queue.py +178 -0
- nat/utils/reactive/__init__.py +0 -0
- nat/utils/reactive/base/__init__.py +0 -0
- nat/utils/reactive/base/observable_base.py +65 -0
- nat/utils/reactive/base/observer_base.py +55 -0
- nat/utils/reactive/base/subject_base.py +79 -0
- nat/utils/reactive/observable.py +59 -0
- nat/utils/reactive/observer.py +76 -0
- nat/utils/reactive/subject.py +131 -0
- nat/utils/reactive/subscription.py +49 -0
- nat/utils/settings/__init__.py +0 -0
- nat/utils/settings/global_settings.py +195 -0
- nat/utils/string_utils.py +38 -0
- nat/utils/type_converter.py +299 -0
- nat/utils/type_utils.py +488 -0
- nat/utils/url_utils.py +27 -0
- nvidia_nat-1.1.0a20251020.dist-info/METADATA +195 -0
- nvidia_nat-1.1.0a20251020.dist-info/RECORD +480 -0
- nvidia_nat-1.1.0a20251020.dist-info/WHEEL +5 -0
- nvidia_nat-1.1.0a20251020.dist-info/entry_points.txt +22 -0
- nvidia_nat-1.1.0a20251020.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat-1.1.0a20251020.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat-1.1.0a20251020.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import json
|
|
17
|
+
import logging
|
|
18
|
+
|
|
19
|
+
from pydantic import Field
|
|
20
|
+
|
|
21
|
+
from nat.builder.builder import Builder
|
|
22
|
+
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
23
|
+
from nat.builder.function_info import FunctionInfo
|
|
24
|
+
from nat.cli.register_workflow import register_function
|
|
25
|
+
from nat.data_models.component_ref import LLMRef
|
|
26
|
+
from nat.data_models.function import FunctionBaseConfig
|
|
27
|
+
|
|
28
|
+
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class MilvusDocumentSearchToolConfig(FunctionBaseConfig, name="milvus_document_search"):
    """
    This tool retrieves relevant documents for a given user query. The input query is mapped to the most appropriate
    Milvus collection database. This will return relevant documents from the selected collection.
    """
    # Base URL of the RAG chain server; the tool POSTs to f"{base_url}/search".
    base_url: str = Field(description="The base url used to connect to the milvus database.")
    # Number of chunks requested from the search endpoint per query.
    top_k: int = Field(default=4, description="The number of results to return from the milvus database.")
    # Per-request timeout (seconds) passed to the httpx.AsyncClient.
    timeout: int = Field(default=60, description="The timeout configuration to use when sending requests.")
    # Reference to a registered LLM; used with structured output to pick the collection.
    llm_name: LLMRef = Field(description=("The name of the llm client to instantiate to determine most appropriate "
                                          "milvus collection."))
    # Paired lists: collection_descriptions[i] describes collection_names[i].
    # NOTE(review): bare `list` (not list[str]) means pydantic does no per-item
    # validation here — presumably intentional for config flexibility; confirm.
    collection_names: list = Field(default=["nvidia_api_catalog"],
                                   description="The list of available collection names.")
    collection_descriptions: list = Field(default=["Documents about NVIDIA's product catalog"],
                                          description=("Collection descriptions that map to collection names by "
                                                       "index position."))
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@register_function(config_type=MilvusDocumentSearchToolConfig)
async def document_search(config: MilvusDocumentSearchToolConfig, builder: Builder):
    """Register the ``milvus_document_search`` function.

    An LLM (constrained by structured output to the configured collection names)
    maps the user query to the most appropriate Milvus collection, then the
    RAG chain server at ``config.base_url`` is queried for the top-k chunks.

    Args:
        config: Tool configuration (endpoint, collections, LLM reference).
        builder: Workflow builder used to resolve the LLM client.

    Yields:
        FunctionInfo wrapping the inner ``_document_search`` coroutine.
    """
    from typing import Literal

    import httpx
    from langchain_core.messages import HumanMessage
    from langchain_core.messages import SystemMessage
    from langchain_core.pydantic_v1 import BaseModel
    from langchain_core.pydantic_v1 import Field

    # Map each collection name to its description (paired by index position).
    # zip() silently drops unpaired entries, so surface a misconfiguration in
    # the logs instead of failing silently.
    if len(config.collection_names) != len(config.collection_descriptions):
        logger.warning("collection_names (%d) and collection_descriptions (%d) differ in length; "
                       "unpaired entries are ignored.",
                       len(config.collection_names),
                       len(config.collection_descriptions))
    collection_store = dict(zip(config.collection_names, config.collection_descriptions))

    # Force the LLM's structured output to be one of the configured collection names.
    class CollectionName(BaseModel):
        collection_name: Literal[tuple(
            config.collection_names)] = Field(description="The appropriate milvus collection name for the question.")

    class DocumentSearchOutput(BaseModel):
        collection_name: str
        documents: str

    # System prompt used for collection routing.
    prompt_template = f"""You are an agent that helps users find the right Milvus collection based on the question.
    Here are the available list of collections (formatted as collection_name: collection_description): \n
    ({collection_store})
    \nFirst, analyze the available collections and their descriptions.
    Then, select the most appropriate collection for the user's query.
    Return only the name of the predicted collection."""

    # Resolve the LLM client once at registration time; the original fetched it
    # on every query, repeating the builder lookup per request.
    llm = await builder.get_llm(llm_name=config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
    structured_llm = llm.with_structured_output(CollectionName)

    async with httpx.AsyncClient(headers={
            "accept": "application/json", "Content-Type": "application/json"
    },
                                 timeout=config.timeout) as client:

        async def _document_search(query: str) -> DocumentSearchOutput:
            """
            This tool retrieves relevant context for the given question

            Args:
                query (str): The question for which we need to search milvus collections.
            """
            logger.debug("Q: %s", query)

            # Ask the LLM which collection best matches the query.
            sys_message = SystemMessage(content=prompt_template)
            query_string = f"Get relevant chunks for this query: {query}"
            llm_pred = await structured_llm.ainvoke([sys_message, HumanMessage(content=query_string)])

            logger.info("Predicted LLM Collection: %s", llm_pred)

            # Query the RAG endpoint for top-k chunks from the chosen collection.
            url = f"{config.base_url}/search"
            payload = {"query": query, "top_k": config.top_k, "collection_name": llm_pred.collection_name}

            logger.debug("Sending request to the RAG endpoint %s", url)
            # httpx serializes `json=` with json.dumps, matching the original
            # explicit `content=json.dumps(payload)` byte-for-byte.
            response = await client.post(url, json=payload)
            response.raise_for_status()
            results = response.json()

            # No hits: return an empty document string rather than raising.
            if not results["chunks"]:
                return DocumentSearchOutput(collection_name=llm_pred.collection_name, documents="")

            # Render each chunk as a pseudo-XML <Document> snippet.
            parsed_docs = []
            for doc in results["chunks"]:
                source = doc["filename"]
                page = doc.get("page", "")
                page_content = doc["content"]
                parsed_docs.append(f'<Document source="{source}" page="{page}"/>\n{page_content}\n</Document>')

            internal_search_docs = "\n\n---\n\n".join(parsed_docs)
            return DocumentSearchOutput(collection_name=llm_pred.collection_name, documents=internal_search_docs)

        # Original description concatenated sentences without separating spaces
        # ("query.The input"); fixed here.
        yield FunctionInfo.from_fn(
            _document_search,
            description=("This tool retrieves relevant documents for a given user query. "
                         "The input query is mapped to the most appropriate Milvus collection database. "
                         "This will return relevant documents from the selected collection."))
|
nat/tool/github_tools.py
ADDED
|
@@ -0,0 +1,450 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from datetime import datetime
|
|
17
|
+
from typing import Literal
|
|
18
|
+
|
|
19
|
+
from pydantic import BaseModel
|
|
20
|
+
from pydantic import Field
|
|
21
|
+
from pydantic import PositiveInt
|
|
22
|
+
from pydantic import computed_field
|
|
23
|
+
from pydantic import field_validator
|
|
24
|
+
|
|
25
|
+
from nat.builder.builder import Builder
|
|
26
|
+
from nat.builder.function import FunctionGroup
|
|
27
|
+
from nat.builder.function_info import FunctionInfo
|
|
28
|
+
from nat.cli.register_workflow import register_function
|
|
29
|
+
from nat.cli.register_workflow import register_function_group
|
|
30
|
+
from nat.data_models.function import FunctionBaseConfig
|
|
31
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class GithubCreateIssueModel(BaseModel):
    # Request payload for creating a single GitHub issue via the REST API.
    title: str = Field(description="The title of the GitHub Issue")
    body: str = Field(description="The body of the GitHub Issue")
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class GithubCreateIssueModelList(BaseModel):
    # Batch wrapper: the create_issue function accepts several issues at once.
    issues: list[GithubCreateIssueModel] = Field(
        default_factory=list,
        description="A list of GitHub issues, each with a title and a body")
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class GithubGetIssueModel(BaseModel):
    # Query-parameter filters for the GitHub "list issues" endpoint; None
    # values are dropped before the request is sent.
    state: Literal["open", "closed", "all"] | None = Field(default="open",
                                                           description="Issue state used in issue query filter")
    assignee: str | None = Field(default=None, description="Assignee name used in issue query filter")
    creator: str | None = Field(default=None, description="Creator name used in issue query filter")
    mentioned: str | None = Field(default=None, description="Name of person mentioned in issue")
    labels: list[str] | None = Field(default=None, description="A list of labels that are assigned to the issue")
    since: str | None = Field(default=None,
                              description="Only show results that were last updated after the given time.")

    # Bug fix: @field_validator must be the OUTERMOST decorator (listed above
    # @classmethod). The original order (@classmethod on top) wraps the
    # validator proxy in a classmethod, so Pydantic v2 does not register the
    # validator and `since` is never checked/normalized.
    @field_validator('since', mode='before')
    @classmethod
    def validate_since(cls, v):
        """Validate and normalize `since` to ISO 8601 (YYYY-MM-DDTHH:MM:SSZ).

        Raises:
            ValueError: If the value is not in the expected format.
        """
        if v is None:
            return v
        try:
            # Parse the string to a datetime object
            parsed_date = datetime.strptime(v, "%Y-%m-%dT%H:%M:%SZ")
            # Return the formatted string
            return parsed_date.isoformat() + 'Z'
        except ValueError as e:
            raise ValueError("since must be in ISO 8601 format: YYYY-MM-DDTHH:MM:SSZ") from e
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class GithubGetIssueModelList(BaseModel):
    # Batch wrapper: one issues query is performed per filter-parameter set.
    filter_parameters: list[GithubGetIssueModel] = Field(
        default_factory=list,
        description="A list of query params when fetching issues each of type GithubGetIssueModel")
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class GithubUpdateIssueModel(BaseModel):
    # Payload for patching an existing issue. Only issue_number is required;
    # unset/None fields are excluded from the PATCH body by the caller.
    issue_number: str = Field(description="The issue number that will be updated")
    title: str | None = Field(default=None, description="The title of the GitHub Issue")
    body: str | None = Field(default=None, description="The body of the GitHub Issue")
    state: Literal["open", "closed"] | None = Field(default=None, description="The new state of the issue")
    state_reason: Literal["completed", "not_planned", "reopened"] | None = Field(
        default=None, description="The reason for changing the state of the issue")
    labels: list[str] | None = Field(default=None, description="A list of labels to assign to the issue")
    assignees: list[str] | None = Field(default=None, description="A list of assignees to assign to the issue")
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class GithubUpdateIssueModelList(BaseModel):
    # Batch wrapper: the update_issue function patches several issues at once.
    issues: list[GithubUpdateIssueModel] = Field(
        default_factory=list,
        description="A list of GitHub issues each of type GithubUpdateIssueModel")
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class GithubCreatePullModel(BaseModel):
    # Payload for opening a pull request; source_branch/target_branch
    # serialize to the API's "head"/"base" keys via their aliases.
    title: str = Field(description="Title of the pull request")
    body: str = Field(description="Description of the pull request")
    source_branch: str = Field(description="The name of the branch containing your changes", serialization_alias="head")
    target_branch: str = Field(description="The name of the branch you want to merge into", serialization_alias="base")
    assignees: list[str] | None = Field(
        default=None,
        description="List of GitHub usernames to assign to the PR. Always the current user")
    reviewers: list[str] | None = Field(default=None, description="List of GitHub usernames to request review from")
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class GithubCreatePullList(BaseModel):
    # Batch wrapper: the create_pull function opens several PRs at once.
    pull_details: list[GithubCreatePullModel] = Field(
        default_factory=list, description="A list of params used for creating the PR in GitHub")
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class GithubGetPullsModel(BaseModel):
    # Query-parameter filters for listing pull requests.
    state: Literal["open", "closed", "all"] | None = Field(
        default="open", description="Issue state used in issue query filter")
    head: str | None = Field(
        default=None, description="Filters pulls by head user or head organization and branch name")
    base: str | None = Field(default=None, description="Filters pull by branch name")
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
class GithubGetPullsModelList(BaseModel):
    # Batch wrapper: one pulls query is performed per filter-parameter set.
    filter_parameters: list[GithubGetPullsModel] = Field(
        default_factory=list,
        description="A list of query params when fetching pull requests each of type GithubGetPullsModel")
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class GithubCommitCodeModel(BaseModel):
    # Describes one locally modified file to push to the remote repository.
    branch: str = Field(description="The branch of the remote repo to which the code will be committed")
    commit_msg: str = Field(description="Message with which the code will be committed to the remote repo")
    local_path: str = Field(
        description="Local filepath of the file that has been updated and needs to be committed to the remote repo")
    remote_path: str = Field(
        description="Remote filepath of the updated file in GitHub. Path is relative to root of current repository")
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class GithubCommitCodeModelList(BaseModel):
    # Batch wrapper: the commit function pushes several files at once.
    updated_files: list[GithubCommitCodeModel] = Field(
        default_factory=list, description="A list of local filepaths and commit messages")
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
class GithubGroupConfig(FunctionGroupBaseConfig, name="github"):
    """Function group for GitHub repository operations.

    Exposes issue, pull request, and commit operations with shared configuration.
    """
    # Target repository, e.g. "octocat/hello-world".
    repo_name: str = Field(description="The repository name in the format 'owner/repo'")
    timeout: int = Field(default=300, description="Timeout in seconds for GitHub API requests")
    # Required for commit function
    local_repo_dir: str | None = Field(default=None,
                                       description="Absolute path to the local clone. Required for 'commit' function")
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
@register_function_group(config_type=GithubGroupConfig)
async def github_tool(config: GithubGroupConfig, _builder: Builder):
    """Register the `github` function group with shared configuration.

    Implements:
    - create_issue, get_issue, update_issue
    - create_pull, get_pull
    - commit

    Args:
        config: Shared settings (repo name, request timeout, optional local clone path).
        _builder: Workflow builder (unused; required by the registration signature).

    Yields:
        The populated FunctionGroup. The shared HTTP client stays open for the
        group's lifetime and is closed when the generator is finalized.

    Raises:
        ValueError: If no GitHub token is found in the environment.
    """
    import base64
    import json
    import os

    import httpx

    # Accept any of the common token environment variables; first match wins.
    token: str | None = None
    for env_var in ["GITHUB_TOKEN", "GITHUB_PAT", "GH_TOKEN"]:
        token = os.getenv(env_var)
        if token:
            break

    if not token:
        # Bug fix: the original adjacent string literals were missing a space,
        # producing "followingenvironment" in the error message.
        raise ValueError("No GitHub token found in environment variables. Please set one of the following "
                         "environment variables: GITHUB_TOKEN, GITHUB_PAT, GH_TOKEN")

    headers = {
        "Authorization": f"Bearer {token}",
        "Accept": "application/vnd.github+json",
        "User-Agent": "NeMo-Agent-Toolkit",
    }

    async with httpx.AsyncClient(timeout=config.timeout, headers=headers) as client:

        # Issues
        async def create_issue(issues_list: GithubCreateIssueModelList) -> str:
            """Create each issue in the list; return the JSON responses as a string."""
            url = f"https://api.github.com/repos/{config.repo_name}/issues"
            results = []
            for issue in issues_list.issues:
                payload = issue.model_dump(exclude_unset=True)
                response = await client.post(url, json=payload)
                response.raise_for_status()
                results.append(response.json())
            return json.dumps(results)

        async def get_issue(issues_list: GithubGetIssueModelList) -> str:
            """List issues once per filter set; return the JSON responses as a string."""
            url = f"https://api.github.com/repos/{config.repo_name}/issues"
            results = []
            for issue in issues_list.filter_parameters:
                params = issue.model_dump(exclude_unset=True, exclude_none=True)
                response = await client.get(url, params=params)
                response.raise_for_status()
                results.append(response.json())
            return json.dumps(results)

        async def update_issue(issues_list: GithubUpdateIssueModelList) -> str:
            """Patch each issue in the list; return the JSON responses as a string."""
            url = f"https://api.github.com/repos/{config.repo_name}/issues"
            results = []
            for issue in issues_list.issues:
                payload = issue.model_dump(exclude_unset=True, exclude_none=True)
                # The issue number goes in the URL, not the PATCH body.
                issue_number = payload.pop("issue_number")
                issue_url = f"{url}/{issue_number}"
                response = await client.patch(issue_url, json=payload)
                response.raise_for_status()
                results.append(response.json())
            return json.dumps(results)

        # Pull requests
        async def create_pull(pull_list: GithubCreatePullList) -> str:
            """Create each PR, then optionally add assignees and request reviewers."""
            results = []
            pr_url = f"https://api.github.com/repos/{config.repo_name}/pulls"

            for pull_detail in pull_list.pull_details:

                # by_alias=True maps source_branch/target_branch to the API's head/base keys.
                pr_data = pull_detail.model_dump(
                    include={"title", "body", "source_branch", "target_branch"},
                    by_alias=True,
                )
                pr_response = await client.post(pr_url, json=pr_data)
                pr_response.raise_for_status()
                pr_number = pr_response.json()["number"]

                result = {"pull_request": pr_response.json()}

                if pull_detail.assignees:
                    # Assignees are managed through the issues endpoint, keyed by PR number.
                    assignees_url = f"https://api.github.com/repos/{config.repo_name}/issues/{pr_number}/assignees"
                    assignees_data = {"assignees": pull_detail.assignees}
                    assignees_response = await client.post(assignees_url, json=assignees_data)
                    assignees_response.raise_for_status()
                    result["assignees"] = assignees_response.json()

                if pull_detail.reviewers:
                    reviewers_url = f"https://api.github.com/repos/{config.repo_name}/pulls/{pr_number}/requested_reviewers"
                    reviewers_data = {"reviewers": pull_detail.reviewers}
                    reviewers_response = await client.post(reviewers_url, json=reviewers_data)
                    reviewers_response.raise_for_status()
                    result["reviewers"] = reviewers_response.json()

                results.append(result)

            return json.dumps(results)

        async def get_pull(pull_list: GithubGetPullsModelList) -> str:
            """List pull requests once per filter set; return the JSON responses as a string."""
            url = f"https://api.github.com/repos/{config.repo_name}/pulls"
            results = []
            for pull_params in pull_list.filter_parameters:
                params = pull_params.model_dump(exclude_unset=True, exclude_none=True)
                response = await client.get(url, params=params)
                response.raise_for_status()
                results.append(response.json())

            return json.dumps(results)

        # Commits (commit updated files)
        async def commit(updated_file_list: GithubCommitCodeModelList) -> str:
            """Commit each updated local file to its branch via the Git data API.

            For every file: create a blob, build a tree on top of the branch head,
            create a commit, then advance the branch ref to it.

            Raises:
                ValueError: If local_repo_dir is unset or local_path escapes it.
                FileNotFoundError: If the local file does not exist.
            """
            if not config.local_repo_dir:
                raise ValueError("'local_repo_dir' must be set in the github function group config to use 'commit'")

            results = []
            for updated_file in updated_file_list.updated_files:
                branch = updated_file.branch
                commit_msg = updated_file.commit_msg
                local_path = updated_file.local_path
                remote_path = updated_file.remote_path

                # Read content from the local file (secure + binary-safe)
                safe_root = os.path.realpath(config.local_repo_dir)
                candidate = os.path.realpath(os.path.join(config.local_repo_dir, local_path))
                if not candidate.startswith(safe_root + os.sep):
                    raise ValueError(f"local_path '{local_path}' resolves outside local_repo_dir")
                if not os.path.isfile(candidate):
                    raise FileNotFoundError(f"File not found: {candidate}")
                with open(candidate, "rb") as f:
                    content_bytes = f.read()
                content_b64 = base64.b64encode(content_bytes).decode("ascii")

                # 1) Create blob
                blob_url = f"https://api.github.com/repos/{config.repo_name}/git/blobs"
                blob_data = {"content": content_b64, "encoding": "base64"}
                blob_response = await client.post(blob_url, json=blob_data)
                blob_response.raise_for_status()
                blob_sha = blob_response.json()["sha"]

                # 2) Get current ref (parent commit SHA)
                ref_url = f"https://api.github.com/repos/{config.repo_name}/git/refs/heads/{branch}"
                ref_response = await client.get(ref_url)
                ref_response.raise_for_status()
                parent_commit_sha = ref_response.json()["object"]["sha"]

                # 3) Get parent commit to retrieve its tree SHA
                parent_commit_url = f"https://api.github.com/repos/{config.repo_name}/git/commits/{parent_commit_sha}"
                parent_commit_resp = await client.get(parent_commit_url)
                parent_commit_resp.raise_for_status()
                base_tree_sha = parent_commit_resp.json()["tree"]["sha"]

                # 4) Create tree
                tree_url = f"https://api.github.com/repos/{config.repo_name}/git/trees"
                tree_data = {
                    "base_tree": base_tree_sha,
                    "tree": [{
                        "path": remote_path, "mode": "100644", "type": "blob", "sha": blob_sha
                    }],
                }
                tree_response = await client.post(tree_url, json=tree_data)
                tree_response.raise_for_status()
                tree_sha = tree_response.json()["sha"]

                # 5) Create commit
                commit_url = f"https://api.github.com/repos/{config.repo_name}/git/commits"
                commit_data = {"message": commit_msg, "tree": tree_sha, "parents": [parent_commit_sha]}
                commit_response = await client.post(commit_url, json=commit_data)
                commit_response.raise_for_status()
                commit_sha = commit_response.json()["sha"]

                # 6) Update ref
                update_ref_url = f"https://api.github.com/repos/{config.repo_name}/git/refs/heads/{branch}"
                update_ref_data = {"sha": commit_sha, "force": False}
                update_ref_response = await client.patch(update_ref_url, json=update_ref_data)
                update_ref_response.raise_for_status()

                results.append({
                    "blob_resp": blob_response.json(),
                    "parent_commit": parent_commit_resp.json(),
                    "new_tree": tree_response.json(),
                    "commit_resp": commit_response.json(),
                    "update_ref_resp": update_ref_response.json(),
                })

            return json.dumps(results)

        group = FunctionGroup(config=config)

        group.add_function("create_issue",
                           create_issue,
                           description=f"Creates a GitHub issue in the repo named {config.repo_name}",
                           input_schema=GithubCreateIssueModelList)
        group.add_function("get_issue",
                           get_issue,
                           description=f"Fetches a particular GitHub issue in the repo named {config.repo_name}",
                           input_schema=GithubGetIssueModelList)
        group.add_function("update_issue",
                           update_issue,
                           description=f"Updates a GitHub issue in the repo named {config.repo_name}",
                           input_schema=GithubUpdateIssueModelList)
        # Bug fix: the three descriptions below were missing a space at the
        # literal join, yielding "inthe", "requestin" and "repositoryin".
        group.add_function("create_pull",
                           create_pull,
                           description="Creates a pull request with assignees and reviewers in "
                           f"the GitHub repository named {config.repo_name}",
                           input_schema=GithubCreatePullList)
        group.add_function("get_pull",
                           get_pull,
                           description="Fetches the files for a particular GitHub pull request "
                           f"in the repo named {config.repo_name}",
                           input_schema=GithubGetPullsModelList)
        group.add_function("commit",
                           commit,
                           description="Commits and pushes modified code to a GitHub repository "
                           f"in the repo named {config.repo_name}",
                           input_schema=GithubCommitCodeModelList)

        yield group
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
class GithubFilesGroupConfig(FunctionBaseConfig, name="github_files_tool"):
    # Config for the tool that fetches raw file contents from github.com URLs.
    timeout: int = Field(default=5, description="Timeout in seconds for HTTP requests")
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
@register_function(config_type=GithubFilesGroupConfig)
async def github_files_tool(config: GithubFilesGroupConfig, _builder: Builder):
    """Register the `github_files_tool` function.

    Fetches the raw text of a file (optionally a line or line range) from a
    github.com blob URL via raw.githubusercontent.com.
    """

    import re

    import httpx

    class FileMetadata(BaseModel):
        # Pieces parsed out of a github.com blob URL by the regex below.
        repo_path: str  # "org/repo"
        file_path: str  # "<branch>/<path/to/file>" as captured after "/blob/"
        start: str | None = Field(default=None)
        end: str | None = Field(default=None)

        @computed_field
        @property
        def start_line(self) -> PositiveInt | None:
            # 1-based start line from the "#L<n>" fragment, if present.
            return int(self.start) if self.start else None

        @computed_field
        @property
        def end_line(self) -> PositiveInt | None:
            # 1-based end line from the "-L<n>" fragment, if present.
            return int(self.end) if self.end else None

    async with httpx.AsyncClient(timeout=config.timeout) as client:

        async def get(url_text: str) -> str:
            """
            Returns the text of a github file using a github url starting with https://github.com and ending
            with a specific file. If a line reference is provided (#L409), the text of the line is returned.
            If a range of lines is provided (#L409-L417), the text of the lines is returned.

            Examples:
            - https://github.com/org/repo/blob/main/README.md -> Returns full text of the README.md file
            - https://github.com/org/repo/blob/main/README.md#L409 -> Returns the 409th line of the README.md file
            - https://github.com/org/repo/blob/main/README.md#L409-L417 -> Returns lines 409-417 of the README.md file
            """

            pattern = r"https://github\.com/(?P<repo_path>[^/]*/[^/]*)/blob/(?P<file_path>[^?#]*)(?:#L(?P<start>\d+)(?:-L(?P<end>\d+))?)?"
            match = re.match(pattern, url_text)
            if not match:
                return ("Invalid github url. Please provide a valid github url. "
                        "Example: 'https://github.com/org/repo/blob/main/README.md' "
                        "or 'https://github.com/org/repo/blob/main/README.md#L409' "
                        "or 'https://github.com/org/repo/blob/main/README.md#L409-L417'")

            file_metadata = FileMetadata(**match.groupdict())

            # The following URL is the raw URL of the file. refs/heads/ always points to the top commit of the branch
            raw_url = f"https://raw.githubusercontent.com/{file_metadata.repo_path}/refs/heads/{file_metadata.file_path}"
            try:
                response = await client.get(raw_url)
                response.raise_for_status()
            except httpx.TimeoutException:
                return f"Timeout encountered when retrieving resource: {raw_url}"
            except httpx.HTTPStatusError as e:
                # Bug fix: a bad branch/path (e.g. 404) previously raised an
                # unhandled exception out of the tool; report it as a friendly
                # string, consistent with the timeout branch above.
                return f"HTTP error {e.response.status_code} encountered when retrieving resource: {raw_url}"

            # No line fragment: return the whole file.
            if file_metadata.start_line is None:
                return f"```{response.text}\n```"

            lines = response.text.splitlines()

            if file_metadata.start_line > len(lines):
                return f"Error: Line {file_metadata.start_line} is out of range for the file {file_metadata.file_path}"

            # Single-line fragment (#L<n>): return just that line.
            if file_metadata.end_line is None:
                return f"```{lines[file_metadata.start_line - 1]}\n```"

            if file_metadata.end_line > len(lines):
                return f"Error: Line {file_metadata.end_line} is out of range for the file {file_metadata.file_path}"

            # Range fragment (#L<a>-L<b>): inclusive slice of lines a..b.
            selected_lines = lines[file_metadata.start_line - 1:file_metadata.end_line]
            response_text = "\n".join(selected_lines)
            return f"```{response_text}\n```"

        yield FunctionInfo.from_fn(get, description=get.__doc__)
|
|
File without changes
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
|
|
18
|
+
from pydantic import Field
|
|
19
|
+
|
|
20
|
+
from nat.builder.builder import Builder
|
|
21
|
+
from nat.builder.function_info import FunctionInfo
|
|
22
|
+
from nat.cli.register_workflow import register_function
|
|
23
|
+
from nat.data_models.component_ref import MemoryRef
|
|
24
|
+
from nat.data_models.function import FunctionBaseConfig
|
|
25
|
+
from nat.memory.models import MemoryItem
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class AddToolConfig(FunctionBaseConfig, name="add_memory"):
    """Function to add memory to a hosted memory platform."""

    description: str = Field(
        default="Tool to add a memory about a user's interactions to a system for retrieval later.",
        description="The description of this function's use for tool calling agents.")
    memory: MemoryRef = Field(
        default=MemoryRef("saas_memory"),
        description="Instance name of the memory client instance from the workflow configuration object.")
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@register_function(config_type=AddToolConfig)
async def add_memory_tool(config: AddToolConfig, builder: Builder):
    """
    Function to add memory to a hosted memory platform.
    """
    from langchain_core.tools import ToolException

    # First, retrieve the memory client
    memory_editor = await builder.get_memory_client(config.memory)

    async def _arun(item: MemoryItem) -> str:
        """
        Asynchronous execution of addition of memories.

        Args:
            item (MemoryItem): The memory item to add. Must include:
                - conversation: List of dicts with "role" and "content" keys
                - user_id: String identifier for the user
                - metadata: Dict of metadata (can be empty)
                - tags: Optional list of tags
                - memory: Optional memory string

        Note: If conversation is not provided, it will be created from the memory field
        if available, otherwise an error will be raised.
        """
        try:
            # If conversation is not provided but memory is, create a conversation
            if not item.conversation and item.memory:
                item.conversation = [{"role": "user", "content": item.memory}]
            elif not item.conversation:
                raise ToolException("Either conversation or memory must be provided")

            await memory_editor.add_items([item])
            return "Memory added successfully. You can continue. Please respond to the user."

        except ToolException:
            # Bug fix: re-raise the validation error unchanged. Previously it was
            # caught by the blanket handler below and double-wrapped as
            # "Error adding memory: Either conversation or memory must be provided".
            raise
        except Exception as e:
            raise ToolException(f"Error adding memory: {e}") from e

    yield FunctionInfo.from_fn(_arun, description=config.description)
|