nvidia-nat 1.2.0rc5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/agent/__init__.py +0 -0
- aiq/agent/base.py +239 -0
- aiq/agent/dual_node.py +67 -0
- aiq/agent/react_agent/__init__.py +0 -0
- aiq/agent/react_agent/agent.py +355 -0
- aiq/agent/react_agent/output_parser.py +104 -0
- aiq/agent/react_agent/prompt.py +41 -0
- aiq/agent/react_agent/register.py +149 -0
- aiq/agent/reasoning_agent/__init__.py +0 -0
- aiq/agent/reasoning_agent/reasoning_agent.py +225 -0
- aiq/agent/register.py +23 -0
- aiq/agent/rewoo_agent/__init__.py +0 -0
- aiq/agent/rewoo_agent/agent.py +411 -0
- aiq/agent/rewoo_agent/prompt.py +108 -0
- aiq/agent/rewoo_agent/register.py +158 -0
- aiq/agent/tool_calling_agent/__init__.py +0 -0
- aiq/agent/tool_calling_agent/agent.py +119 -0
- aiq/agent/tool_calling_agent/register.py +106 -0
- aiq/authentication/__init__.py +14 -0
- aiq/authentication/api_key/__init__.py +14 -0
- aiq/authentication/api_key/api_key_auth_provider.py +96 -0
- aiq/authentication/api_key/api_key_auth_provider_config.py +124 -0
- aiq/authentication/api_key/register.py +26 -0
- aiq/authentication/exceptions/__init__.py +14 -0
- aiq/authentication/exceptions/api_key_exceptions.py +38 -0
- aiq/authentication/http_basic_auth/__init__.py +0 -0
- aiq/authentication/http_basic_auth/http_basic_auth_provider.py +81 -0
- aiq/authentication/http_basic_auth/register.py +30 -0
- aiq/authentication/interfaces.py +93 -0
- aiq/authentication/oauth2/__init__.py +14 -0
- aiq/authentication/oauth2/oauth2_auth_code_flow_provider.py +107 -0
- aiq/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +39 -0
- aiq/authentication/oauth2/register.py +25 -0
- aiq/authentication/register.py +21 -0
- aiq/builder/__init__.py +0 -0
- aiq/builder/builder.py +285 -0
- aiq/builder/component_utils.py +316 -0
- aiq/builder/context.py +264 -0
- aiq/builder/embedder.py +24 -0
- aiq/builder/eval_builder.py +161 -0
- aiq/builder/evaluator.py +29 -0
- aiq/builder/framework_enum.py +24 -0
- aiq/builder/front_end.py +73 -0
- aiq/builder/function.py +344 -0
- aiq/builder/function_base.py +380 -0
- aiq/builder/function_info.py +627 -0
- aiq/builder/intermediate_step_manager.py +174 -0
- aiq/builder/llm.py +25 -0
- aiq/builder/retriever.py +25 -0
- aiq/builder/user_interaction_manager.py +74 -0
- aiq/builder/workflow.py +148 -0
- aiq/builder/workflow_builder.py +1117 -0
- aiq/cli/__init__.py +14 -0
- aiq/cli/cli_utils/__init__.py +0 -0
- aiq/cli/cli_utils/config_override.py +231 -0
- aiq/cli/cli_utils/validation.py +37 -0
- aiq/cli/commands/__init__.py +0 -0
- aiq/cli/commands/configure/__init__.py +0 -0
- aiq/cli/commands/configure/channel/__init__.py +0 -0
- aiq/cli/commands/configure/channel/add.py +28 -0
- aiq/cli/commands/configure/channel/channel.py +36 -0
- aiq/cli/commands/configure/channel/remove.py +30 -0
- aiq/cli/commands/configure/channel/update.py +30 -0
- aiq/cli/commands/configure/configure.py +33 -0
- aiq/cli/commands/evaluate.py +139 -0
- aiq/cli/commands/info/__init__.py +14 -0
- aiq/cli/commands/info/info.py +39 -0
- aiq/cli/commands/info/list_channels.py +32 -0
- aiq/cli/commands/info/list_components.py +129 -0
- aiq/cli/commands/info/list_mcp.py +213 -0
- aiq/cli/commands/registry/__init__.py +14 -0
- aiq/cli/commands/registry/publish.py +88 -0
- aiq/cli/commands/registry/pull.py +118 -0
- aiq/cli/commands/registry/registry.py +38 -0
- aiq/cli/commands/registry/remove.py +108 -0
- aiq/cli/commands/registry/search.py +155 -0
- aiq/cli/commands/sizing/__init__.py +14 -0
- aiq/cli/commands/sizing/calc.py +297 -0
- aiq/cli/commands/sizing/sizing.py +27 -0
- aiq/cli/commands/start.py +246 -0
- aiq/cli/commands/uninstall.py +81 -0
- aiq/cli/commands/validate.py +47 -0
- aiq/cli/commands/workflow/__init__.py +14 -0
- aiq/cli/commands/workflow/templates/__init__.py.j2 +0 -0
- aiq/cli/commands/workflow/templates/config.yml.j2 +16 -0
- aiq/cli/commands/workflow/templates/pyproject.toml.j2 +22 -0
- aiq/cli/commands/workflow/templates/register.py.j2 +5 -0
- aiq/cli/commands/workflow/templates/workflow.py.j2 +36 -0
- aiq/cli/commands/workflow/workflow.py +37 -0
- aiq/cli/commands/workflow/workflow_commands.py +313 -0
- aiq/cli/entrypoint.py +135 -0
- aiq/cli/main.py +44 -0
- aiq/cli/register_workflow.py +488 -0
- aiq/cli/type_registry.py +1000 -0
- aiq/data_models/__init__.py +14 -0
- aiq/data_models/api_server.py +694 -0
- aiq/data_models/authentication.py +231 -0
- aiq/data_models/common.py +171 -0
- aiq/data_models/component.py +54 -0
- aiq/data_models/component_ref.py +168 -0
- aiq/data_models/config.py +406 -0
- aiq/data_models/dataset_handler.py +123 -0
- aiq/data_models/discovery_metadata.py +335 -0
- aiq/data_models/embedder.py +27 -0
- aiq/data_models/evaluate.py +127 -0
- aiq/data_models/evaluator.py +26 -0
- aiq/data_models/front_end.py +26 -0
- aiq/data_models/function.py +30 -0
- aiq/data_models/function_dependencies.py +72 -0
- aiq/data_models/interactive.py +246 -0
- aiq/data_models/intermediate_step.py +302 -0
- aiq/data_models/invocation_node.py +38 -0
- aiq/data_models/llm.py +27 -0
- aiq/data_models/logging.py +26 -0
- aiq/data_models/memory.py +27 -0
- aiq/data_models/object_store.py +44 -0
- aiq/data_models/profiler.py +54 -0
- aiq/data_models/registry_handler.py +26 -0
- aiq/data_models/retriever.py +30 -0
- aiq/data_models/retry_mixin.py +35 -0
- aiq/data_models/span.py +187 -0
- aiq/data_models/step_adaptor.py +64 -0
- aiq/data_models/streaming.py +33 -0
- aiq/data_models/swe_bench_model.py +54 -0
- aiq/data_models/telemetry_exporter.py +26 -0
- aiq/data_models/ttc_strategy.py +30 -0
- aiq/embedder/__init__.py +0 -0
- aiq/embedder/langchain_client.py +41 -0
- aiq/embedder/nim_embedder.py +59 -0
- aiq/embedder/openai_embedder.py +43 -0
- aiq/embedder/register.py +24 -0
- aiq/eval/__init__.py +14 -0
- aiq/eval/config.py +60 -0
- aiq/eval/dataset_handler/__init__.py +0 -0
- aiq/eval/dataset_handler/dataset_downloader.py +106 -0
- aiq/eval/dataset_handler/dataset_filter.py +52 -0
- aiq/eval/dataset_handler/dataset_handler.py +254 -0
- aiq/eval/evaluate.py +506 -0
- aiq/eval/evaluator/__init__.py +14 -0
- aiq/eval/evaluator/base_evaluator.py +73 -0
- aiq/eval/evaluator/evaluator_model.py +45 -0
- aiq/eval/intermediate_step_adapter.py +99 -0
- aiq/eval/rag_evaluator/__init__.py +0 -0
- aiq/eval/rag_evaluator/evaluate.py +178 -0
- aiq/eval/rag_evaluator/register.py +143 -0
- aiq/eval/register.py +23 -0
- aiq/eval/remote_workflow.py +133 -0
- aiq/eval/runners/__init__.py +14 -0
- aiq/eval/runners/config.py +39 -0
- aiq/eval/runners/multi_eval_runner.py +54 -0
- aiq/eval/runtime_event_subscriber.py +52 -0
- aiq/eval/swe_bench_evaluator/__init__.py +0 -0
- aiq/eval/swe_bench_evaluator/evaluate.py +215 -0
- aiq/eval/swe_bench_evaluator/register.py +36 -0
- aiq/eval/trajectory_evaluator/__init__.py +0 -0
- aiq/eval/trajectory_evaluator/evaluate.py +75 -0
- aiq/eval/trajectory_evaluator/register.py +40 -0
- aiq/eval/tunable_rag_evaluator/__init__.py +0 -0
- aiq/eval/tunable_rag_evaluator/evaluate.py +245 -0
- aiq/eval/tunable_rag_evaluator/register.py +52 -0
- aiq/eval/usage_stats.py +41 -0
- aiq/eval/utils/__init__.py +0 -0
- aiq/eval/utils/output_uploader.py +140 -0
- aiq/eval/utils/tqdm_position_registry.py +40 -0
- aiq/eval/utils/weave_eval.py +184 -0
- aiq/experimental/__init__.py +0 -0
- aiq/experimental/decorators/__init__.py +0 -0
- aiq/experimental/decorators/experimental_warning_decorator.py +130 -0
- aiq/experimental/test_time_compute/__init__.py +0 -0
- aiq/experimental/test_time_compute/editing/__init__.py +0 -0
- aiq/experimental/test_time_compute/editing/iterative_plan_refinement_editor.py +147 -0
- aiq/experimental/test_time_compute/editing/llm_as_a_judge_editor.py +204 -0
- aiq/experimental/test_time_compute/editing/motivation_aware_summarization.py +107 -0
- aiq/experimental/test_time_compute/functions/__init__.py +0 -0
- aiq/experimental/test_time_compute/functions/execute_score_select_function.py +105 -0
- aiq/experimental/test_time_compute/functions/its_tool_orchestration_function.py +205 -0
- aiq/experimental/test_time_compute/functions/its_tool_wrapper_function.py +146 -0
- aiq/experimental/test_time_compute/functions/plan_select_execute_function.py +224 -0
- aiq/experimental/test_time_compute/models/__init__.py +0 -0
- aiq/experimental/test_time_compute/models/editor_config.py +132 -0
- aiq/experimental/test_time_compute/models/scoring_config.py +112 -0
- aiq/experimental/test_time_compute/models/search_config.py +120 -0
- aiq/experimental/test_time_compute/models/selection_config.py +154 -0
- aiq/experimental/test_time_compute/models/stage_enums.py +43 -0
- aiq/experimental/test_time_compute/models/strategy_base.py +66 -0
- aiq/experimental/test_time_compute/models/tool_use_config.py +41 -0
- aiq/experimental/test_time_compute/models/ttc_item.py +48 -0
- aiq/experimental/test_time_compute/register.py +36 -0
- aiq/experimental/test_time_compute/scoring/__init__.py +0 -0
- aiq/experimental/test_time_compute/scoring/llm_based_agent_scorer.py +168 -0
- aiq/experimental/test_time_compute/scoring/llm_based_plan_scorer.py +168 -0
- aiq/experimental/test_time_compute/scoring/motivation_aware_scorer.py +111 -0
- aiq/experimental/test_time_compute/search/__init__.py +0 -0
- aiq/experimental/test_time_compute/search/multi_llm_planner.py +128 -0
- aiq/experimental/test_time_compute/search/multi_query_retrieval_search.py +122 -0
- aiq/experimental/test_time_compute/search/single_shot_multi_plan_planner.py +128 -0
- aiq/experimental/test_time_compute/selection/__init__.py +0 -0
- aiq/experimental/test_time_compute/selection/best_of_n_selector.py +63 -0
- aiq/experimental/test_time_compute/selection/llm_based_agent_output_selector.py +131 -0
- aiq/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +159 -0
- aiq/experimental/test_time_compute/selection/llm_based_plan_selector.py +128 -0
- aiq/experimental/test_time_compute/selection/threshold_selector.py +58 -0
- aiq/front_ends/__init__.py +14 -0
- aiq/front_ends/console/__init__.py +14 -0
- aiq/front_ends/console/authentication_flow_handler.py +233 -0
- aiq/front_ends/console/console_front_end_config.py +32 -0
- aiq/front_ends/console/console_front_end_plugin.py +96 -0
- aiq/front_ends/console/register.py +25 -0
- aiq/front_ends/cron/__init__.py +14 -0
- aiq/front_ends/fastapi/__init__.py +14 -0
- aiq/front_ends/fastapi/auth_flow_handlers/__init__.py +0 -0
- aiq/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +27 -0
- aiq/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +107 -0
- aiq/front_ends/fastapi/fastapi_front_end_config.py +234 -0
- aiq/front_ends/fastapi/fastapi_front_end_controller.py +68 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin.py +116 -0
- aiq/front_ends/fastapi/fastapi_front_end_plugin_worker.py +1092 -0
- aiq/front_ends/fastapi/html_snippets/__init__.py +14 -0
- aiq/front_ends/fastapi/html_snippets/auth_code_grant_success.py +35 -0
- aiq/front_ends/fastapi/intermediate_steps_subscriber.py +80 -0
- aiq/front_ends/fastapi/job_store.py +183 -0
- aiq/front_ends/fastapi/main.py +72 -0
- aiq/front_ends/fastapi/message_handler.py +298 -0
- aiq/front_ends/fastapi/message_validator.py +345 -0
- aiq/front_ends/fastapi/register.py +25 -0
- aiq/front_ends/fastapi/response_helpers.py +195 -0
- aiq/front_ends/fastapi/step_adaptor.py +321 -0
- aiq/front_ends/mcp/__init__.py +14 -0
- aiq/front_ends/mcp/mcp_front_end_config.py +32 -0
- aiq/front_ends/mcp/mcp_front_end_plugin.py +93 -0
- aiq/front_ends/mcp/register.py +27 -0
- aiq/front_ends/mcp/tool_converter.py +242 -0
- aiq/front_ends/register.py +22 -0
- aiq/front_ends/simple_base/__init__.py +14 -0
- aiq/front_ends/simple_base/simple_front_end_plugin_base.py +54 -0
- aiq/llm/__init__.py +0 -0
- aiq/llm/aws_bedrock_llm.py +57 -0
- aiq/llm/nim_llm.py +46 -0
- aiq/llm/openai_llm.py +46 -0
- aiq/llm/register.py +23 -0
- aiq/llm/utils/__init__.py +14 -0
- aiq/llm/utils/env_config_value.py +94 -0
- aiq/llm/utils/error.py +17 -0
- aiq/memory/__init__.py +20 -0
- aiq/memory/interfaces.py +183 -0
- aiq/memory/models.py +112 -0
- aiq/meta/module_to_distro.json +3 -0
- aiq/meta/pypi.md +58 -0
- aiq/object_store/__init__.py +20 -0
- aiq/object_store/in_memory_object_store.py +76 -0
- aiq/object_store/interfaces.py +84 -0
- aiq/object_store/models.py +36 -0
- aiq/object_store/register.py +20 -0
- aiq/observability/__init__.py +14 -0
- aiq/observability/exporter/__init__.py +14 -0
- aiq/observability/exporter/base_exporter.py +449 -0
- aiq/observability/exporter/exporter.py +78 -0
- aiq/observability/exporter/file_exporter.py +33 -0
- aiq/observability/exporter/processing_exporter.py +322 -0
- aiq/observability/exporter/raw_exporter.py +52 -0
- aiq/observability/exporter/span_exporter.py +265 -0
- aiq/observability/exporter_manager.py +335 -0
- aiq/observability/mixin/__init__.py +14 -0
- aiq/observability/mixin/batch_config_mixin.py +26 -0
- aiq/observability/mixin/collector_config_mixin.py +23 -0
- aiq/observability/mixin/file_mixin.py +288 -0
- aiq/observability/mixin/file_mode.py +23 -0
- aiq/observability/mixin/resource_conflict_mixin.py +134 -0
- aiq/observability/mixin/serialize_mixin.py +61 -0
- aiq/observability/mixin/type_introspection_mixin.py +183 -0
- aiq/observability/processor/__init__.py +14 -0
- aiq/observability/processor/batching_processor.py +310 -0
- aiq/observability/processor/callback_processor.py +42 -0
- aiq/observability/processor/intermediate_step_serializer.py +28 -0
- aiq/observability/processor/processor.py +71 -0
- aiq/observability/register.py +96 -0
- aiq/observability/utils/__init__.py +14 -0
- aiq/observability/utils/dict_utils.py +236 -0
- aiq/observability/utils/time_utils.py +31 -0
- aiq/plugins/.namespace +1 -0
- aiq/profiler/__init__.py +0 -0
- aiq/profiler/calc/__init__.py +14 -0
- aiq/profiler/calc/calc_runner.py +627 -0
- aiq/profiler/calc/calculations.py +288 -0
- aiq/profiler/calc/data_models.py +188 -0
- aiq/profiler/calc/plot.py +345 -0
- aiq/profiler/callbacks/__init__.py +0 -0
- aiq/profiler/callbacks/agno_callback_handler.py +295 -0
- aiq/profiler/callbacks/base_callback_class.py +20 -0
- aiq/profiler/callbacks/langchain_callback_handler.py +290 -0
- aiq/profiler/callbacks/llama_index_callback_handler.py +205 -0
- aiq/profiler/callbacks/semantic_kernel_callback_handler.py +238 -0
- aiq/profiler/callbacks/token_usage_base_model.py +27 -0
- aiq/profiler/data_frame_row.py +51 -0
- aiq/profiler/data_models.py +24 -0
- aiq/profiler/decorators/__init__.py +0 -0
- aiq/profiler/decorators/framework_wrapper.py +131 -0
- aiq/profiler/decorators/function_tracking.py +254 -0
- aiq/profiler/forecasting/__init__.py +0 -0
- aiq/profiler/forecasting/config.py +18 -0
- aiq/profiler/forecasting/model_trainer.py +75 -0
- aiq/profiler/forecasting/models/__init__.py +22 -0
- aiq/profiler/forecasting/models/forecasting_base_model.py +40 -0
- aiq/profiler/forecasting/models/linear_model.py +196 -0
- aiq/profiler/forecasting/models/random_forest_regressor.py +268 -0
- aiq/profiler/inference_metrics_model.py +28 -0
- aiq/profiler/inference_optimization/__init__.py +0 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/__init__.py +0 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +460 -0
- aiq/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +258 -0
- aiq/profiler/inference_optimization/data_models.py +386 -0
- aiq/profiler/inference_optimization/experimental/__init__.py +0 -0
- aiq/profiler/inference_optimization/experimental/concurrency_spike_analysis.py +468 -0
- aiq/profiler/inference_optimization/experimental/prefix_span_analysis.py +405 -0
- aiq/profiler/inference_optimization/llm_metrics.py +212 -0
- aiq/profiler/inference_optimization/prompt_caching.py +163 -0
- aiq/profiler/inference_optimization/token_uniqueness.py +107 -0
- aiq/profiler/inference_optimization/workflow_runtimes.py +72 -0
- aiq/profiler/intermediate_property_adapter.py +102 -0
- aiq/profiler/profile_runner.py +473 -0
- aiq/profiler/utils.py +184 -0
- aiq/registry_handlers/__init__.py +0 -0
- aiq/registry_handlers/local/__init__.py +0 -0
- aiq/registry_handlers/local/local_handler.py +176 -0
- aiq/registry_handlers/local/register_local.py +37 -0
- aiq/registry_handlers/metadata_factory.py +60 -0
- aiq/registry_handlers/package_utils.py +567 -0
- aiq/registry_handlers/pypi/__init__.py +0 -0
- aiq/registry_handlers/pypi/pypi_handler.py +251 -0
- aiq/registry_handlers/pypi/register_pypi.py +40 -0
- aiq/registry_handlers/register.py +21 -0
- aiq/registry_handlers/registry_handler_base.py +157 -0
- aiq/registry_handlers/rest/__init__.py +0 -0
- aiq/registry_handlers/rest/register_rest.py +56 -0
- aiq/registry_handlers/rest/rest_handler.py +237 -0
- aiq/registry_handlers/schemas/__init__.py +0 -0
- aiq/registry_handlers/schemas/headers.py +42 -0
- aiq/registry_handlers/schemas/package.py +68 -0
- aiq/registry_handlers/schemas/publish.py +63 -0
- aiq/registry_handlers/schemas/pull.py +82 -0
- aiq/registry_handlers/schemas/remove.py +36 -0
- aiq/registry_handlers/schemas/search.py +91 -0
- aiq/registry_handlers/schemas/status.py +47 -0
- aiq/retriever/__init__.py +0 -0
- aiq/retriever/interface.py +37 -0
- aiq/retriever/milvus/__init__.py +14 -0
- aiq/retriever/milvus/register.py +81 -0
- aiq/retriever/milvus/retriever.py +228 -0
- aiq/retriever/models.py +74 -0
- aiq/retriever/nemo_retriever/__init__.py +14 -0
- aiq/retriever/nemo_retriever/register.py +60 -0
- aiq/retriever/nemo_retriever/retriever.py +190 -0
- aiq/retriever/register.py +22 -0
- aiq/runtime/__init__.py +14 -0
- aiq/runtime/loader.py +215 -0
- aiq/runtime/runner.py +190 -0
- aiq/runtime/session.py +158 -0
- aiq/runtime/user_metadata.py +130 -0
- aiq/settings/__init__.py +0 -0
- aiq/settings/global_settings.py +318 -0
- aiq/test/.namespace +1 -0
- aiq/tool/__init__.py +0 -0
- aiq/tool/chat_completion.py +74 -0
- aiq/tool/code_execution/README.md +151 -0
- aiq/tool/code_execution/__init__.py +0 -0
- aiq/tool/code_execution/code_sandbox.py +267 -0
- aiq/tool/code_execution/local_sandbox/.gitignore +1 -0
- aiq/tool/code_execution/local_sandbox/Dockerfile.sandbox +60 -0
- aiq/tool/code_execution/local_sandbox/__init__.py +13 -0
- aiq/tool/code_execution/local_sandbox/local_sandbox_server.py +198 -0
- aiq/tool/code_execution/local_sandbox/sandbox.requirements.txt +6 -0
- aiq/tool/code_execution/local_sandbox/start_local_sandbox.sh +50 -0
- aiq/tool/code_execution/register.py +74 -0
- aiq/tool/code_execution/test_code_execution_sandbox.py +414 -0
- aiq/tool/code_execution/utils.py +100 -0
- aiq/tool/datetime_tools.py +42 -0
- aiq/tool/document_search.py +141 -0
- aiq/tool/github_tools/__init__.py +0 -0
- aiq/tool/github_tools/create_github_commit.py +133 -0
- aiq/tool/github_tools/create_github_issue.py +87 -0
- aiq/tool/github_tools/create_github_pr.py +106 -0
- aiq/tool/github_tools/get_github_file.py +106 -0
- aiq/tool/github_tools/get_github_issue.py +166 -0
- aiq/tool/github_tools/get_github_pr.py +256 -0
- aiq/tool/github_tools/update_github_issue.py +100 -0
- aiq/tool/mcp/__init__.py +14 -0
- aiq/tool/mcp/exceptions.py +142 -0
- aiq/tool/mcp/mcp_client.py +255 -0
- aiq/tool/mcp/mcp_tool.py +96 -0
- aiq/tool/memory_tools/__init__.py +0 -0
- aiq/tool/memory_tools/add_memory_tool.py +79 -0
- aiq/tool/memory_tools/delete_memory_tool.py +67 -0
- aiq/tool/memory_tools/get_memory_tool.py +72 -0
- aiq/tool/nvidia_rag.py +95 -0
- aiq/tool/register.py +38 -0
- aiq/tool/retriever.py +89 -0
- aiq/tool/server_tools.py +66 -0
- aiq/utils/__init__.py +0 -0
- aiq/utils/data_models/__init__.py +0 -0
- aiq/utils/data_models/schema_validator.py +58 -0
- aiq/utils/debugging_utils.py +43 -0
- aiq/utils/dump_distro_mapping.py +32 -0
- aiq/utils/exception_handlers/__init__.py +0 -0
- aiq/utils/exception_handlers/automatic_retries.py +289 -0
- aiq/utils/exception_handlers/mcp.py +211 -0
- aiq/utils/exception_handlers/schemas.py +114 -0
- aiq/utils/io/__init__.py +0 -0
- aiq/utils/io/model_processing.py +28 -0
- aiq/utils/io/yaml_tools.py +119 -0
- aiq/utils/log_utils.py +37 -0
- aiq/utils/metadata_utils.py +74 -0
- aiq/utils/optional_imports.py +142 -0
- aiq/utils/producer_consumer_queue.py +178 -0
- aiq/utils/reactive/__init__.py +0 -0
- aiq/utils/reactive/base/__init__.py +0 -0
- aiq/utils/reactive/base/observable_base.py +65 -0
- aiq/utils/reactive/base/observer_base.py +55 -0
- aiq/utils/reactive/base/subject_base.py +79 -0
- aiq/utils/reactive/observable.py +59 -0
- aiq/utils/reactive/observer.py +76 -0
- aiq/utils/reactive/subject.py +131 -0
- aiq/utils/reactive/subscription.py +49 -0
- aiq/utils/settings/__init__.py +0 -0
- aiq/utils/settings/global_settings.py +197 -0
- aiq/utils/string_utils.py +38 -0
- aiq/utils/type_converter.py +290 -0
- aiq/utils/type_utils.py +484 -0
- aiq/utils/url_utils.py +27 -0
- nvidia_nat-1.2.0rc5.dist-info/METADATA +363 -0
- nvidia_nat-1.2.0rc5.dist-info/RECORD +435 -0
- nvidia_nat-1.2.0rc5.dist-info/WHEEL +5 -0
- nvidia_nat-1.2.0rc5.dist-info/entry_points.txt +20 -0
- nvidia_nat-1.2.0rc5.dist-info/licenses/LICENSE-3rd-party.txt +3686 -0
- nvidia_nat-1.2.0rc5.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat-1.2.0rc5.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import logging
|
|
19
|
+
from contextlib import asynccontextmanager
|
|
20
|
+
from enum import Enum
|
|
21
|
+
from typing import Any
|
|
22
|
+
|
|
23
|
+
from mcp import ClientSession
|
|
24
|
+
from mcp.client.sse import sse_client
|
|
25
|
+
from mcp.types import TextContent
|
|
26
|
+
from pydantic import BaseModel
|
|
27
|
+
from pydantic import Field
|
|
28
|
+
from pydantic import create_model
|
|
29
|
+
|
|
30
|
+
from aiq.tool.mcp.exceptions import MCPToolNotFoundError
|
|
31
|
+
from aiq.utils.exception_handlers.mcp import mcp_exception_handler
|
|
32
|
+
|
|
33
|
+
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def model_from_mcp_schema(name: str, mcp_input_schema: dict) -> type[BaseModel]:
    """
    Create a pydantic model from the input schema of the MCP tool.

    The generated class is named ``<Name>InputSchema``, where ``<Name>`` is ``name`` converted to
    PascalCase. Object properties that declare their own ``properties`` become nested models.

    Args:
        name: Name used to derive the generated model's class name.
        mcp_input_schema: JSON-schema style dict. At the top level ``properties`` and ``required`` are
            read; for each property ``type``, ``enum``, ``properties``, ``items``, ``default``,
            ``nullable`` and ``description`` are read.

    Returns:
        The dynamically created pydantic model class.
    """
    _type_map = {
        "string": str,
        "number": float,
        "integer": int,
        "boolean": bool,
        "array": list,
        "null": None,
        "object": dict,
    }

    properties = mcp_input_schema.get("properties", {})
    required_fields = set(mcp_input_schema.get("required", []))
    schema_dict = {}

    def _generate_valid_classname(class_name: str):
        return class_name.replace('_', ' ').replace('-', ' ').title().replace(' ', '')

    def _make_optional(field_type):
        # `None | None` raises TypeError, so a bare "null" type (mapped to None) is returned unchanged.
        return field_type if field_type is None else field_type | None

    def _resolve_type(json_type):
        """Map a JSON-schema ``type`` (a single name or a list of names) to a Python annotation."""
        if not isinstance(json_type, list):
            return _type_map.get(json_type, Any)

        members = [_type_map.get(t, Any) for t in json_type if t != "null"]
        if not members:
            # Only "null" was listed (or the list was empty).
            return None if "null" in json_type else Any

        resolved = members[0]
        for member in members[1:]:
            resolved = resolved | member
        # Add None regardless of where "null" appeared in the list.
        return _make_optional(resolved) if "null" in json_type else resolved

    def _generate_field(field_name: str, field_properties: dict[str, Any]) -> tuple:
        json_type = field_properties.get("type", "string")
        enum_vals = field_properties.get("enum")

        if enum_vals:
            enum_name = f"{field_name.capitalize()}Enum"
            # Enum member names must be strings; the member values keep their original JSON type.
            field_type = Enum(enum_name, {str(item): item for item in enum_vals})
        elif json_type == "object" and "properties" in field_properties:
            field_type = model_from_mcp_schema(name=field_name, mcp_input_schema=field_properties)
        elif json_type == "array" and "items" in field_properties:
            item_properties = field_properties.get("items", {})
            if item_properties.get("type") == "object" and "properties" in item_properties:
                item_type = model_from_mcp_schema(name=field_name, mcp_input_schema=item_properties)
            else:
                # Objects without declared properties stay plain dicts (an empty model would drop every key).
                item_type = _resolve_type(item_properties.get("type", "string"))
            field_type = list[item_type]
        else:
            field_type = _resolve_type(json_type)

        # Type lists containing "null" have always defaulted to None. Keep that behaviour, but
        # otherwise honor the schema's "required" list for every form of "type".
        accepts_null = isinstance(json_type, list) and "null" in json_type

        if "default" in field_properties:
            # An explicit default always wins and leaves the annotation unchanged.
            default_value = field_properties["default"]
        elif field_name in required_fields and not accepts_null:
            # Required with no default: pydantic's Ellipsis marks the field as required.
            default_value = ...
        else:
            # Optional with no default: default to None and widen the type to allow it.
            default_value = None
            field_type = _make_optional(field_type)

        nullable = field_properties.get("nullable", False)
        description = field_properties.get("description", "")

        if nullable:
            field_type = _make_optional(field_type)

        return field_type, Field(default=default_value, description=description)

    for field_name, field_props in properties.items():
        schema_dict[field_name] = _generate_field(field_name=field_name, field_properties=field_props)
    return create_model(f"{_generate_valid_classname(name)}InputSchema", **schema_dict)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class MCPSSEClient:
    """
    Opens sessions against an MCP server over the SSE transport.

    Args:
        url (str): The url of the MCP server
    """

    def __init__(self, url: str):
        self.url = url

    @asynccontextmanager
    async def connect_to_sse_server(self):
        """
        Async context manager that yields an initialized ``ClientSession`` for ``self.url``.

        The SSE streams and the session are both closed when the ``async with`` block exits.
        """
        async with sse_client(url=self.url) as streams:
            reader, writer = streams
            async with ClientSession(reader, writer) as active_session:
                # The session is initialized before being handed to the caller.
                await active_session.initialize()
                yield active_session
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
class MCPBuilder(MCPSSEClient):
    """
    Connects to an MCP server and produces ``MCPToolClient`` wrappers for the tools it serves.

    Args:
        url (str): The url of the MCP server
    """

    def __init__(self, url):
        super().__init__(url)
        # Filled by get_tool() on first use; a falsy value (None or empty dict) triggers a fetch.
        self._tools = None

    @mcp_exception_handler
    async def get_tools(self):
        """
        Fetch every tool the MCP server advertises.

        Returns:
            Dict of tool name to MCPToolClient

        Raises:
            MCPError: If connection or tool retrieval fails
        """
        async with self.connect_to_sse_server() as session:
            listing = await session.list_tools()

        tools = {}
        for entry in listing.tools:
            tools[entry.name] = MCPToolClient(self.url,
                                              entry.name,
                                              entry.description,
                                              tool_input_schema=entry.inputSchema)
        return tools

    @mcp_exception_handler
    async def get_tool(self, tool_name: str) -> MCPToolClient:
        """
        Look up a single MCP tool by name, fetching the tool list on first use.

        Args:
            tool_name (str): Name of the tool to load.

        Returns:
            MCPToolClient for the configured tool.

        Raises:
            MCPToolNotFoundError: If no tool is available with that name
            MCPError: If connection fails
        """
        if not self._tools:
            self._tools = await self.get_tools()

        found = self._tools.get(tool_name)
        if found:
            return found
        raise MCPToolNotFoundError(tool_name, self.url)

    @mcp_exception_handler
    async def call_tool(self, tool_name: str, tool_args: dict | None):
        """Call ``tool_name`` with ``tool_args`` in a new session and return the raw MCP result."""
        async with self.connect_to_sse_server() as session:
            return await session.call_tool(tool_name, tool_args)
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
class MCPToolClient(MCPSSEClient):
    """
    Callable handle for a single tool exposed by an MCP server.

    Each ``acall`` opens a new SSE session through ``connect_to_sse_server``.

    Args:
        url (str): The url of the MCP server
        tool_name (str): The name of the tool to wrap
        tool_description (str | None): The description of the tool provided by the MCP server.
        tool_input_schema (dict | None): The input schema for the tool. When provided, it is converted
            into a pydantic model and exposed as ``input_schema``.
    """

    def __init__(self, url: str, tool_name: str, tool_description: str | None, tool_input_schema: dict | None = None):
        super().__init__(url)
        self._tool_name = tool_name
        self._tool_description = tool_description
        if tool_input_schema:
            self._input_schema = model_from_mcp_schema(self._tool_name, tool_input_schema)
        else:
            self._input_schema = None

    @property
    def name(self):
        """The tool's name."""
        return self._tool_name

    @property
    def description(self):
        """
        The tool's description, or ``"MCP Tool <name>"`` when no (or an empty) description is set.
        """
        return self._tool_description or f"MCP Tool {self._tool_name}"

    @property
    def input_schema(self):
        """
        Pydantic model built from the tool's input schema, or None if no schema was supplied.
        """
        return self._input_schema

    def set_description(self, description: str):
        """
        Replace the tool's description with ``description``.
        """
        self._tool_description = description

    @mcp_exception_handler
    async def acall(self, tool_args: dict) -> str:
        """
        Invoke the MCP tool.

        Args:
            tool_args (dict[str, Any]): Key/value inputs forwarded unchanged to the MCP tool.

        Returns:
            The text parts of the result joined by newlines. Non-text parts are skipped and logged
            as warnings.
        """
        async with self.connect_to_sse_server() as session:
            result = await session.call_tool(self._tool_name, tool_args)

        texts: list[str] = []
        for part in result.content:
            if not isinstance(part, TextContent):
                # Only text output is supported for now; other content types are logged and dropped.
                logger.warning("Got not-text output from %s of type %s", self.name, type(part))
                continue
            texts.append(part.text)
        return "\n".join(texts)
|
aiq/tool/mcp/mcp_tool.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
|
|
18
|
+
from pydantic import BaseModel
|
|
19
|
+
from pydantic import Field
|
|
20
|
+
from pydantic import HttpUrl
|
|
21
|
+
|
|
22
|
+
from aiq.builder.builder import Builder
|
|
23
|
+
from aiq.builder.function_info import FunctionInfo
|
|
24
|
+
from aiq.cli.register_workflow import register_function
|
|
25
|
+
from aiq.data_models.function import FunctionBaseConfig
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)


class MCPToolConfig(FunctionBaseConfig, name="mcp_tool_wrapper"):
    """
    Function which connects to a Model Context Protocol (MCP) server and wraps the selected tool as a NeMo Agent toolkit
    function.
    """
    # Field descriptions are plain single-line strings (like the other tool configs) so that no stray newlines or
    # source indentation leak into the generated schema / CLI help text.
    url: HttpUrl = Field(description="The URL of the MCP server")
    mcp_tool_name: str = Field(description="The name of the tool served by the MCP Server that you want to use")
    description: str | None = Field(default=None,
                                    description=("Description for the tool that will override the description "
                                                 "provided by the MCP server. Should only be used if the description "
                                                 "provided by the server is poor or nonexistent."))
    return_exception: bool = Field(default=True,
                                   description=("If true, the tool will return the exception message if the tool "
                                                "call fails. If false, raise the exception."))
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@register_function(config_type=MCPToolConfig)
async def mcp_tool(config: MCPToolConfig, builder: Builder):  # pylint: disable=unused-argument
    """
    Generate an AIQ Toolkit Function that wraps a tool provided by the MCP server.

    Args:
        config: The MCP server ``url``, the ``mcp_tool_name`` to wrap, an optional ``description`` override, and
            ``return_exception``. That flag selects whether call errors are returned as text or re-raised.
        builder: Unused; part of the ``register_function`` factory signature.

    Yields:
        FunctionInfo wrapping the remote MCP tool, using the tool's description and input schema.
    """

    from aiq.tool.mcp.mcp_client import MCPBuilder
    from aiq.tool.mcp.mcp_client import MCPToolClient

    client = MCPBuilder(url=str(config.url))

    tool: MCPToolClient = await client.get_tool(config.mcp_tool_name)
    if config.description:
        tool.set_description(description=config.description)

    logger.info("Configured to use tool: %s from MCP server at %s", tool.name, str(config.url))

    def _convert_from_str(input_str: str) -> tool.input_schema:
        return tool.input_schema.model_validate_json(input_str)

    async def _response_fn(tool_input: BaseModel | None = None, **kwargs) -> str:
        # Run the tool, catching any errors and sending to agent for correction
        try:
            if tool_input is not None:
                args = tool_input.model_dump()
                return await tool.acall(args)

            # MCPToolClient stores None as its input schema when the server advertises none. Only validate when a
            # schema exists; otherwise `None.model_validate` would make every call fail.
            if tool.input_schema is not None:
                tool.input_schema.model_validate(kwargs)
            # Arguments whose value is None are not forwarded to the server.
            filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
            return await tool.acall(filtered_kwargs)
        except Exception as e:
            if config.return_exception:
                if tool_input is not None:
                    logger.warning("Error calling tool %s with serialized input: %s",
                                   tool.name,
                                   tool_input.model_dump(),
                                   exc_info=True)
                else:
                    logger.warning("Error calling tool %s with input: %s", tool.name, kwargs, exc_info=True)
                return str(e)
            # If the tool call fails, raise the exception.
            raise

    yield FunctionInfo.create(single_fn=_response_fn,
                              description=tool.description,
                              input_schema=tool.input_schema,
                              converters=[_convert_from_str])
|
|
File without changes
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
|
|
18
|
+
from pydantic import Field
|
|
19
|
+
|
|
20
|
+
from aiq.builder.builder import Builder
|
|
21
|
+
from aiq.builder.function_info import FunctionInfo
|
|
22
|
+
from aiq.cli.register_workflow import register_function
|
|
23
|
+
from aiq.data_models.component_ref import MemoryRef
|
|
24
|
+
from aiq.data_models.function import FunctionBaseConfig
|
|
25
|
+
from aiq.memory.models import MemoryItem
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)


class AddToolConfig(FunctionBaseConfig, name="add_memory"):
    """Configuration for the ``add_memory`` tool, which adds memory to a hosted memory platform."""

    # Text shown to tool-calling agents describing when to use this tool.
    description: str = Field(
        default="Tool to add memory about a user's interactions to a system for retrieval later.",
        description="The description of this function's use for tool calling agents.",
    )
    # Name of the memory client, declared in the workflow configuration, that receives new memory items.
    memory: MemoryRef = Field(
        default="saas_memory",
        description="Instance name of the memory client instance from the workflow configuration object.",
    )
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@register_function(config_type=AddToolConfig)
async def add_memory_tool(config: AddToolConfig, builder: Builder):
    """
    Function to add memory to a hosted memory platform.

    Yields:
        FunctionInfo wrapping ``_arun``, which stores one ``MemoryItem`` through the configured memory client.
    """
    from langchain_core.tools import ToolException

    # First, retrieve the memory client
    memory_editor = builder.get_memory_client(config.memory)

    async def _arun(item: MemoryItem) -> str:
        """
        Asynchronous execution of addition of memories.

        Args:
            item (MemoryItem): The memory item to add. Must include:
                - conversation: List of dicts with "role" and "content" keys
                - user_id: String identifier for the user
                - metadata: Dict of metadata (can be empty)
                - tags: Optional list of tags
                - memory: Optional memory string

        Note: If conversation is not provided, it will be created from the memory field
        if available, otherwise an error will be raised.

        Returns:
            A confirmation message for the agent.

        Raises:
            ToolException: If neither ``conversation`` nor ``memory`` is provided, or if the memory client fails.
        """
        # Validate before the try block so this message is not re-wrapped as "Error adding memory: ...".
        if not item.conversation:
            if not item.memory:
                raise ToolException("Either conversation or memory must be provided")
            # Build the conversation on a copy so the caller's MemoryItem is left unmodified.
            item = item.model_copy(update={"conversation": [{"role": "user", "content": item.memory}]})

        try:
            await memory_editor.add_items([item])
        except Exception as e:
            raise ToolException(f"Error adding memory: {e}") from e

        return "Memory added successfully. You can continue. Please respond to the user."

    yield FunctionInfo.from_fn(_arun, description=config.description)
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
|
|
18
|
+
from pydantic import Field
|
|
19
|
+
|
|
20
|
+
from aiq.builder.builder import Builder
|
|
21
|
+
from aiq.builder.function_info import FunctionInfo
|
|
22
|
+
from aiq.cli.register_workflow import register_function
|
|
23
|
+
from aiq.data_models.component_ref import MemoryRef
|
|
24
|
+
from aiq.data_models.function import FunctionBaseConfig
|
|
25
|
+
from aiq.memory.models import DeleteMemoryInput
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)


class DeleteToolConfig(FunctionBaseConfig, name="delete_memory"):
    """Function to delete memory from a hosted memory platform."""

    # This text is what tool-calling agents read to decide when to call the tool. It must describe the delete
    # action accurately; describing it as retrieval would lead agents to erase memories while trying to read them.
    description: str = Field(default=("Tool to delete memories about a user's interactions "
                                      "from the memory system."),
                             description="The description of this function's use for tool calling agents.")
    memory: MemoryRef = Field(default="saas_memory",
                              description=("Instance name of the memory client instance from the workflow "
                                           "configuration object."))
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@register_function(config_type=DeleteToolConfig)
async def delete_memory_tool(config: DeleteToolConfig, builder: Builder):
    """
    Register a tool that removes a user's memories through the configured memory client.

    Yields:
        FunctionInfo wrapping ``_arun``; its input is described by ``DeleteMemoryInput``.
    """
    from langchain_core.tools import ToolException

    editor = builder.get_memory_client(config.memory)

    async def _arun(user_id: str) -> str:
        """
        Asynchronously remove the memories stored for ``user_id``.

        Raises:
            ToolException: Wrapping any error raised by the memory client.
        """
        try:
            await editor.remove_items(user_id=user_id)
        except Exception as e:
            raise ToolException(f"Error deleting memory: {e}") from e
        return "Memories deleted!"

    yield FunctionInfo.from_fn(_arun, description=config.description, input_schema=DeleteMemoryInput)
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
|
|
18
|
+
from pydantic import Field
|
|
19
|
+
|
|
20
|
+
from aiq.builder.builder import Builder
|
|
21
|
+
from aiq.builder.function_info import FunctionInfo
|
|
22
|
+
from aiq.cli.register_workflow import register_function
|
|
23
|
+
from aiq.data_models.component_ref import MemoryRef
|
|
24
|
+
from aiq.data_models.function import FunctionBaseConfig
|
|
25
|
+
from aiq.memory.models import SearchMemoryInput
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)


class GetToolConfig(FunctionBaseConfig, name="get_memory"):
    """Configuration for the ``get_memory`` tool, which retrieves memory from a hosted memory platform."""

    # Text shown to tool-calling agents describing when to use this tool.
    description: str = Field(
        default="Tool to retrieve memory about a user's interactions to help answer questions in a personalized way.",
        description="The description of this function's use for tool calling agents.",
    )
    # Name of the memory client, declared in the workflow configuration, that is searched.
    memory: MemoryRef = Field(
        default="saas_memory",
        description="Instance name of the memory client instance from the workflow configuration object.",
    )
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@register_function(config_type=GetToolConfig)
async def get_memory_tool(config: GetToolConfig, builder: Builder):
    """
    Function to get memory from a hosted memory platform.

    Yields:
        FunctionInfo wrapping ``_arun``, which searches the configured memory client and returns the matches as JSON.
    """

    import json

    from langchain_core.tools import ToolException

    # First, retrieve the memory client
    memory_editor = builder.get_memory_client(config.memory)

    async def _arun(search_input: SearchMemoryInput) -> str:
        """
        Asynchronously search for memories matching ``search_input``.

        Args:
            search_input: The ``query``, the number of results ``top_k`` and the ``user_id`` to search with.

        Returns:
            A "Memories as a JSON:" header line followed by a JSON array of the matching memories.

        Raises:
            ToolException: Wrapping any error raised while searching or serializing the memories.
        """
        try:
            memories = await memory_editor.search(
                query=search_input.query,
                top_k=search_input.top_k,
                user_id=search_input.user_id,
            )

            memory_str = f"Memories as a JSON: \n{json.dumps([mem.model_dump(mode='json') for mem in memories])}"
            return memory_str

        except Exception as e:

            raise ToolException(f"Error retrieving memory: {e}") from e

    yield FunctionInfo.from_fn(_arun, description=config.description)
|
aiq/tool/nvidia_rag.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import json
|
|
17
|
+
import logging
|
|
18
|
+
|
|
19
|
+
from pydantic import Field
|
|
20
|
+
|
|
21
|
+
from aiq.builder.builder import Builder
|
|
22
|
+
from aiq.builder.function_info import FunctionInfo
|
|
23
|
+
from aiq.cli.register_workflow import register_function
|
|
24
|
+
from aiq.data_models.function import FunctionBaseConfig
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)


class NVIDIARAGToolConfig(FunctionBaseConfig, name="nvidia_rag"):
    """
    Configuration for a tool that searches the NVIDIA Developer database for documents across a variety of NVIDIA
    asset types.
    """
    # Connection settings for the RAG service.
    base_url: str = Field(description="The base url to the RAG service.")
    timeout: int = Field(default=60, description="The timeout configuration to use when sending requests.")

    # Formatting of the retrieved documents. The template placeholders are filled from each retrieved chunk.
    document_separator: str = Field(default="\n\n", description="The delimiter to use between retrieved documents.")
    document_prompt: str = Field(default=("-------\n\n"
                                          "Title: {document_title}\n"
                                          "Text: {page_content}\n"
                                          "Source URL: {document_url}"),
                                 description="The prompt to use to retrieve documents from the RAG service")

    # Search parameters sent with every request.
    top_k: int = Field(default=4, description="The number of results to return from the RAG service.")
    collection_name: str = Field(default="nvidia_api_catalog",
                                 description="The name of the collection to use when retrieving documents.")
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@register_function(config_type=NVIDIARAGToolConfig)
async def nvidia_rag_tool(config: NVIDIARAGToolConfig, builder: Builder):
    """
    Register a tool that queries the RAG service's ``/search`` endpoint and formats the returned chunks.

    A single ``httpx.AsyncClient`` is kept open for as long as the function is in use.

    Yields:
        FunctionInfo wrapping ``runnable(query)``. It returns the retrieved chunks formatted with
        ``config.document_prompt`` and joined by ``config.document_separator``. If the request or parsing fails, it
        returns an error message string instead.
    """
    import httpx
    from langchain.prompts import PromptTemplate
    from langchain_core.documents import Document
    from langchain_core.prompts import aformat_document

    document_prompt = PromptTemplate.from_template(config.document_prompt)
    # Strip a trailing slash so that e.g. "http://host/" does not produce "http://host//search".
    search_url = f"{config.base_url.rstrip('/')}/search"

    async with httpx.AsyncClient(headers={
            "accept": "application/json", "Content-Type": "application/json"
    },
                                 timeout=config.timeout) as client:

        async def runnable(query: str) -> str:

            try:
                payload = {"query": query, "top_k": config.top_k, "collection_name": config.collection_name}

                logger.debug("Sending request to the RAG endpoint %s.", search_url)
                response = await client.post(search_url, content=json.dumps(payload))

                response.raise_for_status()

                output = response.json()

                docs = [
                    Document(
                        page_content=ret["content"],
                        metadata={
                            "document_title": ret["filename"],
                            # NOTE(review): hard-coded placeholder, so every document's "Source URL" renders as
                            # "nemo_framework". Confirm whether the service returns a real URL field.
                            "document_url": "nemo_framework",
                            "document_full_text": ret["content"],
                            "score_rerank": ret["score"]
                        },
                        type="Document",
                    ) for ret in output["chunks"]
                ]

                parsed_output = config.document_separator.join(
                    [await aformat_document(doc, document_prompt) for doc in docs])
                return parsed_output
            except Exception as e:
                # logger.exception already attaches the current traceback.
                logger.exception("Error while running the tool")
                return f"Error while running the tool: {e}"

        # Yield inside the `async with` so the HTTP client stays open while the tool is in use.
        yield FunctionInfo.from_fn(
            runnable,
            description=("Search the NVIDIA Developer database for documents across a variety of "
                         "NVIDIA asset types"))
|